From 3f47e9f2fe5aa13ee77b4138adc0bbd2ff438e17 Mon Sep 17 00:00:00 2001 From: Louis Dupont Date: Mon, 13 Nov 2023 14:46:08 +0100 Subject: [PATCH 1/2] hotfix --- src/data_gradients/dataset_adapters/config/data_config.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/data_gradients/dataset_adapters/config/data_config.py b/src/data_gradients/dataset_adapters/config/data_config.py index 50da8ea9..1b83a22e 100644 --- a/src/data_gradients/dataset_adapters/config/data_config.py +++ b/src/data_gradients/dataset_adapters/config/data_config.py @@ -152,7 +152,10 @@ def _fill_missing_params(self, json_dict: JSONDict): if self.n_classes is None: self.n_classes = json_dict.get("n_classes") if self.class_names is None: - self.class_names = json_dict.get("class_names") + class_names = json_dict.get("class_names") + if isinstance(class_names, dict): + class_names = {int(k): v for k, v in class_names.items()} + self.class_names = class_names if self.class_names_to_use is None: self.class_names_to_use = json_dict.get("class_names_to_use") if self.image_channels is None: @@ -342,7 +345,7 @@ def _represents_int(s: str) -> bool: validation=lambda answer: _represents_int(answer) and int(answer) > 0, ) n_classes = int(question.ask()) - return {f"class_{i}": i for i in range(n_classes)} + return {i: f"class_{i}" for i in range(n_classes)} elif class_names: if isinstance(class_names, list): return dict(zip(range(len(class_names)), class_names)) From f4b905ead8f00274c67efcb141aec89cd5396afe Mon Sep 17 00:00:00 2001 From: Louis Dupont Date: Tue, 14 Nov 2023 10:14:24 +0100 Subject: [PATCH 2/2] dic -> Mapping --- src/data_gradients/dataset_adapters/config/data_config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/data_gradients/dataset_adapters/config/data_config.py b/src/data_gradients/dataset_adapters/config/data_config.py index 1b83a22e..e9ac4d3b 100644 --- a/src/data_gradients/dataset_adapters/config/data_config.py +++ b/src/data_gradients/dataset_adapters/config/data_config.py @@ -6,7 +6,7 @@ import torch from abc import ABC from dataclasses import dataclass -from typing import Dict, Optional, Callable, Union, List +from typing import Dict, Optional, Callable, Union, List, Mapping import data_gradients from data_gradients.dataset_adapters.config.questions import FixedOptionsQuestion, OpenEndedQuestion, text_to_yellow @@ -153,7 +153,7 @@ def _fill_missing_params(self, json_dict: JSONDict): self.n_classes = json_dict.get("n_classes") if self.class_names is None: class_names = json_dict.get("class_names") - if isinstance(class_names, dict): + if isinstance(class_names, Mapping): class_names = {int(k): v for k, v in class_names.items()} self.class_names = class_names if self.class_names_to_use is None: