Skip to content

Commit

Permalink
Refactor inspect logic
Browse files Browse the repository at this point in the history
  • Loading branch information
russss authored and Simon Willison committed May 22, 2018
1 parent d59366d commit 58b5a37
Show file tree
Hide file tree
Showing 3 changed files with 161 additions and 145 deletions.
167 changes: 23 additions & 144 deletions datasette/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,27 +15,24 @@
from sanic import Sanic, response
from sanic.exceptions import InvalidUsage, NotFound

from datasette.views.base import (
HASH_BLOCK_SIZE,
from .views.base import (
DatasetteError,
RenderMixin,
ureg
)
from datasette.views.database import DatabaseDownload, DatabaseView
from datasette.views.index import IndexView
from datasette.views.table import RowView, TableView
from .views.database import DatabaseDownload, DatabaseView
from .views.index import IndexView
from .views.table import RowView, TableView

from . import hookspecs
from .utils import (
detect_fts,
detect_spatialite,
escape_css_string,
escape_sqlite,
get_all_foreign_keys,
get_plugins,
module_from_path,
to_css_class
)
from .inspect import inspect_hash, inspect_views, inspect_tables
from .version import __version__

app_root = Path(__file__).parent.parent
Expand Down Expand Up @@ -209,144 +206,26 @@ def prepare_connection(self, conn):
pm.hook.prepare_connection(conn=conn)

def inspect(self):
if not self._inspect:
self._inspect = {}
for filename in self.files:
path = Path(filename)
name = path.stem
if name in self._inspect:
raise Exception("Multiple files with same stem %s" % name)

# Calculate hash, efficiently
m = hashlib.sha256()
with path.open("rb") as fp:
while True:
data = fp.read(HASH_BLOCK_SIZE)
if not data:
break

m.update(data)
# List tables and their row counts
database_metadata = self.metadata.get("databases", {}).get(name, {})
tables = {}
views = []
with sqlite3.connect(
"file:{}?immutable=1".format(path), uri=True
) as conn:
self.prepare_connection(conn)
table_names = [
r["name"]
for r in conn.execute(
'select * from sqlite_master where type="table"'
)
]
views = [
v[0]
for v in conn.execute(
'select name from sqlite_master where type = "view"'
)
]
for table in table_names:
try:
count = conn.execute(
"select count(*) from {}".format(escape_sqlite(table))
).fetchone()[
0
]
except sqlite3.OperationalError:
# This can happen when running against a FTS virtual tables
# e.g. "select count(*) from some_fts;"
count = 0
# Does this table have a FTS table?
fts_table = detect_fts(conn, table)

# Figure out primary keys
table_info_rows = [
row
for row in conn.execute(
'PRAGMA table_info("{}")'.format(table)
).fetchall()
if row[-1]
]
table_info_rows.sort(key=lambda row: row[-1])
primary_keys = [str(r[1]) for r in table_info_rows]
label_column = None
# If table has two columns, one of which is ID, then label_column is the other one
column_names = [
r[1]
for r in conn.execute(
"PRAGMA table_info({});".format(escape_sqlite(table))
).fetchall()
]
if (
column_names
and len(column_names) == 2
and "id" in column_names
):
label_column = [c for c in column_names if c != "id"][0]
table_metadata = database_metadata.get("tables", {}).get(
table, {}
)
tables[table] = {
"name": table,
"columns": column_names,
"primary_keys": primary_keys,
"count": count,
"label_column": label_column,
"hidden": table_metadata.get("hidden") or False,
"fts_table": fts_table,
}

foreign_keys = get_all_foreign_keys(conn)
for table, info in foreign_keys.items():
tables[table]["foreign_keys"] = info

# Mark tables 'hidden' if they relate to FTS virtual tables
hidden_tables = [
r["name"]
for r in conn.execute(
"""
select name from sqlite_master
where rootpage = 0
and sql like '%VIRTUAL TABLE%USING FTS%'
"""
)
]

if detect_spatialite(conn):
# Also hide Spatialite internal tables
hidden_tables += [
"ElementaryGeometries",
"SpatialIndex",
"geometry_columns",
"spatial_ref_sys",
"spatialite_history",
"sql_statements_log",
"sqlite_sequence",
"views_geometry_columns",
"virts_geometry_columns",
] + [
r["name"]
for r in conn.execute(
"""
select name from sqlite_master
where name like "idx_%"
and type = "table"
"""
)
]

for t in tables.keys():
for hidden_table in hidden_tables:
if t == hidden_table or t.startswith(hidden_table):
tables[t]["hidden"] = True
continue

" Inspect the database and return a dictionary of table metadata "
if self._inspect:
return self._inspect

self._inspect = {}
for filename in self.files:
path = Path(filename)
name = path.stem
if name in self._inspect:
raise Exception("Multiple files with same stem %s" % name)

with sqlite3.connect(
"file:{}?immutable=1".format(path), uri=True
) as conn:
self.prepare_connection(conn)
self._inspect[name] = {
"hash": m.hexdigest(),
"hash": inspect_hash(path),
"file": str(path),
"tables": tables,
"views": views,
"views": inspect_views(conn),
"tables": inspect_tables(conn, self.metadata.get("databases", {}).get(name, {}))
}
return self._inspect

Expand Down
138 changes: 138 additions & 0 deletions datasette/inspect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
import hashlib
import sqlite3

from .utils import detect_spatialite, detect_fts, escape_sqlite, get_all_foreign_keys


HASH_BLOCK_SIZE = 1024 * 1024


def inspect_hash(path):
    """Return the SHA-256 hex digest of the file at ``path``.

    The file is consumed in chunks of HASH_BLOCK_SIZE bytes, so the whole
    database never has to be held in memory at once. ``path`` must offer an
    ``open("rb")`` method, as ``pathlib.Path`` does.
    """
    digest = hashlib.sha256()
    with path.open("rb") as handle:
        # Two-argument iter() keeps calling read() until it yields b"" (EOF).
        for chunk in iter(lambda: handle.read(HASH_BLOCK_SIZE), b""):
            digest.update(chunk)
    return digest.hexdigest()


def inspect_views(conn):
    """List the names of the views in a database.

    ``conn`` is an open sqlite3 connection. Returns a list of view names
    read from ``sqlite_master``.
    """
    # 'view' is single-quoted on purpose: SQLite reads double quotes as
    # identifier quoting and only falls back to a string literal as a
    # legacy quirk, which can be turned off when SQLite is compiled.
    return [
        row[0]
        for row in conn.execute(
            "select name from sqlite_master where type = 'view'"
        )
    ]


def detect_label_column(column_names):
    """Pick the column used as the display label for a joined column.

    A table qualifies only when it has exactly two columns and one of them
    is named "id"; the label column is then the other one. Returns None for
    every other shape of table, including an empty or missing column list.
    """
    if not column_names:
        return None
    if len(column_names) != 2 or "id" not in column_names:
        return None

    non_id_columns = [name for name in column_names if name != "id"]
    return non_id_columns[0]


def detect_primary_keys(conn, table):
    """Return the primary key column names of ``table``, in key order.

    ``conn`` is an open sqlite3 connection and ``table`` the unquoted table
    name. Tables without a declared primary key yield an empty list.
    """
    # Double any embedded double quote so the name stays a single quoted
    # identifier; otherwise a table called e.g. we"ird breaks the PRAGMA.
    quoted_table = '"{}"'.format(table.replace('"', '""'))
    # The last field of each table_info row is the column's position in the
    # primary key; it is falsy (0) for columns that are not part of the key.
    pk_rows = [
        row
        for row in conn.execute(
            "PRAGMA table_info({})".format(quoted_table)
        ).fetchall()
        if row[-1]
    ]
    # Order compound keys by their position within the key, not column order.
    pk_rows.sort(key=lambda row: row[-1])
    return [str(row[1]) for row in pk_rows]


def inspect_tables(conn, database_metadata):
    """List tables with their row counts and flag uninteresting ones as hidden.

    ``conn`` is an open sqlite3 connection whose rows support lookup by
    column name (rows are read as ``r["name"]``). ``database_metadata`` is
    the metadata dictionary for this database; its optional "tables" entry
    maps table names to per-table settings such as "hidden".

    Returns a dictionary keyed by table name. Each value holds "name",
    "columns", "primary_keys", "count", "label_column", "hidden" and
    "fts_table", plus "foreign_keys" for every table reported by
    get_all_foreign_keys().
    """
    tables = {}
    # SQL string literals are single-quoted throughout: SQLite only accepts
    # double-quoted strings as a legacy fallback that can be compiled out.
    table_names = [
        r["name"]
        for r in conn.execute(
            "select * from sqlite_master where type='table'"
        )
    ]

    for table in table_names:
        table_metadata = database_metadata.get("tables", {}).get(
            table, {}
        )

        try:
            count = conn.execute(
                "select count(*) from {}".format(escape_sqlite(table))
            ).fetchone()[0]
        except sqlite3.OperationalError:
            # This can happen when running against a FTS virtual table
            # e.g. "select count(*) from some_fts;"
            count = 0

        column_names = [
            r[1]
            for r in conn.execute(
                "PRAGMA table_info({});".format(escape_sqlite(table))
            ).fetchall()
        ]

        tables[table] = {
            "name": table,
            "columns": column_names,
            "primary_keys": detect_primary_keys(conn, table),
            "count": count,
            "label_column": detect_label_column(column_names),
            "hidden": table_metadata.get("hidden") or False,
            "fts_table": detect_fts(conn, table),
        }

    foreign_keys = get_all_foreign_keys(conn)
    for table, info in foreign_keys.items():
        tables[table]["foreign_keys"] = info

    # Mark tables 'hidden' if they relate to FTS virtual tables
    hidden_tables = [
        r["name"]
        for r in conn.execute(
            """
                select name from sqlite_master
                where rootpage = 0
                and sql like '%VIRTUAL TABLE%USING FTS%'
            """
        )
    ]

    if detect_spatialite(conn):
        # Also hide Spatialite internal tables
        hidden_tables += [
            "ElementaryGeometries",
            "SpatialIndex",
            "geometry_columns",
            "spatial_ref_sys",
            "spatialite_history",
            "sql_statements_log",
            "sqlite_sequence",
            "views_geometry_columns",
            "virts_geometry_columns",
        ] + [
            r["name"]
            for r in conn.execute(
                """
                    select name from sqlite_master
                    where name like 'idx_%'
                    and type = 'table'
                """
            )
        ]

    # A table is hidden when its name equals, or merely starts with, any
    # hidden name. str.startswith() accepts a tuple of prefixes and covers
    # the exact-match case too, so one call per table is enough.
    hidden_prefixes = tuple(hidden_tables)
    for table_name, table_info in tables.items():
        if table_name.startswith(hidden_prefixes):
            table_info["hidden"] = True

    return tables
1 change: 0 additions & 1 deletion datasette/views/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
connections = threading.local()
ureg = pint.UnitRegistry()

HASH_BLOCK_SIZE = 1024 * 1024
HASH_LENGTH = 7


Expand Down

0 comments on commit 58b5a37

Please sign in to comment.