-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #13 from iesl/regularizers
Added delta regularizer for now. More can be added later.
- Loading branch information
Showing
4 changed files
with
108 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .regularizer import BoxRegularizer | ||
from .l2_side_regularizer import l2_side_regularizer, L2SideBoxRegularizer |
58 changes: 58 additions & 0 deletions
58
box_embeddings/modules/regularization/l2_side_regularizer.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
from typing import List, Tuple, Union, Dict, Any, Optional | ||
from box_embeddings.common.registrable import Registrable | ||
import torch | ||
from box_embeddings.parameterizations.box_tensor import BoxTensor | ||
from box_embeddings.modules.regularization.regularizer import BoxRegularizer | ||
|
||
eps = 1e-23  # tiny positive constant added inside log() so a zero side length does not produce -inf/nan | ||
|
||
|
||
def l2_side_regularizer(
    box_tensor: BoxTensor, log_scale: bool = False
) -> Union[float, torch.Tensor]:
    """Applies L2 regularization on the side lengths of all boxes.

    Penalizes large boxes by their side lengths ``Z - z``, reduced with a
    mean over all boxes and all dimensions.

    Args:
        box_tensor: boxes whose side lengths are penalized.
        log_scale: if True, return the mean of ``log(|Z - z| + eps)``
            instead of the mean of ``(Z - z) ** 2``. Preferable when
            box_dim is large, where squared side lengths can vanish.

    Returns:
        A scalar tensor: mean squared side length, or mean log side
        length (``eps`` added for numerical stability) when
        ``log_scale`` is True.
    """
    z = box_tensor.z  # (..., box_dim) lower-left corner
    Z = box_tensor.Z  # (..., box_dim) upper-right corner

    if not log_scale:
        return torch.mean((Z - z) ** 2)
    else:
        # NOTE: log branch penalizes |side|, not side**2, matching the original.
        return torch.mean(torch.log(torch.abs(Z - z) + eps))
|
||
|
||
@BoxRegularizer.register("l2_side")
class L2SideBoxRegularizer(BoxRegularizer):

    """Applies L2 regularization on box side lengths via :func:`l2_side_regularizer`."""

    def __init__(self, weight: float, log_scale: bool = False) -> None:
        """
        Args:
            weight: Weight (hyperparameter) given to this regularization in the overall loss.
            log_scale: Whether the output should be in log scale or not.
                Should be true in almost any practical case where box_dim>5.
        """
        super().__init__(weight, log_scale=log_scale)

    def _forward(self, box_tensor: BoxTensor) -> Union[float, torch.Tensor]:
        """Applies L2 regularization on all sides of all boxes.

        Args:
            box_tensor: Input box tensor whose side lengths are penalized.

        Returns:
            Scalar regularization value (mean over sides; log scale if
            ``self.log_scale`` is set). Weighting is applied by the base
            class's ``forward``, not here.
        """

        return l2_side_regularizer(box_tensor, log_scale=self.log_scale)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
from typing import List, Tuple, Union, Dict, Any, Optional | ||
from box_embeddings.common.registrable import Registrable | ||
import torch | ||
from box_embeddings.parameterizations.box_tensor import BoxTensor | ||
|
||
|
||
class BoxRegularizer(torch.nn.Module, Registrable):

    """Base box-regularizer class.

    Subclasses override :meth:`_forward`; :meth:`forward` scales the
    result by ``self.weight``.
    """

    def __init__(
        self, weight: float, log_scale: bool = True, **kwargs: Any
    ) -> None:
        """
        Args:
            weight: Weight (hyperparameter) given to this regularization in the overall loss.
            log_scale: Whether the output should be in log scale or not.
                Should be true in almost any practical case where box_dim>5.
            kwargs: Unused
        """
        super().__init__()  # type:ignore
        self.weight = weight
        self.log_scale = log_scale

    def forward(self, box_tensor: BoxTensor) -> Union[float, torch.Tensor]:
        """Calls the _forward and multiplies the weight

        Args:
            box_tensor: Input box tensor

        Returns:
            scalar regularization loss (``weight * _forward(box_tensor)``)
        """

        return self.weight * self._forward(box_tensor)

    def _forward(self, box_tensor: BoxTensor) -> Union[float, torch.Tensor]:
        """The method that does all the work and needs to be overriden

        Args:
            box_tensor: Input box tensor

        Returns:
            0.0 — the base class is a no-op regularizer; subclasses
            return the actual (unweighted) regularization value.
        """

        return 0.0
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,2 @@ | ||
torch >= 1.7.0 | ||
torch >= 1.6.0 | ||
numpy |