Visualize Attention Mechanism in Sequence to Sequence Models

Monojit Sarkar

Attention mechanisms are really fascinating. They try to mimic the way humans perceive the world. In this post, I will visualize how attention mechanisms behave on several simple problems.

I will demonstrate them on five tasks: reversing a character sequence, left-shifting it, right-shifting it, sorting digits, and finally converting human-readable dates to machine-readable dates.

This is what the overall architecture looks like. It used to be a mystery, but not anymore.
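Each row of the heatmaps in this post is one set of attention weights: for a given output step, the decoder scores every input position and normalizes those scores with a softmax. Here is a minimal sketch of that step, assuming plain dot-product scoring (the scoring function is my assumption for illustration, and `softmax` and `attention_weights` are hypothetical helper names):

```python
import math

def softmax(scores):
    # Subtract the max before exponentiating for numerical stability.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attention_weights(decoder_state, encoder_states):
    # Dot-product score between the decoder state and each encoder state.
    # This is an assumption; additive scoring is another common choice.
    scores = [sum(d * e for d, e in zip(decoder_state, enc))
              for enc in encoder_states]
    return softmax(scores)

# Toy example: three input positions with 2-dimensional states.
encoder_states = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
decoder_state = [0.0, 2.0]
weights = attention_weights(decoder_state, encoder_states)
```

The weights sum to 1, and the input positions whose states line up best with the decoder state get the largest share. Stacking one such row per output step gives the kind of matrix that a heatmap displays.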

Reversing an Input Sequence

Input: monojitcnn

Output: nnctijonom

What is the interpretation? The first character "m" (the top character on the y-axis) should become the last character after reversing, so the attention heatmap for "m" points towards "n" (the last character on the x-axis), which is the last character of the input sequence. The rest of the characters follow the same trend.
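To make the expected pattern concrete, here is a small sketch of an idealized alignment for reversal. It is not the trained model's attention, only the hard one-to-one mapping the task implies. I assume rows are output steps and columns are input positions, which may not match the axis orientation of the plot, and `reverse_alignment` is a hypothetical helper:

```python
def reverse_alignment(n):
    # Row i is output step i; column j is input position j.
    # Reversal copies input position n - 1 - i to output step i,
    # so the ones lie on the anti-diagonal.
    return [[1.0 if j == n - 1 - i else 0.0 for j in range(n)]
            for i in range(n)]

source = "monojitcnn"
target = source[::-1]
alignment = reverse_alignment(len(source))

# Reading off the strongest column in each row reconstructs the output.
decoded = "".join(source[row.index(max(row))] for row in alignment)
```

A trained model produces soft weights rather than exact ones and zeros, but a clean anti-diagonal band is the shape to look for.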

Left shift an input sequence by 4 characters

Input: monojitcnn

Output: jitcnnmono

Pay attention to the lower left side of the heatmap. The attention mechanism detects that the 4 characters at the beginning of the input sequence correspond to the last 4 characters of the output sequence.
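The mapping behind this task can be sketched in a few lines. `shift_left` and `left_shift_sources` are hypothetical helpers written for illustration, not the code used to train the model:

```python
def shift_left(seq, k):
    # Rotate left: the first k characters wrap around to the end.
    k %= len(seq)
    return seq[k:] + seq[:k]

def left_shift_sources(n, k):
    # For each output position i, the input position it is copied from.
    return [(i + k) % n for i in range(n)]

source = "monojitcnn"
target = shift_left(source, 4)
sources = left_shift_sources(len(source), 4)
```

Here `sources` is `[4, 5, 6, 7, 8, 9, 0, 1, 2, 3]`: the last four output positions read from input positions 0 to 3, which is the wrap-around block described above.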

Right shift an input sequence by 3 characters

Input: monojitcnn

Output: cnnmonojit

Pay attention to the upper right side of the heatmap. The attention mechanism detects that the 3 characters at the end of the input sequence correspond to the first 3 characters of the output sequence.
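The right shift is the mirror image of the left shift. A minimal sketch, again with hypothetical helper names (`shift_right`, `right_shift_sources`):

```python
def shift_right(seq, k):
    # Rotate right: the last k characters wrap around to the front.
    n = len(seq)
    k %= n
    return seq[n - k:] + seq[:n - k]

def right_shift_sources(n, k):
    # For each output position i, the input position it is copied from.
    return [(i - k) % n for i in range(n)]

source = "monojitcnn"
target = shift_right(source, 3)
sources = right_shift_sources(len(source), 3)
```

Here `sources` starts with `[7, 8, 9]`, so the first three output positions read from the last three input positions. After that the mapping is a simple offset diagonal.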

So far, the inputs and outputs of each problem followed a consistent pattern. That made it easy for the attention mechanism to learn the patterns.

What happens if we bring a little bit of uncertainty into the picture?

Sort digits in ascending order

The task: given a random sequence of digits, the model should sort them into ascending order.

Input: 1, 9, 4, 3, 8, 7, 5, 6

Output: 1, 3, 4, 5, 6, 7, 8, 9

Note that this is a case where the attention matrix doesn't make intuitive sense. The model learns some kind of representation that allows it to sort the digits.
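One possible reason the heatmap is harder to read (my interpretation, not something I verified on the model): unlike reversing or shifting, the "correct" alignment for sorting depends on the digit values, not on their positions. The sketch below computes that idealized alignment for two inputs that differ only in the order of their first two digits; `sort_alignment` is a hypothetical helper:

```python
def sort_alignment(digits):
    # Input positions listed in the order they appear in the sorted output.
    # sorted() is stable, so equal digits keep their original order.
    return sorted(range(len(digits)), key=lambda j: digits[j])

a = [1, 9, 4, 3, 8, 7, 5, 6]
order_a = sort_alignment(a)
sorted_a = [a[j] for j in order_a]

# Same digits, but the first two are swapped.
b = [9, 1, 4, 3, 8, 7, 5, 6]
order_b = sort_alignment(b)
```

The two alignments differ even though the sorted output is identical, so there is no single fixed diagonal or band for the model to learn.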

Let's increase the difficulty a bit more.

Human Readable Dates to Machine Translated Dates

For this task, I considered random dates between 01-01-2000 and 01-01-2024 in various formats. For example, the date 30th July 2021 can be represented as:

  • 7 30 21

  • July 30, 2021

  • Jul 30, 2021

  • Friday, July 30, 2021

  • 30 Jul 2021

  • 30 July 2021

  • 30 07 2021

  • Fri 30 Jul 2021

  • Friday 30 July 2021

The model's task is to convert all of these formats to 30-07-2021.
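Training pairs for this task could be generated along the following lines. This is a sketch under assumptions, not the exact generator used here: `human_formats`, `machine_format`, and `random_pair` are hypothetical names, the zero-padding of single-digit days is a guess, and the month and weekday names assume the default English locale.

```python
import random
from datetime import date, timedelta

def human_formats(d):
    # The nine input styles listed above, in the same order.
    return [
        f"{d.month} {d.day} {d:%y}",
        f"{d:%B} {d.day}, {d.year}",
        f"{d:%b} {d.day}, {d.year}",
        f"{d:%A}, {d:%B} {d.day}, {d.year}",
        f"{d.day} {d:%b} {d.year}",
        f"{d.day} {d:%B} {d.year}",
        f"{d:%d %m %Y}",
        f"{d:%a} {d.day} {d:%b} {d.year}",
        f"{d:%A} {d.day} {d:%B} {d.year}",
    ]

def machine_format(d):
    # Target format: DD-MM-YYYY.
    return d.strftime("%d-%m-%Y")

def random_pair(rng, start=date(2000, 1, 1), end=date(2024, 1, 1)):
    # Pick a random date in [start, end] and one random input style.
    d = start + timedelta(days=rng.randrange((end - start).days + 1))
    return rng.choice(human_formats(d)), machine_format(d)

example = date(2021, 7, 30)
inputs = human_formats(example)
target = machine_format(example)
pair = random_pair(random.Random(0))
```

For the example date, `inputs` reproduces the nine strings in the list above and `target` is `30-07-2021`.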

Date 1

Input: Thursday, October 13 2011

Output: 13-10-2011

Note how the attention mechanism learns to ignore the weekday "Thursday" and learns to map "October" to "10". Really fascinating.

Date 2

Input: 28 March 2005

Output: 28-03-2005

Attention Heatmap for various Human Readable Date Formats

