summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--test/ruby/test_zjit.rb36
-rw-r--r--zjit/src/codegen.rs42
-rw-r--r--zjit/src/hir.rs130
-rw-r--r--zjit/src/hir/opt_tests.rs61
-rw-r--r--zjit/src/hir/tests.rs65
-rw-r--r--zjit/src/stats.rs2
6 files changed, 336 insertions, 0 deletions
diff --git a/test/ruby/test_zjit.rb b/test/ruby/test_zjit.rb
index e347986abc..2066610cb2 100644
--- a/test/ruby/test_zjit.rb
+++ b/test/ruby/test_zjit.rb
@@ -470,6 +470,42 @@ class TestZJIT < Test::Unit::TestCase
}, insns: [:getblockparamproxy]
end
+ def test_getblockparam
+ assert_compiles '2', %q{
+ def test(&blk)
+ blk
+ end
+ test { 2 }.call
+ test { 2 }.call
+ }, insns: [:getblockparam]
+ end
+
+ def test_getblockparam_proxy_side_exit_restores_block_local
+ assert_compiles '2', %q{
+ def test(&block)
+ b = block
+ # sideexits here
+ raise "test" unless block
+ b ? 2 : 3
+ end
+ test {}
+ test {}
+ }, insns: [:getblockparam, :getblockparamproxy]
+ end
+
+ def test_getblockparam_used_twice_in_args
+ assert_compiles '1', %q{
+ def f(*args) = args
+ def test(&blk)
+ b = blk
+ f(*[1], blk)
+ blk
+ end
+ test {1}.call
+ test {1}.call
+ }, insns: [:getblockparam]
+ end
+
def test_optimized_method_call_proc_call
assert_compiles '2', %q{
p = proc { |x| x * 2 }
diff --git a/zjit/src/codegen.rs b/zjit/src/codegen.rs
index 0030493ddf..870fe7584a 100644
--- a/zjit/src/codegen.rs
+++ b/zjit/src/codegen.rs
@@ -550,6 +550,8 @@ fn gen_insn(cb: &mut CodeBlock, jit: &mut JITState, asm: &mut Assembler, functio
Insn::SetGlobal { id, val, state } => no_output!(gen_setglobal(jit, asm, *id, opnd!(val), &function.frame_state(*state))),
Insn::GetGlobal { id, state } => gen_getglobal(jit, asm, *id, &function.frame_state(*state)),
&Insn::GetLocal { ep_offset, level, use_sp, .. } => gen_getlocal(asm, ep_offset, level, use_sp),
+ &Insn::IsBlockParamModified { level } => gen_is_block_param_modified(asm, level),
+ &Insn::GetBlockParam { ep_offset, level, state } => gen_getblockparam(jit, asm, ep_offset, level, &function.frame_state(state)),
&Insn::SetLocal { val, ep_offset, level } => no_output!(gen_setlocal(asm, opnd!(val), function.type_of(val), ep_offset, level)),
Insn::GetConstantPath { ic, state } => gen_get_constant_path(jit, asm, *ic, &function.frame_state(*state)),
Insn::GetClassVar { id, ic, state } => gen_getclassvar(jit, asm, *id, *ic, &function.frame_state(*state)),
@@ -743,6 +745,46 @@ fn gen_setlocal(asm: &mut Assembler, val: Opnd, val_type: Type, local_ep_offset:
}
}
+/// Returns 1 (as CBool) when VM_FRAME_FLAG_MODIFIED_BLOCK_PARAM is set; returns 0 otherwise.
+fn gen_is_block_param_modified(asm: &mut Assembler, level: u32) -> Opnd {
+ let ep = gen_get_ep(asm, level);
+ let flags = asm.load(Opnd::mem(VALUE_BITS, ep, SIZEOF_VALUE_I32 * (VM_ENV_DATA_INDEX_FLAGS as i32)));
+ asm.test(flags, VM_FRAME_FLAG_MODIFIED_BLOCK_PARAM.into());
+ asm.csel_nz(Opnd::Imm(1), Opnd::Imm(0))
+}
+
+/// Get the block parameter as a Proc, write it to the environment,
+/// and mark the flag as modified.
+fn gen_getblockparam(jit: &mut JITState, asm: &mut Assembler, ep_offset: u32, level: u32, state: &FrameState) -> Opnd {
+ gen_prepare_leaf_call_with_gc(asm, state);
+ // Bail out if write barrier is required.
+ let ep = gen_get_ep(asm, level);
+ let flags = Opnd::mem(VALUE_BITS, ep, SIZEOF_VALUE_I32 * (VM_ENV_DATA_INDEX_FLAGS as i32));
+ asm.test(flags, VM_ENV_FLAG_WB_REQUIRED.into());
+ asm.jnz(side_exit(jit, state, SideExitReason::BlockParamWbRequired));
+
+ // Convert block handler to Proc.
+ let block_handler = asm.load(Opnd::mem(VALUE_BITS, ep, SIZEOF_VALUE_I32 * VM_ENV_DATA_INDEX_SPECVAL));
+ let proc = asm_ccall!(asm, rb_vm_bh_to_procval, EC, block_handler);
+
+ // Write Proc to EP and mark modified.
+ let ep = gen_get_ep(asm, level);
+ let local_ep_offset = c_int::try_from(ep_offset).unwrap_or_else(|_| {
+ panic!("Could not convert local_ep_offset {ep_offset} to i32")
+ });
+ let offset = -(SIZEOF_VALUE_I32 * local_ep_offset);
+ asm.mov(Opnd::mem(VALUE_BITS, ep, offset), proc);
+
+ let flags = Opnd::mem(VALUE_BITS, ep, SIZEOF_VALUE_I32 * (VM_ENV_DATA_INDEX_FLAGS as i32));
+ let flags_val = asm.load(flags);
+ let modified = asm.or(flags_val, VM_FRAME_FLAG_MODIFIED_BLOCK_PARAM.into());
+ asm.store(flags, modified);
+
+ // Read the Proc from EP.
+ let ep = gen_get_ep(asm, level);
+ asm.load(Opnd::mem(VALUE_BITS, ep, offset))
+}
+
fn gen_guard_block_param_proxy(jit: &JITState, asm: &mut Assembler, level: u32, state: &FrameState) {
// Bail out if the `&block` local variable has been modified
let ep = gen_get_ep(asm, level);
diff --git a/zjit/src/hir.rs b/zjit/src/hir.rs
index 4326d37b34..b4f78c025d 100644
--- a/zjit/src/hir.rs
+++ b/zjit/src/hir.rs
@@ -506,6 +506,7 @@ pub enum SideExitReason {
Interrupt,
BlockParamProxyModified,
BlockParamProxyNotIseqOrIfunc,
+ BlockParamWbRequired,
StackOverflow,
FixnumModByZero,
FixnumDivByZero,
@@ -839,6 +840,11 @@ pub enum Insn {
/// If `use_sp` is true, it uses the SP register to optimize the read.
/// `rest_param` is used by infer_types to infer the ArrayExact type.
GetLocal { level: u32, ep_offset: u32, use_sp: bool, rest_param: bool },
+ /// Check whether VM_FRAME_FLAG_MODIFIED_BLOCK_PARAM is set in the environment flags.
+ /// Returns CBool (0/1).
+ IsBlockParamModified { level: u32 },
+ /// Get the block parameter as a Proc.
+ GetBlockParam { level: u32, ep_offset: u32, state: InsnId },
/// Set a local variable in a higher scope or the heap
SetLocal { level: u32, ep_offset: u32, val: InsnId },
GetSpecialSymbol { symbol_type: SpecialBackrefSymbol, state: InsnId },
@@ -1150,6 +1156,8 @@ impl Insn {
Insn::GetSpecialNumber { .. } => effects::Any,
Insn::GetClassVar { .. } => effects::Any,
Insn::SetClassVar { .. } => effects::Any,
+ Insn::IsBlockParamModified { .. } => effects::Any,
+ Insn::GetBlockParam { .. } => effects::Any,
Insn::Snapshot { .. } => effects::Empty,
Insn::Jump(_) => effects::Any,
Insn::IfTrue { .. } => effects::Any,
@@ -1523,6 +1531,11 @@ impl<'a> std::fmt::Display for InsnPrinter<'a> {
Insn::GuardGreaterEq { left, right, .. } => write!(f, "GuardGreaterEq {left}, {right}"),
Insn::GuardSuperMethodEntry { lep, cme, .. } => write!(f, "GuardSuperMethodEntry {lep}, {:p}", self.ptr_map.map_ptr(cme)),
Insn::GetBlockHandler { lep } => write!(f, "GetBlockHandler {lep}"),
+ &Insn::GetBlockParam { level, ep_offset, .. } => {
+ let name = get_local_var_name_for_printer(self.iseq, level, ep_offset)
+ .map_or(String::new(), |x| format!("{x}, "));
+ write!(f, "GetBlockParam {name}l{level}, EP@{ep_offset}")
+ },
Insn::PatchPoint { invariant, .. } => { write!(f, "PatchPoint {}", invariant.print(self.ptr_map)) },
Insn::GetConstantPath { ic, .. } => { write!(f, "GetConstantPath {:p}", self.ptr_map.map_ptr(ic)) },
Insn::IsBlockGiven { lep } => { write!(f, "IsBlockGiven {lep}") },
@@ -1589,6 +1602,9 @@ impl<'a> std::fmt::Display for InsnPrinter<'a> {
let name = get_local_var_name_for_printer(self.iseq, level, ep_offset).map_or(String::new(), |x| format!("{x}, "));
write!(f, "GetLocal {name}l{level}, EP@{ep_offset}{}", if rest_param { ", *" } else { "" })
},
+ &Insn::IsBlockParamModified { level } => {
+ write!(f, "IsBlockParamModified l{level}")
+ },
&Insn::SetLocal { val, level, ep_offset } => {
let name = get_local_var_name_for_printer(self.iseq, level, ep_offset).map_or(String::new(), |x| format!("{x}, "));
write!(f, "SetLocal {name}l{level}, EP@{ep_offset}, {val}")
@@ -2139,6 +2155,7 @@ impl Function {
| PutSpecialObject {..}
| GetGlobal {..}
| GetLocal {..}
+ | IsBlockParamModified {..}
| SideExit {..}
| EntryPoint {..}
| LoadPC
@@ -2193,6 +2210,7 @@ impl Function {
&GuardSuperMethodEntry { lep, cme, state } => GuardSuperMethodEntry { lep: find!(lep), cme, state },
&GetBlockHandler { lep } => GetBlockHandler { lep: find!(lep) },
&IsBlockGiven { lep } => IsBlockGiven { lep: find!(lep) },
+ &GetBlockParam { level, ep_offset, state } => GetBlockParam { level, ep_offset, state: find!(state) },
&FixnumAdd { left, right, state } => FixnumAdd { left: find!(left), right: find!(right), state },
&FixnumSub { left, right, state } => FixnumSub { left: find!(left), right: find!(right), state },
&FixnumMult { left, right, state } => FixnumMult { left: find!(left), right: find!(right), state },
@@ -2488,6 +2506,8 @@ impl Function {
Insn::AnyToString { .. } => types::String,
Insn::GetLocal { rest_param: true, .. } => types::ArrayExact,
Insn::GetLocal { .. } => types::BasicObject,
+ Insn::IsBlockParamModified { .. } => types::CBool,
+ Insn::GetBlockParam { .. } => types::BasicObject,
Insn::GetBlockHandler { .. } => types::RubyValue,
// The type of Snapshot doesn't really matter; it's never materialized. It's used only
// as a reference for FrameState, which we use to generate side-exit code.
@@ -4386,6 +4406,7 @@ impl Function {
| &Insn::GetLEP
| &Insn::LoadSelf
| &Insn::GetLocal { .. }
+ | &Insn::IsBlockParamModified { .. }
| &Insn::PutSpecialObject { .. }
| &Insn::IncrCounter(_)
| &Insn::IncrCounterPtr { .. } =>
@@ -4396,6 +4417,7 @@ impl Function {
}
&Insn::PatchPoint { state, .. }
| &Insn::CheckInterrupts { state }
+ | &Insn::GetBlockParam { state, .. }
| &Insn::GetConstantPath { ic: _, state } => {
worklist.push_back(state);
}
@@ -5153,6 +5175,8 @@ impl Function {
| Insn::GetSpecialNumber { .. }
| Insn::GetSpecialSymbol { .. }
| Insn::GetLocal { .. }
+ | Insn::GetBlockParam { .. }
+ | Insn::IsBlockParamModified { .. }
| Insn::StoreField { .. } => {
Ok(())
}
@@ -6428,6 +6452,112 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> {
// TODO(Shopify/ruby#753): GC root, so we should be able to avoid unnecessary GC tracing
state.stack_push(fun.push_insn(block, Insn::Const { val: Const::Value(unsafe { rb_block_param_proxy }) }));
}
+ YARVINSN_getblockparam => {
+ fn new_branch_block(
+ fun: &mut Function,
+ insn_idx: u32,
+ exit_state: &FrameState,
+ locals_count: usize,
+ stack_count: usize,
+ ) -> (BlockId, InsnId, FrameState, InsnId) {
+ let block = fun.new_block(insn_idx);
+ let self_param = fun.push_insn(block, Insn::Param);
+ let mut state = exit_state.clone();
+ state.locals.clear();
+ state.stack.clear();
+ state.locals.extend((0..locals_count).map(|_| fun.push_insn(block, Insn::Param)));
+ state.stack.extend((0..stack_count).map(|_| fun.push_insn(block, Insn::Param)));
+ let snapshot = fun.push_insn(block, Insn::Snapshot { state: state.clone() });
+ (block, self_param, state, snapshot)
+ }
+
+ fn finish_getblockparam_branch(
+ fun: &mut Function,
+ block: BlockId,
+ self_param: InsnId,
+ state: &mut FrameState,
+ join_block: BlockId,
+ ep_offset: u32,
+ level: u32,
+ val: InsnId,
+ ) {
+ if level == 0 {
+ state.setlocal(ep_offset, val);
+ }
+ state.stack_push(val);
+ fun.push_insn(block, Insn::Jump(BranchEdge {
+ target: join_block,
+ args: state.as_args(self_param),
+ }));
+ }
+
+ let ep_offset = get_arg(pc, 0).as_u32();
+ let level = get_arg(pc, 1).as_u32();
+ let branch_insn_idx = exit_state.insn_idx as u32;
+
+ // If the block param is already a Proc (modified), read it from EP.
+ // Otherwise, convert it to a Proc and store it to EP.
+ let is_modified = fun.push_insn(block, Insn::IsBlockParamModified { level });
+
+ let locals_count = state.locals.len();
+ let stack_count = state.stack.len();
+ let entry_args = state.as_args(self_param);
+
+ // Set up branch and join blocks.
+ let (modified_block, modified_self_param, mut modified_state, ..) =
+ new_branch_block(&mut fun, branch_insn_idx, &exit_state, locals_count, stack_count);
+ let (unmodified_block, unmodified_self_param, mut unmodified_state, unmodified_exit_id) =
+ new_branch_block(&mut fun, branch_insn_idx, &exit_state, locals_count, stack_count);
+ let join_block = insn_idx_to_block.get(&insn_idx).copied().unwrap_or_else(|| fun.new_block(insn_idx));
+
+ fun.push_insn(block, Insn::IfTrue {
+ val: is_modified,
+ target: BranchEdge { target: modified_block, args: entry_args.clone() },
+ });
+ fun.push_insn(block, Insn::Jump(BranchEdge {
+ target: unmodified_block,
+ args: entry_args,
+ }));
+
+ // Push modified block: read Proc from EP.
+ let modified_val = fun.push_insn(modified_block, Insn::GetLocal {
+ ep_offset,
+ level,
+ use_sp: false,
+ rest_param: false,
+ });
+ finish_getblockparam_branch(
+ &mut fun,
+ modified_block,
+ modified_self_param,
+ &mut modified_state,
+ join_block,
+ ep_offset,
+ level,
+ modified_val,
+ );
+
+ // Push unmodified block: convert block handler to Proc.
+ let unmodified_val = fun.push_insn(unmodified_block, Insn::GetBlockParam {
+ ep_offset,
+ level,
+ state: unmodified_exit_id,
+ });
+ finish_getblockparam_branch(
+ &mut fun,
+ unmodified_block,
+ unmodified_self_param,
+ &mut unmodified_state,
+ join_block,
+ ep_offset,
+ level,
+ unmodified_val,
+ );
+
+ // Continue compilation from the join block at the next instruction.
+ queue.push_back((unmodified_state, join_block, insn_idx, local_inval));
+ break;
+ }
YARVINSN_pop => { state.stack_pop()?; }
YARVINSN_dup => { state.stack_push(state.stack_top()?); }
YARVINSN_dupn => {
diff --git a/zjit/src/hir/opt_tests.rs b/zjit/src/hir/opt_tests.rs
index da24025010..0a42652993 100644
--- a/zjit/src/hir/opt_tests.rs
+++ b/zjit/src/hir/opt_tests.rs
@@ -3853,6 +3853,67 @@ mod hir_opt_tests {
}
#[test]
+ fn test_getblockparam() {
+ eval("
+ def test(&block) = block
+ ");
+ assert_snapshot!(hir_string("test"), @r"
+ fn test@<compiled>:2:
+ bb0():
+ EntryPoint interpreter
+ v1:BasicObject = LoadSelf
+ v2:BasicObject = GetLocal :block, l0, SP@4
+ Jump bb2(v1, v2)
+ bb1(v5:BasicObject, v6:BasicObject):
+ EntryPoint JIT(0)
+ Jump bb2(v5, v6)
+ bb2(v8:BasicObject, v9:BasicObject):
+ v13:CBool = IsBlockParamModified l0
+ IfTrue v13, bb3(v8, v9)
+ v24:BasicObject = GetBlockParam :block, l0, EP@3
+ Jump bb5(v8, v24, v24)
+ bb3(v14:BasicObject, v15:BasicObject):
+ v22:BasicObject = GetLocal :block, l0, EP@3
+ Jump bb5(v14, v22, v22)
+ bb5(v26:BasicObject, v27:BasicObject, v28:BasicObject):
+ CheckInterrupts
+ Return v28
+ ");
+ }
+
+ #[test]
+ fn test_getblockparam_nested_block() {
+ eval("
+ def test(&block)
+ proc do
+ block
+ end
+ end
+ ");
+ assert_snapshot!(hir_string_proc("test"), @r"
+ fn block in test@<compiled>:4:
+ bb0():
+ EntryPoint interpreter
+ v1:BasicObject = LoadSelf
+ Jump bb2(v1)
+ bb1(v4:BasicObject):
+ EntryPoint JIT(0)
+ Jump bb2(v4)
+ bb2(v6:BasicObject):
+ v10:CBool = IsBlockParamModified l1
+ IfTrue v10, bb3(v6)
+ v19:BasicObject = GetBlockParam :block, l1, EP@3
+ Jump bb5(v6, v19)
+ bb3(v11:BasicObject):
+ v17:BasicObject = GetLocal :block, l1, EP@3
+ Jump bb5(v11, v17)
+ bb5(v21:BasicObject, v22:BasicObject):
+ CheckInterrupts
+ Return v22
+ ");
+ }
+
+ #[test]
fn test_getinstancevariable() {
eval("
def test = @foo
diff --git a/zjit/src/hir/tests.rs b/zjit/src/hir/tests.rs
index 3e28178273..44082ce908 100644
--- a/zjit/src/hir/tests.rs
+++ b/zjit/src/hir/tests.rs
@@ -2684,6 +2684,71 @@ pub mod hir_build_tests {
}
#[test]
+ fn test_getblockparam() {
+ eval("
+ def test(&block) = block
+ ");
+ assert_snapshot!(hir_string("test"), @r"
+ fn test@<compiled>:2:
+ bb0():
+ EntryPoint interpreter
+ v1:BasicObject = LoadSelf
+ v2:BasicObject = GetLocal :block, l0, SP@4
+ Jump bb2(v1, v2)
+ bb1(v5:BasicObject, v6:BasicObject):
+ EntryPoint JIT(0)
+ Jump bb2(v5, v6)
+ bb2(v8:BasicObject, v9:BasicObject):
+ v13:CBool = IsBlockParamModified l0
+ IfTrue v13, bb3(v8, v9)
+ Jump bb4(v8, v9)
+ bb3(v14:BasicObject, v15:BasicObject):
+ v22:BasicObject = GetLocal :block, l0, EP@3
+ Jump bb5(v14, v22, v22)
+ bb4(v17:BasicObject, v18:BasicObject):
+ v24:BasicObject = GetBlockParam :block, l0, EP@3
+ Jump bb5(v17, v24, v24)
+ bb5(v26:BasicObject, v27:BasicObject, v28:BasicObject):
+ CheckInterrupts
+ Return v28
+ ");
+ }
+
+ #[test]
+ fn test_getblockparam_nested_block() {
+ eval("
+ def test(&block)
+ proc do
+ block
+ end
+ end
+ ");
+ assert_snapshot!(hir_string_proc("test"), @r"
+ fn block in test@<compiled>:4:
+ bb0():
+ EntryPoint interpreter
+ v1:BasicObject = LoadSelf
+ Jump bb2(v1)
+ bb1(v4:BasicObject):
+ EntryPoint JIT(0)
+ Jump bb2(v4)
+ bb2(v6:BasicObject):
+ v10:CBool = IsBlockParamModified l1
+ IfTrue v10, bb3(v6)
+ Jump bb4(v6)
+ bb3(v11:BasicObject):
+ v17:BasicObject = GetLocal :block, l1, EP@3
+ Jump bb5(v11, v17)
+ bb4(v13:BasicObject):
+ v19:BasicObject = GetBlockParam :block, l1, EP@3
+ Jump bb5(v13, v19)
+ bb5(v21:BasicObject, v22:BasicObject):
+ CheckInterrupts
+ Return v22
+ ");
+ }
+
+ #[test]
fn test_splatarray_mut() {
eval("
def test(a) = [*a]
diff --git a/zjit/src/stats.rs b/zjit/src/stats.rs
index cf100dcda2..556a1417a4 100644
--- a/zjit/src/stats.rs
+++ b/zjit/src/stats.rs
@@ -210,6 +210,7 @@ make_counters! {
exit_stackoverflow,
exit_block_param_proxy_modified,
exit_block_param_proxy_not_iseq_or_ifunc,
+ exit_block_param_wb_required,
exit_too_many_keyword_parameters,
}
@@ -557,6 +558,7 @@ pub fn side_exit_counter(reason: crate::hir::SideExitReason) -> Counter {
StackOverflow => exit_stackoverflow,
BlockParamProxyModified => exit_block_param_proxy_modified,
BlockParamProxyNotIseqOrIfunc => exit_block_param_proxy_not_iseq_or_ifunc,
+ BlockParamWbRequired => exit_block_param_wb_required,
TooManyKeywordParameters => exit_too_many_keyword_parameters,
PatchPoint(Invariant::BOPRedefined { .. })
=> exit_patchpoint_bop_redefined,