summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMax Bernstein <ruby@bernsteinbear.com>2025-10-24 19:07:04 -0700
committerMax Bernstein <tekknolagi@gmail.com>2025-10-28 10:49:30 -0400
commitc2bef01b668174936c0a25358d9d50b38bcf341c (patch)
treeb6c3d53f0638c3265ea2fdf3f3446a57a829f239
parente973baa837a9cc17189ed4e32e43e047f622766b (diff)
ZJIT: Optimize Kernel#===
-rw-r--r--zjit/src/cruby_methods.rs12
-rw-r--r--zjit/src/hir.rs118
2 files changed, 122 insertions, 8 deletions
diff --git a/zjit/src/cruby_methods.rs b/zjit/src/cruby_methods.rs
index e7be5ab445..199b5a64ac 100644
--- a/zjit/src/cruby_methods.rs
+++ b/zjit/src/cruby_methods.rs
@@ -194,6 +194,7 @@ pub fn init() -> Annotations {
annotate!(rb_mKernel, "itself", inline_kernel_itself);
annotate!(rb_mKernel, "block_given?", inline_kernel_block_given_p);
+ annotate!(rb_mKernel, "===", inline_eqq);
annotate!(rb_cString, "bytesize", types::Fixnum, no_gc, leaf, elidable);
annotate!(rb_cString, "size", types::Fixnum, no_gc, leaf, elidable);
annotate!(rb_cString, "length", types::Fixnum, no_gc, leaf, elidable);
@@ -379,6 +380,17 @@ fn inline_nilclass_nil_p(fun: &mut hir::Function, block: hir::BlockId, _recv: hi
Some(fun.push_insn(block, hir::Insn::Const { val: hir::Const::Value(Qtrue) }))
}
+fn inline_eqq(fun: &mut hir::Function, block: hir::BlockId, recv: hir::InsnId, args: &[hir::InsnId], state: hir::InsnId) -> Option<hir::InsnId> {
+ let &[other] = args else { return None; };
+ let recv_class = fun.type_of(recv).runtime_exact_ruby_class()?;
+ if !fun.assume_expected_cfunc(block, recv_class, ID!(eq), rb_obj_equal as _, state) {
+ return None;
+ }
+ let c_result = fun.push_insn(block, hir::Insn::IsBitEqual { left: recv, right: other });
+ let result = fun.push_insn(block, hir::Insn::BoxBool { val: c_result });
+ Some(result)
+}
+
fn inline_kernel_nil_p(fun: &mut hir::Function, block: hir::BlockId, _recv: hir::InsnId, args: &[hir::InsnId], _state: hir::InsnId) -> Option<hir::InsnId> {
if !args.is_empty() { return None; }
Some(fun.push_insn(block, hir::Insn::Const { val: hir::Const::Value(Qfalse) }))
diff --git a/zjit/src/hir.rs b/zjit/src/hir.rs
index 6d0b120a37..c184c2d4b2 100644
--- a/zjit/src/hir.rs
+++ b/zjit/src/hir.rs
@@ -2037,6 +2037,21 @@ impl Function {
None
}
+ pub fn assume_expected_cfunc(&mut self, block: BlockId, class: VALUE, method_id: ID, cfunc: *mut c_void, state: InsnId) -> bool {
+ let cme = unsafe { rb_callable_method_entry(class, method_id) };
+ if cme.is_null() { return false; }
+ let def_type = unsafe { get_cme_def_type(cme) };
+ if def_type != VM_METHOD_TYPE_CFUNC { return false; }
+ if unsafe { get_mct_func(get_cme_def_body_cfunc(cme)) } != cfunc {
+ return false;
+ }
+ self.gen_patch_points_for_optimized_ccall(block, class, method_id, cme, state);
+ if class.instance_can_have_singleton_class() {
+ self.push_insn(block, Insn::PatchPoint { invariant: Invariant::NoSingletonClass { klass: class }, state });
+ }
+ true
+ }
+
pub fn likely_a(&self, val: InsnId, ty: Type, state: InsnId) -> bool {
if self.type_of(val).is_subtype(ty) {
return true;
@@ -2545,6 +2560,11 @@ impl Function {
self.infer_types();
}
+ fn gen_patch_points_for_optimized_ccall(&mut self, block: BlockId, recv_class: VALUE, method_id: ID, method: *const rb_callable_method_entry_struct, state: InsnId) {
+ self.push_insn(block, Insn::PatchPoint { invariant: Invariant::NoTracePoint, state });
+ self.push_insn(block, Insn::PatchPoint { invariant: Invariant::MethodRedefined { klass: recv_class, method: method_id, cme: method }, state });
+ }
+
/// Optimize SendWithoutBlock that land in a C method to a direct CCall without
/// runtime lookup.
fn optimize_c_calls(&mut self) {
@@ -2552,11 +2572,6 @@ impl Function {
return;
}
- fn gen_patch_points_for_optimized_ccall(fun: &mut Function, block: BlockId, recv_class: VALUE, method_id: ID, method: *const rb_callable_method_entry_struct, state: InsnId) {
- fun.push_insn(block, Insn::PatchPoint { invariant: Invariant::NoTracePoint, state });
- fun.push_insn(block, Insn::PatchPoint { invariant: Invariant::MethodRedefined { klass: recv_class, method: method_id, cme: method }, state });
- }
-
// Try to reduce a Send insn to a CCallWithFrame
fn reduce_send_to_ccall(
fun: &mut Function,
@@ -2615,7 +2630,7 @@ impl Function {
}
// Commit to the replacement. Put PatchPoint.
- gen_patch_points_for_optimized_ccall(fun, block, recv_class, method_id, method, state);
+ fun.gen_patch_points_for_optimized_ccall(block, recv_class, method_id, method, state);
if recv_class.instance_can_have_singleton_class() {
fun.push_insn(block, Insn::PatchPoint { invariant: Invariant::NoSingletonClass { klass: recv_class }, state });
}
@@ -2717,7 +2732,7 @@ impl Function {
}
// Commit to the replacement. Put PatchPoint.
- gen_patch_points_for_optimized_ccall(fun, block, recv_class, method_id, method, state);
+ fun.gen_patch_points_for_optimized_ccall(block, recv_class, method_id, method, state);
if recv_class.instance_can_have_singleton_class() {
fun.push_insn(block, Insn::PatchPoint { invariant: Invariant::NoSingletonClass { klass: recv_class }, state });
}
@@ -2784,7 +2799,7 @@ impl Function {
// func(int argc, VALUE *argv, VALUE recv)
let ci_flags = unsafe { vm_ci_flag(call_info) };
if ci_flags & VM_CALL_ARGS_SIMPLE != 0 {
- gen_patch_points_for_optimized_ccall(fun, block, recv_class, method_id, method, state);
+ fun.gen_patch_points_for_optimized_ccall(block, recv_class, method_id, method, state);
if recv_class.instance_can_have_singleton_class() {
fun.push_insn(block, Insn::PatchPoint { invariant: Invariant::NoSingletonClass { klass: recv_class }, state });
@@ -13182,6 +13197,93 @@ mod opt_tests {
}
#[test]
+ fn test_specialize_nil_eqq() {
+ eval("
+ def test(a, b) = a === b
+ test(nil, 5)
+ ");
+ assert_snapshot!(hir_string("test"), @r"
+ fn test@<compiled>:2:
+ bb0():
+ EntryPoint interpreter
+ v1:BasicObject = LoadSelf
+ v2:BasicObject = GetLocal l0, SP@5
+ v3:BasicObject = GetLocal l0, SP@4
+ Jump bb2(v1, v2, v3)
+ bb1(v6:BasicObject, v7:BasicObject, v8:BasicObject):
+ EntryPoint JIT(0)
+ Jump bb2(v6, v7, v8)
+ bb2(v10:BasicObject, v11:BasicObject, v12:BasicObject):
+ PatchPoint MethodRedefined(NilClass@0x1000, ===@0x1008, cme:0x1010)
+ v25:NilClass = GuardType v11, NilClass
+ PatchPoint MethodRedefined(NilClass@0x1000, ==@0x1038, cme:0x1040)
+ v28:CBool = IsBitEqual v25, v12
+ v29:BoolExact = BoxBool v28
+ IncrCounter inline_cfunc_optimized_send_count
+ CheckInterrupts
+ Return v29
+ ");
+ }
+
+ #[test]
+ fn test_specialize_true_eqq() {
+ eval("
+ def test(a, b) = a === b
+ test(true, 5)
+ ");
+ assert_snapshot!(hir_string("test"), @r"
+ fn test@<compiled>:2:
+ bb0():
+ EntryPoint interpreter
+ v1:BasicObject = LoadSelf
+ v2:BasicObject = GetLocal l0, SP@5
+ v3:BasicObject = GetLocal l0, SP@4
+ Jump bb2(v1, v2, v3)
+ bb1(v6:BasicObject, v7:BasicObject, v8:BasicObject):
+ EntryPoint JIT(0)
+ Jump bb2(v6, v7, v8)
+ bb2(v10:BasicObject, v11:BasicObject, v12:BasicObject):
+ PatchPoint MethodRedefined(TrueClass@0x1000, ===@0x1008, cme:0x1010)
+ v25:TrueClass = GuardType v11, TrueClass
+ PatchPoint MethodRedefined(TrueClass@0x1000, ==@0x1038, cme:0x1040)
+ v28:CBool = IsBitEqual v25, v12
+ v29:BoolExact = BoxBool v28
+ IncrCounter inline_cfunc_optimized_send_count
+ CheckInterrupts
+ Return v29
+ ");
+ }
+
+ #[test]
+ fn test_specialize_false_eqq() {
+ eval("
+ def test(a, b) = a === b
+ test(true, 5)
+ ");
+ assert_snapshot!(hir_string("test"), @r"
+ fn test@<compiled>:2:
+ bb0():
+ EntryPoint interpreter
+ v1:BasicObject = LoadSelf
+ v2:BasicObject = GetLocal l0, SP@5
+ v3:BasicObject = GetLocal l0, SP@4
+ Jump bb2(v1, v2, v3)
+ bb1(v6:BasicObject, v7:BasicObject, v8:BasicObject):
+ EntryPoint JIT(0)
+ Jump bb2(v6, v7, v8)
+ bb2(v10:BasicObject, v11:BasicObject, v12:BasicObject):
+ PatchPoint MethodRedefined(TrueClass@0x1000, ===@0x1008, cme:0x1010)
+ v25:TrueClass = GuardType v11, TrueClass
+ PatchPoint MethodRedefined(TrueClass@0x1000, ==@0x1038, cme:0x1040)
+ v28:CBool = IsBitEqual v25, v12
+ v29:BoolExact = BoxBool v28
+ IncrCounter inline_cfunc_optimized_send_count
+ CheckInterrupts
+ Return v29
+ ");
+ }
+
+ #[test]
fn test_guard_fixnum_and_fixnum() {
eval("
def test(x, y) = x & y