diff options
| author | Max Bernstein <rubybugs@bernsteinbear.com> | 2025-10-28 20:11:27 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-10-29 00:11:27 +0000 |
| commit | f8a333ae193017999b38f6a4838582cc2c333063 (patch) | |
| tree | 8bc50d52aa3a37ed7f83d2df58b213ebe39c7f30 | |
| parent | 80be97e4a2c878d7c5a129b245f1e2430b99b19b (diff) | |
ZJIT: Add type checker to HIR (#14978)
Allow instructions to constrain their operands' input types to avoid
accidentally creating invalid HIR.
| -rw-r--r-- | zjit/src/hir.rs | 162 | ||||
| -rw-r--r-- | zjit/src/stats.rs | 4 |
2 files changed, 166 insertions, 0 deletions
diff --git a/zjit/src/hir.rs b/zjit/src/hir.rs index 3f68764722..b284ae6c11 100644 --- a/zjit/src/hir.rs +++ b/zjit/src/hir.rs @@ -1356,6 +1356,9 @@ pub enum ValidationError { OperandNotDefined(BlockId, InsnId, InsnId), /// The offending block and instruction DuplicateInstruction(BlockId, InsnId), + /// The offending instruction, its operand, expected type string, actual type string + MismatchedOperandType(InsnId, InsnId, String, String), + MiscValidationError(InsnId, String), } fn can_direct_send(iseq: *const rb_iseq_t) -> bool { @@ -3565,11 +3568,170 @@ impl Function { Ok(()) } + fn assert_subtype(&self, user: InsnId, operand: InsnId, expected: Type) -> Result<(), ValidationError> { + let actual = self.type_of(operand); + if !actual.is_subtype(expected) { + return Err(ValidationError::MismatchedOperandType(user, operand, format!("{}", expected), format!("{}", actual))); + } + Ok(()) + } + + fn validate_insn_type(&self, insn_id: InsnId) -> Result<(), ValidationError> { + let insn_id = self.union_find.borrow().find_const(insn_id); + let insn = self.find(insn_id); + match insn { + Insn::StringCopy { val, .. } => self.assert_subtype(insn_id, val, types::StringExact), + Insn::StringIntern { val, .. } => self.assert_subtype(insn_id, val, types::StringExact), + Insn::ArrayDup { val, .. } => self.assert_subtype(insn_id, val, types::ArrayExact), + Insn::StringAppend { recv, other, .. } => { + self.assert_subtype(insn_id, recv, types::StringExact)?; + self.assert_subtype(insn_id, other, types::String) + } + Insn::NewHash { ref elements, .. } => { + if elements.len() % 2 != 0 { + return Err(ValidationError::MiscValidationError(insn_id, "NewHash elements length is not even".to_string())); + } + Ok(()) + } + Insn::NewRangeFixnum { low, high, .. } => { + self.assert_subtype(insn_id, low, types::Fixnum)?; + self.assert_subtype(insn_id, high, types::Fixnum) + } + Insn::ArrayExtend { left, right, .. } => { + // TODO(max): Do left and right need to be ArrayExact? + self.assert_subtype(insn_id, left, types::Array)?; + self.assert_subtype(insn_id, right, types::Array) + } + Insn::ArrayPush { array, .. } => self.assert_subtype(insn_id, array, types::Array), + Insn::ArrayPop { array, .. } => self.assert_subtype(insn_id, array, types::Array), + Insn::ArrayLength { array, .. } => self.assert_subtype(insn_id, array, types::Array), + Insn::HashAref { hash, .. } => self.assert_subtype(insn_id, hash, types::Hash), + Insn::HashDup { val, .. } => self.assert_subtype(insn_id, val, types::HashExact), + Insn::ObjectAllocClass { class, .. } => { + let has_leaf_allocator = unsafe { rb_zjit_class_has_default_allocator(class) } || class_has_leaf_allocator(class); + if !has_leaf_allocator { + return Err(ValidationError::MiscValidationError(insn_id, "ObjectAllocClass must have leaf allocator".to_string())); + } + Ok(()) + } + Insn::Test { val } => self.assert_subtype(insn_id, val, types::BasicObject), + Insn::IsNil { val } => self.assert_subtype(insn_id, val, types::BasicObject), + Insn::IsMethodCfunc { val, .. } => self.assert_subtype(insn_id, val, types::BasicObject), + Insn::IsBitEqual { left, right } + | Insn::IsBitNotEqual { left, right } => { + if self.is_a(left, types::CInt) && self.is_a(right, types::CInt) { + // TODO(max): Check that int sizes match + Ok(()) + } else if self.is_a(left, types::CPtr) && self.is_a(right, types::CPtr) { + Ok(()) + } else if self.is_a(left, types::RubyValue) && self.is_a(right, types::RubyValue) { + Ok(()) + } else { + return Err(ValidationError::MiscValidationError(insn_id, "IsBitEqual can only compare CInt/CInt or RubyValue/RubyValue".to_string())); + } + } + Insn::BoxBool { val } => self.assert_subtype(insn_id, val, types::CBool), + Insn::SetGlobal { val, .. } => self.assert_subtype(insn_id, val, types::BasicObject), + Insn::GetIvar { self_val, .. } => self.assert_subtype(insn_id, self_val, types::BasicObject), + Insn::SetIvar { self_val, val, .. } => { + self.assert_subtype(insn_id, self_val, types::BasicObject)?; + self.assert_subtype(insn_id, val, types::BasicObject) + } + Insn::DefinedIvar { self_val, .. } => self.assert_subtype(insn_id, self_val, types::BasicObject), + Insn::LoadIvarEmbedded { self_val, .. } => self.assert_subtype(insn_id, self_val, types::HeapBasicObject), + Insn::LoadIvarExtended { self_val, .. } => self.assert_subtype(insn_id, self_val, types::HeapBasicObject), + Insn::SetLocal { val, .. } => self.assert_subtype(insn_id, val, types::BasicObject), + Insn::SetClassVar { val, .. } => self.assert_subtype(insn_id, val, types::BasicObject), + Insn::IfTrue { val, .. } | Insn::IfFalse { val, .. } => self.assert_subtype(insn_id, val, types::CBool), + Insn::SendWithoutBlock { recv, ref args, .. } + | Insn::SendWithoutBlockDirect { recv, ref args, .. } + | Insn::Send { recv, ref args, .. } + | Insn::SendForward { recv, ref args, .. } + | Insn::InvokeSuper { recv, ref args, .. } + | Insn::CCallVariadic { recv, ref args, .. } => { + self.assert_subtype(insn_id, recv, types::BasicObject)?; + for &arg in args { + self.assert_subtype(insn_id, arg, types::BasicObject)?; + } + Ok(()) + } + Insn::CCallWithFrame { ref args, .. } + | Insn::InvokeBuiltin { ref args, .. } + | Insn::InvokeBlock { ref args, .. } => { + for &arg in args { + self.assert_subtype(insn_id, arg, types::BasicObject)?; + } + Ok(()) + } + Insn::Return { val } => self.assert_subtype(insn_id, val, types::BasicObject), + Insn::Throw { val, .. } => self.assert_subtype(insn_id, val, types::BasicObject), + Insn::FixnumAdd { left, right, .. } + | Insn::FixnumSub { left, right, .. } + | Insn::FixnumMult { left, right, .. } + | Insn::FixnumDiv { left, right, .. } + | Insn::FixnumMod { left, right, .. } + | Insn::FixnumEq { left, right } + | Insn::FixnumNeq { left, right } + | Insn::FixnumLt { left, right } + | Insn::FixnumLe { left, right } + | Insn::FixnumGt { left, right } + | Insn::FixnumGe { left, right } + | Insn::FixnumAnd { left, right } + | Insn::FixnumOr { left, right } + | Insn::FixnumXor { left, right } + => { + self.assert_subtype(insn_id, left, types::Fixnum)?; + self.assert_subtype(insn_id, right, types::Fixnum) + } + Insn::ObjToString { val, .. } => self.assert_subtype(insn_id, val, types::BasicObject), + Insn::AnyToString { val, str, .. } => { + self.assert_subtype(insn_id, val, types::BasicObject)?; + self.assert_subtype(insn_id, str, types::BasicObject) + } + Insn::GuardType { val, .. } => self.assert_subtype(insn_id, val, types::BasicObject), + Insn::GuardTypeNot { val, .. } => self.assert_subtype(insn_id, val, types::BasicObject), + Insn::GuardBitEquals { val, expected, .. } => { + match expected { + Const::Value(_) => self.assert_subtype(insn_id, val, types::RubyValue), + Const::CInt8(_) => self.assert_subtype(insn_id, val, types::CInt8), + Const::CInt16(_) => self.assert_subtype(insn_id, val, types::CInt16), + Const::CInt32(_) => self.assert_subtype(insn_id, val, types::CInt32), + Const::CInt64(_) => self.assert_subtype(insn_id, val, types::CInt64), + Const::CUInt8(_) => self.assert_subtype(insn_id, val, types::CUInt8), + Const::CUInt16(_) => self.assert_subtype(insn_id, val, types::CUInt16), + Const::CUInt32(_) => self.assert_subtype(insn_id, val, types::CUInt32), + Const::CUInt64(_) => self.assert_subtype(insn_id, val, types::CUInt64), + Const::CBool(_) => self.assert_subtype(insn_id, val, types::CBool), + Const::CDouble(_) => self.assert_subtype(insn_id, val, types::CDouble), + Const::CPtr(_) => self.assert_subtype(insn_id, val, types::CPtr), + } + } + Insn::GuardShape { val, .. } => self.assert_subtype(insn_id, val, types::BasicObject), + Insn::GuardNotFrozen { val, .. } => self.assert_subtype(insn_id, val, types::BasicObject), + Insn::StringGetbyteFixnum { string, index } => { + self.assert_subtype(insn_id, string, types::String)?; + self.assert_subtype(insn_id, index, types::Fixnum) + } + _ => Ok(()), + } + } + + /// Check that insn types match the expected types for each instruction. + fn validate_types(&self) -> Result<(), ValidationError> { + for block_id in self.rpo() { + for &insn_id in &self.blocks[block_id.0].insns { + self.validate_insn_type(insn_id)?; + } + } + Ok(()) + } + /// Run all validation passes we have. pub fn validate(&self) -> Result<(), ValidationError> { self.validate_block_terminators_and_jumps()?; self.validate_definite_assignment()?; self.validate_insn_uniqueness()?; + self.validate_types()?; Ok(()) } } diff --git a/zjit/src/stats.rs b/zjit/src/stats.rs index 4874d0fe64..9965526b76 100644 --- a/zjit/src/stats.rs +++ b/zjit/src/stats.rs @@ -202,6 +202,8 @@ make_counters! { compile_error_validation_jump_target_not_in_rpo, compile_error_validation_operand_not_defined, compile_error_validation_duplicate_instruction, + compile_error_validation_type_check_failure, + compile_error_validation_misc_validation_error, // The number of times YARV instructions are executed on JIT code zjit_insn_count, @@ -320,6 +322,8 @@ pub fn exit_counter_for_compile_error(compile_error: &CompileError) -> Counter { JumpTargetNotInRPO(_) => compile_error_validation_jump_target_not_in_rpo, OperandNotDefined(_, _, _) => compile_error_validation_operand_not_defined, DuplicateInstruction(_, _) => compile_error_validation_duplicate_instruction, + MismatchedOperandType(..) => compile_error_validation_type_check_failure, + MiscValidationError(..) => compile_error_validation_misc_validation_error, }, } } |
