diff options
| author | Max Bernstein <max@bernsteinbear.com> | 2025-03-10 15:53:42 -0400 |
|---|---|---|
| committer | Takashi Kokubun <takashikkbn@gmail.com> | 2025-04-18 21:52:59 +0900 |
| commit | 42cef565a405287c2b62501eafb0786d1f72e08f (patch) | |
| tree | 2f4cecbd325dfae3f3c0c95775b18c061f46ebbe /zjit | |
| parent | 43d532e36db81a68164f0a807c34c57b6f4a363f (diff) | |
Add intraprocedural flow typing (https://github.com/Shopify/zjit/pull/23)
* Add RPO
* Add basic flow typing
* Add more tests, check for type bit equality
* Run to fixpoint
* Only use/flow types if insn has an output
* WIP
* WIP 2: merge pred args
* It compiles again
* Infer more Const instructions
* Boot VM
* Test displaying types
* .
* Use type_of more
* Extract Param inference into its own function for readability
* .
* .
* Add notion of unions to generated type bit patterns
* .
* .
* Fix hierarchy for user/exact
* .
* .
* .
* Give ArraySet a receiver
* Use Function::find consistently
* s/fd/find/g
* Comment
* .
* Add TODO about recursion
* FrameStateId
* Use worklist based type inference
This requires computing "uses", or at least which blocks to revisit if
their params change.
* Just use a set
* Revert "Just use a set"
This reverts commit 54d88be00cbf78ce7e928c66d955c968187a5ec9.
* Revert "Use worklist based type inference"
This reverts commit e99b24629723c8848fefd5a75caa23e84c2f552e.
* .
* Store block params separately
* Sparse type inference
* Get tests passing after rebase
* .
* Use assert_method_hir
* .
Notes
Notes:
Merged: https://github.com/ruby/ruby/pull/13131
Diffstat (limited to 'zjit')
| -rw-r--r-- | zjit/src/gen_hir_type.rb | 56 | ||||
| -rw-r--r-- | zjit/src/hir.rs | 604 | ||||
| -rw-r--r-- | zjit/src/hir_type.inc.rs | 31 | ||||
| -rw-r--r-- | zjit/src/hir_type.rs | 24 |
4 files changed, 612 insertions, 103 deletions
diff --git a/zjit/src/gen_hir_type.rb b/zjit/src/gen_hir_type.rb index 42f25f8fc8..7fc29fe8f8 100644 --- a/zjit/src/gen_hir_type.rb +++ b/zjit/src/gen_hir_type.rb @@ -46,19 +46,19 @@ end # Start at Any. All types are subtypes of Any. any = Type.new "Any" # Build the Ruby object universe. -object = any.subtype "Object" -object.subtype "ObjectExact" -$object_user = object.subtype "ObjectUser" -$user = any.subtype "User" -$builtin_exact = object.subtype "BuiltinExact" +$object = any.subtype "Object" +object_exact = $object.subtype "ObjectExact" +object_user = $object.subtype "ObjectUser" +$user = [object_user.name] +$builtin_exact = [object_exact.name] # Define a new type that can be subclassed (most of them). def base_type name - type = $object_user.subtype name + type = $object.subtype name exact = type.subtype(name+"Exact") user = type.subtype(name+"User") - $builtin_exact.subtypes << exact - $user.subtypes << user + $builtin_exact << exact.name + $user << user.name [type, exact] end @@ -82,8 +82,8 @@ symbol_exact.subtype "StaticSymbol" symbol_exact.subtype "DynamicSymbol" base_type "NilClass" -base_type "TrueClass" -base_type "FalseClass" +_, true_exact = base_type "TrueClass" +_, false_exact = base_type "FalseClass" # Build the primitive object universe. primitive = any.subtype "Primitive" @@ -101,40 +101,54 @@ unsigned = primitive_int.subtype "CUnsigned" # Assign individual bits to type leaves and union bit patterns to nodes with subtypes num_bits = 0 -bits = {"Empty" => ["0u64"]} -numeric_bits = {"Empty" => 0} +$bits = {"Empty" => ["0u64"]} +$numeric_bits = {"Empty" => 0} Set[any, *any.all_subtypes].sort_by(&:name).each {|type| subtypes = type.subtypes if subtypes.empty? # Assign bits for leaves - bits[type.name] = ["1u64 << #{num_bits}"] - numeric_bits[type.name] = 1 << num_bits + $bits[type.name] = ["1u64 << #{num_bits}"] + $numeric_bits[type.name] = 1 << num_bits num_bits += 1 else # Assign bits for unions - bits[type.name] = subtypes.map(&:name).sort + $bits[type.name] = subtypes.map(&:name).sort end } [*any.all_subtypes, any].each {|type| subtypes = type.subtypes unless subtypes.empty? - numeric_bits[type.name] = subtypes.map {|ty| numeric_bits[ty.name]}.reduce(&:|) + $numeric_bits[type.name] = subtypes.map {|ty| $numeric_bits[ty.name]}.reduce(&:|) end } +# Unions are for names of groups of type bit patterns that don't fit neatly +# into the Ruby class hierarchy. For example, we might want to refer to a union +# of TrueClassExact|FalseClassExact by the name BoolExact even though a "bool" +# doesn't exist as a class in Ruby. +def add_union name, type_names + type_names = type_names.sort + $bits[name] = type_names + $numeric_bits[name] = type_names.map {|type_name| $numeric_bits[type_name]}.reduce(&:|) +end + +add_union "BuiltinExact", $builtin_exact +add_union "User", $user +add_union "BoolExact", [true_exact.name, false_exact.name] + # ===== Finished generating the DAG; write Rust code ===== puts "// This file is @generated by src/gen_hir_type.rb." puts "mod bits {" -bits.keys.sort.map {|type_name| - subtypes = bits[type_name].join(" | ") +$bits.keys.sort.map {|type_name| + subtypes = $bits[type_name].join(" | ") puts " pub const #{type_name}: u64 = #{subtypes};" } -puts " pub const AllBitPatterns: [(&'static str, u64); #{bits.size}] = [" +puts " pub const AllBitPatterns: [(&'static str, u64); #{$bits.size}] = [" # Sort the bit patterns by decreasing value so that we can print the densest # possible to-string representation of a Type. For example, CSigned instead of # CInt8|CInt16|... -numeric_bits.sort_by {|key, val| -val}.each {|type_name, _| +$numeric_bits.sort_by {|key, val| -val}.each {|type_name, _| puts " (\"#{type_name}\", #{type_name})," } puts " ];" @@ -143,7 +157,7 @@ puts " pub const NumTypeBits: u64 = #{num_bits}; puts "pub mod types { use super::*;" -bits.keys.sort.map {|type_name| +$bits.keys.sort.map {|type_name| puts " pub const #{type_name}: Type = Type::from_bits(bits::#{type_name});" } puts "}" diff --git a/zjit/src/hir.rs b/zjit/src/hir.rs index f86aac169d..968fe1907e 100644 --- a/zjit/src/hir.rs +++ b/zjit/src/hir.rs @@ -5,10 +5,9 @@ use crate::{ cruby::*, options::get_option, hir_type::types::Fixnum, options::DumpHIR, profile::get_or_create_iseq_payload }; use std::collections::{HashMap, HashSet}; +use crate::hir_type::{Type, types}; -use crate::hir_type::Type; - -#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)] +#[derive(Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Hash, Debug)] pub struct InsnId(pub usize); impl Into<usize> for InsnId { @@ -122,6 +121,7 @@ impl std::fmt::Display for Invariant { #[derive(Debug, Clone, PartialEq)] pub enum Const { Value(VALUE), + CBool(bool), CInt8(i8), CInt16(i16), CInt32(i32), @@ -154,7 +154,7 @@ pub enum Insn { StringIntern { val: InsnId }, NewArray { count: usize }, - ArraySet { idx: usize, val: InsnId }, + ArraySet { array: InsnId, idx: usize, val: InsnId }, ArrayDup { val: InsnId }, // Check if the value is truthy and "return" a C boolean. In reality, we will likely fuse this @@ -213,6 +213,18 @@ pub enum Insn { PatchPoint(Invariant), } +impl Insn { + /// Not every instruction returns a value. Return true if the instruction does and false otherwise. + pub fn has_output(&self) -> bool { + match self { + Insn::ArraySet { .. } | Insn::Snapshot { .. } | Insn::Jump(_) + | Insn::IfTrue { .. } | Insn::IfFalse { .. } | Insn::Return { .. } + | Insn::PatchPoint { .. } => false, + _ => true, + } + } +} + #[derive(Default, Debug)] pub struct Block { params: Vec<InsnId>, @@ -336,6 +348,7 @@ pub struct Function { pub insns: Vec<Insn>, union_find: UnionFind<InsnId>, + insn_types: Vec<Type>, blocks: Vec<Block>, entry_block: BlockId, frame_states: Vec<FrameState>, @@ -346,6 +359,7 @@ impl Function { Function { iseq, insns: vec![], + insn_types: vec![], union_find: UnionFind::new(), blocks: vec![Block::default()], entry_block: BlockId(0), @@ -362,6 +376,7 @@ impl Function { self.blocks[block.0].insns.push(id); } self.insns.push(insn); + self.insn_types.push(types::Empty); id } @@ -384,6 +399,10 @@ impl Function { id } + /// Return a copy of the instruction where the instruction and its operands have been read from + /// the union-find table (to find the current most-optimized version of this instruction). See + /// [`UnionFind`] for more. + /// /// Use for pattern matching over instructions in a union-find-safe way. For example: /// ```rust /// match func.find(insn_id) { @@ -393,15 +412,42 @@ impl Function { /// _ => {} /// } /// ``` - fn find(&mut self, insn_id: InsnId) -> Insn { - let insn_id = self.union_find.find(insn_id); + fn find(&self, insn_id: InsnId) -> Insn { + macro_rules! find { + ( $x:expr ) => { + { + self.union_find.find_const($x) + } + }; + } + let insn_id = self.union_find.find_const(insn_id); use Insn::*; match &self.insns[insn_id.0] { - result@(PutSelf | Const {..} | Param {..} | NewArray {..} | GetConstantPath {..}) => result.clone(), - StringCopy { val } => StringCopy { val: self.union_find.find(*val) }, - StringIntern { val } => StringIntern { val: self.union_find.find(*val) }, - Test { val } => Test { val: self.union_find.find(*val) }, - insn => todo!("find({insn:?})"), + result@(PutSelf | Const {..} | Param {..} | NewArray {..} | GetConstantPath {..} | Snapshot {..} + | Jump(_) | PatchPoint {..}) => result.clone(), + Return { val } => Return { val: find!(*val) }, + StringCopy { val } => StringCopy { val: find!(*val) }, + StringIntern { val } => StringIntern { val: find!(*val) }, + Test { val } => Test { val: find!(*val) }, + IfTrue { val, target } => IfTrue { val: find!(*val), target: target.clone() }, + IfFalse { val, target } => IfFalse { val: find!(*val), target: target.clone() }, + GuardType { val, guard_type, state } => GuardType { val: find!(*val), guard_type: *guard_type, state: *state }, + FixnumAdd { left, right, state } => FixnumAdd { left: find!(*left), right: find!(*right), state: *state }, + FixnumSub { left, right, state } => FixnumSub { left: find!(*left), right: find!(*right), state: *state }, + FixnumMult { left, right, state } => FixnumMult { left: find!(*left), right: find!(*right), state: *state }, + FixnumDiv { left, right, state } => FixnumDiv { left: find!(*left), right: find!(*right), state: *state }, + FixnumMod { left, right, state } => FixnumMod { left: find!(*left), right: find!(*right), state: *state }, + FixnumNeq { left, right, state } => FixnumNeq { left: find!(*left), right: find!(*right), state: *state }, + FixnumEq { left, right, state } => FixnumEq { left: find!(*left), right: find!(*right), state: *state }, + FixnumGt { left, right, state } => FixnumGt { left: find!(*left), right: find!(*right), state: *state }, + FixnumGe { left, right, state } => FixnumGe { left: find!(*left), right: find!(*right), state: *state }, + FixnumLt { left, right, state } => FixnumLt { left: find!(*left), right: find!(*right), state: *state }, + FixnumLe { left, right, state } => FixnumLe { left: find!(*left), right: find!(*right), state: *state }, + Send { self_val, call_info, args } => Send { self_val: find!(*self_val), call_info: call_info.clone(), args: args.iter().map(|arg| find!(*arg)).collect() }, + ArraySet { array, idx, val } => ArraySet { array: find!(*array), idx: *idx, val: find!(*val) }, + ArrayDup { val } => ArrayDup { val: find!(*val) }, + CCall { cfun, args } => CCall { cfun: *cfun, args: args.iter().map(|arg| find!(*arg)).collect() }, + Defined { .. } => todo!("find(Defined)"), } } @@ -410,6 +456,159 @@ impl Function { let new_insn = self.push_insn(block, replacement); self.union_find.make_equal_to(insn, new_insn); } + + fn type_of(&self, insn: InsnId) -> Type { + assert!(self.insns[insn.0].has_output()); + self.insn_types[insn.0] + } + + fn infer_type(&self, insn: InsnId) -> Type { + assert!(self.insns[insn.0].has_output()); + match &self.insns[insn.0] { + Insn::Param { .. } => unimplemented!("use infer_param_type instead"), + Insn::ArraySet { .. } | Insn::Snapshot { .. } | Insn::Jump(_) + | Insn::IfTrue { .. } | Insn::IfFalse { .. } | Insn::Return { .. } + | Insn::PatchPoint { .. } => + panic!("Cannot infer type of instruction with no output"), + Insn::Const { val: Const::Value(val) } => Type::from_value(*val), + Insn::Const { val: Const::CBool(val) } => Type::from_cbool(*val), + Insn::Const { val: Const::CInt8(val) } => Type::from_cint(types::CInt8, *val as i64), + Insn::Const { val: Const::CInt16(val) } => Type::from_cint(types::CInt16, *val as i64), + Insn::Const { val: Const::CInt32(val) } => Type::from_cint(types::CInt32, *val as i64), + Insn::Const { val: Const::CInt64(val) } => Type::from_cint(types::CInt64, *val as i64), + Insn::Const { val: Const::CUInt8(val) } => Type::from_cint(types::CUInt8, *val as i64), + Insn::Const { val: Const::CUInt16(val) } => Type::from_cint(types::CUInt16, *val as i64), + Insn::Const { val: Const::CUInt32(val) } => Type::from_cint(types::CUInt32, *val as i64), + Insn::Const { val: Const::CUInt64(val) } => Type::from_cint(types::CUInt64, *val as i64), + Insn::Const { val: Const::CPtr(val) } => Type::from_cint(types::CPtr, *val as i64), + Insn::Const { val: Const::CDouble(val) } => Type::from_double(*val), + Insn::Test { val } if self.type_of(*val).is_subtype(types::NilClassExact) || self.type_of(*val).is_subtype(types::FalseClassExact) => Type::from_cbool(false), + Insn::Test { val } if !self.type_of(*val).could_be(types::NilClassExact) && !self.type_of(*val).could_be(types::FalseClassExact) => Type::from_cbool(true), + Insn::Test { .. } => types::CBool, + Insn::StringCopy { .. } => types::StringExact, + Insn::StringIntern { .. } => types::StringExact, + Insn::NewArray { .. } => types::ArrayExact, + Insn::ArrayDup { .. } => types::ArrayExact, + Insn::CCall { .. } => types::Any, + Insn::GuardType { val, guard_type, .. } => self.type_of(*val).intersection(*guard_type), + Insn::FixnumAdd { .. } => types::Fixnum, + Insn::FixnumSub { .. } => types::Fixnum, + Insn::FixnumMult { .. } => types::Fixnum, + Insn::FixnumDiv { .. } => types::Fixnum, + Insn::FixnumMod { .. } => types::Fixnum, + Insn::FixnumEq { .. } => types::BoolExact, + Insn::FixnumNeq { .. } => types::BoolExact, + Insn::FixnumLt { .. } => types::BoolExact, + Insn::FixnumLe { .. } => types::BoolExact, + Insn::FixnumGt { .. } => types::BoolExact, + Insn::FixnumGe { .. } => types::BoolExact, + Insn::Send { .. } => types::Object, + Insn::PutSelf => types::Object, + Insn::Defined { .. } => types::Object, + Insn::GetConstantPath { .. } => types::Object, + } + } + + fn infer_types(&mut self) { + // Reset all types + self.insn_types.fill(types::Empty); + // Compute predecessor instructions for each block + let mut preds: Vec<Vec<InsnId>> = vec![Vec::new(); self.blocks.len()]; + let rpo = self.rpo(); + // Walk the graph, computing predecessor blocks + for block in &rpo { + for insn in &self.blocks[block.0].insns { + match self.find(*insn) { + Insn::IfTrue { target, .. } + | Insn::IfFalse { target, .. } + | Insn::Jump(target) => + preds[target.target.0].push(*insn), + _ => {} + } + } + } + for idx in 0..preds.len() { + preds[idx].sort(); + preds[idx].dedup(); + } + // Walk the graph, computing types until fixpoint + let mut reachable = vec![false; self.blocks.len()]; + reachable[self.entry_block.0] = true; + loop { + let mut changed = false; + for block in &rpo { + if !reachable[block.0] { continue; } + for insn_id in &self.blocks[block.0].insns { + let insn = self.find(*insn_id); + let insn_type = match insn { + Insn::IfTrue { val, target: BranchEdge { target, args } } => { + assert!(!self.type_of(val).bit_equal(types::Empty)); + if self.type_of(val).could_be(Type::from_cbool(true)) { + reachable[target.0] = true; + for (idx, arg) in args.iter().enumerate() { + let param = self.blocks[target.0].params[idx]; + self.insn_types[param.0] = self.type_of(param).union(self.type_of(*arg)); + } + } + continue; + } + Insn::IfFalse { val, target: BranchEdge { target, args } } => { + assert!(!self.type_of(val).bit_equal(types::Empty)); + if self.type_of(val).could_be(Type::from_cbool(false)) { + reachable[target.0] = true; + for (idx, arg) in args.iter().enumerate() { + let param = self.blocks[target.0].params[idx]; + self.insn_types[param.0] = self.type_of(param).union(self.type_of(*arg)); + } + } + continue; + } + Insn::Jump(BranchEdge { target, args }) => { + reachable[target.0] = true; + for (idx, arg) in args.iter().enumerate() { + let param = self.blocks[target.0].params[idx]; + self.insn_types[param.0] = self.type_of(param).union(self.type_of(*arg)); + } + continue; + } + _ if insn.has_output() => self.infer_type(*insn_id), + _ => continue, + }; + if !self.type_of(*insn_id).bit_equal(insn_type) { + self.insn_types[insn_id.0] = insn_type; + changed = true; + } + } + } + if !changed { + break; + } + } + } + + /// Return a traversal of the `Function`'s `BlockId`s in reverse post-order. + fn rpo(&self) -> Vec<BlockId> { + let mut result = vec![]; + self.po_from(self.entry_block, &mut result, &mut HashSet::new()); + result.reverse(); + result + } + + fn po_from(&self, block: BlockId, mut result: &mut Vec<BlockId>, mut seen: &mut HashSet<BlockId>) { + // TODO(max): Avoid using recursion for post-order traversal. For graphs, this is slightly + // trickier than it might seem. + if seen.contains(&block) { return; } + seen.insert(block); + for insn in &self.blocks[block.0].insns { + match self.find(*insn) { + Insn::Jump(edge) => self.po_from(edge.target, &mut result, &mut seen), + Insn::IfTrue { target, .. } => self.po_from(target.target, &mut result, &mut seen), + Insn::IfFalse { target, .. } => self.po_from(target.target, &mut result, &mut seen), + _ => {} + } + } + result.push(block); + } } impl<'a> std::fmt::Display for FunctionPrinter<'a> { @@ -422,23 +621,36 @@ impl<'a> std::fmt::Display for FunctionPrinter<'a> { let mut sep = ""; for param in &block.params { write!(f, "{sep}{param}")?; + let insn_type = fun.type_of(*param); + if !insn_type.is_subtype(types::Empty) { + write!(f, ":{insn_type}")?; + } sep = ", "; } } writeln!(f, "):")?; for insn_id in &block.insns { - if !self.display_snapshot && matches!(fun.insns[insn_id.0], Insn::Snapshot {..}) { + let insn = fun.find(*insn_id); + if !self.display_snapshot && matches!(insn, Insn::Snapshot {..}) { continue; } - write!(f, " {insn_id} = ")?; - match &fun.insns[insn_id.0] { + write!(f, " ")?; + if insn.has_output() { + let insn_type = fun.type_of(*insn_id); + if insn_type.is_subtype(types::Empty) { + write!(f, "{insn_id} = ")?; + } else { + write!(f, "{insn_id}:{insn_type} = ")?; + } + } + match insn { Insn::Const { val } => { write!(f, "Const {val}")?; } Insn::Param { idx } => { write!(f, "Param {idx}")?; } Insn::NewArray { count } => { write!(f, "NewArray {count}")?; } - Insn::ArraySet { idx, val } => { write!(f, "ArraySet {idx}, {val}")?; } + Insn::ArraySet { array, idx, val } => { write!(f, "ArraySet {array}, {idx}, {val}")?; } Insn::ArrayDup { val } => { write!(f, "ArrayDup {val}")?; } Insn::Test { val } => { write!(f, "Test {val}")?; } - Insn::Snapshot { state } => { write!(f, "Snapshot {}", fun.frame_state(*state))?; } + Insn::Snapshot { state } => { write!(f, "Snapshot {}", fun.frame_state(state))?; } Insn::Jump(target) => { write!(f, "Jump {target}")?; } Insn::IfTrue { val, target } => { write!(f, "IfTrue {val}, {target}")?; } Insn::IfFalse { val, target } => { write!(f, "IfFalse {val}, {target}")?; } @@ -702,11 +914,11 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> { } YARVINSN_newarray => { let count = get_arg(pc, 0).as_usize(); - let insn_id = fun.push_insn(block, Insn::NewArray { count }); + let array = fun.push_insn(block, Insn::NewArray { count }); for idx in (0..count).rev() { - fun.push_insn(block, Insn::ArraySet { idx, val: state.pop()? }); + fun.push_insn(block, Insn::ArraySet { array, idx, val: state.pop()? }); } - state.push(insn_id); + state.push(array); } YARVINSN_duparray => { let val = fun.push_insn(block, Insn::Const { val: Const::Value(get_arg(pc, 0)) }); @@ -968,6 +1180,8 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> { } } + fun.infer_types(); + match get_option!(dump_hir) { Some(DumpHIR::WithoutSnapshot) => println!("HIR:\n{}", FunctionPrinter::without_snapshot(&fun)), Some(DumpHIR::All) => println!("HIR:\n{}", FunctionPrinter::with_snapshot(&fun)), @@ -1028,6 +1242,188 @@ mod union_find_tests { } #[cfg(test)] +mod rpo_tests { + use super::*; + + #[test] + fn one_block() { + let mut function = Function::new(std::ptr::null()); + let entry = function.entry_block; + let val = function.push_insn(entry, Insn::Const { val: Const::Value(Qnil) }); + function.push_insn(entry, Insn::Return { val }); + assert_eq!(function.rpo(), vec![entry]); + } + + #[test] + fn jump() { + let mut function = Function::new(std::ptr::null()); + let entry = function.entry_block; + let exit = function.new_block(); + function.push_insn(entry, Insn::Jump(BranchEdge { target: exit, args: vec![] })); + let val = function.push_insn(entry, Insn::Const { val: Const::Value(Qnil) }); + function.push_insn(entry, Insn::Return { val }); + assert_eq!(function.rpo(), vec![entry, exit]); + } + + #[test] + fn diamond_iftrue() { + let mut function = Function::new(std::ptr::null()); + let entry = function.entry_block; + let side = function.new_block(); + let exit = function.new_block(); + function.push_insn(side, Insn::Jump(BranchEdge { target: exit, args: vec![] })); + let val = function.push_insn(entry, Insn::Const { val: Const::Value(Qnil) }); + function.push_insn(entry, Insn::IfTrue { val, target: BranchEdge { target: side, args: vec![] } }); + function.push_insn(entry, Insn::Jump(BranchEdge { target: exit, args: vec![] })); + let val = function.push_insn(entry, Insn::Const { val: Const::Value(Qnil) }); + function.push_insn(entry, Insn::Return { val }); + assert_eq!(function.rpo(), vec![entry, side, exit]); + } + + #[test] + fn diamond_iffalse() { + let mut function = Function::new(std::ptr::null()); + let entry = function.entry_block; + let side = function.new_block(); + let exit = function.new_block(); + function.push_insn(side, Insn::Jump(BranchEdge { target: exit, args: vec![] })); + let val = function.push_insn(entry, Insn::Const { val: Const::Value(Qnil) }); + function.push_insn(entry, Insn::IfFalse { val, target: BranchEdge { target: side, args: vec![] } }); + function.push_insn(entry, Insn::Jump(BranchEdge { target: exit, args: vec![] })); + let val = function.push_insn(entry, Insn::Const { val: Const::Value(Qnil) }); + function.push_insn(entry, Insn::Return { val }); + assert_eq!(function.rpo(), vec![entry, side, exit]); + } + + #[test] + fn a_loop() { + let mut function = Function::new(std::ptr::null()); + let entry = function.entry_block; + function.push_insn(entry, Insn::Jump(BranchEdge { target: entry, args: vec![] })); + assert_eq!(function.rpo(), vec![entry]); + } +} + +#[cfg(test)] +mod infer_tests { + use super::*; + + #[track_caller] + fn assert_subtype(left: Type, right: Type) { + assert!(left.is_subtype(right), "{left} is not a subtype of {right}"); + } + + #[track_caller] + fn assert_bit_equal(left: Type, right: Type) { + assert!(left.bit_equal(right), "{left} != {right}"); + } + + #[test] + fn test_const() { + let mut function = Function::new(std::ptr::null()); + let val = function.push_insn(function.entry_block, Insn::Const { val: Const::Value(Qnil) }); + assert_bit_equal(function.infer_type(val), types::NilClassExact); + } + + #[test] + fn test_nil() { + crate::cruby::with_rubyvm(|| { + let mut function = Function::new(std::ptr::null()); + let nil = function.push_insn(function.entry_block, Insn::Const { val: Const::Value(Qnil) }); + let val = function.push_insn(function.entry_block, Insn::Test { val: nil }); + function.infer_types(); + assert_bit_equal(function.type_of(val), Type::from_cbool(false)); + }); + } + + #[test] + fn test_false() { + crate::cruby::with_rubyvm(|| { + let mut function = Function::new(std::ptr::null()); + let false_ = function.push_insn(function.entry_block, Insn::Const { val: Const::Value(Qfalse) }); + let val = function.push_insn(function.entry_block, Insn::Test { val: false_ }); + function.infer_types(); + assert_bit_equal(function.type_of(val), Type::from_cbool(false)); + }); + } + + #[test] + fn test_truthy() { + crate::cruby::with_rubyvm(|| { + let mut function = Function::new(std::ptr::null()); + let true_ = function.push_insn(function.entry_block, Insn::Const { val: Const::Value(Qtrue) }); + let val = function.push_insn(function.entry_block, Insn::Test { val: true_ }); + function.infer_types(); + assert_bit_equal(function.type_of(val), Type::from_cbool(true)); + }); + } + + #[test] + fn test_unknown() { + crate::cruby::with_rubyvm(|| { + let mut function = Function::new(std::ptr::null()); + let param = function.push_insn(function.entry_block, Insn::PutSelf); + let val = function.push_insn(function.entry_block, Insn::Test { val: param }); + function.infer_types(); + assert_bit_equal(function.type_of(val), types::CBool); + }); + } + + #[test] + fn newarray() { + let mut function = Function::new(std::ptr::null()); + let val = function.push_insn(function.entry_block, Insn::NewArray { count: 0 }); + assert_bit_equal(function.infer_type(val), types::ArrayExact); + } + + #[test] + fn arraydup() { + let mut function = Function::new(std::ptr::null()); + let arr = function.push_insn(function.entry_block, Insn::NewArray { count: 0 }); + let val = function.push_insn(function.entry_block, Insn::ArrayDup { val: arr }); + assert_bit_equal(function.infer_type(val), types::ArrayExact); + } + + #[test] + fn diamond_iffalse_merge_fixnum() { + let mut function = Function::new(std::ptr::null()); + let entry = function.entry_block; + let side = function.new_block(); + let exit = function.new_block(); + let v0 = function.push_insn(side, Insn::Const { val: Const::Value(VALUE::fixnum_from_usize(3)) }); + function.push_insn(side, Insn::Jump(BranchEdge { target: exit, args: vec![v0] })); + let val = function.push_insn(entry, Insn::Const { val: Const::CBool(false) }); + function.push_insn(entry, Insn::IfFalse { val, target: BranchEdge { target: side, args: vec![] } }); + let v1 = function.push_insn(entry, Insn::Const { val: Const::Value(VALUE::fixnum_from_usize(4)) }); + function.push_insn(entry, Insn::Jump(BranchEdge { target: exit, args: vec![v1] })); + let param = function.push_insn(exit, Insn::Param { idx: 0 }); + crate::cruby::with_rubyvm(|| { + function.infer_types(); + }); + assert_bit_equal(function.type_of(param), types::Fixnum); + } + + #[test] + fn diamond_iffalse_merge_bool() { + let mut function = Function::new(std::ptr::null()); + let entry = function.entry_block; + let side = function.new_block(); + let exit = function.new_block(); + let v0 = function.push_insn(side, Insn::Const { val: Const::Value(Qtrue) }); + function.push_insn(side, Insn::Jump(BranchEdge { target: exit, args: vec![v0] })); + let val = function.push_insn(entry, Insn::Const { val: Const::CBool(false) }); + function.push_insn(entry, Insn::IfFalse { val, target: BranchEdge { target: side, args: vec![] } }); + let v1 = function.push_insn(entry, Insn::Const { val: Const::Value(Qfalse) }); + function.push_insn(entry, Insn::Jump(BranchEdge { target: exit, args: vec![v1] })); + let param = function.push_insn(exit, Insn::Param { idx: 0 }); + crate::cruby::with_rubyvm(|| { + function.infer_types(); + assert_bit_equal(function.type_of(param), types::TrueClassExact.union(types::FalseClassExact)); + }); + } +} + +#[cfg(test)] mod tests { use super::*; @@ -1093,8 +1489,8 @@ mod tests { let function = iseq_to_hir(iseq).unwrap(); assert_function_hir(function, " bb0(): - v1 = Const Value(123) - v3 = Return v1 + v1:Fixnum[123] = Const Value(123) + Return v1 "); } @@ -1105,10 +1501,10 @@ mod tests { let function = iseq_to_hir(iseq).unwrap(); assert_function_hir(function, " bb0(): - v1 = Const Value(1) - v3 = Const Value(2) - v5 = Send v1, :+, v3 - v7 = Return v5 + v1:Fixnum[1] = Const Value(1) + v3:Fixnum[2] = Const Value(2) + v5:Object = Send v1, :+, v3 + Return v5 "); } @@ -1119,9 +1515,9 @@ mod tests { let function = iseq_to_hir(iseq).unwrap(); assert_function_hir(function, " bb0(): - v0 = Const Value(nil) - v2 = Const Value(1) - v6 = Return v2 + v0:NilClassExact = Const Value(nil) + v2:Fixnum[1] = Const Value(1) + Return v2 "); } @@ -1132,15 +1528,15 @@ mod tests { let function = iseq_to_hir(iseq).unwrap(); assert_function_hir(function, " bb0(): - v0 = Const Value(nil) - v2 = Const Value(true) - v6 = Test v2 - v7 = IfFalse v6, bb1(v2) - v9 = Const Value(3) - v11 = Return v9 + v0:NilClassExact = Const Value(nil) + v2:TrueClassExact = Const Value(true) + v6:CBool[true] = Test v2 + IfFalse v6, bb1(v2) + v9:Fixnum[3] = Const Value(3) + Return v9 bb1(v12): v14 = Const Value(4) - v16 = Return v14 + Return v14 "); } @@ -1152,11 +1548,11 @@ mod tests { "); assert_method_hir("test", " bb0(v0, v1): - v5 = PatchPoint BOPRedefined(INTEGER_REDEFINED_OP_FLAG, BOP_PLUS) + PatchPoint BOPRedefined(INTEGER_REDEFINED_OP_FLAG, BOP_PLUS) v6 = GuardType v0, Fixnum v7 = GuardType v1, Fixnum - v8 = FixnumAdd v6, v7 - v10 = Return v8 + v8:Fixnum = FixnumAdd v6, v7 + Return v8 "); } @@ -1168,11 +1564,11 @@ mod tests { "); assert_method_hir("test", " bb0(v0, v1): - v5 = PatchPoint BOPRedefined(INTEGER_REDEFINED_OP_FLAG, BOP_MINUS) + PatchPoint BOPRedefined(INTEGER_REDEFINED_OP_FLAG, BOP_MINUS) v6 = GuardType v0, Fixnum v7 = GuardType v1, Fixnum - v8 = FixnumSub v6, v7 - v10 = Return v8 + v8:Fixnum = FixnumSub v6, v7 + Return v8 "); } @@ -1184,11 +1580,11 @@ mod tests { "); assert_method_hir("test", " bb0(v0, v1): - v5 = PatchPoint BOPRedefined(INTEGER_REDEFINED_OP_FLAG, BOP_MULT) + PatchPoint BOPRedefined(INTEGER_REDEFINED_OP_FLAG, BOP_MULT) v6 = GuardType v0, Fixnum v7 = GuardType v1, Fixnum - v8 = FixnumMult v6, v7 - v10 = Return v8 + v8:Fixnum = FixnumMult v6, v7 + Return v8 "); } @@ -1200,11 +1596,11 @@ mod tests { "); assert_method_hir("test", " bb0(v0, v1): - v5 = PatchPoint BOPRedefined(INTEGER_REDEFINED_OP_FLAG, BOP_DIV) + PatchPoint BOPRedefined(INTEGER_REDEFINED_OP_FLAG, BOP_DIV) v6 = GuardType v0, Fixnum v7 = GuardType v1, Fixnum - v8 = FixnumDiv v6, v7 - v10 = Return v8 + v8:Fixnum = FixnumDiv v6, v7 + Return v8 "); } @@ -1216,11 +1612,11 @@ mod tests { "); assert_method_hir("test", " bb0(v0, v1): - v5 = PatchPoint BOPRedefined(INTEGER_REDEFINED_OP_FLAG, BOP_MOD) + PatchPoint BOPRedefined(INTEGER_REDEFINED_OP_FLAG, BOP_MOD) v6 = GuardType v0, Fixnum v7 = GuardType v1, Fixnum - v8 = FixnumMod v6, v7 - v10 = Return v8 + v8:Fixnum = FixnumMod v6, v7 + Return v8 "); } @@ -1232,11 +1628,11 @@ mod tests { "); assert_method_hir("test", " bb0(v0, v1): - v5 = PatchPoint BOPRedefined(INTEGER_REDEFINED_OP_FLAG, BOP_EQ) + PatchPoint BOPRedefined(INTEGER_REDEFINED_OP_FLAG, BOP_EQ) v6 = GuardType v0, Fixnum v7 = GuardType v1, Fixnum - v8 = FixnumEq v6, v7 - v10 = Return v8 + v8:BoolExact = FixnumEq v6, v7 + Return v8 "); } @@ -1248,11 +1644,11 @@ mod tests { "); assert_method_hir("test", " bb0(v0, v1): - v5 = PatchPoint BOPRedefined(INTEGER_REDEFINED_OP_FLAG, BOP_NEQ) + PatchPoint BOPRedefined(INTEGER_REDEFINED_OP_FLAG, BOP_NEQ) v6 = GuardType v0, Fixnum v7 = GuardType v1, Fixnum - v8 = FixnumNeq v6, v7 - v10 = Return v8 + v8:BoolExact = FixnumNeq v6, v7 + Return v8 "); } @@ -1264,11 +1660,11 @@ mod tests { "); assert_method_hir("test", " bb0(v0, v1): - v5 = PatchPoint BOPRedefined(INTEGER_REDEFINED_OP_FLAG, BOP_LT) + PatchPoint BOPRedefined(INTEGER_REDEFINED_OP_FLAG, BOP_LT) v6 = GuardType v0, Fixnum v7 = GuardType v1, Fixnum - v8 = FixnumLt v6, v7 - v10 = Return v8 + v8:BoolExact = FixnumLt v6, v7 + Return v8 "); } @@ -1280,11 +1676,11 @@ mod tests { "); assert_method_hir("test", " bb0(v0, v1): - v5 = PatchPoint BOPRedefined(INTEGER_REDEFINED_OP_FLAG, BOP_LE) + PatchPoint BOPRedefined(INTEGER_REDEFINED_OP_FLAG, BOP_LE) v6 = GuardType v0, Fixnum v7 = GuardType v1, Fixnum - v8 = FixnumLe v6, v7 - v10 = Return v8 + v8:BoolExact = FixnumLe v6, v7 + Return v8 "); } @@ -1296,11 +1692,57 @@ mod tests { "); assert_method_hir("test", " bb0(v0, v1): - v5 = PatchPoint BOPRedefined(INTEGER_REDEFINED_OP_FLAG, BOP_GT) + PatchPoint BOPRedefined(INTEGER_REDEFINED_OP_FLAG, BOP_GT) v6 = GuardType v0, Fixnum v7 = GuardType v1, Fixnum - v8 = FixnumGt v6, v7 - v10 = Return v8 + v8:BoolExact = FixnumGt v6, v7 + Return v8 + "); + } + + #[test] + fn test_loop() { + eval(" + def test + result = 0 + times = 10 + while times > 0 + result = result + 1 + times = times - 1 + end + result + end + test + "); + assert_method_hir("test", " + bb0(): + v0:NilClassExact = Const Value(nil) + v1:NilClassExact = Const Value(nil) + v3:Fixnum[0] = Const Value(0) + v6:Fixnum[10] = Const Value(10) + Jump bb2(v3, v6) + bb1(v29:Fixnum, v30:Fixnum): + v33:Fixnum[1] = Const Value(1) + PatchPoint BOPRedefined(INTEGER_REDEFINED_OP_FLAG, BOP_PLUS) + v36:Fixnum = GuardType v29, Fixnum + v37:Fixnum[1] = GuardType v33, Fixnum + v38:Fixnum = FixnumAdd v36, v37 + v42:Fixnum[1] = Const Value(1) + PatchPoint BOPRedefined(INTEGER_REDEFINED_OP_FLAG, BOP_MINUS) + v45:Fixnum = GuardType v30, Fixnum + v46:Fixnum[1] = GuardType v42, Fixnum + v47:Fixnum = FixnumSub v45, v46 + Jump bb2(v38, v47) + bb2(v10:Fixnum, v11:Fixnum): + v14:Fixnum[0] = Const Value(0) + PatchPoint BOPRedefined(INTEGER_REDEFINED_OP_FLAG, BOP_GT) + v17:Fixnum = GuardType v11, Fixnum + v18:Fixnum[0] = GuardType v14, Fixnum + v19:BoolExact = FixnumGt v17, v18 + v21:CBool = Test v19 + IfTrue v21, bb1(v10, v11) + v24:NilClassExact = Const Value(nil) + Return v10 "); } @@ -1312,11 +1754,37 @@ mod tests { "); assert_method_hir("test", " bb0(v0, v1): - v5 = PatchPoint BOPRedefined(INTEGER_REDEFINED_OP_FLAG, BOP_GE) + PatchPoint BOPRedefined(INTEGER_REDEFINED_OP_FLAG, BOP_GE) v6 = GuardType v0, Fixnum v7 = GuardType v1, Fixnum - v8 = FixnumGe v6, v7 - v10 = Return v8 + v8:BoolExact = FixnumGe v6, v7 + Return v8 + "); + } + + #[test] + fn test_display_types() { + eval(" + def test + cond = true + if cond + 3 + else + 4 + end + end "); + assert_method_hir("test", " + bb0(): + v0:NilClassExact = Const Value(nil) + v2:TrueClassExact = Const Value(true) + v6:CBool[true] = Test v2 + IfFalse v6, bb1(v2) + v9:Fixnum[3] = Const Value(3) + Return v9 + bb1(v12): + v14 = Const Value(4) + Return v14 + "); } } diff --git a/zjit/src/hir_type.inc.rs b/zjit/src/hir_type.inc.rs index 3921ec042d..b55c0c7d83 100644 --- a/zjit/src/hir_type.inc.rs +++ b/zjit/src/hir_type.inc.rs @@ -1,11 +1,12 @@ // This file is @generated by src/gen_hir_type.rb. mod bits { - pub const Any: u64 = Object | Primitive | User; + pub const Any: u64 = Object | Primitive; pub const Array: u64 = ArrayExact | ArrayUser; pub const ArrayExact: u64 = 1u64 << 0; pub const ArrayUser: u64 = 1u64 << 1; pub const Bignum: u64 = 1u64 << 2; - pub const BuiltinExact: u64 = ArrayExact | FalseClassExact | FloatExact | HashExact | IntegerExact | NilClassExact | StringExact | SymbolExact | TrueClassExact; + pub const BoolExact: u64 = FalseClassExact | TrueClassExact; + pub const BuiltinExact: u64 = ArrayExact | FalseClassExact | FloatExact | HashExact | IntegerExact | NilClassExact | ObjectExact | StringExact | SymbolExact | TrueClassExact; pub const CBool: u64 = 1u64 << 3; pub const CDouble: u64 = 1u64 << 4; pub const CInt: u64 = CSigned | CUnsigned; @@ -41,29 +42,29 @@ mod bits { pub const NilClass: u64 = NilClassExact | NilClassUser; pub const NilClassExact: u64 = 1u64 << 25; pub const NilClassUser: u64 = 1u64 << 26; - pub const Object: u64 = BuiltinExact | ObjectExact | ObjectUser; + pub const Object: u64 = Array | FalseClass | Float | Hash | Integer | NilClass | ObjectExact | ObjectUser | String | Symbol | TrueClass; pub const ObjectExact: u64 = 1u64 << 27; - pub const ObjectUser: u64 = Array | FalseClass | Float | Hash | Integer | NilClass | String | Symbol | TrueClass; + pub const ObjectUser: u64 = 1u64 << 28; pub const Primitive: u64 = CBool | CDouble | CInt | CNull | CPtr; - pub const StaticSymbol: u64 = 1u64 << 28; + pub const StaticSymbol: u64 = 1u64 << 29; pub const String: u64 = StringExact | StringUser; - pub const StringExact: u64 = 1u64 << 29; - pub const StringUser: u64 = 1u64 << 30; + pub const StringExact: u64 = 1u64 << 30; + pub const StringUser: u64 = 1u64 << 31; pub const Symbol: u64 = SymbolExact | SymbolUser; pub const SymbolExact: u64 = DynamicSymbol | StaticSymbol; - pub const SymbolUser: u64 = 1u64 << 31; + pub const SymbolUser: u64 = 1u64 << 32; pub const TrueClass: u64 = TrueClassExact | TrueClassUser; - pub const TrueClassExact: u64 = 1u64 << 32; - pub const TrueClassUser: u64 = 1u64 << 33; - pub const User: u64 = ArrayUser | FalseClassUser | FloatUser | HashUser | IntegerUser | NilClassUser | StringUser | SymbolUser | TrueClassUser; - pub const AllBitPatterns: [(&'static str, u64); 56] = [ + pub const TrueClassExact: u64 = 1u64 << 33; + pub const TrueClassUser: u64 = 1u64 << 34; + pub const User: u64 = ArrayUser | FalseClassUser | FloatUser | HashUser | IntegerUser | NilClassUser | ObjectUser | StringUser | SymbolUser | TrueClassUser; + pub const AllBitPatterns: [(&'static str, u64); 57] = [ ("Any", Any), ("Object", Object), - ("ObjectUser", ObjectUser), ("TrueClass", TrueClass), ("User", User), ("TrueClassUser", TrueClassUser), ("BuiltinExact", BuiltinExact), + ("BoolExact", BoolExact), ("TrueClassExact", TrueClassExact), ("Symbol", Symbol), ("SymbolUser", SymbolUser), @@ -72,6 +73,7 @@ mod bits { ("StringExact", StringExact), ("SymbolExact", SymbolExact), ("StaticSymbol", StaticSymbol), + ("ObjectUser", ObjectUser), ("ObjectExact", ObjectExact), ("NilClass", NilClass), ("NilClassUser", NilClassUser), @@ -114,7 +116,7 @@ mod bits { ("ArrayExact", ArrayExact), ("Empty", Empty), ]; - pub const NumTypeBits: u64 = 34; + pub const NumTypeBits: u64 = 35; } pub mod types { use super::*; @@ -123,6 +125,7 @@ pub mod types { pub const ArrayExact: Type = Type::from_bits(bits::ArrayExact); pub const ArrayUser: Type = Type::from_bits(bits::ArrayUser); pub const Bignum: Type = Type::from_bits(bits::Bignum); + pub const BoolExact: Type = Type::from_bits(bits::BoolExact); pub const BuiltinExact: Type = Type::from_bits(bits::BuiltinExact); pub const CBool: Type = Type::from_bits(bits::CBool); pub const CDouble: Type = Type::from_bits(bits::CDouble); diff --git a/zjit/src/hir_type.rs b/zjit/src/hir_type.rs index 4b6b2cbb23..de52483fc1 100644 --- a/zjit/src/hir_type.rs +++ b/zjit/src/hir_type.rs @@ -194,6 +194,11 @@ impl Type { Type { bits: ty.bits, spec: Specialization::Int(val as u64) } } + /// Create a `Type` (a `CDouble` with double specialization) from a f64. + pub fn from_double(val: f64) -> Type { + Type { bits: bits::CDouble, spec: Specialization::Double(val) } + } + /// Create a `Type` from a primitive boolean. pub fn from_cbool(val: bool) -> Type { Type { bits: bits::CBool, spec: Specialization::Int(val as u64) } @@ -226,6 +231,7 @@ impl Type { false } + /// Union both types together, preserving specialization if possible. pub fn union(&self, other: Type) -> Type { // Easy cases first if self.is_subtype(other) { return other; } @@ -256,6 +262,19 @@ impl Type { Type { bits, spec: Specialization::Type(super_class) } } + /// Intersect both types, preserving specialization if possible. + pub fn intersection(&self, other: Type) -> Type { + let bits = self.bits & other.bits; + if bits == bits::Empty { return types::Empty; } + if self.spec_is_subtype_of(other) { return Type { bits, spec: self.spec }; } + if other.spec_is_subtype_of(*self) { return Type { bits, spec: other.spec }; } + types::Empty + } + + pub fn could_be(&self, other: Type) -> bool { + !self.intersection(other).bit_equal(types::Empty) + } + /// Check if the type field of `self` is a subtype of the type field of `other` and also check /// if the specialization of `self` is a subtype of the specialization of `other`. pub fn is_subtype(&self, other: Type) -> bool { @@ -298,6 +317,11 @@ impl Type { } } + /// Check bit equality of two `Type`s. Do not use! You are probably looking for [`Type::is_subtype`]. + pub fn bit_equal(&self, other: Type) -> bool { + self.bits == other.bits && self.spec == other.spec + } + /// Check *only* if `self`'s specialization is a subtype of `other`'s specialization. Private. /// You probably want [`Type::is_subtype`] instead. fn spec_is_subtype_of(&self, other: Type) -> bool { |
