diff options
| -rw-r--r-- | zjit/src/hir.rs | 41 | ||||
| -rw-r--r-- | zjit/src/hir_type/gen_hir_type.rb | 54 | ||||
| -rw-r--r-- | zjit/src/hir_type/hir_type.inc.rs | 64 | ||||
| -rw-r--r-- | zjit/src/hir_type/mod.rs | 154 |
4 files changed, 171 insertions, 142 deletions
diff --git a/zjit/src/hir.rs b/zjit/src/hir.rs index 6e3fd78e0a..e5286cd5e0 100644 --- a/zjit/src/hir.rs +++ b/zjit/src/hir.rs @@ -829,7 +829,10 @@ impl<'a> std::fmt::Display for InsnPrinter<'a> { Insn::ArrayDup { val, .. } => { write!(f, "ArrayDup {val}") } Insn::HashDup { val, .. } => { write!(f, "HashDup {val}") } Insn::ObjectAlloc { val, .. } => { write!(f, "ObjectAlloc {val}") } - Insn::ObjectAllocClass { class, .. } => { write!(f, "ObjectAllocClass {}", class.print(self.ptr_map)) } + &Insn::ObjectAllocClass { class, .. } => { + let class_name = get_class_name(class); + write!(f, "ObjectAllocClass {class_name}:{}", class.print(self.ptr_map)) + } Insn::StringCopy { val, .. } => { write!(f, "StringCopy {val}") } Insn::StringConcat { strings, .. } => { write!(f, "StringConcat")?; @@ -5637,7 +5640,7 @@ mod tests { v10:ArrayExact = ToArray v2 PatchPoint NoEPEscape(test) GuardBlockParamProxy l0 - v15:BasicObject[BlockParamProxy] = Const Value(VALUE(0x1000)) + v15:HeapObject[BlockParamProxy] = Const Value(VALUE(0x1000)) SideExit UnhandledYARVInsn(splatkw) "); } @@ -6221,16 +6224,16 @@ mod tests { v10:BasicObject = InvokeBuiltin dir_s_open, v0, v1, v2 PatchPoint NoEPEscape(open) GuardBlockParamProxy l0 - v17:BasicObject[BlockParamProxy] = Const Value(VALUE(0x1000)) + v17:HeapObject[BlockParamProxy] = Const Value(VALUE(0x1000)) CheckInterrupts - v20:CBool = Test v17 + v20:CBool[true] = Test v17 IfFalse v20, bb1(v0, v1, v2, v3, v4, v10) PatchPoint NoEPEscape(open) v27:BasicObject = InvokeBlock, v10 v31:BasicObject = InvokeBuiltin dir_s_close, v0, v10 CheckInterrupts Return v27 - bb1(v37:BasicObject, v38:BasicObject, v39:BasicObject, v40:BasicObject, v41:BasicObject, v42:BasicObject): + bb1(v37, v38, v39, v40, v41, v42): PatchPoint NoEPEscape(open) CheckInterrupts Return v42 @@ -8033,7 +8036,7 @@ mod opt_tests { bb0(v0:BasicObject): PatchPoint SingleRactorMode PatchPoint StableConstantNames(0x1000, MY_MODULE) - v13:BasicObject[VALUE(0x1008)] = Const Value(VALUE(0x1008)) + v13:HeapObject[VALUE(0x1008)] = Const Value(VALUE(0x1008)) CheckInterrupts Return v13 "); @@ -8378,7 +8381,7 @@ mod opt_tests { v34:Class[VALUE(0x1008)] = Const Value(VALUE(0x1008)) v6:NilClass = Const Value(nil) PatchPoint MethodRedefined(C@0x1008, new@0x1010, cme:0x1018) - v37:HeapObject[class_exact:C] = ObjectAllocClass VALUE(0x1008) + v37:HeapObject[class_exact:C] = ObjectAllocClass C:VALUE(0x1008) PatchPoint MethodRedefined(C@0x1008, initialize@0x1040, cme:0x1048) v39:NilClass = CCall initialize@0x1070, v37 CheckInterrupts @@ -8407,7 +8410,7 @@ mod opt_tests { v6:NilClass = Const Value(nil) v7:Fixnum[1] = Const Value(1) PatchPoint MethodRedefined(C@0x1008, new@0x1010, cme:0x1018) - v39:HeapObject[class_exact:C] = ObjectAllocClass VALUE(0x1008) + v39:HeapObject[class_exact:C] = ObjectAllocClass C:VALUE(0x1008) PatchPoint MethodRedefined(C@0x1008, initialize@0x1040, cme:0x1048) v41:BasicObject = SendWithoutBlockDirect v39, :initialize (0x1070), v7 CheckInterrupts @@ -8430,7 +8433,7 @@ mod opt_tests { v34:Class[VALUE(0x1008)] = Const Value(VALUE(0x1008)) v6:NilClass = Const Value(nil) PatchPoint MethodRedefined(Object@0x1008, new@0x1010, cme:0x1018) - v37:HeapObject[class_exact:Object] = ObjectAllocClass VALUE(0x1008) + v37:ObjectExact = ObjectAllocClass Object:VALUE(0x1008) PatchPoint MethodRedefined(Object@0x1008, initialize@0x1040, cme:0x1048) v39:NilClass = CCall initialize@0x1070, v37 CheckInterrupts @@ -8453,7 +8456,7 @@ mod opt_tests { v34:Class[VALUE(0x1008)] = Const Value(VALUE(0x1008)) v6:NilClass = Const Value(nil) PatchPoint MethodRedefined(BasicObject@0x1008, new@0x1010, cme:0x1018) - v37:HeapObject[class_exact:BasicObject] = ObjectAllocClass VALUE(0x1008) + v37:BasicObjectExact = ObjectAllocClass BasicObject:VALUE(0x1008) PatchPoint MethodRedefined(BasicObject@0x1008, initialize@0x1040, cme:0x1048) v39:NilClass = CCall initialize@0x1070, v37 CheckInterrupts @@ -8476,7 +8479,7 @@ mod opt_tests { v34:Class[VALUE(0x1008)] = Const Value(VALUE(0x1008)) v6:NilClass = Const Value(nil) PatchPoint MethodRedefined(Hash@0x1008, new@0x1010, cme:0x1018) - v37:HashExact = ObjectAllocClass VALUE(0x1008) + v37:HashExact = ObjectAllocClass Hash:VALUE(0x1008) v12:BasicObject = SendWithoutBlock v37, :initialize CheckInterrupts CheckInterrupts @@ -8523,7 +8526,7 @@ mod opt_tests { PatchPoint MethodRedefined(Set@0x1008, new@0x1010, cme:0x1018) v10:HeapObject = ObjectAlloc v34 PatchPoint MethodRedefined(Set@0x1008, initialize@0x1040, cme:0x1048) - v39:HeapObject[class_exact:Set] = GuardType v10, HeapObject[class_exact:Set] + v39:SetExact = GuardType v10, SetExact v40:BasicObject = CCallVariadic initialize@0x1070, v39 CheckInterrupts CheckInterrupts @@ -8568,7 +8571,7 @@ mod opt_tests { v7:StringExact[VALUE(0x1010)] = Const Value(VALUE(0x1010)) v9:StringExact = StringCopy v7 PatchPoint MethodRedefined(Regexp@0x1008, new@0x1018, cme:0x1020) - v41:HeapObject[class_exact:Regexp] = ObjectAllocClass VALUE(0x1008) + v41:RegexpExact = ObjectAllocClass Regexp:VALUE(0x1008) PatchPoint MethodRedefined(Regexp@0x1008, initialize@0x1048, cme:0x1050) v44:BasicObject = CCallVariadic initialize@0x1078, v41, v9 CheckInterrupts @@ -8618,7 +8621,7 @@ mod opt_tests { fn test@<compiled>:2: bb0(v0:BasicObject, v1:BasicObject): GuardBlockParamProxy l0 - v7:BasicObject[BlockParamProxy] = Const Value(VALUE(0x1000)) + v7:HeapObject[BlockParamProxy] = Const Value(VALUE(0x1000)) v9:BasicObject = Send v0, 0x1008, :tap, v7 CheckInterrupts Return v9 @@ -9632,10 +9635,9 @@ mod opt_tests { bb0(v0:BasicObject): PatchPoint SingleRactorMode PatchPoint StableConstantNames(0x1000, O) - v15:BasicObject[VALUE(0x1008)] = Const Value(VALUE(0x1008)) + v15:HeapObject[VALUE(0x1008)] = Const Value(VALUE(0x1008)) PatchPoint MethodRedefined(C@0x1010, foo@0x1018, cme:0x1020) - v18:HeapObject[VALUE(0x1008)] = GuardType v15, HeapObject - v19:HeapObject[VALUE(0x1008)] = GuardShape v18, 0x1048 + v19:HeapObject[VALUE(0x1008)] = GuardShape v15, 0x1048 v20:NilClass = Const Value(nil) CheckInterrupts Return v20 @@ -9659,10 +9661,9 @@ mod opt_tests { bb0(v0:BasicObject): PatchPoint SingleRactorMode PatchPoint StableConstantNames(0x1000, O) - v15:BasicObject[VALUE(0x1008)] = Const Value(VALUE(0x1008)) + v15:HeapObject[VALUE(0x1008)] = Const Value(VALUE(0x1008)) PatchPoint MethodRedefined(C@0x1010, foo@0x1018, cme:0x1020) - v18:HeapObject[VALUE(0x1008)] = GuardType v15, HeapObject - v19:HeapObject[VALUE(0x1008)] = GuardShape v18, 0x1048 + v19:HeapObject[VALUE(0x1008)] = GuardShape v15, 0x1048 v20:NilClass = Const Value(nil) CheckInterrupts Return v20 diff --git a/zjit/src/hir_type/gen_hir_type.rb b/zjit/src/hir_type/gen_hir_type.rb index 15aa68a600..1ab6adf2eb 100644 --- a/zjit/src/hir_type/gen_hir_type.rb +++ b/zjit/src/hir_type/gen_hir_type.rb @@ -58,10 +58,19 @@ object_subclass = $object.subtype "ObjectSubclass" $subclass = [basic_object_subclass.name, object_subclass.name] $builtin_exact = [basic_object_exact.name, object_exact.name] +$c_names = { + "ObjectExact" => "rb_cObject", + "BasicObjectExact" => "rb_cBasicObject", +} + # Define a new type that can be subclassed (most of them). -def base_type name +# If c_name is given, mark the rb_cXYZ object as equivalent to this exact type. +def base_type name, c_name: nil type = $object.subtype name exact = type.subtype(name+"Exact") + if c_name + $c_names[exact.name] = c_name + end subclass = type.subtype(name+"Subclass") $builtin_exact << exact.name $subclass << subclass.name @@ -69,39 +78,45 @@ def base_type name end # Define a new type that cannot be subclassed. -def final_type name - type = $object.subtype name +# If c_name is given, mark the rb_cXYZ object as equivalent to this type. +def final_type name, base: $object, c_name: nil + if c_name + $c_names[name] = c_name + end + type = base.subtype name $builtin_exact << type.name type end -base_type "String" -base_type "Array" -base_type "Hash" -base_type "Range" -base_type "Set" -base_type "Regexp" -module_class, _ = base_type "Module" -module_class.subtype "Class" +base_type "String", c_name: "rb_cString" +base_type "Array", c_name: "rb_cArray" +base_type "Hash", c_name: "rb_cHash" +base_type "Range", c_name: "rb_cRange" +base_type "Set", c_name: "rb_cSet" +base_type "Regexp", c_name: "rb_cRegexp" +module_class, _ = base_type "Module", c_name: "rb_cModule" +class_ = final_type "Class", base: module_class, c_name: "rb_cClass" + +numeric, _ = base_type "Numeric", c_name: "rb_cNumeric" -integer_exact = final_type "Integer" +integer_exact = final_type "Integer", base: numeric, c_name: "rb_cInteger" # CRuby partitions Integer into immediate and non-immediate variants. fixnum = integer_exact.subtype "Fixnum" integer_exact.subtype "Bignum" -float_exact = final_type "Float" +float_exact = final_type "Float", base: numeric, c_name: "rb_cFloat" # CRuby partitions Float into immediate and non-immediate variants. flonum = float_exact.subtype "Flonum" float_exact.subtype "HeapFloat" -symbol_exact = final_type "Symbol" +symbol_exact = final_type "Symbol", c_name: "rb_cSymbol" # CRuby partitions Symbol into immediate and non-immediate variants. static_sym = symbol_exact.subtype "StaticSymbol" symbol_exact.subtype "DynamicSymbol" -nil_exact = final_type "NilClass" -true_exact = final_type "TrueClass" -false_exact = final_type "FalseClass" +nil_exact = final_type "NilClass", c_name: "rb_cNilClass" +true_exact = final_type "TrueClass", c_name: "rb_cTrueClass" +false_exact = final_type "FalseClass", c_name: "rb_cFalseClass" # Build the cvalue object universe. This is for C-level types that may be # passed around when calling into the Ruby VM or after some strength reduction @@ -183,4 +198,9 @@ puts "pub mod types { $bits.keys.sort.map {|type_name| puts " pub const #{type_name}: Type = Type::from_bits(bits::#{type_name});" } +puts " pub const ExactBitsAndClass: [(u64, *const VALUE); #{$c_names.size}] = [" +$c_names.each {|type_name, c_name| + puts " (bits::#{type_name}, &raw const crate::cruby::#{c_name})," +} +puts " ];" puts "}" diff --git a/zjit/src/hir_type/hir_type.inc.rs b/zjit/src/hir_type/hir_type.inc.rs index 2e03fdac96..c392735742 100644 --- a/zjit/src/hir_type/hir_type.inc.rs +++ b/zjit/src/hir_type/hir_type.inc.rs @@ -9,7 +9,7 @@ mod bits { pub const BasicObjectSubclass: u64 = 1u64 << 3; pub const Bignum: u64 = 1u64 << 4; pub const BoolExact: u64 = FalseClass | TrueClass; - pub const BuiltinExact: u64 = ArrayExact | BasicObjectExact | FalseClass | Float | HashExact | Integer | ModuleExact | NilClass | ObjectExact | RangeExact | RegexpExact | SetExact | StringExact | Symbol | TrueClass; + pub const BuiltinExact: u64 = ArrayExact | BasicObjectExact | Class | FalseClass | Float | HashExact | Integer | ModuleExact | NilClass | NumericExact | ObjectExact | RangeExact | RegexpExact | SetExact | StringExact | Symbol | TrueClass; pub const CBool: u64 = 1u64 << 5; pub const CDouble: u64 = 1u64 << 6; pub const CInt: u64 = CSigned | CUnsigned; @@ -45,28 +45,31 @@ mod bits { pub const ModuleExact: u64 = 1u64 << 26; pub const ModuleSubclass: u64 = 1u64 << 27; pub const NilClass: u64 = 1u64 << 28; - pub const Object: u64 = Array | FalseClass | Float | Hash | Integer | Module | NilClass | ObjectExact | ObjectSubclass | Range | Regexp | Set | String | Symbol | TrueClass; - pub const ObjectExact: u64 = 1u64 << 29; - pub const ObjectSubclass: u64 = 1u64 << 30; + pub const Numeric: u64 = Float | Integer | NumericExact | NumericSubclass; + pub const NumericExact: u64 = 1u64 << 29; + pub const NumericSubclass: u64 = 1u64 << 30; + pub const Object: u64 = Array | FalseClass | Hash | Module | NilClass | Numeric | ObjectExact | ObjectSubclass | Range | Regexp | Set | String | Symbol | TrueClass; + pub const ObjectExact: u64 = 1u64 << 31; + pub const ObjectSubclass: u64 = 1u64 << 32; pub const Range: u64 = RangeExact | RangeSubclass; - pub const RangeExact: u64 = 1u64 << 31; - pub const RangeSubclass: u64 = 1u64 << 32; + pub const RangeExact: u64 = 1u64 << 33; + pub const RangeSubclass: u64 = 1u64 << 34; pub const Regexp: u64 = RegexpExact | RegexpSubclass; - pub const RegexpExact: u64 = 1u64 << 33; - pub const RegexpSubclass: u64 = 1u64 << 34; + pub const RegexpExact: u64 = 1u64 << 35; + pub const RegexpSubclass: u64 = 1u64 << 36; pub const RubyValue: u64 = BasicObject | CallableMethodEntry | Undef; pub const Set: u64 = SetExact | SetSubclass; - pub const SetExact: u64 = 1u64 << 35; - pub const SetSubclass: u64 = 1u64 << 36; - pub const StaticSymbol: u64 = 1u64 << 37; + pub const SetExact: u64 = 1u64 << 37; + pub const SetSubclass: u64 = 1u64 << 38; + pub const StaticSymbol: u64 = 1u64 << 39; pub const String: u64 = StringExact | StringSubclass; - pub const StringExact: u64 = 1u64 << 38; - pub const StringSubclass: u64 = 1u64 << 39; - pub const Subclass: u64 = ArraySubclass | BasicObjectSubclass | HashSubclass | ModuleSubclass | ObjectSubclass | RangeSubclass | RegexpSubclass | SetSubclass | StringSubclass; + pub const StringExact: u64 = 1u64 << 40; + pub const StringSubclass: u64 = 1u64 << 41; + pub const Subclass: u64 = ArraySubclass | BasicObjectSubclass | HashSubclass | ModuleSubclass | NumericSubclass | ObjectSubclass | RangeSubclass | RegexpSubclass | SetSubclass | StringSubclass; pub const Symbol: u64 = DynamicSymbol | StaticSymbol; - pub const TrueClass: u64 = 1u64 << 40; - pub const Undef: u64 = 1u64 << 41; - pub const AllBitPatterns: [(&str, u64); 66] = [ + pub const TrueClass: u64 = 1u64 << 42; + pub const Undef: u64 = 1u64 << 43; + pub const AllBitPatterns: [(&'static str, u64); 69] = [ ("Any", Any), ("RubyValue", RubyValue), ("Immediate", Immediate), @@ -94,6 +97,9 @@ mod bits { ("RangeExact", RangeExact), ("ObjectSubclass", ObjectSubclass), ("ObjectExact", ObjectExact), + ("Numeric", Numeric), + ("NumericSubclass", NumericSubclass), + ("NumericExact", NumericExact), ("NilClass", NilClass), ("Module", Module), ("ModuleSubclass", ModuleSubclass), @@ -134,7 +140,7 @@ mod bits { ("ArrayExact", ArrayExact), ("Empty", Empty), ]; - pub const NumTypeBits: u64 = 42; + pub const NumTypeBits: u64 = 44; } pub mod types { use super::*; @@ -183,6 +189,9 @@ pub mod types { pub const ModuleExact: Type = Type::from_bits(bits::ModuleExact); pub const ModuleSubclass: Type = Type::from_bits(bits::ModuleSubclass); pub const NilClass: Type = Type::from_bits(bits::NilClass); + pub const Numeric: Type = Type::from_bits(bits::Numeric); + pub const NumericExact: Type = Type::from_bits(bits::NumericExact); + pub const NumericSubclass: Type = Type::from_bits(bits::NumericSubclass); pub const Object: Type = Type::from_bits(bits::Object); pub const ObjectExact: Type = Type::from_bits(bits::ObjectExact); pub const ObjectSubclass: Type = Type::from_bits(bits::ObjectSubclass); @@ -204,4 +213,23 @@ pub mod types { pub const Symbol: Type = Type::from_bits(bits::Symbol); pub const TrueClass: Type = Type::from_bits(bits::TrueClass); pub const Undef: Type = Type::from_bits(bits::Undef); + pub const ExactBitsAndClass: [(u64, *const VALUE); 17] = [ + (bits::ObjectExact, &raw const crate::cruby::rb_cObject), + (bits::BasicObjectExact, &raw const crate::cruby::rb_cBasicObject), + (bits::StringExact, &raw const crate::cruby::rb_cString), + (bits::ArrayExact, &raw const crate::cruby::rb_cArray), + (bits::HashExact, &raw const crate::cruby::rb_cHash), + (bits::RangeExact, &raw const crate::cruby::rb_cRange), + (bits::SetExact, &raw const crate::cruby::rb_cSet), + (bits::RegexpExact, &raw const crate::cruby::rb_cRegexp), + (bits::ModuleExact, &raw const crate::cruby::rb_cModule), + (bits::Class, &raw const crate::cruby::rb_cClass), + (bits::NumericExact, &raw const crate::cruby::rb_cNumeric), + (bits::Integer, &raw const crate::cruby::rb_cInteger), + (bits::Float, &raw const crate::cruby::rb_cFloat), + (bits::Symbol, &raw const crate::cruby::rb_cSymbol), + (bits::NilClass, &raw const crate::cruby::rb_cNilClass), + (bits::TrueClass, &raw const crate::cruby::rb_cTrueClass), + (bits::FalseClass, &raw const crate::cruby::rb_cFalseClass), + ]; } diff --git a/zjit/src/hir_type/mod.rs b/zjit/src/hir_type/mod.rs index f2fb870257..5478653a23 100644 --- a/zjit/src/hir_type/mod.rs +++ b/zjit/src/hir_type/mod.rs @@ -2,7 +2,7 @@ #![allow(non_upper_case_globals)] use crate::cruby::{rb_block_param_proxy, Qfalse, Qnil, Qtrue, RUBY_T_ARRAY, RUBY_T_CLASS, RUBY_T_HASH, RUBY_T_MODULE, RUBY_T_STRING, VALUE}; -use crate::cruby::{rb_cInteger, rb_cFloat, rb_cArray, rb_cHash, rb_cString, rb_cSymbol, rb_cObject, rb_cTrueClass, rb_cFalseClass, rb_cNilClass, rb_cRange, rb_cSet, rb_cRegexp, rb_cClass, rb_cModule, rb_zjit_singleton_class_p}; +use crate::cruby::{rb_cInteger, rb_cFloat, rb_cArray, rb_cHash, rb_cString, rb_cSymbol, rb_cRange, rb_cModule, rb_zjit_singleton_class_p}; use crate::cruby::ClassRelationship; use crate::cruby::get_class_name; use crate::cruby::ruby_sym_to_rust_string; @@ -177,10 +177,43 @@ impl Type { } } + fn bits_from_exact_class(class: VALUE) -> Option<u64> { + types::ExactBitsAndClass + .iter() + .find(|&(_, class_object)| unsafe { **class_object } == class) + .map(|&(bits, _)| bits) + } + + fn from_heap_object(val: VALUE) -> Type { + assert!(!val.special_const_p(), "val should be a heap object"); + let bits = + // GC-hidden types + if is_array_exact(val) { bits::ArrayExact } + else if is_hash_exact(val) { bits::HashExact } + else if is_string_exact(val) { bits::StringExact } + // Singleton classes + else if is_module_exact(val) { bits::ModuleExact } + else if val.builtin_type() == RUBY_T_CLASS { bits::Class } + // Classes that have an immediate/heap split + else if val.class_of() == unsafe { rb_cInteger } { bits::Bignum } + else if val.class_of() == unsafe { rb_cFloat } { bits::HeapFloat } + else if val.class_of() == unsafe { rb_cSymbol } { bits::DynamicSymbol } + else { + Self::bits_from_exact_class(val.class_of()).unwrap_or({ + // We don't have a specific built-in bit pattern for this class, so generalize + // as HeapObject with object specialization. + bits::HeapObject + }) + }; + let spec = Specialization::Object(val); + Type { bits, spec } + } + /// Create a `Type` from a Ruby `VALUE`. The type is not guaranteed to have object /// specialization in its `specialization` field (for example, `Qnil` will just be /// `types::NilClass`), but will be available via `ruby_object()`. pub fn from_value(val: VALUE) -> Type { + // Immediates if val.fixnum_p() { Type { bits: bits::Fixnum, spec: Specialization::Object(val) } } @@ -199,45 +232,8 @@ impl Type { // valid on imemo. Type { bits: bits::CallableMethodEntry, spec: Specialization::Object(val) } } - else if val.class_of() == unsafe { rb_cInteger } { - Type { bits: bits::Bignum, spec: Specialization::Object(val) } - } - else if val.class_of() == unsafe { rb_cFloat } { - Type { bits: bits::HeapFloat, spec: Specialization::Object(val) } - } - else if val.class_of() == unsafe { rb_cSymbol } { - Type { bits: bits::DynamicSymbol, spec: Specialization::Object(val) } - } - else if is_array_exact(val) { - Type { bits: bits::ArrayExact, spec: Specialization::Object(val) } - } - else if is_hash_exact(val) { - Type { bits: bits::HashExact, spec: Specialization::Object(val) } - } - else if is_range_exact(val) { - Type { bits: bits::RangeExact, spec: Specialization::Object(val) } - } - else if is_string_exact(val) { - Type { bits: bits::StringExact, spec: Specialization::Object(val) } - } - else if is_module_exact(val) { - Type { bits: bits::ModuleExact, spec: Specialization::Object(val) } - } - else if val.builtin_type() == RUBY_T_CLASS { - Type { bits: bits::Class, spec: Specialization::Object(val) } - } - else if val.class_of() == unsafe { rb_cRegexp } { - Type { bits: bits::RegexpExact, spec: Specialization::Object(val) } - } - else if val.class_of() == unsafe { rb_cSet } { - Type { bits: bits::SetExact, spec: Specialization::Object(val) } - } - else if val.class_of() == unsafe { rb_cObject } { - Type { bits: bits::ObjectExact, spec: Specialization::Object(val) } - } else { - // TODO(max): Add more cases for inferring type bits from built-in types - Type { bits: bits::BasicObject, spec: Specialization::Object(val) } + Self::from_heap_object(val) } } @@ -248,26 +244,17 @@ impl Type { else if val.is_nil() { types::NilClass } else if val.is_true() { types::TrueClass } else if val.is_false() { types::FalseClass } - else if val.class() == unsafe { rb_cString } { types::StringExact } - else if val.class() == unsafe { rb_cArray } { types::ArrayExact } - else if val.class() == unsafe { rb_cHash } { types::HashExact } - else { - // TODO(max): Add more cases for inferring type bits from built-in types - Type { bits: bits::HeapObject, spec: Specialization::TypeExact(val.class()) } - } + else { Self::from_class(val.class()) } } pub fn from_class(class: VALUE) -> Type { - if class == unsafe { rb_cArray } { types::ArrayExact } - else if class == unsafe { rb_cFalseClass } { types::FalseClass } - else if class == unsafe { rb_cHash } { types::HashExact } - else if class == unsafe { rb_cInteger } { types::Integer} - else if class == unsafe { rb_cNilClass } { types::NilClass } - else if class == unsafe { rb_cString } { types::StringExact } - else if class == unsafe { rb_cTrueClass } { types::TrueClass } - else { - // TODO(max): Add more cases for inferring type bits from built-in types - Type { bits: bits::HeapObject, spec: Specialization::TypeExact(class) } + match Self::bits_from_exact_class(class) { + Some(bits) => Type::from_bits(bits), + None => { + // We don't have a specific built-in bit pattern for this class, so generalize + // as HeapObject with object specialization. + Type { bits: bits::HeapObject, spec: Specialization::TypeExact(class) } + } } } @@ -361,21 +348,10 @@ impl Type { } fn is_builtin(class: VALUE) -> bool { - if class == unsafe { rb_cArray } { return true; } - if class == unsafe { rb_cClass } { return true; } - if class == unsafe { rb_cFalseClass } { return true; } - if class == unsafe { rb_cFloat } { return true; } - if class == unsafe { rb_cHash } { return true; } - if class == unsafe { rb_cInteger } { return true; } - if class == unsafe { rb_cModule } { return true; } - if class == unsafe { rb_cNilClass } { return true; } - if class == unsafe { rb_cObject } { return true; } - if class == unsafe { rb_cRange } { return true; } - if class == unsafe { rb_cRegexp } { return true; } - if class == unsafe { rb_cString } { return true; } - if class == unsafe { rb_cSymbol } { return true; } - if class == unsafe { rb_cTrueClass } { return true; } - false + types::ExactBitsAndClass + .iter() + .find(|&(_, class_object)| unsafe { **class_object } == class) + .is_some() } /// Union both types together, preserving specialization if possible. @@ -471,22 +447,10 @@ impl Type { if let Some(val) = self.exact_ruby_class() { return Some(val); } - if self.is_subtype(types::ArrayExact) { return Some(unsafe { rb_cArray }); } - if self.is_subtype(types::Class) { return Some(unsafe { rb_cClass }); } - if self.is_subtype(types::FalseClass) { return Some(unsafe { rb_cFalseClass }); } - if self.is_subtype(types::Float) { return Some(unsafe { rb_cFloat }); } - if self.is_subtype(types::HashExact) { return Some(unsafe { rb_cHash }); } - if self.is_subtype(types::Integer) { return Some(unsafe { rb_cInteger }); } - if self.is_subtype(types::ModuleExact) { return Some(unsafe { rb_cModule }); } - if self.is_subtype(types::NilClass) { return Some(unsafe { rb_cNilClass }); } - if self.is_subtype(types::ObjectExact) { return Some(unsafe { rb_cObject }); } - if self.is_subtype(types::RangeExact) { return Some(unsafe { rb_cRange }); } - if self.is_subtype(types::RegexpExact) { return Some(unsafe { rb_cRegexp }); } - if self.is_subtype(types::SetExact) { return Some(unsafe { rb_cSet }); } - if self.is_subtype(types::StringExact) { return Some(unsafe { rb_cString }); } - if self.is_subtype(types::Symbol) { return Some(unsafe { rb_cSymbol }); } - if self.is_subtype(types::TrueClass) { return Some(unsafe { rb_cTrueClass }); } - None + types::ExactBitsAndClass + .iter() + .find(|&(bits, _)| self.is_subtype(Type::from_bits(*bits))) + .map(|&(_, class_object)| unsafe { *class_object }) } /// Check bit equality of two `Type`s. Do not use! You are probably looking for [`Type::is_subtype`]. @@ -534,6 +498,11 @@ mod tests { use crate::cruby::rb_hash_new; use crate::cruby::rb_float_new; use crate::cruby::define_class; + use crate::cruby::rb_cObject; + use crate::cruby::rb_cSet; + use crate::cruby::rb_cTrueClass; + use crate::cruby::rb_cFalseClass; + use crate::cruby::rb_cNilClass; #[track_caller] fn assert_bit_equal(left: Type, right: Type) { @@ -592,6 +561,17 @@ mod tests { } #[test] + fn numeric() { + assert_subtype(types::Integer, types::Numeric); + assert_subtype(types::Float, types::Numeric); + assert_subtype(types::Float.union(types::Integer), types::Numeric); + assert_bit_equal(types::Float + .union(types::Integer) + .union(types::NumericExact) + .union(types::NumericSubclass), types::Numeric); + } + + #[test] fn symbol() { assert_subtype(types::StaticSymbol, types::Symbol); assert_subtype(types::DynamicSymbol, types::Symbol); |
