summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMax Bernstein <rubybugs@bernsteinbear.com>2025-10-28 20:11:27 -0400
committerGitHub <noreply@github.com>2025-10-29 00:11:27 +0000
commitf8a333ae193017999b38f6a4838582cc2c333063 (patch)
tree8bc50d52aa3a37ed7f83d2df58b213ebe39c7f30
parent80be97e4a2c878d7c5a129b245f1e2430b99b19b (diff)
ZJIT: Add type checker to HIR (#14978)
Allow instructions to constrain their operands' input types to avoid accidentally creating invalid HIR.
-rw-r--r--zjit/src/hir.rs162
-rw-r--r--zjit/src/stats.rs4
2 files changed, 166 insertions, 0 deletions
diff --git a/zjit/src/hir.rs b/zjit/src/hir.rs
index 3f68764722..b284ae6c11 100644
--- a/zjit/src/hir.rs
+++ b/zjit/src/hir.rs
@@ -1356,6 +1356,9 @@ pub enum ValidationError {
OperandNotDefined(BlockId, InsnId, InsnId),
/// The offending block and instruction
DuplicateInstruction(BlockId, InsnId),
+ /// The offending instruction, its operand, expected type string, actual type string
+ MismatchedOperandType(InsnId, InsnId, String, String),
+ MiscValidationError(InsnId, String),
}
fn can_direct_send(iseq: *const rb_iseq_t) -> bool {
@@ -3565,11 +3568,170 @@ impl Function {
Ok(())
}
+ fn assert_subtype(&self, user: InsnId, operand: InsnId, expected: Type) -> Result<(), ValidationError> {
+ let actual = self.type_of(operand);
+ if !actual.is_subtype(expected) {
+ return Err(ValidationError::MismatchedOperandType(user, operand, format!("{}", expected), format!("{}", actual)));
+ }
+ Ok(())
+ }
+
+ fn validate_insn_type(&self, insn_id: InsnId) -> Result<(), ValidationError> {
+ let insn_id = self.union_find.borrow().find_const(insn_id);
+ let insn = self.find(insn_id);
+ match insn {
+ Insn::StringCopy { val, .. } => self.assert_subtype(insn_id, val, types::StringExact),
+ Insn::StringIntern { val, .. } => self.assert_subtype(insn_id, val, types::StringExact),
+ Insn::ArrayDup { val, .. } => self.assert_subtype(insn_id, val, types::ArrayExact),
+ Insn::StringAppend { recv, other, .. } => {
+ self.assert_subtype(insn_id, recv, types::StringExact)?;
+ self.assert_subtype(insn_id, other, types::String)
+ }
+ Insn::NewHash { ref elements, .. } => {
+ if elements.len() % 2 != 0 {
+ return Err(ValidationError::MiscValidationError(insn_id, "NewHash elements length is not even".to_string()));
+ }
+ Ok(())
+ }
+ Insn::NewRangeFixnum { low, high, .. } => {
+ self.assert_subtype(insn_id, low, types::Fixnum)?;
+ self.assert_subtype(insn_id, high, types::Fixnum)
+ }
+ Insn::ArrayExtend { left, right, .. } => {
+ // TODO(max): Do left and right need to be ArrayExact?
+ self.assert_subtype(insn_id, left, types::Array)?;
+ self.assert_subtype(insn_id, right, types::Array)
+ }
+ Insn::ArrayPush { array, .. } => self.assert_subtype(insn_id, array, types::Array),
+ Insn::ArrayPop { array, .. } => self.assert_subtype(insn_id, array, types::Array),
+ Insn::ArrayLength { array, .. } => self.assert_subtype(insn_id, array, types::Array),
+ Insn::HashAref { hash, .. } => self.assert_subtype(insn_id, hash, types::Hash),
+ Insn::HashDup { val, .. } => self.assert_subtype(insn_id, val, types::HashExact),
+ Insn::ObjectAllocClass { class, .. } => {
+ let has_leaf_allocator = unsafe { rb_zjit_class_has_default_allocator(class) } || class_has_leaf_allocator(class);
+ if !has_leaf_allocator {
+ return Err(ValidationError::MiscValidationError(insn_id, "ObjectAllocClass must have leaf allocator".to_string()));
+ }
+ Ok(())
+ }
+ Insn::Test { val } => self.assert_subtype(insn_id, val, types::BasicObject),
+ Insn::IsNil { val } => self.assert_subtype(insn_id, val, types::BasicObject),
+ Insn::IsMethodCfunc { val, .. } => self.assert_subtype(insn_id, val, types::BasicObject),
+ Insn::IsBitEqual { left, right }
+ | Insn::IsBitNotEqual { left, right } => {
+ if self.is_a(left, types::CInt) && self.is_a(right, types::CInt) {
+ // TODO(max): Check that int sizes match
+ Ok(())
+ } else if self.is_a(left, types::CPtr) && self.is_a(right, types::CPtr) {
+ Ok(())
+ } else if self.is_a(left, types::RubyValue) && self.is_a(right, types::RubyValue) {
+ Ok(())
+ } else {
+ return Err(ValidationError::MiscValidationError(insn_id, "IsBitEqual can only compare CInt/CInt or RubyValue/RubyValue".to_string()));
+ }
+ }
+ Insn::BoxBool { val } => self.assert_subtype(insn_id, val, types::CBool),
+ Insn::SetGlobal { val, .. } => self.assert_subtype(insn_id, val, types::BasicObject),
+ Insn::GetIvar { self_val, .. } => self.assert_subtype(insn_id, self_val, types::BasicObject),
+ Insn::SetIvar { self_val, val, .. } => {
+ self.assert_subtype(insn_id, self_val, types::BasicObject)?;
+ self.assert_subtype(insn_id, val, types::BasicObject)
+ }
+ Insn::DefinedIvar { self_val, .. } => self.assert_subtype(insn_id, self_val, types::BasicObject),
+ Insn::LoadIvarEmbedded { self_val, .. } => self.assert_subtype(insn_id, self_val, types::HeapBasicObject),
+ Insn::LoadIvarExtended { self_val, .. } => self.assert_subtype(insn_id, self_val, types::HeapBasicObject),
+ Insn::SetLocal { val, .. } => self.assert_subtype(insn_id, val, types::BasicObject),
+ Insn::SetClassVar { val, .. } => self.assert_subtype(insn_id, val, types::BasicObject),
+ Insn::IfTrue { val, .. } | Insn::IfFalse { val, .. } => self.assert_subtype(insn_id, val, types::CBool),
+ Insn::SendWithoutBlock { recv, ref args, .. }
+ | Insn::SendWithoutBlockDirect { recv, ref args, .. }
+ | Insn::Send { recv, ref args, .. }
+ | Insn::SendForward { recv, ref args, .. }
+ | Insn::InvokeSuper { recv, ref args, .. }
+ | Insn::CCallVariadic { recv, ref args, .. } => {
+ self.assert_subtype(insn_id, recv, types::BasicObject)?;
+ for &arg in args {
+ self.assert_subtype(insn_id, arg, types::BasicObject)?;
+ }
+ Ok(())
+ }
+ Insn::CCallWithFrame { ref args, .. }
+ | Insn::InvokeBuiltin { ref args, .. }
+ | Insn::InvokeBlock { ref args, .. } => {
+ for &arg in args {
+ self.assert_subtype(insn_id, arg, types::BasicObject)?;
+ }
+ Ok(())
+ }
+ Insn::Return { val } => self.assert_subtype(insn_id, val, types::BasicObject),
+ Insn::Throw { val, .. } => self.assert_subtype(insn_id, val, types::BasicObject),
+ Insn::FixnumAdd { left, right, .. }
+ | Insn::FixnumSub { left, right, .. }
+ | Insn::FixnumMult { left, right, .. }
+ | Insn::FixnumDiv { left, right, .. }
+ | Insn::FixnumMod { left, right, .. }
+ | Insn::FixnumEq { left, right }
+ | Insn::FixnumNeq { left, right }
+ | Insn::FixnumLt { left, right }
+ | Insn::FixnumLe { left, right }
+ | Insn::FixnumGt { left, right }
+ | Insn::FixnumGe { left, right }
+ | Insn::FixnumAnd { left, right }
+ | Insn::FixnumOr { left, right }
+ | Insn::FixnumXor { left, right }
+ => {
+ self.assert_subtype(insn_id, left, types::Fixnum)?;
+ self.assert_subtype(insn_id, right, types::Fixnum)
+ }
+ Insn::ObjToString { val, .. } => self.assert_subtype(insn_id, val, types::BasicObject),
+ Insn::AnyToString { val, str, .. } => {
+ self.assert_subtype(insn_id, val, types::BasicObject)?;
+ self.assert_subtype(insn_id, str, types::BasicObject)
+ }
+ Insn::GuardType { val, .. } => self.assert_subtype(insn_id, val, types::BasicObject),
+ Insn::GuardTypeNot { val, .. } => self.assert_subtype(insn_id, val, types::BasicObject),
+ Insn::GuardBitEquals { val, expected, .. } => {
+ match expected {
+ Const::Value(_) => self.assert_subtype(insn_id, val, types::RubyValue),
+ Const::CInt8(_) => self.assert_subtype(insn_id, val, types::CInt8),
+ Const::CInt16(_) => self.assert_subtype(insn_id, val, types::CInt16),
+ Const::CInt32(_) => self.assert_subtype(insn_id, val, types::CInt32),
+ Const::CInt64(_) => self.assert_subtype(insn_id, val, types::CInt64),
+ Const::CUInt8(_) => self.assert_subtype(insn_id, val, types::CUInt8),
+ Const::CUInt16(_) => self.assert_subtype(insn_id, val, types::CUInt16),
+ Const::CUInt32(_) => self.assert_subtype(insn_id, val, types::CUInt32),
+ Const::CUInt64(_) => self.assert_subtype(insn_id, val, types::CUInt64),
+ Const::CBool(_) => self.assert_subtype(insn_id, val, types::CBool),
+ Const::CDouble(_) => self.assert_subtype(insn_id, val, types::CDouble),
+ Const::CPtr(_) => self.assert_subtype(insn_id, val, types::CPtr),
+ }
+ }
+ Insn::GuardShape { val, .. } => self.assert_subtype(insn_id, val, types::BasicObject),
+ Insn::GuardNotFrozen { val, .. } => self.assert_subtype(insn_id, val, types::BasicObject),
+ Insn::StringGetbyteFixnum { string, index } => {
+ self.assert_subtype(insn_id, string, types::String)?;
+ self.assert_subtype(insn_id, index, types::Fixnum)
+ }
+ _ => Ok(()),
+ }
+ }
+
+ /// Check that insn types match the expected types for each instruction.
+ fn validate_types(&self) -> Result<(), ValidationError> {
+ for block_id in self.rpo() {
+ for &insn_id in &self.blocks[block_id.0].insns {
+ self.validate_insn_type(insn_id)?;
+ }
+ }
+ Ok(())
+ }
+
/// Run all validation passes we have.
pub fn validate(&self) -> Result<(), ValidationError> {
self.validate_block_terminators_and_jumps()?;
self.validate_definite_assignment()?;
self.validate_insn_uniqueness()?;
+ self.validate_types()?;
Ok(())
}
}
diff --git a/zjit/src/stats.rs b/zjit/src/stats.rs
index 4874d0fe64..9965526b76 100644
--- a/zjit/src/stats.rs
+++ b/zjit/src/stats.rs
@@ -202,6 +202,8 @@ make_counters! {
compile_error_validation_jump_target_not_in_rpo,
compile_error_validation_operand_not_defined,
compile_error_validation_duplicate_instruction,
+ compile_error_validation_type_check_failure,
+ compile_error_validation_misc_validation_error,
// The number of times YARV instructions are executed on JIT code
zjit_insn_count,
@@ -320,6 +322,8 @@ pub fn exit_counter_for_compile_error(compile_error: &CompileError) -> Counter {
JumpTargetNotInRPO(_) => compile_error_validation_jump_target_not_in_rpo,
OperandNotDefined(_, _, _) => compile_error_validation_operand_not_defined,
DuplicateInstruction(_, _) => compile_error_validation_duplicate_instruction,
+ MismatchedOperandType(..) => compile_error_validation_type_check_failure,
+ MiscValidationError(..) => compile_error_validation_misc_validation_error,
},
}
}