A Step by Step Guide to Linear Discriminant Analysis (LDA) in Machine Learning

Arbash Hussain

Introduction

Welcome to the ninth blog post in our Machine Learning series! Today, we'll explore Linear Discriminant Analysis (LDA), a powerful algorithm used for dimensionality reduction and classification. By the end of this guide, you'll understand how LDA works and when to use it in your machine learning projects. As always, we will also code it from scratch in Python.

What is Linear Discriminant Analysis?

In the simplest terms, LDA is a technique that reduces the number of features (dimensions) in your data while keeping as much class-related information as possible. Imagine having a huge dataset with lots of columns, but you want to focus on just a few key ones that help separate your categories—LDA helps you do that. It's especially useful in classification tasks.

Mathematical Intuition Behind LDA

LDA seeks to project data onto a lower-dimensional space in such a way that it maximizes the separation between different classes. In other words, it reduces dimensionality while keeping class separability intact. The algorithm does this by computing two types of scatter matrices:

  1. Within-class scatter matrix: Measures how much data points within each class vary.

  2. Between-class scatter matrix: Measures how much the means of different classes differ from each other.

We want to maximize the between-class scatter while minimizing the within-class scatter, which will give us the best projection for separating classes. Let's walk through the steps in detail:

Step 1: Compute the Mean Vectors

For each class c, calculate the mean vector μc, which represents the average feature values for all data points in that class.

The mean vector for class c is given by:

$$\mu_c = \frac{1}{n_c} \sum_{i=1}^{n_c} x_i$$

Where:

  • n_c is the number of samples in class c.

  • x_i are the data points belonging to class c.
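
As a quick illustration, here is a minimal NumPy sketch that computes the per-class mean vectors on a tiny made-up dataset (the numbers are purely illustrative):

import numpy as np

# Tiny toy dataset: 6 samples, 2 features, 2 classes (illustrative values only)
X = np.array([[1.0, 2.0], [2.0, 3.0], [3.0, 3.0],
              [6.0, 5.0], [7.0, 8.0], [8.0, 8.0]])
y = np.array([0, 0, 0, 1, 1, 1])

# Mean vector for each class c: the average of the samples that belong to c
mean_vectors = {c: X[y == c].mean(axis=0) for c in np.unique(y)}
print(mean_vectors)  # class 0 -> roughly [2.0, 2.67], class 1 -> [7.0, 7.0]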

Step 2: Compute the Within-Class Scatter Matrix

The within-class scatter matrix measures the variance within each class. It’s calculated by summing up the scatter for each class separately.

$$S_W = \sum_{c=1}^{k} S_W^{(c)}$$

Where the scatter matrix for each class c is:

$$S_W^{(c)} = \sum_{i=1}^{n_c} (x_i - \mu_c)(x_i - \mu_c)^T$$

This represents the covariance of the data points within class c.
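
Continuing the same toy example, a minimal sketch of the within-class scatter computation could look like this:

import numpy as np

# Same illustrative toy data as in Step 1
X = np.array([[1.0, 2.0], [2.0, 3.0], [3.0, 3.0],
              [6.0, 5.0], [7.0, 8.0], [8.0, 8.0]])
y = np.array([0, 0, 0, 1, 1, 1])

n_features = X.shape[1]
S_W = np.zeros((n_features, n_features))
for c in np.unique(y):
    Xc = X[y == c]
    mean_c = Xc.mean(axis=0)
    # Sum of outer products (x_i - mu_c)(x_i - mu_c)^T for class c
    S_W += (Xc - mean_c).T @ (Xc - mean_c)
print(S_W)  # roughly [[4.0, 4.0], [4.0, 6.67]]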

Step 3: Compute the Between-Class Scatter Matrix

The between-class scatter matrix measures how much the class means differ from the overall mean. It’s calculated by comparing each class’s mean to the overall mean of the data.

Let the overall mean vector be μ:

$$\mu = \frac{1}{n} \sum_{i=1}^{n} x_i$$

Now, the between-class scatter matrix S_B is given by:

$$S_B = \sum_{c=1}^{k} n_c (\mu_c - \mu)(\mu_c - \mu)^T$$

We multiply by the number of samples n_c in each class because larger classes should have more influence on the overall scatter.

Where:

  • n_c is the number of samples in class c.

  • μc is the mean vector of class c.

  • μ is the overall mean vector.
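
A matching sketch for the between-class scatter, again on the illustrative toy data:

import numpy as np

# Same illustrative toy data as in the previous steps
X = np.array([[1.0, 2.0], [2.0, 3.0], [3.0, 3.0],
              [6.0, 5.0], [7.0, 8.0], [8.0, 8.0]])
y = np.array([0, 0, 0, 1, 1, 1])

n_features = X.shape[1]
mean_all = X.mean(axis=0)                       # overall mean vector mu
S_B = np.zeros((n_features, n_features))
for c in np.unique(y):
    Xc = X[y == c]
    n_c = Xc.shape[0]                           # number of samples in class c
    mean_diff = (Xc.mean(axis=0) - mean_all).reshape(-1, 1)
    S_B += n_c * mean_diff @ mean_diff.T        # n_c (mu_c - mu)(mu_c - mu)^T
print(S_B)  # roughly [[37.5, 32.5], [32.5, 28.17]]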

Step 4: Solve the Generalized Eigenvalue Problem

To find the optimal linear discriminants (the directions that maximize class separability), we need to solve the following generalized eigenvalue problem:

$$S_W^{-1} S_B \mathbf{w} = \lambda \mathbf{w}$$

Where:

  • w are the eigenvectors (these are the directions, or linear discriminants, along which the data will be projected).

  • λ are the eigenvalues (these tell us how much variance the corresponding eigenvector captures).
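
Plugging in the (rounded) scatter matrices from the toy example above, solving this in NumPy is a minimal sketch:

import numpy as np

# Rounded S_W and S_B from the toy example (illustrative values)
S_W = np.array([[4.0, 4.0], [4.0, 6.6667]])
S_B = np.array([[37.5, 32.5], [32.5, 28.1667]])

# Solve S_W^{-1} S_B w = lambda w as a standard eigenvalue problem
eigenvalues, eigenvectors = np.linalg.eig(np.linalg.inv(S_W) @ S_B)
print(eigenvalues)   # how much class separation each direction captures
print(eigenvectors)  # the candidate discriminant directions (as columns)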

Step 5: Select the Top n Discriminants

We sort the eigenvectors w by their corresponding eigenvalues λ in descending order. The eigenvectors corresponding to the largest eigenvalues capture the directions that maximize class separability.

If we choose to project the data into a lower dimension (let’s say 2D for visualization), we select the top 2 eigenvectors with the largest eigenvalues.
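
Note that for a problem with k classes, S_B has rank at most k − 1, so there are at most k − 1 useful discriminants (which is why the 3-class Iris dataset later in this post is projected onto 2 dimensions). Sorting and selecting the top directions is straightforward; the eigenvalues and eigenvectors below are hypothetical placeholders:

import numpy as np

# Hypothetical eig output for a 3-feature problem (placeholder values)
eigenvalues = np.array([0.2, 9.7, 0.0])
eigenvectors = np.eye(3)                        # columns are the corresponding eigenvectors

n_components = 2
idxs = np.argsort(np.abs(eigenvalues))[::-1]    # indices sorted by |eigenvalue|, descending
W = eigenvectors[:, idxs[:n_components]]        # keep the top n_components directions
print(W.shape)  # (3, 2)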

Step 6: Project the Data

Finally, the data is projected onto the new space formed by the selected eigenvectors:

$$X_{\text{new}} = X \mathbf{W}$$

Where:

  • X is the original dataset,

  • W is the matrix of eigenvectors (linear discriminants),

  • X_new is the transformed data in the new feature space.
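
To finish the walkthrough, here is a minimal sketch of the projection itself; the projection matrix W is made up for illustration:

import numpy as np

# Toy data from the earlier steps and a made-up 2x1 projection matrix W
X = np.array([[1.0, 2.0], [2.0, 3.0], [3.0, 3.0],
              [6.0, 5.0], [7.0, 8.0], [8.0, 8.0]])
W = np.array([[0.83], [0.55]])   # one discriminant direction (illustrative values)

X_new = X @ W                    # project the 6x2 data onto a 6x1 space
print(X_new.shape)               # (6, 1)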

Implementation Steps

Step 1: Initialize LDA

import numpy as np

class LDA:
    def __init__(self, n_components):
        self.n_components = n_components
        self.linear_discriminants = None

We create an LDA class. The n_components parameter tells us how many new features (linear discriminants) we want to extract.

Step 2: Calculate Scatter Matrices

    def fit(self, X, y):
        n_features = X.shape[1]
        labels = np.unique(y)

        mean_all = np.mean(X, axis=0)             # overall mean vector
        sw = np.zeros((n_features, n_features))   # within-class scatter
        sb = np.zeros((n_features, n_features))   # between-class scatter

        # For each class c: compute its mean, then add its contribution
        # to the within-class and between-class scatter matrices
        for c in labels:
            xc = X[y == c]
            mean_c = np.mean(xc, axis=0)

            # (n_features, n_c) dot (n_c, n_features) => (n_features, n_features)
            sw += (xc - mean_c).T.dot(xc - mean_c)

            # (n_features, 1) dot (1, n_features) => (n_features, n_features)
            nc = xc.shape[0]
            mean_diff = (mean_c - mean_all).reshape(n_features, 1)
            sb += nc * mean_diff.dot(mean_diff.T)

In this function, we calculate two matrices:

  • sw: It shows how much the data within each class scatters.

  • sb: This shows how much the classes differ from each other.

Step 3: Eigenvectors and Eigenvalues


        # Still inside fit: combine the scatter matrices
        A = np.linalg.inv(sw).dot(sb)

        # Eigenvalues and eigenvectors of S_W^{-1} S_B
        eigenvalues, eigenvectors = np.linalg.eig(A)
        eigenvectors = eigenvectors.T  # eig returns eigenvectors as columns; make each row an eigenvector
        idxs = np.argsort(abs(eigenvalues))[::-1]  # sort by eigenvalue, descending
        eigenvalues = eigenvalues[idxs]
        eigenvectors = eigenvectors[idxs]
        self.linear_discriminants = eigenvectors[: self.n_components]

This step, which still belongs inside the fit method, finds the eigenvalues and eigenvectors of S_W⁻¹S_B; they identify the most important directions (the linear discriminants) for separating the classes.

Step 4: Transform the Data

    def transform(self, X):
        # Project the data onto the selected linear discriminants
        return np.dot(X, self.linear_discriminants.T)

The transform method projects the original data into the new space formed by the linear discriminants.

Testing

For testing, we apply our LDA implementation to the Iris dataset.

if __name__ == "__main__":
    from sklearn.datasets import load_iris
    import matplotlib.pyplot as plt

    data = load_iris()
    X, y = data.data, data.target

    lda = LDA(n_components=2)
    lda.fit(X, y)
    x_projected = lda.transform(X)

    print("Shape of X before transforming:", X.shape)
    print("Shape of X after transforming:", x_projected.shape)

    x1 = x_projected[:, 0]
    x2 = x_projected[:, 1]

    plt.scatter(
        x1, x2, c=y, edgecolors="none", alpha=0.8, cmap="viridis"
    )

    plt.xlabel("linear discriminant 1")
    plt.ylabel("linear discriminant 2")
    plt.colorbar()
    plt.show()

Output

The script prints the shape of X before transforming, (150, 4), and after transforming, (150, 2), and then shows a scatter plot of the Iris samples projected onto the first two linear discriminants, coloured by class.

Common Misconceptions about LDA

  • LDA and PCA are the same: Nope! PCA focuses on maximizing the overall variance, while LDA focuses on maximizing the separation between classes (see the short comparison snippet after this list).

  • LDA can work on any type of data: Not quite. LDA needs labeled data, and it works best for classification tasks where the classes are reasonably well separated.
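
To make the first point concrete, here is a short comparison sketch using scikit-learn's built-in implementations (assuming scikit-learn is installed; the from-scratch class above is what the rest of this post uses):

from sklearn.datasets import load_iris
from sklearn.decomposition import PCA
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

X, y = load_iris(return_X_y=True)

# PCA is unsupervised: it never looks at the labels y
X_pca = PCA(n_components=2).fit_transform(X)

# LDA is supervised: it uses y to maximize class separation
X_lda = LinearDiscriminantAnalysis(n_components=2).fit_transform(X, y)

print(X_pca.shape, X_lda.shape)  # (150, 2) (150, 2)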

When to Apply LDA

  • You have a classification problem with labeled data.

  • You want to reduce the number of features while keeping the classes distinguishable.

  • Your classes are fairly well-separated.

Advantages of LDA

  • Simplicity: LDA is simple to implement and understand.

  • Efficient: It works well when the classes are linearly separable.

  • Low computational cost: Compared to some other algorithms, LDA is fast and lightweight.

Disadvantages of LDA

  • Linear separability: It doesn’t perform well when classes overlap in a non-linear way.

  • Sensitive to outliers: LDA is not robust to outliers and can be easily affected by them.

Practical Applications of LDA

  • Face recognition: LDA is often used to classify different people based on facial features.

  • Medical diagnosis: It can be applied to classify patients based on different test results.

  • Marketing: LDA helps classify customers into different groups based on behavior.

Conclusion

I hope this guide is useful to you. If so, please like and follow. You can also check out my other blogs in my series on machine learning algorithms. Your feedback and engagement are highly appreciated, so feel free to share your thoughts or ask questions in the comments section.
