summaryrefslogtreecommitdiff
path: root/lib/ruby_vm
diff options
context:
space:
mode:
authorTakashi Kokubun <takashikkbn@gmail.com>2023-02-24 14:48:02 -0800
committerTakashi Kokubun <takashikkbn@gmail.com>2023-03-05 23:28:59 -0800
commit33213542f241709727475a386a3fa189d426b52d (patch)
tree08aa9e8c2a3c94bd26e87e8aba7a88d2c6406693 /lib/ruby_vm
parent5576da7900162234c8e114b72401a8e0681c7c61 (diff)
Implement invokesuper
Notes
Notes: Merged: https://github.com/ruby/ruby/pull/7448
Diffstat (limited to 'lib/ruby_vm')
-rw-r--r--lib/ruby_vm/mjit/insn_compiler.rb236
1 files changed, 208 insertions, 28 deletions
diff --git a/lib/ruby_vm/mjit/insn_compiler.rb b/lib/ruby_vm/mjit/insn_compiler.rb
index 469b5b6cac..29a66b4913 100644
--- a/lib/ruby_vm/mjit/insn_compiler.rb
+++ b/lib/ruby_vm/mjit/insn_compiler.rb
@@ -22,7 +22,7 @@ module RubyVM::MJIT
asm.incr_counter(:mjit_insns_count)
asm.comment("Insn: #{insn.name}")
- # 58/101
+ # 59/101
case insn.name
when :nop then nop(jit, ctx, asm)
when :getlocal then getlocal(jit, ctx, asm)
@@ -82,7 +82,7 @@ module RubyVM::MJIT
# opt_str_uminus
# opt_newarray_max
# opt_newarray_min
- # invokesuper
+ when :invokesuper then invokesuper(jit, ctx, asm)
# invokeblock
when :leave then leave(jit, ctx, asm)
# throw
@@ -629,10 +629,24 @@ module RubyVM::MJIT
# @param jit [RubyVM::MJIT::JITState]
# @param ctx [RubyVM::MJIT::Context]
# @param asm [RubyVM::MJIT::Assembler]
- # @param cd `RubyVM::MJIT::CPointer::Struct_rb_call_data`
- def opt_send_without_block(jit, ctx, asm)
- cd = C.rb_call_data.new(jit.operand(0))
- jit_call_general(jit, ctx, asm, cd)
+ def opt_send_without_block(jit, ctx, asm, cd: C.rb_call_data.new(jit.operand(0)))
+ # Specialize on a compile-time receiver, and split a block for chain guards
+ unless jit.at_current_insn?
+ defer_compilation(jit, ctx, asm)
+ return EndBlock
+ end
+
+ # calling->ci
+ mid = C.vm_ci_mid(cd.ci)
+ argc = C.vm_ci_argc(cd.ci)
+ flags = C.vm_ci_flag(cd.ci)
+
+ # vm_sendish
+ cme = jit_search_method(jit, ctx, asm, mid, argc, flags)
+ if cme == CantCompile
+ return CantCompile
+ end
+ jit_call_general(jit, ctx, asm, mid, argc, flags, cme)
end
# objtostring
@@ -648,7 +662,35 @@ module RubyVM::MJIT
# opt_str_uminus
# opt_newarray_max
# opt_newarray_min
- # invokesuper
+
+ # @param jit [RubyVM::MJIT::JITState]
+ # @param ctx [RubyVM::MJIT::Context]
+ # @param asm [RubyVM::MJIT::Assembler]
+ def invokesuper(jit, ctx, asm)
+ # Specialize on a compile-time receiver, and split a block for chain guards
+ unless jit.at_current_insn?
+ defer_compilation(jit, ctx, asm)
+ return EndBlock
+ end
+
+ cd = C.rb_call_data.new(jit.operand(0))
+ blockiseq = jit.operand(1)
+
+ jit_caller_setup_arg_block(jit, ctx, asm, cd.ci, blockiseq, true)
+
+ # calling->ci
+ mid = C.vm_ci_mid(cd.ci)
+ argc = C.vm_ci_argc(cd.ci)
+ flags = C.vm_ci_flag(cd.ci)
+
+ # vm_sendish
+ cme = jit_search_super_method(jit, ctx, asm, mid, argc, flags)
+ if cme == CantCompile
+ return CantCompile
+ end
+ jit_call_general(jit, ctx, asm, mid, argc, flags, cme)
+ end
+
# invokeblock
# @param jit [RubyVM::MJIT::JITState]
@@ -977,7 +1019,7 @@ module RubyVM::MJIT
# opt_neq is passed two rb_call_data as arguments:
# first for ==, second for !=
neq_cd = C.rb_call_data.new(jit.operand(1))
- jit_call_general(jit, ctx, asm, neq_cd)
+ opt_send_without_block(jit, ctx, asm, cd: neq_cd)
end
# @param jit [RubyVM::MJIT::JITState]
@@ -1861,9 +1903,25 @@ module RubyVM::MJIT
asm.jnz(side_exit(jit, ctx))
end
- # vm_get_ep
+ # See get_lvar_level in compile.c
+ def get_lvar_level(iseq)
+ level = 0
+ while iseq.to_i != iseq.body.local_iseq.to_i
+ level += 1
+ iseq = iseq.body.parent_iseq
+ end
+ return level
+ end
+
+ # GET_LEP
# @param jit [RubyVM::MJIT::JITState]
- # @param ctx [RubyVM::MJIT::Context]
+ # @param asm [RubyVM::MJIT::Assembler]
+ def jit_get_lep(jit, asm, reg:)
+ level = get_lvar_level(jit.iseq)
+ jit_get_ep(asm, level, reg:)
+ end
+
+ # vm_get_ep
# @param asm [RubyVM::MJIT::Assembler]
def jit_get_ep(asm, level, reg:)
asm.mov(reg, [CFP, C.rb_control_frame_t.offsetof(:ep)])
@@ -1940,29 +1998,37 @@ module RubyVM::MJIT
EndBlock
end
- # vm_call_general (vm_sendish -> vm_call_general)
+ # vm_caller_setup_arg_block
# @param jit [RubyVM::MJIT::JITState]
# @param ctx [RubyVM::MJIT::Context]
# @param asm [RubyVM::MJIT::Assembler]
- def jit_call_general(jit, ctx, asm, cd)
- ci = cd.ci
- mid = C.vm_ci_mid(ci)
- argc = C.vm_ci_argc(ci)
- flags = C.vm_ci_flag(ci)
- jit_call_method(jit, ctx, asm, mid, argc, flags)
+ def jit_caller_setup_arg_block(jit, ctx, asm, ci, blockiseq, is_super)
+ if C.vm_ci_flag(ci) & C.VM_CALL_ARGS_BLOCKARG != 0
+ asm.incr_counter(:send_blockarg)
+ return CantCompile
+ elsif blockiseq != 0
+ asm.incr_counter(:send_blockiseq)
+ return CantCompile
+ else
+ if is_super
+ # GET_BLOCK_HANDLER();
+ # Guard no block passed. Only handle that case for now.
+ asm.comment('guard no block given')
+ jit_get_lep(jit, asm, reg: :rax)
+ asm.cmp([:rax, C.VALUE.size * C.VM_ENV_DATA_INDEX_SPECVAL], C.VM_BLOCK_HANDLER_NONE)
+ asm.jne(counted_exit(side_exit(jit, ctx), :send_block_handler))
+ else
+ raise NotImplementedError
+ end
+ end
end
- # vm_call_method
+ # vm_search_method
# @param jit [RubyVM::MJIT::JITState]
# @param ctx [RubyVM::MJIT::Context]
# @param asm [RubyVM::MJIT::Assembler]
- # @param send_shift [Integer] The number of shifts needed for VM_CALL_OPT_SEND
- def jit_call_method(jit, ctx, asm, mid, argc, flags, send_shift: 0)
- # Specialize on a compile-time receiver, and split a block for chain guards
- unless jit.at_current_insn?
- defer_compilation(jit, ctx, asm)
- return EndBlock
- end
+ def jit_search_method(jit, ctx, asm, mid, argc, flags, send_shift: 0)
+ assert_equal(true, jit.at_current_insn?)
# Generate a side exit
side_exit = side_exit(jit, ctx)
@@ -1993,6 +2059,111 @@ module RubyVM::MJIT
return CantCompile # We don't support vm_call_method_name
end
+ # Invalidate on redefinition (part of vm_search_method_fastpath)
+ Invariants.assume_method_lookup_stable(jit, cme)
+
+ return cme
+ end
+
+ def jit_search_super_method(jit, ctx, asm, mid, argc, flags)
+ assert_equal(true, jit.at_current_insn?)
+
+ me = C.rb_vm_frame_method_entry(jit.cfp)
+ if me.nil?
+ return CantCompile
+ end
+
+ # FIXME: We should track and invalidate this block when this cme is invalidated
+ current_defined_class = me.defined_class
+ mid = me.def.original_id
+
+ if me.to_i != C.rb_callable_method_entry(current_defined_class, me.called_id).to_i
+ # Though we likely could generate this call, as we are only concerned
+ # with the method entry remaining valid, assume_method_lookup_stable
+ # below requires that the method lookup matches as well
+ return CantCompile
+ end
+
+ # vm_search_normal_superclass
+ rbasic_klass = C.to_ruby(C.RBasic.new(C.to_value(current_defined_class)).klass)
+ if C.BUILTIN_TYPE(current_defined_class) == C.RUBY_T_ICLASS && C.BUILTIN_TYPE(rbasic_klass) == C.RUBY_T_MODULE && \
+ C.FL_TEST_RAW(rbasic_klass, C.RMODULE_IS_REFINEMENT) != 0
+ return CantCompile
+ end
+ comptime_superclass = C.rb_class_get_superclass(current_defined_class)
+
+ # Don't JIT calls that aren't simple
+ # Note, not using VM_CALL_ARGS_SIMPLE because sometimes we pass a block.
+
+ if flags & C.VM_CALL_KWARG != 0
+ asm.incr_counter(:send_kwarg)
+ return CantCompile
+ end
+ if flags & C.VM_CALL_KW_SPLAT != 0
+ asm.incr_counter(:send_kw_splat)
+ return CantCompile
+ end
+ if flags & C.VM_CALL_ARGS_BLOCKARG != 0
+ asm.incr_counter(:send_blockarg)
+ return CantCompile
+ end
+
+ # Ensure we haven't rebound this method onto an incompatible class.
+ # In the interpreter we try to avoid making this check by performing some
+ # cheaper calculations first, but since we specialize on the method entry
+ # and so only have to do this once at compile time this is fine to always
+ # check and side exit.
+ comptime_recv = jit.peek_at_stack(argc)
+ unless comptime_recv.kind_of?(current_defined_class)
+ return CantCompile
+ end
+
+ # Do method lookup
+ cme = C.rb_callable_method_entry(comptime_superclass, mid)
+
+ if cme.nil?
+ return CantCompile
+ end
+
+ # Check that we'll be able to write this method dispatch before generating checks
+ cme_def_type = cme.def.type
+ if cme_def_type != C.VM_METHOD_TYPE_ISEQ && cme_def_type != C.VM_METHOD_TYPE_CFUNC
+ # others unimplemented
+ return CantCompile
+ end
+
+ # Guard that the receiver has the same class as the one from compile time
+ side_exit = side_exit(jit, ctx)
+
+ asm.comment('guard known me')
+ jit_get_lep(jit, asm, reg: :rax)
+
+ asm.mov(:rcx, me.to_i)
+ asm.cmp([:rax, C.VALUE.size * C.VM_ENV_DATA_INDEX_ME_CREF], :rcx)
+ asm.jne(counted_exit(side_exit, :invokesuper_me_changed))
+
+ # We need to assume that both our current method entry and the super
+ # method entry we invoke remain stable
+ Invariants.assume_method_lookup_stable(jit, me)
+ Invariants.assume_method_lookup_stable(jit, cme)
+
+ return cme
+ end
+
+ # vm_call_general
+ # @param jit [RubyVM::MJIT::JITState]
+ # @param ctx [RubyVM::MJIT::Context]
+ # @param asm [RubyVM::MJIT::Assembler]
+ def jit_call_general(jit, ctx, asm, mid, argc, flags, cme)
+ jit_call_method(jit, ctx, asm, mid, argc, flags, cme)
+ end
+
+ # vm_call_method
+ # @param jit [RubyVM::MJIT::JITState]
+ # @param ctx [RubyVM::MJIT::Context]
+ # @param asm [RubyVM::MJIT::Assembler]
+ # @param send_shift [Integer] The number of shifts needed for VM_CALL_OPT_SEND
+ def jit_call_method(jit, ctx, asm, mid, argc, flags, cme, send_shift: 0)
# The main check of vm_call_method before vm_call_method_each_type
case C.METHOD_ENTRY_VISI(cme)
when C.METHOD_VISI_PUBLIC
@@ -2011,8 +2182,11 @@ module RubyVM::MJIT
raise 'unreachable'
end
- # Invalidate on redefinition (part of vm_search_method_fastpath)
- Invariants.assume_method_lookup_stable(jit, cme)
+ # Get a compile-time receiver
+ recv_idx = argc + (flags & C.VM_CALL_ARGS_BLOCKARG != 0 ? 1 : 0)
+ recv_idx += send_shift
+ comptime_recv = jit.peek_at_stack(recv_idx)
+ recv_opnd = ctx.stack_opnd(recv_idx)
jit_call_method_each_type(jit, ctx, asm, argc, flags, cme, comptime_recv, recv_opnd, send_shift:)
end
@@ -2343,8 +2517,14 @@ module RubyVM::MJIT
asm.cmp(C_RET, mid)
jit_chain_guard(:jne, jit, ctx, asm, mid_changed_exit)
+ # rb_callable_method_entry_with_refinements
+ cme = jit_search_method(jit, ctx, asm, mid, argc, flags, send_shift:)
+ if cme == CantCompile
+ return CantCompile
+ end
+
if flags & C.VM_CALL_FCALL != 0
- return jit_call_method(jit, ctx, asm, mid, argc, flags, send_shift:)
+ return jit_call_method(jit, ctx, asm, mid, argc, flags, cme, send_shift:)
end
raise NotImplementedError # unreachable for now