summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorStan Lo <stan.lo@shopify.com>2025-07-09 21:20:39 +0100
committerMax Bernstein <tekknolagi@gmail.com>2025-07-09 17:50:41 -0400
commite2a81c738c453d072bdeae1e604a5a95c3376a9f (patch)
treeb90242ca79cf0e562ffc52a4768875b31eb78c8f
parent10b582dab64509ed8de949b02b1c766f88f04621 (diff)
ZJIT: Optimize `opt_and` and `opt_or` instructions for Fixnum
-rw-r--r--test/ruby/test_zjit.rb36
-rw-r--r--zjit/src/codegen.rs12
-rw-r--r--zjit/src/hir.rs54
3 files changed, 101 insertions, 1 deletions
diff --git a/test/ruby/test_zjit.rb b/test/ruby/test_zjit.rb
index 38baf10adb..26f183a2f0 100644
--- a/test/ruby/test_zjit.rb
+++ b/test/ruby/test_zjit.rb
@@ -330,6 +330,42 @@ class TestZJIT < Test::Unit::TestCase
RUBY
end
+ def test_fixnum_and
+ assert_compiles '1', %q{
+ def test(a, b) = a & b
+ test(2, 2)
+ test(2, 2)
+ test(5, 3)
+ }, call_threshold: 2, insns: [:opt_and]
+ end
+
+ def test_fixnum_and_fallthrough
+ assert_compiles 'false', %q{
+ def test(a, b) = a & b
+ test(2, 2)
+ test(2, 2)
+ test(true, false)
+ }, call_threshold: 2, insns: [:opt_and]
+ end
+
+ def test_fixnum_or
+ assert_compiles '3', %q{
+ def test(a, b) = a | b
+ test(5, 3)
+ test(5, 3)
+ test(1, 2)
+ }, call_threshold: 2, insns: [:opt_or]
+ end
+
+ def test_fixnum_or_fallthrough
+ assert_compiles 'true', %q{
+ def test(a, b) = a | b
+ test(2, 2)
+ test(2, 2)
+ test(true, false)
+ }, call_threshold: 2, insns: [:opt_or]
+ end
+
def test_opt_not
assert_compiles('[true, true, false]', <<~RUBY, insns: [:opt_not])
def test(obj) = !obj
diff --git a/zjit/src/codegen.rs b/zjit/src/codegen.rs
index 73ca1de74a..3432374ccb 100644
--- a/zjit/src/codegen.rs
+++ b/zjit/src/codegen.rs
@@ -291,6 +291,8 @@ fn gen_insn(cb: &mut CodeBlock, jit: &mut JITState, asm: &mut Assembler, functio
Insn::FixnumLe { left, right } => gen_fixnum_le(asm, opnd!(left), opnd!(right))?,
Insn::FixnumGt { left, right } => gen_fixnum_gt(asm, opnd!(left), opnd!(right))?,
Insn::FixnumGe { left, right } => gen_fixnum_ge(asm, opnd!(left), opnd!(right))?,
+ Insn::FixnumAnd { left, right } => gen_fixnum_and(asm, opnd!(left), opnd!(right))?,
+ Insn::FixnumOr { left, right } => gen_fixnum_or(asm, opnd!(left), opnd!(right))?,
Insn::IsNil { val } => gen_isnil(asm, opnd!(val))?,
Insn::Test { val } => gen_test(asm, opnd!(val))?,
Insn::GuardType { val, guard_type, state } => gen_guard_type(jit, asm, opnd!(val), *guard_type, &function.frame_state(*state))?,
@@ -939,6 +941,16 @@ fn gen_fixnum_ge(asm: &mut Assembler, left: lir::Opnd, right: lir::Opnd) -> Opti
Some(asm.csel_ge(Qtrue.into(), Qfalse.into()))
}
+/// Compile Fixnum & Fixnum
+fn gen_fixnum_and(asm: &mut Assembler, left: lir::Opnd, right: lir::Opnd) -> Option<lir::Opnd> {
+ Some(asm.and(left, right))
+}
+
+/// Compile Fixnum | Fixnum
+fn gen_fixnum_or(asm: &mut Assembler, left: lir::Opnd, right: lir::Opnd) -> Option<lir::Opnd> {
+ Some(asm.or(left, right))
+}
+
// Compile val == nil
fn gen_isnil(asm: &mut Assembler, val: lir::Opnd) -> Option<lir::Opnd> {
asm.cmp(val, Qnil.into());
diff --git a/zjit/src/hir.rs b/zjit/src/hir.rs
index dab3b6698d..c12ddfda57 100644
--- a/zjit/src/hir.rs
+++ b/zjit/src/hir.rs
@@ -523,7 +523,7 @@ pub enum Insn {
/// Non-local control flow. See the throw YARV instruction
Throw { throw_state: u32, val: InsnId },
- /// Fixnum +, -, *, /, %, ==, !=, <, <=, >, >=
+ /// Fixnum +, -, *, /, %, ==, !=, <, <=, >, >=, &, |
FixnumAdd { left: InsnId, right: InsnId, state: InsnId },
FixnumSub { left: InsnId, right: InsnId, state: InsnId },
FixnumMult { left: InsnId, right: InsnId, state: InsnId },
@@ -535,6 +535,8 @@ pub enum Insn {
FixnumLe { left: InsnId, right: InsnId },
FixnumGt { left: InsnId, right: InsnId },
FixnumGe { left: InsnId, right: InsnId },
+ FixnumAnd { left: InsnId, right: InsnId },
+ FixnumOr { left: InsnId, right: InsnId },
// Distinct from `SendWithoutBlock` with `mid:to_s` because does not have a patch point for String to_s being redefined
ObjToString { val: InsnId, call_info: CallInfo, cd: *const rb_call_data, state: InsnId },
@@ -604,6 +606,8 @@ impl Insn {
Insn::FixnumLe { .. } => false,
Insn::FixnumGt { .. } => false,
Insn::FixnumGe { .. } => false,
+ Insn::FixnumAnd { .. } => false,
+ Insn::FixnumOr { .. } => false,
Insn::GetLocal { .. } => false,
Insn::IsNil { .. } => false,
Insn::CCall { elidable, .. } => !elidable,
@@ -705,6 +709,8 @@ impl<'a> std::fmt::Display for InsnPrinter<'a> {
Insn::FixnumLe { left, right, .. } => { write!(f, "FixnumLe {left}, {right}") },
Insn::FixnumGt { left, right, .. } => { write!(f, "FixnumGt {left}, {right}") },
Insn::FixnumGe { left, right, .. } => { write!(f, "FixnumGe {left}, {right}") },
+ Insn::FixnumAnd { left, right, .. } => { write!(f, "FixnumAnd {left}, {right}") },
+ Insn::FixnumOr { left, right, .. } => { write!(f, "FixnumOr {left}, {right}") },
Insn::GuardType { val, guard_type, .. } => { write!(f, "GuardType {val}, {}", guard_type.print(self.ptr_map)) },
Insn::GuardBitEquals { val, expected, .. } => { write!(f, "GuardBitEquals {val}, {}", expected.print(self.ptr_map)) },
Insn::PatchPoint(invariant) => { write!(f, "PatchPoint {}", invariant.print(self.ptr_map)) },
@@ -1098,6 +1104,8 @@ impl Function {
FixnumGe { left, right } => FixnumGe { left: find!(*left), right: find!(*right) },
FixnumLt { left, right } => FixnumLt { left: find!(*left), right: find!(*right) },
FixnumLe { left, right } => FixnumLe { left: find!(*left), right: find!(*right) },
+ FixnumAnd { left, right } => FixnumAnd { left: find!(*left), right: find!(*right) },
+ FixnumOr { left, right } => FixnumOr { left: find!(*left), right: find!(*right) },
ObjToString { val, call_info, cd, state } => ObjToString {
val: find!(*val),
call_info: call_info.clone(),
@@ -1225,6 +1233,8 @@ impl Function {
Insn::FixnumLe { .. } => types::BoolExact,
Insn::FixnumGt { .. } => types::BoolExact,
Insn::FixnumGe { .. } => types::BoolExact,
+ Insn::FixnumAnd { .. } => types::Fixnum,
+ Insn::FixnumOr { .. } => types::Fixnum,
Insn::PutSpecialObject { .. } => types::BasicObject,
Insn::SendWithoutBlock { .. } => types::BasicObject,
Insn::SendWithoutBlockDirect { .. } => types::BasicObject,
@@ -1444,6 +1454,10 @@ impl Function {
self.try_rewrite_fixnum_op(block, insn_id, &|left, right| Insn::FixnumGt { left, right }, BOP_GT, self_val, args[0], state),
Insn::SendWithoutBlock { self_val, call_info: CallInfo { method_name }, args, state, .. } if method_name == ">=" && args.len() == 1 =>
self.try_rewrite_fixnum_op(block, insn_id, &|left, right| Insn::FixnumGe { left, right }, BOP_GE, self_val, args[0], state),
+ Insn::SendWithoutBlock { self_val, call_info: CallInfo { method_name }, args, state, .. } if method_name == "&" && args.len() == 1 =>
+ self.try_rewrite_fixnum_op(block, insn_id, &|left, right| Insn::FixnumAnd { left, right }, BOP_AND, self_val, args[0], state),
+ Insn::SendWithoutBlock { self_val, call_info: CallInfo { method_name }, args, state, .. } if method_name == "|" && args.len() == 1 =>
+ self.try_rewrite_fixnum_op(block, insn_id, &|left, right| Insn::FixnumOr { left, right }, BOP_OR, self_val, args[0], state),
Insn::SendWithoutBlock { self_val, call_info: CallInfo { method_name }, args, .. } if method_name == "freeze" && args.len() == 0 =>
self.try_rewrite_freeze(block, insn_id, self_val),
Insn::SendWithoutBlock { self_val, call_info: CallInfo { method_name }, args, .. } if method_name == "-@" && args.len() == 0 =>
@@ -1856,6 +1870,8 @@ impl Function {
| Insn::FixnumGe { left, right }
| Insn::FixnumEq { left, right }
| Insn::FixnumNeq { left, right }
+ | Insn::FixnumAnd { left, right }
+ | Insn::FixnumOr { left, right }
=> {
worklist.push_back(left);
worklist.push_back(right);
@@ -6787,4 +6803,40 @@ mod opt_tests {
Return v8
"#]]);
}
+
+ #[test]
+ fn test_guard_fixnum_and_fixnum() {
+ eval("
+ def test(x, y) = x & y
+
+ test(1, 2)
+ ");
+ assert_optimized_method_hir("test", expect![[r#"
+ fn test:
+ bb0(v0:BasicObject, v1:BasicObject, v2:BasicObject):
+ PatchPoint BOPRedefined(INTEGER_REDEFINED_OP_FLAG, 28)
+ v8:Fixnum = GuardType v1, Fixnum
+ v9:Fixnum = GuardType v2, Fixnum
+ v10:Fixnum = FixnumAnd v8, v9
+ Return v10
+ "#]]);
+ }
+
+ #[test]
+ fn test_guard_fixnum_or_fixnum() {
+ eval("
+ def test(x, y) = x | y
+
+ test(1, 2)
+ ");
+ assert_optimized_method_hir("test", expect![[r#"
+ fn test:
+ bb0(v0:BasicObject, v1:BasicObject, v2:BasicObject):
+ PatchPoint BOPRedefined(INTEGER_REDEFINED_OP_FLAG, 29)
+ v8:Fixnum = GuardType v1, Fixnum
+ v9:Fixnum = GuardType v2, Fixnum
+ v10:Fixnum = FixnumOr v8, v9
+ Return v10
+ "#]]);
+ }
}