Skip to content

Commit

Permalink
feat: added SMARTLoss implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
flavioschneider committed Apr 18, 2022
1 parent db0e482 commit 6fdfb0c
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 0 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
__pycache__
29 changes: 29 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from setuptools import setup, find_packages

setup(
name = 'smart-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.1',
license='MIT',
description = 'SMART Fine-Tuning - Pytorch',
author = 'Flavio Schneider',
author_email = '[email protected]',
url = 'https://github.com/archinetai/smart-pytorch',
keywords = [
'artificial intelligence',
'deep learning',
'fine-tuning',
'pre-trained',
],
install_requires=[
'torch>=1.6',
'data-science-types>=0.2'
],
classifiers=[
'Development Status :: 4 - Beta',
'Intended Audience :: Developers',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'License :: OSI Approved :: MIT License',
'Programming Language :: Python :: 3.6',
],
)
56 changes: 56 additions & 0 deletions smart_pytorch/smart_pytorch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from typing import List, Union, Callable

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

def inf_norm(x):
return torch.norm(x, p=float('inf'), dim=-1, keepdim=True)

def to_list(x):
return x if isinstance(x, list) else [x]

class SMARTLoss(nn.Module):

def __init__(
self,
eval_fn: Callable,
loss_fn: Union[Callable, List[Callable]],
norm_fn: Callable = inf_norm,
num_steps: int = 1,
step_size: float = 1e-3,
epsilon: float = 1e-6,
noise_var: float = 1e-5
) -> None:
super().__init__()
self.eval_fn = eval_fn
self.loss_fn = to_list(loss_fn)
self.norm_fn = norm_fn
self.num_steps = num_steps
self.step_size = step_size
self.epsilon = epsilon
self.noise_var = noise_var

def forward(self, embed: Tensor, state: Union[Tensor, List[Tensor]]) -> Tensor:
states = to_list(state)
noise = torch.randn_like(embed, requires_grad=True) * self.noise_var

for i in range(self.num_steps + 2):
# Compute perturbed states
embed_perturbed = embed + noise
states_perturbed = to_list(self.eval_fn(embed_perturbed))
loss = 0
# Compute perturbation loss over all states
for j in range(len(states)):
loss += self.loss_fn[j](states_perturbed[j], states[j].detach())
if i == self.num_steps + 1:
return loss
# Compute noise gradient
noise_gradient = torch.autograd.grad(loss, noise)[0]
# Move noise towards gradient to change state as much as possible
step = noise + self.step_size * noise_gradient
step_norm = self.norm_fn(step)
noise = step / (step_norm + self.epsilon)
# Reset noise gradients for next step
noise = noise.detach().requires_grad_()

0 comments on commit 6fdfb0c

Please sign in to comment.