diff options
| author | Takashi Kokubun <takashikkbn@gmail.com> | 2023-02-24 14:48:02 -0800 |
|---|---|---|
| committer | Takashi Kokubun <takashikkbn@gmail.com> | 2023-03-05 23:28:59 -0800 |
| commit | 33213542f241709727475a386a3fa189d426b52d (patch) | |
| tree | 08aa9e8c2a3c94bd26e87e8aba7a88d2c6406693 /lib/ruby_vm | |
| parent | 5576da7900162234c8e114b72401a8e0681c7c61 (diff) | |
Implement invokesuper
Notes
Notes:
Merged: https://github.com/ruby/ruby/pull/7448
Diffstat (limited to 'lib/ruby_vm')
| -rw-r--r-- | lib/ruby_vm/mjit/insn_compiler.rb | 236 |
1 files changed, 208 insertions, 28 deletions
diff --git a/lib/ruby_vm/mjit/insn_compiler.rb b/lib/ruby_vm/mjit/insn_compiler.rb index 469b5b6cac..29a66b4913 100644 --- a/lib/ruby_vm/mjit/insn_compiler.rb +++ b/lib/ruby_vm/mjit/insn_compiler.rb @@ -22,7 +22,7 @@ module RubyVM::MJIT asm.incr_counter(:mjit_insns_count) asm.comment("Insn: #{insn.name}") - # 58/101 + # 59/101 case insn.name when :nop then nop(jit, ctx, asm) when :getlocal then getlocal(jit, ctx, asm) @@ -82,7 +82,7 @@ module RubyVM::MJIT # opt_str_uminus # opt_newarray_max # opt_newarray_min - # invokesuper + when :invokesuper then invokesuper(jit, ctx, asm) # invokeblock when :leave then leave(jit, ctx, asm) # throw @@ -629,10 +629,24 @@ module RubyVM::MJIT # @param jit [RubyVM::MJIT::JITState] # @param ctx [RubyVM::MJIT::Context] # @param asm [RubyVM::MJIT::Assembler] - # @param cd `RubyVM::MJIT::CPointer::Struct_rb_call_data` - def opt_send_without_block(jit, ctx, asm) - cd = C.rb_call_data.new(jit.operand(0)) - jit_call_general(jit, ctx, asm, cd) + def opt_send_without_block(jit, ctx, asm, cd: C.rb_call_data.new(jit.operand(0))) + # Specialize on a compile-time receiver, and split a block for chain guards + unless jit.at_current_insn? + defer_compilation(jit, ctx, asm) + return EndBlock + end + + # calling->ci + mid = C.vm_ci_mid(cd.ci) + argc = C.vm_ci_argc(cd.ci) + flags = C.vm_ci_flag(cd.ci) + + # vm_sendish + cme = jit_search_method(jit, ctx, asm, mid, argc, flags) + if cme == CantCompile + return CantCompile + end + jit_call_general(jit, ctx, asm, mid, argc, flags, cme) end # objtostring @@ -648,7 +662,35 @@ module RubyVM::MJIT # opt_str_uminus # opt_newarray_max # opt_newarray_min - # invokesuper + + # @param jit [RubyVM::MJIT::JITState] + # @param ctx [RubyVM::MJIT::Context] + # @param asm [RubyVM::MJIT::Assembler] + def invokesuper(jit, ctx, asm) + # Specialize on a compile-time receiver, and split a block for chain guards + unless jit.at_current_insn? + defer_compilation(jit, ctx, asm) + return EndBlock + end + + cd = C.rb_call_data.new(jit.operand(0)) + blockiseq = jit.operand(1) + + jit_caller_setup_arg_block(jit, ctx, asm, cd.ci, blockiseq, true) + + # calling->ci + mid = C.vm_ci_mid(cd.ci) + argc = C.vm_ci_argc(cd.ci) + flags = C.vm_ci_flag(cd.ci) + + # vm_sendish + cme = jit_search_super_method(jit, ctx, asm, mid, argc, flags) + if cme == CantCompile + return CantCompile + end + jit_call_general(jit, ctx, asm, mid, argc, flags, cme) + end + # invokeblock # @param jit [RubyVM::MJIT::JITState] @@ -977,7 +1019,7 @@ module RubyVM::MJIT # opt_neq is passed two rb_call_data as arguments: # first for ==, second for != neq_cd = C.rb_call_data.new(jit.operand(1)) - jit_call_general(jit, ctx, asm, neq_cd) + opt_send_without_block(jit, ctx, asm, cd: neq_cd) end # @param jit [RubyVM::MJIT::JITState] @@ -1861,9 +1903,25 @@ module RubyVM::MJIT asm.jnz(side_exit(jit, ctx)) end - # vm_get_ep + # See get_lvar_level in compile.c + def get_lvar_level(iseq) + level = 0 + while iseq.to_i != iseq.body.local_iseq.to_i + level += 1 + iseq = iseq.body.parent_iseq + end + return level + end + + # GET_LEP # @param jit [RubyVM::MJIT::JITState] - # @param ctx [RubyVM::MJIT::Context] + # @param asm [RubyVM::MJIT::Assembler] + def jit_get_lep(jit, asm, reg:) + level = get_lvar_level(jit.iseq) + jit_get_ep(asm, level, reg:) + end + + # vm_get_ep # @param asm [RubyVM::MJIT::Assembler] def jit_get_ep(asm, level, reg:) asm.mov(reg, [CFP, C.rb_control_frame_t.offsetof(:ep)]) @@ -1940,29 +1998,37 @@ module RubyVM::MJIT EndBlock end - # vm_call_general (vm_sendish -> vm_call_general) + # vm_caller_setup_arg_block # @param jit [RubyVM::MJIT::JITState] # @param ctx [RubyVM::MJIT::Context] # @param asm [RubyVM::MJIT::Assembler] - def jit_call_general(jit, ctx, asm, cd) - ci = cd.ci - mid = C.vm_ci_mid(ci) - argc = C.vm_ci_argc(ci) - flags = C.vm_ci_flag(ci) - jit_call_method(jit, ctx, asm, mid, argc, flags) + def jit_caller_setup_arg_block(jit, ctx, asm, ci, blockiseq, is_super) + if C.vm_ci_flag(ci) & C.VM_CALL_ARGS_BLOCKARG != 0 + asm.incr_counter(:send_blockarg) + return CantCompile + elsif blockiseq != 0 + asm.incr_counter(:send_blockiseq) + return CantCompile + else + if is_super + # GET_BLOCK_HANDLER(); + # Guard no block passed. Only handle that case for now. + asm.comment('guard no block given') + jit_get_lep(jit, asm, reg: :rax) + asm.cmp([:rax, C.VALUE.size * C.VM_ENV_DATA_INDEX_SPECVAL], C.VM_BLOCK_HANDLER_NONE) + asm.jne(counted_exit(side_exit(jit, ctx), :send_block_handler)) + else + raise NotImplementedError + end + end end - # vm_call_method + # vm_search_method # @param jit [RubyVM::MJIT::JITState] # @param ctx [RubyVM::MJIT::Context] # @param asm [RubyVM::MJIT::Assembler] - # @param send_shift [Integer] The number of shifts needed for VM_CALL_OPT_SEND - def jit_call_method(jit, ctx, asm, mid, argc, flags, send_shift: 0) - # Specialize on a compile-time receiver, and split a block for chain guards - unless jit.at_current_insn? - defer_compilation(jit, ctx, asm) - return EndBlock - end + def jit_search_method(jit, ctx, asm, mid, argc, flags, send_shift: 0) + assert_equal(true, jit.at_current_insn?) # Generate a side exit side_exit = side_exit(jit, ctx) @@ -1993,6 +2059,111 @@ module RubyVM::MJIT return CantCompile # We don't support vm_call_method_name end + # Invalidate on redefinition (part of vm_search_method_fastpath) + Invariants.assume_method_lookup_stable(jit, cme) + + return cme + end + + def jit_search_super_method(jit, ctx, asm, mid, argc, flags) + assert_equal(true, jit.at_current_insn?) + + me = C.rb_vm_frame_method_entry(jit.cfp) + if me.nil? + return CantCompile + end + + # FIXME: We should track and invalidate this block when this cme is invalidated + current_defined_class = me.defined_class + mid = me.def.original_id + + if me.to_i != C.rb_callable_method_entry(current_defined_class, me.called_id).to_i + # Though we likely could generate this call, as we are only concerned + # with the method entry remaining valid, assume_method_lookup_stable + # below requires that the method lookup matches as well + return CantCompile + end + + # vm_search_normal_superclass + rbasic_klass = C.to_ruby(C.RBasic.new(C.to_value(current_defined_class)).klass) + if C.BUILTIN_TYPE(current_defined_class) == C.RUBY_T_ICLASS && C.BUILTIN_TYPE(rbasic_klass) == C.RUBY_T_MODULE && \ + C.FL_TEST_RAW(rbasic_klass, C.RMODULE_IS_REFINEMENT) != 0 + return CantCompile + end + comptime_superclass = C.rb_class_get_superclass(current_defined_class) + + # Don't JIT calls that aren't simple + # Note, not using VM_CALL_ARGS_SIMPLE because sometimes we pass a block. + + if flags & C.VM_CALL_KWARG != 0 + asm.incr_counter(:send_kwarg) + return CantCompile + end + if flags & C.VM_CALL_KW_SPLAT != 0 + asm.incr_counter(:send_kw_splat) + return CantCompile + end + if flags & C.VM_CALL_ARGS_BLOCKARG != 0 + asm.incr_counter(:send_blockarg) + return CantCompile + end + + # Ensure we haven't rebound this method onto an incompatible class. + # In the interpreter we try to avoid making this check by performing some + # cheaper calculations first, but since we specialize on the method entry + # and so only have to do this once at compile time this is fine to always + # check and side exit. + comptime_recv = jit.peek_at_stack(argc) + unless comptime_recv.kind_of?(current_defined_class) + return CantCompile + end + + # Do method lookup + cme = C.rb_callable_method_entry(comptime_superclass, mid) + + if cme.nil? + return CantCompile + end + + # Check that we'll be able to write this method dispatch before generating checks + cme_def_type = cme.def.type + if cme_def_type != C.VM_METHOD_TYPE_ISEQ && cme_def_type != C.VM_METHOD_TYPE_CFUNC + # others unimplemented + return CantCompile + end + + # Guard that the receiver has the same class as the one from compile time + side_exit = side_exit(jit, ctx) + + asm.comment('guard known me') + jit_get_lep(jit, asm, reg: :rax) + + asm.mov(:rcx, me.to_i) + asm.cmp([:rax, C.VALUE.size * C.VM_ENV_DATA_INDEX_ME_CREF], :rcx) + asm.jne(counted_exit(side_exit, :invokesuper_me_changed)) + + # We need to assume that both our current method entry and the super + # method entry we invoke remain stable + Invariants.assume_method_lookup_stable(jit, me) + Invariants.assume_method_lookup_stable(jit, cme) + + return cme + end + + # vm_call_general + # @param jit [RubyVM::MJIT::JITState] + # @param ctx [RubyVM::MJIT::Context] + # @param asm [RubyVM::MJIT::Assembler] + def jit_call_general(jit, ctx, asm, mid, argc, flags, cme) + jit_call_method(jit, ctx, asm, mid, argc, flags, cme) + end + + # vm_call_method + # @param jit [RubyVM::MJIT::JITState] + # @param ctx [RubyVM::MJIT::Context] + # @param asm [RubyVM::MJIT::Assembler] + # @param send_shift [Integer] The number of shifts needed for VM_CALL_OPT_SEND + def jit_call_method(jit, ctx, asm, mid, argc, flags, cme, send_shift: 0) # The main check of vm_call_method before vm_call_method_each_type case C.METHOD_ENTRY_VISI(cme) when C.METHOD_VISI_PUBLIC @@ -2011,8 +2182,11 @@ module RubyVM::MJIT raise 'unreachable' end - # Invalidate on redefinition (part of vm_search_method_fastpath) - Invariants.assume_method_lookup_stable(jit, cme) + # Get a compile-time receiver + recv_idx = argc + (flags & C.VM_CALL_ARGS_BLOCKARG != 0 ? 1 : 0) + recv_idx += send_shift + comptime_recv = jit.peek_at_stack(recv_idx) + recv_opnd = ctx.stack_opnd(recv_idx) jit_call_method_each_type(jit, ctx, asm, argc, flags, cme, comptime_recv, recv_opnd, send_shift:) end @@ -2343,8 +2517,14 @@ module RubyVM::MJIT asm.cmp(C_RET, mid) jit_chain_guard(:jne, jit, ctx, asm, mid_changed_exit) + # rb_callable_method_entry_with_refinements + cme = jit_search_method(jit, ctx, asm, mid, argc, flags, send_shift:) + if cme == CantCompile + return CantCompile + end + if flags & C.VM_CALL_FCALL != 0 - return jit_call_method(jit, ctx, asm, mid, argc, flags, send_shift:) + return jit_call_method(jit, ctx, asm, mid, argc, flags, cme, send_shift:) end raise NotImplementedError # unreachable for now |
