summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author Max Bernstein <rubybugs@bernsteinbear.com> 2025-11-21 08:57:26 -0800
committer GitHub <noreply@github.com> 2025-11-21 11:57:26 -0500
commit e0bb3fb1cda2238d0c98afcdec2fe282c29994aa (patch)
tree 577333a9b74bd100f6ee3e0e4de5cb0a335a269c
parent 8728406c418f1a200cda02a259ba164d185a8ebd (diff)
ZJIT: Inline Integer#<< for constant rhs (#15258)
This is good for protoboeuf and other binary parsing
-rw-r--r-- zjit/src/asm/arm64/mod.rs 2
-rw-r--r-- zjit/src/codegen.rs 20
-rw-r--r-- zjit/src/cruby_methods.rs 10
-rw-r--r-- zjit/src/hir.rs 10
-rw-r--r-- zjit/src/hir/opt_tests.rs 109
-rw-r--r-- zjit/src/stats.rs 2
6 files changed, 151 insertions, 2 deletions
diff --git a/zjit/src/asm/arm64/mod.rs b/zjit/src/asm/arm64/mod.rs
index a445911731..4094d101fb 100644
--- a/zjit/src/asm/arm64/mod.rs
+++ b/zjit/src/asm/arm64/mod.rs
@@ -649,7 +649,7 @@ pub fn lsl(cb: &mut CodeBlock, rd: A64Opnd, rn: A64Opnd, shift: A64Opnd) {
ShiftImm::lsl(rd.reg_no, rn.reg_no, uimm as u8, rd.num_bits).into()
},
- _ => panic!("Invalid operands combination to lsl instruction")
+ _ => panic!("Invalid operands combination {rd:?} {rn:?} {shift:?} to lsl instruction")
};
cb.write_bytes(&bytes);
diff --git a/zjit/src/codegen.rs b/zjit/src/codegen.rs
index b95d137222..9f74838c11 100644
--- a/zjit/src/codegen.rs
+++ b/zjit/src/codegen.rs
@@ -402,6 +402,12 @@ fn gen_insn(cb: &mut CodeBlock, jit: &mut JITState, asm: &mut Assembler, functio
Insn::FixnumAnd { left, right } => gen_fixnum_and(asm, opnd!(left), opnd!(right)),
Insn::FixnumOr { left, right } => gen_fixnum_or(asm, opnd!(left), opnd!(right)),
Insn::FixnumXor { left, right } => gen_fixnum_xor(asm, opnd!(left), opnd!(right)),
+ &Insn::FixnumLShift { left, right, state } => {
+ // We only create FixnumLShift when we know the shift amount statically and it's in [0,
+ // 63].
+ let shift_amount = function.type_of(right).fixnum_value().unwrap() as u64;
+ gen_fixnum_lshift(jit, asm, opnd!(left), shift_amount, &function.frame_state(state))
+ }
&Insn::FixnumMod { left, right, state } => gen_fixnum_mod(jit, asm, opnd!(left), opnd!(right), &function.frame_state(state)),
Insn::IsNil { val } => gen_isnil(asm, opnd!(val)),
&Insn::IsMethodCfunc { val, cd, cfunc, state: _ } => gen_is_method_cfunc(jit, asm, opnd!(val), cd, cfunc),
@@ -1700,6 +1706,20 @@ fn gen_fixnum_xor(asm: &mut Assembler, left: lir::Opnd, right: lir::Opnd) -> lir
asm.add(out_val, Opnd::UImm(1))
}
+/// Compile Fixnum << Fixnum for a shift amount known at compile time.
+/// Side-exits with FixnumLShiftOverflow when shifting back does not reproduce the untagged input.
+fn gen_fixnum_lshift(jit: &mut JITState, asm: &mut Assembler, left: lir::Opnd, shift_amount: u64, state: &FrameState) -> lir::Opnd {
+    // Shift amount is known statically to be in the range [0, 63]
+    assert!(shift_amount < 64);
+    let in_val = asm.sub(left, Opnd::UImm(1)); // Drop tag bit
+    let out_val = asm.lshift(in_val, shift_amount.into());
+    // Overflow check: undo the shift and compare against the untagged input
+    let unshifted = asm.rshift(out_val, shift_amount.into());
+    asm.cmp(in_val, unshifted);
+    asm.jne(side_exit(jit, state, FixnumLShiftOverflow));
+    asm.add(out_val, 1.into()) // Re-tag the output value
+}
+
fn gen_fixnum_mod(jit: &mut JITState, asm: &mut Assembler, left: lir::Opnd, right: lir::Opnd, state: &FrameState) -> lir::Opnd {
// Check for left % 0, which raises ZeroDivisionError
asm.cmp(right, Opnd::from(VALUE::fixnum_from_usize(0)));
diff --git a/zjit/src/cruby_methods.rs b/zjit/src/cruby_methods.rs
index d1f76e8da0..3999ef0a10 100644
--- a/zjit/src/cruby_methods.rs
+++ b/zjit/src/cruby_methods.rs
@@ -240,6 +240,7 @@ pub fn init() -> Annotations {
annotate!(rb_cInteger, ">=", inline_integer_ge);
annotate!(rb_cInteger, "<", inline_integer_lt);
annotate!(rb_cInteger, "<=", inline_integer_le);
+ annotate!(rb_cInteger, "<<", inline_integer_lshift);
annotate!(rb_cString, "to_s", inline_string_to_s, types::StringExact);
let thread_singleton = unsafe { rb_singleton_class(rb_cThread) };
annotate!(thread_singleton, "current", inline_thread_current, types::BasicObject, no_gc, leaf);
@@ -546,6 +547,15 @@ fn inline_integer_le(fun: &mut hir::Function, block: hir::BlockId, recv: hir::In
try_inline_fixnum_op(fun, block, &|left, right| hir::Insn::FixnumLe { left, right }, BOP_LE, recv, other, state)
}
+/// Try to inline Integer#<< as FixnumLShift; only for a compile-time constant shift amount in 0..=63.
+fn inline_integer_lshift(fun: &mut hir::Function, block: hir::BlockId, recv: hir::InsnId, args: &[hir::InsnId], state: hir::InsnId) -> Option<hir::InsnId> {
+    let &[other] = args else { return None; };
+    // Only inline when the shift amount is known at compile-time and could plausibly create a fixnum.
+    let other_value = fun.type_of(other).fixnum_value()?;
+    if !(0..=63).contains(&other_value) { return None; }
+    try_inline_fixnum_op(fun, block, &|left, right| hir::Insn::FixnumLShift { left, right, state }, BOP_LTLT, recv, other, state)
+}
+
fn inline_basic_object_eq(fun: &mut hir::Function, block: hir::BlockId, recv: hir::InsnId, args: &[hir::InsnId], _state: hir::InsnId) -> Option<hir::InsnId> {
let &[other] = args else { return None; };
let c_result = fun.push_insn(block, hir::Insn::IsBitEqual { left: recv, right: other });
diff --git a/zjit/src/hir.rs b/zjit/src/hir.rs
index 333d5e5bff..00014c5758 100644
--- a/zjit/src/hir.rs
+++ b/zjit/src/hir.rs
@@ -485,6 +485,7 @@ pub enum SideExitReason {
FixnumAddOverflow,
FixnumSubOverflow,
FixnumMultOverflow,
+ FixnumLShiftOverflow,
GuardType(Type),
GuardTypeNot(Type),
GuardShape(ShapeId),
@@ -868,7 +869,7 @@ pub enum Insn {
/// Non-local control flow. See the throw YARV instruction
Throw { throw_state: u32, val: InsnId, state: InsnId },
- /// Fixnum +, -, *, /, %, ==, !=, <, <=, >, >=, &, |, ^
+ /// Fixnum +, -, *, /, %, ==, !=, <, <=, >, >=, &, |, ^, <<
FixnumAdd { left: InsnId, right: InsnId, state: InsnId },
FixnumSub { left: InsnId, right: InsnId, state: InsnId },
FixnumMult { left: InsnId, right: InsnId, state: InsnId },
@@ -883,6 +884,7 @@ pub enum Insn {
FixnumAnd { left: InsnId, right: InsnId },
FixnumOr { left: InsnId, right: InsnId },
FixnumXor { left: InsnId, right: InsnId },
+ FixnumLShift { left: InsnId, right: InsnId, state: InsnId },
// Distinct from `SendWithoutBlock` with `mid:to_s` because does not have a patch point for String to_s being redefined
ObjToString { val: InsnId, cd: *const rb_call_data, state: InsnId },
@@ -979,6 +981,7 @@ impl Insn {
Insn::FixnumAnd { .. } => false,
Insn::FixnumOr { .. } => false,
Insn::FixnumXor { .. } => false,
+ Insn::FixnumLShift { .. } => false,
Insn::GetLocal { .. } => false,
Insn::IsNil { .. } => false,
Insn::LoadPC => false,
@@ -1218,6 +1221,7 @@ impl<'a> std::fmt::Display for InsnPrinter<'a> {
Insn::FixnumAnd { left, right, .. } => { write!(f, "FixnumAnd {left}, {right}") },
Insn::FixnumOr { left, right, .. } => { write!(f, "FixnumOr {left}, {right}") },
Insn::FixnumXor { left, right, .. } => { write!(f, "FixnumXor {left}, {right}") },
+ Insn::FixnumLShift { left, right, .. } => { write!(f, "FixnumLShift {left}, {right}") },
Insn::GuardType { val, guard_type, .. } => { write!(f, "GuardType {val}, {}", guard_type.print(self.ptr_map)) },
Insn::GuardTypeNot { val, guard_type, .. } => { write!(f, "GuardTypeNot {val}, {}", guard_type.print(self.ptr_map)) },
Insn::GuardBitEquals { val, expected, .. } => { write!(f, "GuardBitEquals {val}, {}", expected.print(self.ptr_map)) },
@@ -1836,6 +1840,7 @@ impl Function {
&FixnumAnd { left, right } => FixnumAnd { left: find!(left), right: find!(right) },
&FixnumOr { left, right } => FixnumOr { left: find!(left), right: find!(right) },
&FixnumXor { left, right } => FixnumXor { left: find!(left), right: find!(right) },
+ &FixnumLShift { left, right, state } => FixnumLShift { left: find!(left), right: find!(right), state },
&ObjToString { val, cd, state } => ObjToString {
val: find!(val),
cd,
@@ -2054,6 +2059,7 @@ impl Function {
Insn::FixnumAnd { .. } => types::Fixnum,
Insn::FixnumOr { .. } => types::Fixnum,
Insn::FixnumXor { .. } => types::Fixnum,
+ Insn::FixnumLShift { .. } => types::Fixnum,
Insn::PutSpecialObject { .. } => types::BasicObject,
Insn::SendWithoutBlock { .. } => types::BasicObject,
Insn::SendWithoutBlockDirect { .. } => types::BasicObject,
@@ -3506,6 +3512,7 @@ impl Function {
| &Insn::FixnumDiv { left, right, state }
| &Insn::FixnumMod { left, right, state }
| &Insn::ArrayExtend { left, right, state }
+ | &Insn::FixnumLShift { left, right, state }
=> {
worklist.push_back(left);
worklist.push_back(right);
@@ -4271,6 +4278,7 @@ impl Function {
| Insn::FixnumAnd { left, right }
| Insn::FixnumOr { left, right }
| Insn::FixnumXor { left, right }
+ | Insn::FixnumLShift { left, right, .. }
| Insn::NewRangeFixnum { low: left, high: right, .. }
=> {
self.assert_subtype(insn_id, left, types::Fixnum)?;
diff --git a/zjit/src/hir/opt_tests.rs b/zjit/src/hir/opt_tests.rs
index 82f54f611a..19c0ce66e3 100644
--- a/zjit/src/hir/opt_tests.rs
+++ b/zjit/src/hir/opt_tests.rs
@@ -6467,6 +6467,115 @@ mod hir_opt_tests {
}
#[test]
+ fn test_inline_integer_ltlt_with_known_fixnum() { // A constant in-range shift amount (5): snapshot shows FixnumLShift replacing the call
+ eval("
+ def test(x) = x << 5
+ test(4)
+ ");
+ assert_contains_opcode("test", YARVINSN_opt_ltlt);
+ assert_snapshot!(hir_string("test"), @r"
+ fn test@<compiled>:2:
+ bb0():
+ EntryPoint interpreter
+ v1:BasicObject = LoadSelf
+ v2:BasicObject = GetLocal l0, SP@4
+ Jump bb2(v1, v2)
+ bb1(v5:BasicObject, v6:BasicObject):
+ EntryPoint JIT(0)
+ Jump bb2(v5, v6)
+ bb2(v8:BasicObject, v9:BasicObject):
+ v14:Fixnum[5] = Const Value(5)
+ PatchPoint MethodRedefined(Integer@0x1000, <<@0x1008, cme:0x1010)
+ v24:Fixnum = GuardType v9, Fixnum
+ v25:Fixnum = FixnumLShift v24, v14
+ IncrCounter inline_cfunc_optimized_send_count
+ CheckInterrupts
+ Return v25
+ ");
+ }
+
+ #[test]
+ fn test_dont_inline_integer_ltlt_with_negative() { // A negative constant shift amount (-5): snapshot keeps CCallWithFrame Integer#<<
+ eval("
+ def test(x) = x << -5
+ test(4)
+ ");
+ assert_contains_opcode("test", YARVINSN_opt_ltlt);
+ assert_snapshot!(hir_string("test"), @r"
+ fn test@<compiled>:2:
+ bb0():
+ EntryPoint interpreter
+ v1:BasicObject = LoadSelf
+ v2:BasicObject = GetLocal l0, SP@4
+ Jump bb2(v1, v2)
+ bb1(v5:BasicObject, v6:BasicObject):
+ EntryPoint JIT(0)
+ Jump bb2(v5, v6)
+ bb2(v8:BasicObject, v9:BasicObject):
+ v14:Fixnum[-5] = Const Value(-5)
+ PatchPoint MethodRedefined(Integer@0x1000, <<@0x1008, cme:0x1010)
+ v24:Fixnum = GuardType v9, Fixnum
+ v25:BasicObject = CCallWithFrame Integer#<<@0x1038, v24, v14
+ CheckInterrupts
+ Return v25
+ ");
+ }
+
+ #[test]
+ fn test_dont_inline_integer_ltlt_with_out_of_range() { // A constant shift amount of 64 (outside 0..=63): snapshot keeps CCallWithFrame Integer#<<
+ eval("
+ def test(x) = x << 64
+ test(4)
+ ");
+ assert_contains_opcode("test", YARVINSN_opt_ltlt);
+ assert_snapshot!(hir_string("test"), @r"
+ fn test@<compiled>:2:
+ bb0():
+ EntryPoint interpreter
+ v1:BasicObject = LoadSelf
+ v2:BasicObject = GetLocal l0, SP@4
+ Jump bb2(v1, v2)
+ bb1(v5:BasicObject, v6:BasicObject):
+ EntryPoint JIT(0)
+ Jump bb2(v5, v6)
+ bb2(v8:BasicObject, v9:BasicObject):
+ v14:Fixnum[64] = Const Value(64)
+ PatchPoint MethodRedefined(Integer@0x1000, <<@0x1008, cme:0x1010)
+ v24:Fixnum = GuardType v9, Fixnum
+ v25:BasicObject = CCallWithFrame Integer#<<@0x1038, v24, v14
+ CheckInterrupts
+ Return v25
+ ");
+ }
+
+ #[test]
+ fn test_dont_inline_integer_ltlt_with_unknown_fixnum() { // A shift amount that is not a compile-time constant: snapshot keeps CCallWithFrame Integer#<<
+ eval("
+ def test(x, y) = x << y
+ test(4, 5)
+ ");
+ assert_contains_opcode("test", YARVINSN_opt_ltlt);
+ assert_snapshot!(hir_string("test"), @r"
+ fn test@<compiled>:2:
+ bb0():
+ EntryPoint interpreter
+ v1:BasicObject = LoadSelf
+ v2:BasicObject = GetLocal l0, SP@5
+ v3:BasicObject = GetLocal l0, SP@4
+ Jump bb2(v1, v2, v3)
+ bb1(v6:BasicObject, v7:BasicObject, v8:BasicObject):
+ EntryPoint JIT(0)
+ Jump bb2(v6, v7, v8)
+ bb2(v10:BasicObject, v11:BasicObject, v12:BasicObject):
+ PatchPoint MethodRedefined(Integer@0x1000, <<@0x1008, cme:0x1010)
+ v26:Fixnum = GuardType v11, Fixnum
+ v27:BasicObject = CCallWithFrame Integer#<<@0x1038, v26, v12
+ CheckInterrupts
+ Return v27
+ ");
+ }
+
+ #[test]
fn test_optimize_string_append() {
eval(r#"
def test(x, y) = x << y
diff --git a/zjit/src/stats.rs b/zjit/src/stats.rs
index df172997ce..1277db5b7e 100644
--- a/zjit/src/stats.rs
+++ b/zjit/src/stats.rs
@@ -142,6 +142,7 @@ make_counters! {
exit_fixnum_add_overflow,
exit_fixnum_sub_overflow,
exit_fixnum_mult_overflow,
+ exit_fixnum_lshift_overflow,
exit_fixnum_mod_by_zero,
exit_box_fixnum_overflow,
exit_guard_type_failure,
@@ -423,6 +424,7 @@ pub fn side_exit_counter(reason: crate::hir::SideExitReason) -> Counter {
FixnumAddOverflow => exit_fixnum_add_overflow,
FixnumSubOverflow => exit_fixnum_sub_overflow,
FixnumMultOverflow => exit_fixnum_mult_overflow,
+ FixnumLShiftOverflow => exit_fixnum_lshift_overflow,
FixnumModByZero => exit_fixnum_mod_by_zero,
BoxFixnumOverflow => exit_box_fixnum_overflow,
GuardType(_) => exit_guard_type_failure,