Visualize Attention Mechanism in Sequence to Sequence Models
The attention mechanism is really fascinating: it tries to mimic the way humans perceive the world by focusing on the parts of the input that matter most at each step. In this post, I will provide visualizations of how attention works on several simple problems.
I will demonstrate it by reversing a character sequence, left shifting and right shifting characters, sorting digits, and finally converting human-readable dates to machine-readable dates.
This is what the overall architecture looks like. It used to be a mystery, but not anymore.
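Under the hood, at every decoding step the decoder scores each encoder state against its current state, turns those scores into a probability distribution, and uses the weighted sum as context. Those weight distributions are exactly the rows of the heatmaps below. Here is a minimal sketch of one common variant, Bahdanau-style additive attention, in PyTorch; the class and variable names are my own, so treat it as an illustration rather than the exact code behind the figures.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class AdditiveAttention(nn.Module):
    # Illustrative sketch of Bahdanau-style additive attention.
    def __init__(self, hidden_dim):
        super().__init__()
        self.W_enc = nn.Linear(hidden_dim, hidden_dim, bias=False)  # projects encoder states
        self.W_dec = nn.Linear(hidden_dim, hidden_dim, bias=False)  # projects decoder state
        self.v = nn.Linear(hidden_dim, 1, bias=False)               # reduces to one score per position

    def forward(self, dec_state, enc_states):
        # dec_state: (batch, hidden), enc_states: (batch, src_len, hidden)
        scores = self.v(torch.tanh(
            self.W_enc(enc_states) + self.W_dec(dec_state).unsqueeze(1)
        )).squeeze(-1)                        # (batch, src_len)
        weights = F.softmax(scores, dim=-1)   # one row of the attention heatmap
        context = torch.bmm(weights.unsqueeze(1), enc_states).squeeze(1)
        return context, weights
```

The `weights` tensor returned at each step is what gets stacked into the heatmaps you see throughout this post.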
Reversing an Input Sequence
Input: monojitcnn
Output: nnctijonom
What is the interpretation? Since the first character "m" (the top character on the y-axis) should become the last character after reversing, the attention heatmap row for "m" points towards "n" (the last character on the x-axis), which is the last character of the input sequence. The rest of the characters follow the same trend.
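For reference, this is roughly how such a heatmap can be drawn once the per-step attention weights have been collected into an (output_length, input_length) array; the function and variable names here are illustrative.

```python
import matplotlib.pyplot as plt
import numpy as np

def plot_attention(weights, src_tokens, tgt_tokens):
    # weights: array of shape (len(tgt_tokens), len(src_tokens))
    fig, ax = plt.subplots()
    ax.imshow(np.asarray(weights), cmap="viridis", aspect="auto")
    ax.set_xticks(range(len(src_tokens)))
    ax.set_xticklabels(src_tokens)   # input characters on the x-axis
    ax.set_yticks(range(len(tgt_tokens)))
    ax.set_yticklabels(tgt_tokens)   # output characters on the y-axis
    ax.set_xlabel("input sequence")
    ax.set_ylabel("output sequence")
    plt.show()
```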
Left shift an input sequence by 4 characters
Input: monojitcnn
Output: jitcnnmono
Pay attention to the lower left side of the heatmap. The attention mechanism detects that the 4 characters at the beginning of the input sequence correspond to the last 4 characters of the output sequence.
Right shift an input sequence by 3 characters
Input: monojitcnn
Output: cnnmonojit
Pay attention to the upper right side of the heatmap. The attention mechanism detects that the 3 characters at the end of the input sequence correspond to the first 3 characters of the output sequence.
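If you want to reproduce these toy tasks, generating the training pairs is a one-liner per task. Here is a sketch; the helper name `make_pair` is just for illustration.

```python
import random
import string

def make_pair(task, length=10):
    s = "".join(random.choices(string.ascii_lowercase, k=length))
    if task == "reverse":
        return s, s[::-1]
    if task == "left_shift_4":
        return s, s[4:] + s[:4]    # first 4 chars move to the end
    if task == "right_shift_3":
        return s, s[-3:] + s[:-3]  # last 3 chars move to the front
    raise ValueError(f"unknown task: {task}")

print(make_pair("reverse"))  # a random (input, reversed input) pair
```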
So far, the inputs and outputs of these problems have followed a consistent positional pattern, which made it easy for the attention mechanism to learn.
What would happen if a little bit of uncertainty were brought into the picture?
Sort digits in ascending order
The task: given a random sequence of digits, the model must sort them into ascending order.
Input: 1, 9, 4, 3, 8, 7, 5, 6
Output: 1, 3, 4, 5, 6, 7, 8, 9
Note that this is a case where the attention matrix doesn't make intuitive sense: the model learns some kind of internal representation that allows it to sort the digits.
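Generating pairs for this task is just as simple; a quick illustrative sketch:

```python
import random

def sort_pair(length=8):
    # A random digit sequence and its sorted counterpart.
    digits = [random.randint(0, 9) for _ in range(length)]
    return digits, sorted(digits)

print(sort_pair())  # e.g. ([1, 9, 4, 3, 8, 7, 5, 6], [1, 3, 4, 5, 6, 7, 8, 9])
```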
Let's increase the difficulty a bit more.
Human Readable Dates to Machine Translated Dates
For this task, I considered random dates between 01-01-2000 and 01-01-2024 in various formats. So a given date, 30th July 2021, can be represented as:
7 30 21
July 30, 2021
Jul 30, 2021
Friday, July 30, 2021
30 Jul 2021
30 July 2021
30 07 2021
Fri 30 Jul 2021
Friday 30 July 2021
The task of the model is to convert all of these formats to 30-07-2021.
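Such (human-readable, machine-readable) pairs can be generated with nothing but the standard library. Here is a sketch; the `FORMATS` list mirrors the examples above (up to zero padding), and the function name is just for illustration.

```python
import random
from datetime import date, timedelta

# strftime patterns approximating the formats listed above
FORMATS = ["%m %d %y", "%B %d, %Y", "%b %d, %Y", "%A, %B %d, %Y",
           "%d %b %Y", "%d %B %Y", "%d %m %Y", "%a %d %b %Y", "%A %d %B %Y"]

def random_date_pair():
    start, end = date(2000, 1, 1), date(2024, 1, 1)
    d = start + timedelta(days=random.randrange((end - start).days))
    human = d.strftime(random.choice(FORMATS)).lower()
    return human, d.strftime("%d-%m-%Y")  # target is always dd-mm-yyyy

print(random_date_pair())  # e.g. ('friday, july 30, 2021', '30-07-2021')
```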
Date 1
Input: Thrusday, October 13 2011
Output: 13-10-2011
Note how the attention mechanism learns to ignore the weekday "thrusday" and to map "october" to "10". Really fascinating.
Date 2
Input: 28 March 2005
Output: 28-03-2005
Attention Heatmaps for Various Human-Readable Date Formats