summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author Kevin Newton <kddnewton@gmail.com> 2024-04-29 14:33:54 -0400
committer Kevin Newton <kddnewton@gmail.com> 2024-05-01 12:34:29 -0400
commit b6fa18fbe90c63b2979a4f1f8aecab1de4373664 (patch)
tree 22490e92a0151da5eb1bc50179eed12389333d6a
parent 1b8650964bb4c69e23a1e9e5f6b2d14bfe0b698a (diff)
[PRISM] Properly precheck regexp for encoding issues
-rw-r--r-- prism_compile.c 100
-rw-r--r-- test/.excludes-prism/TestM17N.rb 3
-rw-r--r-- test/.excludes-prism/TestRegexp.rb 1
3 files changed, 57 insertions, 47 deletions
diff --git a/prism_compile.c b/prism_compile.c
index 4a4ba5acaa..915e80d65e 100644
--- a/prism_compile.c
+++ b/prism_compile.c
@@ -351,8 +351,34 @@ pm_optimizable_range_item_p(const pm_node_t *node)
return (!node || PM_NODE_TYPE_P(node, PM_INTEGER_NODE) || PM_NODE_TYPE_P(node, PM_NIL_NODE));
}
+/** Raise an error corresponding to the invalid regular expression. */
+static VALUE
+parse_regexp_error(rb_iseq_t *iseq, int32_t line_number, const char *fmt, ...)
+{
+ va_list args;
+ va_start(args, fmt);
+ VALUE error = rb_syntax_error_append(Qnil, rb_iseq_path(iseq), line_number, -1, NULL, "%" PRIsVALUE, args);
+ va_end(args);
+ rb_exc_raise(error);
+}
+
static VALUE
-pm_static_literal_concat(const pm_node_list_t *nodes, const pm_scope_node_t *scope_node, bool top)
+parse_regexp_string_part(rb_iseq_t *iseq, const pm_scope_node_t *scope_node, const pm_node_t *node, const pm_string_t *unescaped, rb_encoding *regexp_encoding)
+{
+ // If we were passed an explicit regexp encoding, then we need to double
+ // check that it's okay here for this fragment of the string.
+ VALUE string = rb_enc_str_new((const char *) pm_string_source(unescaped), pm_string_length(unescaped), regexp_encoding);
+ VALUE error = rb_reg_check_preprocess(string);
+
+ if (error != Qnil) {
+ parse_regexp_error(iseq, pm_node_line_number(scope_node->parser, node), "%" PRIsVALUE, rb_obj_as_string(error));
+ }
+
+ return string;
+}
+
+static VALUE
+pm_static_literal_concat(rb_iseq_t *iseq, const pm_node_list_t *nodes, const pm_scope_node_t *scope_node, rb_encoding *regexp_encoding, bool top)
{
VALUE current = Qnil;
@@ -362,10 +388,16 @@ pm_static_literal_concat(const pm_node_list_t *nodes, const pm_scope_node_t *sco
switch (PM_NODE_TYPE(part)) {
case PM_STRING_NODE:
- string = parse_string_encoded(part, &((const pm_string_node_t *) part)->unescaped, scope_node->encoding);
+ if (regexp_encoding == NULL) {
+ string = parse_string_encoded(part, &((const pm_string_node_t *) part)->unescaped, scope_node->encoding);
+ }
+ else {
+ string = parse_regexp_string_part(iseq, scope_node, part, &((const pm_string_node_t *) part)->unescaped, regexp_encoding);
+ }
+
break;
case PM_INTERPOLATED_STRING_NODE:
- string = pm_static_literal_concat(&((const pm_interpolated_string_node_t *) part)->parts, scope_node, false);
+ string = pm_static_literal_concat(iseq, &((const pm_interpolated_string_node_t *) part)->parts, scope_node, regexp_encoding, false);
break;
default:
RUBY_ASSERT(false && "unexpected node type in pm_static_literal_concat");
@@ -445,7 +477,7 @@ parse_regexp_flags(const pm_node_t *node)
#undef ENC_UTF8
static rb_encoding *
-parse_regexp_encoding(const pm_node_t *node)
+parse_regexp_encoding(const pm_scope_node_t *scope_node, const pm_node_t *node)
{
if (PM_NODE_FLAG_P(node, PM_REGULAR_EXPRESSION_FLAGS_ASCII_8BIT)) {
return rb_ascii8bit_encoding();
@@ -460,21 +492,10 @@ parse_regexp_encoding(const pm_node_t *node)
return rb_enc_get_from_index(ENCINDEX_Windows_31J);
}
else {
- return NULL;
+ return scope_node->encoding;
}
}
-/** Raise an error corresponding to the invalid regular expression. */
-static VALUE
-parse_regexp_error(rb_iseq_t *iseq, int32_t line_number, const char *fmt, ...)
-{
- va_list args;
- va_start(args, fmt);
- VALUE error = rb_syntax_error_append(Qnil, rb_iseq_path(iseq), line_number, -1, NULL, "%" PRIsVALUE, args);
- va_end(args);
- rb_exc_raise(error);
-}
-
static VALUE
parse_regexp(rb_iseq_t *iseq, const pm_scope_node_t *scope_node, const pm_node_t *node, VALUE string)
{
@@ -498,22 +519,16 @@ parse_regexp(rb_iseq_t *iseq, const pm_scope_node_t *scope_node, const pm_node_t
static inline VALUE
parse_regexp_literal(rb_iseq_t *iseq, const pm_scope_node_t *scope_node, const pm_node_t *node, const pm_string_t *unescaped)
{
- rb_encoding *encoding = parse_regexp_encoding(node);
- if (encoding == NULL) encoding = scope_node->encoding;
-
- VALUE string = rb_enc_str_new((const char *) pm_string_source(unescaped), pm_string_length(unescaped), encoding);
+ rb_encoding *regexp_encoding = parse_regexp_encoding(scope_node, node);
+ VALUE string = rb_enc_str_new((const char *) pm_string_source(unescaped), pm_string_length(unescaped), regexp_encoding);
return parse_regexp(iseq, scope_node, node, string);
}
static inline VALUE
parse_regexp_concat(rb_iseq_t *iseq, const pm_scope_node_t *scope_node, const pm_node_t *node, const pm_node_list_t *parts)
{
- rb_encoding *encoding = parse_regexp_encoding(node);
- if (encoding == NULL) encoding = scope_node->encoding;
-
- VALUE string = pm_static_literal_concat(parts, scope_node, false);
- rb_enc_associate(string, encoding);
-
+ rb_encoding *regexp_encoding = parse_regexp_encoding(scope_node, node);
+ VALUE string = pm_static_literal_concat(iseq, parts, scope_node, regexp_encoding, false);
return parse_regexp(iseq, scope_node, node, string);
}
@@ -528,20 +543,19 @@ pm_interpolated_node_compile(rb_iseq_t *iseq, const pm_node_list_t *parts, const
if (parts_size > 0) {
VALUE current_string = Qnil;
- rb_encoding *default_encoding = regexp_encoding != NULL ? regexp_encoding : scope_node->encoding;
for (size_t index = 0; index < parts_size; index++) {
const pm_node_t *part = parts->nodes[index];
if (PM_NODE_TYPE_P(part, PM_STRING_NODE)) {
const pm_string_node_t *string_node = (const pm_string_node_t *) part;
- VALUE string_value = parse_string_encoded((const pm_node_t *) string_node, &string_node->unescaped, default_encoding);
+ VALUE string_value;
- // If we were passed an explicit regexp encoding, then we need
- // to double check that it's okay here.
- if (regexp_encoding != NULL) {
- VALUE error = rb_reg_check_preprocess(string_value);
- if (error != Qnil) parse_regexp_error(iseq, pm_node_line_number(scope_node->parser, (const pm_node_t *) string_node), "%" PRIsVALUE, rb_obj_as_string(error));
+ if (regexp_encoding == NULL) {
+ string_value = parse_string_encoded(part, &string_node->unescaped, scope_node->encoding);
+ }
+ else {
+ string_value = parse_regexp_string_part(iseq, scope_node, (const pm_node_t *) string_node, &string_node->unescaped, regexp_encoding);
}
if (RTEST(current_string)) {
@@ -561,13 +575,13 @@ pm_interpolated_node_compile(rb_iseq_t *iseq, const pm_node_list_t *parts, const
PM_NODE_TYPE_P(((const pm_embedded_statements_node_t *) part)->statements->body.nodes[0], PM_STRING_NODE)
) {
const pm_string_node_t *string_node = (const pm_string_node_t *) ((const pm_embedded_statements_node_t *) part)->statements->body.nodes[0];
- VALUE string_value = parse_string_encoded((const pm_node_t *) string_node, &string_node->unescaped, default_encoding);
+ VALUE string_value;
- // If we were passed an explicit regexp encoding, then we
- // need to double check that it's okay here.
- if (regexp_encoding != NULL) {
- VALUE error = rb_reg_check_preprocess(string_value);
- if (error != Qnil) parse_regexp_error(iseq, pm_node_line_number(scope_node->parser, (const pm_node_t *) string_node), "%" PRIsVALUE, rb_obj_as_string(error));
+ if (regexp_encoding == NULL) {
+ string_value = parse_string_encoded(part, &string_node->unescaped, scope_node->encoding);
+ }
+ else {
+ string_value = parse_regexp_string_part(iseq, scope_node, (const pm_node_t *) string_node, &string_node->unescaped, regexp_encoding);
}
if (RTEST(current_string)) {
@@ -579,7 +593,7 @@ pm_interpolated_node_compile(rb_iseq_t *iseq, const pm_node_list_t *parts, const
}
else {
if (!RTEST(current_string)) {
- current_string = rb_enc_str_new(NULL, 0, default_encoding);
+ current_string = rb_enc_str_new(NULL, 0, regexp_encoding != NULL ? regexp_encoding : scope_node->encoding);
}
PUSH_INSN1(ret, *node_location, putobject, rb_fstring(current_string));
@@ -618,7 +632,7 @@ pm_interpolated_node_compile(rb_iseq_t *iseq, const pm_node_list_t *parts, const
static void
pm_compile_regexp_dynamic(rb_iseq_t *iseq, const pm_node_t *node, const pm_node_list_t *parts, const pm_line_column_t *node_location, LINK_ANCHOR *const ret, bool popped, pm_scope_node_t *scope_node)
{
- rb_encoding *regexp_encoding = parse_regexp_encoding(node);
+ rb_encoding *regexp_encoding = parse_regexp_encoding(scope_node, node);
int length = pm_interpolated_node_compile(iseq, parts, node_location, ret, popped, scope_node, regexp_encoding);
PUSH_INSN2(ret, *node_location, toregexp, INT2FIX(parse_regexp_flags(node) & 0xFF), INT2FIX(length));
@@ -717,13 +731,13 @@ pm_static_literal_value(rb_iseq_t *iseq, const pm_node_t *node, const pm_scope_n
return parse_regexp_concat(iseq, scope_node, (const pm_node_t *) cast, &cast->parts);
}
case PM_INTERPOLATED_STRING_NODE: {
- VALUE string = pm_static_literal_concat(&((const pm_interpolated_string_node_t *) node)->parts, scope_node, false);
+ VALUE string = pm_static_literal_concat(iseq, &((const pm_interpolated_string_node_t *) node)->parts, scope_node, NULL, false);
int line_number = pm_node_line_number(scope_node->parser, node);
return pm_static_literal_string(iseq, string, line_number);
}
case PM_INTERPOLATED_SYMBOL_NODE: {
const pm_interpolated_symbol_node_t *cast = (const pm_interpolated_symbol_node_t *) node;
- VALUE string = pm_static_literal_concat(&cast->parts, scope_node, true);
+ VALUE string = pm_static_literal_concat(iseq, &cast->parts, scope_node, NULL, true);
return ID2SYM(rb_intern_str(string));
}
diff --git a/test/.excludes-prism/TestM17N.rb b/test/.excludes-prism/TestM17N.rb
index bf11829020..25c4394634 100644
--- a/test/.excludes-prism/TestM17N.rb
+++ b/test/.excludes-prism/TestM17N.rb
@@ -1,6 +1,3 @@
-exclude(:test_dynamic_eucjp_regexp, "https://github.com/ruby/prism/issues/2664")
-exclude(:test_dynamic_sjis_regexp, "https://github.com/ruby/prism/issues/2664")
-exclude(:test_dynamic_utf8_regexp, "https://github.com/ruby/prism/issues/2664")
exclude(:test_regexp_ascii, "https://github.com/ruby/prism/issues/2664")
exclude(:test_regexp_usascii, "unknown")
exclude(:test_string_mixed_unicode, "unknown")
diff --git a/test/.excludes-prism/TestRegexp.rb b/test/.excludes-prism/TestRegexp.rb
index 52d4a69025..1d41a4dc57 100644
--- a/test/.excludes-prism/TestRegexp.rb
+++ b/test/.excludes-prism/TestRegexp.rb
@@ -1,3 +1,2 @@
exclude(:test_invalid_escape_error, "unknown")
-exclude(:test_invalid_fragment, "https://github.com/ruby/prism/issues/2664")
exclude(:test_unescape, "unknown")