Fuzzy dedup modifications (#687)
* Driver broadcast of large spark config variables

Signed-off-by: Constantin M Adam <[email protected]>

* Launch k8s spark cluster using spark native k8s API

Signed-off-by: Constantin M Adam <[email protected]>

* Add PyYAML to the base image dependencies

Signed-off-by: Constantin M Adam <[email protected]>

* Updated documentation for get_bcast_params() method

Signed-off-by: Constantin M Adam <[email protected]>

* Updated documentation

Signed-off-by: Constantin M Adam <[email protected]>

---------

Signed-off-by: Constantin M Adam <[email protected]>
cmadam authored Oct 11, 2024
1 parent fa76928 commit 2869eea
Showing 5 changed files with 95 additions and 16 deletions.
8 changes: 5 additions & 3 deletions data-processing-lib/doc/spark-runtime.md
@@ -41,9 +41,11 @@ of this parameter:

## Transforms

* [SparkTransformRuntimeConfiguration](../spark/src/data_processing_spark/transform/runtime_configuration.py) allows
to configure transform to use PySpark

* [SparkTransformRuntimeConfiguration](../spark/src/data_processing_spark/runtime/spark/runtime_configuration.py)
allows configuring a transform to use PySpark. In addition to the features of its base class
[TransformRuntimeConfiguration](../python/src/data_processing/runtime/runtime_configuration.py), this class includes
a `get_bcast_params()` method for retrieving very large configuration settings. Before starting the transform
execution, the Spark runtime broadcasts these settings to all the workers (see the sketch below).
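
As a rough sketch (not part of the toolkit), a fuzzy-dedup-style configuration could override `get_bcast_params()` along these lines; the class name, file path, and the `get_file()` call are illustrative assumptions:

```python
from typing import Any

from data_processing.data_access import DataAccessFactoryBase
from data_processing_spark.runtime.spark import SparkTransformRuntimeConfiguration


class FuzzyDedupSparkRuntimeConfiguration(SparkTransformRuntimeConfiguration):
    """Illustrative configuration that broadcasts a large list of document IDs."""

    def get_bcast_params(self, data_access_factory: DataAccessFactoryBase) -> dict[str, Any]:
        # Download the (potentially very large) removal list once on the driver;
        # the Spark runtime then broadcasts the returned dictionary to every worker.
        data_access = data_access_factory.create_data_access()
        # get_file() is assumed here; use whichever DataAccess read method fits your storage.
        doc_ids_bytes, _ = data_access.get_file("duplicates/doc_ids_to_remove.parquet")
        return {"doc_ids_to_remove": doc_ids_bytes}
```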

## Runtime

3 changes: 2 additions & 1 deletion data-processing-lib/spark/pyproject.toml
@@ -13,7 +13,8 @@ authors = [
dependencies = [
"data-prep-toolkit==0.2.2.dev0",
"pyspark>=3.5.2",
"psutil>=6.0.0"
"psutil>=6.0.0",
"PyYAML>=6.0.2"
]

[project_urls]
@@ -10,6 +10,9 @@
# limitations under the License.
################################################################################

from typing import Any

from data_processing.data_access import DataAccessFactoryBase
from data_processing.runtime import TransformRuntimeConfiguration
from data_processing.transform import TransformConfiguration
from data_processing_spark.runtime.spark import DefaultSparkTransformRuntime
@@ -29,6 +32,16 @@ def __init__(
super().__init__(transform_config=transform_config)
self.runtime_class = runtime_class

def get_bcast_params(self, data_access_factory: DataAccessFactoryBase) -> dict[str, Any]:
"""Allows retrieving and broadcasting to all the workers very large
configuration parameters, like the list of document IDs to remove for
fuzzy dedup, or the list of blocked web domains for block listing. This
function is called by the spark runtime after spark initialization, and
before spark_context.parallelize()
:param data_access_factory - creates data_access object to download the large config parameter
"""
return {}

def create_transform_runtime(self) -> DefaultSparkTransformRuntime:
"""
Create transform runtime with the parameters captured during apply_input_params()
@@ -10,24 +10,69 @@
# limitations under the License.
################################################################################

import os
import socket
import time
import traceback
from datetime import datetime

import yaml
from data_processing.data_access import DataAccessFactoryBase
from data_processing.transform import TransformStatistics
from data_processing.utils import GB, get_logger
from data_processing_spark.runtime.spark import (
SparkTransformExecutionConfiguration,
SparkTransformFileProcessor,
SparkTransformRuntimeConfiguration,
SparkTransformExecutionConfiguration,
)
from pyspark import SparkConf, SparkContext
from pyspark.sql import SparkSession


logger = get_logger(__name__)


def _init_spark(runtime_config: SparkTransformRuntimeConfiguration) -> SparkSession:
server_port_https = int(os.getenv("KUBERNETES_SERVICE_PORT_HTTPS", "-1"))
if server_port_https == -1:
# running locally
spark_config = {"spark.driver.host": "127.0.0.1"}
return SparkSession.builder.appName(runtime_config.get_name()).config(map=spark_config).getOrCreate()
else:
# running in Kubernetes, use spark_profile.yml and
# environment variables for configuration
server_port = os.environ["KUBERNETES_SERVICE_PORT"]
master_url = f"k8s://https://kubernetes.default:{server_port}"

# Read Spark configuration profile
config_filepath = os.path.abspath(
os.path.join(os.getenv("SPARK_HOME"), "work-dir", "config", "spark_profile.yml")
)
with open(config_filepath, "r") as config_fp:
spark_config = yaml.safe_load(os.path.expandvars(config_fp.read()))
spark_config["spark.submit.deployMode"] = "client"

# configure the executor pods from template
executor_pod_template_file = os.path.join(
os.getenv("SPARK_HOME"),
"work-dir",
"src",
"templates",
"spark-executor-pod-template.yml",
)
spark_config["spark.kubernetes.executor.podTemplateFile"] = executor_pod_template_file
spark_config["spark.kubernetes.container.image.pullPolicy"] = "Always"

# Pass the driver IP address to the workers for callback
myservice_url = socket.gethostbyname(socket.gethostname())
spark_config["spark.driver.host"] = myservice_url
spark_config["spark.driver.bindAddress"] = "0.0.0.0"
spark_config["spark.decommission.enabled"] = True
logger.info(f"Launching Spark Session with configuration\n" f"{yaml.dump(spark_config, indent=2)}")
app_name = spark_config.get("spark.app.name", "my-spark-app")
return SparkSession.builder.master(master_url).appName(app_name).config(map=spark_config).getOrCreate()
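
For reference, a minimal sketch of the profile expansion that `_init_spark()` performs in the Kubernetes branch; the file contents and environment variable names below are illustrative, not prescribed by the toolkit:

```python
import os

import yaml

# Illustrative contents; the real spark_profile.yml lives under
# $SPARK_HOME/work-dir/config and its keys depend on the deployment.
profile_text = """
spark.app.name: ${SPARK_APP_NAME}
spark.executor.instances: 4
spark.executor.memory: 8g
spark.kubernetes.namespace: ${POD_NAMESPACE}
"""

os.environ.setdefault("SPARK_APP_NAME", "fuzzy-dedup")
os.environ.setdefault("POD_NAMESPACE", "data-prep")

# os.path.expandvars() substitutes the ${VAR} references before YAML parsing,
# so the resulting dict can be handed to SparkSession.builder.config(map=...).
spark_config = yaml.safe_load(os.path.expandvars(profile_text))
print(spark_config["spark.app.name"])  # -> fuzzy-dedup
```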


def orchestrate(
runtime_config: SparkTransformRuntimeConfiguration,
execution_configuration: SparkTransformExecutionConfiguration,
@@ -45,14 +90,17 @@
logger.info(f"orchestrator started at {start_ts}")
# create data access
data_access = data_access_factory.create_data_access()
bcast_params = runtime_config.get_bcast_params(data_access_factory)
if data_access is None:
logger.error("No DataAccess instance provided - exiting")
return 1
# initialize Spark
conf = SparkConf().setAppName(runtime_config.get_name()).set("spark.driver.host", "127.0.0.1")
sc = SparkContext(conf=conf)
spark_session = _init_spark(runtime_config)
sc = spark_session.sparkContext
# broadcast
spark_runtime_config = sc.broadcast(runtime_config)
daf = sc.broadcast(data_access_factory)
spark_bcast_params = sc.broadcast(bcast_params)

def process_partition(iterator):
"""
@@ -63,6 +111,7 @@ def process_partition(iterator):
# local statistics dictionary
statistics = TransformStatistics()
# create transformer runtime
bcast_params = spark_bcast_params.value
d_access_factory = daf.value
runtime_conf = spark_runtime_config.value
runtime = runtime_conf.create_transform_runtime()
@@ -77,8 +126,11 @@ def process_partition(iterator):
logger.debug(f"partition {f}")
# add additional parameters
transform_params = (
runtime.get_transform_config(partition=int(f[1]), data_access_factory=d_access_factory,
statistics=statistics))
runtime.get_transform_config(
partition=int(f[1]), data_access_factory=d_access_factory, statistics=statistics
)
| bcast_params
)
# create transform with partition number
file_processor.create_transform(transform_params)
first = False
@@ -128,7 +180,7 @@ def process_partition(iterator):
memory = 0.0
for i in range(executors.size()):
memory += executors.toList().apply(i)._2()._1()
resources = {"cpus": cpus, "gpus": 0, "memory": round(memory/GB, 2), "object_store": 0}
resources = {"cpus": cpus, "gpus": 0, "memory": round(memory / GB, 2), "object_store": 0}
input_params = runtime_config.get_transform_metadata() | execution_configuration.get_input_params()
metadata = {
"pipeline": execution_configuration.pipeline_id,
@@ -143,7 +195,8 @@ def process_partition(iterator):
"execution_stats": {
"num partitions": num_partitions,
"execution time, min": round((time.time() - start_time) / 60, 3),
} | resources,
}
| resources,
"job_output_stats": stats,
}
logger.debug(f"Saving job metadata: {metadata}.")
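Downstream, a transform sees the broadcast values as ordinary keys in its configuration dictionary, because `process_partition()` merges `bcast_params` into the per-partition transform parameters. A minimal sketch, assuming the toolkit's `AbstractTableTransform` base class and its table-in/table-out `transform()` signature; the transform name, column name, and `doc_ids_to_remove` key are illustrative:

```python
from typing import Any

import pyarrow as pa
from data_processing.transform import AbstractTableTransform


class FuzzyDedupFilterTransform(AbstractTableTransform):
    """Illustrative transform that consumes a broadcast document-ID removal list."""

    def __init__(self, config: dict[str, Any]):
        super().__init__(config)
        # The broadcast dictionary was merged into the transform parameters,
        # so its entries are available as plain config keys on every worker.
        self.doc_ids_to_remove = set(config.get("doc_ids_to_remove", []))

    def transform(self, table: pa.Table, file_name: str = None) -> tuple[list[pa.Table], dict[str, Any]]:
        # Drop rows whose document id appears in the broadcast removal list.
        keep = [doc_id not in self.doc_ids_to_remove for doc_id in table.column("document_id").to_pylist()]
        filtered = table.filter(pa.array(keep))
        return [filtered], {"documents_removed": table.num_rows - filtered.num_rows}
```
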
@@ -34,19 +34,29 @@ def get_transform_config(
"""
Get the dictionary of configuration that will be provided to the transform's initializer.
This is the opportunity for this runtime to create a new set of configuration based on the
config/params provided to this instance's initializer. This may include the addition
of new configuration data such as ray shared memory, new actors, etc, that might be needed and
expected by the transform in its initializer and/or transform() methods.
config/params provided to this instance's initializer.
:param partition - the partition assigned to this worker, needed by transforms like doc_id
:param data_access_factory - data access factory class being used by the Spark orchestrator.
:param statistics - reference to statistics actor
:return: dictionary of transform init params
"""
return self.params

def get_bcast_params(self, data_access_factory: DataAccessFactoryBase) -> dict[str, Any]:
"""Allows retrieving and broadcasting to all the workers very large
configuration parameters, like the list of document IDs to remove for
fuzzy dedup, or the list of blocked web domains for block listing. This
function is called by the spark runtime after spark initialization, and
before spark_context.parallelize()
:param data_access_factory - creates data_access object to download the large config parameter
"""
return {}

def compute_execution_stats(self, stats: TransformStatistics) -> None:
"""
Update/augment the given statistics object with runtime-specific additions/modifications.
This method does not return a value; the job execution statistics are generally reported
as metadata by the Spark Orchestrator.
:param stats: output of statistics as aggregated across all calls to all transforms.
:return: job execution statistics. These are generally reported as metadata by the Ray Orchestrator.
"""
pass
pass
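
For comparison with the default implementation above, a runtime can also override `get_transform_config()` to pass per-partition information to its transform. A hedged sketch (the class name and the `partition_index` key are illustrative, not part of the toolkit):

```python
from typing import Any

from data_processing.data_access import DataAccessFactoryBase
from data_processing.transform import TransformStatistics
from data_processing_spark.runtime.spark import DefaultSparkTransformRuntime


class DocIDSparkTransformRuntime(DefaultSparkTransformRuntime):
    """Illustrative runtime that forwards the partition number to its transform,
    e.g. so document IDs can be generated without collisions across workers."""

    def get_transform_config(
        self, partition: int, data_access_factory: DataAccessFactoryBase, statistics: TransformStatistics
    ) -> dict[str, Any]:
        # Augment the configured parameters with the partition index; a transform can
        # combine it with a per-row counter to build globally unique identifiers.
        return self.params | {"partition_index": partition}
```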
