summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author Stan Lo <stan001212@gmail.com> 2025-12-01 17:14:08 +0000
committer GitHub <noreply@github.com> 2025-12-01 17:14:08 +0000
commit c58970b57a91a10eb75f933258643d0393ce0ba8 (patch)
tree 2e60bf50b418b2e07d121ad6a61f1f17fc331f1d
parent 8dc5822b007937f75eb0b156c9d9dcc7b16f9de8 (diff)
ZJIT: Optimize variadic cfunc `Send` calls into `CCallVariadic` (#14898)
ZJIT: Optimize variadic cfunc Send calls into CCallVariadic
-rw-r--r-- test/ruby/test_zjit.rb 15
-rw-r--r-- zjit/src/codegen.rs 37
-rw-r--r-- zjit/src/hir.rs 69
-rw-r--r-- zjit/src/hir/opt_tests.rs 42
4 files changed, 125 insertions, 38 deletions
diff --git a/test/ruby/test_zjit.rb b/test/ruby/test_zjit.rb
index 7472ff7715..d821d8ad5c 100644
--- a/test/ruby/test_zjit.rb
+++ b/test/ruby/test_zjit.rb
@@ -525,6 +525,21 @@ class TestZJIT < Test::Unit::TestCase
}
end
+ def test_send_variadic_with_block
+ assert_compiles '[[1, "a"], [2, "b"], [3, "c"]]', %q{
+ A = [1, 2, 3]
+ B = ["a", "b", "c"]
+
+ def test
+ result = []
+ A.zip(B) { |x, y| result << [x, y] }
+ result
+ end
+
+ test; test
+ }, call_threshold: 2
+ end
+
def test_send_splat
assert_runs '[1, 2]', %q{
def test(a, b) = [a, b]
diff --git a/zjit/src/codegen.rs b/zjit/src/codegen.rs
index 66436b2374..5200894f87 100644
--- a/zjit/src/codegen.rs
+++ b/zjit/src/codegen.rs
@@ -431,8 +431,8 @@ fn gen_insn(cb: &mut CodeBlock, jit: &mut JITState, asm: &mut Assembler, functio
gen_send_without_block(jit, asm, *cd, &function.frame_state(*state), SendFallbackReason::CCallWithFrameTooManyArgs),
Insn::CCallWithFrame { cfunc, name, args, cme, state, blockiseq, .. } =>
gen_ccall_with_frame(jit, asm, *cfunc, *name, opnds!(args), *cme, *blockiseq, &function.frame_state(*state)),
- Insn::CCallVariadic { cfunc, recv, args, name, cme, state, return_type: _, elidable: _ } => {
- gen_ccall_variadic(jit, asm, *cfunc, *name, opnd!(recv), opnds!(args), *cme, &function.frame_state(*state))
+ Insn::CCallVariadic { cfunc, recv, args, name, cme, state, blockiseq, .. } => {
+ gen_ccall_variadic(jit, asm, *cfunc, *name, opnd!(recv), opnds!(args), *cme, *blockiseq, &function.frame_state(*state))
}
Insn::GetIvar { self_val, id, ic, state: _ } => gen_getivar(jit, asm, opnd!(self_val), *id, *ic),
Insn::SetGlobal { id, val, state } => no_output!(gen_setglobal(jit, asm, *id, opnd!(val), &function.frame_state(*state))),
@@ -845,26 +845,47 @@ fn gen_ccall_variadic(
recv: Opnd,
args: Vec<Opnd>,
cme: *const rb_callable_method_entry_t,
+ blockiseq: Option<IseqPtr>,
state: &FrameState,
) -> lir::Opnd {
gen_incr_counter(asm, Counter::variadic_cfunc_optimized_send_count);
+ gen_stack_overflow_check(jit, asm, state, state.stack_size());
- gen_prepare_non_leaf_call(jit, asm, state);
+ let args_with_recv_len = args.len() + 1;
- let stack_growth = state.stack_size();
- gen_stack_overflow_check(jit, asm, state, stack_growth);
+ // Compute the caller's stack size after consuming recv and args.
+ // state.stack() includes recv + args, so subtract both.
+ let caller_stack_size = state.stack_size() - args_with_recv_len;
- gen_push_frame(asm, args.len(), state, ControlFrame {
+ // Can't use gen_prepare_non_leaf_call() because we need to adjust the SP
+ // to account for the receiver and arguments (like gen_ccall_with_frame does)
+ gen_prepare_call_with_gc(asm, state, false);
+ gen_save_sp(asm, caller_stack_size);
+ gen_spill_stack(jit, asm, state);
+ gen_spill_locals(jit, asm, state);
+
+ let block_handler_specval = if let Some(block_iseq) = blockiseq {
+ // Change cfp->block_code in the current frame. See vm_caller_setup_arg_block().
+ // VM_CFP_TO_CAPTURED_BLOCK then turns &cfp->self into a block handler.
+ // rb_captured_block->code.iseq aliases with cfp->block_code.
+ asm.store(Opnd::mem(64, CFP, RUBY_OFFSET_CFP_BLOCK_CODE), VALUE::from(block_iseq).into());
+ let cfp_self_addr = asm.lea(Opnd::mem(64, CFP, RUBY_OFFSET_CFP_SELF));
+ asm.or(cfp_self_addr, Opnd::Imm(1))
+ } else {
+ VM_BLOCK_HANDLER_NONE.into()
+ };
+
+ gen_push_frame(asm, args_with_recv_len, state, ControlFrame {
recv,
iseq: None,
cme,
frame_type: VM_FRAME_MAGIC_CFUNC | VM_FRAME_FLAG_CFRAME | VM_ENV_FLAG_LOCAL,
- specval: VM_BLOCK_HANDLER_NONE.into(),
+ specval: block_handler_specval,
pc: PC_POISON,
});
asm_comment!(asm, "switch to new SP register");
- let sp_offset = (state.stack().len() - args.len() + VM_ENV_DATA_SIZE.to_usize()) * SIZEOF_VALUE;
+ let sp_offset = (caller_stack_size + VM_ENV_DATA_SIZE.to_usize()) * SIZEOF_VALUE;
let new_sp = asm.add(SP, sp_offset.into());
asm.mov(SP, new_sp);
diff --git a/zjit/src/hir.rs b/zjit/src/hir.rs
index fbc9d80700..8c1ec664d0 100644
--- a/zjit/src/hir.rs
+++ b/zjit/src/hir.rs
@@ -804,6 +804,7 @@ pub enum Insn {
state: InsnId,
return_type: Type,
elidable: bool,
+ blockiseq: Option<IseqPtr>,
},
/// Un-optimized fallback implementation (dynamic dispatch) for send-ish instructions
@@ -1920,8 +1921,8 @@ impl Function {
elidable,
blockiseq,
},
- &CCallVariadic { cfunc, recv, ref args, cme, name, state, return_type, elidable } => CCallVariadic {
- cfunc, recv: find!(recv), args: find_vec!(args), cme, name, state, return_type, elidable
+ &CCallVariadic { cfunc, recv, ref args, cme, name, state, return_type, elidable, blockiseq } => CCallVariadic {
+ cfunc, recv: find!(recv), args: find_vec!(args), cme, name, state, return_type, elidable, blockiseq
},
&Defined { op_type, obj, pushval, v, state } => Defined { op_type, obj, pushval, v: find!(v), state: find!(state) },
&DefinedIvar { self_val, pushval, id, state } => DefinedIvar { self_val: find!(self_val), pushval, id, state },
@@ -2989,9 +2990,23 @@ impl Function {
return Err(());
}
- // Find the `argc` (arity) of the C method, which describes the parameters it expects
+ let ci_flags = unsafe { vm_ci_flag(call_info) };
+
+ // When seeing &block argument, fall back to dynamic dispatch for now
+ // TODO: Support block forwarding
+ if unspecializable_call_type(ci_flags) {
+ fun.count_complex_call_features(block, ci_flags);
+ fun.set_dynamic_send_reason(send_insn_id, ComplexArgPass);
+ return Err(());
+ }
+
+ let blockiseq = if blockiseq.is_null() { None } else { Some(blockiseq) };
+
let cfunc = unsafe { get_cme_def_body_cfunc(cme) };
+ // Find the `argc` (arity) of the C method, which describes the parameters it expects
let cfunc_argc = unsafe { get_mct_argc(cfunc) };
+ let cfunc_ptr = unsafe { get_mct_func(cfunc) }.cast();
+
match cfunc_argc {
0.. => {
// (self, arg0, arg1, ..., argc) form
@@ -3001,16 +3016,6 @@ impl Function {
return Err(());
}
- let ci_flags = unsafe { vm_ci_flag(call_info) };
-
- // When seeing &block argument, fall back to dynamic dispatch for now
- // TODO: Support block forwarding
- if unspecializable_call_type(ci_flags) {
- fun.count_complex_call_features(block, ci_flags);
- fun.set_dynamic_send_reason(send_insn_id, ComplexArgPass);
- return Err(());
- }
-
// Commit to the replacement. Put PatchPoint.
fun.gen_patch_points_for_optimized_ccall(block, recv_class, method_id, cme, state);
if recv_class.instance_can_have_singleton_class() {
@@ -3023,17 +3028,14 @@ impl Function {
fun.insn_types[recv.0] = fun.infer_type(recv);
}
- let blockiseq = if blockiseq.is_null() { None } else { Some(blockiseq) };
-
// Emit a call
- let cfunc = unsafe { get_mct_func(cfunc) }.cast();
let mut cfunc_args = vec![recv];
cfunc_args.append(&mut args);
let name = rust_str_to_id(&qualified_method_name(unsafe { (*cme).owner }, unsafe { (*cme).called_id }));
let ccall = fun.push_insn(block, Insn::CCallWithFrame {
cd,
- cfunc,
+ cfunc: cfunc_ptr,
args: cfunc_args,
cme,
name,
@@ -3047,9 +3049,37 @@ impl Function {
}
// Variadic method
-1 => {
+ // The method gets a pointer to the first argument
// func(int argc, VALUE *argv, VALUE recv)
- fun.set_dynamic_send_reason(send_insn_id, SendCfuncVariadic);
- Err(())
+ fun.gen_patch_points_for_optimized_ccall(block, recv_class, method_id, cme, state);
+
+ if recv_class.instance_can_have_singleton_class() {
+ fun.push_insn(block, Insn::PatchPoint { invariant: Invariant::NoSingletonClass { klass: recv_class }, state });
+ }
+ if let Some(profiled_type) = profiled_type {
+ // Guard receiver class
+ recv = fun.push_insn(block, Insn::GuardType { val: recv, guard_type: Type::from_profiled_type(profiled_type), state });
+ fun.insn_types[recv.0] = fun.infer_type(recv);
+ }
+
+ if get_option!(stats) {
+ count_not_inlined_cfunc(fun, block, cme);
+ }
+
+ let ccall = fun.push_insn(block, Insn::CCallVariadic {
+ cfunc: cfunc_ptr,
+ recv,
+ args,
+ cme,
+ name: method_id,
+ state,
+ return_type: types::BasicObject,
+ elidable: false,
+ blockiseq
+ });
+
+ fun.make_equal_to(send_insn_id, ccall);
+ Ok(())
}
-2 => {
// (self, args_ruby_array)
@@ -3252,6 +3282,7 @@ impl Function {
state,
return_type,
elidable,
+ blockiseq: None,
});
fun.make_equal_to(send_insn_id, ccall);
diff --git a/zjit/src/hir/opt_tests.rs b/zjit/src/hir/opt_tests.rs
index b32da5a9eb..26dad38b58 100644
--- a/zjit/src/hir/opt_tests.rs
+++ b/zjit/src/hir/opt_tests.rs
@@ -5302,26 +5302,46 @@ mod hir_opt_tests {
}
#[test]
- fn test_do_not_optimize_send_variadic_with_block() {
+ fn test_optimize_send_variadic_with_block() {
eval(r#"
- def test = [1, 2, 3].index { |x| x == 2 }
+ A = [1, 2, 3]
+ B = ["a", "b", "c"]
+
+ def test
+ result = []
+ A.zip(B) { |x, y| result << [x, y] }
+ result
+ end
+
test; test
"#);
assert_snapshot!(hir_string("test"), @r"
- fn test@<compiled>:2:
+ fn test@<compiled>:6:
bb0():
EntryPoint interpreter
v1:BasicObject = LoadSelf
- Jump bb2(v1)
- bb1(v4:BasicObject):
+ v2:NilClass = Const Value(nil)
+ Jump bb2(v1, v2)
+ bb1(v5:BasicObject):
EntryPoint JIT(0)
- Jump bb2(v4)
- bb2(v6:BasicObject):
- v10:ArrayExact[VALUE(0x1000)] = Const Value(VALUE(0x1000))
- v11:ArrayExact = ArrayDup v10
- v13:BasicObject = Send v11, 0x1008, :index
+ v6:NilClass = Const Value(nil)
+ Jump bb2(v5, v6)
+ bb2(v8:BasicObject, v9:NilClass):
+ v13:ArrayExact = NewArray
+ SetLocal l0, EP@3, v13
+ PatchPoint SingleRactorMode
+ PatchPoint StableConstantNames(0x1000, A)
+ v36:ArrayExact[VALUE(0x1008)] = Const Value(VALUE(0x1008))
+ PatchPoint SingleRactorMode
+ PatchPoint StableConstantNames(0x1010, B)
+ v39:ArrayExact[VALUE(0x1018)] = Const Value(VALUE(0x1018))
+ PatchPoint MethodRedefined(Array@0x1020, zip@0x1028, cme:0x1030)
+ PatchPoint NoSingletonClass(Array@0x1020)
+ v43:BasicObject = CCallVariadic zip@0x1058, v36, v39
+ v25:BasicObject = GetLocal l0, EP@3
+ v29:BasicObject = GetLocal l0, EP@3
CheckInterrupts
- Return v13
+ Return v29
");
}