Skip to content

Commit

Permalink
[Rust] Improve ergonomics of generate function in ChatModule (#1262)
Browse files Browse the repository at this point in the history
Following PR #1253, I think the ergonomics of `ChatModule`'s `generate` function can be improved, given that it is an important public-facing API.

This PR simplifies the function's usage by implementing the `From` trait for the `Prompt` enum. It also updates the example code.

Now the interface changes to:

```rust
/// Single prompt case:
cm.generate("what is the meaning of life?", None)

/// Multiple prompt case:
let messages: Vec<ChatMessage> = vec![message1, message2, message3];
let output = cm.generate(messages, None).unwrap();
```
  • Loading branch information
YuchenJin authored Nov 15, 2023
1 parent 8304d4c commit 64e3410
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 14 deletions.
13 changes: 3 additions & 10 deletions rust/examples/mlc_chat.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,11 @@
extern crate mlc_llm;

use mlc_llm::chat_module::{ChatMessage, ChatModule, Prompt};
use mlc_llm::chat_module::{ChatMessage, ChatModule};

fn main() {
// Single prompt example
let cm = ChatModule::new("/path/to/Llama2-13B-q8f16_1", "rocm", None).unwrap();
let output = cm
.generate(
&Prompt::String("what is the meaning of life?".to_owned()),
None,
)
.unwrap();
let output = cm.generate("what is the meaning of life?", None).unwrap();
println!("resp: {:?}", output);
println!("stats: {:?}", cm.stats(false));

Expand Down Expand Up @@ -38,9 +33,7 @@ fn main() {

let messages = vec![message1, message2, message3, message4, message5];

let prompt = Prompt::MessageList(messages);

let output = cm.generate(&prompt, None).unwrap();
let output = cm.generate(messages, None).unwrap();
println!("resp: {:?}", output);
println!("stats: {:?}", cm.stats(false));
}
27 changes: 23 additions & 4 deletions rust/src/chat_module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,24 @@ pub enum Prompt {
MessageList(Vec<ChatMessage>),
}

impl From<&str> for Prompt {
fn from(s: &str) -> Self {
Prompt::String(s.to_owned())
}
}

impl From<String> for Prompt {
fn from(s: String) -> Self {
Prompt::String(s)
}
}

impl From<Vec<ChatMessage>> for Prompt {
fn from(messages: Vec<ChatMessage>) -> Self {
Prompt::MessageList(messages)
}
}

#[derive(Debug, Copy, Clone)]
pub enum PlaceInPrompt {
All = 0,
Expand Down Expand Up @@ -266,12 +284,12 @@ fn get_lib_module_path(
/// let cm = ChatModule::new("Llama-2-7b-chat-hf-q4f16_1", "cuda", None, None).unwrap();
///
/// // Generate a response for a given prompt
/// let output = cm.generate(&Prompt::String("what is the meaning of life?".to_owned()), None).unwrap();
/// let output = cm.generate("what is the meaning of life?", None).unwrap();
///
/// // Print prefill and decode performance statistics
/// println!("Statistics: {:?}\n", cm.stats(false).unwrap());
///
/// let output = cm.generate(&Prompt::String("what is Rust?".to_owned()), None).unwrap();
/// let output = cm.generate("what is Rust?", None).unwrap();
/// ```
pub struct ChatModule {
chat_module: Module,
Expand Down Expand Up @@ -428,7 +446,7 @@ impl ChatModule {
/// response.
pub fn generate(
&self,
prompt: &Prompt,
prompt: impl Into<Prompt>,
generation_config: Option<&GenerationConfig>,
) -> Result<Vec<String>> {
// TODO: add progress_callback
Expand All @@ -441,9 +459,10 @@ impl ChatModule {
}
}

let prompt = prompt.into();
for _ in 0..num_return_sequences {
self.reset_chat().unwrap();
self.prefill(prompt, true, PlaceInPrompt::All, generation_config)
self.prefill(&prompt, true, PlaceInPrompt::All, generation_config)
.unwrap();

while !self.stopped().unwrap() {
Expand Down

0 comments on commit 64e3410

Please sign in to comment.