
Commit

Decouple and run MLP runs for comparisons
rpsilva-aws committed Jan 15, 2025
1 parent bb1dbaf commit b65e408
Showing 7 changed files with 192 additions and 149 deletions.
1 change: 0 additions & 1 deletion test/neuron/run_tests.sh
@@ -214,7 +214,6 @@ function run_xla_op_tests3 {
  run_test "$CDIR/spmd/test_xla_auto_sharding.py"
  #run_test "$CDIR/spmd/test_spmd_parameter_wrapping.py"
  run_test "$CDIR/spmd/test_train_spmd_linear_model.py"
- run_test "$CDIR/spmd/test_train_spmd_linear_model.py" "$@" --use_gradient_checkpointing
  run_test "$CDIR/spmd/test_xla_spmd_python_api_interaction.py"
  run_test "$CDIR/spmd/test_xla_auto_sharding.py"
  run_test "$CDIR/spmd/test_fsdp_v2.py"
2 changes: 1 addition & 1 deletion test/run_tests.sh
@@ -231,7 +231,7 @@ function run_xla_op_tests3 {
  run_test "$CDIR/spmd/test_xla_auto_sharding.py"
  run_test "$CDIR/spmd/test_spmd_parameter_wrapping.py"
  run_test "$CDIR/spmd/test_mp_input_sharding.py"
- run_test "$CDIR/spmd/test_train_spmd_linear_model.py"
+ run_test "$CDIR/spmd/test_train_spmd_linear_model.py" "$@" --skip-gradient-checkpointing
  run_save_tensor_hlo run_test "$CDIR/spmd/test_spmd_lowering_context.py"
  run_test "$CDIR/test_operations_hlo.py" "$@" --verbosity=$VERBOSITY
  run_test "$CDIR/test_input_output_aliases.py"
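For context, a minimal sketch of how the new flag is routed (illustrative only, not part of this commit): `--skip-gradient-checkpointing` is consumed by the test module itself rather than by the test harness; `parse_known_args` splits it out of argv and leaves everything else for `unittest`, as the test file below does.

```python
import argparse

# Illustrative sketch, not part of the commit: how the new flag is split
# from the arguments that unittest will consume.
parser = argparse.ArgumentParser()
parser.add_argument('--skip-gradient-checkpointing', action='store_true')
parsed_args, remaining_argv = parser.parse_known_args(
    ['--skip-gradient-checkpointing', '-v'])
assert parsed_args.skip_gradient_checkpointing is True
assert remaining_argv == ['-v']  # left over for unittest.main
```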
Empty file added test/spmd/__init__.py
185 changes: 39 additions & 146 deletions test/spmd/test_train_spmd_linear_model.py
@@ -1,164 +1,57 @@
+import argparse
+from contextlib import contextmanager
 import os
 import sys
-from typing import Optional
+import unittest

-import numpy as np
 import torch
-from torch import nn
-import torch.optim as optim

-import args_parse
-import torch_xla
-import torch_xla.core.xla_model as xm
-import torch_xla.debug.profiler as xp
-import torch_xla.distributed.parallel_loader as pl
-import torch_xla.distributed.spmd as xs
-import torch_xla.runtime as xr
-import torch_xla.utils.utils as xu
-from torch_xla.distributed.spmd import Mesh
-from torch_xla.utils.checkpoint import checkpoint
+import test_xla_sharding_base

+parent_folder = os.path.dirname(os.path.dirname(__file__))
+sys.path.append(parent_folder)
+from utils.train_spmd_linear_model import train_and_evaluate

-MODEL_OPTS = {
-    '--sharding': {
-        'choices': ['batch', 'megatron-lm', 'fsdp'],
-        'nargs': '+',
-        'default': [],
-    },
-    '--input_dim': {
-        'type': int,
-        'default': 16834,
-    },
-    '--train_dataset_len': {
-        'type': int,
-        'default': 1024 * 8,
-    },
-    '--use_gradient_checkpointing': {
-        'action': 'store_true',
-    }
-}
+SKIP_GRADIENT_CHECKPOINTING: bool = False

-FLAGS = {}
-PROFILER_SERVER = None

+@contextmanager
+def extended_argv(args):
+  original_argv = sys.argv[:]
+  sys.argv.extend(args)
+  try:
+    yield
+  finally:
+    sys.argv = original_argv

-class SimpleLinear(nn.Module):
-  NUM_CLASSES = 3
-
-  def __init__(self):
-    super().__init__()
-    self.layers = torch.nn.Sequential(
-        nn.Linear(FLAGS.input_dim, FLAGS.input_dim // 2),
-        nn.ReLU(),
-        nn.Linear(FLAGS.input_dim // 2, 3),
-        # Add an additional 3x3 layer at the end to ensure the final layer
-        # is not sharded.
-        nn.Linear(3, self.NUM_CLASSES),
-    )
+class TestSPMDLinearModel(test_xla_sharding_base.XlaShardingTest):

-  def forward(self, x):
-    if FLAGS.use_gradient_checkpointing:
-      for n_l, layer in enumerate(self.layers):
-        # Apply gradient checkpointing for reduced memory footprint.
-        # This would result in increased computation cost.
-        if n_l > 0:
-          x = checkpoint(layer, x)
-        else:
-          x = layer(x)
-    else:
-      x = self.layers(x)
-    return x
+  def test_basic(self):
+    print('Training loop with baseline')
+    with extended_argv([]):
+      baseline_losses, baseline_result = train_and_evaluate()
+    # Verify that the model losses are not zero.
+    assert all(loss != 0 for loss in baseline_losses)
+    # Verify that the model produces non-zero outputs.
+    assert not torch.any(baseline_result == 0)

-
-def train():
-  device = xm.xla_device()
-  torch.manual_seed(42)
-  model = SimpleLinear().to(device)
-  print('===> Preparing data..')
-  train_loader = xu.SampleGenerator(
-      data=(torch.randn(FLAGS.batch_size, FLAGS.input_dim),
-            torch.randint(
-                0, model.NUM_CLASSES, (FLAGS.batch_size,), dtype=torch.int64)),
-      sample_count=FLAGS.train_dataset_len // FLAGS.batch_size)
-
-  num_devices = xr.global_runtime_device_count()
-  print(f'num_devices: {num_devices}')
-  # Define a mesh with all devices along one axis
-  mesh_shape = (num_devices, 1)
-  device_ids = np.arange(num_devices)
-  mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))
-
-  if 'batch' in FLAGS.sharding:
-    train_loader = pl.MpDeviceLoader(
-        train_loader, device, input_sharding=xs.ShardingSpec(mesh, (0, 1)))
-
-  if 'fsdp' in FLAGS.sharding:
-    train_loader = pl.MpDeviceLoader(
-        train_loader, device, input_sharding=xs.ShardingSpec(mesh, (0, 1)))
-    print('Sharding model weights')
-    # Shard the weights according to their 0th dim
-    xs.mark_sharding(model.layers[0].weight, mesh, (0, 1))
-    xs.mark_sharding(model.layers[2].weight, mesh, (0, 1))
-
-  if 'megatron-lm' in FLAGS.sharding:
-    print('Sharding model weights')
-    # Shard the first layer's weights row-wise
-    xs.mark_sharding(model.layers[0].weight, mesh, (0, 1))
-    # Shard the second layer's weights column-wise
-    xs.mark_sharding(model.layers[2].weight, mesh, (1, 0))
-
-  optimizer = optim.SGD(model.parameters(), lr=FLAGS.lr)
-
-  loss_fn = nn.CrossEntropyLoss()
-
-  def train_loop_fn(loader, epoch):
-    model.train()
-    for step, (data, target) in enumerate(loader):
-      with xp.StepTrace('train_linear_model'):
-        with xp.Trace('build_graph'):
-          x = data.to(device)
-          y = target.to(device)
-          optimizer.zero_grad()
-          output = model(x)
-          loss = loss_fn(output, y)
-          losses.append(loss.clone().detach())
-          loss.backward()
-        optimizer.step()
-      xm.mark_step()
-      if step % FLAGS.log_steps == 0:
-        print(f"Epoch {epoch} step {step} loss {loss}")
-
-  losses = []
-  for epoch in range(FLAGS.num_epochs):
-    train_loop_fn(train_loader, epoch)
-  return losses, model
-
-
-def train_and_evaluate():
-  default_config = {
-      'batch_size': 128,
-      'num_epochs': 1,
-      'lr': 0.1,
-      'log_steps': 8,
-      'opts': MODEL_OPTS.items()
-  }
-
-  global PROFILER_SERVER, FLAGS
-  FLAGS = args_parse.parse_common_options(**default_config)
-  if FLAGS.profile:
-    PROFILER_SERVER = xp.start_server(FLAGS.profiler_port)
-  xr.use_spmd(auto=FLAGS.auto_spmd)
-  print('Start training loop...')
-  losses, m = train()
-  t = torch.randn(10, FLAGS.input_dim).to(xm.xla_device())
-  return [loss.cpu().item() for loss in losses], m(t).cpu()
+    if not SKIP_GRADIENT_CHECKPOINTING:
+      print('Training loop with gradient checkpointing')
+      with extended_argv(['--use_gradient_checkpointing']):
+        checkpointing_losses, checkpointing_result = train_and_evaluate()
+        # Verify that the runs match with and without checkpointing.
+        assert torch.allclose(baseline_result, checkpointing_result)
+        assert all(
+            torch.allclose(baseline_loss, checkpointing_loss)
+            for baseline_loss, checkpointing_loss in zip(
+                baseline_losses, checkpointing_losses))


 if __name__ == '__main__':
-  losses, result = train_and_evaluate()
-  # Verify that the model losses are not zero.
-  assert all(loss != 0 for loss in losses)
-  # Verify that the model produces non-zero outputs.
-  assert torch.all(result != 0)
+  parser = argparse.ArgumentParser()
+  parser.add_argument('--skip-gradient-checkpointing', action='store_true')
+  parsed_args, remaining_argv = parser.parse_known_args()
+  SKIP_GRADIENT_CHECKPOINTING = parsed_args.skip_gradient_checkpointing
+  test = unittest.main(argv=[sys.argv[0]] + remaining_argv, exit=False)
+  sys.exit(0 if test.result.wasSuccessful() else 1)
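A usage sketch of the decoupled trainer (illustrative only, and assuming it is run from the test/ directory so that `args_parse` and the `utils` package are importable): `train_and_evaluate` reads its flags from `sys.argv` via `args_parse`, which is why the test above injects configuration through the `extended_argv` context manager.

```python
import sys

# Illustrative sketch, not part of the commit. Assumes the working
# directory is test/ so that `args_parse` and `utils` are importable.
from utils.train_spmd_linear_model import train_and_evaluate

# Flags are parsed from sys.argv inside train_and_evaluate, so extend
# argv before the call and restore it afterwards, exactly as the test's
# extended_argv context manager does.
original_argv = sys.argv[:]
sys.argv.extend(['--sharding', 'megatron-lm', '--use_gradient_checkpointing'])
try:
  losses, output = train_and_evaluate()
finally:
  sys.argv = original_argv

print(f'final loss: {losses[-1].item()}, output shape: {tuple(output.shape)}')
```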
1 change: 0 additions & 1 deletion test/tpu/run_tests.sh
@@ -15,7 +15,6 @@ python3 "$TEST_CDIR/spmd/test_xla_sharding.py"
python3 "$TEST_CDIR/spmd/test_xla_virtual_device.py"
python3 "$TEST_CDIR/spmd/test_xla_distributed_checkpoint.py"
python3 "$TEST_CDIR/spmd/test_train_spmd_linear_model.py"
-python3 "$TEST_CDIR/spmd/test_train_spmd_linear_model.py" "$@" --use_gradient_checkpointing
python3 "$TEST_CDIR/spmd/test_xla_spmd_python_api_interaction.py"
python3 "$TEST_CDIR/spmd/test_xla_auto_sharding.py"
python3 "$TEST_CDIR/spmd/test_fsdp_v2.py"
Empty file added test/utils/__init__.py
152 changes: 152 additions & 0 deletions test/utils/train_spmd_linear_model.py
@@ -0,0 +1,152 @@
+import sys
+from typing import Optional
+
+import numpy as np
+import torch
+from torch import nn
+import torch.optim as optim
+
+import args_parse
+import torch_xla
+import torch_xla.core.xla_model as xm
+import torch_xla.debug.profiler as xp
+import torch_xla.distributed.parallel_loader as pl
+import torch_xla.distributed.spmd as xs
+import torch_xla.runtime as xr
+import torch_xla.utils.utils as xu
+from torch_xla.distributed.spmd import Mesh
+from torch_xla.utils.checkpoint import checkpoint
+
+MODEL_OPTS = {
+    '--sharding': {
+        'choices': ['batch', 'megatron-lm', 'fsdp'],
+        'nargs': '+',
+        'default': [],
+    },
+    '--input_dim': {
+        'type': int,
+        'default': 16834,
+    },
+    '--train_dataset_len': {
+        'type': int,
+        'default': 1024 * 8,
+    },
+    '--use_gradient_checkpointing': {
+        'action': 'store_true',
+    }
+}
+
+FLAGS = {}
+PROFILER_SERVER = None
+
+
+class SimpleLinear(nn.Module):
+  NUM_CLASSES = 3
+
+  def __init__(self):
+    super().__init__()
+    self.layers = torch.nn.Sequential(
+        nn.Linear(FLAGS.input_dim, FLAGS.input_dim // 2),
+        nn.ReLU(),
+        nn.Linear(FLAGS.input_dim // 2, 3),
+        # Add an additional 3x3 layer at the end to ensure the final layer
+        # is not sharded.
+        nn.Linear(3, self.NUM_CLASSES),
+    )
+
+  def forward(self, x):
+    if FLAGS.use_gradient_checkpointing:
+      for n_l, layer in enumerate(self.layers):
+        # Apply gradient checkpointing for reduced memory footprint.
+        # This would result in increased computation cost.
+        if n_l > 0:
+          x = checkpoint(layer, x)
+        else:
+          x = layer(x)
+    else:
+      x = self.layers(x)
+    return x
+
+
+def train():
+  device = xm.xla_device()
+  torch.manual_seed(42)
+  model = SimpleLinear().to(device)
+  print('===> Preparing data..')
+  train_loader = xu.SampleGenerator(
+      data=(torch.randn(FLAGS.batch_size, FLAGS.input_dim),
+            torch.randint(
+                0, model.NUM_CLASSES, (FLAGS.batch_size,), dtype=torch.int64)),
+      sample_count=FLAGS.train_dataset_len // FLAGS.batch_size)
+
+  num_devices = xr.global_runtime_device_count()
+  print(f'num_devices: {num_devices}')
+  # Define a mesh with all devices along one axis
+  mesh_shape = (num_devices, 1)
+  device_ids = np.arange(num_devices)
+  mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))
+
+  if 'batch' in FLAGS.sharding:
+    train_loader = pl.MpDeviceLoader(
+        train_loader, device, input_sharding=xs.ShardingSpec(mesh, (0, 1)))
+
+  if 'fsdp' in FLAGS.sharding:
+    train_loader = pl.MpDeviceLoader(
+        train_loader, device, input_sharding=xs.ShardingSpec(mesh, (0, 1)))
+    print('Sharding model weights')
+    # Shard the weights according to their 0th dim
+    xs.mark_sharding(model.layers[0].weight, mesh, (0, 1))
+    xs.mark_sharding(model.layers[2].weight, mesh, (0, 1))
+
+  if 'megatron-lm' in FLAGS.sharding:
+    print('Sharding model weights')
+    # Shard the first layer's weights row-wise
+    xs.mark_sharding(model.layers[0].weight, mesh, (0, 1))
+    # Shard the second layer's weights column-wise
+    xs.mark_sharding(model.layers[2].weight, mesh, (1, 0))
+
+  optimizer = optim.SGD(model.parameters(), lr=FLAGS.lr)
+
+  loss_fn = nn.CrossEntropyLoss()
+
+  def train_loop_fn(loader, epoch):
+    model.train()
+    for step, (data, target) in enumerate(loader):
+      with xp.StepTrace('train_linear_model'):
+        with xp.Trace('build_graph'):
+          x = data.to(device)
+          y = target.to(device)
+          optimizer.zero_grad()
+          output = model(x)
+          loss = loss_fn(output, y)
+          losses.append(loss.clone().detach())
+          loss.backward()
+        optimizer.step()
+      xm.mark_step()
+      if step % FLAGS.log_steps == 0:
+        print(f"Epoch {epoch} step {step} loss {loss}")
+
+  losses = []
+  for epoch in range(FLAGS.num_epochs):
+    train_loop_fn(train_loader, epoch)
+  return losses, model
+
+
+def train_and_evaluate():
+  default_config = {
+      'batch_size': 128,
+      'num_epochs': 1,
+      'lr': 0.1,
+      'log_steps': 8,
+      'opts': MODEL_OPTS.items()
+  }
+
+  global PROFILER_SERVER, FLAGS
+  FLAGS = args_parse.parse_common_options(**default_config)
+  if FLAGS.profile:
+    PROFILER_SERVER = xp.start_server(FLAGS.profiler_port)
+  xr.use_spmd(auto=FLAGS.auto_spmd)
+  print('Start training loop...')
+  losses, m = train()
+  t = torch.randn(10, FLAGS.input_dim).to(xm.xla_device())
+  return [loss.cpu() for loss in losses], m(t).cpu()
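For intuition about the partition specs used above (an illustrative sketch assuming eight devices, not part of the commit): the mesh is (num_devices, 1) with axes ('x', 'y'), so a spec of (0, 1) shards a weight's rows across 'x' (the row-wise/FSDP case), while (1, 0) shards its columns across 'x' (the column-wise Megatron-LM case).

```python
# Illustrative sketch, not part of the commit: assumes 8 devices and
# mirrors the mesh construction in train().
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr
from torch_xla.distributed.spmd import Mesh

xr.use_spmd()
mesh = Mesh(np.arange(8), (8, 1), ('x', 'y'))
w = torch.randn(16, 8).to(xm.xla_device())
# (0, 1): tensor dim 0 -> mesh axis 'x' (8-way), dim 1 -> 'y' (1-way, kept
# whole), i.e. row-wise sharding; (1, 0) would shard dim 1 instead.
xs.mark_sharding(w, mesh, (0, 1))  # each device holds 2 of the 16 rows
```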
