diff options
| author | Takashi Kokubun <takashi.kokubun@shopify.com> | 2025-11-10 07:30:17 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-11-10 07:30:17 -0800 |
| commit | b539cd2a33dae904e55fd8f054349b9ff076793a (patch) | |
| tree | 6ece03af412224d4f78126d964645b4fada8218e | |
| parent | 557eec792e520a25c3ca59aa0a824af5d009c7d8 (diff) | |
ZJIT: Deduplicate side exits (#15105)
| -rw-r--r-- | zjit/src/asm/mod.rs | 2 | ||||
| -rw-r--r-- | zjit/src/backend/arm64/mod.rs | 16 | ||||
| -rw-r--r-- | zjit/src/backend/lir.rs | 188 | ||||
| -rw-r--r-- | zjit/src/backend/x86_64/mod.rs | 8 | ||||
| -rw-r--r-- | zjit/src/codegen.rs | 37 |
5 files changed, 154 insertions, 97 deletions
diff --git a/zjit/src/asm/mod.rs b/zjit/src/asm/mod.rs index 86176c0ec9..9b792f5f37 100644 --- a/zjit/src/asm/mod.rs +++ b/zjit/src/asm/mod.rs @@ -16,7 +16,7 @@ pub mod x86_64; pub mod arm64; /// Index to a label created by cb.new_label() -#[derive(Copy, Clone, Debug, Eq, PartialEq)] +#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq)] pub struct Label(pub usize); /// The object that knows how to encode the branch instruction. diff --git a/zjit/src/backend/arm64/mod.rs b/zjit/src/backend/arm64/mod.rs index 428d4bff77..79c07d42ad 100644 --- a/zjit/src/backend/arm64/mod.rs +++ b/zjit/src/backend/arm64/mod.rs @@ -2,6 +2,7 @@ use std::mem::take; use crate::asm::{CodeBlock, Label}; use crate::asm::arm64::*; +use crate::codegen::split_patch_point; use crate::cruby::*; use crate::backend::lir::*; use crate::options::asm_dump; @@ -826,6 +827,14 @@ impl Assembler { *opnds = vec![]; asm.push_insn(insn); } + // For compile_exits, support splitting simple return values here + Insn::CRet(opnd) => { + match opnd { + Opnd::Reg(C_RET_REG) => {}, + _ => asm.load_into(C_RET_OPND, *opnd), + } + asm.cret(C_RET_OPND); + } Insn::Lea { opnd, out } => { *opnd = split_only_stack_membase(asm, *opnd, SCRATCH0_OPND, &stack_state); let mem_out = split_memory_write(out, SCRATCH0_OPND); @@ -894,6 +903,9 @@ impl Assembler { } } } + &mut Insn::PatchPoint { ref target, invariant, payload } => { + split_patch_point(asm, target, invariant, payload); + } _ => { asm.push_insn(insn); } @@ -1514,7 +1526,7 @@ impl Assembler { Insn::Jonz(opnd, target) => { emit_cmp_zero_jump(cb, opnd.into(), false, target.clone()); }, - Insn::PatchPoint(_) | + Insn::PatchPoint { .. } => unreachable!("PatchPoint should have been lowered to PadPatchPoint in arm64_scratch_split"), Insn::PadPatchPoint => { // If patch points are too close to each other or the end of the block, fill nop instructions if let Some(last_patch_pos) = last_patch_pos { @@ -1694,7 +1706,7 @@ mod tests { let val64 = asm.add(CFP, Opnd::UImm(64)); asm.store(Opnd::mem(64, SP, 0x10), val64); - let side_exit = Target::SideExit { reason: SideExitReason::Interrupt, pc: 0 as _, stack: vec![], locals: vec![], label: None }; + let side_exit = Target::SideExit { reason: SideExitReason::Interrupt, exit: SideExit { pc: Opnd::const_ptr(0 as *const u8), stack: vec![], locals: vec![] } }; asm.push_insn(Insn::Joz(val64, side_exit)); asm.parallel_mov(vec![(C_ARG_OPNDS[0], C_RET_OPND.with_num_bits(32)), (C_ARG_OPNDS[1], Opnd::mem(64, SP, -8))]); diff --git a/zjit/src/backend/lir.rs b/zjit/src/backend/lir.rs index 69b030608b..cb8382a43c 100644 --- a/zjit/src/backend/lir.rs +++ b/zjit/src/backend/lir.rs @@ -6,9 +6,10 @@ use std::rc::Rc; use std::sync::{Arc, Mutex}; use crate::codegen::local_size_and_idx_to_ep_offset; use crate::cruby::{Qundef, RUBY_OFFSET_CFP_PC, RUBY_OFFSET_CFP_SP, SIZEOF_VALUE_I32, vm_stack_canary}; -use crate::hir::SideExitReason; +use crate::hir::{Invariant, SideExitReason}; use crate::options::{TraceExits, debug, get_option}; use crate::cruby::VALUE; +use crate::payload::IseqPayload; use crate::stats::{exit_counter_ptr, exit_counter_ptr_for_opcode, side_exit_counter, CompileError}; use crate::virtualmem::CodePtr; use crate::asm::{CodeBlock, Label}; @@ -25,7 +26,7 @@ pub use crate::backend::current::{ pub static JIT_PRESERVED_REGS: &[Opnd] = &[CFP, SP, EC]; // Memory operand base -#[derive(Clone, Copy, PartialEq, Eq, Debug)] +#[derive(Clone, Copy, PartialEq, Eq, Debug, Hash)] pub enum MemBase { /// Register: Every Opnd::Mem should have MemBase::Reg as of emit. @@ -37,7 +38,7 @@ pub enum MemBase } // Memory location -#[derive(Copy, Clone, PartialEq, Eq)] +#[derive(Copy, Clone, PartialEq, Eq, Hash)] pub struct Mem { // Base register number or instruction index @@ -87,7 +88,7 @@ impl fmt::Debug for Mem { } /// Operand to an IR instruction -#[derive(Clone, Copy, PartialEq, Eq)] +#[derive(Clone, Copy, PartialEq, Eq, Hash)] pub enum Opnd { None, // For insns with no output @@ -298,6 +299,14 @@ impl From<VALUE> for Opnd { } } +/// Context for a side exit. If `SideExit` matches, it reuses the same code. +#[derive(Clone, Debug, Eq, Hash, PartialEq)] +pub struct SideExit { + pub pc: Opnd, + pub stack: Vec<Opnd>, + pub locals: Vec<Opnd>, +} + /// Branch target (something that we can jump to) /// for branch instructions #[derive(Clone, Debug)] @@ -309,13 +318,10 @@ pub enum Target Label(Label), /// Side exit to the interpreter SideExit { - pc: *const VALUE, - stack: Vec<Opnd>, - locals: Vec<Opnd>, - /// We use this to enrich asm comments. + /// Context used for compiling the side exit + exit: SideExit, + /// We use this to increment exit counters reason: SideExitReason, - /// Some if the side exit should write this label. We use it for patch points. - label: Option<Label>, }, } @@ -525,7 +531,7 @@ pub enum Insn { Or { left: Opnd, right: Opnd, out: Opnd }, /// Patch point that will be rewritten to a jump to a side exit on invalidation. - PatchPoint(Target), + PatchPoint { target: Target, invariant: Invariant, payload: *mut IseqPayload }, /// Make sure the last PatchPoint has enough space to insert a jump. /// We insert this instruction at the end of each block so that the jump @@ -590,7 +596,7 @@ impl Insn { Insn::Jonz(_, target) | Insn::Label(target) | Insn::LeaJumpTarget { target, .. } | - Insn::PatchPoint(target) => { + Insn::PatchPoint { target, .. } => { Some(target) } _ => None, @@ -652,7 +658,7 @@ impl Insn { Insn::Mov { .. } => "Mov", Insn::Not { .. } => "Not", Insn::Or { .. } => "Or", - Insn::PatchPoint(_) => "PatchPoint", + Insn::PatchPoint { .. } => "PatchPoint", Insn::PadPatchPoint => "PadPatchPoint", Insn::PosMarker(_) => "PosMarker", Insn::RShift { .. } => "RShift", @@ -750,7 +756,7 @@ impl Insn { Insn::Jonz(_, target) | Insn::Label(target) | Insn::LeaJumpTarget { target, .. } | - Insn::PatchPoint(target) => Some(target), + Insn::PatchPoint { target, .. } => Some(target), _ => None } } @@ -797,8 +803,8 @@ impl<'a> Iterator for InsnOpndIterator<'a> { Insn::Jz(target) | Insn::Label(target) | Insn::LeaJumpTarget { target, .. } | - Insn::PatchPoint(target) => { - if let Target::SideExit { stack, locals, .. } = target { + Insn::PatchPoint { target, .. } => { + if let Target::SideExit { exit: SideExit { stack, locals, .. }, .. } = target { let stack_idx = self.idx; if stack_idx < stack.len() { let opnd = &stack[stack_idx]; @@ -823,7 +829,7 @@ impl<'a> Iterator for InsnOpndIterator<'a> { return Some(opnd); } - if let Target::SideExit { stack, locals, .. } = target { + if let Target::SideExit { exit: SideExit { stack, locals, .. }, .. } = target { let stack_idx = self.idx - 1; if stack_idx < stack.len() { let opnd = &stack[stack_idx]; @@ -966,8 +972,8 @@ impl<'a> InsnOpndMutIterator<'a> { Insn::Jz(target) | Insn::Label(target) | Insn::LeaJumpTarget { target, .. } | - Insn::PatchPoint(target) => { - if let Target::SideExit { stack, locals, .. } = target { + Insn::PatchPoint { target, .. } => { + if let Target::SideExit { exit: SideExit { stack, locals, .. }, .. } = target { let stack_idx = self.idx; if stack_idx < stack.len() { let opnd = &mut stack[stack_idx]; @@ -992,7 +998,7 @@ impl<'a> InsnOpndMutIterator<'a> { return Some(opnd); } - if let Target::SideExit { stack, locals, .. } = target { + if let Target::SideExit { exit: SideExit { stack, locals, .. }, .. } = target { let stack_idx = self.idx - 1; if stack_idx < stack.len() { let opnd = &mut stack[stack_idx]; @@ -1779,10 +1785,41 @@ impl Assembler /// Compile Target::SideExit and convert it into Target::CodePtr for all instructions pub fn compile_exits(&mut self) { + /// Compile the main side-exit code. This function takes only SideExit so + /// that it can be safely deduplicated by using SideExit as a dedup key. + fn compile_exit(asm: &mut Assembler, exit: &SideExit) { + let SideExit { pc, stack, locals } = exit; + + asm_comment!(asm, "save cfp->pc"); + asm.store(Opnd::mem(64, CFP, RUBY_OFFSET_CFP_PC), *pc); + + asm_comment!(asm, "save cfp->sp"); + asm.lea_into(Opnd::mem(64, CFP, RUBY_OFFSET_CFP_SP), Opnd::mem(64, SP, stack.len() as i32 * SIZEOF_VALUE_I32)); + + if !stack.is_empty() { + asm_comment!(asm, "write stack slots: {}", join_opnds(&stack, ", ")); + for (idx, &opnd) in stack.iter().enumerate() { + asm.store(Opnd::mem(64, SP, idx as i32 * SIZEOF_VALUE_I32), opnd); + } + } + + if !locals.is_empty() { + asm_comment!(asm, "write locals: {}", join_opnds(&locals, ", ")); + for (idx, &opnd) in locals.iter().enumerate() { + asm.store(Opnd::mem(64, SP, (-local_size_and_idx_to_ep_offset(locals.len(), idx) - 1) * SIZEOF_VALUE_I32), opnd); + } + } + + asm_comment!(asm, "exit to the interpreter"); + asm.frame_teardown(&[]); // matching the setup in gen_entry_point() + asm.cret(Opnd::UImm(Qundef.as_u64())); + } + fn join_opnds(opnds: &Vec<Opnd>, delimiter: &str) -> String { opnds.iter().map(|opnd| format!("{opnd}")).collect::<Vec<_>>().join(delimiter) } + // Extract targets first so that we can update instructions while referencing part of them. let mut targets = HashMap::new(); for (idx, insn) in self.insns.iter().enumerate() { if let Some(target @ Target::SideExit { .. }) = insn.target() { @@ -1790,71 +1827,66 @@ impl Assembler } } + // Map from SideExit to compiled Label. This table is used to deduplicate side exit code. + let mut compiled_exits: HashMap<SideExit, Label> = HashMap::new(); + for (idx, target) in targets { // Compile a side exit. Note that this is past the split pass and alloc_regs(), // so you can't use an instruction that returns a VReg. - if let Target::SideExit { pc, stack, locals, reason, label } = target { - asm_comment!(self, "Exit: {reason}"); - let side_exit_label = if let Some(label) = label { - Target::Label(label) - } else { - self.new_label("side_exit") - }; - self.write_label(side_exit_label.clone()); - - // Restore the PC and the stack for regular side exits. We don't do this for - // side exits right after JIT-to-JIT calls, which restore them before the call. - asm_comment!(self, "write stack slots: {}", join_opnds(&stack, ", ")); - for (idx, &opnd) in stack.iter().enumerate() { - self.store(Opnd::mem(64, SP, idx as i32 * SIZEOF_VALUE_I32), opnd); - } - - asm_comment!(self, "write locals: {}", join_opnds(&locals, ", ")); - for (idx, &opnd) in locals.iter().enumerate() { - self.store(Opnd::mem(64, SP, (-local_size_and_idx_to_ep_offset(locals.len(), idx) - 1) * SIZEOF_VALUE_I32), opnd); - } - - asm_comment!(self, "save cfp->pc"); - self.store(Opnd::mem(64, CFP, RUBY_OFFSET_CFP_PC), Opnd::const_ptr(pc)); - - asm_comment!(self, "save cfp->sp"); - self.lea_into(Opnd::mem(64, CFP, RUBY_OFFSET_CFP_SP), Opnd::mem(64, SP, stack.len() as i32 * SIZEOF_VALUE_I32)); - - // Using C_RET_OPND as an additional scratch register, which is no longer used - if get_option!(stats) { - asm_comment!(self, "increment a side exit counter"); - self.incr_counter(Opnd::const_ptr(exit_counter_ptr(reason)), 1.into()); - - if let SideExitReason::UnhandledYARVInsn(opcode) = reason { - asm_comment!(self, "increment an unhandled YARV insn counter"); - self.incr_counter(Opnd::const_ptr(exit_counter_ptr_for_opcode(opcode)), 1.into()); + if let Target::SideExit { exit: exit @ SideExit { pc, .. }, reason } = target { + // Only record the exit if `trace_side_exits` is defined and the counter is either the one specified + let should_record_exit = get_option!(trace_side_exits).map(|trace| match trace { + TraceExits::All => true, + TraceExits::Counter(counter) if counter == side_exit_counter(reason) => true, + _ => false, + }).unwrap_or(false); + + // If enabled, instrument exits first, and then jump to a shared exit. + let counted_exit = if get_option!(stats) || should_record_exit { + let counted_exit = self.new_label("counted_exit"); + self.write_label(counted_exit.clone()); + asm_comment!(self, "Counted Exit: {reason}"); + + if get_option!(stats) { + asm_comment!(self, "increment a side exit counter"); + self.incr_counter(Opnd::const_ptr(exit_counter_ptr(reason)), 1.into()); + + if let SideExitReason::UnhandledYARVInsn(opcode) = reason { + asm_comment!(self, "increment an unhandled YARV insn counter"); + self.incr_counter(Opnd::const_ptr(exit_counter_ptr_for_opcode(opcode)), 1.into()); + } } - } - - if get_option!(trace_side_exits).is_some() { - // Get the corresponding `Counter` for the current `SideExitReason`. - let side_exit_counter = side_exit_counter(reason); - - // Only record the exit if `trace_side_exits` is defined and the counter is either the one specified - let should_record_exit = get_option!(trace_side_exits) - .map(|trace| match trace { - TraceExits::All => true, - TraceExits::Counter(counter) if counter == side_exit_counter => true, - _ => false, - }) - .unwrap_or(false); if should_record_exit { - asm_ccall!(self, rb_zjit_record_exit_stack, Opnd::const_ptr(pc as *const u8)); + // Preserve caller-saved registers that may be used in the shared exit. + self.cpush_all(); + asm_ccall!(self, rb_zjit_record_exit_stack, pc); + self.cpop_all(); } - } - asm_comment!(self, "exit to the interpreter"); - self.frame_teardown(&[]); // matching the setup in :bb0-prologue: - self.mov(C_RET_OPND, Opnd::UImm(Qundef.as_u64())); - self.cret(C_RET_OPND); + // If the side exit has already been compiled, jump to it. + // Otherwise, let it fall through and compile the exit next. + if let Some(&exit_label) = compiled_exits.get(&exit) { + self.jmp(Target::Label(exit_label)); + } + Some(counted_exit) + } else { + None + }; + + // Compile the shared side exit if not compiled yet + let compiled_exit = if let Some(&compiled_exit) = compiled_exits.get(&exit) { + Target::Label(compiled_exit) + } else { + let new_exit = self.new_label("side_exit"); + self.write_label(new_exit.clone()); + asm_comment!(self, "Exit: {pc}"); + compile_exit(self, &exit); + compiled_exits.insert(exit, new_exit.unwrap_label()); + new_exit + }; - *self.insns[idx].target_mut().unwrap() = side_exit_label; + *self.insns[idx].target_mut().unwrap() = counted_exit.unwrap_or(compiled_exit); } } } @@ -2268,8 +2300,8 @@ impl Assembler { out } - pub fn patch_point(&mut self, target: Target) { - self.push_insn(Insn::PatchPoint(target)); + pub fn patch_point(&mut self, target: Target, invariant: Invariant, payload: *mut IseqPayload) { + self.push_insn(Insn::PatchPoint { target, invariant, payload }); } pub fn pad_patch_point(&mut self) { diff --git a/zjit/src/backend/x86_64/mod.rs b/zjit/src/backend/x86_64/mod.rs index aea25ca2a4..11876eb894 100644 --- a/zjit/src/backend/x86_64/mod.rs +++ b/zjit/src/backend/x86_64/mod.rs @@ -2,6 +2,7 @@ use std::mem::{self, take}; use crate::asm::*; use crate::asm::x86_64::*; +use crate::codegen::split_patch_point; use crate::stats::CompileError; use crate::virtualmem::CodePtr; use crate::cruby::*; @@ -628,6 +629,9 @@ impl Assembler { }; asm.store(dest, src); } + &mut Insn::PatchPoint { ref target, invariant, payload } => { + split_patch_point(asm, target, invariant, payload); + } _ => { asm.push_insn(insn); } @@ -989,7 +993,7 @@ impl Assembler { Insn::Joz(..) | Insn::Jonz(..) => unreachable!("Joz/Jonz should be unused for now"), - Insn::PatchPoint(_) | + Insn::PatchPoint { .. } => unreachable!("PatchPoint should have been lowered to PadPatchPoint in x86_scratch_split"), Insn::PadPatchPoint => { // If patch points are too close to each other or the end of the block, fill nop instructions if let Some(last_patch_pos) = last_patch_pos { @@ -1127,7 +1131,7 @@ mod tests { let val64 = asm.add(CFP, Opnd::UImm(64)); asm.store(Opnd::mem(64, SP, 0x10), val64); - let side_exit = Target::SideExit { reason: SideExitReason::Interrupt, pc: 0 as _, stack: vec![], locals: vec![], label: None }; + let side_exit = Target::SideExit { reason: SideExitReason::Interrupt, exit: SideExit { pc: Opnd::const_ptr(0 as *const u8), stack: vec![], locals: vec![] } }; asm.push_insn(Insn::Joz(val64, side_exit)); asm.parallel_mov(vec![(C_ARG_OPNDS[0], C_RET_OPND.with_num_bits(32)), (C_ARG_OPNDS[1], Opnd::mem(64, SP, -8))]); diff --git a/zjit/src/codegen.rs b/zjit/src/codegen.rs index 68e8ad8966..58c396ed9b 100644 --- a/zjit/src/codegen.rs +++ b/zjit/src/codegen.rs @@ -7,7 +7,6 @@ use std::rc::Rc; use std::ffi::{c_int, c_long, c_void}; use std::slice; -use crate::asm::Label; use crate::backend::current::ALLOC_REGS; use crate::invariants::{ track_bop_assumption, track_cme_assumption, track_no_ep_escape_assumption, track_no_trace_point_assumption, @@ -19,7 +18,7 @@ use crate::state::ZJITState; use crate::stats::{send_fallback_counter, exit_counter_for_compile_error, incr_counter, incr_counter_by, send_fallback_counter_for_method_type, send_without_block_fallback_counter_for_method_type, send_without_block_fallback_counter_for_optimized_method_type, send_fallback_counter_ptr_for_opcode, CompileError}; use crate::stats::{counter_ptr, with_time_stat, Counter, Counter::{compile_time_ns, exit_compile_error}}; use crate::{asm::CodeBlock, cruby::*, options::debug, virtualmem::CodePtr}; -use crate::backend::lir::{self, asm_comment, asm_ccall, Assembler, Opnd, Target, CFP, C_ARG_OPNDS, C_RET_OPND, EC, NATIVE_STACK_PTR, NATIVE_BASE_PTR, SP}; +use crate::backend::lir::{self, Assembler, C_ARG_OPNDS, C_RET_OPND, CFP, EC, NATIVE_BASE_PTR, NATIVE_STACK_PTR, Opnd, SP, SideExit, Target, asm_ccall, asm_comment}; use crate::hir::{iseq_to_hir, BlockId, BranchEdge, Invariant, RangeType, SideExitReason::{self, *}, SpecialBackrefSymbol, SpecialObjectType}; use crate::hir::{Const, FrameState, Function, Insn, InsnId, SendFallbackReason}; use crate::hir_type::{types, Type}; @@ -701,15 +700,26 @@ fn gen_invokebuiltin(jit: &JITState, asm: &mut Assembler, state: &FrameState, bf /// Record a patch point that should be invalidated on a given invariant fn gen_patch_point(jit: &mut JITState, asm: &mut Assembler, invariant: &Invariant, state: &FrameState) { let payload_ptr = get_or_create_iseq_payload_ptr(jit.iseq); - let label = asm.new_label("patch_point").unwrap_label(); let invariant = *invariant; + let exit = build_side_exit(jit, state); - // Compile a side exit. Fill nop instructions if the last patch point is too close. - asm.patch_point(build_side_exit(jit, state, PatchPoint(invariant), Some(label))); + // Let compile_exits compile a side exit. Let scratch_split lower it with split_patch_point. + asm.patch_point(Target::SideExit { exit, reason: PatchPoint(invariant) }, invariant, payload_ptr); +} + +/// This is used by scratch_split to lower PatchPoint into PadPatchPoint and PosMarker. +/// It's called at scratch_split so that we can use the Label after side-exit deduplication in compile_exits. +pub fn split_patch_point(asm: &mut Assembler, target: &Target, invariant: Invariant, payload_ptr: *mut IseqPayload) { + let Target::Label(exit_label) = *target else { + unreachable!("PatchPoint's target should have been lowered to Target::Label by compile_exits: {target:?}"); + }; + + // Fill nop instructions if the last patch point is too close. + asm.pad_patch_point(); // Remember the current address as a patch point asm.pos_marker(move |code_ptr, cb| { - let side_exit_ptr = cb.resolve_label(label); + let side_exit_ptr = cb.resolve_label(exit_label); match invariant { Invariant::BOPRedefined { klass, bop } => { track_bop_assumption(klass, bop, code_ptr, side_exit_ptr, payload_ptr); @@ -2038,13 +2048,14 @@ fn compile_iseq(iseq: IseqPtr) -> Result<Function, CompileError> { Ok(function) } -/// Build a Target::SideExit for non-PatchPoint instructions +/// Build a Target::SideExit fn side_exit(jit: &JITState, state: &FrameState, reason: SideExitReason) -> Target { - build_side_exit(jit, state, reason, None) + let exit = build_side_exit(jit, state); + Target::SideExit { exit, reason } } -/// Build a Target::SideExit out of a FrameState -fn build_side_exit(jit: &JITState, state: &FrameState, reason: SideExitReason, label: Option<Label>) -> Target { +/// Build a side-exit context +fn build_side_exit(jit: &JITState, state: &FrameState) -> SideExit { let mut stack = Vec::new(); for &insn_id in state.stack() { stack.push(jit.get_opnd(insn_id)); @@ -2055,12 +2066,10 @@ fn build_side_exit(jit: &JITState, state: &FrameState, reason: SideExitReason, l locals.push(jit.get_opnd(insn_id)); } - Target::SideExit { - pc: state.pc, + SideExit{ + pc: Opnd::const_ptr(state.pc), stack, locals, - reason, - label, } } |
