Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature : Support .tflite Predictions #2

Merged
merged 2 commits into from
Oct 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@ Run your fAIr Model Predictions anywhere !

## Prerequisites

- Install ```tensorflow-cpu```
fAIr Predictor has support for GPU , CPU and tflite based devices

- Install ```tensorflow-cpu``` or ```tflite-runtime``` according to your requirements

```tflite-runtime``` support is for having very light deployment in order to run inference &
```tensorflow-cpu``` might require installation of ```efficientnet```

## Example on Collab
```python
Expand Down
115 changes: 93 additions & 22 deletions predictor/prediction.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Standard library imports
import concurrent.futures
import os
import time
import uuid
Expand All @@ -7,10 +8,20 @@

# Third party imports
import numpy as np
from tensorflow import keras

try:
import tensorflow as tf
from tensorflow import keras
except ImportError:
try:
import tflite_runtime.interpreter as tflite
except ImportError:
raise ImportError(
"Neither TensorFlow nor TFLite is installed. Please install either TensorFlow or TFLite."
)

from .georeferencer import georeference
from .utils import open_images, remove_files, save_mask
from .utils import open_images_keras, open_images_pillow, remove_files, save_mask

BATCH_SIZE = 8
IMAGE_SIZE = 256
Expand Down Expand Up @@ -53,35 +64,95 @@ def run_prediction(
prediction_path = temp_dir
start = time.time()
print(f"Using : {checkpoint_path}")
model = keras.models.load_model(checkpoint_path)
if checkpoint_path.endswith(".tflite"):
interpreter = tflite.Interpreter(model_path=checkpoint_path)
interpreter.resize_tensor_input(
interpreter.get_input_details()[0]["index"], (BATCH_SIZE, 256, 256, 3)
)
interpreter.allocate_tensors()
input_tensor_index = interpreter.get_input_details()[0]["index"]
output = interpreter.tensor(interpreter.get_output_details()[0]["index"])
else:
model = keras.models.load_model(checkpoint_path)
print(f"It took {round(time.time()-start)} sec to load model")
start = time.time()

os.makedirs(prediction_path, exist_ok=True)
image_paths = glob(f"{input_path}/*.png")
if checkpoint_path.endswith(".tflite"):
for i in range((len(image_paths) + BATCH_SIZE - 1) // BATCH_SIZE):
image_batch = image_paths[BATCH_SIZE * i : BATCH_SIZE * (i + 1)]
if len(image_batch) < BATCH_SIZE:
interpreter.resize_tensor_input(
interpreter.get_input_details()[0]["index"], (1, 256, 256, 3)
)
interpreter.allocate_tensors()
input_tensor_index = interpreter.get_input_details()[0]["index"]
output = interpreter.tensor(
interpreter.get_output_details()[0]["index"]
)
for path in image_batch:
images = open_images_pillow([path])
images = images.reshape(-1, IMAGE_SIZE, IMAGE_SIZE, 3).astype(
np.float32
)
interpreter.set_tensor(input_tensor_index, images)
interpreter.invoke()
preds = output()
preds = np.argmax(preds, axis=-1)
preds = np.expand_dims(preds, axis=-1)
preds = np.where(
preds > confidence, 1, 0
) # Filter out low confidence predictions

save_mask(
preds[0],
str(f"{prediction_path}/{Path(path).stem}.png"),
)
else:
images = open_images_pillow(image_batch)
images = images.reshape(-1, IMAGE_SIZE, IMAGE_SIZE, 3).astype(
np.float32
)
interpreter.set_tensor(input_tensor_index, images)
interpreter.invoke()
preds = output()
preds = np.argmax(preds, axis=-1)
preds = np.expand_dims(preds, axis=-1)
preds = np.where(
preds > confidence, 1, 0
) # Filter out low confidence predictions

for idx, path in enumerate(image_batch):
save_mask(
preds[idx],
str(f"{prediction_path}/{Path(path).stem}.png"),
)

else:
for i in range((len(image_paths) + BATCH_SIZE - 1) // BATCH_SIZE):
image_batch = image_paths[BATCH_SIZE * i : BATCH_SIZE * (i + 1)]
images = open_images_keras(image_batch)
images = images.reshape(-1, IMAGE_SIZE, IMAGE_SIZE, 3)
preds = model.predict(images)
preds = np.argmax(preds, axis=-1)
preds = np.expand_dims(preds, axis=-1)
preds = np.where(
preds > confidence, 1, 0
) # Filter out low confidence predictions

for idx, path in enumerate(image_batch):
save_mask(
preds[idx],
str(f"{prediction_path}/{Path(path).stem}.png"),
)

for i in range((len(image_paths) + BATCH_SIZE - 1) // BATCH_SIZE):
image_batch = image_paths[BATCH_SIZE * i : BATCH_SIZE * (i + 1)]
images = open_images(image_batch)
images = images.reshape(-1, IMAGE_SIZE, IMAGE_SIZE, 3)

preds = model.predict(images)
preds = np.argmax(preds, axis=-1)
preds = np.expand_dims(preds, axis=-1)
preds = np.where(
preds > confidence, 1, 0
) # Filter out low confidence predictions

for idx, path in enumerate(image_batch):
save_mask(
preds[idx],
str(f"{prediction_path}/{Path(path).stem}.png"),
)
print(
f"It took {round(time.time()-start)} sec to predict with {confidence} Confidence Threshold"
)
keras.backend.clear_session()
del model
if not checkpoint_path.endswith(".tflite"):
keras.backend.clear_session()
del model
start = time.time()
georeference_path = os.path.join(prediction_path, "georeference")
georeference(
Expand Down
50 changes: 34 additions & 16 deletions predictor/utils.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,25 @@
from typing import List
import concurrent.futures
import math
import os
import re
from glob import glob
from typing import List, Tuple

import geopandas
import numpy as np
import requests
from PIL import Image
from tensorflow import keras
from glob import glob
import os
import geopandas
import re
from shapely.geometry import box
import math
import requests
import concurrent.futures
from typing import Tuple


try:
from tensorflow import keras
except Exception as ex:
print("Unable to import tensorflow")

IMAGE_SIZE = 256


def get_start_end_download_coords(bbox_coords, zm_level, tile_size):

# start point where we will start downloading the tiles

start_point_lng = bbox_coords[0] # getting the starting lat lng
Expand Down Expand Up @@ -60,10 +60,10 @@ def download_image(url, base_path, source_name):
# filename = z-x-y
filename = f"{base_path}/{source_name}-{match.group(2)}-{match.group(3)}-{match.group(1)}.png"


with open(filename, "wb") as f:
f.write(image)


def convert2worldcd(lat, lng, tile_size):
"""
World coordinates are measured from the Mercator projection's origin
Expand All @@ -80,6 +80,7 @@ def convert2worldcd(lat, lng, tile_size):
# print("world coordinate space is %s, %s",world_x,world_y)
return world_x, world_y


def latlng2tile(zoom, lat, lng, tile_size):
"""By dividing the pixel coordinates by the tile size and taking the
integer parts of the result, you produce as a by-product the tile
Expand All @@ -92,6 +93,7 @@ def latlng2tile(zoom, lat, lng, tile_size):
t_y = math.floor((w_y * zoom_byte) / tile_size)
return t_x, t_y


def download_imagery(start: list, end: list, zm_level, base_path, source="maxar"):
"""Downloads imagery from start to end tile coordinate system

Expand Down Expand Up @@ -152,6 +154,7 @@ def download_imagery(start: list, end: list, zm_level, base_path, source="maxar"
for url in download_urls:
executor.submit(download_image, url, base_path, source_name)


def get_bounding_box(filename: str) -> Tuple[float, float, float, float]:
"""Get the EPSG:3857 coordinates of bounding box for the OAM image.

Expand All @@ -162,7 +165,7 @@ def get_bounding_box(filename: str) -> Tuple[float, float, float, float]:
Returns:
A tuple, (x_min, y_min, x_max, y_max), with coordinates in meters.
"""
filename = re.sub(r'\.(png|jpeg)$', '', filename)
filename = re.sub(r"\.(png|jpeg)$", "", filename)
_, *tile_info = re.split("-", filename)
x_tile, y_tile, zoom = map(int, tile_info)

Expand Down Expand Up @@ -205,7 +208,7 @@ def num2deg(x_tile: int, y_tile: int, zoom: int) -> Tuple[float, float]:
return lon_deg, lat_deg


def open_images(paths: List[str]) -> np.ndarray:
def open_images_keras(paths: List[str]) -> np.ndarray:
"""Open images from some given paths."""
images = []
for path in paths:
Expand All @@ -217,14 +220,29 @@ def open_images(paths: List[str]) -> np.ndarray:

return np.array(images)


def open_images_pillow(paths: List[str]) -> np.ndarray:
"""Open images from given paths using Pillow and resize them."""
images = []
for path in paths:
img = Image.open(path)
img = img.resize((IMAGE_SIZE, IMAGE_SIZE)).convert("RGB")
img_array = np.array(img, dtype=np.float32)
img_array = img_array.reshape(IMAGE_SIZE, IMAGE_SIZE, 3) / 255.0
images.append(img_array)

return np.array(images)


def remove_files(pattern: str) -> None:
"""Remove files matching a wildcard."""
files = glob(pattern)
for file in files:
os.remove(file)


def save_mask(mask: np.ndarray, filename: str) -> None:
"""Save the mask array to the specified location."""
reshaped_mask = mask.reshape((IMAGE_SIZE, IMAGE_SIZE)) * 255
result = Image.fromarray(reshaped_mask.astype(np.uint8))
result.save(filename)
result.save(filename)
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
"Programming Language :: Python :: 3.9",
],
install_requires=[
"requests",
"Pillow",
"rtree>=1.0.0,<=1.1.0",
"tqdm>=4.0.0,<=4.62.3",
"geopandas>=0.14.0,<=0.14.5",
Expand Down
19 changes: 10 additions & 9 deletions tests/app_test.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
import os
import shutil

import requests
from efficientnet.tfkeras import EfficientNetB4

from predictor import predict

model_path = "checkpoint.h5"
# import requests
# from efficientnet.tfkeras import EfficientNetB4


model_path = "checkpoint.tflite"
bbox = [100.56228021333352, 13.685230854641182, 100.56383321235313, 13.685961853747969]
zoom_level = 20
tms_url = "https://tiles.openaerialmap.org/6501a65c0906de000167e64d/0/6501a65c0906de000167e64e/{z}/{x}/{y}"

if not os.path.exists(model_path):
url = "https://fair-dev.hotosm.org/api/v1/workspace/download/dataset_65/output/training_297/checkpoint.h5"
response = requests.get(url, stream=True)
with open(model_path, "wb") as out_file:
shutil.copyfileobj(response.raw, out_file)
# if not os.path.exists(model_path):
# url = "https://fair-dev.hotosm.org/api/v1/workspace/download/dataset_65/output/training_297/checkpoint.h5"
# response = requests.get(url, stream=True)
# with open(model_path, "wb") as out_file:
# shutil.copyfileobj(response.raw, out_file)


bbox = [100.56228021333352, 13.685230854641182, 100.56383321235313, 13.685961853747969]
Expand Down