Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding torch implementation #3

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
281 changes: 142 additions & 139 deletions examples/test_torch.py
Original file line number Diff line number Diff line change
@@ -1,167 +1,170 @@
"""
pytorch has a "functional" grad API [1,2] as of v1.5
Define custom derivatives via JVPs (forward mode).

torch.autograd.functional
The result of this implementation is only approximate, because the
gradient is computed with the centered finite differences method.

in addition to

# like jax.nn and jax.experimental.stax
torch.nn.functional

However, unlike jax, torch.autograd.functional's functions don't return
functions. One needs to supply the function to differentiate along with the
input at which grad(func) shall be evaluated.

# like jax.grad(func)(x)
#
# default vector v in VJP is v=None -> v=1 -> return grad(func)(x)
torch.autograd.functional.vjp(func, x) -> Tensor

torch.autograd.functional.hessian(func, x) -> Tensor

whereas

jax.grad(func) -> grad_func
jax.grad(func)(x) -> grad_func(x) -> DeviceArray

jax.hessian(func) -> hess_func
jax.hessian(func)(x) -> hess_func(x) -> DeviceArray


resources
https://pytorch.org/tutorials/beginner/pytorch_with_examples.html#autograd
https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html#sphx-glr-beginner-blitz-autograd-tutorial-py
https://pytorch.org/docs/stable/autograd.html
https://pytorch.org/docs/stable/notes/autograd.html


[1] https://github.com/pytorch/pytorch/commit/1f4a4aaf643b70ebcb40f388ae5226a41ca57d9b
[2] https://pytorch.org/docs/stable/autograd.html#functional-higher-level-api
"""

import torch
import numpy as np

rand = torch.rand

### gradient of scalar result c w.r.t. x, evaluated at x
### step by step, see grad_fn
##x = torch.rand(3, requires_grad=True)
####c = x.sin().pow(2.0).sum()
##a = torch.sin(x)
##print(a)
##b = torch.pow(a, 2.0)
##print(b)
##c = torch.sum(b)
##print(c)
### same as torch.autograd.grad(c,x)
##c.backward()
##print(x.grad)
##
##
### VJP: extract one row of J
##x = torch.rand(3, requires_grad=True)
##v = torch.tensor([1.0,0,0])
##b = x.sin().pow(2.0)
##b.backward(v)
##print(x.grad)


func_plain_torch = lambda x: torch.sin(x).pow(2.0).sum()


def copy(x, requires_grad=False):
_x = x.clone().detach()
if not requires_grad:
assert not _x.requires_grad
import control as ct
import torch
from torch.autograd import Function
import torch.autograd.gradcheck as gradcheck


def model(params):
    """
    Build the state-space matrices from ``params`` and return the system transfer function.

    - params: indexable holding nine values. Entries 0-3 fill the 2x2 matrix A
      row by row, entries 4-5 fill B, entries 6-7 fill C and entry 8 is D.
      See https://github.com/oselin/gradient_ML_LTI/blob/main/gradient_propagation.ipynb

    NOTE: The transfer function is computed by the control library (ct.ss2tf),
    which does not natively support backpropagation of the gradient.
    """
    double = torch.double

    state_matrix = torch.tensor(
        [[params[0], params[1]],
         [params[2], params[3]]],
        dtype=double,
    )
    input_matrix = torch.tensor([params[4], params[5]], dtype=double)
    output_matrix = torch.tensor([params[6], params[7]], dtype=double)
    feedthrough = torch.tensor([params[8]], dtype=double)

    return ct.ss2tf(state_matrix, input_matrix, output_matrix, feedthrough)


def forced_response(trn_fcn, u, time):
    """
    Return a torch tensor containing the forced response of the system to the input u.

    - trn_fcn: transfer function on which to compute the forced response
    - u: system input in time domain; it is detached and converted to a
      numpy array before being handed to the control library
    - time: time array during which the input is applied

    The result is a new double-precision tensor built from a copy of the
    simulated outputs and created with requires_grad=True.
    """
    input_signal = u.detach().numpy()
    response = ct.forced_response(trn_fcn, time, input_signal)
    samples = response.outputs.copy()
    return torch.tensor(samples, requires_grad=True, dtype=torch.double)


def impulse_response(trn_fcn, time):
    """
    Return a torch tensor containing the impulse response of the system.

    - trn_fcn: transfer function on which to compute the impulse response
    - time: time array over which the impulse response is computed

    The result is a new double-precision tensor built from a copy of the
    simulated outputs and created with requires_grad=True.
    """
    response = ct.impulse_response(trn_fcn, time)
    samples = response.outputs.copy()
    return torch.tensor(samples, requires_grad=True, dtype=torch.double)


def get_magnitude_torch(tensor):
    """
    Compute the order of magnitude (base-10 exponent) of a torch value.

    - tensor: torch tensor holding the value

    Returns 0 when the tensor is all zeros, otherwise floor(log10(|tensor|))
    converted to a Python int.
    """
    all_zero = torch.equal(tensor, torch.zeros_like(tensor))
    if all_zero:
        # log10(0) is -inf, so a zero value is reported as magnitude 0
        return 0

    exponent = torch.log10(torch.abs(tensor)).floor()
    return int(exponent.item())


def grad(f, x, h=None):
"""
Return the gradient of f at x, as list of partial derivatives
computed via (centered) finite differences.
- f: function, callable. Function f on which to compute the gradient
- x: parameters with respect to which the gradient is computed
- h: step size for finite differences gradient computation; when None, a
step of 10**(magnitude - 6) is derived per parameter via get_magnitude_torch

Entry i of the result is (f(x_p) - f(x_m)) / (2*hs[i]), where x_p and x_m are
copies of x with element i shifted by +hs[i] and -hs[i] respectively.

NOTE: the suitable step size depends on the parameter magnitude; with h=None
this implementation scales the step to each parameter automatically.

NOTE: for parameters with magnitude of 1e4, h=1e-2 is demonstrated to be significant
"""

grads, hs = [], []
if (h is None): # h is set to auto
for x_i in x:
# Get the magnitude of the parameter
coeff = get_magnitude_torch(x_i) - 6
hs.append(float(10**coeff))
else:
# NOTE(review): the next two lines use `_x` and `requires_grad`, which this
# function never defines; they look like removed-side diff residue from the
# old `copy` helper -- confirm and drop.
_x.requires_grad = requires_grad
return _x
# A caller-supplied h is applied uniformly to every parameter
hs = [h for _ in x]


# -----------------------------------------------------------------------------
# poor man's jax-like API
# -----------------------------------------------------------------------------
# One pair of perturbed evaluations of f per parameter
for i in range(len(x)):
# NOTE: the copy of x will be with requires_gradient=False
# NOTE(review): clone() by itself does not detach; presumably this relies on
# being called from backward(), where grad mode is off -- confirm.
x_p = x.clone()
x_m = x.clone()

x_p[i] += hs[i]
x_m[i] -= hs[i]

# NOTE(review): `_wrap_input` below (through `return func(x)`) is a decorator
# from the old jax-like API and is unrelated to the finite-difference loop;
# it looks like removed-side diff residue -- confirm.
def _wrap_input(func):
def wrapper(_x):
if isinstance(_x, torch.Tensor):
x = _x
else:
x = torch.Tensor(np.atleast_1d(_x))
x.requires_grad = True
if x.grad is not None:
x.grad.zero_()
return func(x)
# Centered difference quotient for parameter i
dfdi = (f(x_p) - f(x_m))/(2*hs[i])
grads.append(dfdi)

# NOTE(review): `return wrapper` closes the old `_wrap_input`; the
# finite-difference function itself ends with `return grads`.
return wrapper
return grads


# only to make scalar args work
@_wrap_input
def cos(x):
    """Return torch.cos(x); the _wrap_input decorator pre-processes the argument."""
    result = torch.cos(x)
    return result
class TransferFunction(Function):
"""
Extend torch.autograd capabilities by designing a custom class TransferFunction that inherits from torch.autograd.Function.
This allows to manually define both forward and backward methods.
forward evaluates forced_response(model(function_input), u, time).
backward multiplies grad_output by the finite-difference partial derivatives obtained from
grad(gx, f_input, h=1e-3) and sums over dim 0, i.e. a vector-Jacobian product; it returns None for u and time.
"""
@staticmethod
def forward(ctx, function_input, u, time):

# Direct computation: compute the forward operation i.e the output of the transfer function
output = forced_response(model(function_input), u, time)

# NOTE(review): the decorated `func` below belongs to the old jax-like API
# (it calls func_plain_torch) and is unused by forward; it looks like
# removed-side diff residue -- confirm.
@_wrap_input
def func(x):
return func_plain_torch(x)
# Save the inputs (parameters, input signal, time) needed to compute the gradient in backward
ctx.save_for_backward(function_input, u, time)

return output

@staticmethod
def backward(ctx, grad_output):
# Try to bind the output gradient with the input gradient. i.e. chain rule

# NOTE(review): the one-argument `grad(func)` / `_gradfunc` below is the old
# autograd-based helper; it does not match the `grad(gx, f_input, h=1e-3)` call
# further down and looks like removed-side diff residue -- confirm.
def grad(func):
@_wrap_input
def _gradfunc(x):
out = func(x)
out.backward(torch.ones_like(out))
# x.grad is a Tensor of x.shape which holds the derivatives of func
# w.r.t each x[i,j,k,...] evaluated at x, got it?
return x.grad
# Recover the tensors stored by forward
f_input, u, time, = ctx.saved_tensors

return _gradfunc
# Create g(x) where x are the params.
# This allows to test the function by manually changing the single parameter
# See grad function to understand why
gx = lambda p: forced_response(model(p), u, time)

# Compute the gradients wrt each parameter
grads = grad(gx, f_input, h=1e-3)

# Apply the chain rule for each partial derivative to update each parameter p
out = [grad_output*i for i in grads]

# Convert the output from a list of partial derivatives to a N-by-p matrix, with p number of parameters, N size of data over time
out = torch.stack(out, dim=1)

# NOTE(review): `elementwise_grad = grad` is an alias from the old API and is
# not used by backward; presumably diff residue -- confirm.
elementwise_grad = grad
# sum all the gradients to match the needed output dimension, i.e. p
out = torch.sum(out, dim=0)

# One gradient per forward input: parameters, then None for u and time
return out, None, None



def test():
# NOTE(review): everything from here to the last `assert (g1 == g2).all()`
# uses the one-argument `grad(func)(x)` form together with `cos`,
# `elementwise_grad`, `copy` and `func`. That does not match the
# `grad(f, x, h=None)` finite-difference signature and looks like the
# removed side of the diff -- confirm before running.
# Check that grad() works
assert torch.allclose(grad(torch.sin)(1.234), cos(1.234))
x = rand(10) * 5 - 5
assert torch.allclose(elementwise_grad(torch.sin)(x), torch.cos(x))
assert grad(func)(x).shape == x.shape

# Show 4 different pytorch grad APIs
x1 = rand(3, requires_grad=True)

# 1
c1 = func_plain_torch(x1)
c1.backward()
g1 = x1.grad

# 2
x2 = copy(x1, requires_grad=True)
c2 = func_plain_torch(x2)
torch.autograd.backward(c2)
g2 = x2.grad
assert (g1 == g2).all()

# 3
x2 = copy(x1, requires_grad=True)
c2 = func_plain_torch(x2)
g2 = torch.autograd.grad(c2, x2)[0]
assert (g1 == g2).all()

# 4
x2 = copy(x1)
g2 = torch.autograd.functional.vjp(func_plain_torch, x2)[1]
assert (g1 == g2).all()

# jax-like functional API defined here
x2 = copy(x1)
g2 = grad(func)(x2)
assert (g1 == g2).all()
# Definition of time, input, parameters, ground truth
# 101 samples starting at 1 and stopping short of 10 (endpoint excluded)
time = torch.tensor(np.linspace(1, 10, 101, endpoint=False))
u = torch.sin(time).requires_grad_(True)

# Nine reference parameters, in the order consumed by model()
ref_params = torch.tensor([-1, 1, 3, -4, 1, -1, 0, 1, 0], requires_grad=True, dtype=torch.double)


# The extended class has to be called via the .apply method.
# It is easier to assign it to an intermediate variable
myTransferFunction = TransferFunction.apply

# Only ref_params keeps requires_grad=True here; u and time are switched off.
# NOTE(review): the result is stored in test_passed but never asserted or
# returned in the visible code -- confirm this is intended.
test_passed =gradcheck(myTransferFunction, (ref_params.requires_grad_(True), u.requires_grad_(False), time.requires_grad_(False)))


# Run the checks when the file is executed as a script
if __name__ == "__main__":
# NOTE(review): test() appears twice; presumably the old and new side of a
# whitespace-only diff line -- confirm that a single call is intended.
test()
test()