Engineering

Making Pytorch Transformer Twice as Fast on Sequence Generation.

December 17, 2020


At Scale AI, we use Machine Learning models in a wide range of applications to

empower our data labeling pipeline. We strive for speed and efficiency, and

always try to get the best out of the models. Here, we will discuss a few tricks we discovered that drastically improve inference speed over the stock PyTorch Transformer implementation, in just a few lines of code.

Transformers are Here to Stay



Transformers have become ubiquitous. They were first introduced in Attention

is All You Need

(Vaswani et al., 2017) and

were quickly added to Pytorch. Their popularity increased even more with the

development of HuggingFace, which made large NLP pre-trained models such as

BERT

(Devlin et al., 2018)

widely accessible, and created recipes to enable simple fine-tuning on a wide

range of tasks. They’ve been successfully applied to a wide variety of

sequence-to-sequence (Seq2Seq) tasks including machine translation, text

summarization, or even image captioning (an image is just a sequence of

pixels!). This popularity is completely warranted, because Transformers have

some significant upsides:


  • The Transformer architecture is non-sequential, making it distributable. Other traditional methods for sequence modeling, such as RNNs, are limited to processing sequences one token at a time. This sequential nature prevents parallelization and makes training slow. Transformers process entire sequences at once in a highly parallel fashion. This makes them incredibly fast on GPUs and helps handle long-range dependencies elegantly.

  • Transformers make few assumptions about the data. Traditionally, ML practitioners have tailored their networks to process specific types of data. Constraints such as forcing RNNs to process text sequentially from left to right allowed these networks to perform well even on scarce training data, spearheading breakthroughs. However, these constraints introduce biases in the model, as sequential ordering is rarely the optimal way to understand text. Transformers keep data representations generic, which makes them capable of learning more subtle interactions between words. Recent papers (Alexey Dosovitskiy et al., 2020) show that the same story could also be true in computer vision, with Transformers outperforming the long-established CNNs once trained on the huge datasets that have recently been collected. With enough data, Transformers learn more complex and accurate representations than the constrained networks which used to be the only viable option.

Sequence-to-Sequence with Transformers



But Transformers also have their weaknesses. When generating sequences for Seq2Seq tasks at inference time, Transformers are constrained because each item in the output sequence can only be predicted one at a time. This, combined with the quadratic complexity of attention, can make them slower than their recurrent counterparts. (Training does not suffer from this issue, thanks to teacher forcing.)
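To make the contrast concrete, here is a minimal, self-contained sketch of a teacher-forced forward pass; the model, sizes and random embeddings are made up for illustration. The whole shifted target sequence is fed at once behind a causal mask, so a single forward pass covers every timestep, whereas at inference there is no ground truth to feed and the token-by-token loops shown later in this post are needed.

import torch
from torch import nn

# Toy illustration of teacher forcing with made-up sizes.
model = nn.Transformer(d_model=64, nhead=4)
src = torch.randn(10, 2, 64)  # (src_len, batch, d_model)
tgt = torch.randn(20, 2, 64)  # ground-truth target embeddings, shifted right

# the causal mask prevents each position from attending to future positions,
# so all 20 predictions can be computed in one parallel forward pass
causal_mask = model.generate_square_subsequent_mask(20)
out = model(src, tgt, tgt_mask=causal_mask)  # shape: (20, 2, 64)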


Seq2Seq models typically create an internal high-level representation of the

input sequence and then decode (i.e. generate) the output sentence. Given the

high-level representation of the input sentence and the words that have

already been decoded, Seq2Seq models estimate the most likely words to

complete the sentence. This phenomenon is called auto-regression and

the phase corresponding to generating a new word (or token) is a

timestep.

When a Transformer is used as a Seq2Seq model, the input sequence is fed

through an Encoder, and the output sequence is then generated by a Decoder, as

illustrated in figures 1 and 2.

Decoding Inefficiency of the PyTorch Transformers



The Transformer class in PyTorch is generic, which is great because it gives the ML researchers at Scale AI fine-grained control, but it also means the class isn't optimized for inference speed. Let's take a deeper look.


First, as Figure 1 shows, the encoder output can be computed separately from the decoder. This means that the encoder outputs can be computed once and re-used at every timestep thereafter. But PyTorch does NOT cache this for you: calling nn.Transformer at each timestep re-runs the encoder, wasting compute at every decoding step. To fix this, the Transformer encoder and decoder should always be separated.

# THIS IS THE NAIVE WAY TO USE TRANSFORMERS
import torch
from torch import nn

# `hdim`, `nhead`, `num_layers`, `dim_feedforward`, `device`, `src`,
# `first_token`, `len_output_to_decode`, `embedding`, `to_vocab` and
# `generate_square_subsequent_mask` are assumed to be defined elsewhere
# (a sketch of these helpers is given right after this snippet)

# INITIALIZATION
transformer = nn.Transformer(
  d_model=hdim,
  nhead=nhead,
  num_encoder_layers=num_layers,
  num_decoder_layers=num_layers,
  dim_feedforward=dim_feedforward,
).to(device=device)
transformer.eval()

# INFERENCE LOOP
decoded_tokens = first_token
for i in range(len_output_to_decode):  # generate `len_output_to_decode` tokens
  mask_dec = generate_square_subsequent_mask(
    i + 1, device=first_token.device
  )  # create mask for autoregressive decoding
  decoded_embeddings = embedding(decoded_tokens)

  # the encoder part of the model is re-run inside this call at every timestep
  output = transformer(src, decoded_embeddings, tgt_mask=mask_dec)
  logits = to_vocab(output)  # projection to vocab size

  # keep most likely tokens
  top_indices = torch.argmax(logits, dim=-1)
  # we only care about the last token that was decoded
  top_indices_last_token = top_indices[-1:]
  # add most likely token to the already decoded tokens
  decoded_tokens = torch.cat(
    [decoded_tokens, top_indices_last_token], dim=0
  )
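The snippets in this post rely on a few helpers and hyperparameters that are not shown. For completeness, here is a minimal sketch of what they could look like; most hyperparameter values follow the footnotes, while the shapes, the number of tokens to decode, the mask helper and the module definitions are our own assumptions for illustration:

import torch
from torch import nn

# assumed hyperparameters (see footnotes) and helper modules
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
hdim, nhead, num_layers, dim_feedforward = 512, 8, 6, 2048
vocab_size = 30000
len_output_to_decode = 100

embedding = nn.Embedding(vocab_size, hdim).to(device)  # token ids -> embeddings
to_vocab = nn.Linear(hdim, vocab_size).to(device)      # embeddings -> logits

def generate_square_subsequent_mask(sz, device):
  # float mask with 0 on/below the diagonal and -inf above it, so that each
  # position can only attend to itself and to earlier positions
  return torch.triu(
    torch.full((sz, sz), float("-inf"), device=device), diagonal=1
  )

# shapes follow PyTorch's default (seq_len, batch, hidden) convention
batch_size, src_len = 8, 500
src = torch.randn(src_len, batch_size, hdim, device=device)  # encoder input embeddings
first_token = torch.zeros(  # ids of a <bos>-like start token, shape (1, batch)
  1, batch_size, dtype=torch.long, device=device
)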


The following snippet is a much more efficient way to get the same results: it decouples the encoder from the decoder, so the encoder output is computed only once. Note that the code corresponding to the inference loop barely changes.

# INITIALIZATION
encoder = nn.TransformerEncoder(
  nn.TransformerEncoderLayer(
    d_model=hdim, nhead=nhead, dim_feedforward=dim_feedforward
  ),
  num_layers=num_layers,
).to(device=device)
encoder.eval()

decoder = nn.TransformerDecoder(
  nn.TransformerDecoderLayer(
    d_model=hdim, nhead=nhead, dim_feedforward=dim_feedforward
  ),
  num_layers=num_layers,
).to(device=device)
decoder.eval()

# INFERENCE LOOP
decoded_tokens = first_token
src_embeddings = encoder(src)
for i in range(len_output_to_decode):
  mask_dec = generate_square_subsequent_mask(
    i + 1, device=first_token.device
  ) # create mask for autoregressive decoding
  decoded_embeddings = embedding(decoded_tokens)

  # the decoder uses the encoder output `src_embeddings`
  output = decoder(decoded_embeddings, src_embeddings, tgt_mask=mask_dec)

  logits = to_vocab(output) # projection to vocab size

  # keep most likely tokens
  top_indices = torch.argmax(logits, dim=-1)
  # we only care about the last token that was decoded
  top_indices_last_token = top_indices[-1:]
  # add most likely token to the already decoded tokens
  decoded_tokens = torch.cat(
    [decoded_tokens, top_indices_last_token], dim=0
  )


The main inefficiency builds on the previous point. As Figure 2 shows, the embedding of a decoded token only depends on the tokens that were decoded before it. This is a direct consequence of the Transformer model being autoregressive. It is therefore unnecessary to recompute the embeddings of the already decoded tokens at every timestep; instead, we can again cache them. Each timestep then only consists of computing the attention for the newest token's embedding.


Figure 3: Decoder self-attention links when decoding tokens.

The boxes at the bottom represent the embeddings of the output tokens before self-attention; the top boxes represent the embeddings of the output tokens after self-attention. Using our trick (right side), most of the embeddings are not recomputed, as they are cached. The number of links to take into account becomes linear instead of quadratic.


From a complexity perspective, generating the n-th output token without our trick involves computing self-attention over the entire current output (O(n²)) and computing encoder-decoder attention between the whole input (whose length we will denote M) and the current output (O(Mn)). Hence, the complexity of each timestep is O(Mn + n²). Given that we want to decode N tokens, N timesteps are needed, and the final complexity is O(MN² + N³).
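As a quick check, summing the per-timestep cost over the N timesteps gives the cubic total:

$$\sum_{n=1}^{N} O(Mn + n^2) = O\left(M \sum_{n=1}^{N} n + \sum_{n=1}^{N} n^2\right) = O(MN^2 + N^3)$$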


Our trick accelerates each timestep. Only the parts of the self-attention and encoder-decoder attention responsible for updating the last token are computed. Figure 3 shows how this works for the self-attention. The new complexity of the n-th timestep is O(M + n), so with N timesteps the final complexity is sped up to O(MN + N²).
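The same summation as before confirms the new total:

$$\sum_{n=1}^{N} O(M + n) = O\left(MN + \sum_{n=1}^{N} n\right) = O(MN + N^2)$$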


The PyTorch TransformerDecoder is not written with the autoregressive use case in mind. However, by subclassing the TransformerDecoder (and its layer), we introduce a CausalTransformerDecoder which uses a cache to implement the improvement above. Our code differs from the PyTorch implementation by only a few lines. Our new decoder works similarly to the original TransformerDecoder, except that we now have to take the cache into account:

causal_decoder = CausalTransformerDecoder(
  CausalTransformerDecoderLayer(
    d_model=hdim,
    nhead=nhead,
    dim_feedforward=dim_feedforward,
  ),
  num_layers=num_layers,
).to(device=device)
causal_decoder.eval()

decoded_tokens = first_token
src_embeddings = encoder(src)
cache = None
for i in range(len_output_to_decode):
  decoded_embeddings = embedding(decoded_tokens)

  # no explicit causal mask is needed here: at inference the causal decoder
  # only computes attention for the newest token, which is allowed to attend
  # to every previously decoded token anyway
  # only change here: we pass the cache as an extra argument and get it back
  output, cache = causal_decoder(decoded_embeddings, src_embeddings, cache)

  logits = to_vocab(output) # projection to vocab size

  # keep most likely tokens
  top_indices = torch.argmax(logits, dim=-1)
  # we only care about the last token that was decoded
  top_indices_last_token = top_indices[-1:]
  # add most likely token to the already decoded tokens
  decoded_tokens = torch.cat(
    [decoded_tokens, top_indices_last_token], dim=0
  )
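To give a feel for what the cache does, here is a simplified, inference-only sketch of what such a causal decoder can look like. It is not the exact code from our repo (which also handles training mode and masks), but it shows the two key ideas: each layer only computes the representation of the newest token, and the per-layer outputs for the already decoded tokens are read from the cache instead of being recomputed.

import torch
from torch import nn

class CausalTransformerDecoderLayer(nn.TransformerDecoderLayer):
  def forward(self, tgt, memory):
    # only compute the update for the newest token (inference-only sketch)
    tgt_last = tgt[-1:, :, :]
    # self-attention: query = newest token, keys/values = all decoded tokens
    tmp = self.self_attn(tgt_last, tgt, tgt)[0]
    tgt_last = self.norm1(tgt_last + self.dropout1(tmp))
    # encoder-decoder attention over the full encoder output `memory`
    tmp = self.multihead_attn(tgt_last, memory, memory)[0]
    tgt_last = self.norm2(tgt_last + self.dropout2(tmp))
    # feed-forward block
    tmp = self.linear2(self.dropout(self.activation(self.linear1(tgt_last))))
    return self.norm3(tgt_last + self.dropout3(tmp))

class CausalTransformerDecoder(nn.TransformerDecoder):
  def forward(self, tgt, memory, cache=None):
    output = tgt  # (tgt_len, batch, hdim)
    new_token_cache = []
    for i, layer in enumerate(self.layers):
      output = layer(output, memory)  # (1, batch, hdim): newest token only
      new_token_cache.append(output)
      if cache is not None:
        # prepend this layer's cached outputs for the earlier tokens, so the
        # next layer sees the full sequence without recomputing anything
        output = torch.cat([cache[i], output], dim=0)
    # cache shape: (num_layers, tgt_len, batch, hdim)
    new_cache = torch.stack(new_token_cache, dim=0)
    if cache is not None:
      new_cache = torch.cat([cache, new_cache], dim=1)
    return output, new_cache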

Experiments



We put our changes to the test to see how much faster we could get. We present

two different scenarios: translation and generation of long texts.


We compare our three different implementations

(see footnotes for details):


  • The most naive PyTorch implementation (defined in the first piece of code), which uses nn.Transformer.
  • The PyTorch encoder-decoder implementation (second piece of code).
  • Our CausalTransformerDecoder (third piece of code).


As a reminder, these are three different implementations of the same model.

When initialized with the same weights, they return the same outputs.

Text Translation



The first setting corresponds to translation. In this setting the input and

output sequences are generally short and of similar lengths.

The non-linear curves show that the attention mechanisms progressively

become the most compute-intensive parts of the model as the number of input

and output tokens increases.

Our causal implementation is up to 40% faster than the Pytorch

Encoder-Decoder implementation, and 150% faster than the Pytorch

nn.Transformer implementation for 500 input/output tokens.

Long Text Generation



We now ask the model to generate long sequences from a fixed size input. Such

a situation might arise when generating a story from an image or from an

initial prompt.


The results below were obtained with a fixed input size of 500 tokens.

Increasing the number of input tokens makes the models slower but doesn’t

change the overall trends observed.

Our causal model is twice as fast as the PyTorch encoder-decoder

implementation when the number of tokens to generate exceeds 1,000.

When decoding more than 500 tokens, the speed ratio between the causal model and the other implementations grows linearly with the output length. This matches the analysis above, in which the overall decoding complexity was reduced by a factor of N.


Finally, our CausalTransformerDecoder can also be used without any input sentence (i.e. without an encoder), as is the case in some common generation settings where the model is asked to complete a story or an article. More information about this type of generation can be found in the GPT papers (Alec Radford et al., 2018). The results we find for this case are similar to the ones above.

Digging deeper…



One might notice that caching the output of each layer is still sub-optimal. Indeed, the first stage of each attention layer consists of projecting the embeddings onto the query, key and value spaces. In both the PyTorch implementation and the proposed implementation, the same embeddings get projected repeatedly. Instead, the queries, keys and values could be cached directly. However, this requires substantially more changes to PyTorch internals, which could break with future PyTorch upgrades. Moreover, the estimated gains are minor (less than 5% in our experiments).
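To illustrate the idea, here is a toy single-head self-attention module that caches its projected keys and values, so each new token only needs to be projected once. This is a conceptual sketch with made-up names and shapes, not something implemented in our repo:

import torch
from torch import nn

class CachedSelfAttention(nn.Module):
  # toy single-head self-attention that caches projected keys/values
  def __init__(self, hdim):
    super().__init__()
    self.q_proj = nn.Linear(hdim, hdim)
    self.k_proj = nn.Linear(hdim, hdim)
    self.v_proj = nn.Linear(hdim, hdim)
    self.scale = hdim ** -0.5

  def forward(self, new_token, kv_cache=None):
    # new_token: (batch, 1, hdim), the newest decoded token only
    q = self.q_proj(new_token)
    k_new, v_new = self.k_proj(new_token), self.v_proj(new_token)
    if kv_cache is not None:
      k = torch.cat([kv_cache[0], k_new], dim=1)  # reuse old projections
      v = torch.cat([kv_cache[1], v_new], dim=1)
    else:
      k, v = k_new, v_new
    attn = torch.softmax((q @ k.transpose(1, 2)) * self.scale, dim=-1)
    return attn @ v, (k, v)  # output (batch, 1, hdim) and the updated cache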


For standard NLP use cases, the HuggingFace Transformers library already implements these optimizations; notably, it caches keys and values. It also comes with different decoding flavors, such as beam search or nucleus sampling.

Conclusion



The simple tricks proposed take advantage of the fact that the stock PyTorch Transformer implementation is too generic to be optimized for autoregressive inference. The changes provide modest speed improvements when generating a few hundred tokens, and become significant boosts over the original PyTorch implementation when the output length nears a thousand tokens. These gains grow roughly in proportion to the number of output tokens to decode. And best of all, they can be implemented in just a few lines using our repo.

References:



Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan

N Gomez, Łukasz Kaiser, and Illia Polosukhin.

Attention is all you need. In Advances in Neural Information

Processing Systems 30, pp. 5998–6008. Curran Associates, Inc., 2017. URL:

https://proceedings.neurips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf


Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova.

BERT: pre-training of deep bidirectional transformers for language

understanding. CoRR, abs/1810.04805, 2018. URL:

http://arxiv.org/abs/1810.04805.


Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn,

Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani et al.

An Image is Worth 16x16 Words: Transformers for Image Recognition at

Scale.

arXiv preprint, 2020. URL:

https://arxiv.org/pdf/2010.11929


Thomas Wolf, Lysandre Debut, Victor Sanh, Julien Chaumond, Clement Delangue,

Anthony Moi, Pierric Cistac et al.

Transformers: State-of-the-art natural language processing. arXiv preprint, 2019. URL:

https://arxiv.org/pdf/1910.03771.pdf


Alec Radford, Karthik Narasimhan, Tim Salimans, and Ilya Sutskever.

Improving language understanding by generative pre-training. Preprint, 2018. URL:

https://www.cs.ubc.ca/~amuham01/LING530/papers/radford2018improving.pdf

Footnotes:



1 - The vocabulary size used for this experiment was 30,000. It corresponds to

the vocabulary size chosen in the BERT original paper, which uses BPE

tokenization. We also ran the same experiments with a much smaller vocabulary

(128 tokens), to imitate a character-level setting which did not show

significant benefits.


2 - The hyperparameters used for the Transformer architecture are the ones of

the original paper (6 layers, 8 heads, 512 hidden dimensions, 2048

feed-forward hidden dimensions for both encoder/decoder). The results should

be similar with other configurations, provided that the encoder and decoder

have the same size.


3 - The displayed results correspond to a batch-size of 8 sequences, but we

made sure that a batch-size of 1 gives the same trend. A batch-size of 1 is

usually the most common in inference as requests are sent asynchronously.

However, in our specific setup the model takes several seconds to generate

sentences, so it is more natural to batch requests.


4 - As a safety check, we benchmarked the HuggingFace GPT-2 implementation against

our Causal Decoder. To do that, we used the same set of hyperparameters. We

generated up to 1000 tokens with the two models. The speed ratio between these

two models was close to 1, oscillating between 0.85 and 1.10.

5 - All the experiments were run on a V100 GPU.

