summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMax Bernstein <ruby@bernsteinbear.com>2025-09-18 13:47:35 -0400
committerMax Bernstein <tekknolagi@gmail.com>2025-09-19 22:38:29 -0400
commit254b9b4952156ee443b920e89fec7b6d16a0f785 (patch)
tree00456298f98cc113cd347a805d3428320dd29264
parent4a04e6f7555c94e0e58ad090681e359c7dcbf22e (diff)
ZJIT: Expand the list of safe allocators
It's not just the default allocator; other allocators are also leaf.
-rw-r--r--zjit/bindgen/src/main.rs1
-rw-r--r--zjit/src/codegen.rs13
-rw-r--r--zjit/src/cruby.rs13
-rw-r--r--zjit/src/cruby_bindings.inc.rs1
-rw-r--r--zjit/src/hir.rs163
5 files changed, 186 insertions, 5 deletions
diff --git a/zjit/bindgen/src/main.rs b/zjit/bindgen/src/main.rs
index c4fd625818..c4233521cc 100644
--- a/zjit/bindgen/src/main.rs
+++ b/zjit/bindgen/src/main.rs
@@ -336,6 +336,7 @@ fn main() {
.allowlist_function("rb_yarv_class_of")
.allowlist_function("rb_zjit_class_initialized_p")
.allowlist_function("rb_zjit_class_has_default_allocator")
+ .allowlist_function("rb_zjit_class_get_alloc_func")
.allowlist_function("rb_get_ec_cfp")
.allowlist_function("rb_get_cfp_iseq")
.allowlist_function("rb_get_cfp_pc")
diff --git a/zjit/src/codegen.rs b/zjit/src/codegen.rs
index 0b236d0b57..03228db0bb 100644
--- a/zjit/src/codegen.rs
+++ b/zjit/src/codegen.rs
@@ -1314,9 +1314,16 @@ fn gen_object_alloc_class(asm: &mut Assembler, class: VALUE, state: &FrameState)
// Allocating an object for a known class with default allocator is leaf; see doc for
// `ObjectAllocClass`.
gen_prepare_leaf_call_with_gc(asm, state);
- assert!(unsafe { rb_zjit_class_has_default_allocator(class) }, "class must have default allocator");
- // TODO(max): inline code to allocate an instance
- asm_ccall!(asm, rb_class_allocate_instance, class.into())
+ if unsafe { rb_zjit_class_has_default_allocator(class) } {
+ // TODO(max): inline code to allocate an instance
+ asm_ccall!(asm, rb_class_allocate_instance, class.into())
+ } else {
+ assert!(class_has_leaf_allocator(class), "class passed to ObjectAllocClass must have a leaf allocator");
+ let alloc_func = unsafe { rb_zjit_class_get_alloc_func(class) };
+ assert!(alloc_func.is_some(), "class {} passed to ObjectAllocClass must have an allocator", get_class_name(class));
+ asm_comment!(asm, "call allocator for class {}", get_class_name(class));
+ asm.ccall(alloc_func.unwrap() as *const u8, vec![class.into()])
+ }
}
/// Compile code that exits from JIT code with a return value
diff --git a/zjit/src/cruby.rs b/zjit/src/cruby.rs
index 374faa6d97..82d0582da4 100644
--- a/zjit/src/cruby.rs
+++ b/zjit/src/cruby.rs
@@ -1218,6 +1218,19 @@ pub fn get_class_name(class: VALUE) -> String {
}).unwrap_or_else(|| "Unknown".to_string())
}
+pub fn class_has_leaf_allocator(class: VALUE) -> bool {
+ // empty_hash_alloc
+ if class == unsafe { rb_cHash } { return true; }
+ // empty_ary_alloc
+ if class == unsafe { rb_cArray } { return true; }
+ // empty_str_alloc
+ if class == unsafe { rb_cString } { return true; }
+ // rb_reg_s_alloc
+ if class == unsafe { rb_cRegexp } { return true; }
+ // rb_class_allocate_instance
+ unsafe { rb_zjit_class_has_default_allocator(class) }
+}
+
/// Interned ID values for Ruby symbols and method names.
/// See [type@crate::cruby::ID] and usages outside of ZJIT.
pub(crate) mod ids {
diff --git a/zjit/src/cruby_bindings.inc.rs b/zjit/src/cruby_bindings.inc.rs
index dfdb6c0f29..5a06d99f9e 100644
--- a/zjit/src/cruby_bindings.inc.rs
+++ b/zjit/src/cruby_bindings.inc.rs
@@ -948,6 +948,7 @@ unsafe extern "C" {
recv: VALUE,
) -> *const rb_callable_method_entry_struct;
pub fn rb_zjit_class_initialized_p(klass: VALUE) -> bool;
+ pub fn rb_zjit_class_get_alloc_func(klass: VALUE) -> rb_alloc_func_t;
pub fn rb_zjit_class_has_default_allocator(klass: VALUE) -> bool;
pub fn rb_iseq_encoded_size(iseq: *const rb_iseq_t) -> ::std::os::raw::c_uint;
pub fn rb_iseq_pc_at_idx(iseq: *const rb_iseq_t, insn_idx: u32) -> *mut VALUE;
diff --git a/zjit/src/hir.rs b/zjit/src/hir.rs
index 449646024a..4750b187e0 100644
--- a/zjit/src/hir.rs
+++ b/zjit/src/hir.rs
@@ -1978,8 +1978,8 @@ impl Function {
if unsafe { rb_zjit_singleton_class_p(class) } {
self.push_insn_id(block, insn_id); continue;
}
- if !unsafe { rb_zjit_class_has_default_allocator(class) } {
- // Custom or NULL allocator; could run arbitrary code.
+ if !class_has_leaf_allocator(class) {
+ // Custom, known unsafe, or NULL allocator; could run arbitrary code.
self.push_insn_id(block, insn_id); continue;
}
let replacement = self.push_insn(block, Insn::ObjectAllocClass { class, state });
@@ -8297,6 +8297,165 @@ mod opt_tests {
}
#[test]
+ fn test_opt_new_object() {
+ eval("
+ def test = Object.new
+ test
+ ");
+ assert_snapshot!(hir_string("test"), @r"
+ fn test@<compiled>:2:
+ bb0(v0:BasicObject):
+ PatchPoint SingleRactorMode
+ PatchPoint StableConstantNames(0x1000, Object)
+ v34:Class[VALUE(0x1008)] = Const Value(VALUE(0x1008))
+ v6:NilClass = Const Value(nil)
+ PatchPoint MethodRedefined(Object@0x1008, new@0x1010, cme:0x1018)
+ v37:HeapObject[class_exact:Object] = ObjectAllocClass VALUE(0x1008)
+ PatchPoint MethodRedefined(Object@0x1008, initialize@0x1040, cme:0x1048)
+ v39:NilClass = CCall initialize@0x1070, v37
+ CheckInterrupts
+ CheckInterrupts
+ Return v37
+ ");
+ }
+
+ #[test]
+ fn test_opt_new_basic_object() {
+ eval("
+ def test = BasicObject.new
+ test
+ ");
+ assert_snapshot!(hir_string("test"), @r"
+ fn test@<compiled>:2:
+ bb0(v0:BasicObject):
+ PatchPoint SingleRactorMode
+ PatchPoint StableConstantNames(0x1000, BasicObject)
+ v34:Class[VALUE(0x1008)] = Const Value(VALUE(0x1008))
+ v6:NilClass = Const Value(nil)
+ PatchPoint MethodRedefined(BasicObject@0x1008, new@0x1010, cme:0x1018)
+ v37:HeapObject[class_exact:BasicObject] = ObjectAllocClass VALUE(0x1008)
+ PatchPoint MethodRedefined(BasicObject@0x1008, initialize@0x1040, cme:0x1048)
+ v39:NilClass = CCall initialize@0x1070, v37
+ CheckInterrupts
+ CheckInterrupts
+ Return v37
+ ");
+ }
+
+ #[test]
+ fn test_opt_new_hash() {
+ eval("
+ def test = Hash.new
+ test
+ ");
+ assert_snapshot!(hir_string("test"), @r"
+ fn test@<compiled>:2:
+ bb0(v0:BasicObject):
+ PatchPoint SingleRactorMode
+ PatchPoint StableConstantNames(0x1000, Hash)
+ v34:Class[VALUE(0x1008)] = Const Value(VALUE(0x1008))
+ v6:NilClass = Const Value(nil)
+ PatchPoint MethodRedefined(Hash@0x1008, new@0x1010, cme:0x1018)
+ v37:HashExact = ObjectAllocClass VALUE(0x1008)
+ v12:BasicObject = SendWithoutBlock v37, :initialize
+ CheckInterrupts
+ CheckInterrupts
+ Return v37
+ ");
+ }
+
+ #[test]
+ fn test_opt_new_array() {
+ eval("
+ def test = Array.new 1
+ test
+ ");
+ assert_snapshot!(hir_string("test"), @r"
+ fn test@<compiled>:2:
+ bb0(v0:BasicObject):
+ PatchPoint SingleRactorMode
+ PatchPoint StableConstantNames(0x1000, Array)
+ v36:Class[VALUE(0x1008)] = Const Value(VALUE(0x1008))
+ v6:NilClass = Const Value(nil)
+ v7:Fixnum[1] = Const Value(1)
+ PatchPoint MethodRedefined(Array@0x1008, new@0x1010, cme:0x1018)
+ PatchPoint MethodRedefined(Class@0x1040, new@0x1010, cme:0x1018)
+ v45:BasicObject = CCallVariadic new@0x1048, v36, v7
+ CheckInterrupts
+ Return v45
+ ");
+ }
+
+ #[test]
+ fn test_opt_new_set() {
+ eval("
+ def test = Set.new
+ test
+ ");
+ assert_snapshot!(hir_string("test"), @r"
+ fn test@<compiled>:2:
+ bb0(v0:BasicObject):
+ PatchPoint SingleRactorMode
+ PatchPoint StableConstantNames(0x1000, Set)
+ v34:Class[VALUE(0x1008)] = Const Value(VALUE(0x1008))
+ v6:NilClass = Const Value(nil)
+ PatchPoint MethodRedefined(Set@0x1008, new@0x1010, cme:0x1018)
+ v10:HeapObject = ObjectAlloc v34
+ PatchPoint MethodRedefined(Set@0x1008, initialize@0x1040, cme:0x1048)
+ v39:BasicObject = CCallVariadic initialize@0x1070, v10
+ CheckInterrupts
+ CheckInterrupts
+ Return v10
+ ");
+ }
+
+ #[test]
+ fn test_opt_new_string() {
+ eval("
+ def test = String.new
+ test
+ ");
+ assert_snapshot!(hir_string("test"), @r"
+ fn test@<compiled>:2:
+ bb0(v0:BasicObject):
+ PatchPoint SingleRactorMode
+ PatchPoint StableConstantNames(0x1000, String)
+ v34:Class[VALUE(0x1008)] = Const Value(VALUE(0x1008))
+ v6:NilClass = Const Value(nil)
+ PatchPoint MethodRedefined(String@0x1008, new@0x1010, cme:0x1018)
+ PatchPoint MethodRedefined(Class@0x1040, new@0x1010, cme:0x1018)
+ v43:BasicObject = CCallVariadic new@0x1048, v34
+ CheckInterrupts
+ Return v43
+ ");
+ }
+
+ #[test]
+ fn test_opt_new_regexp() {
+ eval("
+ def test = Regexp.new ''
+ test
+ ");
+ assert_snapshot!(hir_string("test"), @r"
+ fn test@<compiled>:2:
+ bb0(v0:BasicObject):
+ PatchPoint SingleRactorMode
+ PatchPoint StableConstantNames(0x1000, Regexp)
+ v38:Class[VALUE(0x1008)] = Const Value(VALUE(0x1008))
+ v6:NilClass = Const Value(nil)
+ v7:StringExact[VALUE(0x1010)] = Const Value(VALUE(0x1010))
+ v9:StringExact = StringCopy v7
+ PatchPoint MethodRedefined(Regexp@0x1008, new@0x1018, cme:0x1020)
+ v41:HeapObject[class_exact:Regexp] = ObjectAllocClass VALUE(0x1008)
+ PatchPoint MethodRedefined(Regexp@0x1008, initialize@0x1048, cme:0x1050)
+ v44:BasicObject = CCallVariadic initialize@0x1078, v41, v9
+ CheckInterrupts
+ CheckInterrupts
+ Return v41
+ ");
+ }
+
+ #[test]
fn test_opt_length() {
eval("
def test(a,b) = [a,b].length