diff options
| -rw-r--r-- | zjit/src/backend/lir.rs | 68 | ||||
| -rw-r--r-- | zjit/src/codegen.rs | 15 | ||||
| -rw-r--r-- | zjit/src/hir.rs | 1 |
3 files changed, 51 insertions, 33 deletions
diff --git a/zjit/src/backend/lir.rs b/zjit/src/backend/lir.rs index e5453e4d55..94c53569b4 100644 --- a/zjit/src/backend/lir.rs +++ b/zjit/src/backend/lir.rs @@ -270,6 +270,14 @@ impl From<VALUE> for Opnd { } } +/// Set of things we need to restore for side exits. +#[derive(Clone, Debug)] +pub struct SideExitContext { + pub pc: *const VALUE, + pub stack: Vec<Opnd>, + pub locals: Vec<Opnd>, +} + /// Branch target (something that we can jump to) /// for branch instructions #[derive(Clone, Debug)] @@ -281,12 +289,14 @@ pub enum Target Label(Label), /// Side exit to the interpreter SideExit { - pc: *const VALUE, - stack: Vec<Opnd>, - locals: Vec<Opnd>, - c_stack_bytes: usize, + /// Context to restore on regular side exits. None for side exits right + /// after JIT-to-JIT calls because we restore them before the JIT call. + context: Option<SideExitContext>, + /// We use this to enrich asm comments. reason: SideExitReason, - // Some if the side exit should write this label. We use it for patch points. + /// The number of bytes we need to adjust the C stack pointer by. + c_stack_bytes: usize, + /// Some if the side exit should write this label. We use it for patch points. label: Option<Label>, }, } @@ -767,7 +777,7 @@ impl<'a> Iterator for InsnOpndIterator<'a> { Insn::Label(target) | Insn::LeaJumpTarget { target, .. } | Insn::PatchPoint(target) => { - if let Target::SideExit { stack, locals, .. } = target { + if let Target::SideExit { context: Some(SideExitContext { stack, locals, .. }), .. } = target { let stack_idx = self.idx; if stack_idx < stack.len() { let opnd = &stack[stack_idx]; @@ -792,7 +802,7 @@ impl<'a> Iterator for InsnOpndIterator<'a> { return Some(opnd); } - if let Target::SideExit { stack, locals, .. } = target { + if let Target::SideExit { context: Some(SideExitContext { stack, locals, .. }), .. } = target { let stack_idx = self.idx - 1; if stack_idx < stack.len() { let opnd = &stack[stack_idx]; @@ -923,7 +933,7 @@ impl<'a> InsnOpndMutIterator<'a> { Insn::Label(target) | Insn::LeaJumpTarget { target, .. } | Insn::PatchPoint(target) => { - if let Target::SideExit { stack, locals, .. } = target { + if let Target::SideExit { context: Some(SideExitContext { stack, locals, .. }), .. } = target { let stack_idx = self.idx; if stack_idx < stack.len() { let opnd = &mut stack[stack_idx]; @@ -948,7 +958,7 @@ impl<'a> InsnOpndMutIterator<'a> { return Some(opnd); } - if let Target::SideExit { stack, locals, .. } = target { + if let Target::SideExit { context: Some(SideExitContext { stack, locals, .. }), .. } = target { let stack_idx = self.idx - 1; if stack_idx < stack.len() { let opnd = &mut stack[stack_idx]; @@ -1803,7 +1813,7 @@ impl Assembler for (idx, target) in targets { // Compile a side exit. Note that this is past the split pass and alloc_regs(), // so you can't use a VReg or an instruction that needs to be split. - if let Target::SideExit { pc, stack, locals, c_stack_bytes, reason, label } = target { + if let Target::SideExit { context, reason, c_stack_bytes, label } = target { asm_comment!(self, "Exit: {reason}"); let side_exit_label = if let Some(label) = label { Target::Label(label) @@ -1823,26 +1833,30 @@ impl Assembler } } - asm_comment!(self, "write stack slots: {stack:?}"); - for (idx, &opnd) in stack.iter().enumerate() { - let opnd = split_store_source(self, opnd); - self.store(Opnd::mem(64, SP, idx as i32 * SIZEOF_VALUE_I32), opnd); - } + // Restore the PC and the stack for regular side exits. We don't do this for + // side exits right after JIT-to-JIT calls, which restore them before the call. + if let Some(SideExitContext { pc, stack, locals }) = context { + asm_comment!(self, "write stack slots: {stack:?}"); + for (idx, &opnd) in stack.iter().enumerate() { + let opnd = split_store_source(self, opnd); + self.store(Opnd::mem(64, SP, idx as i32 * SIZEOF_VALUE_I32), opnd); + } - asm_comment!(self, "write locals: {locals:?}"); - for (idx, &opnd) in locals.iter().enumerate() { - let opnd = split_store_source(self, opnd); - self.store(Opnd::mem(64, SP, (-local_size_and_idx_to_ep_offset(locals.len(), idx) - 1) * SIZEOF_VALUE_I32), opnd); - } + asm_comment!(self, "write locals: {locals:?}"); + for (idx, &opnd) in locals.iter().enumerate() { + let opnd = split_store_source(self, opnd); + self.store(Opnd::mem(64, SP, (-local_size_and_idx_to_ep_offset(locals.len(), idx) - 1) * SIZEOF_VALUE_I32), opnd); + } - asm_comment!(self, "save cfp->pc"); - self.load_into(Opnd::Reg(Assembler::SCRATCH_REG), Opnd::const_ptr(pc)); - self.store(Opnd::mem(64, CFP, RUBY_OFFSET_CFP_PC), Opnd::Reg(Assembler::SCRATCH_REG)); + asm_comment!(self, "save cfp->pc"); + self.load_into(Opnd::Reg(Assembler::SCRATCH_REG), Opnd::const_ptr(pc)); + self.store(Opnd::mem(64, CFP, RUBY_OFFSET_CFP_PC), Opnd::Reg(Assembler::SCRATCH_REG)); - asm_comment!(self, "save cfp->sp"); - self.lea_into(Opnd::Reg(Assembler::SCRATCH_REG), Opnd::mem(64, SP, stack.len() as i32 * SIZEOF_VALUE_I32)); - let cfp_sp = Opnd::mem(64, CFP, RUBY_OFFSET_CFP_SP); - self.store(cfp_sp, Opnd::Reg(Assembler::SCRATCH_REG)); + asm_comment!(self, "save cfp->sp"); + self.lea_into(Opnd::Reg(Assembler::SCRATCH_REG), Opnd::mem(64, SP, stack.len() as i32 * SIZEOF_VALUE_I32)); + let cfp_sp = Opnd::mem(64, CFP, RUBY_OFFSET_CFP_SP); + self.store(cfp_sp, Opnd::Reg(Assembler::SCRATCH_REG)); + } if c_stack_bytes > 0 { asm_comment!(self, "restore C stack pointer"); diff --git a/zjit/src/codegen.rs b/zjit/src/codegen.rs index 917528ca0a..72652363b1 100644 --- a/zjit/src/codegen.rs +++ b/zjit/src/codegen.rs @@ -7,7 +7,7 @@ use crate::invariants::track_bop_assumption; use crate::gc::get_or_create_iseq_payload; use crate::state::ZJITState; use crate::{asm::CodeBlock, cruby::*, options::debug, virtualmem::CodePtr}; -use crate::backend::lir::{self, asm_comment, asm_ccall, Assembler, Opnd, Target, CFP, C_ARG_OPNDS, C_RET_OPND, EC, NATIVE_STACK_PTR, SP}; +use crate::backend::lir::{self, asm_comment, asm_ccall, Assembler, Opnd, SideExitContext, Target, CFP, C_ARG_OPNDS, C_RET_OPND, EC, NATIVE_STACK_PTR, SP}; use crate::hir::{iseq_to_hir, Block, BlockId, BranchEdge, CallInfo, Invariant, RangeType, SideExitReason, SideExitReason::*, SpecialObjectType, SELF_PARAM_IDX}; use crate::hir::{Const, FrameState, Function, Insn, InsnId}; use crate::hir_type::{types::Fixnum, Type}; @@ -774,7 +774,8 @@ fn gen_send_without_block_direct( // TODO: Let side exit code pop all JIT frames to optimize away this cmp + je. asm_comment!(asm, "side-exit if callee side-exits"); asm.cmp(ret, Qundef.into()); - asm.je(ZJITState::get_exit_trampoline().into()); + // Restore the C stack pointer on exit + asm.je(Target::SideExit { context: None, reason: CalleeSideExit, c_stack_bytes: jit.c_stack_bytes, label: None }); asm_comment!(asm, "restore SP register for the caller"); let new_sp = asm.sub(SP, sp_offset.into()); @@ -1112,11 +1113,13 @@ fn build_side_exit(jit: &mut JITState, state: &FrameState, reason: SideExitReaso } let target = Target::SideExit { - pc: state.pc, - stack, - locals, - c_stack_bytes: jit.c_stack_bytes, + context: Some(SideExitContext { + pc: state.pc, + stack, + locals, + }), reason, + c_stack_bytes: jit.c_stack_bytes, label, }; Some(target) diff --git a/zjit/src/hir.rs b/zjit/src/hir.rs index 598774de8c..1c6b55a9ff 100644 --- a/zjit/src/hir.rs +++ b/zjit/src/hir.rs @@ -414,6 +414,7 @@ pub enum SideExitReason { GuardType(Type), GuardBitEquals(VALUE), PatchPoint(Invariant), + CalleeSideExit, } impl std::fmt::Display for SideExitReason { |
