diff options
| author | Max Bernstein <ruby@bernsteinbear.com> | 2025-10-24 19:07:04 -0700 |
|---|---|---|
| committer | Max Bernstein <tekknolagi@gmail.com> | 2025-10-28 10:49:30 -0400 |
| commit | c2bef01b668174936c0a25358d9d50b38bcf341c (patch) | |
| tree | b6c3d53f0638c3265ea2fdf3f3446a57a829f239 | |
| parent | e973baa837a9cc17189ed4e32e43e047f622766b (diff) | |
ZJIT: Optimize Kernel#===
| -rw-r--r-- | zjit/src/cruby_methods.rs | 12 | ||||
| -rw-r--r-- | zjit/src/hir.rs | 118 |
2 files changed, 122 insertions, 8 deletions
diff --git a/zjit/src/cruby_methods.rs b/zjit/src/cruby_methods.rs index e7be5ab445..199b5a64ac 100644 --- a/zjit/src/cruby_methods.rs +++ b/zjit/src/cruby_methods.rs @@ -194,6 +194,7 @@ pub fn init() -> Annotations { annotate!(rb_mKernel, "itself", inline_kernel_itself); annotate!(rb_mKernel, "block_given?", inline_kernel_block_given_p); + annotate!(rb_mKernel, "===", inline_eqq); annotate!(rb_cString, "bytesize", types::Fixnum, no_gc, leaf, elidable); annotate!(rb_cString, "size", types::Fixnum, no_gc, leaf, elidable); annotate!(rb_cString, "length", types::Fixnum, no_gc, leaf, elidable); @@ -379,6 +380,17 @@ fn inline_nilclass_nil_p(fun: &mut hir::Function, block: hir::BlockId, _recv: hi Some(fun.push_insn(block, hir::Insn::Const { val: hir::Const::Value(Qtrue) })) } +fn inline_eqq(fun: &mut hir::Function, block: hir::BlockId, recv: hir::InsnId, args: &[hir::InsnId], state: hir::InsnId) -> Option<hir::InsnId> { + let &[other] = args else { return None; }; + let recv_class = fun.type_of(recv).runtime_exact_ruby_class()?; + if !fun.assume_expected_cfunc(block, recv_class, ID!(eq), rb_obj_equal as _, state) { + return None; + } + let c_result = fun.push_insn(block, hir::Insn::IsBitEqual { left: recv, right: other }); + let result = fun.push_insn(block, hir::Insn::BoxBool { val: c_result }); + Some(result) +} + fn inline_kernel_nil_p(fun: &mut hir::Function, block: hir::BlockId, _recv: hir::InsnId, args: &[hir::InsnId], _state: hir::InsnId) -> Option<hir::InsnId> { if !args.is_empty() { return None; } Some(fun.push_insn(block, hir::Insn::Const { val: hir::Const::Value(Qfalse) })) diff --git a/zjit/src/hir.rs b/zjit/src/hir.rs index 6d0b120a37..c184c2d4b2 100644 --- a/zjit/src/hir.rs +++ b/zjit/src/hir.rs @@ -2037,6 +2037,21 @@ impl Function { None } + pub fn assume_expected_cfunc(&mut self, block: BlockId, class: VALUE, method_id: ID, cfunc: *mut c_void, state: InsnId) -> bool { + let cme = unsafe { rb_callable_method_entry(class, method_id) }; + if cme.is_null() { return false; } + let def_type = unsafe { get_cme_def_type(cme) }; + if def_type != VM_METHOD_TYPE_CFUNC { return false; } + if unsafe { get_mct_func(get_cme_def_body_cfunc(cme)) } != cfunc { + return false; + } + self.gen_patch_points_for_optimized_ccall(block, class, method_id, cme, state); + if class.instance_can_have_singleton_class() { + self.push_insn(block, Insn::PatchPoint { invariant: Invariant::NoSingletonClass { klass: class }, state }); + } + true + } + pub fn likely_a(&self, val: InsnId, ty: Type, state: InsnId) -> bool { if self.type_of(val).is_subtype(ty) { return true; @@ -2545,6 +2560,11 @@ impl Function { self.infer_types(); } + fn gen_patch_points_for_optimized_ccall(&mut self, block: BlockId, recv_class: VALUE, method_id: ID, method: *const rb_callable_method_entry_struct, state: InsnId) { + self.push_insn(block, Insn::PatchPoint { invariant: Invariant::NoTracePoint, state }); + self.push_insn(block, Insn::PatchPoint { invariant: Invariant::MethodRedefined { klass: recv_class, method: method_id, cme: method }, state }); + } + /// Optimize SendWithoutBlock that land in a C method to a direct CCall without /// runtime lookup. fn optimize_c_calls(&mut self) { @@ -2552,11 +2572,6 @@ impl Function { return; } - fn gen_patch_points_for_optimized_ccall(fun: &mut Function, block: BlockId, recv_class: VALUE, method_id: ID, method: *const rb_callable_method_entry_struct, state: InsnId) { - fun.push_insn(block, Insn::PatchPoint { invariant: Invariant::NoTracePoint, state }); - fun.push_insn(block, Insn::PatchPoint { invariant: Invariant::MethodRedefined { klass: recv_class, method: method_id, cme: method }, state }); - } - // Try to reduce a Send insn to a CCallWithFrame fn reduce_send_to_ccall( fun: &mut Function, @@ -2615,7 +2630,7 @@ impl Function { } // Commit to the replacement. Put PatchPoint. - gen_patch_points_for_optimized_ccall(fun, block, recv_class, method_id, method, state); + fun.gen_patch_points_for_optimized_ccall(block, recv_class, method_id, method, state); if recv_class.instance_can_have_singleton_class() { fun.push_insn(block, Insn::PatchPoint { invariant: Invariant::NoSingletonClass { klass: recv_class }, state }); } @@ -2717,7 +2732,7 @@ impl Function { } // Commit to the replacement. Put PatchPoint. - gen_patch_points_for_optimized_ccall(fun, block, recv_class, method_id, method, state); + fun.gen_patch_points_for_optimized_ccall(block, recv_class, method_id, method, state); if recv_class.instance_can_have_singleton_class() { fun.push_insn(block, Insn::PatchPoint { invariant: Invariant::NoSingletonClass { klass: recv_class }, state }); } @@ -2784,7 +2799,7 @@ impl Function { // func(int argc, VALUE *argv, VALUE recv) let ci_flags = unsafe { vm_ci_flag(call_info) }; if ci_flags & VM_CALL_ARGS_SIMPLE != 0 { - gen_patch_points_for_optimized_ccall(fun, block, recv_class, method_id, method, state); + fun.gen_patch_points_for_optimized_ccall(block, recv_class, method_id, method, state); if recv_class.instance_can_have_singleton_class() { fun.push_insn(block, Insn::PatchPoint { invariant: Invariant::NoSingletonClass { klass: recv_class }, state }); @@ -13182,6 +13197,93 @@ mod opt_tests { } #[test] + fn test_specialize_nil_eqq() { + eval(" + def test(a, b) = a === b + test(nil, 5) + "); + assert_snapshot!(hir_string("test"), @r" + fn test@<compiled>:2: + bb0(): + EntryPoint interpreter + v1:BasicObject = LoadSelf + v2:BasicObject = GetLocal l0, SP@5 + v3:BasicObject = GetLocal l0, SP@4 + Jump bb2(v1, v2, v3) + bb1(v6:BasicObject, v7:BasicObject, v8:BasicObject): + EntryPoint JIT(0) + Jump bb2(v6, v7, v8) + bb2(v10:BasicObject, v11:BasicObject, v12:BasicObject): + PatchPoint MethodRedefined(NilClass@0x1000, ===@0x1008, cme:0x1010) + v25:NilClass = GuardType v11, NilClass + PatchPoint MethodRedefined(NilClass@0x1000, ==@0x1038, cme:0x1040) + v28:CBool = IsBitEqual v25, v12 + v29:BoolExact = BoxBool v28 + IncrCounter inline_cfunc_optimized_send_count + CheckInterrupts + Return v29 + "); + } + + #[test] + fn test_specialize_true_eqq() { + eval(" + def test(a, b) = a === b + test(true, 5) + "); + assert_snapshot!(hir_string("test"), @r" + fn test@<compiled>:2: + bb0(): + EntryPoint interpreter + v1:BasicObject = LoadSelf + v2:BasicObject = GetLocal l0, SP@5 + v3:BasicObject = GetLocal l0, SP@4 + Jump bb2(v1, v2, v3) + bb1(v6:BasicObject, v7:BasicObject, v8:BasicObject): + EntryPoint JIT(0) + Jump bb2(v6, v7, v8) + bb2(v10:BasicObject, v11:BasicObject, v12:BasicObject): + PatchPoint MethodRedefined(TrueClass@0x1000, ===@0x1008, cme:0x1010) + v25:TrueClass = GuardType v11, TrueClass + PatchPoint MethodRedefined(TrueClass@0x1000, ==@0x1038, cme:0x1040) + v28:CBool = IsBitEqual v25, v12 + v29:BoolExact = BoxBool v28 + IncrCounter inline_cfunc_optimized_send_count + CheckInterrupts + Return v29 + "); + } + + #[test] + fn test_specialize_false_eqq() { + eval(" + def test(a, b) = a === b + test(true, 5) + "); + assert_snapshot!(hir_string("test"), @r" + fn test@<compiled>:2: + bb0(): + EntryPoint interpreter + v1:BasicObject = LoadSelf + v2:BasicObject = GetLocal l0, SP@5 + v3:BasicObject = GetLocal l0, SP@4 + Jump bb2(v1, v2, v3) + bb1(v6:BasicObject, v7:BasicObject, v8:BasicObject): + EntryPoint JIT(0) + Jump bb2(v6, v7, v8) + bb2(v10:BasicObject, v11:BasicObject, v12:BasicObject): + PatchPoint MethodRedefined(TrueClass@0x1000, ===@0x1008, cme:0x1010) + v25:TrueClass = GuardType v11, TrueClass + PatchPoint MethodRedefined(TrueClass@0x1000, ==@0x1038, cme:0x1040) + v28:CBool = IsBitEqual v25, v12 + v29:BoolExact = BoxBool v28 + IncrCounter inline_cfunc_optimized_send_count + CheckInterrupts + Return v29 + "); + } + + #[test] fn test_guard_fixnum_and_fixnum() { eval(" def test(x, y) = x & y |
