diff options
| author | Nozomi Hijikata <121233810+nozomemein@users.noreply.github.com> | 2026-03-24 09:56:28 +0900 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2026-03-23 17:56:28 -0700 |
| commit | a897e06db526996c55451c5d7d5049013b76e1a7 (patch) | |
| tree | e11da9e73cdb3c1d83074b9545f9539ca5285844 /zjit | |
| parent | 216c5eb335c040e8f363ee689a07642d82b91418 (diff) | |
ZJIT: Compile checkmatch insn (#16496)
Diffstat (limited to 'zjit')
| -rw-r--r-- | zjit/src/codegen.rs | 15 | ||||
| -rw-r--r-- | zjit/src/codegen_tests.rs | 78 | ||||
| -rw-r--r-- | zjit/src/hir.rs | 36 | ||||
| -rw-r--r-- | zjit/src/hir/tests.rs | 136 |
4 files changed, 265 insertions, 0 deletions
diff --git a/zjit/src/codegen.rs b/zjit/src/codegen.rs index a1f7d3f65c..30c99152e2 100644 --- a/zjit/src/codegen.rs +++ b/zjit/src/codegen.rs @@ -662,6 +662,7 @@ fn gen_insn(cb: &mut CodeBlock, jit: &mut JITState, asm: &mut Assembler, functio Insn::PutSpecialObject { value_type } => gen_putspecialobject(asm, *value_type), Insn::AnyToString { val, str, state } => gen_anytostring(asm, opnd!(val), opnd!(str), &function.frame_state(*state)), Insn::Defined { op_type, obj, pushval, v, state } => gen_defined(jit, asm, *op_type, *obj, *pushval, opnd!(v), &function.frame_state(*state)), + Insn::CheckMatch { target, pattern, flag, state } => gen_checkmatch(jit, asm, opnd!(target), opnd!(pattern), *flag, &function.frame_state(*state)), Insn::GetSpecialSymbol { symbol_type, state: _ } => gen_getspecial_symbol(asm, *symbol_type), Insn::GetSpecialNumber { nth, state } => gen_getspecial_number(asm, *nth, &function.frame_state(*state)), &Insn::IncrCounter(counter) => no_output!(gen_incr_counter(asm, counter)), @@ -1266,6 +1267,20 @@ fn gen_defined_ivar(asm: &mut Assembler, self_val: Opnd, id: ID, pushval: VALUE) asm_ccall!(asm, rb_zjit_defined_ivar, self_val, id.0.into(), Opnd::Value(pushval)) } +fn gen_checkmatch(jit: &JITState, asm: &mut Assembler, target: Opnd, pattern: Opnd, flag: u32, state: &FrameState) -> lir::Opnd { + // rb_vm_check_match is not leaf unless flag is VM_CHECKMATCH_TYPE_WHEN. + // See also: leafness_of_checkmatch() and check_match() + if flag != VM_CHECKMATCH_TYPE_WHEN { + gen_prepare_non_leaf_call(jit, asm, state); + } + + unsafe extern "C" { + fn rb_vm_check_match(ec: EcPtr, target: VALUE, pattern: VALUE, flag: u32) -> VALUE; + } + + asm_ccall!(asm, rb_vm_check_match, EC, target, pattern, flag.into()) +} + fn gen_array_extend(jit: &mut JITState, asm: &mut Assembler, left: Opnd, right: Opnd, state: &FrameState) { gen_prepare_non_leaf_call(jit, asm, state); asm_ccall!(asm, rb_ary_concat, left, right); diff --git a/zjit/src/codegen_tests.rs b/zjit/src/codegen_tests.rs index 143200c2bc..4f479aa072 100644 --- a/zjit/src/codegen_tests.rs +++ b/zjit/src/codegen_tests.rs @@ -4508,6 +4508,84 @@ fn test_opt_case_dispatch() { } #[test] +fn test_checkmatch_case() { + eval(r#" + def test(o) + case o + in Integer + 1 + else + 2 + end + end + "#); + assert_contains_opcode("test", YARVINSN_checkmatch); + assert_snapshot!(inspect(r#"[test(1), test(2), test("3")]"#), @"[1, 1, 2]"); +} + +#[test] +fn test_checkmatch_case_splat_array() { + eval(r#" + def test(o) + case o + when *[1, 2] + 1 + else + 2 + end + end + "#); + assert_contains_opcode("test", YARVINSN_checkmatch); + assert_snapshot!(inspect("[test(1), test(2), test(3)]"), @"[1, 1, 2]"); +} + +#[test] +fn test_checkmatch_when_splat_array() { + eval(r#" + def test + case + when *[1, 2] + 1 + else + 2 + end + end + "#); + assert_contains_opcode("test", YARVINSN_checkmatch); + assert_snapshot!(inspect("[test, test]"), @"[1, 1]"); +} + +#[test] +fn test_checkmatch_rescue() { + // Rescue behavior is tested functionally here. It still side-exits because + // JIT exception handling is not supported yet. + eval(r#" + def test + begin + raise TypeError + rescue TypeError + 1 + end + end + "#); + assert_snapshot!(inspect("[test, test]"), @"[1, 1]"); +} + +#[test] +fn test_checkmatch_rescue_splat_array() { + eval(r#" + def test + begin + raise TypeError + rescue *[TypeError, ArgumentError] + 1 + end + end + "#); + assert_snapshot!(inspect("[test, test]"), @"[1, 1]"); +} + +#[test] fn test_stack_overflow() { assert_snapshot!(inspect(" def recurse(n) diff --git a/zjit/src/hir.rs b/zjit/src/hir.rs index d4ac692934..db7a328771 100644 --- a/zjit/src/hir.rs +++ b/zjit/src/hir.rs @@ -847,6 +847,8 @@ pub enum Insn { /// Return Qtrue if `val` is an instance of `class`, else Qfalse. /// Equivalent to `class_search_ancestor(CLASS_OF(val), class)`. IsA { val: InsnId, class: InsnId }, + /// `case`/`when`/`rescue` match check for `pattern` against `target`. + CheckMatch { target: InsnId, pattern: InsnId, flag: u32, state: InsnId }, /// Get a global variable named `id` GetGlobal { id: ID, state: InsnId }, @@ -1103,6 +1105,11 @@ macro_rules! for_each_operand_impl { Insn::IsBlockParamModified { ep } => { $visit_one!(ep); } + Insn::CheckMatch { target, pattern, state, .. } => { + $visit_one!(target); + $visit_one!(pattern); + $visit_one!(state); + } Insn::PatchPoint { state, .. } | Insn::CheckInterrupts { state } | Insn::GetBlockParam { state, .. } @@ -1495,6 +1502,8 @@ impl Insn { Insn::LoadSelf { .. } => Effect::read_write(abstract_heaps::Frame, abstract_heaps::Empty), Insn::LoadField { .. } => Effect::read_write(abstract_heaps::Memory, abstract_heaps::Empty), Insn::StoreField { .. } => effects::Any, + // TODO: Refine CheckMatch effects by flag. + Insn::CheckMatch { .. } => effects::Any, // WriteBarrier can write to object flags and mark bits in Allocator memory. // This is why WriteBarrier writes to the "Memory" effect. We do not yet have a more granular specialization for flags Insn::WriteBarrier { .. } => Effect::read_write(abstract_heaps::Allocator, abstract_heaps::Allocator.union(abstract_heaps::Memory)), @@ -1984,6 +1993,23 @@ impl<'a> std::fmt::Display for InsnPrinter<'a> { } Insn::DefinedIvar { self_val, id, .. } => write!(f, "DefinedIvar {self_val}, :{}", id.contents_lossy()), Insn::GetIvar { self_val, id, .. } => write!(f, "GetIvar {self_val}, :{}", id.contents_lossy()), + Insn::CheckMatch { target, pattern, flag, .. } => { + const TYPE_MASK: u32 = 0x03; + const ARRAY_FLAG: u32 = 0x04; + + let match_type = match *flag & TYPE_MASK { + VM_CHECKMATCH_TYPE_WHEN => "WHEN", + VM_CHECKMATCH_TYPE_CASE => "CASE", + VM_CHECKMATCH_TYPE_RESCUE => "RESCUE", + _ => return write!(f, "CheckMatch {target}, {pattern}, {flag}"), + }; + let flag = if *flag & ARRAY_FLAG != 0 { + format!("{match_type}|ARRAY") + } else { + match_type.to_string() + }; + write!(f, "CheckMatch {target}, {pattern}, {flag}") + } Insn::LoadPC => write!(f, "LoadPC"), Insn::LoadEC => write!(f, "LoadEC"), Insn::LoadSP => write!(f, "LoadSP"), @@ -2760,6 +2786,7 @@ impl Function { &CCallVariadic { cfunc, recv, ref args, cme, name, state, return_type, elidable, blockiseq } => CCallVariadic { cfunc, recv: find!(recv), args: find_vec!(args), cme, name, state, return_type, elidable, blockiseq }, + &CheckMatch { target, pattern, flag, state } => CheckMatch { target: find!(target), pattern: find!(pattern), flag, state: find!(state) }, &Defined { op_type, obj, pushval, v, state } => Defined { op_type, obj, pushval, v: find!(v), state: find!(state) }, &DefinedIvar { self_val, pushval, id, state } => DefinedIvar { self_val: find!(self_val), pushval, id, state }, &GetConstant { klass, id, allow_nil, state } => GetConstant { klass: find!(klass), id, allow_nil: find!(allow_nil), state }, @@ -2898,6 +2925,7 @@ impl Function { &Insn::CCallWithFrame { return_type, .. } => return_type, Insn::CCall { return_type, .. } => *return_type, &Insn::CCallVariadic { return_type, .. } => return_type, + Insn::CheckMatch { .. } => types::BasicObject, Insn::GuardType { val, guard_type, .. } => self.type_of(*val).intersection(*guard_type), Insn::RefineType { val, new_type, .. } => self.type_of(*val).intersection(*new_type), Insn::HasType { .. } => types::CBool, @@ -5929,6 +5957,7 @@ impl Function { Insn::SetIvar { self_val: left, val: right, .. } | Insn::NewRange { low: left, high: right, .. } | Insn::AnyToString { val: left, str: right, .. } + | Insn::CheckMatch { target: left, pattern: right, .. } | Insn::WriteBarrier { recv: left, val: right } => { self.assert_subtype(insn_id, left, types::BasicObject)?; self.assert_subtype(insn_id, right, types::BasicObject) @@ -7097,6 +7126,13 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> { }; state.stack_push(fun.push_insn(block, Insn::FixnumBitCheck { val, index })); } + YARVINSN_checkmatch => { + let flag = get_arg(pc, 0).as_u32(); + let pattern = state.stack_pop()?; + let target = state.stack_pop()?; + let result = fun.push_insn(block, Insn::CheckMatch { target, pattern, flag, state: exit_id }); + state.stack_push(result); + } YARVINSN_getconstant => { let id = ID(get_arg(pc, 0).as_u64()); let allow_nil = state.stack_pop()?; diff --git a/zjit/src/hir/tests.rs b/zjit/src/hir/tests.rs index 5ff827d9b4..b56320254f 100644 --- a/zjit/src/hir/tests.rs +++ b/zjit/src/hir/tests.rs @@ -340,6 +340,142 @@ pub mod hir_build_tests { } #[test] + fn test_checkmatch_case() { + eval(r#" + def test(o) + case o + in Integer + 1 + else + 2 + end + end + test(1) + "#); + assert_contains_opcode("test", YARVINSN_checkmatch); + assert_snapshot!(hir_string("test"), @" + fn test@<compiled>:3: + bb1(): + EntryPoint interpreter + v1:BasicObject = LoadSelf + v2:CPtr = LoadSP + v3:BasicObject = LoadField v2, :o@0x1000 + Jump bb3(v1, v3) + bb2(): + EntryPoint JIT(0) + v6:BasicObject = LoadArg :self@0 + v7:BasicObject = LoadArg :o@1 + Jump bb3(v6, v7) + bb3(v9:BasicObject, v10:BasicObject): + v14:NilClass = Const Value(nil) + v18:BasicObject = GetConstantPath 0x1008 + v20:BasicObject = CheckMatch v10, v18, CASE + CheckInterrupts + v23:CBool = Test v20 + v24:Truthy = RefineType v20, Truthy + IfTrue v23, bb4(v9, v10, v14, v10) + v26:Falsy = RefineType v20, Falsy + v31:Fixnum[2] = Const Value(2) + CheckInterrupts + Return v31 + bb4(v36:BasicObject, v37:BasicObject, v38:NilClass, v39:BasicObject): + v44:Fixnum[1] = Const Value(1) + CheckInterrupts + Return v44 + "); + } + + #[test] + fn test_checkmatch_case_splat_array() { + eval(r#" + def test(o) + case o + when *[1, 2] + 1 + else + 2 + end + end + test(1) + "#); + assert_contains_opcode("test", YARVINSN_checkmatch); + assert_snapshot!(hir_string("test"), @" + fn test@<compiled>:3: + bb1(): + EntryPoint interpreter + v1:BasicObject = LoadSelf + v2:CPtr = LoadSP + v3:BasicObject = LoadField v2, :o@0x1000 + Jump bb3(v1, v3) + bb2(): + EntryPoint JIT(0) + v6:BasicObject = LoadArg :self@0 + v7:BasicObject = LoadArg :o@1 + Jump bb3(v6, v7) + bb3(v9:BasicObject, v10:BasicObject): + v16:ArrayExact[VALUE(0x1008)] = Const Value(VALUE(0x1008)) + v17:ArrayExact = ArrayDup v16 + v19:BasicObject = CheckMatch v10, v17, CASE|ARRAY + CheckInterrupts + v22:CBool = Test v19 + v23:Truthy = RefineType v19, Truthy + IfTrue v22, bb4(v9, v10, v10) + v25:Falsy = RefineType v19, Falsy + v29:Fixnum[2] = Const Value(2) + CheckInterrupts + Return v29 + bb4(v34:BasicObject, v35:BasicObject, v36:BasicObject): + v41:Fixnum[1] = Const Value(1) + CheckInterrupts + Return v41 + "); + } + + #[test] + fn test_checkmatch_when_splat_array() { + eval(r#" + def test + case + when *[1, 2] + 1 + else + 2 + end + end + test + "#); + assert_contains_opcode("test", YARVINSN_checkmatch); + assert_snapshot!(hir_string("test"), @" + fn test@<compiled>:4: + bb1(): + EntryPoint interpreter + v1:BasicObject = LoadSelf + Jump bb3(v1) + bb2(): + EntryPoint JIT(0) + v4:BasicObject = LoadArg :self@0 + Jump bb3(v4) + bb3(v6:BasicObject): + v10:NilClass = Const Value(nil) + v12:ArrayExact[VALUE(0x1000)] = Const Value(VALUE(0x1000)) + v13:ArrayExact = ArrayDup v12 + v15:BasicObject = CheckMatch v10, v13, WHEN|ARRAY + CheckInterrupts + v18:CBool = Test v15 + v19:Truthy = RefineType v15, Truthy + IfTrue v18, bb4(v6) + v21:Falsy = RefineType v15, Falsy + v24:Fixnum[2] = Const Value(2) + CheckInterrupts + Return v24 + bb4(v29:BasicObject): + v33:Fixnum[1] = Const Value(1) + CheckInterrupts + Return v33 + "); + } + + #[test] fn test_new_array() { eval("def test = []"); assert_contains_opcode("test", YARVINSN_newarray); |
