Merge pull request #9 from goodwillpunning/addProjectPackaging
Update project packaging
  • Loading branch information
goodwillpunning authored Mar 4, 2023
2 parents ec07128 + 69b2509 commit fdc4700
Showing 7 changed files with 57 additions and 29 deletions.
Binary file modified dist/beaker-0.0.1-py3-none-any.whl
Binary file modified dist/beaker-0.0.1.tar.gz
5 changes: 5 additions & 0 deletions pyproject.toml
@@ -16,6 +16,11 @@ classifiers = [
"License :: Other/Proprietary License",
"Operating System :: OS Independent",
]
dependencies = [
"requests",
"databricks-sql-connector",
"pyspark"
]

[project.urls]
"Homepage" = "https://github.com/goodwillpunning/beaker"
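With the runtime dependencies now declared in pyproject.toml, installing the built wheel (or sdist) should pull in requests, databricks-sql-connector, and pyspark automatically. A minimal post-install check, offered here as a sketch rather than as part of this commit:

from importlib.metadata import version, PackageNotFoundError

# Confirm the three declared dependencies resolved at install time.
for dist in ("requests", "databricks-sql-connector", "pyspark"):
    try:
        print(dist, version(dist))
    except PackageNotFoundError:
        print(dist, "is NOT installed")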
1 change: 1 addition & 0 deletions src/beaker/__init__.py
@@ -0,0 +1 @@
from .benchmark import Benchmark
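Re-exporting Benchmark at the package root shortens the import path for callers; illustratively:

# Both forms resolve to the same class; the first is enabled by the line above.
from beaker import Benchmark
# from beaker.benchmark import Benchmark  (still works, just longer)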
54 changes: 33 additions & 21 deletions src/beaker/benchmark.py
@@ -1,26 +1,33 @@
import time
import re
import requests
from functools import reduce
from pyspark.sql import DataFrame
from concurrent.futures import ThreadPoolExecutor
from beaker.sqlwarehouseutils import SQLWarehouseUtils
from beaker.spark_fixture import get_spark_session


class Benchmark:
"""Encapsulates a query benchmark test."""
def __init__(self, query=None, query_file=None, query_file_dir=None, concurrency=1, db_hostname=None, warehouse_http_path=None, token=None, catalog='hive_metastore', new_warehouse_config=None, results_cache_enabled=False):
self.query=query
self.query_file=query_file
self.query_file_dir=query_file_dir
self.concurrency=concurrency
self.hostname=self.setHostname(db_hostname)
self.http_path=warehouse_http_path
self.token=token
self.catalog=catalog
self.new_warehouse_config=new_warehouse_config
self.results_cache_enabled=results_cache_enabled

def __init__(self, query=None, query_file=None, query_file_dir=None, concurrency=1, db_hostname=None,
warehouse_http_path=None, token=None, catalog='hive_metastore', new_warehouse_config=None,
results_cache_enabled=False):
self.query = query
self.query_file = query_file
self.query_file_dir = query_file_dir
self.concurrency = concurrency
self.hostname = self.setHostname(db_hostname)
self.http_path = warehouse_http_path
self.token = token
self.catalog = catalog
self.new_warehouse_config = new_warehouse_config
self.results_cache_enabled = results_cache_enabled
# Check if a new SQL warehouse needs to be created
if new_warehouse_config is not None:
self.setWarehouseConfig(new_warehouse_config)
self.spark = get_spark_session()

    def _get_user_id(self):
        """Helper method for filtering query history by the current user's Id"""
@@ -54,16 +61,17 @@ def setWarehouseConfig(self, config):
    def setWarehouse(self, http_path):
        """Sets the SQL Warehouse http path to use for the benchmark."""
        assert self._validate_warehouse(id), "Invalid HTTP path for SQL Warehouse."
        self.http_path=http_path
        self.http_path = http_path

    def setConcurrency(self, concurrency):
        """Sets the query execution parallelism."""
        self.concurrency=concurrency
        self.concurrency = concurrency

    def setHostname(self, hostname):
        """Sets the Databricks workspace hostname."""
        hostname_clean = hostname.strip().replace("http://", "").replace("https://", "").replace("/", "") if hostname is not None else hostname
        self.hostname=hostname_clean
        hostname_clean = hostname.strip().replace("http://", "").replace("https://", "")\
            .replace("/", "") if hostname is not None else hostname
        self.hostname = hostname_clean

    def setWarehouseToken(self, token):
        """Sets the API token for communicating with the SQL warehouse."""
@@ -102,9 +110,10 @@ def _execute_single_query(self, query, id=None):
        sql_warehouse = SQLWarehouseUtils(self.hostname, self.http_path, self.token)
        sql_warehouse.execute_query(query)
        end_time = time.perf_counter()
        elapsed_time = f"{end_time-start_time:0.3f}"
        elapsed_time = f"{end_time - start_time:0.3f}"
        metrics = [(id, self.hostname, self.http_path, self.concurrency, query, elapsed_time)]
        metrics_df = spark.createDataFrame(metrics, "id string, hostname string, warehouse string, concurrency int, query_text string, query_duration_secs string")
        metrics_df = self.spark.createDataFrame(metrics,
            "id string, hostname string, warehouse string, concurrency int, query_text string, query_duration_secs string")
        return metrics_df

    def _set_default_catalog(self):
@@ -141,7 +150,7 @@ def _execute_queries_from_file(self, query_file):

        # Load queries from SQL file
        print(f"Loading queries from file: '{query_file}'")
        query_file_cleaned = query_file.replace("dbfs:/", "/dbfs/") # Replace `dbfs:` paths
        query_file_cleaned = query_file.replace("dbfs:/", "/dbfs/")  # Replace `dbfs:` paths

        # Parse the raw SQL, splitting lines into a query identifier (header) and query string
        with open(query_file_cleaned) as f:
@@ -165,14 +174,16 @@ def _execute_queries_from_file(self, query_file):
            num_threads = len(queries_in_bucket)
            with ThreadPoolExecutor(max_workers=num_threads) as executor:
                # Maps the method '_execute_single_query' with a list of queries.
                metrics_list = list(executor.map(lambda query_with_header: self._execute_single_query(*query_with_header), query_bucket))
                metrics_list = list(
                    executor.map(lambda query_with_header: self._execute_single_query(*query_with_header),
                                 query_bucket))
            final_metrics_result = final_metrics_result + metrics_list

        # Union together the metrics DFs
        if len(final_metrics_result) > 0:
            final_metrics_df = reduce(DataFrame.unionAll, final_metrics_result)
        else:
            final_metrics_df = spark.sparkContext.emptyRDD()
            final_metrics_df = self.spark.sparkContext.emptyRDD()
        return final_metrics_df
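# Illustrative aside, not part of benchmark.py or of this diff: the fan-out above is a
# plain ThreadPoolExecutor.map over (header, query) tuples. A self-contained sketch of
# the same pattern, with run_one standing in for _execute_single_query:
from concurrent.futures import ThreadPoolExecutor

def run_one(header, query_text):
    # Stand-in worker: returns the query header and a dummy "metric".
    return (header, len(query_text))

bucket = [("q1", "SELECT 1"), ("q2", "SELECT 2")]
with ThreadPoolExecutor(max_workers=len(bucket)) as executor:
    # Same shape as the call above: unpack each (header, query) tuple into the worker.
    results = list(executor.map(lambda pair: run_one(*pair), bucket))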

    def execute(self):
@@ -190,7 +201,8 @@ def execute(self):
        if self.query_file_dir is not None:
            print("Loading query files from directory.")
            # TODO: Implement query directory parsing
            #metrics_df = self._execute_queries_from_dir(self.query_file_dir)
            # metrics_df = self._execute_queries_from_dir(self.query_file_dir)
            metrics_df = self.spark.sparkContext.emptyRDD
        elif self.query_file is not None:
            print("Loading query file.")
            metrics_df = self._execute_queries_from_file(self.query_file)
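With the reworked constructor and the Spark fixture, a benchmark run can be driven from a few lines of user code. A hypothetical usage sketch follows; the hostname, HTTP path, token, and query-file path are placeholders, and execute() returning the metrics DataFrame is assumed from the method bodies shown above:

from beaker import Benchmark

bench = Benchmark(
    query_file="dbfs:/tmp/queries.sql",          # placeholder DBFS path
    concurrency=4,                               # parallel query streams per bucket
    db_hostname="https://example.cloud.databricks.com/",  # scheme and slashes stripped by setHostname
    warehouse_http_path="/sql/1.0/warehouses/0123456789abcdef",  # placeholder
    token="dapi-REDACTED",                       # placeholder personal access token
)
metrics_df = bench.execute()  # assumed to return the per-query timing DataFrame

Note that the constructor itself calls get_spark_session(), so running this locally assumes ENV=LOCAL is set, as handled in spark_fixture.py below.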
8 changes: 0 additions & 8 deletions src/beaker/local_spark_singleton.py

This file was deleted.

18 changes: 18 additions & 0 deletions src/beaker/spark_fixture.py
@@ -0,0 +1,18 @@
from pyspark.sql import SparkSession
from functools import lru_cache
import os


@lru_cache(maxsize=None)
def get_spark_session():
    if os.getenv("ENV") == "LOCAL":
        return SparkSession.builder \
            .master("local") \
            .appName("beaker") \
            .config("spark.sql.shuffle.partitions", "1") \
            .config("spark.driver.host", "localhost") \
            .getOrCreate()
    else:
        return SparkSession.builder \
            .appName("beaker") \
            .getOrCreate()
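Because get_spark_session() is wrapped in lru_cache, repeated calls hand back the same memoized SparkSession, and the local-versus-cluster decision above is made only once per process. A small illustration; setting ENV here is an assumption and must happen before the first call:

import os
os.environ["ENV"] = "LOCAL"  # force the local branch above

from beaker.spark_fixture import get_spark_session

spark_a = get_spark_session()
spark_b = get_spark_session()
assert spark_a is spark_b            # lru_cache returns the memoized session
print(spark_a.sparkContext.master)   # "local" in the ENV=LOCAL branch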
