Skip to content

Commit

Permalink
Merge pull request #1433 from galthran-wq/batching
Browse files Browse the repository at this point in the history
Batching on `.represent` to improve performance and utilize GPU in full
  • Loading branch information
serengil authored Feb 16, 2025
2 parents 112d189 + f1734b2 commit ca73032
Show file tree
Hide file tree
Showing 7 changed files with 184 additions and 101 deletions.
11 changes: 7 additions & 4 deletions deepface/DeepFace.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
import warnings
import logging
from typing import Any, Dict, IO, List, Union, Optional
from typing import Any, Dict, IO, List, Union, Optional, Sequence

# this has to be set before importing tensorflow
os.environ["TF_USE_LEGACY_KERAS"] = "1"
Expand Down Expand Up @@ -376,7 +376,7 @@ def find(


def represent(
img_path: Union[str, np.ndarray, IO[bytes]],
img_path: Union[str, np.ndarray, IO[bytes], Sequence[Union[str, np.ndarray, IO[bytes]]]],
model_name: str = "VGG-Face",
enforce_detection: bool = True,
detector_backend: str = "opencv",
Expand All @@ -390,10 +390,13 @@ def represent(
Represent facial images as multi-dimensional vector embeddings.
Args:
img_path (str or np.ndarray or IO[bytes]): The exact path to the image, a numpy array
img_path (str, np.ndarray, IO[bytes], or Sequence[Union[str, np.ndarray, IO[bytes]]]):
The exact path to the image, a numpy array
in BGR format, a file object that supports at least `.read` and is opened in binary
mode, or a base64 encoded image. If the source image contains multiple faces,
the result will include information for each detected face.
the result will include information for each detected face. If a sequence is provided,
each element should be a string, a numpy array, or a file object opened in binary mode
representing an image, and the function will process images in batch.
model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512,
OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet
Expand Down
9 changes: 7 additions & 2 deletions deepface/models/FacialRecognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,17 @@ class FacialRecognition(ABC):
input_shape: Tuple[int, int]
output_shape: int

def forward(self, img: np.ndarray) -> Union[List[float], List[List[float]]]:
    """
    Find embeddings with the wrapped keras model.

    Args:
        img (np.ndarray): a single pre-processed image (3 dimensional) or a
            batch of pre-processed images (4 dimensional, batch axis first).

    Returns:
        embeddings (list): a flat list of floats when the model output holds
            exactly one row, otherwise one list of floats per image.

    Raises:
        ValueError: if the wrapped model is not a keras model. Clients backed
            by another framework must override this method.
    """
    if not isinstance(self.model, Model):
        raise ValueError(
            "You must overwrite forward method if it is not a keras model,"
            f"but {self.model_name} not overwritten!"
        )

    # The model is always called on a batch, so a lone 3 dimensional image is
    # promoted to a batch of one. Note that `img.shape` is a tuple: the number
    # of dimensions has to be read from `img.ndim`, never compared as
    # `img.shape == 4`.
    if img.ndim == 3:
        img = np.expand_dims(img, axis=0)

    # model.predict causes memory issue when it is called in a for loop,
    # so the model is invoked directly instead.
    embeddings = self.model(img, training=False).numpy()

    # keep the single-image contract: one image in, one flat vector out
    if embeddings.shape[0] == 1:
        return embeddings[0].tolist()
    return embeddings.tolist()
27 changes: 13 additions & 14 deletions deepface/models/facial_recognition/Dlib.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# built-in dependencies
from typing import List
from typing import List, Union

# 3rd party dependencies
import numpy as np
Expand All @@ -26,35 +26,34 @@ def __init__(self):
self.input_shape = (150, 150)
self.output_shape = 128

def forward(self, img: np.ndarray) -> Union[List[float], List[List[float]]]:
    """
    Find embeddings with Dlib model.
    This model necessitates the override of the forward method
    because it is not a keras model.

    Args:
        img (np.ndarray): pre-loaded image(s) in BGR; either one image
            (3 dimensional) or a batch of images (4 dimensional).

    Returns
        embeddings (list of lists or list of floats): one flat vector when a
            single descriptor is produced, otherwise one vector per image.
    """
    # a single image becomes a batch of one
    batch = img[np.newaxis] if img.ndim == 3 else img

    # bgr to rgb: flip the channel axis of every image in the batch
    pixels = batch[:, :, :, ::-1]

    # img is in scale of [0, 1] but expected [0, 255]
    if pixels.max() <= 1:
        pixels = pixels * 255
    pixels = pixels.astype(np.uint8)

    descriptors = self.model.model.compute_face_descriptor(pixels)
    vectors = [np.array(descriptor).tolist() for descriptor in descriptors]

    return vectors[0] if len(vectors) == 1 else vectors


class DlibResNet:
Expand Down
19 changes: 11 additions & 8 deletions deepface/models/facial_recognition/SFace.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# built-in dependencies
from typing import Any, List
from typing import Any, List, Union

# 3rd party dependencies
import numpy as np
Expand Down Expand Up @@ -27,7 +27,7 @@ def __init__(self):
self.input_shape = (112, 112)
self.output_shape = 128

def forward(self, img: np.ndarray) -> Union[List[float], List[List[float]]]:
    """
    Find embeddings with SFace model.
    This model necessitates the override of the forward method.

    Args:
        img (np.ndarray): a single image (3 dimensional) or a batch of images
            (4 dimensional, batch axis first). Values are multiplied by 255
            and cast to uint8, so the input is presumably scaled to [0, 1]
            (confirm against the caller).

    Returns
        embeddings (list): a flat list of floats when exactly one image is
            given, otherwise one list of floats per image.
    """
    # Promote a lone image to a batch of one; otherwise the loop below would
    # walk over the rows of the image instead of over images.
    if img.ndim == 3:
        img = np.expand_dims(img, axis=0)

    # revert the image to original format and preprocess using the model
    input_blob = (img * 255).astype(np.uint8)

    # the recognizer is called once per image; its outputs are stacked along
    # the first axis so that row i belongs to image i
    embeddings = np.concatenate(
        [self.model.model.feature(face) for face in input_blob], axis=0
    )

    if embeddings.shape[0] == 1:
        return embeddings[0].tolist()
    return embeddings.tolist()


def load_model(
Expand Down
13 changes: 9 additions & 4 deletions deepface/models/facial_recognition/VGGFace.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,7 @@ def __init__(self):
def forward(self, img: np.ndarray) -> List[float]:
"""
Generates embeddings using the VGG-Face model.
This method incorporates an additional normalization layer,
necessitating the override of the forward method.
This method incorporates an additional normalization layer.
Args:
img (np.ndarray): pre-loaded image in BGR
Expand All @@ -70,8 +69,14 @@ def forward(self, img: np.ndarray) -> List[float]:

# having normalization layer in descriptor troubles for some gpu users (e.g. issue 957, 966)
# instead we are now calculating it with traditional way not with keras backend
embedding = self.model(img, training=False).numpy()[0].tolist()
embedding = verification.l2_normalize(embedding)
embedding = super().forward(img)
if (
isinstance(embedding, list) and
isinstance(embedding[0], list)
):
embedding = verification.l2_normalize(embedding, axis=1)
else:
embedding = verification.l2_normalize(embedding)
return embedding.tolist()


Expand Down
165 changes: 96 additions & 69 deletions deepface/modules/representation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# built-in dependencies
from typing import Any, Dict, List, Union, Optional
from typing import Any, Dict, List, Union, Optional, Sequence, IO

# 3rd party dependencies
import numpy as np
Expand All @@ -11,7 +11,7 @@


def represent(
img_path: Union[str, np.ndarray],
img_path: Union[str, IO[bytes], np.ndarray, Sequence[Union[str, np.ndarray, IO[bytes]]]],
model_name: str = "VGG-Face",
enforce_detection: bool = True,
detector_backend: str = "opencv",
Expand All @@ -25,9 +25,11 @@ def represent(
Represent facial images as multi-dimensional vector embeddings.
Args:
img_path (str or np.ndarray): The exact path to the image, a numpy array in BGR format,
or a base64 encoded image. If the source image contains multiple faces, the result will
include information for each detected face.
img_path (str, np.ndarray, IO[bytes], or Sequence[Union[str, np.ndarray, IO[bytes]]]):
The exact path to the image, a numpy array in BGR format,
a file object opened in binary mode, a base64 encoded image, or a list of these.
A 4 dimensional numpy array is treated as a batch of images.
If the source image contains multiple faces,
the result will include information for each detected face.
model_name (str): Model for face recognition. Options: VGG-Face, Facenet, Facenet512,
OpenFace, DeepFace, DeepID, Dlib, ArcFace, SFace and GhostFaceNet
Expand Down Expand Up @@ -70,70 +72,95 @@ def represent(
task="facial_recognition", model_name=model_name
)

# ---------------------------------
# we have run pre-process in verification. so, this can be skipped if it is coming from verify.
target_size = model.input_shape
if detector_backend != "skip":
img_objs = detection.extract_faces(
img_path=img_path,
detector_backend=detector_backend,
grayscale=False,
enforce_detection=enforce_detection,
align=align,
expand_percentage=expand_percentage,
anti_spoofing=anti_spoofing,
max_faces=max_faces,
)
else: # skip
# Try load. If load error, will raise exception internal
img, _ = image_utils.load_image(img_path)

if len(img.shape) != 3:
raise ValueError(f"Input img must be 3 dimensional but it is {img.shape}")

# make dummy region and confidence to keep compatibility with `extract_faces`
img_objs = [
{
"face": img,
"facial_area": {"x": 0, "y": 0, "w": img.shape[0], "h": img.shape[1]},
"confidence": 0,
}
]
# ---------------------------------

if max_faces is not None and max_faces < len(img_objs):
# sort as largest facial areas come first
img_objs = sorted(
img_objs,
key=lambda img_obj: img_obj["facial_area"]["w"] * img_obj["facial_area"]["h"],
reverse=True,
)
# discard rest of the items
img_objs = img_objs[0:max_faces]

for img_obj in img_objs:
if anti_spoofing is True and img_obj.get("is_real", True) is False:
raise ValueError("Spoof detected in the given image.")
img = img_obj["face"]

# bgr to rgb
img = img[:, :, ::-1]

region = img_obj["facial_area"]
confidence = img_obj["confidence"]

# resize to expected shape of ml model
img = preprocessing.resize_image(
img=img,
# thanks to DeepId (!)
target_size=(target_size[1], target_size[0]),
)

# custom normalization
img = preprocessing.normalize_input(img=img, normalization=normalization)

embedding = model.forward(img)

# Handle list of image paths or 4D numpy array
if isinstance(img_path, list):
images = img_path
elif isinstance(img_path, np.ndarray) and img_path.ndim == 4:
images = [img_path[i] for i in range(img_path.shape[0])]
else:
images = [img_path]

batch_images = []
batch_regions = []
batch_confidences = []

for single_img_path in images:
# ---------------------------------
# we have run pre-process in verification.
# so, this can be skipped if it is coming from verify.
target_size = model.input_shape
if detector_backend != "skip":
img_objs = detection.extract_faces(
img_path=single_img_path,
detector_backend=detector_backend,
grayscale=False,
enforce_detection=enforce_detection,
align=align,
expand_percentage=expand_percentage,
anti_spoofing=anti_spoofing,
max_faces=max_faces,
)
else: # skip
# Try load. If load error, will raise exception internal
img, _ = image_utils.load_image(single_img_path)

if len(img.shape) != 3:
raise ValueError(f"Input img must be 3 dimensional but it is {img.shape}")

# make dummy region and confidence to keep compatibility with `extract_faces`
img_objs = [
{
"face": img,
"facial_area": {"x": 0, "y": 0, "w": img.shape[0], "h": img.shape[1]},
"confidence": 0,
}
]
# ---------------------------------

if max_faces is not None and max_faces < len(img_objs):
# sort as largest facial areas come first
img_objs = sorted(
img_objs,
key=lambda img_obj: img_obj["facial_area"]["w"] * img_obj["facial_area"]["h"],
reverse=True,
)
# discard rest of the items
img_objs = img_objs[0:max_faces]

for img_obj in img_objs:
if anti_spoofing is True and img_obj.get("is_real", True) is False:
raise ValueError("Spoof detected in the given image.")
img = img_obj["face"]

# bgr to rgb
img = img[:, :, ::-1]

region = img_obj["facial_area"]
confidence = img_obj["confidence"]

# resize to expected shape of ml model
img = preprocessing.resize_image(
img=img,
# thanks to DeepId (!)
target_size=(target_size[1], target_size[0]),
)

# custom normalization
img = preprocessing.normalize_input(img=img, normalization=normalization)

batch_images.append(img)
batch_regions.append(region)
batch_confidences.append(confidence)

# Convert list of images to a numpy array for batch processing
batch_images = np.concatenate(batch_images, axis=0)

# Forward pass through the model for the entire batch
embeddings = model.forward(batch_images)
if len(batch_images) == 1:
embeddings = [embeddings]

for embedding, region, confidence in zip(embeddings, batch_regions, batch_confidences):
resp_objs.append(
{
"embedding": embedding,
Expand Down
Loading

0 comments on commit ca73032

Please sign in to comment.