Skip to content

Commit

Permalink
Add local dask tests to spot-check support on CI going forward
Browse files Browse the repository at this point in the history
(depends on dask changes to be PR'd)
  • Loading branch information
ihnorton committed Mar 29, 2019
1 parent 1c81fa6 commit 617a6ae
Showing 1 changed file with 73 additions and 0 deletions.
73 changes: 73 additions & 0 deletions tiledb/tests/test_dask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import dask, dask.array as da
import tiledb
from tiledb.tests.common import DiskTestCase

import numpy as np
from numpy.testing import assert_array_equal, assert_approx_equal


class DaskSupport(DiskTestCase):
    """Spot-check the dask.array <-> TileDB integration.

    Each test creates its array under a path obtained from ``self.path``
    and checks that data survives a round trip through
    ``da.from_tiledb`` / ``to_tiledb``.
    """

    def test_dask_from_numpy_1d(self):
        """Round-trip a single-attribute array created with ``from_numpy``.

        NOTE(review): despite the ``_1d`` suffix, the array written here
        is two-dimensional (500 x 500); the name is kept so existing test
        selections keep working.
        """
        uri = self.path("np_1attr")
        A = np.random.randn(500, 500)
        tiledb.from_numpy(uri, A, tile=100).close()

        # Read through an already-open array handle; the context manager
        # guarantees the handle is released even if the assertion fails.
        with tiledb.open(uri) as T:
            D = da.from_tiledb(T)
            assert_array_equal(D, A)

        # Read again, this time letting dask open the array from its URI.
        D2 = da.from_tiledb(uri)
        assert_array_equal(D2, A)
        self.assertAlmostEqual(
            np.mean(A), D2.mean().compute(scheduler='single-threaded'))

    def _make_multiattr_2d(self, uri, shape=(0, 100), tile=10):
        """Create an empty dense 2-D array with attributes attr1 and attr2.

        Parameters
        ----------
        uri : str
            Location at which the array is created.
        shape : tuple
            Currently unused: the dimension domains are fixed at
            (0, 100) for "x" and (0, 500) for "y". Kept so existing
            callers that pass it keep working.
        tile : int
            Tile extent applied to both dimensions.
        """
        dom = tiledb.Domain(
            tiledb.Dim("x", (0, 100), dtype=np.uint64, tile=tile),
            tiledb.Dim("y", (0, 500), dtype=np.uint64, tile=tile))
        schema = tiledb.ArraySchema(
            attrs=(tiledb.Attr("attr1"),
                   tiledb.Attr("attr2")),
            domain=dom)

        tiledb.DenseArray.create(uri, schema)

    def test_dask_multiattr_2d(self):
        """Read one attribute of a two-attribute array under several schedulers."""
        uri = self.path("multiattr")

        self._make_multiattr_2d(uri)

        with tiledb.DenseArray(uri, 'w') as T:
            ar1 = np.random.randn(*T.schema.shape)
            ar2 = np.random.randn(*T.schema.shape)

            T[:] = {'attr1': ar1,
                    'attr2': ar2}

        # basic round-trip from dask.array, using an open array handle
        with tiledb.DenseArray(uri, 'r') as T:
            D = da.from_tiledb(T, attribute='attr2')
            assert_array_equal(ar2, np.array(D))

        # smoke-test computation
        # note: re-init from_tiledb each time, or else dask just uses the
        # cached materialization
        D = da.from_tiledb(uri, attribute='attr2')
        self.assertAlmostEqual(
            np.mean(ar2),
            D.mean().compute(scheduler='threads', num_workers=4))
        D = da.from_tiledb(uri, attribute='attr2')
        self.assertAlmostEqual(
            np.mean(ar2), D.mean().compute(scheduler='single-threaded'))
        # NOTE: the 'processes' scheduler is not exercised by this test.

        # test dask.distributed
        # Imported here so that only this test requires dask.distributed.
        from dask.distributed import Client
        D = da.from_tiledb(uri, attribute='attr2')
        with Client():
            assert_approx_equal(D.mean().compute(), np.mean(ar2))

    def test_dask_write(self):
        """Write a dask array with ``to_tiledb`` and read it back."""
        uri = self.path("dask_w")
        # The shape must be passed as a tuple: a second positional argument
        # to da.random.random is taken as ``chunks``, which would yield a
        # 1-D array of length 10 instead of a 10 x 10 one.
        D = da.random.random((10, 10), chunks=(10, 10))
        D.to_tiledb(uri)
        DT = da.from_tiledb(uri)
        assert_array_equal(D, DT)

0 comments on commit 617a6ae

Please sign in to comment.