Skip to content

Commit

Permalink
added more type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
bushwhackr committed Feb 8, 2024
1 parent efab860 commit 0ef5334
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 9 deletions.
9 changes: 7 additions & 2 deletions sdk/python/feast/infra/contrib/spark_kafka_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.avro.functions import from_avro
from pyspark.sql.functions import col, from_json
from pyspark.sql.streaming import StreamingQuery

from feast.data_format import AvroFormat, JsonFormat
from feast.data_source import KafkaSource, PushMode
Expand Down Expand Up @@ -63,7 +64,11 @@ def __init__(
self.join_keys = [fs.get_entity(entity).join_key for entity in sfv.entities]
super().__init__(fs=fs, sfv=sfv, data_source=sfv.stream_source)

def ingest_stream_feature_view(self, to: PushMode = PushMode.ONLINE) -> None:
# Type hinting for data_source type.
# data_source type has been checked to be an instance of KafkaSource.
self.data_source: KafkaSource = self.data_source # type: ignore

def ingest_stream_feature_view(self, to: PushMode = PushMode.ONLINE) -> StreamingQuery:
ingested_stream_df = self._ingest_stream_data()
transformed_df = self._construct_transformation_plan(ingested_stream_df)
online_store_query = self._write_stream_data(transformed_df, to)
Expand Down Expand Up @@ -122,7 +127,7 @@ def _ingest_stream_data(self) -> StreamTable:
def _construct_transformation_plan(self, df: StreamTable) -> StreamTable:
return self.sfv.udf.__call__(df) if self.sfv.udf else df

def _write_stream_data(self, df: StreamTable, to: PushMode):
def _write_stream_data(self, df: StreamTable, to: PushMode) -> StreamingQuery:
# Validation occurs at the fs.write_to_online_store() phase against the stream feature view schema.
def batch_write(row: DataFrame, batch_id: int):
rows: pd.DataFrame = row.toPandas()
Expand Down
6 changes: 3 additions & 3 deletions sdk/python/feast/infra/offline_stores/redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,13 @@ class RedshiftOfflineStoreConfig(FeastConfigBaseModel):
type: Literal["redshift"] = "redshift"
""" Offline store type selector"""

cluster_id: Optional[StrictStr]
cluster_id: Optional[StrictStr] = None
""" Redshift cluster identifier, for provisioned clusters """

user: Optional[StrictStr]
user: Optional[StrictStr] = None
""" Redshift user name, only required for provisioned clusters """

workgroup: Optional[StrictStr]
workgroup: Optional[StrictStr] = None
""" Redshift workgroup identifier, for serverless """

region: StrictStr
Expand Down
4 changes: 2 additions & 2 deletions sdk/python/feast/infra/offline_stores/snowflake_source.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import warnings
from typing import Callable, Dict, Iterable, Optional, Tuple
from typing import Callable, Dict, Iterable, Optional, Tuple, Any

from typeguard import typechecked

Expand Down Expand Up @@ -223,7 +223,7 @@ def get_table_column_names_and_types(
query = f"SELECT * FROM {self.get_table_query_string()} LIMIT 5"
cursor = execute_snowflake_statement(conn, query)

metadata = [
metadata: list[dict[str, Any]] = [
{
"column_name": column.name,
"type_code": column.type_code,
Expand Down
2 changes: 1 addition & 1 deletion sdk/python/feast/infra/registry/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,7 @@ def _delete_object(
"""
cursor = execute_snowflake_statement(conn, query)

if cursor.rowcount < 1 and not_found_exception:
if cursor.rowcount < 1 and not_found_exception: # type: ignore
raise not_found_exception(name, project)
self._set_last_updated_metadata(datetime.utcnow(), project)

Expand Down
4 changes: 3 additions & 1 deletion sdk/python/feast/infra/utils/snowflake/snowflake_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import feast
from feast.errors import SnowflakeIncompleteConfig, SnowflakeQueryUnknownError
from feast.feature_view import FeatureView
from feast.infra.offline_stores.snowflake import SnowflakeOfflineStoreConfig
from feast.infra.online_stores.snowflake import SnowflakeOnlineStoreConfig
from feast.repo_config import RepoConfig

try:
Expand All @@ -43,7 +45,7 @@


class GetSnowflakeConnection:
def __init__(self, config: str, autocommit=True):
def __init__(self, config: SnowflakeOfflineStoreConfig | SnowflakeOnlineStoreConfig, autocommit=True):
self.config = config
self.autocommit = autocommit

Expand Down

0 comments on commit 0ef5334

Please sign in to comment.