summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJemma Issroff <jemmaissroff@gmail.com>2023-12-14 16:44:38 -0500
committerJemma Issroff <jemmaissroff@gmail.com>2023-12-14 17:11:54 -0500
commit5587bd4b37cc01227ce5546b3350033bb1ef9775 (patch)
treeecc7135e9e9cc4e2d47f2c83ceba887fa2261570
parent39c072d6f7846a99f64e52090a320576715325c5 (diff)
[PRISM] Implement safe navigation in CallNodes
This commit implements safe navigation for CallNodes, CallAndWriteNodes and CallOperatorWriteNodes
-rw-r--r--prism_compile.c36
-rw-r--r--test/ruby/test_compile_prism.rb46
2 files changed, 76 insertions, 6 deletions
diff --git a/prism_compile.c b/prism_compile.c
index 28f69b2b9e..e99218ed60 100644
--- a/prism_compile.c
+++ b/prism_compile.c
@@ -855,9 +855,10 @@ pm_compile_class_path(LINK_ANCHOR *const ret, rb_iseq_t *iseq, const pm_node_t *
}
static void
-pm_compile_call_and_or_write_node(bool and_node, pm_node_t *receiver, pm_node_t *value, pm_constant_id_t write_name, pm_constant_id_t read_name, LINK_ANCHOR *const ret, rb_iseq_t *iseq, int lineno, const uint8_t * src, bool popped, pm_scope_node_t *scope_node)
+pm_compile_call_and_or_write_node(bool and_node, pm_node_t *receiver, pm_node_t *value, pm_constant_id_t write_name, pm_constant_id_t read_name, bool safe_nav, LINK_ANCHOR *const ret, rb_iseq_t *iseq, int lineno, const uint8_t * src, bool popped, pm_scope_node_t *scope_node)
{
LABEL *call_end_label = NEW_LABEL(lineno);
+ LABEL *else_label = NEW_LABEL(lineno);
LABEL *end_label = NEW_LABEL(lineno);
NODE dummy_line_node = generate_dummy_line_node(lineno, lineno);
@@ -869,6 +870,11 @@ pm_compile_call_and_or_write_node(bool and_node, pm_node_t *receiver, pm_node_t
PM_COMPILE_NOT_POPPED(receiver);
+ if (safe_nav) {
+ PM_DUP;
+ ADD_INSNL(ret, &dummy_line_node, branchnil, else_label);
+ }
+
ID write_name_id = pm_constant_id_lookup(scope_node, write_name);
ID read_name_id = pm_constant_id_lookup(scope_node, read_name);
PM_DUP;
@@ -901,6 +907,10 @@ pm_compile_call_and_or_write_node(bool and_node, pm_node_t *receiver, pm_node_t
PM_SWAP;
}
+ if (safe_nav) {
+ ADD_LABEL(ret, else_label);
+ }
+
ADD_LABEL(ret, end_label);
PM_POP;
return;
@@ -2565,10 +2575,16 @@ pm_compile_call(rb_iseq_t *iseq, const pm_call_node_t *call_node, LINK_ANCHOR *c
pm_newline_list_t newline_list = parser->newline_list;
int lineno = (int)pm_newline_list_line_column(&newline_list, ((pm_node_t *)call_node)->location.start).line;
NODE dummy_line_node = generate_dummy_line_node(lineno, lineno);
- LABEL *end = NEW_LABEL(lineno);
+ LABEL *else_label = NEW_LABEL(lineno);
+ LABEL *end_label = NEW_LABEL(lineno);
pm_node_t *pm_node = (pm_node_t *)call_node;
+ if (call_node->base.flags & PM_CALL_NODE_FLAGS_SAFE_NAVIGATION) {
+ PM_DUP;
+ ADD_INSNL(ret, &dummy_line_node, branchnil, else_label);
+ }
+
int flags = 0;
struct rb_callinfo_kwarg *kw_arg = NULL;
@@ -2582,7 +2598,7 @@ pm_compile_call(rb_iseq_t *iseq, const pm_call_node_t *call_node, LINK_ANCHOR *c
block_iseq = NEW_CHILD_ISEQ(next_scope_node, make_name_for_block(iseq), ISEQ_TYPE_BLOCK, lineno);
if (ISEQ_BODY(block_iseq)->catch_table) {
- ADD_CATCH_ENTRY(CATCH_TYPE_BREAK, start, end, block_iseq, end);
+ ADD_CATCH_ENTRY(CATCH_TYPE_BREAK, start, end_label, block_iseq, end_label);
}
ISEQ_COMPILE_DATA(iseq)->current_block = block_iseq;
}
@@ -2613,7 +2629,12 @@ pm_compile_call(rb_iseq_t *iseq, const pm_call_node_t *call_node, LINK_ANCHOR *c
else {
ADD_SEND_R(ret, &dummy_line_node, method_id, INT2FIX(orig_argc), block_iseq, INT2FIX(flags), kw_arg);
}
- ADD_LABEL(ret, end);
+
+ if (call_node->base.flags & PM_CALL_NODE_FLAGS_SAFE_NAVIGATION) {
+ ADD_INSNL(ret, &dummy_line_node, jump, end_label);
+ ADD_LABEL(ret, else_label);
+ }
+ ADD_LABEL(ret, end_label);
PM_POP_IF_POPPED;
}
@@ -3131,14 +3152,17 @@ pm_compile_node(rb_iseq_t *iseq, const pm_node_t *node, LINK_ANCHOR *const ret,
case PM_CALL_AND_WRITE_NODE: {
pm_call_and_write_node_t *call_and_write_node = (pm_call_and_write_node_t*) node;
- pm_compile_call_and_or_write_node(true, call_and_write_node->receiver, call_and_write_node->value, call_and_write_node->write_name, call_and_write_node->read_name, ret, iseq, lineno, src, popped, scope_node);
+ bool safe_nav = node->flags & PM_CALL_NODE_FLAGS_SAFE_NAVIGATION;
+
+ pm_compile_call_and_or_write_node(true, call_and_write_node->receiver, call_and_write_node->value, call_and_write_node->write_name, call_and_write_node->read_name, safe_nav, ret, iseq, lineno, src, popped, scope_node);
return;
}
case PM_CALL_OR_WRITE_NODE: {
pm_call_or_write_node_t *call_or_write_node = (pm_call_or_write_node_t*) node;
+ bool safe_nav = node->flags & PM_CALL_NODE_FLAGS_SAFE_NAVIGATION;
- pm_compile_call_and_or_write_node(false, call_or_write_node->receiver, call_or_write_node->value, call_or_write_node->write_name, call_or_write_node->read_name, ret, iseq, lineno, src, popped, scope_node);
+ pm_compile_call_and_or_write_node(false, call_or_write_node->receiver, call_or_write_node->value, call_or_write_node->write_name, call_or_write_node->read_name, safe_nav, ret, iseq, lineno, src, popped, scope_node);
return;
}
diff --git a/test/ruby/test_compile_prism.rb b/test/ruby/test_compile_prism.rb
index 6b16205869..fc563951f6 100644
--- a/test/ruby/test_compile_prism.rb
+++ b/test/ruby/test_compile_prism.rb
@@ -1470,6 +1470,16 @@ module Prism
pm = PrivateMethod.new
pm.send(:instance_var)
CODE
+
+ # Testing safe navigation operator
+ assert_prism_eval(<<-CODE)
+ def self.test_prism_call_node
+ if [][0]&.first
+ 1
+ end
+ end
+ test_prism_call_node
+ CODE
end
def test_CallAndWriteNode
@@ -1507,6 +1517,24 @@ module Prism
self.test_call_and_write_node &&= 1
CODE
)
+
+ assert_prism_eval(<<-CODE)
+ def self.test_prism_call_node; end
+ def self.test_prism_call_node=(val)
+ val
+ end
+ self&.test_prism_call_node &&= 1
+ CODE
+
+ assert_prism_eval(<<-CODE)
+ def self.test_prism_call_node
+ 2
+ end
+ def self.test_prism_call_node=(val)
+ val
+ end
+ self&.test_prism_call_node &&= 1
+ CODE
end
def test_CallOrWriteNode
@@ -1544,6 +1572,24 @@ module Prism
self.test_call_or_write_node ||= 1
CODE
)
+
+ assert_prism_eval(<<-CODE)
+ def self.test_prism_call_node
+ 2
+ end
+ def self.test_prism_call_node=(val)
+ val
+ end
+ self&.test_prism_call_node ||= 1
+ CODE
+
+ assert_prism_eval(<<-CODE)
+ def self.test_prism_call_node; end
+ def self.test_prism_call_node=(val)
+ val
+ end
+ self&.test_prism_call_node ||= 1
+ CODE
end
def test_CallOperatorWriteNode