summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJeff Zhang <jeff@39bytes.dev>2026-01-20 18:31:26 -0500
committerAlan Wu <XrXr@users.noreply.github.com>2026-01-20 19:42:25 -0500
commitf7e73ba3bf9ced61ac9cca5cf042fd0efe398f69 (patch)
treec2aa093c35376c849ac01ad1c4e80f1751633bcb
parente24b52885feaa87cdb5796c2a08e5995274e83cb (diff)
ZJIT: A64: Avoid gaps in the stack when preserving registers for calls
Previously, we used a `str x, [sp, #-0x10]!` for each value, which left an 8-byte gap. Use STP to store a pair at a time instead.
-rw-r--r--zjit/src/backend/arm64/mod.rs110
-rw-r--r--zjit/src/backend/lir.rs54
-rw-r--r--zjit/src/backend/x86_64/mod.rs121
-rw-r--r--zjit/src/codegen.rs26
4 files changed, 301 insertions, 10 deletions
diff --git a/zjit/src/backend/arm64/mod.rs b/zjit/src/backend/arm64/mod.rs
index a019e2037d..574249dabd 100644
--- a/zjit/src/backend/arm64/mod.rs
+++ b/zjit/src/backend/arm64/mod.rs
@@ -1420,12 +1420,20 @@ impl Assembler {
Insn::CPush(opnd) => {
emit_push(cb, opnd.into());
},
+ Insn::CPushPair(opnd0, opnd1) => {
+ // Second operand ends up at the lower stack address
+ stp_pre(cb, opnd1.into(), opnd0.into(), A64Opnd::new_mem(64, C_SP_REG, -C_SP_STEP));
+ },
Insn::CPop { out } => {
emit_pop(cb, out.into());
},
Insn::CPopInto(opnd) => {
emit_pop(cb, opnd.into());
},
+ Insn::CPopPairInto(opnd0, opnd1) => {
+ // First operand is popped from the lower stack address
+ ldp_post(cb, opnd0.into(), opnd1.into(), A64Opnd::new_mem(64, C_SP_REG, C_SP_STEP));
+ },
Insn::CCall { fptr, .. } => {
match fptr {
Opnd::UImm(fptr) => {
@@ -2663,6 +2671,76 @@ mod tests {
}
#[test]
+ fn test_ccall_register_preservation_even() {
+ let (mut asm, mut cb) = setup_asm();
+
+ let v0 = asm.load(1.into());
+ let v1 = asm.load(2.into());
+ let v2 = asm.load(3.into());
+ let v3 = asm.load(4.into());
+ asm.ccall(0 as _, vec![]);
+ _ = asm.add(v0, v1);
+ _ = asm.add(v2, v3);
+
+ asm.compile_with_num_regs(&mut cb, ALLOC_REGS.len());
+
+ assert_disasm_snapshot!(cb.disasm(), @"
+ 0x0: mov x0, #1
+ 0x4: mov x1, #2
+ 0x8: mov x2, #3
+ 0xc: mov x3, #4
+ 0x10: mov x4, x0
+ 0x14: stp x2, x1, [sp, #-0x10]!
+ 0x18: stp x4, x3, [sp, #-0x10]!
+ 0x1c: mov x16, #0
+ 0x20: blr x16
+ 0x24: ldp x4, x3, [sp], #0x10
+ 0x28: ldp x2, x1, [sp], #0x10
+ 0x2c: adds x4, x4, x1
+ 0x30: adds x2, x2, x3
+ ");
+ assert_snapshot!(cb.hexdump(), @"200080d2410080d2620080d2830080d2e40300aae207bfa9e40fbfa9100080d200023fd6e40fc1a8e207c1a8840001ab420003ab");
+ }
+
+ #[test]
+ fn test_ccall_register_preservation_odd() {
+ let (mut asm, mut cb) = setup_asm();
+
+ let v0 = asm.load(1.into());
+ let v1 = asm.load(2.into());
+ let v2 = asm.load(3.into());
+ let v3 = asm.load(4.into());
+ let v4 = asm.load(5.into());
+ asm.ccall(0 as _, vec![]);
+ _ = asm.add(v0, v1);
+ _ = asm.add(v2, v3);
+ _ = asm.add(v2, v4);
+
+ asm.compile_with_num_regs(&mut cb, ALLOC_REGS.len());
+
+ assert_disasm_snapshot!(cb.disasm(), @"
+ 0x0: mov x0, #1
+ 0x4: mov x1, #2
+ 0x8: mov x2, #3
+ 0xc: mov x3, #4
+ 0x10: mov x4, #5
+ 0x14: mov x5, x0
+ 0x18: stp x2, x1, [sp, #-0x10]!
+ 0x1c: stp x4, x3, [sp, #-0x10]!
+ 0x20: str x5, [sp, #-0x10]!
+ 0x24: mov x16, #0
+ 0x28: blr x16
+ 0x2c: ldr x5, [sp], #0x10
+ 0x30: ldp x4, x3, [sp], #0x10
+ 0x34: ldp x2, x1, [sp], #0x10
+ 0x38: adds x5, x5, x1
+ 0x3c: adds x0, x2, x3
+ 0x40: adds x2, x2, x4
+ ");
+ assert_snapshot!(cb.hexdump(), @"200080d2410080d2620080d2830080d2a40080d2e50300aae207bfa9e40fbfa9e50f1ff8100080d200023fd6e50741f8e40fc1a8e207c1a8a50001ab400003ab420004ab");
+ }
+
+ #[test]
fn test_ccall_resolve_parallel_moves_large_cycle() {
let (mut asm, mut cb) = setup_asm();
@@ -2686,6 +2764,38 @@ mod tests {
}
#[test]
+ fn test_cpush_pair() {
+ let (mut asm, mut cb) = setup_asm();
+ let v0 = asm.load(1.into());
+ let v1 = asm.load(2.into());
+ asm.cpush_pair(v0, v1);
+ asm.compile_with_num_regs(&mut cb, ALLOC_REGS.len());
+
+ assert_disasm_snapshot!(cb.disasm(), @"
+ 0x0: mov x0, #1
+ 0x4: mov x1, #2
+ 0x8: stp x1, x0, [sp, #-0x10]!
+ ");
+ assert_snapshot!(cb.hexdump(), @"200080d2410080d2e103bfa9");
+ }
+
+ #[test]
+ fn test_cpop_pair_into() {
+ let (mut asm, mut cb) = setup_asm();
+ let v0 = asm.load(1.into());
+ let v1 = asm.load(2.into());
+ asm.cpop_pair_into(v0, v1);
+ asm.compile_with_num_regs(&mut cb, ALLOC_REGS.len());
+
+ assert_disasm_snapshot!(cb.disasm(), @"
+ 0x0: mov x0, #1
+ 0x4: mov x1, #2
+ 0x8: ldp x0, x1, [sp], #0x10
+ ");
+ assert_snapshot!(cb.hexdump(), @"200080d2410080d2e007c1a8");
+ }
+
+ #[test]
fn test_split_spilled_lshift() {
let (mut asm, mut cb) = setup_asm();
diff --git a/zjit/src/backend/lir.rs b/zjit/src/backend/lir.rs
index 69ddbf2471..bb39d85cc8 100644
--- a/zjit/src/backend/lir.rs
+++ b/zjit/src/backend/lir.rs
@@ -380,9 +380,17 @@ pub enum Insn {
/// Pop a register from the C stack and store it into another register
CPopInto(Opnd),
/// Pop a pair of registers from the C stack and store them into a pair of registers.
+ /// The registers are popped from left to right.
+ CPopPairInto(Opnd, Opnd),
+
/// Push a register onto the C stack
CPush(Opnd),
+ /// Push a pair of registers onto the C stack.
+ /// The registers are pushed from left to right.
+ CPushPair(Opnd, Opnd),
+
// C function call with N arguments (variadic)
CCall {
opnds: Vec<Opnd>,
@@ -609,7 +617,9 @@ impl Insn {
Insn::Cmp { .. } => "Cmp",
Insn::CPop { .. } => "CPop",
Insn::CPopInto(_) => "CPopInto",
+ Insn::CPopPairInto(_, _) => "CPopPairInto",
Insn::CPush(_) => "CPush",
+ Insn::CPushPair(_, _) => "CPushPair",
Insn::CCall { .. } => "CCall",
Insn::CRet(_) => "CRet",
Insn::CSelE { .. } => "CSelE",
@@ -865,6 +875,8 @@ impl<'a> Iterator for InsnOpndIterator<'a> {
},
Insn::Add { left: opnd0, right: opnd1, .. } |
Insn::And { left: opnd0, right: opnd1, .. } |
+ Insn::CPushPair(opnd0, opnd1) |
+ Insn::CPopPairInto(opnd0, opnd1) |
Insn::Cmp { left: opnd0, right: opnd1 } |
Insn::CSelE { truthy: opnd0, falsy: opnd1, .. } |
Insn::CSelG { truthy: opnd0, falsy: opnd1, .. } |
@@ -1034,6 +1046,8 @@ impl<'a> InsnOpndMutIterator<'a> {
},
Insn::Add { left: opnd0, right: opnd1, .. } |
Insn::And { left: opnd0, right: opnd1, .. } |
+ Insn::CPushPair(opnd0, opnd1) |
+ Insn::CPopPairInto(opnd0, opnd1) |
Insn::Cmp { left: opnd0, right: opnd1 } |
Insn::CSelE { truthy: opnd0, falsy: opnd1, .. } |
Insn::CSelG { truthy: opnd0, falsy: opnd1, .. } |
@@ -1592,9 +1606,19 @@ impl Assembler
saved_regs = pool.live_regs();
// Save live registers
- for &(reg, _) in saved_regs.iter() {
- asm.cpush(Opnd::Reg(reg));
- pool.dealloc_opnd(&Opnd::Reg(reg));
+ for pair in saved_regs.chunks(2) {
+ match *pair {
+ [(reg0, _), (reg1, _)] => {
+ asm.cpush_pair(Opnd::Reg(reg0), Opnd::Reg(reg1));
+ pool.dealloc_opnd(&Opnd::Reg(reg0));
+ pool.dealloc_opnd(&Opnd::Reg(reg1));
+ }
+ [(reg, _)] => {
+ asm.cpush(Opnd::Reg(reg));
+ pool.dealloc_opnd(&Opnd::Reg(reg));
+ }
+ _ => unreachable!("chunks(2)")
+ }
}
// On x86_64, maintain 16-byte stack alignment
if cfg!(target_arch = "x86_64") && saved_regs.len() % 2 == 1 {
@@ -1725,9 +1749,19 @@ impl Assembler
asm.cpop_into(Opnd::Reg(saved_regs.last().unwrap().0));
}
// Restore saved registers
- for &(reg, vreg_idx) in saved_regs.iter().rev() {
- asm.cpop_into(Opnd::Reg(reg));
- pool.take_reg(&reg, vreg_idx);
+ for pair in saved_regs.chunks(2).rev() {
+ match *pair {
+ [(reg, vreg_idx)] => {
+ asm.cpop_into(Opnd::Reg(reg));
+ pool.take_reg(&reg, vreg_idx);
+ }
+ [(reg0, vreg_idx0), (reg1, vreg_idx1)] => {
+ asm.cpop_pair_into(Opnd::Reg(reg1), Opnd::Reg(reg0));
+ pool.take_reg(&reg1, vreg_idx1);
+ pool.take_reg(&reg0, vreg_idx0);
+ }
+ _ => unreachable!("chunks(2)")
+ }
}
saved_regs.clear();
}
@@ -2125,10 +2159,18 @@ impl Assembler {
self.push_insn(Insn::CPopInto(opnd));
}
+ pub fn cpop_pair_into(&mut self, opnd0: Opnd, opnd1: Opnd) {
+ self.push_insn(Insn::CPopPairInto(opnd0, opnd1));
+ }
+
pub fn cpush(&mut self, opnd: Opnd) {
self.push_insn(Insn::CPush(opnd));
}
+ pub fn cpush_pair(&mut self, opnd0: Opnd, opnd1: Opnd) {
+ self.push_insn(Insn::CPushPair(opnd0, opnd1));
+ }
+
pub fn cret(&mut self, opnd: Opnd) {
self.push_insn(Insn::CRet(opnd));
}
diff --git a/zjit/src/backend/x86_64/mod.rs b/zjit/src/backend/x86_64/mod.rs
index c1b9b2da13..38b9f2791b 100644
--- a/zjit/src/backend/x86_64/mod.rs
+++ b/zjit/src/backend/x86_64/mod.rs
@@ -851,12 +851,20 @@ impl Assembler {
Insn::CPush(opnd) => {
push(cb, opnd.into());
},
+ Insn::CPushPair(opnd0, opnd1) => {
+ push(cb, opnd0.into());
+ push(cb, opnd1.into());
+ },
Insn::CPop { out } => {
pop(cb, out.into());
},
Insn::CPopInto(opnd) => {
pop(cb, opnd.into());
},
+ Insn::CPopPairInto(opnd0, opnd1) => {
+ pop(cb, opnd0.into());
+ pop(cb, opnd1.into());
+ },
// C function call
Insn::CCall { fptr, .. } => {
@@ -1649,6 +1657,119 @@ mod tests {
}
#[test]
+ fn test_ccall_register_preservation_even() {
+ let (mut asm, mut cb) = setup_asm();
+
+ let v0 = asm.load(1.into());
+ let v1 = asm.load(2.into());
+ let v2 = asm.load(3.into());
+ let v3 = asm.load(4.into());
+ asm.ccall(0 as _, vec![]);
+ _ = asm.add(v0, v1);
+ _ = asm.add(v2, v3);
+
+ asm.compile_with_num_regs(&mut cb, ALLOC_REGS.len());
+
+ assert_disasm_snapshot!(cb.disasm(), @"
+ 0x0: mov edi, 1
+ 0x5: mov esi, 2
+ 0xa: mov edx, 3
+ 0xf: mov ecx, 4
+ 0x14: push rdi
+ 0x15: push rsi
+ 0x16: push rdx
+ 0x17: push rcx
+ 0x18: mov eax, 0
+ 0x1d: call rax
+ 0x1f: pop rcx
+ 0x20: pop rdx
+ 0x21: pop rsi
+ 0x22: pop rdi
+ 0x23: add rdi, rsi
+ 0x26: add rdx, rcx
+ ");
+ assert_snapshot!(cb.hexdump(), @"bf01000000be02000000ba03000000b90400000057565251b800000000ffd0595a5e5f4801f74801ca");
+ }
+
+ #[test]
+ fn test_ccall_register_preservation_odd() {
+ let (mut asm, mut cb) = setup_asm();
+
+ let v0 = asm.load(1.into());
+ let v1 = asm.load(2.into());
+ let v2 = asm.load(3.into());
+ let v3 = asm.load(4.into());
+ let v4 = asm.load(5.into());
+ asm.ccall(0 as _, vec![]);
+ _ = asm.add(v0, v1);
+ _ = asm.add(v2, v3);
+ _ = asm.add(v2, v4);
+
+ asm.compile_with_num_regs(&mut cb, ALLOC_REGS.len());
+
+ assert_disasm_snapshot!(cb.disasm(), @"
+ 0x0: mov edi, 1
+ 0x5: mov esi, 2
+ 0xa: mov edx, 3
+ 0xf: mov ecx, 4
+ 0x14: mov r8d, 5
+ 0x1a: push rdi
+ 0x1b: push rsi
+ 0x1c: push rdx
+ 0x1d: push rcx
+ 0x1e: push r8
+ 0x20: push r8
+ 0x22: mov eax, 0
+ 0x27: call rax
+ 0x29: pop r8
+ 0x2b: pop r8
+ 0x2d: pop rcx
+ 0x2e: pop rdx
+ 0x2f: pop rsi
+ 0x30: pop rdi
+ 0x31: add rdi, rsi
+ 0x34: mov rdi, rdx
+ 0x37: add rdi, rcx
+ 0x3a: add rdx, r8
+ ");
+ assert_snapshot!(cb.hexdump(), @"bf01000000be02000000ba03000000b90400000041b8050000005756525141504150b800000000ffd041584158595a5e5f4801f74889d74801cf4c01c2");
+ }
+
+ #[test]
+ fn test_cpush_pair() {
+ let (mut asm, mut cb) = setup_asm();
+ let v0 = asm.load(1.into());
+ let v1 = asm.load(2.into());
+ asm.cpush_pair(v0, v1);
+ asm.compile_with_num_regs(&mut cb, ALLOC_REGS.len());
+
+ assert_disasm_snapshot!(cb.disasm(), @"
+ 0x0: mov edi, 1
+ 0x5: mov esi, 2
+ 0xa: push rdi
+ 0xb: push rsi
+ ");
+ assert_snapshot!(cb.hexdump(), @"bf01000000be020000005756");
+ }
+
+ #[test]
+ fn test_cpop_pair_into() {
+ let (mut asm, mut cb) = setup_asm();
+ let v0 = asm.load(1.into());
+ let v1 = asm.load(2.into());
+ asm.cpop_pair_into(v0, v1);
+ asm.compile_with_num_regs(&mut cb, ALLOC_REGS.len());
+
+ assert_disasm_snapshot!(cb.disasm(), @"
+ 0x0: mov edi, 1
+ 0x5: mov esi, 2
+ 0xa: pop rdi
+ 0xb: pop rsi
+ ");
+ assert_snapshot!(cb.hexdump(), @"bf01000000be020000005f5e");
+ }
+
+ #[test]
fn test_cmov_mem() {
let (mut asm, mut cb) = setup_asm();
diff --git a/zjit/src/codegen.rs b/zjit/src/codegen.rs
index 0ae85c24a2..4dae41bf02 100644
--- a/zjit/src/codegen.rs
+++ b/zjit/src/codegen.rs
@@ -2638,8 +2638,17 @@ pub fn gen_function_stub_hit_trampoline(cb: &mut CodeBlock) -> Result<CodePtr, C
asm.frame_setup(&[]);
asm_comment!(asm, "preserve argument registers");
- for &reg in ALLOC_REGS.iter() {
- asm.cpush(Opnd::Reg(reg));
+
+ for pair in ALLOC_REGS.chunks(2) {
+ match *pair {
+ [reg0, reg1] => {
+ asm.cpush_pair(Opnd::Reg(reg0), Opnd::Reg(reg1));
+ }
+ [reg] => {
+ asm.cpush(Opnd::Reg(reg));
+ }
+ _ => unreachable!("chunks(2)")
+ }
}
if cfg!(target_arch = "x86_64") && ALLOC_REGS.len() % 2 == 1 {
asm.cpush(Opnd::Reg(ALLOC_REGS[0])); // maintain alignment for x86_64
@@ -2653,8 +2662,17 @@ pub fn gen_function_stub_hit_trampoline(cb: &mut CodeBlock) -> Result<CodePtr, C
if cfg!(target_arch = "x86_64") && ALLOC_REGS.len() % 2 == 1 {
asm.cpop_into(Opnd::Reg(ALLOC_REGS[0]));
}
- for &reg in ALLOC_REGS.iter().rev() {
- asm.cpop_into(Opnd::Reg(reg));
+
+ for pair in ALLOC_REGS.chunks(2).rev() {
+ match *pair {
+ [reg] => {
+ asm.cpop_into(Opnd::Reg(reg));
+ }
+ [reg0, reg1] => {
+ asm.cpop_pair_into(Opnd::Reg(reg1), Opnd::Reg(reg0));
+ }
+ _ => unreachable!("chunks(2)")
+ }
}
// Discard the current frame since the JIT function will set it up again