diff options
| -rw-r--r-- | test/ruby/test_zjit.rb | 36 | ||||
| -rw-r--r-- | zjit/src/codegen.rs | 42 | ||||
| -rw-r--r-- | zjit/src/hir.rs | 130 | ||||
| -rw-r--r-- | zjit/src/hir/opt_tests.rs | 61 | ||||
| -rw-r--r-- | zjit/src/hir/tests.rs | 65 | ||||
| -rw-r--r-- | zjit/src/stats.rs | 2 |
6 files changed, 336 insertions, 0 deletions
diff --git a/test/ruby/test_zjit.rb b/test/ruby/test_zjit.rb index e347986abc..2066610cb2 100644 --- a/test/ruby/test_zjit.rb +++ b/test/ruby/test_zjit.rb @@ -470,6 +470,42 @@ class TestZJIT < Test::Unit::TestCase }, insns: [:getblockparamproxy] end + def test_getblockparam + assert_compiles '2', %q{ + def test(&blk) + blk + end + test { 2 }.call + test { 2 }.call + }, insns: [:getblockparam] + end + + def test_getblockparam_proxy_side_exit_restores_block_local + assert_compiles '2', %q{ + def test(&block) + b = block + # sideexits here + raise "test" unless block + b ? 2 : 3 + end + test {} + test {} + }, insns: [:getblockparam, :getblockparamproxy] + end + + def test_getblockparam_used_twice_in_args + assert_compiles '1', %q{ + def f(*args) = args + def test(&blk) + b = blk + f(*[1], blk) + blk + end + test {1}.call + test {1}.call + }, insns: [:getblockparam] + end + def test_optimized_method_call_proc_call assert_compiles '2', %q{ p = proc { |x| x * 2 } diff --git a/zjit/src/codegen.rs b/zjit/src/codegen.rs index 0030493ddf..870fe7584a 100644 --- a/zjit/src/codegen.rs +++ b/zjit/src/codegen.rs @@ -550,6 +550,8 @@ fn gen_insn(cb: &mut CodeBlock, jit: &mut JITState, asm: &mut Assembler, functio Insn::SetGlobal { id, val, state } => no_output!(gen_setglobal(jit, asm, *id, opnd!(val), &function.frame_state(*state))), Insn::GetGlobal { id, state } => gen_getglobal(jit, asm, *id, &function.frame_state(*state)), &Insn::GetLocal { ep_offset, level, use_sp, .. } => gen_getlocal(asm, ep_offset, level, use_sp), + &Insn::IsBlockParamModified { level } => gen_is_block_param_modified(asm, level), + &Insn::GetBlockParam { ep_offset, level, state } => gen_getblockparam(jit, asm, ep_offset, level, &function.frame_state(state)), &Insn::SetLocal { val, ep_offset, level } => no_output!(gen_setlocal(asm, opnd!(val), function.type_of(val), ep_offset, level)), Insn::GetConstantPath { ic, state } => gen_get_constant_path(jit, asm, *ic, &function.frame_state(*state)), Insn::GetClassVar { id, ic, state } => gen_getclassvar(jit, asm, *id, *ic, &function.frame_state(*state)), @@ -743,6 +745,46 @@ fn gen_setlocal(asm: &mut Assembler, val: Opnd, val_type: Type, local_ep_offset: } } +/// Returns 1 (as CBool) when VM_FRAME_FLAG_MODIFIED_BLOCK_PARAM is set; returns 0 otherwise. +fn gen_is_block_param_modified(asm: &mut Assembler, level: u32) -> Opnd { + let ep = gen_get_ep(asm, level); + let flags = asm.load(Opnd::mem(VALUE_BITS, ep, SIZEOF_VALUE_I32 * (VM_ENV_DATA_INDEX_FLAGS as i32))); + asm.test(flags, VM_FRAME_FLAG_MODIFIED_BLOCK_PARAM.into()); + asm.csel_nz(Opnd::Imm(1), Opnd::Imm(0)) +} + +/// Get the block parameter as a Proc, write it to the environment, +/// and mark the flag as modified. +fn gen_getblockparam(jit: &mut JITState, asm: &mut Assembler, ep_offset: u32, level: u32, state: &FrameState) -> Opnd { + gen_prepare_leaf_call_with_gc(asm, state); + // Bail out if write barrier is required. + let ep = gen_get_ep(asm, level); + let flags = Opnd::mem(VALUE_BITS, ep, SIZEOF_VALUE_I32 * (VM_ENV_DATA_INDEX_FLAGS as i32)); + asm.test(flags, VM_ENV_FLAG_WB_REQUIRED.into()); + asm.jnz(side_exit(jit, state, SideExitReason::BlockParamWbRequired)); + + // Convert block handler to Proc. + let block_handler = asm.load(Opnd::mem(VALUE_BITS, ep, SIZEOF_VALUE_I32 * VM_ENV_DATA_INDEX_SPECVAL)); + let proc = asm_ccall!(asm, rb_vm_bh_to_procval, EC, block_handler); + + // Write Proc to EP and mark modified. + let ep = gen_get_ep(asm, level); + let local_ep_offset = c_int::try_from(ep_offset).unwrap_or_else(|_| { + panic!("Could not convert local_ep_offset {ep_offset} to i32") + }); + let offset = -(SIZEOF_VALUE_I32 * local_ep_offset); + asm.mov(Opnd::mem(VALUE_BITS, ep, offset), proc); + + let flags = Opnd::mem(VALUE_BITS, ep, SIZEOF_VALUE_I32 * (VM_ENV_DATA_INDEX_FLAGS as i32)); + let flags_val = asm.load(flags); + let modified = asm.or(flags_val, VM_FRAME_FLAG_MODIFIED_BLOCK_PARAM.into()); + asm.store(flags, modified); + + // Read the Proc from EP. + let ep = gen_get_ep(asm, level); + asm.load(Opnd::mem(VALUE_BITS, ep, offset)) +} + fn gen_guard_block_param_proxy(jit: &JITState, asm: &mut Assembler, level: u32, state: &FrameState) { // Bail out if the `&block` local variable has been modified let ep = gen_get_ep(asm, level); diff --git a/zjit/src/hir.rs b/zjit/src/hir.rs index 4326d37b34..b4f78c025d 100644 --- a/zjit/src/hir.rs +++ b/zjit/src/hir.rs @@ -506,6 +506,7 @@ pub enum SideExitReason { Interrupt, BlockParamProxyModified, BlockParamProxyNotIseqOrIfunc, + BlockParamWbRequired, StackOverflow, FixnumModByZero, FixnumDivByZero, @@ -839,6 +840,11 @@ pub enum Insn { /// If `use_sp` is true, it uses the SP register to optimize the read. /// `rest_param` is used by infer_types to infer the ArrayExact type. GetLocal { level: u32, ep_offset: u32, use_sp: bool, rest_param: bool }, + /// Check whether VM_FRAME_FLAG_MODIFIED_BLOCK_PARAM is set in the environment flags. + /// Returns CBool (0/1). + IsBlockParamModified { level: u32 }, + /// Get the block parameter as a Proc. + GetBlockParam { level: u32, ep_offset: u32, state: InsnId }, /// Set a local variable in a higher scope or the heap SetLocal { level: u32, ep_offset: u32, val: InsnId }, GetSpecialSymbol { symbol_type: SpecialBackrefSymbol, state: InsnId }, @@ -1150,6 +1156,8 @@ impl Insn { Insn::GetSpecialNumber { .. } => effects::Any, Insn::GetClassVar { .. } => effects::Any, Insn::SetClassVar { .. } => effects::Any, + Insn::IsBlockParamModified { .. } => effects::Any, + Insn::GetBlockParam { .. } => effects::Any, Insn::Snapshot { .. } => effects::Empty, Insn::Jump(_) => effects::Any, Insn::IfTrue { .. } => effects::Any, @@ -1523,6 +1531,11 @@ impl<'a> std::fmt::Display for InsnPrinter<'a> { Insn::GuardGreaterEq { left, right, .. } => write!(f, "GuardGreaterEq {left}, {right}"), Insn::GuardSuperMethodEntry { lep, cme, .. } => write!(f, "GuardSuperMethodEntry {lep}, {:p}", self.ptr_map.map_ptr(cme)), Insn::GetBlockHandler { lep } => write!(f, "GetBlockHandler {lep}"), + &Insn::GetBlockParam { level, ep_offset, .. } => { + let name = get_local_var_name_for_printer(self.iseq, level, ep_offset) + .map_or(String::new(), |x| format!("{x}, ")); + write!(f, "GetBlockParam {name}l{level}, EP@{ep_offset}") + }, Insn::PatchPoint { invariant, .. } => { write!(f, "PatchPoint {}", invariant.print(self.ptr_map)) }, Insn::GetConstantPath { ic, .. } => { write!(f, "GetConstantPath {:p}", self.ptr_map.map_ptr(ic)) }, Insn::IsBlockGiven { lep } => { write!(f, "IsBlockGiven {lep}") }, @@ -1589,6 +1602,9 @@ impl<'a> std::fmt::Display for InsnPrinter<'a> { let name = get_local_var_name_for_printer(self.iseq, level, ep_offset).map_or(String::new(), |x| format!("{x}, ")); write!(f, "GetLocal {name}l{level}, EP@{ep_offset}{}", if rest_param { ", *" } else { "" }) }, + &Insn::IsBlockParamModified { level } => { + write!(f, "IsBlockParamModified l{level}") + }, &Insn::SetLocal { val, level, ep_offset } => { let name = get_local_var_name_for_printer(self.iseq, level, ep_offset).map_or(String::new(), |x| format!("{x}, ")); write!(f, "SetLocal {name}l{level}, EP@{ep_offset}, {val}") @@ -2139,6 +2155,7 @@ impl Function { | PutSpecialObject {..} | GetGlobal {..} | GetLocal {..} + | IsBlockParamModified {..} | SideExit {..} | EntryPoint {..} | LoadPC @@ -2193,6 +2210,7 @@ impl Function { &GuardSuperMethodEntry { lep, cme, state } => GuardSuperMethodEntry { lep: find!(lep), cme, state }, &GetBlockHandler { lep } => GetBlockHandler { lep: find!(lep) }, &IsBlockGiven { lep } => IsBlockGiven { lep: find!(lep) }, + &GetBlockParam { level, ep_offset, state } => GetBlockParam { level, ep_offset, state: find!(state) }, &FixnumAdd { left, right, state } => FixnumAdd { left: find!(left), right: find!(right), state }, &FixnumSub { left, right, state } => FixnumSub { left: find!(left), right: find!(right), state }, &FixnumMult { left, right, state } => FixnumMult { left: find!(left), right: find!(right), state }, @@ -2488,6 +2506,8 @@ impl Function { Insn::AnyToString { .. } => types::String, Insn::GetLocal { rest_param: true, .. } => types::ArrayExact, Insn::GetLocal { .. } => types::BasicObject, + Insn::IsBlockParamModified { .. } => types::CBool, + Insn::GetBlockParam { .. } => types::BasicObject, Insn::GetBlockHandler { .. } => types::RubyValue, // The type of Snapshot doesn't really matter; it's never materialized. It's used only // as a reference for FrameState, which we use to generate side-exit code. @@ -4386,6 +4406,7 @@ impl Function { | &Insn::GetLEP | &Insn::LoadSelf | &Insn::GetLocal { .. } + | &Insn::IsBlockParamModified { .. } | &Insn::PutSpecialObject { .. } | &Insn::IncrCounter(_) | &Insn::IncrCounterPtr { .. } => @@ -4396,6 +4417,7 @@ impl Function { } &Insn::PatchPoint { state, .. } | &Insn::CheckInterrupts { state } + | &Insn::GetBlockParam { state, .. } | &Insn::GetConstantPath { ic: _, state } => { worklist.push_back(state); } @@ -5153,6 +5175,8 @@ impl Function { | Insn::GetSpecialNumber { .. } | Insn::GetSpecialSymbol { .. } | Insn::GetLocal { .. } + | Insn::GetBlockParam { .. } + | Insn::IsBlockParamModified { .. } | Insn::StoreField { .. } => { Ok(()) } @@ -6428,6 +6452,112 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> { // TODO(Shopify/ruby#753): GC root, so we should be able to avoid unnecessary GC tracing state.stack_push(fun.push_insn(block, Insn::Const { val: Const::Value(unsafe { rb_block_param_proxy }) })); } + YARVINSN_getblockparam => { + fn new_branch_block( + fun: &mut Function, + insn_idx: u32, + exit_state: &FrameState, + locals_count: usize, + stack_count: usize, + ) -> (BlockId, InsnId, FrameState, InsnId) { + let block = fun.new_block(insn_idx); + let self_param = fun.push_insn(block, Insn::Param); + let mut state = exit_state.clone(); + state.locals.clear(); + state.stack.clear(); + state.locals.extend((0..locals_count).map(|_| fun.push_insn(block, Insn::Param))); + state.stack.extend((0..stack_count).map(|_| fun.push_insn(block, Insn::Param))); + let snapshot = fun.push_insn(block, Insn::Snapshot { state: state.clone() }); + (block, self_param, state, snapshot) + } + + fn finish_getblockparam_branch( + fun: &mut Function, + block: BlockId, + self_param: InsnId, + state: &mut FrameState, + join_block: BlockId, + ep_offset: u32, + level: u32, + val: InsnId, + ) { + if level == 0 { + state.setlocal(ep_offset, val); + } + state.stack_push(val); + fun.push_insn(block, Insn::Jump(BranchEdge { + target: join_block, + args: state.as_args(self_param), + })); + } + + let ep_offset = get_arg(pc, 0).as_u32(); + let level = get_arg(pc, 1).as_u32(); + let branch_insn_idx = exit_state.insn_idx as u32; + + // If the block param is already a Proc (modified), read it from EP. + // Otherwise, convert it to a Proc and store it to EP. + let is_modified = fun.push_insn(block, Insn::IsBlockParamModified { level }); + + let locals_count = state.locals.len(); + let stack_count = state.stack.len(); + let entry_args = state.as_args(self_param); + + // Set up branch and join blocks. + let (modified_block, modified_self_param, mut modified_state, ..) = + new_branch_block(&mut fun, branch_insn_idx, &exit_state, locals_count, stack_count); + let (unmodified_block, unmodified_self_param, mut unmodified_state, unmodified_exit_id) = + new_branch_block(&mut fun, branch_insn_idx, &exit_state, locals_count, stack_count); + let join_block = insn_idx_to_block.get(&insn_idx).copied().unwrap_or_else(|| fun.new_block(insn_idx)); + + fun.push_insn(block, Insn::IfTrue { + val: is_modified, + target: BranchEdge { target: modified_block, args: entry_args.clone() }, + }); + fun.push_insn(block, Insn::Jump(BranchEdge { + target: unmodified_block, + args: entry_args, + })); + + // Push modified block: read Proc from EP. + let modified_val = fun.push_insn(modified_block, Insn::GetLocal { + ep_offset, + level, + use_sp: false, + rest_param: false, + }); + finish_getblockparam_branch( + &mut fun, + modified_block, + modified_self_param, + &mut modified_state, + join_block, + ep_offset, + level, + modified_val, + ); + + // Push unmodified block: convert block handler to Proc. + let unmodified_val = fun.push_insn(unmodified_block, Insn::GetBlockParam { + ep_offset, + level, + state: unmodified_exit_id, + }); + finish_getblockparam_branch( + &mut fun, + unmodified_block, + unmodified_self_param, + &mut unmodified_state, + join_block, + ep_offset, + level, + unmodified_val, + ); + + // Continue compilation from the join block at the next instruction. + queue.push_back((unmodified_state, join_block, insn_idx, local_inval)); + break; + } YARVINSN_pop => { state.stack_pop()?; } YARVINSN_dup => { state.stack_push(state.stack_top()?); } YARVINSN_dupn => { diff --git a/zjit/src/hir/opt_tests.rs b/zjit/src/hir/opt_tests.rs index da24025010..0a42652993 100644 --- a/zjit/src/hir/opt_tests.rs +++ b/zjit/src/hir/opt_tests.rs @@ -3853,6 +3853,67 @@ mod hir_opt_tests { } #[test] + fn test_getblockparam() { + eval(" + def test(&block) = block + "); + assert_snapshot!(hir_string("test"), @r" + fn test@<compiled>:2: + bb0(): + EntryPoint interpreter + v1:BasicObject = LoadSelf + v2:BasicObject = GetLocal :block, l0, SP@4 + Jump bb2(v1, v2) + bb1(v5:BasicObject, v6:BasicObject): + EntryPoint JIT(0) + Jump bb2(v5, v6) + bb2(v8:BasicObject, v9:BasicObject): + v13:CBool = IsBlockParamModified l0 + IfTrue v13, bb3(v8, v9) + v24:BasicObject = GetBlockParam :block, l0, EP@3 + Jump bb5(v8, v24, v24) + bb3(v14:BasicObject, v15:BasicObject): + v22:BasicObject = GetLocal :block, l0, EP@3 + Jump bb5(v14, v22, v22) + bb5(v26:BasicObject, v27:BasicObject, v28:BasicObject): + CheckInterrupts + Return v28 + "); + } + + #[test] + fn test_getblockparam_nested_block() { + eval(" + def test(&block) + proc do + block + end + end + "); + assert_snapshot!(hir_string_proc("test"), @r" + fn block in test@<compiled>:4: + bb0(): + EntryPoint interpreter + v1:BasicObject = LoadSelf + Jump bb2(v1) + bb1(v4:BasicObject): + EntryPoint JIT(0) + Jump bb2(v4) + bb2(v6:BasicObject): + v10:CBool = IsBlockParamModified l1 + IfTrue v10, bb3(v6) + v19:BasicObject = GetBlockParam :block, l1, EP@3 + Jump bb5(v6, v19) + bb3(v11:BasicObject): + v17:BasicObject = GetLocal :block, l1, EP@3 + Jump bb5(v11, v17) + bb5(v21:BasicObject, v22:BasicObject): + CheckInterrupts + Return v22 + "); + } + + #[test] fn test_getinstancevariable() { eval(" def test = @foo diff --git a/zjit/src/hir/tests.rs b/zjit/src/hir/tests.rs index 3e28178273..44082ce908 100644 --- a/zjit/src/hir/tests.rs +++ b/zjit/src/hir/tests.rs @@ -2684,6 +2684,71 @@ pub mod hir_build_tests { } #[test] + fn test_getblockparam() { + eval(" + def test(&block) = block + "); + assert_snapshot!(hir_string("test"), @r" + fn test@<compiled>:2: + bb0(): + EntryPoint interpreter + v1:BasicObject = LoadSelf + v2:BasicObject = GetLocal :block, l0, SP@4 + Jump bb2(v1, v2) + bb1(v5:BasicObject, v6:BasicObject): + EntryPoint JIT(0) + Jump bb2(v5, v6) + bb2(v8:BasicObject, v9:BasicObject): + v13:CBool = IsBlockParamModified l0 + IfTrue v13, bb3(v8, v9) + Jump bb4(v8, v9) + bb3(v14:BasicObject, v15:BasicObject): + v22:BasicObject = GetLocal :block, l0, EP@3 + Jump bb5(v14, v22, v22) + bb4(v17:BasicObject, v18:BasicObject): + v24:BasicObject = GetBlockParam :block, l0, EP@3 + Jump bb5(v17, v24, v24) + bb5(v26:BasicObject, v27:BasicObject, v28:BasicObject): + CheckInterrupts + Return v28 + "); + } + + #[test] + fn test_getblockparam_nested_block() { + eval(" + def test(&block) + proc do + block + end + end + "); + assert_snapshot!(hir_string_proc("test"), @r" + fn block in test@<compiled>:4: + bb0(): + EntryPoint interpreter + v1:BasicObject = LoadSelf + Jump bb2(v1) + bb1(v4:BasicObject): + EntryPoint JIT(0) + Jump bb2(v4) + bb2(v6:BasicObject): + v10:CBool = IsBlockParamModified l1 + IfTrue v10, bb3(v6) + Jump bb4(v6) + bb3(v11:BasicObject): + v17:BasicObject = GetLocal :block, l1, EP@3 + Jump bb5(v11, v17) + bb4(v13:BasicObject): + v19:BasicObject = GetBlockParam :block, l1, EP@3 + Jump bb5(v13, v19) + bb5(v21:BasicObject, v22:BasicObject): + CheckInterrupts + Return v22 + "); + } + + #[test] fn test_splatarray_mut() { eval(" def test(a) = [*a] diff --git a/zjit/src/stats.rs b/zjit/src/stats.rs index cf100dcda2..556a1417a4 100644 --- a/zjit/src/stats.rs +++ b/zjit/src/stats.rs @@ -210,6 +210,7 @@ make_counters! { exit_stackoverflow, exit_block_param_proxy_modified, exit_block_param_proxy_not_iseq_or_ifunc, + exit_block_param_wb_required, exit_too_many_keyword_parameters, } @@ -557,6 +558,7 @@ pub fn side_exit_counter(reason: crate::hir::SideExitReason) -> Counter { StackOverflow => exit_stackoverflow, BlockParamProxyModified => exit_block_param_proxy_modified, BlockParamProxyNotIseqOrIfunc => exit_block_param_proxy_not_iseq_or_ifunc, + BlockParamWbRequired => exit_block_param_wb_required, TooManyKeywordParameters => exit_too_many_keyword_parameters, PatchPoint(Invariant::BOPRedefined { .. }) => exit_patchpoint_bop_redefined, |
