[CodeStyle][Typos][I-15] Fix typo infered (part4) #70985

Merged · 2 commits · Jan 26, 2025
26 changes: 13 additions & 13 deletions paddle/phi/infermeta/spmd_rules/matmul.cc
@@ -27,36 +27,36 @@ using phi::distributed::auto_parallel::str_join;

////////////////// Utils Functions //////////////////

-TensorDistAttr GetMatmulInferedDistAttr(
+TensorDistAttr GetMatmulInferredDistAttr(
const TensorDistAttr& origin_dist_attr,
const std::vector<int64_t>& shape,
const std::string& tensor_axis,
const std::unordered_map<std::string, int64_t>& axis_to_dim_map,
bool trans_axis) {
TensorDistAttr dist_attr = CopyTensorDistAttrForOutput(origin_dist_attr);
-std::vector<int64_t> infered_dims_mapping;
-infered_dims_mapping.reserve(tensor_axis.size());
+std::vector<int64_t> inferred_dims_mapping;
+inferred_dims_mapping.reserve(tensor_axis.size());

for (size_t i = 0; i < tensor_axis.size(); ++i) {
if (shape.size() > i && shape[i] == 1) {
-infered_dims_mapping.push_back(-1);
+inferred_dims_mapping.push_back(-1);
} else {
auto itr = axis_to_dim_map.find(tensor_axis.substr(i, 1));
if (itr == axis_to_dim_map.end()) {
// infer the k axis as -1 in inferbackward.
-infered_dims_mapping.push_back(-1);
+inferred_dims_mapping.push_back(-1);
} else {
-infered_dims_mapping.push_back(itr->second);
+inferred_dims_mapping.push_back(itr->second);
}
}
}

if (trans_axis) {
-std::iter_swap(infered_dims_mapping.end() - 2,
-infered_dims_mapping.end() - 1);
+std::iter_swap(inferred_dims_mapping.end() - 2,
+inferred_dims_mapping.end() - 1);
}

-dist_attr.set_dims_mapping(infered_dims_mapping);
+dist_attr.set_dims_mapping(inferred_dims_mapping);
return dist_attr;
}

@@ -199,9 +199,9 @@ SpmdInfo MatmulInferSpmd(const DistMetaTensor& x,
if (trans_y) {
std::iter_swap(y_shape.end() - 2, y_shape.end() - 1);
}
-TensorDistAttr x_dist_attr_dst = GetMatmulInferedDistAttr(
+TensorDistAttr x_dist_attr_dst = GetMatmulInferredDistAttr(
x_dist_attr_src, x_shape, x_axes, axis_to_dim_map, trans_x);
-TensorDistAttr y_dist_attr_dst = GetMatmulInferedDistAttr(
+TensorDistAttr y_dist_attr_dst = GetMatmulInferredDistAttr(
y_dist_attr_src, y_shape, y_axes, axis_to_dim_map, trans_y);

// Step2.3: Handle Partial
@@ -256,9 +256,9 @@ SpmdInfo MatmulInferSpmdReverse(const DistMetaTensor& x,
auto axis_to_dim_map =
ShardingMergeForTensors({{out_axes, out_dims_mapping}}, false);

-TensorDistAttr x_dist_attr_dst = GetMatmulInferedDistAttr(
+TensorDistAttr x_dist_attr_dst = GetMatmulInferredDistAttr(
x.dist_attr(), x_shape, x_axes, axis_to_dim_map, trans_x);
-TensorDistAttr y_dist_attr_dst = GetMatmulInferedDistAttr(
+TensorDistAttr y_dist_attr_dst = GetMatmulInferredDistAttr(
y.dist_attr(), y_shape, y_axes, axis_to_dim_map, trans_y);

// step3: Handle Partial
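
Note on the renamed helper: GetMatmulInferredDistAttr walks the einsum-style axes of one matmul operand and maps each axis to the mesh dimension it is sharded on, keeping size-1 dimensions and axes absent from the axis-to-dim map replicated (-1), and swapping the last two entries when the operand is transposed. A minimal Python sketch of that dims-mapping logic, for illustration only (the function and variable names below are not the real API):

    def get_matmul_inferred_dims_mapping(shape, tensor_axes, axis_to_dim_map, trans_axis):
        """Sketch of the dims-mapping inference in GetMatmulInferredDistAttr."""
        inferred = []
        for i, axis in enumerate(tensor_axes):
            if i < len(shape) and shape[i] == 1:
                # A size-1 dimension cannot be sharded, so keep it replicated.
                inferred.append(-1)
            else:
                # Axes missing from the map (e.g. the contracted k axis when
                # inferring backward) also fall back to replicated (-1).
                inferred.append(axis_to_dim_map.get(axis, -1))
        if trans_axis:
            # A transposed operand swaps its last two logical axes.
            inferred[-2], inferred[-1] = inferred[-1], inferred[-2]
        return inferred

    # x with axes "mk" sharded as {"m": 0}, y with axes "kn" sharded as {"n": 1}:
    print(get_matmul_inferred_dims_mapping([8, 4], "mk", {"m": 0}, False))  # [0, -1]
    print(get_matmul_inferred_dims_mapping([4, 6], "kn", {"n": 1}, False))  # [-1, 1]
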
2 changes: 1 addition & 1 deletion paddle/phi/infermeta/spmd_rules/reshape.cc
@@ -60,7 +60,7 @@ std::vector<int64_t> InferTargetShape(const std::vector<int64_t>& shape,
0,
common::errors::InvalidArgument(
"The total element number of the src tensor (%lld) is not "
"divisible by the infered size (%lld) of the -1 dimension.",
"divisible by the inferred size (%lld) of the -1 dimension.",
len,
infer_size));
new_shape[infer_idx] = infer_size;
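
For context, the message above is produced while resolving a single -1 entry in the reshape target shape: the source tensor's element count must be evenly divisible by the product of the known target dimensions. A hedged Python sketch of that check (simplified names, not the actual InferTargetShape signature):

    from math import prod

    def infer_target_shape(src_numel, target_shape):
        """Resolve one -1 placeholder in target_shape against the source element count."""
        infer_idx = target_shape.index(-1)
        known = prod(d for i, d in enumerate(target_shape) if i != infer_idx)
        if src_numel % known != 0:
            raise ValueError(
                f"The total element number of the src tensor ({src_numel}) is not "
                f"divisible by the product of the known target dimensions ({known})."
            )
        resolved = list(target_shape)
        resolved[infer_idx] = src_numel // known
        return resolved

    print(infer_target_shape(24, [2, -1, 3]))  # [2, 4, 3]
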
44 changes: 22 additions & 22 deletions python/paddle/distributed/auto_parallel/static/operators/common.py
@@ -717,8 +717,8 @@ def merge_forward_backward_dims_mapping(fw_results, bw_results):
flatten_bw_outputs = paddle.utils.flatten(bw_results[1])
ninputs = len(flatten_fw_inputs)
noutputs = len(flatten_fw_outputs)
-infered_input_dims_mappings = []
-infered_output_dims_mappings = []
+inferred_input_dims_mappings = []
+inferred_output_dims_mappings = []

for i in range(ninputs):
compatible_dims_mapping = compute_compatible_dims_mapping(
@@ -727,7 +727,7 @@ def merge_forward_backward_dims_mapping(fw_results, bw_results):
flatten_bw_inputs[i].dims_mapping,
]
)
-infered_input_dims_mappings.append(compatible_dims_mapping)
+inferred_input_dims_mappings.append(compatible_dims_mapping)

for i in range(noutputs):
compatible_dims_mapping = compute_compatible_dims_mapping(
@@ -736,60 +736,60 @@ def merge_forward_backward_dims_mapping(fw_results, bw_results):
flatten_bw_outputs[i].dims_mapping,
]
)
-infered_output_dims_mappings.append(compatible_dims_mapping)
-return infered_input_dims_mappings, infered_output_dims_mappings
+inferred_output_dims_mappings.append(compatible_dims_mapping)
+return inferred_input_dims_mappings, inferred_output_dims_mappings


def update_op_dims_mapping(
dist_op, input_arg_names, output_arg_names, fw_results, bw_results
):
(
-infered_input_dims_mappings,
-infered_output_dims_mappings,
+inferred_input_dims_mappings,
+inferred_output_dims_mappings,
) = merge_forward_backward_dims_mapping(fw_results, bw_results)

op_dist_attr = dist_op.dist_attr
changed = False
-if len(input_arg_names) != len(infered_input_dims_mappings):
+if len(input_arg_names) != len(inferred_input_dims_mappings):
warnings.warn(
f"dims mapping is NOT Match, infered [{len(infered_input_dims_mappings)}], original: [{len(input_arg_names)}]; dist op: [{dist_op}]"
f"dims mapping is NOT Match, inferred [{len(inferred_input_dims_mappings)}], original: [{len(input_arg_names)}]; dist op: [{dist_op}]"
)
-if len(output_arg_names) != len(infered_output_dims_mappings):
+if len(output_arg_names) != len(inferred_output_dims_mappings):
warnings.warn(
f"dims mapping is NOT Match, infered [{len(infered_output_dims_mappings)}], original: [{len(output_arg_names)}]; dist op: [{dist_op}]"
f"dims mapping is NOT Match, inferred [{len(inferred_output_dims_mappings)}], original: [{len(output_arg_names)}]; dist op: [{dist_op}]"
)

for i in range(len(input_arg_names)):
original_dims_mapping = op_dist_attr.get_input_dims_mapping(
input_arg_names[i]
)
-infered_dims_mapping = infered_input_dims_mappings[i]
-if (infered_dims_mapping is not None) and (
-original_dims_mapping != infered_dims_mapping
+inferred_dims_mapping = inferred_input_dims_mappings[i]
+if (inferred_dims_mapping is not None) and (
+original_dims_mapping != inferred_dims_mapping
):
_logger.debug(
f"Changed: Op [{dist_op.serial_op.type}], name [{input_arg_names[i]}], Original [{original_dims_mapping}], Infered [{infered_dims_mapping}]"
f"Changed: Op [{dist_op.serial_op.type}], name [{input_arg_names[i]}], Original [{original_dims_mapping}], Inferred [{inferred_dims_mapping}]"
)
changed = True
op_dist_attr.set_input_dims_mapping(
-input_arg_names[i], infered_dims_mapping
+input_arg_names[i], inferred_dims_mapping
)
# TODO support partial for inputs

for i in range(len(output_arg_names)):
original_dims_mapping = op_dist_attr.get_output_dims_mapping(
output_arg_names[i]
)
-infered_dims_mapping = infered_output_dims_mappings[i]
-if (infered_dims_mapping is not None) and (
-original_dims_mapping != infered_dims_mapping
+inferred_dims_mapping = inferred_output_dims_mappings[i]
+if (inferred_dims_mapping is not None) and (
+original_dims_mapping != inferred_dims_mapping
):
_logger.debug(
f"Changed: Op [{dist_op.serial_op.type}], name [{output_arg_names[i]}], Original [{original_dims_mapping}], Infered [{infered_dims_mapping}]"
f"Changed: Op [{dist_op.serial_op.type}], name [{output_arg_names[i]}], Original [{original_dims_mapping}], Inferred [{inferred_dims_mapping}]"
)
changed = True
op_dist_attr.set_output_dims_mapping(
-output_arg_names[i], infered_dims_mapping
+output_arg_names[i], inferred_dims_mapping
)

# NOTE in partial stage-I, we infer partial for output in infer_forward only
@@ -802,7 +802,7 @@ def update_op_dims_mapping(
!= output_dist_attr._partial_dims()
):
# _logger.info(
# "Changed: Op [{}], tensor name [{}], Original partial on [{}], Infered partial on [{}]".format(
# "Changed: Op [{}], tensor name [{}], Original partial on [{}], Inferred partial on [{}]".format(
# dist_op.serial_op.type,
# output_arg_names[i],
# output_dist_attr._partial_dims(),
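
As background for the renames in this file: merge_forward_backward_dims_mapping reconciles, tensor by tensor, the dims mappings proposed by the forward and backward SPMD rules (-1 meaning replicated/unspecified), and update_op_dims_mapping then writes an inferred mapping back into the op's dist_attr only when it differs from the original. The sketch below is a simplified illustration of the merge step; compute_compatible_dims_mapping here is a stand-in for the real helper, and the plain list-of-lists arguments are an assumption made for brevity (the real function takes the flattened fw_results/bw_results specs):

    def compute_compatible_dims_mapping(candidates):
        """Stand-in: per dimension, prefer a concrete mesh axis over -1.

        Returns None when two candidates name different concrete mesh axes,
        mirroring the "no compatible mapping" case the caller must handle.
        """
        merged = []
        for dims in zip(*candidates):
            concrete = {d for d in dims if d != -1}
            if len(concrete) > 1:
                return None  # conflicting shardings, cannot merge
            merged.append(concrete.pop() if concrete else -1)
        return merged

    def merge_forward_backward_dims_mapping(fw_mappings, bw_mappings):
        """Sketch: merge the per-tensor mappings inferred by the fw and bw rules."""
        return [
            compute_compatible_dims_mapping([fw, bw])
            for fw, bw in zip(fw_mappings, bw_mappings)
        ]

    # Forward shards dim 0 on mesh axis 0; backward leaves it unspecified:
    print(merge_forward_backward_dims_mapping([[0, -1]], [[-1, -1]]))  # [[0, -1]]
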
@@ -72,10 +72,10 @@ def update_dims_mapping(dist_op):
if changed:
(
_,
-infered_output_dims_mappings,
+inferred_output_dims_mappings,
) = merge_forward_backward_dims_mapping(fw_results, bw_results)
dist_op.dist_attr.set_output_dims_mapping(
-mask_name, infered_output_dims_mappings[0]
+mask_name, inferred_output_dims_mappings[0]
)

return changed
@@ -73,11 +73,11 @@ def update_dims_mapping(dist_op):
)

# step4: update xshape
-infered_input_dims_mappings, _ = merge_forward_backward_dims_mapping(
+inferred_input_dims_mappings, _ = merge_forward_backward_dims_mapping(
fw_results, bw_results
)
dist_op.dist_attr.set_output_dims_mapping(
-xshape_name, [-1] + infered_input_dims_mappings[0]
+xshape_name, [-1] + inferred_input_dims_mappings[0]
)

return changed
@@ -72,11 +72,11 @@ def update_dims_mapping(dist_op):
)

# step4: update xshape
-infered_input_dims_mappings, _ = merge_forward_backward_dims_mapping(
+inferred_input_dims_mappings, _ = merge_forward_backward_dims_mapping(
fw_results, bw_results
)
dist_op.dist_attr.set_output_dims_mapping(
-xshape_name, [-1] + infered_input_dims_mappings[0]
+xshape_name, [-1] + inferred_input_dims_mappings[0]
)

return changed
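
In both xshape updates above, the inferred input mapping gets a -1 prepended. The assumption reflected here is that the auxiliary XShape output carries one extra leading dimension that is never sharded, so it is marked replicated; a tiny illustration:

    inferred_input_dims_mappings = [[0, -1]]  # mapping inferred for the op's input
    xshape_dims_mapping = [-1] + inferred_input_dims_mappings[0]
    print(xshape_dims_mapping)  # [-1, 0, -1]: extra leading XShape dim stays replicated
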
@@ -46,7 +46,7 @@ class Partitioner:
Given a serial program which has been auto completed with shard annotation, the Partitioner
convert the serial program into a "distributed" program. The Partitioner will modify the serial
program in following two ways, which is also the major difference between serial and distributed program:
-1. partition op: replace a serial op into its corresponding dist op infered from the shard annotation
+1. partition op: replace a serial op into its corresponding dist op inferred from the shard annotation
2. partition var: if a var is sharded, modify the shape of var according to its shard annotation

Partitioner is supposed to be call by the auto parallel framework, and not supposed to be directly called by user.
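
The docstring fixed above names the two partitioning steps. To make step 2 (partition var) concrete: a sharded variable's per-rank shape divides each global dimension by the size of the mesh axis given in its dims mapping, while -1 keeps the dimension replicated. The sketch below illustrates that rule under the assumption of evenly divisible shapes; it is not the Partitioner API:

    def local_shape(global_shape, dims_mapping, mesh_shape):
        """Per-rank shape implied by a dims mapping (assumes even divisibility)."""
        shape = []
        for size, mesh_axis in zip(global_shape, dims_mapping):
            if mesh_axis == -1:
                shape.append(size)  # replicated dimension keeps its global size
            else:
                shape.append(size // mesh_shape[mesh_axis])  # split across that mesh axis
        return shape

    # A [1024, 512] weight sharded on axis 0 of a 4 x 2 mesh becomes [256, 512] per rank.
    print(local_shape([1024, 512], [0, -1], [4, 2]))  # [256, 512]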