Move database inspection out of Datasette.inspect() into datasette/inspect.py.

Review notes on this revision of the patch:
* Line breaks and indentation were restored from a whitespace-collapsed copy;
  check context and removed lines against the base revision before applying.
* Hunk line counts reflect the edited added lines; the post-image blob ids on
  the "index" lines predate those edits.
* The datasette/views/base.py hunk is limited to the lines that were visible.

diff --git a/datasette/app.py b/datasette/app.py
index e4ce6622d1..f6abe3583f 100644
--- a/datasette/app.py
+++ b/datasette/app.py
@@ -15,27 +15,24 @@
 from sanic import Sanic, response
 from sanic.exceptions import InvalidUsage, NotFound

-from datasette.views.base import (
-    HASH_BLOCK_SIZE,
+from .views.base import (
     DatasetteError,
     RenderMixin,
     ureg
 )
-from datasette.views.database import DatabaseDownload, DatabaseView
-from datasette.views.index import IndexView
-from datasette.views.table import RowView, TableView
+from .views.database import DatabaseDownload, DatabaseView
+from .views.index import IndexView
+from .views.table import RowView, TableView

 from . import hookspecs
 from .utils import (
-    detect_fts,
-    detect_spatialite,
     escape_css_string,
     escape_sqlite,
-    get_all_foreign_keys,
     get_plugins,
     module_from_path,
     to_css_class
 )
+from .inspect import inspect_hash, inspect_views, inspect_tables
 from .version import __version__

 app_root = Path(__file__).parent.parent
@@ -209,144 +206,36 @@ def prepare_connection(self, conn):
         pm.hook.prepare_connection(conn=conn)

     def inspect(self):
-        if not self._inspect:
-            self._inspect = {}
-            for filename in self.files:
-                path = Path(filename)
-                name = path.stem
-                if name in self._inspect:
-                    raise Exception("Multiple files with same stem %s" % name)
-
-                # Calculate hash, efficiently
-                m = hashlib.sha256()
-                with path.open("rb") as fp:
-                    while True:
-                        data = fp.read(HASH_BLOCK_SIZE)
-                        if not data:
-                            break
-
-                        m.update(data)
-                # List tables and their row counts
-                database_metadata = self.metadata.get("databases", {}).get(name, {})
-                tables = {}
-                views = []
-                with sqlite3.connect(
-                    "file:{}?immutable=1".format(path), uri=True
-                ) as conn:
-                    self.prepare_connection(conn)
-                    table_names = [
-                        r["name"]
-                        for r in conn.execute(
-                            'select * from sqlite_master where type="table"'
-                        )
-                    ]
-                    views = [
-                        v[0]
-                        for v in conn.execute(
-                            'select name from sqlite_master where type = "view"'
-                        )
-                    ]
-                    for table in table_names:
-                        try:
-                            count = conn.execute(
-                                "select count(*) from {}".format(escape_sqlite(table))
-                            ).fetchone()[
-                                0
-                            ]
-                        except sqlite3.OperationalError:
-                            # This can happen when running against a FTS virtual tables
-                            # e.g. "select count(*) from some_fts;"
-                            count = 0
-                        # Does this table have a FTS table?
-                        fts_table = detect_fts(conn, table)
-
-                        # Figure out primary keys
-                        table_info_rows = [
-                            row
-                            for row in conn.execute(
-                                'PRAGMA table_info("{}")'.format(table)
-                            ).fetchall()
-                            if row[-1]
-                        ]
-                        table_info_rows.sort(key=lambda row: row[-1])
-                        primary_keys = [str(r[1]) for r in table_info_rows]
-                        label_column = None
-                        # If table has two columns, one of which is ID, then label_column is the other one
-                        column_names = [
-                            r[1]
-                            for r in conn.execute(
-                                "PRAGMA table_info({});".format(escape_sqlite(table))
-                            ).fetchall()
-                        ]
-                        if (
-                            column_names
-                            and len(column_names) == 2
-                            and "id" in column_names
-                        ):
-                            label_column = [c for c in column_names if c != "id"][0]
-                        table_metadata = database_metadata.get("tables", {}).get(
-                            table, {}
-                        )
-                        tables[table] = {
-                            "name": table,
-                            "columns": column_names,
-                            "primary_keys": primary_keys,
-                            "count": count,
-                            "label_column": label_column,
-                            "hidden": table_metadata.get("hidden") or False,
-                            "fts_table": fts_table,
-                        }
-
-                    foreign_keys = get_all_foreign_keys(conn)
-                    for table, info in foreign_keys.items():
-                        tables[table]["foreign_keys"] = info
-
-                    # Mark tables 'hidden' if they relate to FTS virtual tables
-                    hidden_tables = [
-                        r["name"]
-                        for r in conn.execute(
-                            """
-                                select name from sqlite_master
-                                where rootpage = 0
-                                and sql like '%VIRTUAL TABLE%USING FTS%'
-                            """
-                        )
-                    ]
-
-                    if detect_spatialite(conn):
-                        # Also hide Spatialite internal tables
-                        hidden_tables += [
-                            "ElementaryGeometries",
-                            "SpatialIndex",
-                            "geometry_columns",
-                            "spatial_ref_sys",
-                            "spatialite_history",
-                            "sql_statements_log",
-                            "sqlite_sequence",
-                            "views_geometry_columns",
-                            "virts_geometry_columns",
-                        ] + [
-                            r["name"]
-                            for r in conn.execute(
-                                """
-                                    select name from sqlite_master
-                                    where name like "idx_%"
-                                    and type = "table"
-                                """
-                            )
-                        ]
-
-                    for t in tables.keys():
-                        for hidden_table in hidden_tables:
-                            if t == hidden_table or t.startswith(hidden_table):
-                                tables[t]["hidden"] = True
-                                continue
-
+        """Inspect every configured database file and cache the result.
+
+        Returns a dict keyed by file stem; each value holds that file's
+        "hash", "file" path, "views" and "tables".
+        """
+        if self._inspect:
+            return self._inspect
+
+        self._inspect = {}
+        for filename in self.files:
+            path = Path(filename)
+            name = path.stem
+            if name in self._inspect:
+                raise Exception("Multiple files with same stem %s" % name)
+
+            database_metadata = self.metadata.get("databases", {}).get(name, {})
+            # sqlite3's connection context manager only scopes a transaction;
+            # it never closes the connection, so close it explicitly.
+            conn = sqlite3.connect(
+                "file:{}?immutable=1".format(path), uri=True
+            )
+            try:
+                self.prepare_connection(conn)
                 self._inspect[name] = {
-                    "hash": m.hexdigest(),
+                    "hash": inspect_hash(path),
                     "file": str(path),
-                    "tables": tables,
-                    "views": views,
+                    "views": inspect_views(conn),
+                    "tables": inspect_tables(conn, database_metadata),
                 }
+            finally:
+                conn.close()
         return self._inspect

diff --git a/datasette/inspect.py b/datasette/inspect.py
new file mode 100644
index 0000000000..1f35fa66aa
--- /dev/null
+++ b/datasette/inspect.py
@@ -0,0 +1,147 @@
+import hashlib
+import sqlite3
+
+from .utils import detect_spatialite, detect_fts, escape_sqlite, get_all_foreign_keys
+
+
+HASH_BLOCK_SIZE = 1024 * 1024
+
+
+def inspect_hash(path):
+    " Calculate the hash of a database, efficiently. "
+    m = hashlib.sha256()
+    with path.open("rb") as fp:
+        while True:
+            data = fp.read(HASH_BLOCK_SIZE)
+            if not data:
+                break
+            m.update(data)
+
+    return m.hexdigest()
+
+
+def inspect_views(conn):
+    " List views in a database. "
+    return [v[0] for v in conn.execute('select name from sqlite_master where type = "view"')]
+
+
+def detect_label_column(column_names):
+    """ Detect the label column - which we display as the label for a joined column.
+
+    If a table has two columns, one of which is ID, then label_column is the other one.
+    """
+    if column_names and len(column_names) == 2 and "id" in column_names:
+        return [c for c in column_names if c != "id"][0]
+
+    return None
+
+
+def detect_primary_keys(conn, table):
+    " Figure out primary keys for a table. "
+    # Quote the name with escape_sqlite(), exactly as inspect_tables() does for
+    # its own table_info query, rather than splicing it between raw quotes.
+    table_info_rows = [
+        row
+        for row in conn.execute(
+            "PRAGMA table_info({});".format(escape_sqlite(table))
+        ).fetchall()
+        if row[-1]
+    ]
+    # The last table_info column is "pk": 0 for non-key columns, otherwise the
+    # 1-based position of the column within the primary key.
+    table_info_rows.sort(key=lambda row: row[-1])
+    return [str(r[1]) for r in table_info_rows]
+
+
+def inspect_tables(conn, database_metadata):
+    """ List tables with their columns, primary keys and row counts.
+
+    Rows must be addressable by column name (r["name"]). Tables related to FTS
+    virtual tables or to Spatialite internals stay in the result, flagged with
+    "hidden": True.
+    """
+    tables = {}
+    table_names = [
+        r["name"]
+        for r in conn.execute(
+            'select * from sqlite_master where type="table"'
+        )
+    ]
+
+    for table in table_names:
+        table_metadata = database_metadata.get("tables", {}).get(
+            table, {}
+        )
+
+        try:
+            count = conn.execute(
+                "select count(*) from {}".format(escape_sqlite(table))
+            ).fetchone()[0]
+        except sqlite3.OperationalError:
+            # This can happen when running against a FTS virtual table
+            # e.g. "select count(*) from some_fts;"
+            count = 0
+
+        column_names = [
+            r[1]
+            for r in conn.execute(
+                "PRAGMA table_info({});".format(escape_sqlite(table))
+            ).fetchall()
+        ]
+
+        tables[table] = {
+            "name": table,
+            "columns": column_names,
+            "primary_keys": detect_primary_keys(conn, table),
+            "count": count,
+            "label_column": detect_label_column(column_names),
+            "hidden": table_metadata.get("hidden") or False,
+            "fts_table": detect_fts(conn, table),
+        }
+
+    foreign_keys = get_all_foreign_keys(conn)
+    for table, info in foreign_keys.items():
+        tables[table]["foreign_keys"] = info
+
+    # Mark tables 'hidden' if they relate to FTS virtual tables
+    hidden_tables = [
+        r["name"]
+        for r in conn.execute(
+            """
+                select name from sqlite_master
+                where rootpage = 0
+                and sql like '%VIRTUAL TABLE%USING FTS%'
+            """
+        )
+    ]
+
+    if detect_spatialite(conn):
+        # Also hide Spatialite internal tables
+        hidden_tables += [
+            "ElementaryGeometries",
+            "SpatialIndex",
+            "geometry_columns",
+            "spatial_ref_sys",
+            "spatialite_history",
+            "sql_statements_log",
+            "sqlite_sequence",
+            "views_geometry_columns",
+            "virts_geometry_columns",
+        ] + [
+            r["name"]
+            for r in conn.execute(
+                """
+                    select name from sqlite_master
+                    where name like "idx_%"
+                    and type = "table"
+                """
+            )
+        ]
+
+    for t in tables:
+        for hidden_table in hidden_tables:
+            if t == hidden_table or t.startswith(hidden_table):
+                tables[t]["hidden"] = True
+                break  # one matching prefix is enough
+
+    return tables
diff --git a/datasette/views/base.py b/datasette/views/base.py
index d950fa73cb..997350ddbc 100644
--- a/datasette/views/base.py
+++ b/datasette/views/base.py
@@ -25,5 +25,4 @@
 connections = threading.local()
 ureg = pint.UnitRegistry()

-HASH_BLOCK_SIZE = 1024 * 1024
 HASH_LENGTH = 7