diff options
| author | Jeff Zhang <jeff@39bytes.dev> | 2026-01-20 18:31:26 -0500 |
|---|---|---|
| committer | Alan Wu <XrXr@users.noreply.github.com> | 2026-01-20 19:42:25 -0500 |
| commit | f7e73ba3bf9ced61ac9cca5cf042fd0efe398f69 (patch) | |
| tree | c2aa093c35376c849ac01ad1c4e80f1751633bcb | |
| parent | e24b52885feaa87cdb5796c2a08e5995274e83cb (diff) | |
ZJIT: A64: Avoid gaps in the stack when preserving registers for calls
Previously, we used a `str x, [sp, #-0x10]!` for each value, which
decrements SP by 16 bytes (keeping the required 16-byte stack alignment)
but stores only one 8-byte register, leaving an 8-byte gap per value.
Use STP to store a pair of registers at a time instead.
| -rw-r--r-- | zjit/src/backend/arm64/mod.rs | 110 | ||||
| -rw-r--r-- | zjit/src/backend/lir.rs | 54 | ||||
| -rw-r--r-- | zjit/src/backend/x86_64/mod.rs | 121 | ||||
| -rw-r--r-- | zjit/src/codegen.rs | 26 |
4 files changed, 301 insertions, 10 deletions
diff --git a/zjit/src/backend/arm64/mod.rs b/zjit/src/backend/arm64/mod.rs index a019e2037d..574249dabd 100644 --- a/zjit/src/backend/arm64/mod.rs +++ b/zjit/src/backend/arm64/mod.rs @@ -1420,12 +1420,20 @@ impl Assembler { Insn::CPush(opnd) => { emit_push(cb, opnd.into()); }, + Insn::CPushPair(opnd0, opnd1) => { + // Second operand ends up at the lower stack address + stp_pre(cb, opnd1.into(), opnd0.into(), A64Opnd::new_mem(64, C_SP_REG, -C_SP_STEP)); + }, Insn::CPop { out } => { emit_pop(cb, out.into()); }, Insn::CPopInto(opnd) => { emit_pop(cb, opnd.into()); }, + Insn::CPopPairInto(opnd0, opnd1) => { + // First operand is popped from the lower stack address + ldp_post(cb, opnd0.into(), opnd1.into(), A64Opnd::new_mem(64, C_SP_REG, C_SP_STEP)); + }, Insn::CCall { fptr, .. } => { match fptr { Opnd::UImm(fptr) => { @@ -2663,6 +2671,76 @@ mod tests { } #[test] + fn test_ccall_register_preservation_even() { + let (mut asm, mut cb) = setup_asm(); + + let v0 = asm.load(1.into()); + let v1 = asm.load(2.into()); + let v2 = asm.load(3.into()); + let v3 = asm.load(4.into()); + asm.ccall(0 as _, vec![]); + _ = asm.add(v0, v1); + _ = asm.add(v2, v3); + + asm.compile_with_num_regs(&mut cb, ALLOC_REGS.len()); + + assert_disasm_snapshot!(cb.disasm(), @" + 0x0: mov x0, #1 + 0x4: mov x1, #2 + 0x8: mov x2, #3 + 0xc: mov x3, #4 + 0x10: mov x4, x0 + 0x14: stp x2, x1, [sp, #-0x10]! + 0x18: stp x4, x3, [sp, #-0x10]! 
+ 0x1c: mov x16, #0 + 0x20: blr x16 + 0x24: ldp x4, x3, [sp], #0x10 + 0x28: ldp x2, x1, [sp], #0x10 + 0x2c: adds x4, x4, x1 + 0x30: adds x2, x2, x3 + "); + assert_snapshot!(cb.hexdump(), @"200080d2410080d2620080d2830080d2e40300aae207bfa9e40fbfa9100080d200023fd6e40fc1a8e207c1a8840001ab420003ab"); + } + + #[test] + fn test_ccall_register_preservation_odd() { + let (mut asm, mut cb) = setup_asm(); + + let v0 = asm.load(1.into()); + let v1 = asm.load(2.into()); + let v2 = asm.load(3.into()); + let v3 = asm.load(4.into()); + let v4 = asm.load(5.into()); + asm.ccall(0 as _, vec![]); + _ = asm.add(v0, v1); + _ = asm.add(v2, v3); + _ = asm.add(v2, v4); + + asm.compile_with_num_regs(&mut cb, ALLOC_REGS.len()); + + assert_disasm_snapshot!(cb.disasm(), @" + 0x0: mov x0, #1 + 0x4: mov x1, #2 + 0x8: mov x2, #3 + 0xc: mov x3, #4 + 0x10: mov x4, #5 + 0x14: mov x5, x0 + 0x18: stp x2, x1, [sp, #-0x10]! + 0x1c: stp x4, x3, [sp, #-0x10]! + 0x20: str x5, [sp, #-0x10]! + 0x24: mov x16, #0 + 0x28: blr x16 + 0x2c: ldr x5, [sp], #0x10 + 0x30: ldp x4, x3, [sp], #0x10 + 0x34: ldp x2, x1, [sp], #0x10 + 0x38: adds x5, x5, x1 + 0x3c: adds x0, x2, x3 + 0x40: adds x2, x2, x4 + "); + assert_snapshot!(cb.hexdump(), @"200080d2410080d2620080d2830080d2a40080d2e50300aae207bfa9e40fbfa9e50f1ff8100080d200023fd6e50741f8e40fc1a8e207c1a8a50001ab400003ab420004ab"); + } + + #[test] fn test_ccall_resolve_parallel_moves_large_cycle() { let (mut asm, mut cb) = setup_asm(); @@ -2686,6 +2764,38 @@ mod tests { } #[test] + fn test_cpush_pair() { + let (mut asm, mut cb) = setup_asm(); + let v0 = asm.load(1.into()); + let v1 = asm.load(2.into()); + asm.cpush_pair(v0, v1); + asm.compile_with_num_regs(&mut cb, ALLOC_REGS.len()); + + assert_disasm_snapshot!(cb.disasm(), @" + 0x0: mov x0, #1 + 0x4: mov x1, #2 + 0x8: stp x1, x0, [sp, #-0x10]! 
+ "); + assert_snapshot!(cb.hexdump(), @"200080d2410080d2e103bfa9"); + } + + #[test] + fn test_cpop_pair_into() { + let (mut asm, mut cb) = setup_asm(); + let v0 = asm.load(1.into()); + let v1 = asm.load(2.into()); + asm.cpop_pair_into(v0, v1); + asm.compile_with_num_regs(&mut cb, ALLOC_REGS.len()); + + assert_disasm_snapshot!(cb.disasm(), @" + 0x0: mov x0, #1 + 0x4: mov x1, #2 + 0x8: ldp x0, x1, [sp], #0x10 + "); + assert_snapshot!(cb.hexdump(), @"200080d2410080d2e007c1a8"); + } + + #[test] fn test_split_spilled_lshift() { let (mut asm, mut cb) = setup_asm(); diff --git a/zjit/src/backend/lir.rs b/zjit/src/backend/lir.rs index 69ddbf2471..bb39d85cc8 100644 --- a/zjit/src/backend/lir.rs +++ b/zjit/src/backend/lir.rs @@ -380,9 +380,17 @@ pub enum Insn { /// Pop a register from the C stack and store it into another register CPopInto(Opnd), + /// Pop a pair of registers from the C stack and store it into a pair of registers. + /// The registers are popped from left to right. + CPopPairInto(Opnd, Opnd), + /// Push a register onto the C stack CPush(Opnd), + /// Push a pair of registers onto the C stack. + /// The registers are pushed from left to right. + CPushPair(Opnd, Opnd), + // C function call with N arguments (variadic) CCall { opnds: Vec<Opnd>, @@ -609,7 +617,9 @@ impl Insn { Insn::Cmp { .. } => "Cmp", Insn::CPop { .. } => "CPop", Insn::CPopInto(_) => "CPopInto", + Insn::CPopPairInto(_, _) => "CPopPairInto", Insn::CPush(_) => "CPush", + Insn::CPushPair(_, _) => "CPushPair", Insn::CCall { .. } => "CCall", Insn::CRet(_) => "CRet", Insn::CSelE { .. } => "CSelE", @@ -865,6 +875,8 @@ impl<'a> Iterator for InsnOpndIterator<'a> { }, Insn::Add { left: opnd0, right: opnd1, .. } | Insn::And { left: opnd0, right: opnd1, .. } | + Insn::CPushPair(opnd0, opnd1) | + Insn::CPopPairInto(opnd0, opnd1) | Insn::Cmp { left: opnd0, right: opnd1 } | Insn::CSelE { truthy: opnd0, falsy: opnd1, .. } | Insn::CSelG { truthy: opnd0, falsy: opnd1, .. 
} | @@ -1034,6 +1046,8 @@ impl<'a> InsnOpndMutIterator<'a> { }, Insn::Add { left: opnd0, right: opnd1, .. } | Insn::And { left: opnd0, right: opnd1, .. } | + Insn::CPushPair(opnd0, opnd1) | + Insn::CPopPairInto(opnd0, opnd1) | Insn::Cmp { left: opnd0, right: opnd1 } | Insn::CSelE { truthy: opnd0, falsy: opnd1, .. } | Insn::CSelG { truthy: opnd0, falsy: opnd1, .. } | @@ -1592,9 +1606,19 @@ impl Assembler saved_regs = pool.live_regs(); // Save live registers - for &(reg, _) in saved_regs.iter() { - asm.cpush(Opnd::Reg(reg)); - pool.dealloc_opnd(&Opnd::Reg(reg)); + for pair in saved_regs.chunks(2) { + match *pair { + [(reg0, _), (reg1, _)] => { + asm.cpush_pair(Opnd::Reg(reg0), Opnd::Reg(reg1)); + pool.dealloc_opnd(&Opnd::Reg(reg0)); + pool.dealloc_opnd(&Opnd::Reg(reg1)); + } + [(reg, _)] => { + asm.cpush(Opnd::Reg(reg)); + pool.dealloc_opnd(&Opnd::Reg(reg)); + } + _ => unreachable!("chunks(2)") + } } // On x86_64, maintain 16-byte stack alignment if cfg!(target_arch = "x86_64") && saved_regs.len() % 2 == 1 { @@ -1725,9 +1749,19 @@ impl Assembler asm.cpop_into(Opnd::Reg(saved_regs.last().unwrap().0)); } // Restore saved registers - for &(reg, vreg_idx) in saved_regs.iter().rev() { - asm.cpop_into(Opnd::Reg(reg)); - pool.take_reg(®, vreg_idx); + for pair in saved_regs.chunks(2).rev() { + match *pair { + [(reg, vreg_idx)] => { + asm.cpop_into(Opnd::Reg(reg)); + pool.take_reg(®, vreg_idx); + } + [(reg0, vreg_idx0), (reg1, vreg_idx1)] => { + asm.cpop_pair_into(Opnd::Reg(reg1), Opnd::Reg(reg0)); + pool.take_reg(®1, vreg_idx1); + pool.take_reg(®0, vreg_idx0); + } + _ => unreachable!("chunks(2)") + } } saved_regs.clear(); } @@ -2125,10 +2159,18 @@ impl Assembler { self.push_insn(Insn::CPopInto(opnd)); } + pub fn cpop_pair_into(&mut self, opnd0: Opnd, opnd1: Opnd) { + self.push_insn(Insn::CPopPairInto(opnd0, opnd1)); + } + pub fn cpush(&mut self, opnd: Opnd) { self.push_insn(Insn::CPush(opnd)); } + pub fn cpush_pair(&mut self, opnd0: Opnd, opnd1: Opnd) { + 
self.push_insn(Insn::CPushPair(opnd0, opnd1)); + } + pub fn cret(&mut self, opnd: Opnd) { self.push_insn(Insn::CRet(opnd)); } diff --git a/zjit/src/backend/x86_64/mod.rs b/zjit/src/backend/x86_64/mod.rs index c1b9b2da13..38b9f2791b 100644 --- a/zjit/src/backend/x86_64/mod.rs +++ b/zjit/src/backend/x86_64/mod.rs @@ -851,12 +851,20 @@ impl Assembler { Insn::CPush(opnd) => { push(cb, opnd.into()); }, + Insn::CPushPair(opnd0, opnd1) => { + push(cb, opnd0.into()); + push(cb, opnd1.into()); + }, Insn::CPop { out } => { pop(cb, out.into()); }, Insn::CPopInto(opnd) => { pop(cb, opnd.into()); }, + Insn::CPopPairInto(opnd0, opnd1) => { + pop(cb, opnd0.into()); + pop(cb, opnd1.into()); + }, // C function call Insn::CCall { fptr, .. } => { @@ -1649,6 +1657,119 @@ mod tests { } #[test] + fn test_ccall_register_preservation_even() { + let (mut asm, mut cb) = setup_asm(); + + let v0 = asm.load(1.into()); + let v1 = asm.load(2.into()); + let v2 = asm.load(3.into()); + let v3 = asm.load(4.into()); + asm.ccall(0 as _, vec![]); + _ = asm.add(v0, v1); + _ = asm.add(v2, v3); + + asm.compile_with_num_regs(&mut cb, ALLOC_REGS.len()); + + assert_disasm_snapshot!(cb.disasm(), @" + 0x0: mov edi, 1 + 0x5: mov esi, 2 + 0xa: mov edx, 3 + 0xf: mov ecx, 4 + 0x14: push rdi + 0x15: push rsi + 0x16: push rdx + 0x17: push rcx + 0x18: mov eax, 0 + 0x1d: call rax + 0x1f: pop rcx + 0x20: pop rdx + 0x21: pop rsi + 0x22: pop rdi + 0x23: add rdi, rsi + 0x26: add rdx, rcx + "); + assert_snapshot!(cb.hexdump(), @"bf01000000be02000000ba03000000b90400000057565251b800000000ffd0595a5e5f4801f74801ca"); + } + + #[test] + fn test_ccall_register_preservation_odd() { + let (mut asm, mut cb) = setup_asm(); + + let v0 = asm.load(1.into()); + let v1 = asm.load(2.into()); + let v2 = asm.load(3.into()); + let v3 = asm.load(4.into()); + let v4 = asm.load(5.into()); + asm.ccall(0 as _, vec![]); + _ = asm.add(v0, v1); + _ = asm.add(v2, v3); + _ = asm.add(v2, v4); + + asm.compile_with_num_regs(&mut cb, ALLOC_REGS.len()); + + 
assert_disasm_snapshot!(cb.disasm(), @" + 0x0: mov edi, 1 + 0x5: mov esi, 2 + 0xa: mov edx, 3 + 0xf: mov ecx, 4 + 0x14: mov r8d, 5 + 0x1a: push rdi + 0x1b: push rsi + 0x1c: push rdx + 0x1d: push rcx + 0x1e: push r8 + 0x20: push r8 + 0x22: mov eax, 0 + 0x27: call rax + 0x29: pop r8 + 0x2b: pop r8 + 0x2d: pop rcx + 0x2e: pop rdx + 0x2f: pop rsi + 0x30: pop rdi + 0x31: add rdi, rsi + 0x34: mov rdi, rdx + 0x37: add rdi, rcx + 0x3a: add rdx, r8 + "); + assert_snapshot!(cb.hexdump(), @"bf01000000be02000000ba03000000b90400000041b8050000005756525141504150b800000000ffd041584158595a5e5f4801f74889d74801cf4c01c2"); + } + + #[test] + fn test_cpush_pair() { + let (mut asm, mut cb) = setup_asm(); + let v0 = asm.load(1.into()); + let v1 = asm.load(2.into()); + asm.cpush_pair(v0, v1); + asm.compile_with_num_regs(&mut cb, ALLOC_REGS.len()); + + assert_disasm_snapshot!(cb.disasm(), @" + 0x0: mov edi, 1 + 0x5: mov esi, 2 + 0xa: push rdi + 0xb: push rsi + "); + assert_snapshot!(cb.hexdump(), @"bf01000000be020000005756"); + } + + #[test] + fn test_cpop_pair_into() { + let (mut asm, mut cb) = setup_asm(); + let v0 = asm.load(1.into()); + let v1 = asm.load(2.into()); + asm.cpop_pair_into(v0, v1); + asm.compile_with_num_regs(&mut cb, ALLOC_REGS.len()); + + assert_disasm_snapshot!(cb.disasm(), @" + 0x0: mov edi, 1 + 0x5: mov esi, 2 + 0xa: pop rdi + 0xb: pop rsi + "); + assert_snapshot!(cb.hexdump(), @"bf01000000be020000005f5e"); + } + + #[test] fn test_cmov_mem() { let (mut asm, mut cb) = setup_asm(); diff --git a/zjit/src/codegen.rs b/zjit/src/codegen.rs index 0ae85c24a2..4dae41bf02 100644 --- a/zjit/src/codegen.rs +++ b/zjit/src/codegen.rs @@ -2638,8 +2638,17 @@ pub fn gen_function_stub_hit_trampoline(cb: &mut CodeBlock) -> Result<CodePtr, C asm.frame_setup(&[]); asm_comment!(asm, "preserve argument registers"); - for ® in ALLOC_REGS.iter() { - asm.cpush(Opnd::Reg(reg)); + + for pair in ALLOC_REGS.chunks(2) { + match *pair { + [reg0, reg1] => { + asm.cpush_pair(Opnd::Reg(reg0), 
Opnd::Reg(reg1)); + } + [reg] => { + asm.cpush(Opnd::Reg(reg)); + } + _ => unreachable!("chunks(2)") + } } if cfg!(target_arch = "x86_64") && ALLOC_REGS.len() % 2 == 1 { asm.cpush(Opnd::Reg(ALLOC_REGS[0])); // maintain alignment for x86_64 @@ -2653,8 +2662,17 @@ pub fn gen_function_stub_hit_trampoline(cb: &mut CodeBlock) -> Result<CodePtr, C if cfg!(target_arch = "x86_64") && ALLOC_REGS.len() % 2 == 1 { asm.cpop_into(Opnd::Reg(ALLOC_REGS[0])); } - for ® in ALLOC_REGS.iter().rev() { - asm.cpop_into(Opnd::Reg(reg)); + + for pair in ALLOC_REGS.chunks(2).rev() { + match *pair { + [reg] => { + asm.cpop_into(Opnd::Reg(reg)); + } + [reg0, reg1] => { + asm.cpop_pair_into(Opnd::Reg(reg1), Opnd::Reg(reg0)); + } + _ => unreachable!("chunks(2)") + } } // Discard the current frame since the JIT function will set it up again |
