Building a Minimal IR for ONNX Model Compilation: From Concept to Implementation

Yuanbo Li

Introduction: The Role of Intermediate Representations

Our compiler's Intermediate Representation (IR) serves as the crucial bridge between high-level ONNX models and optimized executable code. This post documents my implementation of a minimal Relay-inspired IR and the ONNX translation pipeline, the first step in my TVM-inspired compiler project.

Core Implementation Components

1. Operator Registry: The Conversion Backbone

Purpose: Central mapping from ONNX operator types to conversion functions

class OperatorRegistry { 
private:
    std::unordered_map<std::string, ConversionFunc> registry;
    static OperatorRegistry* _operator_registry_instance;
    void register_all_ops();
public:
    OperatorRegistry() {
        register_all_ops(); 
    }

    static OperatorRegistry* get_instance();

    void registerOp(const std::string& optype, ConversionFunc func){
        registry[optype] = std::move(func);
    }

    ConversionFunc getConversionFunc(const std::string& optype){
        auto it = registry.find(optype);
        if(it != registry.end()){
            return it->second;
        }
        throw std::runtime_error("Operator not found in registry: " + optype);
    }

    std::vector<std::shared_ptr<RelayExpr>> convertNode(onnx::NodeProto& node, const std::vector<std::shared_ptr<RelayExpr>>& inputs){
        // Reuse the lookup (and its error handling) from getConversionFunc
        return getConversionFunc(node.op_type())(node, inputs);
    }
};

The operator registry is the central dispatch mechanism for ONNX-to-Relay conversions, implementing a classic factory pattern. At its core is a map from ONNX operator type strings (like "Gemm") to conversion functions. Each function accepts an ONNX node and the already-converted input expressions, and returns the constructed Relay expressions. The singleton pattern ensures a single global registry is shared by the whole conversion pipeline.

Key design aspects:

1. Extensibility: New operators can be added via simple registerOp calls

2. Type Safety: Conversion functions enforce strict signature matching

This design enables clean separation between operator definitions and graph traversal logic, critical for maintaining a modular codebase as we expand supported operators.

2. Type System Implementation

The core type system implements a hierarchy of type representations essential for validating neural network operations:

Base Type Class

class Type {
public:
    virtual ~Type() = default; // Safe polymorphic deletion through base pointers

    virtual void print(std::ostream& os) {
        os << "Not implemented print for Type";
        exit(1); // Enforces implementation in derived classes
    }
};

Primitive Types

Handles fundamental data types with enum-based kind tracking:

class PrimType : public Type {
public:
    enum TypeKind { kInt, kFloat, kBool };
    TypeKind kind;

    void print(std::ostream& os) override {
        switch(kind) {
            case kInt: os << "int"; break;
            case kFloat: os << "float"; break;
            case kBool: os << "bool"; break;
        }
    }
};

Tensor Type

Captures shape and dtype information critical for neural network tensors:

class TensorType : public Type {
public:
    std::vector<int> shape;
    PrimType dtype;

    // Constructor matching the usage in the snippets below
    TensorType(std::vector<int> s, PrimType::TypeKind k)
        : shape(std::move(s)) {
        dtype.kind = k;
    }

    void print(std::ostream& os) override {
        os << "Tensor[(";
        for(size_t i = 0; i < shape.size(); i++) {
            os << shape[i];
            if(i + 1 < shape.size()) os << ", ";
        }
        os << "), ";
        dtype.print(os);
        os << "]";
    }
};

Key Features:

1. Polymorphic Storage

Types are stored as shared_ptr<Type> enabling type-safe container storage:

   std::shared_ptr<Type> inputVarType = 
       std::make_shared<TensorType>(input_shape, PrimType::kFloat);

2. Type-Aware Variables

Relay variables embed type information directly:

   class RelayVar : public RelayExpr {
       std::shared_ptr<Type> type;
       // ...
   };

3. Validation Through Printing

The print() method serves a dual purpose: debugging output and type validation in tests:

   // Test case example
   TEST(RelayTypes, TensorTypePrint) {
       TensorType tensor({1,3}, PrimType::kFloat);
       std::stringstream ss;
       tensor.print(ss);
       ASSERT_EQ(ss.str(), "Tensor[(1, 3), float]");
   }

3. ONNX Model Parsing Architecture

The parser implements a three-phase translation process from ONNX models to Relay IR:

1. Model Inspection & Validation

2. Symbolic Variable Creation

3. Graph Translation Workflow

// Core conversion loop
for(auto node : model.graph().node()) {
    // 1. Collect input arguments
    std::vector<std::shared_ptr<relay::RelayExpr>> args;
    for(const auto& input : node.input()) {
        args.push_back(input2relayVars[input]);
    }

    // 2. Dispatch to operator registry
    auto* converter = relay::OperatorRegistry::get_instance();
    auto output = converter->convertNode(node, args);

    // 3. Map outputs
    for(int i = 0; i < output.size(); i++) {
        output2relayExprs[node.output(i)] = output[i];
    }
}

4. Final Outcome

Now the ONNX-to-Relay converter can parse a very simple ONNX model like this one:

$ ./myTvm ../misc/onnx-input/simple_model.onnx 
Parsing ONNX model: ../misc/onnx-input/simple_model.onnx
Graph name: main_graph
Inputs: 
  input : (1, 3, )

Outputs: 
  output : (1, 2, )

Nodes: 
  /fc/Gemm : Gemm
    input: input
    input: fc.weight
    input: fc.bias
    output: output

def @main(%input: Tensor[(1, 3), float]) {
  relay.nn.bias_add(relay.nn.dense(%input, %fc.weight), %fc.bias)
}

This post mainly covers commit 66006905 in the repo.
