A Step-by-Step Guide to K-Nearest Neighbors (KNN) in Machine Learning
Introduction
Welcome back, everyone, to the 3rd blog post in our Machine Learning Algorithms Series! Today, we'll dive into K-Nearest Neighbors (KNN), a fundamental algorithm in machine learning. We'll be implementing the KNN algorithm from scratch in Python. By the end of this blog, you'll have a clear understanding of how KNN works, how to implement it, and when to use it. Let's get started!
What is KNN?
K-Nearest Neighbors (KNN) is a straightforward yet powerful supervised machine learning algorithm used for both classification and regression tasks. Its simplicity lies in its non-parametric nature: it makes no assumptions about the underlying data distribution. Instead, KNN works by finding the 'k' closest data points (neighbors) in the training dataset to a new input point and making predictions based on these neighbors.
For classification tasks, KNN predicts the class label of the new data point by a majority vote among its nearest neighbors. The class label that appears most frequently among the nearest neighbors is assigned to the new data point.
For regression tasks, KNN predicts the value of the new data point by taking the average of the values of its nearest neighbors. This average value serves as the predicted output for the new data point.
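For instance, if k = 3 and the three nearest neighbors of a new point have labels ['cat', 'cat', 'dog'], the majority vote predicts 'cat'; if their target values are 2.0, 3.0, and 4.0, the regression prediction is their average, 3.0. Here is a toy sketch of both rules (separate from the implementation we build below):

from collections import Counter
import numpy as np

# Classification: majority vote among the k nearest labels
neighbor_labels = ["cat", "cat", "dog"]
print(Counter(neighbor_labels).most_common(1)[0][0])  # -> cat

# Regression: average of the k nearest target values
neighbor_values = [2.0, 3.0, 4.0]
print(np.mean(neighbor_values))  # -> 3.0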
Step by Step Implementation
Importing Necessary Libraries
We start by importing the necessary libraries. These help us handle data, compute distances, and visualize results.
import numpy as np
from collections import Counter
from sklearn.model_selection import train_test_split
from sklearn.datasets import make_regression, make_classification
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
numpy: For numerical operations.
Counter: For counting occurrences of elements.
train_test_split: To split data into training and testing sets.
make_regression and make_classification: To generate synthetic datasets.
matplotlib: For plotting.
Defining the Euclidean Distance Function
This function calculates the Euclidean distance between two points. It’s essential for determining the nearest neighbors.
def euclidean_distance(x1, x2):
    return np.sqrt(np.sum((x1 - x2) ** 2))
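As a quick check, the distance between the points (0, 0) and (3, 4) should be 5:

a = np.array([0, 0])
b = np.array([3, 4])
print(euclidean_distance(a, b))  # 5.0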
Implementing the KNN Class
The KNN class encapsulates the algorithm’s logic.
Initialization
The __init__ method initializes the KNN class with the number of neighbors k and a flag isclassifier to indicate whether the task is classification or regression.
class KNN:
    def __init__(self, isclassifier, k=3):
        self.k = k
        self.isclassifier = isclassifier
Training
The fit method stores the training data. There's no complex training process in KNN; it just stores the data.
    def fit(self, x, y):
        self.x_train = x
        self.y_train = y
Prediction
The predict method generates predictions for the test data by calling _predict_single for each test point.
    def predict(self, X):
        self.x_test = X
        predictions = [self._predict_single(x) for x in X]
        return predictions
Single Prediction
The _predict_single method calculates the distances from the test point to all training points, finds the k nearest neighbors, and makes a prediction based on the type of task (classification or regression).
    def _predict_single(self, x1):
        # Find the distance between x1 and every point in x_train
        distances = [euclidean_distance(x1, x2) for x2 in self.x_train]
        # Sort the distances and take the indices of the k closest training points
        k_indices = np.argsort(distances)[:self.k]
        k_nearest_nbrs = [self.y_train[i] for i in k_indices]
        if self.isclassifier:
            # Classification: majority vote among the k nearest labels
            prediction = Counter(k_nearest_nbrs).most_common()
            return prediction[0][0]
        else:
            # Regression: average of the k nearest target values
            return np.mean(k_nearest_nbrs)
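Before wiring up the full test script, here is a minimal sanity check of the class on a tiny hand-made dataset (the points and labels below are made up purely for illustration):

x_train = np.array([[0, 0], [0, 1], [5, 5], [6, 5]])
y_train = np.array([0, 0, 1, 1])

knn = KNN(isclassifier=True, k=3)
knn.fit(x_train, y_train)
print(knn.predict(np.array([[0.5, 0.5], [5.5, 5.0]])))  # expected: [0, 1]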
Main Function for Testing
This section tests our KNN implementation with both classification and regression tasks.
Classification Task
if __name__ == "__main__":
    cmap = ListedColormap(["#FF0000", "#00FF00", "#0000FF"])

    # Classification
    X, y = make_classification(n_samples=100, n_features=2, n_informative=2, n_redundant=0, n_clusters_per_class=1, random_state=44)
    x_train, x_test, y_train, y_test = train_test_split(X, y, random_state=42)

    classifier = KNN(isclassifier=True, k=5)
    classifier.fit(x_train, y_train)

    preds = classifier.predict(x_test)
    accuracy = np.sum(preds == y_test) / len(y_test)
    print("On Classification Task")
    print("Accuracy:", accuracy)
Data Generation: Creates a synthetic dataset for classification.
Data Splitting: Splits the data into training and testing sets.
Training: Stores the training data in KNN classifier object.
Prediction and Accuracy: Predicts the labels for the test set and calculates accuracy.
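The ListedColormap defined at the top of the main block isn't used in the snippet above; presumably it is meant for visualizing the data. A minimal plotting sketch, assuming the classification variables (X, y, cmap) from above are still in scope:

# Scatter plot of the synthetic classification data, colored by class
plt.figure(figsize=(8, 6))
plt.scatter(X[:, 0], X[:, 1], c=y, cmap=cmap, edgecolor="k", s=40)
plt.title("Synthetic classification dataset")
plt.xlabel("Feature 1")
plt.ylabel("Feature 2")
plt.show()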
Regression Task
    # Regression
    X, y = make_regression(n_samples=100, n_features=1, noise=0.1)
    x_train, x_test, y_train, y_test = train_test_split(X, y, random_state=42)

    regressor = KNN(isclassifier=False, k=5)
    regressor.fit(x_train, y_train)

    rmse = np.sqrt(np.mean((y_test - regressor.predict(x_test)) ** 2))
    print("On Regression Task")
    print("RMSE:", rmse)
Data Generation: Creates a synthetic dataset for regression.
Data Splitting: Splits the data into training and testing sets.
Training: Trains the KNN regressor.
Prediction and RMSE: Predicts the values for the test set and calculates Root Mean Squared Error (RMSE).
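As an optional cross-check (this uses scikit-learn's built-in KNN estimators, which are separate from our from-scratch implementation), the same regression split can be fed to KNeighborsRegressor and the RMSE compared; KNeighborsClassifier works the same way for the classification task.

from sklearn.neighbors import KNeighborsRegressor

# Fit scikit-learn's KNN regressor on the same regression split used above
sk_reg = KNeighborsRegressor(n_neighbors=5)
sk_reg.fit(x_train, y_train)
sk_preds = sk_reg.predict(x_test)
print("sklearn RMSE:", np.sqrt(np.mean((y_test - sk_preds) ** 2)))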
Output
- Our KNN implementation performs quite well on both the classification and regression tasks.
Common Misconceptions about KNN
KNN is always accurate: KNN can be effective but is sensitive to noise and irrelevant features. Proper feature selection and preprocessing are essential.
KNN works well with high-dimensional data: In high-dimensional spaces, the concept of distance becomes less meaningful (curse of dimensionality).
KNN is computationally efficient: Prediction can be slow for large datasets due to the need to calculate distances to all training points. Techniques like KD-Trees can help (see the sketch below).
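On the last point: scikit-learn's KNN estimators can switch from brute-force search to a KD-Tree index, which speeds up neighbor lookups on low-dimensional data. A minimal sketch, assuming scikit-learn is available:

from sklearn.datasets import make_classification
from sklearn.neighbors import KNeighborsClassifier

X, y = make_classification(n_samples=10000, n_features=5, random_state=0)

# algorithm="kd_tree" builds a KD-Tree index instead of scanning every training point
clf = KNeighborsClassifier(n_neighbors=5, algorithm="kd_tree")
clf.fit(X, y)
print(clf.predict(X[:3]))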
When to Apply K-Nearest Neighbors: Key Points to Consider
1. Type of Task: Classification or Regression
Classification: Classifying a new sample based on the majority class of its nearest neighbors.
Regression: Predicting a continuous value based on the average value of its nearest neighbors.
2. Dataset Size and Dimensionality
Small to Medium-Sized Datasets: KNN works well with small to medium-sized datasets, since every prediction scans the entire training set.
Low to Moderate Dimensionality: KNN performs best in low to moderate dimensions, where distances remain meaningful.
3. Data Distribution
Locally Homogeneous Data: KNN assumes that nearby points are similar.
Smooth Decision Boundaries: Effective when decision boundaries between classes are smooth.
4. No Assumption of Data Distribution
Non-Parametric Nature: KNN makes no assumptions about data distribution, making it flexible and model-free.
Advantages of KNN
Simplicity: Easy to understand and implement.
Versatility: Suitable for both classification and regression tasks.
No Training Phase: No complex training process—just storing the dataset.
Disadvantages of KNN
Computationally Intensive: Prediction can be slow for large datasets.
Sensitivity to Irrelevant Features and Scale: All features contribute equally to the distance, which can be problematic if some features are irrelevant or measured on very different scales (see the scaling sketch after this list).
Curse of Dimensionality: Performance degrades in high-dimensional spaces.
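Because every feature enters the Euclidean distance with equal weight, standardizing features before fitting usually helps. A minimal sketch using scikit-learn's StandardScaler together with our KNN class (the dataset here is synthetic and just for illustration):

from sklearn.preprocessing import StandardScaler
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=200, n_features=4, random_state=0)
x_train, x_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Standardize so that no single feature dominates the distance computation
scaler = StandardScaler()
x_train = scaler.fit_transform(x_train)  # fit the scaler on training data only
x_test = scaler.transform(x_test)        # apply the same statistics to the test data

knn = KNN(isclassifier=True, k=5)
knn.fit(x_train, y_train)
preds = knn.predict(x_test)
print("Accuracy:", np.mean(np.array(preds) == y_test))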
Practical Applications
Image Recognition: KNN can be used for tasks like handwritten digit recognition.
Recommender Systems: Helps in collaborative filtering by finding similar users or items.
Medical Diagnosis: Assists in diagnosing diseases based on historical patient data.
Conclusion
I hope this guide has been helpful and encourages you to explore and experiment further with K-Nearest Neighbors (KNN). If you liked this blog, please leave a like and a follow. You can also check out my other blogs on machine learning algorithms; I've been posting them as a series, and I hope you enjoy them.