summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--jit.c6
-rw-r--r--test/ruby/test_zjit.rb30
-rw-r--r--yjit.c6
-rw-r--r--yjit/bindgen/src/main.rs2
-rw-r--r--yjit/src/codegen.rs4
-rw-r--r--yjit/src/cruby_bindings.inc.rs2
-rw-r--r--zjit/bindgen/src/main.rs1
-rw-r--r--zjit/src/codegen.rs14
-rw-r--r--zjit/src/cruby_bindings.inc.rs1
-rw-r--r--zjit/src/hir.rs137
-rw-r--r--zjit/src/hir_type/mod.rs17
-rw-r--r--zjit/src/stats.rs1
12 files changed, 204 insertions, 17 deletions
diff --git a/jit.c b/jit.c
index f233a2f01f..0b491f0481 100644
--- a/jit.c
+++ b/jit.c
@@ -450,6 +450,12 @@ rb_yarv_ary_entry_internal(VALUE ary, long offset)
return rb_ary_entry_internal(ary, offset);
}
+long
+rb_jit_array_len(VALUE a)
+{
+ return rb_array_len(a);
+}
+
void
rb_set_cfp_pc(struct rb_control_frame_struct *cfp, const VALUE *pc)
{
diff --git a/test/ruby/test_zjit.rb b/test/ruby/test_zjit.rb
index faf717096a..b0c717bc24 100644
--- a/test/ruby/test_zjit.rb
+++ b/test/ruby/test_zjit.rb
@@ -1714,6 +1714,36 @@ class TestZJIT < Test::Unit::TestCase
}, call_threshold: 1, insns: [:opt_getconstant_path]
end
+ def test_expandarray_no_splat
+ assert_compiles '[3, 4]', %q{
+ def test(o)
+ a, b = o
+ [a, b]
+ end
+ test [3, 4]
+ }, call_threshold: 1, insns: [:expandarray]
+ end
+
+ def test_expandarray_splat
+ assert_compiles '[3, [4]]', %q{
+ def test(o)
+ a, *b = o
+ [a, b]
+ end
+ test [3, 4]
+ }, call_threshold: 1, insns: [:expandarray]
+ end
+
+ def test_expandarray_splat_post
+ assert_compiles '[3, [4], 5]', %q{
+ def test(o)
+ a, *b, c = o
+ [a, b, c]
+ end
+ test [3, 4, 5]
+ }, call_threshold: 1, insns: [:expandarray]
+ end
+
def test_getconstant_path_autoload
# A constant-referencing expression can run arbitrary code through Kernel#autoload.
Dir.mktmpdir('autoload') do |tmpdir|
diff --git a/yjit.c b/yjit.c
index 57b09d73b0..598fe57167 100644
--- a/yjit.c
+++ b/yjit.c
@@ -69,12 +69,6 @@ STATIC_ASSERT(pointer_tagging_scheme, USE_FLONUM);
// The "_yjit_" part is for trying to be informative. We might want different
// suffixes for symbols meant for Rust and symbols meant for broader CRuby.
-long
-rb_yjit_array_len(VALUE a)
-{
- return rb_array_len(a);
-}
-
# define PTR2NUM(x) (rb_int2inum((intptr_t)(void *)(x)))
// For a given raw_sample (frame), set the hash with the caller's
diff --git a/yjit/bindgen/src/main.rs b/yjit/bindgen/src/main.rs
index 29b17346cd..0d4d57e069 100644
--- a/yjit/bindgen/src/main.rs
+++ b/yjit/bindgen/src/main.rs
@@ -381,7 +381,7 @@ fn main() {
.allowlist_function("rb_METHOD_ENTRY_VISI")
.allowlist_function("rb_RCLASS_ORIGIN")
.allowlist_function("rb_method_basic_definition_p")
- .allowlist_function("rb_yjit_array_len")
+ .allowlist_function("rb_jit_array_len")
.allowlist_function("rb_obj_class")
.allowlist_function("rb_obj_is_proc")
.allowlist_function("rb_vm_base_ptr")
diff --git a/yjit/src/codegen.rs b/yjit/src/codegen.rs
index 25a8545e85..bf758a4f62 100644
--- a/yjit/src/codegen.rs
+++ b/yjit/src/codegen.rs
@@ -2305,7 +2305,7 @@ fn gen_expandarray(
}
// Get the compile-time array length
- let comptime_len = unsafe { rb_yjit_array_len(comptime_recv) as u32 };
+ let comptime_len = unsafe { rb_jit_array_len(comptime_recv) as u32 };
// Move the array from the stack and check that it's an array.
guard_object_is_array(
@@ -7603,7 +7603,7 @@ fn gen_send_iseq(
gen_counter_incr(jit, asm, Counter::send_iseq_splat_not_array);
return None;
} else {
- unsafe { rb_yjit_array_len(array) as u32}
+ unsafe { rb_jit_array_len(array) as u32}
};
// Arity check accounting for size of the splat. When callee has rest parameters, we insert
diff --git a/yjit/src/cruby_bindings.inc.rs b/yjit/src/cruby_bindings.inc.rs
index 1e34440460..0a14a69928 100644
--- a/yjit/src/cruby_bindings.inc.rs
+++ b/yjit/src/cruby_bindings.inc.rs
@@ -1113,7 +1113,6 @@ extern "C" {
lines: *mut ::std::os::raw::c_int,
) -> ::std::os::raw::c_int;
pub fn rb_jit_cont_each_iseq(callback: rb_iseq_callback, data: *mut ::std::os::raw::c_void);
- pub fn rb_yjit_array_len(a: VALUE) -> ::std::os::raw::c_long;
pub fn rb_yjit_exit_locations_dict(
yjit_raw_samples: *mut VALUE,
yjit_line_samples: *mut ::std::os::raw::c_int,
@@ -1250,6 +1249,7 @@ extern "C" {
pub fn rb_IMEMO_TYPE_P(imemo: VALUE, imemo_type: imemo_type) -> ::std::os::raw::c_int;
pub fn rb_assert_cme_handle(handle: VALUE);
pub fn rb_yarv_ary_entry_internal(ary: VALUE, offset: ::std::os::raw::c_long) -> VALUE;
+ pub fn rb_jit_array_len(a: VALUE) -> ::std::os::raw::c_long;
pub fn rb_set_cfp_pc(cfp: *mut rb_control_frame_struct, pc: *const VALUE);
pub fn rb_set_cfp_sp(cfp: *mut rb_control_frame_struct, sp: *mut VALUE);
pub fn rb_jit_shape_too_complex_p(shape_id: shape_id_t) -> bool;
diff --git a/zjit/bindgen/src/main.rs b/zjit/bindgen/src/main.rs
index b54e9404fd..f13b61acf0 100644
--- a/zjit/bindgen/src/main.rs
+++ b/zjit/bindgen/src/main.rs
@@ -272,6 +272,7 @@ fn main() {
.allowlist_function("rb_jit_mark_executable")
.allowlist_function("rb_jit_mark_unused")
.allowlist_function("rb_jit_get_page_size")
+ .allowlist_function("rb_jit_array_len")
.allowlist_function("rb_zjit_iseq_builtin_attrs")
.allowlist_function("rb_zjit_iseq_inspect")
.allowlist_function("rb_zjit_iseq_insn_set")
diff --git a/zjit/src/codegen.rs b/zjit/src/codegen.rs
index 42501de242..ea4a7ecc7c 100644
--- a/zjit/src/codegen.rs
+++ b/zjit/src/codegen.rs
@@ -356,6 +356,7 @@ fn gen_insn(cb: &mut CodeBlock, jit: &mut JITState, asm: &mut Assembler, functio
Insn::NewRangeFixnum { low, high, flag, state } => gen_new_range_fixnum(asm, opnd!(low), opnd!(high), *flag, &function.frame_state(*state)),
Insn::ArrayDup { val, state } => gen_array_dup(asm, opnd!(val), &function.frame_state(*state)),
Insn::ArrayArefFixnum { array, index, .. } => gen_aref_fixnum(asm, opnd!(array), opnd!(index)),
+ Insn::ArrayLength { array } => gen_array_length(asm, opnd!(array)),
Insn::ObjectAlloc { val, state } => gen_object_alloc(jit, asm, opnd!(val), &function.frame_state(*state)),
&Insn::ObjectAllocClass { class, state } => gen_object_alloc_class(asm, class, &function.frame_state(state)),
Insn::StringCopy { val, chilled, state } => gen_string_copy(asm, opnd!(val), *chilled, &function.frame_state(*state)),
@@ -1258,6 +1259,10 @@ fn gen_aref_fixnum(
asm_ccall!(asm, rb_ary_entry, array, unboxed_idx)
}
+fn gen_array_length(asm: &mut Assembler, array: Opnd) -> lir::Opnd {
+ asm_ccall!(asm, rb_jit_array_len, array)
+}
+
/// Compile a new hash instruction
fn gen_new_hash(
jit: &mut JITState,
@@ -1589,8 +1594,13 @@ fn gen_guard_type_not(jit: &mut JITState, asm: &mut Assembler, val: lir::Opnd, g
}
/// Compile an identity check with a side exit
-fn gen_guard_bit_equals(jit: &mut JITState, asm: &mut Assembler, val: lir::Opnd, expected: VALUE, state: &FrameState) -> lir::Opnd {
- asm.cmp(val, Opnd::Value(expected));
+fn gen_guard_bit_equals(jit: &mut JITState, asm: &mut Assembler, val: lir::Opnd, expected: crate::hir::Const, state: &FrameState) -> lir::Opnd {
+ let expected_opnd: Opnd = match expected {
+ crate::hir::Const::Value(v) => { Opnd::Value(v) }
+ crate::hir::Const::CInt64(v) => { v.into() }
+ _ => panic!("gen_guard_bit_equals: unexpected hir::Const {:?}", expected),
+ };
+ asm.cmp(val, expected_opnd);
asm.jnz(side_exit(jit, state, GuardBitEquals(expected)));
val
}
diff --git a/zjit/src/cruby_bindings.inc.rs b/zjit/src/cruby_bindings.inc.rs
index ffaafed5ff..c67e229a80 100644
--- a/zjit/src/cruby_bindings.inc.rs
+++ b/zjit/src/cruby_bindings.inc.rs
@@ -1045,6 +1045,7 @@ unsafe extern "C" {
pub fn rb_IMEMO_TYPE_P(imemo: VALUE, imemo_type: imemo_type) -> ::std::os::raw::c_int;
pub fn rb_assert_cme_handle(handle: VALUE);
pub fn rb_yarv_ary_entry_internal(ary: VALUE, offset: ::std::os::raw::c_long) -> VALUE;
+ pub fn rb_jit_array_len(a: VALUE) -> ::std::os::raw::c_long;
pub fn rb_set_cfp_pc(cfp: *mut rb_control_frame_struct, pc: *const VALUE);
pub fn rb_set_cfp_sp(cfp: *mut rb_control_frame_struct, sp: *mut VALUE);
pub fn rb_jit_shape_too_complex_p(shape_id: shape_id_t) -> bool;
diff --git a/zjit/src/hir.rs b/zjit/src/hir.rs
index bccd27fc39..73b7cca8d0 100644
--- a/zjit/src/hir.rs
+++ b/zjit/src/hir.rs
@@ -270,7 +270,7 @@ impl<'a> std::fmt::Display for InvariantPrinter<'a> {
}
}
-#[derive(Debug, Clone, PartialEq)]
+#[derive(Debug, Clone, PartialEq, Copy)]
pub enum Const {
Value(VALUE),
CBool(bool),
@@ -460,7 +460,7 @@ pub enum SideExitReason {
GuardType(Type),
GuardTypeNot(Type),
GuardShape(ShapeId),
- GuardBitEquals(VALUE),
+ GuardBitEquals(Const),
PatchPoint(Invariant),
CalleeSideExit,
ObjToStringFallback,
@@ -583,6 +583,8 @@ pub enum Insn {
/// Push `val` onto `array`, where `array` is already `Array`.
ArrayPush { array: InsnId, val: InsnId, state: InsnId },
ArrayArefFixnum { array: InsnId, index: InsnId },
+ /// Return the length of the array as a C `long` ([`types::CInt64`])
+ ArrayLength { array: InsnId },
HashAref { hash: InsnId, key: InsnId, state: InsnId },
HashDup { val: InsnId, state: InsnId },
@@ -768,8 +770,8 @@ pub enum Insn {
/// Side-exit if val doesn't have the expected type.
GuardType { val: InsnId, guard_type: Type, state: InsnId },
GuardTypeNot { val: InsnId, guard_type: Type, state: InsnId },
- /// Side-exit if val is not the expected VALUE.
- GuardBitEquals { val: InsnId, expected: VALUE, state: InsnId },
+ /// Side-exit if val is not the expected Const.
+ GuardBitEquals { val: InsnId, expected: Const, state: InsnId },
/// Side-exit if val doesn't have the expected shape.
GuardShape { val: InsnId, shape: ShapeId, state: InsnId },
/// Side-exit if the block param has been modified or the block handler for the frame
@@ -899,6 +901,9 @@ impl<'a> std::fmt::Display for InsnPrinter<'a> {
Insn::ArrayArefFixnum { array, index, .. } => {
write!(f, "ArrayArefFixnum {array}, {index}")
}
+ Insn::ArrayLength { array } => {
+ write!(f, "ArrayLength {array}")
+ }
Insn::NewHash { elements, .. } => {
write!(f, "NewHash")?;
let mut prefix = " ";
@@ -1604,6 +1609,7 @@ impl Function {
&NewRange { low, high, flag, state } => NewRange { low: find!(low), high: find!(high), flag, state: find!(state) },
&NewRangeFixnum { low, high, flag, state } => NewRangeFixnum { low: find!(low), high: find!(high), flag, state: find!(state) },
&ArrayArefFixnum { array, index } => ArrayArefFixnum { array: find!(array), index: find!(index) },
+ &ArrayLength { array } => ArrayLength { array: find!(array) },
&ArrayMax { ref elements, state } => ArrayMax { elements: find_vec!(elements), state: find!(state) },
&SetGlobal { id, val, state } => SetGlobal { id, val: find!(val), state },
&GetIvar { self_val, id, state } => GetIvar { self_val: find!(self_val), id, state },
@@ -1691,6 +1697,7 @@ impl Function {
Insn::NewArray { .. } => types::ArrayExact,
Insn::ArrayDup { .. } => types::ArrayExact,
Insn::ArrayArefFixnum { .. } => types::BasicObject,
+ Insn::ArrayLength { .. } => types::CInt64,
Insn::HashAref { .. } => types::BasicObject,
Insn::NewHash { .. } => types::HashExact,
Insn::HashDup { .. } => types::HashExact,
@@ -1703,7 +1710,7 @@ impl Function {
&Insn::CCallVariadic { return_type, .. } => return_type,
Insn::GuardType { val, guard_type, .. } => self.type_of(*val).intersection(*guard_type),
Insn::GuardTypeNot { .. } => types::BasicObject,
- Insn::GuardBitEquals { val, expected, .. } => self.type_of(*val).intersection(Type::from_value(*expected)),
+ Insn::GuardBitEquals { val, expected, .. } => self.type_of(*val).intersection(Type::from_const(*expected)),
Insn::GuardShape { val, .. } => self.type_of(*val),
Insn::FixnumAdd { .. } => types::Fixnum,
Insn::FixnumSub { .. } => types::Fixnum,
@@ -2803,6 +2810,9 @@ impl Function {
worklist.push_back(array);
worklist.push_back(index);
}
+ &Insn::ArrayLength { array } => {
+ worklist.push_back(array);
+ }
&Insn::HashAref { hash, key, state } => {
worklist.push_back(hash);
worklist.push_back(key);
@@ -4428,6 +4438,31 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> {
state.stack_push(result);
}
}
+ YARVINSN_expandarray => {
+ let num = get_arg(pc, 0).as_u64();
+ let flag = get_arg(pc, 1).as_u64();
+ if flag != 0 {
+ // We don't (yet) handle 0x01 (rest args), 0x02 (post args), or 0x04
+ // (reverse?)
+ //
+ // Unhandled opcode; side-exit into the interpreter
+ let exit_id = fun.push_insn(block, Insn::Snapshot { state: exit_state });
+ fun.push_insn(block, Insn::SideExit { state: exit_id, reason: SideExitReason::UnhandledYARVInsn(opcode) });
+ break; // End the block
+ }
+ let val = state.stack_pop()?;
+ let exit_id = fun.push_insn(block, Insn::Snapshot { state: exit_state });
+ let array = fun.push_insn(block, Insn::GuardType { val, guard_type: types::ArrayExact, state: exit_id, });
+ let length = fun.push_insn(block, Insn::ArrayLength { array });
+ fun.push_insn(block, Insn::GuardBitEquals { val: length, expected: Const::CInt64(num as i64), state: exit_id });
+ for i in (0..num).rev() {
+ // TODO(max): Add a short-cut path for long indices into an array where the
+ // index is known to be in-bounds
+ let index = fun.push_insn(block, Insn::Const { val: Const::Value(VALUE::fixnum_from_usize(i.try_into().unwrap())) });
+ let element = fun.push_insn(block, Insn::ArrayArefFixnum { array, index });
+ state.stack_push(element);
+ }
+ }
_ => {
// Unhandled opcode; side-exit into the interpreter
let exit_id = fun.push_insn(block, Insn::Snapshot { state: exit_state });
@@ -7914,6 +7949,98 @@ mod tests {
Return v17
");
}
+
+ #[test]
+ fn test_expandarray_no_splat() {
+ eval(r#"
+ def test(o)
+ a, b = o
+ end
+ "#);
+ assert_contains_opcode("test", YARVINSN_expandarray);
+ assert_snapshot!(hir_string("test"), @r"
+ fn test@<compiled>:3:
+ bb0():
+ EntryPoint interpreter
+ v1:BasicObject = LoadSelf
+ v2:BasicObject = GetLocal l0, SP@6
+ v3:NilClass = Const Value(nil)
+ v4:NilClass = Const Value(nil)
+ Jump bb2(v1, v2, v3, v4)
+ bb1(v7:BasicObject, v8:BasicObject):
+ EntryPoint JIT(0)
+ v9:NilClass = Const Value(nil)
+ v10:NilClass = Const Value(nil)
+ Jump bb2(v7, v8, v9, v10)
+ bb2(v12:BasicObject, v13:BasicObject, v14:NilClass, v15:NilClass):
+ v20:ArrayExact = GuardType v13, ArrayExact
+ v21:CInt64 = ArrayLength v20
+ v22:CInt64[2] = GuardBitEquals v21, CInt64(2)
+ v23:Fixnum[1] = Const Value(1)
+ v24:BasicObject = ArrayArefFixnum v20, v23
+ v25:Fixnum[0] = Const Value(0)
+ v26:BasicObject = ArrayArefFixnum v20, v25
+ PatchPoint NoEPEscape(test)
+ CheckInterrupts
+ Return v13
+ ");
+ }
+
+ #[test]
+ fn test_expandarray_splat() {
+ eval(r#"
+ def test(o)
+ a, *b = o
+ end
+ "#);
+ assert_contains_opcode("test", YARVINSN_expandarray);
+ assert_snapshot!(hir_string("test"), @r"
+ fn test@<compiled>:3:
+ bb0():
+ EntryPoint interpreter
+ v1:BasicObject = LoadSelf
+ v2:BasicObject = GetLocal l0, SP@6
+ v3:NilClass = Const Value(nil)
+ v4:NilClass = Const Value(nil)
+ Jump bb2(v1, v2, v3, v4)
+ bb1(v7:BasicObject, v8:BasicObject):
+ EntryPoint JIT(0)
+ v9:NilClass = Const Value(nil)
+ v10:NilClass = Const Value(nil)
+ Jump bb2(v7, v8, v9, v10)
+ bb2(v12:BasicObject, v13:BasicObject, v14:NilClass, v15:NilClass):
+ SideExit UnhandledYARVInsn(expandarray)
+ ");
+ }
+
+ #[test]
+ fn test_expandarray_splat_post() {
+ eval(r#"
+ def test(o)
+ a, *b, c = o
+ end
+ "#);
+ assert_contains_opcode("test", YARVINSN_expandarray);
+ assert_snapshot!(hir_string("test"), @r"
+ fn test@<compiled>:3:
+ bb0():
+ EntryPoint interpreter
+ v1:BasicObject = LoadSelf
+ v2:BasicObject = GetLocal l0, SP@7
+ v3:NilClass = Const Value(nil)
+ v4:NilClass = Const Value(nil)
+ v5:NilClass = Const Value(nil)
+ Jump bb2(v1, v2, v3, v4, v5)
+ bb1(v8:BasicObject, v9:BasicObject):
+ EntryPoint JIT(0)
+ v10:NilClass = Const Value(nil)
+ v11:NilClass = Const Value(nil)
+ v12:NilClass = Const Value(nil)
+ Jump bb2(v8, v9, v10, v11, v12)
+ bb2(v14:BasicObject, v15:BasicObject, v16:NilClass, v17:NilClass, v18:NilClass):
+ SideExit UnhandledYARVInsn(expandarray)
+ ");
+ }
}
#[cfg(test)]
diff --git a/zjit/src/hir_type/mod.rs b/zjit/src/hir_type/mod.rs
index ffde7e458d..f24161657e 100644
--- a/zjit/src/hir_type/mod.rs
+++ b/zjit/src/hir_type/mod.rs
@@ -238,6 +238,23 @@ impl Type {
}
}
+ pub fn from_const(val: Const) -> Type {
+ match val {
+ Const::Value(v) => Self::from_value(v),
+ Const::CBool(v) => Self::from_cbool(v),
+ Const::CInt8(v) => Self::from_cint(types::CInt8, v as i64),
+ Const::CInt16(v) => Self::from_cint(types::CInt16, v as i64),
+ Const::CInt32(v) => Self::from_cint(types::CInt32, v as i64),
+ Const::CInt64(v) => Self::from_cint(types::CInt64, v as i64),
+ Const::CUInt8(v) => Self::from_cint(types::CUInt8, v as i64),
+ Const::CUInt16(v) => Self::from_cint(types::CUInt16, v as i64),
+ Const::CUInt32(v) => Self::from_cint(types::CUInt32, v as i64),
+ Const::CUInt64(v) => Self::from_cint(types::CUInt64, v as i64),
+ Const::CPtr(v) => Self::from_cptr(v),
+ Const::CDouble(v) => Self::from_double(v),
+ }
+ }
+
pub fn from_profiled_type(val: ProfiledType) -> Type {
if val.is_fixnum() { types::Fixnum }
else if val.is_flonum() { types::Flonum }
diff --git a/zjit/src/stats.rs b/zjit/src/stats.rs
index 843806e5be..50f6e61f5c 100644
--- a/zjit/src/stats.rs
+++ b/zjit/src/stats.rs
@@ -140,6 +140,7 @@ make_counters! {
exit_guard_type_failure,
exit_guard_type_not_failure,
exit_guard_bit_equals_failure,
+ exit_guard_int_equals_failure,
exit_guard_shape_failure,
exit_patchpoint_bop_redefined,
exit_patchpoint_method_redefined,