Cache source dataframes in PandasDatasource - see apache#3302
rhunwicks committed Oct 5, 2017
1 parent 1df6237 commit 878c7c4
Showing 4 changed files with 369 additions and 13 deletions.
188 changes: 188 additions & 0 deletions contrib/connectors/pandas/cache.py
@@ -0,0 +1,188 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from io import open
import json
import os
import tempfile
from time import time
try:
    import cPickle as pickle
except ImportError:  # pragma: no cover
    import pickle

import pandas as pd

from werkzeug.contrib.cache import FileSystemCache
from werkzeug.posixemulation import rename


class DataFrameCache(FileSystemCache):
    """
    A cache that stores Pandas DataFrames on the file system.

    DataFrames are stored in Feather format, a fast on-disk representation
    of the Apache Arrow in-memory format, to eliminate serialization
    overhead. DataFrames that cannot be written as Feather are stored as
    HDF5 instead, or pickled if PyTables is not available.

    This cache depends on being the only user of the `cache_dir`. Make
    absolutely sure that nothing but this cache stores files there,
    otherwise the cache may delete files it did not create.

    :param cache_dir: the directory where cache files are stored.
    :param threshold: the maximum number of items the cache stores before
        it starts deleting some.
    :param default_timeout: the default timeout that is used if no timeout is
        specified on :meth:`~BaseCache.set`. A timeout of
        0 indicates that the cache never expires.
    :param mode: the file mode wanted for the cache files, default 0600
    """

    _fs_cache_suffix = '.cached'
    _fs_metadata_suffix = '.metadata'

    def _list_dir(self):
        """return a list of (fully qualified) cache filenames
        """
        return [os.path.join(self._path, fn) for fn in os.listdir(self._path)
                if fn.endswith(self._fs_cache_suffix)]

    def _prune(self):
        entries = self._list_dir()
        if len(entries) > self._threshold:
            now = time()
            for idx, cname in enumerate(entries):
                mname = os.path.splitext(cname)[0] + self._fs_metadata_suffix
                try:
                    with open(mname, 'r', encoding='utf-8') as f:
                        metadata = json.load(f)
                except (IOError, OSError):
                    metadata = {'expires': -1}
                try:
                    remove = ((metadata['expires'] != 0 and metadata['expires'] <= now)
                              or idx % 3 == 0)
                    if remove:
                        os.remove(cname)
                        os.remove(mname)
                except (IOError, OSError):
                    pass

    def clear(self):
        for cname in self._list_dir():
            try:
                mname = os.path.splitext(cname)[0] + self._fs_metadata_suffix
                os.remove(cname)
                os.remove(mname)
            except (IOError, OSError):
                return False
        return True

    def get(self, key):
        filename = self._get_filename(key)
        cname = filename + self._fs_cache_suffix
        mname = filename + self._fs_metadata_suffix
        try:
            with open(mname, 'r', encoding='utf-8') as f:
                metadata = json.load(f)
        except (IOError, OSError):
            metadata = {'expires': -1}
        try:
            with open(cname, 'rb') as f:
                if metadata['expires'] == 0 or metadata['expires'] > time():
                    read_method = getattr(pd, 'read_{}'.format(metadata['format']))
                    read_args = metadata.get('read_args', {})
                    if metadata['format'] == 'hdf':
                        return read_method(f.name, **read_args)
                    else:
                        return read_method(f, **read_args)
                else:
                    os.remove(cname)
                    os.remove(mname)
                    return None
        except (IOError, OSError):
            return None

    def add(self, key, value, timeout=None):
        filename = self._get_filename(key) + self._fs_cache_suffix
        if not os.path.exists(filename):
            return self.set(key, value, timeout)
        return False

    def set(self, key, value, timeout=None):
        metadata = {'expires': self._normalize_timeout(timeout)}
        filename = self._get_filename(key)
        cname = filename + self._fs_cache_suffix
        mname = filename + self._fs_metadata_suffix
        self._prune()
        try:
            fd, tmp = tempfile.mkstemp(suffix=self._fs_transaction_suffix,
                                       dir=self._path)
            with os.fdopen(fd, 'wb') as f:
                try:
                    value.to_feather(f)
                    metadata['format'] = 'feather'
                except ValueError:
                    try:
                        value.to_hdf(tmp, 'df')
                        metadata['format'] = 'hdf'
                        metadata['read_args'] = {'key': 'df'}
                    except ImportError:
                        # PyTables is not installed, so fall back to pickle
                        pickle.dump(value, f, pickle.HIGHEST_PROTOCOL)
                        metadata['format'] = 'pickle'
            rename(tmp, cname)
            os.chmod(cname, self._mode)
            with open(mname, 'w', encoding='utf-8') as f:
                json.dump(metadata, f)
            os.chmod(mname, self._mode)
        except (IOError, OSError):
            return False
        else:
            return True

    def delete(self, key):
        filename = self._get_filename(key)
        cname = filename + self._fs_cache_suffix
        mname = filename + self._fs_metadata_suffix
        try:
            os.remove(cname)
            os.remove(mname)
        except (IOError, OSError):
            return False
        else:
            return True

    def has(self, key):
        filename = self._get_filename(key)
        cname = filename + self._fs_cache_suffix
        mname = filename + self._fs_metadata_suffix
        try:
            with open(mname, 'r', encoding='utf-8') as f:
                metadata = json.load(f)
        except (IOError, OSError):
            metadata = {'expires': -1}
        try:
            with open(cname, 'rb') as f:
                if metadata['expires'] == 0 or metadata['expires'] > time():
                    return True
                else:
                    os.remove(cname)
                    os.remove(mname)
                    return False
        except (IOError, OSError):
            return False

    def inc(self, key, delta=1):
        raise NotImplementedError()

    def dec(self, key, delta=1):
        raise NotImplementedError()


dataframe_cache = DataFrameCache(
    cache_dir='/tmp/pandasdatasource_cache',
    threshold=200,
    default_timeout=24 * 60 * 60,
)
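
For illustration, here is a minimal sketch of how the cache above is intended to be used. The import path, cache key and sample DataFrame are hypothetical and only stand in for whatever the connector passes in.

# Illustrative sketch only -- the import path, key and data below are made up.
import pandas as pd

from cache import dataframe_cache  # assumes cache.py is importable as `cache`

df = pd.DataFrame({'name': ['alpha', 'beta'], 'value': [1, 2]})

# Store the DataFrame under a key for 60 seconds; returns True on success,
# False if writing the cache files failed.
dataframe_cache.set('example-key', df, timeout=60)

# get() returns the DataFrame if it is present and unexpired, otherwise None.
cached = dataframe_cache.get('example-key')
if cached is not None:
    print(cached.head())

dataframe_cache.delete('example-key')
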
50 changes: 37 additions & 13 deletions contrib/connectors/pandas/models.py
@@ -5,6 +5,7 @@

from collections import OrderedDict
from datetime import datetime
import hashlib
import logging
from past.builtins import basestring
try:
@@ -33,6 +34,7 @@
from superset.models.helpers import QueryResult, set_perm
from superset.utils import QueryStatus

from .cache import dataframe_cache

FORMATS = [
    ('csv', 'csv'),
@@ -196,8 +198,11 @@ def full_name(self):

    @property
    def database(self):
        uri = urlparse(self.source_url)
        return PandasDatabase(database_name=uri.netloc,
        try:
            database_name = urlparse(self.source_url).netloc
        except AttributeError:
            database_name = 'memory'
        return PandasDatabase(database_name=database_name,
                              cache_timeout=None)

    @property
@@ -291,21 +296,38 @@ def get_empty_dataframe(self):
            columns.append((col.column_name, type))
        return pd.DataFrame({k: pd.Series(dtype=t) for k, t in columns})

    @property
    def cache_key(self):
        source = {'source_url': self.source_url}
        source.update(self.pandas_read_parameters)
        s = str([(k, source[k]) for k in sorted(source.keys())])
        return hashlib.md5(s.encode('utf-8')).hexdigest()

    def get_dataframe(self):
        """
        Read the source_url and return a Pandas DataFrame.
        Use the PandasColumns to coerce columns into the correct dtype,
        and add any calculated columns to the DataFrame.
        """
        calculated_columns = []
        if self.df is None:
            self.df = self.pandas_read_method(self.source_url,
                                              **self.pandas_read_parameters)
            # read_html returns a list of DataFrames
            if (isinstance(self.df, list) and
                    isinstance(self.df[0], pd.DataFrame)):
                self.df = self.df[0]
        cache_key = self.cache_key
        self.df = dataframe_cache.get(cache_key)
        if not isinstance(self.df, pd.DataFrame):
            self.df = self.pandas_read_method(self.source_url, **self.pandas_read_parameters)

            # read_html returns a list of DataFrames
            if (isinstance(self.df, list) and
                    isinstance(self.df[0], pd.DataFrame)):
                self.df = self.df[0]

            # Our column names are always strings
            self.df.columns = [str(col) for col in self.df.columns]

            timeout = self.cache_timeout or self.database.cache_timeout
            dataframe_cache.set(cache_key, self.df, timeout)

        calculated_columns = []
        for col in self.columns:
            name = col.column_name
            type = col.type
@@ -379,7 +401,7 @@ def get_agg_function(self, expr):
        The function can be defined on the Connector, or on the DataFrame,
        in the local scope
        """
        if expr in ['sum', 'mean', 'std', 'sem', 'count']:
        if expr in ['sum', 'mean', 'std', 'sem', 'count', 'min', 'max']:
            return expr
        if hasattr(self, expr):
            return getattr(self, expr)
@@ -746,17 +768,19 @@ def get_metadata(self):
        dbcols = (
            db.session.query(PandasColumn)
            .filter(PandasColumn.datasource == self)
            .filter(or_(PandasColumn.column_name == col.name
            .filter(or_(PandasColumn.column_name == col
                        for col in df.columns)))
        dbcols = {dbcol.column_name: dbcol for dbcol in dbcols}
        for col in df.columns:
            dbcol = dbcols.get(col.name, None)
            dbcol = dbcols.get(col, None)
            if not dbcol:
                dbcol = PandasColumn(column_name=col, type=df.dtypes[col].name)
                dbcol = PandasColumn(column_name=str(col), type=df.dtypes[col].name)
                dbcol.groupby = dbcol.is_string
                dbcol.filterable = dbcol.is_string
                dbcol.sum = dbcol.is_num
                dbcol.avg = dbcol.is_num
                dbcol.min = dbcol.is_num or dbcol.is_dttm
                dbcol.max = dbcol.is_num or dbcol.is_dttm
            self.columns.append(dbcol)

            if not any_date_col and dbcol.is_time:
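
For reference, the cache_key property shown in the models.py diff above is just an MD5 hex digest of the sorted source parameters, so identical parameters always map to the same cache entry. A standalone sketch of that logic with made-up parameters:

# Standalone sketch of the cache_key derivation from the diff above;
# the source_url and read parameters below are hypothetical.
import hashlib

source = {'source_url': 'http://example.com/data.csv'}
source.update({'sep': ',', 'header': 0})  # stands in for pandas_read_parameters
s = str([(k, source[k]) for k in sorted(source.keys())])
cache_key = hashlib.md5(s.encode('utf-8')).hexdigest()
print(cache_key)  # the same parameters always hash to the same key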
