[data] Add support for read_sql to run read tasks concurrently (#49424)
## Related issue number

Adds parallel SQL read support by sharding the query with MOD/ABS/CONCAT and a configurable hash function (MD5 by default, or a custom one such as `unicode`).
Closes #49206
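
A minimal usage sketch of the new API (it mirrors the sqlite test added below; the table name and database file are illustrative):

```python
import sqlite3

import ray

# Shard the query on the "id" column across two read tasks. "unicode" is used
# as the hash function because stock SQLite has no MD5 function; databases
# such as MySQL can keep the default "MD5".
ds = ray.data.read_sql(
    "SELECT * FROM grade",
    lambda: sqlite3.connect("grades.db"),
    shard_keys=["id"],
    shard_hash_fn="unicode",
    override_num_blocks=2,
)
```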


## Checks

- [ ] I've signed off every commit (by using the -s flag, i.e., `git
commit -s`) in this PR.
- [ ] I've run `scripts/format.sh` to lint the changes in this PR.
- [ ] I've included any doc changes needed for https://docs.ray.io/en/master/.
- [ ] I've added any new APIs to the API Reference. For example, if I added a
  method in Tune, I've added it in `doc/source/tune/api/` under the
  corresponding `.rst` file.
- [ ] I've made sure the tests are passing. Note that there might be a few
flaky tests; see the recent failures at https://flakey-tests.ray.io/
- Testing Strategy
   - [ ] Unit tests
   - [ ] Release tests
   - [ ] This PR is not tested :(

---------

Signed-off-by: jukejian <[email protected]>
Signed-off-by: Richard Liaw <[email protected]>
Co-authored-by: Richard Liaw <[email protected]>
Jay-ju and richardliaw authored Jan 30, 2025
1 parent aea6d46 commit 668e5b4
Showing 3 changed files with 200 additions and 24 deletions.
106 changes: 102 additions & 4 deletions python/ray/data/_internal/datasource/sql_datasource.py
@@ -1,3 +1,5 @@
+import logging
+import math
 from contextlib import contextmanager
 from typing import Any, Callable, Iterable, Iterator, List, Optional
 
@@ -7,6 +9,8 @@
 Connection = Any  # A Python DB API2-compliant `Connection` object.
 Cursor = Any  # A Python DB API2-compliant `Cursor` object.
 
+logger = logging.getLogger(__name__)
+
 
 def _cursor_to_block(cursor) -> Block:
     import pyarrow as pa
@@ -71,19 +75,113 @@ def _connect(connection_factory: Callable[[], Connection]) -> Iterator[Cursor]:
 
 
 class SQLDatasource(Datasource):
-    def __init__(self, sql: str, connection_factory: Callable[[], Connection]):
+    MIN_ROWS_PER_READ_TASK = 50
+
+    def __init__(
+        self,
+        sql: str,
+        connection_factory: Callable[[], Connection],
+        shard_hash_fn: str,
+        shard_keys: Optional[List[str]] = None,
+    ):
         self.sql = sql
+        if shard_keys and len(shard_keys) > 1:
+            self.shard_keys = f"CONCAT({','.join(shard_keys)})"
+        elif shard_keys and len(shard_keys) == 1:
+            self.shard_keys = f"{shard_keys[0]}"
+        else:
+            self.shard_keys = None
+        self.shard_hash_fn = shard_hash_fn
         self.connection_factory = connection_factory
 
     def estimate_inmemory_data_size(self) -> Optional[int]:
         return None
 
+    def supports_sharding(self, parallelism: int) -> bool:
+        """Check if database supports sharding with MOD/ABS/CONCAT operations.
+
+        Returns:
+            bool: True if sharding is supported, False otherwise.
+        """
+        if parallelism <= 1 or self.shard_keys is None:
+            return False
+
+        # Test if database supports required operations (MOD, ABS, MD5, CONCAT)
+        # by executing a sample query
+        hash_fn = self.shard_hash_fn
+        query = (
+            f"SELECT COUNT(1) FROM ({self.sql}) as T"
+            f" WHERE MOD(ABS({hash_fn}({self.shard_keys})), {parallelism}) = 0"
+        )
+        try:
+            with _connect(self.connection_factory) as cursor:
+                cursor.execute(query)
+            return True
+        except Exception as e:
+            logger.info(f"Database does not support sharding: {str(e)}.")
+            return False
+
     def get_read_tasks(self, parallelism: int) -> List[ReadTask]:
-        def read_fn() -> Iterable[Block]:
+        def fallback_read_fn() -> Iterable[Block]:
+            """Read all data in a single block when sharding is not supported."""
             with _connect(self.connection_factory) as cursor:
                 cursor.execute(self.sql)
                 return [_cursor_to_block(cursor)]
 
+        num_rows_total = self._get_num_rows()
+
+        if num_rows_total == 0:
+            return []
+
+        parallelism = min(
+            parallelism, math.ceil(num_rows_total / self.MIN_ROWS_PER_READ_TASK)
+        )
+        num_rows_per_block = num_rows_total // parallelism
+        num_blocks_with_extra_row = num_rows_total % parallelism
+
+        # Check if sharding is supported by the database.
+        # If not, fall back to reading all data in a single task.
+        if not self.supports_sharding(parallelism):
+            logger.info(
+                "Sharding is not supported. "
+                "Falling back to reading all data in a single task."
+            )
+            metadata = BlockMetadata(None, None, None, None, None)
+            return [ReadTask(fallback_read_fn, metadata)]
+
+        tasks = []
+        for i in range(parallelism):
+            num_rows = num_rows_per_block
+            if i < num_blocks_with_extra_row:
+                num_rows += 1
+            read_fn = self._create_parallel_read_fn(i, parallelism)
+            metadata = BlockMetadata(
+                num_rows=num_rows,
+                size_bytes=None,
+                schema=None,
+                input_files=None,
+                exec_stats=None,
+            )
+            tasks.append(ReadTask(read_fn, metadata))
+
+        return tasks
+
+    def _get_num_rows(self) -> int:
+        with _connect(self.connection_factory) as cursor:
+            cursor.execute(f"SELECT COUNT(*) FROM ({self.sql}) as T")
+            return cursor.fetchone()[0]
+
+    def _create_parallel_read_fn(self, task_id: int, parallelism: int):
+        hash_fn = self.shard_hash_fn
+        query = (
+            f"SELECT * FROM ({self.sql}) as T "
+            f"WHERE MOD(ABS({hash_fn}({self.shard_keys})), {parallelism}) = {task_id}"
+        )
+
+        def read_fn() -> Iterable[Block]:
+            with _connect(self.connection_factory) as cursor:
+                cursor.execute(query)
+                block = _cursor_to_block(cursor)
+                return [block]
 
-        metadata = BlockMetadata(None, None, None, None, None)
-        return [ReadTask(read_fn, metadata)]
+        return read_fn
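
In the new `get_read_tasks`, rows are split as evenly as possible across tasks: with 500 rows and an effective parallelism of 3, the blocks receive 167, 167, and 166 rows. Each parallel task then runs the same query restricted to its own shard. A sketch of the queries `_create_parallel_read_fn` produces (the SQL string and shard key are illustrative):

```python
# Per-task shard queries for parallelism=4, mirroring the diff above.
sql = "SELECT * FROM grade"   # illustrative base query
shard_keys = "id"             # several keys would be combined via CONCAT(...)
hash_fn = "MD5"               # the default shard_hash_fn
parallelism = 4

queries = [
    f"SELECT * FROM ({sql}) as T "
    f"WHERE MOD(ABS({hash_fn}({shard_keys})), {parallelism}) = {task_id}"
    for task_id in range(parallelism)
]
# MOD(...) assigns every row to exactly one task, so the shards are disjoint
# and their union is the complete result set.
```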
44 changes: 31 additions & 13 deletions python/ray/data/read_api.py
@@ -400,6 +400,7 @@ def read_datasource(
     # TODO(hchen/chengsu): Remove the duplicated get_read_tasks call here after
     # removing LazyBlockList code path.
     read_tasks = datasource_or_legacy_reader.get_read_tasks(requested_parallelism)
+
     import uuid
 
     stats = DatasetStats(
@@ -2268,6 +2269,8 @@ def read_sql(
     sql: str,
     connection_factory: Callable[[], Connection],
     *,
+    shard_keys: Optional[list[str]] = None,
+    shard_hash_fn: str = "MD5",
     parallelism: int = -1,
     ray_remote_args: Optional[Dict[str, Any]] = None,
     concurrency: Optional[int] = None,
@@ -2278,14 +2281,16 @@
     .. note::
 
-        By default, ``read_sql`` launches multiple read tasks, and each task executes a
-        ``LIMIT`` and ``OFFSET`` to fetch a subset of the rows. However, for many
-        databases, ``OFFSET`` is slow.
-
-        As a workaround, set ``override_num_blocks=1`` to directly fetch all rows in a
-        single task. Note that this approach requires all result rows to fit in the
-        memory of single task. If the rows don't fit, your program may raise an out of
-        memory error.
+        Parallelism is supported by databases that support sharding. This means
+        that the database needs to support all of the following operations:
+        ``MOD``, ``ABS``, and ``CONCAT``.
+
+        You can use ``shard_hash_fn`` to specify the hash function to use for sharding.
+        The default is ``MD5``, but other common alternatives include ``hash``,
+        ``unicode``, and ``SHA``.
+
+        If the database does not support sharding, the read operation will be
+        executed in a single task.
 
     Examples:
@@ -2338,27 +2343,40 @@ def create_connection():
         connection_factory: A function that takes no arguments and returns a
             Python DB API2
             `Connection object <https://peps.python.org/pep-0249/#connection-objects>`_.
+        shard_keys: The keys to shard the data by.
+        shard_hash_fn: The hash function string to use for sharding. Defaults to "MD5".
+            For other databases, common alternatives include "hash" and "SHA".
+            This is applied to the shard keys.
         parallelism: This argument is deprecated. Use ``override_num_blocks`` argument.
         ray_remote_args: kwargs passed to :func:`ray.remote` in the read tasks.
         concurrency: The maximum number of Ray tasks to run concurrently. Set this
             to control number of tasks to run concurrently. This doesn't change the
            total number of tasks run or the total number of output blocks. By default,
            concurrency is dynamically decided based on the available resources.
        override_num_blocks: Override the number of output blocks from all read tasks.
+            This is used for sharding when shard_keys is provided.
            By default, the number of output blocks is dynamically decided based on
            input data size and available resources. You shouldn't manually set this
            value in most cases.
 
     Returns:
         A :class:`Dataset` containing the queried data.
     """
-    if parallelism != -1 and parallelism != 1:
-        raise ValueError(
-            "To ensure correctness, 'read_sql' always launches one task. The "
-            "'parallelism' argument you specified can't be used."
-        )
+    datasource = SQLDatasource(
+        sql=sql,
+        shard_keys=shard_keys,
+        shard_hash_fn=shard_hash_fn,
+        connection_factory=connection_factory,
+    )
+    if override_num_blocks and override_num_blocks > 1:
+        if shard_keys is None:
+            raise ValueError("shard_keys must be provided when override_num_blocks > 1")
+
+        if not datasource.supports_sharding(override_num_blocks):
+            raise ValueError(
+                "Database does not support sharding. Please set override_num_blocks to 1."
+            )
 
-    datasource = SQLDatasource(sql=sql, connection_factory=connection_factory)
     return read_datasource(
         datasource,
         parallelism=parallelism,
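
With the new validation, requesting multiple blocks without shard keys fails fast, before any read task launches. A sketch, assuming a local SQLite database (the table name and file are illustrative):

```python
import sqlite3

import ray

# override_num_blocks > 1 without shard_keys is rejected up front.
try:
    ray.data.read_sql(
        "SELECT * FROM grade",
        lambda: sqlite3.connect("grades.db"),
        override_num_blocks=2,
    )
except ValueError as err:
    print(err)  # shard_keys must be provided when override_num_blocks > 1
```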
74 changes: 67 additions & 7 deletions python/ray/data/tests/test_sql.py
@@ -22,13 +22,6 @@ def temp_database_fixture() -> Generator[str, None, None]:
     yield file.name
 
 
-def test_read_sql_with_parallelism_warns(temp_database):
-    with pytest.raises(ValueError):
-        ray.data.read_sql(
-            "SELECT * FROM movie", lambda: sqlite3.connect(temp_database), parallelism=2
-        )
-
-
 def test_read_sql(temp_database: str):
     connection = sqlite3.connect(temp_database)
     connection.execute("CREATE TABLE movie(title, year, score)")
@@ -49,6 +42,73 @@ def test_read_sql(temp_database: str):
     assert sorted(actual_values) == sorted(expected_values)
 
 
+def test_read_sql_with_parallelism_fallback(temp_database: str):
+    connection = sqlite3.connect(temp_database)
+    connection.execute("CREATE TABLE grade(name, id, score)")
+    base_tuple = ("xiaoming", 1, 8.2)
+    # Generate 500 rows.
+    expected_values = [
+        (f"{base_tuple[0]}{i}", i, base_tuple[2] + i + 1) for i in range(500)
+    ]
+    connection.executemany("INSERT INTO grade VALUES (?, ?, ?)", expected_values)
+    connection.commit()
+    connection.close()
+
+    num_blocks = 2
+    dataset = ray.data.read_sql(
+        "SELECT * FROM grade",
+        lambda: sqlite3.connect(temp_database),
+        override_num_blocks=num_blocks,
+        shard_hash_fn="unicode",
+        shard_keys=["id"],
+    )
+    dataset = dataset.materialize()
+    assert dataset.num_blocks() == num_blocks
+
+    actual_values = [tuple(record.values()) for record in dataset.take_all()]
+    assert sorted(actual_values) == sorted(expected_values)
+
+
+# MySQL variant of the sharded-read test.
+@pytest.mark.skip(reason="skipped because the MySQL test environment is not ready")
+def test_read_sql_with_parallelism_mysql(temp_database: str):
+    # Connect to MySQL.
+    import pymysql
+
+    connection = pymysql.connect(
+        host="10.10.xx.xx", user="root", password="22222", database="test"
+    )
+    cursor = connection.cursor()
+
+    cursor.execute(
+        "CREATE TABLE IF NOT EXISTS grade (name VARCHAR(255), id INT, score FLOAT)"
+    )
+
+    base_tuple = ("xiaoming", 1, 8.2)
+    expected_values = [
+        (f"{base_tuple[0]}{i}", i, base_tuple[2] + i + 1) for i in range(200)
+    ]
+
+    cursor.executemany(
+        "INSERT INTO grade (name, id, score) VALUES (%s, %s, %s)", expected_values
+    )
+    connection.commit()
+
+    cursor.close()
+    connection.close()
+
+    dataset = ray.data.read_sql(
+        "SELECT * FROM grade",
+        lambda: pymysql.connect(host="xxxxx", user="xx", password="xx", database="xx"),
+        parallelism=4,
+        shard_keys=["id"],
+    )
+    actual_values = [tuple(record.values()) for record in dataset.take_all()]
+
+    assert sorted(actual_values) == sorted(expected_values)
+    assert dataset.materialize().num_blocks() == 4
+
+
 def test_write_sql(temp_database: str):
     connection = sqlite3.connect(temp_database)
     connection.cursor().execute("CREATE TABLE test(string, number)")
