diff options
| field | value | date |
|---|---|---|
| author | Stan Lo <stan001212@gmail.com> | 2025-12-01 17:14:08 +0000 |
| committer | GitHub <noreply@github.com> | 2025-12-01 17:14:08 +0000 |
| commit | c58970b57a91a10eb75f933258643d0393ce0ba8 (patch) | |
| tree | 2e60bf50b418b2e07d121ad6a61f1f17fc331f1d | |
| parent | 8dc5822b007937f75eb0b156c9d9dcc7b16f9de8 (diff) | |
ZJIT: Optimize variadic cfunc `Send` calls into `CCallVariadic` (#14898)
ZJIT: Optimize variadic cfunc Send calls into CCallVariadic
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | test/ruby/test_zjit.rb | 15 |
| -rw-r--r-- | zjit/src/codegen.rs | 37 |
| -rw-r--r-- | zjit/src/hir.rs | 69 |
| -rw-r--r-- | zjit/src/hir/opt_tests.rs | 42 |
4 files changed, 125 insertions, 38 deletions
diff --git a/test/ruby/test_zjit.rb b/test/ruby/test_zjit.rb index 7472ff7715..d821d8ad5c 100644 --- a/test/ruby/test_zjit.rb +++ b/test/ruby/test_zjit.rb @@ -525,6 +525,21 @@ class TestZJIT < Test::Unit::TestCase } end + def test_send_variadic_with_block + assert_compiles '[[1, "a"], [2, "b"], [3, "c"]]', %q{ + A = [1, 2, 3] + B = ["a", "b", "c"] + + def test + result = [] + A.zip(B) { |x, y| result << [x, y] } + result + end + + test; test + }, call_threshold: 2 + end + def test_send_splat assert_runs '[1, 2]', %q{ def test(a, b) = [a, b] diff --git a/zjit/src/codegen.rs b/zjit/src/codegen.rs index 66436b2374..5200894f87 100644 --- a/zjit/src/codegen.rs +++ b/zjit/src/codegen.rs @@ -431,8 +431,8 @@ fn gen_insn(cb: &mut CodeBlock, jit: &mut JITState, asm: &mut Assembler, functio gen_send_without_block(jit, asm, *cd, &function.frame_state(*state), SendFallbackReason::CCallWithFrameTooManyArgs), Insn::CCallWithFrame { cfunc, name, args, cme, state, blockiseq, .. } => gen_ccall_with_frame(jit, asm, *cfunc, *name, opnds!(args), *cme, *blockiseq, &function.frame_state(*state)), - Insn::CCallVariadic { cfunc, recv, args, name, cme, state, return_type: _, elidable: _ } => { - gen_ccall_variadic(jit, asm, *cfunc, *name, opnd!(recv), opnds!(args), *cme, &function.frame_state(*state)) + Insn::CCallVariadic { cfunc, recv, args, name, cme, state, blockiseq, .. 
} => { + gen_ccall_variadic(jit, asm, *cfunc, *name, opnd!(recv), opnds!(args), *cme, *blockiseq, &function.frame_state(*state)) } Insn::GetIvar { self_val, id, ic, state: _ } => gen_getivar(jit, asm, opnd!(self_val), *id, *ic), Insn::SetGlobal { id, val, state } => no_output!(gen_setglobal(jit, asm, *id, opnd!(val), &function.frame_state(*state))), @@ -845,26 +845,47 @@ fn gen_ccall_variadic( recv: Opnd, args: Vec<Opnd>, cme: *const rb_callable_method_entry_t, + blockiseq: Option<IseqPtr>, state: &FrameState, ) -> lir::Opnd { gen_incr_counter(asm, Counter::variadic_cfunc_optimized_send_count); + gen_stack_overflow_check(jit, asm, state, state.stack_size()); - gen_prepare_non_leaf_call(jit, asm, state); + let args_with_recv_len = args.len() + 1; - let stack_growth = state.stack_size(); - gen_stack_overflow_check(jit, asm, state, stack_growth); + // Compute the caller's stack size after consuming recv and args. + // state.stack() includes recv + args, so subtract both. + let caller_stack_size = state.stack_size() - args_with_recv_len; - gen_push_frame(asm, args.len(), state, ControlFrame { + // Can't use gen_prepare_non_leaf_call() because we need to adjust the SP + // to account for the receiver and arguments (like gen_ccall_with_frame does) + gen_prepare_call_with_gc(asm, state, false); + gen_save_sp(asm, caller_stack_size); + gen_spill_stack(jit, asm, state); + gen_spill_locals(jit, asm, state); + + let block_handler_specval = if let Some(block_iseq) = blockiseq { + // Change cfp->block_code in the current frame. See vm_caller_setup_arg_block(). + // VM_CFP_TO_CAPTURED_BLOCK then turns &cfp->self into a block handler. + // rb_captured_block->code.iseq aliases with cfp->block_code. 
+ asm.store(Opnd::mem(64, CFP, RUBY_OFFSET_CFP_BLOCK_CODE), VALUE::from(block_iseq).into()); + let cfp_self_addr = asm.lea(Opnd::mem(64, CFP, RUBY_OFFSET_CFP_SELF)); + asm.or(cfp_self_addr, Opnd::Imm(1)) + } else { + VM_BLOCK_HANDLER_NONE.into() + }; + + gen_push_frame(asm, args_with_recv_len, state, ControlFrame { recv, iseq: None, cme, frame_type: VM_FRAME_MAGIC_CFUNC | VM_FRAME_FLAG_CFRAME | VM_ENV_FLAG_LOCAL, - specval: VM_BLOCK_HANDLER_NONE.into(), + specval: block_handler_specval, pc: PC_POISON, }); asm_comment!(asm, "switch to new SP register"); - let sp_offset = (state.stack().len() - args.len() + VM_ENV_DATA_SIZE.to_usize()) * SIZEOF_VALUE; + let sp_offset = (caller_stack_size + VM_ENV_DATA_SIZE.to_usize()) * SIZEOF_VALUE; let new_sp = asm.add(SP, sp_offset.into()); asm.mov(SP, new_sp); diff --git a/zjit/src/hir.rs b/zjit/src/hir.rs index fbc9d80700..8c1ec664d0 100644 --- a/zjit/src/hir.rs +++ b/zjit/src/hir.rs @@ -804,6 +804,7 @@ pub enum Insn { state: InsnId, return_type: Type, elidable: bool, + blockiseq: Option<IseqPtr>, }, /// Un-optimized fallback implementation (dynamic dispatch) for send-ish instructions @@ -1920,8 +1921,8 @@ impl Function { elidable, blockiseq, }, - &CCallVariadic { cfunc, recv, ref args, cme, name, state, return_type, elidable } => CCallVariadic { - cfunc, recv: find!(recv), args: find_vec!(args), cme, name, state, return_type, elidable + &CCallVariadic { cfunc, recv, ref args, cme, name, state, return_type, elidable, blockiseq } => CCallVariadic { + cfunc, recv: find!(recv), args: find_vec!(args), cme, name, state, return_type, elidable, blockiseq }, &Defined { op_type, obj, pushval, v, state } => Defined { op_type, obj, pushval, v: find!(v), state: find!(state) }, &DefinedIvar { self_val, pushval, id, state } => DefinedIvar { self_val: find!(self_val), pushval, id, state }, @@ -2989,9 +2990,23 @@ impl Function { return Err(()); } - // Find the `argc` (arity) of the C method, which describes the parameters it expects + let 
ci_flags = unsafe { vm_ci_flag(call_info) }; + + // When seeing &block argument, fall back to dynamic dispatch for now + // TODO: Support block forwarding + if unspecializable_call_type(ci_flags) { + fun.count_complex_call_features(block, ci_flags); + fun.set_dynamic_send_reason(send_insn_id, ComplexArgPass); + return Err(()); + } + + let blockiseq = if blockiseq.is_null() { None } else { Some(blockiseq) }; + let cfunc = unsafe { get_cme_def_body_cfunc(cme) }; + // Find the `argc` (arity) of the C method, which describes the parameters it expects let cfunc_argc = unsafe { get_mct_argc(cfunc) }; + let cfunc_ptr = unsafe { get_mct_func(cfunc) }.cast(); + match cfunc_argc { 0.. => { // (self, arg0, arg1, ..., argc) form @@ -3001,16 +3016,6 @@ impl Function { return Err(()); } - let ci_flags = unsafe { vm_ci_flag(call_info) }; - - // When seeing &block argument, fall back to dynamic dispatch for now - // TODO: Support block forwarding - if unspecializable_call_type(ci_flags) { - fun.count_complex_call_features(block, ci_flags); - fun.set_dynamic_send_reason(send_insn_id, ComplexArgPass); - return Err(()); - } - // Commit to the replacement. Put PatchPoint. 
fun.gen_patch_points_for_optimized_ccall(block, recv_class, method_id, cme, state); if recv_class.instance_can_have_singleton_class() { @@ -3023,17 +3028,14 @@ impl Function { fun.insn_types[recv.0] = fun.infer_type(recv); } - let blockiseq = if blockiseq.is_null() { None } else { Some(blockiseq) }; - // Emit a call - let cfunc = unsafe { get_mct_func(cfunc) }.cast(); let mut cfunc_args = vec![recv]; cfunc_args.append(&mut args); let name = rust_str_to_id(&qualified_method_name(unsafe { (*cme).owner }, unsafe { (*cme).called_id })); let ccall = fun.push_insn(block, Insn::CCallWithFrame { cd, - cfunc, + cfunc: cfunc_ptr, args: cfunc_args, cme, name, @@ -3047,9 +3049,37 @@ impl Function { } // Variadic method -1 => { + // The method gets a pointer to the first argument // func(int argc, VALUE *argv, VALUE recv) - fun.set_dynamic_send_reason(send_insn_id, SendCfuncVariadic); - Err(()) + fun.gen_patch_points_for_optimized_ccall(block, recv_class, method_id, cme, state); + + if recv_class.instance_can_have_singleton_class() { + fun.push_insn(block, Insn::PatchPoint { invariant: Invariant::NoSingletonClass { klass: recv_class }, state }); + } + if let Some(profiled_type) = profiled_type { + // Guard receiver class + recv = fun.push_insn(block, Insn::GuardType { val: recv, guard_type: Type::from_profiled_type(profiled_type), state }); + fun.insn_types[recv.0] = fun.infer_type(recv); + } + + if get_option!(stats) { + count_not_inlined_cfunc(fun, block, cme); + } + + let ccall = fun.push_insn(block, Insn::CCallVariadic { + cfunc: cfunc_ptr, + recv, + args, + cme, + name: method_id, + state, + return_type: types::BasicObject, + elidable: false, + blockiseq + }); + + fun.make_equal_to(send_insn_id, ccall); + Ok(()) } -2 => { // (self, args_ruby_array) @@ -3252,6 +3282,7 @@ impl Function { state, return_type, elidable, + blockiseq: None, }); fun.make_equal_to(send_insn_id, ccall); diff --git a/zjit/src/hir/opt_tests.rs b/zjit/src/hir/opt_tests.rs index b32da5a9eb..26dad38b58 
100644 --- a/zjit/src/hir/opt_tests.rs +++ b/zjit/src/hir/opt_tests.rs @@ -5302,26 +5302,46 @@ mod hir_opt_tests { } #[test] - fn test_do_not_optimize_send_variadic_with_block() { + fn test_optimize_send_variadic_with_block() { eval(r#" - def test = [1, 2, 3].index { |x| x == 2 } + A = [1, 2, 3] + B = ["a", "b", "c"] + + def test + result = [] + A.zip(B) { |x, y| result << [x, y] } + result + end + test; test "#); assert_snapshot!(hir_string("test"), @r" - fn test@<compiled>:2: + fn test@<compiled>:6: bb0(): EntryPoint interpreter v1:BasicObject = LoadSelf - Jump bb2(v1) - bb1(v4:BasicObject): + v2:NilClass = Const Value(nil) + Jump bb2(v1, v2) + bb1(v5:BasicObject): EntryPoint JIT(0) - Jump bb2(v4) - bb2(v6:BasicObject): - v10:ArrayExact[VALUE(0x1000)] = Const Value(VALUE(0x1000)) - v11:ArrayExact = ArrayDup v10 - v13:BasicObject = Send v11, 0x1008, :index + v6:NilClass = Const Value(nil) + Jump bb2(v5, v6) + bb2(v8:BasicObject, v9:NilClass): + v13:ArrayExact = NewArray + SetLocal l0, EP@3, v13 + PatchPoint SingleRactorMode + PatchPoint StableConstantNames(0x1000, A) + v36:ArrayExact[VALUE(0x1008)] = Const Value(VALUE(0x1008)) + PatchPoint SingleRactorMode + PatchPoint StableConstantNames(0x1010, B) + v39:ArrayExact[VALUE(0x1018)] = Const Value(VALUE(0x1018)) + PatchPoint MethodRedefined(Array@0x1020, zip@0x1028, cme:0x1030) + PatchPoint NoSingletonClass(Array@0x1020) + v43:BasicObject = CCallVariadic zip@0x1058, v36, v39 + v25:BasicObject = GetLocal l0, EP@3 + v29:BasicObject = GetLocal l0, EP@3 CheckInterrupts - Return v13 + Return v29 "); } |
