diff --git a/sdk/python/feast/infra/contrib/spark_kafka_processor.py b/sdk/python/feast/infra/contrib/spark_kafka_processor.py index ea55d89988a..775352cbae5 100644 --- a/sdk/python/feast/infra/contrib/spark_kafka_processor.py +++ b/sdk/python/feast/infra/contrib/spark_kafka_processor.py @@ -5,6 +5,7 @@ from pyspark.sql import DataFrame, SparkSession from pyspark.sql.avro.functions import from_avro from pyspark.sql.functions import col, from_json +from pyspark.sql.streaming import StreamingQuery from feast.data_format import AvroFormat, JsonFormat from feast.data_source import KafkaSource, PushMode @@ -63,7 +64,11 @@ def __init__( self.join_keys = [fs.get_entity(entity).join_key for entity in sfv.entities] super().__init__(fs=fs, sfv=sfv, data_source=sfv.stream_source) - def ingest_stream_feature_view(self, to: PushMode = PushMode.ONLINE) -> None: + # Type hinting for data_source type. + # data_source type has been checked to be an instance of KafkaSource. + self.data_source: KafkaSource = self.data_source # type: ignore + + def ingest_stream_feature_view(self, to: PushMode = PushMode.ONLINE) -> StreamingQuery: ingested_stream_df = self._ingest_stream_data() transformed_df = self._construct_transformation_plan(ingested_stream_df) online_store_query = self._write_stream_data(transformed_df, to) @@ -122,7 +127,7 @@ def _ingest_stream_data(self) -> StreamTable: def _construct_transformation_plan(self, df: StreamTable) -> StreamTable: return self.sfv.udf.__call__(df) if self.sfv.udf else df - def _write_stream_data(self, df: StreamTable, to: PushMode): + def _write_stream_data(self, df: StreamTable, to: PushMode) -> StreamingQuery: # Validation occurs at the fs.write_to_online_store() phase against the stream feature view schema. def batch_write(row: DataFrame, batch_id: int): rows: pd.DataFrame = row.toPandas() diff --git a/sdk/python/feast/infra/offline_stores/redshift.py b/sdk/python/feast/infra/offline_stores/redshift.py index 837cf49655d..6034bf5ac7b 100644 --- a/sdk/python/feast/infra/offline_stores/redshift.py +++ b/sdk/python/feast/infra/offline_stores/redshift.py @@ -51,13 +51,13 @@ class RedshiftOfflineStoreConfig(FeastConfigBaseModel): type: Literal["redshift"] = "redshift" """ Offline store type selector""" - cluster_id: Optional[StrictStr] + cluster_id: Optional[StrictStr] = None """ Redshift cluster identifier, for provisioned clusters """ - user: Optional[StrictStr] + user: Optional[StrictStr] = None """ Redshift user name, only required for provisioned clusters """ - workgroup: Optional[StrictStr] + workgroup: Optional[StrictStr] = None """ Redshift workgroup identifier, for serverless """ region: StrictStr diff --git a/sdk/python/feast/infra/offline_stores/snowflake_source.py b/sdk/python/feast/infra/offline_stores/snowflake_source.py index 0cbf82dd1c1..b4500b9807e 100644 --- a/sdk/python/feast/infra/offline_stores/snowflake_source.py +++ b/sdk/python/feast/infra/offline_stores/snowflake_source.py @@ -1,5 +1,5 @@ import warnings -from typing import Callable, Dict, Iterable, Optional, Tuple +from typing import Callable, Dict, Iterable, Optional, Tuple, Any from typeguard import typechecked @@ -223,7 +223,7 @@ def get_table_column_names_and_types( query = f"SELECT * FROM {self.get_table_query_string()} LIMIT 5" cursor = execute_snowflake_statement(conn, query) - metadata = [ + metadata: list[dict[str, Any]] = [ { "column_name": column.name, "type_code": column.type_code, diff --git a/sdk/python/feast/infra/registry/snowflake.py b/sdk/python/feast/infra/registry/snowflake.py index 56c7bc1f659..c1ebf13d6b8 100644 --- a/sdk/python/feast/infra/registry/snowflake.py +++ b/sdk/python/feast/infra/registry/snowflake.py @@ -418,7 +418,7 @@ def _delete_object( """ cursor = execute_snowflake_statement(conn, query) - if cursor.rowcount < 1 and not_found_exception: + if cursor.rowcount < 1 and not_found_exception: # type: ignore raise not_found_exception(name, project) self._set_last_updated_metadata(datetime.utcnow(), project) diff --git a/sdk/python/feast/infra/utils/snowflake/snowflake_utils.py b/sdk/python/feast/infra/utils/snowflake/snowflake_utils.py index a4cda89a6f6..36588efd0fa 100644 --- a/sdk/python/feast/infra/utils/snowflake/snowflake_utils.py +++ b/sdk/python/feast/infra/utils/snowflake/snowflake_utils.py @@ -22,6 +22,8 @@ import feast from feast.errors import SnowflakeIncompleteConfig, SnowflakeQueryUnknownError from feast.feature_view import FeatureView +from feast.infra.offline_stores.snowflake import SnowflakeOfflineStoreConfig +from feast.infra.online_stores.snowflake import SnowflakeOnlineStoreConfig from feast.repo_config import RepoConfig try: @@ -43,7 +45,7 @@ class GetSnowflakeConnection: - def __init__(self, config: str, autocommit=True): + def __init__(self, config: SnowflakeOfflineStoreConfig | SnowflakeOnlineStoreConfig, autocommit=True): self.config = config self.autocommit = autocommit