summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMax Bernstein <ruby@bernsteinbear.com>2025-11-24 22:44:53 -0500
committerMax Bernstein <tekknolagi@gmail.com>2025-12-01 15:19:26 -0800
commit6db83a00a4272eb1089d67da83e1cd9d4e10227b (patch)
treeeadaee8a0ae5fb3ce337610c60ba1c097e10fdb3
parenta25196395e7502e4d6faad0856c697690d8a202e (diff)
ZJIT: Specialize Integer#>>
Same as Integer#>>. Also add more strict type checks for both Integer#>> and Integer#<<.
-rw-r--r--zjit/src/codegen.rs15
-rw-r--r--zjit/src/cruby_methods.rs11
-rw-r--r--zjit/src/hir.rs22
-rw-r--r--zjit/src/hir/opt_tests.rs105
4 files changed, 152 insertions, 1 deletions
diff --git a/zjit/src/codegen.rs b/zjit/src/codegen.rs
index 6fc8566469..df9a9299cf 100644
--- a/zjit/src/codegen.rs
+++ b/zjit/src/codegen.rs
@@ -409,6 +409,12 @@ fn gen_insn(cb: &mut CodeBlock, jit: &mut JITState, asm: &mut Assembler, functio
let shift_amount = function.type_of(right).fixnum_value().unwrap() as u64;
gen_fixnum_lshift(jit, asm, opnd!(left), shift_amount, &function.frame_state(state))
}
+ &Insn::FixnumRShift { left, right } => {
+ // We only create FixnumRShift when we know the shift amount statically and it's in [0,
+ // 63].
+ let shift_amount = function.type_of(right).fixnum_value().unwrap() as u64;
+ gen_fixnum_rshift(asm, opnd!(left), shift_amount)
+ }
&Insn::FixnumMod { left, right, state } => gen_fixnum_mod(jit, asm, opnd!(left), opnd!(right), &function.frame_state(state)),
Insn::IsNil { val } => gen_isnil(asm, opnd!(val)),
&Insn::IsMethodCfunc { val, cd, cfunc, state: _ } => gen_is_method_cfunc(jit, asm, opnd!(val), cd, cfunc),
@@ -1754,6 +1760,15 @@ fn gen_fixnum_lshift(jit: &mut JITState, asm: &mut Assembler, left: lir::Opnd, s
out_val
}
+/// Compile Fixnum >> Fixnum
+fn gen_fixnum_rshift(asm: &mut Assembler, left: lir::Opnd, shift_amount: u64) -> lir::Opnd {
+ // Shift amount is known statically to be in the range [0, 63]
+ assert!(shift_amount < 64);
+ let result = asm.rshift(left, shift_amount.into());
+ // Re-tag the output value
+ asm.or(result, 1.into())
+}
+
fn gen_fixnum_mod(jit: &mut JITState, asm: &mut Assembler, left: lir::Opnd, right: lir::Opnd, state: &FrameState) -> lir::Opnd {
// Check for left % 0, which raises ZeroDivisionError
asm.cmp(right, Opnd::from(VALUE::fixnum_from_usize(0)));
diff --git a/zjit/src/cruby_methods.rs b/zjit/src/cruby_methods.rs
index f86d383876..f84882101c 100644
--- a/zjit/src/cruby_methods.rs
+++ b/zjit/src/cruby_methods.rs
@@ -242,6 +242,7 @@ pub fn init() -> Annotations {
annotate!(rb_cInteger, "<", inline_integer_lt);
annotate!(rb_cInteger, "<=", inline_integer_le);
annotate!(rb_cInteger, "<<", inline_integer_lshift);
+ annotate!(rb_cInteger, ">>", inline_integer_rshift);
annotate!(rb_cInteger, "to_s", types::StringExact);
annotate!(rb_cString, "to_s", inline_string_to_s, types::StringExact);
let thread_singleton = unsafe { rb_singleton_class(rb_cThread) };
@@ -575,6 +576,16 @@ fn inline_integer_lshift(fun: &mut hir::Function, block: hir::BlockId, recv: hir
try_inline_fixnum_op(fun, block, &|left, right| hir::Insn::FixnumLShift { left, right, state }, BOP_LTLT, recv, other, state)
}
+fn inline_integer_rshift(fun: &mut hir::Function, block: hir::BlockId, recv: hir::InsnId, args: &[hir::InsnId], state: hir::InsnId) -> Option<hir::InsnId> {
+ let &[other] = args else { return None; };
+ // Only convert to FixnumLShift if we know the shift amount is known at compile-time and could
+ // plausibly create a fixnum.
+ let Some(other_value) = fun.type_of(other).fixnum_value() else { return None; };
+ // TODO(max): If other_value > 63, rewrite to constant zero.
+ if other_value < 0 || other_value > 63 { return None; }
+ try_inline_fixnum_op(fun, block, &|left, right| hir::Insn::FixnumRShift { left, right }, BOP_GTGT, recv, other, state)
+}
+
fn inline_basic_object_eq(fun: &mut hir::Function, block: hir::BlockId, recv: hir::InsnId, args: &[hir::InsnId], _state: hir::InsnId) -> Option<hir::InsnId> {
let &[other] = args else { return None; };
let c_result = fun.push_insn(block, hir::Insn::IsBitEqual { left: recv, right: other });
diff --git a/zjit/src/hir.rs b/zjit/src/hir.rs
index 6604c52a82..a69628b869 100644
--- a/zjit/src/hir.rs
+++ b/zjit/src/hir.rs
@@ -891,6 +891,7 @@ pub enum Insn {
FixnumOr { left: InsnId, right: InsnId },
FixnumXor { left: InsnId, right: InsnId },
FixnumLShift { left: InsnId, right: InsnId, state: InsnId },
+ FixnumRShift { left: InsnId, right: InsnId },
// Distinct from `SendWithoutBlock` with `mid:to_s` because does not have a patch point for String to_s being redefined
ObjToString { val: InsnId, cd: *const rb_call_data, state: InsnId },
@@ -989,6 +990,7 @@ impl Insn {
Insn::FixnumOr { .. } => false,
Insn::FixnumXor { .. } => false,
Insn::FixnumLShift { .. } => false,
+ Insn::FixnumRShift { .. } => false,
Insn::GetLocal { .. } => false,
Insn::IsNil { .. } => false,
Insn::LoadPC => false,
@@ -1233,6 +1235,7 @@ impl<'a> std::fmt::Display for InsnPrinter<'a> {
Insn::FixnumOr { left, right, .. } => { write!(f, "FixnumOr {left}, {right}") },
Insn::FixnumXor { left, right, .. } => { write!(f, "FixnumXor {left}, {right}") },
Insn::FixnumLShift { left, right, .. } => { write!(f, "FixnumLShift {left}, {right}") },
+ Insn::FixnumRShift { left, right, .. } => { write!(f, "FixnumRShift {left}, {right}") },
Insn::GuardType { val, guard_type, .. } => { write!(f, "GuardType {val}, {}", guard_type.print(self.ptr_map)) },
Insn::GuardTypeNot { val, guard_type, .. } => { write!(f, "GuardTypeNot {val}, {}", guard_type.print(self.ptr_map)) },
Insn::GuardBitEquals { val, expected, .. } => { write!(f, "GuardBitEquals {val}, {}", expected.print(self.ptr_map)) },
@@ -1854,6 +1857,7 @@ impl Function {
&FixnumOr { left, right } => FixnumOr { left: find!(left), right: find!(right) },
&FixnumXor { left, right } => FixnumXor { left: find!(left), right: find!(right) },
&FixnumLShift { left, right, state } => FixnumLShift { left: find!(left), right: find!(right), state },
+ &FixnumRShift { left, right } => FixnumRShift { left: find!(left), right: find!(right) },
&ObjToString { val, cd, state } => ObjToString {
val: find!(val),
cd,
@@ -2076,6 +2080,7 @@ impl Function {
Insn::FixnumOr { .. } => types::Fixnum,
Insn::FixnumXor { .. } => types::Fixnum,
Insn::FixnumLShift { .. } => types::Fixnum,
+ Insn::FixnumRShift { .. } => types::Fixnum,
Insn::PutSpecialObject { .. } => types::BasicObject,
Insn::SendWithoutBlock { .. } => types::BasicObject,
Insn::SendWithoutBlockDirect { .. } => types::BasicObject,
@@ -3639,6 +3644,7 @@ impl Function {
| &Insn::FixnumAnd { left, right }
| &Insn::FixnumOr { left, right }
| &Insn::FixnumXor { left, right }
+ | &Insn::FixnumRShift { left, right }
| &Insn::IsBitEqual { left, right }
| &Insn::IsBitNotEqual { left, right }
=> {
@@ -4403,12 +4409,26 @@ impl Function {
| Insn::FixnumAnd { left, right }
| Insn::FixnumOr { left, right }
| Insn::FixnumXor { left, right }
- | Insn::FixnumLShift { left, right, .. }
| Insn::NewRangeFixnum { low: left, high: right, .. }
=> {
self.assert_subtype(insn_id, left, types::Fixnum)?;
self.assert_subtype(insn_id, right, types::Fixnum)
}
+ Insn::FixnumLShift { left, right, .. }
+ | Insn::FixnumRShift { left, right, .. } => {
+ self.assert_subtype(insn_id, left, types::Fixnum)?;
+ self.assert_subtype(insn_id, right, types::Fixnum)?;
+ let Some(obj) = self.type_of(right).fixnum_value() else {
+ return Err(ValidationError::MismatchedOperandType(insn_id, right, "<a compile-time constant>".into(), "<unknown>".into()));
+ };
+ if obj < 0 {
+ return Err(ValidationError::MismatchedOperandType(insn_id, right, "<positive>".into(), format!("{obj}")));
+ }
+ if obj > 63 {
+ return Err(ValidationError::MismatchedOperandType(insn_id, right, "<less than 64>".into(), format!("{obj}")));
+ }
+ Ok(())
+ }
Insn::GuardBitEquals { val, expected, .. } => {
match expected {
Const::Value(_) => self.assert_subtype(insn_id, val, types::RubyValue),
diff --git a/zjit/src/hir/opt_tests.rs b/zjit/src/hir/opt_tests.rs
index 60f5814973..610986d62f 100644
--- a/zjit/src/hir/opt_tests.rs
+++ b/zjit/src/hir/opt_tests.rs
@@ -6826,6 +6826,111 @@ mod hir_opt_tests {
}
#[test]
+ fn test_inline_integer_gtgt_with_known_fixnum() {
+ eval("
+ def test(x) = x >> 5
+ test(4)
+ ");
+ assert_snapshot!(hir_string("test"), @r"
+ fn test@<compiled>:2:
+ bb0():
+ EntryPoint interpreter
+ v1:BasicObject = LoadSelf
+ v2:BasicObject = GetLocal l0, SP@4
+ Jump bb2(v1, v2)
+ bb1(v5:BasicObject, v6:BasicObject):
+ EntryPoint JIT(0)
+ Jump bb2(v5, v6)
+ bb2(v8:BasicObject, v9:BasicObject):
+ v14:Fixnum[5] = Const Value(5)
+ PatchPoint MethodRedefined(Integer@0x1000, >>@0x1008, cme:0x1010)
+ v23:Fixnum = GuardType v9, Fixnum
+ v24:Fixnum = FixnumRShift v23, v14
+ IncrCounter inline_cfunc_optimized_send_count
+ CheckInterrupts
+ Return v24
+ ");
+ }
+
+ #[test]
+ fn test_dont_inline_integer_gtgt_with_negative() {
+ eval("
+ def test(x) = x >> -5
+ test(4)
+ ");
+ assert_snapshot!(hir_string("test"), @r"
+ fn test@<compiled>:2:
+ bb0():
+ EntryPoint interpreter
+ v1:BasicObject = LoadSelf
+ v2:BasicObject = GetLocal l0, SP@4
+ Jump bb2(v1, v2)
+ bb1(v5:BasicObject, v6:BasicObject):
+ EntryPoint JIT(0)
+ Jump bb2(v5, v6)
+ bb2(v8:BasicObject, v9:BasicObject):
+ v14:Fixnum[-5] = Const Value(-5)
+ PatchPoint MethodRedefined(Integer@0x1000, >>@0x1008, cme:0x1010)
+ v23:Fixnum = GuardType v9, Fixnum
+ v24:BasicObject = CCallWithFrame v23, :Integer#>>@0x1038, v14
+ CheckInterrupts
+ Return v24
+ ");
+ }
+
+ #[test]
+ fn test_dont_inline_integer_gtgt_with_out_of_range() {
+ eval("
+ def test(x) = x >> 64
+ test(4)
+ ");
+ assert_snapshot!(hir_string("test"), @r"
+ fn test@<compiled>:2:
+ bb0():
+ EntryPoint interpreter
+ v1:BasicObject = LoadSelf
+ v2:BasicObject = GetLocal l0, SP@4
+ Jump bb2(v1, v2)
+ bb1(v5:BasicObject, v6:BasicObject):
+ EntryPoint JIT(0)
+ Jump bb2(v5, v6)
+ bb2(v8:BasicObject, v9:BasicObject):
+ v14:Fixnum[64] = Const Value(64)
+ PatchPoint MethodRedefined(Integer@0x1000, >>@0x1008, cme:0x1010)
+ v23:Fixnum = GuardType v9, Fixnum
+ v24:BasicObject = CCallWithFrame v23, :Integer#>>@0x1038, v14
+ CheckInterrupts
+ Return v24
+ ");
+ }
+
+ #[test]
+ fn test_dont_inline_integer_gtgt_with_unknown_fixnum() {
+ eval("
+ def test(x, y) = x >> y
+ test(4, 5)
+ ");
+ assert_snapshot!(hir_string("test"), @r"
+ fn test@<compiled>:2:
+ bb0():
+ EntryPoint interpreter
+ v1:BasicObject = LoadSelf
+ v2:BasicObject = GetLocal l0, SP@5
+ v3:BasicObject = GetLocal l0, SP@4
+ Jump bb2(v1, v2, v3)
+ bb1(v6:BasicObject, v7:BasicObject, v8:BasicObject):
+ EntryPoint JIT(0)
+ Jump bb2(v6, v7, v8)
+ bb2(v10:BasicObject, v11:BasicObject, v12:BasicObject):
+ PatchPoint MethodRedefined(Integer@0x1000, >>@0x1008, cme:0x1010)
+ v25:Fixnum = GuardType v11, Fixnum
+ v26:BasicObject = CCallWithFrame v25, :Integer#>>@0x1038, v12
+ CheckInterrupts
+ Return v26
+ ");
+ }
+
+ #[test]
fn test_optimize_string_append() {
eval(r#"
def test(x, y) = x << y