Add Regularizations #13

Merged · 2 commits · Jan 16, 2021
2 changes: 2 additions & 0 deletions box_embeddings/modules/regularization/__init__.py
@@ -0,0 +1,2 @@
from .regularizer import BoxRegularizer
from .l2_side_regularizer import l2_side_regularizer, L2SideBoxRegularizer
58 changes: 58 additions & 0 deletions box_embeddings/modules/regularization/l2_side_regularizer.py
@@ -0,0 +1,58 @@
from typing import List, Tuple, Union, Dict, Any, Optional
from box_embeddings.common.registrable import Registrable
import torch
from box_embeddings.parameterizations.box_tensor import BoxTensor
from box_embeddings.modules.regularization.regularizer import BoxRegularizer

eps = 1e-23


def l2_side_regularizer(
    box_tensor: BoxTensor, log_scale: bool = False
) -> Union[float, torch.Tensor]:
    """Applies L2 regularization to the side lengths of all boxes and returns the mean.

    Args:
        box_tensor: Input box tensor.
        log_scale: Whether to take the mean of the log side lengths
            instead of the mean of the squared side lengths.

    Returns:
        The scalar regularization value.
    """
    z = box_tensor.z  # (..., box_dim)
    Z = box_tensor.Z  # (..., box_dim)

    if not log_scale:
        return torch.mean((Z - z) ** 2)
    else:
        return torch.mean(torch.log(torch.abs(Z - z) + eps))


@BoxRegularizer.register("l2_side")
class L2SideBoxRegularizer(BoxRegularizer):

    """Applies L2 regularization to side lengths."""

    def __init__(self, weight: float, log_scale: bool = False) -> None:
        """
        Args:
            weight: Weight (hyperparameter) given to this regularization in the overall loss.
            log_scale: Whether the output should be in log scale or not.
                Should be true in almost any practical case where box_dim > 5.
        """
        super().__init__(weight, log_scale=log_scale)

    def _forward(self, box_tensor: BoxTensor) -> Union[float, torch.Tensor]:
        """Applies L2 regularization to the side lengths of all boxes and returns the mean.

        Args:
            box_tensor: Input box tensor.

        Returns:
            The scalar regularization value.
        """
        return l2_side_regularizer(box_tensor, log_scale=self.log_scale)
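
For context, a minimal usage sketch of the new regularizer (not part of this diff); it assumes a BoxTensor can be constructed from a tensor of shape (..., 2, box_dim) holding the lower corners z and upper corners Z:

import torch
from box_embeddings.parameterizations.box_tensor import BoxTensor
from box_embeddings.modules.regularization import (
    l2_side_regularizer,
    L2SideBoxRegularizer,
)

# Hypothetical data: 4 boxes in 10 dimensions, stacked as (z, Z) pairs
# along dim -2 (assumed BoxTensor layout).
data = torch.rand(4, 2, 10)
data[:, 1, :] += data[:, 0, :]  # make Z >= z so side lengths are non-negative
box = BoxTensor(data)

# Functional form: mean squared side length.
penalty = l2_side_regularizer(box)

# Module form: the weight is applied inside forward().
regularizer = L2SideBoxRegularizer(weight=0.1, log_scale=True)
loss = regularizer(box)  # 0.1 * mean(log(|Z - z| + eps))

The log_scale variant presumably exists because box objectives work with log volumes (sums of log side lengths) once box_dim grows, where products of side lengths would underflow; eps keeps the log finite for degenerate boxes.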
47 changes: 47 additions & 0 deletions box_embeddings/modules/regularization/regularizer.py
@@ -0,0 +1,47 @@
from typing import List, Tuple, Union, Dict, Any, Optional
from box_embeddings.common.registrable import Registrable
import torch
from box_embeddings.parameterizations.box_tensor import BoxTensor


class BoxRegularizer(torch.nn.Module, Registrable):

    """Base box-regularizer class"""

    def __init__(
        self, weight: float, log_scale: bool = True, **kwargs: Any
    ) -> None:
        """
        Args:
            weight: Weight (hyperparameter) given to this regularization in the overall loss.
            log_scale: Whether the output should be in log scale or not.
                Should be true in almost any practical case where box_dim > 5.
            kwargs: Unused.
        """
        super().__init__()  # type:ignore
        self.weight = weight
        self.log_scale = log_scale

    def forward(self, box_tensor: BoxTensor) -> Union[float, torch.Tensor]:
        """Calls `_forward` and multiplies the result by the weight.

        Args:
            box_tensor: Input box tensor.

        Returns:
            The scalar regularization loss.
        """
        return self.weight * self._forward(box_tensor)

    def _forward(self, box_tensor: BoxTensor) -> Union[float, torch.Tensor]:
        """The method that does all the work and needs to be overridden.

        Args:
            box_tensor: Input box tensor.

        Returns:
            0.0 (subclasses should return the actual regularization value).
        """
        return 0.0
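
Also for context, a sketch of how a new regularizer would plug into this base class via the registry, following the "l2_side" pattern above; the "l1_side" name and class are hypothetical, not part of this PR:

import torch
from typing import Union
from box_embeddings.parameterizations.box_tensor import BoxTensor
from box_embeddings.modules.regularization.regularizer import BoxRegularizer


@BoxRegularizer.register("l1_side")  # hypothetical name, mirrors "l2_side"
class L1SideBoxRegularizer(BoxRegularizer):
    """Penalizes the mean absolute side length."""

    def _forward(self, box_tensor: BoxTensor) -> Union[float, torch.Tensor]:
        # The base class forward() multiplies this value by self.weight.
        return torch.mean(torch.abs(box_tensor.Z - box_tensor.z))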
2 changes: 1 addition & 1 deletion core_requirements.txt
@@ -1,2 +1,2 @@
-torch >= 1.7.0
+torch >= 1.6.0
numpy