Skip to content

Commit

Permalink
Backward compatibility for the previous pair format
Browse files Browse the repository at this point in the history
  • Loading branch information
sarlinpe committed Mar 1, 2022
1 parent 1a63832 commit 9c73f56
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 25 deletions.
42 changes: 29 additions & 13 deletions hloc/match_features.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import argparse
from typing import Union, Optional, Dict
from typing import Union, Optional, Dict, List, Tuple
from pathlib import Path
import pprint
import collections.abc as collections
Expand All @@ -9,7 +9,7 @@

from . import matchers, logger
from .utils.base_model import dynamic_load
from .utils.parsers import names_to_pair, parse_retrieval
from .utils.parsers import names_to_pair, names_to_pair_old, parse_retrieval
from .utils.io import list_h5_names


Expand Down Expand Up @@ -95,6 +95,27 @@ def main(conf: Dict,
return matches


def find_pairs_to_match(pairs_all: List[Tuple[str]], match_path: Path = None):
'''Avoid to recompute duplicates to save time.'''
pairs = set()
for i, j in pairs_all:
if (j, i) not in pairs:
pairs.add(i, j)
pairs = list(pairs)
if match_path is not None and match_path.exists():
with h5py.File(str(match_path), 'r') as fd:
pairs_filtered = []
for i, j in pairs:
if (names_to_pair(i, j) in fd or
names_to_pair(j, i) in fd or
names_to_pair_old(i, j) in fd or
names_to_pair_old(j, i) in fd):
continue
pairs_filtered.append((i, j))
return pairs_filtered
return pairs


@torch.no_grad()
def match_from_paths(conf: Dict,
pairs_path: Path,
Expand All @@ -112,25 +133,21 @@ def match_from_paths(conf: Dict,
raise FileNotFoundError(f'Reference feature file {path}.')
name2ref = {n: i for i, p in enumerate(feature_paths_refs)
for n in list_h5_names(p)}
match_path.parent.mkdir(exist_ok=True, parents=True)

assert pairs_path.exists(), pairs_path
pairs = parse_retrieval(pairs_path)
pairs = [(q, r) for q, rs in pairs.items() for r in rs]
pairs = find_pairs_to_match(pairs, exclude=None if overwrite else match_path)
if len(pairs) == 0:
logger.info('Skipping the matching.')
return

device = 'cuda' if torch.cuda.is_available() else 'cpu'
Model = dynamic_load(matchers, conf['model']['name'])
model = Model(conf['model']).eval().to(device)

match_path.parent.mkdir(exist_ok=True, parents=True)
skip_pairs = set(list_h5_names(match_path)
if match_path.exists() and not overwrite else ())

for (name0, name1) in tqdm(pairs, smoothing=.1):
pair = names_to_pair(name0, name1)
# Avoid to recompute duplicates to save time
if pair in skip_pairs or names_to_pair(name1, name0) in skip_pairs:
continue

data = {}
with h5py.File(str(feature_path_q), 'r') as fd:
grp = fd[name0]
Expand All @@ -146,6 +163,7 @@ def match_from_paths(conf: Dict,
data = {k: v[None] for k, v in data.items()}

pred = model(data)
pair = names_to_pair(name0, name1)
with h5py.File(str(match_path), 'a') as fd:
if pair in fd:
del fd[pair]
Expand All @@ -157,8 +175,6 @@ def match_from_paths(conf: Dict,
scores = pred['matching_scores0'][0].cpu().half().numpy()
grp.create_dataset('matching_scores0', data=scores)

skip_pairs.add(pair)

logger.info('Finished exporting matches.')


Expand Down
31 changes: 21 additions & 10 deletions hloc/utils/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import cv2
import h5py

from .parsers import names_to_pair
from .parsers import names_to_pair, names_to_pair_old


def read_image(path, grayscale=False):
Expand Down Expand Up @@ -36,17 +36,28 @@ def get_keypoints(path: Path, name: str) -> np.ndarray:
return p


def find_pair(hfile: h5py.File, name0: str, name1: str):
pair = names_to_pair(name0, name1)
if pair in hfile:
return pair, False
pair = names_to_pair(name1, name0)
if pair in hfile:
return pair, True
# older, less efficient format
pair = names_to_pair_old(name0, name1)
if pair in hfile:
return pair, False
pair = names_to_pair_old(name1, name0)
if pair in hfile:
return pair, True
raise ValueError(
f'Could not find pair {(name0, name1)}... '
'Maybe you matched with a different list of pairs? ')


def get_matches(path: Path, name0: str, name1: str) -> Tuple[np.ndarray]:
with h5py.File(str(path), 'r') as hfile:
reverse = False
pair = names_to_pair(name0, name1)
if pair not in hfile:
pair = names_to_pair(name1, name0)
if pair not in hfile:
raise ValueError(
f'Could not find pair {(name0, name1)}... '
'Maybe you matched with a different list of pairs? ')
reverse = True
reverse, pair = find_pair(hfile, name0, name1)
matches = hfile[pair]['matches0'].__array__()
scores = hfile[pair]['matching_scores0'].__array__()
idx = np.where(matches != -1)[0]
Expand Down
8 changes: 6 additions & 2 deletions hloc/utils/parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,5 +48,9 @@ def parse_retrieval(path):
return dict(retrieval)


def names_to_pair(name0, name1):
return '/'.join((name0.replace('/', '-'), name1.replace('/', '-')))
def names_to_pair(name0, name1, separator='/'):
return separator.join((name0.replace('/', '-'), name1.replace('/', '-')))


def names_to_pair_old(name0, name1):
return names_to_pair(name0, name1, separator='_')

0 comments on commit 9c73f56

Please sign in to comment.