diff options
| author | Max Bernstein <ruby@bernsteinbear.com> | 2025-11-24 22:44:53 -0500 |
|---|---|---|
| committer | Max Bernstein <tekknolagi@gmail.com> | 2025-12-01 15:19:26 -0800 |
| commit | 6db83a00a4272eb1089d67da83e1cd9d4e10227b (patch) | |
| tree | eadaee8a0ae5fb3ce337610c60ba1c097e10fdb3 | |
| parent | a25196395e7502e4d6faad0856c697690d8a202e (diff) | |
ZJIT: Specialize Integer#>>
Same as Integer#>>. Also add more strict type checks for both Integer#>>
and Integer#<<.
| -rw-r--r-- | zjit/src/codegen.rs | 15 | ||||
| -rw-r--r-- | zjit/src/cruby_methods.rs | 11 | ||||
| -rw-r--r-- | zjit/src/hir.rs | 22 | ||||
| -rw-r--r-- | zjit/src/hir/opt_tests.rs | 105 |
4 files changed, 152 insertions, 1 deletions
diff --git a/zjit/src/codegen.rs b/zjit/src/codegen.rs index 6fc8566469..df9a9299cf 100644 --- a/zjit/src/codegen.rs +++ b/zjit/src/codegen.rs @@ -409,6 +409,12 @@ fn gen_insn(cb: &mut CodeBlock, jit: &mut JITState, asm: &mut Assembler, functio let shift_amount = function.type_of(right).fixnum_value().unwrap() as u64; gen_fixnum_lshift(jit, asm, opnd!(left), shift_amount, &function.frame_state(state)) } + &Insn::FixnumRShift { left, right } => { + // We only create FixnumRShift when we know the shift amount statically and it's in [0, + // 63]. + let shift_amount = function.type_of(right).fixnum_value().unwrap() as u64; + gen_fixnum_rshift(asm, opnd!(left), shift_amount) + } &Insn::FixnumMod { left, right, state } => gen_fixnum_mod(jit, asm, opnd!(left), opnd!(right), &function.frame_state(state)), Insn::IsNil { val } => gen_isnil(asm, opnd!(val)), &Insn::IsMethodCfunc { val, cd, cfunc, state: _ } => gen_is_method_cfunc(jit, asm, opnd!(val), cd, cfunc), @@ -1754,6 +1760,15 @@ fn gen_fixnum_lshift(jit: &mut JITState, asm: &mut Assembler, left: lir::Opnd, s out_val } +/// Compile Fixnum >> Fixnum +fn gen_fixnum_rshift(asm: &mut Assembler, left: lir::Opnd, shift_amount: u64) -> lir::Opnd { + // Shift amount is known statically to be in the range [0, 63] + assert!(shift_amount < 64); + let result = asm.rshift(left, shift_amount.into()); + // Re-tag the output value + asm.or(result, 1.into()) +} + fn gen_fixnum_mod(jit: &mut JITState, asm: &mut Assembler, left: lir::Opnd, right: lir::Opnd, state: &FrameState) -> lir::Opnd { // Check for left % 0, which raises ZeroDivisionError asm.cmp(right, Opnd::from(VALUE::fixnum_from_usize(0))); diff --git a/zjit/src/cruby_methods.rs b/zjit/src/cruby_methods.rs index f86d383876..f84882101c 100644 --- a/zjit/src/cruby_methods.rs +++ b/zjit/src/cruby_methods.rs @@ -242,6 +242,7 @@ pub fn init() -> Annotations { annotate!(rb_cInteger, "<", inline_integer_lt); annotate!(rb_cInteger, "<=", inline_integer_le); annotate!(rb_cInteger, "<<", inline_integer_lshift); + annotate!(rb_cInteger, ">>", inline_integer_rshift); annotate!(rb_cInteger, "to_s", types::StringExact); annotate!(rb_cString, "to_s", inline_string_to_s, types::StringExact); let thread_singleton = unsafe { rb_singleton_class(rb_cThread) }; @@ -575,6 +576,16 @@ fn inline_integer_lshift(fun: &mut hir::Function, block: hir::BlockId, recv: hir try_inline_fixnum_op(fun, block, &|left, right| hir::Insn::FixnumLShift { left, right, state }, BOP_LTLT, recv, other, state) } +fn inline_integer_rshift(fun: &mut hir::Function, block: hir::BlockId, recv: hir::InsnId, args: &[hir::InsnId], state: hir::InsnId) -> Option<hir::InsnId> { + let &[other] = args else { return None; }; + // Only convert to FixnumLShift if we know the shift amount is known at compile-time and could + // plausibly create a fixnum. + let Some(other_value) = fun.type_of(other).fixnum_value() else { return None; }; + // TODO(max): If other_value > 63, rewrite to constant zero. + if other_value < 0 || other_value > 63 { return None; } + try_inline_fixnum_op(fun, block, &|left, right| hir::Insn::FixnumRShift { left, right }, BOP_GTGT, recv, other, state) +} + fn inline_basic_object_eq(fun: &mut hir::Function, block: hir::BlockId, recv: hir::InsnId, args: &[hir::InsnId], _state: hir::InsnId) -> Option<hir::InsnId> { let &[other] = args else { return None; }; let c_result = fun.push_insn(block, hir::Insn::IsBitEqual { left: recv, right: other }); diff --git a/zjit/src/hir.rs b/zjit/src/hir.rs index 6604c52a82..a69628b869 100644 --- a/zjit/src/hir.rs +++ b/zjit/src/hir.rs @@ -891,6 +891,7 @@ pub enum Insn { FixnumOr { left: InsnId, right: InsnId }, FixnumXor { left: InsnId, right: InsnId }, FixnumLShift { left: InsnId, right: InsnId, state: InsnId }, + FixnumRShift { left: InsnId, right: InsnId }, // Distinct from `SendWithoutBlock` with `mid:to_s` because does not have a patch point for String to_s being redefined ObjToString { val: InsnId, cd: *const rb_call_data, state: InsnId }, @@ -989,6 +990,7 @@ impl Insn { Insn::FixnumOr { .. } => false, Insn::FixnumXor { .. } => false, Insn::FixnumLShift { .. } => false, + Insn::FixnumRShift { .. } => false, Insn::GetLocal { .. } => false, Insn::IsNil { .. } => false, Insn::LoadPC => false, @@ -1233,6 +1235,7 @@ impl<'a> std::fmt::Display for InsnPrinter<'a> { Insn::FixnumOr { left, right, .. } => { write!(f, "FixnumOr {left}, {right}") }, Insn::FixnumXor { left, right, .. } => { write!(f, "FixnumXor {left}, {right}") }, Insn::FixnumLShift { left, right, .. } => { write!(f, "FixnumLShift {left}, {right}") }, + Insn::FixnumRShift { left, right, .. } => { write!(f, "FixnumRShift {left}, {right}") }, Insn::GuardType { val, guard_type, .. } => { write!(f, "GuardType {val}, {}", guard_type.print(self.ptr_map)) }, Insn::GuardTypeNot { val, guard_type, .. } => { write!(f, "GuardTypeNot {val}, {}", guard_type.print(self.ptr_map)) }, Insn::GuardBitEquals { val, expected, .. } => { write!(f, "GuardBitEquals {val}, {}", expected.print(self.ptr_map)) }, @@ -1854,6 +1857,7 @@ impl Function { &FixnumOr { left, right } => FixnumOr { left: find!(left), right: find!(right) }, &FixnumXor { left, right } => FixnumXor { left: find!(left), right: find!(right) }, &FixnumLShift { left, right, state } => FixnumLShift { left: find!(left), right: find!(right), state }, + &FixnumRShift { left, right } => FixnumRShift { left: find!(left), right: find!(right) }, &ObjToString { val, cd, state } => ObjToString { val: find!(val), cd, @@ -2076,6 +2080,7 @@ impl Function { Insn::FixnumOr { .. } => types::Fixnum, Insn::FixnumXor { .. } => types::Fixnum, Insn::FixnumLShift { .. } => types::Fixnum, + Insn::FixnumRShift { .. } => types::Fixnum, Insn::PutSpecialObject { .. } => types::BasicObject, Insn::SendWithoutBlock { .. } => types::BasicObject, Insn::SendWithoutBlockDirect { .. } => types::BasicObject, @@ -3639,6 +3644,7 @@ impl Function { | &Insn::FixnumAnd { left, right } | &Insn::FixnumOr { left, right } | &Insn::FixnumXor { left, right } + | &Insn::FixnumRShift { left, right } | &Insn::IsBitEqual { left, right } | &Insn::IsBitNotEqual { left, right } => { @@ -4403,12 +4409,26 @@ impl Function { | Insn::FixnumAnd { left, right } | Insn::FixnumOr { left, right } | Insn::FixnumXor { left, right } - | Insn::FixnumLShift { left, right, .. } | Insn::NewRangeFixnum { low: left, high: right, .. } => { self.assert_subtype(insn_id, left, types::Fixnum)?; self.assert_subtype(insn_id, right, types::Fixnum) } + Insn::FixnumLShift { left, right, .. } + | Insn::FixnumRShift { left, right, .. } => { + self.assert_subtype(insn_id, left, types::Fixnum)?; + self.assert_subtype(insn_id, right, types::Fixnum)?; + let Some(obj) = self.type_of(right).fixnum_value() else { + return Err(ValidationError::MismatchedOperandType(insn_id, right, "<a compile-time constant>".into(), "<unknown>".into())); + }; + if obj < 0 { + return Err(ValidationError::MismatchedOperandType(insn_id, right, "<positive>".into(), format!("{obj}"))); + } + if obj > 63 { + return Err(ValidationError::MismatchedOperandType(insn_id, right, "<less than 64>".into(), format!("{obj}"))); + } + Ok(()) + } Insn::GuardBitEquals { val, expected, .. } => { match expected { Const::Value(_) => self.assert_subtype(insn_id, val, types::RubyValue), diff --git a/zjit/src/hir/opt_tests.rs b/zjit/src/hir/opt_tests.rs index 60f5814973..610986d62f 100644 --- a/zjit/src/hir/opt_tests.rs +++ b/zjit/src/hir/opt_tests.rs @@ -6826,6 +6826,111 @@ mod hir_opt_tests { } #[test] + fn test_inline_integer_gtgt_with_known_fixnum() { + eval(" + def test(x) = x >> 5 + test(4) + "); + assert_snapshot!(hir_string("test"), @r" + fn test@<compiled>:2: + bb0(): + EntryPoint interpreter + v1:BasicObject = LoadSelf + v2:BasicObject = GetLocal l0, SP@4 + Jump bb2(v1, v2) + bb1(v5:BasicObject, v6:BasicObject): + EntryPoint JIT(0) + Jump bb2(v5, v6) + bb2(v8:BasicObject, v9:BasicObject): + v14:Fixnum[5] = Const Value(5) + PatchPoint MethodRedefined(Integer@0x1000, >>@0x1008, cme:0x1010) + v23:Fixnum = GuardType v9, Fixnum + v24:Fixnum = FixnumRShift v23, v14 + IncrCounter inline_cfunc_optimized_send_count + CheckInterrupts + Return v24 + "); + } + + #[test] + fn test_dont_inline_integer_gtgt_with_negative() { + eval(" + def test(x) = x >> -5 + test(4) + "); + assert_snapshot!(hir_string("test"), @r" + fn test@<compiled>:2: + bb0(): + EntryPoint interpreter + v1:BasicObject = LoadSelf + v2:BasicObject = GetLocal l0, SP@4 + Jump bb2(v1, v2) + bb1(v5:BasicObject, v6:BasicObject): + EntryPoint JIT(0) + Jump bb2(v5, v6) + bb2(v8:BasicObject, v9:BasicObject): + v14:Fixnum[-5] = Const Value(-5) + PatchPoint MethodRedefined(Integer@0x1000, >>@0x1008, cme:0x1010) + v23:Fixnum = GuardType v9, Fixnum + v24:BasicObject = CCallWithFrame v23, :Integer#>>@0x1038, v14 + CheckInterrupts + Return v24 + "); + } + + #[test] + fn test_dont_inline_integer_gtgt_with_out_of_range() { + eval(" + def test(x) = x >> 64 + test(4) + "); + assert_snapshot!(hir_string("test"), @r" + fn test@<compiled>:2: + bb0(): + EntryPoint interpreter + v1:BasicObject = LoadSelf + v2:BasicObject = GetLocal l0, SP@4 + Jump bb2(v1, v2) + bb1(v5:BasicObject, v6:BasicObject): + EntryPoint JIT(0) + Jump bb2(v5, v6) + bb2(v8:BasicObject, v9:BasicObject): + v14:Fixnum[64] = Const Value(64) + PatchPoint MethodRedefined(Integer@0x1000, >>@0x1008, cme:0x1010) + v23:Fixnum = GuardType v9, Fixnum + v24:BasicObject = CCallWithFrame v23, :Integer#>>@0x1038, v14 + CheckInterrupts + Return v24 + "); + } + + #[test] + fn test_dont_inline_integer_gtgt_with_unknown_fixnum() { + eval(" + def test(x, y) = x >> y + test(4, 5) + "); + assert_snapshot!(hir_string("test"), @r" + fn test@<compiled>:2: + bb0(): + EntryPoint interpreter + v1:BasicObject = LoadSelf + v2:BasicObject = GetLocal l0, SP@5 + v3:BasicObject = GetLocal l0, SP@4 + Jump bb2(v1, v2, v3) + bb1(v6:BasicObject, v7:BasicObject, v8:BasicObject): + EntryPoint JIT(0) + Jump bb2(v6, v7, v8) + bb2(v10:BasicObject, v11:BasicObject, v12:BasicObject): + PatchPoint MethodRedefined(Integer@0x1000, >>@0x1008, cme:0x1010) + v25:Fixnum = GuardType v11, Fixnum + v26:BasicObject = CCallWithFrame v25, :Integer#>>@0x1038, v12 + CheckInterrupts + Return v26 + "); + } + + #[test] fn test_optimize_string_append() { eval(r#" def test(x, y) = x << y |
