summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTakashi Kokubun <takashi.kokubun@shopify.com>2025-11-04 16:09:13 -0800
committerGitHub <noreply@github.com>2025-11-04 16:09:13 -0800
commitbd3b44cb0a341878abe0edf65d01b1a48c93f088 (patch)
treecdaa1d06a4b5f8a135464f50f92d45f046553c19
parentbe905b2e581540dc2c51a54aed537b19955b7bb0 (diff)
ZJIT: Use a shared trampoline across all ISEQs (#15042)
-rw-r--r--vm.c69
-rw-r--r--vm_core.h1
-rw-r--r--vm_exec.h19
-rw-r--r--zjit.h6
-rw-r--r--zjit/src/backend/arm64/mod.rs30
-rw-r--r--zjit/src/backend/lir.rs15
-rw-r--r--zjit/src/backend/x86_64/mod.rs10
-rw-r--r--zjit/src/codegen.rs27
-rw-r--r--zjit/src/cruby.rs6
-rw-r--r--zjit/src/state.rs21
10 files changed, 133 insertions, 71 deletions
diff --git a/vm.c b/vm.c
index 32785dbcc8..f0aebf08a3 100644
--- a/vm.c
+++ b/vm.c
@@ -503,7 +503,7 @@ rb_yjit_threshold_hit(const rb_iseq_t *iseq, uint64_t entry_calls)
#define rb_yjit_threshold_hit(iseq, entry_calls) false
#endif
-#if USE_YJIT || USE_ZJIT
+#if USE_YJIT
// Generate JIT code that supports the following kinds of ISEQ entries:
// * The first ISEQ on vm_exec (e.g. <main>, or Ruby methods/blocks
// called by a C method). The current frame has VM_FRAME_FLAG_FINISH.
@@ -513,13 +513,32 @@ rb_yjit_threshold_hit(const rb_iseq_t *iseq, uint64_t entry_calls)
// The current frame doesn't have VM_FRAME_FLAG_FINISH. The current
// vm_exec does NOT stop whether JIT code returns Qundef or not.
static inline rb_jit_func_t
-jit_compile(rb_execution_context_t *ec)
+yjit_compile(rb_execution_context_t *ec)
{
const rb_iseq_t *iseq = ec->cfp->iseq;
struct rb_iseq_constant_body *body = ISEQ_BODY(iseq);
+ // Increment the ISEQ's call counter and trigger JIT compilation if not compiled
+ if (body->jit_entry == NULL) {
+ body->jit_entry_calls++;
+ if (rb_yjit_threshold_hit(iseq, body->jit_entry_calls)) {
+ rb_yjit_compile_iseq(iseq, ec, false);
+ }
+ }
+ return body->jit_entry;
+}
+#else
+# define yjit_compile(ec) ((rb_jit_func_t)0)
+#endif
+
#if USE_ZJIT
- if (body->jit_entry == NULL && rb_zjit_enabled_p) {
+static inline rb_jit_func_t
+zjit_compile(rb_execution_context_t *ec)
+{
+ const rb_iseq_t *iseq = ec->cfp->iseq;
+ struct rb_iseq_constant_body *body = ISEQ_BODY(iseq);
+
+ if (body->jit_entry == NULL) {
body->jit_entry_calls++;
// At profile-threshold, rewrite some of the YARV instructions
@@ -533,38 +552,38 @@ jit_compile(rb_execution_context_t *ec)
rb_zjit_compile_iseq(iseq, false);
}
}
-#endif
-
-#if USE_YJIT
- // Increment the ISEQ's call counter and trigger JIT compilation if not compiled
- if (body->jit_entry == NULL && rb_yjit_enabled_p) {
- body->jit_entry_calls++;
- if (rb_yjit_threshold_hit(iseq, body->jit_entry_calls)) {
- rb_yjit_compile_iseq(iseq, ec, false);
- }
- }
-#endif
return body->jit_entry;
}
+#else
+# define zjit_compile(ec) ((rb_jit_func_t)0)
+#endif
-// Execute JIT code compiled by jit_compile()
+// Execute JIT code compiled by yjit_compile() or zjit_compile()
static inline VALUE
jit_exec(rb_execution_context_t *ec)
{
- rb_jit_func_t func = jit_compile(ec);
- if (func) {
- // Call the JIT code
- return func(ec, ec->cfp);
- }
- else {
+#if USE_YJIT
+ if (rb_yjit_enabled_p) {
+ rb_jit_func_t func = yjit_compile(ec);
+ if (func) {
+ return func(ec, ec->cfp);
+ }
return Qundef;
}
-}
-#else
-# define jit_compile(ec) ((rb_jit_func_t)0)
-# define jit_exec(ec) Qundef
#endif
+#if USE_ZJIT
+ void *zjit_entry = rb_zjit_entry;
+ if (zjit_entry) {
+ rb_jit_func_t func = zjit_compile(ec);
+ if (func) {
+ return ((rb_zjit_func_t)zjit_entry)(ec, ec->cfp, func);
+ }
+ }
+#endif
+ return Qundef;
+}
+
#if USE_YJIT
// Generate JIT code that supports the following kind of ISEQ entry:
// * The first ISEQ pushed by vm_exec_handle_exception. The frame would
diff --git a/vm_core.h b/vm_core.h
index e8e6a6a3a6..ded0280387 100644
--- a/vm_core.h
+++ b/vm_core.h
@@ -398,6 +398,7 @@ enum rb_builtin_attr {
};
typedef VALUE (*rb_jit_func_t)(struct rb_execution_context_struct *, struct rb_control_frame_struct *);
+typedef VALUE (*rb_zjit_func_t)(struct rb_execution_context_struct *, struct rb_control_frame_struct *, rb_jit_func_t);
struct rb_iseq_constant_body {
enum rb_iseq_type type;
diff --git a/vm_exec.h b/vm_exec.h
index c3b7d4e488..033a48f1e7 100644
--- a/vm_exec.h
+++ b/vm_exec.h
@@ -175,11 +175,22 @@ default: \
// Run the JIT from the interpreter
#define JIT_EXEC(ec, val) do { \
- rb_jit_func_t func; \
/* don't run tailcalls since that breaks FINISH */ \
- if (UNDEF_P(val) && GET_CFP() != ec->cfp && (func = jit_compile(ec))) { \
- val = func(ec, ec->cfp); \
- if (ec->tag->state) THROW_EXCEPTION(val); \
+ if (UNDEF_P(val) && GET_CFP() != ec->cfp) { \
+ rb_zjit_func_t zjit_entry; \
+ if (rb_yjit_enabled_p) { \
+ rb_jit_func_t func = yjit_compile(ec); \
+ if (func) { \
+ val = func(ec, ec->cfp); \
+ if (ec->tag->state) THROW_EXCEPTION(val); \
+ } \
+ } \
+ else if ((zjit_entry = rb_zjit_entry)) { \
+ rb_jit_func_t func = zjit_compile(ec); \
+ if (func) { \
+ val = zjit_entry(ec, ec->cfp, func); \
+ } \
+ } \
} \
} while (0)
diff --git a/zjit.h b/zjit.h
index 7b3e410c91..47240846ff 100644
--- a/zjit.h
+++ b/zjit.h
@@ -10,7 +10,7 @@
#endif
#if USE_ZJIT
-extern bool rb_zjit_enabled_p;
+extern void *rb_zjit_entry;
extern uint64_t rb_zjit_call_threshold;
extern uint64_t rb_zjit_profile_threshold;
void rb_zjit_compile_iseq(const rb_iseq_t *iseq, bool jit_exception);
@@ -29,7 +29,7 @@ void rb_zjit_before_ractor_spawn(void);
void rb_zjit_tracing_invalidate_all(void);
void rb_zjit_invalidate_no_singleton_class(VALUE klass);
#else
-#define rb_zjit_enabled_p false
+#define rb_zjit_entry 0
static inline void rb_zjit_compile_iseq(const rb_iseq_t *iseq, bool jit_exception) {}
static inline void rb_zjit_profile_insn(uint32_t insn, rb_execution_context_t *ec) {}
static inline void rb_zjit_profile_enable(const rb_iseq_t *iseq) {}
@@ -42,4 +42,6 @@ static inline void rb_zjit_tracing_invalidate_all(void) {}
static inline void rb_zjit_invalidate_no_singleton_class(VALUE klass) {}
#endif // #if USE_ZJIT
+#define rb_zjit_enabled_p (rb_zjit_entry != 0)
+
#endif // #ifndef ZJIT_H
diff --git a/zjit/src/backend/arm64/mod.rs b/zjit/src/backend/arm64/mod.rs
index acf0576f9c..532570d732 100644
--- a/zjit/src/backend/arm64/mod.rs
+++ b/zjit/src/backend/arm64/mod.rs
@@ -1428,17 +1428,25 @@ impl Assembler {
}
},
Insn::CCall { fptr, .. } => {
- // The offset to the call target in bytes
- let src_addr = cb.get_write_ptr().raw_ptr(cb) as i64;
- let dst_addr = *fptr as i64;
-
- // Use BL if the offset is short enough to encode as an immediate.
- // Otherwise, use BLR with a register.
- if b_offset_fits_bits((dst_addr - src_addr) / 4) {
- bl(cb, InstructionOffset::from_bytes((dst_addr - src_addr) as i32));
- } else {
- emit_load_value(cb, Self::EMIT_OPND, dst_addr as u64);
- blr(cb, Self::EMIT_OPND);
+ match fptr {
+ Opnd::UImm(fptr) => {
+ // The offset to the call target in bytes
+ let src_addr = cb.get_write_ptr().raw_ptr(cb) as i64;
+ let dst_addr = *fptr as i64;
+
+ // Use BL if the offset is short enough to encode as an immediate.
+ // Otherwise, use BLR with a register.
+ if b_offset_fits_bits((dst_addr - src_addr) / 4) {
+ bl(cb, InstructionOffset::from_bytes((dst_addr - src_addr) as i32));
+ } else {
+ emit_load_value(cb, Self::EMIT_OPND, dst_addr as u64);
+ blr(cb, Self::EMIT_OPND);
+ }
+ }
+ Opnd::Reg(_) => {
+ blr(cb, fptr.into());
+ }
+ _ => unreachable!("unsupported ccall fptr: {fptr:?}")
}
},
Insn::CRet { .. } => {
diff --git a/zjit/src/backend/lir.rs b/zjit/src/backend/lir.rs
index 66e89a1304..e2f75e01c8 100644
--- a/zjit/src/backend/lir.rs
+++ b/zjit/src/backend/lir.rs
@@ -386,7 +386,9 @@ pub enum Insn {
// C function call with N arguments (variadic)
CCall {
opnds: Vec<Opnd>,
- fptr: *const u8,
+ /// The function pointer to be called. This should be Opnd::const_ptr
+ /// (Opnd::UImm) in most cases. gen_entry_trampoline() uses Opnd::Reg.
+ fptr: Opnd,
/// Optional PosMarker to remember the start address of the C call.
/// It's embedded here to insert the PosMarker after push instructions
/// that are split from this CCall on alloc_regs().
@@ -1989,11 +1991,20 @@ impl Assembler {
pub fn ccall(&mut self, fptr: *const u8, opnds: Vec<Opnd>) -> Opnd {
let canary_opnd = self.set_stack_canary();
let out = self.new_vreg(Opnd::match_num_bits(&opnds));
+ let fptr = Opnd::const_ptr(fptr);
self.push_insn(Insn::CCall { fptr, opnds, start_marker: None, end_marker: None, out });
self.clear_stack_canary(canary_opnd);
out
}
+ /// Call a C function stored in a register
+ pub fn ccall_reg(&mut self, fptr: Opnd, num_bits: u8) -> Opnd {
+ assert!(matches!(fptr, Opnd::Reg(_)), "ccall_reg must be called with Opnd::Reg: {fptr:?}");
+ let out = self.new_vreg(num_bits);
+ self.push_insn(Insn::CCall { fptr, opnds: vec![], start_marker: None, end_marker: None, out });
+ out
+ }
+
/// Call a C function with PosMarkers. This is used for recording the start and end
/// addresses of the C call and rewriting it with a different function address later.
pub fn ccall_with_pos_markers(
@@ -2005,7 +2016,7 @@ impl Assembler {
) -> Opnd {
let out = self.new_vreg(Opnd::match_num_bits(&opnds));
self.push_insn(Insn::CCall {
- fptr,
+ fptr: Opnd::const_ptr(fptr),
opnds,
start_marker: Some(Rc::new(start_marker)),
end_marker: Some(Rc::new(end_marker)),
diff --git a/zjit/src/backend/x86_64/mod.rs b/zjit/src/backend/x86_64/mod.rs
index 1d5d90a856..aea25ca2a4 100644
--- a/zjit/src/backend/x86_64/mod.rs
+++ b/zjit/src/backend/x86_64/mod.rs
@@ -863,7 +863,15 @@ impl Assembler {
// C function call
Insn::CCall { fptr, .. } => {
- call_ptr(cb, RAX, *fptr);
+ match fptr {
+ Opnd::UImm(fptr) => {
+ call_ptr(cb, RAX, *fptr as *const u8);
+ }
+ Opnd::Reg(_) => {
+ call(cb, fptr.into());
+ }
+ _ => unreachable!("unsupported ccall fptr: {fptr:?}")
+ }
},
Insn::CRet(opnd) => {
diff --git a/zjit/src/codegen.rs b/zjit/src/codegen.rs
index 7cd677bde3..01212ac88c 100644
--- a/zjit/src/codegen.rs
+++ b/zjit/src/codegen.rs
@@ -106,8 +106,7 @@ pub extern "C" fn rb_zjit_iseq_gen_entry_point(iseq: IseqPtr, jit_exception: boo
}
// Always mark the code region executable if asm.compile() has been used.
- // We need to do this even if code_ptr is None because, whether gen_entry()
- // fails or not, gen_iseq() may have already used asm.compile().
+ // We need to do this even if code_ptr is None because gen_iseq() may have already used asm.compile().
cb.mark_all_executable();
code_ptr.map_or(std::ptr::null(), |ptr| ptr.raw_ptr(cb))
@@ -131,10 +130,7 @@ fn gen_iseq_entry_point(cb: &mut CodeBlock, iseq: IseqPtr, jit_exception: bool)
debug!("{err:?}: gen_iseq failed: {}", iseq_get_location(iseq, 0));
})?;
- // Compile an entry point to the JIT code
- gen_entry(cb, iseq, start_ptr).inspect_err(|err| {
- debug!("{err:?}: gen_entry failed: {}", iseq_get_location(iseq, 0));
- })
+ Ok(start_ptr)
}
/// Stub a branch for a JIT-to-JIT call
@@ -170,14 +166,16 @@ fn register_with_perf(iseq_name: String, start_ptr: usize, code_size: usize) {
};
}
-/// Compile a JIT entry
-fn gen_entry(cb: &mut CodeBlock, iseq: IseqPtr, function_ptr: CodePtr) -> Result<CodePtr, CompileError> {
+/// Compile a shared JIT entry trampoline
+pub fn gen_entry_trampoline(cb: &mut CodeBlock) -> Result<CodePtr, CompileError> {
// Set up registers for CFP, EC, SP, and basic block arguments
let mut asm = Assembler::new();
- gen_entry_prologue(&mut asm, iseq);
+ gen_entry_prologue(&mut asm);
- // Jump to the first block using a call instruction
- asm.ccall(function_ptr.raw_ptr(cb), vec![]);
+ // Jump to the first block using a call instruction. This trampoline is used
+ // as rb_zjit_func_t in jit_exec(), which takes (EC, CFP, rb_jit_func_t).
+ // So C_ARG_OPNDS[2] is rb_jit_func_t, which is (EC, CFP) -> VALUE.
+ asm.ccall_reg(C_ARG_OPNDS[2], VALUE_BITS);
// Restore registers for CFP, EC, and SP after use
asm_comment!(asm, "return to the interpreter");
@@ -190,8 +188,7 @@ fn gen_entry(cb: &mut CodeBlock, iseq: IseqPtr, function_ptr: CodePtr) -> Result
let start_ptr = code_ptr.raw_addr(cb);
let end_ptr = cb.get_write_ptr().raw_addr(cb);
let code_size = end_ptr - start_ptr;
- let iseq_name = iseq_get_location(iseq, 0);
- register_with_perf(format!("entry for {iseq_name}"), start_ptr, code_size);
+ register_with_perf("ZJIT entry trampoline".into(), start_ptr, code_size);
}
Ok(code_ptr)
}
@@ -990,8 +987,8 @@ fn gen_load_field(asm: &mut Assembler, recv: Opnd, id: ID, offset: i32) -> Opnd
}
/// Compile an interpreter entry block to be inserted into an ISEQ
-fn gen_entry_prologue(asm: &mut Assembler, iseq: IseqPtr) {
- asm_comment!(asm, "ZJIT entry point: {}", iseq_get_location(iseq, 0));
+fn gen_entry_prologue(asm: &mut Assembler) {
+ asm_comment!(asm, "ZJIT entry trampoline");
// Save the registers we'll use for CFP, EP, SP
asm.frame_setup(lir::JIT_PRESERVED_REGS);
diff --git a/zjit/src/cruby.rs b/zjit/src/cruby.rs
index 631acbd863..db47385bc8 100644
--- a/zjit/src/cruby.rs
+++ b/zjit/src/cruby.rs
@@ -1071,7 +1071,7 @@ pub use manual_defs::*;
pub mod test_utils {
use std::{ptr::null, sync::Once};
- use crate::{options::{rb_zjit_call_threshold, rb_zjit_prepare_options, set_call_threshold, DEFAULT_CALL_THRESHOLD}, state::{rb_zjit_enabled_p, ZJITState}};
+ use crate::{options::{rb_zjit_call_threshold, rb_zjit_prepare_options, set_call_threshold, DEFAULT_CALL_THRESHOLD}, state::{rb_zjit_entry, ZJITState}};
use super::*;
@@ -1114,10 +1114,10 @@ pub mod test_utils {
}
// Set up globals for convenience
- ZJITState::init();
+ let zjit_entry = ZJITState::init();
// Enable zjit_* instructions
- unsafe { rb_zjit_enabled_p = true; }
+ unsafe { rb_zjit_entry = zjit_entry; }
}
/// Make sure the Ruby VM is set up and run a given callback with rb_protect()
diff --git a/zjit/src/state.rs b/zjit/src/state.rs
index c0e9e0b77c..3cb60cffcb 100644
--- a/zjit/src/state.rs
+++ b/zjit/src/state.rs
@@ -1,6 +1,6 @@
//! Runtime state of ZJIT.
-use crate::codegen::{gen_exit_trampoline, gen_exit_trampoline_with_counter, gen_function_stub_hit_trampoline};
+use crate::codegen::{gen_entry_trampoline, gen_exit_trampoline, gen_exit_trampoline_with_counter, gen_function_stub_hit_trampoline};
use crate::cruby::{self, rb_bug_panic_hook, rb_vm_insn_count, EcPtr, Qnil, rb_vm_insn_addr2opcode, rb_profile_frames, VALUE, VM_INSTRUCTION_SIZE, size_t, rb_gc_mark};
use crate::cruby_methods;
use crate::invariants::Invariants;
@@ -9,14 +9,16 @@ use crate::options::get_option;
use crate::stats::{Counters, InsnCounters, SideExitLocations};
use crate::virtualmem::CodePtr;
use std::collections::HashMap;
+use std::ptr::null;
+/// Shared trampoline to enter ZJIT. Not null when ZJIT is enabled.
#[allow(non_upper_case_globals)]
#[unsafe(no_mangle)]
-pub static mut rb_zjit_enabled_p: bool = false;
+pub static mut rb_zjit_entry: *const u8 = null();
/// Like rb_zjit_enabled_p, but for Rust code.
pub fn zjit_enabled_p() -> bool {
- unsafe { rb_zjit_enabled_p }
+ unsafe { rb_zjit_entry != null() }
}
/// Global state needed for code generation
@@ -65,8 +67,8 @@ pub struct ZJITState {
static mut ZJIT_STATE: Option<ZJITState> = None;
impl ZJITState {
- /// Initialize the ZJIT globals
- pub fn init() {
+ /// Initialize the ZJIT globals. Return the address of the JIT entry trampoline.
+ pub fn init() -> *const u8 {
let mut cb = {
use crate::options::*;
use crate::virtualmem::*;
@@ -79,6 +81,7 @@ impl ZJITState {
CodeBlock::new(mem_block.clone(), get_option!(dump_disasm))
};
+ let entry_trampoline = gen_entry_trampoline(&mut cb).unwrap().raw_ptr(&cb);
let exit_trampoline = gen_exit_trampoline(&mut cb).unwrap();
let function_stub_hit_trampoline = gen_function_stub_hit_trampoline(&mut cb).unwrap();
@@ -114,6 +117,8 @@ impl ZJITState {
let code_ptr = gen_exit_trampoline_with_counter(cb, exit_trampoline).unwrap();
ZJITState::get_instance().exit_trampoline_with_counter = code_ptr;
}
+
+ entry_trampoline
}
/// Return true if zjit_state has been initialized
@@ -252,7 +257,7 @@ pub extern "C" fn rb_zjit_init() {
let result = std::panic::catch_unwind(|| {
// Initialize ZJIT states
cruby::ids::init();
- ZJITState::init();
+ let zjit_entry = ZJITState::init();
// Install a panic hook for ZJIT
rb_bug_panic_hook();
@@ -261,8 +266,8 @@ pub extern "C" fn rb_zjit_init() {
unsafe { rb_vm_insn_count = 0; }
// ZJIT enabled and initialized successfully
- assert!(unsafe{ !rb_zjit_enabled_p });
- unsafe { rb_zjit_enabled_p = true; }
+ assert!(unsafe{ rb_zjit_entry == null() });
+ unsafe { rb_zjit_entry = zjit_entry; }
});
if result.is_err() {