diff --git a/rust/examples/mlc_chat.rs b/rust/examples/mlc_chat.rs
index 0e4c6a916f..fa7132b052 100644
--- a/rust/examples/mlc_chat.rs
+++ b/rust/examples/mlc_chat.rs
@@ -1,16 +1,11 @@
 extern crate mlc_llm;
 
-use mlc_llm::chat_module::{ChatMessage, ChatModule, Prompt};
+use mlc_llm::chat_module::{ChatMessage, ChatModule};
 
 fn main() {
     // Single prompt example
     let cm = ChatModule::new("/path/to/Llama2-13B-q8f16_1", "rocm", None).unwrap();
-    let output = cm
-        .generate(
-            &Prompt::String("what is the meaning of life?".to_owned()),
-            None,
-        )
-        .unwrap();
+    let output = cm.generate("what is the meaning of life?", None).unwrap();
     println!("resp: {:?}", output);
     println!("stats: {:?}", cm.stats(false));
 
@@ -38,9 +33,7 @@ fn main() {
 
     let messages = vec![message1, message2, message3, message4, message5];
 
-    let prompt = Prompt::MessageList(messages);
-
-    let output = cm.generate(&prompt, None).unwrap();
+    let output = cm.generate(messages, None).unwrap();
     println!("resp: {:?}", output);
     println!("stats: {:?}", cm.stats(false));
 }
diff --git a/rust/src/chat_module.rs b/rust/src/chat_module.rs
index c82e4d0bf0..e1882a3fd6 100644
--- a/rust/src/chat_module.rs
+++ b/rust/src/chat_module.rs
@@ -35,6 +35,24 @@ pub enum Prompt {
     MessageList(Vec<ChatMessage>),
 }
 
+impl From<&str> for Prompt {
+    fn from(s: &str) -> Self {
+        Prompt::String(s.to_owned())
+    }
+}
+
+impl From<String> for Prompt {
+    fn from(s: String) -> Self {
+        Prompt::String(s)
+    }
+}
+
+impl From<Vec<ChatMessage>> for Prompt {
+    fn from(messages: Vec<ChatMessage>) -> Self {
+        Prompt::MessageList(messages)
+    }
+}
+
 #[derive(Debug, Copy, Clone)]
 pub enum PlaceInPrompt {
     All = 0,
@@ -266,12 +284,12 @@ fn get_lib_module_path(
 /// let cm = ChatModule::new("Llama-2-7b-chat-hf-q4f16_1", "cuda", None, None).unwrap();
 ///
 /// // Generate a response for a given prompt
-/// let output = cm.generate(&Prompt::String("what is the meaning of life?".to_owned()), None).unwrap();
+/// let output = cm.generate("what is the meaning of life?", None).unwrap();
 ///
 /// // Print prefill and decode performance statistics
 /// println!("Statistics: {:?}\n", cm.stats(false).unwrap());
 ///
-/// let output = cm.generate(&Prompt::String("what is Rust?".to_owned()), None).unwrap();
+/// let output = cm.generate("what is Rust?", None).unwrap();
 /// ```
 pub struct ChatModule {
     chat_module: Module,
@@ -428,7 +446,7 @@ impl ChatModule {
     /// response.
     pub fn generate(
         &self,
-        prompt: &Prompt,
+        prompt: impl Into<Prompt>,
         generation_config: Option<&GenerationConfig>,
     ) -> Result<Vec<String>> {
         // TODO: add progress_callback
@@ -441,9 +459,10 @@ impl ChatModule {
             }
         }
 
+        let prompt = prompt.into();
         for _ in 0..num_return_sequences {
             self.reset_chat().unwrap();
-            self.prefill(prompt, true, PlaceInPrompt::All, generation_config)
+            self.prefill(&prompt, true, PlaceInPrompt::All, generation_config)
                 .unwrap();
 
             while !self.stopped().unwrap() {