diff --git a/collection.go b/collection.go index 4dbd7a4..155f75f 100644 --- a/collection.go +++ b/collection.go @@ -292,6 +292,19 @@ func (c *Collection) AddDocument(ctx context.Context, doc Document) error { return nil } +// ListIDs returns the IDs of all documents in the collection. +func (c *Collection) ListIDs(_ context.Context) []string { + c.documentsLock.RLock() + defer c.documentsLock.RUnlock() + + ids := make([]string, 0, len(c.documents)) + for id := range c.documents { + ids = append(ids, id) + } + + return ids +} + // GetByID returns a document by its ID. // The returned document is a copy of the original document, so it can be safely // modified without affecting the collection. diff --git a/collection_test.go b/collection_test.go index 4e47d4b..e8dc918 100644 --- a/collection_test.go +++ b/collection_test.go @@ -391,6 +391,56 @@ func TestCollection_QueryError(t *testing.T) { } } +func TestCollection_ListIDs(t *testing.T) { + ctx := context.Background() + + // Create collection + db := NewDB() + name := "test" + metadata := map[string]string{"foo": "bar"} + vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}` + embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { + return vectors, nil + } + c, err := db.CreateCollection(name, metadata, embeddingFunc) + if err != nil { + t.Fatal("expected no error, got", err) + } + if c == nil { + t.Fatal("expected collection, got nil") + } + + // Add documents + ids := []string{"1", "2"} + metadatas := []map[string]string{{"foo": "bar"}, {"a": "b"}} + contents := []string{"hello world", "hallo welt"} + err = c.Add(context.Background(), ids, nil, metadatas, contents) + if err != nil { + t.Fatal("expected nil, got", err) + } + + // List IDs + foundIds := c.ListIDs(ctx) + + // Ensure IDs match + // (slices are same length and all the items in the first slice exist in the second slice) + if len(foundIds) != len(ids) { + t.Fatal("expected", len(ids), "got", len(foundIds)) + } + for _, id := range ids { + found := false + for _, foundID := range foundIds { + if id == foundID { + found = true + break + } + } + if !found { + t.Fatal("expected", id, "in", foundIds) + } + } +} + func TestCollection_Get(t *testing.T) { ctx := context.Background()