Skip to content

Commit

Permalink
Tests for embedder_utils and small corrections
Browse files Browse the repository at this point in the history
  • Loading branch information
PrimozGodec committed Jan 14, 2020
1 parent d65a623 commit c08c1ec
Show file tree
Hide file tree
Showing 5 changed files with 198 additions and 17 deletions.
8 changes: 7 additions & 1 deletion orangecontrib/imageanalytics/tests/test_image_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,12 @@ def test_on_successful_response(self):
len(self.embedder_server._embedder._cache._cache_dict), 1)

@patch(_HTTPX_POST_METHOD, make_dummy_post(b''))
def test_on_empty_response(self):
self.assertEqual(self.embedder_server(self.single_example), [None])
self.assertEqual(
len(self.embedder_server._embedder._cache._cache_dict), 0)

@patch(_HTTPX_POST_METHOD, make_dummy_post(b'blabla'))
def test_on_non_json_response(self):
self.assertEqual(self.embedder_server(self.single_example), [None])
self.assertEqual(
Expand Down Expand Up @@ -286,4 +292,4 @@ def test_connection_error(self, _):
for num_images in range(1, 20):
with self.assertRaises(ConnectionError):
self.embedder_server(self.single_example * num_images)
self.setUp() # to init new embedder
self.setUp() # to init new embedder
17 changes: 9 additions & 8 deletions orangecontrib/imageanalytics/utils/embedder_utils.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
from io import BytesIO
from os.path import join, isfile
import logging
import hashlib
import pickle
import ftplib

import cachecontrol.caches
import requests
from PIL.Image import open as open_image, LANCZOS
from io import BytesIO
from os.path import join, isfile
from requests.exceptions import RequestException
from PIL import ImageFile
from urllib.parse import urlparse
from urllib.request import urlopen, URLError
from urllib.request import urlopen
from urllib.error import URLError

from PIL.Image import open as open_image, LANCZOS
from PIL import ImageFile
import numpy as np

from Orange.misc.environ import cache_dir
Expand Down Expand Up @@ -44,7 +45,7 @@ def __init__(self):
join(cache_dir(), __name__ + ".ImageEmbedder.httpcache"))
)

def load_image_or_none(self, file_path, target_size):
def load_image_or_none(self, file_path, target_size=None):
image = self._load_image_from_url_or_local_path(file_path)

if image is None:
Expand All @@ -60,7 +61,7 @@ def load_image_or_none(self, file_path, target_size):
image = image.resize(target_size, LANCZOS)
return image

def load_image_bytes(self, file_path, target_size):
def load_image_bytes(self, file_path, target_size=None):
image = self.load_image_or_none(file_path, target_size)
if image is None:
return None
Expand Down
Empty file.
173 changes: 173 additions & 0 deletions orangecontrib/imageanalytics/utils/tests/test_embedder_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
import os
import unittest
from unittest.mock import patch
from urllib.error import URLError

from PIL.Image import Image
from requests import RequestException

from orangecontrib.imageanalytics.utils.embedder_utils import ImageLoader, \
EmbedderCache


TEST_IMAGES = [
"example_image_0.jpg",
"example_image_1.tiff",
"example_image_2.png"]


def image_name_to_path(im_name):
"""
Transform image names to absolute paths. All images must be in
orangeceontrib.imageanalytics.tests
"""
path = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "..", "..", "tests")
path = os.path.abspath(path)
return os.path.join(path, im_name)


class TestImageLoader(unittest.TestCase):
def setUp(self) -> None:
self.image_loader = ImageLoader()

self.im_paths = [image_name_to_path(f) for f in TEST_IMAGES]
self.im_url = "http://file.biolab.si/images/bone-healing/D14/D14-" \
"0401-11-L1-inj-1-0016-m1.jpg"

def test_load_images(self) -> None:
image = self.image_loader.load_image_or_none(self.im_paths[0])
self.assertTrue(isinstance(image, Image))

image = self.image_loader.load_image_or_none(self.im_paths[0],
target_size=(255, 255))
self.assertTrue(isinstance(image, Image))
self.assertTupleEqual((255, 255), image.size)

def test_load_images_url(self) -> None:
"""
Handle loading images from http, https type urls
"""
image = self.image_loader.load_image_or_none(self.im_url)
self.assertTrue(isinstance(image, Image))

image = self.image_loader.load_image_or_none(self.im_paths[0],
target_size=(255, 255))
self.assertTrue(isinstance(image, Image))
self.assertTupleEqual((255, 255), image.size)

# invalid urls could be handled
image = self.image_loader.load_image_or_none(self.im_url + "a")
self.assertIsNone(image)

@patch("requests.sessions.Session.get", side_effect=RequestException)
def test_load_images_url_request_exception(self, _) -> None:
"""
Handle loading images from http, https type urls
"""
image = self.image_loader.load_image_or_none(self.im_url)
self.assertIsNone(image)

@patch(
"orangecontrib.imageanalytics.utils.embedder_utils.urlopen",
return_value=image_name_to_path(TEST_IMAGES[0]))
def test_load_images_ftp(self, _) -> None:
"""
Handle loading images from ftp, data type urls. Since we do not have
a ftp source we just change path to local path.
"""
image = self.image_loader.load_image_or_none("ftp://abcd")
self.assertTrue(isinstance(image, Image))

image = self.image_loader.load_image_or_none(self.im_paths[0],
target_size=(255, 255))
self.assertTrue(isinstance(image, Image))
self.assertTupleEqual((255, 255), image.size)

@patch(
"orangecontrib.imageanalytics.utils.embedder_utils.urlopen",
side_effect=URLError("wrong url"))
def test_load_images_ftp_error(self, _) -> None:
"""
Handle loading images from ftp, data type urls. Since we do not have
a ftp source we just change path to local path.
"""
image = self.image_loader.load_image_or_none("ftp://abcd")
self.assertIsNone(image)

def test_load_image_bytes(self) -> None:
for image in self.im_paths:
image_bytes = self.image_loader.load_image_bytes(image)
self.assertTrue(isinstance(image_bytes, bytes))

# one with wrong path to get none
image_bytes = self.image_loader.load_image_bytes(
self.im_paths[0] + "a")
self.assertIsNone(image_bytes)

@patch("PIL.Image.Image.convert", side_effect=ValueError())
def test_unsuccessful_convert_to_RGB(self, _) -> None:
image = self.image_loader.load_image_or_none(self.im_paths[2])
self.assertIsNone(image)


class TestEmbedderCache(unittest.TestCase):

def setUp(self) -> None:
self.cache = EmbedderCache("test_model")
self.cache.clear_cache() # make sure cache is empty

def test_save_and_load(self) -> None:
self.cache.add("test", "test")
self.cache.persist_cache()

# when initialing cache again it should load same cache
self.cache = EmbedderCache("test_model")
self.assertEqual("test", self.cache.get_cached_result_or_none("test"))

def test_clear_cache(self) -> None:
"""
Strategy 1: clear before persisting
"""
self.cache.add("test", "test")
self.cache.clear_cache()
self.cache.persist_cache()

self.cache = EmbedderCache("test_model")
self.assertIsNone(self.cache.get_cached_result_or_none("test"))

"""
Strategy 2: clear after persisting
"""
self.cache.add("test", "test")
self.cache.persist_cache()
self.cache.clear_cache()

self.cache = EmbedderCache("test_model")
self.assertIsNone(self.cache.get_cached_result_or_none("test"))

def test_get_cached_result_or_none(self) -> None:
self.assertIsNone(self.cache.get_cached_result_or_none("test"))
self.cache._cache_dict = {"test": "test1"}
self.assertEqual("test1", self.cache.get_cached_result_or_none("test"))

def test_add(self) -> None:
self.assertDictEqual(dict(), self.cache._cache_dict)
self.cache.add("test", "test1")
self.assertDictEqual({"test": "test1"}, self.cache._cache_dict)

@patch(
"orangecontrib.imageanalytics.utils.embedder_utils.EmbedderCache."
"load_pickle", side_effect=EOFError)
def test_unsuccessful_load(self, _) -> None:
self.cache.add("test", "test")
self.cache.persist_cache()

# since load was not succesdful it should be initialized as an empty
# dict
self.cache = EmbedderCache("test_model")
self.assertDictEqual({}, self.cache._cache_dict)


if __name__ == "__main__":
unittest.main()
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import time

import numpy as np
from unittest import mock, skipIf

Expand Down Expand Up @@ -95,7 +96,8 @@ def test_no_connection(self, _):
table = load_images()
self.assertEqual(w.cb_embedder.currentText(), "Inception v3")
self.send_signal(w.Inputs.images, table)
self.wait_until_stop_blocking()
self.wait_until_finished()
self.wait_until_finished()
self.assertEqual(w.cb_embedder.currentText(), "SqueezeNet (local)")

output = self.get_output(self.widget.Outputs.embeddings)
Expand All @@ -116,7 +118,7 @@ def test_embedder_changed(self):
simulate.combobox_activate_index(cbox, 3)

self.assertEqual(w.cb_embedder.currentText(), "VGG-19")
self.wait_until_stop_blocking(wait=20000)
self.wait_until_finished(timeout=20000)

output = self.get_output(self.widget.Outputs.embeddings)
self.assertEqual(type(output), Table)
Expand All @@ -131,7 +133,7 @@ def test_not_image_data_attributes(self):
w = self.widget
table = Table("iris")
self.send_signal(w.Inputs.images, table)
self.wait_until_stop_blocking()
self.wait_until_finished()

# it should jut not chrash
cbox = self.widget.controls.cb_embedder_current_id
Expand All @@ -150,7 +152,7 @@ def test_cancel_embedding(self):
self.send_signal(self.widget.Inputs.images, table)
time.sleep(0.5)
self.widget.cancel_button.click()
self.wait_until_stop_blocking()
self.wait_until_finished()
results = self.get_output(self.widget.Outputs.embeddings)

self.assertIsNone(results)
Expand All @@ -167,11 +169,11 @@ def test_variable_make(self):

data = Table("https://datasets.biolab.si/core/bone-healing.xlsx")[::5]
self.send_signal(w.Inputs.images, data)
self.wait_until_stop_blocking()
self.wait_until_finished()
emb1 = self.get_output(self.widget.Outputs.embeddings)

self.send_signal(w.Inputs.images, data)
self.wait_until_stop_blocking()
self.wait_until_finished()
emb2 = self.get_output(self.widget.Outputs.embeddings)

self.assertTrue(
Expand All @@ -192,9 +194,8 @@ def test_unexpected_error(self, _):
table = load_images()
self.assertEqual(w.cb_embedder.currentText(), "Inception v3")
self.send_signal(w.Inputs.images, table)
self.wait_until_stop_blocking()
self.wait_until_finished()

output = self.get_output(self.widget.Outputs.embeddings)
self.assertIsNone(output)
self.widget.Error.unexpected_error.is_shown()
print(self.widget.Error.unexpected_error)

0 comments on commit c08c1ec

Please sign in to comment.