diff --git a/examples/test_torch.py b/examples/test_torch.py
index 0e793f2..96a1a9d 100644
--- a/examples/test_torch.py
+++ b/examples/test_torch.py
@@ -1,167 +1,170 @@
 """
-pytorch has a "functional" grad API [1,2] as of v1.5
+Define custom derivatives for an operation that autograd cannot trace, by
+implementing torch.autograd.Function with a hand-written backward pass.
 
-    torch.autograd.functional
+The result from this implementation is only approximate, since the
+derivatives are computed with a centered finite-difference scheme.
 
-in addition to
-
-    # like jax.nn and jax.experimental.stax
-    torch.nn.functional
-
-However, unlike jax, torch.autograd.functional's functions don't return
-functions. One needs to supply the function to differentiate along with the
-input at which grad(func) shall be evaluated.
-
-    # like jax.grad(func)(x)
-    #
-    # default vector v in VJP is v=None -> v=1 -> return grad(func)(x)
-    torch.autograd.functional.vjp(func, x) -> Tensor
-
-    torch.autograd.functional.hessian(func, x) -> Tensor
-
-whereas
-
-    jax.grad(func) -> grad_func
-    jax.grad(func)(x) -> grad_func(x) -> DeviceArray
-
-    jax.hessian(func) -> hess_func
-    jax.hessian(func)(x) -> hess_func(x) -> DeviceArray
-
-
-resources
-    https://pytorch.org/tutorials/beginner/pytorch_with_examples.html#autograd
-    https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html#sphx-glr-beginner-blitz-autograd-tutorial-py
-    https://pytorch.org/docs/stable/autograd.html
-    https://pytorch.org/docs/stable/notes/autograd.html
-
-
-[1] https://github.com/pytorch/pytorch/commit/1f4a4aaf643b70ebcb40f388ae5226a41ca57d9b
-[2] https://pytorch.org/docs/stable/autograd.html#functional-higher-level-api
 """
-import torch
 import numpy as np
-
-rand = torch.rand
-
-### gradient of scalar result c w.r.t. x, evaluated at x
-### step by step, see grad_fn
-##x = torch.rand(3, requires_grad=True)
-####c = x.sin().pow(2.0).sum()
-##a = torch.sin(x)
-##print(a)
-##b = torch.pow(a, 2.0)
-##print(b)
-##c = torch.sum(b)
-##print(c)
-### same as torch.autograd.grad(c,x)
-##c.backward()
-##print(x.grad)
-##
-##
-### VJP: extract one row of J
-##x = torch.rand(3, requires_grad=True)
-##v = torch.tensor([1.0,0,0])
-##b = x.sin().pow(2.0)
-##b.backward(v)
-##print(x.grad)
-
-
-func_plain_torch = lambda x: torch.sin(x).pow(2.0).sum()
-
-
-def copy(x, requires_grad=False):
-    _x = x.clone().detach()
-    if not requires_grad:
-        assert not _x.requires_grad
+import control as ct
+import torch
+from torch.autograd import Function
+from torch.autograd import gradcheck
+
+
+def model(params):
+    """
+    Create a state-space representation and return the system transfer function.
+    - params: parameters of the state-space matrices, see
+      https://github.com/oselin/gradient_ML_LTI/blob/main/gradient_propagation.ipynb
+    NOTE: the transfer function is computed with the control library, which does
+    not natively support backpropagation of the gradient.
+    """
+    A = torch.tensor([[params[0], params[1]],
+                      [params[2], params[3]]], dtype=torch.double)
+
+    B = torch.tensor([params[4], params[5]], dtype=torch.double)
+
+    C = torch.tensor([params[6], params[7]], dtype=torch.double)
+
+    D = torch.tensor([params[8]], dtype=torch.double)
+
+    G = ct.ss2tf(A, B, C, D)
+
+    return G
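+
+
+# Illustrative sketch (not part of the original file): show how a parameter
+# vector with the layout documented in model() maps to a transfer function.
+# The values below mirror the reference parameters used in test(); this helper
+# is never called by the tests, run it manually if desired.
+def _demo_model():
+    params = torch.tensor([-1, 1, 3, -4, 1, -1, 0, 1, 0], dtype=torch.double)
+    G = model(params)
+    print(G)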
+ """ + A = torch.tensor([[params[0], params[1]], + [params[2], params[3]]], dtype=torch.double) + + B = torch.tensor([params[4], params[5]], dtype=torch.double) + + C = torch.tensor([params[6], params[7]], dtype=torch.double) + + D = torch.tensor([params[8]], dtype=torch.double) + + G = ct.ss2tf(A, B, C, D) + + return G + + +def forced_response(trn_fcn, u, time): + """ + Return a torch tensor containing the forced response of the system to input + - trn_fcn: transfer function on which to compute the forced resonse + - u: system input in time domain + - time: time array during which the input is applied + """ + output = ct.forced_response(trn_fcn, time, u.detach().numpy()).outputs + output = torch.tensor(output.copy(), requires_grad=True, dtype=torch.double) + return output + + +def impulse_response(trn_fcn, time):# + """ + Return a torch tensor containing the impulse response of the system + - trn_fcn: transfer function on which to compute the impulse resonse + - time: time array during which the impulse response has to be computed + """ + output = ct.impulse_response(trn_fcn, time).outputs + output = torch.tensor(output.copy(), requires_grad=True, dtype=torch.double) + return output + + +def get_magnitude_torch(tensor): + """ + Compute the magnitude of a torch value + - tensor: torch tensor + """ + if torch.equal(tensor, torch.zeros_like(tensor)): + return 0 # Magnitude of a zero tensor is 0 + + magnitude = int(torch.floor(torch.log10(tensor.abs())).item()) + return magnitude + + +def grad(f, x, h=None): + """ + Return the gradient, as list of partial derivatives + computed via (centered) finite differences. + - f: function, callable. Function f on which to compute the gradient + - x: parameter with respect to compute the gradient + - h: step size for finite differences gradient computation + + NOTE: according to the parameter magnitude, a tailored step size h is required. + This gradient implementation takes into account that + + NOTE: for parameters with magnitude of 1e4, h=1e-2 is demonstrated to be significant + """ + + grads, hs = [], [] + if (h is None): # h is set to auto + for x_i in x: + # Get the magnitude of the parameter + coeff = get_magnitude_torch(x_i) - 6 + hs.append(float(10**coeff)) else: - _x.requires_grad = requires_grad - return _x + hs = [h for _ in x] -# ----------------------------------------------------------------------------- -# poor man's jax-like API -# ----------------------------------------------------------------------------- + for i in range(len(x)): + # NOTE: the copy of x will be with requires_gradient=False + x_p = x.clone() + x_m = x.clone() + x_p[i] += hs[i] + x_m[i] -= hs[i] -def _wrap_input(func): - def wrapper(_x): - if isinstance(_x, torch.Tensor): - x = _x - else: - x = torch.Tensor(np.atleast_1d(_x)) - x.requires_grad = True - if x.grad is not None: - x.grad.zero_() - return func(x) + dfdi = (f(x_p) - f(x_m))/(2*hs[i]) + grads.append(dfdi) - return wrapper + return grads -# only to make scalar args work -@_wrap_input -def cos(x): - return torch.cos(x) +class TransferFunction(Function): + """ + Extend torch.autograd capabilities by designing a custom class TransferFunction that inherits from torch.autograd.Function. + This allows to manually define both forward and backward methods. 
+
+
+class TransferFunction(Function):
+    """
+    Extend torch.autograd by defining a custom class TransferFunction that
+    inherits from torch.autograd.Function.
+    This allows both the forward and the backward method to be defined manually.
+    The gradient is propagated in the backward method by combining the
+    finite-difference partial derivatives with the incoming gradient (chain rule).
+    """
+    @staticmethod
+    def forward(ctx, function_input, u, time):
+        # Direct computation: run the forward operation, i.e. compute the
+        # output of the transfer function
+        output = forced_response(model(function_input), u, time)
 
-@_wrap_input
-def func(x):
-    return func_plain_torch(x)
+        # Save the inputs for the gradient computation in backward
+        ctx.save_for_backward(function_input, u, time)
+
+        return output
 
+    @staticmethod
+    def backward(ctx, grad_output):
+        # Combine the incoming output gradient with the local derivatives,
+        # i.e. apply the chain rule
 
-def grad(func):
-    @_wrap_input
-    def _gradfunc(x):
-        out = func(x)
-        out.backward(torch.ones_like(out))
-        # x.grad is a Tensor of x.shape which holds the derivatives of func
-        # w.r.t each x[i,j,k,...] evaluated at x, got it?
-        return x.grad
+        f_input, u, time = ctx.saved_tensors
 
-    return _gradfunc
+        # Create g(p) where p are the parameters. This allows the response to
+        # be re-evaluated at perturbed parameters; see the grad function above.
+        gx = lambda p: forced_response(model(p), u, time)
+
+        # Compute the gradients w.r.t. each parameter
+        grads = grad(gx, f_input, h=1e-3)
 
-elementwise_grad = grad
+        # Apply the chain rule for each partial derivative
+        out = [grad_output * i for i in grads]
+
+        # Convert from a list of partial derivatives to an N-by-p matrix, with
+        # p the number of parameters and N the number of time samples
+        out = torch.stack(out, dim=1)
 
-def test():
-    # Check that grad() works
-    assert torch.allclose(grad(torch.sin)(1.234), cos(1.234))
-    x = rand(10) * 5 - 5
-    assert torch.allclose(elementwise_grad(torch.sin)(x), torch.cos(x))
-    assert grad(func)(x).shape == x.shape
-
-    # Show 4 different pytorch grad APIs
-    x1 = rand(3, requires_grad=True)
-
-    # 1
-    c1 = func_plain_torch(x1)
-    c1.backward()
-    g1 = x1.grad
-
-    # 2
-    x2 = copy(x1, requires_grad=True)
-    c2 = func_plain_torch(x2)
-    torch.autograd.backward(c2)
-    g2 = x2.grad
-    assert (g1 == g2).all()
-
-    # 3
-    x2 = copy(x1, requires_grad=True)
-    c2 = func_plain_torch(x2)
-    g2 = torch.autograd.grad(c2, x2)[0]
-    assert (g1 == g2).all()
-
-    # 4
-    x2 = copy(x1)
-    g2 = torch.autograd.functional.vjp(func_plain_torch, x2)[1]
-    assert (g1 == g2).all()
-
-    # jax-like functional API defined here
-    x2 = copy(x1)
-    g2 = grad(func)(x2)
-    assert (g1 == g2).all()
+        # Sum over the time dimension so the result has shape (p,), matching
+        # the parameter vector
+        out = torch.sum(out, dim=0)
+        return out, None, None
+
+
+def test():
+    # Definition of time, input and ground-truth parameters
+    time = torch.tensor(np.linspace(1, 10, 101, endpoint=False))
+    u = torch.sin(time).requires_grad_(True)
+
+    ref_params = torch.tensor([-1, 1, 3, -4, 1, -1, 0, 1, 0], requires_grad=True, dtype=torch.double)
+
+    # The extended class has to be called via its .apply method.
+    # It is easier to assign it to an intermediate variable.
+    myTransferFunction = TransferFunction.apply
+
+    # gradcheck compares the gradients returned by backward() against its own
+    # numerical estimates and raises if they do not match
+    test_passed = gradcheck(myTransferFunction, (ref_params.requires_grad_(True), u.requires_grad_(False), time.requires_grad_(False)))
+    print("gradcheck passed:", test_passed)
 
 
 if __name__ == "__main__":
     test()
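+
+# Possible usage beyond gradcheck (illustrative sketch, not part of the original
+# file): fit the parameters to a reference response y_ref by gradient descent.
+# y_ref, u, time and ref_params are assumed to be available in scope; the
+# optimizer settings are placeholders.
+#
+#     params = (ref_params.detach() + 0.1 * torch.randn(9, dtype=torch.double)).requires_grad_(True)
+#     optimizer = torch.optim.Adam([params], lr=1e-2)
+#     for _ in range(200):
+#         optimizer.zero_grad()
+#         y_hat = TransferFunction.apply(params, u, time)
+#         loss = torch.mean((y_hat - y_ref) ** 2)
+#         loss.backward()
+#         optimizer.step()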