From 98e6f5e4bc56b8b611d152a43500531478a6472d Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Fri, 6 Feb 2026 17:45:58 -0500 Subject: Enable ZJIT jit hooks for with_jit builtins --- jit_hook.rb | 6 ++--- zjit/bindgen/src/main.rs | 2 ++ zjit/src/cruby.rs | 6 +++++ zjit/src/cruby_bindings.inc.rs | 2 ++ zjit/src/hir/opt_tests.rs | 50 +++++++++++++++++++++++++++++++++++++----- zjit/src/hir/tests.rs | 48 ++++++++++++++++++++++++++++++++++++++++ zjit/src/state.rs | 8 +++++-- 7 files changed, 112 insertions(+), 10 deletions(-) diff --git a/jit_hook.rb b/jit_hook.rb index 346b716948..c605d6e26d 100644 --- a/jit_hook.rb +++ b/jit_hook.rb @@ -2,10 +2,10 @@ class Module # Internal helper for built-in initializations to define methods only when JIT is enabled. # This method is removed in jit_undef.rb. private def with_jit(&block) # :nodoc: - # ZJIT currently doesn't compile Array#each properly, so it's disabled for now. - if defined?(RubyVM::ZJIT) && false # TODO: remove `&& false` (Shopify/ruby#667) + if defined?(RubyVM::ZJIT) RubyVM::ZJIT.send(:add_jit_hook, block) - elsif defined?(RubyVM::YJIT) + end + if defined?(RubyVM::YJIT) RubyVM::YJIT.send(:add_jit_hook, block) end end diff --git a/zjit/bindgen/src/main.rs b/zjit/bindgen/src/main.rs index d71e75c444..dab1a6d929 100644 --- a/zjit/bindgen/src/main.rs +++ b/zjit/bindgen/src/main.rs @@ -167,6 +167,8 @@ fn main() { .allowlist_var("rb_cClass") .allowlist_var("rb_cRegexp") .allowlist_var("rb_cISeq") + .allowlist_var("rb_cRubyVM") + .allowlist_function("rb_const_get") .allowlist_type("ruby_fl_type") .allowlist_type("ruby_fl_ushift") diff --git a/zjit/src/cruby.rs b/zjit/src/cruby.rs index 8e569793a8..a47d9bf61f 100644 --- a/zjit/src/cruby.rs +++ b/zjit/src/cruby.rs @@ -1133,6 +1133,12 @@ pub mod test_utils { crate::cruby::ids::init(); // for ID! usages in tests } + // Call ZJIT hooks to install Ruby implementations of builtins like Array#each + unsafe { + let zjit = rb_const_get(rb_cRubyVM, rust_str_to_id("ZJIT")); + rb_funcallv(zjit, rust_str_to_id("call_jit_hooks"), 0, std::ptr::null()); + } + // Set up globals for convenience let zjit_entry = ZJITState::init(); diff --git a/zjit/src/cruby_bindings.inc.rs b/zjit/src/cruby_bindings.inc.rs index 77103b9930..f178e76728 100644 --- a/zjit/src/cruby_bindings.inc.rs +++ b/zjit/src/cruby_bindings.inc.rs @@ -1960,6 +1960,7 @@ unsafe extern "C" { pub fn rb_ivar_set(obj: VALUE, name: ID, val: VALUE) -> VALUE; pub fn rb_ivar_defined(obj: VALUE, name: ID) -> VALUE; pub fn rb_attr_get(obj: VALUE, name: ID) -> VALUE; + pub fn rb_const_get(space: VALUE, name: ID) -> VALUE; pub fn rb_class_allocate_instance(klass: VALUE) -> VALUE; pub fn rb_obj_equal(obj1: VALUE, obj2: VALUE) -> VALUE; pub fn rb_reg_new_ary(ary: VALUE, options: ::std::os::raw::c_int) -> VALUE; @@ -1982,6 +1983,7 @@ unsafe extern "C" { id: ID, ) -> *const rb_callable_method_entry_t; pub static mut rb_cISeq: VALUE; + pub static mut rb_cRubyVM: VALUE; pub static mut rb_mRubyVMFrozenCore: VALUE; pub static mut rb_block_param_proxy: VALUE; pub fn rb_vm_ep_local_ep(ep: *const VALUE) -> *const VALUE; diff --git a/zjit/src/hir/opt_tests.rs b/zjit/src/hir/opt_tests.rs index 4ce01c438a..8a172e6b12 100644 --- a/zjit/src/hir/opt_tests.rs +++ b/zjit/src/hir/opt_tests.rs @@ -823,11 +823,11 @@ mod hir_opt_tests { bb2(v8:BasicObject, v9:BasicObject): PatchPoint NoSingletonClass(C@0x1000) PatchPoint MethodRedefined(C@0x1000, fun_new_map@0x1008, cme:0x1010) - v23:ArraySubclass[class_exact:C] = GuardType v9, ArraySubclass[class_exact:C] - v24:BasicObject = CCallWithFrame v23, :C#fun_new_map@0x1038, block=0x1040 + v22:ArraySubclass[class_exact:C] = GuardType v9, ArraySubclass[class_exact:C] + v23:BasicObject = SendDirect v22, 0x1038, :fun_new_map (0x1048) v15:BasicObject = GetLocal :o, l0, EP@3 CheckInterrupts - Return v24 + Return v23 "); } @@ -6515,9 +6515,9 @@ mod hir_opt_tests { v11:ArrayExact = ArrayDup v10 PatchPoint NoSingletonClass(Array@0x1008) PatchPoint MethodRedefined(Array@0x1008, map@0x1010, cme:0x1018) - v21:BasicObject = CCallWithFrame v11, :Array#map@0x1040, block=0x1048 + v20:BasicObject = SendDirect v11, 0x1040, :map (0x1050) CheckInterrupts - Return v21 + Return v20 "); } @@ -12004,4 +12004,44 @@ mod hir_opt_tests { Return v99 "); } + + #[test] + fn test_array_each() { + eval("[1, 2, 3].each { |x| x }"); + assert_snapshot!(hir_string_proc("Array.instance_method(:each)"), @r" + fn each@: + bb0(): + EntryPoint interpreter + v1:BasicObject = LoadSelf + v2:NilClass = Const Value(nil) + Jump bb2(v1, v2) + bb1(v5:BasicObject): + EntryPoint JIT(0) + v6:NilClass = Const Value(nil) + Jump bb2(v5, v6) + bb2(v8:BasicObject, v9:NilClass): + v13:NilClass = Const Value(nil) + v15:TrueClass|NilClass = Defined yield, v13 + v17:CBool = Test v15 + IfFalse v17, bb3(v8, v9) + v35:Fixnum[0] = Const Value(0) + Jump bb7(v8, v35) + bb3(v23:BasicObject, v24:NilClass): + v28:BasicObject = InvokeBuiltin , v23 + CheckInterrupts + Return v28 + bb7(v48:BasicObject, v49:BasicObject): + v52:BasicObject = InvokeBuiltin rb_jit_ary_at_end, v48, v49 + v54:CBool = Test v52 + IfFalse v54, bb6(v48, v49) + CheckInterrupts + Return v48 + bb6(v67:BasicObject, v68:BasicObject): + v72:BasicObject = InvokeBuiltin rb_jit_ary_at, v67, v68 + v74:BasicObject = InvokeBlock, v72 # SendFallbackReason: Uncategorized(invokeblock) + v78:BasicObject = InvokeBuiltin rb_jit_fixnum_inc, v67, v68 + PatchPoint NoEPEscape(each) + Jump bb7(v67, v78) + "); + } } diff --git a/zjit/src/hir/tests.rs b/zjit/src/hir/tests.rs index f3ab0fa57c..8830b199c4 100644 --- a/zjit/src/hir/tests.rs +++ b/zjit/src/hir/tests.rs @@ -4069,6 +4069,54 @@ pub mod hir_build_tests { SideExit TooManyKeywordParameters "); } + + #[test] + fn test_array_each() { + assert_snapshot!(hir_string_proc("Array.instance_method(:each)"), @r" + fn each@: + bb0(): + EntryPoint interpreter + v1:BasicObject = LoadSelf + v2:NilClass = Const Value(nil) + Jump bb2(v1, v2) + bb1(v5:BasicObject): + EntryPoint JIT(0) + v6:NilClass = Const Value(nil) + Jump bb2(v5, v6) + bb2(v8:BasicObject, v9:NilClass): + v13:NilClass = Const Value(nil) + v15:TrueClass|NilClass = Defined yield, v13 + v17:CBool = Test v15 + v18:NilClass = RefineType v15, Falsy + IfFalse v17, bb3(v8, v9) + v20:TrueClass = RefineType v15, Truthy + Jump bb5(v8, v9) + bb3(v23:BasicObject, v24:NilClass): + v28:BasicObject = InvokeBuiltin , v23 + Jump bb4(v23, v24, v28) + bb4(v40:BasicObject, v41:NilClass, v42:BasicObject): + CheckInterrupts + Return v42 + bb5(v30:BasicObject, v31:NilClass): + v35:Fixnum[0] = Const Value(0) + Jump bb7(v30, v35) + bb7(v48:BasicObject, v49:BasicObject): + v52:BasicObject = InvokeBuiltin rb_jit_ary_at_end, v48, v49 + v54:CBool = Test v52 + v55:Falsy = RefineType v52, Falsy + IfFalse v54, bb6(v48, v49) + v57:Truthy = RefineType v52, Truthy + v59:NilClass = Const Value(nil) + CheckInterrupts + Return v48 + bb6(v67:BasicObject, v68:BasicObject): + v72:BasicObject = InvokeBuiltin rb_jit_ary_at, v67, v68 + v74:BasicObject = InvokeBlock, v72 # SendFallbackReason: Uncategorized(invokeblock) + v78:BasicObject = InvokeBuiltin rb_jit_fixnum_inc, v67, v68 + PatchPoint NoEPEscape(each) + Jump bb7(v67, v78) + "); + } } /// Test successor and predecessor set computations. diff --git a/zjit/src/state.rs b/zjit/src/state.rs index a807be3f12..04411e7efb 100644 --- a/zjit/src/state.rs +++ b/zjit/src/state.rs @@ -1,7 +1,7 @@ //! Runtime state of ZJIT. use crate::codegen::{gen_entry_trampoline, gen_exit_trampoline, gen_exit_trampoline_with_counter, gen_function_stub_hit_trampoline}; -use crate::cruby::{self, rb_bug_panic_hook, rb_vm_insn_count, src_loc, EcPtr, Qnil, Qtrue, rb_vm_insn_addr2opcode, rb_profile_frames, VALUE, VM_INSTRUCTION_SIZE, size_t, rb_gc_mark, with_vm_lock}; +use crate::cruby::{self, rb_bug_panic_hook, rb_vm_insn_count, src_loc, EcPtr, Qnil, Qtrue, rb_vm_insn_addr2opcode, rb_profile_frames, VALUE, VM_INSTRUCTION_SIZE, size_t, rb_gc_mark, with_vm_lock, rust_str_to_id, rb_funcallv, rb_const_get, rb_cRubyVM}; use crate::cruby_methods; use crate::invariants::Invariants; use crate::asm::CodeBlock; @@ -319,7 +319,11 @@ pub extern "C" fn rb_zjit_init(zjit_enabled: bool) { /// Enable ZJIT compilation. fn zjit_enable() { - // TODO: call RubyVM::ZJIT::call_jit_hooks here + // Call ZJIT hooks before enabling ZJIT to avoid compiling the hooks themselves + unsafe { + let zjit = rb_const_get(rb_cRubyVM, rust_str_to_id("ZJIT")); + rb_funcallv(zjit, rust_str_to_id("call_jit_hooks"), 0, std::ptr::null()); + } // Catch panics to avoid UB for unwinding into C frames. // See https://doc.rust-lang.org/nomicon/exception-safety.html -- cgit v1.2.3