Skip to content

Commit

Permalink
Merge pull request #102 from psqlpy-python/feature/support_pg_vector
Browse files Browse the repository at this point in the history
Added PgVector integration
  • Loading branch information
chandr-andr authored Nov 19, 2024
2 parents 335e591 + b50c4fb commit e831c58
Show file tree
Hide file tree
Showing 7 changed files with 189 additions and 86 deletions.
10 changes: 10 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,6 @@ itertools = "0.12.1"
openssl-src = "300.2.2"
openssl-sys = "0.9.102"
pg_interval = { git = "https://github.com/chandr-andr/rust-postgres-interval.git", branch = "psqlpy" }
pgvector = { git = "https://github.com/chandr-andr/pgvector-rust.git", branch = "psqlpy", features = [
"postgres",
] }
13 changes: 13 additions & 0 deletions python/psqlpy/_internal/extra_types.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -774,3 +774,16 @@ class IntervalArray:
### Parameters:
- `inner`: inner value, sequence of timedelta values.
"""

class PgVector:
"""Represent VECTOR in PostgreSQL."""

def __init__(
self: Self,
inner: typing.Sequence[float | int],
) -> None:
"""Create new instance of PgVector.
### Parameters:
- `inner`: inner value, sequence of float or int values.
"""
2 changes: 2 additions & 0 deletions python/psqlpy/extra_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
MoneyArray,
NumericArray,
PathArray,
PgVector,
PointArray,
PyBox,
PyCircle,
Expand Down Expand Up @@ -98,4 +99,5 @@
"LsegArray",
"CircleArray",
"IntervalArray",
"PgVector",
]
2 changes: 1 addition & 1 deletion python/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ async def psql_pool_with_cert_file(


@pytest.fixture(autouse=True)
async def create_deafult_data_for_tests(
async def create_default_data_for_tests(
psql_pool: ConnectionPool,
table_name: str,
number_database_records: int,
Expand Down
20 changes: 20 additions & 0 deletions src/extra_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,25 @@ use crate::{
},
};

#[pyclass]
#[derive(Clone)]
pub struct PgVector(Vec<f32>);

#[pymethods]
impl PgVector {
#[new]
fn new(vector: Vec<f32>) -> Self {
Self(vector)
}
}

impl PgVector {
#[must_use]
pub fn inner_value(self) -> Vec<f32> {
self.0
}
}

macro_rules! build_python_type {
($st_name:ident, $rust_type:ty) => {
#[pyclass]
Expand Down Expand Up @@ -412,5 +431,6 @@ pub fn extra_types_module(_py: Python<'_>, pymod: &Bound<'_, PyModule>) -> PyRes
pymod.add_class::<LsegArray>()?;
pymod.add_class::<CircleArray>()?;
pymod.add_class::<IntervalArray>()?;
pymod.add_class::<PgVector>()?;
Ok(())
}
Loading

0 comments on commit e831c58

Please sign in to comment.