Building a Golang Microservice for Machine Learning Inference with TensorFlow

Marek Skopowski

In today’s article we’ll focus on how to create a simple REST API in Go that loads a TensorFlow model and serves predictions.

As a base we’ll use the code from the previous article, “Running TensorFlow Models in Golang”, and build on it.

We’ll do a bit of refactoring of the architecture and introduce a modular folder structure for easier maintenance and scalability.

Note: for setting up the environment and preparing the model, please take a look at the Setting Up the Environment section of the previous article.

Architecture

We’re going to split the logic that was previously in the main.go file into smaller chunks:

.
├── main.go                          # entry point
├── internal/
│   ├── inference/
│   │   ├── model.go                 # model loading + prediction
│   │   └── labels.go                # label loading
│   └── handler/
│       └── predict.go               # HTTP handler
├── model/
│   └── saved_mobilenet_v2/          # saved_model.pb + variables/
├── static/
│   └── example.jpg
├── ImageNetLabels.txt
├── go.mod
└── go.sum
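
A quick note on go.mod and go.sum: they come from initializing the module and pulling in the two dependencies used below. The module path here is taken from the import statements in this article, so substitute your own repository path - roughly:

go mod init github.com/flashlabs/kiss-samples/tensorflowrestapi
go get github.com/wamuir/graft/tensorflow
go get github.com/nfnt/resize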

Model

We load the model into memory in the model.go file of the inference package and expose it via the exported Model variable:

package inference

import (
    "fmt"

    tf "github.com/wamuir/graft/tensorflow"
)

var Model *tf.SavedModel

func LoadModel(path string) (err error) {
    Model, err = tf.LoadSavedModel(path, []string{"serve"}, nil)
    if err != nil {
        return fmt.Errorf("LoadSavedModel: %w", err)
    }

    return nil
}

Labels

We load the labels into the exported Labels variable in the labels.go file of the inference package:

package inference

import (
    "bufio"
    "fmt"
    "log"
    "os"
)

var Labels []string

func LoadLabels(path string) error {
    file, err := os.Open(path)
    if err != nil {
        return fmt.Errorf("os.Open: %w", err)
    }

    defer func(file *os.File) {
        if e := file.Close(); e != nil {
            log.Println("file.Close", e)
        }
    }(file)

    var labels []string

    scanner := bufio.NewScanner(file)
    for scanner.Scan() {
        labels = append(labels, scanner.Text())
    }

    if err = scanner.Err(); err != nil {
        return fmt.Errorf("bufio.Scanner: %w", err)
    }

    Labels = labels

    return nil
}

Note that we update the contents of the Labels variable only once the whole file has been read successfully.

Handler

We move the main logic from the previous article into the HTTP server handler - Predict in this example.

The makeTensorFromImage helper function comes along with it.

package handler

import (
    "fmt"
    "image"
    "image/jpeg"
    "log"
    "mime/multipart"
    "net/http"

    "github.com/nfnt/resize"
    tf "github.com/wamuir/graft/tensorflow"

    "github.com/flashlabs/kiss-samples/tensorflowrestapi/internal/inference"
)

func Predict(w http.ResponseWriter, r *http.Request) {
    if r.Method != http.MethodPost {
        http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)

        return
    }

    file, _, err := r.FormFile("image")
    if err != nil {
        http.Error(w, "Failed to get images", http.StatusBadRequest)

        return
    }
    defer func(file multipart.File) {
        if e := file.Close(); e != nil {
            log.Println("file.Close", e)
        }
    }(file)

    img, err := jpeg.Decode(file)
    if err != nil {
        http.Error(w, "Failed to decode image", http.StatusBadRequest)

        return
    }

    tensor, err := makeTensorFromImage(img)
    if err != nil {
        http.Error(w, "Failed to make tensor from image", http.StatusInternalServerError)

        return
    }

    input := inference.Model.Graph.Operation("serving_default_x")
    output := inference.Model.Graph.Operation("StatefulPartitionedCall")

    outputs, err := inference.Model.Session.Run(
        map[tf.Output]*tf.Tensor{
            input.Output(0): tensor,
        },
        []tf.Output{
            output.Output(0),
        },
        nil,
    )
    if err != nil {
        http.Error(w, "Failed to run inference", http.StatusInternalServerError)

        return
    }

    predictions := outputs[0].Value().([][]float32)

    bestIdx, bestScore := 0, predictions[0][0] // seed with the first score so negative logits are handled correctly
    for i, p := range predictions[0] {
        if p > bestScore {
            bestIdx, bestScore = i, p
        }
    }

    label := inference.Labels[bestIdx]

    _, err = fmt.Fprintf(w, `{"class_id": %d, "label": "%s", "confidence": %.4f}`+"\n", bestIdx, label, bestScore)
    if err != nil {
        http.Error(w, "Failed to write response", http.StatusInternalServerError)

        return
    }
}

func makeTensorFromImage(img image.Image) (*tf.Tensor, error) {
    // Resize to 224x224
    resized := resize.Resize(224, 224, img, resize.Bilinear)

    // Create a 4D array to hold input
    bounds := resized.Bounds()
    batch := make([][][][]float32, 1) // batch size 1
    batch[0] = make([][][]float32, bounds.Dy())

    for y := bounds.Min.Y; y < bounds.Max.Y; y++ {
        row := make([][]float32, bounds.Dx())
        for x := bounds.Min.X; x < bounds.Max.X; x++ {
            r, g, b, _ := resized.At(x, y).RGBA()
            row[x] = []float32{
                float32(r) / 65535.0, // normalize to [0,1]
                float32(g) / 65535.0,
                float32(b) / 65535.0,
            }
        }
        batch[0][y] = row
    }

    return tf.NewTensor(batch)
}
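
If you’re wondering where the serving_default_x and StatefulPartitionedCall names come from, they are the input and output operations of the SavedModel’s serving signature. A quick way to see what’s available is to dump all operation names from the loaded graph - a throwaway sketch, assuming graft mirrors the upstream TensorFlow Go bindings’ Graph.Operations method:

package main

import (
    "fmt"
    "log"

    "github.com/flashlabs/kiss-samples/tensorflowrestapi/internal/inference"
)

func main() {
    if err := inference.LoadModel("model/saved_mobilenet_v2"); err != nil {
        log.Fatalf("Failed to load SavedModel: %v", err)
    }

    // Print every operation name in the graph; the serving input and
    // output operations show up in this list.
    for _, op := range inference.Model.Graph.Operations() {
        fmt.Println(op.Name())
    }
}

Running that once against your model is usually enough to confirm the names before wiring them into the handler.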

Please note that instead of logging fatal errors, we write the output to the http.ResponseWriter and stop processing, so the client knows what went wrong with the request. For example, if the request method is not POST, we communicate the issue with a proper message and HTTP status code:

if r.Method != http.MethodPost {
    http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)

    return
}
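
A related aside: the handler above builds its JSON response by hand with fmt.Fprintf, which is fine for this demo but breaks as soon as a label contains a quote or another character that needs escaping. Here’s a minimal alternative sketch using the standard encoding/json package (the writeJSON helper and its field layout are my own, mirroring the handwritten string):

package handler

import (
    "encoding/json"
    "net/http"
)

// writeJSON encodes the prediction result with encoding/json instead of
// formatting the JSON string by hand, so labels with special characters
// are escaped correctly.
func writeJSON(w http.ResponseWriter, classID int, label string, confidence float32) {
    w.Header().Set("Content-Type", "application/json")

    resp := struct {
        ClassID    int     `json:"class_id"`
        Label      string  `json:"label"`
        Confidence float32 `json:"confidence"`
    }{
        ClassID:    classID,
        Label:      label,
        Confidence: confidence,
    }

    if err := json.NewEncoder(w).Encode(resp); err != nil {
        http.Error(w, "Failed to write response", http.StatusInternalServerError)
    }
}

The end of Predict could then simply call writeJSON(w, bestIdx, label, bestScore) instead of hand-building the string.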

Main Program

Now, with all the logic extracted into the proper packages, our main program looks the way it should - it’s small and compact, and it’s responsible only for initialization and running the main process:

package main

import (
    "fmt"
    "log"
    "net/http"

    "github.com/flashlabs/kiss-samples/tensorflowrestapi/internal/handler"
    "github.com/flashlabs/kiss-samples/tensorflowrestapi/internal/inference"
)

func main() {
    fmt.Println("Loading TF model...")
    if err := inference.LoadModel("model/saved_mobilenet_v2"); err != nil {
        log.Fatalf("Failed to load SavedModel: %v", err)
    }

    fmt.Println("Loading labels...")
    if err := inference.LoadLabels("ImageNetLabels.txt"); err != nil {
        log.Fatalf("Failed to load labels: %v", err)
    }

    fmt.Println("Setting up handlers...")
    http.HandleFunc("/predict", handler.Predict)

    fmt.Println("listening on :8080")
    log.Fatal(http.ListenAndServe(":8080", nil))
}

As you can see all it does is:

  • Loading a TF model

  • Loading labels

  • Setting up the HTTP handlers

  • Starting an HTTP server on local port 8080
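
One small caveat: http.ListenAndServe runs a server with no timeouts configured, which is fine for a local demo but worth tightening for anything longer-lived. Here’s a minimal sketch with an explicit http.Server (the timeout values are placeholders, not recommendations):

package main

import (
    "log"
    "net/http"
    "time"
)

// listen starts the HTTP server with explicit timeouts instead of the
// zero-value defaults that plain http.ListenAndServe uses.
func listen() {
    srv := &http.Server{
        Addr:              ":8080",
        ReadHeaderTimeout: 5 * time.Second,  // placeholder
        ReadTimeout:       30 * time.Second, // image uploads may take a while
        WriteTimeout:      30 * time.Second, // placeholder
    }

    log.Fatal(srv.ListenAndServe())
}

main would then call listen() in place of the final http.ListenAndServe line.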

Running a Program

Just run the program with go run main.go and expect output similar to this:

go run main.go
Loading TF model...
2025-05-18 13:29:22.879008: I tensorflow/cc/saved_model/reader.cc:83] Reading SavedModel from: model/saved_mobilenet_v2
2025-05-18 13:29:22.886212: I tensorflow/cc/saved_model/reader.cc:52] Reading meta graph with tags { serve }
2025-05-18 13:29:22.886235: I tensorflow/cc/saved_model/reader.cc:147] Reading SavedModel debug info (if present) from: model/saved_mobilenet_v2
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1747567762.933428 11558409 mlir_graph_optimization_pass.cc:425] MLIR V1 optimization pass is not enabled
2025-05-18 13:29:22.940719: I tensorflow/cc/saved_model/loader.cc:236] Restoring SavedModel bundle.
2025-05-18 13:29:23.159484: I tensorflow/cc/saved_model/loader.cc:220] Running initialization op on SavedModel bundle at path: model/saved_mobilenet_v2
2025-05-18 13:29:23.217182: I tensorflow/cc/saved_model/loader.cc:471] SavedModel load for tags { serve }; Status: success: OK. Took 338178 microseconds.
Loading labels...
Setting up handlers...
listening on :8080

Making a REST Call

Make sure you’re in the project directory so that the static/example.jpg file can be read:

curl -X POST -F image=@static/example.jpg http://localhost:8080/predict
{"class_id": 469, "label": "cab", "confidence": 12.6021}

Looking at our example.jpg file, we can see that it’s working.
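
One note on that output: the confidence of 12.6021 is the raw score the model returned for the winning class - since it’s greater than 1, it’s clearly not a probability (it looks like a logit). If you’d rather report probabilities, you could run the scores through a softmax before picking the best index. A minimal sketch of such a helper (my own addition, not part of the handler above):

package handler

import "math"

// softmax converts raw scores (logits) into probabilities that sum to 1.
func softmax(logits []float32) []float32 {
    if len(logits) == 0 {
        return nil
    }

    // subtract the max logit for numerical stability before exponentiating
    maxLogit := logits[0]
    for _, v := range logits[1:] {
        if v > maxLogit {
            maxLogit = v
        }
    }

    exps := make([]float64, len(logits))
    var sum float64
    for i, v := range logits {
        exps[i] = math.Exp(float64(v - maxLogit))
        sum += exps[i]
    }

    probs := make([]float32, len(logits))
    for i, e := range exps {
        probs[i] = float32(e / sum)
    }

    return probs
}

Predict could then call softmax(predictions[0]) and report the value at bestIdx as the confidence.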

Next Steps

You have a fully working example of a REST API that handles POST requests with image payloads.

You might want to add more endpoints, request validation, image-size checks, and so on.
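
For example, a first piece of validation could be capping the upload size before the form is parsed. Here’s a minimal sketch using the standard library’s http.MaxBytesReader (the wrapper name and the 10 MB limit are my own choices):

package handler

import "net/http"

const maxUploadBytes = 10 << 20 // 10 MB, an arbitrary limit

// WithBodyLimit caps the request body size before the wrapped handler reads it,
// so oversized uploads fail fast instead of being buffered in full.
func WithBodyLimit(next http.HandlerFunc) http.HandlerFunc {
    return func(w http.ResponseWriter, r *http.Request) {
        r.Body = http.MaxBytesReader(w, r.Body, maxUploadBytes)
        next(w, r)
    }
}

Registering it is a one-line change in main.go: http.HandleFunc("/predict", handler.WithBodyLimit(handler.Predict)).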

The sky is the limit.
