diff options
| author | Max Bernstein <ruby@bernsteinbear.com> | 2025-11-19 16:37:45 -0500 |
|---|---|---|
| committer | Max Bernstein <tekknolagi@gmail.com> | 2025-11-21 09:21:57 -0800 |
| commit | ff89e470e21e9d021c6739d83eddda4bd8c071fe (patch) | |
| tree | c3c89868d1ca88363c5e555e630b6ce15fc43155 | |
| parent | e0bb3fb1cda2238d0c98afcdec2fe282c29994aa (diff) | |
ZJIT: Specialize Module#=== and Kernel#is_a? into IsA
| -rw-r--r-- | zjit/src/codegen.rs | 5 | ||||
| -rw-r--r-- | zjit/src/cruby_methods.rs | 21 | ||||
| -rw-r--r-- | zjit/src/hir.rs | 15 | ||||
| -rw-r--r-- | zjit/src/hir/opt_tests.rs | 185 |
4 files changed, 225 insertions, 1 deletions
diff --git a/zjit/src/codegen.rs b/zjit/src/codegen.rs index 9f74838c11..4c865dcd8a 100644 --- a/zjit/src/codegen.rs +++ b/zjit/src/codegen.rs @@ -472,6 +472,7 @@ fn gen_insn(cb: &mut CodeBlock, jit: &mut JITState, asm: &mut Assembler, functio Insn::ArrayInclude { elements, target, state } => gen_array_include(jit, asm, opnds!(elements), opnd!(target), &function.frame_state(*state)), &Insn::DupArrayInclude { ary, target, state } => gen_dup_array_include(jit, asm, ary, opnd!(target), &function.frame_state(state)), Insn::ArrayHash { elements, state } => gen_opt_newarray_hash(jit, asm, opnds!(elements), &function.frame_state(*state)), + &Insn::IsA { val, class } => gen_is_a(asm, opnd!(val), opnd!(class)), &Insn::ArrayMax { state, .. } | &Insn::FixnumDiv { state, .. } | &Insn::Throw { state, .. } @@ -1520,6 +1521,10 @@ fn gen_dup_array_include( ) } +fn gen_is_a(asm: &mut Assembler, obj: Opnd, class: Opnd) -> lir::Opnd { + asm_ccall!(asm, rb_obj_is_kind_of, obj, class) +} + /// Compile a new hash instruction fn gen_new_hash( jit: &mut JITState, diff --git a/zjit/src/cruby_methods.rs b/zjit/src/cruby_methods.rs index 3999ef0a10..6d09b5e5a7 100644 --- a/zjit/src/cruby_methods.rs +++ b/zjit/src/cruby_methods.rs @@ -197,6 +197,7 @@ pub fn init() -> Annotations { annotate!(rb_mKernel, "itself", inline_kernel_itself); annotate!(rb_mKernel, "block_given?", inline_kernel_block_given_p); annotate!(rb_mKernel, "===", inline_eqq); + annotate!(rb_mKernel, "is_a?", inline_kernel_is_a_p); annotate!(rb_cString, "bytesize", inline_string_bytesize); annotate!(rb_cString, "size", types::Fixnum, no_gc, leaf, elidable); annotate!(rb_cString, "length", types::Fixnum, no_gc, leaf, elidable); @@ -206,7 +207,7 @@ pub fn init() -> Annotations { annotate!(rb_cString, "<<", inline_string_append); annotate!(rb_cString, "==", inline_string_eq); annotate!(rb_cModule, "name", types::StringExact.union(types::NilClass), no_gc, leaf, elidable); - annotate!(rb_cModule, "===", types::BoolExact, no_gc, leaf); + annotate!(rb_cModule, "===", inline_module_eqq, types::BoolExact, no_gc, leaf); annotate!(rb_cArray, "length", types::Fixnum, no_gc, leaf, elidable); annotate!(rb_cArray, "size", types::Fixnum, no_gc, leaf, elidable); annotate!(rb_cArray, "empty?", types::BoolExact, no_gc, leaf, elidable); @@ -447,6 +448,15 @@ fn inline_string_eq(fun: &mut hir::Function, block: hir::BlockId, recv: hir::Ins None } +fn inline_module_eqq(fun: &mut hir::Function, block: hir::BlockId, recv: hir::InsnId, args: &[hir::InsnId], _state: hir::InsnId) -> Option<hir::InsnId> { + let &[other] = args else { return None; }; + if fun.is_a(recv, types::Class) { + let result = fun.push_insn(block, hir::Insn::IsA { val: other, class: recv }); + return Some(result); + } + None +} + fn inline_integer_succ(fun: &mut hir::Function, block: hir::BlockId, recv: hir::InsnId, args: &[hir::InsnId], state: hir::InsnId) -> Option<hir::InsnId> { if !args.is_empty() { return None; } if fun.likely_a(recv, types::Fixnum, state) { @@ -613,6 +623,15 @@ fn inline_eqq(fun: &mut hir::Function, block: hir::BlockId, recv: hir::InsnId, a Some(result) } +fn inline_kernel_is_a_p(fun: &mut hir::Function, block: hir::BlockId, recv: hir::InsnId, args: &[hir::InsnId], _state: hir::InsnId) -> Option<hir::InsnId> { + let &[other] = args else { return None; }; + if fun.is_a(other, types::Class) { + let result = fun.push_insn(block, hir::Insn::IsA { val: recv, class: other }); + return Some(result); + } + None +} + fn inline_kernel_nil_p(fun: &mut hir::Function, block: hir::BlockId, _recv: hir::InsnId, args: &[hir::InsnId], _state: hir::InsnId) -> Option<hir::InsnId> { if !args.is_empty() { return None; } Some(fun.push_insn(block, hir::Insn::Const { val: hir::Const::Value(Qfalse) })) diff --git a/zjit/src/hir.rs b/zjit/src/hir.rs index 00014c5758..3a69dd6610 100644 --- a/zjit/src/hir.rs +++ b/zjit/src/hir.rs @@ -721,6 +721,9 @@ pub enum Insn { /// Test the bit at index of val, a Fixnum. /// Return Qtrue if the bit is set, else Qfalse. FixnumBitCheck { val: InsnId, index: u8 }, + /// Return Qtrue if `val` is an instance of `class`, else Qfalse. + /// Equivalent to `class_search_ancestor(CLASS_OF(val), class)`. + IsA { val: InsnId, class: InsnId }, /// Get a global variable named `id` GetGlobal { id: ID, state: InsnId }, @@ -1000,6 +1003,7 @@ impl Insn { Insn::BoxFixnum { .. } => false, Insn::BoxBool { .. } => false, Insn::IsBitEqual { .. } => false, + Insn::IsA { .. } => false, _ => true, } } @@ -1324,6 +1328,7 @@ impl<'a> std::fmt::Display for InsnPrinter<'a> { } Insn::IncrCounter(counter) => write!(f, "IncrCounter {counter:?}"), Insn::CheckInterrupts { .. } => write!(f, "CheckInterrupts"), + Insn::IsA { val, class } => write!(f, "IsA {val}, {class}"), } } } @@ -1946,6 +1951,7 @@ impl Function { &ArrayExtend { left, right, state } => ArrayExtend { left: find!(left), right: find!(right), state }, &ArrayPush { array, val, state } => ArrayPush { array: find!(array), val: find!(val), state }, &CheckInterrupts { state } => CheckInterrupts { state }, + &IsA { val, class } => IsA { val: find!(val), class: find!(class) }, } } @@ -2095,6 +2101,7 @@ impl Function { // The type of Snapshot doesn't really matter; it's never materialized. It's used only // as a reference for FrameState, which we use to generate side-exit code. Insn::Snapshot { .. } => types::Any, + Insn::IsA { .. } => types::BoolExact, } } @@ -3622,6 +3629,10 @@ impl Function { &Insn::ObjectAllocClass { state, .. } | &Insn::SideExit { state, .. } => worklist.push_back(state), &Insn::UnboxFixnum { val } => worklist.push_back(val), + &Insn::IsA { val, class } => { + worklist.push_back(val); + worklist.push_back(class); + } } } @@ -4314,6 +4325,10 @@ impl Function { self.assert_subtype(insn_id, index, types::Fixnum)?; self.assert_subtype(insn_id, value, types::Fixnum) } + Insn::IsA { val, class } => { + self.assert_subtype(insn_id, val, types::BasicObject)?; + self.assert_subtype(insn_id, class, types::Class) + } } } diff --git a/zjit/src/hir/opt_tests.rs b/zjit/src/hir/opt_tests.rs index 19c0ce66e3..9704afcf6e 100644 --- a/zjit/src/hir/opt_tests.rs +++ b/zjit/src/hir/opt_tests.rs @@ -7984,6 +7984,191 @@ mod hir_opt_tests { } #[test] + fn test_specialize_class_eqq() { + eval(r#" + def test(o) = String === o + test("asdf") + "#); + assert_snapshot!(hir_string("test"), @r" + fn test@<compiled>:2: + bb0(): + EntryPoint interpreter + v1:BasicObject = LoadSelf + v2:BasicObject = GetLocal l0, SP@4 + Jump bb2(v1, v2) + bb1(v5:BasicObject, v6:BasicObject): + EntryPoint JIT(0) + Jump bb2(v5, v6) + bb2(v8:BasicObject, v9:BasicObject): + PatchPoint SingleRactorMode + PatchPoint StableConstantNames(0x1000, String) + v26:Class[VALUE(0x1008)] = Const Value(VALUE(0x1008)) + PatchPoint NoEPEscape(test) + PatchPoint MethodRedefined(Class@0x1010, ===@0x1018, cme:0x1020) + PatchPoint NoSingletonClass(Class@0x1010) + v30:BoolExact = IsA v9, v26 + IncrCounter inline_cfunc_optimized_send_count + CheckInterrupts + Return v30 + "); + } + + #[test] + fn test_dont_specialize_module_eqq() { + eval(r#" + def test(o) = Kernel === o + test("asdf") + "#); + assert_snapshot!(hir_string("test"), @r" + fn test@<compiled>:2: + bb0(): + EntryPoint interpreter + v1:BasicObject = LoadSelf + v2:BasicObject = GetLocal l0, SP@4 + Jump bb2(v1, v2) + bb1(v5:BasicObject, v6:BasicObject): + EntryPoint JIT(0) + Jump bb2(v5, v6) + bb2(v8:BasicObject, v9:BasicObject): + PatchPoint SingleRactorMode + PatchPoint StableConstantNames(0x1000, Kernel) + v26:ModuleExact[VALUE(0x1008)] = Const Value(VALUE(0x1008)) + PatchPoint NoEPEscape(test) + PatchPoint MethodRedefined(Module@0x1010, ===@0x1018, cme:0x1020) + PatchPoint NoSingletonClass(Module@0x1010) + IncrCounter inline_cfunc_optimized_send_count + v31:BoolExact = CCall Module#===@0x1048, v26, v9 + CheckInterrupts + Return v31 + "); + } + + #[test] + fn test_specialize_is_a_class() { + eval(r#" + def test(o) = o.is_a?(String) + test("asdf") + "#); + assert_snapshot!(hir_string("test"), @r" + fn test@<compiled>:2: + bb0(): + EntryPoint interpreter + v1:BasicObject = LoadSelf + v2:BasicObject = GetLocal l0, SP@4 + Jump bb2(v1, v2) + bb1(v5:BasicObject, v6:BasicObject): + EntryPoint JIT(0) + Jump bb2(v5, v6) + bb2(v8:BasicObject, v9:BasicObject): + PatchPoint SingleRactorMode + PatchPoint StableConstantNames(0x1000, String) + v24:Class[VALUE(0x1008)] = Const Value(VALUE(0x1008)) + PatchPoint MethodRedefined(String@0x1008, is_a?@0x1010, cme:0x1018) + PatchPoint NoSingletonClass(String@0x1008) + v28:StringExact = GuardType v9, StringExact + v29:BoolExact = IsA v28, v24 + IncrCounter inline_cfunc_optimized_send_count + CheckInterrupts + Return v29 + "); + } + + #[test] + fn test_dont_specialize_is_a_module() { + eval(r#" + def test(o) = o.is_a?(Kernel) + test("asdf") + "#); + assert_snapshot!(hir_string("test"), @r" + fn test@<compiled>:2: + bb0(): + EntryPoint interpreter + v1:BasicObject = LoadSelf + v2:BasicObject = GetLocal l0, SP@4 + Jump bb2(v1, v2) + bb1(v5:BasicObject, v6:BasicObject): + EntryPoint JIT(0) + Jump bb2(v5, v6) + bb2(v8:BasicObject, v9:BasicObject): + PatchPoint SingleRactorMode + PatchPoint StableConstantNames(0x1000, Kernel) + v24:ModuleExact[VALUE(0x1008)] = Const Value(VALUE(0x1008)) + PatchPoint MethodRedefined(String@0x1010, is_a?@0x1018, cme:0x1020) + PatchPoint NoSingletonClass(String@0x1010) + v28:StringExact = GuardType v9, StringExact + v29:BasicObject = CCallWithFrame Kernel#is_a?@0x1048, v28, v24 + CheckInterrupts + Return v29 + "); + } + + #[test] + fn test_elide_is_a() { + eval(r#" + def test(o) + o.is_a?(Integer) + 5 + end + test("asdf") + "#); + assert_snapshot!(hir_string("test"), @r" + fn test@<compiled>:3: + bb0(): + EntryPoint interpreter + v1:BasicObject = LoadSelf + v2:BasicObject = GetLocal l0, SP@4 + Jump bb2(v1, v2) + bb1(v5:BasicObject, v6:BasicObject): + EntryPoint JIT(0) + Jump bb2(v5, v6) + bb2(v8:BasicObject, v9:BasicObject): + PatchPoint SingleRactorMode + PatchPoint StableConstantNames(0x1000, Integer) + v28:Class[VALUE(0x1008)] = Const Value(VALUE(0x1008)) + PatchPoint MethodRedefined(String@0x1010, is_a?@0x1018, cme:0x1020) + PatchPoint NoSingletonClass(String@0x1010) + v32:StringExact = GuardType v9, StringExact + IncrCounter inline_cfunc_optimized_send_count + v21:Fixnum[5] = Const Value(5) + CheckInterrupts + Return v21 + "); + } + + #[test] + fn test_elide_class_eqq() { + eval(r#" + def test(o) + Integer === o + 5 + end + test("asdf") + "#); + assert_snapshot!(hir_string("test"), @r" + fn test@<compiled>:3: + bb0(): + EntryPoint interpreter + v1:BasicObject = LoadSelf + v2:BasicObject = GetLocal l0, SP@4 + Jump bb2(v1, v2) + bb1(v5:BasicObject, v6:BasicObject): + EntryPoint JIT(0) + Jump bb2(v5, v6) + bb2(v8:BasicObject, v9:BasicObject): + PatchPoint SingleRactorMode + PatchPoint StableConstantNames(0x1000, Integer) + v30:Class[VALUE(0x1008)] = Const Value(VALUE(0x1008)) + PatchPoint NoEPEscape(test) + PatchPoint MethodRedefined(Class@0x1010, ===@0x1018, cme:0x1020) + PatchPoint NoSingletonClass(Class@0x1010) + IncrCounter inline_cfunc_optimized_send_count + v23:Fixnum[5] = Const Value(5) + CheckInterrupts + Return v23 + "); + } + + #[test] fn counting_complex_feature_use_for_fallback() { eval(" define_method(:fancy) { |_a, *_b, kw: 100, **kw_rest, &block| } |
