format
zhenglongjiepheonix committed Sep 6, 2024
1 parent f0c8a6b commit 79316ee
Showing 4 changed files with 27 additions and 22 deletions.

optimum/fx/parallelization/op_registry/op_handlers.py: 6 changes (4 additions & 2 deletions)
@@ -19,7 +19,7 @@
 from torch.fx import Node
 
 from ..core import Config
-from ..utils import is_activation, is_embedding, is_linear, is_cross_entropy, is_cross_entropy_parallel_compatible
+from ..utils import is_activation, is_cross_entropy, is_cross_entropy_parallel_compatible, is_embedding, is_linear
 
 
 class Registry:
@@ -450,7 +450,9 @@ def propagate(self) -> List[int]:
         elif is_cross_entropy(self.node):
             logits = self.node.all_input_nodes[0]
             axis = self.extract_axis(logits)
-            if axis is None or (is_cross_entropy_parallel_compatible(self.node) and axis == logits.meta['val'].ndim - 1):
+            if axis is None or (
+                is_cross_entropy_parallel_compatible(self.node) and axis == logits.meta["val"].ndim - 1
+            ):
                 # for cross entropy, the input logits parallel axis can only be the last axis or None
                 return [None]
             else:
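As context for the condition reshaped above: the propagation handler accepts a cross-entropy node only when its logits are either unsharded (axis is None) or sharded along their last (vocabulary) dimension, and in the latter case only when the call itself is parallel-compatible. The snippet below is a standalone restatement of that rule for illustration; it is not part of this commit, and logits_axis_is_parallelizable, axis, ndim, and ce_compatible are hypothetical names standing in for the values the handler reads from the node.

from typing import Optional


def logits_axis_is_parallelizable(axis: Optional[int], ndim: int, ce_compatible: bool) -> bool:
    # Mirrors the condition above: unsharded logits are always fine; sharded
    # logits are only fine on the last axis, and only if the cross-entropy
    # call itself is parallel-compatible.
    return axis is None or (ce_compatible and axis == ndim - 1)


# e.g. logits of shape (batch, seq, vocab): ndim == 3, last axis == 2
assert logits_axis_is_parallelizable(None, 3, ce_compatible=True)
assert logits_axis_is_parallelizable(2, 3, ce_compatible=True)
assert not logits_axis_is_parallelizable(1, 3, ce_compatible=True)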

optimum/fx/parallelization/parallel_layers/loss.py: 27 changes (15 additions & 12 deletions)
@@ -12,11 +12,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from functools import wraps
+from typing import Optional
+
 import torch
-import torch.nn as nn
 import torch.distributed as dist
-from typing import Optional
-from functools import wraps
+import torch.nn as nn
+
 from ..core import ParallelExecutionCtx


@@ -100,7 +102,7 @@ def backward(ctx, grad_output: torch.Tensor):
         return grad_input, None, None
 
 
-def sharded_cross_entropy(sharded_logits: torch.Tensor, target: torch.Tensor,process_group: dist.ProcessGroup):
+def sharded_cross_entropy(sharded_logits: torch.Tensor, target: torch.Tensor, process_group: dist.ProcessGroup):
     return _ShardedCrossEntropy.apply(sharded_logits, target, process_group)


@@ -127,15 +129,15 @@ def wrapper(
         reduce = True if reduce is None else reduce
 
         if size_average and reduce:
-            reduction = 'mean'
+            reduction = "mean"
         elif reduce:
-            reduction = 'sum'
+            reduction = "sum"
         else:
-            reduction = 'none'
+            reduction = "none"
 
-        if reduction == 'mean':
+        if reduction == "mean":
             return loss.mean()
-        elif reduction == 'sum':
+        elif reduction == "sum":
             return loss.sum()
         return loss

@@ -146,15 +148,16 @@ class VocabParallelCrossEntropyLoss(nn.Module):
     """
     Simple parallel cross entropy implementation which does not support weighted mode and label smoothing yet.
     """
-    def __init__(self, ctx: ParallelExecutionCtx, reduction: str = 'mean') -> None:
+
+    def __init__(self, ctx: ParallelExecutionCtx, reduction: str = "mean") -> None:
         super(VocabParallelCrossEntropyLoss, self).__init__()
         self.process_group = ctx.tp_group
         self.reduction = reduction
 
     def forward(self, sharded_logits: torch.Tensor, target: torch.Tensor):
         loss: torch.Tensor = _ShardedCrossEntropy.apply(sharded_logits, target, self.process_group)
-        if self.reduction == 'mean':
+        if self.reduction == "mean":
             return loss.mean()
-        elif self.reduction == 'sum':
+        elif self.reduction == "sum":
             return loss.sum()
         return loss
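For orientation, the two entry points touched in this file keep the signatures shown above: sharded_cross_entropy(sharded_logits, target, process_group) returns the unreduced per-token loss, and VocabParallelCrossEntropyLoss(ctx, reduction=...) reads the tensor-parallel group from ctx.tp_group. The sketch below is a hedged usage example, not code from this commit; it assumes a ParallelExecutionCtx has already been constructed elsewhere and that the logits are sharded along the vocabulary (last) dimension across that group.

import torch
import torch.distributed as dist

# Import paths mirror the file locations in this diff.
from optimum.fx.parallelization.core import ParallelExecutionCtx
from optimum.fx.parallelization.parallel_layers.loss import (
    VocabParallelCrossEntropyLoss,
    sharded_cross_entropy,
)


def module_style_loss(
    sharded_logits: torch.Tensor, target: torch.Tensor, ctx: ParallelExecutionCtx
) -> torch.Tensor:
    # `ctx` is assumed to be an existing ParallelExecutionCtx; its construction
    # is outside the scope of this commit.
    criterion = VocabParallelCrossEntropyLoss(ctx, reduction="mean")
    return criterion(sharded_logits, target)


def functional_style_loss(
    sharded_logits: torch.Tensor, target: torch.Tensor, tp_group: dist.ProcessGroup
) -> torch.Tensor:
    # The functional form returns the unreduced loss, so reduce it explicitly.
    return sharded_cross_entropy(sharded_logits, target, tp_group).mean()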

optimum/fx/parallelization/passes.py: 10 changes (5 additions & 5 deletions)
@@ -21,22 +21,23 @@
 import torch.distributed as dist
 import torch.nn as nn
 from torch.fx import Graph, GraphModule, Node
+
 from .core import Config, ParallelExecutionCtx, ParameterMeta
 from .decomp import decompose_and_functionalize
 from .distributed import scatter
 from .op_registry import REGISTRY, FallbackParallelAxisPropagateHandler
 from .parallel_layers import (
     ColumnParallelLinear,
     RowParallelLinear,
-    VocabParallelEmbedding,
     VocabParallelCrossEntropyLoss,
-    sharded_cross_entropy_wrapper_fn
+    VocabParallelEmbedding,
+    sharded_cross_entropy_wrapper_fn,
 )
 from .utils import (
+    is_cross_entropy,
     is_embedding,
     is_linear,
     is_shape_consumer,
-    is_cross_entropy,
     stable_topological_sort,
 )

@@ -282,7 +283,7 @@ def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Conf
             elif is_cross_entropy(node):
                 axis_before = ParallelAxisSolverPass.get_stored_field_info(node.args[0], "parallel_axis")
                 if axis_before is not None:
-                    self.place_marker_per_node(node, {'axis' : 'vocab'})
+                    self.place_marker_per_node(node, {"axis": "vocab"})
 
         return graph_module

@@ -383,7 +384,6 @@ def handle_cross_entropy(node: Node, ctx: ParallelExecutionCtx) -> None:
         else:
             node.target = sharded_cross_entropy_wrapper_fn(process_group=ctx.tp_group)
 
-
     @staticmethod
     def handle_hard_coded_axis_param(node: Node, ctx: ParallelExecutionCtx) -> None:
         def extract_shape_from_node(node: Node) -> List[Any]:

optimum/fx/parallelization/utils.py: 6 changes (3 additions & 3 deletions)
@@ -96,9 +96,9 @@ def is_cross_entropy_parallel_compatible(node: Node) -> bool:
     For now `VocabParallelCrossEntropyLoss` does not support weighted mode, index ignoring and label smoothing.
     """
     if node.op == "call_function":
-        weight = node.kwargs.get('weight', None)
-        ignore_index = node.kwargs.get('ignore_index', -100)
-        label_smoothing = node.kwargs.get('label_smoothing', 0.0)
+        weight = node.kwargs.get("weight", None)
+        ignore_index = node.kwargs.get("ignore_index", -100)
+        label_smoothing = node.kwargs.get("label_smoothing", 0.0)
         if len(node.args) > 2 and weight is None:
             weight = node.args[2]
         if len(node.args) > 4 and ignore_index == -100:
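As the docstring above states, the parallel path is only taken for plain cross-entropy calls: no weight, the default ignore_index of -100, and zero label smoothing. Below is a minimal illustration (not from this commit) of which calls would and would not qualify, using standard torch.nn.functional.cross_entropy arguments; the variable names are hypothetical.

import torch
import torch.nn.functional as F

logits = torch.randn(8, 32000)
target = torch.randint(0, 32000, (8,))

# Default arguments: no weight, ignore_index == -100, label_smoothing == 0.0,
# so a call like this is the kind the check above accepts as parallel-compatible.
plain = F.cross_entropy(logits, target)

# Any of these overrides falls back to the non-parallel implementation.
weighted = F.cross_entropy(logits, target, weight=torch.ones(32000))
smoothed = F.cross_entropy(logits, target, label_smoothing=0.1)
ignoring = F.cross_entropy(logits, target, ignore_index=0)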
