summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMax Bernstein <rubybugs@bernsteinbear.com>2026-01-30 12:29:15 -0500
committerGitHub <noreply@github.com>2026-01-30 12:29:15 -0500
commit1298f9ac1ad390594815bc4d3739eb312bca8887 (patch)
treef14579f22262b032ef58a334a11856b5ef03414d
parent9be01bc70dca0e727fe1f518ebae1f6f72405b84 (diff)
ZJIT: Support CFunc inlining in InvokeSuper (#16004)
Also generally make the CFunc process look more like `optimize_c_calls`.
-rw-r--r--zjit/src/codegen.rs20
-rw-r--r--zjit/src/hir.rs179
-rw-r--r--zjit/src/hir/opt_tests.rs51
3 files changed, 182 insertions, 68 deletions
diff --git a/zjit/src/codegen.rs b/zjit/src/codegen.rs
index 2038be808d..a3068ff23d 100644
--- a/zjit/src/codegen.rs
+++ b/zjit/src/codegen.rs
@@ -2330,6 +2330,26 @@ fn gen_guard_type(jit: &mut JITState, asm: &mut Assembler, val: lir::Opnd, guard
let tag = asm.and(flags, Opnd::UImm(RUBY_T_MASK as u64));
asm.cmp(tag, Opnd::UImm(RUBY_T_STRING as u64));
asm.jne(side);
+ } else if guard_type.is_subtype(types::Array) {
+ let side = side_exit(jit, state, GuardType(guard_type));
+
+ // Check special constant
+ asm.test(val, Opnd::UImm(RUBY_IMMEDIATE_MASK as u64));
+ asm.jnz(side.clone());
+
+ // Check false
+ asm.cmp(val, Qfalse.into());
+ asm.je(side.clone());
+
+ let val = match val {
+ Opnd::Reg(_) | Opnd::VReg { .. } => val,
+ _ => asm.load(val),
+ };
+
+ let flags = asm.load(Opnd::mem(VALUE_BITS, val, RUBY_OFFSET_RBASIC_FLAGS));
+ let tag = asm.and(flags, Opnd::UImm(RUBY_T_MASK as u64));
+ asm.cmp(tag, Opnd::UImm(RUBY_T_ARRAY as u64));
+ asm.jne(side);
} else if guard_type.bit_equal(types::HeapBasicObject) {
let side_exit = side_exit(jit, state, GuardType(guard_type));
asm.cmp(val, Opnd::Value(Qfalse));
diff --git a/zjit/src/hir.rs b/zjit/src/hir.rs
index 901beffea0..9aa70b5d34 100644
--- a/zjit/src/hir.rs
+++ b/zjit/src/hir.rs
@@ -3564,13 +3564,7 @@ impl Function {
assert!(flags & VM_CALL_FCALL != 0);
// Reject calls with complex argument handling.
- let complex_arg_types = VM_CALL_ARGS_SPLAT
- | VM_CALL_KW_SPLAT
- | VM_CALL_KWARG
- | VM_CALL_ARGS_BLOCKARG
- | VM_CALL_FORWARDING;
-
- if (flags & complex_arg_types) != 0 {
+ if unspecializable_c_call_type(flags) {
self.push_insn_id(block, insn_id);
self.set_dynamic_send_reason(insn_id, SuperComplexArgsPass);
continue;
@@ -3608,14 +3602,18 @@ impl Function {
}
// Look up the super method.
- let super_cme = unsafe { rb_callable_method_entry(superclass, mid) };
+ let mut super_cme = unsafe { rb_callable_method_entry(superclass, mid) };
if super_cme.is_null() {
self.push_insn_id(block, insn_id);
self.set_dynamic_send_reason(insn_id, SuperTargetNotFound);
continue;
}
- let def_type = unsafe { get_cme_def_type(super_cme) };
+ let mut def_type = unsafe { get_cme_def_type(super_cme) };
+ while def_type == VM_METHOD_TYPE_ALIAS {
+ super_cme = unsafe { rb_aliased_callable_method_entry(super_cme) };
+ def_type = unsafe { get_cme_def_type(super_cme) };
+ }
if def_type == VM_METHOD_TYPE_ISEQ {
// Check if the super method's parameters support direct send.
@@ -3653,6 +3651,12 @@ impl Function {
let cfunc_argc = unsafe { get_mct_argc(cfunc) };
let cfunc_ptr = unsafe { get_mct_func(cfunc) }.cast();
+ let props = ZJITState::get_method_annotations().get_cfunc_properties(super_cme);
+ if props.is_none() && get_option!(stats) {
+ self.count_not_annotated_cfunc(block, super_cme);
+ }
+ let props = props.unwrap_or_default();
+
match cfunc_argc {
// C function with fixed argument count.
0.. => {
@@ -3665,20 +3669,48 @@ impl Function {
emit_super_call_guards(self, block, super_cme, current_cme, mid, state);
+ // Try inlining the cfunc into HIR
+ let tmp_block = self.new_block(u32::MAX);
+ if let Some(replacement) = (props.inline)(self, tmp_block, recv, &args, state) {
+ // Copy contents of tmp_block to block
+ assert_ne!(block, tmp_block);
+ let insns = std::mem::take(&mut self.blocks[tmp_block.0].insns);
+ self.blocks[block.0].insns.extend(insns);
+ self.push_insn(block, Insn::IncrCounter(Counter::inline_cfunc_optimized_send_count));
+ self.make_equal_to(insn_id, replacement);
+ if self.type_of(replacement).bit_equal(types::Any) {
+ // Not set yet; infer type
+ self.insn_types[replacement.0] = self.infer_type(replacement);
+ }
+ self.remove_block(tmp_block);
+ continue;
+ }
+
// Use CCallWithFrame for the C function.
let name = rust_str_to_id(&qualified_method_name(unsafe { (*super_cme).owner }, unsafe { (*super_cme).called_id }));
- let ccall = self.push_insn(block, Insn::CCallWithFrame {
- cd,
- cfunc: cfunc_ptr,
- recv,
- args: args.clone(),
- cme: super_cme,
- name,
- state,
- return_type: types::BasicObject,
- elidable: false,
- blockiseq: None,
- });
+ let return_type = props.return_type;
+ let elidable = props.elidable;
+ // Filter for a leaf and GC free function
+ let ccall = if props.leaf && props.no_gc {
+ self.push_insn(block, Insn::IncrCounter(Counter::inline_cfunc_optimized_send_count));
+ self.push_insn(block, Insn::CCall { cfunc: cfunc_ptr, recv, args, name, return_type, elidable })
+ } else {
+ if get_option!(stats) {
+ self.count_not_inlined_cfunc(block, super_cme);
+ }
+ self.push_insn(block, Insn::CCallWithFrame {
+ cd,
+ cfunc: cfunc_ptr,
+ recv,
+ args: args.clone(),
+ cme: super_cme,
+ name,
+ state,
+ return_type: types::BasicObject,
+ elidable: false,
+ blockiseq: None,
+ })
+ };
self.make_equal_to(insn_id, ccall);
}
@@ -3686,19 +3718,48 @@ impl Function {
-1 => {
emit_super_call_guards(self, block, super_cme, current_cme, mid, state);
+ // Try inlining the cfunc into HIR
+ let tmp_block = self.new_block(u32::MAX);
+ if let Some(replacement) = (props.inline)(self, tmp_block, recv, &args, state) {
+ // Copy contents of tmp_block to block
+ assert_ne!(block, tmp_block);
+ emit_super_call_guards(self, block, super_cme, current_cme, mid, state);
+ let insns = std::mem::take(&mut self.blocks[tmp_block.0].insns);
+ self.blocks[block.0].insns.extend(insns);
+ self.push_insn(block, Insn::IncrCounter(Counter::inline_cfunc_optimized_send_count));
+ self.make_equal_to(insn_id, replacement);
+ if self.type_of(replacement).bit_equal(types::Any) {
+ // Not set yet; infer type
+ self.insn_types[replacement.0] = self.infer_type(replacement);
+ }
+ self.remove_block(tmp_block);
+ continue;
+ }
+
// Use CCallVariadic for the variadic C function.
let name = rust_str_to_id(&qualified_method_name(unsafe { (*super_cme).owner }, unsafe { (*super_cme).called_id }));
- let ccall = self.push_insn(block, Insn::CCallVariadic {
- cfunc: cfunc_ptr,
- recv,
- args: args.clone(),
- cme: super_cme,
- name,
- state,
- return_type: types::BasicObject,
- elidable: false,
- blockiseq: None,
- });
+ let return_type = props.return_type;
+ let elidable = props.elidable;
+ // Filter for a leaf and GC free function
+ let ccall = if props.leaf && props.no_gc {
+ self.push_insn(block, Insn::IncrCounter(Counter::inline_cfunc_optimized_send_count));
+ self.push_insn(block, Insn::CCall { cfunc: cfunc_ptr, recv, args, name, return_type, elidable })
+ } else {
+ if get_option!(stats) {
+ self.count_not_inlined_cfunc(block, super_cme);
+ }
+ self.push_insn(block, Insn::CCallVariadic {
+ cfunc: cfunc_ptr,
+ recv,
+ args: args.clone(),
+ cme: super_cme,
+ name,
+ state,
+ return_type: types::BasicObject,
+ elidable: false,
+ blockiseq: None,
+ })
+ };
self.make_equal_to(insn_id, ccall);
}
@@ -3981,6 +4042,28 @@ impl Function {
self.push_insn(block, Insn::PatchPoint { invariant: Invariant::MethodRedefined { klass: recv_class, method: method_id, cme }, state });
}
+ fn count_not_inlined_cfunc(&mut self, block: BlockId, cme: *const rb_callable_method_entry_t) {
+ let owner = unsafe { (*cme).owner };
+ let called_id = unsafe { (*cme).called_id };
+ let qualified_method_name = qualified_method_name(owner, called_id);
+ let not_inlined_cfunc_counter_pointers = ZJITState::get_not_inlined_cfunc_counter_pointers();
+ let counter_ptr = not_inlined_cfunc_counter_pointers.entry(qualified_method_name.clone()).or_insert_with(|| Box::new(0));
+ let counter_ptr = &mut **counter_ptr as *mut u64;
+
+ self.push_insn(block, Insn::IncrCounterPtr { counter_ptr });
+ }
+
+ fn count_not_annotated_cfunc(&mut self, block: BlockId, cme: *const rb_callable_method_entry_t) {
+ let owner = unsafe { (*cme).owner };
+ let called_id = unsafe { (*cme).called_id };
+ let qualified_method_name = qualified_method_name(owner, called_id);
+ let not_annotated_cfunc_counter_pointers = ZJITState::get_not_annotated_cfunc_counter_pointers();
+ let counter_ptr = not_annotated_cfunc_counter_pointers.entry(qualified_method_name.clone()).or_insert_with(|| Box::new(0));
+ let counter_ptr = &mut **counter_ptr as *mut u64;
+
+ self.push_insn(block, Insn::IncrCounterPtr { counter_ptr });
+ }
+
/// Optimize Send/SendWithoutBlock that land in a C method to a direct CCall without
/// runtime lookup.
fn optimize_c_calls(&mut self) {
@@ -4124,7 +4207,7 @@ impl Function {
}
if get_option!(stats) {
- count_not_inlined_cfunc(fun, block, cme);
+ fun.count_not_inlined_cfunc(block, cme);
}
let ccall = fun.push_insn(block, Insn::CCallVariadic {
@@ -4238,7 +4321,7 @@ impl Function {
let props = ZJITState::get_method_annotations().get_cfunc_properties(cme);
if props.is_none() && get_option!(stats) {
- count_not_annotated_cfunc(fun, block, cme);
+ fun.count_not_annotated_cfunc(block, cme);
}
let props = props.unwrap_or_default();
@@ -4277,7 +4360,7 @@ impl Function {
fun.make_equal_to(send_insn_id, ccall);
} else {
if get_option!(stats) {
- count_not_inlined_cfunc(fun, block, cme);
+ fun.count_not_inlined_cfunc(block, cme);
}
let ccall = fun.push_insn(block, Insn::CCallWithFrame {
cd,
@@ -4326,7 +4409,7 @@ impl Function {
let cfunc = unsafe { get_mct_func(cfunc) }.cast();
let props = ZJITState::get_method_annotations().get_cfunc_properties(cme);
if props.is_none() && get_option!(stats) {
- count_not_annotated_cfunc(fun, block, cme);
+ fun.count_not_annotated_cfunc(block, cme);
}
let props = props.unwrap_or_default();
@@ -4349,7 +4432,7 @@ impl Function {
// No inlining; emit a call
if get_option!(stats) {
- count_not_inlined_cfunc(fun, block, cme);
+ fun.count_not_inlined_cfunc(block, cme);
}
let return_type = props.return_type;
let elidable = props.elidable;
@@ -4383,28 +4466,6 @@ impl Function {
Err(())
}
- fn count_not_inlined_cfunc(fun: &mut Function, block: BlockId, cme: *const rb_callable_method_entry_t) {
- let owner = unsafe { (*cme).owner };
- let called_id = unsafe { (*cme).called_id };
- let qualified_method_name = qualified_method_name(owner, called_id);
- let not_inlined_cfunc_counter_pointers = ZJITState::get_not_inlined_cfunc_counter_pointers();
- let counter_ptr = not_inlined_cfunc_counter_pointers.entry(qualified_method_name.clone()).or_insert_with(|| Box::new(0));
- let counter_ptr = &mut **counter_ptr as *mut u64;
-
- fun.push_insn(block, Insn::IncrCounterPtr { counter_ptr });
- }
-
- fn count_not_annotated_cfunc(fun: &mut Function, block: BlockId, cme: *const rb_callable_method_entry_t) {
- let owner = unsafe { (*cme).owner };
- let called_id = unsafe { (*cme).called_id };
- let qualified_method_name = qualified_method_name(owner, called_id);
- let not_annotated_cfunc_counter_pointers = ZJITState::get_not_annotated_cfunc_counter_pointers();
- let counter_ptr = not_annotated_cfunc_counter_pointers.entry(qualified_method_name.clone()).or_insert_with(|| Box::new(0));
- let counter_ptr = &mut **counter_ptr as *mut u64;
-
- fun.push_insn(block, Insn::IncrCounterPtr { counter_ptr });
- }
-
for block in self.rpo() {
let old_insns = std::mem::take(&mut self.blocks[block.0].insns);
assert!(self.blocks[block.0].insns.is_empty());
diff --git a/zjit/src/hir/opt_tests.rs b/zjit/src/hir/opt_tests.rs
index de4e2ec39d..8dec65fed6 100644
--- a/zjit/src/hir/opt_tests.rs
+++ b/zjit/src/hir/opt_tests.rs
@@ -11408,21 +11408,20 @@ mod hir_opt_tests {
#[test]
fn test_invokesuper_to_cfunc_optimizes_to_ccall() {
eval("
- class MyArray < Array
- def length
+ class C < Hash
+ def size
super
end
end
- MyArray.new.length; MyArray.new.length
+ C.new.size
");
- let hir = hir_string_proc("MyArray.new.method(:length)");
+ let hir = hir_string_proc("C.new.method(:size)");
assert!(!hir.contains("InvokeSuper "), "Expected unoptimized InvokeSuper but got:\n{hir}");
- assert!(hir.contains("CCallWithFrame"), "Should optimize to CCallWithFrame for non-variadic cfunc:\n{hir}");
- assert_snapshot!(hir, @"
- fn length@<compiled>:4:
+ assert_snapshot!(hir, @r"
+ fn size@<compiled>:4:
bb0():
EntryPoint interpreter
v1:BasicObject = LoadSelf
@@ -11431,12 +11430,46 @@ mod hir_opt_tests {
EntryPoint JIT(0)
Jump bb2(v4)
bb2(v6:BasicObject):
- PatchPoint MethodRedefined(Array@0x1000, length@0x1008, cme:0x1010)
+ PatchPoint MethodRedefined(Hash@0x1000, size@0x1008, cme:0x1010)
+ v17:CPtr = GetLEP
+ GuardSuperMethodEntry v17, 0x1038
+ v19:RubyValue = GetBlockHandler v17
+ v20:FalseClass = GuardBitEquals v19, Value(false)
+ IncrCounter inline_cfunc_optimized_send_count
+ v22:Fixnum = CCall v6, :Hash#size@0x1040
+ CheckInterrupts
+ Return v22
+ ");
+ }
+
+ #[test]
+ fn test_inline_invokesuper_to_basicobject_initialize() {
+ eval("
+ class C
+ def initialize
+ super
+ end
+ end
+
+ C.new
+ ");
+ assert_snapshot!(hir_string_proc("C.instance_method(:initialize)"), @r"
+ fn initialize@<compiled>:4:
+ bb0():
+ EntryPoint interpreter
+ v1:BasicObject = LoadSelf
+ Jump bb2(v1)
+ bb1(v4:BasicObject):
+ EntryPoint JIT(0)
+ Jump bb2(v4)
+ bb2(v6:BasicObject):
+ PatchPoint MethodRedefined(BasicObject@0x1000, initialize@0x1008, cme:0x1010)
v17:CPtr = GetLEP
GuardSuperMethodEntry v17, 0x1038
v19:RubyValue = GetBlockHandler v17
v20:FalseClass = GuardBitEquals v19, Value(false)
- v21:BasicObject = CCallWithFrame v6, :Array#length@0x1040
+ v21:NilClass = Const Value(nil)
+ IncrCounter inline_cfunc_optimized_send_count
CheckInterrupts
Return v21
");