diff options
author | Maxime Chevalier-Boisvert <maxime.chevalierboisvert@shopify.com> | 2023-08-18 10:05:32 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-08-18 10:05:32 -0400 |
commit | 314eed8a5ec9f1b46624b277dde75f8079026b7b (patch) | |
tree | fcc70dd0335d52c6f33fba361670ed85222ae646 /yjit | |
parent | 724223b4ca0117306529c9cbcfaedc3a07b840bf (diff) |
YJIT: implement fast path for integer multiplication in opt_mult (#8204)
* YJIT: implement fast path for integer multiplication in opt_mult
* Update yjit/src/codegen.rs
Co-authored-by: Alan Wu <XrXr@users.noreply.github.com>
* Implement mul with overflow checking on arm64
* Fix missing semicolon
* Add arm splitting for lshift, rshift, urshift
---------
Co-authored-by: Alan Wu <XrXr@users.noreply.github.com>
Notes:
Merged-By: maximecb <maximecb@ruby-lang.org>
Diffstat (limited to 'yjit')
-rw-r--r-- | yjit/src/asm/arm64/inst/mod.rs | 2 | ||||
-rw-r--r-- | yjit/src/asm/arm64/inst/smulh.rs | 60 | ||||
-rw-r--r-- | yjit/src/asm/arm64/mod.rs | 17 | ||||
-rw-r--r-- | yjit/src/backend/arm64/mod.rs | 43 | ||||
-rw-r--r-- | yjit/src/codegen.rs | 38 | ||||
-rw-r--r-- | yjit/src/stats.rs | 1 |
6 files changed, 156 insertions, 5 deletions
diff --git a/yjit/src/asm/arm64/inst/mod.rs b/yjit/src/asm/arm64/inst/mod.rs index 665ebef57c..bfffd914ef 100644 --- a/yjit/src/asm/arm64/inst/mod.rs +++ b/yjit/src/asm/arm64/inst/mod.rs @@ -17,6 +17,7 @@ mod load_store_exclusive; mod logical_imm; mod logical_reg; mod madd; +mod smulh; mod mov; mod nop; mod pc_rel; @@ -42,6 +43,7 @@ pub use load_store_exclusive::LoadStoreExclusive; pub use logical_imm::LogicalImm; pub use logical_reg::LogicalReg; pub use madd::MAdd; +pub use smulh::SMulH; pub use mov::Mov; pub use nop::Nop; pub use pc_rel::PCRelative; diff --git a/yjit/src/asm/arm64/inst/smulh.rs b/yjit/src/asm/arm64/inst/smulh.rs new file mode 100644 index 0000000000..796a19433f --- /dev/null +++ b/yjit/src/asm/arm64/inst/smulh.rs @@ -0,0 +1,60 @@ +/// The struct that represents an A64 signed multipy high instruction +/// +/// +-------------+-------------+-------------+-------------+-------------+-------------+-------------+-------------+ +/// | 31 30 29 28 | 27 26 25 24 | 23 22 21 20 | 19 18 17 16 | 15 14 13 12 | 11 10 09 08 | 07 06 05 04 | 03 02 01 00 | +/// | 1 0 0 1 1 0 1 1 0 1 0 0 | +/// | rm.............. ra.............. rn.............. rd.............. | +/// +-------------+-------------+-------------+-------------+-------------+-------------+-------------+-------------+ +/// +pub struct SMulH { + /// The number of the general-purpose destination register. + rd: u8, + + /// The number of the first general-purpose source register. + rn: u8, + + /// The number of the third general-purpose source register. + ra: u8, + + /// The number of the second general-purpose source register. + rm: u8, +} + +impl SMulH { + /// SMULH + /// https://developer.arm.com/documentation/ddi0602/2023-06/Base-Instructions/SMULH--Signed-Multiply-High- + pub fn smulh(rd: u8, rn: u8, rm: u8) -> Self { + Self { rd, rn, ra: 0b11111, rm } + } +} + +impl From<SMulH> for u32 { + /// Convert an instruction into a 32-bit value. 
+ fn from(inst: SMulH) -> Self { + 0 + | (0b10011011010 << 21) + | ((inst.rm as u32) << 16) + | ((inst.ra as u32) << 10) + | ((inst.rn as u32) << 5) + | (inst.rd as u32) + } +} + +impl From<SMulH> for [u8; 4] { + /// Convert an instruction into a 4 byte array. + fn from(inst: SMulH) -> [u8; 4] { + let result: u32 = inst.into(); + result.to_le_bytes() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_smulh() { + let result: u32 = SMulH::smulh(0, 1, 2).into(); + assert_eq!(0x9b427c20, result); + } +} diff --git a/yjit/src/asm/arm64/mod.rs b/yjit/src/asm/arm64/mod.rs index bcdbda8dc0..eb99c00ba7 100644 --- a/yjit/src/asm/arm64/mod.rs +++ b/yjit/src/asm/arm64/mod.rs @@ -186,7 +186,7 @@ pub fn asr(cb: &mut CodeBlock, rd: A64Opnd, rn: A64Opnd, shift: A64Opnd) { SBFM::asr(rd.reg_no, rn.reg_no, shift.try_into().unwrap(), rd.num_bits).into() }, - _ => panic!("Invalid operand combination to asr instruction."), + _ => panic!("Invalid operand combination to asr instruction: asr {:?}, {:?}, {:?}", rd, rn, shift), }; cb.write_bytes(&bytes); @@ -713,6 +713,21 @@ pub fn mul(cb: &mut CodeBlock, rd: A64Opnd, rn: A64Opnd, rm: A64Opnd) { cb.write_bytes(&bytes); } +/// SMULH - multiply two 64-bit registers to produce a 128-bit result, put the high 64-bits of the result into rd +pub fn smulh(cb: &mut CodeBlock, rd: A64Opnd, rn: A64Opnd, rm: A64Opnd) { + let bytes: [u8; 4] = match (rd, rn, rm) { + (A64Opnd::Reg(rd), A64Opnd::Reg(rn), A64Opnd::Reg(rm)) => { + assert!(rd.num_bits == rn.num_bits && rn.num_bits == rm.num_bits, "Expected registers to be the same size"); + assert!(rd.num_bits == 64, "smulh only applicable to 64-bit registers"); + + SMulH::smulh(rd.reg_no, rn.reg_no, rm.reg_no).into() + }, + _ => panic!("Invalid operand combination to mul instruction") + }; + + cb.write_bytes(&bytes); +} + /// MVN - move a value in a register to another register, negating it pub fn mvn(cb: &mut CodeBlock, rd: A64Opnd, rm: A64Opnd) { let bytes: [u8; 4] = match (rd, rm) { 
diff --git a/yjit/src/backend/arm64/mod.rs b/yjit/src/backend/arm64/mod.rs index a991a4b215..1007df9cf8 100644 --- a/yjit/src/backend/arm64/mod.rs +++ b/yjit/src/backend/arm64/mod.rs @@ -612,6 +612,19 @@ impl Assembler asm.not(opnd0); }, + Insn::LShift { opnd, shift, .. } | + Insn::RShift { opnd, shift, .. } | + Insn::URShift { opnd, shift, .. } => { + // The operand must be in a register, so + // if we get anything else we need to load it first. + let opnd0 = match opnd { + Opnd::Mem(_) => split_load_operand(asm, *opnd), + _ => *opnd + }; + + *opnd = opnd0; + asm.push_insn(insn); + }, Insn::Store { dest, src } => { // The value being stored must be in a register, so if it's // not already one we'll load it first. @@ -811,6 +824,7 @@ impl Assembler let start_write_pos = cb.get_write_pos(); let mut insn_idx: usize = 0; while let Some(insn) = self.insns.get(insn_idx) { + let mut next_insn_idx = insn_idx + 1; let src_ptr = cb.get_write_ptr(); let had_dropped_bytes = cb.has_dropped_bytes(); let old_label_state = cb.get_label_state(); @@ -863,7 +877,32 @@ impl Assembler subs(cb, out.into(), left.into(), right.into()); }, Insn::Mul { left, right, out } => { - mul(cb, out.into(), left.into(), right.into()); + // If the next instruction is jo (jump on overflow) + match self.insns.get(insn_idx + 1) { + Some(Insn::Jo(target)) => { + // Compute the high 64 bits + smulh(cb, Self::SCRATCH0, left.into(), right.into()); + + // Compute the low 64 bits + // This may clobber one of the input registers, + // so we do it after smulh + mul(cb, out.into(), left.into(), right.into()); + + // Produce a register that is all zeros or all ones + // Based on the sign bit of the 64-bit mul result + asr(cb, Self::SCRATCH1, out.into(), A64Opnd::UImm(63)); + + // If the high 64-bits are not all zeros or all ones, + // matching the sign bit, then we have an overflow + cmp(cb, Self::SCRATCH0, Self::SCRATCH1); + emit_conditional_jump::<{Condition::NE}>(cb, compile_side_exit(*target, self, ocb)); + + 
next_insn_idx += 1; + } + _ => { + mul(cb, out.into(), left.into(), right.into()); + } + } }, Insn::And { left, right, out } => { and(cb, out.into(), left.into(), right.into()); @@ -1158,7 +1197,7 @@ impl Assembler return Err(()); } } else { - insn_idx += 1; + insn_idx = next_insn_idx; gc_offsets.append(&mut insn_gc_offsets); } } diff --git a/yjit/src/codegen.rs b/yjit/src/codegen.rs index df02ab0cc7..a5c030bc12 100644 --- a/yjit/src/codegen.rs +++ b/yjit/src/codegen.rs @@ -3398,8 +3398,42 @@ fn gen_opt_mult( asm: &mut Assembler, ocb: &mut OutlinedCb, ) -> Option<CodegenStatus> { - // Delegate to send, call the method on the recv - gen_opt_send_without_block(jit, asm, ocb) + let two_fixnums = match asm.ctx.two_fixnums_on_stack(jit) { + Some(two_fixnums) => two_fixnums, + None => { + defer_compilation(jit, asm, ocb); + return Some(EndBlock); + } + }; + + if two_fixnums { + if !assume_bop_not_redefined(jit, asm, ocb, INTEGER_REDEFINED_OP_FLAG, BOP_MULT) { + return None; + } + + // Check that both operands are fixnums + guard_two_fixnums(jit, asm, ocb); + + // Get the operands from the stack + let arg1 = asm.stack_pop(1); + let arg0 = asm.stack_pop(1); + + // Do some bitwise gymnastics to handle tag bits + // x * y is translated to (x >> 1) * (y - 1) + 1 + let arg0_untag = asm.rshift(arg0, Opnd::UImm(1)); + let arg1_untag = asm.sub(arg1, Opnd::UImm(1)); + let out_val = asm.mul(arg0_untag, arg1_untag); + asm.jo(Target::side_exit(Counter::opt_mult_overflow)); + let out_val = asm.add(out_val, Opnd::UImm(1)); + + // Push the output on the stack + let dst = asm.stack_push(Type::Fixnum); + asm.mov(dst, out_val); + + Some(KeepCompiling) + } else { + gen_opt_send_without_block(jit, asm, ocb) + } } fn gen_opt_div( diff --git a/yjit/src/stats.rs b/yjit/src/stats.rs index c6b0ea4e18..9ef3e8e94c 100644 --- a/yjit/src/stats.rs +++ b/yjit/src/stats.rs @@ -343,6 +343,7 @@ make_counters! { opt_plus_overflow, opt_minus_overflow, + opt_mult_overflow, opt_mod_zero, opt_div_zero, |