summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMax Bernstein <ruby@bernsteinbear.com>2025-11-19 16:37:45 -0500
committerMax Bernstein <tekknolagi@gmail.com>2025-11-21 09:21:57 -0800
commitff89e470e21e9d021c6739d83eddda4bd8c071fe (patch)
treec3c89868d1ca88363c5e555e630b6ce15fc43155
parente0bb3fb1cda2238d0c98afcdec2fe282c29994aa (diff)
ZJIT: Specialize Module#=== and Kernel#is_a? into IsA
-rw-r--r--zjit/src/codegen.rs5
-rw-r--r--zjit/src/cruby_methods.rs21
-rw-r--r--zjit/src/hir.rs15
-rw-r--r--zjit/src/hir/opt_tests.rs185
4 files changed, 225 insertions, 1 deletions
diff --git a/zjit/src/codegen.rs b/zjit/src/codegen.rs
index 9f74838c11..4c865dcd8a 100644
--- a/zjit/src/codegen.rs
+++ b/zjit/src/codegen.rs
@@ -472,6 +472,7 @@ fn gen_insn(cb: &mut CodeBlock, jit: &mut JITState, asm: &mut Assembler, functio
Insn::ArrayInclude { elements, target, state } => gen_array_include(jit, asm, opnds!(elements), opnd!(target), &function.frame_state(*state)),
&Insn::DupArrayInclude { ary, target, state } => gen_dup_array_include(jit, asm, ary, opnd!(target), &function.frame_state(state)),
Insn::ArrayHash { elements, state } => gen_opt_newarray_hash(jit, asm, opnds!(elements), &function.frame_state(*state)),
+ &Insn::IsA { val, class } => gen_is_a(asm, opnd!(val), opnd!(class)),
&Insn::ArrayMax { state, .. }
| &Insn::FixnumDiv { state, .. }
| &Insn::Throw { state, .. }
@@ -1520,6 +1521,10 @@ fn gen_dup_array_include(
)
}
+fn gen_is_a(asm: &mut Assembler, obj: Opnd, class: Opnd) -> lir::Opnd {
+ asm_ccall!(asm, rb_obj_is_kind_of, obj, class)
+}
+
/// Compile a new hash instruction
fn gen_new_hash(
jit: &mut JITState,
diff --git a/zjit/src/cruby_methods.rs b/zjit/src/cruby_methods.rs
index 3999ef0a10..6d09b5e5a7 100644
--- a/zjit/src/cruby_methods.rs
+++ b/zjit/src/cruby_methods.rs
@@ -197,6 +197,7 @@ pub fn init() -> Annotations {
annotate!(rb_mKernel, "itself", inline_kernel_itself);
annotate!(rb_mKernel, "block_given?", inline_kernel_block_given_p);
annotate!(rb_mKernel, "===", inline_eqq);
+ annotate!(rb_mKernel, "is_a?", inline_kernel_is_a_p);
annotate!(rb_cString, "bytesize", inline_string_bytesize);
annotate!(rb_cString, "size", types::Fixnum, no_gc, leaf, elidable);
annotate!(rb_cString, "length", types::Fixnum, no_gc, leaf, elidable);
@@ -206,7 +207,7 @@ pub fn init() -> Annotations {
annotate!(rb_cString, "<<", inline_string_append);
annotate!(rb_cString, "==", inline_string_eq);
annotate!(rb_cModule, "name", types::StringExact.union(types::NilClass), no_gc, leaf, elidable);
- annotate!(rb_cModule, "===", types::BoolExact, no_gc, leaf);
+ annotate!(rb_cModule, "===", inline_module_eqq, types::BoolExact, no_gc, leaf);
annotate!(rb_cArray, "length", types::Fixnum, no_gc, leaf, elidable);
annotate!(rb_cArray, "size", types::Fixnum, no_gc, leaf, elidable);
annotate!(rb_cArray, "empty?", types::BoolExact, no_gc, leaf, elidable);
@@ -447,6 +448,15 @@ fn inline_string_eq(fun: &mut hir::Function, block: hir::BlockId, recv: hir::Ins
None
}
+fn inline_module_eqq(fun: &mut hir::Function, block: hir::BlockId, recv: hir::InsnId, args: &[hir::InsnId], _state: hir::InsnId) -> Option<hir::InsnId> {
+ let &[other] = args else { return None; };
+ if fun.is_a(recv, types::Class) {
+ let result = fun.push_insn(block, hir::Insn::IsA { val: other, class: recv });
+ return Some(result);
+ }
+ None
+}
+
fn inline_integer_succ(fun: &mut hir::Function, block: hir::BlockId, recv: hir::InsnId, args: &[hir::InsnId], state: hir::InsnId) -> Option<hir::InsnId> {
if !args.is_empty() { return None; }
if fun.likely_a(recv, types::Fixnum, state) {
@@ -613,6 +623,15 @@ fn inline_eqq(fun: &mut hir::Function, block: hir::BlockId, recv: hir::InsnId, a
Some(result)
}
+fn inline_kernel_is_a_p(fun: &mut hir::Function, block: hir::BlockId, recv: hir::InsnId, args: &[hir::InsnId], _state: hir::InsnId) -> Option<hir::InsnId> {
+ let &[other] = args else { return None; };
+ if fun.is_a(other, types::Class) {
+ let result = fun.push_insn(block, hir::Insn::IsA { val: recv, class: other });
+ return Some(result);
+ }
+ None
+}
+
fn inline_kernel_nil_p(fun: &mut hir::Function, block: hir::BlockId, _recv: hir::InsnId, args: &[hir::InsnId], _state: hir::InsnId) -> Option<hir::InsnId> {
if !args.is_empty() { return None; }
Some(fun.push_insn(block, hir::Insn::Const { val: hir::Const::Value(Qfalse) }))
diff --git a/zjit/src/hir.rs b/zjit/src/hir.rs
index 00014c5758..3a69dd6610 100644
--- a/zjit/src/hir.rs
+++ b/zjit/src/hir.rs
@@ -721,6 +721,9 @@ pub enum Insn {
/// Test the bit at index of val, a Fixnum.
/// Return Qtrue if the bit is set, else Qfalse.
FixnumBitCheck { val: InsnId, index: u8 },
+ /// Return Qtrue if `val` is an instance of `class`, else Qfalse.
+ /// Equivalent to `class_search_ancestor(CLASS_OF(val), class)`.
+ IsA { val: InsnId, class: InsnId },
/// Get a global variable named `id`
GetGlobal { id: ID, state: InsnId },
@@ -1000,6 +1003,7 @@ impl Insn {
Insn::BoxFixnum { .. } => false,
Insn::BoxBool { .. } => false,
Insn::IsBitEqual { .. } => false,
+ Insn::IsA { .. } => false,
_ => true,
}
}
@@ -1324,6 +1328,7 @@ impl<'a> std::fmt::Display for InsnPrinter<'a> {
}
Insn::IncrCounter(counter) => write!(f, "IncrCounter {counter:?}"),
Insn::CheckInterrupts { .. } => write!(f, "CheckInterrupts"),
+ Insn::IsA { val, class } => write!(f, "IsA {val}, {class}"),
}
}
}
@@ -1946,6 +1951,7 @@ impl Function {
&ArrayExtend { left, right, state } => ArrayExtend { left: find!(left), right: find!(right), state },
&ArrayPush { array, val, state } => ArrayPush { array: find!(array), val: find!(val), state },
&CheckInterrupts { state } => CheckInterrupts { state },
+ &IsA { val, class } => IsA { val: find!(val), class: find!(class) },
}
}
@@ -2095,6 +2101,7 @@ impl Function {
// The type of Snapshot doesn't really matter; it's never materialized. It's used only
// as a reference for FrameState, which we use to generate side-exit code.
Insn::Snapshot { .. } => types::Any,
+ Insn::IsA { .. } => types::BoolExact,
}
}
@@ -3622,6 +3629,10 @@ impl Function {
&Insn::ObjectAllocClass { state, .. } |
&Insn::SideExit { state, .. } => worklist.push_back(state),
&Insn::UnboxFixnum { val } => worklist.push_back(val),
+ &Insn::IsA { val, class } => {
+ worklist.push_back(val);
+ worklist.push_back(class);
+ }
}
}
@@ -4314,6 +4325,10 @@ impl Function {
self.assert_subtype(insn_id, index, types::Fixnum)?;
self.assert_subtype(insn_id, value, types::Fixnum)
}
+ Insn::IsA { val, class } => {
+ self.assert_subtype(insn_id, val, types::BasicObject)?;
+ self.assert_subtype(insn_id, class, types::Class)
+ }
}
}
diff --git a/zjit/src/hir/opt_tests.rs b/zjit/src/hir/opt_tests.rs
index 19c0ce66e3..9704afcf6e 100644
--- a/zjit/src/hir/opt_tests.rs
+++ b/zjit/src/hir/opt_tests.rs
@@ -7984,6 +7984,191 @@ mod hir_opt_tests {
}
#[test]
+ fn test_specialize_class_eqq() {
+ eval(r#"
+ def test(o) = String === o
+ test("asdf")
+ "#);
+ assert_snapshot!(hir_string("test"), @r"
+ fn test@<compiled>:2:
+ bb0():
+ EntryPoint interpreter
+ v1:BasicObject = LoadSelf
+ v2:BasicObject = GetLocal l0, SP@4
+ Jump bb2(v1, v2)
+ bb1(v5:BasicObject, v6:BasicObject):
+ EntryPoint JIT(0)
+ Jump bb2(v5, v6)
+ bb2(v8:BasicObject, v9:BasicObject):
+ PatchPoint SingleRactorMode
+ PatchPoint StableConstantNames(0x1000, String)
+ v26:Class[VALUE(0x1008)] = Const Value(VALUE(0x1008))
+ PatchPoint NoEPEscape(test)
+ PatchPoint MethodRedefined(Class@0x1010, ===@0x1018, cme:0x1020)
+ PatchPoint NoSingletonClass(Class@0x1010)
+ v30:BoolExact = IsA v9, v26
+ IncrCounter inline_cfunc_optimized_send_count
+ CheckInterrupts
+ Return v30
+ ");
+ }
+
+ #[test]
+ fn test_dont_specialize_module_eqq() {
+ eval(r#"
+ def test(o) = Kernel === o
+ test("asdf")
+ "#);
+ assert_snapshot!(hir_string("test"), @r"
+ fn test@<compiled>:2:
+ bb0():
+ EntryPoint interpreter
+ v1:BasicObject = LoadSelf
+ v2:BasicObject = GetLocal l0, SP@4
+ Jump bb2(v1, v2)
+ bb1(v5:BasicObject, v6:BasicObject):
+ EntryPoint JIT(0)
+ Jump bb2(v5, v6)
+ bb2(v8:BasicObject, v9:BasicObject):
+ PatchPoint SingleRactorMode
+ PatchPoint StableConstantNames(0x1000, Kernel)
+ v26:ModuleExact[VALUE(0x1008)] = Const Value(VALUE(0x1008))
+ PatchPoint NoEPEscape(test)
+ PatchPoint MethodRedefined(Module@0x1010, ===@0x1018, cme:0x1020)
+ PatchPoint NoSingletonClass(Module@0x1010)
+ IncrCounter inline_cfunc_optimized_send_count
+ v31:BoolExact = CCall Module#===@0x1048, v26, v9
+ CheckInterrupts
+ Return v31
+ ");
+ }
+
+ #[test]
+ fn test_specialize_is_a_class() {
+ eval(r#"
+ def test(o) = o.is_a?(String)
+ test("asdf")
+ "#);
+ assert_snapshot!(hir_string("test"), @r"
+ fn test@<compiled>:2:
+ bb0():
+ EntryPoint interpreter
+ v1:BasicObject = LoadSelf
+ v2:BasicObject = GetLocal l0, SP@4
+ Jump bb2(v1, v2)
+ bb1(v5:BasicObject, v6:BasicObject):
+ EntryPoint JIT(0)
+ Jump bb2(v5, v6)
+ bb2(v8:BasicObject, v9:BasicObject):
+ PatchPoint SingleRactorMode
+ PatchPoint StableConstantNames(0x1000, String)
+ v24:Class[VALUE(0x1008)] = Const Value(VALUE(0x1008))
+ PatchPoint MethodRedefined(String@0x1008, is_a?@0x1010, cme:0x1018)
+ PatchPoint NoSingletonClass(String@0x1008)
+ v28:StringExact = GuardType v9, StringExact
+ v29:BoolExact = IsA v28, v24
+ IncrCounter inline_cfunc_optimized_send_count
+ CheckInterrupts
+ Return v29
+ ");
+ }
+
+ #[test]
+ fn test_dont_specialize_is_a_module() {
+ eval(r#"
+ def test(o) = o.is_a?(Kernel)
+ test("asdf")
+ "#);
+ assert_snapshot!(hir_string("test"), @r"
+ fn test@<compiled>:2:
+ bb0():
+ EntryPoint interpreter
+ v1:BasicObject = LoadSelf
+ v2:BasicObject = GetLocal l0, SP@4
+ Jump bb2(v1, v2)
+ bb1(v5:BasicObject, v6:BasicObject):
+ EntryPoint JIT(0)
+ Jump bb2(v5, v6)
+ bb2(v8:BasicObject, v9:BasicObject):
+ PatchPoint SingleRactorMode
+ PatchPoint StableConstantNames(0x1000, Kernel)
+ v24:ModuleExact[VALUE(0x1008)] = Const Value(VALUE(0x1008))
+ PatchPoint MethodRedefined(String@0x1010, is_a?@0x1018, cme:0x1020)
+ PatchPoint NoSingletonClass(String@0x1010)
+ v28:StringExact = GuardType v9, StringExact
+ v29:BasicObject = CCallWithFrame Kernel#is_a?@0x1048, v28, v24
+ CheckInterrupts
+ Return v29
+ ");
+ }
+
+ #[test]
+ fn test_elide_is_a() {
+ eval(r#"
+ def test(o)
+ o.is_a?(Integer)
+ 5
+ end
+ test("asdf")
+ "#);
+ assert_snapshot!(hir_string("test"), @r"
+ fn test@<compiled>:3:
+ bb0():
+ EntryPoint interpreter
+ v1:BasicObject = LoadSelf
+ v2:BasicObject = GetLocal l0, SP@4
+ Jump bb2(v1, v2)
+ bb1(v5:BasicObject, v6:BasicObject):
+ EntryPoint JIT(0)
+ Jump bb2(v5, v6)
+ bb2(v8:BasicObject, v9:BasicObject):
+ PatchPoint SingleRactorMode
+ PatchPoint StableConstantNames(0x1000, Integer)
+ v28:Class[VALUE(0x1008)] = Const Value(VALUE(0x1008))
+ PatchPoint MethodRedefined(String@0x1010, is_a?@0x1018, cme:0x1020)
+ PatchPoint NoSingletonClass(String@0x1010)
+ v32:StringExact = GuardType v9, StringExact
+ IncrCounter inline_cfunc_optimized_send_count
+ v21:Fixnum[5] = Const Value(5)
+ CheckInterrupts
+ Return v21
+ ");
+ }
+
+ #[test]
+ fn test_elide_class_eqq() {
+ eval(r#"
+ def test(o)
+ Integer === o
+ 5
+ end
+ test("asdf")
+ "#);
+ assert_snapshot!(hir_string("test"), @r"
+ fn test@<compiled>:3:
+ bb0():
+ EntryPoint interpreter
+ v1:BasicObject = LoadSelf
+ v2:BasicObject = GetLocal l0, SP@4
+ Jump bb2(v1, v2)
+ bb1(v5:BasicObject, v6:BasicObject):
+ EntryPoint JIT(0)
+ Jump bb2(v5, v6)
+ bb2(v8:BasicObject, v9:BasicObject):
+ PatchPoint SingleRactorMode
+ PatchPoint StableConstantNames(0x1000, Integer)
+ v30:Class[VALUE(0x1008)] = Const Value(VALUE(0x1008))
+ PatchPoint NoEPEscape(test)
+ PatchPoint MethodRedefined(Class@0x1010, ===@0x1018, cme:0x1020)
+ PatchPoint NoSingletonClass(Class@0x1010)
+ IncrCounter inline_cfunc_optimized_send_count
+ v23:Fixnum[5] = Const Value(5)
+ CheckInterrupts
+ Return v23
+ ");
+ }
+
+ #[test]
fn counting_complex_feature_use_for_fallback() {
eval("
define_method(:fancy) { |_a, *_b, kw: 100, **kw_rest, &block| }