From 64e3410ebe5786501cb057aa5a0ae4c65f4a7556 Mon Sep 17 00:00:00 2001 From: Yuchen Jin Date: Tue, 14 Nov 2023 20:54:54 -0800 Subject: [PATCH] [Rust] Improve ergonomics of `generate` function in `ChatModule` (#1262) Following PR #1253, I think ergonomics of the `generate` function of `ChatModule` can be improved (given it's an important public-facing API). This PR simplifies the function's usage by implementing the `From` trait for the `Prompt` enum. Also updated the example code. Now the interface changes to: ```rust /// Single prompt case: cm.generate("what is the meaning of life?", None) /// Multiple prompt case: let messages: Vec<ChatMessage> = vec![message1, message2, message3]; let output = cm.generate(messages, None).unwrap(); ``` --- rust/examples/mlc_chat.rs | 13 +++---------- rust/src/chat_module.rs | 27 +++++++++++++++++++++++---- 2 files changed, 26 insertions(+), 14 deletions(-) diff --git a/rust/examples/mlc_chat.rs b/rust/examples/mlc_chat.rs index 0e4c6a916f..fa7132b052 100644 --- a/rust/examples/mlc_chat.rs +++ b/rust/examples/mlc_chat.rs @@ -1,16 +1,11 @@ extern crate mlc_llm; -use mlc_llm::chat_module::{ChatMessage, ChatModule, Prompt}; +use mlc_llm::chat_module::{ChatMessage, ChatModule}; fn main() { // Single prompt example let cm = ChatModule::new("/path/to/Llama2-13B-q8f16_1", "rocm", None).unwrap(); - let output = cm - .generate( - &Prompt::String("what is the meaning of life?".to_owned()), - None, - ) - .unwrap(); + let output = cm.generate("what is the meaning of life?", None).unwrap(); println!("resp: {:?}", output); println!("stats: {:?}", cm.stats(false)); @@ -38,9 +33,7 @@ fn main() { let messages = vec![message1, message2, message3, message4, message5]; - let prompt = Prompt::MessageList(messages); - - let output = cm.generate(&prompt, None).unwrap(); + let output = cm.generate(messages, None).unwrap(); println!("resp: {:?}", output); println!("stats: {:?}", cm.stats(false)); } diff --git a/rust/src/chat_module.rs b/rust/src/chat_module.rs index 
c82e4d0bf0..e1882a3fd6 100644 --- a/rust/src/chat_module.rs +++ b/rust/src/chat_module.rs @@ -35,6 +35,24 @@ pub enum Prompt { MessageList(Vec<ChatMessage>), } +impl From<&str> for Prompt { + fn from(s: &str) -> Self { + Prompt::String(s.to_owned()) + } +} + +impl From<String> for Prompt { + fn from(s: String) -> Self { + Prompt::String(s) + } +} + +impl From<Vec<ChatMessage>> for Prompt { + fn from(messages: Vec<ChatMessage>) -> Self { + Prompt::MessageList(messages) + } +} + #[derive(Debug, Copy, Clone)] pub enum PlaceInPrompt { All = 0, @@ -266,12 +284,12 @@ fn get_lib_module_path( /// let cm = ChatModule::new("Llama-2-7b-chat-hf-q4f16_1", "cuda", None, None).unwrap(); /// /// // Generate a response for a given prompt -/// let output = cm.generate(&Prompt::String("what is the meaning of life?".to_owned()), None).unwrap(); +/// let output = cm.generate("what is the meaning of life?", None).unwrap(); /// /// // Print prefill and decode performance statistics /// println!("Statistics: {:?}\n", cm.stats(false).unwrap()); /// -/// let output = cm.generate(&Prompt::String("what is Rust?".to_owned()), None).unwrap(); +/// let output = cm.generate("what is Rust?", None).unwrap(); /// ``` pub struct ChatModule { chat_module: Module, @@ -428,7 +446,7 @@ impl ChatModule { /// response. pub fn generate( &self, - prompt: &Prompt, + prompt: impl Into<Prompt>, generation_config: Option<&GenerationConfig>, ) -> Result<Vec<String>> { // TODO: add progress_callback @@ -441,9 +459,10 @@ impl ChatModule { } } + let prompt = prompt.into(); for _ in 0..num_return_sequences { self.reset_chat().unwrap(); - self.prefill(&prompt, true, PlaceInPrompt::All, generation_config) + self.prefill(&prompt, true, PlaceInPrompt::All, generation_config) .unwrap(); while !self.stopped().unwrap() {