summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDaichi Kamiyama <32436625+dak2@users.noreply.github.com>2026-05-13 05:39:11 +0900
committerGitHub <noreply@github.com>2026-05-12 20:39:11 +0000
commitece14b61f505eea1ebefb3b8295df0fcf4d22567 (patch)
tree6e794b8d1caf3e16d6c3e5d56e003626493e86c6
parentab849a434bb1fd12c8e39e601be65a2bda240b39 (diff)
ZJIT: Drop redundant type guards via block-local HIR canonicalize (#16828)
## Summary

`GuardType` narrows only its result id, not the original input. A later guard on the same input therefore looks unfoldable to `fold_constants` even though the side-exit semantics already prove the narrower type. Canonicalize rewrites later uses to the most recent `Guard*` result, making the redundant guard foldable.

For example, in this CFG-join shape:

```ruby
def test(n, cond)
  if cond
    a = n + 1
  else
    a = n + 2
  end
  n + a  # `n` gets a redundant Fixnum guard here
end
```

This PR adds a block-local HIR `canonicalize` pass that walks each block in RPO and rewrites every operand through union-find plus a per-block `rewrite_map` that maps each guarded value to the most recent `Guard*` for that value. After canonicalization, `infer_types` can narrow merge-block parameter types and `fold_constants` can then drop the redundant guards in both shapes above.

Inspired by Cranelift's canonicalize https://cfallin.org/blog/2026/04/09/aegraph/

Fixes: https://github.com/Shopify/ruby/issues/978

## Benchmarks

Bench (arm64 linux devcontainer, ruby/ruby-bench, warmup=10 bench=20)

```
Throughput           master/staged
lobsters             1.021 (+2.1%, within ±3-6% noise)
railsbench           1.012 (+1.2%, within ±3-4% noise)

--zjit-stats         lobsters / railsbench
code_region_bytes    -1.1% / -1.0%   (redundant CFG-join guards still removed)
guard_type_count     -29.4% / +30.1% (railsbench likely single-run noise)
⚠ compile_hir_time   +14.6% / +13.7% (canonicalize_time: 67ms / 31ms)
invalidation_time    ±0% / ±0%
```
-rw-r--r--zjit.rb1
-rw-r--r--zjit/src/hir.rs63
-rw-r--r--zjit/src/hir/opt_tests.rs87
-rw-r--r--zjit/src/stats.rs1
4 files changed, 140 insertions, 12 deletions
diff --git a/zjit.rb b/zjit.rb
index 89a4a15cfd..480ffa1544 100644
--- a/zjit.rb
+++ b/zjit.rb
@@ -148,6 +148,7 @@ class << RubyVM::ZJIT
:compile_hir_time_ns,
:compile_hir_build_time_ns,
:compile_hir_strength_reduce_time_ns,
+ :compile_hir_canonicalize_time_ns,
:compile_hir_fold_constants_time_ns,
:compile_hir_clean_cfg_time_ns,
:compile_hir_eliminate_dead_code_time_ns,
diff --git a/zjit/src/hir.rs b/zjit/src/hir.rs
index 6b2d9ee7e3..31fd50c16f 100644
--- a/zjit/src/hir.rs
+++ b/zjit/src/hir.rs
@@ -4144,6 +4144,9 @@ impl Function {
}
}
}
+ crate::stats::trace_compile_phase("canonicalize", ||
+ crate::stats::with_time_stat(Counter::compile_hir_canonicalize_time_ns, || self.canonicalize())
+ );
crate::stats::trace_compile_phase("infer_types", || self.infer_types());
}
@@ -4198,6 +4201,9 @@ impl Function {
}
}
}
+ crate::stats::trace_compile_phase("canonicalize", ||
+ crate::stats::with_time_stat(Counter::compile_hir_canonicalize_time_ns, || self.canonicalize())
+ );
crate::stats::trace_compile_phase("infer_types", || self.infer_types());
}
@@ -4486,6 +4492,9 @@ impl Function {
}
}
}
+ crate::stats::trace_compile_phase("canonicalize", ||
+ crate::stats::with_time_stat(Counter::compile_hir_canonicalize_time_ns, || self.canonicalize())
+ );
crate::stats::trace_compile_phase("infer_types", || self.infer_types());
}
@@ -4816,6 +4825,9 @@ impl Function {
self.push_insn_id(block, insn_id);
}
}
+ crate::stats::trace_compile_phase("canonicalize", ||
+ crate::stats::with_time_stat(Counter::compile_hir_canonicalize_time_ns, || self.canonicalize())
+ );
crate::stats::trace_compile_phase("infer_types", || self.infer_types());
}
@@ -4927,6 +4939,54 @@ impl Function {
.unwrap_or(insn_id)
}
+ /// Block-local canonicalize: rewrite each operand through union-find and a
+ /// per-block map from a guarded value to the latest `Guard*` result recorded
+ /// for it. Forwards guarded values into branch-edge args (so `infer_types`
+ /// narrows merge-block parameters and `fold_constants` drops redundant
+ /// CFG-join guards) and into ordinary in-block uses.
+ ///
+ /// `Guard*` substitutions are unconditional within a block: a guard's
+ /// side-exit semantics guarantee the substituted value type holds for every
+ /// downstream use in the same block. The map is cleared at each block entry.
+ ///
+ /// `RefineType` results are intentionally never registered as substitutes:
+ /// the narrowing is only valid on one branch arm, which would require
+ /// dropping refine-derived rewrites at each `IfTrue`/`IfFalse`. Cross-arm
+ /// refine forwarding is left for a follow-up dominator-scoped pass.
+ ///
+ /// Inspired by Cranelift's aegraph canonicalize step
+ /// (<https://cfallin.org/blog/2026/04/09/aegraph/>).
+ fn canonicalize(&mut self) {
+ let mut rewrite_map: HashMap<InsnId, InsnId> = HashMap::new(); // guarded value -> guard result
+ for block in self.rpo() {
+ rewrite_map.clear(); // guard facts recorded in one block are not reused in the next
+ for i in 0..self.blocks[block.0].insns.len() {
+ let insn_id = self.blocks[block.0].insns[i];
+ let canonical_id = self.union_find.borrow().find_const(insn_id);
+
+ let union_find = &self.union_find;
+ self.insns[canonical_id.0].for_each_operand_mut(|operand| {
+ let canon = union_find.borrow().find_const(*operand);
+ *operand = rewrite_map.get(&canon).copied().unwrap_or(canon); // no guard recorded: keep `canon`
+ });
+
+ // Binary guards register only `left` because their infer_type is type_of(left).
+ // NOTE(review): `src` is read after the operand rewrite above, so the key may be an earlier guard's result, not the original value; confirm intended.
+ match &self.insns[canonical_id.0] {
+ Insn::GuardType { val: src, .. }
+ | Insn::GuardBitEquals { val: src, .. }
+ | Insn::GuardAnyBitSet { val: src, .. }
+ | Insn::GuardNoBitsSet { val: src, .. }
+ | Insn::GuardGreaterEq { left: src, .. }
+ | Insn::GuardLess { left: src, .. } => {
+ rewrite_map.insert(*src, canonical_id);
+ }
+ _ => {}
+ }
+ }
+ }
+ }
+
/// Use type information left by `infer_types` to fold away operations that can be evaluated at compile-time.
///
/// It can fold fixnum math, truthiness tests, and branches with constant conditionals.
@@ -5315,6 +5375,9 @@ impl Function {
changed = true;
}
if changed {
+ crate::stats::trace_compile_phase("canonicalize", ||
+ crate::stats::with_time_stat(Counter::compile_hir_canonicalize_time_ns, || self.canonicalize())
+ );
crate::stats::trace_compile_phase("infer_types", || self.infer_types());
}
}
diff --git a/zjit/src/hir/opt_tests.rs b/zjit/src/hir/opt_tests.rs
index 41fcfb1698..61b85e3a76 100644
--- a/zjit/src/hir/opt_tests.rs
+++ b/zjit/src/hir/opt_tests.rs
@@ -1129,7 +1129,7 @@ mod hir_opt_tests {
PatchPoint NoSingletonClass(CustomEq@0x1008)
PatchPoint MethodRedefined(CustomEq@0x1008, !=@0x1010, cme:0x1018)
v30:ObjectSubclass[class_exact:CustomEq] = GuardType v10, ObjectSubclass[class_exact:CustomEq]
- v31:BoolExact = CCallWithFrame v30, :BasicObject#!=@0x1040, v10
+ v31:BoolExact = CCallWithFrame v30, :BasicObject#!=@0x1040, v30
v21:NilClass = Const Value(nil)
CheckInterrupts
Return v21
@@ -1752,7 +1752,7 @@ mod hir_opt_tests {
v28:Fixnum[30] = Const Value(30)
v30:Fixnum[40] = Const Value(40)
v32:Fixnum[50] = Const Value(50)
- v34:BasicObject = Send v6, :target, v24, v26, v28, v30, v32 # SendFallbackReason: Argument count does not match parameter count
+ v34:BasicObject = Send v44, :target, v24, v26, v28, v30, v32 # SendFallbackReason: Argument count does not match parameter count
v37:ArrayExact = NewArray v45, v48, v34
CheckInterrupts
Return v37
@@ -4046,7 +4046,7 @@ mod hir_opt_tests {
v33:Fixnum[40] = Const Value(40)
v35:Fixnum[50] = Const Value(50)
v37:Fixnum[60] = Const Value(60)
- v39:BasicObject = Send v6, :target, v27, v29, v31, v33, v35, v37 # SendFallbackReason: Too many arguments for LIR
+ v39:BasicObject = Send v48, :target, v27, v29, v31, v33, v35, v37 # SendFallbackReason: Too many arguments for LIR
v41:ArrayExact = NewArray v49, v52, v39
CheckInterrupts
Return v41
@@ -4651,7 +4651,7 @@ mod hir_opt_tests {
v49:SetExact = GuardType v17, SetExact
v50:BasicObject = CCallVariadic v49, :Set#initialize@0x1068
CheckInterrupts
- Return v17
+ Return v49
");
}
@@ -5683,7 +5683,7 @@ mod hir_opt_tests {
WriteBarrier v28, v10
v33:CShape[0x1003] = Const CShape(0x1003)
StoreField v28, :_shape_id@0x1000, v33
- v14:HeapBasicObject = RefineType v6, HeapBasicObject
+ v14:HeapBasicObject = RefineType v28, HeapBasicObject
v17:Fixnum[2] = Const Value(2)
PatchPoint SingleRactorMode
StoreField v14, :@bar@0x1004, v17
@@ -6359,7 +6359,7 @@ mod hir_opt_tests {
PatchPoint NoSingletonClass(Array@0x1010)
PatchPoint MethodRedefined(Array@0x1010, to_s@0x1018, cme:0x1020)
v33:BasicObject = CCallWithFrame v28, :Array#to_s@0x1048
- v20:String = AnyToString v10, str: v33
+ v20:String = AnyToString v28, str: v33
v22:StringExact = StringConcat v14, v20
CheckInterrupts
Return v22
@@ -10089,7 +10089,7 @@ mod hir_opt_tests {
v33:CInt64 = AdjustBounds v32, v31
v34:CInt64[0] = Const CInt64(0)
v35:CInt64 = GuardGreaterEq v33, v34
- v36:Fixnum = StringGetbyte v28, v33
+ v36:Fixnum = StringGetbyte v28, v35
CheckInterrupts
Return v36
");
@@ -14643,7 +14643,7 @@ mod hir_opt_tests {
SetLocal :other_block, l0, EP@3, v40
v27:CPtr = GetEP 0
v28:BasicObject = LoadField v27, :other_block@0x1051
- v30:BasicObject = InvokeSuper v11, 0x1058, v28 # SendFallbackReason: super: complex argument passing to `super` call
+ v30:BasicObject = InvokeSuper v39, 0x1058, v28 # SendFallbackReason: super: complex argument passing to `super` call
CheckInterrupts
Return v30
");
@@ -15257,7 +15257,7 @@ mod hir_opt_tests {
v78:BasicObject = LoadField v75, :iter_method@0x1058
v79:BasicObject = LoadField v75, :kwsplat@0x1059
SetLocal :sep, l0, EP@5, v119
- Jump bb8(v58, v76, v119, v78, v79)
+ Jump bb8(v118, v76, v119, v78, v79)
bb8(v83:BasicObject, v84:BasicObject, v85:BasicObject, v86:BasicObject, v87:BasicObject):
PatchPoint SingleRactorMode
PatchPoint StableConstantNames(0x1060, CONST)
@@ -15360,7 +15360,7 @@ mod hir_opt_tests {
WriteBarrier v35, v13
v40:CShape[0x1003] = Const CShape(0x1003)
StoreField v35, :_shape_id@0x1000, v40
- v20:HeapBasicObject = RefineType v8, HeapBasicObject
+ v20:HeapBasicObject = RefineType v35, HeapBasicObject
PatchPoint NoEPEscape(initialize)
PatchPoint SingleRactorMode
WriteBarrier v20, v13
@@ -15408,7 +15408,7 @@ mod hir_opt_tests {
WriteBarrier v49, v16
v54:CShape[0x1003] = Const CShape(0x1003)
StoreField v49, :_shape_id@0x1000, v54
- v23:HeapBasicObject = RefineType v10, HeapBasicObject
+ v23:HeapBasicObject = RefineType v49, HeapBasicObject
v26:Fixnum[5] = Const Value(5)
PatchPoint NoEPEscape(initialize)
PatchPoint MethodRedefined(Integer@0x1008, +@0x1010, cme:0x1018)
@@ -15456,7 +15456,7 @@ mod hir_opt_tests {
WriteBarrier v43, v13
v48:CShape[0x1003] = Const CShape(0x1003)
StoreField v43, :_shape_id@0x1000, v48
- v20:HeapBasicObject = RefineType v8, HeapBasicObject
+ v20:HeapBasicObject = RefineType v43, HeapBasicObject
PatchPoint NoEPEscape(initialize)
PatchPoint SingleRactorMode
WriteBarrier v20, v13
@@ -15723,6 +15723,69 @@ mod hir_opt_tests {
}
#[test]
+ fn test_dedup_guard_type_across_cfg_join() { // `n` is used in both arms and again after the join
+ eval("
+ def test(n, cond)
+ if cond
+ a = n + 1
+ else
+ a = n + 2
+ end
+ n + a
+ end
+ test(1, true); test(1, false)
+ ");
+ let hir = hir_string("test");
+ let guard_count = hir.matches("GuardType").count(); // substring count over the printed HIR
+ assert_eq!(
+ guard_count, 2, // exactly two guards may remain; a third would mean the join re-guarded
+ "expected 2 GuardType instructions after cross-block dedup, found {guard_count}\n\nHIR:\n{hir}"
+ );
+ }
+
+ #[test]
+ fn test_forward_guard_through_conditional_branch() { // three leaf branches, each using `n` once
+ eval("
+ def test(n, a, b)
+ if a
+ if b
+ n + 1
+ else
+ n + 2
+ end
+ else
+ n + 3
+ end
+ end
+ test(1, true, true); test(1, true, false); test(1, false, false)
+ ");
+ let hir = hir_string("test");
+ let guard_count = hir.matches("GuardType").count(); // substring count over the printed HIR
+ assert!(
+ guard_count <= 3, // upper bound only: fewer guards also pass
+ "expected at most 3 GuardType instructions (one per leaf branch) after forwarding through conditional branches, found {guard_count}\n\nHIR:\n{hir}"
+ );
+ }
+
+ #[test]
+ fn test_no_forward_when_no_guard_in_branches() { // neither arm uses `n`, so there is no guard to forward
+ let src = "
+ def test(n, cond)
+ a = if cond then 1 else 2 end
+ n + a
+ end
+ test(1, true); test(1, false)
+ ";
+ eval(src);
+ let hir = hir_string("test");
+ let guard_count = hir.matches("GuardType").count(); // substring count over the printed HIR
+ assert_eq!(
+ guard_count, 1, // the guard at `n + a` must survive
+ "expected 1 GuardType (merge block only), found {guard_count}\n\nHIR:\n{hir}"
+ );
+ }
+
+ #[test]
fn test_infer_types_across_non_maximal_basic_blocks() {
// Previous worklist-based type inference only worked for maximal SSA. This is a regression
// test for hanging.
diff --git a/zjit/src/stats.rs b/zjit/src/stats.rs
index 522b74c48a..587dde0660 100644
--- a/zjit/src/stats.rs
+++ b/zjit/src/stats.rs
@@ -171,6 +171,7 @@ make_counters! {
compile_hir_build_time_ns,
compile_hir_strength_reduce_time_ns,
compile_hir_optimize_load_store_time_ns,
+ compile_hir_canonicalize_time_ns,
compile_hir_fold_constants_time_ns,
compile_hir_clean_cfg_time_ns,
compile_hir_remove_redundant_patch_points_time_ns,