diff options
| author | Ken Jin <kenjin4096@gmail.com> | 2025-07-07 23:45:01 +0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-07-07 11:45:01 -0400 |
| commit | c1937480acc0896531b30464951209250b6a581b (patch) | |
| tree | b15ca889ac409ab7f895ae41a1de6cab912cbaac | |
| parent | 0bb44f291e7fb4ec5802826d40a5a445e51ef959 (diff) | |
ZJIT: Add a simple HIR validator (#13780)
This PR adds a simple validator for ZJIT's HIR.
See issue https://github.com/Shopify/ruby/issues/591
| -rw-r--r-- | zjit/src/codegen.rs | 4 | ||||
| -rw-r--r-- | zjit/src/hir.rs | 127 |
2 files changed, 131 insertions, 0 deletions
diff --git a/zjit/src/codegen.rs b/zjit/src/codegen.rs index 306ba31aba..5e5ea40443 100644 --- a/zjit/src/codegen.rs +++ b/zjit/src/codegen.rs @@ -1085,6 +1085,10 @@ fn compile_iseq(iseq: IseqPtr) -> Option<Function> { } }; function.optimize(); + if let Err(err) = function.validate() { + debug!("ZJIT: compile_iseq: {err:?}"); + return None; + } Some(function) } diff --git a/zjit/src/hir.rs b/zjit/src/hir.rs index 32afebce13..df24b061f8 100644 --- a/zjit/src/hir.rs +++ b/zjit/src/hir.rs @@ -900,6 +900,17 @@ impl<T: Copy + Into<usize> + PartialEq> UnionFind<T> { } } +#[derive(Debug, PartialEq)] +pub enum ValidationError { + // All validation errors come with the function's representation as the first argument. + BlockHasNoTerminator(String, BlockId), + // The terminator and its actual position + TerminatorNotAtEnd(String, BlockId, InsnId, usize), + /// Expected length, actual length + MismatchedBlockArity(String, BlockId, usize, usize), +} + + /// A [`Function`], which is analogous to a Ruby ISeq, is a control-flow graph of [`Block`]s /// containing instructions. #[derive(Debug)] @@ -2005,6 +2016,51 @@ impl Function { None => {}, } } + + + /// Validates the following: + /// 1. Basic block jump args match parameter arity. + /// 2. Every terminator must be in the last position. + /// 3. Every block must have a terminator. + fn validate_block_terminators_and_jumps(&self) -> Result<(), ValidationError> { + for block_id in self.rpo() { + let mut block_has_terminator = false; + let insns = &self.blocks[block_id.0].insns; + for (idx, insn_id) in insns.iter().enumerate() { + let insn = self.find(*insn_id); + match &insn { + Insn::Jump(BranchEdge{target, args}) + | Insn::IfTrue { val: _, target: BranchEdge{target, args} } + | Insn::IfFalse { val: _, target: BranchEdge{target, args}} => { + let target_block = &self.blocks[target.0]; + let target_len = target_block.params.len(); + let args_len = args.len(); + if target_len != args_len { + return Err(ValidationError::MismatchedBlockArity(format!("{:?}", self), block_id, target_len, args_len)) + } + } + _ => {} + } + if !insn.is_terminator() { + continue; + } + block_has_terminator = true; + if idx != insns.len() - 1 { + return Err(ValidationError::TerminatorNotAtEnd(format!("{:?}", self), block_id, *insn_id, idx)); + } + } + if !block_has_terminator { + return Err(ValidationError::BlockHasNoTerminator(format!("{:?}", self), block_id)); + } + } + Ok(()) + } + + /// Run all validation passes we have. + pub fn validate(&self) -> Result<(), ValidationError> { + self.validate_block_terminators_and_jumps()?; + Ok(()) + } } impl<'a> std::fmt::Display for FunctionPrinter<'a> { @@ -2241,6 +2297,7 @@ pub enum ParseError { StackUnderflow(FrameState), UnknownParameterType(ParameterType), MalformedIseq(u32), // insn_idx into iseq_encoded + Validation(ValidationError), } /// Return the number of locals in the current ISEQ (includes parameters) @@ -2966,6 +3023,9 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> { } fun.profiles = Some(profiles); + if let Err(e) = fun.validate() { + return Err(ParseError::Validation(e)); + } Ok(fun) } @@ -3070,6 +3130,72 @@ mod rpo_tests { } #[cfg(test)] +mod validation_tests { + use super::*; + + #[track_caller] + fn assert_matches_err(res: Result<(), ValidationError>, expected: ValidationError) { + match res { + Err(validation_err) => { + assert_eq!(validation_err, expected); + } + Ok(_) => assert!(false, "Expected validation error"), + } + } + + #[test] + fn one_block_no_terminator() { + let mut function = Function::new(std::ptr::null()); + let entry = function.entry_block; + let val = function.push_insn(entry, Insn::Const { val: Const::Value(Qnil) }); + assert_matches_err(function.validate(), ValidationError::BlockHasNoTerminator(format!("{:?}", function), entry)); + } + + #[test] + fn one_block_terminator_not_at_end() { + let mut function = Function::new(std::ptr::null()); + let entry = function.entry_block; + let val = function.push_insn(entry, Insn::Const { val: Const::Value(Qnil) }); + let insn_id = function.push_insn(entry, Insn::Return { val }); + function.push_insn(entry, Insn::Const { val: Const::Value(Qnil) }); + assert_matches_err(function.validate(), ValidationError::TerminatorNotAtEnd(format!("{:?}", function), entry, insn_id, 1)); + } + + #[test] + fn iftrue_mismatch_args() { + let mut function = Function::new(std::ptr::null()); + let entry = function.entry_block; + let side = function.new_block(); + let exit = function.new_block(); + let val = function.push_insn(entry, Insn::Const { val: Const::Value(Qnil) }); + function.push_insn(entry, Insn::IfTrue { val, target: BranchEdge { target: side, args: vec![val, val, val] } }); + assert_matches_err(function.validate(), ValidationError::MismatchedBlockArity(format!("{:?}", function), entry, 0, 3)); + } + + #[test] + fn iffalse_mismatch_args() { + let mut function = Function::new(std::ptr::null()); + let entry = function.entry_block; + let side = function.new_block(); + let exit = function.new_block(); + let val = function.push_insn(entry, Insn::Const { val: Const::Value(Qnil) }); + function.push_insn(entry, Insn::IfFalse { val, target: BranchEdge { target: side, args: vec![val, val, val] } }); + assert_matches_err(function.validate(), ValidationError::MismatchedBlockArity(format!("{:?}", function), entry, 0, 3)); + } + + #[test] + fn jump_mismatch_args() { + let mut function = Function::new(std::ptr::null()); + let entry = function.entry_block; + let side = function.new_block(); + let exit = function.new_block(); + let val = function.push_insn(entry, Insn::Const { val: Const::Value(Qnil) }); + function.push_insn(entry, Insn::Jump ( BranchEdge { target: side, args: vec![val, val, val] } )); + assert_matches_err(function.validate(), ValidationError::MismatchedBlockArity(format!("{:?}", function), entry, 0, 3)); + } +} + +#[cfg(test)] mod infer_tests { use super::*; @@ -4701,6 +4827,7 @@ mod opt_tests { unsafe { crate::cruby::rb_zjit_profile_disable(iseq) }; let mut function = iseq_to_hir(iseq).unwrap(); function.optimize(); + function.validate().unwrap(); assert_function_hir(function, hir); } |
