summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKen Jin <kenjin4096@gmail.com>2025-07-07 23:45:01 +0800
committerGitHub <noreply@github.com>2025-07-07 11:45:01 -0400
commitc1937480acc0896531b30464951209250b6a581b (patch)
treeb15ca889ac409ab7f895ae41a1de6cab912cbaac
parent0bb44f291e7fb4ec5802826d40a5a445e51ef959 (diff)
ZJIT: Add a simple HIR validator (#13780)
This PR adds a simple validator for ZJIT's HIR. See issue https://github.com/Shopify/ruby/issues/591
-rw-r--r--zjit/src/codegen.rs4
-rw-r--r--zjit/src/hir.rs127
2 files changed, 131 insertions, 0 deletions
diff --git a/zjit/src/codegen.rs b/zjit/src/codegen.rs
index 306ba31aba..5e5ea40443 100644
--- a/zjit/src/codegen.rs
+++ b/zjit/src/codegen.rs
@@ -1085,6 +1085,10 @@ fn compile_iseq(iseq: IseqPtr) -> Option<Function> {
}
};
function.optimize();
+ if let Err(err) = function.validate() {
+ debug!("ZJIT: compile_iseq: {err:?}");
+ return None;
+ }
Some(function)
}
diff --git a/zjit/src/hir.rs b/zjit/src/hir.rs
index 32afebce13..df24b061f8 100644
--- a/zjit/src/hir.rs
+++ b/zjit/src/hir.rs
@@ -900,6 +900,17 @@ impl<T: Copy + Into<usize> + PartialEq> UnionFind<T> {
}
}
+#[derive(Debug, PartialEq)]
+pub enum ValidationError {
+ // All validation errors come with the function's representation as the first argument.
+ BlockHasNoTerminator(String, BlockId),
+ // The terminator and its actual position
+ TerminatorNotAtEnd(String, BlockId, InsnId, usize),
+ /// Expected length, actual length
+ MismatchedBlockArity(String, BlockId, usize, usize),
+}
+
+
/// A [`Function`], which is analogous to a Ruby ISeq, is a control-flow graph of [`Block`]s
/// containing instructions.
#[derive(Debug)]
@@ -2005,6 +2016,51 @@ impl Function {
None => {},
}
}
+
+
+ /// Validates the following:
+ /// 1. Basic block jump args match parameter arity.
+ /// 2. Every terminator must be in the last position.
+ /// 3. Every block must have a terminator.
+ fn validate_block_terminators_and_jumps(&self) -> Result<(), ValidationError> {
+ for block_id in self.rpo() {
+ let mut block_has_terminator = false;
+ let insns = &self.blocks[block_id.0].insns;
+ for (idx, insn_id) in insns.iter().enumerate() {
+ let insn = self.find(*insn_id);
+ match &insn {
+ Insn::Jump(BranchEdge{target, args})
+ | Insn::IfTrue { val: _, target: BranchEdge{target, args} }
+ | Insn::IfFalse { val: _, target: BranchEdge{target, args}} => {
+ let target_block = &self.blocks[target.0];
+ let target_len = target_block.params.len();
+ let args_len = args.len();
+ if target_len != args_len {
+ return Err(ValidationError::MismatchedBlockArity(format!("{:?}", self), block_id, target_len, args_len))
+ }
+ }
+ _ => {}
+ }
+ if !insn.is_terminator() {
+ continue;
+ }
+ block_has_terminator = true;
+ if idx != insns.len() - 1 {
+ return Err(ValidationError::TerminatorNotAtEnd(format!("{:?}", self), block_id, *insn_id, idx));
+ }
+ }
+ if !block_has_terminator {
+ return Err(ValidationError::BlockHasNoTerminator(format!("{:?}", self), block_id));
+ }
+ }
+ Ok(())
+ }
+
+ /// Run all validation passes we have.
+ pub fn validate(&self) -> Result<(), ValidationError> {
+ self.validate_block_terminators_and_jumps()?;
+ Ok(())
+ }
}
impl<'a> std::fmt::Display for FunctionPrinter<'a> {
@@ -2241,6 +2297,7 @@ pub enum ParseError {
StackUnderflow(FrameState),
UnknownParameterType(ParameterType),
MalformedIseq(u32), // insn_idx into iseq_encoded
+ Validation(ValidationError),
}
/// Return the number of locals in the current ISEQ (includes parameters)
@@ -2966,6 +3023,9 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> {
}
fun.profiles = Some(profiles);
+ if let Err(e) = fun.validate() {
+ return Err(ParseError::Validation(e));
+ }
Ok(fun)
}
@@ -3070,6 +3130,72 @@ mod rpo_tests {
}
#[cfg(test)]
+mod validation_tests {
+ use super::*;
+
+ #[track_caller]
+ fn assert_matches_err(res: Result<(), ValidationError>, expected: ValidationError) {
+ match res {
+ Err(validation_err) => {
+ assert_eq!(validation_err, expected);
+ }
+ Ok(_) => assert!(false, "Expected validation error"),
+ }
+ }
+
+ #[test]
+ fn one_block_no_terminator() {
+ let mut function = Function::new(std::ptr::null());
+ let entry = function.entry_block;
+ let val = function.push_insn(entry, Insn::Const { val: Const::Value(Qnil) });
+ assert_matches_err(function.validate(), ValidationError::BlockHasNoTerminator(format!("{:?}", function), entry));
+ }
+
+ #[test]
+ fn one_block_terminator_not_at_end() {
+ let mut function = Function::new(std::ptr::null());
+ let entry = function.entry_block;
+ let val = function.push_insn(entry, Insn::Const { val: Const::Value(Qnil) });
+ let insn_id = function.push_insn(entry, Insn::Return { val });
+ function.push_insn(entry, Insn::Const { val: Const::Value(Qnil) });
+ assert_matches_err(function.validate(), ValidationError::TerminatorNotAtEnd(format!("{:?}", function), entry, insn_id, 1));
+ }
+
+ #[test]
+ fn iftrue_mismatch_args() {
+ let mut function = Function::new(std::ptr::null());
+ let entry = function.entry_block;
+ let side = function.new_block();
+ let exit = function.new_block();
+ let val = function.push_insn(entry, Insn::Const { val: Const::Value(Qnil) });
+ function.push_insn(entry, Insn::IfTrue { val, target: BranchEdge { target: side, args: vec![val, val, val] } });
+ assert_matches_err(function.validate(), ValidationError::MismatchedBlockArity(format!("{:?}", function), entry, 0, 3));
+ }
+
+ #[test]
+ fn iffalse_mismatch_args() {
+ let mut function = Function::new(std::ptr::null());
+ let entry = function.entry_block;
+ let side = function.new_block();
+ let exit = function.new_block();
+ let val = function.push_insn(entry, Insn::Const { val: Const::Value(Qnil) });
+ function.push_insn(entry, Insn::IfFalse { val, target: BranchEdge { target: side, args: vec![val, val, val] } });
+ assert_matches_err(function.validate(), ValidationError::MismatchedBlockArity(format!("{:?}", function), entry, 0, 3));
+ }
+
+ #[test]
+ fn jump_mismatch_args() {
+ let mut function = Function::new(std::ptr::null());
+ let entry = function.entry_block;
+ let side = function.new_block();
+ let exit = function.new_block();
+ let val = function.push_insn(entry, Insn::Const { val: Const::Value(Qnil) });
+ function.push_insn(entry, Insn::Jump ( BranchEdge { target: side, args: vec![val, val, val] } ));
+ assert_matches_err(function.validate(), ValidationError::MismatchedBlockArity(format!("{:?}", function), entry, 0, 3));
+ }
+}
+
+#[cfg(test)]
mod infer_tests {
use super::*;
@@ -4701,6 +4827,7 @@ mod opt_tests {
unsafe { crate::cruby::rb_zjit_profile_disable(iseq) };
let mut function = iseq_to_hir(iseq).unwrap();
function.optimize();
+ function.validate().unwrap();
assert_function_hir(function, hir);
}