diff options
| author | Stan Lo <stan.lo@shopify.com> | 2025-07-04 21:41:32 +0100 |
|---|---|---|
| committer | Max Bernstein <tekknolagi@gmail.com> | 2025-07-08 12:28:03 -0400 |
| commit | 6c20082852a2f69a11d950b98ff179b2b737ed67 (patch) | |
| tree | 27a6e75dad545dbe51dd7cae890648c2011d937c | |
| parent | f5acefca44951dcaec53324826e4078a3c3ce6f9 (diff) | |
ZJIT: Support inference of ModuleExact type
| -rw-r--r-- | zjit/src/cruby.rs | 1 | ||||
| -rw-r--r-- | zjit/src/hir.rs | 26 | ||||
| -rw-r--r-- | zjit/src/hir_type/gen_hir_type.rb | 1 | ||||
| -rw-r--r-- | zjit/src/hir_type/hir_type.inc.rs | 53 | ||||
| -rw-r--r-- | zjit/src/hir_type/mod.rs | 23 |
5 files changed, 73 insertions, 31 deletions
diff --git a/zjit/src/cruby.rs b/zjit/src/cruby.rs index 82f0e39804..459c7d7d5d 100644 --- a/zjit/src/cruby.rs +++ b/zjit/src/cruby.rs @@ -158,6 +158,7 @@ unsafe extern "C" { pub fn rb_vm_ic_hit_p(ic: IC, reg_ep: *const VALUE) -> bool; pub fn rb_vm_stack_canary() -> VALUE; pub fn rb_vm_push_cfunc_frame(cme: *const rb_callable_method_entry_t, recv_idx: c_int); + pub fn rb_obj_class(klass: VALUE) -> VALUE; } // Renames diff --git a/zjit/src/hir.rs b/zjit/src/hir.rs index a62f4898d7..20ffe17683 100644 --- a/zjit/src/hir.rs +++ b/zjit/src/hir.rs @@ -5921,7 +5921,7 @@ mod opt_tests { } #[test] - fn module_instances_not_class_exact() { + fn module_instances_are_module_exact() { eval(" def test = [Enumerable, Kernel] test # Warm the constant cache @@ -5931,16 +5931,34 @@ mod opt_tests { bb0(v0:BasicObject): PatchPoint SingleRactorMode PatchPoint StableConstantNames(0x1000, Enumerable) - v11:BasicObject[VALUE(0x1008)] = Const Value(VALUE(0x1008)) + v11:ModuleExact[VALUE(0x1008)] = Const Value(VALUE(0x1008)) PatchPoint SingleRactorMode PatchPoint StableConstantNames(0x1010, Kernel) - v14:BasicObject[VALUE(0x1018)] = Const Value(VALUE(0x1018)) + v14:ModuleExact[VALUE(0x1018)] = Const Value(VALUE(0x1018)) v7:ArrayExact = NewArray v11, v14 Return v7 "#]]); } #[test] + fn module_subclasses_are_not_module_exact() { + eval(" + class ModuleSubclass < Module; end + MY_MODULE = ModuleSubclass.new + def test = MY_MODULE + test # Warm the constant cache + "); + assert_optimized_method_hir("test", expect![[r#" + fn test: + bb0(v0:BasicObject): + PatchPoint SingleRactorMode + PatchPoint StableConstantNames(0x1000, MY_MODULE) + v7:BasicObject[VALUE(0x1008)] = Const Value(VALUE(0x1008)) + Return v7 + "#]]); + } + + #[test] fn eliminate_array_size() { eval(" def test @@ -6067,7 +6085,7 @@ mod opt_tests { bb0(v0:BasicObject): PatchPoint SingleRactorMode PatchPoint StableConstantNames(0x1000, Kernel) - v7:BasicObject[VALUE(0x1008)] = Const Value(VALUE(0x1008)) + v7:ModuleExact[VALUE(0x1008)] = Const Value(VALUE(0x1008)) Return v7 "#]]); } diff --git a/zjit/src/hir_type/gen_hir_type.rb b/zjit/src/hir_type/gen_hir_type.rb index 50259a9816..bbd0e1ed32 100644 --- a/zjit/src/hir_type/gen_hir_type.rb +++ b/zjit/src/hir_type/gen_hir_type.rb @@ -75,6 +75,7 @@ base_type "Range" base_type "Set" base_type "Regexp" base_type "Class" +base_type "Module" (integer, integer_exact) = base_type "Integer" # CRuby partitions Integer into immediate and non-immediate variants. diff --git a/zjit/src/hir_type/hir_type.inc.rs b/zjit/src/hir_type/hir_type.inc.rs index 57349aba14..2c6fb48ea5 100644 --- a/zjit/src/hir_type/hir_type.inc.rs +++ b/zjit/src/hir_type/hir_type.inc.rs @@ -9,7 +9,7 @@ mod bits { pub const BasicObjectSubclass: u64 = 1u64 << 3; pub const Bignum: u64 = 1u64 << 4; pub const BoolExact: u64 = FalseClassExact | TrueClassExact; - pub const BuiltinExact: u64 = ArrayExact | BasicObjectExact | ClassExact | FalseClassExact | FloatExact | HashExact | IntegerExact | NilClassExact | ObjectExact | RangeExact | RegexpExact | SetExact | StringExact | SymbolExact | TrueClassExact; + pub const BuiltinExact: u64 = ArrayExact | BasicObjectExact | ClassExact | FalseClassExact | FloatExact | HashExact | IntegerExact | ModuleExact | NilClassExact | ObjectExact | RangeExact | RegexpExact | SetExact | StringExact | SymbolExact | TrueClassExact; pub const CBool: u64 = 1u64 << 5; pub const CDouble: u64 = 1u64 << 6; pub const CInt: u64 = CSigned | CUnsigned; @@ -48,35 +48,38 @@ mod bits { pub const Integer: u64 = IntegerExact | IntegerSubclass; pub const IntegerExact: u64 = Bignum | Fixnum; pub const IntegerSubclass: u64 = 1u64 << 29; + pub const Module: u64 = ModuleExact | ModuleSubclass; + pub const ModuleExact: u64 = 1u64 << 30; + pub const ModuleSubclass: u64 = 1u64 << 31; pub const NilClass: u64 = NilClassExact | NilClassSubclass; - pub const NilClassExact: u64 = 1u64 << 30; - pub const NilClassSubclass: u64 = 1u64 << 31; - pub const Object: u64 = Array | Class | FalseClass | Float | Hash | Integer | NilClass | ObjectExact | ObjectSubclass | Range | Regexp | Set | String | Symbol | TrueClass; - pub const ObjectExact: u64 = 1u64 << 32; - pub const ObjectSubclass: u64 = 1u64 << 33; + pub const NilClassExact: u64 = 1u64 << 32; + pub const NilClassSubclass: u64 = 1u64 << 33; + pub const Object: u64 = Array | Class | FalseClass | Float | Hash | Integer | Module | NilClass | ObjectExact | ObjectSubclass | Range | Regexp | Set | String | Symbol | TrueClass; + pub const ObjectExact: u64 = 1u64 << 34; + pub const ObjectSubclass: u64 = 1u64 << 35; pub const Range: u64 = RangeExact | RangeSubclass; - pub const RangeExact: u64 = 1u64 << 34; - pub const RangeSubclass: u64 = 1u64 << 35; + pub const RangeExact: u64 = 1u64 << 36; + pub const RangeSubclass: u64 = 1u64 << 37; pub const Regexp: u64 = RegexpExact | RegexpSubclass; - pub const RegexpExact: u64 = 1u64 << 36; - pub const RegexpSubclass: u64 = 1u64 << 37; + pub const RegexpExact: u64 = 1u64 << 38; + pub const RegexpSubclass: u64 = 1u64 << 39; pub const RubyValue: u64 = BasicObject | CallableMethodEntry | Undef; pub const Set: u64 = SetExact | SetSubclass; - pub const SetExact: u64 = 1u64 << 38; - pub const SetSubclass: u64 = 1u64 << 39; - pub const StaticSymbol: u64 = 1u64 << 40; + pub const SetExact: u64 = 1u64 << 40; + pub const SetSubclass: u64 = 1u64 << 41; + pub const StaticSymbol: u64 = 1u64 << 42; pub const String: u64 = StringExact | StringSubclass; - pub const StringExact: u64 = 1u64 << 41; - pub const StringSubclass: u64 = 1u64 << 42; - pub const Subclass: u64 = ArraySubclass | BasicObjectSubclass | ClassSubclass | FalseClassSubclass | FloatSubclass | HashSubclass | IntegerSubclass | NilClassSubclass | ObjectSubclass | RangeSubclass | RegexpSubclass | SetSubclass | StringSubclass | SymbolSubclass | TrueClassSubclass; + pub const StringExact: u64 = 1u64 << 43; + pub const StringSubclass: u64 = 1u64 << 44; + pub const Subclass: u64 = ArraySubclass | BasicObjectSubclass | ClassSubclass | FalseClassSubclass | FloatSubclass | HashSubclass | IntegerSubclass | ModuleSubclass | NilClassSubclass | ObjectSubclass | RangeSubclass | RegexpSubclass | SetSubclass | StringSubclass | SymbolSubclass | TrueClassSubclass; pub const Symbol: u64 = SymbolExact | SymbolSubclass; pub const SymbolExact: u64 = DynamicSymbol | StaticSymbol; - pub const SymbolSubclass: u64 = 1u64 << 43; + pub const SymbolSubclass: u64 = 1u64 << 45; pub const TrueClass: u64 = TrueClassExact | TrueClassSubclass; - pub const TrueClassExact: u64 = 1u64 << 44; - pub const TrueClassSubclass: u64 = 1u64 << 45; - pub const Undef: u64 = 1u64 << 46; - pub const AllBitPatterns: [(&'static str, u64); 76] = [ + pub const TrueClassExact: u64 = 1u64 << 46; + pub const TrueClassSubclass: u64 = 1u64 << 47; + pub const Undef: u64 = 1u64 << 48; + pub const AllBitPatterns: [(&'static str, u64); 79] = [ ("Any", Any), ("RubyValue", RubyValue), ("Immediate", Immediate), @@ -110,6 +113,9 @@ mod bits { ("NilClass", NilClass), ("NilClassSubclass", NilClassSubclass), ("NilClassExact", NilClassExact), + ("Module", Module), + ("ModuleSubclass", ModuleSubclass), + ("ModuleExact", ModuleExact), ("Integer", Integer), ("IntegerSubclass", IntegerSubclass), ("Float", Float), @@ -154,7 +160,7 @@ mod bits { ("ArrayExact", ArrayExact), ("Empty", Empty), ]; - pub const NumTypeBits: u64 = 47; + pub const NumTypeBits: u64 = 49; } pub mod types { use super::*; @@ -206,6 +212,9 @@ pub mod types { pub const Integer: Type = Type::from_bits(bits::Integer); pub const IntegerExact: Type = Type::from_bits(bits::IntegerExact); pub const IntegerSubclass: Type = Type::from_bits(bits::IntegerSubclass); + pub const Module: Type = Type::from_bits(bits::Module); + pub const ModuleExact: Type = Type::from_bits(bits::ModuleExact); + pub const ModuleSubclass: Type = Type::from_bits(bits::ModuleSubclass); pub const NilClass: Type = Type::from_bits(bits::NilClass); pub const NilClassExact: Type = Type::from_bits(bits::NilClassExact); pub const NilClassSubclass: Type = Type::from_bits(bits::NilClassSubclass); diff --git a/zjit/src/hir_type/mod.rs b/zjit/src/hir_type/mod.rs index 422055e6d0..907582c251 100644 --- a/zjit/src/hir_type/mod.rs +++ b/zjit/src/hir_type/mod.rs @@ -1,10 +1,11 @@ #![allow(non_upper_case_globals)] -use crate::cruby::{Qfalse, Qnil, Qtrue, VALUE, RUBY_T_ARRAY, RUBY_T_STRING, RUBY_T_HASH, RUBY_T_CLASS}; +use crate::cruby::{Qfalse, Qnil, Qtrue, VALUE, RUBY_T_ARRAY, RUBY_T_STRING, RUBY_T_HASH, RUBY_T_CLASS, RUBY_T_MODULE}; use crate::cruby::{rb_cInteger, rb_cFloat, rb_cArray, rb_cHash, rb_cString, rb_cSymbol, rb_cObject, rb_cTrueClass, rb_cFalseClass, rb_cNilClass, rb_cRange, rb_cSet, rb_cRegexp, rb_cClass, rb_cModule}; use crate::cruby::ClassRelationship; use crate::cruby::get_class_name; use crate::cruby::ruby_sym_to_rust_string; use crate::cruby::rb_mRubyVMFrozenCore; +use crate::cruby::rb_obj_class; use crate::hir::PtrPrintMap; #[derive(Copy, Clone, Debug, PartialEq)] @@ -145,9 +146,16 @@ fn is_range_exact(val: VALUE) -> bool { val.class_of() == unsafe { rb_cRange } } -fn is_class_exact(val: VALUE) -> bool { - // Objects with RUBY_T_CLASS type and not instances of Module - val.builtin_type() == RUBY_T_CLASS && val.class_of() != unsafe { rb_cModule } +fn is_module_exact(val: VALUE) -> bool { + if val.builtin_type() != RUBY_T_MODULE { + return false; + } + + // For Class and Module instances, `class_of` will return the singleton class of the object. + // Using `rb_obj_class` will give us the actual class of the module so we can check if the + // object is an instance of Module, or an instance of Module subclass. + let klass = unsafe { rb_obj_class(val) }; + klass == unsafe { rb_cModule } } impl Type { @@ -202,7 +210,10 @@ impl Type { else if is_string_exact(val) { Type { bits: bits::StringExact, spec: Specialization::Object(val) } } - else if is_class_exact(val) { + else if is_module_exact(val) { + Type { bits: bits::ModuleExact, spec: Specialization::Object(val) } + } + else if val.builtin_type() == RUBY_T_CLASS { Type { bits: bits::ClassExact, spec: Specialization::Object(val) } } else if val.class_of() == unsafe { rb_cRegexp } { @@ -301,6 +312,7 @@ impl Type { if class == unsafe { rb_cFloat } { return true; } if class == unsafe { rb_cHash } { return true; } if class == unsafe { rb_cInteger } { return true; } + if class == unsafe { rb_cModule } { return true; } if class == unsafe { rb_cNilClass } { return true; } if class == unsafe { rb_cObject } { return true; } if class == unsafe { rb_cRange } { return true; } @@ -410,6 +422,7 @@ impl Type { if self.is_subtype(types::FloatExact) { return Some(unsafe { rb_cFloat }); } if self.is_subtype(types::HashExact) { return Some(unsafe { rb_cHash }); } if self.is_subtype(types::IntegerExact) { return Some(unsafe { rb_cInteger }); } + if self.is_subtype(types::ModuleExact) { return Some(unsafe { rb_cModule }); } if self.is_subtype(types::NilClassExact) { return Some(unsafe { rb_cNilClass }); } if self.is_subtype(types::ObjectExact) { return Some(unsafe { rb_cObject }); } if self.is_subtype(types::RangeExact) { return Some(unsafe { rb_cRange }); } |
