From ab849a434bb1fd12c8e39e601be65a2bda240b39 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Mon, 11 May 2026 12:17:09 -0400 Subject: ZJIT: Track param-type changes in infer_types fixpoint Branch arms (IfTrue, IfFalse, Jump) update target block param types but were not flagging the fixpoint loop's `changed` bit. With a pure shuffle block (no non-branch insns to drive `changed` via their own infer_type), the loop could exit while param types were still widening. Now each branch arm sets `changed = true` whenever the union actually grew a param's type. Add an HIR build test: a self-loop with a 4-cycle param rotation must reach the full union of all four input types at every param, which would previously fall short by one type even with parallel phi semantics. --- zjit/src/hir.rs | 72 ++++++++++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 64 insertions(+), 8 deletions(-) diff --git a/zjit/src/hir.rs b/zjit/src/hir.rs index 8b671dd815..6b2d9ee7e3 100644 --- a/zjit/src/hir.rs +++ b/zjit/src/hir.rs @@ -3010,6 +3010,17 @@ impl Function { // Fill entry parameter types self.copy_param_types(); + // Assign `new_type` to `insn` if it differs from the recorded type. + // Returns `true` if a write actually happened, `false` if the type + // was already equal. + let set_type = |this: &mut Function, insn: InsnId, new_type: Type| -> bool { + if this.type_of(insn).bit_equal(new_type) { + return false; + } + this.insn_types[insn.0] = new_type; + true + }; + let mut reachable = BlockSet::with_capacity(self.blocks.len()); reachable.insert(self.entries_block); @@ -3019,7 +3030,8 @@ impl Function { let mut changed = false; for &block in &rpo { if !reachable.get(block) { continue; } - for &insn_id in &self.blocks[block.0].insns { + for i in 0..self.blocks[block.0].insns.len() { + let insn_id = self.blocks[block.0].insns[i]; // Instructions without output, including branch instructions, can't be targets // of make_equal_to, so we don't need find() here. 
let insn_type = match &self.insns[insn_id.0] { @@ -3033,7 +3045,7 @@ impl Function { let arg_types: Vec<Type> = args.iter().map(|a| self.type_of(*a)).collect(); for (idx, arg_type) in arg_types.into_iter().enumerate() { let param = self.blocks[target.0].params[idx]; - self.insn_types[param.0] = self.type_of(param).union(arg_type); + changed |= set_type(self, param, self.type_of(param).union(arg_type)); } } continue; @@ -3045,7 +3057,7 @@ impl Function { let arg_types: Vec<Type> = args.iter().map(|a| self.type_of(*a)).collect(); for (idx, arg_type) in arg_types.into_iter().enumerate() { let param = self.blocks[target.0].params[idx]; - self.insn_types[param.0] = self.type_of(param).union(arg_type); + changed |= set_type(self, param, self.type_of(param).union(arg_type)); } } continue; @@ -3055,7 +3067,7 @@ impl Function { let arg_types: Vec<Type> = args.iter().map(|a| self.type_of(*a)).collect(); for (idx, arg_type) in arg_types.into_iter().enumerate() { let param = self.blocks[target.0].params[idx]; - self.insn_types[param.0] = self.type_of(param).union(arg_type); + changed |= set_type(self, param, self.type_of(param).union(arg_type)); } continue; } @@ -3068,10 +3080,7 @@ impl Function { insn if insn.has_output() => self.infer_type(insn_id), _ => continue, }; - if !self.type_of(insn_id).bit_equal(insn_type) { - self.insn_types[insn_id.0] = insn_type; - changed = true; - } + changed |= set_type(self, insn_id, insn_type); } } if !changed { @@ -9100,6 +9109,53 @@ mod infer_tests { assert_bit_equal(function.type_of(param), types::Fixnum); } + #[test] + fn self_loop_param_rotation_reaches_full_union() { + // bb_entry: jump bb_loop(c1, c2, c3, c4) // 4 distinct types + // bb_loop(p1, p2, p3, p4): + // jump bb_loop(p2, p3, p4, p1) // 4-cycle rotation + // + // Every param transitively flows into every other across enough trips + // around the loop, so the fixpoint for every param is the full union + // of all four input types. 
The fixpoint loop must not exit while a + // branch arm is still widening a param's type. + let mut function = Function::new(std::ptr::null()); + let entry = function.entry_block; + let loop_block = function.new_block(0); + + let c1 = function.push_insn(entry, Insn::Const { val: Const::Value(Qtrue) }); + let c2 = function.push_insn(entry, Insn::Const { val: Const::Value(Qfalse) }); + let c3 = function.push_insn(entry, Insn::Const { val: Const::Value(Qnil) }); + let c4 = function.push_insn(entry, Insn::Const { val: Const::Value(VALUE::fixnum_from_usize(7)) }); + function.push_insn(entry, Insn::Jump(BranchEdge { + target: loop_block, + args: vec![c1, c2, c3, c4], + })); + + let p1 = function.push_insn(loop_block, Insn::Param); + let p2 = function.push_insn(loop_block, Insn::Param); + let p3 = function.push_insn(loop_block, Insn::Param); + let p4 = function.push_insn(loop_block, Insn::Param); + function.push_insn(loop_block, Insn::Jump(BranchEdge { + target: loop_block, + args: vec![p2, p3, p4, p1], + })); + + function.seal_entries(); + crate::cruby::with_rubyvm(|| { + function.infer_types(); + }); + + let full = types::TrueClass + .union(types::FalseClass) + .union(types::NilClass) + .union(types::Fixnum); + assert_bit_equal(function.type_of(p1), full); + assert_bit_equal(function.type_of(p2), full); + assert_bit_equal(function.type_of(p3), full); + assert_bit_equal(function.type_of(p4), full); + } + #[test] fn diamond_iffalse_merge_bool() { let mut function = Function::new(std::ptr::null()); -- cgit v1.2.3