[Question/Bug] DP sharding parameters are inconsistent with others. #2563

Open · JacoCheung opened this issue Nov 18, 2024 · 4 comments

JacoCheung commented Nov 18, 2024

Description:
Hi torchrec team,
I'm using EmbeddingCollection with the sharding type constrained to DATA_PARALLEL. I should then be able to get the parameters and pass them to my optimizer. However, I ran into several problems along the way.

Looking forward to any input. Thanks!

TableBatchedEmbeddingSlice is not a leaf tensor

  1. .parameters() (or embedding_collection._dmp_wrapped_module.embeddings['0'].weight) returns a TableBatchedEmbeddingSlice, which is unexpectedly not a leaf tensor:
ebc = EmbeddingCollection(
    device=torch.device("meta"),
    tables=[
        EmbeddingConfig(
            name="product_table",
            embedding_dim=4,
            num_embeddings=4,
            feature_names=["product"],
            init_fn=init_fn,
        ),
    ]
)
...

model=DMP(ebc)
weight = model._dmp_wrapped_module.embeddings['product_table'].weight

# This is False
weight.is_leaf

Unfortunately, my app requires such a flag to perform some operations.

model.bfloat16() detaches the weights from the underlying storage

When I convert the whole model to a lower precision, say bf16, the underlying DP table storage is not affected, while the weight/param accessors are converted as expected. The weights and the storage then appear to be untied, so the optimizer has no effect on the original storage. A reproducible script is below:

import os
import torch
import torch.distributed as dist
import torchrec
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
from torchrec.modules.embedding_modules import EmbeddingCollection
from torchrec.modules.embedding_configs import EmbeddingConfig

from fbgemm_gpu.split_embedding_configs import SparseType

from torch.distributed.optim import (
    _apply_optimizer_in_backward as apply_optimizer_in_backward,
)
from torchrec.distributed.embedding import EmbeddingCollectionSharder

from torchrec.distributed.planner import EmbeddingShardingPlanner,  ParameterConstraints
from torchrec.distributed.types import (
    ShardingType,
)
from torchrec.optim.optimizers import in_backward_optimizer_filter
torch.manual_seed(1024)

os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29500"
dist.init_process_group(backend="nccl")

def init_fn(x, val=0.1):
    with torch.no_grad():
        dim0 = x.size(0)
        dim1 = x.size(1)
        natural_seq = torch.arange(0,dim0).cuda().to(x.dtype).unsqueeze(-1)
        x.copy_(natural_seq.expand(-1, dim1))
ebc = EmbeddingCollection(
    device=torch.device("meta"),
    tables=[
        EmbeddingConfig(
            name="product_table",
            embedding_dim=4,
            num_embeddings=4,
            feature_names=["product"],
            init_fn=init_fn,
        ),
    ]
)
sharding_types = [ShardingType.DATA_PARALLEL.value]
constraints = {"product_table": ParameterConstraints(sharding_types=sharding_types)}
planner = EmbeddingShardingPlanner(
    constraints=constraints,
)
sharders = [EmbeddingCollectionSharder(fused_params = {'output_dtype':SparseType.BF16 }) ]
plan = planner.collective_plan(ebc, sharders, pg = dist.GroupMember.WORLD)

apply_optimizer_in_backward(
    optimizer_class=torch.optim.SGD,
    params=ebc.parameters(),
    optimizer_kwargs={"lr": 0.02},
)

model = torchrec.distributed.DistributedModelParallel(ebc, sharders=sharders, device=torch.device("cuda"), plan = plan)
model = model.bfloat16()

# model.module._dmp_wrapped_module.embeddings['product'].weight
ref_optimizer = torch.optim.SGD(
    dict(in_backward_optimizer_filter(model.named_parameters())).values(),
    lr=0.02,
)
mb = KeyedJaggedTensor(
    keys = ["product"],
    values = torch.tensor([0, 1, 2]).cuda(),  # all three ids land on rank 0 since WORLD_SIZE=1
    lengths = torch.tensor([3], dtype=torch.int64).cuda(),
)
ret = model(mb)['product'].values()
ret.sum().backward()
ref_optimizer.step()
params = dict(model.named_parameters())
weight = model._dmp_wrapped_module.embeddings['product_table'].weight

print(f"params {params}")
print(f'weight { model.module._lookups[0].module._emb_modules[0]._emb_module.split_embedding_weights()}')

The params get updated while the underlying storage (split_embedding_weights) remains the same (and the next lookup still sees the old storage).

params:

[-0.0200, -0.0200, -0.0200, -0.0200],
[ 0.9805,  0.9805,  0.9805,  0.9805],
[ 1.9766,  1.9766,  1.9766,  1.9766],
[ 3.0000,  3.0000,  3.0000,  3.0000]

storage:

[0., 0., 0., 0.],
[1., 1., 1., 1.],
[2., 2., 2., 2.],
[3., 3., 3., 3.]

grad_fn is AsStridedBackward0 when a TableBatchedEmbeddingSlice is an operand

Besides, I find that next_functions starts with an AsStridedBackward0 ahead of AccumulateGrad when a TableBatchedEmbeddingSlice object is an operand, for example:

myself = params['embeddings.product_table.weight'] * torch.ones_like(params['embeddings.product_table.weight'])
# this is ((<AsStridedBackward0 object at 0x7ff51a936440>, 0), (None, 0))
myself.grad_fn.next_functions
# this one is ((<AccumulateGrad object at 0x7ff51a936440>, 0),)
myself.grad_fn.next_functions[0][0].next_functions

I would like to know why there is an AsStridedBackward0. One library I depend on requires reaching the AccumulateGrad node in a single hop.
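
For comparison, a plain as_strided view of a leaf tensor produces the same extra hop; a minimal sketch, independent of torchrec (base and view are just stand-ins for the batched weight buffer and its slice):

import torch

base = torch.randn(4, 4, requires_grad=True)   # leaf tensor
view = base.as_strided((4, 4), base.stride())  # non-leaf strided view of the leaf

out = view * torch.ones_like(view)
# ((<AsStridedBackward0 ...>, 0), (None, 0)) -- the extra hop before AccumulateGrad
print(out.grad_fn.next_functions)
# ((<AccumulateGrad ...>, 0),) -- the AccumulateGrad belongs to base, the leaf
print(out.grad_fn.next_functions[0][0].next_functions)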

PaulZhang12 (Contributor) commented:

@TroyGarden

TroyGarden (Contributor) commented:

Sorry for the late response. @JacoCheung, it looks like an invalid access to the weights. Could you please try using the weights from the state_dict?
I'll try to reproduce the issue locally; due to bandwidth limitations, likely next week.

TroyGarden (Contributor) commented Nov 24, 2024

@JacoCheung, regarding question 1 ("TableBatchedEmbeddingSlice is not a leaf tensor"):
I managed to reproduce the issue with world_size = 1.
As you can see from the following debugger output, the weight is a TableBatchedEmbeddingSlice, which represents a slice of a table-batched embedding (even though in this test case there is only one slice). You'll have to use _original_tensor to get the original weights:

weight = model._dmp_wrapped_module.embeddings['product_table'].weight
weight
Parameter containing:
Parameter(TableBatchedEmbeddingSlice([[ 0.3090,  0.2935, -0.2901,  0.4279],
                                      [ 0.3136,  0.2422, -0.0231, -0.0045],
                                      [-0.1398, -0.3822,  0.2852, -0.4772],
                                      [ 0.3793, -0.3837, -0.4460,  0.0480]],
                                     requires_grad=True))
weight.is_leaf
False
weight._original_tensor.is_leaf
True

If you want the leaf weight, you can try using the parameters from the state_dict

w1 = model._dmp_wrapped_module.state_dict()['embeddings.product_table.weight']
w1
ShardedTensor(ShardedTensorMetadata(shards_metadata=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[4, 4], placement=rank:0/cpu)], size=torch.Size([4, 4]), tensor_properties=TensorProperties(dtype=torch.float32, layout=torch.strided, requires_grad=False, memory_format=torch.contiguous_format, pin_memory=False)))
w1.is_leaf
True

or

w2 = model._dmp_wrapped_module.embeddings.state_dict()['product_table.weight']
w2
tensor([[ 0.3090,  0.2935, -0.2901,  0.4279],
        [ 0.3136,  0.2422, -0.0231, -0.0045],
        [-0.1398, -0.3822,  0.2852, -0.4772],
        [ 0.3793, -0.3837, -0.4460,  0.0480]])
w2.is_leaf
True

Hope that answered your question.

TroyGarden (Contributor) commented:

@JacoCheung, as for question 2, it's similar to question 1.

  1. You can also use the state_dict to access the weights.
  2. When you call model.bfloat16(), it copies all the parameters of the old model and casts them to bfloat16. However, the TableBatchedEmbeddingSlice is only a slice of the original weights, so its _original_tensor does not switch to the new model's bfloat16 weights; it still references the old model's float32 weights. As a result, the backward pass does not update the old model's storage.

My suggestion is to do the bfloat16 casting before calling DMP, or, more directly, to set the embedding config's dtype to bfloat16; a sketch is below.
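
A minimal sketch of that suggestion, reusing the imports from the reproducer above (it assumes DataType is importable from torchrec.modules.embedding_configs and has a BF16 member; check your torchrec version):

from torchrec.modules.embedding_configs import DataType, EmbeddingConfig

# Option A: store the table in bf16 from the start via the embedding config.
ebc = EmbeddingCollection(
    device=torch.device("meta"),
    tables=[
        EmbeddingConfig(
            name="product_table",
            embedding_dim=4,
            num_embeddings=4,
            feature_names=["product"],
            data_type=DataType.BF16,  # assumed field/member; verify against your version
        ),
    ],
)

# Option B: cast before wrapping with DMP so the slices and the storage stay tied.
# ebc = ebc.bfloat16()
# model = torchrec.distributed.DistributedModelParallel(
#     ebc, sharders=sharders, device=torch.device("cuda"), plan=plan
# )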
