diff options
| author | Takashi Kokubun <takashi.kokubun@shopify.com> | 2025-09-15 17:43:41 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-09-15 17:43:41 -0700 |
| commit | 6c5960ae1955b66b4a13e03012d53853ee1fd1de (patch) | |
| tree | 5c968cf42f044c15f4fa4cc04bf770d6ac13fa8a | |
| parent | e4f09a8c94e6e6d21a6dfa43f71d52e4096234d6 (diff) | |
ZJIT: Support compiling block args (#14537)
| -rw-r--r-- | test/ruby/test_zjit.rb | 15 | ||||
| -rw-r--r-- | zjit/src/backend/lir.rs | 7 | ||||
| -rw-r--r-- | zjit/src/hir.rs | 80 | ||||
| -rw-r--r-- | zjit/src/stats.rs | 4 |
4 files changed, 63 insertions, 43 deletions
diff --git a/test/ruby/test_zjit.rb b/test/ruby/test_zjit.rb index 87d8b06ece..fc79bdda6e 100644 --- a/test/ruby/test_zjit.rb +++ b/test/ruby/test_zjit.rb @@ -439,6 +439,21 @@ class TestZJIT < Test::Unit::TestCase }, call_threshold: 2 end + def test_send_nil_block_arg + assert_compiles 'false', %q{ + def test = block_given? + def entry = test(&nil) + test + } + end + + def test_send_symbol_block_arg + assert_compiles '["1", "2"]', %q{ + def test = [1, 2].map(&:to_s) + test + } + end + def test_forwardable_iseq assert_compiles '1', %q{ def test(...) = 1 diff --git a/zjit/src/backend/lir.rs b/zjit/src/backend/lir.rs index c67bf8c9ba..5f96d0718a 100644 --- a/zjit/src/backend/lir.rs +++ b/zjit/src/backend/lir.rs @@ -6,7 +6,7 @@ use crate::cruby::{Qundef, RUBY_OFFSET_CFP_PC, RUBY_OFFSET_CFP_SP, SIZEOF_VALUE_ use crate::hir::SideExitReason; use crate::options::{debug, get_option}; use crate::cruby::VALUE; -use crate::stats::{exit_counter_ptr, exit_counter_ptr_for_call_type, exit_counter_ptr_for_opcode, CompileError}; +use crate::stats::{exit_counter_ptr, exit_counter_ptr_for_opcode, CompileError}; use crate::virtualmem::CodePtr; use crate::asm::{CodeBlock, Label}; @@ -1601,11 +1601,6 @@ impl Assembler self.load_into(SCRATCH_OPND, Opnd::const_ptr(exit_counter_ptr_for_opcode(opcode))); self.incr_counter_with_reg(Opnd::mem(64, SCRATCH_OPND, 0), 1.into(), C_RET_OPND); } - if let SideExitReason::UnhandledCallType(call_type) = reason { - asm_comment!(self, "increment an unknown call type counter"); - self.load_into(SCRATCH_OPND, Opnd::const_ptr(exit_counter_ptr_for_call_type(call_type))); - self.incr_counter_with_reg(Opnd::mem(64, SCRATCH_OPND, 0), 1.into(), C_RET_OPND); - } } asm_comment!(self, "exit to the interpreter"); diff --git a/zjit/src/hir.rs b/zjit/src/hir.rs index 3726d8ec0e..46bf38dcc6 100644 --- a/zjit/src/hir.rs +++ b/zjit/src/hir.rs @@ -447,10 +447,10 @@ impl PtrPrintMap { #[derive(Debug, Clone, Copy)] pub enum SideExitReason { UnknownNewarraySend(vm_opt_newarray_send_type), - UnhandledCallType(CallType), UnknownSpecialVariable(u64), UnhandledHIRInsn(InsnId), UnhandledYARVInsn(u32), + UnhandledTailCall, FixnumAddOverflow, FixnumSubOverflow, FixnumMultOverflow, @@ -2986,10 +2986,8 @@ fn num_locals(iseq: *const rb_iseq_t) -> usize { } /// If we can't handle the type of send (yet), bail out. -fn unknown_call_type(flag: u32) -> Result<(), CallType> { - if (flag & VM_CALL_ARGS_BLOCKARG) != 0 { return Err(CallType::BlockArg); } - if (flag & VM_CALL_TAILCALL) != 0 { return Err(CallType::Tailcall); } - Ok(()) +fn is_tailcall(flags: u32) -> bool { + (flags & VM_CALL_TAILCALL) != 0 } /// We have IseqPayload, which keeps track of HIR Types in the interpreter, but this is not useful @@ -3501,10 +3499,11 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> { // NB: opt_neq has two cd; get_arg(0) is for eq and get_arg(1) is for neq let cd: *const rb_call_data = get_arg(pc, 1).as_ptr(); let call_info = unsafe { rb_get_call_data_ci(cd) }; - if let Err(call_type) = unknown_call_type(unsafe { rb_vm_ci_flag(call_info) }) { - // Unknown call type; side-exit into the interpreter + let flags = unsafe { rb_vm_ci_flag(call_info) }; + if is_tailcall(flags) { + // Can't handle tailcall; side-exit into the interpreter let exit_id = fun.push_insn(block, Insn::Snapshot { state: exit_state }); - fun.push_insn(block, Insn::SideExit { state: exit_id, reason: SideExitReason::UnhandledCallType(call_type) }); + fun.push_insn(block, Insn::SideExit { state: exit_id, reason: SideExitReason::UnhandledTailCall }); break; // End the block } let argc = unsafe { vm_ci_argc((*cd).ci) }; @@ -3522,10 +3521,11 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> { // NB: these instructions have the recv for the call at get_arg(0) let cd: *const rb_call_data = get_arg(pc, 1).as_ptr(); let call_info = unsafe { rb_get_call_data_ci(cd) }; - if let Err(call_type) = unknown_call_type(unsafe { rb_vm_ci_flag(call_info) }) { - // Unknown call type; side-exit into the interpreter + let flags = unsafe { rb_vm_ci_flag(call_info) }; + if is_tailcall(flags) { + // Can't handle tailcall; side-exit into the interpreter let exit_id = fun.push_insn(block, Insn::Snapshot { state: exit_state }); - fun.push_insn(block, Insn::SideExit { state: exit_id, reason: SideExitReason::UnhandledCallType(call_type) }); + fun.push_insn(block, Insn::SideExit { state: exit_id, reason: SideExitReason::UnhandledTailCall }); break; // End the block } let argc = unsafe { vm_ci_argc((*cd).ci) }; @@ -3580,10 +3580,11 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> { YARVINSN_opt_send_without_block => { let cd: *const rb_call_data = get_arg(pc, 0).as_ptr(); let call_info = unsafe { rb_get_call_data_ci(cd) }; - if let Err(call_type) = unknown_call_type(unsafe { rb_vm_ci_flag(call_info) }) { - // Unknown call type; side-exit into the interpreter + let flags = unsafe { rb_vm_ci_flag(call_info) }; + if is_tailcall(flags) { + // Can't handle tailcall; side-exit into the interpreter let exit_id = fun.push_insn(block, Insn::Snapshot { state: exit_state }); - fun.push_insn(block, Insn::SideExit { state: exit_id, reason: SideExitReason::UnhandledCallType(call_type) }); + fun.push_insn(block, Insn::SideExit { state: exit_id, reason: SideExitReason::UnhandledTailCall }); break; // End the block } let argc = unsafe { vm_ci_argc((*cd).ci) }; @@ -3598,15 +3599,17 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> { let cd: *const rb_call_data = get_arg(pc, 0).as_ptr(); let blockiseq: IseqPtr = get_arg(pc, 1).as_iseq(); let call_info = unsafe { rb_get_call_data_ci(cd) }; - if let Err(call_type) = unknown_call_type(unsafe { rb_vm_ci_flag(call_info) }) { - // Unknown call type; side-exit into the interpreter + let flags = unsafe { rb_vm_ci_flag(call_info) }; + if is_tailcall(flags) { + // Can't handle tailcall; side-exit into the interpreter let exit_id = fun.push_insn(block, Insn::Snapshot { state: exit_state }); - fun.push_insn(block, Insn::SideExit { state: exit_id, reason: SideExitReason::UnhandledCallType(call_type) }); + fun.push_insn(block, Insn::SideExit { state: exit_id, reason: SideExitReason::UnhandledTailCall }); break; // End the block } let argc = unsafe { vm_ci_argc((*cd).ci) }; + let block_arg = (flags & VM_CALL_ARGS_BLOCKARG) != 0; - let args = state.stack_pop_n(argc as usize)?; + let args = state.stack_pop_n(argc as usize + usize::from(block_arg))?; let recv = state.stack_pop()?; let exit_id = fun.push_insn(block, Insn::Snapshot { state: exit_state }); let send = fun.push_insn(block, Insn::Send { recv, cd, blockiseq, args, state: exit_id }); @@ -3624,14 +3627,16 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> { YARVINSN_invokesuper => { let cd: *const rb_call_data = get_arg(pc, 0).as_ptr(); let call_info = unsafe { rb_get_call_data_ci(cd) }; - if let Err(call_type) = unknown_call_type(unsafe { rb_vm_ci_flag(call_info) } & !VM_CALL_SUPER & !VM_CALL_ZSUPER) { - // Unknown call type; side-exit into the interpreter + let flags = unsafe { rb_vm_ci_flag(call_info) }; + if is_tailcall(flags) { + // Can't handle tailcall; side-exit into the interpreter let exit_id = fun.push_insn(block, Insn::Snapshot { state: exit_state }); - fun.push_insn(block, Insn::SideExit { state: exit_id, reason: SideExitReason::UnhandledCallType(call_type) }); + fun.push_insn(block, Insn::SideExit { state: exit_id, reason: SideExitReason::UnhandledTailCall }); break; // End the block } let argc = unsafe { vm_ci_argc((*cd).ci) }; - let args = state.stack_pop_n(argc as usize)?; + let block_arg = (flags & VM_CALL_ARGS_BLOCKARG) != 0; + let args = state.stack_pop_n(argc as usize + usize::from(block_arg))?; let recv = state.stack_pop()?; let blockiseq: IseqPtr = get_arg(pc, 1).as_ptr(); let exit_id = fun.push_insn(block, Insn::Snapshot { state: exit_state }); @@ -3652,14 +3657,16 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> { YARVINSN_invokeblock => { let cd: *const rb_call_data = get_arg(pc, 0).as_ptr(); let call_info = unsafe { rb_get_call_data_ci(cd) }; - if let Err(call_type) = unknown_call_type(unsafe { rb_vm_ci_flag(call_info) }) { - // Unknown call type; side-exit into the interpreter + let flags = unsafe { rb_vm_ci_flag(call_info) }; + if is_tailcall(flags) { + // Can't handle tailcall; side-exit into the interpreter let exit_id = fun.push_insn(block, Insn::Snapshot { state: exit_state }); - fun.push_insn(block, Insn::SideExit { state: exit_id, reason: SideExitReason::UnhandledCallType(call_type) }); + fun.push_insn(block, Insn::SideExit { state: exit_id, reason: SideExitReason::UnhandledTailCall }); break; // End the block } let argc = unsafe { vm_ci_argc((*cd).ci) }; - let args = state.stack_pop_n(argc as usize)?; + let block_arg = (flags & VM_CALL_ARGS_BLOCKARG) != 0; + let args = state.stack_pop_n(argc as usize + usize::from(block_arg))?; let exit_id = fun.push_insn(block, Insn::Snapshot { state: exit_state }); let result = fun.push_insn(block, Insn::InvokeBlock { cd, args, state: exit_id }); state.stack_push(result); @@ -3767,11 +3774,6 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> { } YARVINSN_objtostring => { let cd: *const rb_call_data = get_arg(pc, 0).as_ptr(); - let call_info = unsafe { rb_get_call_data_ci(cd) }; - - if let Err(call_type) = unknown_call_type(unsafe { rb_vm_ci_flag(call_info) }) { - panic!("objtostring should not have unknown call type {call_type:?}"); - } let argc = unsafe { vm_ci_argc((*cd).ci) }; assert_eq!(0, argc, "objtostring should not have args"); @@ -5115,14 +5117,17 @@ mod tests { } #[test] - fn test_cant_compile_block_arg() { + fn test_compile_block_arg() { eval(" def test(a) = foo(&a) "); assert_snapshot!(hir_string("test"), @r" fn test@<compiled>:2: bb0(v0:BasicObject, v1:BasicObject): - SideExit UnhandledCallType(BlockArg) + v6:BasicObject = Send v0, 0x1000, :foo, v1 + v7:BasicObject = GetLocal l0, EP@3 + CheckInterrupts + Return v6 "); } @@ -5194,7 +5199,9 @@ mod tests { fn test@<compiled>:2: bb0(v0:BasicObject): v4:NilClass = Const Value(nil) - SideExit UnhandledCallType(BlockArg) + v6:BasicObject = InvokeSuper v0, 0x1000, v4 + CheckInterrupts + Return v6 "); } @@ -8074,7 +8081,10 @@ mod opt_tests { fn test@<compiled>:2: bb0(v0:BasicObject, v1:BasicObject): v6:BasicObject = GetBlockParamProxy l0 - SideExit UnhandledCallType(BlockArg) + v8:BasicObject = Send v0, 0x1000, :tap, v6 + v9:BasicObject = GetLocal l0, EP@3 + CheckInterrupts + Return v8 "); } diff --git a/zjit/src/stats.rs b/zjit/src/stats.rs index 8a7073dd97..69748f4ebc 100644 --- a/zjit/src/stats.rs +++ b/zjit/src/stats.rs @@ -92,7 +92,7 @@ make_counters! { // exit_: Side exits reasons exit_compile_error, exit_unknown_newarray_send, - exit_unhandled_call_type, + exit_unhandled_tailcall, exit_unknown_special_variable, exit_unhandled_hir_insn, exit_unhandled_yarv_insn, @@ -209,7 +209,7 @@ pub fn exit_counter_ptr(reason: crate::hir::SideExitReason) -> *mut u64 { use crate::stats::Counter::*; let counter = match reason { UnknownNewarraySend(_) => exit_unknown_newarray_send, - UnhandledCallType(_) => exit_unhandled_call_type, + UnhandledTailCall => exit_unhandled_tailcall, UnknownSpecialVariable(_) => exit_unknown_special_variable, UnhandledHIRInsn(_) => exit_unhandled_hir_insn, UnhandledYARVInsn(_) => exit_unhandled_yarv_insn, |
