From a71c3aa2d45375c5cb8498b66a49a88ed08c9c3c Mon Sep 17 00:00:00 2001
From: Dustin Blackman <dev@dustinblackman.com>
Date: Sat, 11 Nov 2023 09:05:47 -0500
Subject: [PATCH] fix: OpenAI authorization

---
 src/infrastructure/backends/openai.rs      | 2 ++
 src/infrastructure/backends/openai_test.rs | 4 +++-
 2 files changed, 5 insertions(+), 1 deletion(-)

diff --git a/src/infrastructure/backends/openai.rs b/src/infrastructure/backends/openai.rs
index 3a0a931..96d03f6 100644
--- a/src/infrastructure/backends/openai.rs
+++ b/src/infrastructure/backends/openai.rs
@@ -113,6 +113,7 @@ impl Backend for OpenAI {
     async fn list_models(&self) -> Result<Vec<String>> {
         let res = reqwest::Client::new()
             .get(format!("{url}/v1/models", url = self.url))
+            .header("Authorization", format!("Bearer {}", self.token))
             .send()
             .await?
             .json::<ModelListResponse>()
@@ -154,6 +155,7 @@ impl Backend for OpenAI {
 
         let res = reqwest::Client::new()
             .post(format!("{url}/v1/chat/completions", url = self.url))
+            .header("Authorization", format!("Bearer {}", self.token))
             .json(&req)
             .send()
             .await?;
diff --git a/src/infrastructure/backends/openai_test.rs b/src/infrastructure/backends/openai_test.rs
index d2089a9..21d97ba 100644
--- a/src/infrastructure/backends/openai_test.rs
+++ b/src/infrastructure/backends/openai_test.rs
@@ -86,15 +86,16 @@ async fn it_lists_models() -> Result<()> {
     let mut server = mockito::Server::new();
     let mock = server
         .mock("GET", "/v1/models")
+        .match_header("Authorization", "Bearer abc")
         .with_status(200)
         .with_body(body)
         .create();
 
     let backend = OpenAI::with_url(server.url());
     let res = backend.list_models().await?;
+    mock.assert();
 
     assert_eq!(res, vec!["first".to_string(), "second".to_string()]);
-    mock.assert();
 
     return Ok(());
 }
@@ -129,6 +130,7 @@ async fn it_gets_completions() -> Result<()> {
     let mut server = mockito::Server::new();
     let mock = server
         .mock("POST", "/v1/chat/completions")
+        .match_header("Authorization", "Bearer abc")
         .with_status(200)
         .with_body(body)
         .create();