diff --git a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py
index aeb9e3cd68..4b50188632 100644
--- a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py
+++ b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py
@@ -99,6 +99,8 @@ def pull_latest_from_table_or_query(
         fields_as_string = ", ".join(fields_with_aliases)
         aliases_as_string = ", ".join(aliases)
 
+        date_partition_column = data_source.date_partition_column
+
         start_date_str = _format_datetime(start_date)
         end_date_str = _format_datetime(end_date)
         query = f"""
@@ -109,7 +111,7 @@
                 SELECT {fields_as_string},
                 ROW_NUMBER() OVER({partition_by_join_key_string} ORDER BY {timestamp_desc_string}) AS feast_row_
                 FROM {from_expression} t1
-                WHERE {timestamp_field} BETWEEN TIMESTAMP('{start_date_str}') AND TIMESTAMP('{end_date_str}')
+                WHERE {timestamp_field} BETWEEN TIMESTAMP('{start_date_str}') AND TIMESTAMP('{end_date_str}'){" AND "+date_partition_column+" >= '"+start_date.strftime('%Y-%m-%d')+"' AND "+date_partition_column+" <= '"+end_date.strftime('%Y-%m-%d')+"' " if date_partition_column != "" and date_partition_column is not None else ''}
             ) t2
             WHERE feast_row_ = 1
             """
@@ -641,8 +643,15 @@ def _cast_data_frame(
     {% endfor %}
     FROM {{ featureview.table_subquery }}
     WHERE {{ featureview.timestamp_field }} <= '{{ featureview.max_event_timestamp }}'
+    {% if featureview.date_partition_column != "" and featureview.date_partition_column is not none %}
+    AND {{ featureview.date_partition_column }} <= '{{ featureview.max_event_timestamp[:10] }}'
+    {% endif %}
+
     {% if featureview.ttl == 0 %}{% else %}
     AND {{ featureview.timestamp_field }} >= '{{ featureview.min_event_timestamp }}'
+    {% if featureview.date_partition_column != "" and featureview.date_partition_column is not none %}
+    AND {{ featureview.date_partition_column }} >= '{{ featureview.min_event_timestamp[:10] }}'
+    {% endif %}
     {% endif %}
 ),
 
diff --git a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark_source.py b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark_source.py
index 209e3b87e8..7ad331239f 100644
--- a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark_source.py
+++ b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark_source.py
@@ -45,6 +45,7 @@ def __init__(
         tags: Optional[Dict[str, str]] = None,
         owner: Optional[str] = "",
         timestamp_field: Optional[str] = None,
+        date_partition_column: Optional[str] = None,
     ):
         """Creates a SparkSource object.
 
@@ -64,6 +65,8 @@
                 maintainer.
             timestamp_field: Event timestamp field used for point-in-time
                 joins of feature values.
+            date_partition_column: The column to partition the data on for faster
+                retrieval. This is useful for large tables and will limit the number of partitions scanned.
         """
         # If no name, use the table as the default name.
         if name is None and table is None:
@@ -77,6 +80,7 @@
             created_timestamp_column=created_timestamp_column,
             field_mapping=field_mapping,
             description=description,
+            date_partition_column=date_partition_column,
             tags=tags,
             owner=owner,
         )
@@ -135,6 +139,7 @@ def from_proto(data_source: DataSourceProto) -> Any:
             query=spark_options.query,
             path=spark_options.path,
             file_format=spark_options.file_format,
+            date_partition_column=data_source.date_partition_column,
             timestamp_field=data_source.timestamp_field,
             created_timestamp_column=data_source.created_timestamp_column,
             description=data_source.description,
@@ -148,6 +153,7 @@ def to_proto(self) -> DataSourceProto:
             type=DataSourceProto.BATCH_SPARK,
             data_source_class_type="feast.infra.offline_stores.contrib.spark_offline_store.spark_source.SparkSource",
             field_mapping=self.field_mapping,
+            date_partition_column=self.date_partition_column,
             spark_options=self.spark_options.to_proto(),
             description=self.description,
             tags=self.tags,
diff --git a/sdk/python/tests/unit/infra/offline_stores/contrib/spark_offline_store/test_spark.py b/sdk/python/tests/unit/infra/offline_stores/contrib/spark_offline_store/test_spark.py
index b8f8cc4247..307ba4058c 100644
--- a/sdk/python/tests/unit/infra/offline_stores/contrib/spark_offline_store/test_spark.py
+++ b/sdk/python/tests/unit/infra/offline_stores/contrib/spark_offline_store/test_spark.py
@@ -71,6 +71,68 @@ def test_pull_latest_from_table_with_nested_timestamp_or_query(mock_get_spark_se
     assert retrieval_job.query.strip() == expected_query.strip()
 
 
+@patch(
+    "feast.infra.offline_stores.contrib.spark_offline_store.spark.get_spark_session_or_start_new_with_repoconfig"
+)
+def test_pull_latest_from_table_with_nested_timestamp_or_query_and_date_partition_column_set(
+    mock_get_spark_session,
+):
+    mock_spark_session = MagicMock()
+    mock_get_spark_session.return_value = mock_spark_session
+
+    test_repo_config = RepoConfig(
+        project="test_project",
+        registry="test_registry",
+        provider="local",
+        offline_store=SparkOfflineStoreConfig(type="spark"),
+    )
+
+    test_data_source = SparkSource(
+        name="test_nested_batch_source",
+        description="test_nested_batch_source",
+        table="offline_store_database_name.offline_store_table_name",
+        timestamp_field="nested_timestamp",
+        field_mapping={
+            "event_header.event_published_datetime_utc": "nested_timestamp",
+        },
+        date_partition_column="effective_date",
+    )
+
+    # Define the parameters for the method
+    join_key_columns = ["key1", "key2"]
+    feature_name_columns = ["feature1", "feature2"]
+    timestamp_field = "event_header.event_published_datetime_utc"
+    created_timestamp_column = "created_timestamp"
+    start_date = datetime(2021, 1, 1)
+    end_date = datetime(2021, 1, 2)
+
+    # Call the method
+    retrieval_job = SparkOfflineStore.pull_latest_from_table_or_query(
+        config=test_repo_config,
+        data_source=test_data_source,
+        join_key_columns=join_key_columns,
+        feature_name_columns=feature_name_columns,
+        timestamp_field=timestamp_field,
+        created_timestamp_column=created_timestamp_column,
+        start_date=start_date,
+        end_date=end_date,
+    )
+
+    expected_query = """SELECT
+                key1, key2, feature1, feature2, nested_timestamp, created_timestamp
+                
+            FROM (
+                SELECT key1, key2, feature1, feature2, event_header.event_published_datetime_utc AS nested_timestamp, created_timestamp,
+                ROW_NUMBER() OVER(PARTITION BY key1, key2 ORDER BY event_header.event_published_datetime_utc DESC, created_timestamp DESC) AS feast_row_
+                FROM `offline_store_database_name`.`offline_store_table_name` t1
+                WHERE event_header.event_published_datetime_utc BETWEEN TIMESTAMP('2021-01-01 00:00:00.000000') AND TIMESTAMP('2021-01-02 00:00:00.000000') AND effective_date >= '2021-01-01' AND effective_date <= '2021-01-02' 
+            ) t2
+            WHERE feast_row_ = 1"""  # noqa: W293, W291
+
+    assert isinstance(retrieval_job, RetrievalJob)
+    assert retrieval_job.query.strip() == expected_query.strip()
+
+
 @patch(
     "feast.infra.offline_stores.contrib.spark_offline_store.spark.get_spark_session_or_start_new_with_repoconfig"
 )
@@ -127,3 +189,62 @@ def test_pull_latest_from_table_without_nested_timestamp_or_query(
 
     assert isinstance(retrieval_job, RetrievalJob)
     assert retrieval_job.query.strip() == expected_query.strip()
+
+
+@patch(
+    "feast.infra.offline_stores.contrib.spark_offline_store.spark.get_spark_session_or_start_new_with_repoconfig"
+)
+def test_pull_latest_from_table_without_nested_timestamp_or_query_and_date_partition_column_set(
+    mock_get_spark_session,
+):
+    mock_spark_session = MagicMock()
+    mock_get_spark_session.return_value = mock_spark_session
+
+    test_repo_config = RepoConfig(
+        project="test_project",
+        registry="test_registry",
+        provider="local",
+        offline_store=SparkOfflineStoreConfig(type="spark"),
+    )
+
+    test_data_source = SparkSource(
+        name="test_batch_source",
+        description="test_batch_source",
+        table="offline_store_database_name.offline_store_table_name",
+        timestamp_field="event_published_datetime_utc",
+        date_partition_column="effective_date",
+    )
+
+    # Define the parameters for the method
+    join_key_columns = ["key1", "key2"]
+    feature_name_columns = ["feature1", "feature2"]
+    timestamp_field = "event_published_datetime_utc"
+    created_timestamp_column = "created_timestamp"
+    start_date = datetime(2021, 1, 1)
+    end_date = datetime(2021, 1, 2)
+
+    # Call the method
+    retrieval_job = SparkOfflineStore.pull_latest_from_table_or_query(
+        config=test_repo_config,
+        data_source=test_data_source,
+        join_key_columns=join_key_columns,
+        feature_name_columns=feature_name_columns,
+        timestamp_field=timestamp_field,
+        created_timestamp_column=created_timestamp_column,
+        start_date=start_date,
+        end_date=end_date,
+    )
+
+    expected_query = """SELECT
+                key1, key2, feature1, feature2, event_published_datetime_utc, created_timestamp
+                
+            FROM (
+                SELECT key1, key2, feature1, feature2, event_published_datetime_utc, created_timestamp,
+                ROW_NUMBER() OVER(PARTITION BY key1, key2 ORDER BY event_published_datetime_utc DESC, created_timestamp DESC) AS feast_row_
+                FROM `offline_store_database_name`.`offline_store_table_name` t1
+                WHERE event_published_datetime_utc BETWEEN TIMESTAMP('2021-01-01 00:00:00.000000') AND TIMESTAMP('2021-01-02 00:00:00.000000') AND effective_date >= '2021-01-01' AND effective_date <= '2021-01-02' 
+            ) t2
+            WHERE feast_row_ = 1"""  # noqa: W293, W291
+
+    assert isinstance(retrieval_job, RetrievalJob)
+    assert retrieval_job.query.strip() == expected_query.strip()
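
Usage sketch (illustrative only, not part of the patch above; the source,
table, and column names below are hypothetical):

    from feast.infra.offline_stores.contrib.spark_offline_store.spark_source import (
        SparkSource,
    )

    # Declaring date_partition_column lets the Spark offline store append a
    # partition-pruning predicate to the queries it generates for this source.
    driver_stats = SparkSource(
        name="driver_hourly_stats",             # hypothetical source name
        table="analytics.driver_hourly_stats",  # hypothetical Hive table
        timestamp_field="event_timestamp",
        created_timestamp_column="created",
        date_partition_column="ds",             # hypothetical partition column
    )

For a retrieval window of 2021-01-01 through 2021-01-02, pull_latest_from_table_or_query
then emits a filter of the form

    WHERE event_timestamp BETWEEN TIMESTAMP('2021-01-01 00:00:00.000000')
          AND TIMESTAMP('2021-01-02 00:00:00.000000')
      AND ds >= '2021-01-01' AND ds <= '2021-01-02'

so Spark can skip partitions outside the window instead of scanning the whole
table. The point-in-time join template gains the same bounds, derived from the
min/max event timestamps truncated to dates via [:10].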