Skip to content

Commit

Permalink
Fix --list-algorithms using path names instead of algorithm names (fi… (
Browse files Browse the repository at this point in the history
#569)

* fix --list-algorithms using path names instead of algorithm names (fixes #555)
  • Loading branch information
maumueller authored Jan 21, 2025
1 parent e38c914 commit 2331417
Showing 1 changed file with 21 additions and 11 deletions.
32 changes: 21 additions & 11 deletions ann_benchmarks/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,16 +145,15 @@ def load_configs(point_type: str, base_dir: str = "ann_benchmarks/algorithms") -
print(f"Error loading YAML from {config_file}: {e}")
return configs

def _get_definitions(base_dir: str = "ann_benchmarks/algorithms") -> Dict[str, Dict[str, Any]]:
"""Load algorithm configurations for a given point_type."""
def _get_definitions(base_dir: str = "ann_benchmarks/algorithms") -> List[Dict[str, Any]]:
"""Load algorithm configurations."""
config_files = get_config_files(base_dir=base_dir)
configs = {}
configs = []
for config_file in config_files:
with open(config_file, 'r') as stream:
try:
config_data = yaml.safe_load(stream)
algorithm_name = os.path.basename(os.path.dirname(config_file))
configs[algorithm_name] = config_data
configs.append(config_data)
except yaml.YAMLError as e:
print(f"Error loading YAML from {config_file}: {e}")
return configs
Expand Down Expand Up @@ -211,16 +210,27 @@ def list_algorithms(base_dir: str = "ann_benchmarks/algorithms") -> None:
base_dir (str, optional): The base directory where the algorithms are stored.
Defaults to "ann_benchmarks/algorithms".
"""
definitions = _get_definitions(base_dir)

print("The following algorithms are supported...", definitions)
for algorithm in definitions:
all_configs = _get_definitions(base_dir)
data = {}
for algo_configs in all_configs:
for point_type, config_for_point_type in algo_configs.items():
for metric, ccc in config_for_point_type.items():
algo_name = ccc[0]["name"]
if algo_name not in data:
data[algo_name] = {}
if point_type not in data[algo_name]:
data[algo_name][point_type] = []
data[algo_name][point_type].append(metric)

print("The following algorithms are supported:", ", ".join(data))
print("Details of supported metrics and data types: ")
for algorithm in data:
print('\t... for the algorithm "%s"...' % algorithm)

for point_type in definitions[algorithm]:
for point_type in data[algorithm]:
print('\t\t... and the point type "%s", metrics: ' % point_type)

for metric in definitions[algorithm][point_type]:
for metric in data[algorithm][point_type]:
print("\t\t\t%s" % metric)


Expand Down

0 comments on commit 2331417

Please sign in to comment.