summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTakashi Kokubun <takashikkbn@gmail.com>2023-02-13 23:57:40 -0800
committerTakashi Kokubun <takashikkbn@gmail.com>2023-03-05 23:28:59 -0800
commit3774fe4e912afc97649207fdfd17aa8b44903152 (patch)
treeec2137945cda0d943fb776076bb969ff59115c2d
parentb5c5052839b126d479add428c0c227157a96a90c (diff)
Implement opt_eq and opt_neq
Notes
Notes: Merged: https://github.com/ruby/ruby/pull/7448
-rw-r--r--lib/ruby_vm/mjit/assembler.rb24
-rw-r--r--lib/ruby_vm/mjit/insn_compiler.rb117
-rw-r--r--mjit_c.h1
-rw-r--r--mjit_c.rb12
-rwxr-xr-xtool/mjit/bindgen.rb2
5 files changed, 151 insertions, 5 deletions
diff --git a/lib/ruby_vm/mjit/assembler.rb b/lib/ruby_vm/mjit/assembler.rb
index bbc9691b56..3ffb3e45ad 100644
--- a/lib/ruby_vm/mjit/assembler.rb
+++ b/lib/ruby_vm/mjit/assembler.rb
@@ -156,6 +156,22 @@ module RubyVM::MJIT
end
end
+ def cmove(dst, src)
+ case [dst, src]
+ # CMOVE r64, r/m64 (Mod 11: reg)
+ in [Symbol => dst_reg, Symbol => src_reg]
+ # REX.W + 0F 44 /r
+ # RM: Operand 1: ModRM:reg (r, w), Operand 2: ModRM:r/m (r)
+ insn(
+ prefix: REX_W,
+ opcode: [0x0f, 0x44],
+ mod_rm: ModRM[mod: Mod11, reg: dst_reg, rm: src_reg],
+ )
+ else
+ raise NotImplementedError, "cmove: not-implemented operands: #{dst.inspect}, #{src.inspect}"
+ end
+ end
+
def cmovg(dst, src)
case [dst, src]
# CMOVG r64, r/m64 (Mod 11: reg)
@@ -290,6 +306,10 @@ module RubyVM::MJIT
def je(dst)
case dst
+ # JE rel8
+ in Label => dst_label
+ # 74 cb
+ insn(opcode: 0x74, imm: dst_label)
# JE rel32
in Integer => dst_addr
# 0F 84 cd
@@ -301,6 +321,10 @@ module RubyVM::MJIT
def jmp(dst)
case dst
+ # JZ rel8
+ in Label => dst_label
+ # EB cb
+ insn(opcode: 0xeb, imm: dst_label)
# JMP rel32
in Integer => dst_addr
# E9 cd
diff --git a/lib/ruby_vm/mjit/insn_compiler.rb b/lib/ruby_vm/mjit/insn_compiler.rb
index 95d837b073..a3ebe4a9e5 100644
--- a/lib/ruby_vm/mjit/insn_compiler.rb
+++ b/lib/ruby_vm/mjit/insn_compiler.rb
@@ -23,7 +23,7 @@ module RubyVM::MJIT
asm.incr_counter(:mjit_insns_count)
asm.comment("Insn: #{insn.name}")
- # 38/101
+ # 40/101
case insn.name
when :nop then nop(jit, ctx, asm)
# getlocal
@@ -98,8 +98,8 @@ module RubyVM::MJIT
when :opt_mult then opt_mult(jit, ctx, asm)
when :opt_div then opt_div(jit, ctx, asm)
when :opt_mod then opt_mod(jit, ctx, asm)
- # opt_eq
- # opt_neq
+ when :opt_eq then opt_eq(jit, ctx, asm)
+ when :opt_neq then opt_neq(jit, ctx, asm)
when :opt_lt then opt_lt(jit, ctx, asm)
when :opt_le then opt_le(jit, ctx, asm)
when :opt_gt then opt_gt(jit, ctx, asm)
@@ -610,8 +610,32 @@ module RubyVM::MJIT
end
end
- # opt_eq
- # opt_neq
+ # @param jit [RubyVM::MJIT::JITState]
+ # @param ctx [RubyVM::MJIT::Context]
+ # @param asm [RubyVM::MJIT::Assembler]
+ def opt_eq(jit, ctx, asm)
+ unless jit.at_current_insn?
+ defer_compilation(jit, ctx, asm)
+ return EndBlock
+ end
+
+ if jit_equality_specialized(jit, ctx, asm)
+ jump_to_next_insn(jit, ctx, asm)
+ EndBlock
+ else
+ opt_send_without_block(jit, ctx, asm)
+ end
+ end
+
+ # @param jit [RubyVM::MJIT::JITState]
+ # @param ctx [RubyVM::MJIT::Context]
+ # @param asm [RubyVM::MJIT::Assembler]
+ def opt_neq(jit, ctx, asm)
+ # opt_neq is passed two rb_call_data as arguments:
+ # first for ==, second for !=
+ neq_cd = C.rb_call_data.new(jit.operand(1))
+ jit_call_method(jit, ctx, asm, neq_cd)
+ end
# @param jit [RubyVM::MJIT::JITState]
# @param ctx [RubyVM::MJIT::Context]
@@ -1168,6 +1192,89 @@ module RubyVM::MJIT
# @param jit [RubyVM::MJIT::JITState]
# @param ctx [RubyVM::MJIT::Context]
# @param asm [RubyVM::MJIT::Assembler]
+ def jit_equality_specialized(jit, ctx, asm)
+ # Create a side-exit to fall back to the interpreter
+ side_exit = side_exit(jit, ctx)
+
+ a_opnd = ctx.stack_opnd(1)
+ b_opnd = ctx.stack_opnd(0)
+
+ comptime_a = jit.peek_at_stack(1)
+ comptime_b = jit.peek_at_stack(0)
+
+ if two_fixnums_on_stack?(jit)
+ unless Invariants.assume_bop_not_redefined(jit, C.INTEGER_REDEFINED_OP_FLAG, C.BOP_EQ)
+ return false
+ end
+
+ guard_two_fixnums(jit, ctx, asm, side_exit)
+
+ asm.comment('check fixnum equality')
+ asm.mov(:rax, a_opnd)
+ asm.mov(:rcx, b_opnd)
+ asm.cmp(:rax, :rcx)
+ asm.mov(:rax, Qfalse)
+ asm.mov(:rcx, Qtrue)
+ asm.cmove(:rax, :rcx)
+
+ # Push the output on the stack
+ ctx.stack_pop(2)
+ dst = ctx.stack_push
+ asm.mov(dst, :rax)
+
+ true
+ elsif comptime_a.class == String && comptime_b.class == String
+ unless Invariants.assume_bop_not_redefined(jit, C.STRING_REDEFINED_OP_FLAG, C.BOP_EQ)
+ # if overridden, emit the generic version
+ return false
+ end
+
+ # Guard that a is a String
+ if jit_guard_known_class(jit, ctx, asm, comptime_a.class, a_opnd, comptime_a, side_exit) == CantCompile
+ return false
+ end
+
+ equal_label = asm.new_label(:equal)
+ ret_label = asm.new_label(:ret)
+
+ # If they are equal by identity, return true
+ asm.mov(:rax, a_opnd)
+ asm.mov(:rcx, b_opnd)
+ asm.cmp(:rax, :rcx)
+ asm.je(equal_label)
+
+ # Otherwise guard that b is a T_STRING (from type info) or String (from runtime guard)
+ # Note: any T_STRING is valid here, but we check for a ::String for simplicity
+ # To pass a mutable static variable (rb_cString) requires an unsafe block
+ if jit_guard_known_class(jit, ctx, asm, comptime_b.class, b_opnd, comptime_b, side_exit) == CantCompile
+ return false
+ end
+
+ asm.comment('call rb_str_eql_internal')
+ asm.mov(C_ARGS[0], a_opnd)
+ asm.mov(C_ARGS[1], b_opnd)
+ asm.call(C.rb_str_eql_internal)
+
+ # Push the output on the stack
+ ctx.stack_pop(2)
+ dst = ctx.stack_push
+ asm.mov(dst, C_RET)
+ asm.jmp(ret_label)
+
+ asm.write_label(equal_label)
+ asm.mov(dst, Qtrue)
+
+ asm.write_label(ret_label)
+
+ true
+ else
+ false
+ end
+ end
+
+ # @param jit [RubyVM::MJIT::JITState]
+ # @param ctx [RubyVM::MJIT::Context]
+ # @param asm [RubyVM::MJIT::Assembler]
def jit_prepare_routine_call(jit, ctx, asm)
jit_save_pc(jit, asm)
jit_save_sp(jit, ctx, asm)
diff --git a/mjit_c.h b/mjit_c.h
index 12b75e1082..668920e3fb 100644
--- a/mjit_c.h
+++ b/mjit_c.h
@@ -3,6 +3,7 @@
#define MJIT_C_H
#include "ruby/internal/config.h"
+#include "internal/string.h"
#include "vm_core.h"
#include "vm_callinfo.h"
#include "builtin.h"
diff --git a/mjit_c.rb b/mjit_c.rb
index 5979c5ab51..1de0887aa3 100644
--- a/mjit_c.rb
+++ b/mjit_c.rb
@@ -164,6 +164,10 @@ module RubyVM::MJIT # :nodoc: all
}
end
+ def rb_str_eql_internal
+ Primitive.cexpr! 'SIZET2NUM((size_t)rb_str_eql_internal)'
+ end
+
#========================================================================================
#
# Old stuff
@@ -337,6 +341,10 @@ module RubyVM::MJIT # :nodoc: all
Primitive.cexpr! %q{ UINT2NUM(BOP_AREF) }
end
+ def C.BOP_EQ
+ Primitive.cexpr! %q{ UINT2NUM(BOP_EQ) }
+ end
+
def C.BOP_GE
Primitive.cexpr! %q{ UINT2NUM(BOP_GE) }
end
@@ -433,6 +441,10 @@ module RubyVM::MJIT # :nodoc: all
Primitive.cexpr! %q{ UINT2NUM(SHAPE_ROOT) }
end
+ def C.STRING_REDEFINED_OP_FLAG
+ Primitive.cexpr! %q{ UINT2NUM(STRING_REDEFINED_OP_FLAG) }
+ end
+
def C.T_OBJECT
Primitive.cexpr! %q{ UINT2NUM(T_OBJECT) }
end
diff --git a/tool/mjit/bindgen.rb b/tool/mjit/bindgen.rb
index 74072e07c8..281cc903e9 100755
--- a/tool/mjit/bindgen.rb
+++ b/tool/mjit/bindgen.rb
@@ -350,6 +350,7 @@ generator = BindingGenerator.new(
UINT: %w[
BOP_AND
BOP_AREF
+ BOP_EQ
BOP_GE
BOP_GT
BOP_LE
@@ -361,6 +362,7 @@ generator = BindingGenerator.new(
ARRAY_REDEFINED_OP_FLAG
HASH_REDEFINED_OP_FLAG
INTEGER_REDEFINED_OP_FLAG
+ STRING_REDEFINED_OP_FLAG
METHOD_VISI_PRIVATE
METHOD_VISI_PROTECTED
METHOD_VISI_PUBLIC