Skip to content

Commit

Permalink
Minor updates.
Browse files Browse the repository at this point in the history
  • Loading branch information
Mihir Rana committed Aug 30, 2021
1 parent 6818005 commit fa27d5b
Show file tree
Hide file tree
Showing 8 changed files with 83 additions and 52 deletions.
8 changes: 4 additions & 4 deletions pytorch_common/additional_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ class BaseDatasetConfig(Munch):
Base configuration class for
dataset-related settings.
Class attributes can be accessed with
both `config["key"]` and `config.key`.
Class attributes can be accessed with both
`configobj["key"]` and `configobj.key`.
"""

def __init__(self, dictionary: Optional[_StringDict] = None):
Expand Down Expand Up @@ -43,8 +43,8 @@ class BaseModelConfig(Munch):
Base configuration class for
model-related settings.
Class attributes can be accessed with
both `config["key"]` and `config.key`.
Class attributes can be accessed with both
`configobj["key"]` and `configobj.key`.
"""

def __init__(self, dictionary: Optional[_StringDict] = None, model_type: Optional[str] = "classification"):
Expand Down
19 changes: 14 additions & 5 deletions pytorch_common/datasets_dl.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,19 @@ def __getitem__(self, index):
raise NotImplementedError

def __len__(self):
raise NotImplementedError
return len(self.data)

def print_dataset(self) -> None:
"""
Print useful summary statistics of the dataset.
"""
logging.info("\n" + "-" * 40)
print_dataframe(self.data)
value_counts = self.data[self.target_col].value_counts()
logging.info(f"Target value counts:\n{value_counts}")

if self.target_col in self.data:
value_counts = self.data[self.target_col].value_counts()
logging.info(f"Target value counts:\n{value_counts}")

logging.info("\n" + "-" * 40)

def save(self, *args, **kwargs) -> None:
Expand Down Expand Up @@ -85,11 +88,14 @@ def remove(cls, *args, **kwargs) -> None:
"""
remove_object(*args, **kwargs)

def progress_apply(self, data: pd.DataFrame, func: Callable, *args, **kwargs) -> pd.DataFrame:
def progress_apply(self, data: Union[pd.DataFrame, pd.Series], func: Callable, *args, **kwargs) -> pd.DataFrame:
"""
Generic function to `progress_apply` a given row-level
function `func` on the given `data` (chunk).
"""
if isinstance(data, pd.Series):
return data.progress_apply(func, *args, **kwargs)
assert isinstance(data, pd.DataFrame)
return data.progress_apply(func, *args, **kwargs, axis=1)

def sample_class(
Expand Down Expand Up @@ -193,7 +199,10 @@ def undersample_class(
self.shuffle_and_reindex_data()

def _get_class_info(
self, class_to_sample: Optional[Union[float, str]] = None, column: Optional[str] = None, minority: bool = True,
self,
class_to_sample: Optional[Union[float, str]] = None,
column: Optional[str] = None,
minority: bool = True,
) -> Tuple[Union[float, str], int, List[int]]:
"""
Get the label, counts, and indices of each class.
Expand Down
9 changes: 6 additions & 3 deletions pytorch_common/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def get_loss_criterion_function(config: _Config, criterion: Optional[str] = "cro
agg_func = torch.mean
else:
raise ValueError(
f"Param 'multilabel_reduction' ('{multilabel_reduction}') " f"must be one of ['sum', 'mean']."
f"Param 'multilabel_reduction' ('{multilabel_reduction}') must be one of ['sum', 'mean']."
)

# Get per-label loss
Expand All @@ -124,7 +124,10 @@ def get_loss_criterion_function(config: _Config, criterion: Optional[str] = "cro
# Multilabel classification
else:
return lambda output_hist, y_hist: agg_func(
torch.stack([loss_criterion(output_hist, y_hist[..., i]) for i in range(y_hist.shape[-1])], dim=0,)
torch.stack(
[loss_criterion(output_hist, y_hist[..., i]) for i in range(y_hist.shape[-1])],
dim=0,
)
)


Expand Down Expand Up @@ -153,7 +156,7 @@ def get_eval_criterion_function(
agg_func = np.mean
else:
raise ValueError(
f"Param 'multilabel_reduction' ('{multilabel_reduction}') " f"must be one of ['mean', 'none']."
f"Param 'multilabel_reduction' ('{multilabel_reduction}') must be one of ['mean', 'none']."
)

# Get per-label eval criterion
Expand Down
2 changes: 1 addition & 1 deletion pytorch_common/models_dl.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def predict_proba(
:return probs: Predicted probabilities of each class
"""
if self.model_type != "classification" and threshold is not None:
raise ValueError(f"Param 'threshold' ('{threshold}') can only " f"be provided for classification models.")
raise ValueError(f"Param 'threshold' ('{threshold}') can only be provided for classification models.")

probs = F.softmax(outputs, dim=-1) # Get probabilities of each class
num_classes = probs.shape[-1]
Expand Down
49 changes: 33 additions & 16 deletions pytorch_common/train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,14 @@ def train_model(
if not config.disable_checkpointing:
logging.info("Replacing current best model checkpoint...")
best_checkpoint_file = save_model(
model, config, epoch, train_logger, val_logger, optimizer, scheduler, config_info_dict,
model,
config,
epoch,
train_logger,
val_logger,
optimizer,
scheduler,
config_info_dict,
)
remove_model(config, best_epoch, config_info_dict)
best_epoch = epoch
Expand All @@ -177,7 +184,14 @@ def train_model(
if not config.disable_checkpointing:
logging.info("Dumping model and results...")
save_model(
model, config, stop_epoch, train_logger, val_logger, optimizer, scheduler, config_info_dict,
model,
config,
stop_epoch,
train_logger,
val_logger,
optimizer,
scheduler,
config_info_dict,
)

# Save current and best models
Expand Down Expand Up @@ -370,10 +384,10 @@ def perform_one_epoch(

# Store all required items to be returned
loss_hist: List[float] = []
targets_hist: List[torch.Tensor] = []
outputs_hist: List[torch.Tensor] = []
preds_hist: List[torch.Tensor] = []
probs_hist: List[torch.Tensor] = []
targets_hist: _TensorOrTensors = []
outputs_hist: _TensorOrTensors = []
preds_hist: _TensorOrTensors = []
probs_hist: _TensorOrTensors = []

# Enable gradient computation if training to be performed else disable it.
# Technically not required if this function is called from other supported
Expand Down Expand Up @@ -411,9 +425,7 @@ def perform_one_epoch(

# Print progess
if batch_idx in batches_to_print:
logging.info(
f"{num_examples_complete}/{num_examples} " f"({percent_batches_complete:.0f}%) complete."
)
logging.info(f"{num_examples_complete}/{num_examples} ({percent_batches_complete:.0f}%) complete.")

else: # Perform training / evaluation
# Compute and store loss
Expand Down Expand Up @@ -514,7 +526,7 @@ def take_scheduler_step(scheduler: object, val_metric: Optional[float] = None) -

scheduler_name = scheduler.__class__.__name__
if scheduler_name in REQUIRE_VAL_METRIC:
assert val_metric is not None, f"Param 'val_metric' must be provided " f"for '{scheduler_name}' scheduler."
assert val_metric is not None, f"Param 'val_metric' must be provided for '{scheduler_name}' scheduler."
scheduler.step(val_metric)
else:
scheduler.step()
Expand Down Expand Up @@ -604,7 +616,8 @@ def generate_checkpoint_dict(

# Save items if provided
for name, obj in zip(
("train_logger", "val_logger", "optimizer", "scheduler"), (train_logger, val_logger, optimizer, scheduler),
("train_logger", "val_logger", "optimizer", "scheduler"),
(train_logger, val_logger, optimizer, scheduler),
):
if obj is not None:
checkpoint[name] = obj if name in ["train_logger", "val_logger"] else obj.state_dict()
Expand Down Expand Up @@ -742,7 +755,7 @@ def load_state_dict(
if state_dict is not None:
obj.load_state_dict(state_dict)
else:
raise KeyError(f"{key} argument expected its state dict in " f"the loaded checkpoint but none was found.")
raise KeyError(f"{key} argument expected its state dict in the loaded checkpoint but none was found.")
return obj

# Load optimizer
Expand Down Expand Up @@ -794,9 +807,9 @@ def validate_checkpoint_type(checkpoint_type: str, checkpoint_file: Optional[str
`checkpoint_file`, if provided.
"""
ALLOWED_CHECKPOINT_TYPES = ["state", "model"]
assert checkpoint_type in ALLOWED_CHECKPOINT_TYPES, (
f"Param 'checkpoint_type' ('{checkpoint_type}') " f"must be one of {ALLOWED_CHECKPOINT_TYPES}."
)
assert (
checkpoint_type in ALLOWED_CHECKPOINT_TYPES
), f"Param 'checkpoint_type' ('{checkpoint_type}') must be one of {ALLOWED_CHECKPOINT_TYPES}."

# Check that provided checkpoint_type matches that of checkpoint_file
if checkpoint_file is not None:
Expand Down Expand Up @@ -853,7 +866,11 @@ def __init__(
"""
self.criterion = criterion
self._init_params(
mode=mode, min_delta=min_delta, patience=patience, best_val=best_val, best_val_tol=best_val_tol,
mode=mode,
min_delta=min_delta,
patience=patience,
best_val=best_val,
best_val_tol=best_val_tol,
)
self._validate_params()
self.best: Optional[float] = None
Expand Down
3 changes: 3 additions & 0 deletions pytorch_common/types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
from matplotlib.figure import Figure
Expand All @@ -17,6 +18,7 @@
"Union",
"Munch",
"_StringDict",
"_StringArrayDict",
"_Config",
"_Device",
"_Batch",
Expand All @@ -35,6 +37,7 @@


_StringDict = Dict[str, Any]
_StringArrayDict = Dict[str, np.ndarray]
_Config = Union[_StringDict, Munch]
_Device = Union[str, torch.device]
_Batch = Iterable
Expand Down
35 changes: 17 additions & 18 deletions pytorch_common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def create_dir_if_not_exists(dir_path: str) -> None:
if it doesn't exist already.
"""
if not os.path.isdir(dir_path):
os.makedirs(dir_path, exist_ok=True) # exist_ok=True to avoid concurrent dir creation
os.makedirs(dir_path, exist_ok=True) # `exist_ok=True` to avoid concurrent dir creation

# Create parent dir
create_dir_if_not_exists(parent_dir_path)
Expand Down Expand Up @@ -162,9 +162,7 @@ def save_plot(
fig.savefig(get_file_path(config.plot_dir, f"{file_name}.{ext}"), dpi=300)


def save_object(
obj: Any, primary_path: str, file_name: Optional[str] = None, module: Optional[str] = "pickle"
) -> None:
def save_object(obj: Any, primary_path: str, file_name: Optional[str] = None, module: Optional[str] = "pickle") -> None:
"""
This is a generic function to save any given
object using different `module`s, e.g. pickle,
Expand Down Expand Up @@ -293,11 +291,9 @@ def get_pickle_module(pickle_module: Optional[str] = "pickle") -> Union[pickle,
Return the correct module for pickling.
:param pickle_module: must be one of ["pickle", "dill"]
"""
if pickle_module == "pickle":
return pickle
elif pickle_module == "dill":
return dill
raise ValueError(f"Param 'pickle_module' ('{pickle_module}') must be one of ['pickle', 'dill'].")
if not pickle_module in ["pickle", "dill"]:
raise ValueError(f"Param 'pickle_module' ('{pickle_module}') must be one of ['pickle', 'dill'].")
return eval(pickle_module)


def delete_model(model: nn.Module) -> None:
Expand Down Expand Up @@ -358,7 +354,10 @@ def get_unique_config_name(primary_name: str, config_info_dict: Optional[_String


def get_checkpoint_name(
checkpoint_type: str, model_name: str, epoch: int, config_info_dict: Optional[_StringDict] = None,
checkpoint_type: str,
model_name: str,
epoch: int,
config_info_dict: Optional[_StringDict] = None,
) -> str:
"""
Returns the appropriate name of checkpoint file
Expand All @@ -383,7 +382,7 @@ def get_trainable_params(model: nn.Module) -> Dict[str, int]:
num_params = sum(p.numel() for p in model.parameters())
num_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
model_name = getattr(model, "__name__", model.__class__.__name__)
logging.info(f"Number of trainable/total parameters in {model_name}: " f"{num_trainable_params}/{num_params}")
logging.info(f"Number of trainable/total parameters in {model_name}: {num_trainable_params}/{num_params}")
return {"trainable": num_trainable_params, "total": num_params}


Expand Down Expand Up @@ -714,7 +713,10 @@ def add_eval_metrics(self, eval_metrics: Dict[str, float], epoch: Optional[int]
self.eval_metrics_hist[eval_criterion][epoch] = eval_metrics[eval_criterion]

def get_eval_metrics(
self, eval_criterion: Optional[str] = None, epoch: Optional[int] = None, flatten: Optional[bool] = False,
self,
eval_criterion: Optional[str] = None,
epoch: Optional[int] = None,
flatten: Optional[bool] = False,
) -> Union[float, List[float], OrderedDict[str, Union[float, List[float]]]]:
"""
Get the evaluation metrics history.
Expand Down Expand Up @@ -742,10 +744,7 @@ def get_eval_metrics(
return self.eval_metrics_hist[eval_criterion] # Return ordered dict
elif epoch is not None:
return OrderedDict(
{
eval_criterion: self.eval_metrics_hist[eval_criterion][epoch]
for eval_criterion in self.eval_criteria
}
{eval_criterion: self.eval_metrics_hist[eval_criterion][epoch] for eval_criterion in self.eval_criteria}
)
return self.eval_metrics_hist

Expand Down Expand Up @@ -779,7 +778,7 @@ def log_epoch_metrics(self, epoch: Optional[int] = -1) -> str:
assert epoch_loss == epoch_eval_metrics
dataset_type = "TRAIN" if self.is_train else "VAL "
mean_loss_epoch = np.mean(self.get_losses(epoch=epoch_loss))
result_str = f"\n\033[1m{dataset_type} Epoch: {epoch_loss}" f"\tAverage loss: {mean_loss_epoch:.4f}, "
result_str = f"\n\033[1m{dataset_type} Epoch: {epoch_loss}\tAverage loss: {mean_loss_epoch:.4f}, "
result_str += ", ".join(
[
f"{eval_criterion}: {self.get_eval_metrics(eval_criterion, epoch_loss):.4f}"
Expand Down Expand Up @@ -957,7 +956,7 @@ def _set_pooler(self, model_type: str) -> None:
self.pooler = self.POOLER_MAPPING[self.model_type]
else:
logging.warning(
f"No supported sequence pooler was found for model of " f"type '{model_type}'. Using the default one."
f"No supported sequence pooler was found for model of type '{model_type}'. Using the default one."
)
self.model_type = self.DEFAULT_POOLER_TYPE
self.pooler = self._default_pooler
Expand Down
10 changes: 5 additions & 5 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
# Application info
name="pytorch_common",
version="1.4",
version="1.5",
author="Mihir Rana",
author_email="[email protected]",
description="Repo for common PyTorch code",
Expand All @@ -15,16 +15,16 @@
install_requires=[
"numpy>=1.17.2",
"pandas>=0.24.0",
"matplotlib>=3.2.1",
"dask[dataframe]==2.21.0",
"matplotlib>=3.3.2",
"dask[dataframe]>=2.30.0",
"toolz==0.10.0",
"scikit-learn>=0.22.1",
"dill==0.3.2",
"dill>=0.3.3",
"munch>=2.5.0",
"locket==0.2.0",
],
# Optional dependencies
extras_require={"nlp": ["transformers>=3.0.2"]}, # for NLP related projects
extras_require={"nlp": ["transformers==4.9.2"]}, # for NLP related projects
# Add config and sql files to the package
# https://python-packaging.readthedocs.io/en/latest/non-code-files.html
include_package_data=True,
Expand Down

0 comments on commit fa27d5b

Please sign in to comment.