One of the main goals for torchtitan was to provide a version of distributed LLM training that was not only high performance but also built on native PyTorch techniques and readable code. The challenge is composing so many individual library components (FSDP, TP, PP, Float8, Compile, DCP, ..., just to name a few) without having to make too many changes to the model guts in the process. A lot of the work happens behind the scenes: designing individual components to make fewer assumptions, use common abstractions (e.g. DTensor), and generally "get along". But we also found a few tweaks to the model code invaluable, and we want to share those changes and the rationale behind them.
When applying Pipeline Parallelism, you will have to construct nn.Module objects representing the portion of the model that runs on a given pipeline stage. Whether you plan to manually edit your model code, or use techniques like tracing to extract model chunks, a few changes to the original model code can go a long way to making this process easier.
Most likely, you can write your model in such a way that the top-level nn.Module owns a sequence of child modules that it calls during forward, delegating most of the complexity to the child module forwards. If you can reduce your top level forward to mostly a for-loop over child module calls, then you'll simplify the pipeline-partitioning task to choosing the set of submodules to keep per stage. If you have non-trivial logic in the top-level forward, you'll have to find a way to patch that logic back onto the resulting pipeline stage model, which can be annoying.
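As a minimal sketch, a model whose top-level forward is just a loop over its children can be partitioned by slicing its list of submodules. The names Block, ToyModel and split_into_stages are hypothetical and not torchtitan code:

```python
import torch
import torch.nn as nn


class Block(nn.Module):
    """Hypothetical stand-in for a transformer block; the complexity lives here."""

    def __init__(self, dim: int):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        return h + torch.relu(self.proj(h))


class ToyModel(nn.Module):
    def __init__(self, dim: int = 16, n_layers: int = 4):
        super().__init__()
        self.layers = nn.ModuleList(Block(dim) for _ in range(n_layers))

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        # The top-level forward is only a loop over children, so any
        # contiguous slice of self.layers behaves like a pipeline stage.
        for layer in self.layers:
            h = layer(h)
        return h


def split_into_stages(model: ToyModel, num_stages: int) -> list[nn.Sequential]:
    """Hypothetical helper: partition by choosing which children each stage keeps."""
    layers = list(model.layers)
    assert len(layers) % num_stages == 0, "sketch assumes an even split"
    per_stage = len(layers) // num_stages
    return [
        nn.Sequential(*layers[i * per_stage:(i + 1) * per_stage])
        for i in range(num_stages)
    ]
```

Chaining the stages reproduces the full forward. Real stage construction also has to handle uneven splits and cross-stage communication, which this sketch leaves out.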
Example (PR #321): We used to slice the freqs_cis buffer by seq_len in the top-level forward, pass that into child modules, and expect that inside the child modules the seq_len would match up with the size of other local tensors. But when we consider PP splitting, we don't know whether TP was applied, so this could create a mismatch. It's just as easy to perform the freqs_cis slicing inside the child submodule, using the runtime-accurate local seq_len, and this sidesteps the issue at PP slicing time.
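Here is a sketch of the slicing-inside-the-child pattern, assuming a RoPE-style complex freqs_cis table. AttentionStub, ToyModel and precompute_freqs_cis are hypothetical names, and the attention itself is omitted so that only the rotary step remains. The top-level forward passes the whole buffer down, and the child slices it with the seq_len of the tensor it actually receives:

```python
import torch
import torch.nn as nn


def precompute_freqs_cis(head_dim: int, max_seq_len: int, theta: float = 10000.0) -> torch.Tensor:
    # RoPE-style table of unit complex numbers, shape (max_seq_len, head_dim // 2).
    freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
    t = torch.arange(max_seq_len, dtype=torch.float32)
    return torch.polar(torch.ones(max_seq_len, head_dim // 2), torch.outer(t, freqs))


class AttentionStub(nn.Module):
    """Hypothetical child module that slices freqs_cis with its local seq_len."""

    def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
        bsz, seqlen, head_dim = x.shape
        # Slice using the seq_len of the tensor this module actually sees,
        # rather than a seq_len computed in the top-level forward.
        freqs = freqs_cis[:seqlen]
        x_c = torch.view_as_complex(x.float().reshape(bsz, seqlen, head_dim // 2, 2))
        return torch.view_as_real(x_c * freqs).flatten(2).type_as(x)


class ToyModel(nn.Module):
    def __init__(self, head_dim: int = 8, max_seq_len: int = 16):
        super().__init__()
        self.register_buffer("freqs_cis", precompute_freqs_cis(head_dim, max_seq_len))
        self.attn = AttentionStub()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # No seq_len-dependent slicing here: the whole buffer is passed down.
        return self.attn(x, self.freqs_cis)
```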
Example (PR #322): We decided to reuse the top-level model object on every PP stage: we just delete the layers we don't want and make sure that the top-level forward still does the right thing. This means we don't have to write a separate runtime pp_forward that glues together child modules per stage. The first change was using a ModuleDict instead of a ModuleList to store layers. This preserves layer Fully Qualified Names (FQNs) even when some layers are deleted: e.g. layers.1 stays layers.1 even if you remove layers.0, which isn't true for a list. This matters for checkpoint save/load, because Distributed Checkpointing (DCP) uses FQNs as globally unique IDs for sharding metadata, so preserving them is a requirement for using it. The second change was making the input and output layers optional: if the layer exists, we run it; otherwise we pass the input through unchanged. With these two changes, we can just (meta-)initialize the whole model, delete the unused parts per stage, and then materialize the remaining part on GPU before loading a checkpoint.
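The two changes can be sketched on a toy model. ToyTransformer, build_stage and modulelist_fqns_after_delete are hypothetical names, and materialization targets CPU instead of GPU so the sketch runs anywhere:

```python
import torch
import torch.nn as nn


class ToyTransformer(nn.Module):
    def __init__(self, vocab: int = 32, dim: int = 8, n_layers: int = 4):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab, dim)
        # ModuleDict keyed by layer index keeps FQNs such as "layers.2.weight"
        # stable even after other layers are deleted.
        self.layers = nn.ModuleDict({str(i): nn.Linear(dim, dim) for i in range(n_layers)})
        self.output = nn.Linear(dim, vocab)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        # Input and output layers are optional: run them if present, else pass through.
        if self.tok_embeddings is not None:
            h = self.tok_embeddings(h)
        for layer in self.layers.values():
            h = layer(h)
        if self.output is not None:
            h = self.output(h)
        return h


def build_stage(stage_idx: int, num_stages: int, n_layers: int = 4) -> ToyTransformer:
    """Hypothetical helper: meta-init the whole model, then drop what this stage doesn't own."""
    with torch.device("meta"):
        model = ToyTransformer(n_layers=n_layers)
    per_stage = n_layers // num_stages
    keep = range(stage_idx * per_stage, (stage_idx + 1) * per_stage)
    for name in list(model.layers.keys()):
        if int(name) not in keep:
            del model.layers[name]
    if stage_idx != 0:
        model.tok_embeddings = None
    if stage_idx != num_stages - 1:
        model.output = None
    # Allocate real (uninitialized) storage only for the surviving parameters;
    # a checkpoint load is expected to fill them in afterwards.
    model.to_empty(device="cpu")
    return model


def modulelist_fqns_after_delete() -> list[str]:
    # Contrast: ModuleList renumbers its children after a deletion, so the
    # former "1" silently becomes "0".
    layers = nn.ModuleList(nn.Linear(2, 2) for _ in range(3))
    del layers[0]
    return [name for name, _ in layers.named_children()]
```

For a two-stage split, stage 1 ends up with state_dict keys layers.2.*, layers.3.* and output.*, the same FQNs they have in the full model, while modulelist_fqns_after_delete() returns ["0", "1"] because ModuleList renumbers its children.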
Initializing the pipeline-parallel model is challenging because we assume the model could be so large that it doesn't fit on a local GPU (or possibly even on CPU), and we also want to use the (bitwise) same initialization as we use for 1D or 2D parallel models, to ease debugging and comparisons between runs. It's not easy to rewrite the original model's init_weights function to tolerate initializing only some layers while also serializing initialization operations globally for a consistent RNG order.
For now, we sidestep all these problems with a simple but brutal solution: initialize the whole model on some CPU instance, save a checkpoint file (a "seed checkpoint"), and then lean on Distributed Checkpointing's load functionality to initialize the FQNs that are present on a given PP stage after stage creation. As future work, we are considering adding a more elaborate initialization scheme to torch.pipelining.
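A single-process sketch of the seed-checkpoint flow follows. It uses torch.save and torch.load on an in-memory buffer as a stand-in for DCP, and make_model, create_seed_checkpoint and init_stage_from_seed are hypothetical helpers. The key property is that a stage loads exactly the FQNs it still owns, so its values are bitwise identical to a full-model initialization with the same seed:

```python
import io

import torch
import torch.nn as nn


def make_model(n_layers: int = 4, dim: int = 8) -> nn.Module:
    # Hypothetical stand-in model; ModuleDict keeps FQNs stable after deletion.
    model = nn.Module()
    model.layers = nn.ModuleDict({str(i): nn.Linear(dim, dim) for i in range(n_layers)})
    return model


def create_seed_checkpoint(seed: int = 0) -> bytes:
    # Step 1: initialize the whole model once on CPU with a fixed seed, so the
    # RNG order (and therefore every value) matches a non-PP run.
    torch.manual_seed(seed)
    full = make_model()
    buf = io.BytesIO()
    torch.save(full.state_dict(), buf)
    return buf.getvalue()


def init_stage_from_seed(seed_ckpt: bytes, keep: list[str]) -> nn.Module:
    # Step 2: meta-init the stage, delete the layers it doesn't own,
    # materialize the rest, then load only the FQNs that remain.
    with torch.device("meta"):
        stage = make_model()
    for name in list(stage.layers.keys()):
        if name not in keep:
            del stage.layers[name]
    stage.to_empty(device="cpu")
    full_sd = torch.load(io.BytesIO(seed_ckpt), weights_only=True)
    stage.load_state_dict({k: full_sd[k] for k in stage.state_dict()})
    return stage
```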
One issue with seed checkpoints is that we rely on initializing every piece of model state from the checkpoint, which means the model can't have any non-persistent buffers; otherwise we would have to initialize those specially in train.py after pipeline splitting. freqs_cis was originally a non-persistent buffer, and we changed it to persistent so that it could be loaded from the seed checkpoint.
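The difference is visible directly in state_dict(). In this sketch, RopeHolder and restore_from_seed are hypothetical, and a small arange stands in for the real freqs_cis; only the persistent buffer survives the round trip from seed checkpoint to a meta-initialized, to_empty-materialized stage:

```python
import torch
import torch.nn as nn


class RopeHolder(nn.Module):
    """Hypothetical module holding a precomputed buffer such as freqs_cis."""

    def __init__(self, persistent: bool):
        super().__init__()
        self.register_buffer("freqs_cis", torch.arange(4.0), persistent=persistent)


def restore_from_seed(persistent: bool) -> RopeHolder:
    # The "seed checkpoint": non-persistent buffers are simply not in it.
    seed_sd = RopeHolder(persistent).state_dict()
    with torch.device("meta"):
        stage = RopeHolder(persistent)
    stage.to_empty(device="cpu")  # storage allocated but uninitialized
    stage.load_state_dict(seed_sd)
    return stage
```

With persistent=False, load_state_dict has nothing to restore, and the buffer keeps whatever uninitialized memory to_empty allocated.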
We intentionally upcast the final output tensor to fp32 inside the loss function rather than in Transformer.forward(), so that the forward and backward casts can be fused with the loss forward and backward, respectively, when we torch.compile() the loss function. This can improve both throughput and memory usage.
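A sketch of such a loss function, assuming logits of shape (batch, seq, vocab) and integer labels of shape (batch, seq); cross_entropy_loss and build_loss_fn are illustrative names. The .float() call sits inside the function that gets compiled, so the compiler sees the cast and the loss together:

```python
import torch
import torch.nn.functional as F


def cross_entropy_loss(pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # pred comes straight out of Transformer.forward(), possibly in bf16.
    # Upcasting here rather than in the model lets torch.compile fuse the
    # cast (and its backward) with the loss computation.
    return F.cross_entropy(pred.flatten(0, 1).float(), labels.flatten(0, 1))


def build_loss_fn(use_compile: bool = True):
    # Compilation is lazy: it happens on the first call of the returned function.
    return torch.compile(cross_entropy_loss) if use_compile else cross_entropy_loss
```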
Users should set the environment variable TORCH_NCCL_AVOID_RECORD_STREAMS=1 when using tensor parallelism (TP) to avoid unexpectedly high memory usage.
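One way to apply the setting from Python, as a sketch: the variable should be in the environment before the NCCL process group is created (we assume it is read when the process group is constructed). configure_tp_env and init_distributed_for_tp are hypothetical helpers:

```python
import os

import torch.distributed as dist


def configure_tp_env() -> None:
    # Must happen before the NCCL process group is created (assumption: the
    # flag is read at process-group construction, not per collective).
    os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"


def init_distributed_for_tp() -> None:
    # Hypothetical entry point, typically run under torchrun, which supplies
    # RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT in the environment.
    configure_tp_env()
    dist.init_process_group(backend="nccl")
```

Exporting the variable in the shell that launches torchrun works equally well.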
TP uses async collectives (i.e. with async_op=True), such as all-gather, reduce-scatter, and all-reduce, to overlap communication with compute. Under the hood, an async collective runs the NCCL communication kernel in a separate CUDA stream owned by the process group. Calling wait() on the returned work object has the current stream wait for the process group's stream, allowing the current stream to correctly use the result of the collective.
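The calling pattern looks like the sketch below. To keep it runnable on a CPU-only machine, it uses a single-rank gloo process group, which is an assumption for illustration only: with gloo, wait() blocks the CPU, whereas with NCCL it only makes the current CUDA stream wait, as described above. overlapped_all_reduce and run_single_process_demo are hypothetical names:

```python
import os
import tempfile

import torch
import torch.distributed as dist


def overlapped_all_reduce(x: torch.Tensor) -> torch.Tensor:
    # Launch the collective without waiting for it; returns a work handle.
    work = dist.all_reduce(x, op=dist.ReduceOp.SUM, async_op=True)
    # Independent compute issued here can overlap with the communication.
    bias = torch.full_like(x, 3.0)
    # wait() must come before x (the collective's output) is consumed.
    work.wait()
    return x + bias


def run_single_process_demo() -> torch.Tensor:
    # Single-rank gloo group on CPU, purely so the sketch runs anywhere.
    with tempfile.TemporaryDirectory() as tmp:
        dist.init_process_group(
            backend="gloo",
            init_method=f"file://{os.path.join(tmp, 'store')}",
            rank=0,
            world_size=1,
        )
        try:
            return overlapped_all_reduce(torch.arange(4, dtype=torch.float32))
        finally:
            dist.destroy_process_group()
```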
An async collective represents a producer-consumer pattern across streams: the collective tensors are produced in a compute stream (usually the default stream), and they are consumed in a communication stream (from the process group). Under such producer-consumer patterns across streams, we must ensure that the tensors are not freed before their usage in the consumer stream.
Tensor.record_stream is a legacy approach for ensuring this. The process group will call record_stream(comm_stream) on the collective input and output tensors after issuing the collective kernel in the process group's comm_stream. This records a CUDA event in comm_stream, and the CUDA caching allocator that manages CUDA tensor memory in PyTorch will query this recorded event upon future allocations. Only once the event has completed, meaning that the collective has finished running, can the tensor memory be freed and considered for future reuse. This couples the caching allocator's memory reuse with GPU kernel timing, which does not happen without record_stream. While the collective kernel runs on GPU, any allocations made from the CPU for future ops cannot reuse that memory, even if we know that those future ops must run after the current collective. This inability to reuse leads to unexpected memory stacking.
By setting TORCH_NCCL_AVOID_RECORD_STREAMS=1, the process group avoids calling record_stream on the collective tensors and instead uses a different approach: it simply stashes references to the collective tensors until the user calls wait() on the work object. Holding references ensures that the collective tensors will not be freed by the caching allocator. This can only lead to a memory regression if the user never calls wait(), because with record_stream the caching allocator would still eventually free the collective tensors once the collective finishes on the GPU. Since never calling wait() is neither common nor expected usage, we recommend setting this environment variable.