From e0bb3fb1cda2238d0c98afcdec2fe282c29994aa Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Fri, 21 Nov 2025 08:57:26 -0800 Subject: ZJIT: Inline Integer#<< for constant rhs (#15258) This is good for protoboeuf and other binary parsing --- zjit/src/asm/arm64/mod.rs | 2 +- zjit/src/codegen.rs | 20 +++++++++ zjit/src/cruby_methods.rs | 10 +++++ zjit/src/hir.rs | 10 ++++- zjit/src/hir/opt_tests.rs | 109 ++++++++++++++++++++++++++++++++++++++++++++++ zjit/src/stats.rs | 2 + 6 files changed, 151 insertions(+), 2 deletions(-) diff --git a/zjit/src/asm/arm64/mod.rs b/zjit/src/asm/arm64/mod.rs index a445911731..4094d101fb 100644 --- a/zjit/src/asm/arm64/mod.rs +++ b/zjit/src/asm/arm64/mod.rs @@ -649,7 +649,7 @@ pub fn lsl(cb: &mut CodeBlock, rd: A64Opnd, rn: A64Opnd, shift: A64Opnd) { ShiftImm::lsl(rd.reg_no, rn.reg_no, uimm as u8, rd.num_bits).into() }, - _ => panic!("Invalid operands combination to lsl instruction") + _ => panic!("Invalid operands combination {rd:?} {rn:?} {shift:?} to lsl instruction") }; cb.write_bytes(&bytes); diff --git a/zjit/src/codegen.rs b/zjit/src/codegen.rs index b95d137222..9f74838c11 100644 --- a/zjit/src/codegen.rs +++ b/zjit/src/codegen.rs @@ -402,6 +402,12 @@ fn gen_insn(cb: &mut CodeBlock, jit: &mut JITState, asm: &mut Assembler, functio Insn::FixnumAnd { left, right } => gen_fixnum_and(asm, opnd!(left), opnd!(right)), Insn::FixnumOr { left, right } => gen_fixnum_or(asm, opnd!(left), opnd!(right)), Insn::FixnumXor { left, right } => gen_fixnum_xor(asm, opnd!(left), opnd!(right)), + &Insn::FixnumLShift { left, right, state } => { + // We only create FixnumLShift when we know the shift amount statically and it's in [0, + // 63]. 
+ let shift_amount = function.type_of(right).fixnum_value().unwrap() as u64; + gen_fixnum_lshift(jit, asm, opnd!(left), shift_amount, &function.frame_state(state)) + } &Insn::FixnumMod { left, right, state } => gen_fixnum_mod(jit, asm, opnd!(left), opnd!(right), &function.frame_state(state)), Insn::IsNil { val } => gen_isnil(asm, opnd!(val)), &Insn::IsMethodCfunc { val, cd, cfunc, state: _ } => gen_is_method_cfunc(jit, asm, opnd!(val), cd, cfunc), @@ -1700,6 +1706,20 @@ fn gen_fixnum_xor(asm: &mut Assembler, left: lir::Opnd, right: lir::Opnd) -> lir asm.add(out_val, Opnd::UImm(1)) } +/// Compile Fixnum << Fixnum +fn gen_fixnum_lshift(jit: &mut JITState, asm: &mut Assembler, left: lir::Opnd, shift_amount: u64, state: &FrameState) -> lir::Opnd { + // Shift amount is known statically to be in the range [0, 63] + assert!(shift_amount < 64); + let in_val = asm.sub(left, Opnd::UImm(1)); // Drop tag bit + let out_val = asm.lshift(in_val, shift_amount.into()); + let unshifted = asm.rshift(out_val, shift_amount.into()); + asm.cmp(in_val, unshifted); + asm.jne(side_exit(jit, state, FixnumLShiftOverflow)); + // Re-tag the output value + let out_val = asm.add(out_val, 1.into()); + out_val +} + fn gen_fixnum_mod(jit: &mut JITState, asm: &mut Assembler, left: lir::Opnd, right: lir::Opnd, state: &FrameState) -> lir::Opnd { // Check for left % 0, which raises ZeroDivisionError asm.cmp(right, Opnd::from(VALUE::fixnum_from_usize(0))); diff --git a/zjit/src/cruby_methods.rs b/zjit/src/cruby_methods.rs index d1f76e8da0..3999ef0a10 100644 --- a/zjit/src/cruby_methods.rs +++ b/zjit/src/cruby_methods.rs @@ -240,6 +240,7 @@ pub fn init() -> Annotations { annotate!(rb_cInteger, ">=", inline_integer_ge); annotate!(rb_cInteger, "<", inline_integer_lt); annotate!(rb_cInteger, "<=", inline_integer_le); + annotate!(rb_cInteger, "<<", inline_integer_lshift); annotate!(rb_cString, "to_s", inline_string_to_s, types::StringExact); let thread_singleton = unsafe { rb_singleton_class(rb_cThread) }; 
annotate!(thread_singleton, "current", inline_thread_current, types::BasicObject, no_gc, leaf); @@ -546,6 +547,15 @@ fn inline_integer_le(fun: &mut hir::Function, block: hir::BlockId, recv: hir::In try_inline_fixnum_op(fun, block, &|left, right| hir::Insn::FixnumLe { left, right }, BOP_LE, recv, other, state) } +fn inline_integer_lshift(fun: &mut hir::Function, block: hir::BlockId, recv: hir::InsnId, args: &[hir::InsnId], state: hir::InsnId) -> Option { + let &[other] = args else { return None; }; + // Only convert to FixnumLShift if the shift amount is known at compile-time and could + // plausibly create a fixnum. + let Some(other_value) = fun.type_of(other).fixnum_value() else { return None; }; + if other_value < 0 || other_value > 63 { return None; } + try_inline_fixnum_op(fun, block, &|left, right| hir::Insn::FixnumLShift { left, right, state }, BOP_LTLT, recv, other, state) +} + fn inline_basic_object_eq(fun: &mut hir::Function, block: hir::BlockId, recv: hir::InsnId, args: &[hir::InsnId], _state: hir::InsnId) -> Option { let &[other] = args else { return None; }; let c_result = fun.push_insn(block, hir::Insn::IsBitEqual { left: recv, right: other }); diff --git a/zjit/src/hir.rs b/zjit/src/hir.rs index 333d5e5bff..00014c5758 100644 --- a/zjit/src/hir.rs +++ b/zjit/src/hir.rs @@ -485,6 +485,7 @@ pub enum SideExitReason { FixnumAddOverflow, FixnumSubOverflow, FixnumMultOverflow, + FixnumLShiftOverflow, GuardType(Type), GuardTypeNot(Type), GuardShape(ShapeId), @@ -868,7 +869,7 @@ pub enum Insn { /// Non-local control flow. 
See the throw YARV instruction Throw { throw_state: u32, val: InsnId, state: InsnId }, - /// Fixnum +, -, *, /, %, ==, !=, <, <=, >, >=, &, |, ^ + /// Fixnum +, -, *, /, %, ==, !=, <, <=, >, >=, &, |, ^, << FixnumAdd { left: InsnId, right: InsnId, state: InsnId }, FixnumSub { left: InsnId, right: InsnId, state: InsnId }, FixnumMult { left: InsnId, right: InsnId, state: InsnId }, @@ -883,6 +884,7 @@ pub enum Insn { FixnumAnd { left: InsnId, right: InsnId }, FixnumOr { left: InsnId, right: InsnId }, FixnumXor { left: InsnId, right: InsnId }, + FixnumLShift { left: InsnId, right: InsnId, state: InsnId }, // Distinct from `SendWithoutBlock` with `mid:to_s` because does not have a patch point for String to_s being redefined ObjToString { val: InsnId, cd: *const rb_call_data, state: InsnId }, @@ -979,6 +981,7 @@ impl Insn { Insn::FixnumAnd { .. } => false, Insn::FixnumOr { .. } => false, Insn::FixnumXor { .. } => false, + Insn::FixnumLShift { .. } => false, Insn::GetLocal { .. } => false, Insn::IsNil { .. } => false, Insn::LoadPC => false, @@ -1218,6 +1221,7 @@ impl<'a> std::fmt::Display for InsnPrinter<'a> { Insn::FixnumAnd { left, right, .. } => { write!(f, "FixnumAnd {left}, {right}") }, Insn::FixnumOr { left, right, .. } => { write!(f, "FixnumOr {left}, {right}") }, Insn::FixnumXor { left, right, .. } => { write!(f, "FixnumXor {left}, {right}") }, + Insn::FixnumLShift { left, right, .. } => { write!(f, "FixnumLShift {left}, {right}") }, Insn::GuardType { val, guard_type, .. } => { write!(f, "GuardType {val}, {}", guard_type.print(self.ptr_map)) }, Insn::GuardTypeNot { val, guard_type, .. } => { write!(f, "GuardTypeNot {val}, {}", guard_type.print(self.ptr_map)) }, Insn::GuardBitEquals { val, expected, .. 
} => { write!(f, "GuardBitEquals {val}, {}", expected.print(self.ptr_map)) }, @@ -1836,6 +1840,7 @@ impl Function { &FixnumAnd { left, right } => FixnumAnd { left: find!(left), right: find!(right) }, &FixnumOr { left, right } => FixnumOr { left: find!(left), right: find!(right) }, &FixnumXor { left, right } => FixnumXor { left: find!(left), right: find!(right) }, + &FixnumLShift { left, right, state } => FixnumLShift { left: find!(left), right: find!(right), state }, &ObjToString { val, cd, state } => ObjToString { val: find!(val), cd, @@ -2054,6 +2059,7 @@ impl Function { Insn::FixnumAnd { .. } => types::Fixnum, Insn::FixnumOr { .. } => types::Fixnum, Insn::FixnumXor { .. } => types::Fixnum, + Insn::FixnumLShift { .. } => types::Fixnum, Insn::PutSpecialObject { .. } => types::BasicObject, Insn::SendWithoutBlock { .. } => types::BasicObject, Insn::SendWithoutBlockDirect { .. } => types::BasicObject, @@ -3506,6 +3512,7 @@ impl Function { | &Insn::FixnumDiv { left, right, state } | &Insn::FixnumMod { left, right, state } | &Insn::ArrayExtend { left, right, state } + | &Insn::FixnumLShift { left, right, state } => { worklist.push_back(left); worklist.push_back(right); @@ -4271,6 +4278,7 @@ impl Function { | Insn::FixnumAnd { left, right } | Insn::FixnumOr { left, right } | Insn::FixnumXor { left, right } + | Insn::FixnumLShift { left, right, .. } | Insn::NewRangeFixnum { low: left, high: right, .. 
} => { self.assert_subtype(insn_id, left, types::Fixnum)?; diff --git a/zjit/src/hir/opt_tests.rs b/zjit/src/hir/opt_tests.rs index 82f54f611a..19c0ce66e3 100644 --- a/zjit/src/hir/opt_tests.rs +++ b/zjit/src/hir/opt_tests.rs @@ -6466,6 +6466,115 @@ mod hir_opt_tests { "); } + #[test] + fn test_inline_integer_ltlt_with_known_fixnum() { + eval(" + def test(x) = x << 5 + test(4) + "); + assert_contains_opcode("test", YARVINSN_opt_ltlt); + assert_snapshot!(hir_string("test"), @r" + fn test@:2: + bb0(): + EntryPoint interpreter + v1:BasicObject = LoadSelf + v2:BasicObject = GetLocal l0, SP@4 + Jump bb2(v1, v2) + bb1(v5:BasicObject, v6:BasicObject): + EntryPoint JIT(0) + Jump bb2(v5, v6) + bb2(v8:BasicObject, v9:BasicObject): + v14:Fixnum[5] = Const Value(5) + PatchPoint MethodRedefined(Integer@0x1000, <<@0x1008, cme:0x1010) + v24:Fixnum = GuardType v9, Fixnum + v25:Fixnum = FixnumLShift v24, v14 + IncrCounter inline_cfunc_optimized_send_count + CheckInterrupts + Return v25 + "); + } + + #[test] + fn test_dont_inline_integer_ltlt_with_negative() { + eval(" + def test(x) = x << -5 + test(4) + "); + assert_contains_opcode("test", YARVINSN_opt_ltlt); + assert_snapshot!(hir_string("test"), @r" + fn test@:2: + bb0(): + EntryPoint interpreter + v1:BasicObject = LoadSelf + v2:BasicObject = GetLocal l0, SP@4 + Jump bb2(v1, v2) + bb1(v5:BasicObject, v6:BasicObject): + EntryPoint JIT(0) + Jump bb2(v5, v6) + bb2(v8:BasicObject, v9:BasicObject): + v14:Fixnum[-5] = Const Value(-5) + PatchPoint MethodRedefined(Integer@0x1000, <<@0x1008, cme:0x1010) + v24:Fixnum = GuardType v9, Fixnum + v25:BasicObject = CCallWithFrame Integer#<<@0x1038, v24, v14 + CheckInterrupts + Return v25 + "); + } + + #[test] + fn test_dont_inline_integer_ltlt_with_out_of_range() { + eval(" + def test(x) = x << 64 + test(4) + "); + assert_contains_opcode("test", YARVINSN_opt_ltlt); + assert_snapshot!(hir_string("test"), @r" + fn test@:2: + bb0(): + EntryPoint interpreter + v1:BasicObject = LoadSelf + 
v2:BasicObject = GetLocal l0, SP@4 + Jump bb2(v1, v2) + bb1(v5:BasicObject, v6:BasicObject): + EntryPoint JIT(0) + Jump bb2(v5, v6) + bb2(v8:BasicObject, v9:BasicObject): + v14:Fixnum[64] = Const Value(64) + PatchPoint MethodRedefined(Integer@0x1000, <<@0x1008, cme:0x1010) + v24:Fixnum = GuardType v9, Fixnum + v25:BasicObject = CCallWithFrame Integer#<<@0x1038, v24, v14 + CheckInterrupts + Return v25 + "); + } + + #[test] + fn test_dont_inline_integer_ltlt_with_unknown_fixnum() { + eval(" + def test(x, y) = x << y + test(4, 5) + "); + assert_contains_opcode("test", YARVINSN_opt_ltlt); + assert_snapshot!(hir_string("test"), @r" + fn test@:2: + bb0(): + EntryPoint interpreter + v1:BasicObject = LoadSelf + v2:BasicObject = GetLocal l0, SP@5 + v3:BasicObject = GetLocal l0, SP@4 + Jump bb2(v1, v2, v3) + bb1(v6:BasicObject, v7:BasicObject, v8:BasicObject): + EntryPoint JIT(0) + Jump bb2(v6, v7, v8) + bb2(v10:BasicObject, v11:BasicObject, v12:BasicObject): + PatchPoint MethodRedefined(Integer@0x1000, <<@0x1008, cme:0x1010) + v26:Fixnum = GuardType v11, Fixnum + v27:BasicObject = CCallWithFrame Integer#<<@0x1038, v26, v12 + CheckInterrupts + Return v27 + "); + } + #[test] fn test_optimize_string_append() { eval(r#" diff --git a/zjit/src/stats.rs b/zjit/src/stats.rs index df172997ce..1277db5b7e 100644 --- a/zjit/src/stats.rs +++ b/zjit/src/stats.rs @@ -142,6 +142,7 @@ make_counters! 
{ exit_fixnum_add_overflow, exit_fixnum_sub_overflow, exit_fixnum_mult_overflow, + exit_fixnum_lshift_overflow, exit_fixnum_mod_by_zero, exit_box_fixnum_overflow, exit_guard_type_failure, @@ -423,6 +424,7 @@ pub fn side_exit_counter(reason: crate::hir::SideExitReason) -> Counter { FixnumAddOverflow => exit_fixnum_add_overflow, FixnumSubOverflow => exit_fixnum_sub_overflow, FixnumMultOverflow => exit_fixnum_mult_overflow, + FixnumLShiftOverflow => exit_fixnum_lshift_overflow, FixnumModByZero => exit_fixnum_mod_by_zero, BoxFixnumOverflow => exit_box_fixnum_overflow, GuardType(_) => exit_guard_type_failure, -- cgit v1.2.3