Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions examples/decision_tree/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Decision Tree Example (Iris Dataset)

This folder demonstrates how to train a simple decision tree model on the Iris dataset, convert the trained model to GGML format, and run inference in a Go-based MIPS environment.

------

# Folder Description

Below is a quick overview of each folder:

1. **`converet/`**
- **`convert.py`**: An example script that demonstrates converting models to GGML format.
- **`iris_ggml_model.bin`**: An example model already converted to GGML format.

2. **`go-mips-inference/`**
- **`build.sh`**: A Bash script to build the Go inference program.
- **`mips_inference`**: Compiled Go inference program.
- **`mips_inference.go`**: Main Go source code implementing the MIPS inference logic.
- **`val.py`**: Python script for verifying outputs.

3. **`train/`**
- **`iris_decision_tree_model.pkl`**: A serialized scikit-learn decision tree model trained on the Iris dataset.
- **`train.py`**: Python script that trains a decision tree on the Iris dataset, saving the model to `iris_decision_tree_model.pkl`.

------

# How to Run the Decision Tree Example

## 1. Train a Decision Tree

1. Navigate to the `train/` folder
2. Run `python train.py`

## 2. Convert to GGML Format

1. Navigate to the `converet/` folder
2. Run `python convert.py`

## 3. Build MIPS Inference

1. Navigate to the `go-mips-inference` folder
2. Run `./build.sh`

## 4. Validate the result

We provide `val.py` (in the go-mips-inference folder) that can be used to check or compare inference outputs.

```
$ python val.py
Accuracy on the Iris test set: 100.00%
```
42 changes: 42 additions & 0 deletions examples/decision_tree/converet/convert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import joblib
import numpy as np
import struct


# Load the fitted classifier that train/train.py saved with joblib.
# The path is relative to the current working directory, so this script
# has to be started from inside its own folder.
clf = joblib.load('../train/iris_decision_tree_model.pkl')


def convert_tree_to_ggml(clf, filename='iris_ggml_model.bin'):
    """Serialize a fitted decision tree into a flat binary model file.

    Every field is 4 bytes wide and written little-endian, which is the
    byte order the Go loader in ``go-mips-inference`` decodes:

        int32    magic (0x67676d6c)
        int32    n_features
        int32    n_nodes
        then, for each node in index order:
            int32    split feature index
            float32  split threshold
            int32    left child index
            int32    right child index
            float32  x len(tree_.value[i].flatten())  per-class values

    Args:
        clf: Fitted estimator exposing ``n_features_in_`` and ``tree_``
            (with ``node_count``, ``feature``, ``threshold``,
            ``children_left``, ``children_right`` and ``value``).
        filename: Path of the file to create or overwrite.
    """
    tree_ = clf.tree_
    n_nodes = tree_.node_count  # Total number of nodes in the tree

    with open(filename, 'wb') as fout:
        # "<" pins little-endian, standard-size fields. Without it struct
        # uses the host's native order and the file would only be readable
        # when produced on a little-endian machine.
        fout.write(struct.pack("<3i", 0x67676d6c, clf.n_features_in_, n_nodes))

        for i in range(n_nodes):
            # tree_.value[i] is nested per output; flatten to one value per class.
            value = tree_.value[i].flatten()

            fout.write(struct.pack(
                "<ifii",
                tree_.feature[i],         # Feature index
                tree_.threshold[i],       # Threshold
                tree_.children_left[i],   # Left child
                tree_.children_right[i],  # Right child
            ))
            fout.write(struct.pack(f"<{len(value)}f", *value))  # Class values

    print(f"Model converted to GGML format: {filename}")

# Convert the loaded classifier using the default output filename, which is
# written relative to the current working directory.
convert_tree_to_ggml(clf)
Binary file added examples/decision_tree/converet/iris_ggml_model.bin
Binary file not shown.
13 changes: 13 additions & 0 deletions examples/decision_tree/go-mips-inference/build.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#!/usr/bin/env bash
# Build the Go inference program into an executable next to this script.
set -e

# Work from the script's own directory so the relative paths below resolve
# regardless of where the script was invoked from.
cd "$(dirname "$0")"

# Create the module file only when it does not exist yet.
[ -f go.mod ] || go mod init go-mips-inference

go mod tidy
go build -o mips_inference mips_inference.go

echo "Build complete. The executable is named 'mips_inference'."
46 changes: 46 additions & 0 deletions examples/decision_tree/go-mips-inference/common/vmutils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package common

import (
"io/ioutil"
"log"
"math"
)

// vm only ===================================================================================

// Fixed addresses of the MIPS memory layout, listed in ascending order.
const (
	MAGIC_ADDR  = 0x30000800
	INPUT_ADDR  = 0x31000000
	OUTPUT_ADDR = 0x32000000
	MODEL_ADDR  = 0x33000000
)

// ReadBytesFromFile returns the complete contents of the file at filePath.
// A read failure is not returned to the caller: it is logged with
// log.Fatalf, which terminates the process.
func ReadBytesFromFile(filePath string) []byte {
	contents, readErr := ioutil.ReadFile(filePath)
	if readErr == nil {
		return contents
	}
	log.Fatalf("Error reading file %s: %v\n", filePath, readErr)
	return nil // not reached; log.Fatalf exits, but the compiler needs a return
}

// ReadInt32FromBytes decodes the four bytes of data starting at *idx as a
// little-endian int32 and advances *idx by four. It panics if fewer than
// four bytes remain.
func ReadInt32FromBytes(data []byte, idx *int) int32 {
	start := *idx
	var word uint32
	// Byte i carries bits 8*i .. 8*i+7 (least significant byte first).
	for i := 0; i < 4; i++ {
		word |= uint32(data[start+i]) << (8 * uint(i))
	}
	*idx = start + 4
	return int32(word)
}

// ReadFloat32FromBytes decodes the four bytes of data starting at *idx as a
// little-endian IEEE-754 float32 and advances *idx by four. It panics if
// fewer than four bytes remain.
func ReadFloat32FromBytes(data []byte, idx *int) float32 {
	pos := *idx
	lo := uint32(data[pos])
	midLo := uint32(data[pos+1]) << 8
	midHi := uint32(data[pos+2]) << 16
	hi := uint32(data[pos+3]) << 24
	*idx = pos + 4
	return math.Float32frombits(lo | midLo | midHi | hi)
}
3 changes: 3 additions & 0 deletions examples/decision_tree/go-mips-inference/go.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
module go-mips-inference

go 1.23.4
Binary file not shown.
120 changes: 120 additions & 0 deletions examples/decision_tree/go-mips-inference/mips_inference.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
package main

import (
"fmt"
"os"
"strconv"
"go-mips-inference/common"
)

// irisModel is the in-memory form of the decision tree filled in by
// loadIrisModel. The per-node slices are parallel: element i of each one
// describes node i.
type irisModel struct {
// Feature count read from the file header.
nFeatures int32
// Node count read from the file header; every slice below has this length.
nNodes int32
// Per node: index into the input vector that this node compares.
features []int32
// Per node: evaluation goes left when the input value is <= this threshold.
thresholds []float32
// Per node: left and right child ids. A node whose two children are both
// -1 is treated as a leaf by evalIrisModel.
lefts []int32
rights []int32
// Per node: three class scores; at a leaf the index of the largest one is
// the predicted class.
values [][]float32
}

// loadIrisModel reads the binary decision-tree file at filePath into model.
//
// Expected layout (every field is 4 bytes): magic 0x67676d6c, n_features,
// n_nodes, then for each node: feature index, threshold, left child, right
// child and three class scores.
//
// An error is returned when the magic number is wrong, a header count is not
// positive, the file is shorter than the declared node count requires, or a
// non-leaf node refers to a child or feature index outside the model. The
// model is only written to after the whole file has been validated. A file
// that cannot be read at all is handled inside common.ReadBytesFromFile.
func loadIrisModel(model *irisModel, filePath string) error {
	const (
		wordSize   = 4                           // width of every serialized field
		headerSize = 3 * wordSize                // magic + n_features + n_nodes
		numClasses = 3                           // 3 classes for Iris
		nodeSize   = (4 + numClasses) * wordSize // 4 scalar fields + class scores
	)

	// 1) Read the file instead of memory address
	modelBytes := common.ReadBytesFromFile(filePath)
	if len(modelBytes) < headerSize {
		return fmt.Errorf("model file too short: got %d bytes, header needs %d",
			len(modelBytes), headerSize)
	}

	idx := 0

	// 2) Check magic number
	magic := common.ReadInt32FromBytes(modelBytes, &idx)
	if magic != 0x67676d6c {
		return fmt.Errorf("invalid magic number: 0x%x", magic)
	}

	// 3) Read n_features, n_nodes and make sure the payload is complete
	nFeatures := common.ReadInt32FromBytes(modelBytes, &idx)
	nNodes := common.ReadInt32FromBytes(modelBytes, &idx)
	if nFeatures <= 0 || nNodes <= 0 {
		return fmt.Errorf("invalid header: n_features=%d, n_nodes=%d", nFeatures, nNodes)
	}
	// Compare in int64 so a large declared node count cannot overflow int.
	needed := int64(nNodes) * nodeSize
	available := int64(len(modelBytes) - headerSize)
	if available < needed {
		return fmt.Errorf("model file truncated: %d nodes need %d bytes, only %d available",
			nNodes, needed, available)
	}

	// 4) Allocate slices
	features := make([]int32, nNodes)
	thresholds := make([]float32, nNodes)
	lefts := make([]int32, nNodes)
	rights := make([]int32, nNodes)
	values := make([][]float32, nNodes)

	// 5) Read node data
	for i := int32(0); i < nNodes; i++ {
		features[i] = common.ReadInt32FromBytes(modelBytes, &idx)
		thresholds[i] = common.ReadFloat32FromBytes(modelBytes, &idx)
		lefts[i] = common.ReadInt32FromBytes(modelBytes, &idx)
		rights[i] = common.ReadInt32FromBytes(modelBytes, &idx)

		classProb := make([]float32, numClasses)
		for j := range classProb {
			classProb[j] = common.ReadFloat32FromBytes(modelBytes, &idx)
		}
		values[i] = classProb

		// A leaf has both children set to -1. Any other node is followed
		// during evaluation, so its indices must be usable.
		if lefts[i] == -1 && rights[i] == -1 {
			continue
		}
		if lefts[i] < 0 || lefts[i] >= nNodes || rights[i] < 0 || rights[i] >= nNodes {
			return fmt.Errorf("node %d: child index out of range (left=%d, right=%d, n_nodes=%d)",
				i, lefts[i], rights[i], nNodes)
		}
		if features[i] < 0 || features[i] >= nFeatures {
			return fmt.Errorf("node %d: feature index %d out of range (n_features=%d)",
				i, features[i], nFeatures)
		}
	}

	// 6) Publish the fully validated model
	model.nFeatures = nFeatures
	model.nNodes = nNodes
	model.features = features
	model.thresholds = thresholds
	model.lefts = lefts
	model.rights = rights
	model.values = values
	return nil
}

// evalIrisModel walks the tree from node 0 and returns the index of the
// largest class score stored at the leaf it reaches (the first such index
// on ties).
func evalIrisModel(model *irisModel, features []float32) int {
	cur := int32(0)

	// Descend until a leaf, i.e. a node whose two children are both -1.
	for model.lefts[cur] != -1 || model.rights[cur] != -1 {
		if features[model.features[cur]] <= model.thresholds[cur] {
			cur = model.lefts[cur]
		} else {
			cur = model.rights[cur]
		}
	}

	// Pick the class with the highest score at the leaf.
	scores := model.values[cur]
	best := 0
	for k, score := range scores {
		if score > scores[best] {
			best = k
		}
	}
	return best
}

// main parses four feature values from the command line, loads the model
// from a fixed path relative to the working directory, and prints the
// predicted class index. Any failure is reported on stderr with exit code 1.
func main() {
	const numInputs = 4 // number of feature arguments taken from the command line

	// 1) Validate the command line before touching the model file, so a
	//    usage mistake is reported as such.
	if len(os.Args) < numInputs+1 {
		fmt.Fprintf(os.Stderr, "Usage: %s f1 f2 f3 f4\n", os.Args[0])
		os.Exit(1)
	}

	// 2) Convert the arguments to float32
	input := make([]float32, numInputs)
	for i := range input {
		val, err := strconv.ParseFloat(os.Args[i+1], 32)
		if err != nil {
			fmt.Fprintf(os.Stderr, "Invalid feature: %s\n", os.Args[i+1])
			os.Exit(1)
		}
		input[i] = float32(val)
	}

	// 3) Load your model file from local path
	// NOTE(review): the file name says "big-endian", while the loader's
	// helpers assemble each field least-significant byte first; confirm the
	// name is only historical.
	var model irisModel
	modelPath := "models/iris/ggml-model-small-f32-big-endian.bin"
	if err := loadIrisModel(&model, modelPath); err != nil {
		fmt.Fprintf(os.Stderr, "Failed to load model: %v\n", err)
		os.Exit(1)
	}

	// 4) The tree indexes the input by feature id, so refuse a model that
	//    expects more features than were supplied instead of panicking.
	if int(model.nFeatures) > len(input) {
		fmt.Fprintf(os.Stderr, "Model expects %d features, but only %d were provided\n",
			model.nFeatures, len(input))
		os.Exit(1)
	}

	// 5) Evaluate
	predictedClass := evalIrisModel(&model, input)
	fmt.Printf("Predicted class: %d\n", predictedClass)
}
Binary file not shown.
57 changes: 57 additions & 0 deletions examples/decision_tree/go-mips-inference/val.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import subprocess
import re
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

def predict_with_mips_inference(features):
    """Classify one sample by running the local ``./mips_inference`` binary.

    Each feature value is converted to a string and passed as a command-line
    argument. The class index is parsed from the "Predicted class: X" text
    on the program's standard output.

    Raises:
        FileNotFoundError: ``./mips_inference`` does not exist in the
            current working directory.
        RuntimeError: the program exited with a non-zero status.
        ValueError: the output did not contain the expected pattern.
    """
    command = ["./mips_inference", *map(str, features)]

    try:
        completed = subprocess.run(command, capture_output=True, text=True)
    except FileNotFoundError:
        raise FileNotFoundError("The './mips_inference' executable file was not found. Please ensure it is compiled and in the current directory.")

    if completed.returncode != 0:
        raise RuntimeError(f"mips_inference execution failed, error message: {completed.stderr}")

    # Assume the output contains "Predicted class: X"
    output = completed.stdout.strip()
    found = re.search(r"Predicted class:\s*(\d+)", output)
    if found is None:
        raise ValueError(f"Failed to parse the prediction result, please check the mips_inference output format: {output}")
    return int(found.group(1))

def main():
    """Run the inference binary on the held-out Iris samples and print its accuracy."""
    dataset = load_iris()

    # Only the test portion is needed here; the split parameters
    # (test_size=0.2, random_state=42) are the ones train.py uses.
    _, X_test, _, y_test = train_test_split(
        dataset.data, dataset.target, test_size=0.2, random_state=42
    )

    y_pred = [predict_with_mips_inference(sample) for sample in X_test]

    acc = accuracy_score(y_test, y_pred)
    print(f"Accuracy on the Iris test set: {acc * 100:.2f}%")


if __name__ == "__main__":
    main()
Binary file not shown.
25 changes: 25 additions & 0 deletions examples/decision_tree/train/train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Import necessary libraries
import torch
import torch.nn as nn
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score
import joblib

# Load the Iris samples and their class labels.
iris = load_iris()
X, y = iris.data, iris.target

# Hold out 20% of the samples for evaluation; the fixed seed keeps the split
# the same on every run.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=42,
)

# fit() returns the estimator itself, so construction and training chain.
clf = DecisionTreeClassifier(random_state=42).fit(X_train, y_train)

# Report accuracy on the held-out samples.
y_pred = clf.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print(f'Accuracy of Decision Tree on Iris dataset: {accuracy:.2f}')

# Save the model (using joblib or pickle); the file is written to the
# current working directory.
joblib.dump(clf, 'iris_decision_tree_model.pkl')