Please help me to understand how mlx gradients flow through the graph #1106
AndreiChirap asked this question in Q&A · Answered by awni
Hi everyone, I'm trying to learn how MLX's autograd engine works. Can someone help me translate this PyTorch code snippet to MLX? I think this will help me understand the process better.

```python
import math

import torch
import torchvision

# Load MNIST, flatten each image to a 784-vector, and scale pixels to [0, 1].
mnist = torchvision.datasets.MNIST(root='/tmp', train=True, download=True)
digits = mnist.data.reshape(-1, 28 * 28) / 255
targets = mnist.targets

# Hold out the last 10,000 training images for evaluation.
train_data = digits[:50000]
train_targets = targets[:50000]
test_data = digits[50000:]
test_targets = targets[50000:]

lr = 0.09
in_features = 28 * 28
out_features = 10

# Initialize a single linear layer from U(-scale, scale).
scale = math.sqrt(1 / in_features)
uniform = torch.distributions.uniform.Uniform(low=-scale, high=scale)
W = uniform.sample((out_features, in_features))
W.requires_grad = True
b = uniform.sample((out_features,))
b.requires_grad = True


def zero_grad(*params):
    for p in params:
        p.grad = None


@torch.no_grad()
def step(*params, lr=lr):
    # Plain gradient descent, updating the parameters in place.
    for p in params:
        p -= lr * p.grad


for i in range(1000):
    zero_grad(W, b)
    out = torch.nn.functional.gelu(torch.addmm(b, train_data, W.T))
    loss = torch.nn.functional.cross_entropy(out, train_targets, reduction="mean")
    loss.backward()
    step(W, b)

_, idx = out.max(1)
print("Train Acc: ", (train_targets == idx).sum() / len(train_targets))

out = torch.nn.functional.gelu(torch.addmm(b, test_data, W.T))
_, idx = out.max(1)
print("Test Acc: ", (test_targets == idx).sum() / len(test_targets))
```

Thanks!
Answered by awni on Nov 17, 2024
I would take a look at our MNIST example as a starting point. It should touch every aspect of the above, though the model and optimizer are different. For an even more basic intro to this, check out the Linear Regression example. Also check out the usage guide on function transformations, and more specifically the section on autograd.
Answer selected by AndreiChirap