#!/usr/bin/env python # Martin Kersner, m.kersner@gmail.com # 2016/01/18 from __future__ import print_function import os import sys from skimage.io import imread import numpy as np from utils import get_id_classes, convert_from_color_segmentation def main(): ## ext = '.png' class_names = ['bird', 'bottle', 'chair'] ## path, txt_file = process_arguments(sys.argv) clear_class_logs(class_names) class_ids = get_id_classes(class_names) with open(txt_file, 'rb') as f: for img_name in f: img_name = img_name.strip() detected_class = contain_class(os.path.join(path, img_name)+ext, class_ids, class_names) if detected_class: log_class(img_name, detected_class) def clear_class_logs(class_names): for c in class_names: file_name = c + '.txt' if os.path.isfile(file_name): os.remove(file_name) def log_class(img_name, detected_class): with open(detected_class + '.txt', 'ab') as f: print(img_name, file=f) def contain_class(img_name, class_ids, class_names): img = imread(img_name) # If label is three-dimensional image we have to convert it to # corresponding labels (0 - 20). Currently anticipated labels are from # VOC pascal datasets. if (len(img.shape) > 2): img = convert_from_color_segmentation(img) for i,j in enumerate(class_ids): if j in np.unique(img): return class_names[i] return False def process_arguments(argv): if len(argv) != 3: help() dataset_segmentation_path = argv[1] list_of_images = argv[2] return dataset_segmentation_path, list_of_images def help(): print('Usage: python filter_images.py PATH LIST_FILE\n' 'PATH points to directory with segmentation image labels.\n' 'LIST_FILE denotes text file containing names of images in PATH.\n' 'Names do not include extension of images.' , file=sys.stderr) exit() if __name__ == '__main__': main()