Running TensorFlow Models in Golang

Marek Skopowski

In this article we’re going to walk through loading a pre-trained TensorFlow model and running inference with the Go bindings.

The TensorFlow team is not currently maintaining the documentation for installing the Go bindings for TensorFlow:

https://github.com/tensorflow/tensorflow/tree/master/tensorflow/go

The new “official” home of the Go bindings (as recommended by the TensorFlow project itself) is William Muir’s graft repo - https://github.com/wamuir/graft

Setting Up the Environment

Requirements:

  • Go 1.21+

  • TensorFlow (TF) installed (the Go bindings rely on the TF C library)

Installing TensorFlow:

brew install tensorflow

Installing the Go TF package:

go get -u github.com/wamuir/graft/tensorflow/...

To check that your TF installation works, follow the “Hello TensorFlow” example from the graft README - https://github.com/wamuir/graft:

package main

import (
    tf "github.com/wamuir/graft/tensorflow"
    "github.com/wamuir/graft/tensorflow/op"
    "fmt"
)

func main() {
    // Construct a graph with an operation that produces a string constant.
    s := op.NewScope()
    c := op.Const(s, "Hello from TensorFlow version " + tf.Version())
    graph, err := s.Finalize()
    if err != nil {
        panic(err)
    }

    // Execute the graph in a session.
    sess, err := tf.NewSession(graph, nil)
    if err != nil {
        panic(err)
    }
    output, err := sess.Run(nil, []tf.Output{c}, nil)
    if err != nil {
        panic(err)
    }
    fmt.Println(output[0].Value())
}

The output when you run the program should be similar to this one:

go run main.go
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1745750228.799498 1712814 mlir_graph_optimization_pass.cc:425] MLIR V1 optimization pass is not enabled
Hello from TensorFlow version 2.19.0

If you run into any issues, please refer to the section “Common Pitfalls and Troubleshooting” at the end of this article.

Preparing the Model

We’re going to use the pre-trained image classifier model mobilenet_v2 - https://www.kaggle.com/models/google/mobilenet-v2/tensorFlow2

Unfortunately, the model downloaded from that source had issues with its input layer. The article repository (see “Sources”) therefore includes a converter.py script that re-exports the model from source with a named input layer called serving_default_x.

You don’t have to run it yourself, since the converted model is already in the repository, but converter.py is worth a look to see how a model can be exported to the SavedModel format.

Loading and Running the Model in Go

The code below splits the processing into the following main steps:

  • Load the model

  • Load an image

  • Create input tensors (preprocess the image)

  • Run the session (run inference)

  • Fetch outputs (predictions)

  • Find the best prediction and display the results

Our goal is to detect what’s in this image; we want to know whether it’s a squirrel:

1.jpg

package main

import (
    "bufio"
    "fmt"
    "image"
    "image/jpeg"
    "log"
    "os"

    "github.com/nfnt/resize"
    tf "github.com/wamuir/graft/tensorflow"
)

func main() {
    // Load the SavedModel
    model, err := tf.LoadSavedModel("saved_mobilenet_v2", []string{"serve"}, nil)
    if err != nil {
        log.Fatalf("LoadSavedModel: %v", err)
    }
    defer func(Session *tf.Session) {
        if e := Session.Close(); e != nil {
            log.Fatal("Session.Close", e)
        }
    }(model.Session)

    // Load an image
    img, err := loadImage("images/1.jpg")
    if err != nil {
        log.Fatalf("loadImage: %v", err)
    }

    // Preprocess the image
    tensor, err := makeTensorFromImage(img)
    if err != nil {
        log.Fatalf("makeTensorFromImage: %v", err)
    }

    inputOp := model.Graph.Operation("serving_default_x")
    if inputOp == nil {
        log.Fatal("model.Graph.Operation: serving_default_x not found")
    }

    outputOp := model.Graph.Operation("StatefulPartitionedCall")
    if outputOp == nil {
        log.Fatal("model.Graph.Operation: StatefulPartitionedCall not found")
    }

    // Run inference
    outputs, err := model.Session.Run(
        map[tf.Output]*tf.Tensor{
            inputOp.Output(0): tensor,
        },
        []tf.Output{
            outputOp.Output(0),
        },
        nil,
    )
    if err != nil {
        log.Fatalf("Session.Run: %v", err)
    }

    // Predictions
    predictions := outputs[0].Value().([][]float32)

    // Find the top-1 prediction. Start from the first score rather than 0,
    // because raw model scores can all be negative.
    bestIdx := 0
    bestScore := predictions[0][0]
    for i, p := range predictions[0] {
        if p > bestScore {
            bestIdx = i
            bestScore = p
        }
    }

    labels, err := loadLabels("ImageNetLabels.txt")
    if err != nil {
        log.Fatalf("loadLabels: %v", err)
    }

    fmt.Printf("Predicted label: %s (index: %d, confidence: %.4f)\n", labels[bestIdx], bestIdx, bestScore)
}

func loadImage(filename string) (image.Image, error) {
    file, err := os.Open(filename)
    if err != nil {
        return nil, fmt.Errorf("os.Open: %w", err)
    }
    defer func(file *os.File) {
        if e := file.Close(); e != nil {
            log.Fatal("file.Close", e)
        }
    }(file)

    img, err := jpeg.Decode(file)
    if err != nil {
        return nil, fmt.Errorf("jpeg.Decode: %w", err)
    }

    return img, nil
}

func makeTensorFromImage(img image.Image) (*tf.Tensor, error) {
    // Resize to 224x224
    resized := resize.Resize(224, 224, img, resize.Bilinear)

    // Create a 4D array to hold input
    bounds := resized.Bounds()
    batch := make([][][][]float32, 1) // batch size 1
    batch[0] = make([][][]float32, bounds.Dy())

    for y := bounds.Min.Y; y < bounds.Max.Y; y++ {
        row := make([][]float32, bounds.Dx())
        for x := bounds.Min.X; x < bounds.Max.X; x++ {
            r, g, b, _ := resized.At(x, y).RGBA()
            // Offset by bounds.Min so indexing stays valid even if the
            // image bounds do not start at (0, 0).
            row[x-bounds.Min.X] = []float32{
                float32(r) / 65535.0, // normalize to [0,1]
                float32(g) / 65535.0,
                float32(b) / 65535.0,
            }
        }
        batch[0][y-bounds.Min.Y] = row
    }

    return tf.NewTensor(batch)
}

func loadLabels(filename string) ([]string, error) {
    file, err := os.Open(filename)
    if err != nil {
        return nil, fmt.Errorf("os.Open: %w", err)
    }
    defer func(file *os.File) {
        if e := file.Close(); e != nil {
            log.Fatal("file.Close", e)
        }
    }(file)

    var labels []string

    scanner := bufio.NewScanner(file)
    for scanner.Scan() {
        labels = append(labels, scanner.Text())
    }

    if err = scanner.Err(); err != nil {
        return nil, fmt.Errorf("bufio.Scanner: %w", err)
    }

    return labels, nil
}

The output is as follows:

go run main.go
2025-04-27 18:56:20.238487: I tensorflow/cc/saved_model/reader.cc:83] Reading SavedModel from: saved_mobilenet_v2
2025-04-27 18:56:20.244913: I tensorflow/cc/saved_model/reader.cc:52] Reading meta graph with tags { serve }
2025-04-27 18:56:20.244945: I tensorflow/cc/saved_model/reader.cc:147] Reading SavedModel debug info (if present) from: saved_mobilenet_v2
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1745772980.292412 2175877 mlir_graph_optimization_pass.cc:425] MLIR V1 optimization pass is not enabled
2025-04-27 18:56:20.298863: I tensorflow/cc/saved_model/loader.cc:236] Restoring SavedModel bundle.
2025-04-27 18:56:20.505321: I tensorflow/cc/saved_model/loader.cc:220] Running initialization op on SavedModel bundle at path: saved_mobilenet_v2
2025-04-27 18:56:20.562274: I tensorflow/cc/saved_model/loader.cc:471] SavedModel load for tags { serve }; Status: success: OK. Took 323789 microseconds.
Predicted label: fox squirrel (index: 336, confidence: 8.3710)

As you can see, the top-scoring label is “fox squirrel”, which is exactly what we wanted to achieve. Personally, I’m not sure whether this is a fox squirrel or some other kind of squirrel, but it’s definitely a squirrel.
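One detail in the output is worth a closer look: the value printed as confidence (8.3710) is greater than 1, so it cannot be a probability and is most likely a raw score (logit). The following is a minimal sketch, assuming the model outputs logits, of how such scores could be converted to probabilities with softmax and how the top few classes could be listed. The softmax and topK helpers and the sample values are hypothetical and not part of the article's code.

```go
package main

import (
	"fmt"
	"math"
	"sort"
)

// softmax converts raw scores (logits) into probabilities that sum to 1.
// Subtracting the maximum first keeps math.Exp from overflowing.
func softmax(logits []float32) []float32 {
	if len(logits) == 0 {
		return nil
	}
	maxLogit := logits[0]
	for _, v := range logits[1:] {
		if v > maxLogit {
			maxLogit = v
		}
	}
	exps := make([]float64, len(logits))
	var sum float64
	for i, v := range logits {
		exps[i] = math.Exp(float64(v - maxLogit))
		sum += exps[i]
	}
	probs := make([]float32, len(logits))
	for i, e := range exps {
		probs[i] = float32(e / sum)
	}
	return probs
}

// topK returns the indices of the k highest values, best first.
func topK(values []float32, k int) []int {
	idx := make([]int, len(values))
	for i := range idx {
		idx[i] = i
	}
	sort.SliceStable(idx, func(a, b int) bool {
		return values[idx[a]] > values[idx[b]]
	})
	if k < 0 {
		k = 0
	}
	if k > len(idx) {
		k = len(idx)
	}
	return idx[:k]
}

func main() {
	// Hypothetical logits standing in for predictions[0].
	logits := []float32{1.0, 8.0, 3.0, -2.0}
	probs := softmax(logits)
	for _, i := range topK(probs, 3) {
		fmt.Printf("index %d: probability %.4f\n", i, probs[i])
	}
}
```

In the main program, softmax(predictions[0]) could be called before the top-1 search so that the printed value falls in [0, 1]. The ranking of the classes stays the same because softmax preserves order.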

All the resources (models, images and labels) can be found in the article repository.

Common Pitfalls and Troubleshooting

Issues with the TensorFlow library:

go run main.go
# github.com/wamuir/graft/tensorflow
../../../go/pkg/mod/github.com/wamuir/graft@v0.10.0/tensorflow/tensor.go:69:26: could not determine what C.TF_FLOAT8_E4M3FN refers to
../../../go/pkg/mod/github.com/wamuir/graft@v0.10.0/tensorflow/tensor.go:68:26: could not determine what C.TF_FLOAT8_E5M2 refers to
../../../go/pkg/mod/github.com/wamuir/graft@v0.10.0/tensorflow/tensor.go:70:26: could not determine what C.TF_INT4 refers to
../../../go/pkg/mod/github.com/wamuir/graft@v0.10.0/tensorflow/tensor.go:71:26: could not determine what C.TF_UINT4 refers to

Even though TensorFlow is installed via Homebrew, it's not properly configured for pkg-config, which is needed for Go to find and link against the TensorFlow C library.

Run brew link --force libtensorflow

Then, if needed, also add these environment variables to your shell profile file (I’m using .zshrc):

# TensorFlow configuration
export LIBRARY_PATH="/opt/homebrew/lib:$LIBRARY_PATH"
export CPATH="/opt/homebrew/include:$CPATH"
export PKG_CONFIG_PATH="/opt/homebrew/lib/pkgconfig:$PKG_CONFIG_PATH"
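After reloading the shell, it can help to confirm that the variables are actually visible to processes started from it, since the Go toolchain inherits the same environment. The following is a small sketch for that check; the hasPathEntry helper is a hypothetical name, and the directories simply mirror the exports above, so adjust them if your Homebrew prefix differs.

```go
package main

import (
	"fmt"
	"os"
	"path/filepath"
)

// hasPathEntry reports whether a list-style variable value (entries
// separated by the OS path list separator) contains dir.
func hasPathEntry(value, dir string) bool {
	want := filepath.Clean(dir)
	for _, entry := range filepath.SplitList(value) {
		if entry != "" && filepath.Clean(entry) == want {
			return true
		}
	}
	return false
}

func main() {
	// Variable names and directories mirror the exports shown above.
	checks := []struct{ name, dir string }{
		{"LIBRARY_PATH", "/opt/homebrew/lib"},
		{"CPATH", "/opt/homebrew/include"},
		{"PKG_CONFIG_PATH", "/opt/homebrew/lib/pkgconfig"},
	}
	for _, c := range checks {
		ok := hasPathEntry(os.Getenv(c.name), c.dir)
		fmt.Printf("%-16s contains %s: %v\n", c.name, c.dir, ok)
	}
}
```

If any line reports false, the profile file was probably not re-read by the current shell; opening a new terminal or sourcing the file is usually enough.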

Sources
