diff options
| author | Kevin Newton <kddnewton@gmail.com> | 2024-04-29 14:33:54 -0400 |
|---|---|---|
| committer | Kevin Newton <kddnewton@gmail.com> | 2024-05-01 12:34:29 -0400 |
| commit | b6fa18fbe90c63b2979a4f1f8aecab1de4373664 (patch) | |
| tree | 22490e92a0151da5eb1bc50179eed12389333d6a | |
| parent | 1b8650964bb4c69e23a1e9e5f6b2d14bfe0b698a (diff) | |
[PRISM] Properly precheck regexp for encoding issues
| -rw-r--r-- | prism_compile.c | 100 | ||||
| -rw-r--r-- | test/.excludes-prism/TestM17N.rb | 3 | ||||
| -rw-r--r-- | test/.excludes-prism/TestRegexp.rb | 1 |
3 files changed, 57 insertions, 47 deletions
diff --git a/prism_compile.c b/prism_compile.c index 4a4ba5acaa..915e80d65e 100644 --- a/prism_compile.c +++ b/prism_compile.c @@ -351,8 +351,34 @@ pm_optimizable_range_item_p(const pm_node_t *node) return (!node || PM_NODE_TYPE_P(node, PM_INTEGER_NODE) || PM_NODE_TYPE_P(node, PM_NIL_NODE)); } +/** Raise an error corresponding to the invalid regular expression. */ +static VALUE +parse_regexp_error(rb_iseq_t *iseq, int32_t line_number, const char *fmt, ...) +{ + va_list args; + va_start(args, fmt); + VALUE error = rb_syntax_error_append(Qnil, rb_iseq_path(iseq), line_number, -1, NULL, "%" PRIsVALUE, args); + va_end(args); + rb_exc_raise(error); +} + static VALUE -pm_static_literal_concat(const pm_node_list_t *nodes, const pm_scope_node_t *scope_node, bool top) +parse_regexp_string_part(rb_iseq_t *iseq, const pm_scope_node_t *scope_node, const pm_node_t *node, const pm_string_t *unescaped, rb_encoding *regexp_encoding) +{ + // If we were passed an explicit regexp encoding, then we need to double + // check that it's okay here for this fragment of the string. + VALUE string = rb_enc_str_new((const char *) pm_string_source(unescaped), pm_string_length(unescaped), regexp_encoding); + VALUE error = rb_reg_check_preprocess(string); + + if (error != Qnil) { + parse_regexp_error(iseq, pm_node_line_number(scope_node->parser, node), "%" PRIsVALUE, rb_obj_as_string(error)); + } + + return string; +} + +static VALUE +pm_static_literal_concat(rb_iseq_t *iseq, const pm_node_list_t *nodes, const pm_scope_node_t *scope_node, rb_encoding *regexp_encoding, bool top) { VALUE current = Qnil; @@ -362,10 +388,16 @@ pm_static_literal_concat(const pm_node_list_t *nodes, const pm_scope_node_t *sco switch (PM_NODE_TYPE(part)) { case PM_STRING_NODE: - string = parse_string_encoded(part, &((const pm_string_node_t *) part)->unescaped, scope_node->encoding); + if (regexp_encoding == NULL) { + string = parse_string_encoded(part, &((const pm_string_node_t *) part)->unescaped, scope_node->encoding); + } + else { + string = parse_regexp_string_part(iseq, scope_node, part, &((const pm_string_node_t *) part)->unescaped, regexp_encoding); + } + break; case PM_INTERPOLATED_STRING_NODE: - string = pm_static_literal_concat(&((const pm_interpolated_string_node_t *) part)->parts, scope_node, false); + string = pm_static_literal_concat(iseq, &((const pm_interpolated_string_node_t *) part)->parts, scope_node, regexp_encoding, false); break; default: RUBY_ASSERT(false && "unexpected node type in pm_static_literal_concat"); @@ -445,7 +477,7 @@ parse_regexp_flags(const pm_node_t *node) #undef ENC_UTF8 static rb_encoding * -parse_regexp_encoding(const pm_node_t *node) +parse_regexp_encoding(const pm_scope_node_t *scope_node, const pm_node_t *node) { if (PM_NODE_FLAG_P(node, PM_REGULAR_EXPRESSION_FLAGS_ASCII_8BIT)) { return rb_ascii8bit_encoding(); @@ -460,21 +492,10 @@ parse_regexp_encoding(const pm_node_t *node) return rb_enc_get_from_index(ENCINDEX_Windows_31J); } else { - return NULL; + return scope_node->encoding; } } -/** Raise an error corresponding to the invalid regular expression. */ -static VALUE -parse_regexp_error(rb_iseq_t *iseq, int32_t line_number, const char *fmt, ...) -{ - va_list args; - va_start(args, fmt); - VALUE error = rb_syntax_error_append(Qnil, rb_iseq_path(iseq), line_number, -1, NULL, "%" PRIsVALUE, args); - va_end(args); - rb_exc_raise(error); -} - static VALUE parse_regexp(rb_iseq_t *iseq, const pm_scope_node_t *scope_node, const pm_node_t *node, VALUE string) { @@ -498,22 +519,16 @@ parse_regexp(rb_iseq_t *iseq, const pm_scope_node_t *scope_node, const pm_node_t static inline VALUE parse_regexp_literal(rb_iseq_t *iseq, const pm_scope_node_t *scope_node, const pm_node_t *node, const pm_string_t *unescaped) { - rb_encoding *encoding = parse_regexp_encoding(node); - if (encoding == NULL) encoding = scope_node->encoding; - - VALUE string = rb_enc_str_new((const char *) pm_string_source(unescaped), pm_string_length(unescaped), encoding); + rb_encoding *regexp_encoding = parse_regexp_encoding(scope_node, node); + VALUE string = rb_enc_str_new((const char *) pm_string_source(unescaped), pm_string_length(unescaped), regexp_encoding); return parse_regexp(iseq, scope_node, node, string); } static inline VALUE parse_regexp_concat(rb_iseq_t *iseq, const pm_scope_node_t *scope_node, const pm_node_t *node, const pm_node_list_t *parts) { - rb_encoding *encoding = parse_regexp_encoding(node); - if (encoding == NULL) encoding = scope_node->encoding; - - VALUE string = pm_static_literal_concat(parts, scope_node, false); - rb_enc_associate(string, encoding); - + rb_encoding *regexp_encoding = parse_regexp_encoding(scope_node, node); + VALUE string = pm_static_literal_concat(iseq, parts, scope_node, regexp_encoding, false); return parse_regexp(iseq, scope_node, node, string); } @@ -528,20 +543,19 @@ pm_interpolated_node_compile(rb_iseq_t *iseq, const pm_node_list_t *parts, const if (parts_size > 0) { VALUE current_string = Qnil; - rb_encoding *default_encoding = regexp_encoding != NULL ? regexp_encoding : scope_node->encoding; for (size_t index = 0; index < parts_size; index++) { const pm_node_t *part = parts->nodes[index]; if (PM_NODE_TYPE_P(part, PM_STRING_NODE)) { const pm_string_node_t *string_node = (const pm_string_node_t *) part; - VALUE string_value = parse_string_encoded((const pm_node_t *) string_node, &string_node->unescaped, default_encoding); + VALUE string_value; - // If we were passed an explicit regexp encoding, then we need - // to double check that it's okay here. - if (regexp_encoding != NULL) { - VALUE error = rb_reg_check_preprocess(string_value); - if (error != Qnil) parse_regexp_error(iseq, pm_node_line_number(scope_node->parser, (const pm_node_t *) string_node), "%" PRIsVALUE, rb_obj_as_string(error)); + if (regexp_encoding == NULL) { + string_value = parse_string_encoded(part, &string_node->unescaped, scope_node->encoding); + } + else { + string_value = parse_regexp_string_part(iseq, scope_node, (const pm_node_t *) string_node, &string_node->unescaped, regexp_encoding); } if (RTEST(current_string)) { @@ -561,13 +575,13 @@ pm_interpolated_node_compile(rb_iseq_t *iseq, const pm_node_list_t *parts, const PM_NODE_TYPE_P(((const pm_embedded_statements_node_t *) part)->statements->body.nodes[0], PM_STRING_NODE) ) { const pm_string_node_t *string_node = (const pm_string_node_t *) ((const pm_embedded_statements_node_t *) part)->statements->body.nodes[0]; - VALUE string_value = parse_string_encoded((const pm_node_t *) string_node, &string_node->unescaped, default_encoding); + VALUE string_value; - // If we were passed an explicit regexp encoding, then we - // need to double check that it's okay here. - if (regexp_encoding != NULL) { - VALUE error = rb_reg_check_preprocess(string_value); - if (error != Qnil) parse_regexp_error(iseq, pm_node_line_number(scope_node->parser, (const pm_node_t *) string_node), "%" PRIsVALUE, rb_obj_as_string(error)); + if (regexp_encoding == NULL) { + string_value = parse_string_encoded(part, &string_node->unescaped, scope_node->encoding); + } + else { + string_value = parse_regexp_string_part(iseq, scope_node, (const pm_node_t *) string_node, &string_node->unescaped, regexp_encoding); } if (RTEST(current_string)) { @@ -579,7 +593,7 @@ pm_interpolated_node_compile(rb_iseq_t *iseq, const pm_node_list_t *parts, const } else { if (!RTEST(current_string)) { - current_string = rb_enc_str_new(NULL, 0, default_encoding); + current_string = rb_enc_str_new(NULL, 0, regexp_encoding != NULL ? regexp_encoding : scope_node->encoding); } PUSH_INSN1(ret, *node_location, putobject, rb_fstring(current_string)); @@ -618,7 +632,7 @@ pm_interpolated_node_compile(rb_iseq_t *iseq, const pm_node_list_t *parts, const static void pm_compile_regexp_dynamic(rb_iseq_t *iseq, const pm_node_t *node, const pm_node_list_t *parts, const pm_line_column_t *node_location, LINK_ANCHOR *const ret, bool popped, pm_scope_node_t *scope_node) { - rb_encoding *regexp_encoding = parse_regexp_encoding(node); + rb_encoding *regexp_encoding = parse_regexp_encoding(scope_node, node); int length = pm_interpolated_node_compile(iseq, parts, node_location, ret, popped, scope_node, regexp_encoding); PUSH_INSN2(ret, *node_location, toregexp, INT2FIX(parse_regexp_flags(node) & 0xFF), INT2FIX(length)); @@ -717,13 +731,13 @@ pm_static_literal_value(rb_iseq_t *iseq, const pm_node_t *node, const pm_scope_n return parse_regexp_concat(iseq, scope_node, (const pm_node_t *) cast, &cast->parts); } case PM_INTERPOLATED_STRING_NODE: { - VALUE string = pm_static_literal_concat(&((const pm_interpolated_string_node_t *) node)->parts, scope_node, false); + VALUE string = pm_static_literal_concat(iseq, &((const pm_interpolated_string_node_t *) node)->parts, scope_node, NULL, false); int line_number = pm_node_line_number(scope_node->parser, node); return pm_static_literal_string(iseq, string, line_number); } case PM_INTERPOLATED_SYMBOL_NODE: { const pm_interpolated_symbol_node_t *cast = (const pm_interpolated_symbol_node_t *) node; - VALUE string = pm_static_literal_concat(&cast->parts, scope_node, true); + VALUE string = pm_static_literal_concat(iseq, &cast->parts, scope_node, NULL, true); return ID2SYM(rb_intern_str(string)); } diff --git a/test/.excludes-prism/TestM17N.rb b/test/.excludes-prism/TestM17N.rb index bf11829020..25c4394634 100644 --- a/test/.excludes-prism/TestM17N.rb +++ b/test/.excludes-prism/TestM17N.rb @@ -1,6 +1,3 @@ -exclude(:test_dynamic_eucjp_regexp, "https://github.com/ruby/prism/issues/2664") -exclude(:test_dynamic_sjis_regexp, "https://github.com/ruby/prism/issues/2664") -exclude(:test_dynamic_utf8_regexp, "https://github.com/ruby/prism/issues/2664") exclude(:test_regexp_ascii, "https://github.com/ruby/prism/issues/2664") exclude(:test_regexp_usascii, "unknown") exclude(:test_string_mixed_unicode, "unknown") diff --git a/test/.excludes-prism/TestRegexp.rb b/test/.excludes-prism/TestRegexp.rb index 52d4a69025..1d41a4dc57 100644 --- a/test/.excludes-prism/TestRegexp.rb +++ b/test/.excludes-prism/TestRegexp.rb @@ -1,3 +1,2 @@ exclude(:test_invalid_escape_error, "unknown") -exclude(:test_invalid_fragment, "https://github.com/ruby/prism/issues/2664") exclude(:test_unescape, "unknown") |
