From c6c04dd53e3e70d0dd47c3d8b702414412d93e90 Mon Sep 17 00:00:00 2001
From: Alessandro Decina <alessandro.d@gmail.com>
Date: Fri, 15 Sep 2023 16:37:04 +0000
Subject: [PATCH] jit: make sure RSP is 16-byte aligned when we call into Rust
 code

The System V AMD64 ABI requires RSP to be 16-byte aligned at call sites.
Internally we don't emit any instructions that require alignment, but
when we call out to rustc-generated code we must align the stack.
---
 src/jit.rs | 21 +++++++++++++++++----
 1 file changed, 17 insertions(+), 4 deletions(-)

diff --git a/src/jit.rs b/src/jit.rs
index 7c5038f2..d4e36f7d 100644
--- a/src/jit.rs
+++ b/src/jit.rs
@@ -923,13 +923,23 @@ impl<'a, C: ContextObject> JitCompiler<'a, C> {
         for reg in saved_registers.iter() {
             self.emit_ins(X86Instruction::push(*reg, None));
         }
-    
+
+        // Align RSP to 16 bytes
+        self.emit_ins(X86Instruction::push(RSP, None));
+        self.emit_ins(X86Instruction::push(RSP, Some(X86IndirectAccess::OffsetIndexShift(0, RSP, 0))));
+        self.emit_ins(X86Instruction::alu(OperandSize::S64, 0x81, 4, RSP, -16, None));
+
+        let stack_arguments = arguments.len().saturating_sub(ARGUMENT_REGISTERS.len()) as i64;
+        if stack_arguments % 2 != 0 {
+            // If we're going to pass an odd number of stack args we need to pad
+            // to preserve alignment
+            self.emit_ins(X86Instruction::alu(OperandSize::S64, 0x81, 5, RSP, 8, None));
+        }
+
         // Pass arguments
-        let mut stack_arguments = 0;
         for argument in arguments {
             let is_stack_argument = argument.index >= ARGUMENT_REGISTERS.len();
             let dst = if is_stack_argument {
-                stack_arguments += 1;
                 R11
             } else {
                 ARGUMENT_REGISTERS[argument.index]
@@ -997,7 +1007,10 @@ impl<'a, C: ContextObject> JitCompiler<'a, C> {
         }
     
         // Restore registers from stack
-        self.emit_ins(X86Instruction::alu(OperandSize::S64, 0x81, 0, RSP, stack_arguments * 8, None));
+        self.emit_ins(X86Instruction::alu(OperandSize::S64, 0x81, 0, RSP,
+            if stack_arguments % 2 != 0 { stack_arguments + 1 } else { stack_arguments } * 8, None));
+        self.emit_ins(X86Instruction::load(OperandSize::S64, RSP, RSP, X86IndirectAccess::OffsetIndexShift(8, RSP, 0)));
+
         for reg in saved_registers.iter().rev() {
             self.emit_ins(X86Instruction::pop(*reg));
         }