Skip to content

Commit

Permalink
add python implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
zkeram committed Jan 31, 2025
1 parent 0c47746 commit 0446827
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 13 deletions.
64 changes: 64 additions & 0 deletions polars_bio/interval_op_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from pathlib import Path

import datafusion
import pandas as pd
import polars as pl
import pyarrow as pa
from typing_extensions import TYPE_CHECKING, Union

from .constants import DEFAULT_INTERVAL_COLUMNS

def get_py_ctx() -> datafusion.context.SessionContext:
    """Create and return a new DataFusion session context."""
    session = datafusion.context.SessionContext()
    return session

def read_df_to_datafusion(
    py_ctx: datafusion.context.SessionContext,
    df: Union[str, pl.DataFrame, pl.LazyFrame, pd.DataFrame]
) -> datafusion.DataFrame:
    """Load *df* into the given DataFusion session as a DataFusion DataFrame.

    Args:
        py_ctx: Session context used to register or read the input.
        df: An in-memory frame (polars eager/lazy, pandas) or a path string.
            Paths ending in ``.csv`` are read as CSV, ``.bed`` as header-less
            tab-separated BED; any other extension is assumed to be Parquet.

    Returns:
        A DataFusion DataFrame wrapping the input data.

    Raises:
        ValueError: If ``df`` is none of the supported types.
    """
    if isinstance(df, pl.DataFrame):
        return py_ctx.from_polars(df)
    if isinstance(df, pd.DataFrame):
        return py_ctx.from_pandas(df)
    if isinstance(df, pl.LazyFrame):
        # DataFusion has no lazy-frame ingestion; materialize first.
        return py_ctx.from_polars(df.collect())
    if isinstance(df, str):
        ext = Path(df).suffix
        if ext == '.csv':
            return py_ctx.read_csv(df)
        if ext == '.bed':
            # BED is header-less and tab-separated: contig, start, end.
            # Fixes: keyword is `delimiter` (was misspelled `delimited`),
            # and the schema columns come from the package-wide
            # DEFAULT_INTERVAL_COLUMNS constant (DEFAULT_COLUMNS was undefined).
            return py_ctx.read_csv(
                df,
                has_header=False,
                delimiter='\t',
                file_extension='.bed',
                schema=pa.schema([
                    (DEFAULT_INTERVAL_COLUMNS[0], pa.string()),
                    (DEFAULT_INTERVAL_COLUMNS[1], pa.int64()),
                    (DEFAULT_INTERVAL_COLUMNS[2], pa.int64()),
                ]),
            )
        # Default: treat the path as Parquet.
        return py_ctx.read_parquet(df)
    raise ValueError("Invalid `df` argument.")

def df_to_lazyframe(
    df: datafusion.DataFrame
) -> pl.LazyFrame:
    """Convert a DataFusion DataFrame into a polars LazyFrame.

    TODO: make this genuinely lazy — e.g. register an IO source whose
    callback pulls batches on demand (honoring projected columns,
    predicate, row limit, and batch size) instead of collecting eagerly.
    """
    eager = df.to_polars()
    return eager.lazy()

def convert_result(
    df: datafusion.DataFrame,
    output_type: str,
    streaming: bool
) -> Union[pl.LazyFrame, pl.DataFrame, pd.DataFrame]:
    """Materialize a DataFusion result into the requested frame type.

    Args:
        df: Result of a DataFusion query.
        output_type: One of ``"polars.DataFrame"``, ``"pandas.DataFrame"``,
            or ``"polars.LazyFrame"``.
        streaming: When True, return a LazyFrame regardless of
            ``output_type``. NOTE(review): not truly streaming yet — the
            result is collected eagerly, then re-wrapped lazily.

    Returns:
        The converted frame.

    Raises:
        ValueError: If ``output_type`` is not one of the supported names.
    """
    if streaming:
        # TODO: implement real streaming instead of collect-then-lazy.
        return df.to_polars().lazy()
    if output_type == "polars.DataFrame":
        return df.to_polars()
    if output_type == "pandas.DataFrame":
        return df.to_pandas()
    if output_type == "polars.LazyFrame":
        return df_to_lazyframe(df)
    raise ValueError("Invalid `output_type` argument")
74 changes: 61 additions & 13 deletions polars_bio/range_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
from .constants import DEFAULT_INTERVAL_COLUMNS
from .context import ctx
from .range_op_helpers import _validate_overlap_input, range_operation
from .interval_op_helpers import read_df_to_datafusion, convert_result, get_py_ctx

import datafusion
from datafusion import col, literal

__all__ = ["overlap", "nearest", "count_overlaps"]

Expand Down Expand Up @@ -222,21 +226,65 @@ def count_overlaps(
```
Todo:
Support for on_cols.
Support return_input.
"""
_validate_overlap_input(cols1, cols2, on_cols, suffixes, output_type, how="inner")

range_op = RangeOp.CountOverlapsNaive if naive_query else RangeOp.CountOverlaps

my_ctx = get_py_ctx()
on_cols = [] if on_cols is None else on_cols
cols1 = DEFAULT_INTERVAL_COLUMNS if cols1 is None else cols1
cols2 = DEFAULT_INTERVAL_COLUMNS if cols2 is None else cols2
range_options = RangeOptions(
range_op=range_op,
filter_op=overlap_filter,
suffixes=suffixes,
columns_1=cols1,
columns_2=cols2,
streaming=streaming,
)
return range_operation(df1, df2, range_options, output_type, ctx)
if naive_query:
range_options = RangeOptions(
range_op=NaiveRangeQuery,
filter_op=overlap_filter,
suffixes=suffixes,
columns_1=cols1,
columns_2=cols2,
streaming=streaming,
)
return range_operation(df1, df2, range_options, output_type, ctx)
df1 = read_df_to_datafusion(my_ctx, df1)
df2 = read_df_to_datafusion(my_ctx, df2)

# TODO: guarantee no collisions
s1start_s2end = "s1starts2end"
s1end_s2start = "s1ends2start"
contig = "contig"
count = "count"
starts = "starts"
ends = "ends"
is_s1 = "is_s1"
suff, _ = suffixes

df1 = df1.select(*([literal(1).alias(is_s1), col(cols1[1]).alias(s1start_s2end), col(cols1[2]).alias(s1end_s2start), col(cols1[0]).alias(contig)] + on_cols))
df2 = df2.select(*([literal(0).alias(is_s1), col(cols2[2]).alias(s1end_s2start), col(cols2[1]).alias(s1start_s2end), col(cols2[0]).alias(contig)] + on_cols))

df = df1.union(df2)

partitioning = [col(contig)] + [col(c) for c in on_cols]
df.show()
df = df.select(*([s1start_s2end, s1end_s2start, contig, is_s1,
datafusion.functions.sum(col(is_s1)).over(
datafusion.expr.Window(
partition_by=partitioning,
order_by=[col(s1start_s2end).sort(), col(is_s1).sort(ascending=(overlap_filter == FilterOp.Strict))],
)
).alias(starts),
datafusion.functions.sum(col(is_s1)).over(
datafusion.expr.Window(
partition_by=partitioning,
order_by=[col(s1end_s2start).sort(), col(is_s1).sort(ascending=(overlap_filter == FilterOp.Weak))],
)
).alias(ends)] + on_cols))
df.show()
df = df.filter(col(is_s1) == 0)
df = df.select(*([
col(contig).alias(cols1[0] + suff),
col(s1end_s2start).alias(cols1[1] + suff),
col(s1start_s2end).alias(cols1[2] + suff)] +
on_cols +
[(col(starts) - col(ends)).alias(count)]
))

return convert_result(df, output_type, streaming)

0 comments on commit 0446827

Please sign in to comment.