-rw-r--r--   zjit/src/hir.rs             24
-rw-r--r--   zjit/src/hir/opt_tests.rs  120
-rw-r--r--   zjit/src/invariants.rs      24
3 files changed, 128 insertions, 40 deletions
diff --git a/zjit/src/hir.rs b/zjit/src/hir.rs
index 7812c6058e..3fd5fcc5cb 100644
--- a/zjit/src/hir.rs
+++ b/zjit/src/hir.rs
@@ -6087,14 +6087,12 @@ pub fn jit_entry_insns(iseq: IseqPtr) -> Vec<u32> {
 
 struct BytecodeInfo {
     jump_targets: Vec<u32>,
-    has_blockiseq: bool,
 }
 
 fn compute_bytecode_info(iseq: *const rb_iseq_t, opt_table: &[u32]) -> BytecodeInfo {
     let iseq_size = unsafe { get_iseq_encoded_size(iseq) };
     let mut insn_idx = 0;
     let mut jump_targets: HashSet<u32> = opt_table.iter().copied().collect();
-    let mut has_blockiseq = false;
     while insn_idx < iseq_size {
         // Get the current pc and opcode
         let pc = unsafe { rb_iseq_pc_at_idx(iseq, insn_idx) };
@@ -6118,18 +6116,12 @@ fn compute_bytecode_info(iseq: *const rb_iseq_t, opt_table: &[u32]) -> BytecodeInfo {
                     jump_targets.insert(insn_idx);
                 }
             }
-            YARVINSN_send | YARVINSN_invokesuper => {
-                let blockiseq: IseqPtr = get_arg(pc, 1).as_iseq();
-                if !blockiseq.is_null() {
-                    has_blockiseq = true;
-                }
-            }
             _ => {}
         }
     }
     let mut result = jump_targets.into_iter().collect::<Vec<_>>();
     result.sort();
-    BytecodeInfo { jump_targets: result, has_blockiseq }
+    BytecodeInfo { jump_targets: result }
 }
 
 #[derive(Debug, PartialEq, Clone, Copy)]
@@ -6244,7 +6236,7 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> {
 
     // Compute a map of PC->Block by finding jump targets
     let jit_entry_insns = jit_entry_insns(iseq);
-    let BytecodeInfo { jump_targets, has_blockiseq } = compute_bytecode_info(iseq, &jit_entry_insns);
+    let BytecodeInfo { jump_targets } = compute_bytecode_info(iseq, &jit_entry_insns);
 
     // Make all empty basic blocks. The ordering of the BBs matters for getting fallthrough jumps
     // in good places, but it's not necessary for correctness. TODO: Higher quality scheduling during lowering.
@@ -6276,7 +6268,11 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> {
 
     // Check if the EP is escaped for the ISEQ from the beginning. We give up
     // optimizing locals in that case because they're shared with other frames.
-    let ep_escaped = iseq_escapes_ep(iseq);
+    let ep_starts_escaped = iseq_escapes_ep(iseq);
+    // Check if the EP has been escaped at some point in the ISEQ. If it has, then we assume that
+    // its EP is shared with other frames.
+    let ep_has_been_escaped = crate::invariants::iseq_escapes_ep(iseq);
+    let ep_escaped = ep_starts_escaped || ep_has_been_escaped;
 
     // Iteratively fill out basic blocks using a queue.
     // TODO(max): Basic block arguments at edges
@@ -6620,7 +6616,7 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> {
                     // Use FrameState to get kw_bits when possible, just like getlocal_WC_0.
                     let val = if !local_inval {
                         state.getlocal(ep_offset)
-                    } else if ep_escaped || has_blockiseq {
+                    } else if ep_escaped {
                         fun.push_insn(block, Insn::GetLocal { ep_offset, level: 0, use_sp: false, rest_param: false })
                     } else {
                         let exit_id = fun.push_insn(block, Insn::Snapshot { state: exit_state.without_locals() });
@@ -6743,7 +6739,7 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> {
                         // In case of JIT-to-JIT send locals might never end up in EP memory.
                         let val = state.getlocal(ep_offset);
                         state.stack_push(val);
-                    } else if ep_escaped || has_blockiseq { // TODO: figure out how to drop has_blockiseq here
+                    } else if ep_escaped {
                         // Read the local using EP
                         let val = fun.push_insn(block, Insn::GetLocal { ep_offset, level: 0, use_sp: false, rest_param: false });
                         state.setlocal(ep_offset, val); // remember the result to spill on side-exits
@@ -6764,7 +6760,7 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> {
                 YARVINSN_setlocal_WC_0 => {
                     let ep_offset = get_arg(pc, 0).as_u32();
                     let val = state.stack_pop()?;
-                    if ep_escaped || has_blockiseq { // TODO: figure out how to drop has_blockiseq here
+                    if ep_escaped {
                         // Write the local using EP
                         fun.push_insn(block, Insn::SetLocal { val, ep_offset, level: 0 });
                     } else if local_inval {
diff --git a/zjit/src/hir/opt_tests.rs b/zjit/src/hir/opt_tests.rs
index d0d2e19078..afbbc8bedc 100644
--- a/zjit/src/hir/opt_tests.rs
+++ b/zjit/src/hir/opt_tests.rs
@@ -2926,15 +2926,14 @@ mod hir_opt_tests {
           Jump bb2(v5, v6)
         bb2(v8:BasicObject, v9:NilClass):
           v13:Fixnum[1] = Const Value(1)
-          SetLocal :a, l0, EP@3, v13
           PatchPoint NoSingletonClass(Object@0x1000)
           PatchPoint MethodRedefined(Object@0x1000, foo@0x1008, cme:0x1010)
           v31:HeapObject[class_exact*:Object@VALUE(0x1000)] = GuardType v8, HeapObject[class_exact*:Object@VALUE(0x1000)]
           IncrCounter inline_iseq_optimized_send_count
-          v20:BasicObject = GetLocal :a, l0, EP@3
-          v24:BasicObject = GetLocal :a, l0, EP@3
+          v19:BasicObject = GetLocal :a, l0, EP@3
+          PatchPoint NoEPEscape(test)
           CheckInterrupts
-          Return v24
+          Return v19
         ");
     }
 
@@ -3415,15 +3414,14 @@ mod hir_opt_tests {
           Jump bb2(v6, v7, v8)
         bb2(v10:BasicObject, v11:BasicObject, v12:NilClass):
           v16:ArrayExact = NewArray
-          SetLocal :a, l0, EP@3, v16
-          v22:TrueClass = Const Value(true)
+          v21:TrueClass = Const Value(true)
           IncrCounter complex_arg_pass_caller_kwarg
-          v24:BasicObject = Send v11, 0x1000, :each_line, v22 # SendFallbackReason: Complex argument passing
-          v25:BasicObject = GetLocal :s, l0, EP@4
-          v26:BasicObject = GetLocal :a, l0, EP@3
-          v30:BasicObject = GetLocal :a, l0, EP@3
+          v23:BasicObject = Send v11, 0x1000, :each_line, v21 # SendFallbackReason: Complex argument passing
+          v24:BasicObject = GetLocal :s, l0, EP@4
+          v25:BasicObject = GetLocal :a, l0, EP@3
+          PatchPoint NoEPEscape(test)
           CheckInterrupts
-          Return v30
+          Return v25
         ");
     }
 
@@ -6523,7 +6521,6 @@ mod hir_opt_tests {
          Jump bb2(v5, v6)
        bb2(v8:BasicObject, v9:NilClass):
          v13:ArrayExact = NewArray
-         SetLocal :result, l0, EP@3, v13
          PatchPoint SingleRactorMode
          PatchPoint StableConstantNames(0x1000, A)
          v36:ArrayExact[VALUE(0x1008)] = Const Value(VALUE(0x1008))
@@ -6533,10 +6530,10 @@ mod hir_opt_tests {
          PatchPoint NoSingletonClass(Array@0x1020)
          PatchPoint MethodRedefined(Array@0x1020, zip@0x1028, cme:0x1030)
          v43:BasicObject = CCallVariadic v36, :zip@0x1058, v39
-         v25:BasicObject = GetLocal :result, l0, EP@3
-         v29:BasicObject = GetLocal :result, l0, EP@3
+         v24:BasicObject = GetLocal :result, l0, EP@3
+         PatchPoint NoEPEscape(test)
          CheckInterrupts
-         Return v29
+         Return v24
         ");
     }
 
@@ -11887,4 +11884,97 @@ mod hir_opt_tests {
           Return v31
         ");
     }
+
+    #[test]
+    fn recompile_after_ep_escape_uses_ep_locals() {
+        // When a method creates a lambda, EP escapes to the heap. After
+        // invalidation and recompilation, the compiler must use EP-based
+        // locals (SetLocal/GetLocal) instead of SSA locals, because the
+        // spill target (stack) and the read target (heap EP) diverge.
+        eval("
+            CONST = {}.freeze
+            def test_ep_escape(list, sep=nil, iter_method=:each)
+              sep ||= lambda { }
+              kwsplat = CONST
+              list.__send__(iter_method) {|*v| yield(*v) }
+            end
+
+            test_ep_escape({a: 1}, nil, :each_pair) { |k, v|
+              test_ep_escape([1], lambda { }) { |x| }
+            }
+            test_ep_escape({a: 1}, nil, :each_pair) { |k, v|
+              test_ep_escape([1], lambda { }) { |x| }
+            }
+        ");
+        assert_snapshot!(hir_string("test_ep_escape"), @r"
+        fn test_ep_escape@<compiled>:3:
+        bb0():
+          EntryPoint interpreter
+          v1:BasicObject = LoadSelf
+          v2:BasicObject = GetLocal :list, l0, SP@7
+          v3:BasicObject = GetLocal :sep, l0, SP@6
+          v4:BasicObject = GetLocal :iter_method, l0, SP@5
+          v5:NilClass = Const Value(nil)
+          v6:CPtr = LoadPC
+          v7:CPtr[CPtr(0x1000)] = Const CPtr(0x1008)
+          v8:CBool = IsBitEqual v6, v7
+          IfTrue v8, bb2(v1, v2, v3, v4, v5)
+          v10:CPtr[CPtr(0x1000)] = Const CPtr(0x1008)
+          v11:CBool = IsBitEqual v6, v10
+          IfTrue v11, bb4(v1, v2, v3, v4, v5)
+          Jump bb6(v1, v2, v3, v4, v5)
+        bb1(v15:BasicObject, v16:BasicObject):
+          EntryPoint JIT(0)
+          v17:NilClass = Const Value(nil)
+          v18:NilClass = Const Value(nil)
+          v19:NilClass = Const Value(nil)
+          Jump bb2(v15, v16, v17, v18, v19)
+        bb2(v35:BasicObject, v36:BasicObject, v37:BasicObject, v38:BasicObject, v39:NilClass):
+          v42:NilClass = Const Value(nil)
+          SetLocal :sep, l0, EP@5, v42
+          Jump bb4(v35, v36, v42, v38, v39)
+        bb3(v22:BasicObject, v23:BasicObject, v24:BasicObject):
+          EntryPoint JIT(1)
+          v25:NilClass = Const Value(nil)
+          v26:NilClass = Const Value(nil)
+          Jump bb4(v22, v23, v24, v25, v26)
+        bb4(v46:BasicObject, v47:BasicObject, v48:BasicObject, v49:BasicObject, v50:NilClass):
+          v53:StaticSymbol[:each] = Const Value(VALUE(0x1010))
+          SetLocal :iter_method, l0, EP@4, v53
+          Jump bb6(v46, v47, v48, v53, v50)
+        bb5(v29:BasicObject, v30:BasicObject, v31:BasicObject, v32:BasicObject):
+          EntryPoint JIT(2)
+          v33:NilClass = Const Value(nil)
+          Jump bb6(v29, v30, v31, v32, v33)
+        bb6(v57:BasicObject, v58:BasicObject, v59:BasicObject, v60:BasicObject, v61:NilClass):
+          CheckInterrupts
+          v67:CBool = Test v59
+          v68:Truthy = RefineType v59, Truthy
+          IfTrue v67, bb7(v57, v58, v68, v60, v61)
+          v70:Falsy = RefineType v59, Falsy
+          PatchPoint NoSingletonClass(Object@0x1018)
+          PatchPoint MethodRedefined(Object@0x1018, lambda@0x1020, cme:0x1028)
+          v114:HeapObject[class_exact*:Object@VALUE(0x1018)] = GuardType v57, HeapObject[class_exact*:Object@VALUE(0x1018)]
+          v115:BasicObject = CCallWithFrame v114, :Kernel#lambda@0x1050, block=0x1058
+          v74:BasicObject = GetLocal :list, l0, EP@6
+          v76:BasicObject = GetLocal :iter_method, l0, EP@4
+          v77:BasicObject = GetLocal :kwsplat, l0, EP@3
+          SetLocal :sep, l0, EP@5, v115
+          Jump bb7(v57, v74, v115, v76, v77)
+        bb7(v81:BasicObject, v82:BasicObject, v83:BasicObject, v84:BasicObject, v85:BasicObject):
+          PatchPoint SingleRactorMode
+          PatchPoint StableConstantNames(0x1060, CONST)
+          v110:HashExact[VALUE(0x1068)] = Const Value(VALUE(0x1068))
+          SetLocal :kwsplat, l0, EP@3, v110
+          v95:BasicObject = GetLocal :list, l0, EP@6
+          v97:BasicObject = GetLocal :iter_method, l0, EP@4
+          v99:BasicObject = Send v95, 0x1070, :__send__, v97 # SendFallbackReason: Send: unsupported method type Optimized
+          v100:BasicObject = GetLocal :list, l0, EP@6
+          v101:BasicObject = GetLocal :sep, l0, EP@5
+          v102:BasicObject = GetLocal :iter_method, l0, EP@4
+          v103:BasicObject = GetLocal :kwsplat, l0, EP@3
+          CheckInterrupts
+          Return v99
+        ");
+    }
 }
diff --git a/zjit/src/invariants.rs b/zjit/src/invariants.rs
index f1180acf2a..7aa13cbfcb 100644
--- a/zjit/src/invariants.rs
+++ b/zjit/src/invariants.rs
@@ -206,20 +206,22 @@ pub extern "C" fn rb_zjit_invalidate_no_ep_escape(iseq: IseqPtr) {
         return;
     }
 
-    // Remember that this ISEQ may escape EP
-    let invariants = ZJITState::get_invariants();
-    invariants.ep_escape_iseqs.insert(iseq);
+    with_vm_lock(src_loc!(), || {
+        // Remember that this ISEQ may escape EP
+        let invariants = ZJITState::get_invariants();
+        invariants.ep_escape_iseqs.insert(iseq);
 
-    // If the ISEQ has been compiled assuming it doesn't escape EP, invalidate the JIT code.
-    if let Some(patch_points) = invariants.no_ep_escape_iseq_patch_points.get(&iseq) {
-        debug!("EP is escaped: {}", iseq_name(iseq));
+        // If the ISEQ has been compiled assuming it doesn't escape EP, invalidate the JIT code.
+        if let Some(patch_points) = invariants.no_ep_escape_iseq_patch_points.get(&iseq) {
+            debug!("EP is escaped: {}", iseq_name(iseq));
 
-        // Invalidate the patch points for this ISEQ
-        let cb = ZJITState::get_code_block();
-        compile_patch_points!(cb, patch_points, "EP is escaped: {}", iseq_name(iseq));
+            // Invalidate the patch points for this ISEQ
+            let cb = ZJITState::get_code_block();
+            compile_patch_points!(cb, patch_points, "EP is escaped: {}", iseq_name(iseq));
 
-        cb.mark_all_executable();
-    }
+            cb.mark_all_executable();
+        }
+    });
 }
 
 /// Track that JIT code for a ISEQ will assume that base pointer is equal to environment pointer.
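Note: the NoEPEscape invariant exercised above concerns Ruby code like the following minimal sketch (hypothetical, not part of this commit). Capturing a local in a lambda forces the method's environment pointer (EP) to escape to the heap, which is the condition `rb_zjit_invalidate_no_ep_escape` records and `iseq_to_hir` now consults before optimizing locals into SSA form:

    # Minimal sketch, not from this commit: `count` is captured by the
    # lambda, so the environment of `counter` must outlive the call and
    # escapes to the heap. Once ZJIT observes the escape, recompiled code
    # reads and writes `count` through the EP (GetLocal/SetLocal) instead
    # of keeping it purely in SSA form.
    def counter
      count = 0
      increment = lambda { count += 1 }  # EP escapes here
      increment.call
      count
    end

    counter # => 1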
