summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--zjit/src/codegen.rs2
-rw-r--r--zjit/src/hir.rs32
-rw-r--r--zjit/src/hir/opt_tests.rs51
-rw-r--r--zjit/src/stats.rs6
4 files changed, 81 insertions, 10 deletions
diff --git a/zjit/src/codegen.rs b/zjit/src/codegen.rs
index 0a19035dc1..72cfa478f5 100644
--- a/zjit/src/codegen.rs
+++ b/zjit/src/codegen.rs
@@ -380,7 +380,7 @@ fn gen_insn(cb: &mut CodeBlock, jit: &mut JITState, asm: &mut Assembler, functio
&Insn::SendWithoutBlock { cd, state, reason, .. } => gen_send_without_block(jit, asm, cd, &function.frame_state(state), reason),
// Give up SendWithoutBlockDirect for 6+ args since asm.ccall() doesn't support it.
Insn::SendWithoutBlockDirect { cd, state, args, .. } if args.len() + 1 > C_ARG_OPNDS.len() => // +1 for self
- gen_send_without_block(jit, asm, *cd, &function.frame_state(*state), SendFallbackReason::SendWithoutBlockDirectTooManyArgs),
+ gen_send_without_block(jit, asm, *cd, &function.frame_state(*state), SendFallbackReason::TooManyArgsForLir),
Insn::SendWithoutBlockDirect { cme, iseq, recv, args, state, .. } => gen_send_without_block_direct(cb, jit, asm, *cme, *iseq, opnd!(recv), opnds!(args), &function.frame_state(*state)),
&Insn::InvokeSuper { cd, blockiseq, state, reason, .. } => gen_invokesuper(jit, asm, cd, blockiseq, &function.frame_state(state), reason),
&Insn::InvokeBlock { cd, state, reason, .. } => gen_invokeblock(jit, asm, cd, &function.frame_state(state), reason),
diff --git a/zjit/src/hir.rs b/zjit/src/hir.rs
index dcf378a282..a3b4834894 100644
--- a/zjit/src/hir.rs
+++ b/zjit/src/hir.rs
@@ -9,7 +9,7 @@ use crate::{
cast::IntoUsize, codegen::local_idx_to_ep_offset, cruby::*, payload::{get_or_create_iseq_payload, IseqPayload}, options::{debug, get_option, DumpHIR}, state::ZJITState
};
use std::{
- cell::RefCell, collections::{HashMap, HashSet, VecDeque}, ffi::{c_void, c_uint, CStr}, fmt::Display, mem::{align_of, size_of}, ptr, slice::Iter
+ cell::RefCell, collections::{HashMap, HashSet, VecDeque}, ffi::{c_void, c_uint, c_int, CStr}, fmt::Display, mem::{align_of, size_of}, ptr, slice::Iter
};
use crate::hir_type::{Type, types};
use crate::bitset::BitSet;
@@ -593,7 +593,6 @@ pub enum SendFallbackReason {
SendWithoutBlockCfuncArrayVariadic,
SendWithoutBlockNotOptimizedMethodType(MethodType),
SendWithoutBlockNotOptimizedMethodTypeOptimized(OptimizedMethodType),
- SendWithoutBlockDirectTooManyArgs,
SendWithoutBlockBopRedefined,
SendWithoutBlockOperandsNotFixnum,
SendPolymorphic,
@@ -604,8 +603,11 @@ pub enum SendFallbackReason {
SendNotOptimizedMethodType(MethodType),
CCallWithFrameTooManyArgs,
ObjToStringNotString,
+ TooManyArgsForLir,
/// The Proc object for a BMETHOD is not defined by an ISEQ. (See `enum rb_block_type`.)
BmethodNonIseqProc,
+ /// Caller supplies too few or too many arguments than what the callee's parameters expects.
+ ArgcParamMismatch,
/// The call has at least one feature on the caller or callee side that the optimizer does not
/// support.
ComplexArgPass,
@@ -1457,7 +1459,7 @@ pub enum ValidationError {
MiscValidationError(InsnId, String),
}
-fn can_direct_send(function: &mut Function, block: BlockId, iseq: *const rb_iseq_t) -> bool {
+fn can_direct_send(function: &mut Function, block: BlockId, iseq: *const rb_iseq_t, send_insn: InsnId, args: &[InsnId]) -> bool {
let mut can_send = true;
let mut count_failure = |counter| {
can_send = false;
@@ -1472,6 +1474,24 @@ fn can_direct_send(function: &mut Function, block: BlockId, iseq: *const rb_iseq
if unsafe { rb_get_iseq_flags_has_block(iseq) } { count_failure(complex_arg_pass_param_block) }
if unsafe { rb_get_iseq_flags_forwardable(iseq) } { count_failure(complex_arg_pass_param_forwardable) }
+ if !can_send {
+ function.set_dynamic_send_reason(send_insn, ComplexArgPass);
+ return false;
+ }
+
+ // Check argument count against callee's parameters. Note that correctness for this calculation
+ // relies on rejecting features above.
+ let lead_num = unsafe { get_iseq_body_param_lead_num(iseq) };
+ let opt_num = unsafe { get_iseq_body_param_opt_num(iseq) };
+ can_send = c_int::try_from(args.len())
+ .as_ref()
+ .map(|argc| (lead_num..=lead_num + opt_num).contains(argc))
+ .unwrap_or(false);
+ if !can_send {
+ function.set_dynamic_send_reason(send_insn, ArgcParamMismatch);
+ return false
+ }
+
can_send
}
@@ -2358,8 +2378,7 @@ impl Function {
// Only specialize positional-positional calls
// TODO(max): Handle other kinds of parameter passing
let iseq = unsafe { get_def_iseq_ptr((*cme).def) };
- if !can_direct_send(self, block, iseq) {
- self.set_dynamic_send_reason(insn_id, ComplexArgPass);
+ if !can_direct_send(self, block, iseq, insn_id, args.as_slice()) {
self.push_insn_id(block, insn_id); continue;
}
self.push_insn(block, Insn::PatchPoint { invariant: Invariant::MethodRedefined { klass, method: mid, cme }, state });
@@ -2384,8 +2403,7 @@ impl Function {
let capture = unsafe { proc_block.as_.captured.as_ref() };
let iseq = unsafe { *capture.code.iseq.as_ref() };
- if !can_direct_send(self, block, iseq) {
- self.set_dynamic_send_reason(insn_id, ComplexArgPass);
+ if !can_direct_send(self, block, iseq, insn_id, args.as_slice()) {
self.push_insn_id(block, insn_id); continue;
}
// Can't pass a block to a block for now
diff --git a/zjit/src/hir/opt_tests.rs b/zjit/src/hir/opt_tests.rs
index 58b21cdb84..d770b67094 100644
--- a/zjit/src/hir/opt_tests.rs
+++ b/zjit/src/hir/opt_tests.rs
@@ -6184,6 +6184,57 @@ mod hir_opt_tests {
}
#[test]
+ fn test_dont_optimize_when_passing_too_many_args() {
+ eval(r#"
+ public def foo(lead, opt=raise) = opt
+ def test = 0.foo(3, 3, 3)
+ "#);
+ assert_snapshot!(hir_string("test"), @r"
+ fn test@<compiled>:3:
+ bb0():
+ EntryPoint interpreter
+ v1:BasicObject = LoadSelf
+ Jump bb2(v1)
+ bb1(v4:BasicObject):
+ EntryPoint JIT(0)
+ Jump bb2(v4)
+ bb2(v6:BasicObject):
+ v10:Fixnum[0] = Const Value(0)
+ v12:Fixnum[3] = Const Value(3)
+ v14:Fixnum[3] = Const Value(3)
+ v16:Fixnum[3] = Const Value(3)
+ IncrCounter complex_arg_pass_param_opt
+ v18:BasicObject = SendWithoutBlock v10, :foo, v12, v14, v16
+ CheckInterrupts
+ Return v18
+ ");
+ }
+
+ #[test]
+ fn test_dont_optimize_when_passing_too_few_args() {
+ eval(r#"
+ public def foo(lead, opt=raise) = opt
+ def test = 0.foo
+ "#);
+ assert_snapshot!(hir_string("test"), @r"
+ fn test@<compiled>:3:
+ bb0():
+ EntryPoint interpreter
+ v1:BasicObject = LoadSelf
+ Jump bb2(v1)
+ bb1(v4:BasicObject):
+ EntryPoint JIT(0)
+ Jump bb2(v4)
+ bb2(v6:BasicObject):
+ v10:Fixnum[0] = Const Value(0)
+ IncrCounter complex_arg_pass_param_opt
+ v12:BasicObject = SendWithoutBlock v10, :foo
+ CheckInterrupts
+ Return v12
+ ");
+ }
+
+ #[test]
fn test_dont_inline_integer_succ_with_args() {
eval("
def test = 4.succ 1
diff --git a/zjit/src/stats.rs b/zjit/src/stats.rs
index d91707b587..3874234f72 100644
--- a/zjit/src/stats.rs
+++ b/zjit/src/stats.rs
@@ -178,7 +178,7 @@ make_counters! {
send_fallback_send_without_block_cfunc_array_variadic,
send_fallback_send_without_block_not_optimized_method_type,
send_fallback_send_without_block_not_optimized_method_type_optimized,
- send_fallback_send_without_block_direct_too_many_args,
+ send_fallback_too_many_args_for_lir,
send_fallback_send_without_block_bop_redefined,
send_fallback_send_without_block_operands_not_fixnum,
send_fallback_send_polymorphic,
@@ -186,6 +186,7 @@ make_counters! {
send_fallback_send_no_profiles,
send_fallback_send_not_optimized_method_type,
send_fallback_ccall_with_frame_too_many_args,
+ send_fallback_argc_param_mismatch,
// The call has at least one feature on the caller or callee side
// that the optimizer does not support.
send_fallback_one_or_more_complex_arg_pass,
@@ -476,7 +477,7 @@ pub fn send_fallback_counter(reason: crate::hir::SendFallbackReason) -> Counter
SendWithoutBlockNotOptimizedMethodType(_) => send_fallback_send_without_block_not_optimized_method_type,
SendWithoutBlockNotOptimizedMethodTypeOptimized(_)
=> send_fallback_send_without_block_not_optimized_method_type_optimized,
- SendWithoutBlockDirectTooManyArgs => send_fallback_send_without_block_direct_too_many_args,
+ TooManyArgsForLir => send_fallback_too_many_args_for_lir,
SendWithoutBlockBopRedefined => send_fallback_send_without_block_bop_redefined,
SendWithoutBlockOperandsNotFixnum => send_fallback_send_without_block_operands_not_fixnum,
SendPolymorphic => send_fallback_send_polymorphic,
@@ -485,6 +486,7 @@ pub fn send_fallback_counter(reason: crate::hir::SendFallbackReason) -> Counter
SendCfuncVariadic => send_fallback_send_cfunc_variadic,
SendCfuncArrayVariadic => send_fallback_send_cfunc_array_variadic,
ComplexArgPass => send_fallback_one_or_more_complex_arg_pass,
+ ArgcParamMismatch => send_fallback_argc_param_mismatch,
BmethodNonIseqProc => send_fallback_bmethod_non_iseq_proc,
SendNotOptimizedMethodType(_) => send_fallback_send_not_optimized_method_type,
CCallWithFrameTooManyArgs => send_fallback_ccall_with_frame_too_many_args,