summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--zjit/src/gen_hir_type.rb56
-rw-r--r--zjit/src/hir.rs604
-rw-r--r--zjit/src/hir_type.inc.rs31
-rw-r--r--zjit/src/hir_type.rs24
4 files changed, 612 insertions, 103 deletions
diff --git a/zjit/src/gen_hir_type.rb b/zjit/src/gen_hir_type.rb
index 42f25f8fc8..7fc29fe8f8 100644
--- a/zjit/src/gen_hir_type.rb
+++ b/zjit/src/gen_hir_type.rb
@@ -46,19 +46,19 @@ end
# Start at Any. All types are subtypes of Any.
any = Type.new "Any"
# Build the Ruby object universe.
-object = any.subtype "Object"
-object.subtype "ObjectExact"
-$object_user = object.subtype "ObjectUser"
-$user = any.subtype "User"
-$builtin_exact = object.subtype "BuiltinExact"
+$object = any.subtype "Object"
+object_exact = $object.subtype "ObjectExact"
+object_user = $object.subtype "ObjectUser"
+$user = [object_user.name]
+$builtin_exact = [object_exact.name]
# Define a new type that can be subclassed (most of them).
def base_type name
- type = $object_user.subtype name
+ type = $object.subtype name
exact = type.subtype(name+"Exact")
user = type.subtype(name+"User")
- $builtin_exact.subtypes << exact
- $user.subtypes << user
+ $builtin_exact << exact.name
+ $user << user.name
[type, exact]
end
@@ -82,8 +82,8 @@ symbol_exact.subtype "StaticSymbol"
symbol_exact.subtype "DynamicSymbol"
base_type "NilClass"
-base_type "TrueClass"
-base_type "FalseClass"
+_, true_exact = base_type "TrueClass"
+_, false_exact = base_type "FalseClass"
# Build the primitive object universe.
primitive = any.subtype "Primitive"
@@ -101,40 +101,54 @@ unsigned = primitive_int.subtype "CUnsigned"
# Assign individual bits to type leaves and union bit patterns to nodes with subtypes
num_bits = 0
-bits = {"Empty" => ["0u64"]}
-numeric_bits = {"Empty" => 0}
+$bits = {"Empty" => ["0u64"]}
+$numeric_bits = {"Empty" => 0}
Set[any, *any.all_subtypes].sort_by(&:name).each {|type|
subtypes = type.subtypes
if subtypes.empty?
# Assign bits for leaves
- bits[type.name] = ["1u64 << #{num_bits}"]
- numeric_bits[type.name] = 1 << num_bits
+ $bits[type.name] = ["1u64 << #{num_bits}"]
+ $numeric_bits[type.name] = 1 << num_bits
num_bits += 1
else
# Assign bits for unions
- bits[type.name] = subtypes.map(&:name).sort
+ $bits[type.name] = subtypes.map(&:name).sort
end
}
[*any.all_subtypes, any].each {|type|
subtypes = type.subtypes
unless subtypes.empty?
- numeric_bits[type.name] = subtypes.map {|ty| numeric_bits[ty.name]}.reduce(&:|)
+ $numeric_bits[type.name] = subtypes.map {|ty| $numeric_bits[ty.name]}.reduce(&:|)
end
}
+# Unions are for names of groups of type bit patterns that don't fit neatly
+# into the Ruby class hierarchy. For example, we might want to refer to a union
+# of TrueClassExact|FalseClassExact by the name BoolExact even though a "bool"
+# doesn't exist as a class in Ruby.
+def add_union name, type_names
+ type_names = type_names.sort
+ $bits[name] = type_names
+ $numeric_bits[name] = type_names.map {|type_name| $numeric_bits[type_name]}.reduce(&:|)
+end
+
+add_union "BuiltinExact", $builtin_exact
+add_union "User", $user
+add_union "BoolExact", [true_exact.name, false_exact.name]
+
# ===== Finished generating the DAG; write Rust code =====
puts "// This file is @generated by src/gen_hir_type.rb."
puts "mod bits {"
-bits.keys.sort.map {|type_name|
- subtypes = bits[type_name].join(" | ")
+$bits.keys.sort.map {|type_name|
+ subtypes = $bits[type_name].join(" | ")
puts " pub const #{type_name}: u64 = #{subtypes};"
}
-puts " pub const AllBitPatterns: [(&'static str, u64); #{bits.size}] = ["
+puts " pub const AllBitPatterns: [(&'static str, u64); #{$bits.size}] = ["
# Sort the bit patterns by decreasing value so that we can print the densest
# possible to-string representation of a Type. For example, CSigned instead of
# CInt8|CInt16|...
-numeric_bits.sort_by {|key, val| -val}.each {|type_name, _|
+$numeric_bits.sort_by {|key, val| -val}.each {|type_name, _|
puts " (\"#{type_name}\", #{type_name}),"
}
puts " ];"
@@ -143,7 +157,7 @@ puts " pub const NumTypeBits: u64 = #{num_bits};
puts "pub mod types {
use super::*;"
-bits.keys.sort.map {|type_name|
+$bits.keys.sort.map {|type_name|
puts " pub const #{type_name}: Type = Type::from_bits(bits::#{type_name});"
}
puts "}"
diff --git a/zjit/src/hir.rs b/zjit/src/hir.rs
index f86aac169d..968fe1907e 100644
--- a/zjit/src/hir.rs
+++ b/zjit/src/hir.rs
@@ -5,10 +5,9 @@ use crate::{
cruby::*, options::get_option, hir_type::types::Fixnum, options::DumpHIR, profile::get_or_create_iseq_payload
};
use std::collections::{HashMap, HashSet};
+use crate::hir_type::{Type, types};
-use crate::hir_type::Type;
-
-#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
+#[derive(Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Hash, Debug)]
pub struct InsnId(pub usize);
impl Into<usize> for InsnId {
@@ -122,6 +121,7 @@ impl std::fmt::Display for Invariant {
#[derive(Debug, Clone, PartialEq)]
pub enum Const {
Value(VALUE),
+ CBool(bool),
CInt8(i8),
CInt16(i16),
CInt32(i32),
@@ -154,7 +154,7 @@ pub enum Insn {
StringIntern { val: InsnId },
NewArray { count: usize },
- ArraySet { idx: usize, val: InsnId },
+ ArraySet { array: InsnId, idx: usize, val: InsnId },
ArrayDup { val: InsnId },
// Check if the value is truthy and "return" a C boolean. In reality, we will likely fuse this
@@ -213,6 +213,18 @@ pub enum Insn {
PatchPoint(Invariant),
}
+impl Insn {
+ /// Not every instruction returns a value. Return true if the instruction does and false otherwise.
+ pub fn has_output(&self) -> bool {
+ match self {
+ Insn::ArraySet { .. } | Insn::Snapshot { .. } | Insn::Jump(_)
+ | Insn::IfTrue { .. } | Insn::IfFalse { .. } | Insn::Return { .. }
+ | Insn::PatchPoint { .. } => false,
+ _ => true,
+ }
+ }
+}
+
#[derive(Default, Debug)]
pub struct Block {
params: Vec<InsnId>,
@@ -336,6 +348,7 @@ pub struct Function {
pub insns: Vec<Insn>,
union_find: UnionFind<InsnId>,
+ insn_types: Vec<Type>,
blocks: Vec<Block>,
entry_block: BlockId,
frame_states: Vec<FrameState>,
@@ -346,6 +359,7 @@ impl Function {
Function {
iseq,
insns: vec![],
+ insn_types: vec![],
union_find: UnionFind::new(),
blocks: vec![Block::default()],
entry_block: BlockId(0),
@@ -362,6 +376,7 @@ impl Function {
self.blocks[block.0].insns.push(id);
}
self.insns.push(insn);
+ self.insn_types.push(types::Empty);
id
}
@@ -384,6 +399,10 @@ impl Function {
id
}
+ /// Return a copy of the instruction where the instruction and its operands have been read from
+ /// the union-find table (to find the current most-optimized version of this instruction). See
+ /// [`UnionFind`] for more.
+ ///
/// Use for pattern matching over instructions in a union-find-safe way. For example:
/// ```rust
/// match func.find(insn_id) {
@@ -393,15 +412,42 @@ impl Function {
/// _ => {}
/// }
/// ```
- fn find(&mut self, insn_id: InsnId) -> Insn {
- let insn_id = self.union_find.find(insn_id);
+ fn find(&self, insn_id: InsnId) -> Insn {
+ macro_rules! find {
+ ( $x:expr ) => {
+ {
+ self.union_find.find_const($x)
+ }
+ };
+ }
+ let insn_id = self.union_find.find_const(insn_id);
use Insn::*;
match &self.insns[insn_id.0] {
- result@(PutSelf | Const {..} | Param {..} | NewArray {..} | GetConstantPath {..}) => result.clone(),
- StringCopy { val } => StringCopy { val: self.union_find.find(*val) },
- StringIntern { val } => StringIntern { val: self.union_find.find(*val) },
- Test { val } => Test { val: self.union_find.find(*val) },
- insn => todo!("find({insn:?})"),
+ result@(PutSelf | Const {..} | Param {..} | NewArray {..} | GetConstantPath {..} | Snapshot {..}
+ | Jump(_) | PatchPoint {..}) => result.clone(),
+ Return { val } => Return { val: find!(*val) },
+ StringCopy { val } => StringCopy { val: find!(*val) },
+ StringIntern { val } => StringIntern { val: find!(*val) },
+ Test { val } => Test { val: find!(*val) },
+ IfTrue { val, target } => IfTrue { val: find!(*val), target: target.clone() },
+ IfFalse { val, target } => IfFalse { val: find!(*val), target: target.clone() },
+ GuardType { val, guard_type, state } => GuardType { val: find!(*val), guard_type: *guard_type, state: *state },
+ FixnumAdd { left, right, state } => FixnumAdd { left: find!(*left), right: find!(*right), state: *state },
+ FixnumSub { left, right, state } => FixnumSub { left: find!(*left), right: find!(*right), state: *state },
+ FixnumMult { left, right, state } => FixnumMult { left: find!(*left), right: find!(*right), state: *state },
+ FixnumDiv { left, right, state } => FixnumDiv { left: find!(*left), right: find!(*right), state: *state },
+ FixnumMod { left, right, state } => FixnumMod { left: find!(*left), right: find!(*right), state: *state },
+ FixnumNeq { left, right, state } => FixnumNeq { left: find!(*left), right: find!(*right), state: *state },
+ FixnumEq { left, right, state } => FixnumEq { left: find!(*left), right: find!(*right), state: *state },
+ FixnumGt { left, right, state } => FixnumGt { left: find!(*left), right: find!(*right), state: *state },
+ FixnumGe { left, right, state } => FixnumGe { left: find!(*left), right: find!(*right), state: *state },
+ FixnumLt { left, right, state } => FixnumLt { left: find!(*left), right: find!(*right), state: *state },
+ FixnumLe { left, right, state } => FixnumLe { left: find!(*left), right: find!(*right), state: *state },
+ Send { self_val, call_info, args } => Send { self_val: find!(*self_val), call_info: call_info.clone(), args: args.iter().map(|arg| find!(*arg)).collect() },
+ ArraySet { array, idx, val } => ArraySet { array: find!(*array), idx: *idx, val: find!(*val) },
+ ArrayDup { val } => ArrayDup { val: find!(*val) },
+ CCall { cfun, args } => CCall { cfun: *cfun, args: args.iter().map(|arg| find!(*arg)).collect() },
+ Defined { .. } => todo!("find(Defined)"),
}
}
@@ -410,6 +456,159 @@ impl Function {
let new_insn = self.push_insn(block, replacement);
self.union_find.make_equal_to(insn, new_insn);
}
+
+ fn type_of(&self, insn: InsnId) -> Type {
+ assert!(self.insns[insn.0].has_output());
+ self.insn_types[insn.0]
+ }
+
+ fn infer_type(&self, insn: InsnId) -> Type {
+ assert!(self.insns[insn.0].has_output());
+ match &self.insns[insn.0] {
+ Insn::Param { .. } => unimplemented!("use infer_param_type instead"),
+ Insn::ArraySet { .. } | Insn::Snapshot { .. } | Insn::Jump(_)
+ | Insn::IfTrue { .. } | Insn::IfFalse { .. } | Insn::Return { .. }
+ | Insn::PatchPoint { .. } =>
+ panic!("Cannot infer type of instruction with no output"),
+ Insn::Const { val: Const::Value(val) } => Type::from_value(*val),
+ Insn::Const { val: Const::CBool(val) } => Type::from_cbool(*val),
+ Insn::Const { val: Const::CInt8(val) } => Type::from_cint(types::CInt8, *val as i64),
+ Insn::Const { val: Const::CInt16(val) } => Type::from_cint(types::CInt16, *val as i64),
+ Insn::Const { val: Const::CInt32(val) } => Type::from_cint(types::CInt32, *val as i64),
+ Insn::Const { val: Const::CInt64(val) } => Type::from_cint(types::CInt64, *val as i64),
+ Insn::Const { val: Const::CUInt8(val) } => Type::from_cint(types::CUInt8, *val as i64),
+ Insn::Const { val: Const::CUInt16(val) } => Type::from_cint(types::CUInt16, *val as i64),
+ Insn::Const { val: Const::CUInt32(val) } => Type::from_cint(types::CUInt32, *val as i64),
+ Insn::Const { val: Const::CUInt64(val) } => Type::from_cint(types::CUInt64, *val as i64),
+ Insn::Const { val: Const::CPtr(val) } => Type::from_cint(types::CPtr, *val as i64),
+ Insn::Const { val: Const::CDouble(val) } => Type::from_double(*val),
+ Insn::Test { val } if self.type_of(*val).is_subtype(types::NilClassExact) || self.type_of(*val).is_subtype(types::FalseClassExact) => Type::from_cbool(false),
+ Insn::Test { val } if !self.type_of(*val).could_be(types::NilClassExact) && !self.type_of(*val).could_be(types::FalseClassExact) => Type::from_cbool(true),
+ Insn::Test { .. } => types::CBool,
+ Insn::StringCopy { .. } => types::StringExact,
+ Insn::StringIntern { .. } => types::StringExact,
+ Insn::NewArray { .. } => types::ArrayExact,
+ Insn::ArrayDup { .. } => types::ArrayExact,
+ Insn::CCall { .. } => types::Any,
+ Insn::GuardType { val, guard_type, .. } => self.type_of(*val).intersection(*guard_type),
+ Insn::FixnumAdd { .. } => types::Fixnum,
+ Insn::FixnumSub { .. } => types::Fixnum,
+ Insn::FixnumMult { .. } => types::Fixnum,
+ Insn::FixnumDiv { .. } => types::Fixnum,
+ Insn::FixnumMod { .. } => types::Fixnum,
+ Insn::FixnumEq { .. } => types::BoolExact,
+ Insn::FixnumNeq { .. } => types::BoolExact,
+ Insn::FixnumLt { .. } => types::BoolExact,
+ Insn::FixnumLe { .. } => types::BoolExact,
+ Insn::FixnumGt { .. } => types::BoolExact,
+ Insn::FixnumGe { .. } => types::BoolExact,
+ Insn::Send { .. } => types::Object,
+ Insn::PutSelf => types::Object,
+ Insn::Defined { .. } => types::Object,
+ Insn::GetConstantPath { .. } => types::Object,
+ }
+ }
+
+ fn infer_types(&mut self) {
+ // Reset all types
+ self.insn_types.fill(types::Empty);
+ // Compute predecessor instructions for each block
+ let mut preds: Vec<Vec<InsnId>> = vec![Vec::new(); self.blocks.len()];
+ let rpo = self.rpo();
+ // Walk the graph, computing predecessor blocks
+ for block in &rpo {
+ for insn in &self.blocks[block.0].insns {
+ match self.find(*insn) {
+ Insn::IfTrue { target, .. }
+ | Insn::IfFalse { target, .. }
+ | Insn::Jump(target) =>
+ preds[target.target.0].push(*insn),
+ _ => {}
+ }
+ }
+ }
+ for idx in 0..preds.len() {
+ preds[idx].sort();
+ preds[idx].dedup();
+ }
+ // Walk the graph, computing types until fixpoint
+ let mut reachable = vec![false; self.blocks.len()];
+ reachable[self.entry_block.0] = true;
+ loop {
+ let mut changed = false;
+ for block in &rpo {
+ if !reachable[block.0] { continue; }
+ for insn_id in &self.blocks[block.0].insns {
+ let insn = self.find(*insn_id);
+ let insn_type = match insn {
+ Insn::IfTrue { val, target: BranchEdge { target, args } } => {
+ assert!(!self.type_of(val).bit_equal(types::Empty));
+ if self.type_of(val).could_be(Type::from_cbool(true)) {
+ reachable[target.0] = true;
+ for (idx, arg) in args.iter().enumerate() {
+ let param = self.blocks[target.0].params[idx];
+ self.insn_types[param.0] = self.type_of(param).union(self.type_of(*arg));
+ }
+ }
+ continue;
+ }
+ Insn::IfFalse { val, target: BranchEdge { target, args } } => {
+ assert!(!self.type_of(val).bit_equal(types::Empty));
+ if self.type_of(val).could_be(Type::from_cbool(false)) {
+ reachable[target.0] = true;
+ for (idx, arg) in args.iter().enumerate() {
+ let param = self.blocks[target.0].params[idx];
+ self.insn_types[param.0] = self.type_of(param).union(self.type_of(*arg));
+ }
+ }
+ continue;
+ }
+ Insn::Jump(BranchEdge { target, args }) => {
+ reachable[target.0] = true;
+ for (idx, arg) in args.iter().enumerate() {
+ let param = self.blocks[target.0].params[idx];
+ self.insn_types[param.0] = self.type_of(param).union(self.type_of(*arg));
+ }
+ continue;
+ }
+ _ if insn.has_output() => self.infer_type(*insn_id),
+ _ => continue,
+ };
+ if !self.type_of(*insn_id).bit_equal(insn_type) {
+ self.insn_types[insn_id.0] = insn_type;
+ changed = true;
+ }
+ }
+ }
+ if !changed {
+ break;
+ }
+ }
+ }
+
+ /// Return a traversal of the `Function`'s `BlockId`s in reverse post-order.
+ fn rpo(&self) -> Vec<BlockId> {
+ let mut result = vec![];
+ self.po_from(self.entry_block, &mut result, &mut HashSet::new());
+ result.reverse();
+ result
+ }
+
+ fn po_from(&self, block: BlockId, mut result: &mut Vec<BlockId>, mut seen: &mut HashSet<BlockId>) {
+ // TODO(max): Avoid using recursion for post-order traversal. For graphs, this is slightly
+ // trickier than it might seem.
+ if seen.contains(&block) { return; }
+ seen.insert(block);
+ for insn in &self.blocks[block.0].insns {
+ match self.find(*insn) {
+ Insn::Jump(edge) => self.po_from(edge.target, &mut result, &mut seen),
+ Insn::IfTrue { target, .. } => self.po_from(target.target, &mut result, &mut seen),
+ Insn::IfFalse { target, .. } => self.po_from(target.target, &mut result, &mut seen),
+ _ => {}
+ }
+ }
+ result.push(block);
+ }
}
impl<'a> std::fmt::Display for FunctionPrinter<'a> {
@@ -422,23 +621,36 @@ impl<'a> std::fmt::Display for FunctionPrinter<'a> {
let mut sep = "";
for param in &block.params {
write!(f, "{sep}{param}")?;
+ let insn_type = fun.type_of(*param);
+ if !insn_type.is_subtype(types::Empty) {
+ write!(f, ":{insn_type}")?;
+ }
sep = ", ";
}
}
writeln!(f, "):")?;
for insn_id in &block.insns {
- if !self.display_snapshot && matches!(fun.insns[insn_id.0], Insn::Snapshot {..}) {
+ let insn = fun.find(*insn_id);
+ if !self.display_snapshot && matches!(insn, Insn::Snapshot {..}) {
continue;
}
- write!(f, " {insn_id} = ")?;
- match &fun.insns[insn_id.0] {
+ write!(f, " ")?;
+ if insn.has_output() {
+ let insn_type = fun.type_of(*insn_id);
+ if insn_type.is_subtype(types::Empty) {
+ write!(f, "{insn_id} = ")?;
+ } else {
+ write!(f, "{insn_id}:{insn_type} = ")?;
+ }
+ }
+ match insn {
Insn::Const { val } => { write!(f, "Const {val}")?; }
Insn::Param { idx } => { write!(f, "Param {idx}")?; }
Insn::NewArray { count } => { write!(f, "NewArray {count}")?; }
- Insn::ArraySet { idx, val } => { write!(f, "ArraySet {idx}, {val}")?; }
+ Insn::ArraySet { array, idx, val } => { write!(f, "ArraySet {array}, {idx}, {val}")?; }
Insn::ArrayDup { val } => { write!(f, "ArrayDup {val}")?; }
Insn::Test { val } => { write!(f, "Test {val}")?; }
- Insn::Snapshot { state } => { write!(f, "Snapshot {}", fun.frame_state(*state))?; }
+ Insn::Snapshot { state } => { write!(f, "Snapshot {}", fun.frame_state(state))?; }
Insn::Jump(target) => { write!(f, "Jump {target}")?; }
Insn::IfTrue { val, target } => { write!(f, "IfTrue {val}, {target}")?; }
Insn::IfFalse { val, target } => { write!(f, "IfFalse {val}, {target}")?; }
@@ -702,11 +914,11 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> {
}
YARVINSN_newarray => {
let count = get_arg(pc, 0).as_usize();
- let insn_id = fun.push_insn(block, Insn::NewArray { count });
+ let array = fun.push_insn(block, Insn::NewArray { count });
for idx in (0..count).rev() {
- fun.push_insn(block, Insn::ArraySet { idx, val: state.pop()? });
+ fun.push_insn(block, Insn::ArraySet { array, idx, val: state.pop()? });
}
- state.push(insn_id);
+ state.push(array);
}
YARVINSN_duparray => {
let val = fun.push_insn(block, Insn::Const { val: Const::Value(get_arg(pc, 0)) });
@@ -968,6 +1180,8 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> {
}
}
+ fun.infer_types();
+
match get_option!(dump_hir) {
Some(DumpHIR::WithoutSnapshot) => println!("HIR:\n{}", FunctionPrinter::without_snapshot(&fun)),
Some(DumpHIR::All) => println!("HIR:\n{}", FunctionPrinter::with_snapshot(&fun)),
@@ -1028,6 +1242,188 @@ mod union_find_tests {
}
#[cfg(test)]
+mod rpo_tests {
+ use super::*;
+
+ #[test]
+ fn one_block() {
+ let mut function = Function::new(std::ptr::null());
+ let entry = function.entry_block;
+ let val = function.push_insn(entry, Insn::Const { val: Const::Value(Qnil) });
+ function.push_insn(entry, Insn::Return { val });
+ assert_eq!(function.rpo(), vec![entry]);
+ }
+
+ #[test]
+ fn jump() {
+ let mut function = Function::new(std::ptr::null());
+ let entry = function.entry_block;
+ let exit = function.new_block();
+ function.push_insn(entry, Insn::Jump(BranchEdge { target: exit, args: vec![] }));
+ let val = function.push_insn(entry, Insn::Const { val: Const::Value(Qnil) });
+ function.push_insn(entry, Insn::Return { val });
+ assert_eq!(function.rpo(), vec![entry, exit]);
+ }
+
+ #[test]
+ fn diamond_iftrue() {
+ let mut function = Function::new(std::ptr::null());
+ let entry = function.entry_block;
+ let side = function.new_block();
+ let exit = function.new_block();
+ function.push_insn(side, Insn::Jump(BranchEdge { target: exit, args: vec![] }));
+ let val = function.push_insn(entry, Insn::Const { val: Const::Value(Qnil) });
+ function.push_insn(entry, Insn::IfTrue { val, target: BranchEdge { target: side, args: vec![] } });
+ function.push_insn(entry, Insn::Jump(BranchEdge { target: exit, args: vec![] }));
+ let val = function.push_insn(entry, Insn::Const { val: Const::Value(Qnil) });
+ function.push_insn(entry, Insn::Return { val });
+ assert_eq!(function.rpo(), vec![entry, side, exit]);
+ }
+
+ #[test]
+ fn diamond_iffalse() {
+ let mut function = Function::new(std::ptr::null());
+ let entry = function.entry_block;
+ let side = function.new_block();
+ let exit = function.new_block();
+ function.push_insn(side, Insn::Jump(BranchEdge { target: exit, args: vec![] }));
+ let val = function.push_insn(entry, Insn::Const { val: Const::Value(Qnil) });
+ function.push_insn(entry, Insn::IfFalse { val, target: BranchEdge { target: side, args: vec![] } });
+ function.push_insn(entry, Insn::Jump(BranchEdge { target: exit, args: vec![] }));
+ let val = function.push_insn(entry, Insn::Const { val: Const::Value(Qnil) });
+ function.push_insn(entry, Insn::Return { val });
+ assert_eq!(function.rpo(), vec![entry, side, exit]);
+ }
+
+ #[test]
+ fn a_loop() {
+ let mut function = Function::new(std::ptr::null());
+ let entry = function.entry_block;
+ function.push_insn(entry, Insn::Jump(BranchEdge { target: entry, args: vec![] }));
+ assert_eq!(function.rpo(), vec![entry]);
+ }
+}
+
+#[cfg(test)]
+mod infer_tests {
+ use super::*;
+
+ #[track_caller]
+ fn assert_subtype(left: Type, right: Type) {
+ assert!(left.is_subtype(right), "{left} is not a subtype of {right}");
+ }
+
+ #[track_caller]
+ fn assert_bit_equal(left: Type, right: Type) {
+ assert!(left.bit_equal(right), "{left} != {right}");
+ }
+
+ #[test]
+ fn test_const() {
+ let mut function = Function::new(std::ptr::null());
+ let val = function.push_insn(function.entry_block, Insn::Const { val: Const::Value(Qnil) });
+ assert_bit_equal(function.infer_type(val), types::NilClassExact);
+ }
+
+ #[test]
+ fn test_nil() {
+ crate::cruby::with_rubyvm(|| {
+ let mut function = Function::new(std::ptr::null());
+ let nil = function.push_insn(function.entry_block, Insn::Const { val: Const::Value(Qnil) });
+ let val = function.push_insn(function.entry_block, Insn::Test { val: nil });
+ function.infer_types();
+ assert_bit_equal(function.type_of(val), Type::from_cbool(false));
+ });
+ }
+
+ #[test]
+ fn test_false() {
+ crate::cruby::with_rubyvm(|| {
+ let mut function = Function::new(std::ptr::null());
+ let false_ = function.push_insn(function.entry_block, Insn::Const { val: Const::Value(Qfalse) });
+ let val = function.push_insn(function.entry_block, Insn::Test { val: false_ });
+ function.infer_types();
+ assert_bit_equal(function.type_of(val), Type::from_cbool(false));
+ });
+ }
+
+ #[test]
+ fn test_truthy() {
+ crate::cruby::with_rubyvm(|| {
+ let mut function = Function::new(std::ptr::null());
+ let true_ = function.push_insn(function.entry_block, Insn::Const { val: Const::Value(Qtrue) });
+ let val = function.push_insn(function.entry_block, Insn::Test { val: true_ });
+ function.infer_types();
+ assert_bit_equal(function.type_of(val), Type::from_cbool(true));
+ });
+ }
+
+ #[test]
+ fn test_unknown() {
+ crate::cruby::with_rubyvm(|| {
+ let mut function = Function::new(std::ptr::null());
+ let param = function.push_insn(function.entry_block, Insn::PutSelf);
+ let val = function.push_insn(function.entry_block, Insn::Test { val: param });
+ function.infer_types();
+ assert_bit_equal(function.type_of(val), types::CBool);
+ });
+ }
+
+ #[test]
+ fn newarray() {
+ let mut function = Function::new(std::ptr::null());
+ let val = function.push_insn(function.entry_block, Insn::NewArray { count: 0 });
+ assert_bit_equal(function.infer_type(val), types::ArrayExact);
+ }
+
+ #[test]
+ fn arraydup() {
+ let mut function = Function::new(std::ptr::null());
+ let arr = function.push_insn(function.entry_block, Insn::NewArray { count: 0 });
+ let val = function.push_insn(function.entry_block, Insn::ArrayDup { val: arr });
+ assert_bit_equal(function.infer_type(val), types::ArrayExact);
+ }
+
+ #[test]
+ fn diamond_iffalse_merge_fixnum() {
+ let mut function = Function::new(std::ptr::null());
+ let entry = function.entry_block;
+ let side = function.new_block();
+ let exit = function.new_block();
+ let v0 = function.push_insn(side, Insn::Const { val: Const::Value(VALUE::fixnum_from_usize(3)) });
+ function.push_insn(side, Insn::Jump(BranchEdge { target: exit, args: vec![v0] }));
+ let val = function.push_insn(entry, Insn::Const { val: Const::CBool(false) });
+ function.push_insn(entry, Insn::IfFalse { val, target: BranchEdge { target: side, args: vec![] } });
+ let v1 = function.push_insn(entry, Insn::Const { val: Const::Value(VALUE::fixnum_from_usize(4)) });
+ function.push_insn(entry, Insn::Jump(BranchEdge { target: exit, args: vec![v1] }));
+ let param = function.push_insn(exit, Insn::Param { idx: 0 });
+ crate::cruby::with_rubyvm(|| {
+ function.infer_types();
+ });
+ assert_bit_equal(function.type_of(param), types::Fixnum);
+ }
+
+ #[test]
+ fn diamond_iffalse_merge_bool() {
+ let mut function = Function::new(std::ptr::null());
+ let entry = function.entry_block;
+ let side = function.new_block();
+ let exit = function.new_block();
+ let v0 = function.push_insn(side, Insn::Const { val: Const::Value(Qtrue) });
+ function.push_insn(side, Insn::Jump(BranchEdge { target: exit, args: vec![v0] }));
+ let val = function.push_insn(entry, Insn::Const { val: Const::CBool(false) });
+ function.push_insn(entry, Insn::IfFalse { val, target: BranchEdge { target: side, args: vec![] } });
+ let v1 = function.push_insn(entry, Insn::Const { val: Const::Value(Qfalse) });
+ function.push_insn(entry, Insn::Jump(BranchEdge { target: exit, args: vec![v1] }));
+ let param = function.push_insn(exit, Insn::Param { idx: 0 });
+ crate::cruby::with_rubyvm(|| {
+ function.infer_types();
+ assert_bit_equal(function.type_of(param), types::TrueClassExact.union(types::FalseClassExact));
+ });
+ }
+}
+
+#[cfg(test)]
mod tests {
use super::*;
@@ -1093,8 +1489,8 @@ mod tests {
let function = iseq_to_hir(iseq).unwrap();
assert_function_hir(function, "
bb0():
- v1 = Const Value(123)
- v3 = Return v1
+ v1:Fixnum[123] = Const Value(123)
+ Return v1
");
}
@@ -1105,10 +1501,10 @@ mod tests {
let function = iseq_to_hir(iseq).unwrap();
assert_function_hir(function, "
bb0():
- v1 = Const Value(1)
- v3 = Const Value(2)
- v5 = Send v1, :+, v3
- v7 = Return v5
+ v1:Fixnum[1] = Const Value(1)
+ v3:Fixnum[2] = Const Value(2)
+ v5:Object = Send v1, :+, v3
+ Return v5
");
}
@@ -1119,9 +1515,9 @@ mod tests {
let function = iseq_to_hir(iseq).unwrap();
assert_function_hir(function, "
bb0():
- v0 = Const Value(nil)
- v2 = Const Value(1)
- v6 = Return v2
+ v0:NilClassExact = Const Value(nil)
+ v2:Fixnum[1] = Const Value(1)
+ Return v2
");
}
@@ -1132,15 +1528,15 @@ mod tests {
let function = iseq_to_hir(iseq).unwrap();
assert_function_hir(function, "
bb0():
- v0 = Const Value(nil)
- v2 = Const Value(true)
- v6 = Test v2
- v7 = IfFalse v6, bb1(v2)
- v9 = Const Value(3)
- v11 = Return v9
+ v0:NilClassExact = Const Value(nil)
+ v2:TrueClassExact = Const Value(true)
+ v6:CBool[true] = Test v2
+ IfFalse v6, bb1(v2)
+ v9:Fixnum[3] = Const Value(3)
+ Return v9
bb1(v12):
v14 = Const Value(4)
- v16 = Return v14
+ Return v14
");
}
@@ -1152,11 +1548,11 @@ mod tests {
");
assert_method_hir("test", "
bb0(v0, v1):
- v5 = PatchPoint BOPRedefined(INTEGER_REDEFINED_OP_FLAG, BOP_PLUS)
+ PatchPoint BOPRedefined(INTEGER_REDEFINED_OP_FLAG, BOP_PLUS)
v6 = GuardType v0, Fixnum
v7 = GuardType v1, Fixnum
- v8 = FixnumAdd v6, v7
- v10 = Return v8
+ v8:Fixnum = FixnumAdd v6, v7
+ Return v8
");
}
@@ -1168,11 +1564,11 @@ mod tests {
");
assert_method_hir("test", "
bb0(v0, v1):
- v5 = PatchPoint BOPRedefined(INTEGER_REDEFINED_OP_FLAG, BOP_MINUS)
+ PatchPoint BOPRedefined(INTEGER_REDEFINED_OP_FLAG, BOP_MINUS)
v6 = GuardType v0, Fixnum
v7 = GuardType v1, Fixnum
- v8 = FixnumSub v6, v7
- v10 = Return v8
+ v8:Fixnum = FixnumSub v6, v7
+ Return v8
");
}
@@ -1184,11 +1580,11 @@ mod tests {
");
assert_method_hir("test", "
bb0(v0, v1):
- v5 = PatchPoint BOPRedefined(INTEGER_REDEFINED_OP_FLAG, BOP_MULT)
+ PatchPoint BOPRedefined(INTEGER_REDEFINED_OP_FLAG, BOP_MULT)
v6 = GuardType v0, Fixnum
v7 = GuardType v1, Fixnum
- v8 = FixnumMult v6, v7
- v10 = Return v8
+ v8:Fixnum = FixnumMult v6, v7
+ Return v8
");
}
@@ -1200,11 +1596,11 @@ mod tests {
");
assert_method_hir("test", "
bb0(v0, v1):
- v5 = PatchPoint BOPRedefined(INTEGER_REDEFINED_OP_FLAG, BOP_DIV)
+ PatchPoint BOPRedefined(INTEGER_REDEFINED_OP_FLAG, BOP_DIV)
v6 = GuardType v0, Fixnum
v7 = GuardType v1, Fixnum
- v8 = FixnumDiv v6, v7
- v10 = Return v8
+ v8:Fixnum = FixnumDiv v6, v7
+ Return v8
");
}
@@ -1216,11 +1612,11 @@ mod tests {
");
assert_method_hir("test", "
bb0(v0, v1):
- v5 = PatchPoint BOPRedefined(INTEGER_REDEFINED_OP_FLAG, BOP_MOD)
+ PatchPoint BOPRedefined(INTEGER_REDEFINED_OP_FLAG, BOP_MOD)
v6 = GuardType v0, Fixnum
v7 = GuardType v1, Fixnum
- v8 = FixnumMod v6, v7
- v10 = Return v8
+ v8:Fixnum = FixnumMod v6, v7
+ Return v8
");
}
@@ -1232,11 +1628,11 @@ mod tests {
");
assert_method_hir("test", "
bb0(v0, v1):
- v5 = PatchPoint BOPRedefined(INTEGER_REDEFINED_OP_FLAG, BOP_EQ)
+ PatchPoint BOPRedefined(INTEGER_REDEFINED_OP_FLAG, BOP_EQ)
v6 = GuardType v0, Fixnum
v7 = GuardType v1, Fixnum
- v8 = FixnumEq v6, v7
- v10 = Return v8
+ v8:BoolExact = FixnumEq v6, v7
+ Return v8
");
}
@@ -1248,11 +1644,11 @@ mod tests {
");
assert_method_hir("test", "
bb0(v0, v1):
- v5 = PatchPoint BOPRedefined(INTEGER_REDEFINED_OP_FLAG, BOP_NEQ)
+ PatchPoint BOPRedefined(INTEGER_REDEFINED_OP_FLAG, BOP_NEQ)
v6 = GuardType v0, Fixnum
v7 = GuardType v1, Fixnum
- v8 = FixnumNeq v6, v7
- v10 = Return v8
+ v8:BoolExact = FixnumNeq v6, v7
+ Return v8
");
}
@@ -1264,11 +1660,11 @@ mod tests {
");
assert_method_hir("test", "
bb0(v0, v1):
- v5 = PatchPoint BOPRedefined(INTEGER_REDEFINED_OP_FLAG, BOP_LT)
+ PatchPoint BOPRedefined(INTEGER_REDEFINED_OP_FLAG, BOP_LT)
v6 = GuardType v0, Fixnum
v7 = GuardType v1, Fixnum
- v8 = FixnumLt v6, v7
- v10 = Return v8
+ v8:BoolExact = FixnumLt v6, v7
+ Return v8
");
}
@@ -1280,11 +1676,11 @@ mod tests {
");
assert_method_hir("test", "
bb0(v0, v1):
- v5 = PatchPoint BOPRedefined(INTEGER_REDEFINED_OP_FLAG, BOP_LE)
+ PatchPoint BOPRedefined(INTEGER_REDEFINED_OP_FLAG, BOP_LE)
v6 = GuardType v0, Fixnum
v7 = GuardType v1, Fixnum
- v8 = FixnumLe v6, v7
- v10 = Return v8
+ v8:BoolExact = FixnumLe v6, v7
+ Return v8
");
}
@@ -1296,11 +1692,57 @@ mod tests {
");
assert_method_hir("test", "
bb0(v0, v1):
- v5 = PatchPoint BOPRedefined(INTEGER_REDEFINED_OP_FLAG, BOP_GT)
+ PatchPoint BOPRedefined(INTEGER_REDEFINED_OP_FLAG, BOP_GT)
v6 = GuardType v0, Fixnum
v7 = GuardType v1, Fixnum
- v8 = FixnumGt v6, v7
- v10 = Return v8
+ v8:BoolExact = FixnumGt v6, v7
+ Return v8
+ ");
+ }
+
+ #[test]
+ fn test_loop() {
+ eval("
+ def test
+ result = 0
+ times = 10
+ while times > 0
+ result = result + 1
+ times = times - 1
+ end
+ result
+ end
+ test
+ ");
+ assert_method_hir("test", "
+ bb0():
+ v0:NilClassExact = Const Value(nil)
+ v1:NilClassExact = Const Value(nil)
+ v3:Fixnum[0] = Const Value(0)
+ v6:Fixnum[10] = Const Value(10)
+ Jump bb2(v3, v6)
+ bb1(v29:Fixnum, v30:Fixnum):
+ v33:Fixnum[1] = Const Value(1)
+ PatchPoint BOPRedefined(INTEGER_REDEFINED_OP_FLAG, BOP_PLUS)
+ v36:Fixnum = GuardType v29, Fixnum
+ v37:Fixnum[1] = GuardType v33, Fixnum
+ v38:Fixnum = FixnumAdd v36, v37
+ v42:Fixnum[1] = Const Value(1)
+ PatchPoint BOPRedefined(INTEGER_REDEFINED_OP_FLAG, BOP_MINUS)
+ v45:Fixnum = GuardType v30, Fixnum
+ v46:Fixnum[1] = GuardType v42, Fixnum
+ v47:Fixnum = FixnumSub v45, v46
+ Jump bb2(v38, v47)
+ bb2(v10:Fixnum, v11:Fixnum):
+ v14:Fixnum[0] = Const Value(0)
+ PatchPoint BOPRedefined(INTEGER_REDEFINED_OP_FLAG, BOP_GT)
+ v17:Fixnum = GuardType v11, Fixnum
+ v18:Fixnum[0] = GuardType v14, Fixnum
+ v19:BoolExact = FixnumGt v17, v18
+ v21:CBool = Test v19
+ IfTrue v21, bb1(v10, v11)
+ v24:NilClassExact = Const Value(nil)
+ Return v10
");
}
@@ -1312,11 +1754,37 @@ mod tests {
");
assert_method_hir("test", "
bb0(v0, v1):
- v5 = PatchPoint BOPRedefined(INTEGER_REDEFINED_OP_FLAG, BOP_GE)
+ PatchPoint BOPRedefined(INTEGER_REDEFINED_OP_FLAG, BOP_GE)
v6 = GuardType v0, Fixnum
v7 = GuardType v1, Fixnum
- v8 = FixnumGe v6, v7
- v10 = Return v8
+ v8:BoolExact = FixnumGe v6, v7
+ Return v8
+ ");
+ }
+
+ #[test]
+ fn test_display_types() {
+ eval("
+ def test
+ cond = true
+ if cond
+ 3
+ else
+ 4
+ end
+ end
");
+ assert_method_hir("test", "
+ bb0():
+ v0:NilClassExact = Const Value(nil)
+ v2:TrueClassExact = Const Value(true)
+ v6:CBool[true] = Test v2
+ IfFalse v6, bb1(v2)
+ v9:Fixnum[3] = Const Value(3)
+ Return v9
+ bb1(v12):
+ v14 = Const Value(4)
+ Return v14
+ ");
}
}
diff --git a/zjit/src/hir_type.inc.rs b/zjit/src/hir_type.inc.rs
index 3921ec042d..b55c0c7d83 100644
--- a/zjit/src/hir_type.inc.rs
+++ b/zjit/src/hir_type.inc.rs
@@ -1,11 +1,12 @@
// This file is @generated by src/gen_hir_type.rb.
mod bits {
- pub const Any: u64 = Object | Primitive | User;
+ pub const Any: u64 = Object | Primitive;
pub const Array: u64 = ArrayExact | ArrayUser;
pub const ArrayExact: u64 = 1u64 << 0;
pub const ArrayUser: u64 = 1u64 << 1;
pub const Bignum: u64 = 1u64 << 2;
- pub const BuiltinExact: u64 = ArrayExact | FalseClassExact | FloatExact | HashExact | IntegerExact | NilClassExact | StringExact | SymbolExact | TrueClassExact;
+ pub const BoolExact: u64 = FalseClassExact | TrueClassExact;
+ pub const BuiltinExact: u64 = ArrayExact | FalseClassExact | FloatExact | HashExact | IntegerExact | NilClassExact | ObjectExact | StringExact | SymbolExact | TrueClassExact;
pub const CBool: u64 = 1u64 << 3;
pub const CDouble: u64 = 1u64 << 4;
pub const CInt: u64 = CSigned | CUnsigned;
@@ -41,29 +42,29 @@ mod bits {
pub const NilClass: u64 = NilClassExact | NilClassUser;
pub const NilClassExact: u64 = 1u64 << 25;
pub const NilClassUser: u64 = 1u64 << 26;
- pub const Object: u64 = BuiltinExact | ObjectExact | ObjectUser;
+ pub const Object: u64 = Array | FalseClass | Float | Hash | Integer | NilClass | ObjectExact | ObjectUser | String | Symbol | TrueClass;
pub const ObjectExact: u64 = 1u64 << 27;
- pub const ObjectUser: u64 = Array | FalseClass | Float | Hash | Integer | NilClass | String | Symbol | TrueClass;
+ pub const ObjectUser: u64 = 1u64 << 28;
pub const Primitive: u64 = CBool | CDouble | CInt | CNull | CPtr;
- pub const StaticSymbol: u64 = 1u64 << 28;
+ pub const StaticSymbol: u64 = 1u64 << 29;
pub const String: u64 = StringExact | StringUser;
- pub const StringExact: u64 = 1u64 << 29;
- pub const StringUser: u64 = 1u64 << 30;
+ pub const StringExact: u64 = 1u64 << 30;
+ pub const StringUser: u64 = 1u64 << 31;
pub const Symbol: u64 = SymbolExact | SymbolUser;
pub const SymbolExact: u64 = DynamicSymbol | StaticSymbol;
- pub const SymbolUser: u64 = 1u64 << 31;
+ pub const SymbolUser: u64 = 1u64 << 32;
pub const TrueClass: u64 = TrueClassExact | TrueClassUser;
- pub const TrueClassExact: u64 = 1u64 << 32;
- pub const TrueClassUser: u64 = 1u64 << 33;
- pub const User: u64 = ArrayUser | FalseClassUser | FloatUser | HashUser | IntegerUser | NilClassUser | StringUser | SymbolUser | TrueClassUser;
- pub const AllBitPatterns: [(&'static str, u64); 56] = [
+ pub const TrueClassExact: u64 = 1u64 << 33;
+ pub const TrueClassUser: u64 = 1u64 << 34;
+ pub const User: u64 = ArrayUser | FalseClassUser | FloatUser | HashUser | IntegerUser | NilClassUser | ObjectUser | StringUser | SymbolUser | TrueClassUser;
+ pub const AllBitPatterns: [(&'static str, u64); 57] = [
("Any", Any),
("Object", Object),
- ("ObjectUser", ObjectUser),
("TrueClass", TrueClass),
("User", User),
("TrueClassUser", TrueClassUser),
("BuiltinExact", BuiltinExact),
+ ("BoolExact", BoolExact),
("TrueClassExact", TrueClassExact),
("Symbol", Symbol),
("SymbolUser", SymbolUser),
@@ -72,6 +73,7 @@ mod bits {
("StringExact", StringExact),
("SymbolExact", SymbolExact),
("StaticSymbol", StaticSymbol),
+ ("ObjectUser", ObjectUser),
("ObjectExact", ObjectExact),
("NilClass", NilClass),
("NilClassUser", NilClassUser),
@@ -114,7 +116,7 @@ mod bits {
("ArrayExact", ArrayExact),
("Empty", Empty),
];
- pub const NumTypeBits: u64 = 34;
+ pub const NumTypeBits: u64 = 35;
}
pub mod types {
use super::*;
@@ -123,6 +125,7 @@ pub mod types {
pub const ArrayExact: Type = Type::from_bits(bits::ArrayExact);
pub const ArrayUser: Type = Type::from_bits(bits::ArrayUser);
pub const Bignum: Type = Type::from_bits(bits::Bignum);
+ pub const BoolExact: Type = Type::from_bits(bits::BoolExact);
pub const BuiltinExact: Type = Type::from_bits(bits::BuiltinExact);
pub const CBool: Type = Type::from_bits(bits::CBool);
pub const CDouble: Type = Type::from_bits(bits::CDouble);
diff --git a/zjit/src/hir_type.rs b/zjit/src/hir_type.rs
index 4b6b2cbb23..de52483fc1 100644
--- a/zjit/src/hir_type.rs
+++ b/zjit/src/hir_type.rs
@@ -194,6 +194,11 @@ impl Type {
Type { bits: ty.bits, spec: Specialization::Int(val as u64) }
}
+ /// Create a `Type` (a `CDouble` with double specialization) from a f64.
+ pub fn from_double(val: f64) -> Type {
+ Type { bits: bits::CDouble, spec: Specialization::Double(val) }
+ }
+
/// Create a `Type` from a primitive boolean.
pub fn from_cbool(val: bool) -> Type {
Type { bits: bits::CBool, spec: Specialization::Int(val as u64) }
@@ -226,6 +231,7 @@ impl Type {
false
}
+ /// Union both types together, preserving specialization if possible.
pub fn union(&self, other: Type) -> Type {
// Easy cases first
if self.is_subtype(other) { return other; }
@@ -256,6 +262,19 @@ impl Type {
Type { bits, spec: Specialization::Type(super_class) }
}
+ /// Intersect both types, preserving specialization if possible.
+ pub fn intersection(&self, other: Type) -> Type {
+ let bits = self.bits & other.bits;
+ if bits == bits::Empty { return types::Empty; }
+ if self.spec_is_subtype_of(other) { return Type { bits, spec: self.spec }; }
+ if other.spec_is_subtype_of(*self) { return Type { bits, spec: other.spec }; }
+ types::Empty
+ }
+
+ pub fn could_be(&self, other: Type) -> bool {
+ !self.intersection(other).bit_equal(types::Empty)
+ }
+
/// Check if the type field of `self` is a subtype of the type field of `other` and also check
/// if the specialization of `self` is a subtype of the specialization of `other`.
pub fn is_subtype(&self, other: Type) -> bool {
@@ -298,6 +317,11 @@ impl Type {
}
}
+ /// Check bit equality of two `Type`s. Do not use! You are probably looking for [`Type::is_subtype`].
+ pub fn bit_equal(&self, other: Type) -> bool {
+ self.bits == other.bits && self.spec == other.spec
+ }
+
/// Check *only* if `self`'s specialization is a subtype of `other`'s specialization. Private.
/// You probably want [`Type::is_subtype`] instead.
fn spec_is_subtype_of(&self, other: Type) -> bool {