Building a Minimal IR for ONNX Model Compilation: From Concept to Implementation

Introduction: The Role of Intermediate Representations
Our compiler's Intermediate Representation (IR) serves as the crucial bridge between high-level ONNX models and optimized executable code. This post documents my implementation of a minimal Relay-inspired IR and the ONNX translation pipeline, an initial step in my TVM-style compiler project.
Core Implementation Components
1. Operator Registry: The Conversion Backbone
Purpose: Central mapping of ONNX operators to conversion functions
class OperatorRegistry {
private:
    std::unordered_map<std::string, ConversionFunc> registry;
    static OperatorRegistry* _operator_registry_instance;

    void register_all_ops();

public:
    OperatorRegistry() {
        register_all_ops();
    }

    static OperatorRegistry* get_instance();

    void registerOp(const std::string& optype, ConversionFunc func) {
        registry[optype] = std::move(func);
    }

    ConversionFunc getConversionFunc(const std::string& optype) {
        auto it = registry.find(optype);
        if (it != registry.end()) {
            return it->second;
        }
        throw std::runtime_error("Operator not found in registry: " + optype);
    }

    std::vector<std::shared_ptr<RelayExpr>> convertNode(
            onnx::NodeProto& node,
            const std::vector<std::shared_ptr<RelayExpr>>& inputs) {
        // Reuse the lookup (and its error message) instead of duplicating it
        return getConversionFunc(node.op_type())(node, inputs);
    }
};
The operator registry serves as the central dispatch mechanism for ONNX-to-Relay conversions, implementing a classic registry (factory-style) pattern. At its core is a map from ONNX operator type strings (like "Gemm") to conversion functions. Each function accepts an ONNX node proto together with its input Relay expressions and returns the constructed Relay expressions. The singleton provides a single global registry, while the public constructor still allows separate instances to be created in tests.
Key design aspects:
1. Extensibility: New operators can be added via simple registerOp calls
2. Type Safety: Conversion functions enforce strict signature matching
This design enables clean separation between operator definitions and graph traversal logic, critical for maintaining a modular codebase as we expand supported operators.
2. Type System Implementation
The core type system implements a hierarchy of type representations essential for validating neural network operations:
Base Type Class
class Type {
public:
    virtual void print(std::ostream& os) {
        os << "Not implemented print for Type";
        exit(1); // Enforces implementation in derived classes
    }
};
Primitive Types
Handles fundamental data types with enum-based kind tracking:
class PrimType : public Type {
public:
    enum TypeKind { kInt, kFloat, kBool };
    TypeKind kind;

    void print(std::ostream& os) override {
        switch (kind) {
            case kInt:   os << "int";   break;
            case kFloat: os << "float"; break;
            case kBool:  os << "bool";  break;
        }
    }
};
Tensor Type
Captures shape and dtype information critical for neural network tensors:
class TensorType : public Type {
public:
    std::vector<int> shape;
    PrimType dtype;

    void print(std::ostream& os) override {
        os << "Tensor[(";
        for (size_t i = 0; i < shape.size(); i++) {
            os << shape[i];
            if (i + 1 < shape.size()) os << ", ";
        }
        os << "), ";
        dtype.print(os);
        os << "]";
    }
};
Key Features:
1. Polymorphic Storage
Types are stored as shared_ptr<Type> enabling type-safe container storage:
std::shared_ptr<Type> inputVarType =
std::make_shared<TensorType>(input_shape, PrimType::kFloat);
2. Type-Aware Variables
Relay variables embed type information directly:
class RelayVar : public RelayExpr {
    std::shared_ptr<Type> type;
    // ...
};
3. Validation Through Printing
The print() method serves a dual purpose: human-readable debugging output and string-based type assertions in tests:
// Test case example
TEST(RelayTypes, TensorTypePrint) {
    TensorType tensor({1, 3}, PrimType::kFloat);
    std::stringstream ss;
    tensor.print(ss);
    ASSERT_EQ(ss.str(), "Tensor[(1, 3), float]");
}
3. ONNX Model Parsing Architecture
The parser implements a three-phase translation process from ONNX models to Relay IR:
1. Model Inspection & Validation
2. Symbolic Variable Creation
3. Graph Translation Workflow
// Core conversion loop
for (auto node : model.graph().node()) {
    // 1. Collect input arguments
    std::vector<std::shared_ptr<relay::RelayExpr>> args;
    for (const auto& input : node.input()) {
        args.push_back(input2relayVars[input]);
    }

    // 2. Dispatch to the operator registry
    auto* converter = relay::OperatorRegistry::get_instance();
    auto output = converter->convertNode(node, args);

    // 3. Map outputs by name for downstream nodes
    for (int i = 0; i < output.size(); i++) {
        output2relayExprs[node.output(i)] = output[i];
    }
}
4. Final Outcome
Now the ONNX-to-Relay converter can parse a very simple ONNX model like this one:
$ ./myTvm ../misc/onnx-input/simple_model.onnx
Parsing ONNX model: ../misc/onnx-input/simple_model.onnx
Graph name: main_graph
Inputs:
input : (1, 3, )
Outputs:
output : (1, 2, )
Nodes:
/fc/Gemm : Gemm
input: input
input: fc.weight
input: fc.bias
output: output
def @main(%input: Tensor[(1, 3), float]) {
relay.nn.bias_add(relay.nn.dense(%input, %fc.weight), %fc.bias)
}
This post mainly covers commit 66006905 in the repo.