Commit 206135d

update utility functions

Former-commit-id: 576c987
FecoDoo committed Apr 1, 2022
1 parent c4de667
Showing 9 changed files with 142 additions and 74 deletions.
6 changes: 5 additions & 1 deletion .gitignore
@@ -169,4 +169,8 @@ data/*
 misc/
 logs/
 models/
-custom/
+custom/
+utils/predictors/
+videos/
+images/
+.cache
31 changes: 27 additions & 4 deletions core/callbacks/callbacks.py
@@ -154,10 +154,33 @@ def __init__(


 class TensorBoard(TensorBoard):
-    def __init__(self, log_dir='logs', histogram_freq=0, write_graph=True, write_images=False, write_steps_per_second=False, update_freq='epoch', profile_batch=2, embeddings_freq=0, embeddings_metadata=None, **kwargs):
-        super().__init__(log_dir, histogram_freq, write_graph, write_images, write_steps_per_second, update_freq, profile_batch, embeddings_freq, embeddings_metadata, **kwargs)
+    def __init__(
+        self,
+        log_dir="logs",
+        histogram_freq=0,
+        write_graph=True,
+        write_images=False,
+        write_steps_per_second=False,
+        update_freq="epoch",
+        profile_batch=2,
+        embeddings_freq=0,
+        embeddings_metadata=None,
+        **kwargs
+    ):
+        super().__init__(
+            log_dir,
+            histogram_freq,
+            write_graph,
+            write_images,
+            write_steps_per_second,
+            update_freq,
+            profile_batch,
+            embeddings_freq,
+            embeddings_metadata,
+            **kwargs
+        )


 class CSVLogger(CSVLogger):
-    def __init__(self, filename, separator=',', append=False):
-        super().__init__(filename, separator, append)
+    def __init__(self, filename, separator=",", append=False):
+        super().__init__(filename, separator, append)
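A minimal training sketch wiring the wrapped callbacks into model.fit; the toy model and data are stand-ins, not part of this repo:

import numpy as np
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense
from core.callbacks.callbacks import TensorBoard, CSVLogger

# stand-in model; in this repo the LipNet model would be trained instead
model = Sequential([Dense(1, input_shape=(4,))])
model.compile(optimizer="adam", loss="mse")

model.fit(
    np.random.rand(16, 4),
    np.random.rand(16, 1),
    epochs=2,
    callbacks=[
        TensorBoard(log_dir="logs/demo"),             # defaults pinned above
        CSVLogger(filename="logs/demo/metrics.csv"),  # one row per epoch
    ],
)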
6 changes: 3 additions & 3 deletions core/generators/dataset_generator.py
@@ -4,7 +4,6 @@
 from typing import List, Tuple, Dict
 from core.generators.batch_generator import BatchGenerator
 from core.helpers.align import Align, align_from_file
-from pathlib import Path


 class DatasetGenerator(object):
@@ -47,9 +46,10 @@ def build_dataset(self):
         val_aligns = self.generate_align_hash(val_videos)

         with open(cache_path, "wb") as f:
-            pickle.dump(obj=(train_videos, train_aligns, val_videos, val_aligns), file=f)
+            pickle.dump(
+                obj=(train_videos, train_aligns, val_videos, val_aligns), file=f
+            )

         print(
             "Found {} videos and {} aligns for training".format(
                 len(train_videos), len(train_aligns)
12 changes: 6 additions & 6 deletions core/model/lipnet.py
@@ -15,11 +15,11 @@
 class LipNet(object):
     def __init__(
         self,
-        frame_count: int,
-        image_channels: int,
-        image_height: int,
-        image_width: int,
-        max_string: int,
+        frame_count: int = env.FRAME_COUNT,
+        image_channels: int = env.IMAGE_CHANNELS,
+        image_height: int = env.IMAGE_HEIGHT,
+        image_width: int = env.IMAGE_WIDTH,
+        max_string: int = env.MAX_STRING,
         output_size: int = env.OUTPUT_SIZE,
     ):
         input_shape = self.get_input_shape(
@@ -86,7 +86,7 @@ def __init__(
             outputs=self.loss_out,
         )

-    def compile_model(self, optimizer=None):
+    def compile(self, optimizer=None):
         if optimizer is None:
             optimizer = Adam(
                 learning_rate=ADAM_LEARN_RATE,
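With the env-backed defaults, the model can be built without spelling out every input dimension; a usage sketch (the optimizer fallback follows the compile() body above):

from core.model.lipnet import LipNet

lipnet = LipNet()  # frame count, image shape, etc. fall back to env.py
lipnet.compile()   # renamed from compile_model(); None selects the default Adam

# any single dimension can still be overridden explicitly
lipnet_gray = LipNet(image_channels=1)  # hypothetical grayscale variant
lipnet_gray.compile()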
11 changes: 11 additions & 0 deletions env.py
@@ -21,3 +21,14 @@
 DEV = 0
 USE_CACHE = 1
 EPOCH = 60
+
+VIDEO_PATTERN = "*.mpg"
+
+# prediction
+DICTIONARY_PATH = "data/dictionaries/grid.txt"
+MODEL_PATH = "models/lipnet.h5"
+VIDEO_PATH = "videos"
+DLIB_SHAPE_PREDICTOR_PATH = "data/dlib/shape_predictor_68_face_landmarks.dat"
+
+# others
+TF_CPP_MIN_LOG_LEVEL = "3"
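TF_CPP_MIN_LOG_LEVEL only takes effect if it reaches the environment before TensorFlow is first imported; a sketch of the intended consumption (the entry point itself is not part of this commit):

import os
import env

# "3" filters INFO, WARNING and ERROR from TensorFlow's C++ backend,
# leaving only FATAL messages; must be set before the first TF import
os.environ["TF_CPP_MIN_LOG_LEVEL"] = env.TF_CPP_MIN_LOG_LEVEL

import tensorflow as tf  # noqa: E402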
19 changes: 0 additions & 19 deletions log.py

This file was deleted.

57 changes: 24 additions & 33 deletions preprocessing.py
@@ -4,26 +4,30 @@
 import env
 from multiprocessing import Pool, log_to_stderr
 from pathlib import Path
-from utils.preprocessor import Extractor
+from utils.converter import Converter
 from tqdm import tqdm

 ROOT = Path(__file__).resolve().parents[0]
 DATA_TARGET_DIR = ROOT.joinpath("data/dataset")
-ERROR_DIR = ROOT.joinpath("logs/")
+ERROR_DIR = ROOT.joinpath("logs/preprocessing")

 if not DATA_TARGET_DIR.exists():
     DATA_TARGET_DIR.mkdir()

 if not ERROR_DIR.exists():
     ERROR_DIR.mkdir()

-pattern = "*.mpg"
+pattern = env.VIDEO_PATTERN

 # define logger
 logger = log_to_stderr(level=logging.DEBUG)


-def convert(group_path: os.PathLike):
+def convert_video_to_array(group_path: os.PathLike):
+    """
+    convert features from *.mpg files into *.npy format
+    """

     try:
         groupname = group_path.name

@@ -34,7 +38,7 @@ def convert(group_path: os.PathLike):

         logger.info(f"Start: {groupname}")

-        extractor = Extractor(logger)
+        converter = Converter(logger)

         for file_path in group_path.glob(pattern):
             hash_string = file_path.stem
@@ -44,7 +48,7 @@ def convert(group_path: os.PathLike):
                 logger.info(f"{groupname + ' | ' + hash_string} | skipped")
                 continue

-            if not extractor.video_to_frames(file_path, output_path):
+            if not converter.video_to_frames(file_path, output_path):
                 videos_failed.append(hash_string + "\n")

         with open(str(ERROR_DIR.joinpath(groupname + ".txt")), "w") as f:
@@ -53,25 +57,15 @@ def convert(group_path: os.PathLike):
     except Exception:
         logger.error(traceback.format_exc())
     finally:
-        logger.info(f"{group_path.name} completed.")
+        logger.info(f"{group_path.name} processing completed.")


 # @validate_preprocessing_config
 def manager():
     """
-    convert features from *.mpg files and convert into *.npy format
+    Multiprocessing manager
     """

-    logger.info(
-        r"""
-         __         __     ______   __   __     ______     __  __     ______
-        /\ \       /\ \   /\  == \ /\ "-.\ \   /\  ___\   /\_\_\_\   /\__  _\
-        \ \ \____  \ \ \  \ \  _-/ \ \ \-.  \  \ \  __\   \/_/\_\/_  \/_/\ \/
-         \ \_____\  \ \_\  \ \_\    \ \_\\"\_\  \ \_____\    /\_\/\_\    \ \_\
-          \/_____/   \/_/   \/_/     \/_/ \/_/   \/_____/   \/_/\/_/     \/_/
-        """
-    )
-
     try:
         data_source_dir = ROOT.joinpath("../dataset/lipnet/train/").resolve()
         data_target_dir = ROOT.joinpath("data/dataset").resolve()
@@ -89,21 +83,18 @@ def manager():
                 logger.error(f"Group {i} is not a directory")
                 groups.remove(i)
                 continue

-        if env.DEV:
-            convert(groups[0])
-        else:
-            with Pool(processes=None) as pool:
-                res = [
-                    pool.apply_async(
-                        convert,
-                        args=(group_path,),
-                    )
-                    for group_path in groups
-                ]
-
-                for p in res:
-                    p.get()
+        with Pool(processes=None) as pool:
+            res = [
+                pool.apply_async(
+                    convert_video_to_array,
+                    args=(group_path,),
+                )
+                for group_path in groups
+            ]
+
+            for p in res:
+                p.get()

     except Exception:
         logger.error(traceback.format_exc())
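The env.DEV fast path (serial processing of a single group) was removed above; an equivalent debug run can still be invoked by hand. A sketch, with a hypothetical speaker directory:

from pathlib import Path
from preprocessing import convert_video_to_array

# process one speaker group serially, bypassing the Pool in manager()
group = Path("../dataset/lipnet/train/s1").resolve()  # hypothetical group dir
convert_video_to_array(group)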
11 changes: 3 additions & 8 deletions utils/preprocessor.py → utils/converter.py
@@ -10,18 +10,13 @@
 from dlib import get_frontal_face_detector, shape_predictor


-class Extractor:
+class Converter:
     def __init__(self, logger):
         self.logger = logger
-        self.frame_shape = (env.IMAGE_HEIGHT, env.IMAGE_WIDTH, env.IMAGE_CHANNELS)
+        self.image_size = (env.IMAGE_HEIGHT, env.IMAGE_WIDTH)
         self.detector = get_frontal_face_detector()
-        self.predictor = shape_predictor(
-            os.path.join(
-                os.path.realpath(os.path.dirname(__file__)),
-                "predictors/shape_predictor_68_face_landmarks.dat",
-            )
-        )
+        self.predictor = shape_predictor(env.DLIB_SHAPE_PREDICTOR_PATH)

         _, (self.mouth_x_idx, self.mouth_y_idx) = list(
             face_utils.FACIAL_LANDMARKS_IDXS.items()
@@ -87,7 +82,7 @@ def extract_mouth_points(self, frame: np.ndarray) -> Optional[np.ndarray]:

         shape = face_utils.shape_to_np(self.predictor(gray, detected[0]))

-        return shape[self.mouth_x_idx:self.mouth_y_idx]
+        return shape[self.mouth_x_idx : self.mouth_y_idx]

     @staticmethod
     def crop_image(image: np.ndarray, center: tuple, size: tuple) -> np.ndarray:
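A sketch of driving the renamed Converter directly, mirroring the call pattern in preprocessing.py; the paths are illustrative, and the dlib predictor file must already exist at env.DLIB_SHAPE_PREDICTOR_PATH:

import logging
from pathlib import Path
from utils.converter import Converter

logger = logging.getLogger("converter-demo")
converter = Converter(logger)  # loads the dlib face detector and shape predictor

# video_to_frames returns a falsy value on failure, which preprocessing.py
# uses to record failed video hashes
if not converter.video_to_frames(Path("videos/demo.mpg"), Path("data/demo.npy")):
    logger.error("mouth extraction failed for videos/demo.mpg")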
63 changes: 63 additions & 0 deletions utils/helper.py
@@ -0,0 +1,63 @@
+import os
+import csv
+import numpy as np
+import skvideo.io
+from typing import List
+from core.helpers.video import (
+    get_video_data_from_file,
+    reshape_and_normalize_video_data,
+)
+from core.utils.visualization import visualize_video_subtitle
+
+
+def display_results(
+    valid_paths: list, results: list, display: bool = True, visualize: bool = False
+):
+    # nothing to do when neither output mode was requested
+    if not display and not visualize:
+        return
+
+    for p, r in zip(valid_paths, results):
+        if display:
+            print("\nVideo: {}\n Result: {}".format(p, r))
+
+        if visualize:
+            v = get_entire_video_data(p)
+            visualize_video_subtitle(v, r)
+
+
+def query_save_csv_path(default: str = "output.csv"):
+    path = input("Output CSV name (default is '{}'): ".format(default))
+
+    if not path:
+        path = default
+    if not path.endswith(".csv"):
+        path += ".csv"
+
+    return os.path.realpath(path)
+
+
+def query_yes_no(query: str, default: bool = True) -> bool:
+    prompt = "[Y/n]" if default else "[y/N]"
+    inp = input(query + " " + prompt + " ")
+
+    return default if not inp else inp.lower()[0] == "y"
+
+
+def write_results_to_csv(path: os.PathLike, valid_paths: list, results: list):
+    # `path` is expected to be a pathlib.Path (.exists() is used below)
+    already_exists = path.exists()
+
+    # append mode keeps the header check consistent with `already_exists`
+    with open(path, "a") as f:
+        writer = csv.writer(f)
+
+        if not already_exists:
+            writer.writerow(["file", "prediction"])
+
+        for p, r in zip(valid_paths, results):
+            writer.writerow([p, r])
+
+
+def get_entire_video_data(path: os.PathLike) -> np.ndarray:
+    if path.suffix == ".mpg":
+        return np.swapaxes(skvideo.io.vread(path), 1, 2)
+    else:
+        return get_video_data_from_file(path)

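A sketch of the prediction-side flow these helpers support (the predict script itself is not in this commit, so the inputs below are assumptions):

from pathlib import Path
from utils.helper import (
    display_results,
    query_save_csv_path,
    query_yes_no,
    write_results_to_csv,
)

paths = [Path("videos/demo.mpg")]    # hypothetical input video
results = ["bin blue at f two now"]  # hypothetical decoded transcript

display_results(paths, results, display=True, visualize=False)

if query_yes_no("Save results to CSV?"):
    write_results_to_csv(Path(query_save_csv_path()), paths, results)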