summaryrefslogtreecommitdiff
path: root/ujit_core.c
diff options
context:
space:
mode:
authorMaxime Chevalier-Boisvert <maxime.chevalierboisvert@shopify.com>2020-12-16 21:45:51 -0500
committerAlan Wu <XrXr@users.noreply.github.com>2021-10-20 18:19:26 -0400
commit7d7e58d3521c797d772fdb5f974061c8a3758594 (patch)
treef85258a44efaf7c7394015cb99b6342fecd50491 /ujit_core.c
parent40b70ef7c762701d26539e5a401449d7f3733b5a (diff)
Implement branch stub logic
Diffstat (limited to 'ujit_core.c')
-rw-r--r--ujit_core.c86
1 files changed, 54 insertions, 32 deletions
diff --git a/ujit_core.c b/ujit_core.c
index 916e8bc4d8..4a9ddc334d 100644
--- a/ujit_core.c
+++ b/ujit_core.c
@@ -101,25 +101,6 @@ static const struct st_hash_type hashtype_blockid = {
blockid_hash,
};
-// Called by the generated code when a branch stub is executed
-// Triggers compilation of branches and code patching
-void branch_stub_hit(uint32_t branch_idx, uint32_t target_idx)
-{
-
-
-
-
- // TODO
- //uint8_t* code_ptr = ujit_compile_block(blockid.iseq, blockid.idx, false);
- //st_insert(version_tbl, (st_data_t)&blockid, (st_data_t)code_ptr);
-
-
-
-
-
-
-}
-
// Retrieve a basic block version for an (iseq, idx) tuple
uint8_t* find_block_version(blockid_t block)
{
@@ -132,6 +113,48 @@ uint8_t* find_block_version(blockid_t block)
return NULL;
}
+// Called by the generated code when a branch stub is executed
+// Triggers compilation of branches and code patching
+uint8_t* branch_stub_hit(uint32_t branch_idx, uint32_t target_idx)
+{
+ assert (branch_idx < num_branches);
+ assert (target_idx < 2);
+ branch_t branch = branch_entries[branch_idx];
+ blockid_t target = branch.targets[target_idx];
+
+ // If either of the target blocks will be placed next
+ if (cb->write_pos == branch.end_pos)
+ {
+ branch.shape = (uint8_t)target_idx;
+
+ // Rewrite the branch with the new, potentially more compact shape
+ cb_set_pos(cb, branch.start_pos);
+ branch.gen_fn(cb, branch.dst_addrs[0], branch.dst_addrs[1], branch.shape);
+ assert (cb->write_pos <= branch.end_pos);
+ }
+
+ // Try to find a compiled version of this block
+ uint8_t* code_ptr = find_block_version(target);
+
+ // If this block hasn't yet been compiled
+ if (!code_ptr)
+ {
+ code_ptr = ujit_compile_block(target.iseq, target.idx, false);
+ st_insert(version_tbl, (st_data_t)&target, (st_data_t)code_ptr);
+ branch.dst_addrs[target_idx] = code_ptr;
+ }
+
+ // Rewrite the branch with the new jump target address
+ size_t cur_pos = cb->write_pos;
+ cb_set_pos(cb, branch.start_pos);
+ branch.gen_fn(cb, branch.dst_addrs[0], branch.dst_addrs[1], branch.shape);
+ assert (cb->write_pos <= branch.end_pos);
+ cb_set_pos(cb, cur_pos);
+
+ // Return a pointer to the compiled block version
+ return code_ptr;
+}
+
// Get a version or stub corresponding to a branch target
// TODO: need incoming and target versioning contexts
uint8_t* get_branch_target(codeblock_t* ocb, blockid_t target, uint32_t branch_idx, uint32_t target_idx)
@@ -145,16 +168,13 @@ uint8_t* get_branch_target(codeblock_t* ocb, blockid_t target, uint32_t branch_i
// Generate an outlined stub that will call
// branch_stub_hit(uint32_t branch_idx, uint32_t target_idx)
+ mov(ocb, RDI, imm_opnd(branch_idx));
+ mov(ocb, RSI, imm_opnd(target_idx));
+ call_ptr(ocb, REG0, (void *)&branch_stub_hit);
-
-
-
-
-
-
-
-
-
+ // Jump to the address returned by the
+ // branch_stub_hit call
+ jmp_rm(ocb, RAX);
return stub_addr;
}
@@ -162,13 +182,13 @@ uint8_t* get_branch_target(codeblock_t* ocb, blockid_t target, uint32_t branch_i
void gen_branch(codeblock_t* cb, codeblock_t* ocb, blockid_t target0, blockid_t target1, branchgen_fn gen_fn)
{
// Get branch targets or stubs (code pointers)
- uint8_t* target_code0 = get_branch_target(ocb, target0, num_branches, 0);
- uint8_t* target_code1 = get_branch_target(ocb, target1, num_branches, 1);
+ uint8_t* dst_addr0 = get_branch_target(ocb, target0, num_branches, 0);
+ uint8_t* dst_addr1 = get_branch_target(ocb, target1, num_branches, 1);
uint32_t start_pos = (uint32_t)cb->write_pos;
// Call the branch generation function
- gen_fn(cb, target_code0, target_code1, DEFAULT);
+ gen_fn(cb, dst_addr0, dst_addr1, SHAPE_DEFAULT);
uint32_t end_pos = (uint32_t)cb->write_pos;
@@ -177,7 +197,9 @@ void gen_branch(codeblock_t* cb, codeblock_t* ocb, blockid_t target0, blockid_t
start_pos,
end_pos,
{ target0, target1 },
- gen_fn
+ { dst_addr0, dst_addr1 },
+ gen_fn,
+ SHAPE_DEFAULT
};
assert (num_branches < MAX_BRANCHES);