Skip to content

Commit

Permalink
python refactoring & project restructuration (#223)
Browse files Browse the repository at this point in the history
* py3 super

* contrib

* data

* dl WIP

* rl

* dl new structure

* rl utils

* codestyle

* tests

* codestyle

* working notebooks

* update

* now dl works

* rl should work

* init update

* docs update

* doc fix

* modules fix

* no registry docs

* tests update

* imread update

* rl core refactoring

* rl algorithms refactored

* docs update

* example fixes

* tests update

* tests check

* get_loader -> dl

* import fix
  • Loading branch information
Scitator authored Jun 17, 2019
1 parent 606ffcf commit edb1e1a
Show file tree
Hide file tree
Showing 222 changed files with 5,096 additions and 5,283 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ jobs:
- pip install tifffile
- wget -P ./data/ https://www.dropbox.com/s/0rvuae4mj6jn922/isbi.tar.gz
- tar -xf ./data/isbi.tar.gz -C ./data/
- for f in examples/_tests_scripts/*.py; do PYTHONPATH=./catalyst:${PYTHONPATH} python "$f"; done
- (set -e; for f in examples/_tests_scripts/*.py; do PYTHONPATH=./catalyst:${PYTHONPATH} python "$f"; done)
- PYTHONPATH=./examples:./catalyst:${PYTHONPATH}
python catalyst/dl/scripts/run.py
--expdir=./examples/_tests_mnist_stages
Expand Down
2 changes: 1 addition & 1 deletion catalyst/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "19.06.2"
__version__ = "19.06.3"
20 changes: 11 additions & 9 deletions catalyst/contrib/criterion/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
# flake8: noqa

from torch.nn.modules.loss import *
from .ce import *
from .center import *
from .contrastive import *
from .dice import *
from .focal import *
from .huber import *
from .iou import *
from .lovasz import *
from .wing import *
from .ce import NaiveCrossEntropyLoss
from .center import CenterLoss
from .contrastive import ContrastiveDistanceLoss, ContrastiveEmbeddingLoss
from .dice import DiceLoss, BCEDiceLoss
from .focal import FocalLossBinary, FocalLossMultiClass
from .huber import HuberLoss
from .iou import IoULoss, BCEIoULoss
from .lovasz import LovaszLossBinary, LovaszLossMultiClass, \
LovaszLossMultiLabel
from .wing import WingLoss
6 changes: 3 additions & 3 deletions catalyst/contrib/criterion/dice.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from functools import partial

import torch.nn as nn
from catalyst.dl import metrics
from catalyst.dl.utils import criterion


class DiceLoss(nn.Module):
Expand All @@ -11,10 +11,10 @@ def __init__(
threshold: float = None,
activation: str = "Sigmoid"
):
super(DiceLoss, self).__init__()
super().__init__()

self.loss_fn = partial(
metrics.dice,
criterion.dice,
eps=eps,
threshold=threshold,
activation=activation)
Expand Down
65 changes: 33 additions & 32 deletions catalyst/contrib/criterion/focal.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from functools import partial

from torch.nn.modules.loss import _Loss
from catalyst.dl.losses import sigmoid_focal_loss, reduced_focal_loss
from catalyst.dl.utils import criterion


class FocalLossBinary(_Loss):
Expand All @@ -22,13 +22,13 @@ def __init__(

if reduced:
self.loss_fn = partial(
reduced_focal_loss,
criterion.reduced_focal_loss,
gamma=gamma,
threshold=threshold,
reduction=reduction)
else:
self.loss_fn = partial(
sigmoid_focal_loss,
criterion.sigmoid_focal_loss,
gamma=gamma,
alpha=alpha,
reduction=reduction)
Expand Down Expand Up @@ -87,32 +87,33 @@ def forward(self, logits, targets):
return loss


class FocalLossMultiLabel(_Loss):
"""
Compute focal loss for multi-label problem.
Ignores targets having -1 label
"""

def forward(self, logits, targets):
"""
Args:
logits: [bs; num_classes]
targets: [bs; num_classes]
"""
num_classes = logits.size(1)
loss = 0

for cls in range(num_classes):
# Filter anchors with -1 label from loss computation
if cls == self.ignore:
continue

cls_label_target = targets[..., cls].long()
cls_label_input = logits[..., cls]

loss += self.loss_fn(cls_label_input, cls_label_target)

return loss


__all__ = ["FocalLossBinary", "FocalLossMultiClass", "FocalLossMultiLabel"]
# @TODO: check
# class FocalLossMultiLabel(_Loss):
# """
# Compute focal loss for multi-label problem.
# Ignores targets having -1 label
# """
#
# def forward(self, logits, targets):
# """
# Args:
# logits: [bs; num_classes]
# targets: [bs; num_classes]
# """
# num_classes = logits.size(1)
# loss = 0
#
# for cls in range(num_classes):
# # Filter anchors with -1 label from loss computation
# if cls == self.ignore:
# continue
#
# cls_label_target = targets[..., cls].long()
# cls_label_input = logits[..., cls]
#
# loss += self.loss_fn(cls_label_input, cls_label_target)
#
# return loss


__all__ = ["FocalLossBinary", "FocalLossMultiClass"]
2 changes: 1 addition & 1 deletion catalyst/contrib/criterion/huber.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

class HuberLoss(nn.Module):
def __init__(self, clip_delta=1.0, reduction="elementwise_mean"):
super(HuberLoss, self).__init__()
super().__init__()
self.clip_delta = clip_delta
self.reduction = reduction or "none"

Expand Down
4 changes: 2 additions & 2 deletions catalyst/contrib/criterion/iou.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from functools import partial
import torch.nn as nn
from catalyst.dl import metrics
from catalyst.dl.utils import criterion


class IoULoss(nn.Module):
Expand All @@ -21,7 +21,7 @@ def __init__(
):
super().__init__()
self.metric_fn = partial(
metrics.iou,
criterion.iou,
eps=eps,
threshold=threshold,
activation=activation)
Expand Down
38 changes: 36 additions & 2 deletions catalyst/contrib/criterion/wing.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,41 @@
from functools import partial
import math

import torch
import torch.nn as nn
from catalyst.dl import losses


def wing_loss(
outputs: torch.Tensor,
targets: torch.Tensor,
width: int = 5,
curvature: float = 0.5,
reduction: str = "mean"
):
"""
https://arxiv.org/pdf/1711.06753.pdf
Source https://github.com/BloodAxe/pytorch-toolbelt
See :class:`~pytorch_toolbelt.losses` for details.
"""
diff_abs = (targets - outputs).abs()
loss = diff_abs.clone()

idx_smaller = diff_abs < width
idx_bigger = diff_abs >= width

loss[idx_smaller] = \
width * torch.log(1 + diff_abs[idx_smaller] / curvature)

C = width - width * math.log(1 + width / curvature)
loss[idx_bigger] = loss[idx_bigger] - C

if reduction == "sum":
loss = loss.sum()
if reduction == "mean":
loss = loss.mean()

return loss


class WingLoss(nn.Module):
Expand All @@ -13,7 +47,7 @@ def __init__(
):
super().__init__()
self.loss_fn = partial(
losses.wing_loss,
wing_loss,
width=width,
curvature=curvature,
reduction=reduction)
Expand Down
3 changes: 0 additions & 3 deletions catalyst/contrib/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,2 @@
# flake8: noqa
from .sequential import *
# from .encoder import *
# from .classification import *
# from .segmentation import *
2 changes: 1 addition & 1 deletion catalyst/contrib/models/classification/mobilenetv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def conv_1x1_bn(inp, oup):

class InvertedResidual(nn.Module):
def __init__(self, inp, oup, stride, expand_ratio):
super(InvertedResidual, self).__init__()
super().__init__()
assert stride in [1, 2]

hidden_dim = round(inp * expand_ratio)
Expand Down
6 changes: 3 additions & 3 deletions catalyst/contrib/models/classification/mobilenetv3.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def _make_divisible(v, divisor, min_value=None):

class h_sigmoid(nn.Module):
def __init__(self, inplace=True):
super(h_sigmoid, self).__init__()
super().__init__()
self.relu = nn.ReLU6(inplace=inplace)

def forward(self, x):
Expand All @@ -40,7 +40,7 @@ def forward(self, x):

class h_swish(nn.Module):
def __init__(self, inplace=True):
super(h_swish, self).__init__()
super().__init__()
self.sigmoid = h_sigmoid(inplace=inplace)

def forward(self, x):
Expand All @@ -49,7 +49,7 @@ def forward(self, x):

class SELayer(nn.Module):
def __init__(self, channel, reduction=4):
super(SELayer, self).__init__()
super().__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Sequential(
nn.Linear(channel, channel // reduction),
Expand Down
4 changes: 2 additions & 2 deletions catalyst/contrib/models/encoder/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .resnet import ResnetEncoder
from .mobilenet import MobileNetEncoder
from .mobilenetv2 import MobileNetV2Encoder

__all__ = ["ResnetEncoder", "MobileNetEncoder"]
__all__ = ["ResnetEncoder", "MobileNetV2Encoder"]
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from ..classification import MobileNetV2


class MobileNetEncoder(nn.Module):
class MobileNetV2Encoder(nn.Module):
def __init__(
self,
input_size=224,
Expand Down
7 changes: 7 additions & 0 deletions catalyst/contrib/models/segmentation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,10 @@
from .linknet import *
from .psp import *
from .unet import *

__all__ = [
"UnetMetaSpec", "UnetSpec", "ResnetUnetSpec",
"Unet", "Linknet", "FPNUnet", "PSPnet",
"ResnetUnet", "ResnetLinknet", "ResnetFPNUnet", "ResnetPSPnet",
"MobileUnet", "ResNetUnet", "ResNetLinknet"
]
6 changes: 3 additions & 3 deletions catalyst/contrib/models/segmentation/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from .head import HeadSpec


class UnetSpec(nn.Module):
class UnetMetaSpec(nn.Module):
def __init__(
self,
encoder: EncoderSpec,
Expand All @@ -31,7 +31,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return output


class _UnetSpec(UnetSpec):
class UnetSpec(UnetMetaSpec):
def __init__(
self,
num_classes: int = 1,
Expand Down Expand Up @@ -77,7 +77,7 @@ def _get_components(
raise NotImplementedError()


class _ResnetUnetSpec(UnetSpec):
class ResnetUnetSpec(UnetMetaSpec):
def __init__(
self,
num_classes: int = 1,
Expand Down
6 changes: 3 additions & 3 deletions catalyst/contrib/models/segmentation/fpn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
from .bridge import UnetBridge
from .decoder import FPNDecoder
from .head import FPNHead
from .core import _UnetSpec, _ResnetUnetSpec
from .core import UnetSpec, ResnetUnetSpec


class FPNUnet(_UnetSpec):
class FPNUnet(UnetSpec):

def _get_components(
self,
Expand Down Expand Up @@ -42,7 +42,7 @@ def _get_components(
return encoder, bridge, decoder, head


class ResnetFPNUnet(_ResnetUnetSpec):
class ResnetFPNUnet(ResnetUnetSpec):

def _get_components(
self,
Expand Down
6 changes: 3 additions & 3 deletions catalyst/contrib/models/segmentation/linknet.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
from .bridge import UnetBridge
from .decoder import UNetDecoder
from .head import UnetHead
from .core import _UnetSpec, _ResnetUnetSpec
from .core import UnetSpec, ResnetUnetSpec


class Linknet(_UnetSpec):
class Linknet(UnetSpec):

def _get_components(
self,
Expand Down Expand Up @@ -43,7 +43,7 @@ def _get_components(
return encoder, bridge, decoder, head


class ResnetLinknet(_ResnetUnetSpec):
class ResnetLinknet(ResnetUnetSpec):

def _get_components(
self,
Expand Down
4 changes: 4 additions & 0 deletions catalyst/contrib/models/segmentation/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,7 @@
from .mobileunet import *
from .resnetunet import *
from .resnetlinknet import *

__all__ = [
"MobileUnet", "ResNetUnet", "ResNetLinknet"
]
19 changes: 1 addition & 18 deletions catalyst/contrib/models/segmentation/models/resnetunet.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,25 +81,8 @@ def forward(self, x):
class ResNetUnet(nn.Module):
"""
U-Net inspired encoder-decoder architecture for semantic segmentation,
with a ResNet encoder as proposed by Alexander Buslaev.
with a ResNet encoder as proposed by Alexander Buslaev.
Also known as AlbuNet
See:
- https://arxiv.org/abs/1505.04597 -
U-Net: Convolutional Networks for Biomedical Image Segmentation
- https://arxiv.org/abs/1411.4038 -
Fully Convolutional Networks for Semantic Segmentation
- https://arxiv.org/abs/1512.03385 -
Deep Residual Learning for Image Recognition
- https://arxiv.org/abs/1801.05746 -
TernausNet: U-Net with VGG11
Encoder Pre-Trained on ImageNet for Image Segmentation
- https://arxiv.org/abs/1806.00844 -
TernausNetV2: Fully Convolutional Network for Instance Segmentation
based on https://github.com/mapbox/robosat/blob/master/robosat/unet.py
"""
def __init__(
self,
Expand Down
Loading

0 comments on commit edb1e1a

Please sign in to comment.