Skip to content

Commit

Permalink
[ENH] HNSW should fork (#2124)
Browse files Browse the repository at this point in the history
## Description of changes

*Summarize the changes made by this PR.*
 - Improvements & Bug fixes
	 - HNSW segments should fork following reader/writer
- Main logic change is making HNSW fork() write to a new file while
loading from an old file - and make hnsw orchestrator respect it
	 - Error handling improvements
 - New functionality
	 - None

## Test plan
*How are these changes tested?*
- [x] Tests pass locally with `pytest` for python, `yarn test` for js,
`cargo test` for rust

## Documentation Changes
None
  • Loading branch information
HammadB authored May 6, 2024
1 parent 970ac74 commit f5c7651
Show file tree
Hide file tree
Showing 10 changed files with 534 additions and 104 deletions.
2 changes: 2 additions & 0 deletions rust/worker/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,5 @@ pub(crate) enum ErrorCodes {
pub(crate) trait ChromaError: Error + Send {
fn code(&self) -> ErrorCodes;
}

impl Error for Box<dyn ChromaError> {}
8 changes: 4 additions & 4 deletions rust/worker/src/execution/operators/flush_s3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ use crate::types::SegmentFlushInfo;
use crate::{
execution::operator::Operator,
segment::{
distributed_hnsw_segment::DistributedHNSWSegment, record_segment::RecordSegmentWriter,
SegmentWriter,
distributed_hnsw_segment::DistributedHNSWSegmentWriter,
record_segment::RecordSegmentWriter, SegmentWriter,
},
};
use async_trait::async_trait;
Expand All @@ -24,13 +24,13 @@ impl FlushS3Operator {
#[derive(Debug)]
pub struct FlushS3Input {
record_segment_writer: RecordSegmentWriter,
hnsw_segment_writer: Box<DistributedHNSWSegment>,
hnsw_segment_writer: Box<DistributedHNSWSegmentWriter>,
}

impl FlushS3Input {
pub fn new(
record_segment_writer: RecordSegmentWriter,
hnsw_segment_writer: Box<DistributedHNSWSegment>,
hnsw_segment_writer: Box<DistributedHNSWSegmentWriter>,
) -> Self {
Self {
record_segment_writer,
Expand Down
4 changes: 2 additions & 2 deletions rust/worker/src/execution/operators/hnsw_knn.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::{
errors::ChromaError, execution::operator::Operator,
segment::distributed_hnsw_segment::DistributedHNSWSegment,
segment::distributed_hnsw_segment::DistributedHNSWSegmentReader,
};
use async_trait::async_trait;

Expand All @@ -9,7 +9,7 @@ pub struct HnswKnnOperator {}

#[derive(Debug)]
pub struct HnswKnnOperatorInput {
pub segment: Box<DistributedHNSWSegment>,
pub segment: Box<DistributedHNSWSegmentReader>,
pub query: Vec<f32>,
pub k: usize,
}
Expand Down
8 changes: 4 additions & 4 deletions rust/worker/src/execution/operators/write_segments.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::segment::SegmentWriter;
use crate::{
execution::{data::data_chunk::Chunk, operator::Operator},
segment::{
distributed_hnsw_segment::DistributedHNSWSegment, record_segment::RecordSegmentWriter,
distributed_hnsw_segment::DistributedHNSWSegmentWriter, record_segment::RecordSegmentWriter,
},
types::LogRecord,
};
Expand All @@ -21,14 +21,14 @@ impl WriteSegmentsOperator {
#[derive(Debug)]
pub struct WriteSegmentsInput {
record_segment_writer: RecordSegmentWriter,
hnsw_segment_writer: Box<DistributedHNSWSegment>,
hnsw_segment_writer: Box<DistributedHNSWSegmentWriter>,
chunk: Chunk<LogRecord>,
}

impl<'me> WriteSegmentsInput {
pub fn new(
record_segment_writer: RecordSegmentWriter,
hnsw_segment_writer: Box<DistributedHNSWSegment>,
hnsw_segment_writer: Box<DistributedHNSWSegmentWriter>,
chunk: Chunk<LogRecord>,
) -> Self {
WriteSegmentsInput {
Expand All @@ -42,7 +42,7 @@ impl<'me> WriteSegmentsInput {
#[derive(Debug)]
pub struct WriteSegmentsOutput {
pub(crate) record_segment_writer: RecordSegmentWriter,
pub(crate) hnsw_segment_writer: Box<DistributedHNSWSegment>,
pub(crate) hnsw_segment_writer: Box<DistributedHNSWSegmentWriter>,
}

pub type WriteSegmentsResult = Result<WriteSegmentsOutput, ()>;
Expand Down
9 changes: 5 additions & 4 deletions rust/worker/src/execution/orchestration/compact.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use crate::execution::operators::write_segments::WriteSegmentsOperator;
use crate::execution::operators::write_segments::WriteSegmentsResult;
use crate::index::hnsw_provider::HnswIndexProvider;
use crate::log::log::Log;
use crate::segment::distributed_hnsw_segment::DistributedHNSWSegment;
use crate::segment::distributed_hnsw_segment::DistributedHNSWSegmentWriter;
use crate::segment::record_segment::RecordSegmentWriter;
use crate::segment::LogMaterializer;
use crate::segment::SegmentFlusher;
Expand Down Expand Up @@ -250,7 +250,7 @@ impl CompactOrchestrator {
async fn flush_s3(
&mut self,
record_segment_writer: RecordSegmentWriter,
hnsw_segment_writer: Box<DistributedHNSWSegment>,
hnsw_segment_writer: Box<DistributedHNSWSegmentWriter>,
self_address: Box<dyn Receiver<FlushS3Result>>,
) {
self.state = ExecutionState::Flush;
Expand Down Expand Up @@ -296,7 +296,8 @@ impl CompactOrchestrator {

async fn get_segment_writers(
&mut self,
) -> Result<(RecordSegmentWriter, Box<DistributedHNSWSegment>), Box<dyn ChromaError>> {
) -> Result<(RecordSegmentWriter, Box<DistributedHNSWSegmentWriter>), Box<dyn ChromaError>>
{
// Care should be taken to use the same writers across the compaction process
// Since the segment writers are stateful, we should not create new writers for each partition
// Nor should we create new writers across different tasks
Expand Down Expand Up @@ -376,7 +377,7 @@ impl CompactOrchestrator {
.dimension
.expect("Dimension is required in the compactor");

let hnsw_segment_writer = match DistributedHNSWSegment::from_segment(
let hnsw_segment_writer = match DistributedHNSWSegmentWriter::from_segment(
hnsw_segment,
dimension as usize,
self.hnsw_index_provider.clone(),
Expand Down
24 changes: 18 additions & 6 deletions rust/worker/src/execution/orchestration/hnsw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@ use crate::execution::operators::merge_knn_results::{
};
use crate::execution::operators::pull_log::PullLogsResult;
use crate::index::hnsw_provider::HnswIndexProvider;
use crate::segment::distributed_hnsw_segment::DistributedHNSWSegment;
use crate::segment::distributed_hnsw_segment::{
DistributedHNSWSegmentFromSegmentError, DistributedHNSWSegmentReader,
DistributedHNSWSegmentWriter,
};
use crate::sysdb::sysdb::{GetCollectionsError, GetSegmentsError, SysDb};
use crate::system::{ComponentContext, System};
use crate::types::{Collection, LogRecord, Segment, SegmentType, VectorQueryResult};
Expand Down Expand Up @@ -233,7 +236,7 @@ impl HnswQueryOrchestrator {
.expect("Invariant violation. Collection dimension is not set");

// Fetch the data needed for the duration of the query - The HNSW Segment, The record Segment and the Collection
let hnsw_segment_reader = match DistributedHNSWSegment::from_segment(
let hnsw_segment_reader = match DistributedHNSWSegmentReader::from_segment(
// These unwraps are safe because we have already checked that the segments are set in the orchestrator on_start
hnsw_segment,
dimensionality as usize,
Expand All @@ -242,10 +245,19 @@ impl HnswQueryOrchestrator {
.await
{
Ok(reader) => reader,
Err(e) => {
self.terminate_with_error(e, ctx);
return;
}
Err(e) => match *e {
DistributedHNSWSegmentFromSegmentError::Uninitialized => {
// no task, decrement the merge dependency count and return
self.hnsw_result_distances = Some(Vec::new());
self.hnsw_result_offset_ids = Some(Vec::new());
self.merge_dependency_count -= 1;
return;
}
_ => {
self.terminate_with_error(e, ctx);
return;
}
},
};

println!("Created HNSW Segment Reader: {:?}", hnsw_segment_reader);
Expand Down
8 changes: 5 additions & 3 deletions rust/worker/src/index/hnsw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ pub(crate) struct HnswIndexConfig {
pub(crate) enum HnswIndexFromSegmentError {
#[error("Missing config `{0}`")]
MissingConfig(String),
#[error("Invalid metadata value")]
MetadataValueError(#[from] MetadataValueConversionError),
}

impl ChromaError for HnswIndexFromSegmentError {
Expand All @@ -47,7 +49,7 @@ impl HnswIndexConfig {
pub(crate) fn from_segment(
segment: &Segment,
persist_path: &std::path::Path,
) -> Result<HnswIndexConfig, Box<dyn ChromaError>> {
) -> Result<HnswIndexConfig, Box<HnswIndexFromSegmentError>> {
let persist_path = match persist_path.to_str() {
Some(persist_path) => persist_path,
None => {
Expand Down Expand Up @@ -78,7 +80,7 @@ impl HnswIndexConfig {
fn get_metadata_value_as<'a, T>(
metadata: &'a Metadata,
key: &str,
) -> Result<T, Box<dyn ChromaError>>
) -> Result<T, Box<HnswIndexFromSegmentError>>
where
T: TryFrom<&'a MetadataValue, Error = MetadataValueConversionError>,
{
Expand All @@ -92,7 +94,7 @@ impl HnswIndexConfig {
};
match res {
Ok(value) => Ok(value),
Err(e) => Err(Box::new(e)),
Err(e) => Err(Box::new(HnswIndexFromSegmentError::MetadataValueError(e))),
}
}

Expand Down
Loading

0 comments on commit f5c7651

Please sign in to comment.