diff --git a/pyroSAR/drivers.py b/pyroSAR/drivers.py index 96d457c0..3c882640 100644 --- a/pyroSAR/drivers.py +++ b/pyroSAR/drivers.py @@ -1891,7 +1891,7 @@ class Archive(object): """ def __init__(self, dbfile, custom_fields=None, postgres=False, user='postgres', - password='1234', host='localhost', port=5432, cleanup=True): + password='1234', host='localhost', port=5432, cleanup=True, add_geometry=False): # check for driver, if postgres then check if server is reachable if not postgres: self.driver = 'sqlite' @@ -1973,7 +1973,14 @@ def __init__(self, dbfile, custom_fields=None, postgres=False, user='postgres', Column('hv', Integer), Column('vh', Integer), Column('bbox', Geometry(geometry_type='POLYGON', management=True, srid=4326))) - + + colnames = self.get_colnames() if sql_inspect(self.engine).has_table('data') else [] + + if (add_geometry and not sql_inspect(self.engine).has_table('data')) or 'geometry' in colnames: + # add geometry to schema if new database is created, or database had it enabled once + self.data_schema.append_column(Column('geometry', + Geometry(geometry_type='POLYGON', management=True, srid=4326))) + # add custom fields if self.custom_fields is not None: for key, val in self.custom_fields.items(): @@ -1993,6 +2000,10 @@ def __init__(self, dbfile, custom_fields=None, postgres=False, user='postgres', if not sql_inspect(self.engine).has_table('data'): log.debug("creating DB table 'data'") self.data_schema.create(self.engine) + elif add_geometry and 'geometry' not in self.get_colnames('data'): + log.info("add_geometry has been enabled after database is already created, " + "if you want to update table 'data' to include this column, " + "run Archive.update_geometry_field() once!") if not sql_inspect(self.engine).has_table('duplicates'): log.debug("creating DB table 'duplicates'") self.duplicates_schema.create(self.engine) @@ -2008,7 +2019,72 @@ def __init__(self, dbfile, custom_fields=None, postgres=False, user='postgres', log.info('checking for 
missing scenes') self.cleanup() sys.stdout.flush() - + + def update_geometry_field(self): + """ + Add the geometry as column to an existing database, then re-ingests all scenes with added geometry column + """ + if 'geometry' not in self.get_colnames(): + # get all scenes from data table + temp_data = self.Session().query(self.Data.scene) + to_insert = [] + # save in new list + for entry in temp_data: + to_insert.append(entry[0]) + # remove old table + self.drop_table('data') + + # define new table + temp_table_schema = Table('data', self.meta, + Column('sensor', String), + Column('orbit', String), + Column('orbitNumber_abs', Integer), + Column('orbitNumber_rel', Integer), + Column('cycleNumber', Integer), + Column('frameNumber', Integer), + Column('acquisition_mode', String), + Column('start', String), + Column('stop', String), + Column('product', String), + Column('samples', Integer), + Column('lines', Integer), + Column('outname_base', String, primary_key=True), + Column('scene', String), + Column('hh', Integer), + Column('vv', Integer), + Column('hv', Integer), + Column('vh', Integer), + Column('bbox', Geometry(geometry_type='POLYGON', management=True, srid=4326)), + Column('geometry', Geometry(geometry_type='POLYGON', management=True, srid=4326))) + + # create table + log.debug("creating DB table 'data' with geometry field") + temp_table_schema.create(self.engine) + + # update base + self.Base = automap_base(metadata=self.meta) + self.Base.prepare(self.engine, reflect=True) + self.Data = self.Base.classes.data + # insert previous data in new table + self.insert(to_insert) + + def get_class_by_tablename(self, table): + """Return class reference mapped to table. + adapted from OrangeTux's comment on + https://stackoverflow.com/questions/11668355/sqlalchemy-get-model-from-table-name-this-may-imply-appending-some-function-to + + Parameters + ---------- + table: str + String with name of table. + Returns + ------- + Class reference or None. 
+ """ + for c in self.Base.classes: + if hasattr(c, '__table__') and str(c.__table__) == table: + return c + def add_tables(self, tables): """ Add tables to the database per :class:`sqlalchemy.schema.Table` @@ -2069,8 +2145,8 @@ def __load_spatialite(dbapi_conn, connection_record): continue else: dbapi_conn.load_extension('mod_spatialite') - - def __prepare_insertion(self, scene): + + def __prepare_insertion(self, scene, table='data'): """ read scene metadata and parse a string for inserting it into the database @@ -2078,6 +2154,8 @@ def __prepare_insertion(self, scene): ---------- scene: str or ID a SAR scene + table: str + which table to prepare insert obj for? Returns ------- @@ -2086,16 +2164,17 @@ def __prepare_insertion(self, scene): id = scene if isinstance(scene, ID) else identify(scene) pols = [x.lower() for x in id.polarizations] # insertion as an object of Class Data (reflected in the init()) - insertion = self.Data() - colnames = self.get_colnames() + # insertion = self.Data() + insertion = self.get_class_by_tablename(table)() + colnames = self.get_colnames(table) for attribute in colnames: - if attribute == 'bbox': - geom = id.bbox() + if attribute in ['bbox', 'geometry']: + geom = getattr(scene, attribute)() geom.reproject(4326) geom = geom.convert2wkt(set3D=False)[0] geom = 'SRID=4326;' + str(geom) # set attributes of the Data object according to input - setattr(insertion, 'bbox', geom) + setattr(insertion, attribute, geom) elif attribute in ['hh', 'vv', 'hv', 'vh']: setattr(insertion, attribute, int(attribute in pols)) else: @@ -2112,23 +2191,24 @@ def __prepare_insertion(self, scene): def __select_missing(self, table): """ - + Parameters + ------- + table: str + which table to search missing scenes in, must contain scene column Returns ------- list the names of all scenes, which are no longer stored in their registered location """ - if table == 'data': - # using ORM query to get all scenes locations - scenes = self.Session().query(self.Data.scene) 
- elif table == 'duplicates': - scenes = self.Session().query(self.Duplicates.scene) - else: - raise ValueError("parameter 'table' must either be 'data' or 'duplicates'") + table_obj = self.get_class_by_tablename(table) + if table_obj is None: + log.info(f'Table {table} is not registered in the database') + return [] + scenes = self.Session().query(table_obj.scene) files = [self.encode(x[0]) for x in scenes] return [x for x in files if not os.path.isfile(x)] - - def insert(self, scene_in, pbar=False, test=False): + + def insert(self, scene_in, table='data', pbar=False, test=False): """ Insert one or many scenes into the database @@ -2136,13 +2216,13 @@ def insert(self, scene_in, pbar=False, test=False): ---------- scene_in: str or ID or list a SAR scene or a list of scenes to be inserted + table: str + which table to insert to pbar: bool show a progress bar? test: bool should the insertion only be tested or directly be committed to the database? """ - length = len(scene_in) if isinstance(scene_in, list) else 1 - if isinstance(scene_in, (ID, str)): scene_in = [scene_in] if not isinstance(scene_in, list): @@ -2150,7 +2230,7 @@ def insert(self, scene_in, pbar=False, test=False): 'or a list containing several of either') log.info('filtering scenes by name') - scenes = self.filter_scenelist(scene_in) + scenes = self.filter_scenelist(scene_in, table) if len(scenes) == 0: log.info('...nothing to be done') return @@ -2177,8 +2257,8 @@ def insert(self, scene_in, pbar=False, test=False): session = self.Session() for i, id in enumerate(scenes): basename = id.outname_base() - if not self.is_registered(id) and basename not in basenames: - insertion = self.__prepare_insertion(id) + if not self.is_registered(id, table) and basename not in basenames: + insertion = self.__prepare_insertion(id, table) insertions.append(insertion) counter_regulars += 1 log.debug('regular: {}'.format(id.scene)) @@ -2212,8 +2292,8 @@ def insert(self, scene_in, pbar=False, test=False): 
log.info(message.format(counter_regulars, '' if counter_regulars == 1 else 's')) message = '{0} duplicate{1} registered' log.info(message.format(counter_duplicates, '' if counter_duplicates == 1 else 's')) - - def is_registered(self, scene): + + def is_registered(self, scene, table='data'): """ Simple check if a scene is already registered in the database. @@ -2221,7 +2301,8 @@ def is_registered(self, scene): ---------- scene: str or ID the SAR scene - + table: + which table to search in (must contain outname_base column with scene id) Returns ------- bool @@ -2229,8 +2310,9 @@ def is_registered(self, scene): """ id = scene if isinstance(scene, ID) else identify(scene) # ORM query, where scene equals id.scene, return first - exists_data = self.Session().query(self.Data.outname_base).filter( - self.Data.outname_base == id.outname_base()).first() + table = self.get_class_by_tablename(table) + exists_data = self.Session().query(table.outname_base).filter( + table.outname_base == id.outname_base()).first() exists_duplicates = self.Session().query(self.Duplicates.outname_base).filter( self.Duplicates.outname_base == id.outname_base()).first() in_data = False @@ -2264,15 +2346,19 @@ def __is_registered_in_duplicates(self, scene): in_dup = len(exists_duplicates) != 0 return in_dup - def cleanup(self): + def cleanup(self, table='data'): """ Remove all scenes from the database, which are no longer stored in their registered location + Parameters + ---------- + table: str + tablename, table must contain scene column Returns ------- """ - missing = self.__select_missing('data') + missing = self.__select_missing(table) for scene in missing: log.info('Removing missing scene from database tables: {}'.format(scene)) self.drop_element(scene, with_duplicates=True) @@ -2329,8 +2415,8 @@ def export2shp(self, path, table='data'): # ogr2ogr(db_connection, path, options={'format': 'ESRI Shapefile'}) subprocess.call(['ogr2ogr', '-f', 'ESRI Shapefile', path, db_connection, table]) - - def 
filter_scenelist(self, scenelist): + + def filter_scenelist(self, scenelist, table='data'): """ Filter a list of scenes by file names already registered in the database. @@ -2338,6 +2424,8 @@ def filter_scenelist(self, scenelist): ---------- scenelist: :obj:`list` of :obj:`str` or :obj:`pyroSAR.drivers.ID` the scenes to be filtered + table: str + which table to search in Returns ------- @@ -2350,7 +2438,8 @@ def filter_scenelist(self, scenelist): raise TypeError("items in scenelist must be of type 'str' or 'pyroSAR.ID'") # ORM query, get all scenes locations - scenes_data = self.Session().query(self.Data.scene) + table = self.get_class_by_tablename(table) + scenes_data = self.Session().query(table.scene) registered = [os.path.basename(self.encode(x[0])) for x in scenes_data] scenes_duplicates = self.Session().query(self.Duplicates.scene) duplicates = [os.path.basename(self.encode(x[0])) for x in scenes_duplicates] @@ -2362,15 +2451,19 @@ def get_colnames(self, table='data'): """ Return the names of all columns of a table. 
+ Parameters + ---------- + table: str + tablename + Returns ------- list the column names of the chosen table """ - # get all columns of one table, but shows geometry columns not correctly - table_info = Table(table, self.meta, autoload=True, autoload_with=self.engine) - col_names = table_info.c.keys() - + dicts = sql_inspect(self.engine).get_columns(table) + col_names = [i['name'] for i in dicts] + return sorted([self.encode(x) for x in col_names]) def get_tablenames(self, return_all=False): @@ -2398,7 +2491,8 @@ def get_tablenames(self, return_all=False): 'virts_geometry_columns', 'virts_geometry_columns_auth', 'virts_geometry_columns_field_infos', 'virts_geometry_columns_statistics', 'data_licenses', 'KNN'] # get tablenames from metadata - tables = sorted([self.encode(x) for x in self.meta.tables.keys()]) + insp = sql_inspect(self.engine) + tables = sorted([self.encode(x) for x in insp.get_table_names()]) if return_all: return tables else: @@ -2408,17 +2502,22 @@ def get_tablenames(self, return_all=False): ret.append(i) return ret - def get_unique_directories(self): + def get_unique_directories(self, table='data'): """ Get a list of directories containing registered scenes + Parameters + ---------- + table: str + tablename, table must contain scene column Returns ------- list the directory names """ # ORM query, get all directories - scenes = self.Session().query(self.Data.scene) + table = self.get_class_by_tablename(table) + scenes = self.Session().query(table.scene) registered = [os.path.dirname(self.encode(x[0])) for x in scenes] return list(set(registered)) @@ -2509,7 +2608,7 @@ def move(self, scenelist, directory, pbar=False): log.info('The following scenes already exist at the target location:\n{}'.format('\n'.join(double))) def select(self, vectorobject=None, mindate=None, maxdate=None, processdir=None, - recursive=False, polarizations=None, **args): + recursive=False, polarizations=None, use_geometry=False, table='data', **args): """ select scenes from 
the database @@ -2528,6 +2627,8 @@ def select(self, vectorobject=None, mindate=None, maxdate=None, processdir=None, (only if `processdir` is not None) should also the subdirectories of the `processdir` be scanned? polarizations: list a list of polarization strings, e.g. ['HH', 'VV'] + use_geometry: bool + use the map_overlay as footprint instead of the bounding box **args: any further arguments (columns), which are registered in the database. See :meth:`~Archive.get_colnames()` @@ -2537,8 +2638,8 @@ def select(self, vectorobject=None, mindate=None, maxdate=None, processdir=None, the file names pointing to the selected scenes """ - arg_valid = [x for x in args.keys() if x in self.get_colnames()] - arg_invalid = [x for x in args.keys() if x not in self.get_colnames()] + arg_valid = [x for x in args.keys() if x in self.get_colnames(table)] + arg_invalid = [x for x in args.keys() if x not in self.get_colnames(table)] if len(arg_invalid) > 0: log.info('the following arguments will be ignored as they are not registered in the data base: {}'.format( ', '.join(arg_invalid))) @@ -2578,18 +2679,24 @@ def select(self, vectorobject=None, mindate=None, maxdate=None, processdir=None, if isinstance(vectorobject, Vector): vectorobject.reproject(4326) site_geom = vectorobject.convert2wkt(set3D=False)[0] + + if not use_geometry: + vector_id = 'bbox' + else: + log.info('Using precise footprints for selection.') + vector_id = 'geometry' + # postgres has a different way to store geometries if self.driver == 'postgresql': - arg_format.append("st_intersects(bbox, 'SRID=4326; {}')".format( - site_geom - )) + arg_format.append(F"st_intersects({vector_id}, 'SRID=4326; {site_geom}')") else: - arg_format.append('st_intersects(GeomFromText(?, 4326), bbox) = 1') + arg_format.append(F"st_intersects(GeomFromText(?, 4326), {vector_id}) = 1") vals.append(site_geom) + else: log.info('WARNING: argument vectorobject is ignored, must be of type spatialist.vector.Vector') - query = '''SELECT scene, 
def test_archive_geometry(tmpdir, testdata):
    """
    Check that the optional 'geometry' column can be added to an existing
    database via update_geometry_field and is created from scratch when
    add_geometry=True.
    """
    # create a database without the geometry column, then upgrade it
    dbfile = os.path.join(str(tmpdir), 'scenes.db')
    db = pyroSAR.Archive(dbfile)
    db.insert(testdata['s1'])
    assert 'geometry' not in db.get_colnames('data')
    db.close()
    db = pyroSAR.Archive(dbfile, add_geometry=True)
    db.update_geometry_field()
    assert 'geometry' in db.get_colnames('data')
    db.insert(testdata['s1_2'])
    db.close()
    os.remove(dbfile)
    
    # create a database with the geometry column enabled from the start;
    # re-opening without add_geometry must keep the column intact
    dbfile = os.path.join(str(tmpdir), 'scenes_geo.db')
    db_geo = pyroSAR.Archive(dbfile, add_geometry=True)
    db_geo.insert(testdata['s1'])
    assert 'geometry' in db_geo.get_colnames('data')
    db_geo.close()
    db_geo = pyroSAR.Archive(dbfile)
    assert 'geometry' in db_geo.get_colnames('data')
    db_geo.insert(testdata['s1_2'])
    db_geo.close()
    os.remove(dbfile)