Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PP API and nD Distributed Timeline Profiling #41

Merged
merged 3 commits into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added docs/pictures/ndtimeline_arch.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/pictures/ndtimeline_trace.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/pictures/pp.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
################################################################################
#
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
MackZackA marked this conversation as resolved.
Show resolved Hide resolved
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand Down
2 changes: 1 addition & 1 deletion examples/open_llama_4D_benchmark/llama_mfu_calculator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
################################################################################
#
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
MackZackA marked this conversation as resolved.
Show resolved Hide resolved
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
################################################################################
#
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
MackZackA marked this conversation as resolved.
Show resolved Hide resolved
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand Down
2 changes: 1 addition & 1 deletion examples/open_llama_4D_benchmark/sharding_plan.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
################################################################################
#
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
MackZackA marked this conversation as resolved.
Show resolved Hide resolved
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ pytest
tqdm
optree
accelerate
transformers==4.37.2
transformers==4.40.2
flash_attn
matplotlib
mmh3
4 changes: 1 addition & 3 deletions test/checkpoint/nano_gpt/test_nano_gpt_load_save.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,7 @@ def init_method(self):
@skip_unless_torch_gpu
@with_comms
def test_load(self):
ddp_gpt, dist_optimizer, _ = build_gpt_model_optimizer_and_dataset(
self.init_method, dp_size=2, tp_size=2
)
ddp_gpt, dist_optimizer, _ = build_gpt_model_optimizer_and_dataset(self.init_method, dp_size=2, tp_size=2)

# Load the model and optimizer after first data

Expand Down
2 changes: 1 addition & 1 deletion test/checkpoint/open_llama/test_open_llama_dp_reshard.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
################################################################################
#
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
MackZackA marked this conversation as resolved.
Show resolved Hide resolved
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand Down
2 changes: 1 addition & 1 deletion test/checkpoint/open_llama/test_open_llama_load_save.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
################################################################################
#
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
MackZackA marked this conversation as resolved.
Show resolved Hide resolved
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand Down
2 changes: 1 addition & 1 deletion test/checkpoint/open_llama/test_open_llama_tp_reshard.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
################################################################################
#
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
MackZackA marked this conversation as resolved.
Show resolved Hide resolved
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand Down
6 changes: 4 additions & 2 deletions test/model/open_llama/test_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ def test_attention(self):
input.retain_grad()
non_parallel_attention, _ = get_model()
non_parallel_attention = non_parallel_attention.cuda()
golden_outputs = non_parallel_attention(input)
dummy_position_ids = torch.randint(low=0, high=s, size=(bsz, s)).cuda()
golden_outputs = non_parallel_attention(input, position_ids=dummy_position_ids)
golden_loss = golden_outputs[0].mean()
golden_loss.backward()

Expand Down Expand Up @@ -84,8 +85,9 @@ def test_attention(self):
d_input = distribute_tensor(input.detach(), device_mesh, [Shard(1)])
d_input.requires_grad_()
d_input.retain_grad()
d_position_id = distribute_tensor(dummy_position_ids.detach(), device_mesh, [Replicate()])

vescale_outputs = vescale_attention(d_input)
vescale_outputs = vescale_attention(d_input, position_ids=d_position_id)
vescale_outputs[0] = vescale_outputs[0].redistribute(placements=[Replicate()] * device_mesh.ndim)
vescale_loss = vescale_outputs[0].mean()

Expand Down
6 changes: 4 additions & 2 deletions test/model/open_llama/test_decoder_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ def test_decoder(self):
input.retain_grad()
non_parallel_decoder, _ = get_model()
non_parallel_decoder = non_parallel_decoder.cuda()
golden_outputs = non_parallel_decoder(input)
dummy_position_id = torch.randint(low=0, high=s, size=(bsz, s)).cuda()
golden_outputs = non_parallel_decoder(input, position_ids=dummy_position_id)
golden_loss = golden_outputs[0].mean()
golden_loss.backward()

Expand Down Expand Up @@ -95,8 +96,9 @@ def test_decoder(self):
d_input = distribute_tensor(input.detach(), device_mesh, [Shard(1)])
d_input.requires_grad_()
d_input.retain_grad()
d_position_id = distribute_tensor(dummy_position_id.detach(), device_mesh, [Replicate()])

vescale_outputs = vescale_decoder(d_input)
vescale_outputs = vescale_decoder(d_input, position_ids=d_position_id)
vescale_outputs[0] = vescale_outputs[0].redistribute(placements=[Replicate()] * device_mesh.ndim)
vescale_loss = vescale_outputs[0].mean()

Expand Down
1 change: 1 addition & 0 deletions test/ndtimeline/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# make pylint happy
37 changes: 37 additions & 0 deletions test/ndtimeline/test_local_raw_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
################################################################################
#
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################

import os
from vescale.ndtimeline.world_info import WorldInfo
from vescale.ndtimeline.handlers import LocalRawNDHandler
from vescale.ndtimeline.variables import LOCAL_LOGGING_PATH


def test_basic_usage():
    """Drive ``LocalRawNDHandler`` and check which log files show up on disk.

    One record must create the base log file under ``LOCAL_LOGGING_PATH``.
    After eight further records a file with the ``.2`` suffix must exist and
    one with the ``.4`` suffix must not (presumably the effect of
    ``chunk_sz=10`` / ``backup_cnt=3`` rotation -- confirm against the
    handler implementation). The test then deletes the files it inspected.
    """
    base_name = "timeline_run0_raw.log"
    handler = LocalRawNDHandler(run_id=0, chunk_sz=10, backup_cnt=3)

    def log_path(suffix=""):
        # Every file this test looks at lives directly under LOCAL_LOGGING_PATH.
        return os.path.join(LOCAL_LOGGING_PATH, base_name + suffix)

    def emit(metric, elapsed):
        # Only the metric name and elapsed value vary between calls.
        handler(metric, elapsed, [1.0], [1.0], [{}], range(0, 1), WorldInfo(0, 0), {})

    # A single record is enough for the base file to appear.
    emit("test_metric", 1.0)
    assert os.path.exists(log_path())

    # Eight more records, alternating between two metric names.
    for _ in range(4):
        emit("test_metric", 1.0)
        emit("test_metric2", 2.0)
    assert os.path.exists(log_path(".2"))
    assert not os.path.exists(log_path(".4"))

    # Clean up: the base file first, then the numbered files 1..3.
    os.remove(log_path())
    for index in range(1, 4):
        os.remove(log_path("." + str(index)))
    assert not os.path.exists(log_path(".2"))
30 changes: 30 additions & 0 deletions test/ndtimeline/test_metric_level.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
################################################################################
#
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################

from vescale.ndtimeline import NDMetricLevel


def test_cmp_level():
    """Check the ordering and equality comparisons between ``NDMetricLevel`` members."""
    info = NDMetricLevel.INFO
    trace = NDMetricLevel.TRACE
    user_debug = NDMetricLevel.USER_DEBUG
    user_info = NDMetricLevel.USER_INFO

    # Debug-flavoured levels compare at or above INFO.
    assert NDMetricLevel.FRAMEWORK_DEBUG >= info
    assert user_debug >= info
    assert user_debug > info

    # USER_INFO compares strictly below INFO, and INFO strictly below DEBUG.
    assert user_info < info
    assert user_info <= info
    assert info < NDMetricLevel.DEBUG

    # A level compared with itself: <=, >= and == must all hold.
    assert trace <= trace
    assert trace >= trace
    assert trace == trace
61 changes: 61 additions & 0 deletions test/ndtimeline/test_parser_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
################################################################################
#
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################

import pytest
from vescale.ndtimeline.world_info import WorldInfo
from vescale.ndtimeline.handlers import ParserNDHandler
from vescale.ndtimeline.exceptions import NDHandlerError


def test_normal_input_with_tags():
    """Pass ``ParserNDHandler`` three consistent parts and check the parsed output.

    The elapsed list, the since-start list and the tag list all have three
    entries and the step range covers a single step. The handler is expected
    to return exactly one record whose ``step`` is 0.
    """
    elapsed_parts = [1.0, 3.2, 1.4]
    since_start_parts = [1710332816.6118143, 1710332833.2222, 1710332846.1313]

    # All parts except the last share one tag dict object; the last differs.
    shared_tag = {"is_test": True}
    tags = [shared_tag for _ in elapsed_parts[:-1]]
    tags.append({"is_test": False})

    steps = range(0, 1)
    world = WorldInfo(0, 0)
    parse = ParserNDHandler()

    records = parse(
        "test_metric",
        sum(elapsed_parts),
        elapsed_parts,
        since_start_parts,
        tags,
        steps,
        world,
        {},
    )

    assert len(records) == 1
    assert records[0].step == 0


def test_normal_invalid_input():
    """Check that ``ParserNDHandler`` raises ``NDHandlerError`` on inconsistent input.

    The elapsed list and the tag list have three entries each, but the
    since-start list has only two. Presumably this length mismatch is what
    the handler rejects -- confirm against the handler implementation.
    """
    elapsed_parts = [1.0, 3.2, 1.4]
    # Deliberately one entry short compared with ``elapsed_parts``.
    since_start_parts = [1710332816.6118143, 1710332846.1313]

    shared_tag = {"is_test": True}
    tags = [shared_tag for _ in elapsed_parts[:-1]]
    tags.append({"is_test": False})

    steps = range(0, 1)
    world = WorldInfo(0, 0)
    parse = ParserNDHandler()

    call_args = (
        "test_metric",
        sum(elapsed_parts),
        elapsed_parts,
        since_start_parts,
        tags,
        steps,
        world,
        {},
    )
    with pytest.raises(NDHandlerError):
        parse(*call_args)
53 changes: 53 additions & 0 deletions test/parallel/pipeline/api/four_mlp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
################################################################################
#
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################

import torch
import torch.nn as nn
import os


class MLP(nn.Module):
    """Two bias-free 1024x1024 linear layers with a GELU between them.

    The weights are filled with constants, and every forward pass saves its
    output tensor to disk under a name built from the ``model_name``
    environment variable, ``value`` and a per-instance call counter.
    """

    def __init__(self, features_in, feature_middle, features_out, value):
        """Build the layers and fill their weights with constants.

        Args:
            features_in: Accepted but not used; layer sizes are fixed at 1024.
            feature_middle: Accepted but not used; layer sizes are fixed at 1024.
            features_out: Accepted but not used; layer sizes are fixed at 1024.
            value: Fill value for the ``fc1`` weights (``fc2`` gets twice this
                value); also embedded in the name of each saved output file.
        """
        super().__init__()
        self.value = value
        # Number of forward passes so far; used to number the saved tensors.
        self.counter = 0

        first = nn.Linear(1024, 1024, bias=False)
        first.weight.data.fill_(value)
        self.fc1 = first

        second = nn.Linear(1024, 1024, bias=False)
        second.weight.data.fill_(value * 2)
        self.fc2 = second

        self.gelu = nn.GELU()

    def forward(self, x):
        """Apply fc1 -> GELU -> fc2, save the result to disk and return it.

        Raises:
            KeyError: If the ``model_name`` environment variable is not set.
        """
        out = self.fc2(self.gelu(self.fc1(x)))
        dump_path = f"{os.environ['model_name']}_mlp{self.value}_fwd{self.counter}_out_tensor.pt"
        torch.save(out, dump_path)
        self.counter += 1
        return out


class FourMLP(nn.Module):
    """Four ``MLP`` stages, tagged with values 0 to 3, applied one after another.

    Each stage is registered both as its own attribute (``mlp1`` .. ``mlp4``)
    and as a member of the ``sequence`` container.
    """

    def __init__(self, hidden):
        """Create the four stages and chain them.

        Args:
            hidden: Base size; multiples of it are passed to each ``MLP`` as
                its size arguments (which ``MLP`` in this file does not use).
        """
        super().__init__()
        self.mlp1 = MLP(features_in=hidden * 1, feature_middle=hidden * 2, features_out=hidden * 3, value=0)
        self.mlp2 = MLP(features_in=hidden * 3, feature_middle=hidden * 4, features_out=hidden * 5, value=1)
        self.mlp3 = MLP(features_in=hidden * 5, feature_middle=hidden * 6, features_out=hidden * 7, value=2)
        self.mlp4 = MLP(features_in=hidden * 7, feature_middle=hidden * 8, features_out=hidden * 9, value=3)

        stages = (self.mlp1, self.mlp2, self.mlp3, self.mlp4)
        self.sequence = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the four stages in order and return the result."""
        return self.sequence(x)
Loading