diff --git a/.gitignore b/.gitignore
index 0c09d88..b580ba1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -11,3 +11,9 @@
 # Build helpers
 *beaker.egg-info*
 build/*
+
+myvenv/
+
+*/.env
+.env
+
diff --git a/README.md b/README.md
index 70db902..9d88483 100644
--- a/README.md
+++ b/README.md
@@ -84,7 +84,7 @@ metrics = benchmark.execute()
 print(metrics)
 ```
 
-`metrics` is a list of dict. Each dict is the result of a single query execution.
+`metrics` is a pandas DataFrame; each row holds the result of a single query execution.
 
 If you want to examine the results as a spark DataFrame and your environment has the capability of creating a spark session, you can use spark_fixture.
 
diff --git a/dist/beaker-0.0.5-py3-none-any.whl b/dist/beaker-0.0.5-py3-none-any.whl
index 46f0669..d257cc1 100644
Binary files a/dist/beaker-0.0.5-py3-none-any.whl and b/dist/beaker-0.0.5-py3-none-any.whl differ
diff --git a/dist/beaker-0.0.5.tar.gz b/dist/beaker-0.0.5.tar.gz
index f1dcc86..4866de0 100644
Binary files a/dist/beaker-0.0.5.tar.gz and b/dist/beaker-0.0.5.tar.gz differ
diff --git a/dist/beaker-0.0.6-py3-none-any.whl b/dist/beaker-0.0.6-py3-none-any.whl
new file mode 100644
index 0000000..8d47ac9
Binary files /dev/null and b/dist/beaker-0.0.6-py3-none-any.whl differ
diff --git a/dist/beaker-0.0.6.tar.gz b/dist/beaker-0.0.6.tar.gz
new file mode 100644
index 0000000..8178a10
Binary files /dev/null and b/dist/beaker-0.0.6.tar.gz differ
diff --git a/examples/beaker_getting_started.dbc b/examples/beaker_getting_started.dbc
deleted file mode 100644
index 7b1f71b..0000000
Binary files a/examples/beaker_getting_started.dbc and /dev/null differ
diff --git a/examples/beaker_getting_started.py b/examples/beaker_getting_started.py
index fcb1065..30ebd2c 100644
--- a/examples/beaker_getting_started.py
+++ b/examples/beaker_getting_started.py
@@ -1,5 +1,7 @@
 # Databricks notebook source
-# MAGIC %pip install databricks-sql-connector
+# MAGIC %pip install databricks-sql-connector -q
+# MAGIC %pip install databricks-sdk -q
+# MAGIC dbutils.library.restartPython()
 
 # COMMAND ----------
 
@@ -31,6 +33,10 @@
 
 # COMMAND ----------
 
+importlib.reload(spark_fixture)
+
+# COMMAND ----------
+
 # MAGIC %md
 # MAGIC ## Create a new Benchmark Test
 
@@ -47,11 +53,14 @@
 # COMMAND ----------
 
 # Change hostname and http_path to your dbsql warehouse
-hostname = "your-dbsql-hostname"
-http_path = "your-dbsql-http-path"
+hostname = spark.conf.get('spark.databricks.workspaceUrl')
+# Extract the API token from the notebook context via dbutils
+pat = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().get()
+# Or, add the appropriate scope and key for your token if it is configured in Databricks secrets
+# pat = dbutils.secrets.get(scope="your-scope", key="your-token")
 
-# Add the appropriate scope and key for your token
-pat = dbutils.secrets.get(scope="your-scope", key="your-token")
+# Warehouse HTTP path example; replace with your own
+http_path = "/sql/1.0/warehouses/475b94ddc7cd5211"
 
 # COMMAND ----------
 
@@ -59,9 +68,9 @@
 # Use the builder pattern to add parameters for connecting to the warehouse
 bm.setName(name="simple_test")
 bm.setHostname(hostname=hostname)
+bm.setWarehouseToken(token=pat)
 bm.setWarehouse(http_path=http_path)
 bm.setConcurrency(concurrency=1)
-bm.setWarehouseToken(token=pat)
 
 # Define the query to execute and target Catalog
 query_str = """
@@ -75,11 +84,7 @@
 # COMMAND ----------
 
 # Run the benchmark!
-metrics = bm.execute()
-
-# COMMAND ----------
-
-metrics
+metrics_pdf = bm.execute()
 
 # COMMAND ----------
 
@@ -88,7 +93,7 @@
 
 # COMMAND ----------
 
-df_simple_test = spark_fixture.metrics_to_df_view(metrics, "simple_test_vw")
+df_simple_test = spark_fixture.metrics_to_df_view(metrics_pdf, "simple_test_vw")
 df_simple_test.display()
 
 # COMMAND ----------
@@ -109,10 +114,19 @@
 
 # COMMAND ----------
 
+hostname = spark.conf.get('spark.databricks.workspaceUrl')
+pat = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().get()
+query_str = """
+SELECT count(*)
+  FROM delta.`/databricks-datasets/nyctaxi/tables/nyctaxi_yellow`
+  WHERE passenger_count > 2
+"""
+
 new_warehouse_config = {
     "type": "warehouse",
     "runtime": "latest",
-    "size": "Large",
+    "size": "2X-Small",
+    "warehouse": "serverless",
     "min_num_clusters": 1,
     "max_num_clusters": 3,
     "enable_photon": True,
@@ -132,8 +146,8 @@
 # benchmark.preWarmTables(tables=["table_a", "table_b", "table_c"])
 
 # Run the benchmark!
-metrics = bm.execute()
-print(metrics)
+metrics_pdf = bm.execute()
+display(metrics_pdf)
 
 # COMMAND ----------
 
@@ -189,7 +203,14 @@
 
 # COMMAND ----------
 
-metrics = bm.execute()
-print(metrics)
+metrics_pdf = bm.execute()
+# Create a Spark dataframe from the returned metrics pandas dataframe
+metrics_df = spark_fixture.metrics_to_df_view(metrics_pdf, view_name="metrics_view")
 
 # COMMAND ----------
+
+# MAGIC %sql select * from metrics_view
+
+# COMMAND ----------
+
+
diff --git a/examples/beaker_standalone.py b/examples/beaker_standalone.py
old mode 100755
new mode 100644
index 75cf15f..29f44c0
--- a/examples/beaker_standalone.py
+++ b/examples/beaker_standalone.py
@@ -4,44 +4,63 @@
 logger = logging.getLogger()
 logger.setLevel(logging.INFO)
 
-sys.path.append("../src")
+from dotenv import load_dotenv
+load_dotenv()
 
-from beaker import benchmark
+sys.path.append("../src")
 
-bm = benchmark.Benchmark()
+from beaker import benchmark, sqlwarehouseutils
 
 hostname = os.getenv("DATABRICKS_HOST")
 http_path = os.getenv("DATABRICKS_HTTP_PATH")
 # Don't put tokens in plaintext in code
 access_token = os.getenv("DATABRICKS_ACCESS_TOKEN")
+catalog_name = os.getenv("CATALOG")
+schema_name = os.getenv("SCHEMA")
+
+bm = benchmark.Benchmark()
 
 bm.setName(name="simple_test")
 bm.setHostname(hostname=hostname)
-bm.setWarehouse(http_path=http_path)
-bm.setConcurrency(concurrency=2)
 bm.setWarehouseToken(token=access_token)
+bm.setWarehouse(http_path=http_path)
+bm.setConcurrency(concurrency=1)
 
 print("---- Specify query in code ------")
 query_str = """
 SELECT count(*)
   FROM delta.`/databricks-datasets/nyctaxi/tables/nyctaxi_yellow`
-  WHERE passenger_count > 2
+  WHERE passenger_count > 2;
 """
 bm.setQuery(query=query_str)
 bm.setCatalog(catalog="hive_metastore")
-
-metrics = bm.execute()
-print(metrics)
+bm.setSchema(schema="default")
+metrics_pdf = bm.execute()
+print(metrics_pdf)
 
 print("---- Specify a single query file ------")
 bm.query_file_format = "semicolon-delimited"
 bm.setQueryFile("queries/q1.sql")
-metrics = bm.execute()
-print(metrics)
+metrics_pdf = bm.execute()
+print(metrics_pdf)
 
-print("---- Specify a query directory ------")
+print("---- Specify a query directory, semicolon format ------")
+bm.query_file_format = "semicolon-delimited"
 bm.setQueryFileDir("queries")
-metrics = bm.execute()
-print(metrics)
+metrics_pdf = bm.execute()
+print(metrics_pdf)
+
+
+print("---- Specify a query directory, original format ------")
+bm.query_file_format = "original"
+bm.setQueryFileDir("queries_orig")
+metrics_pdf = bm.execute()
+print(metrics_pdf)
+
+
+print("---- Close connection ------") +bm.sql_warehouse.close_connection() +# res = bm.stop_warehouse("c0688d9c9c6d1091") +# print(res) \ No newline at end of file diff --git a/examples/beaker_standalone_tpch.py b/examples/beaker_standalone_tpch.py new file mode 100644 index 0000000..142a728 --- /dev/null +++ b/examples/beaker_standalone_tpch.py @@ -0,0 +1,40 @@ +import os, sys +import logging + +logger = logging.getLogger() +logger.setLevel(logging.INFO) + +from dotenv import load_dotenv +load_dotenv() + +sys.path.append("../src") + +from beaker import benchmark, sqlwarehouseutils + +hostname = os.getenv("DATABRICKS_HOST") +http_path = os.getenv("DATABRICKS_HTTP_PATH") +# Don't put tokens in plaintext in code +access_token = os.getenv("DATABRICKS_ACCESS_TOKEN") +catalog_name = "samples" +schema_name = "tpch" + +bm = benchmark.Benchmark() +bm.setName(name="simple_test") +bm.setHostname(hostname=hostname) +bm.setWarehouseToken(token=access_token) +bm.setWarehouse(http_path=http_path) +bm.setConcurrency(concurrency=1) + +print("---- Test prewarm table ------") +bm.setCatalog(catalog_name) +bm.setSchema(schema_name) +tables = ["customer", "lineitem", "nation", "orders", "part", "partsupp", "region", "supplier"] +# bm.preWarmTables(tables=tables) + +bm.query_file_format = "original" +bm.setQueryFileDir("tpch") +metrics_pdf = bm.execute() +print(metrics_pdf) + +print("---- Close connection ------") +bm.sql_warehouse.close_connection() \ No newline at end of file diff --git a/examples/getting_started.ipynb b/examples/getting_started.ipynb index 7f9973f..7aea0ac 100644 --- a/examples/getting_started.ipynb +++ b/examples/getting_started.ipynb @@ -3,8 +3,58 @@ { "cell_type": "code", "execution_count": null, - "id": "c1f35f99-4af0-4dca-923f-1d38963ba4d6", - "metadata": {}, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "419dc4ed-3e7b-4ee0-825a-186e9bce875d", + "showTitle": false, + "title": "" + } + }, + "outputs": [], + "source": [ + "%pip install -r requirements.txt -q" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "b0b8e357-baf1-4bde-afa8-e83a9569e70f", + "showTitle": false, + "title": "" + } + }, + "outputs": [], + "source": [ + "dbutils.library.restartPython()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "f61e57fe-b420-4019-b5f2-30e675245201", + "showTitle": false, + "title": "" + } + }, "outputs": [], "source": [ "import os, sys" @@ -13,8 +63,18 @@ { "cell_type": "code", "execution_count": null, - "id": "8985fd77-f702-495c-940f-92400ba472a9", - "metadata": {}, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "a2b5b664-16e8-4475-a139-f8362b598acc", + "showTitle": false, + "title": "" + } + }, "outputs": [], "source": [ "import logging\n", @@ -25,8 +85,18 @@ { "cell_type": "code", "execution_count": null, - "id": "58fb1eb8-aac6-4a36-8103-c2a15c34c6db", - "metadata": {}, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": 
"0036c4db-840d-43aa-a8f7-689af7c42a12", + "showTitle": false, + "title": "" + } + }, "outputs": [], "source": [ "sys.path.append('../src/')" @@ -35,8 +105,18 @@ { "cell_type": "code", "execution_count": null, - "id": "7e5fef46-98c2-44f3-a5ae-7b44b1b08366", - "metadata": {}, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "3582a127-4232-4dec-aa3e-1338e6a701cd", + "showTitle": false, + "title": "" + } + }, "outputs": [], "source": [ "from beaker import benchmark" @@ -45,8 +125,18 @@ { "cell_type": "code", "execution_count": null, - "id": "2e08370d-d302-4f72-a3d0-f5ebb18cf0f6", - "metadata": {}, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "0e5351e2-cc32-494f-a4a0-0e9e1c85e4ff", + "showTitle": false, + "title": "" + } + }, "outputs": [], "source": [ "bm = benchmark.Benchmark()" @@ -55,8 +145,18 @@ { "cell_type": "code", "execution_count": null, - "id": "4135ff45-95be-47e9-8b3a-24fae504e613", - "metadata": {}, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "0cf52295-d4c6-477b-98e6-d3937ea2d43c", + "showTitle": false, + "title": "" + } + }, "outputs": [], "source": [ "import importlib\n", @@ -66,8 +166,18 @@ { "cell_type": "code", "execution_count": null, - "id": "58f65a8d-e3db-490a-879c-0714d27aa37d", - "metadata": {}, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "ae05a0a6-bb77-4f4b-b162-70d4383a7db1", + "showTitle": false, + "title": "" + } + }, "outputs": [], "source": [ "hostname = os.getenv(\"DATABRICKS_HOST\")\n", @@ -80,8 +190,18 @@ { "cell_type": "code", "execution_count": null, - "id": "bd04cc75-1331-45f9-be0b-389f55bb6e35", - "metadata": {}, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "e5b84533-ed24-4ea3-88d6-01bf4277cca1", + "showTitle": false, + "title": "" + } + }, "outputs": [], "source": [ "bm.setName(name=\"simple_test\")\n", @@ -94,8 +214,18 @@ { "cell_type": "code", "execution_count": null, - "id": "881feb1d-f3e7-4506-9df9-eafdeef9115a", - "metadata": {}, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "81365f27-ef0b-432d-9440-d72916670b0e", + "showTitle": false, + "title": "" + } + }, "outputs": [], "source": [ "query_str=\"\"\"\n", @@ -110,36 +240,44 @@ { "cell_type": "code", "execution_count": null, - "id": "fd31721a-0701-4bea-a880-0caed688afb4", - "metadata": {}, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "b5e93556-19ff-4add-bc6b-78fc6b21ca6c", + "showTitle": false, + "title": "" + } + }, "outputs": [], "source": [ - "metrics = bm.execute()" + "metrics_pdf = bm.execute()" ] }, { "cell_type": "code", "execution_count": null, - "id": "2a8b5eb7-359f-4dab-a1ff-0b4e2e8bd595", "metadata": {}, "outputs": [], "source": [ - "metrics" + "display(metrics_pdf)" ] }, { "cell_type": "code", "execution_count": null, - "id": "8df2bd2c-03b5-4032-bfcb-ca880442fe95", - "metadata": {}, - "outputs": [], - "source": 
[] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "48280d72-1497-4c9a-8471-08c53eb10db7", - "metadata": {}, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": {}, + "inputWidgets": {}, + "nuid": "c2f71cd4-5b87-4e98-a6b7-1d079fbfd05c", + "showTitle": false, + "title": "" + } + }, "outputs": [], "source": [ "bm.query_file_format = \"semicolon-delimited\"" @@ -148,8 +286,15 @@ { "cell_type": "code", "execution_count": null, - "id": "e4bf5ddb-15fa-4c0f-805d-6411955a3232", - "metadata": {}, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": {}, + "inputWidgets": {}, + "nuid": "59f008bd-460d-454b-80e4-0bbec5f0c155", + "showTitle": false, + "title": "" + } + }, "outputs": [], "source": [ "bm.setQueryFile('queries/q1.sql')" @@ -158,8 +303,15 @@ { "cell_type": "code", "execution_count": null, - "id": "361b06ee-d388-48de-93b0-bd1b03ce2527", - "metadata": {}, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": {}, + "inputWidgets": {}, + "nuid": "b023c128-b6db-4664-bb18-bea1c4994ee1", + "showTitle": false, + "title": "" + } + }, "outputs": [], "source": [ "bm.execute()" @@ -168,8 +320,15 @@ { "cell_type": "code", "execution_count": null, - "id": "d4753bb4-a833-4f06-9408-6ee3a3d157e1", - "metadata": {}, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": {}, + "inputWidgets": {}, + "nuid": "150794cd-7d2e-4d38-8bf9-0ef4de40ea58", + "showTitle": false, + "title": "" + } + }, "outputs": [], "source": [ "bm.setQueryFileDir('queries')" @@ -178,36 +337,41 @@ { "cell_type": "code", "execution_count": null, - "id": "232368e3-85de-454e-ba52-3903d451c2c1", - "metadata": {}, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": {}, + "inputWidgets": {}, + "nuid": "ae66fd20-ee78-4df1-9bc7-b873cd3f08f0", + "showTitle": false, + "title": "" + } + }, "outputs": [], "source": [ - "metrics = bm.execute()" + "metrics_pdf = bm.execute()" ] }, { "cell_type": "code", "execution_count": null, - "id": "b0b203eb-992c-41e0-8038-40ce733d076c", "metadata": {}, "outputs": [], "source": [ - "metrics" + "metrics_pdf" ] }, { "cell_type": "code", "execution_count": null, - "id": "1b4de6a3-1ac9-4ad7-89cd-b8cfaad5e3ba", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "cabfde2b", - "metadata": {}, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": {}, + "inputWidgets": {}, + "nuid": "6bf0f336-7732-4e24-bdf5-00ca81870bba", + "showTitle": false, + "title": "" + } + }, "outputs": [], "source": [ "from beaker import sqlwarehouseutils" @@ -216,8 +380,15 @@ { "cell_type": "code", "execution_count": null, - "id": "d80b78ab", - "metadata": {}, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": {}, + "inputWidgets": {}, + "nuid": "3e796340-dbd4-40c0-a140-56648522eb50", + "showTitle": false, + "title": "" + } + }, "outputs": [], "source": [ "sql_warehouse = sqlwarehouseutils.SQLWarehouseUtils(\n", @@ -233,8 +404,15 @@ { "cell_type": "code", "execution_count": null, - "id": "c0004348", - "metadata": {}, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": {}, + "inputWidgets": {}, + "nuid": "7744be3c-7f67-4345-8930-c51499fc8e75", + "showTitle": false, + "title": "" + } + }, "outputs": [], "source": [ "x = sql_warehouse.get_rows(\"\"\"\n", @@ -249,8 +427,15 @@ { "cell_type": "code", "execution_count": null, - "id": "20df6e8f", - "metadata": {}, + "metadata": { + 
"application/vnd.databricks.v1+cell": { + "cellMetadata": {}, + "inputWidgets": {}, + "nuid": "4116033a-c709-4ced-b2b3-c45ba72445c9", + "showTitle": false, + "title": "" + } + }, "outputs": [], "source": [ "x" @@ -258,6 +443,15 @@ } ], "metadata": { + "application/vnd.databricks.v1+notebook": { + "dashboards": [], + "language": "python", + "notebookMetadata": { + "pythonIndentUnit": 2 + }, + "notebookName": "getting_started", + "widgets": {} + }, "kernelspec": { "display_name": "Python [conda env:mr-delta]", "language": "python", @@ -277,5 +471,5 @@ } }, "nbformat": 4, - "nbformat_minor": 5 + "nbformat_minor": 0 } diff --git a/examples/queries/q1.sql b/examples/queries/q1.sql index e756f3f..3584537 100644 --- a/examples/queries/q1.sql +++ b/examples/queries/q1.sql @@ -1,2 +1,3 @@ -- {"query_id":"q1"} select 'q1', now(); + diff --git a/examples/queries/q10.sql b/examples/queries/q10.sql new file mode 100644 index 0000000..a257670 --- /dev/null +++ b/examples/queries/q10.sql @@ -0,0 +1,2 @@ +-- {"query_id":"q10"} +select 'q10', now(); diff --git a/examples/queries_orig/q1.sql b/examples/queries_orig/q1.sql new file mode 100644 index 0000000..5d501f4 --- /dev/null +++ b/examples/queries_orig/q1.sql @@ -0,0 +1,2 @@ +Q1 +select 'q1', now(); diff --git a/examples/queries_orig/q2.sql b/examples/queries_orig/q2.sql new file mode 100644 index 0000000..44e8a1e --- /dev/null +++ b/examples/queries_orig/q2.sql @@ -0,0 +1,2 @@ +Q2 +select 'q2', now(); diff --git a/examples/requirements.txt b/examples/requirements.txt new file mode 100644 index 0000000..5a3d0e5 --- /dev/null +++ b/examples/requirements.txt @@ -0,0 +1,5 @@ +databricks-sql-connector +databricks-sdk +pyspark +python-dotenv +pandas \ No newline at end of file diff --git a/examples/standalone_dist_test.py b/examples/standalone_dist_test.py index 1b7a0ba..5e6cdcd 100644 --- a/examples/standalone_dist_test.py +++ b/examples/standalone_dist_test.py @@ -14,21 +14,23 @@ logger = logging.getLogger() logger.setLevel(logging.INFO) -bm = benchmark.Benchmark() hostname = os.getenv("DATABRICKS_HOST") http_path = os.getenv("DATABRICKS_HTTP_PATH") # Don't put tokens in plaintext in code access_token = os.getenv("DATABRICKS_ACCESS_TOKEN") +bm = benchmark.Benchmark() +bm.setName(name="standalone_dist_test") bm.setHostname(hostname=hostname) +bm.setWarehouseToken(token=access_token) bm.setWarehouse(http_path=http_path) bm.setConcurrency(concurrency=5) -bm.setWarehouseToken(token=access_token) bm.setQuery("select now(), 'foo';") -bm.setQueryRepeatCount(100) +bm.setCatalog(catalog="hive_metastore") +bm.setQueryRepeatCount(10) -metrics = bm.execute() -#print(metrics) +metrics_pdf = bm.execute() +print(metrics_pdf) diff --git a/examples/tpch/1.sql b/examples/tpch/1.sql new file mode 100644 index 0000000..503b9a8 --- /dev/null +++ b/examples/tpch/1.sql @@ -0,0 +1,409 @@ +Q01 +SELECT + l_returnflag, + l_linestatus, + sum(l_quantity) as sum_qty, + sum(l_extendedprice) as sum_base_price, + sum(l_extendedprice * (1 - l_discount)) as sum_disc_price, + sum(l_extendedprice * (1 - l_discount) * (1 + l_tax)) as sum_charge, + avg(l_quantity) as avg_qty, + avg(l_extendedprice) as avg_price, + avg(l_discount) as avg_disc, + count(*) as count_order +FROM + lineitem +WHERE + l_shipdate <= date '1998-12-01' - interval '90' day +GROUP BY + l_returnflag, + l_linestatus +ORDER BY + l_returnflag, + l_linestatus; + +Q02 +select + s_acctbal, + s_name, + n_name, + p_partkey, + p_mfgr, + s_address, + s_phone, + s_comment +from + part, + supplier, + partsupp, + nation, + region +where 
+ p_partkey = ps_partkey + and s_suppkey = ps_suppkey + and p_size = 15 + and p_type like '%BRASS' + and s_nationkey = n_nationkey + and n_regionkey = r_regionkey + and r_name = 'EUROPE' + and ps_supplycost = ( + select + min(ps_supplycost) + from + partsupp, + supplier, + nation, + region + where + p_partkey = ps_partkey + and s_suppkey = ps_suppkey + and s_nationkey = n_nationkey + and n_regionkey = r_regionkey + and r_name = 'EUROPE' + ) +order by + s_acctbal desc, + n_name, + s_name, + p_partkey; + +Q03 +select + l_orderkey, + sum(l_extendedprice * (1 - l_discount)) as revenue, + o_orderdate, + o_shippriority +from + customer, + orders, + lineitem +where + c_mktsegment = 'BUILDING' + and c_custkey = o_custkey + and l_orderkey = o_orderkey + and o_orderdate < date('1995-03-15') + and l_shipdate > date('1995-03-15') +group by + l_orderkey, + o_orderdate, + o_shippriority +order by + revenue desc, + o_orderdate; + +Q04 +select + o_orderpriority, + count(*) as order_count +from + orders +where + o_orderdate >= date('1993-07-01') + and o_orderdate < add_months('1993-07-01', 3) + and exists ( + select + * + from + lineitem + where + l_orderkey = o_orderkey + and l_commitdate < l_receiptdate + ) +group by + o_orderpriority +order by + o_orderpriority; + +Q05 +select + n_name, + sum(l_extendedprice * (1 - l_discount)) as revenue +from + customer, + orders, + lineitem, + supplier, + nation, + region +where + c_custkey = o_custkey + and l_orderkey = o_orderkey + and l_suppkey = s_suppkey + and c_nationkey = s_nationkey + and s_nationkey = n_nationkey + and n_regionkey = r_regionkey + and r_name = 'ASIA' + and o_orderdate >= date('1993-01-01') + and o_orderdate < date('1994-01-01') +group by + n_name +order by + revenue desc + +Q06 +select + sum(l_extendedprice * l_discount) as revenue +from + lineitem +where + l_shipdate >= date('1993-01-01') + and l_shipdate < date('1994-01-01') + and l_discount between.06 - 0.01 + and.06 + 0.01 + and l_quantity < 24; + +Q07 +select + supp_nation, + cust_nation, + l_year, + sum(volume) as revenue +from + ( + select + n1.n_name as supp_nation, + n2.n_name as cust_nation, + year(l_shipdate) as l_year, + l_extendedprice * (1 - l_discount) as volume + from + supplier, + lineitem, + orders, + customer, + nation n1, + nation n2 + where + s_suppkey = l_suppkey + and o_orderkey = l_orderkey + and c_custkey = o_custkey + and s_nationkey = n1.n_nationkey + and c_nationkey = n2.n_nationkey + and ( + ( + n1.n_name = 'FRANCE' + and n2.n_name = 'GERMANY' + ) + or ( + n1.n_name = 'GERMANY' + and n2.n_name = 'FRANCE' + ) + ) + and l_shipdate between date('1995-01-01') + and date('1996-12-31') + ) as shipping +group by + supp_nation, + cust_nation, + l_year +order by + supp_nation, + cust_nation, + l_year; + +Q08 +select + o_year, + sum( + case + when nation = 'BRAZIL' then volume + else 0 + end + ) / sum(volume) as mkt_share +from + ( + select + year(o_orderdate) as o_year, + l_extendedprice * (1 - l_discount) as volume, + n2.n_name as nation + from + part, + supplier, + lineitem, + orders, + customer, + nation n1, + nation n2, + region + where + p_partkey = l_partkey + and s_suppkey = l_suppkey + and l_orderkey = o_orderkey + and o_custkey = c_custkey + and c_nationkey = n1.n_nationkey + and n1.n_regionkey = r_regionkey + and r_name = 'AMERICA' + and s_nationkey = n2.n_nationkey + and o_orderdate between date('1995-01-01') + and date('1996-12-31') + and p_type = 'ECONOMY ANODIZED STEEL' + ) as all_nations +group by + o_year +order by + o_year; + +Q09 +select + nation, + 
o_year, + sum(amount) as sum_profit +from + ( + select + n_name as nation, + year(o_orderdate) as o_year, + l_extendedprice * (1 - l_discount) - ps_supplycost * l_quantity as amount + from + part, + supplier, + lineitem, + partsupp, + orders, + nation + where + s_suppkey = l_suppkey + and ps_suppkey = l_suppkey + and ps_partkey = l_partkey + and p_partkey = l_partkey + and o_orderkey = l_orderkey + and s_nationkey = n_nationkey + and p_name like '%green%' + ) as profit +group by + nation, + o_year +order by + nation, + o_year desc; + +Q10 +select + c_custkey, + c_name, + sum(l_extendedprice * (1 - l_discount)) as revenue, + c_acctbal, + n_name, + c_address, + c_phone, + c_comment +from + customer, + orders, + lineitem, + nation +where + c_custkey = o_custkey + and l_orderkey = o_orderkey + and o_orderdate >= date('1993-10-01') + and o_orderdate < ADD_MONTHS('1993-10-01', 3) + and l_returnflag = 'R' + and c_nationkey = n_nationkey +group by + c_custkey, + c_name, + c_acctbal, + c_phone, + n_name, + c_address, + c_comment +order by + revenue desc; + +Q11 +select + ps_partkey, + sum(ps_supplycost * ps_availqty) as value +from + partsupp, + supplier, + nation +where + ps_suppkey = s_suppkey + and s_nationkey = n_nationkey + and n_name = 'GERMANY' +group by + ps_partkey +having + sum(ps_supplycost * ps_availqty) > ( + select + sum(ps_supplycost * ps_availqty) * 0.0000010000 + from + partsupp, + supplier, + nation + where + ps_suppkey = s_suppkey + and s_nationkey = n_nationkey + and n_name = 'GERMANY' + ) +order by + value desc; + +Q12 +select + l_shipmode, + sum( + case + when o_orderpriority = '1-URGENT' + or o_orderpriority = '2-HIGH' then 1 + else 0 + end + ) as high_line_count, + sum( + case + when o_orderpriority <> '1-URGENT' + and o_orderpriority <> '2-HIGH' then 1 + else 0 + end + ) as low_line_count +from + orders, + lineitem +where + o_orderkey = l_orderkey + and l_shipmode in ('MAIL', 'SHIP') + and l_commitdate < l_receiptdate + and l_shipdate < l_commitdate + and l_receiptdate >= date('1995-01-01') + and l_receiptdate < date('1996-01-01') +group by + l_shipmode +order by + l_shipmode; + +Q13 +select + c_count, + count(*) as custdist +from + ( + select + c_custkey, + count(o_orderkey) + from + customer + left outer join orders on c_custkey = o_custkey + and o_comment not like '%special%requests%' + group by + c_custkey + ) as orders (c_custkey, c_count) +group by + c_count +order by + custdist desc, + c_count desc; + +Q14 +select + 100.00 * sum( + case + when p_type like 'PROMO%' then l_extendedprice * (1 - l_discount) + else 0 + end + ) / sum(l_extendedprice * (1 - l_discount)) as promo_revenue +from + lineitem, + part +where + l_partkey = p_partkey + and l_shipdate >= date('1995-09-01') + and l_shipdate < add_months('1995-09-01', 1); \ No newline at end of file diff --git a/examples/tpch/2.sql b/examples/tpch/2.sql new file mode 100644 index 0000000..fb4faa7 --- /dev/null +++ b/examples/tpch/2.sql @@ -0,0 +1,276 @@ +Q15 +with revenue0 as ( + select + l_suppkey as supplier_no, + sum(l_extendedprice * (1 - l_discount)) as total_revenue + from + lineitem + where + l_shipdate >= date '1996-01-01' + and l_shipdate < date '1996-01-01' + interval '3' month + group by + l_suppkey + ) +select + s_suppkey, + s_name, + s_address, + s_phone, + total_revenue +from + supplier, + revenue0 +where + s_suppkey = supplier_no + and total_revenue = ( + select + max(total_revenue) + from + revenue0 + ) +order by + s_suppkey; + +Q16 +select + p_brand, + p_type, + p_size, + count(distinct ps_suppkey) as 
supplier_cnt +from + partsupp, + part +where + p_partkey = ps_partkey + and p_brand <> 'Brand#45' + and p_type not like 'MEDIUM POLISHED%' + and p_size in (49, 14, 23, 45, 19, 3, 36, 9) + and ps_suppkey not in ( + select + s_suppkey + from + supplier + where + s_comment like '%Customer%Complaints%' + ) +group by + p_brand, + p_type, + p_size +order by + supplier_cnt desc, + p_brand, + p_type, + p_size; + +Q17 +select + sum(l_extendedprice) / 7.0 as avg_yearly +from + lineitem, + part +where + p_partkey = l_partkey + and p_brand = 'Brand#23' + and p_container = 'MED BOX' + and l_quantity < ( + select + 0.2 * avg(l_quantity) + from + lineitem + where + l_partkey = p_partkey + ); + +Q18 +select + c_name, + c_custkey, + o_orderkey, + o_orderdate, + o_totalprice, + sum(l_quantity) +from + customer, + orders, + lineitem +where + o_orderkey in ( + select + l_orderkey + from + lineitem + group by + l_orderkey + having + sum(l_quantity) > 300 + ) + and c_custkey = o_custkey + and o_orderkey = l_orderkey +group by + c_name, + c_custkey, + o_orderkey, + o_orderdate, + o_totalprice +order by + o_totalprice desc, + o_orderdate; + +Q19 +select + sum(l_extendedprice * (1 - l_discount)) as revenue +from + lineitem, + part +where + ( + p_partkey = l_partkey + and p_brand = 'Brand#12' + and p_container in ('SM CASE', 'SM BOX', 'SM PACK', 'SM PKG') + and l_quantity >= 1 + and l_quantity <= 1 + 10 + and p_size between 1 + and 5 + and l_shipmode in ('AIR', 'AIR REG') + and l_shipinstruct = 'DELIVER IN PERSON' + ) + or ( + p_partkey = l_partkey + and p_brand = 'Brand#23' + and p_container in ('MED BAG', 'MED BOX', 'MED PKG', 'MED PACK') + and l_quantity >= 10 + and l_quantity <= 10 + 10 + and p_size between 1 + and 10 + and l_shipmode in ('AIR', 'AIR REG') + and l_shipinstruct = 'DELIVER IN PERSON' + ) + or ( + p_partkey = l_partkey + and p_brand = 'Brand#34' + and p_container in ('LG CASE', 'LG BOX', 'LG PACK', 'LG PKG') + and l_quantity >= 20 + and l_quantity <= 20 + 10 + and p_size between 1 + and 15 + and l_shipmode in ('AIR', 'AIR REG') + and l_shipinstruct = 'DELIVER IN PERSON' + ); + +Q20 +select + s_name, + s_address +from + supplier, + nation +where + s_suppkey in ( + select + ps_suppkey + from + partsupp + where + ps_partkey in ( + select + p_partkey + from + part + where + p_name like 'forest%' + ) + and ps_availqty > ( + select + 0.5 * sum(l_quantity) + from + lineitem + where + l_partkey = ps_partkey + and l_suppkey = ps_suppkey + and l_shipdate >= date('1993-01-01') + and l_shipdate < date('1994-01-01') + ) + ) + and s_nationkey = n_nationkey + and n_name = 'CANADA' +order by + s_name; + +Q21 +select + s_name, + count(*) as numwait +from + supplier, + lineitem l1, + orders, + nation +where + s_suppkey = l1.l_suppkey + and o_orderkey = l1.l_orderkey + and o_orderstatus = 'F' + and l1.l_receiptdate > l1.l_commitdate + and exists ( + select + * + from + lineitem l2 + where + l2.l_orderkey = l1.l_orderkey + and l2.l_suppkey <> l1.l_suppkey + ) + and not exists ( + select + * + from + lineitem l3 + where + l3.l_orderkey = l1.l_orderkey + and l3.l_suppkey <> l1.l_suppkey + and l3.l_receiptdate > l3.l_commitdate + ) + and s_nationkey = n_nationkey + and n_name = 'SAUDI ARABIA' +group by + s_name +order by + numwait desc, + s_name; + +Q22 +select + cntrycode, + count(*) as numcust, + sum(c_acctbal) as totacctbal +from + ( + select + SUBSTR(c_phone, 1, 2) as cntrycode, + c_acctbal + from + customer + where + SUBSTR(c_phone, 1, 2) in ('13', '31', '23', '29', '30', '18', '17') + and c_acctbal > ( + select + 
avg(c_acctbal) + from + customer + where + c_acctbal > 0.00 + and SUBSTR(c_phone, 1, 2) in ('13', '31', '23', '29', '30', '18', '17') + ) + and not exists ( + select + * + from + orders + where + o_custkey = c_custkey + ) + ) as custsale +group by + cntrycode +order by + cntrycode; \ No newline at end of file diff --git a/src/__init__.py b/src/__init__.py index 6e83b38..0cdff65 100644 --- a/src/__init__.py +++ b/src/__init__.py @@ -1 +1 @@ -VERSION = '0.0.5' +VERSION = '0.0.6' diff --git a/src/beaker/benchmark.py b/src/beaker/benchmark.py index 3da256c..cb956f6 100644 --- a/src/beaker/benchmark.py +++ b/src/beaker/benchmark.py @@ -5,6 +5,10 @@ import logging from concurrent.futures import ThreadPoolExecutor import threading +import datetime +import json +import pandas as pd +from pandas import json_normalize from beaker.sqlwarehouseutils import SQLWarehouseUtils from beaker.spark_fixture import get_spark_session, metrics_to_df_view @@ -52,9 +56,10 @@ def __init__( if new_warehouse_config is not None: self.setWarehouseConfig(new_warehouse_config) self.query_file_format = query_file_format + self.sql_warehouse = None def _create_dbc(self): - sql_warehouse = SQLWarehouseUtils( + self.sql_warehouse = SQLWarehouseUtils( self.hostname, self.http_path, self.token, @@ -63,9 +68,9 @@ def _create_dbc(self): self.results_cache_enabled, ) # establish connection on the existing warehouse - sql_warehouse.setConnection() - logging.info(f"Returning new sqlwarehouseutils: {sql_warehouse}") - return sql_warehouse + self.sql_warehouse.setConnection() + + return self.sql_warehouse def _get_thread_local_connection(self): if not hasattr(thread_local, "connection"): @@ -80,9 +85,14 @@ def _get_user_id(self): ) return response.json()["id"] + def _validate_warehouse(self, http_path): """Validates the SQL warehouse HTTP path.""" - return True + pattern = r'^/sql/1\.0/warehouses/[a-f0-9]+$' + if re.match(pattern, http_path): + return True + else: + return False def _launch_new_warehouse(self): """Launches a new SQL Warehouse""" @@ -108,14 +118,26 @@ def setWarehouseConfig(self, config): """Launches a new cluster/warehouse from a JSON config.""" self.new_warehouse_config = config logging.info(f"Creating new warehouse with config: {config}") - warehouse_id = self._launch_new_warehouse() - logging.info(f"The warehouse Id is: {warehouse_id}") - self.http_path = f"/sql/1.0/warehouses/{warehouse_id}" + self.warehouse_id = self._launch_new_warehouse() + self.warehouse_name = self._get_warehouse_info() + self.http_path = f"/sql/1.0/warehouses/{self.warehouse_id}" def setWarehouse(self, http_path): """Sets the SQL Warehouse http path to use for the benchmark.""" - assert self._validate_warehouse(id), "Invalid HTTP path for SQL Warehouse." + assert self._validate_warehouse(http_path), "Invalid HTTP path for SQL Warehouse." 
self.http_path = http_path
+        if self.http_path:
+            self.warehouse_id = self.http_path.split("/")[-1]
+            self.warehouse_name = self._get_warehouse_info()
+
+    def stop_warehouse(self, warehouse_id):
+        """Stops a SQL warehouse."""
+        logging.info(f"Stopping warehouse {warehouse_id}")
+        response = requests.post(
+            f"https://{self.hostname}/api/2.0/sql/warehouses/{warehouse_id}/stop",
+            headers={"Authorization": f"Bearer {self.token}"},
+        )
+        return response.status_code
 
     def setConcurrency(self, concurrency):
         """Sets the query execution parallelism."""
@@ -172,25 +194,25 @@ def setQueryFileDir(self, query_file_dir):
 
     def _execute_single_query(self, query, id=None):
         query = query.strip()
-        logging.info(query)
-
-        sql_warehouse = self._get_thread_local_connection()
-        # TODO: instead of using perf counter, we might want to get the query duration from /api/2.0/sql/history/queries API
+        ## Instead of using perf counter, we want to get the query duration from /api/2.0/sql/history/queries API
         start_time = time.perf_counter()
-        sql_warehouse.execute_query(query)
+        result = self.sql_warehouse.execute_query(query)
         end_time = time.perf_counter()
         elapsed_time = f"{end_time - start_time:0.3f}"
+
         metrics = {
             "id": id,
             "hostname": self.hostname,
             "http_path": self.http_path,
+            "warehouse_name": self.warehouse_name,
             "concurrency": self.concurrency,
             "query": query,
             "elapsed_time": elapsed_time,
         }
         return metrics
 
+
     def _set_default_catalog(self):
         if self.catalog:
             query = f"USE CATALOG {self.catalog}"
@@ -214,34 +236,29 @@ def _get_queries_from_file_format_orig(self, f):
         file_headers, file_queries = self._parse_queries(raw_queries)
         queries = [e for e in zip(file_queries, file_headers)]
         return queries
+
+    def _get_queries_from_file_format_semi(self, file_path):
+        """
+        Parses a SQL file and returns a list of tuples containing the query text and its query_id.
 
-    def _get_queries_from_file_format_semi(self, f, filter_comment_lines=False):
-        fc = None
-        queries = []
-        with open(f, "r") as of:
-            fc = of.read()
-        for idx, q in enumerate(fc.split(";")):
-            q = q.strip()
-            if not q:
-                continue
-            # Keep non-empty lines.
-            # Also keep or remove comments depending on the flag.
-            rq = [
-                l
-                for l in q.split("\n")
-                if l.strip() and not (filter_comment_lines and l.startswith("--"))
-            ]
-            if rq:
-                queries.append(
-                    (
-                        "\n".join(rq),
-                        f"query{idx}",
-                    )
-                )
+        Parameters:
+        file_path (str): The path to the SQL file.
+
+        Returns:
+        list: A list of tuples, where each tuple contains the query text and its query_id.
+        """
+        with open(file_path, 'r') as file:
+            content = file.read()
+
+        pattern = r'-- {"query_id":"(.*?)"}\n(.*?);'
+        matches = re.findall(pattern, content, re.DOTALL)
+
+        # Swap the order of elements in each tuple
+        queries = [(query, query_id) for query_id, query in matches]
         return queries
 
+
     def _get_queries_from_file(self, query_file):
-        print("Get queries from file:", query_file)
         if self.query_file_format == self.QUERY_FILE_FORMAT_SEMICOLON_DELIM:
             return self._get_queries_from_file_format_semi(query_file)
         elif self.query_file_format == self.QUERY_FILE_FORMAT_ORIGINAL:
@@ -258,7 +275,6 @@ def _get_query_filenames_from_dir(self, query_file_dir):
         return [os.path.join(query_file_dir, f) for f in os.listdir(query_file_dir)]
 
     def _get_queries_from_dir(self, query_dir):
-        print("Get queries from dir:", query_dir)
        query_files = self._get_query_filenames_from_dir(query_dir)
         queries = []
         for qf in query_files:
@@ -285,14 +301,95 @@ def _execute_queries(self, queries, num_threads):
             executor.map(lambda x: self._execute_single_query(*x), queries)
         )
         return metrics_list
+
+    def get_query_history(self, warehouse_id, start_ts_ms, end_ts_ms):
+        """
+        Retrieves the query history for a given workspace and SQL warehouse.
+
+        Parameters:
+        -----------
+        warehouse_id (str): The ID of the SQL warehouse for which to retrieve the query history.
+        start_ts_ms (int): The Unix timestamp (milliseconds) value representing the start of the query history.
+        end_ts_ms (int): The Unix timestamp (milliseconds) value representing the end of the query history.
+
+        Returns:
+        --------
+        end_res (list): The query history entries as parsed JSON.
+        """
+        logging.info(f"Extracting query history {self.warehouse_name}")
+        user_id = self._get_user_id()
+        ## Put together request
+        request_string = {
+            "filter_by": {
+                "query_start_time_range": {
+                    "end_time_ms": end_ts_ms,
+                    "start_time_ms": start_ts_ms
+                },
+                "warehouse_ids": [warehouse_id],  # the API expects a list of warehouse IDs
+                "user_ids": [user_id],
+            },
+            "include_metrics": "true",
+            "max_results": "1000"
+        }
+
+        ## Convert the request dict to JSON
+        v = json.dumps(request_string)
+
+        uri = f"https://{self.hostname}/api/2.0/sql/history/queries"
+        headers_auth = {"Authorization":f"Bearer {self.token}"}
+
+        #### Get Query History Results from API
+        response = requests.get(uri, data=v, headers=headers_auth)
+        while True:
+            results = response.json().get('res', [])  # guard against responses with no 'res'
+            if all([item['is_final'] for item in results]):
+                break
+            time.sleep(10)
+            response = requests.get(uri, data=v, headers=headers_auth)
+
+        if (response.status_code == 200) and ("res" in response.json()):
+            end_res = response.json()['res']
+            return end_res
+        else:
+            raise Exception("Failed to retrieve successful query history")
+
+    def clean_query_metrics(self, raw_metrics_pdf):
+        logging.info(f"Clean Query Metrics {self.warehouse_name}")
+        metrics_pdf = json_normalize(raw_metrics_pdf['metrics'].apply(str).apply(eval))
+        metrics_pdf["query_id"] = raw_metrics_pdf["query_id"]
+        metrics_pdf["query"] = raw_metrics_pdf["query"]
+        metrics_pdf["status"] = raw_metrics_pdf["status"]
+        metrics_pdf["warehouse_name"] = self.warehouse_name
+        metrics_pdf["id"] = raw_metrics_pdf["id"]
+        # Reorder the columns
+        metrics_pdf = metrics_pdf.reindex(columns=['id', 'warehouse_name', 'query', 'query_id', 'status'] + [c for c in metrics_pdf.columns if c not in ['id', 'warehouse_name', 'query', 'query_id', 'status']])
+        return metrics_pdf
+
+    def _get_warehouse_info(self):
+        """Gets the warehouse info."""
+        response = requests.get(
+            f"https://{self.hostname}/api/2.0/sql/warehouses/{self.warehouse_id}",
+            headers={"Authorization": f"Bearer {self.token}"},
+        )
+        warehouse_name = response.json()["name"]
+        return warehouse_name
+
     def execute(self):
         """Executes the benchmark test."""
-        logging.info("Executing benchmark test")
-        logging.info("Set default catalog and schema")
+        logging.info("Executing benchmark")
+        self.sql_warehouse = self._get_thread_local_connection()
+
+        logging.debug(self.sql_warehouse)
+
         self._set_default_catalog()
         self._set_default_schema()
-        metrics = None
+
+        print(f"Monitor warehouse `{self.warehouse_name}` at: ", f"https://{self.hostname}/sql/warehouses/{self.warehouse_id}/monitoring")
+
+        start_ts_ms = int(time.time() * 1000)
+        logging.info(f"Benchmark start time: {datetime.datetime.fromtimestamp(start_ts_ms/1000).strftime('%Y-%m-%d %H:%M:%S')}")
+
         if self.query_file_dir is not None:
             logging.info("Loading query files from directory.")
             metrics = self._execute_queries_from_dir(self.query_file_dir)
@@ -304,9 +401,16 @@
             metrics = self._execute_queries_from_query(self.query)
         else:
             raise ValueError("No query specified.")
-        metrics_df = metrics_to_df_view(metrics, f"{self.name}_vw")
-        print(f"Query the metrics view at: ", f"{self.name}_vw")
-        return metrics
+
+        end_ts_ms = int(time.time() * 1000)
+
+        history_metrics = self.get_query_history(self.warehouse_id, start_ts_ms, end_ts_ms)
+        history_pdf = pd.DataFrame(history_metrics)
+        beaker_pdf = pd.DataFrame(metrics)
+        raw_metrics_pdf = history_pdf.merge(beaker_pdf[['query', 'id']].drop_duplicates(), left_on='query_text', right_on='query', how='inner')
+        metrics_pdf = self.clean_query_metrics(raw_metrics_pdf)
+
+        return metrics_pdf
 
     def preWarmTables(self, tables):
         """Delta caches the table before running a benchmark test."""
@@ -317,13 +421,17 @@
         assert (
             self.catalog is not None
         ), "No catalog provided. You can add a catalog by calling `.setCatalog()`."
+
+        self.sql_warehouse = self._get_thread_local_connection()
+
         self._set_default_catalog()
         self._set_default_schema()
+        logging.info(f"Pre-warming tables on {self.catalog}.{self.schema} in {self.warehouse_name}")
         for table in tables:
-            logging.info(f"Pre-warming table: {table}")
-            query = f"SELECT * FROM {table}"
+            query = f"CACHE SELECT * FROM {table}"
             self._execute_single_query(query)
+
     def __str__(self):
         object_str = f"""
         Benchmark Test:
@@ -337,5 +445,6 @@
             query_repeat_count={self.query_repeat_count}
             hostname={self.hostname}
             warehouse_http_path={self.http_path}
+            sql_warehouse={self.sql_warehouse}
         """
         return object_str
diff --git a/src/beaker/spark_fixture.py b/src/beaker/spark_fixture.py
index ae0f6ce..3a3c370 100644
--- a/src/beaker/spark_fixture.py
+++ b/src/beaker/spark_fixture.py
@@ -1,7 +1,7 @@
 import os
 from pyspark.sql import SparkSession
 from functools import lru_cache
-
+from pyspark.sql.functions import col
 
 @lru_cache(maxsize=None)
 def get_spark_session():
@@ -16,12 +16,13 @@ def get_spark_session():
     else:
         return SparkSession.builder.appName("beaker").getOrCreate()
 
-
-def metrics_to_df_view(metrics, view_name):
-    """Convert a list of dicts to a results dataframe.
-    Create a view and return the dataframe.
+
+def metrics_to_df_view(metrics_pdf, view_name):
+    """Convert a pandas dataframe to a spark dataframe.
+    Create a view and return the spark dataframe.
     """
     spark = get_spark_session()
-    df = spark.createDataFrame(metrics)
-    df.createOrReplaceTempView(view_name)
-    return df
+    metrics_df = spark.createDataFrame(metrics_pdf)
+    metrics_df.createOrReplaceTempView(view_name)
+    print(f"Query metrics at: {view_name}")
+    return metrics_df
diff --git a/src/beaker/sqlwarehouseutils.py b/src/beaker/sqlwarehouseutils.py
index 4683539..c9b02eb 100644
--- a/src/beaker/sqlwarehouseutils.py
+++ b/src/beaker/sqlwarehouseutils.py
@@ -62,7 +62,6 @@ def setConnection(self):
             schema=self.schema,
             session_configuration={"use_cached_result": results_caching},
         )
-        logging.info(f"Created new connection: {connection}")
         self.connection = connection
 
     def close_connection(self):
@@ -78,7 +77,9 @@ def close_connection(self):
 
     def execute_query(self, query_str):
         with self.connection.cursor() as cursor:
-            result = cursor.execute(query_str)
+            cursor.execute(query_str)
+            result = cursor.fetchall()
+            return result
 
     def get_rows(self, query_str):
         with self.connection.cursor() as cursor:
@@ -109,6 +110,7 @@ def _get_spark_runtimes(self):
         )
         result = list(map(lambda v: v["key"], response.json()["versions"]))
         return result
+
 
     def launch_warehouse(self, config):
         """Creates a new SQL warehouse based upon a config."""
@@ -222,10 +224,25 @@ def launch_warehouse(self, config):
 
         warehouse_id = response.json().get("id")
         warehouse_start_time = time.time()
-        WorkspaceClient().warehouses.start_and_wait(warehouse_id)
-        print(f"{int(time.time() - warehouse_start_time)}s Warehouse Startup Time")
+
+        WorkspaceClient(host=f"https://{self.hostname}", token=self.access_token).warehouses.start_and_wait(warehouse_id)
+
+        print(f"{int(time.time() - warehouse_start_time)}s Warehouse {warehouse_id} Startup Time")
 
         if not warehouse_id:
             raise Exception(f"did not get back warehouse_id ({response.json()})")
 
         return warehouse_id
+
+
+    def __str__(self):
+        object_str = f"""
+        SQL Warehouse Utils
+        ------------------------
+        hostname={self.hostname}
+        catalog={self.catalog}
+        schema={self.schema}
+        http_path={self.http_path}
+        enable_results_caching={self.enable_results_caching}
+        """
+        return object_str
\ No newline at end of file
diff --git a/src/unittest/test_Benchmark.py b/src/unittest/test_Benchmark.py
new file mode 100644
index 0000000..9bd5b59
--- /dev/null
+++ b/src/unittest/test_Benchmark.py
@@ -0,0 +1,100 @@
+import unittest
+import sys
+from dotenv import load_dotenv
+import os
+
+sys.path.append("../")
+from beaker import benchmark
+
+load_dotenv("../../examples/.env")
+
+hostname = os.getenv("DATABRICKS_HOST")
+http_path = os.getenv("DATABRICKS_HTTP_PATH")
+# Don't put tokens in plaintext in code
+access_token = os.getenv("DATABRICKS_ACCESS_TOKEN")
+catalog_name = os.getenv("CATALOG")
+schema_name = os.getenv("SCHEMA")
+
+class TestBenchmark(unittest.TestCase):
+    def setUp(self):
+        self.bm = benchmark.Benchmark()
+        self.bm.setName(name="unittest")
+        self.bm.setHostname(hostname=hostname)
+        self.bm.setWarehouseToken(token=access_token)
+        self.bm.setWarehouse(http_path=http_path)
+
+    def test_get_queries_from_file_format_semi(self):
+        # Define a test case
+        test_file_path = '../../examples/queries/q10.sql'
+        # replace with the expected output
+        expected_output = [("select 'q10', now()", 'q10')]
+
+        # Call the function with the test case
+        actual_output = self.bm._get_queries_from_file_format_semi(test_file_path)
+
+        # Assert that the actual output matches the expected output
+        self.assertEqual(actual_output, expected_output)
+
+    def test_get_queries_from_file_format_orig(self):
+        # Define a test case
+        test_file_path = '../../examples/queries_orig/q1.sql'
+        # replace with the expected output
+        expected_output = [("select 'q1', now();", 'Q1')]
+
+        # Call the function with the test case
+        actual_output = self.bm._get_queries_from_file_format_orig(test_file_path)
+
+        # Assert that the actual output matches the expected output
+        self.assertEqual(actual_output, expected_output)
+
+    def test_get_queries_from_dir_orig(self):
+        # Define a test case
+        test_dir_path = '../../examples/queries_orig/'
+        # replace with the expected output
+        expected_output = [("select 'q1', now();", 'Q1'), ("select 'q2', now();", 'Q2')]
+
+        # Call the function with the test case
+        self.bm.query_file_format = "original"
+        actual_output = self.bm._get_queries_from_dir(test_dir_path)
+
+        # Assert that the actual output matches the expected output
+        self.assertEqual(actual_output, expected_output)
+
+    def test_get_queries_from_dir_semi(self):
+        # Define a test case
+        test_dir_path = '../../examples/queries/'
+        # replace with the expected output
+        expected_output = [("select 'q1', now()", 'q1'), ("select 'q2', now()", 'q2'), ("select 'q10', now()", 'q10')]
+
+        # Call the function with the test case
+        self.bm.query_file_format = "semicolon-delimited"
+        actual_output = self.bm._get_queries_from_dir(test_dir_path)
+        # Assert that the actual output matches the expected output
+        self.assertEqual(actual_output, expected_output)
+
+    def test_validate_warehouse(self):
+        # Define a test case
+        test_http_path = "/sql/1.0/warehouses/632c5da7a7fd6a78"
+        # replace with the expected output
+        expected_output = True
+
+        # Call the function with the test case
+        actual_output = self.bm._validate_warehouse(test_http_path)
+
+        # Assert that the actual output matches the expected output
+        self.assertEqual(actual_output, expected_output)
+
+        # Define a test case
+        test_http_path2 = "/sql/1.0/warehouses632c5da7a7fd"
+        # replace with the expected output
+        expected_output2 = False
+
+        # Call the function with the test case
+        actual_output2 = self.bm._validate_warehouse(test_http_path2)
+
+        # Assert that the actual output matches the expected output
+        self.assertEqual(actual_output2, expected_output2)
+
+
+if __name__ == '__main__':
+    unittest.main()
\ No newline at end of file
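For reference, a minimal standalone sketch of the semicolon-delimited query-file format that _get_queries_from_file_format_semi and test_get_queries_from_file_format_semi exercise above. The regex is the same one added to src/beaker/benchmark.py in this patch; the sample text is hypothetical and stands in for files such as examples/queries/q1.sql:

import re

# Each query is preceded by a `-- {"query_id":"..."}` header line and
# terminated by a semicolon, as in examples/queries/*.sql.
sample = '''-- {"query_id":"q1"}
select 'q1', now();
-- {"query_id":"q10"}
select 'q10', now();
'''

# Same pattern as benchmark.py; re.DOTALL lets `.` span multi-line queries.
pattern = r'-- {"query_id":"(.*?)"}\n(.*?);'
matches = re.findall(pattern, sample, re.DOTALL)

# Swap to (query, query_id) tuples -- the shape the parser returns.
queries = [(query, query_id) for query_id, query in matches]
print(queries)
# [("select 'q1', now()", 'q1'), ("select 'q10', now()", 'q10')]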