
At Scale AI, we use Machine Learning models in a wide range of applications to
empower our data labeling pipeline. We strive for speed and efficiency, and
always try to get the best out of them. Here, we will discuss some tricks we
discovered that drastically speed up inference over the standard PyTorch
Transformer implementation, in just a few lines of code.
Transformers have become ubiquitous. They were first introduced in Attention
is All You Need (Vaswani et al., 2017) and were quickly added to PyTorch. Their
popularity increased even more with the development of HuggingFace, which made
large pre-trained NLP models such as BERT (Devlin et al., 2018) widely
accessible, and created recipes to enable simple fine-tuning on a wide range of
tasks. They’ve been successfully applied to a wide variety of
sequence-to-sequence (Seq2Seq) tasks including machine translation, text
summarization, and even image captioning (an image is just a sequence of
pixels!). This popularity is completely warranted, because Transformers have
some significant upsides.
But Transformers also have their weaknesses. When generating sequences for
Seq2Seq tasks at inference time, Transformers are constrained because each
item in the output sequence can only be predicted one at a time. This,
combined with the quadratic complexity of attention, can make them slower than
their counterparts. (For training, this is not an issue thanks to teacher
forcing.)
Seq2Seq models typically create an internal high-level representation of the
input sequence and then decode (i.e. generate) the output sentence. Given the
high-level representation of the input sentence and the words that have
already been decoded, Seq2Seq models estimate the most likely words to
complete the sentence. This process is called auto-regression, and each step
of generating a new word (or token) is called a timestep.
When a Transformer is used as a Seq2Seq model, the input sequence is fed
through an Encoder, and the output sequence is then generated by a Decoder, as
illustrated in figures 1 and 2.
The Transformer class in PyTorch is generic, which is great because it gives
the ML researchers at Scale AI fine-grained control, but it also means the
class isn’t optimized for speed. Let’s take a deeper look.
First, it can be seen in Figure 1 that the encoder output can be computed
separately from the decoder. This means that the encoder outputs can be
computed once and re-used for every timestep thereafter. But PyTorch does NOT
cache this for you: the naive nn.Transformer call recomputes the encoder
output at every decoding timestep. To fix this, the Transformer Encoder and
Decoder should always be kept separate.
# THIS IS THE NAIVE WAY TO USE TRANSFORMERS
import torch
from torch import nn

# INITIALIZATION
transformer = nn.Transformer(
    d_model=hdim,
    nhead=nhead,
    num_encoder_layers=num_layers,
    num_decoder_layers=num_layers,
    dim_feedforward=dim_feedforward,
).to(device=device)
transformer.eval()

# INFERENCE LOOP
decoded_tokens = first_token
for i in range(len_output_to_decode):  # generate `len_output_to_decode` tokens
    mask_dec = generate_square_subsequent_mask(
        i + 1, device=first_token.device
    )  # create mask for autoregressive decoding
    decoded_embeddings = embedding(decoded_tokens)
    # the full nn.Transformer call re-runs the encoder at every timestep
    output = transformer(src, decoded_embeddings, tgt_mask=mask_dec)
    logits = to_vocab(output)  # projection to vocab size
    # keep most likely tokens
    top_indices = torch.argmax(logits, dim=-1)
    # we only care about the last token that was decoded
    top_indices_last_token = top_indices[-1:]
    # add most likely token to the already decoded tokens
    decoded_tokens = torch.cat(
        [decoded_tokens, top_indices_last_token], dim=0
    )
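The snippets in this post assume a generate_square_subsequent_mask helper
(PyTorch provides an equivalent on nn.Transformer), as well as pre-defined
hdim, nhead, num_layers, dim_feedforward, device, embedding, to_vocab, src and
first_token. A minimal version of the mask helper, consistent with the
signature used here, could look like this:

def generate_square_subsequent_mask(sz, device=None):
    # additive attention mask: -inf above the diagonal, 0 on and below it,
    # so that position i can only attend to positions <= i
    mask = torch.full((sz, sz), float("-inf"), device=device)
    return torch.triu(mask, diagonal=1)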
The code below is a much more efficient way to get the same results by
decoupling the encoder and the decoder. Note that the code corresponding to
the inference loop barely changes.
# INITIALIZATION
encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(
        d_model=hdim, nhead=nhead, dim_feedforward=dim_feedforward
    ),
    num_layers=num_layers,
).to(device=device)
encoder.eval()

decoder = nn.TransformerDecoder(
    nn.TransformerDecoderLayer(
        d_model=hdim, nhead=nhead, dim_feedforward=dim_feedforward
    ),
    num_layers=num_layers,
).to(device=device)
decoder.eval()

# INFERENCE LOOP
decoded_tokens = first_token
src_embeddings = encoder(src)  # computed once, outside the loop
for i in range(len_output_to_decode):
    mask_dec = generate_square_subsequent_mask(
        i + 1, device=first_token.device
    )  # create mask for autoregressive decoding
    decoded_embeddings = embedding(decoded_tokens)
    # the decoder re-uses the encoder output `src_embeddings`
    output = decoder(decoded_embeddings, src_embeddings, tgt_mask=mask_dec)
    logits = to_vocab(output)  # projection to vocab size
    # keep most likely tokens
    top_indices = torch.argmax(logits, dim=-1)
    # we only care about the last token that was decoded
    top_indices_last_token = top_indices[-1:]
    # add most likely token to the already decoded tokens
    decoded_tokens = torch.cat(
        [decoded_tokens, top_indices_last_token], dim=0
    )
The main remaining inefficiency builds on the previous point. It can be seen in
Figure 2 that the embedding of a decoded token only depends on the tokens that
were decoded before it. This is a direct consequence of the Transformer model
being autoregressive. Thus, it is unnecessary to recompute the embeddings of
the already decoded tokens at every timestep; instead, we can again cache them.
Each timestep then only consists of computing the attention for the newest
token’s embedding.
Figure 3: Decoder self-attention links when decoding tokens.
The boxes at the bottom represent the embeddings of the output tokens
before self-attention; the top boxes represent the embeddings of the output
tokens after self-attention. Using our trick (right side), most of the
embeddings are not recomputed as they are cached. The number of links to take
into account becomes linear instead of quadratic.
From a complexity perspective, generating the n-th output token without our
trick involves computing self-attention over the entire current output
(O(n²)) and computing encoder-decoder attention between the whole
input (whose size we will denote M) and the current output (O(Mn)).
Hence, the complexity of each timestep is O(Mn + n²). Given that we
want to decode N tokens, N timesteps are needed, and the final complexity is
O(MN² + N³).
Our trick accelerates each timestep. Only the parts of the self-attention and
encoder-decoder attention responsible for updating the last token are
computed. Figure 3 shows how this works for the self-attention. The new
complexity of the n-th timestep is O(M + n), so with N timesteps the
final complexity is sped up to O(MN + N²).
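Summing over the N timesteps makes the speed-up explicit:

\[
\sum_{n=1}^{N} O(Mn + n^2) = O(MN^2 + N^3)
\qquad \text{vs.} \qquad
\sum_{n=1}^{N} O(M + n) = O(MN + N^2).
\]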
The PyTorch Transformer decoder is not written with autoregressive decoding in
mind. However, by inheriting from the TransformerDecoder layer, we introduce a
CausalTransformerDecoder which uses a cache to implement the improvement
above. Our implementation differs from the PyTorch one by only a few lines.
Our new decoder works similarly to the original TransformerDecoder, except
that we now have to take into account the cache:
causal_decoder = CausalTransformerDecoder(
    CausalTransformerDecoderLayer(
        d_model=hdim,
        nhead=nhead,
        dim_feedforward=dim_feedforward,
    ),
    num_layers=6,
).to(device=device)
causal_decoder.eval()

decoded_tokens = first_token
src_embeddings = encoder(src)
cache = None
for i in range(len_output_to_decode):
    mask_dec = generate_square_subsequent_mask(
        i + 1, device=first_token.device
    )  # create mask for autoregressive decoding
    decoded_embeddings = embedding(decoded_tokens)
    # only change here: we add the cache as an extra parameter
    output, cache = causal_decoder(decoded_embeddings, src_embeddings, cache)
    logits = to_vocab(output)  # projection to vocab size
    # keep most likely tokens
    top_indices = torch.argmax(logits, dim=-1)
    # we only care about the last token that was decoded
    top_indices_last_token = top_indices[-1:]
    # add most likely token to the already decoded tokens
    decoded_tokens = torch.cat(
        [decoded_tokens, top_indices_last_token], dim=0
    )
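For reference, here is a condensed sketch of how such a decoder can be built by
subclassing the PyTorch modules. This is a simplified illustration of the idea
(post-norm layers, inference only, decoding started from a single first token,
no padding masks), not the exact code from our repo:

class CausalTransformerDecoderLayer(nn.TransformerDecoderLayer):
    # only the representation of the newest token is computed;
    # it still attends to the full decoded prefix and to the encoder output
    def forward(self, tgt, memory, **kwargs):
        tgt_last = tgt[-1:, :, :]
        # self-attention of the last token over the whole decoded prefix
        tmp = self.self_attn(tgt_last, tgt, tgt)[0]
        tgt_last = self.norm1(tgt_last + self.dropout1(tmp))
        # encoder-decoder attention over the pre-computed encoder output
        tmp = self.multihead_attn(tgt_last, memory, memory)[0]
        tgt_last = self.norm2(tgt_last + self.dropout2(tmp))
        # position-wise feed-forward
        tmp = self.linear2(self.dropout(self.activation(self.linear1(tgt_last))))
        return self.norm3(tgt_last + self.dropout3(tmp))

class CausalTransformerDecoder(nn.TransformerDecoder):
    # the cache stores, for every layer, the inputs it saw at previous timesteps
    def forward(self, tgt, memory, cache=None):
        output = tgt[-1:, :, :]  # embedding of the newest token only
        new_cache = []
        for i, layer in enumerate(self.layers):
            # prepend the cached prefix so the new token can attend to it
            layer_input = (
                output if cache is None
                else torch.cat([cache[i], output], dim=0)
            )
            new_cache.append(layer_input)
            output = layer(layer_input, memory)  # shape: (1, batch, hdim)
        return output, new_cache

Note that this sketch only returns the newest token’s representation at each
timestep, which is all the inference loop above needs in order to pick the
next token.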
We put our changes to the test to see how much faster we could get. We present
two different scenarios: translation and generation of long texts.
We compare three different implementations: the naive nn.Transformer usage,
the decoupled encoder and decoder, and the decoupled encoder combined with our
CausalTransformerDecoder. As a reminder, these are three different
implementations of the same model. When initialized with the same weights,
they return the same outputs.
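As a quick sanity check for the first two, nn.Transformer exposes its encoder
and decoder as attributes, so the decoupled pass can be compared against the
monolithic one directly (a minimal sketch reusing the variables defined in the
snippets above):

with torch.no_grad():
    full_output = transformer(src, decoded_embeddings, tgt_mask=mask_dec)
    memory = transformer.encoder(src)
    split_output = transformer.decoder(decoded_embeddings, memory, tgt_mask=mask_dec)
# in eval mode the two paths are numerically identical
assert torch.allclose(full_output, split_output)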
The first setting corresponds to translation. In this setting the input and
output sequences are generally short and of similar lengths.
The non-linear curves show that the attention mechanisms progressively
become the most compute-intensive parts of the model as the number of input
and output tokens increases.
Our causal implementation is up to 40% faster than the PyTorch
Encoder-Decoder implementation, and 150% faster than the PyTorch
nn.Transformer implementation for 500 input/output tokens.
We now ask the model to generate long sequences from a fixed size input. Such
a situation might arise when generating a story from an image or from an
initial prompt.
The results below were obtained with a fixed input size of 500 tokens.
Increasing the number of input tokens makes the models slower but doesn’t
change the overall trends observed.
Our causal model is twice as fast as the PyTorch encoder-decoder
implementation when the number of tokens to generate exceeds 1,000.
When decoding more than 500 tokens, the speed ratio between the causal model
and the other implementations grows roughly linearly with the output length.
This confirms the theoretical result that the overall decoding complexity is
reduced by a factor of N.
Finally, our CausalTransformerDecoder can also be used without any input
sentence (i.e. without an encoder), as is the case in some common generation
settings where the model is asked to complete a story or an article. More
information about this type of generation can be found in the GPT papers
(Alec Radford et al., 2018). The results we find for this case are similar to the ones above.
One might notice that caching the output of each layer is sub-optimal. Indeed,
the first stage of the attention layers consists of projecting the embeddings
into the key, value and query spaces. In both the PyTorch implementation and
the proposed implementation, the same embeddings get projected repeatedly.
Instead, the queries, keys and values could be cached directly. However, this
requires substantially more changes, which could become unstable with new
PyTorch upgrades. Moreover, the estimated gains are minor: less than 5% in our
experiments.
For standard NLP use cases, the HuggingFace repository already includes these
optimizations. Notably, it caches keys and values. It also comes with
different decoding flavors, such as beam search or nucleus sampling.
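For instance, a typical cached generation call looks roughly like the
following (a sketch; the exact arguments depend on the transformers version):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

inputs = tokenizer("Once upon a time", return_tensors="pt")
# use_cache=True (the default) caches keys and values across timesteps;
# num_beams switches greedy decoding to beam search
output_ids = model.generate(**inputs, max_length=100, num_beams=4, use_cache=True)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))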
The simple tricks proposed here take advantage of the fact that the PyTorch
implementation of the Transformer is too generic. The changes provide modest
speed improvements when generating a few hundred tokens, which become
significant boosts over the original PyTorch implementation when the output
length nears a thousand tokens. These gains are naturally proportional to the
number of output tokens to decode. Best of all, they can be implemented in
just a few lines using our repo.
Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan
N Gomez, Łukasz Kaiser, and Illia Polosukhin. Attention is all you need. In
Advances in Neural Information Processing Systems 30, pp. 5998–6008. Curran
Associates, Inc., 2017. URL:
https://proceedings.neurips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf
Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. BERT:
pre-training of deep bidirectional transformers for language understanding.
CoRR, abs/1810.04805, 2018. URL: http://arxiv.org/abs/1810.04805
Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn,
Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, et al. An image is worth
16x16 words: Transformers for image recognition at scale. arXiv preprint,
2020. URL: https://arxiv.org/pdf/2010.11929
Thomas Wolf, Lysandre Debut, Victor Sanh, Julien Chaumond, Clement Delangue,
Anthony Moi, Pierric Cistac, et al. Transformers: State-of-the-art natural
language processing. arXiv preprint, 2019. URL:
https://arxiv.org/pdf/1910.03771.pdf
Alec Radford, Karthik Narasimhan, Tim Salimans, and Ilya Sutskever. Improving
language understanding by generative pre-training. Preprint, 2018. URL:
https://www.cs.ubc.ca/~amuham01/LING530/papers/radford2018improving.pdf
1 - The vocabulary size used for this experiment was 30,000. It corresponds to
the vocabulary size chosen in the BERT original paper, which uses BPE
tokenization. We also ran the same experiments with a much smaller vocabulary
(128 tokens), to imitate a character-level setting which did not show
significant benefits.
2 - The hyperparameters used for the Transformer architecture are the ones of
the original paper (6 layers, 8 heads, 512 hidden dimensions, 2048
feed-forward hidden dimensions for both encoder/decoder). The results should
be similar with other configurations, provided that the encoder and decoder
have the same size.
3 - The displayed results correspond to a batch-size of 8 sequences, but we
made sure that a batch-size of 1 gives the same trend. A batch-size of 1 is
usually the most common in inference as requests are sent asynchronously.
However, in our specific setup the model takes several seconds to generate
sentences, so it is more natural to batch requests.
4 - As a safety check, we benchmarked the GPT-2 HuggingFace implementation
against our Causal Decoder. To do so, we used the same set of hyperparameters.
We generated up to 1,000 tokens with the two models. The speed ratio between
these two models was close to 1, oscillating between 0.85 and 1.10.
5 - All the experiments were run on a V100 GPU.