Skip to content

Commit

Permalink
Inference - automatically recreate model trained with COCO (#929)
Browse files Browse the repository at this point in the history
* added get)classes method

* automatically create model trained on coco dataset

* dded models trained on coco
  • Loading branch information
ai-fast-track authored Sep 21, 2021
1 parent f0d1fe0 commit 9f23f35
Show file tree
Hide file tree
Showing 3 changed files with 343 additions and 84 deletions.
3 changes: 3 additions & 0 deletions icevision/core/class_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ def __init__(
def num_classes(self):
return len(self)

def get_classes(self) -> Sequence[str]:
return self._id2class

def get_by_id(self, id: int) -> str:
return self._id2class[id]

Expand Down
162 changes: 133 additions & 29 deletions icevision/models/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,100 @@
load_state_dict,
)

# COCO Classes: 80 classes
CLASSES = (
"person",
"bicycle",
"car",
"motorcycle",
"airplane",
"bus",
"train",
"truck",
"boat",
"traffic light",
"fire hydrant",
"stop sign",
"parking meter",
"bench",
"bird",
"cat",
"dog",
"horse",
"sheep",
"cow",
"elephant",
"bear",
"zebra",
"giraffe",
"backpack",
"umbrella",
"handbag",
"tie",
"suitcase",
"frisbee",
"skis",
"snowboard",
"sports ball",
"kite",
"baseball bat",
"baseball glove",
"skateboard",
"surfboard",
"tennis racket",
"bottle",
"wine glass",
"cup",
"fork",
"knife",
"spoon",
"bowl",
"banana",
"apple",
"sandwich",
"orange",
"broccoli",
"carrot",
"hot dog",
"pizza",
"donut",
"cake",
"chair",
"couch",
"potted plant",
"bed",
"dining table",
"toilet",
"tv",
"laptop",
"mouse",
"remote",
"keyboard",
"cell phone",
"microwave",
"oven",
"toaster",
"sink",
"refrigerator",
"book",
"clock",
"vase",
"scissors",
"teddy bear",
"hair drier",
"toothbrush",
)


def save_icevision_checkpoint(
model,
model_name,
backbone_name,
class_map,
img_size,
filename,
optimizer=None,
meta=None,
model_name=None,
backbone_name=None,
classes=None,
img_size=None,
):
"""Save checkpoint to file.
Expand Down Expand Up @@ -49,11 +133,6 @@ def save_icevision_checkpoint(
elif not isinstance(meta, dict):
raise TypeError(f"meta must be a dict or None, but got {type(meta)}")

if class_map:
classes = class_map._id2class
else:
classes = None

if classes:
meta.update(classes=classes)

Expand All @@ -71,19 +150,25 @@ def save_icevision_checkpoint(

def model_from_checkpoint(
filename: Union[Path, str],
model_name=None,
backbone_name=None,
classes=None,
is_coco=False,
img_size=None,
map_location=None,
strict=False,
logger=None,
revise_keys=[(r"^module\.", "")],
revise_keys=[
(r"^module\.", ""),
],
eval_mode=True,
):
"""load checkpoint through URL scheme path.
Args:
filename (str): checkpoint file name with given prefix
map_location (str, optional): Same as :func:`torch.load`.
Default: None
logger (:mod:`logging.Logger`, optional): The logger for message.
Default: None
Returns:
dict or OrderedDict: The loaded checkpoint.
Expand All @@ -94,38 +179,54 @@ def model_from_checkpoint(

checkpoint = _load_checkpoint(filename)

class_map = None
num_classes = None
img_size = None
model_name = None
backbone = None
if is_coco and classes:
logger.warning(
"`is_coco` cannot be set to True if `classes` is passed and `not None`. `classes` has priority. `is_coco` will be ignored."
)

classes = checkpoint["meta"].get("classes", None)
if classes is None:
if is_coco:
classes = CLASSES
else:
classes = checkpoint["meta"].get("classes", None)

class_map = None
if classes:
class_map = ClassMap(checkpoint["meta"]["classes"])
class_map = ClassMap(classes)
num_classes = len(class_map)

img_size = checkpoint["meta"].get("img_size", None)
if img_size is None:
img_size = checkpoint["meta"].get("img_size", None)

if model_name is None:
model_name = checkpoint["meta"].get("model_name", None)

model_name = checkpoint["meta"].get("model_name", None)
model_type = None
if model_name:
lib, mod = model_name.split(".")
model_type = getattr(getattr(models, lib), mod)

backbone_name = checkpoint["meta"].get("backbone_name", None)
if backbone_name:
if backbone_name is None:
backbone_name = checkpoint["meta"].get("backbone_name", None)
if model_type and backbone_name:
backbone = getattr(model_type.backbones, backbone_name)

extra_args = {}
if img_size is None:
img_size = checkpoint["meta"].get("img_size", None)

models_with_img_size = ("yolov5", "efficientdet")
# if 'efficientdet' in model_name:
if any(m in model_name for m in models_with_img_size):
if (model_name) and (any(m in model_name for m in models_with_img_size)):
extra_args["img_size"] = img_size

# Instantiate model
model = model_type.model(
backbone=backbone(pretrained=False), num_classes=num_classes, **extra_args
)
if model_type and backbone:
model = model_type.model(
backbone=backbone(pretrained=False), num_classes=num_classes, **extra_args
)
else:
model = None

# OrderedDict is a subclass of dict
if not isinstance(checkpoint, dict):
Expand All @@ -140,7 +241,10 @@ def model_from_checkpoint(
state_dict = {re.sub(p, r, k): v for k, v in state_dict.items()}

# load state_dict
load_state_dict(model, state_dict, strict, logger)
if model:
load_state_dict(model, state_dict, strict, logger)
if eval_mode:
model.eval()

checkpoint_and_model = {
"model": model,
Expand Down
262 changes: 207 additions & 55 deletions notebooks/inference.ipynb

Large diffs are not rendered by default.

0 comments on commit 9f23f35

Please sign in to comment.