diff --git a/contrib/connectors/pandas/cache.py b/contrib/connectors/pandas/cache.py
new file mode 100644
index 0000000000000..15d4713f31bd3
--- /dev/null
+++ b/contrib/connectors/pandas/cache.py
@@ -0,0 +1,188 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+from io import open
+import json
+import os
+import tempfile
+from time import time
+try:
+    import cPickle as pickle
+except ImportError:  # pragma: no cover
+    import pickle
+
+import pandas as pd
+
+from werkzeug.contrib.cache import FileSystemCache
+from werkzeug.posixemulation import rename
+
+
+class DataFrameCache(FileSystemCache):
+    """
+    A cache that stores Pandas DataFrames on the file system.
+
+    DataFrames are stored in Feather Format - a fast on-disk representation
+    of the Apache Arrow in-memory format to eliminate serialization
+    overhead.
+
+    This cache depends on being the only user of the `cache_dir`. Make
+    absolutely sure that nobody but this cache stores files there or
+    otherwise the cache will randomly delete files therein.
+
+    :param cache_dir: the directory where cache files are stored.
+    :param threshold: the maximum number of items the cache stores before
+                      it starts deleting some.
+    :param default_timeout: the default timeout that is used if no timeout is
+                            specified on :meth:`~BaseCache.set`. A timeout of
+                            0 indicates that the cache never expires.
+    :param mode: the file mode wanted for the cache files, default 0600
+    """
+
+    _fs_cache_suffix = '.cached'
+    _fs_metadata_suffix = '.metadata'
+
+    def _list_dir(self):
+        """return a list of (fully qualified) cache filenames
+        """
+        return [os.path.join(self._path, fn) for fn in os.listdir(self._path)
+                if fn.endswith(self._fs_cache_suffix)]
+
+    def _prune(self):
+        entries = self._list_dir()
+        if len(entries) > self._threshold:
+            now = time()
+            for idx, cname in enumerate(entries):
+                mname = os.path.splitext(cname)[0] + self._fs_metadata_suffix
+                try:
+                    with open(mname, 'r', encoding='utf-8') as f:
+                        metadata = json.load(f)
+                except (IOError, OSError):
+                    metadata = {'expires': -1}
+                try:
+                    remove = ((metadata['expires'] != 0 and
+                               metadata['expires'] <= now) or idx % 3 == 0)
+                    if remove:
+                        os.remove(cname)
+                        os.remove(mname)
+                except (IOError, OSError):
+                    pass
+
+    def clear(self):
+        for cname in self._list_dir():
+            try:
+                mname = os.path.splitext(cname)[0] + self._fs_metadata_suffix
+                os.remove(cname)
+                os.remove(mname)
+            except (IOError, OSError):
+                return False
+        return True
+
+    def get(self, key):
+        filename = self._get_filename(key)
+        cname = filename + self._fs_cache_suffix
+        mname = filename + self._fs_metadata_suffix
+        try:
+            with open(mname, 'r', encoding='utf-8') as f:
+                metadata = json.load(f)
+        except (IOError, OSError):
+            metadata = {'expires': -1}
+        try:
+            with open(cname, 'rb') as f:
+                if metadata['expires'] == 0 or metadata['expires'] > time():
+                    read_method = getattr(pd, 'read_{}'.format(metadata['format']))
+                    read_args = metadata.get('read_args', {})
+                    if metadata['format'] == 'hdf':
+                        return read_method(f.name, **read_args)
+                    else:
+                        return read_method(f, **read_args)
+                else:
+                    os.remove(cname)
+                    os.remove(mname)
+                    return None
+        except (IOError, OSError):
+            return None
+
+    def add(self, key, value, timeout=None):
+        filename = self._get_filename(key) + self._fs_cache_suffix
+        if not os.path.exists(filename):
+            return self.set(key, value, timeout)
+        return False
+
+    def set(self, key, value, timeout=None):
+        metadata = {'expires': self._normalize_timeout(timeout)}
+        filename = self._get_filename(key)
+        cname = filename + self._fs_cache_suffix
+        mname = filename + self._fs_metadata_suffix
+        self._prune()
+        try:
+            fd, tmp = tempfile.mkstemp(suffix=self._fs_transaction_suffix,
+                                       dir=self._path)
+            with os.fdopen(fd, 'wb') as f:
+                try:
+                    value.to_feather(f)
+                    metadata['format'] = 'feather'
+                except ValueError:
+                    try:
+                        value.to_hdf(tmp, 'df')
+                        metadata['format'] = 'hdf'
+                        metadata['read_args'] = {'key': 'df'}
+                    except ImportError:
+                        # PyTables is not installed, so fall back to pickle
+                        pickle.dump(value, f, pickle.HIGHEST_PROTOCOL)
+                        metadata['format'] = 'pickle'
+            rename(tmp, cname)
+            os.chmod(cname, self._mode)
+            with open(mname, 'w', encoding='utf-8') as f:
+                json.dump(metadata, f)
+            os.chmod(mname, self._mode)
+        except (IOError, OSError):
+            return False
+        else:
+            return True
+
+    def delete(self, key):
+        filename = self._get_filename(key)
+        cname = filename + self._fs_cache_suffix
+        mname = filename + self._fs_metadata_suffix
+        try:
+            os.remove(cname)
+            os.remove(mname)
+        except (IOError, OSError):
+            return False
+        else:
+            return True
+
+    def has(self, key):
+        filename = self._get_filename(key)
+        cname = filename + self._fs_cache_suffix
+        mname = filename + self._fs_metadata_suffix
+        try:
+            with open(mname, 'r', encoding='utf-8') as f:
+                metadata = json.load(f)
+        except (IOError, OSError):
+            metadata = {'expires': -1}
+        try:
+            with open(cname, 'rb') as f:
+                if metadata['expires'] == 0 or metadata['expires'] > time():
+                    return True
+                else:
+                    os.remove(cname)
+                    os.remove(mname)
+                    return False
+        except (IOError, OSError):
+            return False
+
+    def inc(self, key, delta=1):
+        raise NotImplementedError()
+
+    def dec(self, key, delta=1):
+        raise NotImplementedError()
+
+
+dataframe_cache = DataFrameCache(
+    cache_dir='/tmp/pandasdatasource_cache',
+    threshold=200,
+    default_timeout=24 * 60 * 60,
+)
diff --git a/contrib/connectors/pandas/models.py b/contrib/connectors/pandas/models.py
index fc0509b7dbbb4..a9fdd0be9de2d 100644
--- a/contrib/connectors/pandas/models.py
+++ b/contrib/connectors/pandas/models.py
@@ -5,6 +5,7 @@
 from collections import OrderedDict
 from datetime import datetime
+import hashlib
 import logging
 from past.builtins import basestring
 try:
@@ -33,6 +34,7 @@
 from superset.models.helpers import QueryResult, set_perm
 from superset.utils import QueryStatus
+from .cache import dataframe_cache
 
 FORMATS = [
     ('csv', 'csv'),
@@ -196,8 +198,11 @@ def full_name(self):
 
     @property
     def database(self):
-        uri = urlparse(self.source_url)
-        return PandasDatabase(database_name=uri.netloc,
+        try:
+            database_name = urlparse(self.source_url).netloc
+        except AttributeError:
+            database_name = 'memory'
+        return PandasDatabase(database_name=database_name,
                               cache_timeout=None)
 
     @property
@@ -291,6 +296,13 @@ def get_empty_dataframe(self):
             columns.append((col.column_name, type))
         return pd.DataFrame({k: pd.Series(dtype=t) for k, t in columns})
 
+    @property
+    def cache_key(self):
+        source = {'source_url': self.source_url}
+        source.update(self.pandas_read_parameters)
+        s = str([(k, source[k]) for k in sorted(source.keys())])
+        return hashlib.md5(s.encode('utf-8')).hexdigest()
+
     def get_dataframe(self):
         """
         Read the source_url and return a Pandas DataFrame.
@@ -298,14 +310,24 @@
         Use the PandasColumns to coerce columns into the correct dtype,
         and add any calculated columns to the DataFrame.
""" - calculated_columns = [] if self.df is None: - self.df = self.pandas_read_method(self.source_url, - **self.pandas_read_parameters) - # read_html returns a list of DataFrames - if (isinstance(self.df, list) and - isinstance(self.df[0], pd.DataFrame)): - self.df = self.df[0] + cache_key = self.cache_key + self.df = dataframe_cache.get(cache_key) + if not isinstance(self.df, pd.DataFrame): + self.df = self.pandas_read_method(self.source_url, **self.pandas_read_parameters) + + # read_html returns a list of DataFrames + if (isinstance(self.df, list) and + isinstance(self.df[0], pd.DataFrame)): + self.df = self.df[0] + + # Our column names are always strings + self.df.columns = [str(col) for col in self.df.columns] + + timeout = self.cache_timeout or self.database.cache_timeout + dataframe_cache.set(cache_key, self.df, timeout) + + calculated_columns = [] for col in self.columns: name = col.column_name type = col.type @@ -379,7 +401,7 @@ def get_agg_function(self, expr): The function can be defined on the Connector, or on the DataFrame, in the local scope """ - if expr in ['sum', 'mean', 'std', 'sem', 'count']: + if expr in ['sum', 'mean', 'std', 'sem', 'count', 'min', 'max']: return expr if hasattr(self, expr): return getattr(self, expr) @@ -746,17 +768,19 @@ def get_metadata(self): dbcols = ( db.session.query(PandasColumn) .filter(PandasColumn.datasource == self) - .filter(or_(PandasColumn.column_name == col.name + .filter(or_(PandasColumn.column_name == col for col in df.columns))) dbcols = {dbcol.column_name: dbcol for dbcol in dbcols} for col in df.columns: - dbcol = dbcols.get(col.name, None) + dbcol = dbcols.get(col, None) if not dbcol: - dbcol = PandasColumn(column_name=col, type=df.dtypes[col].name) + dbcol = PandasColumn(column_name=str(col), type=df.dtypes[col].name) dbcol.groupby = dbcol.is_string dbcol.filterable = dbcol.is_string dbcol.sum = dbcol.is_num dbcol.avg = dbcol.is_num + dbcol.min = dbcol.is_num or dbcol.is_dttm + dbcol.max = dbcol.is_num or dbcol.is_dttm self.columns.append(dbcol) if not any_date_col and dbcol.is_time: diff --git a/contrib/tests/cache_tests.py b/contrib/tests/cache_tests.py new file mode 100644 index 0000000000000..825971aea0dcc --- /dev/null +++ b/contrib/tests/cache_tests.py @@ -0,0 +1,143 @@ +"""Unit tests for DataFrameCache""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import datetime +import os +import shutil +import tempfile +import time + +import pandas as pd +from pandas.testing import assert_frame_equal + +from tests.base_tests import SupersetTestCase + +from contrib.connectors.pandas.cache import DataFrameCache + + +class DataFrameCacheTestCase(SupersetTestCase): + + def setUp(self): + self.cache = DataFrameCache(cache_dir=tempfile.mkdtemp()) + + def tearDown(self): + shutil.rmtree(self.cache._path) + + def get_df(self, key): + return pd.DataFrame({'one': pd.Series([1, 2, 3]), + key: pd.Series([1, 2, 3, 4])}) + + def test_get_dict(self): + a = self.get_df('a') + b = self.get_df('b') + assert self.cache.set('a', a) + assert self.cache.set('b', b) + d = self.cache.get_dict('a', 'b') + assert 'a' in d + assert_frame_equal(a, d['a']) + assert_frame_equal(b, d['b']) + + def test_set_get(self): + for i in range(3): + assert self.cache.set(str(i), self.get_df(str(i))) + + for i in range(3): + assert_frame_equal(self.cache.get(str(i)), self.get_df(str(i))) + + def test_get_set(self): + assert self.cache.set('foo', self.get_df('bar')) + 
+        assert_frame_equal(self.cache.get('foo'), self.get_df('bar'))
+
+    def test_get_many(self):
+        assert self.cache.set('foo', self.get_df('bar'))
+        assert self.cache.set('spam', self.get_df('eggs'))
+        result = list(self.cache.get_many('foo', 'spam'))
+        assert_frame_equal(result[0], self.get_df('bar'))
+        assert_frame_equal(result[1], self.get_df('eggs'))
+
+    def test_set_many(self):
+        assert self.cache.set_many({'foo': self.get_df('bar'),
+                                    'spam': self.get_df('eggs')})
+        assert_frame_equal(self.cache.get('foo'), self.get_df('bar'))
+        assert_frame_equal(self.cache.get('spam'), self.get_df('eggs'))
+
+    def test_add(self):
+        # sanity check that add() works like set()
+        assert self.cache.add('foo', self.get_df('bar'))
+        assert_frame_equal(self.cache.get('foo'), self.get_df('bar'))
+        assert not self.cache.add('foo', self.get_df('qux'))
+        assert_frame_equal(self.cache.get('foo'), self.get_df('bar'))
+
+    def test_delete(self):
+        assert self.cache.add('foo', self.get_df('bar'))
+        assert_frame_equal(self.cache.get('foo'), self.get_df('bar'))
+        assert self.cache.delete('foo')
+        assert self.cache.get('foo') is None
+
+    def test_delete_many(self):
+        assert self.cache.add('foo', self.get_df('bar'))
+        assert self.cache.add('spam', self.get_df('eggs'))
+        assert self.cache.delete_many('foo', 'spam')
+        assert self.cache.get('foo') is None
+        assert self.cache.get('spam') is None
+
+    def test_timeout(self):
+        self.cache.set('foo', self.get_df('bar'), 0)
+        assert_frame_equal(self.cache.get('foo'), self.get_df('bar'))
+        self.cache.set('baz', self.get_df('qux'), 1)
+        assert_frame_equal(self.cache.get('baz'), self.get_df('qux'))
+        time.sleep(3)
+        # timeout of zero means no timeout
+        assert_frame_equal(self.cache.get('foo'), self.get_df('bar'))
+        assert self.cache.get('baz') is None
+
+    def test_has(self):
+        assert self.cache.has('foo') in (False, 0)
+        assert self.cache.has('spam') in (False, 0)
+        assert self.cache.set('foo', self.get_df('bar'))
+        assert self.cache.has('foo') in (True, 1)
+        assert self.cache.has('spam') in (False, 0)
+        self.cache.delete('foo')
+        assert self.cache.has('foo') in (False, 0)
+        assert self.cache.has('spam') in (False, 0)
+
+    def test_prune(self):
+        THRESHOLD = 13
+        c = DataFrameCache(cache_dir=tempfile.mkdtemp(),
+                           threshold=THRESHOLD)
+
+        for i in range(2 * THRESHOLD):
+            assert c.set(str(i), self.get_df(str(i)))
+
+        cache_files = os.listdir(c._path)
+        shutil.rmtree(c._path)
+
+        # There will be a small .metadata file for every cached file
+        assert len(cache_files) <= THRESHOLD * 2
+
+    def test_clear(self):
+        cache_files = os.listdir(self.cache._path)
+        assert self.cache.set('foo', self.get_df('bar'))
+        cache_files = os.listdir(self.cache._path)
+        # There will be a small .metadata file for every cached file
+        assert len(cache_files) == 2
+        assert self.cache.clear()
+        cache_files = os.listdir(self.cache._path)
+        assert len(cache_files) == 0
+
+    def test_non_feather_format(self):
+        # The Feather on-disk format isn't indexed and doesn't handle
+        # Object-type columns with non-homogeneous data
+        # See:
+        # - https://github.com/wesm/feather/tree/master/python#limitations
+        # - https://github.com/wesm/feather/issues/200
+        now = datetime.datetime.now
+        df = pd.DataFrame({'one': pd.Series([1, 2, 3], index=['a', 'b', 'c']),
+                           'two': pd.Series([1, 'string', now(), 4],
+                                            index=['a', 'b', 'c', 'd'])})
+
+        assert self.cache.set('foo', df)
+        assert_frame_equal(self.cache.get('foo'), df)
diff --git a/setup.py b/setup.py
index df81a223e1302..c208291472774 100644
--- a/setup.py
+++ b/setup.py
@@ -48,6 +48,7 @@ def get_git_sha():
         'celery==3.1.25',
         'colorama==0.3.9',
         'cryptography==1.9',
+        'feather-format==0.4.0',
         'flask==0.12.2',
         'flask-appbuilder==1.9.4',
         'flask-cache==0.13.1',
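
For reference, a minimal sketch of how the new cache is meant to behave end to end. It uses only the API added in this patch (DataFrameCache and its cache_dir/threshold/default_timeout arguments, set/get); the temporary directory, key, and sample frame are illustrative, not part of the patch.

    import tempfile

    import pandas as pd

    from contrib.connectors.pandas.cache import DataFrameCache

    # Illustrative only: a throwaway directory instead of the module-level
    # /tmp/pandasdatasource_cache instance defined in cache.py.
    cache = DataFrameCache(cache_dir=tempfile.mkdtemp(),
                           threshold=200,
                           default_timeout=24 * 60 * 60)

    df = pd.DataFrame({'a': [1, 2, 3], 'b': ['x', 'y', 'z']})

    # set() writes <key>.cached (Feather, falling back to HDF or pickle)
    # plus a small <key>.metadata JSON file holding the expiry and format.
    cache.set('example-key', df, timeout=60)

    # get() re-reads the frame with the matching pd.read_* function, or
    # returns None once the entry has expired or been pruned.
    cached = cache.get('example-key')
    assert cached is not None and cached.equals(df)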
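The PandasDatasource.cache_key added above is an md5 over the sorted source parameters, so two datasources pointing at the same URL with the same read parameters share one cache entry. A small illustration with made-up parameter values (the hashing itself mirrors the property in the patch; 'sep' and 'header' here are hypothetical read parameters):

    import hashlib

    # Hypothetical stand-ins for self.source_url and self.pandas_read_parameters.
    source = {'source_url': 'https://example.com/data.csv'}
    source.update({'sep': ',', 'header': 0})

    # Same construction as PandasDatasource.cache_key: stable string of
    # sorted (key, value) pairs, hashed to a hex digest.
    s = str([(k, source[k]) for k in sorted(source.keys())])
    cache_key = hashlib.md5(s.encode('utf-8')).hexdigest()
    print(cache_key)  # identical url + read parameters -> identical key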