This repository contains the code for real-time semantic segmentation that addresses the domain shift problem through domain adaptation. The code is implemented in PyTorch and uses the DeepLab and BiSeNet models for semantic segmentation, with two datasets provided: one real (Cityscapes) and one synthetic (GTA5). The main.py script is the entry point for training and is driven by the configuration specified in the config.yaml file.
- Dependencies: Install the required Python packages using the following command:
pip install -r requirements.txt
The config.yaml file contains all the necessary configurations for data, model, training, augmentation, and callbacks. Below is a detailed description of the configuration parameters:
- Cityscapes Dataset:
  - images_train_dir: Directory for training images.
  - images_val_dir: Directory for validation images.
  - segmentation_train_dir: Directory for training segmentation labels.
  - segmentation_val_dir: Directory for validation segmentation labels.
  - image_size: Tuple specifying the image size.
  - num_classes: Number of classes.
  - batch_size: Batch size for training.
  - num_workers: Number of worker processes for data loading.
- GTA5 Modified Dataset:
  - Similar parameters as above, but specific to the GTA5 dataset.
- class_names: List of class names used for segmentation.
- DeepLab:
  - backbone: Backbone network architecture.
  - output_stride: Output stride for the model.
  - num_classes: Number of classes.
  - pretrained: Boolean indicating if a pretrained model should be used.
  - pretrained_path: Path to the pretrained model file.
  - optimizer: Optimizer settings.
  - criterion: Loss function settings.
- BiSeNet:
  - Similar parameters as for DeepLab.
- Adversarial Model:
  - Configuration for the generator and discriminator used in domain adaptation.
- Segmentation:
  - Training settings specific to segmentation tasks.
- Domain Adaptation:
  - Training settings specific to domain adaptation tasks.
- Augmentation:
  - Various data augmentation settings such as Gaussian blur and horizontal flip.
- Model Checkpoint:
  - Settings for saving model checkpoints.
- Early Stopping:
  - Settings for early stopping based on validation loss.
- Logging:
  - Settings for logging training progress with tools like Weights & Biases.
- Images Plots:
  - Settings for saving image plots during training.
- device: Specifies whether to use cpu or cuda (GPU) for training.
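To make the parameter list above more concrete, the sketch below mirrors the Cityscapes dataset keys as a Python dictionary, roughly the shape a YAML loader might return, and checks that the expected keys are present. Only the leaf key names come from the list above. The nesting, the section names (`data`, `cityscapes`), the placeholder paths, the example values, and the helper `missing_keys` are assumptions for illustration, so the actual layout of config.yaml in this repository may differ.

```python
# Hypothetical layout: section names, paths, and values are illustrative
# assumptions. Only the leaf key names are taken from the parameter list.
example_config = {
    "data": {
        "cityscapes": {
            "images_train_dir": "path/to/cityscapes/images/train",
            "images_val_dir": "path/to/cityscapes/images/val",
            "segmentation_train_dir": "path/to/cityscapes/labels/train",
            "segmentation_val_dir": "path/to/cityscapes/labels/val",
            "image_size": (512, 1024),  # example value only
            "num_classes": 19,          # example value only
            "batch_size": 4,            # example value only
            "num_workers": 2,           # example value only
        },
    },
    "device": "cuda",
}

# Keys documented for a dataset section.
REQUIRED_DATASET_KEYS = (
    "images_train_dir",
    "images_val_dir",
    "segmentation_train_dir",
    "segmentation_val_dir",
    "image_size",
    "num_classes",
    "batch_size",
    "num_workers",
)


def missing_keys(section, required):
    """Return the required keys that are absent from a config section."""
    return [key for key in required if key not in section]


if __name__ == "__main__":
    cityscapes = example_config["data"]["cityscapes"]
    print("Missing keys:", missing_keys(cityscapes, REQUIRED_DATASET_KEYS))
```

A check like this, run right after the file is loaded, reports a misspelled or missing key before any data loading starts.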
To start the training process, run the following command:
python main.py --config config.yaml
For help with the available command-line arguments, run the following command:
python main.py --help
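For readers curious how the two commands above could be wired up, here is a minimal sketch using the standard-library `argparse` module. It is not the repository's actual parser: the function names `build_parser` and `parse_args`, the description text, and the default value for `--config` are assumptions. The `--help` flag needs no extra code because `argparse` adds it automatically.

```python
import argparse


def build_parser():
    """Build a parser with a single --config option (illustrative sketch)."""
    parser = argparse.ArgumentParser(
        description="Train a semantic segmentation model (illustrative sketch)."
    )
    parser.add_argument(
        "--config",
        type=str,
        default="config.yaml",  # assumed default, not confirmed by the repository
        help="Path to the YAML configuration file.",
    )
    return parser


def parse_args(argv=None):
    """Parse argv (or sys.argv[1:] when argv is None)."""
    return build_parser().parse_args(argv)


if __name__ == "__main__":
    args = parse_args()
    print("Using configuration file:", args.config)
```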
The main script is responsible for:
- Loading Configuration: Reads the configuration from config.yaml.
- Initializing Model: Sets up the model architecture based on the configuration.
- Data Loading: Prepares the data loaders for training and validation datasets.
- Training Loop: Executes the training loop, including forward and backward passes, loss calculation, and optimizer updates.
- Validation: Performs validation at specified intervals and logs the results.
- Callbacks: Handles callbacks such as model checkpointing, early stopping, and logging.
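The interaction between the validation and callback steps listed above can be sketched in plain Python. The example below simulates a training loop with a list of validation losses and shows one common way to implement early stopping and best-checkpoint tracking. The class `EarlyStopping`, the function `run_training`, and the `patience` and `min_delta` parameters are hypothetical names chosen for illustration and may not match the callbacks implemented in this repository.

```python
class EarlyStopping:
    """Signal a stop when validation loss fails to improve `patience` times in a row."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float("inf")
        self.bad_checks = 0

    def step(self, val_loss):
        """Record a validation loss and return True if training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.bad_checks = 0
        else:
            self.bad_checks += 1
        return self.bad_checks >= self.patience


def run_training(val_losses, patience=3):
    """Simulated loop: `val_losses` stands in for per-epoch validation results.

    Returns (last_epoch_run, best_epoch).
    """
    stopper = EarlyStopping(patience=patience)
    best_epoch = None
    for epoch, val_loss in enumerate(val_losses):
        # A real loop would run forward/backward passes and optimizer updates here.
        improved = val_loss < stopper.best_loss
        should_stop = stopper.step(val_loss)
        if improved:
            best_epoch = epoch  # a checkpoint callback would save the model here
        if should_stop:
            return epoch, best_epoch
    return len(val_losses) - 1, best_epoch


if __name__ == "__main__":
    last_epoch, best_epoch = run_training([1.0, 0.8, 0.9, 0.85, 0.81, 0.7])
    print(f"Stopped after epoch {last_epoch}; best epoch was {best_epoch}")
```

In this simulated run the final loss of 0.7 is never reached, because three consecutive checks without improvement trigger the stop first. That is the usual trade-off when choosing a small patience value.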
The pretrained weights for DeepLab are available at the following link.
DeepLab pretrained weights: https://drive.google.com/file/d/1ZX0UCXvJwqd2uBGCX7LI2n-DfMg3t74v/view?usp=sharing
To download the datasets, use the following links.
Cityscapes: https://drive.google.com/file/d/1Qb4UrNsjvlU-wEsR9d7rckB0YS_LXgb2/view?usp=sharing
GTA5: https://drive.google.com/file/d/1xYxlcMR2WFCpayNrW2-Rb7N-950vvl23/view?usp=sharing
Please refer to this link to convert the GTA5 labels to the same format as Cityscapes: https://github.com/sarrrrry/PyTorchDL_GTA5/blob/master/pytorchdl_gta5/labels.py
First install fvcore with this command (the leading ! is for notebook cells; omit it in a regular shell):
!pip install -U fvcore
To calculate the FLOPs and number of parameters please use this code:
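import torch
from fvcore.nn import FlopCountAnalysis, flop_count_table
# -----------------------------
# Initialize your model here
# -----------------------------
model.eval()  # inference mode, so BatchNorm and Dropout layers behave deterministically
height = ...
width = ...
# Dummy input with a batch dimension: (batch, channels, height, width)
image = torch.zeros((1, 3, height, width))
flops = FlopCountAnalysis(model, image)
# The table reports FLOPs and parameter counts per module
print(flop_count_table(flops))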
Reference: https://github.com/facebookresearch/fvcore/blob/main/docs/flop_count.md
Please refer to this pseudo-code for latency and FPS calculation.
$\texttt{image} \gets \texttt{random(3, height, width)}$
$\texttt{iterations} \gets 1000$
$\texttt{latency} \gets \texttt{[]}$
$\texttt{FPS} \gets \texttt{[]}$
repeat $\texttt{iterations}$ times
$\texttt{start = time.time()}$
$\texttt{output = model(image)}$
$\texttt{end = time.time()}$
$\texttt{latency}_i \texttt{ = end - start}$
$\texttt{latency.append(latency}_i\texttt{)}$
$\texttt{FPS}_i = \frac{\texttt{1}}{\texttt{latency}_i}$
$\texttt{FPS.append(FPS}_i\texttt{)}$
end
$\texttt{meanLatency} \gets \texttt{mean(latency)*1000}$
$\texttt{stdLatency} \gets \texttt{std(latency)*1000}$
$\texttt{meanFPS} \gets \texttt{mean(FPS)}$
$\texttt{stdFPS} \gets \texttt{std(FPS)}$