| | | |
|---|---|---|
| author | Takashi Kokubun <takashi.kokubun@shopify.com> | 2025-11-04 16:09:13 -0800 |
| committer | GitHub <noreply@github.com> | 2025-11-04 16:09:13 -0800 |
| commit | bd3b44cb0a341878abe0edf65d01b1a48c93f088 | |
| tree | cdaa1d06a4b5f8a135464f50f92d45f046553c19 | |
| parent | be905b2e581540dc2c51a54aed537b19955b7bb0 | |
ZJIT: Use a shared trampoline across all ISEQs (#15042)
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | vm.c | 69 |
| -rw-r--r-- | vm_core.h | 1 |
| -rw-r--r-- | vm_exec.h | 19 |
| -rw-r--r-- | zjit.h | 6 |
| -rw-r--r-- | zjit/src/backend/arm64/mod.rs | 30 |
| -rw-r--r-- | zjit/src/backend/lir.rs | 15 |
| -rw-r--r-- | zjit/src/backend/x86_64/mod.rs | 10 |
| -rw-r--r-- | zjit/src/codegen.rs | 27 |
| -rw-r--r-- | zjit/src/cruby.rs | 6 |
| -rw-r--r-- | zjit/src/state.rs | 21 |
10 files changed, 133 insertions, 71 deletions
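Before the diff, here is a minimal, self-contained C sketch of the control flow this commit sets up. It is not code from the repository: `ec_t`, `cfp_t`, `shared_trampoline`, `compiled_iseq`, and `jit_exec_sketch` are invented stand-ins, while `rb_jit_func_t`, `rb_zjit_func_t`, and `rb_zjit_entry` mirror names in the patch. The idea is that one trampoline generated at ZJIT boot is stored in `rb_zjit_entry` (which doubles as the "ZJIT enabled" flag), and each ISEQ's compiled code is handed to that trampoline as a third argument, so no per-ISEQ entry stub has to be generated.

```c
/*
 * Standalone sketch (NOT CRuby source) of the dispatch shape introduced by
 * this commit. All types and helpers besides rb_jit_func_t, rb_zjit_func_t,
 * and rb_zjit_entry are made up for illustration.
 */
#include <stdio.h>
#include <stddef.h>

typedef struct { int dummy; } ec_t;   /* stand-in for rb_execution_context_t */
typedef struct { int dummy; } cfp_t;  /* stand-in for rb_control_frame_t     */
typedef unsigned long VALUE;

/* Per-ISEQ compiled code: (EC, CFP) -> VALUE, same shape as rb_jit_func_t. */
typedef VALUE (*rb_jit_func_t)(ec_t *ec, cfp_t *cfp);
/* Shared trampoline: (EC, CFP, per-ISEQ func) -> VALUE, like rb_zjit_func_t. */
typedef VALUE (*rb_zjit_func_t)(ec_t *ec, cfp_t *cfp, rb_jit_func_t func);

/* In the real code this points at JIT-generated machine code; NULL = disabled. */
static void *rb_zjit_entry = NULL;

/* What the generated trampoline effectively does: set up the JIT register
 * state, then call the per-ISEQ function it was handed as its third argument. */
static VALUE shared_trampoline(ec_t *ec, cfp_t *cfp, rb_jit_func_t func)
{
    return func(ec, cfp);
}

/* A fake "compiled ISEQ" standing in for body->jit_entry. */
static VALUE compiled_iseq(ec_t *ec, cfp_t *cfp)
{
    (void)ec; (void)cfp;
    return 42;
}

/* Mirrors the new jit_exec() flow: JIT code is only entered via the trampoline. */
static VALUE jit_exec_sketch(ec_t *ec, cfp_t *cfp, rb_jit_func_t jit_entry)
{
    void *zjit_entry = rb_zjit_entry;
    if (zjit_entry && jit_entry) {
        return ((rb_zjit_func_t)zjit_entry)(ec, cfp, jit_entry);
    }
    return 0; /* Qundef stand-in: fall back to the interpreter */
}

int main(void)
{
    ec_t ec; cfp_t cfp;
    rb_zjit_entry = (void *)shared_trampoline;  /* "ZJIT initialized" */
    printf("%lu\n", jit_exec_sketch(&ec, &cfp, compiled_iseq)); /* prints 42 */
    return 0;
}
```

In the actual patch, the trampoline body is emitted once by `gen_entry_trampoline()` and calls its third argument via `asm.ccall_reg(C_ARG_OPNDS[2], VALUE_BITS)`, which is what allows the per-ISEQ `gen_entry()` in codegen.rs to be deleted.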
```diff
diff --git a/vm.c b/vm.c
--- a/vm.c
+++ b/vm.c
@@ -503,7 +503,7 @@ rb_yjit_threshold_hit(const rb_iseq_t *iseq, uint64_t entry_calls)
 #define rb_yjit_threshold_hit(iseq, entry_calls) false
 #endif
 
-#if USE_YJIT || USE_ZJIT
+#if USE_YJIT
 // Generate JIT code that supports the following kinds of ISEQ entries:
 // * The first ISEQ on vm_exec (e.g. <main>, or Ruby methods/blocks
 //   called by a C method). The current frame has VM_FRAME_FLAG_FINISH.
@@ -513,13 +513,32 @@ rb_yjit_threshold_hit(const rb_iseq_t *iseq, uint64_t entry_calls)
 //   The current frame doesn't have VM_FRAME_FLAG_FINISH. The current
 //   vm_exec does NOT stop whether JIT code returns Qundef or not.
 static inline rb_jit_func_t
-jit_compile(rb_execution_context_t *ec)
+yjit_compile(rb_execution_context_t *ec)
 {
     const rb_iseq_t *iseq = ec->cfp->iseq;
     struct rb_iseq_constant_body *body = ISEQ_BODY(iseq);
 
+    // Increment the ISEQ's call counter and trigger JIT compilation if not compiled
+    if (body->jit_entry == NULL) {
+        body->jit_entry_calls++;
+        if (rb_yjit_threshold_hit(iseq, body->jit_entry_calls)) {
+            rb_yjit_compile_iseq(iseq, ec, false);
+        }
+    }
+    return body->jit_entry;
+}
+#else
+# define yjit_compile(ec) ((rb_jit_func_t)0)
+#endif
+
 #if USE_ZJIT
-    if (body->jit_entry == NULL && rb_zjit_enabled_p) {
+static inline rb_jit_func_t
+zjit_compile(rb_execution_context_t *ec)
+{
+    const rb_iseq_t *iseq = ec->cfp->iseq;
+    struct rb_iseq_constant_body *body = ISEQ_BODY(iseq);
+
+    if (body->jit_entry == NULL) {
         body->jit_entry_calls++;
 
         // At profile-threshold, rewrite some of the YARV instructions
@@ -533,38 +552,38 @@ jit_compile(rb_execution_context_t *ec)
             rb_zjit_compile_iseq(iseq, false);
         }
     }
-#endif
-
-#if USE_YJIT
-    // Increment the ISEQ's call counter and trigger JIT compilation if not compiled
-    if (body->jit_entry == NULL && rb_yjit_enabled_p) {
-        body->jit_entry_calls++;
-        if (rb_yjit_threshold_hit(iseq, body->jit_entry_calls)) {
-            rb_yjit_compile_iseq(iseq, ec, false);
-        }
-    }
-#endif
     return body->jit_entry;
 }
+#else
+# define zjit_compile(ec) ((rb_jit_func_t)0)
+#endif
 
-// Execute JIT code compiled by jit_compile()
+// Execute JIT code compiled by yjit_compile() or zjit_compile()
 static inline VALUE
 jit_exec(rb_execution_context_t *ec)
 {
-    rb_jit_func_t func = jit_compile(ec);
-    if (func) {
-        // Call the JIT code
-        return func(ec, ec->cfp);
-    }
-    else {
+#if USE_YJIT
+    if (rb_yjit_enabled_p) {
+        rb_jit_func_t func = yjit_compile(ec);
+        if (func) {
+            return func(ec, ec->cfp);
+        }
         return Qundef;
     }
-}
-#else
-# define jit_compile(ec) ((rb_jit_func_t)0)
-# define jit_exec(ec) Qundef
 #endif
+#if USE_ZJIT
+    void *zjit_entry = rb_zjit_entry;
+    if (zjit_entry) {
+        rb_jit_func_t func = zjit_compile(ec);
+        if (func) {
+            return ((rb_zjit_func_t)zjit_entry)(ec, ec->cfp, func);
+        }
+    }
+#endif
+    return Qundef;
+}
+
 
 #if USE_YJIT
 // Generate JIT code that supports the following kind of ISEQ entry:
 // * The first ISEQ pushed by vm_exec_handle_exception. The frame would
diff --git a/vm_core.h b/vm_core.h
--- a/vm_core.h
+++ b/vm_core.h
@@ -398,6 +398,7 @@ enum rb_builtin_attr {
 };
 
 typedef VALUE (*rb_jit_func_t)(struct rb_execution_context_struct *, struct rb_control_frame_struct *);
+typedef VALUE (*rb_zjit_func_t)(struct rb_execution_context_struct *, struct rb_control_frame_struct *, rb_jit_func_t);
 
 struct rb_iseq_constant_body {
     enum rb_iseq_type type;
diff --git a/vm_exec.h b/vm_exec.h
--- a/vm_exec.h
+++ b/vm_exec.h
@@ -175,11 +175,22 @@ default: \
 
 // Run the JIT from the interpreter
 #define JIT_EXEC(ec, val) do { \
-    rb_jit_func_t func; \
     /* don't run tailcalls since that breaks FINISH */ \
-    if (UNDEF_P(val) && GET_CFP() != ec->cfp && (func = jit_compile(ec))) { \
-        val = func(ec, ec->cfp); \
-        if (ec->tag->state) THROW_EXCEPTION(val); \
+    if (UNDEF_P(val) && GET_CFP() != ec->cfp) { \
+        rb_zjit_func_t zjit_entry; \
+        if (rb_yjit_enabled_p) { \
+            rb_jit_func_t func = yjit_compile(ec); \
+            if (func) { \
+                val = func(ec, ec->cfp); \
+                if (ec->tag->state) THROW_EXCEPTION(val); \
+            } \
+        } \
+        else if ((zjit_entry = rb_zjit_entry)) { \
+            rb_jit_func_t func = zjit_compile(ec); \
+            if (func) { \
+                val = zjit_entry(ec, ec->cfp, func); \
+            } \
+        } \
     } \
 } while (0)
 
diff --git a/zjit.h b/zjit.h
--- a/zjit.h
+++ b/zjit.h
@@ -10,7 +10,7 @@
 #endif
 
 #if USE_ZJIT
-extern bool rb_zjit_enabled_p;
+extern void *rb_zjit_entry;
 extern uint64_t rb_zjit_call_threshold;
 extern uint64_t rb_zjit_profile_threshold;
 void rb_zjit_compile_iseq(const rb_iseq_t *iseq, bool jit_exception);
@@ -29,7 +29,7 @@ void rb_zjit_before_ractor_spawn(void);
 void rb_zjit_tracing_invalidate_all(void);
 void rb_zjit_invalidate_no_singleton_class(VALUE klass);
 #else
-#define rb_zjit_enabled_p false
+#define rb_zjit_entry 0
 static inline void rb_zjit_compile_iseq(const rb_iseq_t *iseq, bool jit_exception) {}
 static inline void rb_zjit_profile_insn(uint32_t insn, rb_execution_context_t *ec) {}
 static inline void rb_zjit_profile_enable(const rb_iseq_t *iseq) {}
@@ -42,4 +42,6 @@ static inline void rb_zjit_tracing_invalidate_all(void) {}
 static inline void rb_zjit_invalidate_no_singleton_class(VALUE klass) {}
 #endif // #if USE_ZJIT
 
+#define rb_zjit_enabled_p (rb_zjit_entry != 0)
+
 #endif // #ifndef ZJIT_H
diff --git a/zjit/src/backend/arm64/mod.rs b/zjit/src/backend/arm64/mod.rs
index acf0576f9c..532570d732 100644
--- a/zjit/src/backend/arm64/mod.rs
+++ b/zjit/src/backend/arm64/mod.rs
@@ -1428,17 +1428,25 @@ impl Assembler {
                 }
             },
             Insn::CCall { fptr, .. } => {
-                // The offset to the call target in bytes
-                let src_addr = cb.get_write_ptr().raw_ptr(cb) as i64;
-                let dst_addr = *fptr as i64;
-
-                // Use BL if the offset is short enough to encode as an immediate.
-                // Otherwise, use BLR with a register.
-                if b_offset_fits_bits((dst_addr - src_addr) / 4) {
-                    bl(cb, InstructionOffset::from_bytes((dst_addr - src_addr) as i32));
-                } else {
-                    emit_load_value(cb, Self::EMIT_OPND, dst_addr as u64);
-                    blr(cb, Self::EMIT_OPND);
+                match fptr {
+                    Opnd::UImm(fptr) => {
+                        // The offset to the call target in bytes
+                        let src_addr = cb.get_write_ptr().raw_ptr(cb) as i64;
+                        let dst_addr = *fptr as i64;
+
+                        // Use BL if the offset is short enough to encode as an immediate.
+                        // Otherwise, use BLR with a register.
+                        if b_offset_fits_bits((dst_addr - src_addr) / 4) {
+                            bl(cb, InstructionOffset::from_bytes((dst_addr - src_addr) as i32));
+                        } else {
+                            emit_load_value(cb, Self::EMIT_OPND, dst_addr as u64);
+                            blr(cb, Self::EMIT_OPND);
+                        }
+                    }
+                    Opnd::Reg(_) => {
+                        blr(cb, fptr.into());
+                    }
+                    _ => unreachable!("unsupported ccall fptr: {fptr:?}")
                 }
             },
             Insn::CRet { .. } => {
diff --git a/zjit/src/backend/lir.rs b/zjit/src/backend/lir.rs
index 66e89a1304..e2f75e01c8 100644
--- a/zjit/src/backend/lir.rs
+++ b/zjit/src/backend/lir.rs
@@ -386,7 +386,9 @@ pub enum Insn {
     // C function call with N arguments (variadic)
     CCall {
         opnds: Vec<Opnd>,
-        fptr: *const u8,
+        /// The function pointer to be called. This should be Opnd::const_ptr
+        /// (Opnd::UImm) in most cases. gen_entry_trampoline() uses Opnd::Reg.
+        fptr: Opnd,
         /// Optional PosMarker to remember the start address of the C call.
         /// It's embedded here to insert the PosMarker after push instructions
         /// that are split from this CCall on alloc_regs().
@@ -1989,11 +1991,20 @@ impl Assembler {
     pub fn ccall(&mut self, fptr: *const u8, opnds: Vec<Opnd>) -> Opnd {
         let canary_opnd = self.set_stack_canary();
         let out = self.new_vreg(Opnd::match_num_bits(&opnds));
+        let fptr = Opnd::const_ptr(fptr);
         self.push_insn(Insn::CCall { fptr, opnds, start_marker: None, end_marker: None, out });
         self.clear_stack_canary(canary_opnd);
         out
     }
 
+    /// Call a C function stored in a register
+    pub fn ccall_reg(&mut self, fptr: Opnd, num_bits: u8) -> Opnd {
+        assert!(matches!(fptr, Opnd::Reg(_)), "ccall_reg must be called with Opnd::Reg: {fptr:?}");
+        let out = self.new_vreg(num_bits);
+        self.push_insn(Insn::CCall { fptr, opnds: vec![], start_marker: None, end_marker: None, out });
+        out
+    }
+
     /// Call a C function with PosMarkers. This is used for recording the start and end
     /// addresses of the C call and rewriting it with a different function address later.
     pub fn ccall_with_pos_markers(
@@ -2005,7 +2016,7 @@
     ) -> Opnd {
         let out = self.new_vreg(Opnd::match_num_bits(&opnds));
         self.push_insn(Insn::CCall {
-            fptr,
+            fptr: Opnd::const_ptr(fptr),
             opnds,
             start_marker: Some(Rc::new(start_marker)),
             end_marker: Some(Rc::new(end_marker)),
diff --git a/zjit/src/backend/x86_64/mod.rs b/zjit/src/backend/x86_64/mod.rs
index 1d5d90a856..aea25ca2a4 100644
--- a/zjit/src/backend/x86_64/mod.rs
+++ b/zjit/src/backend/x86_64/mod.rs
@@ -863,7 +863,15 @@ impl Assembler {
 
             // C function call
             Insn::CCall { fptr, .. } => {
-                call_ptr(cb, RAX, *fptr);
+                match fptr {
+                    Opnd::UImm(fptr) => {
+                        call_ptr(cb, RAX, *fptr as *const u8);
+                    }
+                    Opnd::Reg(_) => {
+                        call(cb, fptr.into());
+                    }
+                    _ => unreachable!("unsupported ccall fptr: {fptr:?}")
+                }
             },
 
             Insn::CRet(opnd) => {
diff --git a/zjit/src/codegen.rs b/zjit/src/codegen.rs
index 7cd677bde3..01212ac88c 100644
--- a/zjit/src/codegen.rs
+++ b/zjit/src/codegen.rs
@@ -106,8 +106,7 @@ pub extern "C" fn rb_zjit_iseq_gen_entry_point(iseq: IseqPtr, jit_exception: boo
     }
 
     // Always mark the code region executable if asm.compile() has been used.
-    // We need to do this even if code_ptr is None because, whether gen_entry()
-    // fails or not, gen_iseq() may have already used asm.compile().
+    // We need to do this even if code_ptr is None because gen_iseq() may have already used asm.compile().
     cb.mark_all_executable();
 
     code_ptr.map_or(std::ptr::null(), |ptr| ptr.raw_ptr(cb))
@@ -131,10 +130,7 @@ fn gen_iseq_entry_point(cb: &mut CodeBlock, iseq: IseqPtr, jit_exception: bool)
         debug!("{err:?}: gen_iseq failed: {}", iseq_get_location(iseq, 0));
     })?;
 
-    // Compile an entry point to the JIT code
-    gen_entry(cb, iseq, start_ptr).inspect_err(|err| {
-        debug!("{err:?}: gen_entry failed: {}", iseq_get_location(iseq, 0));
-    })
+    Ok(start_ptr)
 }
 
 /// Stub a branch for a JIT-to-JIT call
@@ -170,14 +166,16 @@ fn register_with_perf(iseq_name: String, start_ptr: usize, code_size: usize) {
     };
 }
 
-/// Compile a JIT entry
-fn gen_entry(cb: &mut CodeBlock, iseq: IseqPtr, function_ptr: CodePtr) -> Result<CodePtr, CompileError> {
+/// Compile a shared JIT entry trampoline
+pub fn gen_entry_trampoline(cb: &mut CodeBlock) -> Result<CodePtr, CompileError> {
     // Set up registers for CFP, EC, SP, and basic block arguments
     let mut asm = Assembler::new();
-    gen_entry_prologue(&mut asm, iseq);
+    gen_entry_prologue(&mut asm);
 
-    // Jump to the first block using a call instruction
-    asm.ccall(function_ptr.raw_ptr(cb), vec![]);
+    // Jump to the first block using a call instruction. This trampoline is used
+    // as rb_zjit_func_t in jit_exec(), which takes (EC, CFP, rb_jit_func_t).
+    // So C_ARG_OPNDS[2] is rb_jit_func_t, which is (EC, CFP) -> VALUE.
+    asm.ccall_reg(C_ARG_OPNDS[2], VALUE_BITS);
 
     // Restore registers for CFP, EC, and SP after use
     asm_comment!(asm, "return to the interpreter");
@@ -190,8 +188,7 @@ fn gen_entry(cb: &mut CodeBlock, iseq: IseqPtr, function_ptr: CodePtr) -> Result
         let start_ptr = code_ptr.raw_addr(cb);
         let end_ptr = cb.get_write_ptr().raw_addr(cb);
         let code_size = end_ptr - start_ptr;
-        let iseq_name = iseq_get_location(iseq, 0);
-        register_with_perf(format!("entry for {iseq_name}"), start_ptr, code_size);
+        register_with_perf("ZJIT entry trampoline".into(), start_ptr, code_size);
     }
     Ok(code_ptr)
 }
@@ -990,8 +987,8 @@ fn gen_load_field(asm: &mut Assembler, recv: Opnd, id: ID, offset: i32) -> Opnd
 }
 
 /// Compile an interpreter entry block to be inserted into an ISEQ
-fn gen_entry_prologue(asm: &mut Assembler, iseq: IseqPtr) {
-    asm_comment!(asm, "ZJIT entry point: {}", iseq_get_location(iseq, 0));
+fn gen_entry_prologue(asm: &mut Assembler) {
+    asm_comment!(asm, "ZJIT entry trampoline");
 
     // Save the registers we'll use for CFP, EP, SP
     asm.frame_setup(lir::JIT_PRESERVED_REGS);
diff --git a/zjit/src/cruby.rs b/zjit/src/cruby.rs
index 631acbd863..db47385bc8 100644
--- a/zjit/src/cruby.rs
+++ b/zjit/src/cruby.rs
@@ -1071,7 +1071,7 @@ pub use manual_defs::*;
 pub mod test_utils {
     use std::{ptr::null, sync::Once};
 
-    use crate::{options::{rb_zjit_call_threshold, rb_zjit_prepare_options, set_call_threshold, DEFAULT_CALL_THRESHOLD}, state::{rb_zjit_enabled_p, ZJITState}};
+    use crate::{options::{rb_zjit_call_threshold, rb_zjit_prepare_options, set_call_threshold, DEFAULT_CALL_THRESHOLD}, state::{rb_zjit_entry, ZJITState}};
 
     use super::*;
 
@@ -1114,10 +1114,10 @@
         }
 
         // Set up globals for convenience
-        ZJITState::init();
+        let zjit_entry = ZJITState::init();
 
         // Enable zjit_* instructions
-        unsafe { rb_zjit_enabled_p = true; }
+        unsafe { rb_zjit_entry = zjit_entry; }
     }
 
     /// Make sure the Ruby VM is set up and run a given callback with rb_protect()
diff --git a/zjit/src/state.rs b/zjit/src/state.rs
index c0e9e0b77c..3cb60cffcb 100644
--- a/zjit/src/state.rs
+++ b/zjit/src/state.rs
@@ -1,6 +1,6 @@
 //! Runtime state of ZJIT.
 
-use crate::codegen::{gen_exit_trampoline, gen_exit_trampoline_with_counter, gen_function_stub_hit_trampoline};
+use crate::codegen::{gen_entry_trampoline, gen_exit_trampoline, gen_exit_trampoline_with_counter, gen_function_stub_hit_trampoline};
 use crate::cruby::{self, rb_bug_panic_hook, rb_vm_insn_count, EcPtr, Qnil, rb_vm_insn_addr2opcode, rb_profile_frames, VALUE, VM_INSTRUCTION_SIZE, size_t, rb_gc_mark};
 use crate::cruby_methods;
 use crate::invariants::Invariants;
@@ -9,14 +9,16 @@ use crate::options::get_option;
 use crate::stats::{Counters, InsnCounters, SideExitLocations};
 use crate::virtualmem::CodePtr;
 use std::collections::HashMap;
+use std::ptr::null;
 
+/// Shared trampoline to enter ZJIT. Not null when ZJIT is enabled.
 #[allow(non_upper_case_globals)]
 #[unsafe(no_mangle)]
-pub static mut rb_zjit_enabled_p: bool = false;
+pub static mut rb_zjit_entry: *const u8 = null();
 
 /// Like rb_zjit_enabled_p, but for Rust code.
 pub fn zjit_enabled_p() -> bool {
-    unsafe { rb_zjit_enabled_p }
+    unsafe { rb_zjit_entry != null() }
 }
 
 /// Global state needed for code generation
@@ -65,8 +67,8 @@ pub struct ZJITState {
 static mut ZJIT_STATE: Option<ZJITState> = None;
 
 impl ZJITState {
-    /// Initialize the ZJIT globals
-    pub fn init() {
+    /// Initialize the ZJIT globals. Return the address of the JIT entry trampoline.
+    pub fn init() -> *const u8 {
         let mut cb = {
             use crate::options::*;
             use crate::virtualmem::*;
@@ -79,6 +81,7 @@ impl ZJITState {
             CodeBlock::new(mem_block.clone(), get_option!(dump_disasm))
         };
 
+        let entry_trampoline = gen_entry_trampoline(&mut cb).unwrap().raw_ptr(&cb);
         let exit_trampoline = gen_exit_trampoline(&mut cb).unwrap();
         let function_stub_hit_trampoline = gen_function_stub_hit_trampoline(&mut cb).unwrap();
 
@@ -114,6 +117,8 @@
             let code_ptr = gen_exit_trampoline_with_counter(cb, exit_trampoline).unwrap();
             ZJITState::get_instance().exit_trampoline_with_counter = code_ptr;
         }
+
+        entry_trampoline
     }
 
     /// Return true if zjit_state has been initialized
@@ -252,7 +257,7 @@ pub extern "C" fn rb_zjit_init() {
     let result = std::panic::catch_unwind(|| {
        // Initialize ZJIT states
        cruby::ids::init();
-       ZJITState::init();
+       let zjit_entry = ZJITState::init();
 
        // Install a panic hook for ZJIT
        rb_bug_panic_hook();
@@ -261,8 +266,8 @@ pub extern "C" fn rb_zjit_init() {
         unsafe { rb_vm_insn_count = 0; }
 
         // ZJIT enabled and initialized successfully
-        assert!(unsafe{ !rb_zjit_enabled_p });
-        unsafe { rb_zjit_enabled_p = true; }
+        assert!(unsafe{ rb_zjit_entry == null() });
+        unsafe { rb_zjit_entry = zjit_entry; }
     });
 
     if result.is_err() {
```
