summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorStan Lo <stan.lo@shopify.com>2025-07-04 21:41:32 +0100
committerMax Bernstein <tekknolagi@gmail.com>2025-07-08 12:28:03 -0400
commit6c20082852a2f69a11d950b98ff179b2b737ed67 (patch)
tree27a6e75dad545dbe51dd7cae890648c2011d937c
parentf5acefca44951dcaec53324826e4078a3c3ce6f9 (diff)
ZJIT: Support inference of ModuleExact type
-rw-r--r--zjit/src/cruby.rs1
-rw-r--r--zjit/src/hir.rs26
-rw-r--r--zjit/src/hir_type/gen_hir_type.rb1
-rw-r--r--zjit/src/hir_type/hir_type.inc.rs53
-rw-r--r--zjit/src/hir_type/mod.rs23
5 files changed, 73 insertions, 31 deletions
diff --git a/zjit/src/cruby.rs b/zjit/src/cruby.rs
index 82f0e39804..459c7d7d5d 100644
--- a/zjit/src/cruby.rs
+++ b/zjit/src/cruby.rs
@@ -158,6 +158,7 @@ unsafe extern "C" {
pub fn rb_vm_ic_hit_p(ic: IC, reg_ep: *const VALUE) -> bool;
pub fn rb_vm_stack_canary() -> VALUE;
pub fn rb_vm_push_cfunc_frame(cme: *const rb_callable_method_entry_t, recv_idx: c_int);
+ pub fn rb_obj_class(klass: VALUE) -> VALUE;
}
// Renames
diff --git a/zjit/src/hir.rs b/zjit/src/hir.rs
index a62f4898d7..20ffe17683 100644
--- a/zjit/src/hir.rs
+++ b/zjit/src/hir.rs
@@ -5921,7 +5921,7 @@ mod opt_tests {
}
#[test]
- fn module_instances_not_class_exact() {
+ fn module_instances_are_module_exact() {
eval("
def test = [Enumerable, Kernel]
test # Warm the constant cache
@@ -5931,16 +5931,34 @@ mod opt_tests {
bb0(v0:BasicObject):
PatchPoint SingleRactorMode
PatchPoint StableConstantNames(0x1000, Enumerable)
- v11:BasicObject[VALUE(0x1008)] = Const Value(VALUE(0x1008))
+ v11:ModuleExact[VALUE(0x1008)] = Const Value(VALUE(0x1008))
PatchPoint SingleRactorMode
PatchPoint StableConstantNames(0x1010, Kernel)
- v14:BasicObject[VALUE(0x1018)] = Const Value(VALUE(0x1018))
+ v14:ModuleExact[VALUE(0x1018)] = Const Value(VALUE(0x1018))
v7:ArrayExact = NewArray v11, v14
Return v7
"#]]);
}
#[test]
+ fn module_subclasses_are_not_module_exact() {
+ eval("
+ class ModuleSubclass < Module; end
+ MY_MODULE = ModuleSubclass.new
+ def test = MY_MODULE
+ test # Warm the constant cache
+ ");
+ assert_optimized_method_hir("test", expect![[r#"
+ fn test:
+ bb0(v0:BasicObject):
+ PatchPoint SingleRactorMode
+ PatchPoint StableConstantNames(0x1000, MY_MODULE)
+ v7:BasicObject[VALUE(0x1008)] = Const Value(VALUE(0x1008))
+ Return v7
+ "#]]);
+ }
+
+ #[test]
fn eliminate_array_size() {
eval("
def test
@@ -6067,7 +6085,7 @@ mod opt_tests {
bb0(v0:BasicObject):
PatchPoint SingleRactorMode
PatchPoint StableConstantNames(0x1000, Kernel)
- v7:BasicObject[VALUE(0x1008)] = Const Value(VALUE(0x1008))
+ v7:ModuleExact[VALUE(0x1008)] = Const Value(VALUE(0x1008))
Return v7
"#]]);
}
diff --git a/zjit/src/hir_type/gen_hir_type.rb b/zjit/src/hir_type/gen_hir_type.rb
index 50259a9816..bbd0e1ed32 100644
--- a/zjit/src/hir_type/gen_hir_type.rb
+++ b/zjit/src/hir_type/gen_hir_type.rb
@@ -75,6 +75,7 @@ base_type "Range"
base_type "Set"
base_type "Regexp"
base_type "Class"
+base_type "Module"
(integer, integer_exact) = base_type "Integer"
# CRuby partitions Integer into immediate and non-immediate variants.
diff --git a/zjit/src/hir_type/hir_type.inc.rs b/zjit/src/hir_type/hir_type.inc.rs
index 57349aba14..2c6fb48ea5 100644
--- a/zjit/src/hir_type/hir_type.inc.rs
+++ b/zjit/src/hir_type/hir_type.inc.rs
@@ -9,7 +9,7 @@ mod bits {
pub const BasicObjectSubclass: u64 = 1u64 << 3;
pub const Bignum: u64 = 1u64 << 4;
pub const BoolExact: u64 = FalseClassExact | TrueClassExact;
- pub const BuiltinExact: u64 = ArrayExact | BasicObjectExact | ClassExact | FalseClassExact | FloatExact | HashExact | IntegerExact | NilClassExact | ObjectExact | RangeExact | RegexpExact | SetExact | StringExact | SymbolExact | TrueClassExact;
+ pub const BuiltinExact: u64 = ArrayExact | BasicObjectExact | ClassExact | FalseClassExact | FloatExact | HashExact | IntegerExact | ModuleExact | NilClassExact | ObjectExact | RangeExact | RegexpExact | SetExact | StringExact | SymbolExact | TrueClassExact;
pub const CBool: u64 = 1u64 << 5;
pub const CDouble: u64 = 1u64 << 6;
pub const CInt: u64 = CSigned | CUnsigned;
@@ -48,35 +48,38 @@ mod bits {
pub const Integer: u64 = IntegerExact | IntegerSubclass;
pub const IntegerExact: u64 = Bignum | Fixnum;
pub const IntegerSubclass: u64 = 1u64 << 29;
+ pub const Module: u64 = ModuleExact | ModuleSubclass;
+ pub const ModuleExact: u64 = 1u64 << 30;
+ pub const ModuleSubclass: u64 = 1u64 << 31;
pub const NilClass: u64 = NilClassExact | NilClassSubclass;
- pub const NilClassExact: u64 = 1u64 << 30;
- pub const NilClassSubclass: u64 = 1u64 << 31;
- pub const Object: u64 = Array | Class | FalseClass | Float | Hash | Integer | NilClass | ObjectExact | ObjectSubclass | Range | Regexp | Set | String | Symbol | TrueClass;
- pub const ObjectExact: u64 = 1u64 << 32;
- pub const ObjectSubclass: u64 = 1u64 << 33;
+ pub const NilClassExact: u64 = 1u64 << 32;
+ pub const NilClassSubclass: u64 = 1u64 << 33;
+ pub const Object: u64 = Array | Class | FalseClass | Float | Hash | Integer | Module | NilClass | ObjectExact | ObjectSubclass | Range | Regexp | Set | String | Symbol | TrueClass;
+ pub const ObjectExact: u64 = 1u64 << 34;
+ pub const ObjectSubclass: u64 = 1u64 << 35;
pub const Range: u64 = RangeExact | RangeSubclass;
- pub const RangeExact: u64 = 1u64 << 34;
- pub const RangeSubclass: u64 = 1u64 << 35;
+ pub const RangeExact: u64 = 1u64 << 36;
+ pub const RangeSubclass: u64 = 1u64 << 37;
pub const Regexp: u64 = RegexpExact | RegexpSubclass;
- pub const RegexpExact: u64 = 1u64 << 36;
- pub const RegexpSubclass: u64 = 1u64 << 37;
+ pub const RegexpExact: u64 = 1u64 << 38;
+ pub const RegexpSubclass: u64 = 1u64 << 39;
pub const RubyValue: u64 = BasicObject | CallableMethodEntry | Undef;
pub const Set: u64 = SetExact | SetSubclass;
- pub const SetExact: u64 = 1u64 << 38;
- pub const SetSubclass: u64 = 1u64 << 39;
- pub const StaticSymbol: u64 = 1u64 << 40;
+ pub const SetExact: u64 = 1u64 << 40;
+ pub const SetSubclass: u64 = 1u64 << 41;
+ pub const StaticSymbol: u64 = 1u64 << 42;
pub const String: u64 = StringExact | StringSubclass;
- pub const StringExact: u64 = 1u64 << 41;
- pub const StringSubclass: u64 = 1u64 << 42;
- pub const Subclass: u64 = ArraySubclass | BasicObjectSubclass | ClassSubclass | FalseClassSubclass | FloatSubclass | HashSubclass | IntegerSubclass | NilClassSubclass | ObjectSubclass | RangeSubclass | RegexpSubclass | SetSubclass | StringSubclass | SymbolSubclass | TrueClassSubclass;
+ pub const StringExact: u64 = 1u64 << 43;
+ pub const StringSubclass: u64 = 1u64 << 44;
+ pub const Subclass: u64 = ArraySubclass | BasicObjectSubclass | ClassSubclass | FalseClassSubclass | FloatSubclass | HashSubclass | IntegerSubclass | ModuleSubclass | NilClassSubclass | ObjectSubclass | RangeSubclass | RegexpSubclass | SetSubclass | StringSubclass | SymbolSubclass | TrueClassSubclass;
pub const Symbol: u64 = SymbolExact | SymbolSubclass;
pub const SymbolExact: u64 = DynamicSymbol | StaticSymbol;
- pub const SymbolSubclass: u64 = 1u64 << 43;
+ pub const SymbolSubclass: u64 = 1u64 << 45;
pub const TrueClass: u64 = TrueClassExact | TrueClassSubclass;
- pub const TrueClassExact: u64 = 1u64 << 44;
- pub const TrueClassSubclass: u64 = 1u64 << 45;
- pub const Undef: u64 = 1u64 << 46;
- pub const AllBitPatterns: [(&'static str, u64); 76] = [
+ pub const TrueClassExact: u64 = 1u64 << 46;
+ pub const TrueClassSubclass: u64 = 1u64 << 47;
+ pub const Undef: u64 = 1u64 << 48;
+ pub const AllBitPatterns: [(&'static str, u64); 79] = [
("Any", Any),
("RubyValue", RubyValue),
("Immediate", Immediate),
@@ -110,6 +113,9 @@ mod bits {
("NilClass", NilClass),
("NilClassSubclass", NilClassSubclass),
("NilClassExact", NilClassExact),
+ ("Module", Module),
+ ("ModuleSubclass", ModuleSubclass),
+ ("ModuleExact", ModuleExact),
("Integer", Integer),
("IntegerSubclass", IntegerSubclass),
("Float", Float),
@@ -154,7 +160,7 @@ mod bits {
("ArrayExact", ArrayExact),
("Empty", Empty),
];
- pub const NumTypeBits: u64 = 47;
+ pub const NumTypeBits: u64 = 49;
}
pub mod types {
use super::*;
@@ -206,6 +212,9 @@ pub mod types {
pub const Integer: Type = Type::from_bits(bits::Integer);
pub const IntegerExact: Type = Type::from_bits(bits::IntegerExact);
pub const IntegerSubclass: Type = Type::from_bits(bits::IntegerSubclass);
+ pub const Module: Type = Type::from_bits(bits::Module);
+ pub const ModuleExact: Type = Type::from_bits(bits::ModuleExact);
+ pub const ModuleSubclass: Type = Type::from_bits(bits::ModuleSubclass);
pub const NilClass: Type = Type::from_bits(bits::NilClass);
pub const NilClassExact: Type = Type::from_bits(bits::NilClassExact);
pub const NilClassSubclass: Type = Type::from_bits(bits::NilClassSubclass);
diff --git a/zjit/src/hir_type/mod.rs b/zjit/src/hir_type/mod.rs
index 422055e6d0..907582c251 100644
--- a/zjit/src/hir_type/mod.rs
+++ b/zjit/src/hir_type/mod.rs
@@ -1,10 +1,11 @@
#![allow(non_upper_case_globals)]
-use crate::cruby::{Qfalse, Qnil, Qtrue, VALUE, RUBY_T_ARRAY, RUBY_T_STRING, RUBY_T_HASH, RUBY_T_CLASS};
+use crate::cruby::{Qfalse, Qnil, Qtrue, VALUE, RUBY_T_ARRAY, RUBY_T_STRING, RUBY_T_HASH, RUBY_T_CLASS, RUBY_T_MODULE};
use crate::cruby::{rb_cInteger, rb_cFloat, rb_cArray, rb_cHash, rb_cString, rb_cSymbol, rb_cObject, rb_cTrueClass, rb_cFalseClass, rb_cNilClass, rb_cRange, rb_cSet, rb_cRegexp, rb_cClass, rb_cModule};
use crate::cruby::ClassRelationship;
use crate::cruby::get_class_name;
use crate::cruby::ruby_sym_to_rust_string;
use crate::cruby::rb_mRubyVMFrozenCore;
+use crate::cruby::rb_obj_class;
use crate::hir::PtrPrintMap;
#[derive(Copy, Clone, Debug, PartialEq)]
@@ -145,9 +146,16 @@ fn is_range_exact(val: VALUE) -> bool {
val.class_of() == unsafe { rb_cRange }
}
-fn is_class_exact(val: VALUE) -> bool {
- // Objects with RUBY_T_CLASS type and not instances of Module
- val.builtin_type() == RUBY_T_CLASS && val.class_of() != unsafe { rb_cModule }
+fn is_module_exact(val: VALUE) -> bool {
+ if val.builtin_type() != RUBY_T_MODULE {
+ return false;
+ }
+
+ // For Class and Module instances, `class_of` will return the singleton class of the object.
+ // Using `rb_obj_class` will give us the actual class of the module so we can check if the
+ // object is an instance of Module, or an instance of Module subclass.
+ let klass = unsafe { rb_obj_class(val) };
+ klass == unsafe { rb_cModule }
}
impl Type {
@@ -202,7 +210,10 @@ impl Type {
else if is_string_exact(val) {
Type { bits: bits::StringExact, spec: Specialization::Object(val) }
}
- else if is_class_exact(val) {
+ else if is_module_exact(val) {
+ Type { bits: bits::ModuleExact, spec: Specialization::Object(val) }
+ }
+ else if val.builtin_type() == RUBY_T_CLASS {
Type { bits: bits::ClassExact, spec: Specialization::Object(val) }
}
else if val.class_of() == unsafe { rb_cRegexp } {
@@ -301,6 +312,7 @@ impl Type {
if class == unsafe { rb_cFloat } { return true; }
if class == unsafe { rb_cHash } { return true; }
if class == unsafe { rb_cInteger } { return true; }
+ if class == unsafe { rb_cModule } { return true; }
if class == unsafe { rb_cNilClass } { return true; }
if class == unsafe { rb_cObject } { return true; }
if class == unsafe { rb_cRange } { return true; }
@@ -410,6 +422,7 @@ impl Type {
if self.is_subtype(types::FloatExact) { return Some(unsafe { rb_cFloat }); }
if self.is_subtype(types::HashExact) { return Some(unsafe { rb_cHash }); }
if self.is_subtype(types::IntegerExact) { return Some(unsafe { rb_cInteger }); }
+ if self.is_subtype(types::ModuleExact) { return Some(unsafe { rb_cModule }); }
if self.is_subtype(types::NilClassExact) { return Some(unsafe { rb_cNilClass }); }
if self.is_subtype(types::ObjectExact) { return Some(unsafe { rb_cObject }); }
if self.is_subtype(types::RangeExact) { return Some(unsafe { rb_cRange }); }