[query] Add query_matrix_table, an analogue to query_table
CHANGELOG: Add query_matrix_table, an analogue to query_table
chrisvittal committed Feb 4, 2025
1 parent 3080087 commit 4e0dbe5
Showing 7 changed files with 274 additions and 0 deletions.
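A minimal usage sketch of the new function (the dataset path and the entries field name 'e' below are hypothetical, mirroring the added test):

    import hail as hl

    # All rows with row key 50; entry fields are localized into an array field 'e'.
    rows = hl.eval(hl.query_matrix_table('/tmp/example.mt', 50, 'e'))

    # All rows with row key in [27, 66).
    rows = hl.eval(hl.query_matrix_table('/tmp/example.mt', hl.interval(27, 66), 'e'))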
10 changes: 10 additions & 0 deletions hail/hail/src/is/hail/expr/ir/IR.scala
@@ -1186,6 +1186,7 @@ package defs {
        classOf[PartitionNativeReader],
        classOf[PartitionNativeReaderIndexed],
        classOf[PartitionNativeIntervalReader],
        classOf[PartitionZippedNativeIntervalReader],
        classOf[PartitionZippedNativeReader],
        classOf[PartitionZippedIndexedNativeReader],
        classOf[BgenPartitionReader],
@@ -1216,6 +1217,15 @@ package defs {
            spec,
            (jv \ "uidFieldName").extract[String],
          )
        case "PartitionZippedNativeIntervalReader" =>
          val path = (jv \ "path").extract[String]
          val spec = RelationalSpec.read(ctx.fs, path).asInstanceOf[AbstractMatrixTableSpec]
          PartitionZippedNativeIntervalReader(
            ctx.stateManager,
            path,
            spec,
            (jv \ "uidFieldName").extract[String],
          )
        case "GVCFPartitionReader" =>
          val header = VCFHeaderInfo.fromJSON((jv \ "header"))
          val callFields = (jv \ "callFields").extract[Set[String]]
56 changes: 56 additions & 0 deletions hail/hail/src/is/hail/expr/ir/TableIR.scala
@@ -943,6 +943,7 @@ case class PartitionNativeIntervalReader(
  lazy val partitioner = rowsSpec.partitioner(sm)

  lazy val contextType: Type = RVDPartitioner.intervalIRRepresentation(partitioner.kType)
  require(partitioner.kType.size > 0)

  def toJValue: JValue = Extraction.decompose(this)(PartitionReader.formats)

@@ -1509,6 +1510,61 @@ case class PartitionZippedNativeReader(left: PartitionReader, right: PartitionReader
  }
}

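// Reads the entries component of a matrix table, but with the rows table's
// partitioner, so that row-key interval contexts can be resolved against it.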
private[this] class PartitionEntriesNativeIntervalReader(
  sm: HailStateManager,
  entriesPath: String,
  entriesSpec: AbstractTableSpec,
  uidFieldName: String,
  rowsTableSpec: AbstractTableSpec,
) extends PartitionNativeIntervalReader(sm, entriesPath, entriesSpec, uidFieldName) {
  override lazy val partitioner = rowsTableSpec.rowsSpec.partitioner(sm)
}

case class PartitionZippedNativeIntervalReader(
  sm: HailStateManager,
  mtPath: String,
  mtSpec: AbstractMatrixTableSpec,
  uidFieldName: String,
) extends PartitionReader {
  require(mtSpec.indexed)

  // XXX: rows and entries paths are hardcoded, see MatrixTableSpec
  private lazy val rowsReader =
    PartitionNativeIntervalReader(sm, mtPath + "/rows", mtSpec.rowsSpec, "__dummy")

  private lazy val entriesReader =
    new PartitionEntriesNativeIntervalReader(
      sm,
      mtPath + "/entries",
      mtSpec.entriesSpec,
      uidFieldName,
      rowsReader.tableSpec,
    )

  private lazy val zippedReader = PartitionZippedNativeReader(rowsReader, entriesReader)

  def contextType = rowsReader.contextType
  def fullRowType = zippedReader.fullRowType
  def rowRequiredness(requestedType: TStruct): RStruct = zippedReader.rowRequiredness(requestedType)
  def toJValue: JValue = Extraction.decompose(this)(PartitionReader.formats)

  def emitStream(
    ctx: ExecuteContext,
    cb: EmitCodeBuilder,
    mb: EmitMethodBuilder[_],
    codeContext: EmitCode,
    requestedType: TStruct,
  ): IEmitCode = {
    val zipContextType: TBaseStruct = tcoerce(zippedReader.contextType)
    val valueContext = cb.memoize(codeContext)
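    // Both child readers consume the same interval context, so it is
    // duplicated into a two-field struct matching zippedReader.contextType.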
    val contexts: IndexedSeq[EmitCode] = FastSeq(valueContext, valueContext)
    val st = SStackStruct(zipContextType, contexts.map(_.emitType))
    val context = EmitCode.present(mb, st.fromEmitCodes(cb, contexts))

    zippedReader.emitStream(ctx, cb, mb, context, requestedType)
  }
}

case class PartitionZippedIndexedNativeReader(
  specLeft: AbstractTypedCodecSpec,
  specRight: AbstractTypedCodecSpec,
2 changes: 2 additions & 0 deletions hail/python/hail/expr/__init__.py
@@ -210,6 +210,7 @@
    qchisqtail,
    qnorm,
    qpois,
    query_matrix_table,
    query_table,
    rand_beta,
    rand_bool,
@@ -554,6 +555,7 @@
    '_console_log',
    'dnorm',
    'dchisq',
    'query_matrix_table',
    'query_table',
    'keyed_union',
    'keyed_intersection',
100 changes: 100 additions & 0 deletions hail/python/hail/expr/functions.py
@@ -7070,6 +7070,106 @@ def coerce_endpoint(point):
    )


@typecheck(path=builtins.str, point_or_interval=expr_any, entries_name=builtins.str)
def query_matrix_table(path, point_or_interval, entries_name='entries_array'):
    """Query row records from a matrix table corresponding to a given point or
    range of row keys. The entry fields are localized as an array of structs as
    in :meth:`.MatrixTable.localize_entries`.

    Notes
    -----
    This function does not dispatch to a distributed runtime; it can be used inside
    already-distributed queries such as in :meth:`.Table.annotate`.

    Warning
    -------
    This function contains no safeguards against reading large amounts of data
    using a single thread.

    Parameters
    ----------
    path : :class:`str`
        Matrix table path.
    point_or_interval
        Point or interval to query.
    entries_name : :class:`str`
        Identifier to use for the localized entries array. Must not conflict
        with any row field identifiers.

    Returns
    -------
    :class:`.ArrayExpression`
    """
    matrix_table = hl.read_matrix_table(path)
    if entries_name in matrix_table.row:
        raise ValueError(
            f'query_matrix_table: field "{entries_name}" is present in matrix table row fields, pick a different `entries_name`'
        )

    entries_table = hl.read_table(path + '/entries')
    [entry_id] = list(entries_table.row)

    full_row_type = tstruct(**matrix_table.row.dtype, **entries_table.row.dtype)
    key_typ = matrix_table.row_key.dtype
    key_names = list(key_typ)
    len = builtins.len

    def coerce_endpoint(point):
        if point.dtype == key_typ[0]:
            point = hl.struct(**{key_names[0]: point})
        ts = point.dtype
        if isinstance(ts, tstruct):
            i = 0
            while i < len(ts):
                if i >= len(key_typ):
                    raise ValueError(
                        f"query_matrix_table: queried with {len(ts)} row key field(s), but matrix table only has {len(key_typ)} row key field(s)"
                    )
                if key_typ[i] != ts[i]:
                    raise ValueError(
                        f"query_matrix_table: mismatch at row key field {i} ({list(ts.keys())[i]!r}): query type is {ts[i]}, matrix table row key type is {key_typ[i]}"
                    )
                i += 1

            if i == 0:
                raise ValueError("query_matrix_table: cannot query with empty row key")

            point_size = builtins.len(point.dtype)
            return hl.tuple([
                hl.struct(**{
                    key_names[i]: (point[i] if i < point_size else hl.missing(key_typ[i]))
                    for i in builtins.range(builtins.len(key_typ))
                }),
                hl.int32(point_size),
            ])
        else:
            raise ValueError(
                f"query_matrix_table: row key mismatch: cannot query a matrix table with row key "
                f"({', '.join(builtins.str(x) for x in key_typ.values())}) with query point type {point.dtype}"
            )

    if point_or_interval.dtype != key_typ[0] and isinstance(point_or_interval.dtype, hl.tinterval):
        partition_interval = hl.interval(
            start=coerce_endpoint(point_or_interval.start),
            end=coerce_endpoint(point_or_interval.end),
            includes_start=point_or_interval.includes_start,
            includes_end=point_or_interval.includes_end,
        )
    else:
        point = coerce_endpoint(point_or_interval)
        partition_interval = hl.interval(start=point, end=point, includes_start=True, includes_end=True)
    read_part_ir = ir.ReadPartition(
        partition_interval._ir, reader=ir.PartitionZippedNativeIntervalReader(path, full_row_type)
    )
    stream_expr = construct_expr(
        read_part_ir,
        type=hl.tstream(full_row_type),
        indices=partition_interval._indices,
        aggregations=partition_interval._aggregations,
    )
    return stream_expr.map(lambda item: item.rename({entry_id: entries_name})).to_array()


@typecheck(msg=expr_str, result=expr_any)
def _console_log(msg, result):
    indices, aggregations = unify_all(msg, result)
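Because the function does not dispatch to a distributed runtime (see the Notes section of the docstring above), it can be embedded in already-distributed queries; a hedged sketch, assuming a matrix table keyed by int32 exists at the hypothetical path:

    import hail as hl

    ht = hl.utils.range_table(5)
    # Annotate each row of `ht` with the matching rows (and localized entries)
    # of the written matrix table.
    ht = ht.annotate(matches=hl.query_matrix_table('/tmp/example.mt', ht.idx, 'e'))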
2 changes: 2 additions & 0 deletions hail/python/hail/ir/__init__.py
@@ -113,6 +113,7 @@
    NDArraySVD,
    NDArrayWrite,
    PartitionNativeIntervalReader,
    PartitionZippedNativeIntervalReader,
    ProjectedTopLevelReference,
    ReadPartition,
    Recur,
@@ -527,6 +528,7 @@
    'TableNativeFanoutWriter',
    'ReadPartition',
    'PartitionNativeIntervalReader',
    'PartitionZippedNativeIntervalReader',
    'GVCFPartitionReader',
    'TableGen',
    'Partitioner',
31 changes: 31 additions & 0 deletions hail/python/hail/ir/ir.py
@@ -3510,6 +3510,37 @@ def row_type(self):
        return tstruct(**self.table_row_type, **{self.uid_field: ttuple(tint64, tint64)})


class PartitionZippedNativeIntervalReader(PartitionReader):
    def __init__(self, path, full_row_type, uid_field=None):
        self.path = path
        self.full_row_type = full_row_type
        self.uid_field = uid_field

    def with_uid_field(self, uid_field):
        return PartitionZippedNativeIntervalReader(path=self.path, full_row_type=self.full_row_type, uid_field=uid_field)

    def render(self):
        return escape_str(
            json.dumps({
                "name": "PartitionZippedNativeIntervalReader",
                "path": self.path,
                "uidFieldName": self.uid_field if self.uid_field is not None else '__dummy',
            })
        )

    def _eq(self, other):
        return (
            isinstance(other, PartitionZippedNativeIntervalReader)
            and self.path == other.path
            and self.uid_field == other.uid_field
        )

    def row_type(self):
        if self.uid_field is None:
            return self.full_row_type
        return tstruct(**self.full_row_type, **{self.uid_field: ttuple(tint64, tint64)})


class ReadPartition(IR):
    @typecheck_method(context=IR, reader=PartitionReader)
    def __init__(self, context, reader):
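For reference, render() above produces the JSON that the new "PartitionZippedNativeIntervalReader" case in IR.scala parses; with illustrative values:

    {"name": "PartitionZippedNativeIntervalReader", "path": "/tmp/example.mt", "uidFieldName": "__dummy"}

Note that full_row_type is not serialized: the Scala side re-reads the matrix table spec from the path.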
73 changes: 73 additions & 0 deletions hail/python/test/hail/matrixtable/test_matrix_table.py
@@ -2383,3 +2383,76 @@ def test_struct_of_arrays_encoding():
    etype = md['_codecSpec']['_eType']
    assert 'EStructOfArrays' in etype
    assert mt._same(std_mt)


def test_query_matrix_table():
    n_cols = 100
    f = new_temp_file(extension='mt')
    mt = hl.utils.range_matrix_table(n_rows=200, n_cols=n_cols, n_partitions=10)
    mt = mt.filter_rows(mt.row_idx % 10 == 0)
    mt = mt.filter_cols(mt.col_idx % 10 == 0)
    mt = mt.annotate_rows(s=hl.str(mt.row_idx))
    mt = mt.annotate_entries(n=mt.row_idx * mt.col_idx)
    mt.write(f)

    queries = [
        hl.query_matrix_table(f, 50, 'e'),
        hl.query_matrix_table(f, hl.struct(row_idx=50), 'e'),
        hl.query_matrix_table(f, 55, 'e'),
        hl.query_matrix_table(f, 5, 'e'),
        hl.query_matrix_table(f, -1, 'e'),
        hl.query_matrix_table(f, 205, 'e'),
        hl.query_matrix_table(f, hl.interval(27, 66), 'e'),
        hl.query_matrix_table(f, hl.interval(276, 33333), 'e'),
        hl.query_matrix_table(f, hl.interval(-22276, -5), 'e'),
        hl.query_matrix_table(f, hl.interval(hl.struct(row_idx=27), hl.struct(row_idx=66)), 'e'),
        hl.query_matrix_table(f, hl.interval(40, 80, includes_end=True), 'e'),
    ]

    col_idxs = [n for n in range(n_cols) if n % 10 == 0]

    def ea_for(n):
        return [n * m for m in col_idxs]

    expected = [
        [hl.Struct(row_idx=50, s='50', e=ea_for(50))],
        [hl.Struct(row_idx=50, s='50', e=ea_for(50))],
        [],
        [],
        [],
        [],
        [
            hl.Struct(row_idx=30, s='30', e=ea_for(30)),
            hl.Struct(row_idx=40, s='40', e=ea_for(40)),
            hl.Struct(row_idx=50, s='50', e=ea_for(50)),
            hl.Struct(row_idx=60, s='60', e=ea_for(60)),
        ],
        [],
        [],
        [
            hl.Struct(row_idx=30, s='30', e=ea_for(30)),
            hl.Struct(row_idx=40, s='40', e=ea_for(40)),
            hl.Struct(row_idx=50, s='50', e=ea_for(50)),
            hl.Struct(row_idx=60, s='60', e=ea_for(60)),
        ],
        [
            hl.Struct(row_idx=40, s='40', e=ea_for(40)),
            hl.Struct(row_idx=50, s='50', e=ea_for(50)),
            hl.Struct(row_idx=60, s='60', e=ea_for(60)),
            hl.Struct(row_idx=70, s='70', e=ea_for(70)),
            hl.Struct(row_idx=80, s='80', e=ea_for(80)),
        ],
    ]

    assert hl.eval(queries) == expected

    with pytest.raises(ValueError, match='query_matrix_table: field "s" is present'):
        hl.query_matrix_table(f, 0, 's')
    with pytest.raises(ValueError, match='query_matrix_table: mismatch at row key field'):
        hl.query_matrix_table(f, hl.interval(hl.struct(row_idx='1'), hl.struct(row_idx='2')))
    with pytest.raises(ValueError, match='query_matrix_table: row key mismatch: cannot query'):
        hl.query_matrix_table(f, '1')
    with pytest.raises(ValueError, match='query_matrix_table: cannot query with empty row key'):
        hl.query_matrix_table(f, hl.struct())
    with pytest.raises(ValueError, match='query_matrix_table: queried with 2 row key field'):
        hl.query_matrix_table(f, hl.struct(row_idx=5, foo='s'))
