Fine-tune an ML Model in Under 50 Lines of Code with fastai
Our Machine Learning journey starts with building a very simple model for recognising objects in images. It is worth mentioning that we are not going to create a model from scratch, nor train it from scratch. We are going to use a pre-trained model and fine-tune it to our needs. It is also worth mentioning that we will be using the fastai library, which is a high-level wrapper around PyTorch. It abstracts away a lot of the complexity that we would have to deal with if we were to use PyTorch directly or build our own model from scratch. In fact, as far as I know, it is fastai's mission to make Deep Learning more accessible to everyone, and I think they are doing a great job at it. So, let's get started.
Problem Definition and the fastai Course
Before we dive into the code, let's define the problem we are trying to solve. Let's say that we have a bunch of images of cats and dogs, and we want to build a model that can recognize whether an image contains a cat or a dog. This is a very simple problem, but it will allow us to get familiar with the fastai library and the process of fine-tuning a model. What's more, if you would like to dig deeper and understand more, I strongly recommend checking out the original fastai course, which is available for free on fast.ai, as well as reading the book Deep Learning for Coders with fastai and PyTorch, which is also available online: https://github.com/fastai/fastbook.
Setting up the environment
Unlike the approach suggested in the course, I would rather show you how to create and fine-tune the model entirely on your own machine, locally in an IDE, without any cloud services or Python notebooks. For reference, I used Python 3.11.6, and the two main libraries used in the project are:
duckduckgo_search ver: 4.2
fastai ver: 2.7.13
I also used VSCode as my IDE and macOS Ventura as my operating system. However, I will not go into details on how to install Python or VSCode. Python 3 installed on your machine and some kind of IDE should be enough to follow along.
First of all you will need to create a new folder for your project, and then create a new virtual environment for it. To create a new virtual environment with venv, you can run the following command:
python3 -m venv env
This will create a new directory called env which will contain the virtual environment for your new project. After creating the virtual environment, you can activate it using the following command:
On macOS or Linux:
source env/bin/activate
On Windows:
.\env\Scripts\activate
Once you activate the virtual environment, you can create a new file which will be the entry file for our project. I called mine main.py. In this file we will import all the libraries we need, and write the code for our project.
Last but not least, we need to install two main libraries: fastai, the PyTorch wrapper, and duckduckgo_search, which will allow us to search for images on the internet. To install them, run the following command:
pip install fastai duckduckgo_search
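To confirm that the install landed in the active virtual environment, a small check like the one below can help. It is only a sketch using the standard library; the helper name installed_version is made up for this example, and the package names are the ones passed to pip above.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(package_name):
    """Return the installed version of a package, or None if it is missing."""
    try:
        return version(package_name)
    except PackageNotFoundError:
        return None


# Package names as used on the pip command line above.
for name in ("fastai", "duckduckgo_search"):
    print(f"{name}: {installed_version(name) or 'not installed'}")
```

If either line prints "not installed", the virtual environment is most likely not activated in the terminal you are using.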
Coding Time
Open the main.py file and import the libraries we will need:
from duckduckgo_search import DDGS
from fastdownload import download_url
from fastcore.all import *
from time import sleep
from fastai.vision.all import *
Data is the core element of our project; in this case it will be images of cats and dogs. We will need to download them from the internet, and then split them into training and validation sets. We could of course download and split them manually, but that would be a lot of work.
To find the images we will use the duckduckgo_search library. It allows us to search for images on duckduckgo.com. Let's create a search function which takes a query and the maximum number of images we want as parameters, and returns a list of image URLs.
def search_images(keywords, max_images=30):
    ddgs = DDGS()
    print(f"Searching for '{keywords}'...")
    return L(ddgs.images(keywords=keywords, max_results=max_images)).itemgot('image')
Please note that this may look different in the fastai course, because the library has changed since then. The same applies to the future: the code we have just written may become obsolete one day, which is why I recommend always checking the documentation of the library you are using.
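The itemgot('image') call in search_images picks the value stored under the 'image' key of every search result. As a rough illustration of that step in plain Python, here is a sketch that assumes each result is a dictionary; the helper name extract_image_urls and the sample data are invented for this example and are not part of either library.

```python
def extract_image_urls(results, key="image"):
    """Collect the URL stored under `key` from each search result dict."""
    urls = []
    for item in results:
        url = item.get(key) if isinstance(item, dict) else None
        if url:  # skip results with a missing or empty URL
            urls.append(url)
    return urls


# Hypothetical sample shaped like the dicts the search is assumed to yield.
sample_results = [
    {"title": "A dog", "image": "https://example.com/dog1.jpg"},
    {"title": "No image field"},
    {"title": "Another dog", "image": "https://example.com/dog2.jpg"},
]
print(extract_image_urls(sample_results))
```

Unlike itemgot, this version silently skips results without the key, which may be handy if the shape of the results changes in a future release.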
Let's also verify that the function works as expected by calling it with a query and the number of images we want:
dog_photo_url = search_images('dog photo', max_images=1)[0]
print(dog_photo_url)
The code above uses the search_images function to search for 'dog photo' on duckduckgo.com and returns a list of image URLs. It then takes the first URL from that list and prints it.
You can test the code by running the following command:
python main.py
If you can see the url of the image, it means that the function works as expected. Now, let's download the image using the download_url function from fastdownload library:
download_url(dog_photo_url, 'dog.jpg', show_progress=True)
The code above will download the image from the url, and save it as dog.jpg file. If you run the code again you should see a "dog.jpg" file in your project directory.
Next, we can do the same for a cat: search for one image of a cat, download it, and save it in the project's directory so that we can use it later for testing our model.
download_url(search_images('cat photo', max_images=1)[0], 'cat.jpg', show_progress=True)
After running main.py again, you should see a "cat.jpg" file in your project directory.
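A download can succeed and still leave you with something that is not an image, for example an HTML error page saved under a .jpg name. The sketch below is one possible sanity check that uses only the standard library: it reads the first bytes of the file and compares them with the signatures of a few common formats. The helper name sniff_image_type is made up for this example, and formats outside the small table (such as WebP) are reported as None.

```python
from pathlib import Path

# Leading bytes ("magic numbers") of a few common image formats.
IMAGE_SIGNATURES = {
    "jpeg": b"\xff\xd8\xff",
    "png": b"\x89PNG\r\n\x1a\n",
    "gif": b"GIF8",
}


def sniff_image_type(path):
    """Return 'jpeg', 'png' or 'gif' based on the file's first bytes, else None."""
    file_path = Path(path)
    if not file_path.is_file():
        return None
    with file_path.open("rb") as handle:
        header = handle.read(8)
    for name, signature in IMAGE_SIGNATURES.items():
        if header.startswith(signature):
            return name
    return None


for filename in ("dog.jpg", "cat.jpg"):
    print(filename, "->", sniff_image_type(filename))
```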
Data for Fine-Tuning Needed
Now that we have our dog and cat photos to verify our model later, we can start gathering data for fine-tuning. We will need to download a bunch of images of cats and dogs and put them into separate folders named accordingly. Let's write a piece of code that will do that for us:
searches = 'dog','cat'
path = Path('dog_or_cat')
for animal in searches:
    dest = (path/animal)
    dest.mkdir(exist_ok=True, parents=True)
    download_images(dest, urls=search_images(f'{animal} photo'))
    sleep(10) # Pause between searches to avoid over-loading server
    download_images(dest, urls=search_images(f'{animal} photo'))
    sleep(10)
    resize_images(path/animal, max_size=400, dest=path/animal)
failed = verify_images(get_image_files(path))
failed.map(Path.unlink)
print(f'{len(failed)} images removed')
The code above will create a new directory called "dog_or_cat" in our project directory, and then download images of dogs and cats into separate subfolders. It will also shrink the images so that their longer side is at most 400 pixels, and remove any files that cannot be opened as images. Note that the loop runs the same search twice per animal, so the second pass may largely fetch images that were already downloaded. The code may take a while to run, depending on your internet connection and the number of images you want to download. You can change the number of images by passing a different max_images value to search_images. However, please note that it's good to have at least 30 images of each class, so that the model can learn properly.
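Before moving on, it can be useful to check how many files ended up in each class folder and whether any of them are exact copies of each other, since copies that land on both sides of the training/validation split can make the results look better than they are. The sketch below uses only the standard library; the helper names are made up for this example, and it only detects byte-identical files, not images that merely look the same.

```python
import hashlib
from collections import defaultdict
from pathlib import Path


def count_files_per_class(root):
    """Count files in each immediate subfolder of `root` (one folder per class)."""
    counts = {}
    for class_dir in Path(root).iterdir():
        if class_dir.is_dir():
            counts[class_dir.name] = sum(1 for f in class_dir.iterdir() if f.is_file())
    return counts


def find_duplicate_files(root):
    """Group files under `root` whose bytes are identical (same SHA-256 digest)."""
    by_digest = defaultdict(list)
    for file_path in sorted(Path(root).rglob("*")):
        if file_path.is_file():
            digest = hashlib.sha256(file_path.read_bytes()).hexdigest()
            by_digest[digest].append(file_path)
    return [paths for paths in by_digest.values() if len(paths) > 1]


data_dir = Path("dog_or_cat")  # the folder created by the download loop
if data_dir.exists():
    print(count_files_per_class(data_dir))
    print(f"{len(find_duplicate_files(data_dir))} groups of identical files")
```

Each group returned by find_duplicate_files lists paths with the same content, so keeping the first path of a group and deleting the rest would be one way to clean up.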
Fine-tuning Time
Now that we have our data ready, we can start fine-tuning our model. First of all we need to create a DataBlock, which is used to create a DataLoaders object. The DataLoaders object is then passed to a Learner object, which is what we use to train our model. Let's start with the DataBlock:
dls = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    get_items=get_image_files,
    splitter=RandomSplitter(valid_pct=0.2, seed=42),
    get_y=parent_label,
    item_tfms=[Resize(192, method='squish')]
).dataloaders(path, bs=32)
The code above creates the DataLoaders object, which feeds batches of images and labels to the model during fine-tuning. Let's break it down a little bit:
DataBlock: This is a class provided by the fastai library for defining how to load and process data for a machine learning model.
blocks=(ImageBlock, CategoryBlock): This specifies the type of data the DataBlock will work with. In this case, it's working with image data (ImageBlock) and category labels (CategoryBlock).
get_items=get_image_files: This specifies the function to use for getting the items (in this case, image files) that the DataBlock will work with.
splitter=RandomSplitter(valid_pct=0.2, seed=42): This specifies how to split the data into training and validation sets. It uses a random splitter with 20% of the data set aside for validation and a seed value of 42 for reproducibility.
get_y=parent_label: This specifies how to get the labels for the data. In this case, it's using the parent directory of the image file as the label.
item_tfms=[Resize(192, method='squish')]: This specifies the transformations to apply to each item. In this case, every image is resized to 192x192 pixels, and the "squish" method stretches or compresses the image to fit that square instead of cropping it.
.dataloaders(path, bs=32): This creates the dataloaders for the data block, where path is the directory containing the image files, and bs=32 specifies the batch size for the dataloaders.
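To make the labelling and splitting steps more concrete, here is a plain-Python sketch of the two ideas: taking the label from the parent folder name, and shuffling indices with a fixed seed before holding out 20% for validation. The function names are made up for this example, the file paths are placeholders, and the sketch will not reproduce the exact split fastai produces because it uses a different random number generator.

```python
import random
from pathlib import Path


def parent_folder_label(file_path):
    """Label an item with the name of the folder that contains it."""
    return Path(file_path).parent.name


def random_split(n_items, valid_pct=0.2, seed=42):
    """Shuffle indices with a fixed seed and hold out `valid_pct` for validation."""
    indices = list(range(n_items))
    random.Random(seed).shuffle(indices)
    n_valid = int(n_items * valid_pct)
    return indices[n_valid:], indices[:n_valid]  # (train, valid)


# Placeholder paths that mimic the dog_or_cat folder layout.
files = [f"dog_or_cat/dog/{i}.jpg" for i in range(5)]
files += [f"dog_or_cat/cat/{i}.jpg" for i in range(5)]

train_idx, valid_idx = random_split(len(files))
print("train:", len(train_idx), "valid:", len(valid_idx))
print("label of first file:", parent_folder_label(files[0]))
```

Because the seed is fixed, calling random_split again with the same arguments gives the same split, which is the point of passing seed=42 to RandomSplitter.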
Now that we have our DataBlock ready, we can create a Learner object, which will be used to train our model:
learn = vision_learner(dls, resnet18, metrics=error_rate)
The code above creates a Learner object, which will be used to train our model. What we do here is pass the dls object, which supplies the data for our model, and resnet18, a pre-trained architecture from torchvision that fastai makes available. We also specify the error_rate metric, which will be used to evaluate our model.
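The error rate is simply the share of validation images the model labels incorrectly, or one minus the accuracy. The plain-Python sketch below shows the calculation on a handful of made-up labels; it illustrates the idea and is not fastai's implementation.

```python
def error_rate(predicted_labels, true_labels):
    """Fraction of predictions that do not match the true label."""
    if len(predicted_labels) != len(true_labels):
        raise ValueError("both sequences must have the same length")
    if not true_labels:
        raise ValueError("at least one label is required")
    wrong = sum(1 for p, t in zip(predicted_labels, true_labels) if p != t)
    return wrong / len(true_labels)


# Made-up predictions for illustration only.
predicted = ["dog", "cat", "dog", "dog"]
actual = ["dog", "cat", "cat", "dog"]
print(error_rate(predicted, actual))
```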
Last but not least, we will need to fine-tune our model. To do that we will use the fine_tune method provided by the Learner object:
learn.fine_tune(10)
The code above fine-tunes our model for 10 epochs. You can change the number of epochs by changing the argument passed to fine_tune. However, it's good to find a sweet spot where the model is neither overfitting nor underfitting. You can read more about it in the fastai course. Training for too many epochs is not necessarily a good thing, and neither is too little data or too few epochs.
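One simple way to look for that sweet spot is to watch the validation loss that is reported for every epoch: while it keeps falling the model is still improving, and once it starts rising again the model is probably beginning to overfit. The sketch below applies that rule of thumb to a list of made-up loss values; the helper names are invented for this example and the numbers do not come from a real training run.

```python
def best_epoch(valid_losses):
    """Return the 0-based index of the epoch with the lowest validation loss."""
    if not valid_losses:
        raise ValueError("need at least one epoch")
    return min(range(len(valid_losses)), key=lambda i: valid_losses[i])


def epochs_since_best(valid_losses):
    """How many epochs have passed without beating the best validation loss."""
    return len(valid_losses) - 1 - best_epoch(valid_losses)


# Made-up validation losses: they improve, then drift upward again.
losses = [0.62, 0.41, 0.30, 0.27, 0.29, 0.33, 0.38]
print("best epoch:", best_epoch(losses))
print("epochs since best:", epochs_since_best(losses))
```

If several epochs have passed since the best one, a smaller number of epochs is likely a better choice for the next run.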
Model Evaluation
Our model is fine-tuned and our data is gathered, so now it's time to evaluate it. Let's give it the dog and cat photos from the beginning and see what it predicts:
predicted_animal,prediction_index,prediction_probability = learn.predict(PILImage.create('dog.jpg'))
print(f"The photo depicts: {predicted_animal}. Probability: {prediction_probability[prediction_index]:.4f}")
predicted_animal,prediction_index,prediction_probability = learn.predict(PILImage.create('cat.jpg'))
print(f"The photo depicts: {predicted_animal}. Probability: {prediction_probability[prediction_index]:.4f}")
What we do here is use the predict method provided by the Learner object and pass it the image we want to classify. Then we print the result, which is a tuple containing the predicted class, the index of that class, and the probabilities for all classes. Each probability is a number between 0 and 1. Feel free to play around with this code, break it, rebuild it, or even try to predict your own images.
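To see how the three returned values relate to each other, here is a plain-Python sketch that starts from a list of class names and a list of probabilities and picks the most likely class. The helper name decode_prediction and the numbers are made up for this example; with fastai the class names would come from the data, and the sketch assumes they are in alphabetical order so that 'cat' comes before 'dog'.

```python
def decode_prediction(vocab, probabilities):
    """Return (label, index, probability) for the most likely class."""
    if len(vocab) != len(probabilities):
        raise ValueError("vocab and probabilities must have the same length")
    index = max(range(len(probabilities)), key=lambda i: probabilities[i])
    return vocab[index], index, probabilities[index]


# Made-up numbers; class order is assumed to be alphabetical.
label, index, probability = decode_prediction(["cat", "dog"], [0.0312, 0.9688])
print(f"The photo depicts: {label}. Probability: {probability:.4f}")
```

This is why the article's print statement indexes the probabilities with the returned index: it selects the probability of the class that was predicted.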
Conclusion
In this article we have learned how to fine-tune a pre-trained model using the fastai library. We have also learned how to gather data for our model, and how to evaluate it. I hope you have enjoyed reading this article, and that you have learned something new. If you have any questions or comments, please feel free to leave them below. I will be happy to answer them. This article also marks the end of the first part of the fastai course series. Again, I strongly recommend checking out the original fastai course, as well as the book. In the next part we will learn how to deploy our model to the web, and make it available for everyone to use. So stay tuned, and see you in the next article.
Written by
Pawel Kowalewski
Software Engineer, Educator, Content Creator, and a former University Lecturer. Coding and programming enthusiast with a primary focus on Web Development and related technologies that hold significant potential for positive impact. Beyond my professional pursuits, I'm an amateur electro-acoustic guitarist, motorcyclist, and rally enthusiast outside of working hours.