author     Takashi Kokubun <takashi.kokubun@shopify.com>  2025-07-02 10:37:30 -0700
committer  GitHub <noreply@github.com>                    2025-07-02 10:37:30 -0700
commit     6e28574ed08b076783035dc67ed0067550ff6bbe
tree       d8c8320bec0a814ab80051cd344598e59eb527de
parent     a0bf36a9f42f8d627dacb2f7f3a697d53f712a5c
ZJIT: Support spilling basic block arguments (#13761)
Co-authored-by: Max Bernstein <max@bernsteinbear.com>
-rw-r--r--  test/ruby/test_zjit.rb          | 34
-rw-r--r--  zjit/src/backend/arm64/mod.rs   |  1
-rw-r--r--  zjit/src/backend/lir.rs         | 15
-rw-r--r--  zjit/src/backend/x86_64/mod.rs  |  1
-rw-r--r--  zjit/src/codegen.rs             | 87
5 files changed, 120 insertions, 18 deletions
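
In short, this change lets a basic block take more arguments than there are allocatable registers: the first arguments stay in fixed registers, the rest are addressed as stack slots relative to the native stack pointer, and the prologue reserves a spill area rounded up to the 16-byte alignment both x86_64 and arm64 require. Below is a minimal standalone sketch of that index-to-operand mapping and the rounding, using a made-up register pool and hypothetical types rather than the real backend; the diff then wires the same idea into param_opnd, aligned_stack_bytes, and the prologue/epilogue stack-pointer adjustments.

// Toy sketch only: maps a basic block argument index to either a register
// from a fixed pool or a byte offset from the native stack pointer, and
// rounds the spill area up to a 16-byte multiple (VALUE is 8 bytes).
const ALLOC_REGS: [&str; 5] = ["r1", "r2", "r3", "r4", "r5"]; // hypothetical pool
const SIZEOF_VALUE: usize = 8;

#[derive(Debug, PartialEq)]
enum ParamOpnd {
    Reg(&'static str),
    Mem { sp_offset: i32 }, // byte offset relative to the native stack pointer
}

fn param_opnd(idx: usize) -> ParamOpnd {
    if idx < ALLOC_REGS.len() {
        ParamOpnd::Reg(ALLOC_REGS[idx])
    } else {
        // Spilled arguments get consecutive 8-byte slots past the register pool.
        ParamOpnd::Mem { sp_offset: -(((idx - ALLOC_REGS.len() + 1) * SIZEOF_VALUE) as i32) }
    }
}

fn aligned_stack_bytes(num_slots: usize) -> usize {
    // Round the slot count up to an even number so the area is 16-byte aligned.
    (num_slots + num_slots % 2) * SIZEOF_VALUE
}

fn main() {
    assert_eq!(param_opnd(4), ParamOpnd::Reg("r5"));
    assert_eq!(param_opnd(5), ParamOpnd::Mem { sp_offset: -8 });
    assert_eq!(param_opnd(6), ParamOpnd::Mem { sp_offset: -16 });
    assert_eq!(aligned_stack_bytes(3), 32); // 3 slots round up to 4 -> 32 bytes
}
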
diff --git a/test/ruby/test_zjit.rb b/test/ruby/test_zjit.rb
index 920ec461a5..0c73e6b456 100644
--- a/test/ruby/test_zjit.rb
+++ b/test/ruby/test_zjit.rb
@@ -693,6 +693,40 @@ class TestZJIT < Test::Unit::TestCase
}, call_threshold: 5, num_profiles: 3
end
+ def test_spilled_basic_block_args
+ assert_compiles '55', %q{
+ def test(n1, n2)
+ n3 = 3
+ n4 = 4
+ n5 = 5
+ n6 = 6
+ n7 = 7
+ n8 = 8
+ n9 = 9
+ n10 = 10
+ if n1 < n2
+ n1 + n2 + n3 + n4 + n5 + n6 + n7 + n8 + n9 + n10
+ end
+ end
+ test(1, 2)
+ }
+ end
+
+ def test_spilled_method_args
+ omit 'CCall with spilled arguments is not implemented yet'
+ assert_compiles '55', %q{
+ def foo(n1, n2, n3, n4, n5, n6, n7, n8, n9, n10)
+ n1 + n2 + n3 + n4 + n5 + n6 + n7 + n8 + n9 + n10
+ end
+
+ def test
+ foo(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
+ end
+
+ test
+ }
+ end
+
def test_opt_aref_with
assert_compiles ':ok', %q{
def aref_with(hash) = hash["key"]
diff --git a/zjit/src/backend/arm64/mod.rs b/zjit/src/backend/arm64/mod.rs
index dd1eb52d34..3c18a57dd0 100644
--- a/zjit/src/backend/arm64/mod.rs
+++ b/zjit/src/backend/arm64/mod.rs
@@ -28,6 +28,7 @@ pub const _C_ARG_OPNDS: [Opnd; 6] = [
// C return value register on this platform
pub const C_RET_REG: Reg = X0_REG;
pub const _C_RET_OPND: Opnd = Opnd::Reg(X0_REG);
+pub const _NATIVE_STACK_PTR: Opnd = Opnd::Reg(XZR_REG);
// These constants define the way we work with Arm64's stack pointer. The stack
// pointer always needs to be aligned to a 16-byte boundary.
diff --git a/zjit/src/backend/lir.rs b/zjit/src/backend/lir.rs
index 5fe4b85b62..c168170b2f 100644
--- a/zjit/src/backend/lir.rs
+++ b/zjit/src/backend/lir.rs
@@ -16,6 +16,7 @@ pub const SP: Opnd = _SP;
pub const C_ARG_OPNDS: [Opnd; 6] = _C_ARG_OPNDS;
pub const C_RET_OPND: Opnd = _C_RET_OPND;
+pub const NATIVE_STACK_PTR: Opnd = _NATIVE_STACK_PTR;
pub use crate::backend::current::{Reg, C_RET_REG};
// Memory operand base
@@ -277,7 +278,7 @@ pub enum Target
/// Pointer to a piece of ZJIT-generated code
CodePtr(CodePtr),
// Side exit with a counter
- SideExit { pc: *const VALUE, stack: Vec<Opnd>, locals: Vec<Opnd> },
+ SideExit { pc: *const VALUE, stack: Vec<Opnd>, locals: Vec<Opnd>, c_stack_bytes: usize },
/// A label within the generated code
Label(Label),
}
@@ -1774,7 +1775,7 @@ impl Assembler
for (idx, target) in targets {
// Compile a side exit. Note that this is past the split pass and alloc_regs(),
// so you can't use a VReg or an instruction that needs to be split.
- if let Target::SideExit { pc, stack, locals } = target {
+ if let Target::SideExit { pc, stack, locals, c_stack_bytes } = target {
let side_exit_label = self.new_label("side_exit".into());
self.write_label(side_exit_label.clone());
@@ -1810,6 +1811,11 @@ impl Assembler
let cfp_sp = Opnd::mem(64, CFP, RUBY_OFFSET_CFP_SP);
self.store(cfp_sp, Opnd::Reg(Assembler::SCRATCH_REG));
+ if c_stack_bytes > 0 {
+ asm_comment!(self, "restore C stack pointer");
+ self.add_into(NATIVE_STACK_PTR, c_stack_bytes.into());
+ }
+
asm_comment!(self, "exit to the interpreter");
self.frame_teardown();
self.mov(C_RET_OPND, Opnd::UImm(Qundef.as_u64()));
@@ -1842,6 +1848,11 @@ impl Assembler {
out
}
+ pub fn add_into(&mut self, left: Opnd, right: Opnd) -> Opnd {
+ self.push_insn(Insn::Add { left, right, out: left });
+ left
+ }
+
#[must_use]
pub fn and(&mut self, left: Opnd, right: Opnd) -> Opnd {
let out = self.new_vreg(Opnd::match_num_bits(&[left, right]));
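
The add_into helper added above differs from the existing add in one way that matters for side exits: side-exit code is emitted after the split pass and alloc_regs(), so it cannot allocate a fresh VReg for a result; the addition instead writes back into its left operand. A toy illustration of that distinction with simplified stand-in types, not the real LIR:

#[derive(Clone, Copy, Debug, PartialEq)]
enum Opnd { Reg(u8), Imm(i64), VReg(usize) }

#[derive(Debug)]
enum Insn { Add { left: Opnd, right: Opnd, out: Opnd } }

struct Assembler { insns: Vec<Insn>, next_vreg: usize }

impl Assembler {
    // Regular add: the result goes into a newly allocated virtual register.
    fn add(&mut self, left: Opnd, right: Opnd) -> Opnd {
        let out = Opnd::VReg(self.next_vreg);
        self.next_vreg += 1;
        self.insns.push(Insn::Add { left, right, out });
        out
    }

    // In-place add: the left operand is also the destination, so it stays
    // usable after register allocation, when no new VRegs may be created.
    fn add_into(&mut self, left: Opnd, right: Opnd) -> Opnd {
        self.insns.push(Insn::Add { left, right, out: left });
        left
    }
}

fn main() {
    let mut asm = Assembler { insns: vec![], next_vreg: 0 };
    let tmp = asm.add(Opnd::Reg(0), Opnd::Imm(8));       // result in a fresh VReg
    let sp  = asm.add_into(Opnd::Reg(31), Opnd::Imm(8)); // result written back into Reg(31)
    assert_eq!(tmp, Opnd::VReg(0));
    assert_eq!(sp, Opnd::Reg(31));
}
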
diff --git a/zjit/src/backend/x86_64/mod.rs b/zjit/src/backend/x86_64/mod.rs
index d83fc184f9..793a096365 100644
--- a/zjit/src/backend/x86_64/mod.rs
+++ b/zjit/src/backend/x86_64/mod.rs
@@ -28,6 +28,7 @@ pub const _C_ARG_OPNDS: [Opnd; 6] = [
// C return value register on this platform
pub const C_RET_REG: Reg = RAX_REG;
pub const _C_RET_OPND: Opnd = Opnd::Reg(RAX_REG);
+pub const _NATIVE_STACK_PTR: Opnd = Opnd::Reg(RSP_REG);
impl CodeBlock {
// The number of bytes that are generated by jmp_ptr
diff --git a/zjit/src/codegen.rs b/zjit/src/codegen.rs
index 33a8af6868..6d73a3a32d 100644
--- a/zjit/src/codegen.rs
+++ b/zjit/src/codegen.rs
@@ -6,7 +6,7 @@ use crate::backend::current::{Reg, ALLOC_REGS};
use crate::profile::get_or_create_iseq_payload;
use crate::state::ZJITState;
use crate::{asm::CodeBlock, cruby::*, options::debug, virtualmem::CodePtr};
-use crate::backend::lir::{self, asm_comment, Assembler, Opnd, Target, CFP, C_ARG_OPNDS, C_RET_OPND, EC, SP};
+use crate::backend::lir::{self, asm_comment, Assembler, Opnd, Target, CFP, C_ARG_OPNDS, C_RET_OPND, EC, NATIVE_STACK_PTR, SP};
use crate::hir::{iseq_to_hir, Block, BlockId, BranchEdge, CallInfo, RangeType, SELF_PARAM_IDX, SpecialObjectType};
use crate::hir::{Const, FrameState, Function, Insn, InsnId};
use crate::hir_type::{types::Fixnum, Type};
@@ -25,16 +25,20 @@ struct JITState {
/// Branches to an ISEQ that need to be compiled later
branch_iseqs: Vec<(Rc<Branch>, IseqPtr)>,
+
+ /// The number of bytes allocated for basic block arguments spilled onto the C stack
+ c_stack_bytes: usize,
}
impl JITState {
/// Create a new JITState instance
- fn new(iseq: IseqPtr, num_insns: usize, num_blocks: usize) -> Self {
+ fn new(iseq: IseqPtr, num_insns: usize, num_blocks: usize, c_stack_bytes: usize) -> Self {
JITState {
iseq,
opnds: vec![None; num_insns],
labels: vec![None; num_blocks],
branch_iseqs: Vec::default(),
+ c_stack_bytes,
}
}
@@ -179,7 +183,8 @@ fn gen_iseq(cb: &mut CodeBlock, iseq: IseqPtr) -> Option<(CodePtr, Vec<(Rc<Branc
/// Compile a function
fn gen_function(cb: &mut CodeBlock, iseq: IseqPtr, function: &Function) -> Option<(CodePtr, Vec<(Rc<Branch>, IseqPtr)>)> {
- let mut jit = JITState::new(iseq, function.num_insns(), function.num_blocks());
+ let c_stack_bytes = aligned_stack_bytes(max_num_params(function).saturating_sub(ALLOC_REGS.len()));
+ let mut jit = JITState::new(iseq, function.num_insns(), function.num_blocks(), c_stack_bytes);
let mut asm = Assembler::new();
// Compile each basic block
@@ -195,6 +200,13 @@ fn gen_function(cb: &mut CodeBlock, iseq: IseqPtr, function: &Function) -> Optio
// Set up the frame at the first block
if block_id == BlockId(0) {
asm.frame_setup();
+
+ // Bump the C stack pointer for basic block arguments
+ if jit.c_stack_bytes > 0 {
+ asm_comment!(asm, "bump C stack pointer");
+ let new_sp = asm.sub(NATIVE_STACK_PTR, jit.c_stack_bytes.into());
+ asm.mov(NATIVE_STACK_PTR, new_sp);
+ }
}
// Compile all parameters
@@ -252,7 +264,7 @@ fn gen_insn(cb: &mut CodeBlock, jit: &mut JITState, asm: &mut Assembler, functio
Insn::SendWithoutBlock { call_info, cd, state, self_val, args, .. } => gen_send_without_block(jit, asm, call_info, *cd, &function.frame_state(*state), self_val, args)?,
Insn::SendWithoutBlockDirect { cme, iseq, self_val, args, state, .. } => gen_send_without_block_direct(cb, jit, asm, *cme, *iseq, opnd!(self_val), args, &function.frame_state(*state))?,
Insn::InvokeBuiltin { bf, args, state } => gen_invokebuiltin(jit, asm, &function.frame_state(*state), bf, args)?,
- Insn::Return { val } => return Some(gen_return(asm, opnd!(val))?),
+ Insn::Return { val } => return Some(gen_return(jit, asm, opnd!(val))?),
Insn::FixnumAdd { left, right, state } => gen_fixnum_add(jit, asm, opnd!(left), opnd!(right), &function.frame_state(*state))?,
Insn::FixnumSub { left, right, state } => gen_fixnum_sub(jit, asm, opnd!(left), opnd!(right), &function.frame_state(*state))?,
Insn::FixnumMult { left, right, state } => gen_fixnum_mult(jit, asm, opnd!(left), opnd!(right), &function.frame_state(*state))?,
@@ -518,7 +530,16 @@ fn gen_branch_params(jit: &mut JITState, asm: &mut Assembler, branch: &BranchEdg
asm_comment!(asm, "set branch params: {}", branch.args.len());
let mut moves: Vec<(Reg, Opnd)> = vec![];
for (idx, &arg) in branch.args.iter().enumerate() {
- moves.push((param_reg(idx), jit.get_opnd(arg)?));
+ match param_opnd(idx) {
+ Opnd::Reg(reg) => {
+ // If a parameter is a register, we need to parallel-move it
+ moves.push((reg, jit.get_opnd(arg)?));
+ },
+ param => {
+ // If a parameter is memory, we set it beforehand
+ asm.mov(param, jit.get_opnd(arg)?);
+ }
+ }
}
asm.parallel_mov(moves);
}
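
The register/memory split in gen_branch_params above is why only register parameters go through parallel_mov: two branch arguments can exchange registers, and lowering those moves one at a time would overwrite a source before it is read, whereas a spilled slot has a fixed address and can simply be stored to up front. A small self-contained illustration of the hazard, unrelated to the real backend:

// Sequential copies clobber when destinations overlap sources: naively
// lowering (a <- b, b <- a) one move at a time loses the old value of `a`.
fn sequential_swap(mut a: i64, mut b: i64) -> (i64, i64) {
    a = b; // old `a` is gone
    b = a;
    (a, b)
}

// A parallel move reads every source before writing any destination.
fn parallel_swap(a: i64, b: i64) -> (i64, i64) {
    (b, a)
}

fn main() {
    assert_eq!(sequential_swap(1, 2), (2, 2)); // value lost
    assert_eq!(parallel_swap(1, 2), (2, 1));   // swap preserved
}
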
@@ -555,7 +576,13 @@ fn gen_const(val: VALUE) -> lir::Opnd {
/// Compile a basic block argument
fn gen_param(asm: &mut Assembler, idx: usize) -> lir::Opnd {
- asm.live_reg_opnd(Opnd::Reg(param_reg(idx)))
+ // Allocate a register or a stack slot
+ match param_opnd(idx) {
+ // If it's a register, insert LiveReg instruction to reserve the register
+ // in the register pool for register allocation.
+ param @ Opnd::Reg(_) => asm.live_reg_opnd(param),
+ param => param,
+ }
}
/// Compile a jump to a basic block
@@ -797,7 +824,7 @@ fn gen_new_range(
}
/// Compile code that exits from JIT code with a return value
-fn gen_return(asm: &mut Assembler, val: lir::Opnd) -> Option<()> {
+fn gen_return(jit: &JITState, asm: &mut Assembler, val: lir::Opnd) -> Option<()> {
// Pop the current frame (ec->cfp++)
// Note: the return PC is already in the previous CFP
asm_comment!(asm, "pop stack frame");
@@ -805,6 +832,13 @@ fn gen_return(asm: &mut Assembler, val: lir::Opnd) -> Option<()> {
asm.mov(CFP, incr_cfp);
asm.mov(Opnd::mem(64, EC, RUBY_OFFSET_EC_CFP), CFP);
+ // Restore the C stack pointer bumped for basic block arguments
+ if jit.c_stack_bytes > 0 {
+ asm_comment!(asm, "restore C stack pointer");
+ let new_sp = asm.add(NATIVE_STACK_PTR, jit.c_stack_bytes.into());
+ asm.mov(NATIVE_STACK_PTR, new_sp);
+ }
+
asm.frame_teardown();
// Return from the function
@@ -992,17 +1026,15 @@ fn gen_push_frame(asm: &mut Assembler, argc: usize, state: &FrameState, frame: C
asm.mov(cfp_opnd(RUBY_OFFSET_CFP_BLOCK_CODE), 0.into());
}
-/// Return a register we use for the basic block argument at a given index
-fn param_reg(idx: usize) -> Reg {
- // To simplify the implementation, allocate a fixed register for each basic block argument for now.
+/// Return an operand we use for the basic block argument at a given index
+fn param_opnd(idx: usize) -> Opnd {
+ // To simplify the implementation, allocate a fixed register or a stack slot for each basic block argument for now.
// TODO: Allow allocating arbitrary registers for basic block arguments
- if idx >= ALLOC_REGS.len() {
- unimplemented!(
- "register spilling not yet implemented, too many basic block arguments ({}/{})",
- idx + 1, ALLOC_REGS.len()
- );
+ if idx < ALLOC_REGS.len() {
+ Opnd::Reg(ALLOC_REGS[idx])
+ } else {
+ Opnd::mem(64, NATIVE_STACK_PTR, -((idx - ALLOC_REGS.len() + 1) as i32) * SIZEOF_VALUE_I32)
}
- ALLOC_REGS[idx]
}
/// Inverse of ep_offset_to_local_idx(). See ep_offset_to_local_idx() for details.
@@ -1045,6 +1077,7 @@ fn side_exit(jit: &mut JITState, state: &FrameState) -> Option<Target> {
pc: state.pc,
stack,
locals,
+ c_stack_bytes: jit.c_stack_bytes,
};
Some(target)
}
@@ -1063,6 +1096,28 @@ fn iseq_entry_escapes_ep(iseq: IseqPtr) -> bool {
}
}
+/// Return the maximum number of arguments for a basic block in a given function
+fn max_num_params(function: &Function) -> usize {
+ let reverse_post_order = function.rpo();
+ reverse_post_order.iter().map(|&block_id| {
+ let block = function.block(block_id);
+ block.params().len()
+ }).max().unwrap_or(0)
+}
+
+/// Given the number of spill slots needed for a function, return the number of bytes
+/// the function needs to allocate on the stack for the stack frame.
+fn aligned_stack_bytes(num_slots: usize) -> usize {
+ // Both x86_64 and arm64 require the stack to be aligned to 16 bytes.
+ // Since SIZEOF_VALUE is 8 bytes, we need to round up the size to the nearest even number.
+ let num_slots = if num_slots % 2 == 0 {
+ num_slots
+ } else {
+ num_slots + 1
+ };
+ num_slots * SIZEOF_VALUE
+}
+
impl Assembler {
/// Make a C call while marking the start and end positions of it
fn ccall_with_branch(&mut self, fptr: *const u8, opnds: Vec<Opnd>, branch: &Rc<Branch>) -> Opnd {