summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMaxime Chevalier-Boisvert <maxime.chevalierboisvert@shopify.com>2023-08-18 10:05:32 -0400
committerGitHub <noreply@github.com>2023-08-18 10:05:32 -0400
commit314eed8a5ec9f1b46624b277dde75f8079026b7b (patch)
treefcc70dd0335d52c6f33fba361670ed85222ae646
parent724223b4ca0117306529c9cbcfaedc3a07b840bf (diff)
YJIT: implement fast path for integer multiplication in opt_mult (#8204)
* YJIT: implement fast path for integer multiplication in opt_mult * Update yjit/src/codegen.rs Co-authored-by: Alan Wu <XrXr@users.noreply.github.com> * Implement mul with overflow checking on arm64 * Fix missing semicolon * Add arm splitting for lshift, rshift, urshift --------- Co-authored-by: Alan Wu <XrXr@users.noreply.github.com>
Notes
Notes: Merged-By: maximecb <maximecb@ruby-lang.org>
-rw-r--r--bootstraptest/test_yjit.rb15
-rw-r--r--yjit/src/asm/arm64/inst/mod.rs2
-rw-r--r--yjit/src/asm/arm64/inst/smulh.rs60
-rw-r--r--yjit/src/asm/arm64/mod.rs17
-rw-r--r--yjit/src/backend/arm64/mod.rs43
-rw-r--r--yjit/src/codegen.rs38
-rw-r--r--yjit/src/stats.rs1
7 files changed, 171 insertions, 5 deletions
diff --git a/bootstraptest/test_yjit.rb b/bootstraptest/test_yjit.rb
index 249c4c19c7..5dd244e3be 100644
--- a/bootstraptest/test_yjit.rb
+++ b/bootstraptest/test_yjit.rb
@@ -4101,3 +4101,18 @@ assert_equal '6', %q{
Sub.new.number { 3 }
}
+
+# Integer multiplication and overflow
+assert_equal '[6, -6, 9671406556917033397649408, -9671406556917033397649408, 21267647932558653966460912964485513216]', %q{
+ def foo(a, b)
+ a * b
+ end
+
+ r1 = foo(2, 3)
+ r2 = foo(2, -3)
+ r3 = foo(2 << 40, 2 << 41)
+ r4 = foo(2 << 40, -2 << 41)
+ r5 = foo(1 << 62, 1 << 62)
+
+ [r1, r2, r3, r4, r5]
+}
diff --git a/yjit/src/asm/arm64/inst/mod.rs b/yjit/src/asm/arm64/inst/mod.rs
index 665ebef57c..bfffd914ef 100644
--- a/yjit/src/asm/arm64/inst/mod.rs
+++ b/yjit/src/asm/arm64/inst/mod.rs
@@ -17,6 +17,7 @@ mod load_store_exclusive;
mod logical_imm;
mod logical_reg;
mod madd;
+mod smulh;
mod mov;
mod nop;
mod pc_rel;
@@ -42,6 +43,7 @@ pub use load_store_exclusive::LoadStoreExclusive;
pub use logical_imm::LogicalImm;
pub use logical_reg::LogicalReg;
pub use madd::MAdd;
+pub use smulh::SMulH;
pub use mov::Mov;
pub use nop::Nop;
pub use pc_rel::PCRelative;
diff --git a/yjit/src/asm/arm64/inst/smulh.rs b/yjit/src/asm/arm64/inst/smulh.rs
new file mode 100644
index 0000000000..796a19433f
--- /dev/null
+++ b/yjit/src/asm/arm64/inst/smulh.rs
@@ -0,0 +1,60 @@
+/// The struct that represents an A64 signed multipy high instruction
+///
+/// +-------------+-------------+-------------+-------------+-------------+-------------+-------------+-------------+
+/// | 31 30 29 28 | 27 26 25 24 | 23 22 21 20 | 19 18 17 16 | 15 14 13 12 | 11 10 09 08 | 07 06 05 04 | 03 02 01 00 |
+/// | 1 0 0 1 1 0 1 1 0 1 0 0 |
+/// | rm.............. ra.............. rn.............. rd.............. |
+/// +-------------+-------------+-------------+-------------+-------------+-------------+-------------+-------------+
+///
+pub struct SMulH {
+ /// The number of the general-purpose destination register.
+ rd: u8,
+
+ /// The number of the first general-purpose source register.
+ rn: u8,
+
+ /// The number of the third general-purpose source register.
+ ra: u8,
+
+ /// The number of the second general-purpose source register.
+ rm: u8,
+}
+
+impl SMulH {
+ /// SMULH
+ /// https://developer.arm.com/documentation/ddi0602/2023-06/Base-Instructions/SMULH--Signed-Multiply-High-
+ pub fn smulh(rd: u8, rn: u8, rm: u8) -> Self {
+ Self { rd, rn, ra: 0b11111, rm }
+ }
+}
+
+impl From<SMulH> for u32 {
+ /// Convert an instruction into a 32-bit value.
+ fn from(inst: SMulH) -> Self {
+ 0
+ | (0b10011011010 << 21)
+ | ((inst.rm as u32) << 16)
+ | ((inst.ra as u32) << 10)
+ | ((inst.rn as u32) << 5)
+ | (inst.rd as u32)
+ }
+}
+
+impl From<SMulH> for [u8; 4] {
+ /// Convert an instruction into a 4 byte array.
+ fn from(inst: SMulH) -> [u8; 4] {
+ let result: u32 = inst.into();
+ result.to_le_bytes()
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test_smulh() {
+ let result: u32 = SMulH::smulh(0, 1, 2).into();
+ assert_eq!(0x9b427c20, result);
+ }
+}
diff --git a/yjit/src/asm/arm64/mod.rs b/yjit/src/asm/arm64/mod.rs
index bcdbda8dc0..eb99c00ba7 100644
--- a/yjit/src/asm/arm64/mod.rs
+++ b/yjit/src/asm/arm64/mod.rs
@@ -186,7 +186,7 @@ pub fn asr(cb: &mut CodeBlock, rd: A64Opnd, rn: A64Opnd, shift: A64Opnd) {
SBFM::asr(rd.reg_no, rn.reg_no, shift.try_into().unwrap(), rd.num_bits).into()
},
- _ => panic!("Invalid operand combination to asr instruction."),
+ _ => panic!("Invalid operand combination to asr instruction: asr {:?}, {:?}, {:?}", rd, rn, shift),
};
cb.write_bytes(&bytes);
@@ -713,6 +713,21 @@ pub fn mul(cb: &mut CodeBlock, rd: A64Opnd, rn: A64Opnd, rm: A64Opnd) {
cb.write_bytes(&bytes);
}
+/// SMULH - multiply two 64-bit registers to produce a 128-bit result, put the high 64-bits of the result into rd
+pub fn smulh(cb: &mut CodeBlock, rd: A64Opnd, rn: A64Opnd, rm: A64Opnd) {
+ let bytes: [u8; 4] = match (rd, rn, rm) {
+ (A64Opnd::Reg(rd), A64Opnd::Reg(rn), A64Opnd::Reg(rm)) => {
+ assert!(rd.num_bits == rn.num_bits && rn.num_bits == rm.num_bits, "Expected registers to be the same size");
+ assert!(rd.num_bits == 64, "smulh only applicable to 64-bit registers");
+
+ SMulH::smulh(rd.reg_no, rn.reg_no, rm.reg_no).into()
+ },
+ _ => panic!("Invalid operand combination to mul instruction")
+ };
+
+ cb.write_bytes(&bytes);
+}
+
/// MVN - move a value in a register to another register, negating it
pub fn mvn(cb: &mut CodeBlock, rd: A64Opnd, rm: A64Opnd) {
let bytes: [u8; 4] = match (rd, rm) {
diff --git a/yjit/src/backend/arm64/mod.rs b/yjit/src/backend/arm64/mod.rs
index a991a4b215..1007df9cf8 100644
--- a/yjit/src/backend/arm64/mod.rs
+++ b/yjit/src/backend/arm64/mod.rs
@@ -612,6 +612,19 @@ impl Assembler
asm.not(opnd0);
},
+ Insn::LShift { opnd, shift, .. } |
+ Insn::RShift { opnd, shift, .. } |
+ Insn::URShift { opnd, shift, .. } => {
+ // The operand must be in a register, so
+ // if we get anything else we need to load it first.
+ let opnd0 = match opnd {
+ Opnd::Mem(_) => split_load_operand(asm, *opnd),
+ _ => *opnd
+ };
+
+ *opnd = opnd0;
+ asm.push_insn(insn);
+ },
Insn::Store { dest, src } => {
// The value being stored must be in a register, so if it's
// not already one we'll load it first.
@@ -811,6 +824,7 @@ impl Assembler
let start_write_pos = cb.get_write_pos();
let mut insn_idx: usize = 0;
while let Some(insn) = self.insns.get(insn_idx) {
+ let mut next_insn_idx = insn_idx + 1;
let src_ptr = cb.get_write_ptr();
let had_dropped_bytes = cb.has_dropped_bytes();
let old_label_state = cb.get_label_state();
@@ -863,7 +877,32 @@ impl Assembler
subs(cb, out.into(), left.into(), right.into());
},
Insn::Mul { left, right, out } => {
- mul(cb, out.into(), left.into(), right.into());
+ // If the next instruction is jo (jump on overflow)
+ match self.insns.get(insn_idx + 1) {
+ Some(Insn::Jo(target)) => {
+ // Compute the high 64 bits
+ smulh(cb, Self::SCRATCH0, left.into(), right.into());
+
+ // Compute the low 64 bits
+ // This may clobber one of the input registers,
+ // so we do it after smulh
+ mul(cb, out.into(), left.into(), right.into());
+
+ // Produce a register that is all zeros or all ones
+ // Based on the sign bit of the 64-bit mul result
+ asr(cb, Self::SCRATCH1, out.into(), A64Opnd::UImm(63));
+
+ // If the high 64-bits are not all zeros or all ones,
+ // matching the sign bit, then we have an overflow
+ cmp(cb, Self::SCRATCH0, Self::SCRATCH1);
+ emit_conditional_jump::<{Condition::NE}>(cb, compile_side_exit(*target, self, ocb));
+
+ next_insn_idx += 1;
+ }
+ _ => {
+ mul(cb, out.into(), left.into(), right.into());
+ }
+ }
},
Insn::And { left, right, out } => {
and(cb, out.into(), left.into(), right.into());
@@ -1158,7 +1197,7 @@ impl Assembler
return Err(());
}
} else {
- insn_idx += 1;
+ insn_idx = next_insn_idx;
gc_offsets.append(&mut insn_gc_offsets);
}
}
diff --git a/yjit/src/codegen.rs b/yjit/src/codegen.rs
index df02ab0cc7..a5c030bc12 100644
--- a/yjit/src/codegen.rs
+++ b/yjit/src/codegen.rs
@@ -3398,8 +3398,42 @@ fn gen_opt_mult(
asm: &mut Assembler,
ocb: &mut OutlinedCb,
) -> Option<CodegenStatus> {
- // Delegate to send, call the method on the recv
- gen_opt_send_without_block(jit, asm, ocb)
+ let two_fixnums = match asm.ctx.two_fixnums_on_stack(jit) {
+ Some(two_fixnums) => two_fixnums,
+ None => {
+ defer_compilation(jit, asm, ocb);
+ return Some(EndBlock);
+ }
+ };
+
+ if two_fixnums {
+ if !assume_bop_not_redefined(jit, asm, ocb, INTEGER_REDEFINED_OP_FLAG, BOP_MULT) {
+ return None;
+ }
+
+ // Check that both operands are fixnums
+ guard_two_fixnums(jit, asm, ocb);
+
+ // Get the operands from the stack
+ let arg1 = asm.stack_pop(1);
+ let arg0 = asm.stack_pop(1);
+
+ // Do some bitwise gymnastics to handle tag bits
+ // x * y is translated to (x >> 1) * (y - 1) + 1
+ let arg0_untag = asm.rshift(arg0, Opnd::UImm(1));
+ let arg1_untag = asm.sub(arg1, Opnd::UImm(1));
+ let out_val = asm.mul(arg0_untag, arg1_untag);
+ asm.jo(Target::side_exit(Counter::opt_mult_overflow));
+ let out_val = asm.add(out_val, Opnd::UImm(1));
+
+ // Push the output on the stack
+ let dst = asm.stack_push(Type::Fixnum);
+ asm.mov(dst, out_val);
+
+ Some(KeepCompiling)
+ } else {
+ gen_opt_send_without_block(jit, asm, ocb)
+ }
}
fn gen_opt_div(
diff --git a/yjit/src/stats.rs b/yjit/src/stats.rs
index c6b0ea4e18..9ef3e8e94c 100644
--- a/yjit/src/stats.rs
+++ b/yjit/src/stats.rs
@@ -343,6 +343,7 @@ make_counters! {
opt_plus_overflow,
opt_minus_overflow,
+ opt_mult_overflow,
opt_mod_zero,
opt_div_zero,