summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorEileen <eileencodes@users.noreply.github.com>2025-08-15 14:31:25 -0400
committerGitHub <noreply@github.com>2025-08-15 18:31:25 +0000
commit2a1210f7cce284e07217e620313085e12e4d575d (patch)
tree3a8720f092ad967e25d2507fbd349f17f9b8ff64
parent1d7ed95604d7f9b9847c0054d1c48704a0d1bded (diff)
ZJIT: Implement getspecial (#13642)
ZJIT: Implement getspecial in ZJIT. Adds support for the getspecial instruction in ZJIT. We split getspecial into two instructions: one for special symbols (`$&`, `$'`, etc.) and one for special backrefs (`$1`, `$2`, etc.). Co-authored-by: Aaron Patterson <tenderlove@ruby-lang.org>
-rw-r--r--test/ruby/test_zjit.rb100
-rw-r--r--zjit/src/codegen.rs35
-rw-r--r--zjit/src/hir.rs57
3 files changed, 191 insertions, 1 deletions
diff --git a/test/ruby/test_zjit.rb b/test/ruby/test_zjit.rb
index e3880a2828..d30af737c3 100644
--- a/test/ruby/test_zjit.rb
+++ b/test/ruby/test_zjit.rb
@@ -1221,6 +1221,106 @@ class TestZJIT < Test::Unit::TestCase
}, insns: [:opt_nil_p]
end
+ def test_getspecial_last_match # $& after a successful match is the matched text, "hello"
+ assert_compiles '"hello"', %q{
+ def test(str)
+ str =~ /hello/
+ $&
+ end
+ test("hello world")
+ }, insns: [:getspecial]
+ end
+
+ def test_getspecial_match_pre # $` is the text preceding the match: "hello " before /world/
+ assert_compiles '"hello "', %q{
+ def test(str)
+ str =~ /world/
+ $`
+ end
+ test("hello world")
+ }, insns: [:getspecial]
+ end
+
+ def test_getspecial_match_post # $' is the text following the match: " world" after /hello/
+ assert_compiles '" world"', %q{
+ def test(str)
+ str =~ /hello/
+ $'
+ end
+ test("hello world")
+ }, insns: [:getspecial]
+ end
+
+ def test_getspecial_match_last_group # $+ is the last capture group of the match, here "world"
+ assert_compiles '"world"', %q{
+ def test(str)
+ str =~ /(hello) (world)/
+ $+
+ end
+ test("hello world")
+ }, insns: [:getspecial]
+ end
+
+ def test_getspecial_numbered_match_1 # $1 is the first capture group, "hello"
+ assert_compiles '"hello"', %q{
+ def test(str)
+ str =~ /(hello) (world)/
+ $1
+ end
+ test("hello world")
+ }, insns: [:getspecial]
+ end
+
+ def test_getspecial_numbered_match_2 # $2 is the second capture group, "world"
+ assert_compiles '"world"', %q{
+ def test(str)
+ str =~ /(hello) (world)/
+ $2
+ end
+ test("hello world")
+ }, insns: [:getspecial]
+ end
+
+ def test_getspecial_numbered_match_nonexistent # $2 is nil when the pattern has only one group
+ assert_compiles 'nil', %q{
+ def test(str)
+ str =~ /(hello)/
+ $2
+ end
+ test("hello world")
+ }, insns: [:getspecial]
+ end
+
+ def test_getspecial_no_match # $& is nil when the preceding match failed
+ assert_compiles 'nil', %q{
+ def test(str)
+ str =~ /xyz/
+ $&
+ end
+ test("hello world")
+ }, insns: [:getspecial]
+ end
+
+ def test_getspecial_complex_pattern # $1 with a \d+ group in the middle of the string: "123"
+ assert_compiles '"123"', %q{
+ def test(str)
+ str =~ /(\d+)/
+ $1
+ end
+ test("abc123def")
+ }, insns: [:getspecial]
+ end
+
+ def test_getspecial_multiple_groups # $2 picks the second of two \d+ groups: "456"
+ assert_compiles '"456"', %q{
+ def test(str)
+ str =~ /(\d+)-(\d+)/
+ $2
+ end
+ test("123-456")
+ }, insns: [:getspecial]
+ end
+
# tool/ruby_vm/views/*.erb relies on the zjit instructions a) being contiguous and
# b) being reliably ordered after all the other instructions.
def test_instruction_order
diff --git a/zjit/src/codegen.rs b/zjit/src/codegen.rs
index ba7555485a..a096f3fad6 100644
--- a/zjit/src/codegen.rs
+++ b/zjit/src/codegen.rs
@@ -10,7 +10,7 @@ use crate::state::ZJITState;
use crate::stats::{counter_ptr, with_time_stat, Counter, Counter::compile_time_ns};
use crate::{asm::CodeBlock, cruby::*, options::debug, virtualmem::CodePtr};
use crate::backend::lir::{self, asm_comment, asm_ccall, Assembler, Opnd, Target, CFP, C_ARG_OPNDS, C_RET_OPND, EC, NATIVE_STACK_PTR, NATIVE_BASE_PTR, SP};
-use crate::hir::{iseq_to_hir, Block, BlockId, BranchEdge, Invariant, RangeType, SideExitReason, SideExitReason::*, SpecialObjectType, SELF_PARAM_IDX};
+use crate::hir::{iseq_to_hir, Block, BlockId, BranchEdge, Invariant, RangeType, SideExitReason, SideExitReason::*, SpecialObjectType, SpecialBackrefSymbol, SELF_PARAM_IDX};
use crate::hir::{Const, FrameState, Function, Insn, InsnId};
use crate::hir_type::{types, Type};
use crate::options::get_option;
@@ -378,6 +378,8 @@ fn gen_insn(cb: &mut CodeBlock, jit: &mut JITState, asm: &mut Assembler, functio
Insn::PutSpecialObject { value_type } => gen_putspecialobject(asm, *value_type),
Insn::AnyToString { val, str, state } => gen_anytostring(asm, opnd!(val), opnd!(str), &function.frame_state(*state))?,
Insn::Defined { op_type, obj, pushval, v, state } => gen_defined(jit, asm, *op_type, *obj, *pushval, opnd!(v), &function.frame_state(*state))?,
+ Insn::GetSpecialSymbol { symbol_type, state: _ } => gen_getspecial_symbol(asm, *symbol_type),
+ Insn::GetSpecialNumber { nth, state } => gen_getspecial_number(asm, *nth, &function.frame_state(*state)),
&Insn::IncrCounter(counter) => return Some(gen_incr_counter(asm, counter)),
Insn::ObjToString { val, cd, state, .. } => gen_objtostring(jit, asm, opnd!(val), *cd, &function.frame_state(*state))?,
Insn::ArrayExtend { .. }
@@ -640,6 +642,37 @@ fn gen_putspecialobject(asm: &mut Assembler, value_type: SpecialObjectType) -> O
asm_ccall!(asm, rb_vm_get_special_object, ep_reg, Opnd::UImm(u64::from(value_type)))
}
+fn gen_getspecial_symbol(asm: &mut Assembler, symbol_type: SpecialBackrefSymbol) -> Opnd {
+ // Emit code for a symbol-style special variable ($&, $`, $', $+): fetch the backref via
+ // rb_backref_get, then pass it to the C function selected by `symbol_type`; that call's result is returned.
+ let backref = asm_ccall!(asm, rb_backref_get,);
+ // NOTE(review): no gen_prepare_call_with_gc is emitted here, unlike gen_getspecial_number — confirm these calls cannot trigger GC.
+ match symbol_type {
+ SpecialBackrefSymbol::LastMatch => { // $&
+ asm_ccall!(asm, rb_reg_last_match, backref)
+ }
+ SpecialBackrefSymbol::PreMatch => { // $`
+ asm_ccall!(asm, rb_reg_match_pre, backref)
+ }
+ SpecialBackrefSymbol::PostMatch => { // $'
+ asm_ccall!(asm, rb_reg_match_post, backref)
+ }
+ SpecialBackrefSymbol::LastGroup => { // $+
+ asm_ccall!(asm, rb_reg_match_last, backref)
+ }
+ }
+}
+
+fn gen_getspecial_number(asm: &mut Assembler, nth: u64, state: &FrameState) -> Opnd {
+ // Emit code for a numbered special variable ($1, $2, ...). `nth` is the raw getspecial operand
+ // as stored by iseq_to_hir; the group index handed to rb_reg_nth_match is `nth >> 1`.
+ let backref = asm_ccall!(asm, rb_backref_get,);
+ // Presumably saves enough frame state for the C call below to allocate/GC safely (helper defined outside this hunk — confirm).
+ gen_prepare_call_with_gc(asm, state);
+ // The shift drops the low bit that iseq_to_hir tests to tell symbol backrefs from numbered ones; unwrap() panics only if the result does not fit Opnd::Imm.
+ asm_ccall!(asm, rb_reg_nth_match, Opnd::Imm((nth >> 1).try_into().unwrap()), backref)
+}
+
/// Compile an interpreter entry block to be inserted into an ISEQ
fn gen_entry_prologue(asm: &mut Assembler, iseq: IseqPtr) {
asm_comment!(asm, "ZJIT entry point: {}", iseq_get_location(iseq, 0));
diff --git a/zjit/src/hir.rs b/zjit/src/hir.rs
index b6e18e7356..c88965f891 100644
--- a/zjit/src/hir.rs
+++ b/zjit/src/hir.rs
@@ -321,6 +321,29 @@ impl From<RangeType> for u32 {
}
}
+/// The symbol-named regexp special variables readable through getspecial, one variant per variable.
+#[derive(Debug, Clone, Copy, PartialEq)]
+pub enum SpecialBackrefSymbol {
+ LastMatch, // $& — compiled to a call to rb_reg_last_match
+ PreMatch, // $` — compiled to a call to rb_reg_match_pre
+ PostMatch, // $' — compiled to a call to rb_reg_match_post
+ LastGroup, // $+ — compiled to a call to rb_reg_match_last
+}
+
+impl TryFrom<u8> for SpecialBackrefSymbol { // decodes the character byte of a symbol backref
+ type Error = String; // message naming the rejected character
+ /// Interprets `value` as a character and maps `&`, `` ` ``, `'` and `+` to their variants; any other byte yields `Err`.
+ fn try_from(value: u8) -> Result<Self, Self::Error> {
+ match value as char {
+ '&' => Ok(SpecialBackrefSymbol::LastMatch),
+ '`' => Ok(SpecialBackrefSymbol::PreMatch),
+ '\'' => Ok(SpecialBackrefSymbol::PostMatch),
+ '+' => Ok(SpecialBackrefSymbol::LastGroup),
+ c => Err(format!("invalid backref symbol: '{}'", c)), // the iseq_to_hir caller turns this into a panic via expect()
+ }
+ }
+}
+
/// Print adaptor for [`Const`]. See [`PtrPrintMap`].
struct ConstPrinter<'a> {
inner: &'a Const,
@@ -415,6 +438,7 @@ pub enum SideExitReason {
PatchPoint(Invariant),
CalleeSideExit,
ObjToStringFallback,
+ UnknownSpecialVariable(u64),
}
impl std::fmt::Display for SideExitReason {
@@ -494,6 +518,8 @@ pub enum Insn {
GetLocal { level: u32, ep_offset: u32 },
/// Set a local variable in a higher scope or the heap
SetLocal { level: u32, ep_offset: u32, val: InsnId },
+ GetSpecialSymbol { symbol_type: SpecialBackrefSymbol, state: InsnId },
+ GetSpecialNumber { nth: u64, state: InsnId },
/// Own a FrameState so that instructions can look up their dominating FrameState when
/// generating deopt side-exits and frame reconstruction metadata. Does not directly generate
@@ -774,6 +800,8 @@ impl<'a> std::fmt::Display for InsnPrinter<'a> {
Insn::SetGlobal { id, val, .. } => write!(f, "SetGlobal :{}, {val}", id.contents_lossy()),
Insn::GetLocal { level, ep_offset } => write!(f, "GetLocal l{level}, EP@{ep_offset}"),
Insn::SetLocal { val, level, ep_offset } => write!(f, "SetLocal l{level}, EP@{ep_offset}, {val}"),
+ Insn::GetSpecialSymbol { symbol_type, .. } => write!(f, "GetSpecialSymbol {symbol_type:?}"),
+ Insn::GetSpecialNumber { nth, .. } => write!(f, "GetSpecialNumber {nth}"),
Insn::ToArray { val, .. } => write!(f, "ToArray {val}"),
Insn::ToNewArray { val, .. } => write!(f, "ToNewArray {val}"),
Insn::ArrayExtend { left, right, .. } => write!(f, "ArrayExtend {left}, {right}"),
@@ -1221,6 +1249,8 @@ impl Function {
&GetIvar { self_val, id, state } => GetIvar { self_val: find!(self_val), id, state },
&SetIvar { self_val, id, val, state } => SetIvar { self_val: find!(self_val), id, val: find!(val), state },
&SetLocal { val, ep_offset, level } => SetLocal { val: find!(val), ep_offset, level },
+ &GetSpecialSymbol { symbol_type, state } => GetSpecialSymbol { symbol_type, state },
+ &GetSpecialNumber { nth, state } => GetSpecialNumber { nth, state },
&ToArray { val, state } => ToArray { val: find!(val), state },
&ToNewArray { val, state } => ToNewArray { val: find!(val), state },
&ArrayExtend { left, right, state } => ArrayExtend { left: find!(left), right: find!(right), state },
@@ -1306,6 +1336,8 @@ impl Function {
Insn::ArrayMax { .. } => types::BasicObject,
Insn::GetGlobal { .. } => types::BasicObject,
Insn::GetIvar { .. } => types::BasicObject,
+ Insn::GetSpecialSymbol { .. } => types::BasicObject,
+ Insn::GetSpecialNumber { .. } => types::BasicObject,
Insn::ToNewArray { .. } => types::ArrayExact,
Insn::ToArray { .. } => types::ArrayExact,
Insn::ObjToString { .. } => types::BasicObject,
@@ -1995,6 +2027,8 @@ impl Function {
worklist.push_back(state);
}
&Insn::GetGlobal { state, .. } |
+ &Insn::GetSpecialSymbol { state, .. } |
+ &Insn::GetSpecialNumber { state, .. } |
&Insn::SideExit { state, .. } => worklist.push_back(state),
}
}
@@ -3325,6 +3359,29 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> {
let anytostring = fun.push_insn(block, Insn::AnyToString { val, str, state: exit_id });
state.stack_push(anytostring);
}
+ YARVINSN_getspecial => {
+ let key = get_arg(pc, 0).as_u64();
+ let svar = get_arg(pc, 1).as_u64();
+
+ let exit_id = fun.push_insn(block, Insn::Snapshot { state: exit_state });
+
+ if svar == 0 {
+ // TODO: Handle non-backref
+ fun.push_insn(block, Insn::SideExit { state: exit_id, reason: SideExitReason::UnknownSpecialVariable(key) });
+ // End the block
+ break;
+ } else if svar & 0x01 != 0 {
+ // Handle symbol backrefs like $&, $`, $', $+
+ let shifted_svar: u8 = (svar >> 1).try_into().unwrap();
+ let symbol_type = SpecialBackrefSymbol::try_from(shifted_svar).expect("invalid backref symbol");
+ let result = fun.push_insn(block, Insn::GetSpecialSymbol { symbol_type, state: exit_id });
+ state.stack_push(result);
+ } else {
+ // Handle number backrefs like $1, $2, $3
+ let result = fun.push_insn(block, Insn::GetSpecialNumber { nth: svar, state: exit_id });
+ state.stack_push(result);
+ }
+ }
_ => {
// Unknown opcode; side-exit into the interpreter
let exit_id = fun.push_insn(block, Insn::Snapshot { state: exit_state });