summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMax Bernstein <rubybugs@bernsteinbear.com>2025-10-22 16:26:37 -0700
committerGitHub <noreply@github.com>2025-10-22 23:26:37 +0000
commitae767b6ca859bfc9b18e964494052c7e2e5a41df (patch)
tree366a92b24d70679f37829aabcb5fca3987af6ba9
parenta763e6dd484951759b1b6cb7022b99bdf192895d (diff)
ZJIT: Inline Kernel#block_given? (#14914)
Fix https://github.com/Shopify/ruby/issues/832
-rw-r--r--test/ruby/test_zjit.rb21
-rw-r--r--zjit/src/codegen.rs16
-rw-r--r--zjit/src/cruby_methods.rs8
-rw-r--r--zjit/src/hir.rs106
4 files changed, 148 insertions, 3 deletions
diff --git a/test/ruby/test_zjit.rb b/test/ruby/test_zjit.rb
index e151a022d1..44f010d056 100644
--- a/test/ruby/test_zjit.rb
+++ b/test/ruby/test_zjit.rb
@@ -1962,6 +1962,27 @@ class TestZJIT < Test::Unit::TestCase
}, call_threshold: 2
end
+ def test_block_given_p
+ assert_compiles "false", "block_given?"
+ assert_compiles '[false, false, true]', %q{
+ def test = block_given?
+ [test, test, test{}]
+ }, call_threshold: 2, insns: [:opt_send_without_block]
+ end
+
+ def test_block_given_p_from_block
+ # This will do some EP hopping to find the local EP,
+ # so it's slightly different than doing it outside of a block.
+
+ assert_compiles '[false, false, true]', %q{
+ def test
+ yield_self { yield_self { block_given? } }
+ end
+
+ [test, test, test{}]
+ }, call_threshold: 2
+ end
+
def test_invokeblock_without_block_after_jit_call
assert_compiles '"no block given (yield)"', %q{
def test(*arr, &b)
diff --git a/zjit/src/codegen.rs b/zjit/src/codegen.rs
index f0ac9b5c7b..848249d774 100644
--- a/zjit/src/codegen.rs
+++ b/zjit/src/codegen.rs
@@ -449,6 +449,7 @@ fn gen_insn(cb: &mut CodeBlock, jit: &mut JITState, asm: &mut Assembler, functio
Insn::LoadSelf => gen_load_self(),
&Insn::LoadIvarEmbedded { self_val, id, index } => gen_load_ivar_embedded(asm, opnd!(self_val), id, index),
&Insn::LoadIvarExtended { self_val, id, index } => gen_load_ivar_extended(asm, opnd!(self_val), id, index),
+ &Insn::IsBlockGiven => gen_is_block_given(jit, asm),
&Insn::ArrayMax { state, .. }
| &Insn::FixnumDiv { state, .. }
| &Insn::Throw { state, .. }
@@ -525,6 +526,8 @@ fn gen_defined(jit: &JITState, asm: &mut Assembler, op_type: usize, obj: VALUE,
// `yield` goes to the block handler stowed in the "local" iseq which is
// the current iseq or a parent. Only the "method" iseq type can be passed a
// block handler. (e.g. `yield` in the top level script is a syntax error.)
+ //
+ // Similar to gen_is_block_given
let local_iseq = unsafe { rb_get_iseq_body_local_iseq(jit.iseq) };
if unsafe { rb_get_iseq_body_type(local_iseq) } == ISEQ_TYPE_METHOD {
let lep = gen_get_lep(jit, asm);
@@ -550,6 +553,19 @@ fn gen_defined(jit: &JITState, asm: &mut Assembler, op_type: usize, obj: VALUE,
}
}
+/// Similar to gen_defined for DEFINED_YIELD
+fn gen_is_block_given(jit: &JITState, asm: &mut Assembler) -> Opnd {
+ let local_iseq = unsafe { rb_get_iseq_body_local_iseq(jit.iseq) };
+ if unsafe { rb_get_iseq_body_type(local_iseq) } == ISEQ_TYPE_METHOD {
+ let lep = gen_get_lep(jit, asm);
+ let block_handler = asm.load(Opnd::mem(64, lep, SIZEOF_VALUE_I32 * VM_ENV_DATA_INDEX_SPECVAL));
+ asm.cmp(block_handler, VM_BLOCK_HANDLER_NONE.into());
+ asm.csel_e(Qfalse.into(), Qtrue.into())
+ } else {
+ Qfalse.into()
+ }
+}
+
/// Get a local variable from a higher scope or the heap. `local_ep_offset` is in number of VALUEs.
/// We generate this instruction with level=0 only when the local variable is on the heap, so we
/// can't optimize the level=0 case using the SP register.
diff --git a/zjit/src/cruby_methods.rs b/zjit/src/cruby_methods.rs
index eabddce739..ee10eaa681 100644
--- a/zjit/src/cruby_methods.rs
+++ b/zjit/src/cruby_methods.rs
@@ -188,6 +188,7 @@ pub fn init() -> Annotations {
}
annotate!(rb_mKernel, "itself", inline_kernel_itself);
+ annotate!(rb_mKernel, "block_given?", inline_kernel_block_given_p);
annotate!(rb_cString, "bytesize", types::Fixnum, no_gc, leaf);
annotate!(rb_cString, "to_s", types::StringExact);
annotate!(rb_cString, "getbyte", inline_string_getbyte);
@@ -247,6 +248,13 @@ fn inline_kernel_itself(_fun: &mut hir::Function, _block: hir::BlockId, recv: hi
None
}
+fn inline_kernel_block_given_p(fun: &mut hir::Function, block: hir::BlockId, _recv: hir::InsnId, args: &[hir::InsnId], _state: hir::InsnId) -> Option<hir::InsnId> {
+ let &[] = args else { return None; };
+ // TODO(max): In local iseq types that are not ISEQ_TYPE_METHOD, rewrite to Constant false.
+ let result = fun.push_insn(block, hir::Insn::IsBlockGiven);
+ return Some(result);
+}
+
fn inline_array_aref(fun: &mut hir::Function, block: hir::BlockId, recv: hir::InsnId, args: &[hir::InsnId], state: hir::InsnId) -> Option<hir::InsnId> {
if let &[index] = args {
if fun.likely_a(index, types::Fixnum, state) {
diff --git a/zjit/src/hir.rs b/zjit/src/hir.rs
index 91484ca970..489ea83a44 100644
--- a/zjit/src/hir.rs
+++ b/zjit/src/hir.rs
@@ -610,8 +610,12 @@ pub enum Insn {
IsMethodCfunc { val: InsnId, cd: *const rb_call_data, cfunc: *const u8, state: InsnId },
/// Return C `true` if left == right
IsBitEqual { left: InsnId, right: InsnId },
+ // TODO(max): In iseq body types that are not ISEQ_TYPE_METHOD, rewrite to Constant false.
Defined { op_type: usize, obj: VALUE, pushval: VALUE, v: InsnId, state: InsnId },
GetConstantPath { ic: *const iseq_inline_constant_cache, state: InsnId },
+ /// Kernel#block_given? but without pushing a frame. Similar to [`Insn::Defined`] with
+ /// `DEFINED_YIELD`
+ IsBlockGiven,
/// Get a global variable named `id`
GetGlobal { id: ID, state: InsnId },
@@ -870,6 +874,7 @@ impl Insn {
Insn::NewRange { .. } => true,
Insn::NewRangeFixnum { .. } => false,
Insn::StringGetbyteFixnum { .. } => false,
+ Insn::IsBlockGiven => false,
_ => true,
}
}
@@ -1065,6 +1070,7 @@ impl<'a> std::fmt::Display for InsnPrinter<'a> {
Insn::GuardBlockParamProxy { level, .. } => write!(f, "GuardBlockParamProxy l{level}"),
Insn::PatchPoint { invariant, .. } => { write!(f, "PatchPoint {}", invariant.print(self.ptr_map)) },
Insn::GetConstantPath { ic, .. } => { write!(f, "GetConstantPath {:p}", self.ptr_map.map_ptr(ic)) },
+ Insn::IsBlockGiven => { write!(f, "IsBlockGiven") },
Insn::CCall { cfunc, args, name, return_type: _, elidable: _ } => {
write!(f, "CCall {}@{:p}", name.contents_lossy(), self.ptr_map.map_ptr(cfunc))?;
for arg in args {
@@ -1562,6 +1568,7 @@ impl Function {
result@(Const {..}
| Param {..}
| GetConstantPath {..}
+ | IsBlockGiven
| PatchPoint {..}
| PutSpecialObject {..}
| GetGlobal {..}
@@ -1828,6 +1835,7 @@ impl Function {
Insn::Defined { pushval, .. } => Type::from_value(*pushval).union(types::NilClass),
Insn::DefinedIvar { pushval, .. } => Type::from_value(*pushval).union(types::NilClass),
Insn::GetConstantPath { .. } => types::BasicObject,
+ Insn::IsBlockGiven { .. } => types::BoolExact,
Insn::ArrayMax { .. } => types::BasicObject,
Insn::GetGlobal { .. } => types::BasicObject,
Insn::GetIvar { .. } => types::BasicObject,
@@ -3009,6 +3017,7 @@ impl Function {
| &Insn::LoadSelf
| &Insn::GetLocal { .. }
| &Insn::PutSpecialObject { .. }
+ | &Insn::IsBlockGiven
| &Insn::IncrCounter(_)
| &Insn::IncrCounterPtr { .. } =>
{}
@@ -8492,13 +8501,23 @@ mod opt_tests {
use super::tests::assert_contains_opcode;
#[track_caller]
- fn hir_string(method: &str) -> String {
- let iseq = crate::cruby::with_rubyvm(|| get_method_iseq("self", method));
+ fn hir_string_function(function: &Function) -> String {
+ format!("{}", FunctionPrinter::without_snapshot(function))
+ }
+
+ #[track_caller]
+ fn hir_string_proc(proc: &str) -> String {
+ let iseq = crate::cruby::with_rubyvm(|| get_proc_iseq(proc));
unsafe { crate::cruby::rb_zjit_profile_disable(iseq) };
let mut function = iseq_to_hir(iseq).unwrap();
function.optimize();
function.validate().unwrap();
- format!("{}", FunctionPrinter::without_snapshot(&function))
+ hir_string_function(&function)
+ }
+
+ #[track_caller]
+ fn hir_string(method: &str) -> String {
+ hir_string_proc(&format!("{}.method(:{})", "self", method))
}
#[test]
@@ -10672,6 +10691,87 @@ mod opt_tests {
}
#[test]
+ fn test_inline_kernel_block_given_p() {
+ eval("
+ def test = block_given?
+ test
+ ");
+ assert_snapshot!(hir_string("test"), @r"
+ fn test@<compiled>:2:
+ bb0():
+ EntryPoint interpreter
+ v1:BasicObject = LoadSelf
+ Jump bb2(v1)
+ bb1(v4:BasicObject):
+ EntryPoint JIT(0)
+ Jump bb2(v4)
+ bb2(v6:BasicObject):
+ PatchPoint MethodRedefined(Object@0x1000, block_given?@0x1008, cme:0x1010)
+ PatchPoint NoSingletonClass(Object@0x1000)
+ v20:HeapObject[class_exact*:Object@VALUE(0x1000)] = GuardType v6, HeapObject[class_exact*:Object@VALUE(0x1000)]
+ v21:BoolExact = IsBlockGiven
+ IncrCounter inline_cfunc_optimized_send_count
+ CheckInterrupts
+ Return v21
+ ");
+ }
+
+ #[test]
+ fn test_inline_kernel_block_given_p_in_block() {
+ eval("
+ TEST = proc { block_given? }
+ TEST.call
+ ");
+ assert_snapshot!(hir_string_proc("TEST"), @r"
+ fn block in <compiled>@<compiled>:2:
+ bb0():
+ EntryPoint interpreter
+ v1:BasicObject = LoadSelf
+ Jump bb2(v1)
+ bb1(v4:BasicObject):
+ EntryPoint JIT(0)
+ Jump bb2(v4)
+ bb2(v6:BasicObject):
+ PatchPoint MethodRedefined(Object@0x1000, block_given?@0x1008, cme:0x1010)
+ PatchPoint NoSingletonClass(Object@0x1000)
+ v20:HeapObject[class_exact*:Object@VALUE(0x1000)] = GuardType v6, HeapObject[class_exact*:Object@VALUE(0x1000)]
+ v21:BoolExact = IsBlockGiven
+ IncrCounter inline_cfunc_optimized_send_count
+ CheckInterrupts
+ Return v21
+ ");
+ }
+
+ #[test]
+ fn test_elide_kernel_block_given_p() {
+ eval("
+ def test
+ block_given?
+ 5
+ end
+ test
+ ");
+ assert_snapshot!(hir_string("test"), @r"
+ fn test@<compiled>:3:
+ bb0():
+ EntryPoint interpreter
+ v1:BasicObject = LoadSelf
+ Jump bb2(v1)
+ bb1(v4:BasicObject):
+ EntryPoint JIT(0)
+ Jump bb2(v4)
+ bb2(v6:BasicObject):
+ PatchPoint MethodRedefined(Object@0x1000, block_given?@0x1008, cme:0x1010)
+ PatchPoint NoSingletonClass(Object@0x1000)
+ v23:HeapObject[class_exact*:Object@VALUE(0x1000)] = GuardType v6, HeapObject[class_exact*:Object@VALUE(0x1000)]
+ IncrCounter inline_cfunc_optimized_send_count
+ v14:Fixnum[5] = Const Value(5)
+ CheckInterrupts
+ Return v14
+ ");
+ }
+
+ #[test]
fn const_send_direct_integer() {
eval("
def test(x) = 1.zero?