Skip to content

Commit

Permalink
rendering fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
mikaylagawarecki committed Oct 30, 2024
1 parent ca34438 commit d83f14b
Showing 1 changed file with 27 additions and 26 deletions.
53 changes: 27 additions & 26 deletions intermediate_source/transformer_building_blocks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""
Dismantling down the ``nn.Transformer`` modules for gains and profits
======================================================================
Dismantling the ``nn.Transformer`` modules for gains and profits
=================================================================
**Author:** `Mikayla Gawarecki <https://github.com/mikaylagawarecki>`_
The ``torch.nn`` module currently provides various ``Transformer``-related layers.
Expand All @@ -16,6 +16,7 @@
1. People want to add slight customizations to their transformer layers
2. Writing these layers and customizations is not hard
Supporting all transformer variants via a small number of out of the box layers would
yield too many keyword arguments. This tutorial will describe how to build your
own performant transformer layers following our recommended best practices.
Expand All @@ -41,7 +42,7 @@
If you are only interested in performant attention score modifications, please
head to the `FlexAttention blog <https://pytorch.org/blog/flexattention/>`_ that
contains a `gym of masks <https://github.com/pytorch-labs/attention-gym>`_ .
contains a `gym of masks <https://github.com/pytorch-labs/attention-gym>`_.
If you are wondering about what building blocks the ``torch`` library provides
for writing your own transformer layers and best practices, you are in the
Expand All @@ -60,7 +61,7 @@
sequence lengths. They eliminate the need for the bug-prone practices of explicit
padding and masking (think ``key_padding_mask`` in ``nn.MultiHeadAttention``).
*`scaled_dot_product_attention <https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html>`_
* `scaled_dot_product_attention <https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html>`_
``scaled_dot_product_attention`` is a primitive for
:math:`\text{softmax}(\frac{QK^T}{\sqrt{E}} + B)V` that dispatches into either fused
Expand Down Expand Up @@ -101,6 +102,7 @@
2. Layer ordering (where to apply norms, where to apply positional encoding etc.)
3. Modifications to attention score (``ALiBi``, Relative Positional Bias etc.)
In a pre-compiler world, one might write their custom transformer and observe
that it works but is slow. Then, one might write a custom fused kernel for
the specific series of ops. In a compiler world, one can do the former, compile
Expand All @@ -118,24 +120,24 @@
# The improvements are threefold:
#
# * User Experience
# Recall that ``nn.MultiheadAttention`` requires ``query```, ``key`` and
# ``value`` to be dense ``torch.Tensor``s. It also provides a
# ``key_padding_mask`` that is used to mask out padding tokens in the ``key``
# that arise due to different sequence lengths within a batch. Since there is
# no ``query_padding_mask`` in ``nn.MHA``, users have to take care to mask/slice
# the outputs appropriately to account for query sequence lengths. Nested tensor
# cleanly removes the need for this sort of error-prone padding masks.
# Recall that ``nn.MultiheadAttention`` requires ``query``, ``key`` and
# ``value`` to be dense ``torch.Tensors``. It also provides a
# ``key_padding_mask`` that is used to mask out padding tokens in the ``key``
# that arise due to different sequence lengths within a batch. Since there is
# no ``query_padding_mask`` in ``nn.MHA``, users have to take care to mask/slice
# the outputs appropriately to account for query sequence lengths. Nested tensor
# cleanly removes the need for this sort of error-prone padding masks.
#
# * Memory
# Instead of materializing a dense ``[B, S, D]`` tensor with a ``[B, S]``
# padding mask (where ``B`` is batch size, ``S`` is max sequence length in the
# batch and ``D`` is embedding size), nested tensors allow you to cleanly
# represent the batch of varying sequence lengths. As a result, the input and
# intermediate activations will use less memory.
# Instead of materializing a dense ``[B, S, D]`` tensor with a ``[B, S]``
# padding mask (where ``B`` is batch size, ``S`` is max sequence length in the
# batch and ``D`` is embedding size), nested tensors allow you to cleanly
# represent the batch of varying sequence lengths. As a result, the input and
# intermediate activations will use less memory.
#
# * Performance
# Since padding is not materialized and unnecessary computation on padding is
# skipped, performance and memory usage improve.
# Since padding is not materialized and unnecessary computation on padding is
# skipped, performance and memory usage improve.
#
# We'll demonstrate the above by building off the ``MultiheadAttention`` layer in the
# `Nested Tensor tutorial <https://pytorch.org/tutorials/prototype/nestedtensor.html>`_
Expand Down Expand Up @@ -319,8 +321,8 @@ def benchmark(func, *args, **kwargs):

##############################################################################
# We will now demonstrate the performance improvements of using nested tensors
# in the ``MultiheadAttention`` layer for self attention. We compare this against
# the traditional ``nn.MultiheadAttention`` with padding and masking.
# in the ``MultiheadAttention`` layer + compile for self attention. We compare this against
# the traditional ``nn.MultiheadAttention`` + compile with padding and masking.

N, E_q, E_k, E_v, E_total = 512, 512, 512, 512, 512
E_out = E_q
Expand Down Expand Up @@ -392,8 +394,9 @@ def benchmark(func, *args, **kwargs):

######################################################################################
# We can also see the same for backward pass
# padding-specific step: remove output projection bias from padded entries for fair comparison

for i, entry_length in enumerate(sentence_lengths):
# padding-specific step: remove output projection bias from padded entries for fair comparison
padded_result[i, entry_length:, :] = 0.0

_, padded_bw_time, padded_bw_peak_mem = benchmark(lambda : padded_result.sum().backward())
Expand All @@ -417,7 +420,7 @@ def benchmark(func, *args, **kwargs):
# this is fairly straightforward using the ``MultiheadAttention`` layer above and
# gives equivalent results to an ``nn.TransformerEncoderLayer`` with
# ``is_causal=True``.

#
# We demonstrate examples of implementing the rest of the nn layers
# `here <https://github.com/mikaylagawarecki/temp>`_ but omit that from this
# tutorial for brevity.
Expand All @@ -438,9 +441,7 @@ def benchmark(func, *args, **kwargs):
# * Packed Projection
# * Cross Attention
# * Fully masked rows no longer cause ``NaN``s
# * [TODO] Modifying attention score: Relative Positional Embedding with NJT
# * [TODO] KV-Caching with NJT
# * [TODO] Grouped Query Attention with NJT
# * Modifying attention score: ALiBi with FlexAttention and NJT

###############################################################################
# Packed Projection
Expand Down Expand Up @@ -566,7 +567,7 @@ def forward(self, x):
# ---------------
# Cross attention is a form of attention where the query and key/value tensors
# are from different sequences.

#
# One example of this is in ``nn.TransformerDecoderLayer`` where the query comes
# from the decoder and the key/value come from the encoder.
#
Expand Down

0 comments on commit d83f14b

Please sign in to comment.