| | | |
|---|---|---|
| author | Takashi Kokubun <takashi.kokubun@shopify.com> | 2025-07-02 10:37:30 -0700 |
| committer | GitHub <noreply@github.com> | 2025-07-02 10:37:30 -0700 |
| commit | 6e28574ed08b076783035dc67ed0067550ff6bbe | |
| tree | d8c8320bec0a814ab80051cd344598e59eb527de | |
| parent | a0bf36a9f42f8d627dacb2f7f3a697d53f712a5c | |
ZJIT: Support spilling basic block arguments (#13761)
Co-authored-by: Max Bernstein <max@bernsteinbear.com>
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | test/ruby/test_zjit.rb | 34 |
| -rw-r--r-- | zjit/src/backend/arm64/mod.rs | 1 |
| -rw-r--r-- | zjit/src/backend/lir.rs | 15 |
| -rw-r--r-- | zjit/src/backend/x86_64/mod.rs | 1 |
| -rw-r--r-- | zjit/src/codegen.rs | 87 |
5 files changed, 120 insertions, 18 deletions
```diff
diff --git a/test/ruby/test_zjit.rb b/test/ruby/test_zjit.rb
index 920ec461a5..0c73e6b456 100644
--- a/test/ruby/test_zjit.rb
+++ b/test/ruby/test_zjit.rb
@@ -693,6 +693,40 @@ class TestZJIT < Test::Unit::TestCase
     }, call_threshold: 5, num_profiles: 3
   end
 
+  def test_spilled_basic_block_args
+    assert_compiles '55', %q{
+      def test(n1, n2)
+        n3 = 3
+        n4 = 4
+        n5 = 5
+        n6 = 6
+        n7 = 7
+        n8 = 8
+        n9 = 9
+        n10 = 10
+        if n1 < n2
+          n1 + n2 + n3 + n4 + n5 + n6 + n7 + n8 + n9 + n10
+        end
+      end
+      test(1, 2)
+    }
+  end
+
+  def test_spilled_method_args
+    omit 'CCall with spilled arguments is not implemented yet'
+    assert_compiles '55', %q{
+      def foo(n1, n2, n3, n4, n5, n6, n7, n8, n9, n10)
+        n1 + n2 + n3 + n4 + n5 + n6 + n7 + n8 + n9 + n10
+      end
+
+      def test
+        foo(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
+      end
+
+      test
+    }
+  end
+
   def test_opt_aref_with
     assert_compiles ':ok', %q{
       def aref_with(hash) = hash["key"]
diff --git a/zjit/src/backend/arm64/mod.rs b/zjit/src/backend/arm64/mod.rs
index dd1eb52d34..3c18a57dd0 100644
--- a/zjit/src/backend/arm64/mod.rs
+++ b/zjit/src/backend/arm64/mod.rs
@@ -28,6 +28,7 @@ pub const _C_ARG_OPNDS: [Opnd; 6] = [
 // C return value register on this platform
 pub const C_RET_REG: Reg = X0_REG;
 pub const _C_RET_OPND: Opnd = Opnd::Reg(X0_REG);
+pub const _NATIVE_STACK_PTR: Opnd = Opnd::Reg(XZR_REG);
 
 // These constants define the way we work with Arm64's stack pointer. The stack
 // pointer always needs to be aligned to a 16-byte boundary.
diff --git a/zjit/src/backend/lir.rs b/zjit/src/backend/lir.rs
index 5fe4b85b62..c168170b2f 100644
--- a/zjit/src/backend/lir.rs
+++ b/zjit/src/backend/lir.rs
@@ -16,6 +16,7 @@ pub const SP: Opnd = _SP;
 
 pub const C_ARG_OPNDS: [Opnd; 6] = _C_ARG_OPNDS;
 pub const C_RET_OPND: Opnd = _C_RET_OPND;
+pub const NATIVE_STACK_PTR: Opnd = _NATIVE_STACK_PTR;
 pub use crate::backend::current::{Reg, C_RET_REG};
 
 // Memory operand base
@@ -277,7 +278,7 @@ pub enum Target
     /// Pointer to a piece of ZJIT-generated code
     CodePtr(CodePtr),
     // Side exit with a counter
-    SideExit { pc: *const VALUE, stack: Vec<Opnd>, locals: Vec<Opnd> },
+    SideExit { pc: *const VALUE, stack: Vec<Opnd>, locals: Vec<Opnd>, c_stack_bytes: usize },
     /// A label within the generated code
     Label(Label),
 }
@@ -1774,7 +1775,7 @@ impl Assembler
         for (idx, target) in targets {
             // Compile a side exit. Note that this is past the split pass and alloc_regs(),
             // so you can't use a VReg or an instruction that needs to be split.
-            if let Target::SideExit { pc, stack, locals } = target {
+            if let Target::SideExit { pc, stack, locals, c_stack_bytes } = target {
                 let side_exit_label = self.new_label("side_exit".into());
                 self.write_label(side_exit_label.clone());
@@ -1810,6 +1811,11 @@ impl Assembler
                 let cfp_sp = Opnd::mem(64, CFP, RUBY_OFFSET_CFP_SP);
                 self.store(cfp_sp, Opnd::Reg(Assembler::SCRATCH_REG));
 
+                if c_stack_bytes > 0 {
+                    asm_comment!(self, "restore C stack pointer");
+                    self.add_into(NATIVE_STACK_PTR, c_stack_bytes.into());
+                }
+
                 asm_comment!(self, "exit to the interpreter");
                 self.frame_teardown();
                 self.mov(C_RET_OPND, Opnd::UImm(Qundef.as_u64()));
@@ -1842,6 +1848,11 @@ impl Assembler
         out
     }
 
+    pub fn add_into(&mut self, left: Opnd, right: Opnd) -> Opnd {
+        self.push_insn(Insn::Add { left, right, out: left });
+        left
+    }
+
     #[must_use]
     pub fn and(&mut self, left: Opnd, right: Opnd) -> Opnd {
         let out = self.new_vreg(Opnd::match_num_bits(&[left, right]));
```
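The new `Target::SideExit` field and the `add_into` helper work together: a side exit must hand the native stack pointer back to the interpreter exactly where it found it, and the exit code is emitted after the split pass and `alloc_regs()`, where the ordinary `Assembler::add` is off limits because it allocates a fresh output VReg. `add_into` sidesteps that by reusing its left operand as the destination. A minimal self-contained model of the distinction, with stand-in types rather than ZJIT's real LIR:

```rust
// Stand-in model: why the emit-time side-exit path needs an in-place add.
#[derive(Clone, Copy, Debug, PartialEq)]
enum Opnd {
    Reg(u8),     // a physical register
    Imm(i64),    // an immediate
    VReg(usize), // a virtual register, only legal before alloc_regs()
}

struct Assembler {
    insns: Vec<(Opnd, Opnd, Opnd)>, // (left, right, out)
    next_vreg: usize,
}

impl Assembler {
    // Ordinary add: the result lands in a brand-new VReg. Fine during normal
    // codegen, illegal once register allocation has already run.
    fn add(&mut self, left: Opnd, right: Opnd) -> Opnd {
        let out = Opnd::VReg(self.next_vreg);
        self.next_vreg += 1;
        self.insns.push((left, right, out));
        out
    }

    // add_into: the destination is the left operand itself, so no new VReg
    // is created. This mirrors the method added to lir.rs above.
    fn add_into(&mut self, left: Opnd, right: Opnd) -> Opnd {
        self.insns.push((left, right, left));
        left
    }
}

fn main() {
    let mut asm = Assembler { insns: vec![], next_vreg: 0 };
    let sp = Opnd::Reg(31); // stand-in for NATIVE_STACK_PTR
    assert_eq!(asm.add_into(sp, Opnd::Imm(16)), sp); // result stays in sp
    assert!(matches!(asm.add(sp, Opnd::Imm(16)), Opnd::VReg(0)));
    println!("{:?}", asm.insns);
}
```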
```diff
diff --git a/zjit/src/backend/x86_64/mod.rs b/zjit/src/backend/x86_64/mod.rs
index d83fc184f9..793a096365 100644
--- a/zjit/src/backend/x86_64/mod.rs
+++ b/zjit/src/backend/x86_64/mod.rs
@@ -28,6 +28,7 @@ pub const _C_ARG_OPNDS: [Opnd; 6] = [
 // C return value register on this platform
 pub const C_RET_REG: Reg = RAX_REG;
 pub const _C_RET_OPND: Opnd = Opnd::Reg(RAX_REG);
+pub const _NATIVE_STACK_PTR: Opnd = Opnd::Reg(RSP_REG);
 
 impl CodeBlock {
     // The number of bytes that are generated by jmp_ptr
diff --git a/zjit/src/codegen.rs b/zjit/src/codegen.rs
index 33a8af6868..6d73a3a32d 100644
--- a/zjit/src/codegen.rs
+++ b/zjit/src/codegen.rs
@@ -6,7 +6,7 @@ use crate::backend::current::{Reg, ALLOC_REGS};
 use crate::profile::get_or_create_iseq_payload;
 use crate::state::ZJITState;
 use crate::{asm::CodeBlock, cruby::*, options::debug, virtualmem::CodePtr};
-use crate::backend::lir::{self, asm_comment, Assembler, Opnd, Target, CFP, C_ARG_OPNDS, C_RET_OPND, EC, SP};
+use crate::backend::lir::{self, asm_comment, Assembler, Opnd, Target, CFP, C_ARG_OPNDS, C_RET_OPND, EC, NATIVE_STACK_PTR, SP};
 use crate::hir::{iseq_to_hir, Block, BlockId, BranchEdge, CallInfo, RangeType, SELF_PARAM_IDX, SpecialObjectType};
 use crate::hir::{Const, FrameState, Function, Insn, InsnId};
 use crate::hir_type::{types::Fixnum, Type};
@@ -25,16 +25,20 @@ struct JITState {
     /// Branches to an ISEQ that need to be compiled later
     branch_iseqs: Vec<(Rc<Branch>, IseqPtr)>,
+
+    /// The number of bytes allocated for basic block arguments spilled onto the C stack
+    c_stack_bytes: usize,
 }
 
 impl JITState {
     /// Create a new JITState instance
-    fn new(iseq: IseqPtr, num_insns: usize, num_blocks: usize) -> Self {
+    fn new(iseq: IseqPtr, num_insns: usize, num_blocks: usize, c_stack_bytes: usize) -> Self {
         JITState {
             iseq,
             opnds: vec![None; num_insns],
             labels: vec![None; num_blocks],
             branch_iseqs: Vec::default(),
+            c_stack_bytes,
         }
     }
@@ -179,7 +183,8 @@ fn gen_iseq(cb: &mut CodeBlock, iseq: IseqPtr) -> Option<(CodePtr, Vec<(Rc<Branch>, IseqPtr)>)> {
 
 /// Compile a function
 fn gen_function(cb: &mut CodeBlock, iseq: IseqPtr, function: &Function) -> Option<(CodePtr, Vec<(Rc<Branch>, IseqPtr)>)> {
-    let mut jit = JITState::new(iseq, function.num_insns(), function.num_blocks());
+    let c_stack_bytes = aligned_stack_bytes(max_num_params(function).saturating_sub(ALLOC_REGS.len()));
+    let mut jit = JITState::new(iseq, function.num_insns(), function.num_blocks(), c_stack_bytes);
     let mut asm = Assembler::new();
 
     // Compile each basic block
```
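`gen_function()` sizes the spill area once per function: the largest parameter count across all basic blocks, minus the number of allocatable registers (saturating at zero), rounded up to keep the C stack 16-byte aligned. Below is a worked example of that arithmetic; `ALLOC_REGS_LEN = 5` is an assumption for illustration, since the real `ALLOC_REGS` length is platform-specific:

```rust
// Worked example of the c_stack_bytes computation in gen_function().
const ALLOC_REGS_LEN: usize = 5; // hypothetical; see backend/{x86_64,arm64}/mod.rs
const SIZEOF_VALUE: usize = 8;   // a Ruby VALUE is 8 bytes on 64-bit platforms

// Same rounding as the aligned_stack_bytes() helper added in this commit:
// an odd number of 8-byte slots is rounded up so the total stays a
// multiple of 16.
fn aligned_stack_bytes(num_slots: usize) -> usize {
    let num_slots = if num_slots % 2 == 0 { num_slots } else { num_slots + 1 };
    num_slots * SIZEOF_VALUE
}

fn main() {
    // A block with 11 parameters: 5 live in registers, 6 spill to the stack.
    let spill_slots = 11usize.saturating_sub(ALLOC_REGS_LEN);
    assert_eq!(spill_slots, 6);
    assert_eq!(aligned_stack_bytes(spill_slots), 48); // 6 slots, already even
    assert_eq!(aligned_stack_bytes(5), 48);           // 5 slots round up to 6

    // Fewer parameters than registers: saturating_sub clamps to zero,
    // so no stack bump is emitted at all.
    assert_eq!(3usize.saturating_sub(ALLOC_REGS_LEN), 0);
    println!("ok");
}
```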
```diff
@@ -195,6 +200,13 @@ fn gen_function(cb: &mut CodeBlock, iseq: IseqPtr, function: &Function) -> Option<(CodePtr, Vec<(Rc<Branch>, IseqPtr)>)> {
         // Set up the frame at the first block
         if block_id == BlockId(0) {
             asm.frame_setup();
+
+            // Bump the C stack pointer for basic block arguments
+            if jit.c_stack_bytes > 0 {
+                asm_comment!(asm, "bump C stack pointer");
+                let new_sp = asm.sub(NATIVE_STACK_PTR, jit.c_stack_bytes.into());
+                asm.mov(NATIVE_STACK_PTR, new_sp);
+            }
         }
 
         // Compile all parameters
@@ -252,7 +264,7 @@ fn gen_insn(cb: &mut CodeBlock, jit: &mut JITState, asm: &mut Assembler, function: &Function
         Insn::SendWithoutBlock { call_info, cd, state, self_val, args, .. } => gen_send_without_block(jit, asm, call_info, *cd, &function.frame_state(*state), self_val, args)?,
         Insn::SendWithoutBlockDirect { cme, iseq, self_val, args, state, .. } => gen_send_without_block_direct(cb, jit, asm, *cme, *iseq, opnd!(self_val), args, &function.frame_state(*state))?,
         Insn::InvokeBuiltin { bf, args, state } => gen_invokebuiltin(jit, asm, &function.frame_state(*state), bf, args)?,
-        Insn::Return { val } => return Some(gen_return(asm, opnd!(val))?),
+        Insn::Return { val } => return Some(gen_return(jit, asm, opnd!(val))?),
         Insn::FixnumAdd { left, right, state } => gen_fixnum_add(jit, asm, opnd!(left), opnd!(right), &function.frame_state(*state))?,
         Insn::FixnumSub { left, right, state } => gen_fixnum_sub(jit, asm, opnd!(left), opnd!(right), &function.frame_state(*state))?,
         Insn::FixnumMult { left, right, state } => gen_fixnum_mult(jit, asm, opnd!(left), opnd!(right), &function.frame_state(*state))?,
@@ -518,7 +530,16 @@ fn gen_branch_params(jit: &mut JITState, asm: &mut Assembler, branch: &BranchEdge
     asm_comment!(asm, "set branch params: {}", branch.args.len());
     let mut moves: Vec<(Reg, Opnd)> = vec![];
     for (idx, &arg) in branch.args.iter().enumerate() {
-        moves.push((param_reg(idx), jit.get_opnd(arg)?));
+        match param_opnd(idx) {
+            Opnd::Reg(reg) => {
+                // If a parameter is a register, we need to parallel-move it
+                moves.push((reg, jit.get_opnd(arg)?));
+            },
+            param => {
+                // If a parameter is memory, we set it beforehand
+                asm.mov(param, jit.get_opnd(arg)?);
+            }
+        }
     }
     asm.parallel_mov(moves);
 }
@@ -555,7 +576,13 @@ fn gen_const(val: VALUE) -> lir::Opnd {
 
 /// Compile a basic block argument
 fn gen_param(asm: &mut Assembler, idx: usize) -> lir::Opnd {
-    asm.live_reg_opnd(Opnd::Reg(param_reg(idx)))
+    // Allocate a register or a stack slot
+    match param_opnd(idx) {
+        // If it's a register, insert LiveReg instruction to reserve the register
+        // in the register pool for register allocation.
+        param @ Opnd::Reg(_) => asm.live_reg_opnd(param),
+        param => param,
+    }
 }
 
 /// Compile a jump to a basic block
```
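In `gen_branch_params()`, only register destinations are deferred into the `parallel_mov`; stack-slot destinations are written out immediately with a plain `mov`. A parallel move exists so that register-to-register shuffles (say, `r0 <- r1` together with `r1 <- r0`) don't clobber a source before it is read, which is presumably why memory destinations can bypass it here. A rough stand-in sketch of that control flow, with simplified operand types and printing in place of real instruction emission:

```rust
// Stand-in sketch of the gen_branch_params() split (not ZJIT's real types).
#[derive(Clone, Copy, Debug)]
enum Opnd {
    Reg(u8),  // one of the fixed ALLOC_REGS
    Mem(i32), // byte offset from the native stack pointer
}

fn set_branch_params(args: &[Opnd], param_opnd: impl Fn(usize) -> Opnd) {
    let mut moves: Vec<(u8, Opnd)> = vec![];
    for (idx, &arg) in args.iter().enumerate() {
        match param_opnd(idx) {
            // Register params are batched and moved in parallel at the end
            Opnd::Reg(reg) => moves.push((reg, arg)),
            // Memory params are set right away
            param => println!("mov {param:?}, {arg:?}"),
        }
    }
    println!("parallel_mov {moves:?}");
}

fn main() {
    // Assume 2 allocatable registers for illustration; arguments 2+ spill.
    let param_opnd = |idx: usize| {
        if idx < 2 {
            Opnd::Reg(idx as u8)
        } else {
            Opnd::Mem(-8 * (idx as i32 - 2 + 1))
        }
    };
    // Args 0 and 1 form a register swap; arg 2 goes straight to memory.
    set_branch_params(&[Opnd::Reg(1), Opnd::Reg(0), Opnd::Mem(-16)], param_opnd);
}
```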
```diff
@@ -797,7 +824,7 @@ fn gen_new_range(
 }
 
 /// Compile code that exits from JIT code with a return value
-fn gen_return(asm: &mut Assembler, val: lir::Opnd) -> Option<()> {
+fn gen_return(jit: &JITState, asm: &mut Assembler, val: lir::Opnd) -> Option<()> {
     // Pop the current frame (ec->cfp++)
     // Note: the return PC is already in the previous CFP
     asm_comment!(asm, "pop stack frame");
@@ -805,6 +832,13 @@ fn gen_return(asm: &mut Assembler, val: lir::Opnd) -> Option<()> {
     asm.mov(CFP, incr_cfp);
     asm.mov(Opnd::mem(64, EC, RUBY_OFFSET_EC_CFP), CFP);
 
+    // Restore the C stack pointer bumped for basic block arguments
+    if jit.c_stack_bytes > 0 {
+        asm_comment!(asm, "restore C stack pointer");
+        let new_sp = asm.add(NATIVE_STACK_PTR, jit.c_stack_bytes.into());
+        asm.mov(NATIVE_STACK_PTR, new_sp);
+    }
+
     asm.frame_teardown();
 
     // Return from the function
@@ -992,17 +1026,15 @@ fn gen_push_frame(asm: &mut Assembler, argc: usize, state: &FrameState,
     asm.mov(cfp_opnd(RUBY_OFFSET_CFP_BLOCK_CODE), 0.into());
 }
 
-/// Return a register we use for the basic block argument at a given index
-fn param_reg(idx: usize) -> Reg {
-    // To simplify the implementation, allocate a fixed register for each basic block argument for now.
+/// Return an operand we use for the basic block argument at a given index
+fn param_opnd(idx: usize) -> Opnd {
+    // To simplify the implementation, allocate a fixed register or a stack slot for each basic block argument for now.
     // TODO: Allow allocating arbitrary registers for basic block arguments
-    if idx >= ALLOC_REGS.len() {
-        unimplemented!(
-            "register spilling not yet implemented, too many basic block arguments ({}/{})",
-            idx + 1, ALLOC_REGS.len()
-        );
+    if idx < ALLOC_REGS.len() {
+        Opnd::Reg(ALLOC_REGS[idx])
+    } else {
+        Opnd::mem(64, NATIVE_STACK_PTR, -((idx - ALLOC_REGS.len() + 1) as i32) * SIZEOF_VALUE_I32)
     }
-    ALLOC_REGS[idx]
 }
 
 /// Inverse of ep_offset_to_local_idx(). See ep_offset_to_local_idx() for details.
@@ -1045,6 +1077,7 @@ fn side_exit(jit: &mut JITState, state: &FrameState) -> Option<Target> {
         pc: state.pc,
         stack,
         locals,
+        c_stack_bytes: jit.c_stack_bytes,
     };
     Some(target)
 }
@@ -1063,6 +1096,28 @@ fn iseq_entry_escapes_ep(iseq: IseqPtr) -> bool {
     }
 }
 
+/// Return the maximum number of arguments for a block in a given function
+fn max_num_params(function: &Function) -> usize {
+    let reverse_post_order = function.rpo();
+    reverse_post_order.iter().map(|&block_id| {
+        let block = function.block(block_id);
+        block.params().len()
+    }).max().unwrap_or(0)
+}
+
+/// Given the number of spill slots needed for a function, return the number of bytes
+/// the function needs to allocate on the stack for the stack frame.
+fn aligned_stack_bytes(num_slots: usize) -> usize {
+    // Both x86_64 and arm64 require the stack to be aligned to 16 bytes.
+    // Since SIZEOF_VALUE is 8 bytes, we need to round up the size to the nearest even number.
+    let num_slots = if num_slots % 2 == 0 {
+        num_slots
+    } else {
+        num_slots + 1
+    };
+    num_slots * SIZEOF_VALUE
+}
+
 impl Assembler {
     /// Make a C call while marking the start and end positions of it
     fn ccall_with_branch(&mut self, fptr: *const u8, opnds: Vec<Opnd>, branch: &Rc<Branch>) -> Opnd {
```
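The offset arithmetic in `param_opnd()` places spilled argument `idx` at `-(idx - ALLOC_REGS.len() + 1) * SIZEOF_VALUE` bytes off the native stack pointer: the first spilled argument at `[NATIVE_STACK_PTR - 8]`, the next at `[NATIVE_STACK_PTR - 16]`, and so on. A self-contained check of that mapping, again assuming a hypothetical register count of 5:

```rust
// Mirror of the offset math in param_opnd(), with stand-in constants.
const ALLOC_REGS_LEN: usize = 5; // assumption; the real ALLOC_REGS is per-platform
const SIZEOF_VALUE_I32: i32 = 8;

/// Byte offset from the native stack pointer for spilled block argument `idx`,
/// or None if the argument still fits in a register.
fn spill_offset(idx: usize) -> Option<i32> {
    if idx < ALLOC_REGS_LEN {
        None
    } else {
        Some(-((idx - ALLOC_REGS_LEN + 1) as i32) * SIZEOF_VALUE_I32)
    }
}

fn main() {
    assert_eq!(spill_offset(4), None);      // last register-resident argument
    assert_eq!(spill_offset(5), Some(-8));  // first spill slot
    assert_eq!(spill_offset(9), Some(-40)); // fifth spill slot
    println!("ok");
}
```

`aligned_stack_bytes()` rounds this spill area up to a multiple of 16 bytes, which is exactly the amount subtracted from `NATIVE_STACK_PTR` at frame setup and added back in `gen_return()` and in the side-exit path.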
