summaryrefslogtreecommitdiff
path: root/zjit
diff options
context:
space:
mode:
authorNozomi Hijikata <121233810+nozomemein@users.noreply.github.com>2026-03-24 09:56:28 +0900
committerGitHub <noreply@github.com>2026-03-23 17:56:28 -0700
commita897e06db526996c55451c5d7d5049013b76e1a7 (patch)
treee11da9e73cdb3c1d83074b9545f9539ca5285844 /zjit
parent216c5eb335c040e8f363ee689a07642d82b91418 (diff)
ZJIT: Compile checkmatch insn (#16496)
Diffstat (limited to 'zjit')
-rw-r--r--zjit/src/codegen.rs15
-rw-r--r--zjit/src/codegen_tests.rs78
-rw-r--r--zjit/src/hir.rs36
-rw-r--r--zjit/src/hir/tests.rs136
4 files changed, 265 insertions, 0 deletions
diff --git a/zjit/src/codegen.rs b/zjit/src/codegen.rs
index a1f7d3f65c..30c99152e2 100644
--- a/zjit/src/codegen.rs
+++ b/zjit/src/codegen.rs
@@ -662,6 +662,7 @@ fn gen_insn(cb: &mut CodeBlock, jit: &mut JITState, asm: &mut Assembler, functio
Insn::PutSpecialObject { value_type } => gen_putspecialobject(asm, *value_type),
Insn::AnyToString { val, str, state } => gen_anytostring(asm, opnd!(val), opnd!(str), &function.frame_state(*state)),
Insn::Defined { op_type, obj, pushval, v, state } => gen_defined(jit, asm, *op_type, *obj, *pushval, opnd!(v), &function.frame_state(*state)),
+ Insn::CheckMatch { target, pattern, flag, state } => gen_checkmatch(jit, asm, opnd!(target), opnd!(pattern), *flag, &function.frame_state(*state)),
Insn::GetSpecialSymbol { symbol_type, state: _ } => gen_getspecial_symbol(asm, *symbol_type),
Insn::GetSpecialNumber { nth, state } => gen_getspecial_number(asm, *nth, &function.frame_state(*state)),
&Insn::IncrCounter(counter) => no_output!(gen_incr_counter(asm, counter)),
@@ -1266,6 +1267,20 @@ fn gen_defined_ivar(asm: &mut Assembler, self_val: Opnd, id: ID, pushval: VALUE)
asm_ccall!(asm, rb_zjit_defined_ivar, self_val, id.0.into(), Opnd::Value(pushval))
}
+fn gen_checkmatch(jit: &JITState, asm: &mut Assembler, target: Opnd, pattern: Opnd, flag: u32, state: &FrameState) -> lir::Opnd {
+ // rb_vm_check_match is not leaf unless flag is VM_CHECKMATCH_TYPE_WHEN.
+ // See also: leafness_of_checkmatch() and check_match()
+ if flag != VM_CHECKMATCH_TYPE_WHEN {
+ gen_prepare_non_leaf_call(jit, asm, state);
+ }
+
+ unsafe extern "C" {
+ fn rb_vm_check_match(ec: EcPtr, target: VALUE, pattern: VALUE, flag: u32) -> VALUE;
+ }
+
+ asm_ccall!(asm, rb_vm_check_match, EC, target, pattern, flag.into())
+}
+
fn gen_array_extend(jit: &mut JITState, asm: &mut Assembler, left: Opnd, right: Opnd, state: &FrameState) {
gen_prepare_non_leaf_call(jit, asm, state);
asm_ccall!(asm, rb_ary_concat, left, right);
diff --git a/zjit/src/codegen_tests.rs b/zjit/src/codegen_tests.rs
index 143200c2bc..4f479aa072 100644
--- a/zjit/src/codegen_tests.rs
+++ b/zjit/src/codegen_tests.rs
@@ -4508,6 +4508,84 @@ fn test_opt_case_dispatch() {
}
#[test]
+fn test_checkmatch_case() {
+ eval(r#"
+ def test(o)
+ case o
+ in Integer
+ 1
+ else
+ 2
+ end
+ end
+ "#);
+ assert_contains_opcode("test", YARVINSN_checkmatch);
+ assert_snapshot!(inspect(r#"[test(1), test(2), test("3")]"#), @"[1, 1, 2]");
+}
+
+#[test]
+fn test_checkmatch_case_splat_array() {
+ eval(r#"
+ def test(o)
+ case o
+ when *[1, 2]
+ 1
+ else
+ 2
+ end
+ end
+ "#);
+ assert_contains_opcode("test", YARVINSN_checkmatch);
+ assert_snapshot!(inspect("[test(1), test(2), test(3)]"), @"[1, 1, 2]");
+}
+
+#[test]
+fn test_checkmatch_when_splat_array() {
+ eval(r#"
+ def test
+ case
+ when *[1, 2]
+ 1
+ else
+ 2
+ end
+ end
+ "#);
+ assert_contains_opcode("test", YARVINSN_checkmatch);
+ assert_snapshot!(inspect("[test, test]"), @"[1, 1]");
+}
+
+#[test]
+fn test_checkmatch_rescue() {
+ // Rescue behavior is tested functionally here. It still side-exits because
+ // JIT exception handling is not supported yet.
+ eval(r#"
+ def test
+ begin
+ raise TypeError
+ rescue TypeError
+ 1
+ end
+ end
+ "#);
+ assert_snapshot!(inspect("[test, test]"), @"[1, 1]");
+}
+
+#[test]
+fn test_checkmatch_rescue_splat_array() {
+ eval(r#"
+ def test
+ begin
+ raise TypeError
+ rescue *[TypeError, ArgumentError]
+ 1
+ end
+ end
+ "#);
+ assert_snapshot!(inspect("[test, test]"), @"[1, 1]");
+}
+
+#[test]
fn test_stack_overflow() {
assert_snapshot!(inspect("
def recurse(n)
diff --git a/zjit/src/hir.rs b/zjit/src/hir.rs
index d4ac692934..db7a328771 100644
--- a/zjit/src/hir.rs
+++ b/zjit/src/hir.rs
@@ -847,6 +847,8 @@ pub enum Insn {
/// Return Qtrue if `val` is an instance of `class`, else Qfalse.
/// Equivalent to `class_search_ancestor(CLASS_OF(val), class)`.
IsA { val: InsnId, class: InsnId },
+ /// `case`/`when`/`rescue` match check for `pattern` against `target`.
+ CheckMatch { target: InsnId, pattern: InsnId, flag: u32, state: InsnId },
/// Get a global variable named `id`
GetGlobal { id: ID, state: InsnId },
@@ -1103,6 +1105,11 @@ macro_rules! for_each_operand_impl {
Insn::IsBlockParamModified { ep } => {
$visit_one!(ep);
}
+ Insn::CheckMatch { target, pattern, state, .. } => {
+ $visit_one!(target);
+ $visit_one!(pattern);
+ $visit_one!(state);
+ }
Insn::PatchPoint { state, .. }
| Insn::CheckInterrupts { state }
| Insn::GetBlockParam { state, .. }
@@ -1495,6 +1502,8 @@ impl Insn {
Insn::LoadSelf { .. } => Effect::read_write(abstract_heaps::Frame, abstract_heaps::Empty),
Insn::LoadField { .. } => Effect::read_write(abstract_heaps::Memory, abstract_heaps::Empty),
Insn::StoreField { .. } => effects::Any,
+ // TODO: Refine CheckMatch effects by flag.
+ Insn::CheckMatch { .. } => effects::Any,
// WriteBarrier can write to object flags and mark bits in Allocator memory.
// This is why WriteBarrier writes to the "Memory" effect. We do not yet have a more granular specialization for flags
Insn::WriteBarrier { .. } => Effect::read_write(abstract_heaps::Allocator, abstract_heaps::Allocator.union(abstract_heaps::Memory)),
@@ -1984,6 +1993,23 @@ impl<'a> std::fmt::Display for InsnPrinter<'a> {
}
Insn::DefinedIvar { self_val, id, .. } => write!(f, "DefinedIvar {self_val}, :{}", id.contents_lossy()),
Insn::GetIvar { self_val, id, .. } => write!(f, "GetIvar {self_val}, :{}", id.contents_lossy()),
+ Insn::CheckMatch { target, pattern, flag, .. } => {
+ const TYPE_MASK: u32 = 0x03;
+ const ARRAY_FLAG: u32 = 0x04;
+
+ let match_type = match *flag & TYPE_MASK {
+ VM_CHECKMATCH_TYPE_WHEN => "WHEN",
+ VM_CHECKMATCH_TYPE_CASE => "CASE",
+ VM_CHECKMATCH_TYPE_RESCUE => "RESCUE",
+ _ => return write!(f, "CheckMatch {target}, {pattern}, {flag}"),
+ };
+ let flag = if *flag & ARRAY_FLAG != 0 {
+ format!("{match_type}|ARRAY")
+ } else {
+ match_type.to_string()
+ };
+ write!(f, "CheckMatch {target}, {pattern}, {flag}")
+ }
Insn::LoadPC => write!(f, "LoadPC"),
Insn::LoadEC => write!(f, "LoadEC"),
Insn::LoadSP => write!(f, "LoadSP"),
@@ -2760,6 +2786,7 @@ impl Function {
&CCallVariadic { cfunc, recv, ref args, cme, name, state, return_type, elidable, blockiseq } => CCallVariadic {
cfunc, recv: find!(recv), args: find_vec!(args), cme, name, state, return_type, elidable, blockiseq
},
+ &CheckMatch { target, pattern, flag, state } => CheckMatch { target: find!(target), pattern: find!(pattern), flag, state: find!(state) },
&Defined { op_type, obj, pushval, v, state } => Defined { op_type, obj, pushval, v: find!(v), state: find!(state) },
&DefinedIvar { self_val, pushval, id, state } => DefinedIvar { self_val: find!(self_val), pushval, id, state },
&GetConstant { klass, id, allow_nil, state } => GetConstant { klass: find!(klass), id, allow_nil: find!(allow_nil), state },
@@ -2898,6 +2925,7 @@ impl Function {
&Insn::CCallWithFrame { return_type, .. } => return_type,
Insn::CCall { return_type, .. } => *return_type,
&Insn::CCallVariadic { return_type, .. } => return_type,
+ Insn::CheckMatch { .. } => types::BasicObject,
Insn::GuardType { val, guard_type, .. } => self.type_of(*val).intersection(*guard_type),
Insn::RefineType { val, new_type, .. } => self.type_of(*val).intersection(*new_type),
Insn::HasType { .. } => types::CBool,
@@ -5929,6 +5957,7 @@ impl Function {
Insn::SetIvar { self_val: left, val: right, .. }
| Insn::NewRange { low: left, high: right, .. }
| Insn::AnyToString { val: left, str: right, .. }
+ | Insn::CheckMatch { target: left, pattern: right, .. }
| Insn::WriteBarrier { recv: left, val: right } => {
self.assert_subtype(insn_id, left, types::BasicObject)?;
self.assert_subtype(insn_id, right, types::BasicObject)
@@ -7097,6 +7126,13 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> {
};
state.stack_push(fun.push_insn(block, Insn::FixnumBitCheck { val, index }));
}
+ YARVINSN_checkmatch => {
+ let flag = get_arg(pc, 0).as_u32();
+ let pattern = state.stack_pop()?;
+ let target = state.stack_pop()?;
+ let result = fun.push_insn(block, Insn::CheckMatch { target, pattern, flag, state: exit_id });
+ state.stack_push(result);
+ }
YARVINSN_getconstant => {
let id = ID(get_arg(pc, 0).as_u64());
let allow_nil = state.stack_pop()?;
diff --git a/zjit/src/hir/tests.rs b/zjit/src/hir/tests.rs
index 5ff827d9b4..b56320254f 100644
--- a/zjit/src/hir/tests.rs
+++ b/zjit/src/hir/tests.rs
@@ -340,6 +340,142 @@ pub mod hir_build_tests {
}
#[test]
+ fn test_checkmatch_case() {
+ eval(r#"
+ def test(o)
+ case o
+ in Integer
+ 1
+ else
+ 2
+ end
+ end
+ test(1)
+ "#);
+ assert_contains_opcode("test", YARVINSN_checkmatch);
+ assert_snapshot!(hir_string("test"), @"
+ fn test@<compiled>:3:
+ bb1():
+ EntryPoint interpreter
+ v1:BasicObject = LoadSelf
+ v2:CPtr = LoadSP
+ v3:BasicObject = LoadField v2, :o@0x1000
+ Jump bb3(v1, v3)
+ bb2():
+ EntryPoint JIT(0)
+ v6:BasicObject = LoadArg :self@0
+ v7:BasicObject = LoadArg :o@1
+ Jump bb3(v6, v7)
+ bb3(v9:BasicObject, v10:BasicObject):
+ v14:NilClass = Const Value(nil)
+ v18:BasicObject = GetConstantPath 0x1008
+ v20:BasicObject = CheckMatch v10, v18, CASE
+ CheckInterrupts
+ v23:CBool = Test v20
+ v24:Truthy = RefineType v20, Truthy
+ IfTrue v23, bb4(v9, v10, v14, v10)
+ v26:Falsy = RefineType v20, Falsy
+ v31:Fixnum[2] = Const Value(2)
+ CheckInterrupts
+ Return v31
+ bb4(v36:BasicObject, v37:BasicObject, v38:NilClass, v39:BasicObject):
+ v44:Fixnum[1] = Const Value(1)
+ CheckInterrupts
+ Return v44
+ ");
+ }
+
+ #[test]
+ fn test_checkmatch_case_splat_array() {
+ eval(r#"
+ def test(o)
+ case o
+ when *[1, 2]
+ 1
+ else
+ 2
+ end
+ end
+ test(1)
+ "#);
+ assert_contains_opcode("test", YARVINSN_checkmatch);
+ assert_snapshot!(hir_string("test"), @"
+ fn test@<compiled>:3:
+ bb1():
+ EntryPoint interpreter
+ v1:BasicObject = LoadSelf
+ v2:CPtr = LoadSP
+ v3:BasicObject = LoadField v2, :o@0x1000
+ Jump bb3(v1, v3)
+ bb2():
+ EntryPoint JIT(0)
+ v6:BasicObject = LoadArg :self@0
+ v7:BasicObject = LoadArg :o@1
+ Jump bb3(v6, v7)
+ bb3(v9:BasicObject, v10:BasicObject):
+ v16:ArrayExact[VALUE(0x1008)] = Const Value(VALUE(0x1008))
+ v17:ArrayExact = ArrayDup v16
+ v19:BasicObject = CheckMatch v10, v17, CASE|ARRAY
+ CheckInterrupts
+ v22:CBool = Test v19
+ v23:Truthy = RefineType v19, Truthy
+ IfTrue v22, bb4(v9, v10, v10)
+ v25:Falsy = RefineType v19, Falsy
+ v29:Fixnum[2] = Const Value(2)
+ CheckInterrupts
+ Return v29
+ bb4(v34:BasicObject, v35:BasicObject, v36:BasicObject):
+ v41:Fixnum[1] = Const Value(1)
+ CheckInterrupts
+ Return v41
+ ");
+ }
+
+ #[test]
+ fn test_checkmatch_when_splat_array() {
+ eval(r#"
+ def test
+ case
+ when *[1, 2]
+ 1
+ else
+ 2
+ end
+ end
+ test
+ "#);
+ assert_contains_opcode("test", YARVINSN_checkmatch);
+ assert_snapshot!(hir_string("test"), @"
+ fn test@<compiled>:4:
+ bb1():
+ EntryPoint interpreter
+ v1:BasicObject = LoadSelf
+ Jump bb3(v1)
+ bb2():
+ EntryPoint JIT(0)
+ v4:BasicObject = LoadArg :self@0
+ Jump bb3(v4)
+ bb3(v6:BasicObject):
+ v10:NilClass = Const Value(nil)
+ v12:ArrayExact[VALUE(0x1000)] = Const Value(VALUE(0x1000))
+ v13:ArrayExact = ArrayDup v12
+ v15:BasicObject = CheckMatch v10, v13, WHEN|ARRAY
+ CheckInterrupts
+ v18:CBool = Test v15
+ v19:Truthy = RefineType v15, Truthy
+ IfTrue v18, bb4(v6)
+ v21:Falsy = RefineType v15, Falsy
+ v24:Fixnum[2] = Const Value(2)
+ CheckInterrupts
+ Return v24
+ bb4(v29:BasicObject):
+ v33:Fixnum[1] = Const Value(1)
+ CheckInterrupts
+ Return v33
+ ");
+ }
+
+ #[test]
fn test_new_array() {
eval("def test = []");
assert_contains_opcode("test", YARVINSN_newarray);