diff options
| author | Daniel Colson <danieljamescolson@gmail.com> | 2025-10-22 23:01:26 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-10-22 20:01:26 -0700 |
| commit | 271be0a2258061be486e24dc61994a3f4155d669 (patch) | |
| tree | dab44b0f29178c35424a318570d0d65d18003212 | |
| parent | da214cf3a9611ca00d3dd204a97b3c22ba90d2d1 (diff) | |
ZJIT: Implement classvar get and set (#14918)
https://github.com/Shopify/ruby/issues/649
Class vars are a bit more involved than ivars, since we need to get the
class from the cref, so this calls out to `rb_vm_getclassvariable` and
`rb_vm_setclassvariable` like YJIT.
| -rw-r--r-- | test/ruby/test_zjit.rb | 22 | ||||
| -rw-r--r-- | zjit/src/codegen.rs | 16 | ||||
| -rw-r--r-- | zjit/src/hir.rs | 88 |
3 files changed, 124 insertions, 2 deletions
diff --git a/test/ruby/test_zjit.rb b/test/ruby/test_zjit.rb index 44f010d056..13c7817017 100644 --- a/test/ruby/test_zjit.rb +++ b/test/ruby/test_zjit.rb @@ -1664,6 +1664,28 @@ class TestZJIT < Test::Unit::TestCase } end + def test_getclassvariable + assert_compiles '42', %q{ + class Foo + def self.test = @@x + end + + Foo.class_variable_set(:@@x, 42) + Foo.test() + } + end + + def test_setclassvariable + assert_compiles '42', %q{ + class Foo + def self.test = @@x = 42 + end + + Foo.test() + Foo.class_variable_get(:@@x) + } + end + def test_attr_reader assert_compiles '[4, 4]', %q{ class C diff --git a/zjit/src/codegen.rs b/zjit/src/codegen.rs index c00bdb474e..029e144303 100644 --- a/zjit/src/codegen.rs +++ b/zjit/src/codegen.rs @@ -426,6 +426,8 @@ fn gen_insn(cb: &mut CodeBlock, jit: &mut JITState, asm: &mut Assembler, functio &Insn::GetLocal { ep_offset, level, use_sp, .. } => gen_getlocal(asm, ep_offset, level, use_sp), &Insn::SetLocal { val, ep_offset, level } => no_output!(gen_setlocal(asm, opnd!(val), function.type_of(val), ep_offset, level)), Insn::GetConstantPath { ic, state } => gen_get_constant_path(jit, asm, *ic, &function.frame_state(*state)), + Insn::GetClassVar { id, ic, state } => gen_getclassvar(jit, asm, *id, *ic, &function.frame_state(*state)), + Insn::SetClassVar { id, val, ic, state } => no_output!(gen_setclassvar(jit, asm, *id, opnd!(val), *ic, &function.frame_state(*state))), Insn::SetIvar { self_val, id, val, state } => no_output!(gen_setivar(jit, asm, opnd!(self_val), *id, opnd!(val), &function.frame_state(*state))), Insn::SideExit { state, reason } => no_output!(gen_side_exit(jit, asm, reason, &function.frame_state(*state))), Insn::PutSpecialObject { value_type } => gen_putspecialobject(asm, *value_type), @@ -832,6 +834,20 @@ fn gen_setivar(jit: &mut JITState, asm: &mut Assembler, recv: Opnd, id: ID, val: asm_ccall!(asm, rb_ivar_set, recv, id.0.into(), val); } +fn gen_getclassvar(jit: &mut JITState, asm: &mut Assembler, id: ID, ic: *const iseq_inline_cvar_cache_entry, state: &FrameState) -> Opnd { + gen_prepare_non_leaf_call(jit, asm, state); + + let iseq = asm.load(Opnd::mem(64, CFP, RUBY_OFFSET_CFP_ISEQ)); + asm_ccall!(asm, rb_vm_getclassvariable, iseq, CFP, id.0.into(), Opnd::const_ptr(ic)) +} + +fn gen_setclassvar(jit: &mut JITState, asm: &mut Assembler, id: ID, val: Opnd, ic: *const iseq_inline_cvar_cache_entry, state: &FrameState) { + gen_prepare_non_leaf_call(jit, asm, state); + + let iseq = asm.load(Opnd::mem(64, CFP, RUBY_OFFSET_CFP_ISEQ)); + asm_ccall!(asm, rb_vm_setclassvariable, iseq, CFP, id.0.into(), val, Opnd::const_ptr(ic)); +} + /// Look up global variables fn gen_getglobal(jit: &mut JITState, asm: &mut Assembler, id: ID, state: &FrameState) -> Opnd { // `Warning` module's method `warn` can be called when reading certain global variables diff --git a/zjit/src/hir.rs b/zjit/src/hir.rs index 834a33d23c..9f422c0146 100644 --- a/zjit/src/hir.rs +++ b/zjit/src/hir.rs @@ -648,6 +648,11 @@ pub enum Insn { GetSpecialSymbol { symbol_type: SpecialBackrefSymbol, state: InsnId }, GetSpecialNumber { nth: u64, state: InsnId }, + /// Get a class variable `id` + GetClassVar { id: ID, ic: *const iseq_inline_cvar_cache_entry, state: InsnId }, + /// Set a class variable `id` to `val` + SetClassVar { id: ID, val: InsnId, ic: *const iseq_inline_cvar_cache_entry, state: InsnId }, + /// Own a FrameState so that instructions can look up their dominating FrameState when /// generating deopt side-exits and frame reconstruction metadata. Does not directly generate /// any code. @@ -811,7 +816,7 @@ impl Insn { match self { Insn::Jump(_) | Insn::IfTrue { .. } | Insn::IfFalse { .. } | Insn::EntryPoint { .. } | Insn::Return { .. } - | Insn::PatchPoint { .. } | Insn::SetIvar { .. } | Insn::ArrayExtend { .. } + | Insn::PatchPoint { .. } | Insn::SetIvar { .. } | Insn::SetClassVar { .. } | Insn::ArrayExtend { .. } | Insn::ArrayPush { .. } | Insn::SideExit { .. } | Insn::SetGlobal { .. } | Insn::SetLocal { .. } | Insn::Throw { .. } | Insn::IncrCounter(_) | Insn::IncrCounterPtr { .. } | Insn::CheckInterrupts { .. } | Insn::GuardBlockParamProxy { .. } => false, @@ -1130,6 +1135,8 @@ impl<'a> std::fmt::Display for InsnPrinter<'a> { Insn::SetLocal { val, level, ep_offset } => write!(f, "SetLocal l{level}, EP@{ep_offset}, {val}"), Insn::GetSpecialSymbol { symbol_type, .. } => write!(f, "GetSpecialSymbol {symbol_type:?}"), Insn::GetSpecialNumber { nth, .. } => write!(f, "GetSpecialNumber {nth}"), + Insn::GetClassVar { id, .. } => write!(f, "GetClassVar :{}", id.contents_lossy()), + Insn::SetClassVar { id, val, .. } => write!(f, "SetClassVar :{}, {val}", id.contents_lossy()), Insn::ToArray { val, .. } => write!(f, "ToArray {val}"), Insn::ToNewArray { val, .. } => write!(f, "ToNewArray {val}"), Insn::ArrayExtend { left, right, .. } => write!(f, "ArrayExtend {left}, {right}"), @@ -1716,6 +1723,8 @@ impl Function { &LoadIvarEmbedded { self_val, id, index } => LoadIvarEmbedded { self_val: find!(self_val), id, index }, &LoadIvarExtended { self_val, id, index } => LoadIvarExtended { self_val: find!(self_val), id, index }, &SetIvar { self_val, id, val, state } => SetIvar { self_val: find!(self_val), id, val: find!(val), state }, + &GetClassVar { id, ic, state } => GetClassVar { id, ic, state }, + &SetClassVar { id, val, ic, state } => SetClassVar { id, val: find!(val), ic, state }, &SetLocal { val, ep_offset, level } => SetLocal { val: find!(val), ep_offset, level }, &GetSpecialSymbol { symbol_type, state } => GetSpecialSymbol { symbol_type, state }, &GetSpecialNumber { nth, state } => GetSpecialNumber { nth, state }, @@ -1765,7 +1774,7 @@ impl Function { Insn::Param { .. } => unimplemented!("params should not be present in block.insns"), Insn::SetGlobal { .. } | Insn::Jump(_) | Insn::EntryPoint { .. } | Insn::IfTrue { .. } | Insn::IfFalse { .. } | Insn::Return { .. } | Insn::Throw { .. } - | Insn::PatchPoint { .. } | Insn::SetIvar { .. } | Insn::ArrayExtend { .. } + | Insn::PatchPoint { .. } | Insn::SetIvar { .. } | Insn::SetClassVar { .. } | Insn::ArrayExtend { .. } | Insn::ArrayPush { .. } | Insn::SideExit { .. } | Insn::SetLocal { .. } | Insn::IncrCounter(_) | Insn::CheckInterrupts { .. } | Insn::GuardBlockParamProxy { .. } | Insn::IncrCounterPtr { .. } => panic!("Cannot infer type of instruction with no output: {}", self.insns[insn.0]), @@ -1848,6 +1857,7 @@ impl Function { Insn::LoadIvarExtended { .. } => types::BasicObject, Insn::GetSpecialSymbol { .. } => types::BasicObject, Insn::GetSpecialNumber { .. } => types::BasicObject, + Insn::GetClassVar { .. } => types::BasicObject, Insn::ToNewArray { .. } => types::ArrayExact, Insn::ToArray { .. } => types::ArrayExact, Insn::ObjToString { .. } => types::BasicObject, @@ -3156,6 +3166,13 @@ impl Function { worklist.push_back(val); worklist.push_back(state); } + &Insn::GetClassVar { state, .. } => { + worklist.push_back(state); + } + &Insn::SetClassVar { val, state, .. } => { + worklist.push_back(val); + worklist.push_back(state); + } &Insn::ArrayPush { array, val, state } => { worklist.push_back(array); worklist.push_back(val); @@ -4639,6 +4656,20 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> { let val = state.stack_pop()?; fun.push_insn(block, Insn::SetIvar { self_val: self_param, id, val, state: exit_id }); } + YARVINSN_getclassvariable => { + let id = ID(get_arg(pc, 0).as_u64()); + let ic = get_arg(pc, 1).as_ptr(); + let exit_id = fun.push_insn(block, Insn::Snapshot { state: exit_state }); + let result = fun.push_insn(block, Insn::GetClassVar { id, ic, state: exit_id }); + state.stack_push(result); + } + YARVINSN_setclassvariable => { + let id = ID(get_arg(pc, 0).as_u64()); + let ic = get_arg(pc, 1).as_ptr(); + let exit_id = fun.push_insn(block, Insn::Snapshot { state: exit_state }); + let val = state.stack_pop()?; + fun.push_insn(block, Insn::SetClassVar { id, val, ic, state: exit_id }); + } YARVINSN_opt_reverse => { // Reverse the order of the top N stack items. let n = get_arg(pc, 0).as_usize(); @@ -7430,6 +7461,59 @@ mod tests { } #[test] + fn test_getclassvariable() { + eval(" + class Foo + def self.test = @@foo + end + "); + let iseq = crate::cruby::with_rubyvm(|| get_method_iseq("Foo", "test")); + assert!(iseq_contains_opcode(iseq, YARVINSN_getclassvariable), "iseq Foo.test does not contain getclassvariable"); + let function = iseq_to_hir(iseq).unwrap(); + assert_snapshot!(hir_string_function(&function), @r" + fn test@<compiled>:3: + bb0(): + EntryPoint interpreter + v1:BasicObject = LoadSelf + Jump bb2(v1) + bb1(v4:BasicObject): + EntryPoint JIT(0) + Jump bb2(v4) + bb2(v6:BasicObject): + v11:BasicObject = GetClassVar :@@foo + CheckInterrupts + Return v11 + "); + } + + #[test] + fn test_setclassvariable() { + eval(" + class Foo + def self.test = @@foo = 42 + end + "); + let iseq = crate::cruby::with_rubyvm(|| get_method_iseq("Foo", "test")); + assert!(iseq_contains_opcode(iseq, YARVINSN_setclassvariable), "iseq Foo.test does not contain setclassvariable"); + let function = iseq_to_hir(iseq).unwrap(); + assert_snapshot!(hir_string_function(&function), @r" + fn test@<compiled>:3: + bb0(): + EntryPoint interpreter + v1:BasicObject = LoadSelf + Jump bb2(v1) + bb1(v4:BasicObject): + EntryPoint JIT(0) + Jump bb2(v4) + bb2(v6:BasicObject): + v10:Fixnum[42] = Const Value(42) + SetClassVar :@@foo, v10 + CheckInterrupts + Return v10 + "); + } + + #[test] fn test_setglobal() { eval(" def test = $foo = 1 |
