diff --git a/deepface/DeepFace.py b/deepface/DeepFace.py index 3abe6db97..5ae05aadd 100644 --- a/deepface/DeepFace.py +++ b/deepface/DeepFace.py @@ -174,7 +174,7 @@ def analyze( expand_percentage: int = 0, silent: bool = False, anti_spoofing: bool = False, -) -> List[Dict[str, Any]]: +) -> Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]: """ Analyze facial attributes such as age, gender, emotion, and race in the provided image. Args: @@ -206,7 +206,10 @@ def analyze( anti_spoofing (boolean): Flag to enable anti spoofing (default is False). Returns: - results (List[Dict[str, Any]]): A list of dictionaries, where each dictionary represents + (List[List[Dict[str, Any]]]): A list of analysis results if received batched image, + explained below. + + (List[Dict[str, Any]]): A list of dictionaries, where each dictionary represents the analysis results for a detected face. Each dictionary in the list contains the following keys: @@ -253,6 +256,29 @@ def analyze( - 'middle eastern': Confidence score for Middle Eastern ethnicity. - 'white': Confidence score for White ethnicity. """ + + if isinstance(img_path, np.ndarray) and len(img_path.shape) == 4: + # Received 4-D array, which means image batch. + # Check batch dimension and process each image separately. + if img_path.shape[0] > 1: + batch_resp_obj = [] + # Execute analysis for each image in the batch. + for single_img in img_path: + resp_obj = demography.analyze( + img_path=single_img, + actions=actions, + enforce_detection=enforce_detection, + detector_backend=detector_backend, + align=align, + expand_percentage=expand_percentage, + silent=silent, + anti_spoofing=anti_spoofing, + ) + + # Append the response object to the batch response list. + batch_resp_obj.append(resp_obj) + return batch_resp_obj + return demography.analyze( img_path=img_path, actions=actions, diff --git a/deepface/models/Demography.py b/deepface/models/Demography.py index ad9392029..1493059b3 100644 --- a/deepface/models/Demography.py +++ b/deepface/models/Demography.py @@ -1,4 +1,4 @@ -from typing import Union +from typing import Union, List from abc import ABC, abstractmethod import numpy as np from deepface.commons import package_utils @@ -18,5 +18,53 @@ class Demography(ABC): model_name: str @abstractmethod - def predict(self, img: np.ndarray) -> Union[np.ndarray, np.float64]: + def predict(self, img: Union[np.ndarray, List[np.ndarray]]) -> Union[np.ndarray, np.float64]: pass + + def _predict_internal(self, img_batch: np.ndarray) -> np.ndarray: + """ + Predict for single image or batched images. + This method uses legacy method while receiving single image as input. + And switch to batch prediction if receives batched images. + + Args: + img_batch: + Batch of images as np.ndarray (n, x, y, c) + with n >= 1, x = image width, y = image height, c = channel + Or Single image as np.ndarray (1, x, y, c) + with x = image width, y = image height and c = channel + The channel dimension will be 1 if input is grayscale. (For emotion model) + """ + if not self.model_name: # Check if called from derived class + raise NotImplementedError("no model selected") + assert img_batch.ndim == 4, "expected 4-dimensional tensor input" + + if img_batch.shape[0] == 1: # Single image + # Predict with legacy method. + return self.model(img_batch, training=False).numpy()[0, :] + + # Batch of images + # Predict with batch prediction + return self.model.predict_on_batch(img_batch) + + def _preprocess_batch_or_single_input( + self, + img: Union[np.ndarray, List[np.ndarray]] + ) -> np.ndarray: + + """ + Preprocess single or batch of images, return as 4-D numpy array. + Args: + img: Single image as np.ndarray (224, 224, 3) or + List of images as List[np.ndarray] or + Batch of images as np.ndarray (n, 224, 224, 3) + Returns: + Four-dimensional numpy array (n, 224, 224, 3) + """ + image_batch = np.array(img) + + # Check input dimension + if len(image_batch.shape) == 3: + # Single image - add batch dimension + image_batch = np.expand_dims(image_batch, axis=0) + return image_batch diff --git a/deepface/models/demography/Age.py b/deepface/models/demography/Age.py index 67ab3ae65..c96015919 100644 --- a/deepface/models/demography/Age.py +++ b/deepface/models/demography/Age.py @@ -1,3 +1,7 @@ +# stdlib dependencies + +from typing import List, Union + # 3rd party dependencies import numpy as np @@ -37,11 +41,29 @@ def __init__(self): self.model = load_model() self.model_name = "Age" - def predict(self, img: np.ndarray) -> np.float64: - # model.predict causes memory issue when it is called in a for loop - # age_predictions = self.model.predict(img, verbose=0)[0, :] - age_predictions = self.model(img, training=False).numpy()[0, :] - return find_apparent_age(age_predictions) + def predict(self, img: Union[np.ndarray, List[np.ndarray]]) -> Union[np.float64, np.ndarray]: + """ + Predict apparent age(s) for single or multiple faces + Args: + img: Single image as np.ndarray (224, 224, 3) or + List of images as List[np.ndarray] or + Batch of images as np.ndarray (n, 224, 224, 3) + Returns: + np.ndarray (age_classes,) if single image, + np.ndarray (n, age_classes) if batched images. + """ + # Preprocessing input image or image list. + imgs = self._preprocess_batch_or_single_input(img) + + # Prediction from 3 channels image + age_predictions = self._predict_internal(imgs) + + # Calculate apparent ages + if len(age_predictions.shape) == 1: # Single prediction list + return find_apparent_age(age_predictions) + + return np.array([ + find_apparent_age(age_prediction) for age_prediction in age_predictions]) def load_model( @@ -65,7 +87,7 @@ def load_model( # -------------------------- - age_model = Model(inputs=model.input, outputs=base_model_output) + age_model = Model(inputs=model.inputs, outputs=base_model_output) # -------------------------- @@ -78,15 +100,16 @@ def load_model( return age_model - def find_apparent_age(age_predictions: np.ndarray) -> np.float64: """ Find apparent age prediction from a given probas of ages Args: - age_predictions (?) + age_predictions (age_classes,) Returns: apparent_age (float) """ + assert len(age_predictions.shape) == 1, f"Input should be a list of predictions, \ + not batched. Got shape: {age_predictions.shape}" output_indexes = np.arange(0, 101) apparent_age = np.sum(age_predictions * output_indexes) return apparent_age diff --git a/deepface/models/demography/Emotion.py b/deepface/models/demography/Emotion.py index d2633b519..499c246cf 100644 --- a/deepface/models/demography/Emotion.py +++ b/deepface/models/demography/Emotion.py @@ -1,3 +1,6 @@ +# stdlib dependencies +from typing import List, Union + # 3rd party dependencies import numpy as np import cv2 @@ -43,16 +46,38 @@ def __init__(self): self.model = load_model() self.model_name = "Emotion" - def predict(self, img: np.ndarray) -> np.ndarray: - img_gray = cv2.cvtColor(img[0], cv2.COLOR_BGR2GRAY) + def _preprocess_image(self, img: np.ndarray) -> np.ndarray: + """ + Preprocess single image for emotion detection + Args: + img: Input image (224, 224, 3) + Returns: + Preprocessed grayscale image (48, 48) + """ + img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) img_gray = cv2.resize(img_gray, (48, 48)) - img_gray = np.expand_dims(img_gray, axis=0) - - # model.predict causes memory issue when it is called in a for loop - # emotion_predictions = self.model.predict(img_gray, verbose=0)[0, :] - emotion_predictions = self.model(img_gray, training=False).numpy()[0, :] - - return emotion_predictions + return img_gray + + def predict(self, img: Union[np.ndarray, List[np.ndarray]]) -> np.ndarray: + """ + Predict emotion probabilities for single or multiple faces + Args: + img: Single image as np.ndarray (224, 224, 3) or + List of images as List[np.ndarray] or + Batch of images as np.ndarray (n, 224, 224, 3) + Returns: + np.ndarray (n, n_emotions) + where n_emotions is the number of emotion categories + """ + # Preprocessing input image or image list. + imgs = self._preprocess_batch_or_single_input(img) + + processed_imgs = np.expand_dims(np.array([self._preprocess_image(img) for img in imgs]), axis=-1) + + # Prediction + predictions = self._predict_internal(processed_imgs) + + return predictions def load_model( diff --git a/deepface/models/demography/Gender.py b/deepface/models/demography/Gender.py index ad1c15e3c..b6a3ef1c3 100644 --- a/deepface/models/demography/Gender.py +++ b/deepface/models/demography/Gender.py @@ -1,3 +1,7 @@ +# stdlib dependencies + +from typing import List, Union + # 3rd party dependencies import numpy as np @@ -37,11 +41,23 @@ def __init__(self): self.model = load_model() self.model_name = "Gender" - def predict(self, img: np.ndarray) -> np.ndarray: - # model.predict causes memory issue when it is called in a for loop - # return self.model.predict(img, verbose=0)[0, :] - return self.model(img, training=False).numpy()[0, :] + def predict(self, img: Union[np.ndarray, List[np.ndarray]]) -> np.ndarray: + """ + Predict gender probabilities for single or multiple faces + Args: + img: Single image as np.ndarray (224, 224, 3) or + List of images as List[np.ndarray] or + Batch of images as np.ndarray (n, 224, 224, 3) + Returns: + np.ndarray (n, 2) + """ + # Preprocessing input image or image list. + imgs = self._preprocess_batch_or_single_input(img) + + # Prediction + predictions = self._predict_internal(imgs) + return predictions def load_model( url=WEIGHTS_URL, @@ -64,7 +80,7 @@ def load_model( # -------------------------- - gender_model = Model(inputs=model.input, outputs=base_model_output) + gender_model = Model(inputs=model.inputs, outputs=base_model_output) # -------------------------- diff --git a/deepface/models/demography/Race.py b/deepface/models/demography/Race.py index 2334c8b46..eae5154cc 100644 --- a/deepface/models/demography/Race.py +++ b/deepface/models/demography/Race.py @@ -1,3 +1,6 @@ +# stdlib dependencies +from typing import List, Union + # 3rd party dependencies import numpy as np @@ -37,10 +40,24 @@ def __init__(self): self.model = load_model() self.model_name = "Race" - def predict(self, img: np.ndarray) -> np.ndarray: - # model.predict causes memory issue when it is called in a for loop - # return self.model.predict(img, verbose=0)[0, :] - return self.model(img, training=False).numpy()[0, :] + def predict(self, img: Union[np.ndarray, List[np.ndarray]]) -> np.ndarray: + """ + Predict race probabilities for single or multiple faces + Args: + img: Single image as np.ndarray (224, 224, 3) or + List of images as List[np.ndarray] or + Batch of images as np.ndarray (n, 224, 224, 3) + Returns: + np.ndarray (n, n_races) + where n_races is the number of race categories + """ + # Preprocessing input image or image list. + imgs = self._preprocess_batch_or_single_input(img) + + # Prediction + predictions = self._predict_internal(imgs) + + return predictions def load_model( @@ -62,7 +79,7 @@ def load_model( # -------------------------- - race_model = Model(inputs=model.input, outputs=base_model_output) + race_model = Model(inputs=model.inputs, outputs=base_model_output) # -------------------------- diff --git a/tests/test_analyze.py b/tests/test_analyze.py index bad44260e..a36acc5d1 100644 --- a/tests/test_analyze.py +++ b/tests/test_analyze.py @@ -1,8 +1,10 @@ # 3rd party dependencies import cv2 +import numpy as np # project dependencies from deepface import DeepFace +from deepface.models.demography import Age, Emotion, Gender, Race from deepface.commons.logger import Logger logger = Logger() @@ -16,6 +18,7 @@ def test_standard_analyze(): demography_objs = DeepFace.analyze(img, silent=True) for demography in demography_objs: logger.debug(demography) + assert type(demography) == dict assert demography["age"] > 20 and demography["age"] < 40 assert demography["dominant_gender"] == "Woman" logger.info("✅ test standard analyze done") @@ -29,6 +32,7 @@ def test_analyze_with_all_actions_as_tuple(): for demography in demography_objs: logger.debug(f"Demography: {demography}") + assert type(demography) == dict age = demography["age"] gender = demography["dominant_gender"] race = demography["dominant_race"] @@ -53,6 +57,7 @@ def test_analyze_with_all_actions_as_list(): for demography in demography_objs: logger.debug(f"Demography: {demography}") + assert type(demography) == dict age = demography["age"] gender = demography["dominant_gender"] race = demography["dominant_race"] @@ -74,6 +79,7 @@ def test_analyze_for_some_actions(): demography_objs = DeepFace.analyze(img, ["age", "gender"], silent=True) for demography in demography_objs: + assert type(demography) == dict age = demography["age"] gender = demography["dominant_gender"] @@ -95,6 +101,7 @@ def test_analyze_for_preloaded_image(): resp_objs = DeepFace.analyze(img, silent=True) for resp_obj in resp_objs: logger.debug(resp_obj) + assert type(resp_obj) == dict assert resp_obj["age"] > 20 and resp_obj["age"] < 40 assert resp_obj["dominant_gender"] == "Woman" @@ -131,7 +138,73 @@ def test_analyze_for_different_detectors(): ] # validate probabilities + assert type(result) == dict if result["dominant_gender"] == "Man": assert result["gender"]["Man"] > result["gender"]["Woman"] else: assert result["gender"]["Man"] < result["gender"]["Woman"] + +def test_analyze_for_batched_image(): + img = "dataset/img4.jpg" + # Copy and combine the same image to create multiple faces + img = cv2.imread(img) + img = np.stack([img, img]) + assert len(img.shape) == 4 # Check dimension. + assert img.shape[0] == 2 # Check batch size. + + demography_batch = DeepFace.analyze(img, silent=True) + # 2 image in batch, so 2 demography objects. + assert len(demography_batch) == 2 + + for demography_objs in demography_batch: + assert len(demography_objs) == 1 # 1 face in each image + for demography in demography_objs: # Iterate over faces + assert type(demography) == dict # Check type + assert demography["age"] > 20 and demography["age"] < 40 + assert demography["dominant_gender"] == "Woman" + logger.info("✅ test analyze for multiple faces done") + +def test_batch_detect_age_for_multiple_faces(): + # Load test image and resize to model input size + img = cv2.resize(cv2.imread("dataset/img1.jpg"), (224, 224)) + imgs = [img, img] + results = Age.ApparentAgeClient().predict(imgs) + # Check there are two ages detected + assert len(results) == 2 + # Check two faces ages are the same in integer format(e.g. 23.6 -> 23) + # Must use int() to compare because of max float precision issue in different platforms + assert np.array_equal(int(results[0]), int(results[1])) + logger.info("✅ test batch detect age for multiple faces done") + +def test_batch_detect_emotion_for_multiple_faces(): + # Load test image and resize to model input size + img = cv2.resize(cv2.imread("dataset/img1.jpg"), (224, 224)) + imgs = [img, img] + results = Emotion.EmotionClient().predict(imgs) + # Check there are two emotions detected + assert len(results) == 2 + # Check two faces emotions are the same + assert np.array_equal(results[0], results[1]) + logger.info("✅ test batch detect emotion for multiple faces done") + +def test_batch_detect_gender_for_multiple_faces(): + # Load test image and resize to model input size + img = cv2.resize(cv2.imread("dataset/img1.jpg"), (224, 224)) + imgs = [img, img] + results = Gender.GenderClient().predict(imgs) + # Check there are two genders detected + assert len(results) == 2 + # Check two genders are the same + assert np.array_equal(results[0], results[1]) + logger.info("✅ test batch detect gender for multiple faces done") + +def test_batch_detect_race_for_multiple_faces(): + # Load test image and resize to model input size + img = cv2.resize(cv2.imread("dataset/img1.jpg"), (224, 224)) + imgs = [img, img] + results = Race.RaceClient().predict(imgs) + # Check there are two races detected + assert len(results) == 2 + # Check two races are the same + assert np.array_equal(results[0], results[1]) + logger.info("✅ test batch detect race for multiple faces done") \ No newline at end of file