summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMax Bernstein <ruby@bernsteinbear.com>2025-09-17 12:46:18 -0400
committerMax Bernstein <tekknolagi@gmail.com>2025-09-17 17:27:35 -0400
commit7a82f1faa0f6157e3e3104d04531f62a8b1db90c (patch)
tree1c06bb1fdf070696587fb8fac46eb52f2b0c0135
parentc31a73d7ea410c74e1c6bc887619898eac3c8795 (diff)
ZJIT: Const-fold IsMethodCfunc
-rw-r--r--test/ruby/test_zjit.rb12
-rw-r--r--vm_insnhelper.c12
-rw-r--r--zjit.c5
-rw-r--r--zjit/bindgen/src/main.rs2
-rw-r--r--zjit/src/codegen.rs2
-rw-r--r--zjit/src/cruby_bindings.inc.rs9
-rw-r--r--zjit/src/hir.rs56
7 files changed, 68 insertions, 30 deletions
diff --git a/test/ruby/test_zjit.rb b/test/ruby/test_zjit.rb
index 42d10490c5..e530fb797f 100644
--- a/test/ruby/test_zjit.rb
+++ b/test/ruby/test_zjit.rb
@@ -851,6 +851,18 @@ class TestZJIT < Test::Unit::TestCase
}, insns: [:opt_new]
end
+ def test_opt_new_invalidate_new
+ assert_compiles '["Foo", "foo"]', %q{
+ class Foo; end
+ def test = Foo.new
+ test; test
+ result = [test.class.name]
+ def Foo.new = "foo"
+ result << test
+ result
+ }, insns: [:opt_new], call_threshold: 2
+ end
+
def test_new_hash_empty
assert_compiles '{}', %q{
def test = {}
diff --git a/vm_insnhelper.c b/vm_insnhelper.c
index 362af31188..8022a29a6e 100644
--- a/vm_insnhelper.c
+++ b/vm_insnhelper.c
@@ -2353,6 +2353,12 @@ vm_search_method(VALUE cd_owner, struct rb_call_data *cd, VALUE recv)
return vm_cc_cme(cc);
}
+const struct rb_callable_method_entry_struct *
+rb_zjit_vm_search_method(VALUE cd_owner, struct rb_call_data *cd, VALUE recv)
+{
+ return vm_search_method(cd_owner, cd, recv);
+}
+
#if __has_attribute(transparent_union)
typedef union {
VALUE (*anyargs)(ANYARGS);
@@ -2417,6 +2423,12 @@ vm_method_cfunc_is(const rb_iseq_t *iseq, CALL_DATA cd, VALUE recv, cfunc_type f
return check_cfunc(cme, func);
}
+bool
+rb_zjit_cme_is_cfunc(const rb_callable_method_entry_t *me, const cfunc_type func)
+{
+ return check_cfunc(me, func);
+}
+
int
rb_vm_method_cfunc_is(const rb_iseq_t *iseq, CALL_DATA cd, VALUE recv, cfunc_type func)
{
diff --git a/zjit.c b/zjit.c
index 6bbe508f24..4b29578b4a 100644
--- a/zjit.c
+++ b/zjit.c
@@ -170,6 +170,11 @@ rb_zjit_local_id(const rb_iseq_t *iseq, unsigned idx)
return ISEQ_BODY(iseq)->local_table[idx];
}
+bool rb_zjit_cme_is_cfunc(const rb_callable_method_entry_t *me, const void *func);
+
+const struct rb_callable_method_entry_struct *
+rb_zjit_vm_search_method(VALUE cd_owner, struct rb_call_data *cd, VALUE recv);
+
// Primitives used by zjit.rb. Don't put other functions below, which wouldn't use them.
VALUE rb_zjit_assert_compiles(rb_execution_context_t *ec, VALUE self);
VALUE rb_zjit_stats(rb_execution_context_t *ec, VALUE self, VALUE target_key);
diff --git a/zjit/bindgen/src/main.rs b/zjit/bindgen/src/main.rs
index c6f02be415..6e9a5a529f 100644
--- a/zjit/bindgen/src/main.rs
+++ b/zjit/bindgen/src/main.rs
@@ -342,6 +342,8 @@ fn main() {
.allowlist_function("rb_get_cfp_ep_level")
.allowlist_function("rb_get_cme_def_type")
.allowlist_function("rb_zjit_constcache_shareable")
+ .allowlist_function("rb_zjit_vm_search_method")
+ .allowlist_function("rb_zjit_cme_is_cfunc")
.allowlist_function("rb_get_cme_def_body_attr_id")
.allowlist_function("rb_get_symbol_id")
.allowlist_function("rb_get_cme_def_body_optimized_type")
diff --git a/zjit/src/codegen.rs b/zjit/src/codegen.rs
index 0f167ceec3..52d1dd315b 100644
--- a/zjit/src/codegen.rs
+++ b/zjit/src/codegen.rs
@@ -389,7 +389,7 @@ fn gen_insn(cb: &mut CodeBlock, jit: &mut JITState, asm: &mut Assembler, functio
Insn::FixnumAnd { left, right } => gen_fixnum_and(asm, opnd!(left), opnd!(right)),
Insn::FixnumOr { left, right } => gen_fixnum_or(asm, opnd!(left), opnd!(right)),
Insn::IsNil { val } => gen_isnil(asm, opnd!(val)),
- &Insn::IsMethodCfunc { val, cd, cfunc } => gen_is_method_cfunc(jit, asm, opnd!(val), cd, cfunc),
+ &Insn::IsMethodCfunc { val, cd, cfunc, state: _ } => gen_is_method_cfunc(jit, asm, opnd!(val), cd, cfunc),
Insn::Test { val } => gen_test(asm, opnd!(val)),
Insn::GuardType { val, guard_type, state } => gen_guard_type(jit, asm, opnd!(val), *guard_type, &function.frame_state(*state)),
Insn::GuardTypeNot { val, guard_type, state } => gen_guard_type_not(jit, asm, opnd!(val), *guard_type, &function.frame_state(*state)),
diff --git a/zjit/src/cruby_bindings.inc.rs b/zjit/src/cruby_bindings.inc.rs
index 4bb1c3dffd..dfa1be9b8f 100644
--- a/zjit/src/cruby_bindings.inc.rs
+++ b/zjit/src/cruby_bindings.inc.rs
@@ -937,6 +937,15 @@ unsafe extern "C" {
pub fn rb_zjit_defined_ivar(obj: VALUE, id: ID, pushval: VALUE) -> VALUE;
pub fn rb_zjit_insn_leaf(insn: ::std::os::raw::c_int, opes: *const VALUE) -> bool;
pub fn rb_zjit_local_id(iseq: *const rb_iseq_t, idx: ::std::os::raw::c_uint) -> ID;
+ pub fn rb_zjit_cme_is_cfunc(
+ me: *const rb_callable_method_entry_t,
+ func: *const ::std::os::raw::c_void,
+ ) -> bool;
+ pub fn rb_zjit_vm_search_method(
+ cd_owner: VALUE,
+ cd: *mut rb_call_data,
+ recv: VALUE,
+ ) -> *const rb_callable_method_entry_struct;
pub fn rb_iseq_encoded_size(iseq: *const rb_iseq_t) -> ::std::os::raw::c_uint;
pub fn rb_iseq_pc_at_idx(iseq: *const rb_iseq_t, insn_idx: u32) -> *mut VALUE;
pub fn rb_iseq_opcode_at_pc(iseq: *const rb_iseq_t, pc: *const VALUE) -> ::std::os::raw::c_int;
diff --git a/zjit/src/hir.rs b/zjit/src/hir.rs
index 114cfda549..7dcf1c6ba8 100644
--- a/zjit/src/hir.rs
+++ b/zjit/src/hir.rs
@@ -576,7 +576,7 @@ pub enum Insn {
/// Return C `true` if `val` is `Qnil`, else `false`.
IsNil { val: InsnId },
/// Return C `true` if `val`'s method on cd resolves to the cfunc.
- IsMethodCfunc { val: InsnId, cd: *const rb_call_data, cfunc: *const u8 },
+ IsMethodCfunc { val: InsnId, cd: *const rb_call_data, cfunc: *const u8, state: InsnId },
Defined { op_type: usize, obj: VALUE, pushval: VALUE, v: InsnId, state: InsnId },
GetConstantPath { ic: *const iseq_inline_constant_cache, state: InsnId },
@@ -1350,7 +1350,7 @@ impl Function {
&ToRegexp { opt, ref values, state } => ToRegexp { opt, values: find_vec!(values), state },
&Test { val } => Test { val: find!(val) },
&IsNil { val } => IsNil { val: find!(val) },
- &IsMethodCfunc { val, cd, cfunc } => IsMethodCfunc { val: find!(val), cd, cfunc },
+ &IsMethodCfunc { val, cd, cfunc, state } => IsMethodCfunc { val: find!(val), cd, cfunc, state },
Jump(target) => Jump(find_branch_edge!(target)),
&IfTrue { val, ref target } => IfTrue { val: find!(val), target: find_branch_edge!(target) },
&IfFalse { val, ref target } => IfFalse { val: find!(val), target: find_branch_edge!(target) },
@@ -1906,6 +1906,16 @@ impl Function {
self.push_insn_id(block, insn_id);
}
}
+ Insn::IsMethodCfunc { val, cd, cfunc, state } if self.type_of(val).ruby_object_known() => {
+ let class = self.type_of(val).ruby_object().unwrap();
+ let cme = unsafe { rb_zjit_vm_search_method(self.iseq.into(), cd as *mut rb_call_data, class) };
+ let is_expected_cfunc = unsafe { rb_zjit_cme_is_cfunc(cme, cfunc as *const c_void) };
+ let method = unsafe { rb_vm_ci_mid((*cd).ci) };
+ self.push_insn(block, Insn::PatchPoint { invariant: Invariant::MethodRedefined { klass: class, method, cme }, state });
+ let replacement = self.push_insn(block, Insn::Const { val: Const::CBool(is_expected_cfunc) });
+ self.insn_types[replacement.0] = self.infer_type(replacement);
+ self.make_equal_to(insn_id, replacement);
+ }
Insn::ObjectAlloc { val, state } => {
let val_type = self.type_of(val);
if val_type.is_subtype(types::Class) && val_type.ruby_object_known() {
@@ -2295,8 +2305,7 @@ impl Function {
| &Insn::Return { val }
| &Insn::Test { val }
| &Insn::SetLocal { val, .. }
- | &Insn::IsNil { val }
- | &Insn::IsMethodCfunc { val, .. } =>
+ | &Insn::IsNil { val } =>
worklist.push_back(val),
&Insn::SetGlobal { val, state, .. }
| &Insn::Defined { v: val, state, .. }
@@ -2308,6 +2317,7 @@ impl Function {
| &Insn::GuardBitEquals { val, state, .. }
| &Insn::GuardShape { val, state, .. }
| &Insn::ToArray { val, state }
+ | &Insn::IsMethodCfunc { val, state, .. }
| &Insn::ToNewArray { val, state } => {
worklist.push_back(val);
worklist.push_back(state);
@@ -3450,7 +3460,8 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> {
// TODO: Guard on a profiled class and add a patch point for #new redefinition
let argc = unsafe { vm_ci_argc((*cd).ci) } as usize;
let val = state.stack_topn(argc)?;
- let test_id = fun.push_insn(block, Insn::IsMethodCfunc { val, cd, cfunc: rb_class_new_instance_pass_kw as *const u8 });
+ let exit_id = fun.push_insn(block, Insn::Snapshot { state: exit_state });
+ let test_id = fun.push_insn(block, Insn::IsMethodCfunc { val, cd, cfunc: rb_class_new_instance_pass_kw as *const u8, state: exit_id });
// Jump to the fallback block if it's not the expected function.
// Skip CheckInterrupts since the #new call will do it very soon anyway.
@@ -3463,7 +3474,6 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> {
queue.push_back((state.clone(), target, target_idx, local_inval));
// Move on to the fast path
- let exit_id = fun.push_insn(block, Insn::Snapshot { state: exit_state });
let insn_id = fun.push_insn(block, Insn::ObjectAlloc { val, state: exit_id });
state.stack_setn(argc, insn_id);
state.stack_setn(argc + 1, insn_id);
@@ -5400,8 +5410,8 @@ mod tests {
bb0(v0:BasicObject):
v5:BasicObject = GetConstantPath 0x1000
v6:NilClass = Const Value(nil)
- v7:CBool = IsMethodCFunc v5, :new
- IfFalse v7, bb1(v0, v6, v5)
+ v8:CBool = IsMethodCFunc v5, :new
+ IfFalse v8, bb1(v0, v6, v5)
v10:HeapObject = ObjectAlloc v5
v12:BasicObject = SendWithoutBlock v10, :initialize
CheckInterrupts
@@ -8092,18 +8102,12 @@ mod opt_tests {
PatchPoint StableConstantNames(0x1000, C)
v34:Class[VALUE(0x1008)] = Const Value(VALUE(0x1008))
v6:NilClass = Const Value(nil)
- v7:CBool = IsMethodCFunc v34, :new
- IfFalse v7, bb1(v0, v6, v34)
- v35:HeapObject[class_exact:C] = ObjectAllocClass VALUE(0x1008)
- v12:BasicObject = SendWithoutBlock v35, :initialize
+ PatchPoint MethodRedefined(C@0x1008, new@0x1010, cme:0x1018)
+ v37:HeapObject[class_exact:C] = ObjectAllocClass VALUE(0x1008)
+ v12:BasicObject = SendWithoutBlock v37, :initialize
CheckInterrupts
- Jump bb2(v0, v35, v12)
- bb1(v16:BasicObject, v17:NilClass, v18:Class[VALUE(0x1008)]):
- v21:BasicObject = SendWithoutBlock v18, :new
- Jump bb2(v16, v21, v17)
- bb2(v23:BasicObject, v24:BasicObject, v25:BasicObject):
CheckInterrupts
- Return v24
+ Return v37
");
}
@@ -8126,19 +8130,13 @@ mod opt_tests {
v36:Class[VALUE(0x1008)] = Const Value(VALUE(0x1008))
v6:NilClass = Const Value(nil)
v7:Fixnum[1] = Const Value(1)
- v8:CBool = IsMethodCFunc v36, :new
- IfFalse v8, bb1(v0, v6, v36, v7)
- v37:HeapObject[class_exact:C] = ObjectAllocClass VALUE(0x1008)
- PatchPoint MethodRedefined(C@0x1008, initialize@0x1010, cme:0x1018)
- v39:BasicObject = SendWithoutBlockDirect v37, :initialize (0x1040), v7
+ PatchPoint MethodRedefined(C@0x1008, new@0x1010, cme:0x1018)
+ v39:HeapObject[class_exact:C] = ObjectAllocClass VALUE(0x1008)
+ PatchPoint MethodRedefined(C@0x1008, initialize@0x1040, cme:0x1048)
+ v41:BasicObject = SendWithoutBlockDirect v39, :initialize (0x1070), v7
CheckInterrupts
- Jump bb2(v0, v37, v39)
- bb1(v17:BasicObject, v18:NilClass, v19:Class[VALUE(0x1008)], v20:Fixnum[1]):
- v23:BasicObject = SendWithoutBlock v19, :new, v20
- Jump bb2(v17, v23, v18)
- bb2(v25:BasicObject, v26:BasicObject, v27:BasicObject):
CheckInterrupts
- Return v26
+ Return v39
");
}