Refactor to better adhere to Single Responsibility Principle
craigmayhew committed Jan 30, 2024
1 parent 8b20496 commit 220e1e3
Showing 1 changed file with 40 additions and 43 deletions.
83 changes: 40 additions & 43 deletions src/main.rs
@@ -57,37 +57,7 @@ const MAX_TOKENS: &i32 = &4096;
 /// Defines default temperature of response
 const TEMPERATURE: &f32 = &0.6;

-/// # Initialise `ChatBody` Struct
-// TODO: This can probably be part of a larger refactor so we aren't passing so many tuples back and forth between functions. i.e. we have ChatBody, just use that
-fn initialise_chat_body (max_tokens: i32, temperature: f32, top_p: f32, conversation_messages: Vec<Message>) -> ChatBody {
-    let chatbody = ChatBody {
-        model: MODEL.to_owned(),
-        max_tokens: Some(max_tokens),
-        temperature: Some(temperature),
-        top_p: Some(top_p),
-        n: Some(1),
-        stream: Some(false), // streaming output is not yet supported by this rust app
-        stop: None,
-        presence_penalty: None,
-        frequency_penalty: None,
-        logit_bias: None,
-        user: None,
-        messages: conversation_messages,
-    };
-
-    info!("ChatBody struct generated");
-    debug!("ChatBody struct: {:?} ", chatbody);
-    chatbody
-}
-
-/// # Send Request To Openai API
-///
-/// Loads the OPENAI_API_KEY environment variable, connects to OpenAI API, sends chat
-async fn send_to_gpt4(input: &str, arguments: (String, i32, f32, f32, bool)) -> Result<String, reqwest::Error> {
-    // debug log
-    debug!("entered send_to_gpt4()");
-
-    let (prepend, max_tokens, temperature, top_p, _render_markdown) = arguments;
+fn create_conversation(prepend: String, input: &str) -> Vec<openai_api_rust::Message> {
     let mut conversation_messages = vec![
         Message { role: Role::System, content: "You are a helpful assistant.".to_string() },
     ];
@@ -99,10 +69,20 @@ async fn send_to_gpt4(input: &str, arguments: (String, i32, f32, f32, bool)) ->
     if !atty::is(Stream::Stdin) {
         conversation_messages.push(Message { role: Role::User, content: input.to_string() });
     }
+
+    conversation_messages
+}
+
+/// # Send Request To Openai API
+///
+/// Loads the OPENAI_API_KEY environment variable, connects to OpenAI API, sends chat
+async fn send_to_gpt4(body: ChatBody) -> Result<String, reqwest::Error> {
+    // debug log
+    debug!("entered send_to_gpt4()");

     // Load API key from environment OPENAI_API_KEY
     let auth = Auth::from_env().expect("Failed to read auth from environment");
     let openai = OpenAI::new(auth, "https://api.openai.com/v1/");
-    let body = initialise_chat_body(max_tokens, temperature, top_p, conversation_messages);
     let chat_completion = openai.chat_completion_create(&body).expect("chat completion failed");
     let choice = chat_completion.choices;
     let message = &choice[0].message.as_ref().expect("Failed to read message from API");
@@ -180,18 +160,36 @@ fn args_setup() -> Command {
 /// # Parse Command Line Arguments
 ///
 /// Arguments are set to defaults where omitted
-fn args_read (args_setup: Command) -> (std::string::String, i32, f32, f32, bool) {
+fn args_read(input: &str, args_setup: Command) -> (ChatBody, bool) {
     let matches = args_setup.get_matches();

     let empty_string = String::from("");

-    let prepend = matches.get_one::<String>("prepend").unwrap_or(&empty_string);
-    let max_tokens = matches.get_one::<i32>("max_tokens").unwrap_or(MAX_TOKENS);
-    let temperature = matches.get_one::<f32>("temperature").unwrap_or(TEMPERATURE);
-    let top_p = matches.get_one::<f32>("top_p").unwrap_or(&0.95);
-    let render_markdown = matches.get_one::<bool>("markdown").unwrap_or(&false);
+    let prepend = matches.get_one::<String>("prepend").unwrap_or(&empty_string).to_owned();
+    let max_tokens = *matches.get_one::<i32>("max_tokens").unwrap_or(MAX_TOKENS);
+    let temperature = *matches.get_one::<f32>("temperature").unwrap_or(TEMPERATURE);
+    let top_p = *matches.get_one::<f32>("top_p").unwrap_or(&0.95);
+    let render_markdown = *matches.get_one::<bool>("markdown").unwrap_or(&false);

+    let chatbody = ChatBody {
+        model: MODEL.to_owned(),
+        max_tokens: Some(max_tokens),
+        temperature: Some(temperature),
+        top_p: Some(top_p),
+        n: Some(1),
+        stream: Some(false), // streaming output is not yet supported by this rust app
+        stop: None,
+        presence_penalty: None,
+        frequency_penalty: None,
+        logit_bias: None,
+        user: None,
+        messages: create_conversation(prepend, input),
+    };
+
+    info!("ChatBody struct generated");
+    debug!("ChatBody struct: {:?} ", chatbody);
+
-    (prepend.to_owned(),max_tokens.to_owned(),temperature.to_owned(),top_p.to_owned(),render_markdown.to_owned())
+    (chatbody, render_markdown)
 }

/// # Render in Markdown, Plaintext or Error
@@ -231,15 +229,14 @@ async fn main() {
     // enable logging
     env_logger::init();

-    let parsed_arguments = args_read(args_setup());
-    let render_markdown = parsed_arguments.4;
-
     let mut input = String::new();
     // if data is being piped in
     // this check is necessary or we hang the whole program waiting for stdin when none arrives
     if !atty::is(Stream::Stdin) {
         io::stdin().read_to_string(&mut input).expect("Failed to read from stdin");
     }

-    markdown_plaintext_or_error(send_to_gpt4(&input, parsed_arguments).await, render_markdown);
+    let (chat_body, render_markdown) = args_read(&input, args_setup());
+
+    markdown_plaintext_or_error(send_to_gpt4(chat_body).await, render_markdown);
 }
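
One practical payoff of this split is testability: create_conversation builds the message list without touching the network or the CLI parser. A hypothetical unit test (not part of this commit) sketching that, assuming the openai_api_rust::Message fields used in the diff; the middle of create_conversation (where prepend is handled) is elided above, so only the visible default system prompt is asserted:

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn conversation_starts_with_default_system_prompt() {
        // create_conversation needs no API key, so it can run in a plain test.
        let messages = create_conversation("Reply tersely.".to_string(), "hello");

        // The default system prompt is always pushed first (see the first hunk).
        assert_eq!(messages[0].content, "You are a helpful assistant.");

        // Note: whether `input` is appended depends on atty::is(Stream::Stdin)
        // at runtime, so any assertion on later messages would be
        // environment-dependent here.
    }
}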
