Merge pull request #13 from iesl/regularizers
Added delta regularizer for now. More can be added later.
dhruvdcoder authored Jan 16, 2021
2 parents 1964ee4 + 16ebcb2 commit 444e3e8
Showing 4 changed files with 108 additions and 1 deletion.
2 changes: 2 additions & 0 deletions box_embeddings/modules/regularization/__init__.py
@@ -0,0 +1,2 @@
from .regularizer import BoxRegularizer
from .l2_side_regularizer import l2_side_regularizer, L2SideBoxRegularizer
58 changes: 58 additions & 0 deletions box_embeddings/modules/regularization/l2_side_regularizer.py
@@ -0,0 +1,58 @@
from typing import List, Tuple, Union, Dict, Any, Optional
from box_embeddings.common.registrable import Registrable
import torch
from box_embeddings.parameterizations.box_tensor import BoxTensor
from box_embeddings.modules.regularization.regularizer import BoxRegularizer

eps = 1e-23


def l2_side_regularizer(
    box_tensor: BoxTensor, log_scale: bool = False
) -> Union[float, torch.Tensor]:
    """Applies l2 regularization on all sides of all boxes and returns the mean.

    Args:
        box_tensor: Input box tensor.
        log_scale: Whether to return the mean of the log side lengths instead.

    Returns:
        Scalar regularization loss (mean over all sides of all boxes).
    """
    z = box_tensor.z  # (..., box_dim)
    Z = box_tensor.Z  # (..., box_dim)

    if not log_scale:
        return torch.mean((Z - z) ** 2)
    else:
        return torch.mean(torch.log(torch.abs(Z - z) + eps))


@BoxRegularizer.register("l2_side")
class L2SideBoxRegularizer(BoxRegularizer):

    """Applies l2 regularization on side lengths."""

    def __init__(self, weight: float, log_scale: bool = False) -> None:
        """Constructs an l2 side-length regularizer.

        Args:
            weight: Weight (hyperparameter) given to this regularization in the overall loss.
            log_scale: Whether the output should be in log scale or not.
                Should be true in almost any practical case where box_dim > 5.
        """
        super().__init__(weight, log_scale=log_scale)

    def _forward(self, box_tensor: BoxTensor) -> Union[float, torch.Tensor]:
        """Applies l2 regularization on all sides of all boxes and returns the mean.

        Args:
            box_tensor: Input box tensor.

        Returns:
            Scalar regularization loss.
        """

        return l2_side_regularizer(box_tensor, log_scale=self.log_scale)
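
For reference, a minimal usage sketch of the new module (assuming a BoxTensor can be built from a tensor of shape (..., 2, box_dim) whose two slices along the penultimate dimension are the lower corner z and the upper corner Z; the exact constructor lives in box_embeddings.parameterizations):

import torch
from box_embeddings.parameterizations.box_tensor import BoxTensor
from box_embeddings.modules.regularization import (
    l2_side_regularizer,
    L2SideBoxRegularizer,
)

# A batch of 3 boxes in 5 dimensions; slice 0 holds z, slice 1 holds Z.
data = torch.rand(3, 2, 5)
data[..., 1, :] += data[..., 0, :]  # ensure Z >= z
boxes = BoxTensor(data)

# Functional form: mean squared side length.
loss = l2_side_regularizer(boxes)

# Module form: BoxRegularizer.forward multiplies the result by the weight.
reg = L2SideBoxRegularizer(weight=1e-4, log_scale=True)
weighted_loss = reg(boxes)

With log_scale=True the penalty is the mean log side length, which is the numerically safer choice when box_dim is large, per the docstring above.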
47 changes: 47 additions & 0 deletions box_embeddings/modules/regularization/regularizer.py
@@ -0,0 +1,47 @@
from typing import List, Tuple, Union, Dict, Any, Optional
from box_embeddings.common.registrable import Registrable
import torch
from box_embeddings.parameterizations.box_tensor import BoxTensor


class BoxRegularizer(torch.nn.Module, Registrable):

    """Base box-regularizer class."""

    def __init__(
        self, weight: float, log_scale: bool = True, **kwargs: Any
    ) -> None:
        """
        Args:
            weight: Weight (hyperparameter) given to this regularization in the overall loss.
            log_scale: Whether the output should be in log scale or not.
                Should be true in almost any practical case where box_dim > 5.
            kwargs: Unused.
        """
        super().__init__()  # type:ignore
        self.weight = weight
        self.log_scale = log_scale

    def forward(self, box_tensor: BoxTensor) -> Union[float, torch.Tensor]:
        """Calls `_forward` and multiplies the result by the weight.

        Args:
            box_tensor: Input box tensor.

        Returns:
            Scalar regularization loss.
        """

        return self.weight * self._forward(box_tensor)

    def _forward(self, box_tensor: BoxTensor) -> Union[float, torch.Tensor]:
        """The method that does all the work and needs to be overridden.

        Args:
            box_tensor: Input box tensor.

        Returns:
            0.0 (subclasses should return the actual regularization value).
        """

        return 0.0
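
To illustrate the intended extension point, here is a hypothetical subclass (the volume_penalty name and penalty formula are illustrative only, not part of this PR); subclasses override only _forward, and forward applies the weight:

import torch
from box_embeddings.parameterizations.box_tensor import BoxTensor
from box_embeddings.modules.regularization.regularizer import BoxRegularizer


@BoxRegularizer.register("volume_penalty")  # hypothetical name, for illustration
class VolumePenaltyRegularizer(BoxRegularizer):
    """Illustrative example: penalizes the mean log-volume of the boxes."""

    def _forward(self, box_tensor: BoxTensor) -> torch.Tensor:
        side = box_tensor.Z - box_tensor.z  # (..., box_dim)
        # The log-volume of a box is the sum of its log side lengths;
        # average that over all boxes in the batch.
        return torch.mean(torch.sum(torch.log(side.clamp_min(1e-23)), dim=-1))


# reg = VolumePenaltyRegularizer(weight=0.01)
# loss = reg(boxes)  # equals 0.01 * reg._forward(boxes)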
2 changes: 1 addition & 1 deletion core_requirements.txt
@@ -1,2 +1,2 @@
-torch >= 1.7.0
+torch >= 1.6.0
 numpy
