Skip to content
This repository has been archived by the owner on Jun 24, 2024. It is now read-only.

Made Sampler and InferenceParameters threadsafe #292

Merged
merged 1 commit into from
Jun 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions crates/llm-base/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,11 @@ pub struct InferenceParameters {
/// sampler that offers a [Default](samplers::TopPTopK::default) implementation.
pub sampler: Arc<dyn Sampler>,
}

//Since Sampler implements Send and Sync, InferenceParameters should too.
unsafe impl Send for InferenceParameters {}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hol up is this necessary? won't these be inherited?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(it's fine but I don't think it's needed?)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PyO3 was complaining, so i had to wrap it in an Arc<>. With this i can just use it in a python objekt.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

weird, I think it should automatically infer Send + Sync for inference parameters as all the fields should be, too 🤔

It's totally fine as is, but can you try playing around with it and see if it still compiles without the unsafe impls?

unsafe impl Sync for InferenceParameters {}

impl Default for InferenceParameters {
fn default() -> Self {
Self {
Expand Down
2 changes: 1 addition & 1 deletion crates/llm-base/src/samplers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use rand::{distributions::WeightedIndex, prelude::Distribution};
use crate::{TokenBias, TokenId};

/// A sampler for generation.
pub trait Sampler: Debug {
pub trait Sampler: Debug + Send + Sync {
/// Given the previous tokens, the logits from the most recent evaluation, and a source of randomness,
/// sample from the logits and return the token ID.
fn sample(
Expand Down