Commit
Merge pull request #2 from goodwillpunning/initialSourceCode
Add initial source code
Showing 6 changed files with 364 additions and 0 deletions.
@@ -1 +1,5 @@
# PyCharm IDE
.idea

# Just Apple stuff.
.DS_Store
Empty file.
@@ -0,0 +1,226 @@
import time
import re
import requests
from functools import reduce
from pyspark.sql import DataFrame
from concurrent.futures import ThreadPoolExecutor

# NOTE: this module also expects a `SQLWarehouseUtils` class and a `spark`
# SparkSession to be available at runtime (e.g. from the sibling modules added
# in this commit, or the Databricks notebook globals).


class Benchmark:
    """Encapsulates a query benchmark test."""

    def __init__(self, query=None, query_file=None, query_file_dir=None, concurrency=1,
                 db_hostname=None, warehouse_http_path=None, token=None, catalog='hive_metastore',
                 new_warehouse_config=None, results_cache_enabled=False):
        self.query = query
        self.query_file = query_file
        self.query_file_dir = query_file_dir
        self.concurrency = concurrency
        # setHostname() stores the cleaned hostname on self.hostname; assigning its
        # return value (None) would overwrite it, so it is simply called here.
        self.setHostname(db_hostname)
        self.http_path = warehouse_http_path
        self.token = token
        self.catalog = catalog
        self.new_warehouse_config = new_warehouse_config
        self.results_cache_enabled = results_cache_enabled
        # Check if a new SQL warehouse needs to be created
        if new_warehouse_config is not None:
            self.setWarehouseConfig(new_warehouse_config)

    def _get_user_id(self):
        """Helper method for filtering query history by the current user's Id."""
        response = requests.get(
            f"https://{self.hostname}/api/2.0/preview/scim/v2/Me",
            headers={
                "Authorization": f"Bearer {self.token}"
            }
        )
        return response.json()['id']

    def _validate_warehouse(self, http_path):
        """Validates the SQL warehouse HTTP path."""
        return True

    def _launch_new_warehouse(self):
        """Launches a new SQL warehouse."""
        warehouse_utils = SQLWarehouseUtils()
        warehouse_utils.setToken(token=self.token)
        warehouse_utils.setHostname(hostname=self.hostname)
        return warehouse_utils.launch_warehouse(self.new_warehouse_config)

    def setWarehouseConfig(self, config):
        """Launches a new cluster/warehouse from a JSON config."""
        self.new_warehouse_config = config
        print(f"Creating new warehouse with config: {config}")
        warehouse_id = self._launch_new_warehouse()
        print(f"The warehouse Id is: {warehouse_id}")
        self.http_path = f"/sql/1.0/warehouses/{warehouse_id}"

    def setWarehouse(self, http_path):
        """Sets the SQL warehouse HTTP path to use for the benchmark."""
        assert self._validate_warehouse(http_path), "Invalid HTTP path for SQL Warehouse."
        self.http_path = http_path

    def setConcurrency(self, concurrency):
        """Sets the query execution parallelism."""
        self.concurrency = concurrency

    def setHostname(self, hostname):
        """Sets the Databricks workspace hostname."""
        hostname_clean = (hostname.strip().replace("http://", "").replace("https://", "").replace("/", "")
                          if hostname is not None else hostname)
        self.hostname = hostname_clean

    def setWarehouseToken(self, token):
        """Sets the API token for communicating with the SQL warehouse."""
        self.token = token

    def setCatalog(self, catalog):
        """Sets the target catalog to execute queries."""
        self.catalog = catalog

    def setQuery(self, query):
        """Sets a single query to execute."""
        self.query = query

    def _validateQueryFile(self, query_file):
        """Validates the query file."""
        return True

    def setQueryFile(self, query_file):
        """Sets the query file to use."""
        assert self._validateQueryFile(query_file), "Invalid query file."
        self.query_file = query_file

    def _validateQueryFileDir(self, query_file_dir):
        """Validates the query file directory."""
        return True

    def setQueryFileDir(self, query_file_dir):
        """Sets the directory to load query files from."""
        assert self._validateQueryFileDir(query_file_dir), "Invalid query file directory."
        self.query_file_dir = query_file_dir

    def _execute_single_query(self, query, id=None):
        query = query.strip()
        print(query)
        start_time = time.perf_counter()
        sql_warehouse = SQLWarehouseUtils(self.hostname, self.http_path, self.token)
        sql_warehouse.execute_query(query)
        end_time = time.perf_counter()
        elapsed_time = f"{end_time - start_time:0.3f}"
        metrics = [(id, self.hostname, self.http_path, self.concurrency, query, elapsed_time)]
        metrics_df = spark.createDataFrame(
            metrics,
            "id string, hostname string, warehouse string, concurrency int, query_text string, query_duration_secs string")
        return metrics_df

    def _set_default_catalog(self):
        query = f"USE CATALOG {self.catalog}"
        self._execute_single_query(query)

    def _set_results_caching(self):
        """Enables/disables results caching."""
        if not self.results_cache_enabled:
            query = "SET use_cached_result=false"
        else:
            query = "SET use_cached_result=true"
        self._execute_single_query(query)

    def _parse_queries(self, raw_queries):
        split_raw = re.split(r"(Q\d+\n+)", raw_queries)[1:]
        split_clean = list(map(str.strip, split_raw))
        headers = split_clean[::2]
        queries = split_clean[1::2]
        return headers, queries

    def _get_concurrent_queries(self, headers_list, queries_list, max_concurrency):
        """Slices headers and queries into equal bins."""
        for i in range(0, len(queries_list), max_concurrency):
            headers = headers_list[i:(i + max_concurrency)]
            queries = queries_list[i:(i + max_concurrency)]
            yield list(zip(queries, headers))

    def _execute_queries_from_file(self, query_file):
        """Parses a file containing a list of queries to execute on a SQL warehouse."""
        # Keep a list of unique query Ids/headers and query strings
        headers = []
        queries = []

        # Load queries from SQL file
        print(f"Loading queries from file: '{query_file}'")
        query_file_cleaned = query_file.replace("dbfs:/", "/dbfs/")  # Replace `dbfs:` paths

        # Parse the raw SQL, splitting lines into a query identifier (header) and query string
        with open(query_file_cleaned) as f:
            raw_queries = f.read()
            file_headers, file_queries = self._parse_queries(raw_queries)
            headers = headers + file_headers
            queries = queries + file_queries

        # Split the list of queries into buckets determined by specified concurrency
        bucketed_queries_list = list(self._get_concurrent_queries(headers, queries, self.concurrency))
        print(f"There are {len(queries)} total queries.")
        print(f"The concurrency is {self.concurrency}")
        print(f"There are {len(bucketed_queries_list)} buckets of queries")

        # Take each bucket of queries and execute concurrently
        final_metrics_result = []
        for query_bucket in bucketed_queries_list:
            print(f'Executing {len(query_bucket)} queries concurrently.')
            # Multi-thread query execution: one worker per query in the bucket
            num_threads = len(query_bucket)
            with ThreadPoolExecutor(max_workers=num_threads) as executor:
                # Maps the method '_execute_single_query' over the (query, header) pairs
                metrics_list = list(executor.map(
                    lambda query_with_header: self._execute_single_query(*query_with_header),
                    query_bucket))
            final_metrics_result = final_metrics_result + metrics_list

        # Union together the metrics DFs
        if len(final_metrics_result) > 0:
            final_metrics_df = reduce(DataFrame.unionAll, final_metrics_result)
        else:
            final_metrics_df = spark.sparkContext.emptyRDD()
        return final_metrics_df

    def execute(self):
        """Executes the benchmark test."""
        print("Executing benchmark test.")
        # Set which Catalog to use
        self._set_default_catalog()
        # Enable/disable results caching on the SQL warehouse
        # https://docs.databricks.com/sql/admin/query-caching.html
        self._set_results_caching()
        # Query format precedence:
        # 1. Query File Dir
        # 2. Query File
        # 3. Single Query
        if self.query_file_dir is not None:
            print("Loading query files from directory.")
            # TODO: Implement query directory parsing
            # metrics_df = self._execute_queries_from_dir(self.query_file_dir)
            raise NotImplementedError("Executing queries from a directory is not implemented yet.")
        elif self.query_file is not None:
            print("Loading query file.")
            metrics_df = self._execute_queries_from_file(self.query_file)
        elif self.query is not None:
            print("Executing single query.")
            metrics_df = self._execute_single_query(self.query)
        else:
            raise ValueError("No query specified.")
        return metrics_df

    def preWarmTables(self, tables):
        """Delta caches the tables before running a benchmark test."""
        assert self.http_path is not None, \
            "No running warehouse. You can launch a new warehouse by calling `.setWarehouseConfig()`."
        assert self.catalog is not None, \
            "No catalog provided. You can add a catalog by calling `.setCatalog()`."
        self._execute_single_query(f"USE CATALOG {self.catalog}")
        for table in tables:
            print(f"Pre-warming table: {table}")
            query = f"SELECT * FROM {table}"
            self._execute_single_query(query)

    def __str__(self):
        object_str = f"""
        Benchmark Test:
        ------------------------
        catalog={self.catalog}
        query="{self.query}"
        query_file={self.query_file}
        query_file_dir={self.query_file_dir}
        concurrency={self.concurrency}
        hostname={self.hostname}
        warehouse_http_path={self.http_path}
        """
        return object_str
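
For orientation, below is a minimal usage sketch of the Benchmark class added in this commit. It is not part of the commit itself: the hostname, token, and query file path are placeholders, the query file is assumed to follow the Q1/Q2/... header format that _parse_queries splits on, and the warehouse config keys are the ones consumed by SQLWarehouseUtils.launch_warehouse.

# Usage sketch only -- all credentials and paths are placeholders.
benchmark = Benchmark(
    query_file="dbfs:/benchmarks/queries.sql",   # hypothetical file with "Q1", "Q2", ... headers
    concurrency=4,
    db_hostname="https://my-workspace.cloud.databricks.com",  # placeholder workspace URL
    token="dapiXXXXXXXXXXXX",                    # placeholder personal access token
    catalog="hive_metastore",
    new_warehouse_config={                       # keys read by SQLWarehouseUtils.launch_warehouse
        "type": "warehouse",
        "runtime": "latest",
        "size": "Small",
        "enable_photon": True,
    },
)
metrics_df = benchmark.execute()   # Spark DataFrame of per-query timings
metrics_df.show()

Each row of the returned DataFrame carries the query id, hostname, warehouse path, concurrency, query text, and duration in seconds, matching the schema string in _execute_single_query.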
@@ -0,0 +1,8 @@
from pyspark.sql import SparkSession


class SparkSingleton:

    @classmethod
    def get_instance(cls):
        return SparkSession.builder.getOrCreate()
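
This class is a thin convenience wrapper; as a sketch (not part of the commit), the `spark` handle referenced in the Benchmark class could be obtained this way outside a notebook:

spark = SparkSingleton.get_instance()  # returns the active SparkSession, creating one if necessary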
@@ -0,0 +1,123 @@
from databricks import sql
import requests


class SQLWarehouseUtils:

    _LATEST_RUNTIME = '11.3.x-photon-scala2.12'
    _CLUSTER_SIZES = ["2X-Small", "X-Small", "Small", "Medium", "Large",
                      "X-Large", "2X-Large", "3X-Large", "4X-Large"]

    def __init__(self, hostname=None, warehouse_http_path=None, token=None):
        self.hostname = hostname
        self.http_path = warehouse_http_path
        self.access_token = token

    def _get_connection(self):
        connection = sql.connect(
            server_hostname=self.hostname,
            http_path=self.http_path,
            access_token=self.access_token,
            session_configuration={"use_cached_result": "false"})
        return connection

    def execute_query(self, query_str):
        connection = self._get_connection()
        cursor = connection.cursor()
        cursor.execute(query_str)
        cursor.close()
        connection.close()

    def get_rows(self, query_str):
        connection = self._get_connection()
        cursor = connection.cursor()
        cursor.execute(query_str)
        rows = cursor.fetchall()
        cursor.close()
        connection.close()
        return rows

    def setToken(self, token):
        self.access_token = token

    def setHostname(self, hostname):
        self.hostname = hostname

    def _get_spark_runtimes(self):
        """Gets a list of the latest Spark runtimes."""
        response = requests.get(
            f"https://{self.hostname}/api/2.0/clusters/spark-versions",
            headers={
                "Authorization": f"Bearer {self.access_token}"
            }
        )
        result = list(map(lambda v: v['key'], response.json()['versions']))
        return result

    def launch_warehouse(self, config):
        """Creates a new SQL warehouse based upon a config."""
        assert self.access_token is not None, \
            "An API token is needed to launch a compute instance. Use `.setToken(token)` to add an API token."
        assert self.hostname is not None, \
            "A Databricks hostname is needed to launch a compute instance. Use `.setHostname(hostname)` to add a Databricks hostname."

        # Determine the type of compute to launch: warehouse or cluster
        if 'type' not in config:
            compute_type = 'warehouse'  # default to a SQL warehouse
        else:
            compute_type = config['type'].strip().lower()
        assert compute_type in ("warehouse", "cluster"), \
            "Invalid compute 'type' provided. Allowed types include: ['warehouse', 'cluster']."

        # Determine the Spark runtime to install
        latest_runtimes = self._get_spark_runtimes()
        if 'runtime' not in config:
            spark_version = self._LATEST_RUNTIME  # default to the latest runtime
        elif config['runtime'].strip().lower() == 'latest':
            spark_version = self._LATEST_RUNTIME  # default to the latest runtime
        else:
            spark_version = config['runtime'].strip().lower()
        assert spark_version in latest_runtimes, \
            f"Invalid Spark 'runtime'. Valid runtimes include: {latest_runtimes}"

        # Determine the size of the compute
        if 'size' not in config:
            size = 'Small'
        else:
            size = config['size'].strip()
        assert size in self._CLUSTER_SIZES, \
            f"Invalid cluster 'size'. Valid cluster 'sizes' include: {self._CLUSTER_SIZES}"

        # Determine whether Photon should be enabled
        if 'enable_photon' not in config:
            enable_photon = 'true'  # default
        else:
            enable_photon = str(config['enable_photon']).lower()

        # Determine auto-scaling
        if 'max_num_clusters' in config:
            max_num_clusters = config['max_num_clusters']
            min_num_clusters = config['min_num_clusters'] if 'min_num_clusters' in config else 1
        else:
            min_num_clusters = 1
            max_num_clusters = 1

        response = requests.post(
            f"https://{self.hostname}/api/2.0/sql/warehouses/",
            headers={
                "Authorization": f"Bearer {self.access_token}"
            },
            json={
                "name": "🧪 Beaker Benchmark Testing Warehouse",
                "cluster_size": size,
                "min_num_clusters": min_num_clusters,
                "max_num_clusters": max_num_clusters,
                "tags": {
                    "custom_tags": [
                        {
                            "key": "Description",
                            "value": "Beaker Benchmark Testing"
                        }
                    ]
                },
                "enable_photon": enable_photon,
                "channel": {
                    "name": "CHANNEL_NAME_CURRENT"
                }
            }
        )
        warehouse_id = response.json()['id']
        return warehouse_id
@@ -0,0 +1,3 @@
databricks-sql-connector
pyspark
requests