An introduction to BERT

Jiangyu Zheng

1 Introduction

BERT (Bidirectional Encoder Representations from Transformers) is important in NLP because it gives machines a better way to understand language: it builds each word's representation from its left and right context at the same time. This deep bidirectional approach produces more accurate language understanding than earlier models, which either processed text in a single direction or used a shallow concatenation of independently trained left-to-right and right-to-left language models (LMs).

BERT is pre-trained on large unlabeled corpora (BooksCorpus and English Wikipedia) and then fine-tuned for specific tasks. Because the same pre-trained model can be adapted to many tasks by adding only a small output layer, it reduces the need for heavily engineered task-specific architectures. When it was released in 2018, BERT set new state-of-the-art results on 11 NLP tasks, including the GLUE benchmark, MultiNLI, and SQuAD v1.1 and v2.0, significantly raising the performance baseline in NLP research and applications. It also inspired a wave of successors (e.g., RoBERTa, ALBERT, DistilBERT).

2 How BERT works

A distinctive feature of BERT is its unified architecture across different tasks: there is minimal difference between the pre-trained architecture and the final downstream architecture.

BERT uses only the encoder part of the Transformer architecture introduced in the 2017 paper “Attention Is All You Need.” This allows it to:

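  • Capture long-range dependencies, since every token can attend directly to every other token regardless of how far apart they are.

  • Use self-attention to weigh the importance of every other word in the sentence when building each word's representation.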
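To make the self-attention bullet concrete, here is a minimal single-head sketch of scaled dot-product attention in plain Python. It is a simplification under stated assumptions: the query, key, and value projections are taken to be the identity (a real Transformer layer learns separate weight matrices, uses multiple heads, and adds feed-forward sublayers), and the toy vectors are made up.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def self_attention(x):
    """Single-head scaled dot-product self-attention.

    x is a list of token vectors (seq_len x d). For simplicity the query,
    key and value projections are the identity, so Q = K = V = x.
    Returns the attention weights and the new token vectors.
    """
    d = len(x[0])
    weights = []
    outputs = []
    for q in x:  # every token issues a query...
        # ...scored against every token in the sequence, left and right
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in x]
        w = softmax(scores)
        weights.append(w)
        # New representation: attention-weighted average of all value vectors
        outputs.append([sum(wj * v[i] for wj, v in zip(w, x)) for i in range(d)])
    return weights, outputs


tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
weights, outputs = self_attention(tokens)
for row in weights:
    print([round(w, 3) for w in row])
```

Each row of weights sums to 1 and says how much one token draws on every token in the sequence, including tokens far away, which is where the long-range dependencies come from.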

Unlike earlier models, which read left-to-right or right-to-left, BERT is deeply bidirectional: every layer conditions on both left and right context simultaneously.
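One toy way to see the difference is an attention mask that records which positions each token may use. In the sketch below (the helper names and the sentence are illustrative), a left-to-right model must build the representation of "bank" without seeing "river", while a bidirectional encoder sees the whole sentence. This only models visibility; how BERT avoids letting a word trivially see the token it has to predict is handled by the masked LM objective described below.

```python
def causal_mask(n):
    """Left-to-right LM: position i may attend only to positions 0..i."""
    return [[j <= i for j in range(n)] for i in range(n)]


def bidirectional_mask(n):
    """BERT-style encoder: every position may attend to every position."""
    return [[True] * n for _ in range(n)]


def visible_context(tokens, mask, i):
    """Tokens that position i can use when building its representation."""
    return [t for t, ok in zip(tokens, mask[i]) if ok]


tokens = ["the", "bank", "of", "the", "river"]
print(visible_context(tokens, causal_mask(5), 1))         # ['the', 'bank']
print(visible_context(tokens, bidirectional_mask(5), 1))  # ['the', 'bank', 'of', 'the', 'river']
```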

2.1 Pre-training

BERT is pre-trained on large corpora using two unsupervised tasks: Masked LM and Next Sentence Prediction (NSP).

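(1) Masked LM: BERT selects 15% of the input tokens at random and trains the model to predict the original tokens from the surrounding context. Of the selected tokens, 80% are replaced with [MASK], 10% with a random token, and 10% are left unchanged. Because [MASK] never appears during fine-tuning, this mix keeps the model from relying on it being present.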

The following code is adapted from a BERT-pytorch repository.

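from random import random, randint

# Special-token indices. These values are an assumption: indices 0-4 are
# reserved for special tokens, which is why random replacements below are
# drawn with randint(5, ...).
PAD_INDEX, UNK_INDEX, MASK_INDEX, CLS_INDEX, SEP_INDEX = 0, 1, 2, 3, 4


class MaskedDocument:
    def __init__(self, sentences, vocabulary_size):
        self.sentences = sentences  # list of sentences, each a list of token indices
        self.vocabulary_size = vocabulary_size
        self.THRESHOLD = 0.15

    def __len__(self):
        # Needed by PairedDataset, which calls len(document)
        return len(self.sentences)

    def __getitem__(self, item):
        """Get a masked sentence and the corresponding target.

        For example: [5, 6, MASK_INDEX, 8, 9], [0, 0, 7, 0, 0]
        (a PAD_INDEX target means "nothing to predict at this position").
        """
        sentence = self.sentences[item]

        masked_sentence = []
        target_sentence = []

        for token_index in sentence:
            r = random()
            if r < self.THRESHOLD:  # select 15% of all tokens in each sequence at random
                if r < self.THRESHOLD * 0.8:  # 80% of the time: replace the word with the [MASK] token
                    masked_sentence.append(MASK_INDEX)
                    target_sentence.append(token_index)
                elif r < self.THRESHOLD * 0.9:  # 10% of the time: replace the word with a random word
                    random_token_index = randint(5, self.vocabulary_size - 1)
                    masked_sentence.append(random_token_index)
                    target_sentence.append(token_index)
                else:  # 10% of the time: keep the word unchanged
                    masked_sentence.append(token_index)
                    target_sentence.append(token_index)
            else:  # not selected: input unchanged, no prediction target
                masked_sentence.append(token_index)
                target_sentence.append(PAD_INDEX)

        return masked_sentence, target_sentence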
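The per-token coin flip above selects about 15% of tokens on average, but the exact count varies from sequence to sequence. Google's original TensorFlow data-preparation script instead picks a fixed number of positions per sequence (about 15% of its length, at least one, capped by a maximum). The sketch below illustrates that variant under simplifying assumptions: the special-token indices are hypothetical, [CLS], [SEP], and padding are never selected, and there is no per-sequence cap.

```python
import random

# Assumed special-token indices (hypothetical; indices 0-4 reserved for specials).
PAD_INDEX, MASK_INDEX, CLS_INDEX, SEP_INDEX = 0, 2, 3, 4
SPECIAL = {PAD_INDEX, MASK_INDEX, CLS_INDEX, SEP_INDEX}


def mask_fixed_count(tokens, vocabulary_size, mask_prob=0.15, rng=random):
    """Select round(mask_prob * n) positions (at least one), then apply 80/10/10."""
    candidates = [i for i, t in enumerate(tokens) if t not in SPECIAL]
    num_to_predict = max(1, round(len(candidates) * mask_prob))
    chosen = rng.sample(candidates, min(num_to_predict, len(candidates)))

    masked = list(tokens)
    target = [PAD_INDEX] * len(tokens)  # PAD_INDEX = "no prediction here"
    for i in chosen:
        target[i] = tokens[i]
        p = rng.random()
        if p < 0.8:
            masked[i] = MASK_INDEX                           # 80%: [MASK]
        elif p < 0.9:
            masked[i] = rng.randint(5, vocabulary_size - 1)  # 10%: random token
        # else (10%): keep the original token
    return masked, target


rng = random.Random(0)
sentence = [CLS_INDEX] + list(range(10, 30)) + [SEP_INDEX]
masked, target = mask_fixed_count(sentence, vocabulary_size=100, rng=rng)
print(sum(t != PAD_INDEX for t in target))  # 3  (20 candidates * 0.15)
```

Passing a seeded random.Random makes the masking reproducible, which is handy when debugging a data pipeline.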

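(2) NSP: BERT is also pre-trained on a binarized next sentence prediction task. Each training example is a pair of sentences A and B: 50% of the time B is the sentence that actually follows A (labeled IsNext), and 50% of the time it is a random sentence from the corpus (labeled NotNext).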

from random import random, randint

# MaskedCorpus comes from the same repository (not shown): it is assumed to be a
# sequence of MaskedDocument objects that also exposes `sentences_count`.
# PAD_INDEX, CLS_INDEX and SEP_INDEX are special-token indices defined elsewhere
# in the repository.


class PairedDataset:

    def __init__(self, data_path, dictionary, dataset_limit=None):
        self.source_corpus = MaskedCorpus(data_path, dictionary, dataset_limit=dataset_limit)
        self.dataset_size = self.source_corpus.sentences_count
        self.corpus_size = len(self.source_corpus)

    def __len__(self):
        return self.dataset_size

    def __getitem__(self, item):
        # Pairs are sampled at random, so `item` is not used.
        document_index = randint(0, self.corpus_size - 1)
        document = self.source_corpus[document_index]
        # Assumes every document has at least two sentences.
        sentence_index = randint(0, len(document) - 2)
        A_masked_sentence, A_target_sentence = document[sentence_index]

        if random() < 0.5:  # 50% of the time B is the actual next sentence that follows A
            B_masked_sentence, B_target_sentence = document[sentence_index + 1]
            is_next = 1
        else:  # 50% of the time it is a random sentence from the corpus
            random_document_index = randint(0, self.corpus_size - 1)
            random_document = self.source_corpus[random_document_index]
            random_sentence_index = randint(0, len(random_document) - 1)
            B_masked_sentence, B_target_sentence = random_document[random_sentence_index]
            is_next = 0

        # Input layout: [CLS] A [SEP] B [SEP]
        sequence = [CLS_INDEX] + A_masked_sentence + [SEP_INDEX] + B_masked_sentence + [SEP_INDEX]

        # Segment ids, e.g. [0,0,0,0,0,1,1,1,1,1,1,1]:
        # 0 for [CLS], A and the first [SEP]; 1 for B and the final [SEP]
        segment = [0] + [0] * len(A_masked_sentence) + [0] + [1] * len(B_masked_sentence) + [1]

        # [CLS] and [SEP] never carry a masked-LM target
        target = [PAD_INDEX] + A_target_sentence + [PAD_INDEX] + B_target_sentence + [PAD_INDEX]

        return (sequence, segment), (target, is_next)

This helps the model capture relationships between sentences, which matters for tasks such as question answering (QA) and natural language inference (NLI). To tell the model whether a token belongs to sentence A or sentence B, BERT adds a learned segment embedding to each token's input; these correspond to the segment ids (0 for A, 1 for B) built in the code above.
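The input layout can be easier to read with strings instead of vocabulary indices. The sketch below uses the sentence pair the BERT paper uses to illustrate its input representation (with a WordPiece "##" continuation); build_pair is a hypothetical helper, and in the real model each segment id simply selects one of two learned segment embedding vectors.

```python
def build_pair(tokens_a, tokens_b):
    """Lay out a sentence pair as BERT expects: [CLS] A [SEP] B [SEP]."""
    tokens = ["[CLS]"] + tokens_a + ["[SEP]"] + tokens_b + ["[SEP]"]
    # Segment 0 covers [CLS], sentence A and the first [SEP];
    # segment 1 covers sentence B and the final [SEP].
    segments = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)
    return tokens, segments


tokens, segments = build_pair(["my", "dog", "is", "cute"], ["he", "likes", "play", "##ing"])
for tok, seg in zip(tokens, segments):
    print(f"{tok:>6} {seg}")
```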

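BERT is called "bidirectional" because, when learning the meaning of a word, it attends to the entire input (both left and right context) at the same time. Two ingredients make this possible: the Transformer encoder's self-attention lets every position attend to every other position, and the masked LM objective hides the words to be predicted, so conditioning on both sides does not let a word trivially see itself. The model below sums token, positional, and segment embeddings, runs them through the encoder, and returns both the masked LM predictions and the NSP classification.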

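import torch
from torch import nn

PAD_INDEX = 0  # assumed padding index, matching the masking code


def pad_masking(sequence):
    # Hypothetical helper (the repository defines its own): True at padding
    # positions, shaped (batch, 1, seq_len) so it broadcasts over query positions.
    # The exact mask format must match what `encoder` expects.
    return (sequence == PAD_INDEX).unsqueeze(1)


class BERT(nn.Module):

    def __init__(self, encoder, token_embedding, positional_embedding, segment_embedding,
                 hidden_size, vocabulary_size):
        super().__init__()

        self.encoder = encoder  # stack of Transformer encoder layers
        self.token_embedding = token_embedding  # assumed nn.Embedding(vocabulary_size, hidden_size)
        self.positional_embedding = positional_embedding  # assumed nn.Embedding(max_len, hidden_size)
        self.segment_embedding = segment_embedding  # assumed nn.Embedding(2, hidden_size)
        self.token_prediction_layer = nn.Linear(hidden_size, vocabulary_size)  # masked LM head
        self.classification_layer = nn.Linear(hidden_size, 2)  # NSP head: IsNext / NotNext

    def forward(self, inputs):
        sequence, segment = inputs  # both (batch, seq_len) LongTensors
        # The positional embedding is indexed by position 0..seq_len-1, not by token id
        positions = torch.arange(sequence.size(1), device=sequence.device).unsqueeze(0).expand_as(sequence)
        token_embedded = self.token_embedding(sequence)
        positional_embedded = self.positional_embedding(positions)
        segment_embedded = self.segment_embedding(segment)
        embedded_sources = token_embedded + positional_embedded + segment_embedded

        mask = pad_masking(sequence)
        encoded_sources = self.encoder(embedded_sources, mask)  # (batch, seq_len, hidden_size)
        token_predictions = self.token_prediction_layer(encoded_sources)
        classification_embedding = encoded_sources[:, 0, :]  # final hidden state of [CLS]
        classification_output = self.classification_layer(classification_embedding)
        return token_predictions, classification_output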
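The two outputs of forward map onto the two pre-training objectives, and the BERT paper trains on the sum of the mean masked LM loss and the mean NSP loss. The plain-Python sketch below shows that combination so the arithmetic is visible. It assumes, as in the dataset code, that a target of PAD_INDEX (0 here) marks a position with nothing to predict; the function names are illustrative. In PyTorch the same skipping is usually done with nn.CrossEntropyLoss(ignore_index=PAD_INDEX).

```python
import math


def cross_entropy(logits, label):
    """-log softmax(logits)[label], computed stably."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


def pretraining_loss(token_logits, targets, nsp_logits, is_next, pad_index=0):
    """Mean masked-LM loss over selected positions + NSP loss."""
    mlm_terms = [cross_entropy(logits, t)
                 for logits, t in zip(token_logits, targets)
                 if t != pad_index]  # pad_index target = position was not selected
    mlm_loss = sum(mlm_terms) / max(1, len(mlm_terms))
    nsp_loss = cross_entropy(nsp_logits, is_next)
    return mlm_loss + nsp_loss


vocab = 4
token_logits = [[0.0] * vocab for _ in range(3)]  # uniform predictions
targets = [0, 2, 0]  # only position 1 was selected for prediction
loss = pretraining_loss(token_logits, targets, [0.0, 0.0], is_next=1)
print(round(loss, 4))  # 2.0794  (= ln 4 + ln 2)
```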

2.2 Fine-tuning

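Once pre-trained, BERT can be fine-tuned on specific downstream tasks by adding a small task-specific output layer and training all parameters end to end on labeled data. Compared with pre-training, fine-tuning is inexpensive: it typically takes only a few epochs over the task's training set.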
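The BERT paper makes "a few epochs" concrete: it reports that a small grid of fine-tuning hyperparameters works well across tasks (with dropout kept at 0.1), and the best setting is picked on the dev set. The snippet below only enumerates that grid; the dictionary layout is illustrative, and the training loop each configuration would feed is omitted.

```python
from itertools import product

# Fine-tuning search space reported in the BERT paper's appendix.
FINE_TUNING_GRID = {
    "batch_size": [16, 32],
    "learning_rate": [5e-5, 3e-5, 2e-5],  # Adam
    "num_epochs": [2, 3, 4],
}


def grid_configs(grid):
    """Yield every combination of the grid as a dict."""
    keys = list(grid)
    for values in product(*(grid[k] for k in keys)):
        yield dict(zip(keys, values))


configs = list(grid_configs(FINE_TUNING_GRID))
print(len(configs))  # 18
print(configs[0])    # {'batch_size': 16, 'learning_rate': 5e-05, 'num_epochs': 2}
```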

During pre-training, the NSP classifier sees only the final hidden state of the [CLS] token, so the model learns to gather sentence-level information at that position. For classification tasks, this final [CLS] hidden state is therefore used as an aggregate representation of the entire input sequence, and fine-tuning starts from a position that has already been trained for sequence-level decisions.
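A minimal sketch of what fine-tuning adds for a classification task: one linear layer over the final [CLS] vector followed by a softmax over the labels (the BERT paper computes log(softmax(CW^T)) for the [CLS] vector C and new weights W). It is written in plain Python so it runs anywhere; ClsClassificationHead is a hypothetical name, the initialization is illustrative, and the "encoder output" is a made-up list of vectors rather than real BERT activations.

```python
import math
import random


class ClsClassificationHead:
    """Linear layer + softmax on top of the final [CLS] hidden state.

    Framework-free sketch: in practice this is an nn.Linear trained jointly
    with all of BERT's parameters during fine-tuning.
    """

    def __init__(self, hidden_size, num_labels, seed=0):
        rng = random.Random(seed)
        # Small random initialization (illustrative, not BERT's exact scheme)
        self.weight = [[rng.gauss(0.0, 0.02) for _ in range(hidden_size)]
                       for _ in range(num_labels)]
        self.bias = [0.0] * num_labels

    def __call__(self, encoded_sequence):
        cls_vector = encoded_sequence[0]  # position 0 holds [CLS]
        logits = [sum(w * h for w, h in zip(row, cls_vector)) + b
                  for row, b in zip(self.weight, self.bias)]
        m = max(logits)
        exps = [math.exp(z - m) for z in logits]
        total = sum(exps)
        return [e / total for e in exps]


# Toy encoder output: 3 positions, hidden size 4 (made-up numbers)
encoded = [[0.5, -1.0, 0.3, 0.8], [0.1, 0.2, 0.3, 0.4], [0.0, 0.0, 0.0, 0.0]]
head = ClsClassificationHead(hidden_size=4, num_labels=3)
probs = head(encoded)
print([round(p, 3) for p in probs])
```

Only position 0 feeds the head directly; the other positions still matter, because self-attention has already mixed their information into the [CLS] vector before it reaches this layer.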
