summaryrefslogtreecommitdiff
path: root/zjit/src/state.rs
diff options
context:
space:
mode:
Diffstat (limited to 'zjit/src/state.rs')
-rw-r--r--zjit/src/state.rs541
1 files changed, 541 insertions, 0 deletions
diff --git a/zjit/src/state.rs b/zjit/src/state.rs
new file mode 100644
index 0000000000..da09d09314
--- /dev/null
+++ b/zjit/src/state.rs
@@ -0,0 +1,541 @@
+//! Runtime state of ZJIT.
+
+use crate::codegen::{gen_entry_trampoline, gen_exit_trampoline, gen_function_stub_hit_trampoline, gen_materialize_exit_trampoline, gen_materialize_exit_trampoline_with_counter};
+use crate::cruby::{self, rb_bug_panic_hook, rb_vm_insn_count, src_loc, EcPtr, Qnil, Qtrue, rb_profile_frames, rb_profile_frame_full_label, rb_profile_frame_absolute_path, rb_profile_frame_path, VALUE, VM_INSTRUCTION_SIZE, with_vm_lock, rust_str_to_id, rb_funcallv, rb_const_get, rb_cRubyVM};
+use crate::cruby_methods;
+use cruby::{ID, rb_callable_method_entry, get_def_method_serial, rb_gc_register_mark_object, ruby_str_to_rust_string_result};
+use std::sync::atomic::Ordering;
+use crate::invariants::Invariants;
+use crate::asm::CodeBlock;
+use crate::options::{get_option, rb_zjit_prepare_options};
+use crate::jit_frame::JITFrame;
+use crate::stats::{Counters, InsnCounters, PerfettoTracer};
+use crate::virtualmem::CodePtr;
+use std::sync::atomic::AtomicUsize;
+use std::collections::HashMap;
+use std::ptr::null;
+
+/// Shared trampoline to enter ZJIT. Not null when ZJIT is enabled.
+#[allow(non_upper_case_globals)]
+#[unsafe(no_mangle)]
+pub static mut rb_zjit_entry: *const u8 = null();
+
+/// Like rb_zjit_enabled_p, but for Rust code.
+pub fn zjit_enabled_p() -> bool {
+ unsafe { rb_zjit_entry != null() }
+}
+
+/// Global state needed for code generation
+pub struct ZJITState {
+ /// Inline code block (fast path)
+ code_block: CodeBlock,
+
+ /// ZJIT statistics
+ counters: Counters,
+
+ /// Side-exit counters
+ exit_counters: InsnCounters,
+
+ /// Send fallback counters
+ send_fallback_counters: InsnCounters,
+
+ /// Assumptions that require invalidation
+ invariants: Invariants,
+
+ /// Assert successful compilation if set to true
+ assert_compiles: bool,
+
+ /// Properties of core library methods
+ method_annotations: cruby_methods::Annotations,
+
+ /// Trampoline to side-exit without restoring PC or the stack
+ exit_trampoline: CodePtr,
+
+ /// Trampoline to materialize JIT frames before side-exiting
+ materialize_exit_trampoline: CodePtr,
+
+ /// Trampoline to materialize JIT frames and increment exit_compilation_failure
+ materialize_exit_trampoline_with_counter: CodePtr,
+
+ /// Trampoline to call function_stub_hit
+ function_stub_hit_trampoline: CodePtr,
+
+ /// Counter pointers for full frame C functions
+ full_frame_cfunc_counter_pointers: HashMap<String, Box<u64>>,
+
+ /// Counter pointers for un-annotated C functions
+ not_annotated_frame_cfunc_counter_pointers: HashMap<String, Box<u64>>,
+
+ /// Counter pointers for all calls to any kind of C function from JIT code
+ ccall_counter_pointers: HashMap<String, Box<u64>>,
+
+ /// Counter pointers for access counts of ISEQs accessed by JIT code
+ iseq_calls_count_pointers: HashMap<String, Box<u64>>,
+
+ /// Perfetto tracer for --zjit-trace-exits
+ perfetto_tracer: Option<PerfettoTracer>,
+
+ /// Frame metadata for ISEQ and C calls that are known at compile time
+ jit_frames: Vec<*mut JITFrame>,
+}
+
+/// Tracks the initialization progress
+enum InitializationState {
+ Uninitialized,
+
+ /// At boot time, rb_zjit_init will be called regardless of whether
+ /// ZJIT is enabled, in this phase we initialize any states that must
+ /// be captured at during boot.
+ Initialized(cruby_methods::Annotations),
+
+ /// When ZJIT is enabled, either during boot with `--zjit`, or lazily
+ /// at a later time with `RubyVM::ZJIT.enable`, we perform the rest
+ /// of the initialization steps and produce the `ZJITState` instance.
+ Enabled(ZJITState),
+
+ /// Indicates that ZJITState::init has panicked. Should never be
+ /// encountered in practice since we abort immediately when that
+ /// happens.
+ Panicked,
+}
+
+/// Private singleton instance of the codegen globals
+static mut ZJIT_STATE: InitializationState = InitializationState::Uninitialized;
+
+impl ZJITState {
+ /// Initialize the ZJIT globals. Return the address of the JIT entry trampoline.
+ pub fn init() -> *const u8 {
+ use InitializationState::*;
+
+ let initialization_state = unsafe {
+ std::mem::replace(&mut ZJIT_STATE, Panicked)
+ };
+
+ let Initialized(method_annotations) = initialization_state else {
+ panic!("rb_zjit_init was never called");
+ };
+
+ let mut cb = {
+ use crate::options::*;
+ use crate::virtualmem::*;
+ use std::rc::Rc;
+ use std::cell::RefCell;
+
+ let mem_block = VirtualMem::alloc(get_option!(exec_mem_bytes), Some(get_option!(mem_bytes)));
+ let mem_block = Rc::new(RefCell::new(mem_block));
+
+ CodeBlock::new(mem_block.clone(), get_option_ref!(dump_disasm).is_some())
+ };
+
+ let entry_trampoline = gen_entry_trampoline(&mut cb).unwrap().raw_ptr(&cb);
+ let exit_trampoline = gen_exit_trampoline(&mut cb).unwrap();
+ let materialize_exit_trampoline = gen_materialize_exit_trampoline(&mut cb, exit_trampoline).unwrap();
+ let function_stub_hit_trampoline = gen_function_stub_hit_trampoline(&mut cb).unwrap();
+
+ let perfetto_tracer = if get_option!(trace_side_exits).is_some() || get_option!(trace_compiles) || get_option!(trace_invalidation) {
+ Some(PerfettoTracer::new())
+ } else {
+ None
+ };
+
+ // Initialize the codegen globals instance
+ let zjit_state = ZJITState {
+ code_block: cb,
+ counters: Counters::default(),
+ exit_counters: [0; VM_INSTRUCTION_SIZE as usize],
+ send_fallback_counters: [0; VM_INSTRUCTION_SIZE as usize],
+ invariants: Invariants::default(),
+ assert_compiles: false,
+ method_annotations,
+ exit_trampoline,
+ materialize_exit_trampoline,
+ materialize_exit_trampoline_with_counter: materialize_exit_trampoline,
+ function_stub_hit_trampoline,
+ full_frame_cfunc_counter_pointers: HashMap::new(),
+ not_annotated_frame_cfunc_counter_pointers: HashMap::new(),
+ ccall_counter_pointers: HashMap::new(),
+ iseq_calls_count_pointers: HashMap::new(),
+ perfetto_tracer,
+ jit_frames: vec![],
+ };
+ unsafe { ZJIT_STATE = Enabled(zjit_state); }
+
+ // With --zjit-stats, use a different trampoline on function stub exits
+ // to count exit_compilation_failure. Note that the trampoline code depends
+ // on the counter, so ZJIT_STATE needs to be initialized first.
+ if get_option!(stats) {
+ let cb = ZJITState::get_code_block();
+ let code_ptr = gen_materialize_exit_trampoline_with_counter(cb, materialize_exit_trampoline).unwrap();
+ ZJITState::get_instance().materialize_exit_trampoline_with_counter = code_ptr;
+ }
+
+ entry_trampoline
+ }
+
+ /// Return true if zjit_state has been initialized
+ pub fn has_instance() -> bool {
+ matches!(unsafe { &ZJIT_STATE }, InitializationState::Enabled(_))
+ }
+
+ /// Get a mutable reference to the codegen globals instance
+ fn get_instance() -> &'static mut ZJITState {
+ if let InitializationState::Enabled(instance) = unsafe { &mut ZJIT_STATE } {
+ instance
+ } else {
+ panic!("ZJITState::get_instance called when ZJIT is not enabled")
+ }
+ }
+
+ /// Get a mutable reference to the inline code block
+ pub fn get_code_block() -> &'static mut CodeBlock {
+ &mut ZJITState::get_instance().code_block
+ }
+
+ /// Get a mutable reference to the invariants
+ pub fn get_invariants() -> &'static mut Invariants {
+ &mut ZJITState::get_instance().invariants
+ }
+
+ pub fn get_jit_frames() -> &'static mut Vec<*mut JITFrame> {
+ &mut ZJITState::get_instance().jit_frames
+ }
+
+ pub fn get_method_annotations() -> &'static cruby_methods::Annotations {
+ &ZJITState::get_instance().method_annotations
+ }
+
+ /// Return true if successful compilation should be asserted
+ pub fn assert_compiles_enabled() -> bool {
+ ZJITState::get_instance().assert_compiles
+ }
+
+ /// Start asserting successful compilation
+ pub fn enable_assert_compiles() {
+ let instance = ZJITState::get_instance();
+ instance.assert_compiles = true;
+ }
+
+ /// Stop asserting successful compilation
+ pub fn disable_assert_compiles() {
+ let instance = ZJITState::get_instance();
+ instance.assert_compiles = false;
+ }
+
+ /// Get a mutable reference to counters for ZJIT stats
+ pub fn get_counters() -> &'static mut Counters {
+ &mut ZJITState::get_instance().counters
+ }
+
+ /// Get a mutable reference to side-exit counters
+ pub fn get_exit_counters() -> &'static mut InsnCounters {
+ &mut ZJITState::get_instance().exit_counters
+ }
+
+ /// Get a mutable reference to fallback counters
+ pub fn get_send_fallback_counters() -> &'static mut InsnCounters {
+ &mut ZJITState::get_instance().send_fallback_counters
+ }
+
+ /// Get a mutable reference to full frame cfunc counter pointers
+ pub fn get_not_inlined_cfunc_counter_pointers() -> &'static mut HashMap<String, Box<u64>> {
+ &mut ZJITState::get_instance().full_frame_cfunc_counter_pointers
+ }
+
+ /// Get a mutable reference to non-annotated cfunc counter pointers
+ pub fn get_not_annotated_cfunc_counter_pointers() -> &'static mut HashMap<String, Box<u64>> {
+ &mut ZJITState::get_instance().not_annotated_frame_cfunc_counter_pointers
+ }
+
+ /// Get a mutable reference to ccall counter pointers
+ pub fn get_ccall_counter_pointers() -> &'static mut HashMap<String, Box<u64>> {
+ &mut ZJITState::get_instance().ccall_counter_pointers
+ }
+
+ /// Get a mutable reference to iseq access count pointers
+ pub fn get_iseq_calls_count_pointers() -> &'static mut HashMap<String, Box<u64>> {
+ &mut ZJITState::get_instance().iseq_calls_count_pointers
+ }
+
+ /// Was --zjit-save-compiled-iseqs specified?
+ pub fn should_log_compiled_iseqs() -> bool {
+ get_option!(log_compiled_iseqs).is_some()
+ }
+
+ /// Log the name of a compiled ISEQ to the file specified in options.log_compiled_iseqs
+ pub fn log_compile(iseq_name: String) {
+ assert!(ZJITState::should_log_compiled_iseqs());
+ let filename = get_option!(log_compiled_iseqs).as_ref().unwrap();
+ use std::io::Write;
+ let mut file = match std::fs::OpenOptions::new().create(true).append(true).open(filename) {
+ Ok(f) => f,
+ Err(e) => {
+ eprintln!("ZJIT: Failed to create file '{}': {}", filename.display(), e);
+ return;
+ }
+ };
+ if let Err(e) = writeln!(file, "{iseq_name}") {
+ eprintln!("ZJIT: Failed to write to file '{}': {}", filename.display(), e);
+ }
+ }
+
+ /// Check if we are allowed to compile a given ISEQ based on --zjit-allowed-iseqs
+ pub fn can_compile_iseq(iseq: cruby::IseqPtr) -> bool {
+ if let Some(ref allowed_iseqs) = get_option!(allowed_iseqs) {
+ let name = cruby::iseq_get_location(iseq, 0);
+ allowed_iseqs.contains(&name)
+ } else {
+ true // If no restrictions, allow all ISEQs
+ }
+ }
+
+ /// Return a code pointer to the side-exit trampoline
+ pub fn get_exit_trampoline() -> CodePtr {
+ ZJITState::get_instance().exit_trampoline
+ }
+
+ /// Return a code pointer to the materialize_exit trampoline
+ pub fn get_materialize_exit_trampoline() -> CodePtr {
+ ZJITState::get_instance().materialize_exit_trampoline
+ }
+
+ /// Return a code pointer to the materialize_exit trampoline for function stubs
+ pub fn get_materialize_exit_trampoline_with_counter() -> CodePtr {
+ ZJITState::get_instance().materialize_exit_trampoline_with_counter
+ }
+
+ /// Return a code pointer to the function stub hit trampoline
+ pub fn get_function_stub_hit_trampoline() -> CodePtr {
+ ZJITState::get_instance().function_stub_hit_trampoline
+ }
+
+ /// Get a mutable reference to the Perfetto tracer
+ pub fn get_tracer() -> Option<&'static mut PerfettoTracer> {
+ if !ZJITState::has_instance() { return None; }
+ ZJITState::get_instance().perfetto_tracer.as_mut()
+ }
+}
+
+/// The `::RubyVM::ZJIT` module.
+pub static ZJIT_MODULE: AtomicUsize = AtomicUsize::new(!0);
+/// Serial of the canonical version of `induce_side_exit!` right after VM boot.
+pub static INDUCE_SIDE_EXIT_SERIAL: AtomicUsize = AtomicUsize::new(!0);
+/// Serial of the canonical version of `induce_compile_failure!` right after VM boot.
+pub static INDUCE_COMPILE_FAILURE_SERIAL: AtomicUsize = AtomicUsize::new(!0);
+/// Serial of the canonical version of `induce_breakpoint!` right after VM boot.
+pub static INDUCE_BREAKPOINT_SERIAL: AtomicUsize = AtomicUsize::new(!0);
+
+/// Check if a method, `method_id`, currently exists on `ZJIT.singleton_class` and has the `expected_serial`.
+pub fn zjit_module_method_match_serial(method_id: ID, expected_serial: &AtomicUsize) -> bool {
+ let zjit_module_singleton = VALUE(ZJIT_MODULE.load(Ordering::Relaxed)).class_of();
+ let cme = unsafe { rb_callable_method_entry(zjit_module_singleton, method_id) };
+ if cme.is_null() {
+ false
+ } else {
+ let serial = unsafe { get_def_method_serial((*cme).def) };
+ serial == expected_serial.load(std::sync::atomic::Ordering::Relaxed)
+ }
+}
+
+/// Initialize IDs and annotate builtin C method entries.
+/// Must be called at boot before ruby_init_prelude() since the prelude
+/// could redefine core methods (e.g. Kernel.prepend via bundler).
+#[unsafe(no_mangle)]
+pub extern "C" fn rb_zjit_init_builtin_cmes() {
+ use InitializationState::*;
+
+ debug_assert!(
+ matches!(unsafe { &ZJIT_STATE }, Uninitialized),
+ "rb_zjit_init_builtin_cmes should only be called once during boot",
+ );
+
+ cruby::ids::init();
+
+ let method_annotations = cruby_methods::init();
+
+ unsafe { ZJIT_STATE = Initialized(method_annotations); }
+
+ // Boot time setup for compiler directives
+ unsafe {
+ let zjit_module = rb_const_get(rb_cRubyVM, rust_str_to_id("ZJIT"));
+
+ let cme = rb_callable_method_entry(zjit_module.class_of(), ID!(induce_side_exit_bang));
+ assert!(! cme.is_null(), "RubyVM::ZJIT.induce_side_exit! should exist on boot");
+ let serial = get_def_method_serial((*cme).def) ;
+ INDUCE_SIDE_EXIT_SERIAL.store(serial, Ordering::Relaxed);
+
+ let cme = rb_callable_method_entry(zjit_module.class_of(), ID!(induce_compile_failure_bang));
+ assert!(! cme.is_null(), "RubyVM::ZJIT.induce_compile_failure! should exist on boot");
+ let serial = get_def_method_serial((*cme).def) ;
+ INDUCE_COMPILE_FAILURE_SERIAL.store(serial, Ordering::Relaxed);
+
+ let cme = rb_callable_method_entry(zjit_module.class_of(), ID!(induce_breakpoint_bang));
+ assert!(! cme.is_null(), "RubyVM::ZJIT.induce_breakpoint! should exist on boot");
+ let serial = get_def_method_serial((*cme).def) ;
+ INDUCE_BREAKPOINT_SERIAL.store(serial, Ordering::Relaxed);
+
+ // Root and pin the module since we'll be doing object identity comparisons.
+ ZJIT_MODULE.store(zjit_module.0, Ordering::Relaxed);
+ rb_gc_register_mark_object(zjit_module);
+ }
+}
+
+/// Initialize ZJIT at boot. This is called even if ZJIT is disabled.
+#[unsafe(no_mangle)]
+pub extern "C" fn rb_zjit_init(zjit_enabled: bool) {
+ // If --zjit, enable ZJIT immediately
+ if zjit_enabled {
+ zjit_enable();
+ }
+}
+
+/// Enable ZJIT compilation.
+fn zjit_enable() {
+ // Call ZJIT hooks before enabling ZJIT to avoid compiling the hooks themselves
+ unsafe {
+ let zjit = rb_const_get(rb_cRubyVM, rust_str_to_id("ZJIT"));
+ rb_funcallv(zjit, rust_str_to_id("call_jit_hooks"), 0, std::ptr::null());
+ }
+
+ // Catch panics to avoid UB for unwinding into C frames.
+ // See https://doc.rust-lang.org/nomicon/exception-safety.html
+ let result = std::panic::catch_unwind(|| {
+ // Initialize ZJIT states
+ let zjit_entry = ZJITState::init();
+
+ // Install a panic hook for ZJIT
+ rb_bug_panic_hook();
+
+ // Discard the instruction count for boot which we never compile
+ unsafe { rb_vm_insn_count = 0; }
+
+ // ZJIT enabled and initialized successfully
+ assert!(unsafe{ rb_zjit_entry == null() });
+ unsafe { rb_zjit_entry = zjit_entry; }
+ });
+
+ if result.is_err() {
+ println!("ZJIT: zjit_enable() panicked. Aborting.");
+ std::process::abort();
+ }
+}
+
+/// Enable ZJIT compilation, returning Qtrue if ZJIT was previously disabled
+#[unsafe(no_mangle)]
+pub extern "C" fn rb_zjit_enable(_ec: EcPtr, _self: VALUE) -> VALUE {
+ with_vm_lock(src_loc!(), || {
+ // Options would not have been initialized during boot if no flags were specified
+ rb_zjit_prepare_options();
+
+ // Initialize and enable ZJIT
+ zjit_enable();
+
+ // Add "+ZJIT" to RUBY_DESCRIPTION
+ unsafe {
+ unsafe extern "C" {
+ fn ruby_set_zjit_description();
+ }
+ ruby_set_zjit_description();
+ }
+
+ Qtrue
+ })
+}
+
+/// Assert that any future ZJIT compilation will return a function pointer (not fail to compile)
+#[unsafe(no_mangle)]
+pub extern "C" fn rb_zjit_assert_compiles(_ec: EcPtr, _self: VALUE) -> VALUE {
+ ZJITState::enable_assert_compiles();
+ Qnil
+}
+
+/// Resolve a profile frame VALUE to a human-readable "label (path)" string.
+fn resolve_frame_label(frame: VALUE) -> String {
+ unsafe {
+ let label_str = ruby_str_to_rust_string_result(rb_profile_frame_full_label(frame)).unwrap_or("<unknown>".into());
+
+ let path = rb_profile_frame_absolute_path(frame);
+ let path = if path.nil_p() { rb_profile_frame_path(frame) } else { path };
+ let path_str = ruby_str_to_rust_string_result(path).unwrap_or("<unknown>".into());
+
+ format!("{label_str} ({path_str})")
+ }
+}
+
+/// Record a backtrace with ZJIT side exits as a Perfetto trace event
+#[unsafe(no_mangle)]
+pub extern "C" fn rb_zjit_record_exit_stack(reason: *const std::ffi::c_char) {
+ if !zjit_enabled_p() || get_option!(trace_side_exits).is_none() {
+ return;
+ }
+
+ let tracer = match ZJITState::get_tracer() {
+ Some(t) => t,
+ None => return,
+ };
+
+ // When `trace_side_exits_sample_interval` is non-zero, apply sampling.
+ if get_option!(trace_side_exits_sample_interval) != 0 {
+ if tracer.skipped_samples < get_option!(trace_side_exits_sample_interval) {
+ tracer.skipped_samples += 1;
+ return;
+ } else {
+ tracer.skipped_samples = 0;
+ }
+ }
+
+ // Collect profile frames
+ let frames = capture_ruby_frames();
+
+ // Get the reason string
+ let reason_str = if reason.is_null() {
+ "unknown"
+ } else {
+ unsafe { std::ffi::CStr::from_ptr(reason).to_str().unwrap_or("unknown") }
+ };
+
+ tracer.write_event("side_exit", reason_str, &frames);
+}
+
+/// Wrap a closure in a Perfetto duration event with category "invalidation"
+/// and a Ruby backtrace captured on the begin event.
+pub fn trace_invalidation<F, R>(reason: &str, func: F) -> R where F: FnOnce() -> R {
+ if !get_option!(trace_invalidation) {
+ return func();
+ }
+
+ // Capture backtrace and emit begin event before patching
+ let frames = capture_ruby_frames();
+ if let Some(tracer) = ZJITState::get_tracer() {
+ let ts = tracer.elapsed_ns();
+ tracer.write_duration_begin("invalidation", reason, ts, &frames);
+ }
+
+ let result = func();
+
+ if let Some(tracer) = ZJITState::get_tracer() {
+ let ts = tracer.elapsed_ns();
+ tracer.write_duration_end("invalidation", reason, ts);
+ }
+ result
+}
+
+/// Capture the current Ruby call stack as human-readable frame labels.
+fn capture_ruby_frames() -> Vec<String> {
+ const BUFF_LEN: usize = 2048;
+ let mut frames_buffer = vec![VALUE(0_usize); BUFF_LEN];
+ let mut lines_buffer = vec![0i32; BUFF_LEN];
+
+ let stack_length = unsafe {
+ rb_profile_frames(
+ 0,
+ BUFF_LEN as i32,
+ frames_buffer.as_mut_ptr(),
+ lines_buffer.as_mut_ptr(),
+ )
+ };
+
+ // Resolve each frame to a human-readable string (top frame first)
+ (0..stack_length as usize)
+ .map(|i| resolve_frame_label(frames_buffer[i]))
+ .collect()
+}