commit    1298f9ac1ad390594815bc4d3739eb312bca8887
tree      f14579f22262b032ef58a334a11856b5ef03414d
parent    9be01bc70dca0e727fe1f518ebae1f6f72405b84
author    Max Bernstein <rubybugs@bernsteinbear.com>  2026-01-30 12:29:15 -0500
committer GitHub <noreply@github.com>                 2026-01-30 12:29:15 -0500
ZJIT: Support CFunc inlining in InvokeSuper (#16004)
Also generally make the CFunc process look more like `optimize_c_calls`.
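For context, here is a minimal Ruby sketch of the pattern this commit targets, adapted from the snapshot tests below: a `super` call that resolves to a core C function (here `Hash#size`). Previously such an `InvokeSuper` always lowered to `CCallWithFrame`; now, if the cfunc is annotated as leaf and GC-free, it can lower to a frameless `CCall`, or be folded into HIR entirely (as `BasicObject#initialize` is, to `Const Value(nil)`).

```ruby
# Adapted from the tests in this commit. The `super` in C#size dispatches
# to the C function implementing Hash#size. With this change, ZJIT can emit
# a direct CCall (no Ruby frame) for it, assuming the cfunc is annotated as
# leaf and GC-free; unannotated cfuncs still go through CCallWithFrame.
class C < Hash
  def size
    super
  end
end

C.new.size  # => 0
```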
-rw-r--r--  zjit/src/codegen.rs        20
-rw-r--r--  zjit/src/hir.rs           179
-rw-r--r--  zjit/src/hir/opt_tests.rs  51

3 files changed, 182 insertions(+), 68 deletions(-)
diff --git a/zjit/src/codegen.rs b/zjit/src/codegen.rs
index 2038be808d..a3068ff23d 100644
--- a/zjit/src/codegen.rs
+++ b/zjit/src/codegen.rs
@@ -2330,6 +2330,26 @@ fn gen_guard_type(jit: &mut JITState, asm: &mut Assembler, val: lir::Opnd, guard
         let tag = asm.and(flags, Opnd::UImm(RUBY_T_MASK as u64));
         asm.cmp(tag, Opnd::UImm(RUBY_T_STRING as u64));
         asm.jne(side);
+    } else if guard_type.is_subtype(types::Array) {
+        let side = side_exit(jit, state, GuardType(guard_type));
+
+        // Check special constant
+        asm.test(val, Opnd::UImm(RUBY_IMMEDIATE_MASK as u64));
+        asm.jnz(side.clone());
+
+        // Check false
+        asm.cmp(val, Qfalse.into());
+        asm.je(side.clone());
+
+        let val = match val {
+            Opnd::Reg(_) | Opnd::VReg { .. } => val,
+            _ => asm.load(val),
+        };
+
+        let flags = asm.load(Opnd::mem(VALUE_BITS, val, RUBY_OFFSET_RBASIC_FLAGS));
+        let tag = asm.and(flags, Opnd::UImm(RUBY_T_MASK as u64));
+        asm.cmp(tag, Opnd::UImm(RUBY_T_ARRAY as u64));
+        asm.jne(side);
     } else if guard_type.bit_equal(types::HeapBasicObject) {
         let side_exit = side_exit(jit, state, GuardType(guard_type));
         asm.cmp(val, Opnd::Value(Qfalse));
diff --git a/zjit/src/hir.rs b/zjit/src/hir.rs
index 901beffea0..9aa70b5d34 100644
--- a/zjit/src/hir.rs
+++ b/zjit/src/hir.rs
@@ -3564,13 +3564,7 @@ impl Function {
             assert!(flags & VM_CALL_FCALL != 0);

             // Reject calls with complex argument handling.
-            let complex_arg_types = VM_CALL_ARGS_SPLAT
-                | VM_CALL_KW_SPLAT
-                | VM_CALL_KWARG
-                | VM_CALL_ARGS_BLOCKARG
-                | VM_CALL_FORWARDING;
-
-            if (flags & complex_arg_types) != 0 {
+            if unspecializable_c_call_type(flags) {
                 self.push_insn_id(block, insn_id);
                 self.set_dynamic_send_reason(insn_id, SuperComplexArgsPass);
                 continue;
@@ -3608,14 +3602,18 @@ impl Function {
             }

             // Look up the super method.
-            let super_cme = unsafe { rb_callable_method_entry(superclass, mid) };
+            let mut super_cme = unsafe { rb_callable_method_entry(superclass, mid) };
             if super_cme.is_null() {
                 self.push_insn_id(block, insn_id);
                 self.set_dynamic_send_reason(insn_id, SuperTargetNotFound);
                 continue;
             }

-            let def_type = unsafe { get_cme_def_type(super_cme) };
+            let mut def_type = unsafe { get_cme_def_type(super_cme) };
+            while def_type == VM_METHOD_TYPE_ALIAS {
+                super_cme = unsafe { rb_aliased_callable_method_entry(super_cme) };
+                def_type = unsafe { get_cme_def_type(super_cme) };
+            }

             if def_type == VM_METHOD_TYPE_ISEQ {
                 // Check if the super method's parameters support direct send.
@@ -3653,6 +3651,12 @@ impl Function {
                 let cfunc_argc = unsafe { get_mct_argc(cfunc) };
                 let cfunc_ptr = unsafe { get_mct_func(cfunc) }.cast();

+                let props = ZJITState::get_method_annotations().get_cfunc_properties(super_cme);
+                if props.is_none() && get_option!(stats) {
+                    self.count_not_annotated_cfunc(block, super_cme);
+                }
+                let props = props.unwrap_or_default();
+
                 match cfunc_argc {
                     // C function with fixed argument count.
                     0.. => {
@@ -3665,20 +3669,48 @@ impl Function {

                         emit_super_call_guards(self, block, super_cme, current_cme, mid, state);

+                        // Try inlining the cfunc into HIR
+                        let tmp_block = self.new_block(u32::MAX);
+                        if let Some(replacement) = (props.inline)(self, tmp_block, recv, &args, state) {
+                            // Copy contents of tmp_block to block
+                            assert_ne!(block, tmp_block);
+                            let insns = std::mem::take(&mut self.blocks[tmp_block.0].insns);
+                            self.blocks[block.0].insns.extend(insns);
+                            self.push_insn(block, Insn::IncrCounter(Counter::inline_cfunc_optimized_send_count));
+                            self.make_equal_to(insn_id, replacement);
+                            if self.type_of(replacement).bit_equal(types::Any) {
+                                // Not set yet; infer type
+                                self.insn_types[replacement.0] = self.infer_type(replacement);
+                            }
+                            self.remove_block(tmp_block);
+                            continue;
+                        }
+
                         // Use CCallWithFrame for the C function.
                         let name = rust_str_to_id(&qualified_method_name(unsafe { (*super_cme).owner }, unsafe { (*super_cme).called_id }));
-                        let ccall = self.push_insn(block, Insn::CCallWithFrame {
-                            cd,
-                            cfunc: cfunc_ptr,
-                            recv,
-                            args: args.clone(),
-                            cme: super_cme,
-                            name,
-                            state,
-                            return_type: types::BasicObject,
-                            elidable: false,
-                            blockiseq: None,
-                        });
+                        let return_type = props.return_type;
+                        let elidable = props.elidable;
+                        // Filter for a leaf and GC free function
+                        let ccall = if props.leaf && props.no_gc {
+                            self.push_insn(block, Insn::IncrCounter(Counter::inline_cfunc_optimized_send_count));
+                            self.push_insn(block, Insn::CCall { cfunc: cfunc_ptr, recv, args, name, return_type, elidable })
+                        } else {
+                            if get_option!(stats) {
+                                self.count_not_inlined_cfunc(block, super_cme);
+                            }
+                            self.push_insn(block, Insn::CCallWithFrame {
+                                cd,
+                                cfunc: cfunc_ptr,
+                                recv,
+                                args: args.clone(),
+                                cme: super_cme,
+                                name,
+                                state,
+                                return_type: types::BasicObject,
+                                elidable: false,
+                                blockiseq: None,
+                            })
+                        };

                        self.make_equal_to(insn_id, ccall);
                     }
@@ -3686,19 +3718,48 @@
                     -1 => {
                         emit_super_call_guards(self, block, super_cme, current_cme, mid, state);

+                        // Try inlining the cfunc into HIR
+                        let tmp_block = self.new_block(u32::MAX);
+                        if let Some(replacement) = (props.inline)(self, tmp_block, recv, &args, state) {
+                            // Copy contents of tmp_block to block
+                            assert_ne!(block, tmp_block);
+                            emit_super_call_guards(self, block, super_cme, current_cme, mid, state);
+                            let insns = std::mem::take(&mut self.blocks[tmp_block.0].insns);
+                            self.blocks[block.0].insns.extend(insns);
+                            self.push_insn(block, Insn::IncrCounter(Counter::inline_cfunc_optimized_send_count));
+                            self.make_equal_to(insn_id, replacement);
+                            if self.type_of(replacement).bit_equal(types::Any) {
+                                // Not set yet; infer type
+                                self.insn_types[replacement.0] = self.infer_type(replacement);
+                            }
+                            self.remove_block(tmp_block);
+                            continue;
+                        }
+
                         // Use CCallVariadic for the variadic C function.
                         let name = rust_str_to_id(&qualified_method_name(unsafe { (*super_cme).owner }, unsafe { (*super_cme).called_id }));
-                        let ccall = self.push_insn(block, Insn::CCallVariadic {
-                            cfunc: cfunc_ptr,
-                            recv,
-                            args: args.clone(),
-                            cme: super_cme,
-                            name,
-                            state,
-                            return_type: types::BasicObject,
-                            elidable: false,
-                            blockiseq: None,
-                        });
+                        let return_type = props.return_type;
+                        let elidable = props.elidable;
+                        // Filter for a leaf and GC free function
+                        let ccall = if props.leaf && props.no_gc {
+                            self.push_insn(block, Insn::IncrCounter(Counter::inline_cfunc_optimized_send_count));
+                            self.push_insn(block, Insn::CCall { cfunc: cfunc_ptr, recv, args, name, return_type, elidable })
+                        } else {
+                            if get_option!(stats) {
+                                self.count_not_inlined_cfunc(block, super_cme);
+                            }
+                            self.push_insn(block, Insn::CCallVariadic {
+                                cfunc: cfunc_ptr,
+                                recv,
+                                args: args.clone(),
+                                cme: super_cme,
+                                name,
+                                state,
+                                return_type: types::BasicObject,
+                                elidable: false,
+                                blockiseq: None,
+                            })
+                        };

                         self.make_equal_to(insn_id, ccall);
                     }
@@ -3981,6 +4042,28 @@ impl Function {
         self.push_insn(block, Insn::PatchPoint { invariant: Invariant::MethodRedefined { klass: recv_class, method: method_id, cme }, state });
     }

+    fn count_not_inlined_cfunc(&mut self, block: BlockId, cme: *const rb_callable_method_entry_t) {
+        let owner = unsafe { (*cme).owner };
+        let called_id = unsafe { (*cme).called_id };
+        let qualified_method_name = qualified_method_name(owner, called_id);
+        let not_inlined_cfunc_counter_pointers = ZJITState::get_not_inlined_cfunc_counter_pointers();
+        let counter_ptr = not_inlined_cfunc_counter_pointers.entry(qualified_method_name.clone()).or_insert_with(|| Box::new(0));
+        let counter_ptr = &mut **counter_ptr as *mut u64;
+
+        self.push_insn(block, Insn::IncrCounterPtr { counter_ptr });
+    }
+
+    fn count_not_annotated_cfunc(&mut self, block: BlockId, cme: *const rb_callable_method_entry_t) {
+        let owner = unsafe { (*cme).owner };
+        let called_id = unsafe { (*cme).called_id };
+        let qualified_method_name = qualified_method_name(owner, called_id);
+        let not_annotated_cfunc_counter_pointers = ZJITState::get_not_annotated_cfunc_counter_pointers();
+        let counter_ptr = not_annotated_cfunc_counter_pointers.entry(qualified_method_name.clone()).or_insert_with(|| Box::new(0));
+        let counter_ptr = &mut **counter_ptr as *mut u64;
+
+        self.push_insn(block, Insn::IncrCounterPtr { counter_ptr });
+    }
+
     /// Optimize Send/SendWithoutBlock that land in a C method to a direct CCall without
     /// runtime lookup.
     fn optimize_c_calls(&mut self) {
@@ -4124,7 +4207,7 @@ impl Function {
                 }

                 if get_option!(stats) {
-                    count_not_inlined_cfunc(fun, block, cme);
+                    fun.count_not_inlined_cfunc(block, cme);
                 }

                 let ccall = fun.push_insn(block, Insn::CCallVariadic {
@@ -4238,7 +4321,7 @@

                 let props = ZJITState::get_method_annotations().get_cfunc_properties(cme);
                 if props.is_none() && get_option!(stats) {
-                    count_not_annotated_cfunc(fun, block, cme);
+                    fun.count_not_annotated_cfunc(block, cme);
                 }
                 let props = props.unwrap_or_default();
@@ -4277,7 +4360,7 @@ impl Function {
                     fun.make_equal_to(send_insn_id, ccall);
                 } else {
                     if get_option!(stats) {
-                        count_not_inlined_cfunc(fun, block, cme);
+                        fun.count_not_inlined_cfunc(block, cme);
                     }
                     let ccall = fun.push_insn(block, Insn::CCallWithFrame {
                         cd,
@@ -4326,7 +4409,7 @@ impl Function {
                 let cfunc = unsafe { get_mct_func(cfunc) }.cast();
                 let props = ZJITState::get_method_annotations().get_cfunc_properties(cme);
                 if props.is_none() && get_option!(stats) {
-                    count_not_annotated_cfunc(fun, block, cme);
+                    fun.count_not_annotated_cfunc(block, cme);
                 }
                 let props = props.unwrap_or_default();

@@ -4349,7 +4432,7 @@ impl Function {
                 // No inlining; emit a call
                 if get_option!(stats) {
-                    count_not_inlined_cfunc(fun, block, cme);
+                    fun.count_not_inlined_cfunc(block, cme);
                 }
                 let return_type = props.return_type;
                 let elidable = props.elidable;

@@ -4383,28 +4466,6 @@ impl Function {
             Err(())
         }

-        fn count_not_inlined_cfunc(fun: &mut Function, block: BlockId, cme: *const rb_callable_method_entry_t) {
-            let owner = unsafe { (*cme).owner };
-            let called_id = unsafe { (*cme).called_id };
-            let qualified_method_name = qualified_method_name(owner, called_id);
-            let not_inlined_cfunc_counter_pointers = ZJITState::get_not_inlined_cfunc_counter_pointers();
-            let counter_ptr = not_inlined_cfunc_counter_pointers.entry(qualified_method_name.clone()).or_insert_with(|| Box::new(0));
-            let counter_ptr = &mut **counter_ptr as *mut u64;
-
-            fun.push_insn(block, Insn::IncrCounterPtr { counter_ptr });
-        }
-
-        fn count_not_annotated_cfunc(fun: &mut Function, block: BlockId, cme: *const rb_callable_method_entry_t) {
-            let owner = unsafe { (*cme).owner };
-            let called_id = unsafe { (*cme).called_id };
-            let qualified_method_name = qualified_method_name(owner, called_id);
-            let not_annotated_cfunc_counter_pointers = ZJITState::get_not_annotated_cfunc_counter_pointers();
-            let counter_ptr = not_annotated_cfunc_counter_pointers.entry(qualified_method_name.clone()).or_insert_with(|| Box::new(0));
-            let counter_ptr = &mut **counter_ptr as *mut u64;
-
-            fun.push_insn(block, Insn::IncrCounterPtr { counter_ptr });
-        }
-
         for block in self.rpo() {
             let old_insns = std::mem::take(&mut self.blocks[block.0].insns);
             assert!(self.blocks[block.0].insns.is_empty());
diff --git a/zjit/src/hir/opt_tests.rs b/zjit/src/hir/opt_tests.rs
index de4e2ec39d..8dec65fed6 100644
--- a/zjit/src/hir/opt_tests.rs
+++ b/zjit/src/hir/opt_tests.rs
@@ -11408,21 +11408,20 @@ mod hir_opt_tests {
     #[test]
     fn test_invokesuper_to_cfunc_optimizes_to_ccall() {
         eval("
-            class MyArray < Array
-              def length
+            class C < Hash
+              def size
                 super
               end
             end

-            MyArray.new.length; MyArray.new.length
+            C.new.size
         ");

-        let hir = hir_string_proc("MyArray.new.method(:length)");
+        let hir = hir_string_proc("C.new.method(:size)");

         assert!(!hir.contains("InvokeSuper "), "Expected unoptimized InvokeSuper but got:\n{hir}");
-        assert!(hir.contains("CCallWithFrame"), "Should optimize to CCallWithFrame for non-variadic cfunc:\n{hir}");
-        assert_snapshot!(hir, @"
-        fn length@<compiled>:4:
+        assert_snapshot!(hir, @r"
+        fn size@<compiled>:4:
         bb0():
           EntryPoint interpreter
           v1:BasicObject = LoadSelf
@@ -11431,12 +11430,46 @@ mod hir_opt_tests {
           EntryPoint JIT(0)
           Jump bb2(v4)
         bb2(v6:BasicObject):
-          PatchPoint MethodRedefined(Array@0x1000, length@0x1008, cme:0x1010)
+          PatchPoint MethodRedefined(Hash@0x1000, size@0x1008, cme:0x1010)
+          v17:CPtr = GetLEP
+          GuardSuperMethodEntry v17, 0x1038
+          v19:RubyValue = GetBlockHandler v17
+          v20:FalseClass = GuardBitEquals v19, Value(false)
+          IncrCounter inline_cfunc_optimized_send_count
+          v22:Fixnum = CCall v6, :Hash#size@0x1040
+          CheckInterrupts
+          Return v22
+        ");
+    }
+
+    #[test]
+    fn test_inline_invokesuper_to_basicobject_initialize() {
+        eval("
+            class C
+              def initialize
+                super
+              end
+            end
+
+            C.new
+        ");
+        assert_snapshot!(hir_string_proc("C.instance_method(:initialize)"), @r"
+        fn initialize@<compiled>:4:
+        bb0():
+          EntryPoint interpreter
+          v1:BasicObject = LoadSelf
+          Jump bb2(v1)
+        bb1(v4:BasicObject):
+          EntryPoint JIT(0)
+          Jump bb2(v4)
+        bb2(v6:BasicObject):
+          PatchPoint MethodRedefined(BasicObject@0x1000, initialize@0x1008, cme:0x1010)
           v17:CPtr = GetLEP
           GuardSuperMethodEntry v17, 0x1038
           v19:RubyValue = GetBlockHandler v17
           v20:FalseClass = GuardBitEquals v19, Value(false)
-          v21:BasicObject = CCallWithFrame v6, :Array#length@0x1040
+          v21:NilClass = Const Value(nil)
+          IncrCounter inline_cfunc_optimized_send_count
           CheckInterrupts
           Return v21
         ");
