summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTakashi Kokubun <takashi.kokubun@shopify.com>2025-09-15 17:43:41 -0700
committerGitHub <noreply@github.com>2025-09-15 17:43:41 -0700
commit6c5960ae1955b66b4a13e03012d53853ee1fd1de (patch)
tree5c968cf42f044c15f4fa4cc04bf770d6ac13fa8a
parente4f09a8c94e6e6d21a6dfa43f71d52e4096234d6 (diff)
ZJIT: Support compiling block args (#14537)
-rw-r--r--test/ruby/test_zjit.rb15
-rw-r--r--zjit/src/backend/lir.rs7
-rw-r--r--zjit/src/hir.rs80
-rw-r--r--zjit/src/stats.rs4
4 files changed, 63 insertions, 43 deletions
diff --git a/test/ruby/test_zjit.rb b/test/ruby/test_zjit.rb
index 87d8b06ece..fc79bdda6e 100644
--- a/test/ruby/test_zjit.rb
+++ b/test/ruby/test_zjit.rb
@@ -439,6 +439,21 @@ class TestZJIT < Test::Unit::TestCase
}, call_threshold: 2
end
+ def test_send_nil_block_arg
+ assert_compiles 'false', %q{
+ def test = block_given?
+ def entry = test(&nil)
+ test
+ }
+ end
+
+ def test_send_symbol_block_arg
+ assert_compiles '["1", "2"]', %q{
+ def test = [1, 2].map(&:to_s)
+ test
+ }
+ end
+
def test_forwardable_iseq
assert_compiles '1', %q{
def test(...) = 1
diff --git a/zjit/src/backend/lir.rs b/zjit/src/backend/lir.rs
index c67bf8c9ba..5f96d0718a 100644
--- a/zjit/src/backend/lir.rs
+++ b/zjit/src/backend/lir.rs
@@ -6,7 +6,7 @@ use crate::cruby::{Qundef, RUBY_OFFSET_CFP_PC, RUBY_OFFSET_CFP_SP, SIZEOF_VALUE_
use crate::hir::SideExitReason;
use crate::options::{debug, get_option};
use crate::cruby::VALUE;
-use crate::stats::{exit_counter_ptr, exit_counter_ptr_for_call_type, exit_counter_ptr_for_opcode, CompileError};
+use crate::stats::{exit_counter_ptr, exit_counter_ptr_for_opcode, CompileError};
use crate::virtualmem::CodePtr;
use crate::asm::{CodeBlock, Label};
@@ -1601,11 +1601,6 @@ impl Assembler
self.load_into(SCRATCH_OPND, Opnd::const_ptr(exit_counter_ptr_for_opcode(opcode)));
self.incr_counter_with_reg(Opnd::mem(64, SCRATCH_OPND, 0), 1.into(), C_RET_OPND);
}
- if let SideExitReason::UnhandledCallType(call_type) = reason {
- asm_comment!(self, "increment an unknown call type counter");
- self.load_into(SCRATCH_OPND, Opnd::const_ptr(exit_counter_ptr_for_call_type(call_type)));
- self.incr_counter_with_reg(Opnd::mem(64, SCRATCH_OPND, 0), 1.into(), C_RET_OPND);
- }
}
asm_comment!(self, "exit to the interpreter");
diff --git a/zjit/src/hir.rs b/zjit/src/hir.rs
index 3726d8ec0e..46bf38dcc6 100644
--- a/zjit/src/hir.rs
+++ b/zjit/src/hir.rs
@@ -447,10 +447,10 @@ impl PtrPrintMap {
#[derive(Debug, Clone, Copy)]
pub enum SideExitReason {
UnknownNewarraySend(vm_opt_newarray_send_type),
- UnhandledCallType(CallType),
UnknownSpecialVariable(u64),
UnhandledHIRInsn(InsnId),
UnhandledYARVInsn(u32),
+ UnhandledTailCall,
FixnumAddOverflow,
FixnumSubOverflow,
FixnumMultOverflow,
@@ -2986,10 +2986,8 @@ fn num_locals(iseq: *const rb_iseq_t) -> usize {
}
/// If we can't handle the type of send (yet), bail out.
-fn unknown_call_type(flag: u32) -> Result<(), CallType> {
- if (flag & VM_CALL_ARGS_BLOCKARG) != 0 { return Err(CallType::BlockArg); }
- if (flag & VM_CALL_TAILCALL) != 0 { return Err(CallType::Tailcall); }
- Ok(())
+fn is_tailcall(flags: u32) -> bool {
+ (flags & VM_CALL_TAILCALL) != 0
}
/// We have IseqPayload, which keeps track of HIR Types in the interpreter, but this is not useful
@@ -3501,10 +3499,11 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> {
// NB: opt_neq has two cd; get_arg(0) is for eq and get_arg(1) is for neq
let cd: *const rb_call_data = get_arg(pc, 1).as_ptr();
let call_info = unsafe { rb_get_call_data_ci(cd) };
- if let Err(call_type) = unknown_call_type(unsafe { rb_vm_ci_flag(call_info) }) {
- // Unknown call type; side-exit into the interpreter
+ let flags = unsafe { rb_vm_ci_flag(call_info) };
+ if is_tailcall(flags) {
+ // Can't handle tailcall; side-exit into the interpreter
let exit_id = fun.push_insn(block, Insn::Snapshot { state: exit_state });
- fun.push_insn(block, Insn::SideExit { state: exit_id, reason: SideExitReason::UnhandledCallType(call_type) });
+ fun.push_insn(block, Insn::SideExit { state: exit_id, reason: SideExitReason::UnhandledTailCall });
break; // End the block
}
let argc = unsafe { vm_ci_argc((*cd).ci) };
@@ -3522,10 +3521,11 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> {
// NB: these instructions have the recv for the call at get_arg(0)
let cd: *const rb_call_data = get_arg(pc, 1).as_ptr();
let call_info = unsafe { rb_get_call_data_ci(cd) };
- if let Err(call_type) = unknown_call_type(unsafe { rb_vm_ci_flag(call_info) }) {
- // Unknown call type; side-exit into the interpreter
+ let flags = unsafe { rb_vm_ci_flag(call_info) };
+ if is_tailcall(flags) {
+ // Can't handle tailcall; side-exit into the interpreter
let exit_id = fun.push_insn(block, Insn::Snapshot { state: exit_state });
- fun.push_insn(block, Insn::SideExit { state: exit_id, reason: SideExitReason::UnhandledCallType(call_type) });
+ fun.push_insn(block, Insn::SideExit { state: exit_id, reason: SideExitReason::UnhandledTailCall });
break; // End the block
}
let argc = unsafe { vm_ci_argc((*cd).ci) };
@@ -3580,10 +3580,11 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> {
YARVINSN_opt_send_without_block => {
let cd: *const rb_call_data = get_arg(pc, 0).as_ptr();
let call_info = unsafe { rb_get_call_data_ci(cd) };
- if let Err(call_type) = unknown_call_type(unsafe { rb_vm_ci_flag(call_info) }) {
- // Unknown call type; side-exit into the interpreter
+ let flags = unsafe { rb_vm_ci_flag(call_info) };
+ if is_tailcall(flags) {
+ // Can't handle tailcall; side-exit into the interpreter
let exit_id = fun.push_insn(block, Insn::Snapshot { state: exit_state });
- fun.push_insn(block, Insn::SideExit { state: exit_id, reason: SideExitReason::UnhandledCallType(call_type) });
+ fun.push_insn(block, Insn::SideExit { state: exit_id, reason: SideExitReason::UnhandledTailCall });
break; // End the block
}
let argc = unsafe { vm_ci_argc((*cd).ci) };
@@ -3598,15 +3599,17 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> {
let cd: *const rb_call_data = get_arg(pc, 0).as_ptr();
let blockiseq: IseqPtr = get_arg(pc, 1).as_iseq();
let call_info = unsafe { rb_get_call_data_ci(cd) };
- if let Err(call_type) = unknown_call_type(unsafe { rb_vm_ci_flag(call_info) }) {
- // Unknown call type; side-exit into the interpreter
+ let flags = unsafe { rb_vm_ci_flag(call_info) };
+ if is_tailcall(flags) {
+ // Can't handle tailcall; side-exit into the interpreter
let exit_id = fun.push_insn(block, Insn::Snapshot { state: exit_state });
- fun.push_insn(block, Insn::SideExit { state: exit_id, reason: SideExitReason::UnhandledCallType(call_type) });
+ fun.push_insn(block, Insn::SideExit { state: exit_id, reason: SideExitReason::UnhandledTailCall });
break; // End the block
}
let argc = unsafe { vm_ci_argc((*cd).ci) };
+ let block_arg = (flags & VM_CALL_ARGS_BLOCKARG) != 0;
- let args = state.stack_pop_n(argc as usize)?;
+ let args = state.stack_pop_n(argc as usize + usize::from(block_arg))?;
let recv = state.stack_pop()?;
let exit_id = fun.push_insn(block, Insn::Snapshot { state: exit_state });
let send = fun.push_insn(block, Insn::Send { recv, cd, blockiseq, args, state: exit_id });
@@ -3624,14 +3627,16 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> {
YARVINSN_invokesuper => {
let cd: *const rb_call_data = get_arg(pc, 0).as_ptr();
let call_info = unsafe { rb_get_call_data_ci(cd) };
- if let Err(call_type) = unknown_call_type(unsafe { rb_vm_ci_flag(call_info) } & !VM_CALL_SUPER & !VM_CALL_ZSUPER) {
- // Unknown call type; side-exit into the interpreter
+ let flags = unsafe { rb_vm_ci_flag(call_info) };
+ if is_tailcall(flags) {
+ // Can't handle tailcall; side-exit into the interpreter
let exit_id = fun.push_insn(block, Insn::Snapshot { state: exit_state });
- fun.push_insn(block, Insn::SideExit { state: exit_id, reason: SideExitReason::UnhandledCallType(call_type) });
+ fun.push_insn(block, Insn::SideExit { state: exit_id, reason: SideExitReason::UnhandledTailCall });
break; // End the block
}
let argc = unsafe { vm_ci_argc((*cd).ci) };
- let args = state.stack_pop_n(argc as usize)?;
+ let block_arg = (flags & VM_CALL_ARGS_BLOCKARG) != 0;
+ let args = state.stack_pop_n(argc as usize + usize::from(block_arg))?;
let recv = state.stack_pop()?;
let blockiseq: IseqPtr = get_arg(pc, 1).as_ptr();
let exit_id = fun.push_insn(block, Insn::Snapshot { state: exit_state });
@@ -3652,14 +3657,16 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> {
YARVINSN_invokeblock => {
let cd: *const rb_call_data = get_arg(pc, 0).as_ptr();
let call_info = unsafe { rb_get_call_data_ci(cd) };
- if let Err(call_type) = unknown_call_type(unsafe { rb_vm_ci_flag(call_info) }) {
- // Unknown call type; side-exit into the interpreter
+ let flags = unsafe { rb_vm_ci_flag(call_info) };
+ if is_tailcall(flags) {
+ // Can't handle tailcall; side-exit into the interpreter
let exit_id = fun.push_insn(block, Insn::Snapshot { state: exit_state });
- fun.push_insn(block, Insn::SideExit { state: exit_id, reason: SideExitReason::UnhandledCallType(call_type) });
+ fun.push_insn(block, Insn::SideExit { state: exit_id, reason: SideExitReason::UnhandledTailCall });
break; // End the block
}
let argc = unsafe { vm_ci_argc((*cd).ci) };
- let args = state.stack_pop_n(argc as usize)?;
+ let block_arg = (flags & VM_CALL_ARGS_BLOCKARG) != 0;
+ let args = state.stack_pop_n(argc as usize + usize::from(block_arg))?;
let exit_id = fun.push_insn(block, Insn::Snapshot { state: exit_state });
let result = fun.push_insn(block, Insn::InvokeBlock { cd, args, state: exit_id });
state.stack_push(result);
@@ -3767,11 +3774,6 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> {
}
YARVINSN_objtostring => {
let cd: *const rb_call_data = get_arg(pc, 0).as_ptr();
- let call_info = unsafe { rb_get_call_data_ci(cd) };
-
- if let Err(call_type) = unknown_call_type(unsafe { rb_vm_ci_flag(call_info) }) {
- panic!("objtostring should not have unknown call type {call_type:?}");
- }
let argc = unsafe { vm_ci_argc((*cd).ci) };
assert_eq!(0, argc, "objtostring should not have args");
@@ -5115,14 +5117,17 @@ mod tests {
}
#[test]
- fn test_cant_compile_block_arg() {
+ fn test_compile_block_arg() {
eval("
def test(a) = foo(&a)
");
assert_snapshot!(hir_string("test"), @r"
fn test@<compiled>:2:
bb0(v0:BasicObject, v1:BasicObject):
- SideExit UnhandledCallType(BlockArg)
+ v6:BasicObject = Send v0, 0x1000, :foo, v1
+ v7:BasicObject = GetLocal l0, EP@3
+ CheckInterrupts
+ Return v6
");
}
@@ -5194,7 +5199,9 @@ mod tests {
fn test@<compiled>:2:
bb0(v0:BasicObject):
v4:NilClass = Const Value(nil)
- SideExit UnhandledCallType(BlockArg)
+ v6:BasicObject = InvokeSuper v0, 0x1000, v4
+ CheckInterrupts
+ Return v6
");
}
@@ -8074,7 +8081,10 @@ mod opt_tests {
fn test@<compiled>:2:
bb0(v0:BasicObject, v1:BasicObject):
v6:BasicObject = GetBlockParamProxy l0
- SideExit UnhandledCallType(BlockArg)
+ v8:BasicObject = Send v0, 0x1000, :tap, v6
+ v9:BasicObject = GetLocal l0, EP@3
+ CheckInterrupts
+ Return v8
");
}
diff --git a/zjit/src/stats.rs b/zjit/src/stats.rs
index 8a7073dd97..69748f4ebc 100644
--- a/zjit/src/stats.rs
+++ b/zjit/src/stats.rs
@@ -92,7 +92,7 @@ make_counters! {
// exit_: Side exits reasons
exit_compile_error,
exit_unknown_newarray_send,
- exit_unhandled_call_type,
+ exit_unhandled_tailcall,
exit_unknown_special_variable,
exit_unhandled_hir_insn,
exit_unhandled_yarv_insn,
@@ -209,7 +209,7 @@ pub fn exit_counter_ptr(reason: crate::hir::SideExitReason) -> *mut u64 {
use crate::stats::Counter::*;
let counter = match reason {
UnknownNewarraySend(_) => exit_unknown_newarray_send,
- UnhandledCallType(_) => exit_unhandled_call_type,
+ UnhandledTailCall => exit_unhandled_tailcall,
UnknownSpecialVariable(_) => exit_unknown_special_variable,
UnhandledHIRInsn(_) => exit_unhandled_hir_insn,
UnhandledYARVInsn(_) => exit_unhandled_yarv_insn,