From 79316ee3cfbcf0311fd087fd1e0903ad78b47735 Mon Sep 17 00:00:00 2001
From: Longjie Zheng
Date: Fri, 6 Sep 2024 06:14:38 +0200
Subject: [PATCH] format

---
 .../op_registry/op_handlers.py               |  6 +++--
 .../parallelization/parallel_layers/loss.py  | 27 ++++++++++---------
 optimum/fx/parallelization/passes.py         | 10 +++----
 optimum/fx/parallelization/utils.py          |  6 ++---
 4 files changed, 27 insertions(+), 22 deletions(-)

diff --git a/optimum/fx/parallelization/op_registry/op_handlers.py b/optimum/fx/parallelization/op_registry/op_handlers.py
index c6e4f721e7..4a9c55e376 100644
--- a/optimum/fx/parallelization/op_registry/op_handlers.py
+++ b/optimum/fx/parallelization/op_registry/op_handlers.py
@@ -19,7 +19,7 @@
 from torch.fx import Node
 
 from ..core import Config
-from ..utils import is_activation, is_embedding, is_linear, is_cross_entropy, is_cross_entropy_parallel_compatible
+from ..utils import is_activation, is_cross_entropy, is_cross_entropy_parallel_compatible, is_embedding, is_linear
 
 
 class Registry:
@@ -450,7 +450,9 @@ def propagate(self) -> List[int]:
         elif is_cross_entropy(self.node):
             logits = self.node.all_input_nodes[0]
             axis = self.extract_axis(logits)
-            if axis is None or (is_cross_entropy_parallel_compatible(self.node) and axis == logits.meta['val'].ndim - 1):
+            if axis is None or (
+                is_cross_entropy_parallel_compatible(self.node) and axis == logits.meta["val"].ndim - 1
+            ):
                 # for cross entropy, the input logits parallel axis can only be the last axis or None
                 return [None]
             else:
diff --git a/optimum/fx/parallelization/parallel_layers/loss.py b/optimum/fx/parallelization/parallel_layers/loss.py
index de3b005035..0a11e33c08 100644
--- a/optimum/fx/parallelization/parallel_layers/loss.py
+++ b/optimum/fx/parallelization/parallel_layers/loss.py
@@ -12,11 +12,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from functools import wraps
+from typing import Optional
+
 import torch
-import torch.nn as nn
 import torch.distributed as dist
-from typing import Optional
-from functools import wraps
+import torch.nn as nn
+
 from ..core import ParallelExecutionCtx
 
 
@@ -100,7 +102,7 @@ def backward(ctx, grad_output: torch.Tensor):
         return grad_input, None, None
 
 
-def sharded_cross_entropy(sharded_logits: torch.Tensor, target: torch.Tensor,process_group: dist.ProcessGroup):
+def sharded_cross_entropy(sharded_logits: torch.Tensor, target: torch.Tensor, process_group: dist.ProcessGroup):
     return _ShardedCrossEntropy.apply(sharded_logits, target, process_group)
 
 
@@ -127,15 +129,15 @@ def wrapper(
             reduce = True if reduce is None else reduce
 
             if size_average and reduce:
-                reduction = 'mean'
+                reduction = "mean"
             elif reduce:
-                reduction = 'sum'
+                reduction = "sum"
             else:
-                reduction = 'none'
+                reduction = "none"
 
-        if reduction == 'mean':
+        if reduction == "mean":
             return loss.mean()
-        elif reduction == 'sum':
+        elif reduction == "sum":
             return loss.sum()
         return loss
 
@@ -146,15 +148,16 @@ class VocabParallelCrossEntropyLoss(nn.Module):
     """
     Simple parallel cross entropy implementation which does not support weighted mode and label smoothing yet.
     """
-    def __init__(self, ctx: ParallelExecutionCtx, reduction: str = 'mean') -> None:
+
+    def __init__(self, ctx: ParallelExecutionCtx, reduction: str = "mean") -> None:
         super(VocabParallelCrossEntropyLoss, self).__init__()
         self.process_group = ctx.tp_group
         self.reduction = reduction
 
     def forward(self, sharded_logits: torch.Tensor, target: torch.Tensor):
         loss: torch.Tensor = _ShardedCrossEntropy.apply(sharded_logits, target, self.process_group)
-        if self.reduction == 'mean':
+        if self.reduction == "mean":
             return loss.mean()
-        elif self.reduction == 'sum':
+        elif self.reduction == "sum":
             return loss.sum()
         return loss
diff --git a/optimum/fx/parallelization/passes.py b/optimum/fx/parallelization/passes.py
index 97d18b3b0a..9015526328 100644
--- a/optimum/fx/parallelization/passes.py
+++ b/optimum/fx/parallelization/passes.py
@@ -21,6 +21,7 @@
 import torch.distributed as dist
 import torch.nn as nn
 from torch.fx import Graph, GraphModule, Node
+
 from .core import Config, ParallelExecutionCtx, ParameterMeta
 from .decomp import decompose_and_functionalize
 from .distributed import scatter
@@ -28,15 +29,15 @@
 from .parallel_layers import (
     ColumnParallelLinear,
     RowParallelLinear,
-    VocabParallelEmbedding,
     VocabParallelCrossEntropyLoss,
-    sharded_cross_entropy_wrapper_fn
+    VocabParallelEmbedding,
+    sharded_cross_entropy_wrapper_fn,
 )
 from .utils import (
+    is_cross_entropy,
     is_embedding,
     is_linear,
     is_shape_consumer,
-    is_cross_entropy,
     stable_topological_sort,
 )
 
@@ -282,7 +283,7 @@ def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Conf
             elif is_cross_entropy(node):
                 axis_before = ParallelAxisSolverPass.get_stored_field_info(node.args[0], "parallel_axis")
                 if axis_before is not None:
-                    self.place_marker_per_node(node, {'axis' : 'vocab'})
+                    self.place_marker_per_node(node, {"axis": "vocab"})
 
         return graph_module
 
@@ -383,7 +384,6 @@ def handle_cross_entropy(node: Node, ctx: ParallelExecutionCtx) -> None:
         else:
             node.target = sharded_cross_entropy_wrapper_fn(process_group=ctx.tp_group)
 
-
     @staticmethod
     def handle_hard_coded_axis_param(node: Node, ctx: ParallelExecutionCtx) -> None:
         def extract_shape_from_node(node: Node) -> List[Any]:
diff --git a/optimum/fx/parallelization/utils.py b/optimum/fx/parallelization/utils.py
index 83a926af6b..3074638737 100644
--- a/optimum/fx/parallelization/utils.py
+++ b/optimum/fx/parallelization/utils.py
@@ -96,9 +96,9 @@ def is_cross_entropy_parallel_compatible(node: Node) -> bool:
     For now `VocabParallelCrossEntropyLoss` does not support weighted mode, index ignoring and label smoothing.
     """
    if node.op == "call_function":
-        weight = node.kwargs.get('weight', None)
-        ignore_index = node.kwargs.get('ignore_index', -100)
-        label_smoothing = node.kwargs.get('label_smoothing', 0.0)
+        weight = node.kwargs.get("weight", None)
+        ignore_index = node.kwargs.get("ignore_index", -100)
+        label_smoothing = node.kwargs.get("label_smoothing", 0.0)
         if len(node.args) > 2 and weight is None:
             weight = node.args[2]
         if len(node.args) > 4 and ignore_index == -100: