Handling Imbalanced Datasets: Applying SMOTE for MultiClass Classification

Mohamad Mahmood

Handling imbalanced datasets, especially in multi-class classification problems, can be challenging. One popular technique for addressing this issue is the Synthetic Minority Over-sampling Technique (SMOTE). Here's an overview of how to apply SMOTE in multi-class classification:

Understanding the Problem

Imbalanced datasets occur when certain classes are underrepresented compared to others. In multi-class problems, this imbalance can lead to biased models that are more accurate for the majority class but perform poorly on minority classes.
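
To make the technique concrete before touching the real data, here is a minimal sketch of SMOTE on a synthetic three-class dataset; the make_classification setup and class weights are illustrative assumptions, not part of the tweet dataset used below:

# Minimal SMOTE illustration on a synthetic imbalanced dataset
# (assumes scikit-learn and imbalanced-learn are installed)
from collections import Counter
from sklearn.datasets import make_classification
from imblearn.over_sampling import SMOTE

# Three classes; class 2 is heavily underrepresented
X, y = make_classification(n_samples=1000, n_classes=3, n_informative=4,
                           weights=[0.60, 0.35, 0.05], random_state=42)
print("Before SMOTE:", Counter(y))

# SMOTE synthesizes new minority samples by interpolating between a
# minority point and one of its nearest minority-class neighbours
smote = SMOTE(random_state=42)
X_res, y_res = smote.fit_resample(X, y)
print("After SMOTE:", Counter(y_res))  # all classes now equal in size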

[0] Prep Dataset

[a] Set pandas column width

(optional but useful)

# Set the maximum column width to None
import pandas as pd
pd.set_option('display.max_colwidth', None)

[b] Get the dataset

# get source dataset
import pandas as pd
import ast

dset_url='https://archive.org/download/misc-dataset/airline_tweets_clean_label_count.csv'
df_airline_clean_label_count=pd.read_csv(dset_url)

# Convert stringified lists to Python lists
df_airline_clean_label_count['label'] = df_airline_clean_label_count['label'].apply(ast.literal_eval)

df_airline_clean_label_count.info()
df_airline_clean_label_count.head()

Output:

[c] Create work-in-progress dataframe, df_wip

# Create a work-in-progress copy
df_wip = df_airline_clean_label_count[df_airline_clean_label_count['label_count'] > 0].copy()

# # Select only the first 2000 rows
# df_wip = df_wip.head(2000)

# Display the result
print(len(df_wip))
df_wip.head()

Output:

[d] Visualize label distributions

import pandas as pd
import matplotlib.pyplot as plt
from collections import Counter

# define task
def plot_label_distribution(df, label_column):
    """
    Plots the distribution of labels from a specified label column in a DataFrame.

    Parameters:
    df (pd.DataFrame): The DataFrame containing the label data.
    label_column (str): The name of the column containing the labels.

    Returns:
    None
    """
    # Step 1: Flatten the list of labels
    all_labels = [label for sublist in df[label_column] for label in sublist]

    # Step 2: Count occurrences of each label
    label_counts = Counter(all_labels)

    # Step 3: Create a visualization
    plt.figure(figsize=(10, 6))
    plt.bar(label_counts.keys(), label_counts.values(), color='skyblue')
    plt.title('Label Distribution')
    plt.xlabel('Labels')
    plt.ylabel('Frequency')
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.show()

# run task
plot_label_distribution(df_wip, 'label')

Output:
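
If you want the exact numbers behind the chart, the same Counter logic can be printed directly; a small sketch reusing df_wip:

# Print numeric label counts (complements the bar chart above)
from collections import Counter

counts = Counter(label for sublist in df_wip['label'] for label in sublist)
for label, n in counts.most_common():
    print(f"{label}: {n}")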

[e] Prep for multiclass label

# Convert the 'label' column from list to string
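# Note: rows carrying more than one label are joined into a single
# combined class string (e.g. 'customer_service, lost_luggage' as a
# hypothetical example), so every row ends up with exactly one class.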
df_wip['label'] = df_wip['label'].apply(lambda x: ', '.join(x) if isinstance(x, list) else x)
display(df_wip.head(1))

Output:

Import necessary libraries for classification tasks

import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.multiclass import OneVsRestClassifier
from sklearn.multioutput import MultiOutputClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report
from sklearn.preprocessing import MultiLabelBinarizer

MultiClass Classification - Before SMOTE

# Import necessary libraries
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.preprocessing import LabelEncoder
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report

# Prepare the data for multiclass classification
# Assume df_wip is your DataFrame with 'clean' (text) and 'label' (single-label) columns

# Step 1: Encode labels
label1_encoder = LabelEncoder()
y1 = label1_encoder.fit_transform(df_wip['label'])

# Step 2: Vectorize the cleaned text
vectorizer1 = TfidfVectorizer()
X1 = vectorizer1.fit_transform(df_wip['clean'])

# Step 3: Split the data into training and testing sets
X1_train, X1_test, y1_train, y1_test = train_test_split(X1, y1, test_size=0.2, random_state=42)

# Step 4: Train the model using a Random Forest Classifier
model1 = RandomForestClassifier(n_estimators=100, random_state=42)
model1.fit(X1_train, y1_train)  # Train on the original training data (no SMOTE)

# Step 5: Make predictions on the test set
y1_pred = model1.predict(X1_test)

# Step 6: Evaluate the model
print("Classification Report for Random Forest (Multiclass without SMOTE):\n")
print(classification_report(y1_test, y1_pred, target_names=label1_encoder.classes_, zero_division=0))

Output:

Classification Report for Random Forest (Multiclass without SMOTE):

                         precision    recall  f1-score   support

     aircraft_condition       0.00      0.00      0.00         7
       customer_service       0.97      0.95      0.96       197
flight_booking_problems       1.00      0.91      0.95        65
      flight_experience       0.92      0.99      0.95       293
          food_beverage       1.00      0.50      0.67        20
           lost_luggage       0.98      0.98      0.98       166
      positive_feedback       0.95      0.99      0.97       302
        safety_security       1.00      0.40      0.57        15

               accuracy                           0.95      1065
              macro avg       0.85      0.71      0.76      1065
           weighted avg       0.95      0.95      0.95      1065

MultiClass Classification - After SMOTE

# Import necessary libraries
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.preprocessing import LabelEncoder
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report
from imblearn.over_sampling import SMOTE

# Prepare the data for multiclass classification
# Assume df_wip is your DataFrame with 'clean' (text) and 'label' (single-label) columns

# Step 1: Encode labels
label2_encoder = LabelEncoder()
y2 = label2_encoder.fit_transform(df_wip['label'])

# Step 2: Vectorize the cleaned text
vectorizer2 = TfidfVectorizer()
X2 = vectorizer2.fit_transform(df_wip['clean'])

# Step 3: Split the data into training and testing sets
X2_train, X2_test, y2_train, y2_test = train_test_split(X2, y2, test_size=0.2, random_state=42)

# Step 4: Apply SMOTE to handle class imbalance
smote = SMOTE(random_state=42)
X2_train_resampled, y2_train_resampled = smote.fit_resample(X2_train, y2_train)

# Step 5: Train the model using a Random Forest Classifier
model2 = RandomForestClassifier(n_estimators=100, random_state=42)
model2.fit(X2_train_resampled, y2_train_resampled)

# Step 6: Make predictions on the test set
y2_pred = model2.predict(X2_test)

# Step 7: Evaluate the model
print("Classification Report for Random Forest (Multiclass with SMOTE):\n")
print(classification_report(y2_test, y2_pred, target_names=label2_encoder.classes_, zero_division=0))

Output:

Classification Report for Random Forest (Multiclass with SMOTE):

                         precision    recall  f1-score   support

     aircraft_condition       1.00      0.14      0.25         7
       customer_service       0.98      0.94      0.96       197
flight_booking_problems       1.00      0.95      0.98        65
      flight_experience       0.96      1.00      0.98       293
          food_beverage       1.00      0.80      0.89        20
           lost_luggage       0.97      0.98      0.97       166
      positive_feedback       0.96      0.98      0.97       302
        safety_security       1.00      0.93      0.97        15

               accuracy                           0.97      1065
              macro avg       0.98      0.84      0.87      1065
           weighted avg       0.97      0.97      0.97      1065
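
As a sanity check on what SMOTE actually did to the training set, you can compare class counts before and after resampling; a minimal sketch reusing y2_train, y2_train_resampled, and label2_encoder from the code above:

# Sanity check: class counts in the training set before and after SMOTE
from collections import Counter

before = Counter(label2_encoder.inverse_transform(y2_train))
after = Counter(label2_encoder.inverse_transform(y2_train_resampled))
for cls in label2_encoder.classes_:
    print(f"{cls:>25}: {before[cls]:>5} -> {after[cls]:>5}")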

Analysis

The outputs above are classification reports for a multiclass classification problem (not multilabel), with and without SMOTE applied.


1. Data Context

  • The dataset contains airline-related tweets labeled with one of several categories (e.g., customer_service, flight_experience, food_beverage, etc.).

  • Each tweet ends up with exactly one label string (multi-label lists were joined into a single combined class in step [e]), making this a multiclass classification problem.

  • The labels are imbalanced, as evident from the support values in the classification reports:

    • Majority classes: flight_experience (293 samples), positive_feedback (302 samples), customer_service (197 samples).

    • Minority classes: aircraft_condition (7 samples), safety_security (15 samples), food_beverage (20 samples).


2. Without SMOTE

Key Observations:

  1. Performance on Majority Classes:

    • The model performs exceptionally well on majority classes like flight_experience, positive_feedback, customer_service, and lost_luggage. Precision, recall, and F1-scores are all above 0.9.

    • For example, flight_experience has an F1-score of 0.95, and positive_feedback has an F1-score of 0.97.

  2. Performance on Minority Classes:

    • The model struggles with minority classes due to their limited representation in the training data:

      • aircraft_condition: Precision, recall, and F1-score are all 0.00 because the model fails to predict this class entirely.

      • safety_security: Recall is 0.40 and F1-score is 0.57, indicating poor performance.

      • food_beverage: Recall is 0.50 and F1-score is 0.67, better than aircraft_condition but still suboptimal.

  3. Overall Metrics:

    • Accuracy: High (0.95) because the majority classes dominate the dataset.

    • Macro Average: Lower (0.76) due to poor performance on minority classes.

    • Weighted Average: High (0.95) because it is heavily influenced by the majority classes (see the quick check after this list).
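
The gap between the two averages is worth making concrete: the macro average is the unweighted mean of the per-class F1-scores, while the weighted average weights each class by its support. A quick check using the per-class values from the report above:

# Macro avg = unweighted mean of per-class F1; weighted avg = support-weighted
import numpy as np

f1 = np.array([0.00, 0.96, 0.95, 0.95, 0.67, 0.98, 0.97, 0.57])
support = np.array([7, 197, 65, 293, 20, 166, 302, 15])

print(round(f1.mean(), 2))                        # 0.76 (macro avg)
print(round(np.average(f1, weights=support), 2))  # 0.95 (weighted avg)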


3. With SMOTE

Key Observations:

  1. Improvement on Minority Classes:

    • SMOTE significantly improves performance on minority classes:

      • aircraft_condition: Precision is 1.00, recall improves to 0.14, and the F1-score rises to 0.25 (up from 0.00).

      • safety_security: Recall improves to 0.93, and the F1-score rises to 0.97 (up from 0.57).

      • food_beverage: Recall improves to 0.80, and the F1-score rises to 0.89 (up from 0.67).

  2. Performance on Majority Classes:

    • SMOTE does not degrade performance on majority classes:

      • flight_experience maintains high metrics (F1-score of 0.98).

      • positive_feedback also maintains high metrics (F1-score of 0.97).

      • customer_service shows a slight drop in recall (0.94 vs. 0.95) but remains robust overall.

  3. Overall Metrics:

    • Accuracy: Slightly higher (0.97 vs. 0.95) due to improved performance across all classes.

    • Macro Average: Improved (0.87 vs. 0.76) because minority classes now contribute more meaningfully.

    • Weighted Average: Improved (0.97 vs. 0.95) due to better handling of both majority and minority classes.


4. Comparison Without vs. With SMOTE

Metric                  Without SMOTE   With SMOTE
Accuracy                0.95            0.97
Macro Avg F1-Score      0.76            0.87
Weighted Avg F1-Score   0.95            0.97

Key Takeaways:

  1. SMOTE Improves Minority Class Performance:

    • Minority classes like aircraft_condition, safety_security, and food_beverage see significant improvements in recall and F1-scores.

    • This is because SMOTE oversamples these classes during training, giving the model more opportunities to learn their patterns.

  2. No Significant Degradation on Majority Classes:

    • Despite the introduction of synthetic samples, the model maintains strong performance on majority classes.

  3. Better Overall Metrics:

    • Both macro and weighted average F1-scores improve, indicating that SMOTE enhances the model's ability to generalize across all classes.

5. Recommendations

  1. Use SMOTE for Imbalanced Datasets:

    • If your dataset has significant class imbalance, SMOTE is a valuable tool to improve performance on minority classes without sacrificing accuracy on majority classes.

  2. Monitor Overfitting:

    • While SMOTE helps balance the dataset, it can sometimes lead to overfitting if the synthetic samples do not accurately represent the true distribution of the minority classes. Regularly validate the model on unseen data to ensure generalization.

  3. Experiment with Other Techniques:

    • Consider combining SMOTE with other strategies like undersampling majority classes or using ensemble methods (e.g., BalancedRandomForestClassifier) for further improvement; see the sketch after this list.

  4. Fine-Tune Hyperparameters:

    • Experiment with different hyperparameters for the Random Forest classifier (e.g., max_depth, min_samples_split) to optimize performance.
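
As a sketch of recommendation 3: imbalanced-learn ships BalancedRandomForestClassifier, which undersamples the majority classes inside each bootstrap sample instead of oversampling beforehand. This is one possible variation, not part of the original walkthrough; it reuses X2_train, y2_train, X2_test, y2_test, and label2_encoder from the code above:

# Alternative to SMOTE: an ensemble that balances each bootstrap sample
from imblearn.ensemble import BalancedRandomForestClassifier
from sklearn.metrics import classification_report

brf = BalancedRandomForestClassifier(n_estimators=100, random_state=42)
brf.fit(X2_train, y2_train)

print(classification_report(y2_test, brf.predict(X2_test),
                            target_names=label2_encoder.classes_,
                            zero_division=0))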

Conclusion

The application of SMOTE in this multiclass classification problem leads to noticeable improvements in performance, particularly for minority classes. It addresses the class imbalance issue effectively, resulting in higher macro and weighted average F1-scores while maintaining strong performance on majority classes.

Colab Notebook:

https://colab.research.google.com/drive/1ohvMIKD9tlTHAmUD2xyv2Oe3px0lvPNK?usp=sharing

Written by

Mohamad Mahmood

Mohamad's interest is in Programming (Mobile, Web, Database and Machine Learning). He studies at the Center For Artificial Intelligence Technology (CAIT), Universiti Kebangsaan Malaysia (UKM).