diff options
| author | Max Bernstein <rubybugs@bernsteinbear.com> | 2025-10-22 11:19:08 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-10-22 11:19:08 -0700 |
| commit | ceed406958349ccd3d29d86ab5b4af9aaf4616e0 (patch) | |
| tree | 739cbf1892b4cb636efb99cae2a14f1c5e269b8d | |
| parent | f09e74ce2b2794571531c708ed684b47a74a2ce9 (diff) | |
ZJIT: Inline simple SendWithoutBlockDirect (#14888)
Copy the YJIT simple inliner except for the kwargs bit. It works great!
| -rw-r--r-- | zjit.rb | 1 |
| -rw-r--r-- | zjit/src/hir.rs | 441 |
| -rw-r--r-- | zjit/src/stats.rs | 1 |
3 files changed, 421 insertions, 22 deletions
@@ -176,6 +176,7 @@ class << RubyVM::ZJIT :optimized_send_count, :iseq_optimized_send_count, :inline_cfunc_optimized_send_count, + :inline_iseq_optimized_send_count, :non_variadic_cfunc_optimized_send_count, :variadic_cfunc_optimized_send_count, ], buf:, stats:, right_align: true, base: :send_count) diff --git a/zjit/src/hir.rs b/zjit/src/hir.rs index fbe99d40d3..7def0b090e 100644 --- a/zjit/src/hir.rs +++ b/zjit/src/hir.rs @@ -1358,6 +1358,72 @@ pub struct Function { profiles: Option<ProfileOracle>, } +/// The kind of a value an ISEQ returns +enum IseqReturn { + Value(VALUE), + LocalVariable(u32), + Receiver, +} + +unsafe extern "C" { + fn rb_simple_iseq_p(iseq: IseqPtr) -> bool; +} + +/// Return the ISEQ's return value if it consists of one simple instruction and leave. +fn iseq_get_return_value(iseq: IseqPtr, captured_opnd: Option<InsnId>, ci_flags: u32) -> Option<IseqReturn> { + // Expect only two instructions and one possible operand + // NOTE: If an ISEQ has an optional keyword parameter with a default value that requires + // computation, the ISEQ will always have more than two instructions and won't be inlined. + let iseq_size = unsafe { get_iseq_encoded_size(iseq) }; + if !(2..=3).contains(&iseq_size) { + return None; + } + + // Get the first two instructions + let first_insn = iseq_opcode_at_idx(iseq, 0); + let second_insn = iseq_opcode_at_idx(iseq, insn_len(first_insn as usize)); + + // Extract the return value if known + if second_insn != YARVINSN_leave { + return None; + } + match first_insn { + YARVINSN_getlocal_WC_0 => { + // Accept only cases where only positional arguments are used by both the callee and the caller. + // Keyword arguments may be specified by the callee or the caller but not used. 
+ if captured_opnd.is_some() + // Equivalent to `VM_CALL_ARGS_SIMPLE - VM_CALL_KWARG - has_block_iseq` + || ci_flags & ( + VM_CALL_ARGS_SPLAT + | VM_CALL_KW_SPLAT + | VM_CALL_ARGS_BLOCKARG + | VM_CALL_FORWARDING + ) != 0 + { + return None; + } + + let ep_offset = unsafe { *rb_iseq_pc_at_idx(iseq, 1) }.as_u32(); + let local_idx = ep_offset_to_local_idx(iseq, ep_offset); + + if unsafe { rb_simple_iseq_p(iseq) } { + return Some(IseqReturn::LocalVariable(local_idx.try_into().unwrap())); + } + + // TODO(max): Support only_kwparam case where the local_idx is a positional parameter + + return None; + } + YARVINSN_putnil => Some(IseqReturn::Value(Qnil)), + YARVINSN_putobject => Some(IseqReturn::Value(unsafe { *rb_iseq_pc_at_idx(iseq, 1) })), + YARVINSN_putobject_INT2FIX_0_ => Some(IseqReturn::Value(VALUE::fixnum_from_usize(0))), + YARVINSN_putobject_INT2FIX_1_ => Some(IseqReturn::Value(VALUE::fixnum_from_usize(1))), + // We don't support invokeblock for now. Such ISEQs are likely not used by blocks anyway. + YARVINSN_putself if captured_opnd.is_none() => Some(IseqReturn::Receiver), + _ => None, + } +} + impl Function { fn new(iseq: *const rb_iseq_t) -> Function { Function { @@ -2343,6 +2409,46 @@ impl Function { self.infer_types(); } + fn inline(&mut self) { + for block in self.rpo() { + let old_insns = std::mem::take(&mut self.blocks[block.0].insns); + assert!(self.blocks[block.0].insns.is_empty()); + for insn_id in old_insns { + match self.find(insn_id) { + // Reject block ISEQs to avoid autosplat and other block parameter complications. + Insn::SendWithoutBlockDirect { recv, iseq, cd, args, .. 
} => { + let call_info = unsafe { (*cd).ci }; + let ci_flags = unsafe { vm_ci_flag(call_info) }; + // .send call is not currently supported for builtins + if ci_flags & VM_CALL_OPT_SEND != 0 { + self.push_insn_id(block, insn_id); continue; + } + let Some(value) = iseq_get_return_value(iseq, None, ci_flags) else { + self.push_insn_id(block, insn_id); continue; + }; + match value { + IseqReturn::LocalVariable(idx) => { + self.push_insn(block, Insn::IncrCounter(Counter::inline_iseq_optimized_send_count)); + self.make_equal_to(insn_id, args[idx as usize]); + } + IseqReturn::Value(value) => { + self.push_insn(block, Insn::IncrCounter(Counter::inline_iseq_optimized_send_count)); + let replacement = self.push_insn(block, Insn::Const { val: Const::Value(value) }); + self.make_equal_to(insn_id, replacement); + } + IseqReturn::Receiver => { + self.push_insn(block, Insn::IncrCounter(Counter::inline_iseq_optimized_send_count)); + self.make_equal_to(insn_id, recv); + } + } + } + _ => { self.push_insn_id(block, insn_id); } + } + } + } + self.infer_types(); + } + fn optimize_getivar(&mut self) { for block in self.rpo() { let old_insns = std::mem::take(&mut self.blocks[block.0].insns); @@ -3208,6 +3314,8 @@ impl Function { // Function is assumed to have types inferred already self.type_specialize(); #[cfg(debug_assertions)] self.assert_validates(); + self.inline(); + #[cfg(debug_assertions)] self.assert_validates(); self.optimize_getivar(); #[cfg(debug_assertions)] self.assert_validates(); self.optimize_c_calls(); @@ -8980,15 +9088,14 @@ mod opt_tests { #[test] fn test_optimize_top_level_call_into_send_direct() { eval(" - def foo - end + def foo = [] def test foo end test; test "); assert_snapshot!(hir_string("test"), @r" - fn test@<compiled>:5: + fn test@<compiled>:4: bb0(): EntryPoint interpreter v1:BasicObject = LoadSelf @@ -9036,8 +9143,7 @@ mod opt_tests { #[test] fn test_optimize_private_top_level_call() { eval(" - def foo - end + def foo = [] private :foo def test foo @@ 
-9045,7 +9151,7 @@ mod opt_tests { test; test "); assert_snapshot!(hir_string("test"), @r" - fn test@<compiled>:6: + fn test@<compiled>:5: bb0(): EntryPoint interpreter v1:BasicObject = LoadSelf @@ -9094,15 +9200,14 @@ mod opt_tests { #[test] fn test_optimize_top_level_call_with_args_into_send_direct() { eval(" - def foo a, b - end + def foo(a, b) = [] def test foo 1, 2 end test; test "); assert_snapshot!(hir_string("test"), @r" - fn test@<compiled>:5: + fn test@<compiled>:4: bb0(): EntryPoint interpreter v1:BasicObject = LoadSelf @@ -9125,10 +9230,8 @@ mod opt_tests { #[test] fn test_optimize_top_level_sends_into_send_direct() { eval(" - def foo - end - def bar - end + def foo = [] + def bar = [] def test foo bar @@ -9136,7 +9239,7 @@ mod opt_tests { test; test "); assert_snapshot!(hir_string("test"), @r" - fn test@<compiled>:7: + fn test@<compiled>:5: bb0(): EntryPoint interpreter v1:BasicObject = LoadSelf @@ -10656,9 +10759,7 @@ mod opt_tests { fn test_send_direct_to_instance_method() { eval(" class C - def foo - 3 - end + def foo = [] end def test(c) = c.foo @@ -10668,7 +10769,7 @@ mod opt_tests { "); assert_snapshot!(hir_string("test"), @r" - fn test@<compiled>:8: + fn test@<compiled>:6: bb0(): EntryPoint interpreter v1:BasicObject = LoadSelf @@ -12097,7 +12198,7 @@ mod opt_tests { fn test_dont_optimize_array_aref_if_redefined() { eval(r##" class Array - def [](index); end + def [](index) = [] end def test = [4,5,6].freeze[10] "##); @@ -12126,7 +12227,7 @@ mod opt_tests { fn test_dont_optimize_array_max_if_redefined() { eval(r##" class Array - def max = 10 + def max = [] end def test = [4,5,6].max "##); @@ -12797,9 +12898,10 @@ mod opt_tests { PatchPoint MethodRedefined(Object@0x1000, foo@0x1008, cme:0x1010) PatchPoint NoSingletonClass(Object@0x1000) v19:HeapObject[class_exact*:Object@VALUE(0x1000)] = GuardType v6, HeapObject[class_exact*:Object@VALUE(0x1000)] - v20:BasicObject = SendWithoutBlockDirect v19, :foo (0x1038) + IncrCounter 
inline_iseq_optimized_send_count + v22:NilClass = Const Value(nil) CheckInterrupts - Return v20 + Return v22 "); } @@ -14617,4 +14719,299 @@ mod opt_tests { Return v25 "); } + + #[test] + fn test_inline_send_without_block_direct_putself() { + eval(r#" + def callee = self + def test = callee + test + "#); + assert_snapshot!(hir_string("test"), @r" + fn test@<compiled>:3: + bb0(): + EntryPoint interpreter + v1:BasicObject = LoadSelf + Jump bb2(v1) + bb1(v4:BasicObject): + EntryPoint JIT(0) + Jump bb2(v4) + bb2(v6:BasicObject): + PatchPoint MethodRedefined(Object@0x1000, callee@0x1008, cme:0x1010) + PatchPoint NoSingletonClass(Object@0x1000) + v19:HeapObject[class_exact*:Object@VALUE(0x1000)] = GuardType v6, HeapObject[class_exact*:Object@VALUE(0x1000)] + IncrCounter inline_iseq_optimized_send_count + CheckInterrupts + Return v19 + "); + } + + #[test] + fn test_inline_send_without_block_direct_putobject_string() { + eval(r#" + # frozen_string_literal: true + def callee = "abc" + def test = callee + test + "#); + assert_snapshot!(hir_string("test"), @r" + fn test@<compiled>:4: + bb0(): + EntryPoint interpreter + v1:BasicObject = LoadSelf + Jump bb2(v1) + bb1(v4:BasicObject): + EntryPoint JIT(0) + Jump bb2(v4) + bb2(v6:BasicObject): + PatchPoint MethodRedefined(Object@0x1000, callee@0x1008, cme:0x1010) + PatchPoint NoSingletonClass(Object@0x1000) + v19:HeapObject[class_exact*:Object@VALUE(0x1000)] = GuardType v6, HeapObject[class_exact*:Object@VALUE(0x1000)] + IncrCounter inline_iseq_optimized_send_count + v22:StringExact[VALUE(0x1038)] = Const Value(VALUE(0x1038)) + CheckInterrupts + Return v22 + "); + } + + #[test] + fn test_inline_send_without_block_direct_putnil() { + eval(r#" + def callee = nil + def test = callee + test + "#); + assert_snapshot!(hir_string("test"), @r" + fn test@<compiled>:3: + bb0(): + EntryPoint interpreter + v1:BasicObject = LoadSelf + Jump bb2(v1) + bb1(v4:BasicObject): + EntryPoint JIT(0) + Jump bb2(v4) + bb2(v6:BasicObject): + PatchPoint 
MethodRedefined(Object@0x1000, callee@0x1008, cme:0x1010) + PatchPoint NoSingletonClass(Object@0x1000) + v19:HeapObject[class_exact*:Object@VALUE(0x1000)] = GuardType v6, HeapObject[class_exact*:Object@VALUE(0x1000)] + IncrCounter inline_iseq_optimized_send_count + v22:NilClass = Const Value(nil) + CheckInterrupts + Return v22 + "); + } + + #[test] + fn test_inline_send_without_block_direct_putobject_true() { + eval(r#" + def callee = true + def test = callee + test + "#); + assert_snapshot!(hir_string("test"), @r" + fn test@<compiled>:3: + bb0(): + EntryPoint interpreter + v1:BasicObject = LoadSelf + Jump bb2(v1) + bb1(v4:BasicObject): + EntryPoint JIT(0) + Jump bb2(v4) + bb2(v6:BasicObject): + PatchPoint MethodRedefined(Object@0x1000, callee@0x1008, cme:0x1010) + PatchPoint NoSingletonClass(Object@0x1000) + v19:HeapObject[class_exact*:Object@VALUE(0x1000)] = GuardType v6, HeapObject[class_exact*:Object@VALUE(0x1000)] + IncrCounter inline_iseq_optimized_send_count + v22:TrueClass = Const Value(true) + CheckInterrupts + Return v22 + "); + } + + #[test] + fn test_inline_send_without_block_direct_putobject_false() { + eval(r#" + def callee = false + def test = callee + test + "#); + assert_snapshot!(hir_string("test"), @r" + fn test@<compiled>:3: + bb0(): + EntryPoint interpreter + v1:BasicObject = LoadSelf + Jump bb2(v1) + bb1(v4:BasicObject): + EntryPoint JIT(0) + Jump bb2(v4) + bb2(v6:BasicObject): + PatchPoint MethodRedefined(Object@0x1000, callee@0x1008, cme:0x1010) + PatchPoint NoSingletonClass(Object@0x1000) + v19:HeapObject[class_exact*:Object@VALUE(0x1000)] = GuardType v6, HeapObject[class_exact*:Object@VALUE(0x1000)] + IncrCounter inline_iseq_optimized_send_count + v22:FalseClass = Const Value(false) + CheckInterrupts + Return v22 + "); + } + + #[test] + fn test_inline_send_without_block_direct_putobject_zero() { + eval(r#" + def callee = 0 + def test = callee + test + "#); + assert_snapshot!(hir_string("test"), @r" + fn test@<compiled>:3: + bb0(): + 
EntryPoint interpreter + v1:BasicObject = LoadSelf + Jump bb2(v1) + bb1(v4:BasicObject): + EntryPoint JIT(0) + Jump bb2(v4) + bb2(v6:BasicObject): + PatchPoint MethodRedefined(Object@0x1000, callee@0x1008, cme:0x1010) + PatchPoint NoSingletonClass(Object@0x1000) + v19:HeapObject[class_exact*:Object@VALUE(0x1000)] = GuardType v6, HeapObject[class_exact*:Object@VALUE(0x1000)] + IncrCounter inline_iseq_optimized_send_count + v22:Fixnum[0] = Const Value(0) + CheckInterrupts + Return v22 + "); + } + + #[test] + fn test_inline_send_without_block_direct_putobject_one() { + eval(r#" + def callee = 1 + def test = callee + test + "#); + assert_snapshot!(hir_string("test"), @r" + fn test@<compiled>:3: + bb0(): + EntryPoint interpreter + v1:BasicObject = LoadSelf + Jump bb2(v1) + bb1(v4:BasicObject): + EntryPoint JIT(0) + Jump bb2(v4) + bb2(v6:BasicObject): + PatchPoint MethodRedefined(Object@0x1000, callee@0x1008, cme:0x1010) + PatchPoint NoSingletonClass(Object@0x1000) + v19:HeapObject[class_exact*:Object@VALUE(0x1000)] = GuardType v6, HeapObject[class_exact*:Object@VALUE(0x1000)] + IncrCounter inline_iseq_optimized_send_count + v22:Fixnum[1] = Const Value(1) + CheckInterrupts + Return v22 + "); + } + + #[test] + fn test_inline_send_without_block_direct_parameter() { + eval(r#" + def callee(x) = x + def test = callee 3 + test + "#); + assert_snapshot!(hir_string("test"), @r" + fn test@<compiled>:3: + bb0(): + EntryPoint interpreter + v1:BasicObject = LoadSelf + Jump bb2(v1) + bb1(v4:BasicObject): + EntryPoint JIT(0) + Jump bb2(v4) + bb2(v6:BasicObject): + v10:Fixnum[3] = Const Value(3) + PatchPoint MethodRedefined(Object@0x1000, callee@0x1008, cme:0x1010) + PatchPoint NoSingletonClass(Object@0x1000) + v20:HeapObject[class_exact*:Object@VALUE(0x1000)] = GuardType v6, HeapObject[class_exact*:Object@VALUE(0x1000)] + IncrCounter inline_iseq_optimized_send_count + CheckInterrupts + Return v10 + "); + } + + #[test] + fn test_inline_send_without_block_direct_last_parameter() { + 
eval(r#" + def callee(x, y, z) = z + def test = callee 1, 2, 3 + test + "#); + assert_snapshot!(hir_string("test"), @r" + fn test@<compiled>:3: + bb0(): + EntryPoint interpreter + v1:BasicObject = LoadSelf + Jump bb2(v1) + bb1(v4:BasicObject): + EntryPoint JIT(0) + Jump bb2(v4) + bb2(v6:BasicObject): + v10:Fixnum[1] = Const Value(1) + v11:Fixnum[2] = Const Value(2) + v12:Fixnum[3] = Const Value(3) + PatchPoint MethodRedefined(Object@0x1000, callee@0x1008, cme:0x1010) + PatchPoint NoSingletonClass(Object@0x1000) + v22:HeapObject[class_exact*:Object@VALUE(0x1000)] = GuardType v6, HeapObject[class_exact*:Object@VALUE(0x1000)] + IncrCounter inline_iseq_optimized_send_count + CheckInterrupts + Return v12 + "); + } + + #[test] + fn test_inline_symbol_to_sym() { + eval(r#" + def test(o) = o.to_sym + test :foo + "#); + assert_snapshot!(hir_string("test"), @r" + fn test@<compiled>:2: + bb0(): + EntryPoint interpreter + v1:BasicObject = LoadSelf + v2:BasicObject = GetLocal l0, SP@4 + Jump bb2(v1, v2) + bb1(v5:BasicObject, v6:BasicObject): + EntryPoint JIT(0) + Jump bb2(v5, v6) + bb2(v8:BasicObject, v9:BasicObject): + PatchPoint MethodRedefined(Symbol@0x1000, to_sym@0x1008, cme:0x1010) + v21:StaticSymbol = GuardType v9, StaticSymbol + IncrCounter inline_iseq_optimized_send_count + CheckInterrupts + Return v21 + "); + } + + #[test] + fn test_inline_integer_to_i() { + eval(r#" + def test(o) = o.to_i + test 5 + "#); + assert_snapshot!(hir_string("test"), @r" + fn test@<compiled>:2: + bb0(): + EntryPoint interpreter + v1:BasicObject = LoadSelf + v2:BasicObject = GetLocal l0, SP@4 + Jump bb2(v1, v2) + bb1(v5:BasicObject, v6:BasicObject): + EntryPoint JIT(0) + Jump bb2(v5, v6) + bb2(v8:BasicObject, v9:BasicObject): + PatchPoint MethodRedefined(Integer@0x1000, to_i@0x1008, cme:0x1010) + v21:Fixnum = GuardType v9, Fixnum + IncrCounter inline_iseq_optimized_send_count + CheckInterrupts + Return v21 + "); + } } diff --git a/zjit/src/stats.rs b/zjit/src/stats.rs index 
913a72fa56..4dd87d269a 100644 --- a/zjit/src/stats.rs +++ b/zjit/src/stats.rs @@ -179,6 +179,7 @@ make_counters! { optimized_send { iseq_optimized_send_count, inline_cfunc_optimized_send_count, + inline_iseq_optimized_send_count, non_variadic_cfunc_optimized_send_count, variadic_cfunc_optimized_send_count, } |
