diff options
| author | Max Bernstein <rubybugs@bernsteinbear.com> | 2026-01-21 14:23:29 -0500 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2026-01-21 19:23:29 +0000 |
| commit | 0cc4e212c47f15f9f5384fb9871a2da8a6276ed4 (patch) | |
| tree | d0513556218558ed9789a771427f1fe985d5eba0 | |
| parent | cfa97af7e1426c76b769495ef4b1689be3b0a685 (diff) | |
ZJIT: Get type information from branchif, branchunless, branchnil instructions (#15915)
Do a sort of "partial static single information (SSI)" form that learns
types of operands from branch instructions. A branchif, for example,
tells us that in the truthy path, we know the operand is not nil, and
not false. Similarly, in the falsy path, we know the operand is either
nil or false.
Add a RefineType instruction to attach this information.
This PR does this in SSA construction because it's pretty
straightforward, but we can also do a more aggressive version of this
that can learn information about e.g. int ranges from other checks later
in the optimization pipeline.
| -rw-r--r-- | zjit/src/codegen.rs | 1 | ||||
| -rw-r--r-- | zjit/src/hir.rs | 78 | ||||
| -rw-r--r-- | zjit/src/hir/opt_tests.rs | 132 | ||||
| -rw-r--r-- | zjit/src/hir/tests.rs | 170 | ||||
| -rw-r--r-- | zjit/src/hir_type/gen_hir_type.rb | 5 | ||||
| -rw-r--r-- | zjit/src/hir_type/hir_type.inc.rs | 11 | ||||
| -rw-r--r-- | zjit/src/hir_type/mod.rs | 60 |
7 files changed, 324 insertions, 133 deletions
diff --git a/zjit/src/codegen.rs b/zjit/src/codegen.rs index d777002e31..0030493ddf 100644 --- a/zjit/src/codegen.rs +++ b/zjit/src/codegen.rs @@ -524,6 +524,7 @@ fn gen_insn(cb: &mut CodeBlock, jit: &mut JITState, asm: &mut Assembler, functio &Insn::BoxFixnum { val, state } => gen_box_fixnum(jit, asm, opnd!(val), &function.frame_state(state)), &Insn::UnboxFixnum { val } => gen_unbox_fixnum(asm, opnd!(val)), Insn::Test { val } => gen_test(asm, opnd!(val)), + Insn::RefineType { val, .. } => opnd!(val), Insn::GuardType { val, guard_type, state } => gen_guard_type(jit, asm, opnd!(val), *guard_type, &function.frame_state(*state)), Insn::GuardTypeNot { val, guard_type, state } => gen_guard_type_not(jit, asm, opnd!(val), *guard_type, &function.frame_state(*state)), &Insn::GuardBitEquals { val, expected, reason, state } => gen_guard_bit_equals(jit, asm, opnd!(val), expected, reason, &function.frame_state(state)), diff --git a/zjit/src/hir.rs b/zjit/src/hir.rs index fc071e3d67..4326d37b34 100644 --- a/zjit/src/hir.rs +++ b/zjit/src/hir.rs @@ -994,6 +994,10 @@ pub enum Insn { ObjToString { val: InsnId, cd: *const rb_call_data, state: InsnId }, AnyToString { val: InsnId, str: InsnId, state: InsnId }, + /// Refine the known type information of with additional type information. + /// Computes the intersection of the existing type and the new type. + RefineType { val: InsnId, new_type: Type }, + /// Side-exit if val doesn't have the expected type. GuardType { val: InsnId, guard_type: Type, state: InsnId }, GuardTypeNot { val: InsnId, guard_type: Type, state: InsnId }, @@ -1212,6 +1216,7 @@ impl Insn { Insn::IncrCounterPtr { .. } => effects::Any, Insn::CheckInterrupts { .. } => effects::Any, Insn::InvokeProc { .. } => effects::Any, + Insn::RefineType { .. } => effects::Empty, } } @@ -1507,6 +1512,7 @@ impl<'a> std::fmt::Display for InsnPrinter<'a> { Insn::FixnumLShift { left, right, .. } => { write!(f, "FixnumLShift {left}, {right}") }, Insn::FixnumRShift { left, right, .. } => { write!(f, "FixnumRShift {left}, {right}") }, Insn::GuardType { val, guard_type, .. } => { write!(f, "GuardType {val}, {}", guard_type.print(self.ptr_map)) }, + Insn::RefineType { val, new_type, .. } => { write!(f, "RefineType {val}, {}", new_type.print(self.ptr_map)) }, Insn::GuardTypeNot { val, guard_type, .. } => { write!(f, "GuardTypeNot {val}, {}", guard_type.print(self.ptr_map)) }, Insn::GuardBitEquals { val, expected, .. } => { write!(f, "GuardBitEquals {val}, {}", expected.print(self.ptr_map)) }, &Insn::GuardShape { val, shape, .. } => { write!(f, "GuardShape {val}, {:p}", self.ptr_map.map_shape(shape)) }, @@ -2174,6 +2180,7 @@ impl Function { Jump(target) => Jump(find_branch_edge!(target)), &IfTrue { val, ref target } => IfTrue { val: find!(val), target: find_branch_edge!(target) }, &IfFalse { val, ref target } => IfFalse { val: find!(val), target: find_branch_edge!(target) }, + &RefineType { val, new_type } => RefineType { val: find!(val), new_type }, &GuardType { val, guard_type, state } => GuardType { val: find!(val), guard_type, state }, &GuardTypeNot { val, guard_type, state } => GuardTypeNot { val: find!(val), guard_type, state }, &GuardBitEquals { val, expected, reason, state } => GuardBitEquals { val: find!(val), expected, reason, state }, @@ -2423,6 +2430,7 @@ impl Function { Insn::CCall { return_type, .. } => *return_type, &Insn::CCallVariadic { return_type, .. } => return_type, Insn::GuardType { val, guard_type, .. } => self.type_of(*val).intersection(*guard_type), + Insn::RefineType { val, new_type, .. } => self.type_of(*val).intersection(*new_type), Insn::GuardTypeNot { .. } => types::BasicObject, Insn::GuardBitEquals { val, expected, .. } => self.type_of(*val).intersection(Type::from_const(*expected)), Insn::GuardShape { val, .. } => self.type_of(*val), @@ -2594,6 +2602,7 @@ impl Function { | Insn::GuardTypeNot { val, .. } | Insn::GuardShape { val, .. } | Insn::GuardBitEquals { val, .. } => self.chase_insn(val), + | Insn::RefineType { val, .. } => self.chase_insn(val), _ => id, } } @@ -4445,6 +4454,7 @@ impl Function { worklist.extend(values); worklist.push_back(state); } + | &Insn::RefineType { val, .. } | &Insn::Return { val } | &Insn::Test { val } | &Insn::SetLocal { val, .. } @@ -5370,6 +5380,7 @@ impl Function { self.assert_subtype(insn_id, val, types::BasicObject)?; self.assert_subtype(insn_id, class, types::Class) } + Insn::RefineType { .. } => Ok(()), } } @@ -5562,6 +5573,19 @@ impl FrameState { state.stack.extend_from_slice(new_args); state } + + fn replace(&mut self, old: InsnId, new: InsnId) { + for slot in &mut self.stack { + if *slot == old { + *slot = new; + } + } + for slot in &mut self.locals { + if *slot == old { + *slot = new; + } + } + } } /// Print adaptor for [`FrameState`]. See [`PtrPrintMap`]. @@ -6245,10 +6269,17 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> { let test_id = fun.push_insn(block, Insn::Test { val }); let target_idx = insn_idx_at_offset(insn_idx, offset); let target = insn_idx_to_block[&target_idx]; + let nil_false_type = types::NilClass.union(types::FalseClass); + let nil_false = fun.push_insn(block, Insn::RefineType { val, new_type: nil_false_type }); + let mut iffalse_state = state.clone(); + iffalse_state.replace(val, nil_false); let _branch_id = fun.push_insn(block, Insn::IfFalse { val: test_id, - target: BranchEdge { target, args: state.as_args(self_param) } + target: BranchEdge { target, args: iffalse_state.as_args(self_param) } }); + let not_nil_false_type = types::BasicObject.subtract(types::NilClass).subtract(types::FalseClass); + let not_nil_false = fun.push_insn(block, Insn::RefineType { val, new_type: not_nil_false_type }); + state.replace(val, not_nil_false); queue.push_back((state.clone(), target, target_idx, local_inval)); } YARVINSN_branchif => { @@ -6258,10 +6289,17 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> { let test_id = fun.push_insn(block, Insn::Test { val }); let target_idx = insn_idx_at_offset(insn_idx, offset); let target = insn_idx_to_block[&target_idx]; + let not_nil_false_type = types::BasicObject.subtract(types::NilClass).subtract(types::FalseClass); + let not_nil_false = fun.push_insn(block, Insn::RefineType { val, new_type: not_nil_false_type }); + let mut iftrue_state = state.clone(); + iftrue_state.replace(val, not_nil_false); let _branch_id = fun.push_insn(block, Insn::IfTrue { val: test_id, - target: BranchEdge { target, args: state.as_args(self_param) } + target: BranchEdge { target, args: iftrue_state.as_args(self_param) } }); + let nil_false_type = types::NilClass.union(types::FalseClass); + let nil_false = fun.push_insn(block, Insn::RefineType { val, new_type: nil_false_type }); + state.replace(val, nil_false); queue.push_back((state.clone(), target, target_idx, local_inval)); } YARVINSN_branchnil => { @@ -6271,10 +6309,16 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> { let test_id = fun.push_insn(block, Insn::IsNil { val }); let target_idx = insn_idx_at_offset(insn_idx, offset); let target = insn_idx_to_block[&target_idx]; + let nil = fun.push_insn(block, Insn::Const { val: Const::Value(Qnil) }); + let mut iftrue_state = state.clone(); + iftrue_state.replace(val, nil); let _branch_id = fun.push_insn(block, Insn::IfTrue { val: test_id, - target: BranchEdge { target, args: state.as_args(self_param) } + target: BranchEdge { target, args: iftrue_state.as_args(self_param) } }); + let new_type = types::BasicObject.subtract(types::NilClass); + let not_nil = fun.push_insn(block, Insn::RefineType { val, new_type }); + state.replace(val, not_nil); queue.push_back((state.clone(), target, target_idx, local_inval)); } YARVINSN_opt_case_dispatch => { @@ -7693,21 +7737,23 @@ mod graphviz_tests { <TR><TD ALIGN="left" PORT="v12">PatchPoint NoTracePoint </TD></TR> <TR><TD ALIGN="left" PORT="v14">CheckInterrupts </TD></TR> <TR><TD ALIGN="left" PORT="v15">v15:CBool = Test v9 </TD></TR> - <TR><TD ALIGN="left" PORT="v16">IfFalse v15, bb3(v8, v9) </TD></TR> - <TR><TD ALIGN="left" PORT="v18">PatchPoint NoTracePoint </TD></TR> - <TR><TD ALIGN="left" PORT="v19">v19:Fixnum[3] = Const Value(3) </TD></TR> - <TR><TD ALIGN="left" PORT="v21">PatchPoint NoTracePoint </TD></TR> - <TR><TD ALIGN="left" PORT="v22">CheckInterrupts </TD></TR> - <TR><TD ALIGN="left" PORT="v23">Return v19 </TD></TR> + <TR><TD ALIGN="left" PORT="v16">v16:Falsy = RefineType v9, Falsy </TD></TR> + <TR><TD ALIGN="left" PORT="v17">IfFalse v15, bb3(v8, v16) </TD></TR> + <TR><TD ALIGN="left" PORT="v18">v18:Truthy = RefineType v9, Truthy </TD></TR> + <TR><TD ALIGN="left" PORT="v20">PatchPoint NoTracePoint </TD></TR> + <TR><TD ALIGN="left" PORT="v21">v21:Fixnum[3] = Const Value(3) </TD></TR> + <TR><TD ALIGN="left" PORT="v23">PatchPoint NoTracePoint </TD></TR> + <TR><TD ALIGN="left" PORT="v24">CheckInterrupts </TD></TR> + <TR><TD ALIGN="left" PORT="v25">Return v21 </TD></TR> </TABLE>>]; - bb2:v16 -> bb3:params:n; + bb2:v17 -> bb3:params:n; bb3 [label=<<TABLE BORDER="0" CELLBORDER="1" CELLSPACING="0"> - <TR><TD ALIGN="LEFT" PORT="params" BGCOLOR="gray">bb3(v24:BasicObject, v25:BasicObject) </TD></TR> - <TR><TD ALIGN="left" PORT="v28">PatchPoint NoTracePoint </TD></TR> - <TR><TD ALIGN="left" PORT="v29">v29:Fixnum[4] = Const Value(4) </TD></TR> - <TR><TD ALIGN="left" PORT="v31">PatchPoint NoTracePoint </TD></TR> - <TR><TD ALIGN="left" PORT="v32">CheckInterrupts </TD></TR> - <TR><TD ALIGN="left" PORT="v33">Return v29 </TD></TR> + <TR><TD ALIGN="LEFT" PORT="params" BGCOLOR="gray">bb3(v26:BasicObject, v27:Falsy) </TD></TR> + <TR><TD ALIGN="left" PORT="v30">PatchPoint NoTracePoint </TD></TR> + <TR><TD ALIGN="left" PORT="v31">v31:Fixnum[4] = Const Value(4) </TD></TR> + <TR><TD ALIGN="left" PORT="v33">PatchPoint NoTracePoint </TD></TR> + <TR><TD ALIGN="left" PORT="v34">CheckInterrupts </TD></TR> + <TR><TD ALIGN="left" PORT="v35">Return v31 </TD></TR> </TABLE>>]; } "#); diff --git a/zjit/src/hir/opt_tests.rs b/zjit/src/hir/opt_tests.rs index b7595f1b27..da24025010 100644 --- a/zjit/src/hir/opt_tests.rs +++ b/zjit/src/hir/opt_tests.rs @@ -52,9 +52,10 @@ mod hir_opt_tests { bb2(v8:BasicObject, v9:NilClass): v13:TrueClass = Const Value(true) CheckInterrupts - v23:Fixnum[3] = Const Value(3) + v22:TrueClass = RefineType v13, Truthy + v25:Fixnum[3] = Const Value(3) CheckInterrupts - Return v23 + Return v25 "); } @@ -84,9 +85,10 @@ mod hir_opt_tests { bb2(v8:BasicObject, v9:NilClass): v13:FalseClass = Const Value(false) CheckInterrupts - v33:Fixnum[4] = Const Value(4) + v20:FalseClass = RefineType v13, Falsy + v35:Fixnum[4] = Const Value(4) CheckInterrupts - Return v33 + Return v35 "); } @@ -267,12 +269,12 @@ mod hir_opt_tests { v10:Fixnum[1] = Const Value(1) v12:Fixnum[2] = Const Value(2) PatchPoint MethodRedefined(Integer@0x1000, <@0x1008, cme:0x1010) - v40:TrueClass = Const Value(true) + v42:TrueClass = Const Value(true) IncrCounter inline_cfunc_optimized_send_count CheckInterrupts - v22:Fixnum[3] = Const Value(3) + v24:Fixnum[3] = Const Value(3) CheckInterrupts - Return v22 + Return v24 "); } @@ -300,18 +302,18 @@ mod hir_opt_tests { v10:Fixnum[1] = Const Value(1) v12:Fixnum[2] = Const Value(2) PatchPoint MethodRedefined(Integer@0x1000, <=@0x1008, cme:0x1010) - v55:TrueClass = Const Value(true) + v59:TrueClass = Const Value(true) IncrCounter inline_cfunc_optimized_send_count CheckInterrupts - v21:Fixnum[2] = Const Value(2) v23:Fixnum[2] = Const Value(2) + v25:Fixnum[2] = Const Value(2) PatchPoint MethodRedefined(Integer@0x1000, <=@0x1008, cme:0x1010) - v57:TrueClass = Const Value(true) + v61:TrueClass = Const Value(true) IncrCounter inline_cfunc_optimized_send_count CheckInterrupts - v33:Fixnum[3] = Const Value(3) + v37:Fixnum[3] = Const Value(3) CheckInterrupts - Return v33 + Return v37 "); } @@ -339,12 +341,12 @@ mod hir_opt_tests { v10:Fixnum[2] = Const Value(2) v12:Fixnum[1] = Const Value(1) PatchPoint MethodRedefined(Integer@0x1000, >@0x1008, cme:0x1010) - v40:TrueClass = Const Value(true) + v42:TrueClass = Const Value(true) IncrCounter inline_cfunc_optimized_send_count CheckInterrupts - v22:Fixnum[3] = Const Value(3) + v24:Fixnum[3] = Const Value(3) CheckInterrupts - Return v22 + Return v24 "); } @@ -372,18 +374,18 @@ mod hir_opt_tests { v10:Fixnum[2] = Const Value(2) v12:Fixnum[1] = Const Value(1) PatchPoint MethodRedefined(Integer@0x1000, >=@0x1008, cme:0x1010) - v55:TrueClass = Const Value(true) + v59:TrueClass = Const Value(true) IncrCounter inline_cfunc_optimized_send_count CheckInterrupts - v21:Fixnum[2] = Const Value(2) v23:Fixnum[2] = Const Value(2) + v25:Fixnum[2] = Const Value(2) PatchPoint MethodRedefined(Integer@0x1000, >=@0x1008, cme:0x1010) - v57:TrueClass = Const Value(true) + v61:TrueClass = Const Value(true) IncrCounter inline_cfunc_optimized_send_count CheckInterrupts - v33:Fixnum[3] = Const Value(3) + v37:Fixnum[3] = Const Value(3) CheckInterrupts - Return v33 + Return v37 "); } @@ -411,12 +413,12 @@ mod hir_opt_tests { v10:Fixnum[1] = Const Value(1) v12:Fixnum[2] = Const Value(2) PatchPoint MethodRedefined(Integer@0x1000, ==@0x1008, cme:0x1010) - v40:FalseClass = Const Value(false) + v42:FalseClass = Const Value(false) IncrCounter inline_cfunc_optimized_send_count CheckInterrupts - v31:Fixnum[4] = Const Value(4) + v33:Fixnum[4] = Const Value(4) CheckInterrupts - Return v31 + Return v33 "); } @@ -444,12 +446,12 @@ mod hir_opt_tests { v10:Fixnum[2] = Const Value(2) v12:Fixnum[2] = Const Value(2) PatchPoint MethodRedefined(Integer@0x1000, ==@0x1008, cme:0x1010) - v40:TrueClass = Const Value(true) + v42:TrueClass = Const Value(true) IncrCounter inline_cfunc_optimized_send_count CheckInterrupts - v22:Fixnum[3] = Const Value(3) + v24:Fixnum[3] = Const Value(3) CheckInterrupts - Return v22 + Return v24 "); } @@ -478,12 +480,12 @@ mod hir_opt_tests { v12:Fixnum[2] = Const Value(2) PatchPoint MethodRedefined(Integer@0x1000, !=@0x1008, cme:0x1010) PatchPoint BOPRedefined(INTEGER_REDEFINED_OP_FLAG, BOP_EQ) - v41:TrueClass = Const Value(true) + v43:TrueClass = Const Value(true) IncrCounter inline_cfunc_optimized_send_count CheckInterrupts - v22:Fixnum[3] = Const Value(3) + v24:Fixnum[3] = Const Value(3) CheckInterrupts - Return v22 + Return v24 "); } @@ -512,12 +514,12 @@ mod hir_opt_tests { v12:Fixnum[2] = Const Value(2) PatchPoint MethodRedefined(Integer@0x1000, !=@0x1008, cme:0x1010) PatchPoint BOPRedefined(INTEGER_REDEFINED_OP_FLAG, BOP_EQ) - v41:FalseClass = Const Value(false) + v43:FalseClass = Const Value(false) IncrCounter inline_cfunc_optimized_send_count CheckInterrupts - v31:Fixnum[4] = Const Value(4) + v33:Fixnum[4] = Const Value(4) CheckInterrupts - Return v31 + Return v33 "); } @@ -4992,8 +4994,9 @@ mod hir_opt_tests { bb2(v8:BasicObject, v9:NilClass): v13:NilClass = Const Value(nil) CheckInterrupts + v21:NilClass = Const Value(nil) CheckInterrupts - Return v13 + Return v21 "); } @@ -5020,10 +5023,11 @@ mod hir_opt_tests { bb2(v8:BasicObject, v9:NilClass): v13:Fixnum[1] = Const Value(1) CheckInterrupts + v23:Fixnum[1] = RefineType v13, NotNil PatchPoint MethodRedefined(Integer@0x1000, itself@0x1008, cme:0x1010) IncrCounter inline_cfunc_optimized_send_count CheckInterrupts - Return v13 + Return v23 "); } @@ -5840,20 +5844,22 @@ mod hir_opt_tests { bb2(v8:BasicObject, v9:BasicObject): CheckInterrupts v15:CBool = Test v9 - IfFalse v15, bb3(v8, v9) - v18:FalseClass = Const Value(false) + v16:Falsy = RefineType v9, Falsy + IfFalse v15, bb3(v8, v16) + v18:Truthy = RefineType v9, Truthy + v20:FalseClass = Const Value(false) CheckInterrupts - Jump bb4(v8, v9, v18) - bb3(v22:BasicObject, v23:BasicObject): - v26:NilClass = Const Value(nil) - Jump bb4(v22, v23, v26) - bb4(v28:BasicObject, v29:BasicObject, v30:NilClass|FalseClass): + Jump bb4(v8, v18, v20) + bb3(v24:BasicObject, v25:Falsy): + v28:NilClass = Const Value(nil) + Jump bb4(v24, v25, v28) + bb4(v30:BasicObject, v31:BasicObject, v32:Falsy): PatchPoint MethodRedefined(NilClass@0x1000, !@0x1008, cme:0x1010) - v41:NilClass = GuardType v30, NilClass - v42:TrueClass = Const Value(true) + v43:NilClass = GuardType v32, NilClass + v44:TrueClass = Const Value(true) IncrCounter inline_cfunc_optimized_send_count CheckInterrupts - Return v42 + Return v44 "); } @@ -10059,9 +10065,9 @@ mod hir_opt_tests { bb2(v6:BasicObject): PatchPoint NoSingletonClass(C@0x1000) PatchPoint MethodRedefined(C@0x1000, class@0x1008, cme:0x1010) - v40:HeapObject[class_exact:C] = GuardType v6, HeapObject[class_exact:C] + v42:HeapObject[class_exact:C] = GuardType v6, HeapObject[class_exact:C] IncrCounter inline_iseq_optimized_send_count - v44:Class[C@0x1000] = Const Value(VALUE(0x1000)) + v46:Class[C@0x1000] = Const Value(VALUE(0x1000)) IncrCounter inline_cfunc_optimized_send_count v13:StaticSymbol[:_lex_actions] = Const Value(VALUE(0x1038)) v15:TrueClass = Const Value(true) @@ -10069,12 +10075,12 @@ mod hir_opt_tests { PatchPoint MethodRedefined(Class@0x1040, respond_to?@0x1048, cme:0x1050) PatchPoint NoSingletonClass(Class@0x1040) PatchPoint MethodRedefined(Class@0x1040, _lex_actions@0x1078, cme:0x1080) - v52:TrueClass = Const Value(true) + v54:TrueClass = Const Value(true) IncrCounter inline_cfunc_optimized_send_count CheckInterrupts - v24:StaticSymbol[:CORRECT] = Const Value(VALUE(0x10a8)) + v26:StaticSymbol[:CORRECT] = Const Value(VALUE(0x10a8)) CheckInterrupts - Return v24 + Return v26 "); } @@ -10230,23 +10236,23 @@ mod hir_opt_tests { CheckInterrupts SetLocal :formatted, l0, EP@3, v15 PatchPoint SingleRactorMode - v54:HeapBasicObject = GuardType v14, HeapBasicObject - v55:CShape = LoadField v54, :_shape_id@0x1000 - v56:CShape[0x1001] = GuardBitEquals v55, CShape(0x1001) - StoreField v54, :@formatted@0x1002, v15 - WriteBarrier v54, v15 - v59:CShape[0x1003] = Const CShape(0x1003) - StoreField v54, :_shape_id@0x1000, v59 - v43:Class[VMFrozenCore] = Const Value(VALUE(0x1008)) + v56:HeapBasicObject = GuardType v14, HeapBasicObject + v57:CShape = LoadField v56, :_shape_id@0x1000 + v58:CShape[0x1001] = GuardBitEquals v57, CShape(0x1001) + StoreField v56, :@formatted@0x1002, v15 + WriteBarrier v56, v15 + v61:CShape[0x1003] = Const CShape(0x1003) + StoreField v56, :_shape_id@0x1000, v61 + v45:Class[VMFrozenCore] = Const Value(VALUE(0x1008)) PatchPoint NoSingletonClass(Class@0x1010) PatchPoint MethodRedefined(Class@0x1010, lambda@0x1018, cme:0x1020) - v64:BasicObject = CCallWithFrame v43, :RubyVM::FrozenCore.lambda@0x1048, block=0x1050 - v46:BasicObject = GetLocal :a, l0, EP@6 - v47:BasicObject = GetLocal :_b, l0, EP@5 - v48:BasicObject = GetLocal :_c, l0, EP@4 - v49:BasicObject = GetLocal :formatted, l0, EP@3 + v66:BasicObject = CCallWithFrame v45, :RubyVM::FrozenCore.lambda@0x1048, block=0x1050 + v48:BasicObject = GetLocal :a, l0, EP@6 + v49:BasicObject = GetLocal :_b, l0, EP@5 + v50:BasicObject = GetLocal :_c, l0, EP@4 + v51:BasicObject = GetLocal :formatted, l0, EP@3 CheckInterrupts - Return v64 + Return v66 "); } diff --git a/zjit/src/hir/tests.rs b/zjit/src/hir/tests.rs index 3b0f591599..3e28178273 100644 --- a/zjit/src/hir/tests.rs +++ b/zjit/src/hir/tests.rs @@ -1083,14 +1083,16 @@ pub mod hir_build_tests { v10:TrueClass|NilClass = DefinedIvar v6, :@foo CheckInterrupts v13:CBool = Test v10 + v14:NilClass = RefineType v10, Falsy IfFalse v13, bb3(v6) - v17:Fixnum[3] = Const Value(3) + v16:TrueClass = RefineType v10, Truthy + v19:Fixnum[3] = Const Value(3) CheckInterrupts - Return v17 - bb3(v22:BasicObject): - v26:Fixnum[4] = Const Value(4) + Return v19 + bb3(v24:BasicObject): + v28:Fixnum[4] = Const Value(4) CheckInterrupts - Return v26 + Return v28 "); } @@ -1146,14 +1148,16 @@ pub mod hir_build_tests { bb2(v8:BasicObject, v9:BasicObject): CheckInterrupts v15:CBool = Test v9 - IfFalse v15, bb3(v8, v9) - v19:Fixnum[3] = Const Value(3) + v16:Falsy = RefineType v9, Falsy + IfFalse v15, bb3(v8, v16) + v18:Truthy = RefineType v9, Truthy + v21:Fixnum[3] = Const Value(3) CheckInterrupts - Return v19 - bb3(v24:BasicObject, v25:BasicObject): - v29:Fixnum[4] = Const Value(4) + Return v21 + bb3(v26:BasicObject, v27:Falsy): + v31:Fixnum[4] = Const Value(4) CheckInterrupts - Return v29 + Return v31 "); } @@ -1184,16 +1188,18 @@ pub mod hir_build_tests { bb2(v10:BasicObject, v11:BasicObject, v12:NilClass): CheckInterrupts v18:CBool = Test v11 - IfFalse v18, bb3(v10, v11, v12) - v22:Fixnum[3] = Const Value(3) + v19:Falsy = RefineType v11, Falsy + IfFalse v18, bb3(v10, v19, v12) + v21:Truthy = RefineType v11, Truthy + v24:Fixnum[3] = Const Value(3) CheckInterrupts - Jump bb4(v10, v11, v22) - bb3(v27:BasicObject, v28:BasicObject, v29:NilClass): - v33:Fixnum[4] = Const Value(4) - Jump bb4(v27, v28, v33) - bb4(v36:BasicObject, v37:BasicObject, v38:Fixnum): + Jump bb4(v10, v21, v24) + bb3(v29:BasicObject, v30:Falsy, v31:NilClass): + v35:Fixnum[4] = Const Value(4) + Jump bb4(v29, v30, v35) + bb4(v38:BasicObject, v39:BasicObject, v40:Fixnum): CheckInterrupts - Return v38 + Return v40 "); } @@ -1484,16 +1490,18 @@ pub mod hir_build_tests { v35:BasicObject = SendWithoutBlock v28, :>, v32 # SendFallbackReason: Uncategorized(opt_gt) CheckInterrupts v38:CBool = Test v35 + v39:Truthy = RefineType v35, Truthy IfTrue v38, bb3(v26, v27, v28) - v41:NilClass = Const Value(nil) + v41:Falsy = RefineType v35, Falsy + v43:NilClass = Const Value(nil) CheckInterrupts Return v27 - bb3(v49:BasicObject, v50:BasicObject, v51:BasicObject): - v56:Fixnum[1] = Const Value(1) - v59:BasicObject = SendWithoutBlock v50, :+, v56 # SendFallbackReason: Uncategorized(opt_plus) - v64:Fixnum[1] = Const Value(1) - v67:BasicObject = SendWithoutBlock v51, :-, v64 # SendFallbackReason: Uncategorized(opt_minus) - Jump bb4(v49, v59, v67) + bb3(v51:BasicObject, v52:BasicObject, v53:BasicObject): + v58:Fixnum[1] = Const Value(1) + v61:BasicObject = SendWithoutBlock v52, :+, v58 # SendFallbackReason: Uncategorized(opt_plus) + v66:Fixnum[1] = Const Value(1) + v69:BasicObject = SendWithoutBlock v53, :-, v66 # SendFallbackReason: Uncategorized(opt_minus) + Jump bb4(v51, v61, v69) "); } @@ -1549,14 +1557,16 @@ pub mod hir_build_tests { v13:TrueClass = Const Value(true) CheckInterrupts v19:CBool[true] = Test v13 - IfFalse v19, bb3(v8, v13) - v23:Fixnum[3] = Const Value(3) + v20 = RefineType v13, Falsy + IfFalse v19, bb3(v8, v20) + v22:TrueClass = RefineType v13, Truthy + v25:Fixnum[3] = Const Value(3) CheckInterrupts - Return v23 - bb3(v28, v29): - v33 = Const Value(4) + Return v25 + bb3(v30, v31): + v35 = Const Value(4) CheckInterrupts - Return v33 + Return v35 "); } @@ -3090,12 +3100,60 @@ pub mod hir_build_tests { bb2(v8:BasicObject, v9:BasicObject): CheckInterrupts v16:CBool = IsNil v9 - IfTrue v16, bb3(v8, v9, v9) - v19:BasicObject = SendWithoutBlock v9, :itself # SendFallbackReason: Uncategorized(opt_send_without_block) - Jump bb3(v8, v9, v19) - bb3(v21:BasicObject, v22:BasicObject, v23:BasicObject): + v17:NilClass = Const Value(nil) + IfTrue v16, bb3(v8, v17, v17) + v19:NotNil = RefineType v9, NotNil + v21:BasicObject = SendWithoutBlock v19, :itself # SendFallbackReason: Uncategorized(opt_send_without_block) + Jump bb3(v8, v19, v21) + bb3(v23:BasicObject, v24:BasicObject, v25:BasicObject): CheckInterrupts - Return v23 + Return v25 + "); + } + + #[test] + fn test_infer_nilability_from_branchif() { + eval(" + def test(x) + if x + x&.itself + else + 4 + end + end + "); + assert_contains_opcode("test", YARVINSN_branchnil); + // Note that IsNil has as its operand a value that we know statically *cannot* be nil + assert_snapshot!(hir_string("test"), @r" + fn test@<compiled>:3: + bb0(): + EntryPoint interpreter + v1:BasicObject = LoadSelf + v2:BasicObject = GetLocal :x, l0, SP@4 + Jump bb2(v1, v2) + bb1(v5:BasicObject, v6:BasicObject): + EntryPoint JIT(0) + Jump bb2(v5, v6) + bb2(v8:BasicObject, v9:BasicObject): + CheckInterrupts + v15:CBool = Test v9 + v16:Falsy = RefineType v9, Falsy + IfFalse v15, bb3(v8, v16) + v18:Truthy = RefineType v9, Truthy + CheckInterrupts + v24:CBool[false] = IsNil v18 + v25:NilClass = Const Value(nil) + IfTrue v24, bb4(v8, v25, v25) + v27:Truthy = RefineType v18, NotNil + v29:BasicObject = SendWithoutBlock v27, :itself # SendFallbackReason: Uncategorized(opt_send_without_block) + CheckInterrupts + Return v29 + bb3(v34:BasicObject, v35:Falsy): + v39:Fixnum[4] = Const Value(4) + Jump bb4(v34, v35, v39) + bb4(v41:BasicObject, v42:Falsy, v43:Fixnum[4]): + CheckInterrupts + Return v43 "); } @@ -3174,14 +3232,16 @@ pub mod hir_build_tests { v32:HeapObject[BlockParamProxy] = Const Value(VALUE(0x1000)) CheckInterrupts v35:CBool[true] = Test v32 + v36 = RefineType v32, Falsy IfFalse v35, bb3(v16, v17, v18, v19, v20, v25) - v40:BasicObject = InvokeBlock, v25 # SendFallbackReason: Uncategorized(invokeblock) - v43:BasicObject = InvokeBuiltin dir_s_close, v16, v25 + v38:HeapObject[BlockParamProxy] = RefineType v32, Truthy + v42:BasicObject = InvokeBlock, v25 # SendFallbackReason: Uncategorized(invokeblock) + v45:BasicObject = InvokeBuiltin dir_s_close, v16, v25 CheckInterrupts - Return v40 - bb3(v49, v50, v51, v52, v53, v54): + Return v42 + bb3(v51, v52, v53, v54, v55, v56): CheckInterrupts - Return v54 + Return v56 "); } @@ -3302,14 +3362,16 @@ pub mod hir_build_tests { v21:BasicObject = SendWithoutBlock v9, :[], v16, v18 # SendFallbackReason: Uncategorized(opt_send_without_block) CheckInterrupts v25:CBool = Test v21 - IfTrue v25, bb3(v8, v9, v13, v9, v16, v18, v21) - v29:Fixnum[2] = Const Value(2) - v32:BasicObject = SendWithoutBlock v9, :[]=, v16, v18, v29 # SendFallbackReason: Uncategorized(opt_send_without_block) + v26:Truthy = RefineType v21, Truthy + IfTrue v25, bb3(v8, v9, v13, v9, v16, v18, v26) + v28:Falsy = RefineType v21, Falsy + v31:Fixnum[2] = Const Value(2) + v34:BasicObject = SendWithoutBlock v9, :[]=, v16, v18, v31 # SendFallbackReason: Uncategorized(opt_send_without_block) CheckInterrupts - Return v29 - bb3(v38:BasicObject, v39:BasicObject, v40:NilClass, v41:BasicObject, v42:Fixnum[0], v43:Fixnum[1], v44:BasicObject): + Return v31 + bb3(v40:BasicObject, v41:BasicObject, v42:NilClass, v43:BasicObject, v44:Fixnum[0], v45:Fixnum[1], v46:Truthy): CheckInterrupts - Return v44 + Return v46 "); } @@ -3652,14 +3714,16 @@ pub mod hir_build_tests { v15:BoolExact = FixnumBitCheck v12, 0 CheckInterrupts v18:CBool = Test v15 + v19:TrueClass = RefineType v15, Truthy IfTrue v18, bb3(v10, v11, v12) - v21:Fixnum[1] = Const Value(1) + v21:FalseClass = RefineType v15, Falsy v23:Fixnum[1] = Const Value(1) - v26:BasicObject = SendWithoutBlock v21, :+, v23 # SendFallbackReason: Uncategorized(opt_plus) - Jump bb3(v10, v26, v12) - bb3(v29:BasicObject, v30:BasicObject, v31:BasicObject): + v25:Fixnum[1] = Const Value(1) + v28:BasicObject = SendWithoutBlock v23, :+, v25 # SendFallbackReason: Uncategorized(opt_plus) + Jump bb3(v10, v28, v12) + bb3(v31:BasicObject, v32:BasicObject, v33:BasicObject): CheckInterrupts - Return v30 + Return v32 "); } diff --git a/zjit/src/hir_type/gen_hir_type.rb b/zjit/src/hir_type/gen_hir_type.rb index 9576d2b1c0..f952a8b715 100644 --- a/zjit/src/hir_type/gen_hir_type.rb +++ b/zjit/src/hir_type/gen_hir_type.rb @@ -178,10 +178,15 @@ add_union "BuiltinExact", $builtin_exact add_union "Subclass", $subclass add_union "BoolExact", [true_exact.name, false_exact.name] add_union "Immediate", [fixnum.name, flonum.name, static_sym.name, nil_exact.name, true_exact.name, false_exact.name, undef_.name] +add_union "Falsy", [nil_exact.name, false_exact.name] $bits["HeapBasicObject"] = ["BasicObject & !Immediate"] $numeric_bits["HeapBasicObject"] = $numeric_bits["BasicObject"] & ~$numeric_bits["Immediate"] $bits["HeapObject"] = ["Object & !Immediate"] $numeric_bits["HeapObject"] = $numeric_bits["Object"] & ~$numeric_bits["Immediate"] +$bits["Truthy"] = ["BasicObject & !Falsy"] +$numeric_bits["Truthy"] = $numeric_bits["BasicObject"] & ~$numeric_bits["Falsy"] +$bits["NotNil"] = ["BasicObject & !NilClass"] +$numeric_bits["NotNil"] = $numeric_bits["BasicObject"] & ~$numeric_bits["NilClass"] # ===== Finished generating the DAG; write Rust code ===== diff --git a/zjit/src/hir_type/hir_type.inc.rs b/zjit/src/hir_type/hir_type.inc.rs index b388b3a0d1..886b4b54dd 100644 --- a/zjit/src/hir_type/hir_type.inc.rs +++ b/zjit/src/hir_type/hir_type.inc.rs @@ -32,6 +32,7 @@ mod bits { pub const DynamicSymbol: u64 = 1u64 << 20; pub const Empty: u64 = 0u64; pub const FalseClass: u64 = 1u64 << 21; + pub const Falsy: u64 = FalseClass | NilClass; pub const Fixnum: u64 = 1u64 << 22; pub const Float: u64 = Flonum | HeapFloat; pub const Flonum: u64 = 1u64 << 23; @@ -47,6 +48,7 @@ mod bits { pub const ModuleExact: u64 = 1u64 << 27; pub const ModuleSubclass: u64 = 1u64 << 28; pub const NilClass: u64 = 1u64 << 29; + pub const NotNil: u64 = BasicObject & !NilClass; pub const Numeric: u64 = Float | Integer | NumericExact | NumericSubclass; pub const NumericExact: u64 = 1u64 << 30; pub const NumericSubclass: u64 = 1u64 << 31; @@ -70,14 +72,17 @@ mod bits { pub const Subclass: u64 = ArraySubclass | BasicObjectSubclass | HashSubclass | ModuleSubclass | NumericSubclass | ObjectSubclass | RangeSubclass | RegexpSubclass | SetSubclass | StringSubclass; pub const Symbol: u64 = DynamicSymbol | StaticSymbol; pub const TrueClass: u64 = 1u64 << 43; + pub const Truthy: u64 = BasicObject & !Falsy; pub const Undef: u64 = 1u64 << 44; - pub const AllBitPatterns: [(&str, u64); 71] = [ + pub const AllBitPatterns: [(&str, u64); 74] = [ ("Any", Any), ("RubyValue", RubyValue), ("Immediate", Immediate), ("Undef", Undef), ("BasicObject", BasicObject), ("Object", Object), + ("NotNil", NotNil), + ("Truthy", Truthy), ("BuiltinExact", BuiltinExact), ("BoolExact", BoolExact), ("TrueClass", TrueClass), @@ -103,6 +108,7 @@ mod bits { ("Numeric", Numeric), ("NumericSubclass", NumericSubclass), ("NumericExact", NumericExact), + ("Falsy", Falsy), ("NilClass", NilClass), ("Module", Module), ("ModuleSubclass", ModuleSubclass), @@ -180,6 +186,7 @@ pub mod types { pub const DynamicSymbol: Type = Type::from_bits(bits::DynamicSymbol); pub const Empty: Type = Type::from_bits(bits::Empty); pub const FalseClass: Type = Type::from_bits(bits::FalseClass); + pub const Falsy: Type = Type::from_bits(bits::Falsy); pub const Fixnum: Type = Type::from_bits(bits::Fixnum); pub const Float: Type = Type::from_bits(bits::Float); pub const Flonum: Type = Type::from_bits(bits::Flonum); @@ -195,6 +202,7 @@ pub mod types { pub const ModuleExact: Type = Type::from_bits(bits::ModuleExact); pub const ModuleSubclass: Type = Type::from_bits(bits::ModuleSubclass); pub const NilClass: Type = Type::from_bits(bits::NilClass); + pub const NotNil: Type = Type::from_bits(bits::NotNil); pub const Numeric: Type = Type::from_bits(bits::Numeric); pub const NumericExact: Type = Type::from_bits(bits::NumericExact); pub const NumericSubclass: Type = Type::from_bits(bits::NumericSubclass); @@ -218,6 +226,7 @@ pub mod types { pub const Subclass: Type = Type::from_bits(bits::Subclass); pub const Symbol: Type = Type::from_bits(bits::Symbol); pub const TrueClass: Type = Type::from_bits(bits::TrueClass); + pub const Truthy: Type = Type::from_bits(bits::Truthy); pub const Undef: Type = Type::from_bits(bits::Undef); pub const ExactBitsAndClass: [(u64, *const VALUE); 17] = [ (bits::ObjectExact, &raw const crate::cruby::rb_cObject), diff --git a/zjit/src/hir_type/mod.rs b/zjit/src/hir_type/mod.rs index cc6a208bcd..1f7526915c 100644 --- a/zjit/src/hir_type/mod.rs +++ b/zjit/src/hir_type/mod.rs @@ -453,6 +453,25 @@ impl Type { types::Empty } + /// Subtract `other` from `self`, preserving specialization if possible. + pub fn subtract(&self, other: Type) -> Type { + // If self is a subtype of other, the result is empty (no negative types). + if self.is_subtype(other) { return types::Empty; } + // Self is not a subtype of other. That means either: + // * Their type bits do not overlap at all (eg Int vs String) + // * Their type bits overlap but self's specialization is not a subtype of other's (eg + // Fixnum[5] vs Fixnum[4]) + // Check for the latter case, returning self unchanged if so. + if !self.spec_is_subtype_of(other) { + return *self; + } + // Now self is either a supertype of other (eg Object vs String or Fixnum vs Fixnum[5]) or + // their type bits do not overlap at all (eg Int vs String). + // Just subtract the bits and keep self's specialization. + let bits = self.bits & !other.bits; + Type { bits, spec: self.spec } + } + pub fn could_be(&self, other: Type) -> bool { !self.intersection(other).bit_equal(types::Empty) } @@ -1060,4 +1079,45 @@ mod tests { assert!(!types::CBool.has_value(Const::CBool(true))); assert!(!types::CShape.has_value(Const::CShape(crate::cruby::ShapeId(0x1234)))); } + + #[test] + fn test_subtract_with_superset_returns_empty() { + let left = types::NilClass; + let right = types::BasicObject; + let result = left.subtract(right); + assert_bit_equal(result, types::Empty); + } + + #[test] + fn test_subtract_with_subset_removes_bits() { + let left = types::BasicObject; + let right = types::NilClass; + let result = left.subtract(right); + assert_subtype(result, types::BasicObject); + assert_not_subtype(types::NilClass, result); + } + + #[test] + fn test_subtract_with_no_overlap_returns_self() { + let left = types::Fixnum; + let right = types::StringExact; + let result = left.subtract(right); + assert_bit_equal(result, left); + } + + #[test] + fn test_subtract_with_no_specialization_overlap_returns_self() { + let left = Type::fixnum(4); + let right = Type::fixnum(5); + let result = left.subtract(right); + assert_bit_equal(result, left); + } + + #[test] + fn test_subtract_with_specialization_subset_removes_specialization() { + let left = types::Fixnum; + let right = Type::fixnum(42); + let result = left.subtract(right); + assert_bit_equal(result, types::Fixnum); + } } |
