diff options
| author | Daichi Kamiyama <32436625+dak2@users.noreply.github.com> | 2026-05-13 05:39:11 +0900 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2026-05-12 20:39:11 +0000 |
| commit | ece14b61f505eea1ebefb3b8295df0fcf4d22567 (patch) | |
| tree | 6e794b8d1caf3e16d6c3e5d56e003626493e86c6 | |
| parent | ab849a434bb1fd12c8e39e601be65a2bda240b39 (diff) | |
ZJIT: Drop redundant type guards via block-local HIR canonicalize (#16828)
## Summary
`GuardType` narrows only its result id, not the original input.
A later guard on the same input therefore looks unfoldable to `fold_constants` even though the side-exit semantics already prove the narrower type.
The `canonicalize` pass rewrites later uses of the guarded input to the most recent `Guard*` result, making the redundant guard foldable.
For example, in this CFG-join shape:
```ruby
def test(n, cond)
  if cond
    a = n + 1
  else
    a = n + 2
  end
  n + a # `n` gets a redundant Fixnum guard here
end
```
This PR adds a block-local HIR `canonicalize` pass that walks each block in reverse post-order (RPO) and rewrites every operand through union-find plus a per-block `rewrite_map`, which maps each guarded value to the most recent `Guard*` result for that value.
After canonicalization, `infer_types` can narrow merge-block parameter types and `fold_constants` can then drop the redundant guards in both shapes above.
Inspired by Cranelift's canonicalize step: <https://cfallin.org/blog/2026/04/09/aegraph/>
Fixes: https://github.com/Shopify/ruby/issues/978
## Benchmarks
Bench (arm64 linux devcontainer, ruby/ruby-bench, warmup=10 bench=20)
```
Throughput master/staged
lobsters 1.021 (+2.1%, within ±3-6% noise)
railsbench 1.012 (+1.2%, within ±3-4% noise)
--zjit-stats lobsters / railsbench
code_region_bytes -1.1% / -1.0% (redundant CFG-join guards still removed)
guard_type_count -29.4% / +30.1% (railsbench likely single-run noise) ⚠
compile_hir_time +14.6% / +13.7% (canonicalize_time: 67ms / 31ms)
invalidation_time ±0% / ±0%
```
| -rw-r--r-- | zjit.rb | 1 | ||||
| -rw-r--r-- | zjit/src/hir.rs | 63 | ||||
| -rw-r--r-- | zjit/src/hir/opt_tests.rs | 87 | ||||
| -rw-r--r-- | zjit/src/stats.rs | 1 |
4 files changed, 140 insertions, 12 deletions
@@ -148,6 +148,7 @@ class << RubyVM::ZJIT :compile_hir_time_ns, :compile_hir_build_time_ns, :compile_hir_strength_reduce_time_ns, + :compile_hir_canonicalize_time_ns, :compile_hir_fold_constants_time_ns, :compile_hir_clean_cfg_time_ns, :compile_hir_eliminate_dead_code_time_ns, diff --git a/zjit/src/hir.rs b/zjit/src/hir.rs index 6b2d9ee7e3..31fd50c16f 100644 --- a/zjit/src/hir.rs +++ b/zjit/src/hir.rs @@ -4144,6 +4144,9 @@ impl Function { } } } + crate::stats::trace_compile_phase("canonicalize", || + crate::stats::with_time_stat(Counter::compile_hir_canonicalize_time_ns, || self.canonicalize()) + ); crate::stats::trace_compile_phase("infer_types", || self.infer_types()); } @@ -4198,6 +4201,9 @@ impl Function { } } } + crate::stats::trace_compile_phase("canonicalize", || + crate::stats::with_time_stat(Counter::compile_hir_canonicalize_time_ns, || self.canonicalize()) + ); crate::stats::trace_compile_phase("infer_types", || self.infer_types()); } @@ -4486,6 +4492,9 @@ impl Function { } } } + crate::stats::trace_compile_phase("canonicalize", || + crate::stats::with_time_stat(Counter::compile_hir_canonicalize_time_ns, || self.canonicalize()) + ); crate::stats::trace_compile_phase("infer_types", || self.infer_types()); } @@ -4816,6 +4825,9 @@ impl Function { self.push_insn_id(block, insn_id); } } + crate::stats::trace_compile_phase("canonicalize", || + crate::stats::with_time_stat(Counter::compile_hir_canonicalize_time_ns, || self.canonicalize()) + ); crate::stats::trace_compile_phase("infer_types", || self.infer_types()); } @@ -4927,6 +4939,54 @@ impl Function { .unwrap_or(insn_id) } + /// Block-local canonicalize: rewrite each operand through union-find and a + /// per-block map of the most recent `Guard*` for that value. Forwards + /// guarded values into branch-edge args (so `infer_types` narrows merge-block + /// parameters and `fold_constants` drops redundant CFG-join guards) and + /// ordinary in-block uses. 
+ /// + /// `Guard*` substitutions are unconditional within a block: a guard's + /// side-exit semantics guarantee the substituted value type holds for every + /// downstream use in the same block. + /// + /// `RefineType` is intentionally skipped: its narrowing is only valid on one + /// branch arm, which would require dropping refine-derived rewrites at each + /// `IfTrue`/`IfFalse`. Cross-arm refine forwarding is left for a follow-up + /// dominator-scoped pass. + /// + /// Inspired by Cranelift's aegraph canonicalize step + /// (<https://cfallin.org/blog/2026/04/09/aegraph/>). + fn canonicalize(&mut self) { + let mut rewrite_map: HashMap<InsnId, InsnId> = HashMap::new(); + for block in self.rpo() { + rewrite_map.clear(); + for i in 0..self.blocks[block.0].insns.len() { + let insn_id = self.blocks[block.0].insns[i]; + let canonical_id = self.union_find.borrow().find_const(insn_id); + + let union_find = &self.union_find; + self.insns[canonical_id.0].for_each_operand_mut(|operand| { + let canon = union_find.borrow().find_const(*operand); + *operand = rewrite_map.get(&canon).copied().unwrap_or(canon); + }); + + // For the binary guards only `left` is registered because their infer_type is + // type_of(left). + match &self.insns[canonical_id.0] { + Insn::GuardType { val: src, .. } + | Insn::GuardBitEquals { val: src, .. } + | Insn::GuardAnyBitSet { val: src, .. } + | Insn::GuardNoBitsSet { val: src, .. } + | Insn::GuardGreaterEq { left: src, .. } + | Insn::GuardLess { left: src, .. } => { + rewrite_map.insert(*src, canonical_id); + } + _ => {} + } + } + } + } + /// Use type information left by `infer_types` to fold away operations that can be evaluated at compile-time. /// /// It can fold fixnum math, truthiness tests, and branches with constant conditionals. 
@@ -5315,6 +5375,9 @@ impl Function { changed = true; } if changed { + crate::stats::trace_compile_phase("canonicalize", || + crate::stats::with_time_stat(Counter::compile_hir_canonicalize_time_ns, || self.canonicalize()) + ); crate::stats::trace_compile_phase("infer_types", || self.infer_types()); } } diff --git a/zjit/src/hir/opt_tests.rs b/zjit/src/hir/opt_tests.rs index 41fcfb1698..61b85e3a76 100644 --- a/zjit/src/hir/opt_tests.rs +++ b/zjit/src/hir/opt_tests.rs @@ -1129,7 +1129,7 @@ mod hir_opt_tests { PatchPoint NoSingletonClass(CustomEq@0x1008) PatchPoint MethodRedefined(CustomEq@0x1008, !=@0x1010, cme:0x1018) v30:ObjectSubclass[class_exact:CustomEq] = GuardType v10, ObjectSubclass[class_exact:CustomEq] - v31:BoolExact = CCallWithFrame v30, :BasicObject#!=@0x1040, v10 + v31:BoolExact = CCallWithFrame v30, :BasicObject#!=@0x1040, v30 v21:NilClass = Const Value(nil) CheckInterrupts Return v21 @@ -1752,7 +1752,7 @@ mod hir_opt_tests { v28:Fixnum[30] = Const Value(30) v30:Fixnum[40] = Const Value(40) v32:Fixnum[50] = Const Value(50) - v34:BasicObject = Send v6, :target, v24, v26, v28, v30, v32 # SendFallbackReason: Argument count does not match parameter count + v34:BasicObject = Send v44, :target, v24, v26, v28, v30, v32 # SendFallbackReason: Argument count does not match parameter count v37:ArrayExact = NewArray v45, v48, v34 CheckInterrupts Return v37 @@ -4046,7 +4046,7 @@ mod hir_opt_tests { v33:Fixnum[40] = Const Value(40) v35:Fixnum[50] = Const Value(50) v37:Fixnum[60] = Const Value(60) - v39:BasicObject = Send v6, :target, v27, v29, v31, v33, v35, v37 # SendFallbackReason: Too many arguments for LIR + v39:BasicObject = Send v48, :target, v27, v29, v31, v33, v35, v37 # SendFallbackReason: Too many arguments for LIR v41:ArrayExact = NewArray v49, v52, v39 CheckInterrupts Return v41 @@ -4651,7 +4651,7 @@ mod hir_opt_tests { v49:SetExact = GuardType v17, SetExact v50:BasicObject = CCallVariadic v49, :Set#initialize@0x1068 CheckInterrupts - Return v17 + Return 
v49 "); } @@ -5683,7 +5683,7 @@ mod hir_opt_tests { WriteBarrier v28, v10 v33:CShape[0x1003] = Const CShape(0x1003) StoreField v28, :_shape_id@0x1000, v33 - v14:HeapBasicObject = RefineType v6, HeapBasicObject + v14:HeapBasicObject = RefineType v28, HeapBasicObject v17:Fixnum[2] = Const Value(2) PatchPoint SingleRactorMode StoreField v14, :@bar@0x1004, v17 @@ -6359,7 +6359,7 @@ mod hir_opt_tests { PatchPoint NoSingletonClass(Array@0x1010) PatchPoint MethodRedefined(Array@0x1010, to_s@0x1018, cme:0x1020) v33:BasicObject = CCallWithFrame v28, :Array#to_s@0x1048 - v20:String = AnyToString v10, str: v33 + v20:String = AnyToString v28, str: v33 v22:StringExact = StringConcat v14, v20 CheckInterrupts Return v22 @@ -10089,7 +10089,7 @@ mod hir_opt_tests { v33:CInt64 = AdjustBounds v32, v31 v34:CInt64[0] = Const CInt64(0) v35:CInt64 = GuardGreaterEq v33, v34 - v36:Fixnum = StringGetbyte v28, v33 + v36:Fixnum = StringGetbyte v28, v35 CheckInterrupts Return v36 "); @@ -14643,7 +14643,7 @@ mod hir_opt_tests { SetLocal :other_block, l0, EP@3, v40 v27:CPtr = GetEP 0 v28:BasicObject = LoadField v27, :other_block@0x1051 - v30:BasicObject = InvokeSuper v11, 0x1058, v28 # SendFallbackReason: super: complex argument passing to `super` call + v30:BasicObject = InvokeSuper v39, 0x1058, v28 # SendFallbackReason: super: complex argument passing to `super` call CheckInterrupts Return v30 "); @@ -15257,7 +15257,7 @@ mod hir_opt_tests { v78:BasicObject = LoadField v75, :iter_method@0x1058 v79:BasicObject = LoadField v75, :kwsplat@0x1059 SetLocal :sep, l0, EP@5, v119 - Jump bb8(v58, v76, v119, v78, v79) + Jump bb8(v118, v76, v119, v78, v79) bb8(v83:BasicObject, v84:BasicObject, v85:BasicObject, v86:BasicObject, v87:BasicObject): PatchPoint SingleRactorMode PatchPoint StableConstantNames(0x1060, CONST) @@ -15360,7 +15360,7 @@ mod hir_opt_tests { WriteBarrier v35, v13 v40:CShape[0x1003] = Const CShape(0x1003) StoreField v35, :_shape_id@0x1000, v40 - v20:HeapBasicObject = RefineType v8, 
HeapBasicObject + v20:HeapBasicObject = RefineType v35, HeapBasicObject PatchPoint NoEPEscape(initialize) PatchPoint SingleRactorMode WriteBarrier v20, v13 @@ -15408,7 +15408,7 @@ mod hir_opt_tests { WriteBarrier v49, v16 v54:CShape[0x1003] = Const CShape(0x1003) StoreField v49, :_shape_id@0x1000, v54 - v23:HeapBasicObject = RefineType v10, HeapBasicObject + v23:HeapBasicObject = RefineType v49, HeapBasicObject v26:Fixnum[5] = Const Value(5) PatchPoint NoEPEscape(initialize) PatchPoint MethodRedefined(Integer@0x1008, +@0x1010, cme:0x1018) @@ -15456,7 +15456,7 @@ mod hir_opt_tests { WriteBarrier v43, v13 v48:CShape[0x1003] = Const CShape(0x1003) StoreField v43, :_shape_id@0x1000, v48 - v20:HeapBasicObject = RefineType v8, HeapBasicObject + v20:HeapBasicObject = RefineType v43, HeapBasicObject PatchPoint NoEPEscape(initialize) PatchPoint SingleRactorMode WriteBarrier v20, v13 @@ -15723,6 +15723,69 @@ mod hir_opt_tests { } #[test] + fn test_dedup_guard_type_across_cfg_join() { + eval(" + def test(n, cond) + if cond + a = n + 1 + else + a = n + 2 + end + n + a + end + test(1, true); test(1, false) + "); + let hir = hir_string("test"); + let guard_count = hir.matches("GuardType").count(); + assert_eq!( + guard_count, 2, + "expected 2 GuardType instructions after cross-block dedup, found {guard_count}\n\nHIR:\n{hir}" + ); + } + + #[test] + fn test_forward_guard_through_conditional_branch() { + eval(" + def test(n, a, b) + if a + if b + n + 1 + else + n + 2 + end + else + n + 3 + end + end + test(1, true, true); test(1, true, false); test(1, false, false) + "); + let hir = hir_string("test"); + let guard_count = hir.matches("GuardType").count(); + assert!( + guard_count <= 3, + "expected at most 3 GuardType instructions (one per leaf branch) after forwarding through conditional branches, found {guard_count}\n\nHIR:\n{hir}" + ); + } + + #[test] + fn test_no_forward_when_no_guard_in_branches() { + let src = " + def test(n, cond) + a = if cond then 1 else 2 end + n + a + end 
+ test(1, true); test(1, false) + "; + eval(src); + let hir = hir_string("test"); + let guard_count = hir.matches("GuardType").count(); + assert_eq!( + guard_count, 1, + "expected 1 GuardType (merge block only), found {guard_count}\n\nHIR:\n{hir}" + ); + } + + #[test] fn test_infer_types_across_non_maximal_basic_blocks() { // Previous worklist-based type inference only worked for maximal SSA. This is a regression // test for hanging. diff --git a/zjit/src/stats.rs b/zjit/src/stats.rs index 522b74c48a..587dde0660 100644 --- a/zjit/src/stats.rs +++ b/zjit/src/stats.rs @@ -171,6 +171,7 @@ make_counters! { compile_hir_build_time_ns, compile_hir_strength_reduce_time_ns, compile_hir_optimize_load_store_time_ns, + compile_hir_canonicalize_time_ns, compile_hir_fold_constants_time_ns, compile_hir_clean_cfg_time_ns, compile_hir_remove_redundant_patch_points_time_ns, |
