Skip to content

Commit

Permalink
[#25807] DocDB: Support various metrics in vector indexes
Browse files Browse the repository at this point in the history
Summary:
PgVector and usearch support different vector distance functions.
PgVector supports: Squared Euclidean/L2, Inner product, Cosine and L1 distance functions.
usearch also supports: Squared Euclidean/L2, Inner product, Cosine, along with 7 other distance functions. It, however, does not support L1.

So both of them supports:
1) L2Squared
2) Cosine
3) InnerProduct

This diff provides support for all those metrics in ybhnsw index access method.
Jira: DB-15107

Test Plan: PgVectorIndexTest.Cosine/*, PgVectorIndexTest.L2/*, PgVectorIndexTest.InnerProduct/*

Reviewers: tnayak, aleksandr.ponomarenko

Reviewed By: tnayak, aleksandr.ponomarenko

Subscribers: yql, ybase

Tags: #jenkins-ready

Differential Revision: https://phorge.dev.yugabyte.com/D41586
  • Loading branch information
spolitov committed Feb 4, 2025
1 parent 1e6f612 commit f8950ff
Show file tree
Hide file tree
Showing 21 changed files with 308 additions and 71 deletions.
3 changes: 2 additions & 1 deletion src/postgres/src/backend/access/yb_access/yb_lsm.c
Original file line number Diff line number Diff line change
Expand Up @@ -718,7 +718,8 @@ static void
ybcinbindschema(YbcPgStatement handle,
struct IndexInfo *indexInfo,
TupleDesc indexTupleDesc,
int16 *coloptions)
int16 *coloptions,
Oid *opclassOids)
{
YBCBindCreateIndexColumns(handle,
indexInfo,
Expand Down
3 changes: 2 additions & 1 deletion src/postgres/src/backend/access/ybgin/ybginutil.c
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,8 @@ void
ybginbindschema(YbcPgStatement handle,
struct IndexInfo *indexInfo,
TupleDesc indexTupleDesc,
int16 *coloptions)
int16 *coloptions,
Oid *objectClassId)
{
YBCBindCreateIndexColumns(handle,
indexInfo,
Expand Down
3 changes: 2 additions & 1 deletion src/postgres/src/backend/catalog/index.c
Original file line number Diff line number Diff line change
Expand Up @@ -1040,7 +1040,8 @@ index_create(Relation heapRelation,
colocationId,
tableSpaceId,
YbGetRelfileNodeId(indexRelation),
InvalidOid /* oldRelfileNodeId */ );
InvalidOid /* oldRelfileNodeId */ ,
classObjectId);
}

/*
Expand Down
6 changes: 4 additions & 2 deletions src/postgres/src/backend/commands/yb_cmds.c
Original file line number Diff line number Diff line change
Expand Up @@ -1226,7 +1226,8 @@ YBCCreateIndex(const char *indexName,
Oid colocationId,
Oid tablespaceId,
Oid indexRelfileNodeId,
Oid oldRelfileNodeId)
Oid oldRelfileNodeId,
Oid *opclassOids)
{
Oid namespaceId = RelationGetNamespace(rel);
char *db_name = get_database_name(YBCGetDatabaseOid(rel));
Expand Down Expand Up @@ -1266,7 +1267,8 @@ YBCCreateIndex(const char *indexName,
true);

Assert(amroutine != NULL && amroutine->yb_ambindschema != NULL);
amroutine->yb_ambindschema(handle, indexInfo, indexTupleDesc, coloptions);
amroutine->yb_ambindschema(handle, indexInfo, indexTupleDesc, coloptions,
opclassOids);

/* Handle SPLIT statement, if present */
if (split_options)
Expand Down
3 changes: 2 additions & 1 deletion src/postgres/src/backend/utils/misc/pg_yb_utils.c
Original file line number Diff line number Diff line change
Expand Up @@ -6129,7 +6129,8 @@ YbIndexSetNewRelfileNode(Relation indexRel, Oid newRelfileNodeId,
InvalidOid /* colocation ID */ ,
indexRel->rd_rel->reltablespace,
newRelfileNodeId,
YbGetRelfileNodeId(indexRel));
YbGetRelfileNodeId(indexRel),
NULL /* opclassOids */ );

table_close(indexedRel, ShareLock);

Expand Down
3 changes: 2 additions & 1 deletion src/postgres/src/include/access/amapi.h
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,8 @@ typedef int64 (*yb_amgetbitmap_function) (IndexScanDesc scan,
typedef void (*yb_ambindschema_function) (YbcPgStatement handle,
struct IndexInfo *indexInfo,
TupleDesc indexTupleDesc,
int16 *coloptions);
int16 *coloptions,
Oid *opclassOids);

/* end index scan */
typedef void (*amendscan_function) (IndexScanDesc scan);
Expand Down
3 changes: 2 additions & 1 deletion src/postgres/src/include/access/ybgin.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ extern bool ybginvalidate(Oid opclassoid);
extern void ybginbindschema(YbcPgStatement handle,
struct IndexInfo *indexInfo,
TupleDesc indexTupleDesc,
int16 *coloptions);
int16 *coloptions,
Oid *opclassIds);
extern IndexScanDesc ybginbeginscan(Relation rel, int nkeys, int norderbys);
extern void ybginrescan(IndexScanDesc scan, ScanKey scankey, int nscankeys,
ScanKey orderbys, int norderbys);
Expand Down
3 changes: 2 additions & 1 deletion src/postgres/src/include/commands/yb_cmds.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@ extern void YBCCreateIndex(const char *indexName,
Oid colocationId,
Oid tablespaceId,
Oid indexRelfileNodeId,
Oid oldRelfileNodeId);
Oid oldRelfileNodeId,
Oid *opclassOids);

extern void YBCBindCreateIndexColumns(YbcPgStatement handle,
IndexInfo *indexInfo,
Expand Down
11 changes: 11 additions & 0 deletions src/postgres/third-party-extensions/pgvector/sql/vector.sql
Original file line number Diff line number Diff line change
Expand Up @@ -254,3 +254,14 @@ CREATE OPERATOR CLASS vector_l2_ops
DEFAULT FOR TYPE vector USING ybhnsw AS
OPERATOR 1 <-> (vector, vector) FOR ORDER BY float_ops,
FUNCTION 1 vector_l2_squared_distance(vector, vector);

CREATE OPERATOR CLASS vector_ip_ops
FOR TYPE vector USING ybhnsw AS
OPERATOR 1 <#> (vector, vector) FOR ORDER BY float_ops,
FUNCTION 1 vector_negative_inner_product(vector, vector);

CREATE OPERATOR CLASS vector_cosine_ops
FOR TYPE vector USING ybhnsw AS
OPERATOR 1 <=> (vector, vector) FOR ORDER BY float_ops,
FUNCTION 1 vector_negative_inner_product(vector, vector),
FUNCTION 2 vector_norm(vector);
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,14 @@ static void
ybdummyannbindcolumnschema(YbcPgStatement handle,
IndexInfo *indexInfo,
TupleDesc indexTupleDesc,
int16 *coloptions)
int16 *coloptions,
Oid *opclassOids)
{
elog(WARNING,
"ybdummyann is meant for internal-testing only and "
"does not yield ordered results");
bindVectorIndexOptions(handle, indexInfo, indexTupleDesc, YB_VEC_DUMMY);
bindVectorIndexOptions(
handle, indexInfo, indexTupleDesc, YB_VEC_DUMMY, YB_VEC_DIST_L2);
YBCBindCreateIndexColumns(
handle, indexInfo, indexTupleDesc, coloptions, 0);
}
Expand Down
29 changes: 27 additions & 2 deletions src/postgres/third-party-extensions/pgvector/src/ybvector/ybhnsw.c
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,40 @@
#include "postgres.h"

#include "ybvector.h"
#include "catalog/pg_opclass.h"
#include "commands/yb_cmds.h"
#include "utils/syscache.h"

static void
ybhnswbindcolumnschema(YbcPgStatement handle,
IndexInfo *indexInfo,
TupleDesc indexTupleDesc,
int16 *coloptions)
int16 *coloptions,
Oid *opclassOids)
{
bindVectorIndexOptions(handle, indexInfo, indexTupleDesc, YB_VEC_HNSW);
HeapTuple ht_opc;
Form_pg_opclass opcrec;

Assert(indexInfo->ii_NumIndexKeyAttrs == 1);
ht_opc = SearchSysCache1(CLAOID, ObjectIdGetDatum(opclassOids[0]));
if (!HeapTupleIsValid(ht_opc))
elog(ERROR, "cache lookup failed for opclass %u", opclassOids[0]);
opcrec = (Form_pg_opclass) GETSTRUCT(ht_opc);
YbcPgVectorDistType dist_type;
if (!strcmp(opcrec->opcname.data, "vector_l2_ops")) {
dist_type = YB_VEC_DIST_L2;
} else if (!strcmp(opcrec->opcname.data, "vector_ip_ops")) {
dist_type = YB_VEC_DIST_IP;
} else if (!strcmp(opcrec->opcname.data, "vector_cosine_ops")) {
dist_type = YB_VEC_DIST_COSINE;
} else {
elog(ERROR, "unsupported vector index op class name %s",
opcrec->opcname.data);
}
ReleaseSysCache(ht_opc);

bindVectorIndexOptions(
handle, indexInfo, indexTupleDesc, YB_VEC_HNSW, dist_type);
YBCBindCreateIndexColumns(handle, indexInfo, indexTupleDesc, coloptions, 0);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ makeBaseYbVectorHandler(bool is_copartitioned)
IndexAmRoutine *amroutine = makeNode(IndexAmRoutine);

amroutine->amstrategies = 0;
amroutine->amsupport = 1;
amroutine->amsupport = 2;
amroutine->amcanorder = false;
amroutine->amcanorderbyop = true;
amroutine->amcanbackward = false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,4 +116,5 @@ void
bindVectorIndexOptions(YbcPgStatement handle,
IndexInfo *indexInfo,
TupleDesc indexTupleDesc,
YbcPgVectorIdxType ybpg_idx_type);
YbcPgVectorIdxType ybpg_idx_type,
YbcPgVectorDistType dist_type);
Original file line number Diff line number Diff line change
Expand Up @@ -69,16 +69,12 @@ void
bindVectorIndexOptions(YbcPgStatement handle,
IndexInfo *indexInfo,
TupleDesc indexTupleDesc,
YbcPgVectorIdxType ybpg_idx_type)
YbcPgVectorIdxType ybpg_idx_type,
YbcPgVectorDistType dist_type)
{
YbcPgVectorIdxOptions options;
options.idx_type = ybpg_idx_type;

/*
* Hardcoded for now.
* TODO(tanuj): Pass down distance info from the used distance opclass.
*/
options.dist_type = YB_VEC_DIST_L2;
options.dist_type = dist_type;

/* We only support indexes with one vector attribute for now. */
Assert(indexTupleDesc->natts == 1);
Expand Down
8 changes: 8 additions & 0 deletions src/yb/docdb/pgsql_operation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,8 @@ DEFINE_RUNTIME_AUTO_bool(ysql_skip_row_lock_for_update, kExternal, true, false,
"take finer column-level locks instead of locking the whole row. This may cause issues with "
"data integrity for operations with implicit dependencies between columns.");

DEFINE_RUNTIME_bool(vector_index_skip_filter_check, false,
"Whether to skip filter check during vector index search.");

DECLARE_uint64(rpc_max_message_size);

Expand Down Expand Up @@ -900,6 +902,9 @@ class PgsqlVectorFilter {
}

Status Init(const PgsqlReadOperationData& data) {
if (FLAGS_vector_index_skip_filter_check) {
return Status::OK();
}
std::vector<ColumnId> columns;
ColumnId index_column = data.vector_index->column_id();
for (const auto& col_ref : data.request.col_refs()) {
Expand All @@ -919,6 +924,9 @@ class PgsqlVectorFilter {
}

bool operator()(const vector_index::VectorId& vector_id) {
if (!row_) {
return true;
}
auto key = dockv::VectorIdKey(vector_id);
// TODO(vector_index) handle failure
auto ybctid = CHECK_RESULT(iter_.impl().FetchDirect(key.AsSlice()));
Expand Down
42 changes: 35 additions & 7 deletions src/yb/docdb/vector_index.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,28 +53,44 @@ const std::string kVectorIndexDirPrefix = "vi-";
namespace {

template <template<class, class> class Factory, class LSM>
auto VectorLSMFactory(size_t dimensions) {
auto VectorLSMFactory(vector_index::DistanceKind distance_kind, size_t dimensions) {
using FactoryImpl = vector_index::MakeVectorIndexFactory<Factory, LSM>;
return [dimensions] {
return [distance_kind, dimensions] {
vector_index::HNSWOptions hnsw_options = {
.dimensions = dimensions,
.num_neighbors_per_vertex = FLAGS_vector_index_num_neighbors_per_vertex,
.num_neighbors_per_vertex_base = FLAGS_vector_index_num_neighbors_per_vertex_base,
.ef_construction = FLAGS_vector_index_ef_construction,
.ef = FLAGS_vector_index_ef,
.distance_kind = distance_kind,
};
return FactoryImpl::Create(hnsw_options);
};
}

vector_index::DistanceKind ConvertDistanceKind(PgVectorDistanceType dist_type) {
switch (dist_type) {
case PgVectorDistanceType::DIST_L2:
return vector_index::DistanceKind::kL2Squared;
case PgVectorDistanceType::DIST_IP:
return vector_index::DistanceKind::kInnerProduct;
case PgVectorDistanceType::DIST_COSINE:
return vector_index::DistanceKind::kCosine;
case PgVectorDistanceType::INVALID_DIST:
break;
}
FATAL_INVALID_ENUM_VALUE(PgVectorDistanceType, dist_type);
}

template<vector_index::IndexableVectorType Vector,
vector_index::ValidDistanceResultType DistanceResult>
Result<typename vector_index::VectorLSMTypes<Vector, DistanceResult>::VectorIndexFactory>
GetVectorLSMFactory(PgVectorIndexType type, size_t dimensions) {
GetVectorLSMFactory(PgVectorIndexType type, vector_index::DistanceKind distance_kind,
size_t dimensions) {
using LSM = vector_index::VectorLSM<Vector, DistanceResult>;
switch (type) {
case PgVectorIndexType::HNSW:
return VectorLSMFactory<vector_index::UsearchIndexFactory, LSM>(dimensions);
return VectorLSMFactory<vector_index::UsearchIndexFactory, LSM>(distance_kind, dimensions);
case PgVectorIndexType::DUMMY: [[fallthrough]];
case PgVectorIndexType::IVFFLAT: [[fallthrough]];
case PgVectorIndexType::UNKNOWN_IDX:
Expand Down Expand Up @@ -115,8 +131,13 @@ Result<vector_index::VectorLSMInsertEntry<Vector>> ConvertEntry(
};
}

size_t EncodeDistance(float distance) {
return bit_cast<uint32_t>(util::CanonicalizeFloat(distance));
EncodedDistance EncodeDistance(float distance) {
uint32_t v = bit_cast<uint32_t>(distance);
if (v >> 31) {
return ~v;
} else {
return v ^ util::kInt32SignBitFlipMask;
}
}

template<vector_index::IndexableVectorType Vector,
Expand Down Expand Up @@ -154,7 +175,8 @@ class VectorIndexImpl : public VectorIndex {
.log_prefix = log_prefix,
.storage_dir = GetStorageDir(data_root_dir, DirName()),
.vector_index_factory = VERIFY_RESULT((GetVectorLSMFactory<Vector, DistanceResult>(
idx_options.idx_type(), idx_options.dimensions()))),
idx_options.idx_type(), ConvertDistanceKind(idx_options.dist_type()),
idx_options.dimensions()))),
.points_per_chunk = FLAGS_vector_index_initial_chunk_size,
.thread_pool = &thread_pool,
.frontiers_factory = [] { return std::make_unique<docdb::ConsensusFrontiers>(); },
Expand Down Expand Up @@ -196,7 +218,13 @@ class VectorIndexImpl : public VectorIndex {
.encoded_distance = EncodeDistance(entry.distance),
.key = KeyBuffer(db_entry.value),
});
#ifndef NDEBUG
if (result.size() > 1) {
CHECK_GE(result.back().encoded_distance, result[result.size() - 2].encoded_distance);
}
#endif
}

return result;
}

Expand Down
2 changes: 1 addition & 1 deletion src/yb/docdb/vector_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

namespace yb::docdb {

using EncodedDistance = size_t;
using EncodedDistance = uint64_t;

struct VectorIndexInsertEntry {
ValueBuffer value;
Expand Down
16 changes: 7 additions & 9 deletions src/yb/tablet/tablet.cc
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,8 @@ using rocksdb::SequenceNumber;

namespace yb::tablet {

bool TEST_fail_on_seq_scan_with_vector_indexes = false;

using strings::Substitute;

using client::YBSession;
Expand Down Expand Up @@ -1785,15 +1787,6 @@ Result<std::unique_ptr<docdb::YQLRowwiseIteratorIf>> Tablet::NewRowIterator(
return std::move(iter);
}

Result<std::unique_ptr<docdb::YQLRowwiseIteratorIf>> Tablet::NewRowIterator(
const TableId& table_id) const {
const std::shared_ptr<tablet::TableInfo> table_info =
VERIFY_RESULT(metadata_->GetTableInfo(table_id));
CHECK(false);
dockv::ReaderProjection projection(table_info->schema());
return NewRowIterator(projection, {}, table_id);
}

Status Tablet::ApplyRowOperations(
WriteOperation* operation, const docdb::StorageSet& apply_to_storages) {
AtomicFlagSleepMs(&FLAGS_TEST_inject_sleep_before_applying_write_batch_ms);
Expand Down Expand Up @@ -2182,6 +2175,11 @@ Status Tablet::DoHandlePgsqlReadRequest(
if (pgsql_read_request.index_request().has_vector_idx_options()) {
vector_index_table_id = index_table_id;
}
#ifndef NDEBUG
} else if (has_vector_indexes_.load(std::memory_order_relaxed)) {
CHECK(!TEST_fail_on_seq_scan_with_vector_indexes ||
pgsql_read_request.has_ybctid_column_value()) << pgsql_read_request.ShortDebugString();
#endif
}
auto index_doc_read_context = !index_table_id.empty()
? VERIFY_RESULT(GetDocReadContext(index_table_id)) : nullptr;
Expand Down
3 changes: 0 additions & 3 deletions src/yb/tablet/tablet.h
Original file line number Diff line number Diff line change
Expand Up @@ -426,9 +426,6 @@ class Tablet : public AbstractTablet,
CoarseTimePoint deadline = CoarseTimePoint::max(),
docdb::SkipSeek skip_seek = docdb::SkipSeek::kFalse) const;

Result<std::unique_ptr<docdb::YQLRowwiseIteratorIf>> NewRowIterator(
const TableId& table_id) const;

Result<std::unique_ptr<docdb::YQLRowwiseIteratorIf>> CreateCDCSnapshotIterator(
const dockv::ReaderProjection& projection,
const ReadHybridTime& time,
Expand Down
1 change: 1 addition & 0 deletions src/yb/yql/pggate/pg_dml_read.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include "yb/gutil/casts.h"
#include "yb/gutil/strings/substitute.h"

#include "yb/util/debug-util.h"
#include "yb/util/logging.h"
#include "yb/util/range.h"
#include "yb/util/slice.h"
Expand Down
Loading

0 comments on commit f8950ff

Please sign in to comment.