diff options
| author | Max Bernstein <ruby@bernsteinbear.com> | 2025-09-18 13:47:35 -0400 |
|---|---|---|
| committer | Max Bernstein <tekknolagi@gmail.com> | 2025-09-19 22:38:29 -0400 |
| commit | 254b9b4952156ee443b920e89fec7b6d16a0f785 (patch) | |
| tree | 00456298f98cc113cd347a805d3428320dd29264 | |
| parent | 4a04e6f7555c94e0e58ad090681e359c7dcbf22e (diff) | |
ZJIT: Expand the list of safe allocators
It's not just the default allocator (`rb_class_allocate_instance`); the built-in allocators for Hash, Array, String, and Regexp are also leaf.
| Mode | File | Lines changed |
|---|---|---|
| -rw-r--r-- | zjit/bindgen/src/main.rs | 1 |
| -rw-r--r-- | zjit/src/codegen.rs | 13 |
| -rw-r--r-- | zjit/src/cruby.rs | 13 |
| -rw-r--r-- | zjit/src/cruby_bindings.inc.rs | 1 |
| -rw-r--r-- | zjit/src/hir.rs | 163 |
5 files changed, 186 insertions, 5 deletions
diff --git a/zjit/bindgen/src/main.rs b/zjit/bindgen/src/main.rs index c4fd625818..c4233521cc 100644 --- a/zjit/bindgen/src/main.rs +++ b/zjit/bindgen/src/main.rs @@ -336,6 +336,7 @@ fn main() { .allowlist_function("rb_yarv_class_of") .allowlist_function("rb_zjit_class_initialized_p") .allowlist_function("rb_zjit_class_has_default_allocator") + .allowlist_function("rb_zjit_class_get_alloc_func") .allowlist_function("rb_get_ec_cfp") .allowlist_function("rb_get_cfp_iseq") .allowlist_function("rb_get_cfp_pc") diff --git a/zjit/src/codegen.rs b/zjit/src/codegen.rs index 0b236d0b57..03228db0bb 100644 --- a/zjit/src/codegen.rs +++ b/zjit/src/codegen.rs @@ -1314,9 +1314,16 @@ fn gen_object_alloc_class(asm: &mut Assembler, class: VALUE, state: &FrameState) // Allocating an object for a known class with default allocator is leaf; see doc for // `ObjectAllocClass`. gen_prepare_leaf_call_with_gc(asm, state); - assert!(unsafe { rb_zjit_class_has_default_allocator(class) }, "class must have default allocator"); - // TODO(max): inline code to allocate an instance - asm_ccall!(asm, rb_class_allocate_instance, class.into()) + if unsafe { rb_zjit_class_has_default_allocator(class) } { + // TODO(max): inline code to allocate an instance + asm_ccall!(asm, rb_class_allocate_instance, class.into()) + } else { + assert!(class_has_leaf_allocator(class), "class passed to ObjectAllocClass must have a leaf allocator"); + let alloc_func = unsafe { rb_zjit_class_get_alloc_func(class) }; + assert!(alloc_func.is_some(), "class {} passed to ObjectAllocClass must have an allocator", get_class_name(class)); + asm_comment!(asm, "call allocator for class {}", get_class_name(class)); + asm.ccall(alloc_func.unwrap() as *const u8, vec![class.into()]) + } } /// Compile code that exits from JIT code with a return value diff --git a/zjit/src/cruby.rs b/zjit/src/cruby.rs index 374faa6d97..82d0582da4 100644 --- a/zjit/src/cruby.rs +++ b/zjit/src/cruby.rs @@ -1218,6 +1218,19 @@ pub fn get_class_name(class: 
VALUE) -> String { }).unwrap_or_else(|| "Unknown".to_string()) } +pub fn class_has_leaf_allocator(class: VALUE) -> bool { + // empty_hash_alloc + if class == unsafe { rb_cHash } { return true; } + // empty_ary_alloc + if class == unsafe { rb_cArray } { return true; } + // empty_str_alloc + if class == unsafe { rb_cString } { return true; } + // rb_reg_s_alloc + if class == unsafe { rb_cRegexp } { return true; } + // rb_class_allocate_instance + unsafe { rb_zjit_class_has_default_allocator(class) } +} + /// Interned ID values for Ruby symbols and method names. /// See [type@crate::cruby::ID] and usages outside of ZJIT. pub(crate) mod ids { diff --git a/zjit/src/cruby_bindings.inc.rs b/zjit/src/cruby_bindings.inc.rs index dfdb6c0f29..5a06d99f9e 100644 --- a/zjit/src/cruby_bindings.inc.rs +++ b/zjit/src/cruby_bindings.inc.rs @@ -948,6 +948,7 @@ unsafe extern "C" { recv: VALUE, ) -> *const rb_callable_method_entry_struct; pub fn rb_zjit_class_initialized_p(klass: VALUE) -> bool; + pub fn rb_zjit_class_get_alloc_func(klass: VALUE) -> rb_alloc_func_t; pub fn rb_zjit_class_has_default_allocator(klass: VALUE) -> bool; pub fn rb_iseq_encoded_size(iseq: *const rb_iseq_t) -> ::std::os::raw::c_uint; pub fn rb_iseq_pc_at_idx(iseq: *const rb_iseq_t, insn_idx: u32) -> *mut VALUE; diff --git a/zjit/src/hir.rs b/zjit/src/hir.rs index 449646024a..4750b187e0 100644 --- a/zjit/src/hir.rs +++ b/zjit/src/hir.rs @@ -1978,8 +1978,8 @@ impl Function { if unsafe { rb_zjit_singleton_class_p(class) } { self.push_insn_id(block, insn_id); continue; } - if !unsafe { rb_zjit_class_has_default_allocator(class) } { - // Custom or NULL allocator; could run arbitrary code. + if !class_has_leaf_allocator(class) { + // Custom, known unsafe, or NULL allocator; could run arbitrary code. 
self.push_insn_id(block, insn_id); continue; } let replacement = self.push_insn(block, Insn::ObjectAllocClass { class, state }); @@ -8297,6 +8297,165 @@ mod opt_tests { } #[test] + fn test_opt_new_object() { + eval(" + def test = Object.new + test + "); + assert_snapshot!(hir_string("test"), @r" + fn test@<compiled>:2: + bb0(v0:BasicObject): + PatchPoint SingleRactorMode + PatchPoint StableConstantNames(0x1000, Object) + v34:Class[VALUE(0x1008)] = Const Value(VALUE(0x1008)) + v6:NilClass = Const Value(nil) + PatchPoint MethodRedefined(Object@0x1008, new@0x1010, cme:0x1018) + v37:HeapObject[class_exact:Object] = ObjectAllocClass VALUE(0x1008) + PatchPoint MethodRedefined(Object@0x1008, initialize@0x1040, cme:0x1048) + v39:NilClass = CCall initialize@0x1070, v37 + CheckInterrupts + CheckInterrupts + Return v37 + "); + } + + #[test] + fn test_opt_new_basic_object() { + eval(" + def test = BasicObject.new + test + "); + assert_snapshot!(hir_string("test"), @r" + fn test@<compiled>:2: + bb0(v0:BasicObject): + PatchPoint SingleRactorMode + PatchPoint StableConstantNames(0x1000, BasicObject) + v34:Class[VALUE(0x1008)] = Const Value(VALUE(0x1008)) + v6:NilClass = Const Value(nil) + PatchPoint MethodRedefined(BasicObject@0x1008, new@0x1010, cme:0x1018) + v37:HeapObject[class_exact:BasicObject] = ObjectAllocClass VALUE(0x1008) + PatchPoint MethodRedefined(BasicObject@0x1008, initialize@0x1040, cme:0x1048) + v39:NilClass = CCall initialize@0x1070, v37 + CheckInterrupts + CheckInterrupts + Return v37 + "); + } + + #[test] + fn test_opt_new_hash() { + eval(" + def test = Hash.new + test + "); + assert_snapshot!(hir_string("test"), @r" + fn test@<compiled>:2: + bb0(v0:BasicObject): + PatchPoint SingleRactorMode + PatchPoint StableConstantNames(0x1000, Hash) + v34:Class[VALUE(0x1008)] = Const Value(VALUE(0x1008)) + v6:NilClass = Const Value(nil) + PatchPoint MethodRedefined(Hash@0x1008, new@0x1010, cme:0x1018) + v37:HashExact = ObjectAllocClass VALUE(0x1008) + v12:BasicObject = 
SendWithoutBlock v37, :initialize + CheckInterrupts + CheckInterrupts + Return v37 + "); + } + + #[test] + fn test_opt_new_array() { + eval(" + def test = Array.new 1 + test + "); + assert_snapshot!(hir_string("test"), @r" + fn test@<compiled>:2: + bb0(v0:BasicObject): + PatchPoint SingleRactorMode + PatchPoint StableConstantNames(0x1000, Array) + v36:Class[VALUE(0x1008)] = Const Value(VALUE(0x1008)) + v6:NilClass = Const Value(nil) + v7:Fixnum[1] = Const Value(1) + PatchPoint MethodRedefined(Array@0x1008, new@0x1010, cme:0x1018) + PatchPoint MethodRedefined(Class@0x1040, new@0x1010, cme:0x1018) + v45:BasicObject = CCallVariadic new@0x1048, v36, v7 + CheckInterrupts + Return v45 + "); + } + + #[test] + fn test_opt_new_set() { + eval(" + def test = Set.new + test + "); + assert_snapshot!(hir_string("test"), @r" + fn test@<compiled>:2: + bb0(v0:BasicObject): + PatchPoint SingleRactorMode + PatchPoint StableConstantNames(0x1000, Set) + v34:Class[VALUE(0x1008)] = Const Value(VALUE(0x1008)) + v6:NilClass = Const Value(nil) + PatchPoint MethodRedefined(Set@0x1008, new@0x1010, cme:0x1018) + v10:HeapObject = ObjectAlloc v34 + PatchPoint MethodRedefined(Set@0x1008, initialize@0x1040, cme:0x1048) + v39:BasicObject = CCallVariadic initialize@0x1070, v10 + CheckInterrupts + CheckInterrupts + Return v10 + "); + } + + #[test] + fn test_opt_new_string() { + eval(" + def test = String.new + test + "); + assert_snapshot!(hir_string("test"), @r" + fn test@<compiled>:2: + bb0(v0:BasicObject): + PatchPoint SingleRactorMode + PatchPoint StableConstantNames(0x1000, String) + v34:Class[VALUE(0x1008)] = Const Value(VALUE(0x1008)) + v6:NilClass = Const Value(nil) + PatchPoint MethodRedefined(String@0x1008, new@0x1010, cme:0x1018) + PatchPoint MethodRedefined(Class@0x1040, new@0x1010, cme:0x1018) + v43:BasicObject = CCallVariadic new@0x1048, v34 + CheckInterrupts + Return v43 + "); + } + + #[test] + fn test_opt_new_regexp() { + eval(" + def test = Regexp.new '' + test + "); + 
assert_snapshot!(hir_string("test"), @r" + fn test@<compiled>:2: + bb0(v0:BasicObject): + PatchPoint SingleRactorMode + PatchPoint StableConstantNames(0x1000, Regexp) + v38:Class[VALUE(0x1008)] = Const Value(VALUE(0x1008)) + v6:NilClass = Const Value(nil) + v7:StringExact[VALUE(0x1010)] = Const Value(VALUE(0x1010)) + v9:StringExact = StringCopy v7 + PatchPoint MethodRedefined(Regexp@0x1008, new@0x1018, cme:0x1020) + v41:HeapObject[class_exact:Regexp] = ObjectAllocClass VALUE(0x1008) + PatchPoint MethodRedefined(Regexp@0x1008, initialize@0x1048, cme:0x1050) + v44:BasicObject = CCallVariadic initialize@0x1078, v41, v9 + CheckInterrupts + CheckInterrupts + Return v41 + "); + } + + #[test] fn test_opt_length() { eval(" def test(a,b) = [a,b].length |
