How to Build an Image Classification App Using TensorFlow Lite and Flutter

Yashraj Tarte
Jul 01, 2024
13 min read

Machine learning and artificial intelligence open up new possibilities for mobile apps. With machine learning, apps can recognize speech, classify images, and interpret body language, and AI offers fresh, engaging ways to connect with people around the world. But how exactly do we weave machine learning into our mobile applications?

Developing mobile applications with integrated machine learning has long been a daunting task. However, tools like Firebase ML and TensorFlow Lite have made it far more approachable. They offer pre-trained machine learning models as well as resources for training and importing custom models. But how do we build a captivating experience on top of these models? Enter Flutter.

The Flutter SDK is an open-source UI toolkit developed by Google and its community for building applications for Android, iOS, web, and desktop from a single codebase. At its heart, Flutter combines a high-performance rendering engine with the Dart programming language. With Flutter, we can build mobile apps with machine learning features such as image classification and object detection for both Android and iOS.

In this article, we will harness the combined power of Flutter and on-device machine learning to develop a Flutter application capable of detecting animals.

Let’s take a quick peek at what we’re going to create today:

Get the Dataset

I downloaded the dataset from here. Download it and pick your favorite classes (in this case, animals). Feel free to use any other dataset if you prefer.

This article will be divided into two modules:

  • Training a CNN (Convolutional Neural Network) for image classification.

  • Integrating the trained .tflite model into a Flutter application for on-device inference.

The entire project is available to clone from my GitHub repository, so go have a look at it.

Module 1: Neural Networks Training

There are two ways to train our model:

Option 1: Using Teachable Machine

What is Teachable Machine? Well, it's a nifty web-based tool for creating machine learning models quickly and easily, making the process accessible to everyone. You can use it to recognize images, sounds, or poses.

Now, let's dive into how we can craft our machine learning model using Teachable Machine.

Firstly, head over to Teachable Machine and launch an Image Project. To kickstart the training process, we'll define five classes based on my dataset: "Elephant," "Kangaroo," "Panda," "Penguin," and "Tiger." Then, upload your training images to begin building your model.

Teachable Machine

Replace Class 1 with Elephant and click on Upload to start uploading the training images for the elephants.

Click on the upload button and either select images from your folder or simply drag and drop them into the designated area.

Now repeat the process for the other classes. Change Class 2 to Kangaroo and upload the training images for kangaroos. Change Class 3 to Panda and upload the training images for pandas. Follow the same steps for the remaining animals.

Click on the "Train Model" button after completing the upload of all the images.

It's going to take some time, especially if you've uploaded a hefty amount of training images. So, sit back, relax, and savor your coffee while the magic happens!

Once your model is trained, click on "Export Model" and proceed to download the TensorFlow Lite Floating Point Model.

Exporting the TensorFlow Lite Floating Point Model

Make sure your model recognizes our cute penguin!

You can find my model.tflite & labels.txt files directly from here.

Option 2: Using TensorFlow and Keras

What is TensorFlow? Well, it's an open-source artificial intelligence library that uses data flow graphs to build models. It empowers developers to create large-scale neural networks with numerous layers. TensorFlow is primarily used for:

  • Classification: Sorting data into categories.

  • Perception: Recognizing patterns, like identifying objects in images.

  • Understanding: Interpreting complex data, like translating languages.

  • Discovering: Finding hidden patterns or insights in data.

  • Prediction: Forecasting future trends based on current data.

  • Creation: Generating new content, like music or text.

Let's get our hands dirty and dive into some code:

Let’s start by opening Jupyter Notebook (or Google Colab):

Code Cell 1: Importing Libraries & Training Modules

  • The os module provides functions for listing the contents of a directory and writing files to it.

    Set the base_dir variable to the location of the dataset containing the training images.
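This step assumes a particular directory layout: one sub-folder per class under base_dir, for example dataset/Elephant/*.jpg. That layout is what Keras' flow_from_directory expects. The helper below is a minimal sketch that uses os to check the layout before training. The name list_classes and the "dataset" path are hypothetical and are not taken from the author's notebook.

```python
import os

def list_classes(base_dir):
    """Return (class_name, image_count) pairs for each sub-folder of base_dir.

    Assumes the layout base_dir/<class_name>/<image files>.
    """
    image_exts = (".jpg", ".jpeg", ".png")
    classes = []
    for name in sorted(os.listdir(base_dir)):
        class_dir = os.path.join(base_dir, name)
        if not os.path.isdir(class_dir):
            continue  # skip stray files sitting next to the class folders
        count = sum(1 for f in os.listdir(class_dir) if f.lower().endswith(image_exts))
        classes.append((name, count))
    return classes

base_dir = "dataset"  # hypothetical path: point this at your extracted dataset
if os.path.isdir(base_dir):
    for name, count in list_classes(base_dir):
        print(f"{name}: {count} images")
```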

Code Cell 2: Preprocessing (formatting images before they are used for model training and inference)

  • First things first, let's set IMAGE_SIZE to 224, the size we'll resize our dataset images to (224 × 224 pixels is MobileNetV2's default input size).

  • Next up, BATCH_SIZE is set to 64—because apparently, our neural network likes to chew on 64 images at a time.

  • The rescale=1./255 bit normalizes pixel values from the 0 to 255 range down to the 0 to 1 range. Small, consistently scaled inputs help the network train faster and more stably.

  • Our dataset is split into a training set and a validation set. The training set does the heavy lifting (training our model), while the validation set sits back and judges how well our model performs on images it hasn't trained on. By setting validation_split=0.2, we're telling Keras, "Use 80% of the data for training and save 20% for validation, because we like our models tested and accurate."

  • Finally, we have our two trusty generators (train_generator and val_generator). They read images from the directory path and churn out batches of preprocessed images. Just to keep us in the loop, each one prints a line of the form "Found <n> images belonging to <k> classes". With our five animals, k should be 5 for both.
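The arithmetic behind these settings can be sketched in plain Python. Dividing by 255 maps each 8-bit pixel into [0, 1]. An 80/20 split combined with a batch size of 64 then determines how many batches make up one pass over the training subset. The helper names are hypothetical. The per-class rounding shown here is an illustrative assumption, and Keras' own splitting may round slightly differently.

```python
import math

BATCH_SIZE = 64
VALIDATION_SPLIT = 0.2

def rescale_pixels(pixels):
    """Map 8-bit pixel values (0..255) into [0.0, 1.0], like rescale=1./255."""
    return [p / 255.0 for p in pixels]

def split_counts(num_images, validation_split=VALIDATION_SPLIT):
    """Return (train, val) sizes for one class, rounding the validation share down."""
    val = int(num_images * validation_split)
    return num_images - val, val

def batches_per_epoch(num_images, batch_size=BATCH_SIZE):
    """Batches needed to see num_images once; the last batch may be partial."""
    return math.ceil(num_images / batch_size)

print(rescale_pixels([0, 128, 255]))
train, val = split_counts(100)
print(train, val)                  # 80 20
print(batches_per_epoch(train))    # 2 (one batch of 64, one of 16)
```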

Code Cell 3: Creating a labels.txt file to hold all our labels (names of animals)

  • Print the class indices (each label and its index) of the dataset to double-check that everything looks right.

  • Our Flutter app needs two files: model.tflite and labels.txt.

  • Opening labels.txt with mode 'w' creates the file and writes the labels (names of animals) into it. If the file already exists, it is overwritten.
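Here is a minimal sketch of this step. It assumes class_indices is a label-to-index dict like the one train_generator.class_indices returns. Keras assigns those indices in alphanumeric order of the folder names. Writing labels in index order matters because line i of labels.txt must name the class the model outputs at position i. The helpers ordered_labels and write_labels are hypothetical.

```python
def ordered_labels(class_indices):
    """Return labels sorted by the index the model uses for each class.

    class_indices maps label -> index, like train_generator.class_indices.
    """
    return [label for label, _ in sorted(class_indices.items(), key=lambda kv: kv[1])]

def write_labels(labels, path="labels.txt"):
    """Write one label per line; mode "w" creates the file or overwrites it."""
    with open(path, "w", encoding="utf-8") as f:
        f.write("\n".join(labels))

indices = {"Elephant": 0, "Kangaroo": 1, "Panda": 2, "Penguin": 3, "Tiger": 4}
print(ordered_labels(indices))  # ['Elephant', 'Kangaroo', 'Panda', 'Penguin', 'Tiger']
# write_labels(ordered_labels(indices))  # creates labels.txt in the working directory
```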

Now we will use MobileNetV2, which is a convolutional neural network architecture that seeks to perform well on mobile devices. It is based on an inverted residual structure where the residual connections are between the bottleneck layers.

Code Cell 4: Creating a base model for Transfer Learning

  • In our approach, we're opting not to load the fully connected output layers of the MobileNetV2 model, which enables us to add and train a new output layer. Therefore, we set the include_top argument to False.

Code Cell 5: Adding Hidden Layers to Neural Networks

  • Make sure to set base_model.trainable to False before compiling the model. This freezes the pre-trained MobileNetV2 weights, so training only updates the new layers we add.

  • Moving on, let's toss in our hidden layers:

    1. Conv2D Layer: Conv2D slides a small window (the kernel) over your input data. At each position, it computes a weighted sum of the values under the window, producing a new set of feature maps. It's all about deciphering the patterns hidden in your images.

      ReLU Activation: Ah, 'relu'—short for rectified linear unit—does this neat trick where it either gives you back what it sees (if it's positive), or just plays dead and gives you a big fat zero. Simple and effective!

    2. Dropout Layer: This one's the party pooper of Neural Networks. The dropout layer crashes the party by randomly shutting down some of the input neurons during training. Why? To stop the network from getting too cozy with its data and becoming a snob that only recognizes what's already in its training set.

    3. GlobalAveragePooling2D Layer: Imagine a layer that goes, "Hey, let's average out all the excitement from the previous layer." GlobalAveragePooling2D does exactly that: it takes the average of each feature map, effectively downsizing the data for the final stretch.

    4. Dense Layer: Dense layer is like the town square of neurons—it connects everyone from the previous layer to each neuron in its gang. And that '5' there? It's the number of animal types our network is going to gossip about.

      Softmax Activation: Softmax takes a bunch of real numbers and turns them into probabilities. It's like turning up the volume on how confident our network is about each animal type it identifies.
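To make the layer descriptions above concrete, here is a plain-Python sketch of what each operation computes on tiny inputs. It is illustrative only. These hand-written functions (all hypothetical names) are not Keras' implementations, which operate on batched, multi-channel tensors and learn the kernel weights during training.

```python
import math
import random

def conv2d_valid(image, kernel):
    """Slide kernel over a single-channel image (no padding, stride 1)."""
    kh, kw = len(kernel), len(kernel[0])
    out_h, out_w = len(image) - kh + 1, len(image[0]) - kw + 1
    return [
        [
            sum(image[i + a][j + b] * kernel[a][b] for a in range(kh) for b in range(kw))
            for j in range(out_w)
        ]
        for i in range(out_h)
    ]

def relu(feature_map):
    """Keep positive values, replace negatives with zero."""
    return [[max(0.0, v) for v in row] for row in feature_map]

def dropout(values, rate, rng):
    """Inverted dropout (training only): zero each value with probability `rate`
    and scale survivors by 1 / (1 - rate) so the expected sum is unchanged."""
    keep = 1.0 - rate
    return [v / keep if rng.random() < keep else 0.0 for v in values]

def global_average_pool(feature_maps):
    """Average each 2-D feature map down to a single number."""
    return [sum(sum(row) for row in fm) / (len(fm) * len(fm[0])) for fm in feature_maps]

def softmax(logits):
    """Turn raw scores into probabilities that sum to 1 (max subtracted for stability)."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

image = [[1, 2, 0], [0, 1, 3], [4, 0, 1]]
kernel = [[1, -1], [-1, 1]]                 # a tiny hand-picked 2x2 kernel
fmap = relu(conv2d_valid(image, kernel))
print(fmap)
print(global_average_pool([fmap]))
print(softmax([2.0, 1.0, 0.1, 0.0, -1.0]))  # five scores, one per animal class
print(dropout([1.0, 1.0, 1.0, 1.0], rate=0.5, rng=random.Random(0)))
```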

Code Cell 6: Compiling the model

  • Before training the model, we need to compile it with model.compile, which defines the loss function, the optimizer, and the metrics used to evaluate the model. Training can't start until the model is compiled, because the training loop relies on that loss function and optimizer.

  • We'll use Adam as our optimizer, a popular choice for training deep neural networks. Adam is often used in place of plain stochastic gradient descent because it adapts the learning rate for each parameter, which tends to make optimizing complex models more efficient.
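For intuition, here is a sketch of a single Adam update for one scalar parameter, following the update rule from the Adam paper (Kingma & Ba). The optimizer keeps running averages of the gradient and of its square, corrects their bias toward zero, and takes a scaled step. The default hyperparameters below are the paper's. Keras' Adam uses the same learning rate and betas but a different default epsilon. The adam_step function and the toy objective are hypothetical.

```python
import math

def adam_step(param, grad, m, v, t, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8):
    """Apply one Adam update to a scalar parameter.

    m, v: running averages of the gradient and squared gradient (start at 0).
    t: 1-based step counter, used for bias correction.
    Returns the updated (param, m, v).
    """
    m = beta1 * m + (1 - beta1) * grad          # momentum-like average of gradients
    v = beta2 * v + (1 - beta2) * grad * grad   # average of squared gradients
    m_hat = m / (1 - beta1 ** t)                # undo the early bias toward 0
    v_hat = v / (1 - beta2 ** t)
    param -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v

# Toy example: minimize f(w) = (w - 3)^2, whose gradient is 2 * (w - 3).
w, m, v = 0.0, 0.0, 0.0
for t in range(1, 5001):
    w, m, v = adam_step(w, 2 * (w - 3), m, v, t, lr=0.01)
print(round(w, 3))
```

Note how the very first step moves the parameter by roughly lr regardless of the gradient's magnitude. That per-parameter normalization is what "adaptive learning rate" refers to.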

Code Cell 7: Training

  • Epochs refer to the number of times the neural network goes through the entire training dataset during training. Each epoch involves using all the training data once for both forward and backward passes. An epoch consists of one or more batches, where each batch is a subset of the dataset used for training.

  • Increasing the number of epochs can improve the accuracy of the neural network by allowing it to learn more from the data. However, setting the number too high can lead to overfitting, where the model becomes overly specialized to the training data and performs poorly on new, unseen data.

Code Cell 7: Output of training process

Training might take a while, especially with a large dataset. So, sit tight and let the model do its thing. Patience is key while the magic happens!
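The link between epochs, batches, and overfitting can be sketched with a little arithmetic. The first helper counts gradient updates, since each batch in each epoch is one update. The second finds the epoch with the lowest validation loss and stops once the loss has not improved for `patience` epochs. That stopping rule is the idea behind Keras' EarlyStopping callback, which the original notebook does not necessarily use. Both helpers are hypothetical, and the image count and loss values are invented for illustration, not results from this model.

```python
import math

def total_updates(num_train_images, batch_size, epochs):
    """Each epoch runs ceil(N / batch_size) batches; each batch is one weight update."""
    return math.ceil(num_train_images / batch_size) * epochs

def early_stop_epoch(val_losses, patience=2):
    """Return (best_epoch, stop_epoch), both 1-based.

    Stops once validation loss fails to improve for `patience` consecutive
    epochs; best_epoch is where it was lowest.
    """
    best_loss, best_epoch, waited = float("inf"), 0, 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best_loss:
            best_loss, best_epoch, waited = loss, epoch, 0
        else:
            waited += 1
            if waited >= patience:
                return best_epoch, epoch
    return best_epoch, len(val_losses)

print(total_updates(2000, 64, 10))  # 320 (hypothetical 2000 training images)
# Invented losses: they fall, then rise as the model starts to overfit.
print(early_stop_epoch([0.9, 0.6, 0.45, 0.47, 0.52, 0.6], patience=2))  # (3, 5)
```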

Now we have to convert our Neural Network Model to a .tflite file which we can use in our Flutter App.

Code Cell 8: Converting the Trained Neural Network Model into a TensorFlow Lite file

  • We first save the trained Keras model in the SavedModel format with tf.saved_model.save, which stores the model along with all trackable objects attached to it.

  • To convert a SavedModel to a TensorFlow Lite model, we use tf.lite.TFLiteConverter.from_saved_model.

  • You can find my Jupyter Notebook [here] and the Google Colab file [here].

If you are using Google Colab, first upload the dataset.zip file to Google Drive. Mount the drive, extract the files using Colab, and then use the dataset. Finally, you can download the model.tflite and labels.txt files using the following code:

from google.colab import files

# Download the exported model and labels from the Colab runtime to your machine
files.download('model.tflite')
files.download('labels.txt')

Warning: just in case your model fails to recognize our cute panda...!

You can find my model.tflite and labels.txt files directly from [here].
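Before copying the two files into the Flutter project, a quick sanity check can save you from a confusing runtime error. This sketch relies on two assumptions. First, TensorFlow Lite models are FlatBuffers whose 4-byte file identifier "TFL3" sits at byte offset 4. Second, labels.txt should contain one non-empty line per model output (five here). The check_export helper is hypothetical, and it does not run the model.

```python
def check_export(model_path, labels_path, expected_classes):
    """Return a list of problems found in an exported model.tflite / labels.txt pair."""
    problems = []
    with open(model_path, "rb") as f:
        header = f.read(8)
    # FlatBuffer layout: 4-byte root-table offset, then the 4-byte file identifier.
    if header[4:8] != b"TFL3":
        problems.append(f"{model_path} does not look like a TensorFlow Lite model")
    with open(labels_path, encoding="utf-8") as f:
        labels = [line.strip() for line in f if line.strip()]
    if len(labels) != expected_classes:
        problems.append(f"{labels_path} has {len(labels)} labels, expected {expected_classes}")
    return problems

# Example (run next to your downloaded files):
# print(check_export("model.tflite", "labels.txt", expected_classes=5) or "Looks good")
```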

Module 2: Importing and Using TensorFlow Lite File in Our Flutter App

Open a terminal, navigate to the directory where you want to keep your project, and run the command flutter create project_name. This generates a new Flutter project in a folder called project_name.

If you are using Visual Studio Code, follow these steps:

  1. Open the command palette from the 'View' menu in the top bar (or by pressing Ctrl+Shift+P).

  2. Choose the option to create a new Flutter app.

  3. Enter the project name and hit Enter.

Let's dive in and start writing some code once again:

Next, head over to the pubspec.yaml file, add the following dependencies, and save. Then run flutter pub get (many editors do this automatically on save). It downloads the packages and records the exact resolved versions in the pubspec.lock file:

dependencies:
  flutter:
    sdk: flutter
  tflite: ^1.1.2
  image_picker: ^0.8.3+2

For TensorFlow Lite to work, you'll need to configure your Android project settings. Open the android/app/build.gradle file, set minSdkVersion to 19, and add the following settings within the android block:

aaptOptions {
    // Store model files uncompressed in the APK so they can be memory-mapped
    noCompress 'tflite'
    noCompress 'lite'
}

android block in android/app/build.gradle path

To enable image_picker to work on iOS, you need to add the necessary keys to your Info.plist file located at /ios/Runner/Info.plist.

<key>NSCameraUsageDescription</key>
<string>Need Camera Access</string>
<key>NSMicrophoneUsageDescription</key>
<string>Need Microphone Access</string>
<key>NSPhotoLibraryUsageDescription</key>
<string>Need Gallery Access</string>

Create a folder named "assets" in the project root and place the model.tflite and labels.txt files inside it. Then declare them in the pubspec.yaml file as follows:

flutter:
  assets:
    - assets/model.tflite
    - assets/labels.txt

Then save the file and, as before, run flutter pub get if your editor doesn't do it automatically.

Let's proceed to develop the user interface and functionality for our Flutter app.

In the 'main.dart' file, set up MaterialApp with the home parameter pointing to Home().

Code: main.dart

Next, create a new file named 'home.dart' and define a StatefulWidget class called Home. This class will serve as our homepage. Let's kick off the functional Flutter app by importing the required packages and setting up our initial functions.

Code: lib/home.dart at initial phase

  • _loading: This handy flag lets us know if someone's actually bothered to pick an image yet or if we're just twiddling our thumbs.

  • _image: It's the lucky winner from our gallery or camera—a snapshot of whatever strikes our fancy at the moment.

  • _output: Think of this as the oracle's proclamation—the TensorFlow Model's educated guess about what's in that chosen image.

  • picker: This nifty tool gives us the power to pluck images from our gallery or snap new ones with our camera.

Next, we'll create six distinct methods for the class:

Code: First two methods in lib/home.dart

  • The first two methods:

    1. initState(): This method gets the ball rolling when the Home widget springs to life—right after we decide to grace the app with our presence and land on Home(). Before diving into building the widget itself, initState() takes charge. Here, we kick off the process by loading our model using loadModel(), a method we'll define shortly. It sets the stage for everything that follows, ensuring we're ready to roll with all systems go.

    2. dispose(): This method is called when the widget is removed for good. We use it to release resources, such as the loaded model, and free up memory.

Code: Last four methods in lib/home.dart

  • The last 4 methods:

    1. classifyImage(): This method runs the classification model on the selected image. numResults sets the maximum number of predictions to return; with five animal classes, a value of 5 allows up to one result per class. The method then calls setState to update the UI with the prediction.

    2. loadModel(): This function loads our model, making it ready for use. It's called within initState() to ensure the model is ready when the widget initializes.

    3. pickImage(): This function retrieves an image directly from the device's camera.

    4. pickGalleryImage(): This function retrieves an image from the user's gallery or photo library.
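The post-processing that classifyImage relies on can be sketched outside Flutter. Take the model's five output probabilities, keep those above a confidence threshold, and return the top numResults as index/label/confidence entries. This Python function is a hypothetical stand-in for the list of results the tflite plugin hands back to Dart. The threshold value and the example probabilities are assumptions for illustration.

```python
def top_results(probabilities, labels, num_results=5, threshold=0.1):
    """Return up to num_results predictions at or above threshold, most confident first."""
    results = [
        {"index": i, "label": labels[i], "confidence": p}
        for i, p in enumerate(probabilities)
        if p >= threshold
    ]
    results.sort(key=lambda r: r["confidence"], reverse=True)
    return results[:num_results]

labels = ["Elephant", "Kangaroo", "Panda", "Penguin", "Tiger"]
probs = [0.02, 0.05, 0.81, 0.10, 0.02]   # made-up model output for a panda photo
print(top_results(probs, labels, num_results=2))
```

With numResults set to 5 and a low threshold, the app can show every class's score. A small numResults keeps only the model's best guesses.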

Let us build the AppBar:

Code: AppBar of homepage in lib/home.dart

Now it's time for the main section of our homepage. Let's create a container to hold the picture that the user has picked.

Code: Container part of homepage in lib/home.dart

We used ClipRRect to give the image nicely rounded corners.

Code: Output display of homepage in lib/home.dart

Next, we'll create two GestureDetectors that, when tapped, call the pickImage and pickGalleryImage functions, respectively.

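  • NOTE: pickImage (without parentheses) inside onTap is a function reference. We hand the function itself to the widget, and Flutter calls it later, when the user taps. A function passed this way is known as a callback.

  • pickImage() (with parentheses) is a function call. It runs immediately, while the widget is being built, and onTap would receive its return value instead of the function.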

Code: Two Gesture Detectors of homepage in lib/home.dart
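Python draws the same function-reference-versus-call distinction as the note above, so here is a rough analog (not Flutter code). Passing `pick_image` hands over the function to be called later, while `pick_image()` runs it right away and passes along its return value. The Button class and pick_image function are hypothetical stand-ins for GestureDetector and pickImage.

```python
calls = []

def pick_image():
    """Stand-in for the Dart pickImage method; records that it ran."""
    calls.append("pick_image")
    return "photo.jpg"

class Button:
    """Toy stand-in for GestureDetector: stores a callback and runs it on tap."""

    def __init__(self, on_tap):
        self.on_tap = on_tap

    def tap(self):
        return self.on_tap()

button = Button(on_tap=pick_image)   # function reference: nothing runs yet
print(calls)                          # []
button.tap()                          # the callback runs only now
print(calls)                          # ['pick_image']

# Button(on_tap=pick_image()) would call pick_image immediately and store its
# return value ("photo.jpg"); tapping would then fail because a str is not callable.
```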

Done!

To run your Flutter project, you have two options:

  1. Using Terminal:

    • Open your terminal or command prompt.

    • Navigate to your Flutter project directory.

    • Run the command: flutter run

  2. Using Visual Studio Code:

    • Open Visual Studio Code.

    • From the Run menu in the top bar, choose "Run Without Debugging", or use the shortcut Ctrl+F5.

Hopefully your program builds without crashing and no errors occur.

Open the app, snap photographs or choose images from the gallery, and see what the model predicts 🤩.

Now you may experience the strength of AI 😉.

Once you get the hang of it, you'll realize how simple it is to use TensorFlow Lite with Flutter to create proof-of-concept machine learning mobile apps. It's fantastic to build such a capable program with so little code. To further your expertise, head to the Kaggle website, download a few more datasets, train different classification models, and use them in your own apps.



Thanks for reading! I hope you found it helpful. Any thoughts, questions, or suggestions are really important to me.


Written by

Yashraj Tarte

I am a final-year student with a deep passion for learning and a voracious appetite for knowledge, and I constantly seek to broaden my horizons as an innovative thinker.