Run ONNX models

ONNX (Open Neural Network Exchange) is a standard format for exporting models — typically created in frameworks like PyTorch — so they can run anywhere. On Dragonwing devices you can use ONNX Runtime with AI Engine Direct to execute ONNX models directly on the NPU for maximum performance.

onnxruntime wheel with AI Engine Direct

onnxruntime currently does not publish prebuilt wheels for aarch64 Linux with AI Engine Direct bindings, so you cannot install onnxruntime from PyPI. You can download prebuilt wheels here:

(Install via pip3 install onnxruntime_qnn-*-linux_aarch64.whl)

To build a wheel for other onnxruntime or Python versions, see edgeimpulse/onnxruntime-qnn-linux-aarch64.
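
After installing the wheel, you can quickly verify that the QNN execution provider is available (a quick check using onnxruntime's standard API; the exact provider list depends on your build):

import onnxruntime as ort

# If the wheel was built with AI Engine Direct bindings, the QNN execution
# provider should show up next to the CPU provider, e.g.
# ['QNNExecutionProvider', 'CPUExecutionProvider']
print(ort.get_available_providers())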

Preparing your ONNX file

The NPU only supports quantized uint8/int8 models with a fixed input shape. If your model is not quantized, or if it has a dynamic input shape, it will automatically be offloaded to the CPU. Here are some tips on how to prepare your model.

A full-length tutorial for exporting a PyTorch model to ONNX is available in the PyTorch documentation.
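
As a minimal sketch (the model and file names are illustrative, mirroring the full example further down this page), an export with a fixed input shape looks roughly like this:

import torch, torchvision

# Export a torchvision SqueezeNet-1.1 to ONNX with a fixed 1x3x224x224 input
weights = torchvision.models.SqueezeNet1_1_Weights.DEFAULT
model = torchvision.models.squeezenet1_1(weights=weights).eval()
dummy = torch.randn(1, 3, 224, 224)
torch.onnx.export(model, dummy, "squeezenet1_1_fp32.onnx",
                  input_names=["input"], output_names=["logits"], opset_version=13)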

Dynamic shapes

If you have a model with dynamic shapes, you'll need to make them fixed first. You can inspect your network's input shape with Netron.

For example, this model has dynamic shapes:

A model with dynamic shape
An ONNX model with a dynamic shape. Here the input tensor is named `pixel_values`.
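
If you prefer to check this from a script rather than in Netron, here is a minimal sketch using the onnx package (the filename is a placeholder):

import onnx

# Dynamic dimensions show up as a name (e.g. 'batch') or 0 instead of a concrete size
model = onnx.load("model_without_shapes.onnx")
for inp in model.graph.input:
    dims = [d.dim_param if d.dim_param else d.dim_value for d in inp.type.tensor_type.shape.dim]
    print(inp.name, dims)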

You can set a fixed shape via onnxruntime.tools.make_dynamic_shape_fixed:

python3 -m onnxruntime.tools.make_dynamic_shape_fixed \
    model_without_shapes.onnx \
    model_with_shapes.onnx \
    --input_name pixel_values \
    --input_shape 1,3,224,224

Afterwards your model has a fixed shape and is ready to run on your NPU.

An ONNX model with a fixed shape
An ONNX model with a fixed shape

Quantizing models

The NPU only supports uint8/int8 quantized models. Unsupported models or layers will automatically be moved back to the CPU. For a guide on quantizing models, see the ONNX Runtime docs: Quantize ONNX Models.

Don't want to quantize yourself? You can download a range of pre-quantized models from Qualcomm AI Hub, or use Edge Impulse to quantize new or existing models.
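
If you are not sure whether an ONNX file is already quantized, you can inspect its operators (a quick sketch; the filename is a placeholder):

import onnx

# QDQ-quantized models contain QuantizeLinear/DequantizeLinear nodes;
# operator-oriented quantization uses QLinear* ops (e.g. QLinearConv)
model = onnx.load("model_with_shapes.onnx")
op_types = {node.op_type for node in model.graph.node}
is_quantized = bool(op_types & {"QuantizeLinear", "DequantizeLinear"}) or \
    any(op.startswith("QLinear") for op in op_types)
print("Quantized:", is_quantized)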

Running a model on the NPU (Python)

To offload a model to the NPU, configure the QNNExecutionProvider and pass it when creating the InferenceSession. For example:

import onnxruntime as ort

providers = [("QNNExecutionProvider", {
    "backend_type": "htp",
    "profiling_level": "detailed",
})]

so = ort.SessionOptions()

sess = ort.InferenceSession(MODEL_PATH, sess_options=so, providers=providers)
actual_providers = sess.get_providers()
print(f"Using providers: {actual_providers}")   # will show QNNExecutionProvider,CPUExecutionProvider if QNN can be loaded

(Make sure you use an onnxruntime wheel with AI Engine Direct bindings, see the top of the page)

Example: SqueezeNet-1.1 (Python)

Open the terminal on your development board, or an ssh session to your development board, and:

  1. Create a new venv, and install onnxruntime and Pillow:

    mkdir -p ~/onnxruntime-demo/
    cd ~/onnxruntime-demo/
    
    python3.12 -m venv .venv
    source .venv/bin/activate
    
    # onnxruntime with AI Engine Direct bindings (only works on Python3.12)
    wget https://cdn.edgeimpulse.com/qc-ai-docs/wheels/onnxruntime_qnn-1.23.0-cp312-cp312-linux_aarch64.whl
    pip3 install onnxruntime_qnn-1.23.0-cp312-cp312-linux_aarch64.whl
    rm onnxruntime*.whl
    
    # Other dependencies
    pip3 install Pillow
  2. Here's an end-to-end example running SqueezeNet-1.1. Save this file as inference_onnx.py:

    import os, sys, time, urllib.request, numpy as np, onnxruntime as ort
    from PIL import Image
    
    use_npu = True if len(sys.argv) >= 2 and sys.argv[1] == '--use-npu' else False
    
    def download_file_if_not_exists(path, url):
        if not os.path.exists(path):
            os.makedirs(os.path.dirname(path), exist_ok=True)
            print(f"Downloading {path} from {url}...")
            urllib.request.urlretrieve(url, path)
        return path
    
    # Path to your model/labels/test image (will be downloaded automatically)
    MODEL_PATH = download_file_if_not_exists('models/squeezenet-1.1/model.onnx', 'https://cdn.edgeimpulse.com/qc-ai-docs/models/SqueezeNet-1.1_w8a8.onnx')
    MODEL_DATA_PATH = download_file_if_not_exists('models/squeezenet-1.1/model.data', 'https://cdn.edgeimpulse.com/qc-ai-docs/models/SqueezeNet-1.1_w8a8.data')
    LABELS_PATH = download_file_if_not_exists('models/squeezenet-1.1_labels.txt', 'https://cdn.edgeimpulse.com/qc-ai-docs/models/SqueezeNet-1.1_labels.txt')
    IMAGE_PATH = download_file_if_not_exists('images/boa-constrictor.jpg', 'https://cdn.edgeimpulse.com/qc-ai-docs/examples/boa-constrictor.jpg')
    
    # Parse labels
    with open(LABELS_PATH, 'r') as f:
        labels = [line for line in f.read().splitlines() if line.strip()]
    
    # Use the HTP backend (NPU) via the QNNExecutionProvider when --use-npu is passed in (otherwise CPU)
    providers = []
    if use_npu:
        providers.append(("QNNExecutionProvider", {
            "backend_type": "htp",
        }))
    else:
        providers.append("CPUExecutionProvider")
    
    so = ort.SessionOptions()
    
    sess = ort.InferenceSession(MODEL_PATH, sess_options=so, providers=providers)
    actual_providers = sess.get_providers()
    print(f"Using providers: {actual_providers}") # Show which providers are actually loaded
    
    inputs  = sess.get_inputs()
    outputs = sess.get_outputs()
    
    # !! Quantization parameters for the model input (hard-coded; update these if you use another model, or read them from the model as shown in the sketch below these steps)
    scale = 1.0 / 255.0
    zero_point = 0
    dtype = np.uint8
    
    # Load, preprocess and quantize image
    def load_image_for_onnx(path, H, W):
        # Load image
        img = Image.open(path).convert("RGB").resize((W, H))
        img_np = np.array(img, dtype=np.float32)
        # !! Normalize: this model expects input scaled 0..1 (no further normalization), but that depends on your model !!
        img_np = img_np / 255
        # HWC -> CHW
        img_np = np.transpose(img_np, (2, 0, 1))
        # Add batch dim
        img_np = np.expand_dims(img_np, 0)
    
        # Quantize input if needed
        if dtype == np.float32:
            return img_np
        elif dtype == np.uint8:
            # q = round(x/scale + zp)
            q = np.round(img_np / scale + zero_point)
            return np.clip(q, 0, 255).astype(np.uint8)
        elif dtype == np.int8:
            # Commonly zero_point ≈ 0 (symmetric), but use provided zp anyway
            q = np.round(img_np / scale + zero_point)
            return np.clip(q, -128, 127).astype(np.int8)
        else:
            raise Exception('Unexpected dtype: ' + str(dtype))
    
    # input data scaled 0..1
    input_data = load_image_for_onnx(path=IMAGE_PATH, H=224, W=224)
    
    # Warmup once
    _ = sess.run(None, { sess.get_inputs()[0].name: input_data })
    
    # Run 10x so we can calculate avg. runtime per inference
    start = time.perf_counter()
    for i in range(10):
        out = sess.run(None, { sess.get_inputs()[0].name: input_data })
    end = time.perf_counter()
    
    # Image classification models from AI Hub typically lack a Softmax() layer at the end of the model, so apply it manually
    def softmax(x, axis=-1):
        # subtract max for numerical stability
        x_max = np.max(x, axis=axis, keepdims=True)
        e_x = np.exp(x - x_max)
        return e_x / np.sum(e_x, axis=axis, keepdims=True)
    
    scores = softmax(np.squeeze(out[0], axis=0))
    
    # Take top 5
    top_k_idx = scores.argsort()[-5:][::-1]
    
    print("\nTop-5 predictions:")
    for i in top_k_idx:
        label = labels[i] if i < len(labels) else f"Class {i}"
        print(f"{label}: score={scores[i]}")
    
    print("")
    print(f'Inference took (on average): {((end - start) * 1000) / 10:.4g}ms. per image')

    Note: this script has hard-coded quantization parameters. If you swap out the model, you might need to change these (a sketch for reading them from the model follows after these steps).

  3. Run the model on the CPU:

    python3 inference_onnx.py
    
    # Top-5 predictions:
    # common iguana: score=0.3682704567909241
    # night snake: score=0.1186317503452301
    # water snake: score=0.1186317503452301
    # boa constrictor: score=0.0813227966427803
    # bullfrog: score=0.0813227966427803
    #
    # Inference took (on average): 6.50 ms per image
  4. Run the model on the NPU:

    python3 inference_onnx.py --use-npu
    
    # Top-5 predictions:
    # common iguana: score=0.30427297949790955
    # water snake: score=0.11838366836309433
    # night snake: score=0.11838366836309433
    # boa constrictor: score=0.11838366836309433
    # rock python: score=0.08115273714065552
    #
    # Inference took (on average): 1.60 ms per image

As you can see, this model runs significantly faster on the NPU, but there is a slight change in the model's output.
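
The script above hard-codes the input quantization parameters (scale = 1/255, zero point = 0). For many quantized models you can read them from the graph instead; here is a sketch, assuming the model dequantizes its uint8 input with a DequantizeLinear node whose scale and zero point are stored as initializers (run it from the ~/onnxruntime-demo directory so the external model.data file is found):

import onnx
from onnx import numpy_helper

m = onnx.load("models/squeezenet-1.1/model.onnx")
inits = {t.name: numpy_helper.to_array(t) for t in m.graph.initializer}
input_name = m.graph.input[0].name

for node in m.graph.node:
    # The DequantizeLinear node that consumes the graph input holds the
    # scale (input[1]) and zero point (input[2]) used above
    if node.op_type == "DequantizeLinear" and node.input[0] == input_name:
        scale = float(inits[node.input[1]])
        zero_point = int(inits[node.input[2]]) if len(node.input) > 2 else 0
        print(f"input scale={scale}, zero_point={zero_point}")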

Example: PyTorch → ONNX → Quantized int8 → inference on NPU

Open the terminal on your development board, or an ssh session to your development board, and:

  1. Create a new venv, and install onnxruntime and Pillow:

    mkdir -p ~/onnxruntime-demo/
    cd ~/onnxruntime-demo/
    
    python3.12 -m venv .venv
    source .venv/bin/activate
    
    # onnxruntime with AI Engine Direct bindings (only works on Python3.12)
    wget https://cdn.edgeimpulse.com/qc-ai-docs/wheels/onnxruntime_qnn-1.23.0-cp312-cp312-linux_aarch64.whl
    pip3 install onnxruntime_qnn-1.23.0-cp312-cp312-linux_aarch64.whl
    rm onnxruntime*.whl
    
    # Other dependencies
    pip3 install Pillow
  2. Here's an end-to-end example that exports SqueezeNet-1.1 from torchvision to ONNX, quantizes it to int8, and runs inference. Save this file as inference_pytorch_onnx.py:

    # inference_pytorch_onnx.py
    import sys, os, time, urllib.request, numpy as np, onnx, onnxruntime as ort
    from PIL import Image
    
    import torch, torchvision
    import torchvision.transforms as T
    from onnxruntime.quantization import (
        CalibrationDataReader,
        QuantFormat,
        QuantType,
        quantize_static,
    )
    
    use_npu = True if len(sys.argv) >= 2 and sys.argv[1] == '--use-npu' else False
    
    def download_file_if_not_exists(path, url):
        if not os.path.exists(path):
            os.makedirs(os.path.dirname(path), exist_ok=True)
            print(f"Downloading {path} from {url}...")
            urllib.request.urlretrieve(url, path)
        return path
    
    weights = torchvision.models.SqueezeNet1_1_Weights.DEFAULT
    IMAGE_PATH = download_file_if_not_exists('images/boa-constrictor.jpg', 'https://cdn.edgeimpulse.com/qc-ai-docs/examples/boa-constrictor.jpg')
    
    # Load PyTorch SqueezeNet1_1 from torchvision
    device = "cpu"
    model = torchvision.models.squeezenet1_1(weights=weights)
    model.eval().to(device)
    
    # Export to ONNX (fp32)
    os.makedirs("models", exist_ok=True)
    onnx_fp32 = "models/squeezenet1_1_fp32.onnx"
    input_size = 224
    dummy = torch.randn(1, 3, input_size, input_size, device=device)
    
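    # NOTE: dynamic_axes below keeps the batch dimension dynamic. The NPU needs a fixed
    # input shape (see "Dynamic shapes" above), so a dynamic batch dim can cause the
    # quantized model to fall back to the CPU. Drop dynamic_axes, or fix the shape
    # afterwards (see the note below these steps), if you want the graph on the NPU.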
    torch.onnx.export(
        model, dummy, onnx_fp32,
        input_names=["input"],
        output_names=["logits"],
        opset_version=13,
        do_constant_folding=True,
        dynamic_axes={"input": {0: "batch"}, "logits": {0: "batch"}},
    )
    onnx.checker.check_model(onnx.load(onnx_fp32))
    print(f"Exported FP32 ONNX -> {onnx_fp32}")
    
    # Provide a calibration data reader for static INT8 quantization
    class ImageFolderDataReader(CalibrationDataReader):
        def __init__(self, image_paths):
            self.image_paths = image_paths
            self.transform = T.Compose([
                T.Resize(256), T.CenterCrop(224),
                T.ToTensor(), T.Normalize(mean=[0.485,0.456,0.406],
                                        std=[0.229,0.224,0.225])
            ])
            self._iter = None
    
        def get_next(self):
            if self._iter is None:
                self._iter = iter(self.image_paths)
            try:
                p = next(self._iter)
            except StopIteration:
                return None
            img = Image.open(p).convert("RGB")
            x = self.transform(img).unsqueeze(0).numpy()
            return {"input": x}
    
    # Replace with representative images from your domain
    calib = ImageFolderDataReader([IMAGE_PATH])
    
    # Read the ONNX input name from the exported model (it should match the "input" name set during export)
    m = onnx.load(onnx_fp32)
    onnx_input_name = m.graph.input[0].name
    
    onnx_int8 = "models/squeezenet1_1_int8.onnx"
    
    # Use QDQ format (widely supported); uint8 activations + int8 weights is a common choice
    quantize_static(
        model_input=onnx_fp32,
        model_output=onnx_int8,
        calibration_data_reader=calib,
        quant_format=QuantFormat.QDQ,
        activation_type=QuantType.QUInt8,
        weight_type=QuantType.QInt8,
    )
    onnx.checker.check_model(onnx.load(onnx_int8))
    print(f"Quantized INT8 ONNX -> {onnx_int8}")
    
    # Use the HTP backend (NPU) via the QNNExecutionProvider when --use-npu is passed in (otherwise CPU)
    providers = []
    if use_npu:
        providers.append(("QNNExecutionProvider", {
            "backend_type": "htp",
        }))
    else:
        providers.append("CPUExecutionProvider")
    
    input_shape = (1, 3, 224, 224)
    img = Image.open(IMAGE_PATH).convert("RGB").resize((224, 224))
    input_data = np.expand_dims(np.transpose(np.array(img, dtype=np.float32) / 255.0, (2, 0, 1)), 0)
    #input_data = np.random.rand(*input_shape).astype(np.float32)
    
    so = ort.SessionOptions()
    
    sess = ort.InferenceSession(onnx_int8, sess_options=so, providers=providers)
    actual_providers = sess.get_providers()
    print(f"Using providers: {actual_providers}") # Show which providers are actually loaded
    
    inputs  = sess.get_inputs()
    outputs = sess.get_outputs()
    
    _ = sess.run(None, { sess.get_inputs()[0].name: input_data })
    
    # Run 10x so we can calculate avg. runtime per inference
    start = time.perf_counter()
    for i in range(10):
        out = sess.run(None, { sess.get_inputs()[0].name: input_data })
    end = time.perf_counter()
    
    def softmax(x, axis=-1):
        # subtract max for numerical stability
        x_max = np.max(x, axis=axis, keepdims=True)
        e_x = np.exp(x - x_max)
        return e_x / np.sum(e_x, axis=axis, keepdims=True)
    
    scores = softmax(np.squeeze(out[0], axis=0))
    
    # Take top 5
    top_k_idx = scores.argsort()[-5:][::-1]
    
    print("\nTop-5 predictions:")
    for i in top_k_idx:
        label = weights.meta["categories"][i] if i < len(weights.meta["categories"]) else f"Class {i}"
        print(f"{label}: score={scores[i]}")
    
    print("")
    print(f'Inference took (on average): {((end - start) * 1000) / 10:.4g}ms. per image')
  3. Run the model on the CPU:

    python3 inference_pytorch_onnx.py
    
    # Top-5 predictions:
    # bassinet: score=0.473060667514801
    # mosquito net: score=0.15856008231639862
    # quilt: score=0.13602180778980255
    # crib: score=0.07408633828163147
    # cradle: score=0.03270549699664116
    # 
    # Inference took (on average): 7.418ms. per image
  4. Run the model on the NPU:

    python3 inference_pytorch_onnx.py --use-npu
    
    # Top-5 predictions:
    # bassinet: score=0.473060667514801
    # mosquito net: score=0.15856008231639862
    # quilt: score=0.13602180778980255
    # crib: score=0.07408633828163147
    # cradle: score=0.03270549699664116
    # 
    # Inference took (on average): 7.513ms. per image
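
Note that the CPU and NPU timings are nearly identical here. Because the model was exported with a dynamic batch dimension (dynamic_axes), ONNX Runtime may keep most of the graph on the CPU (see "Dynamic shapes" above). A possible fix, reusing the tool from that section (the fixed-shape output filename is arbitrary), is to fix the input shape of the quantized model and point the InferenceSession at the fixed-shape file:

python3 -m onnxruntime.tools.make_dynamic_shape_fixed \
    models/squeezenet1_1_int8.onnx \
    models/squeezenet1_1_int8_fixed.onnx \
    --input_name input \
    --input_shape 1,3,224,224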

Tips & tricks

Disable CPU fallback

To debug, you may want to disable fallback to the CPU; with this option set, session creation fails if any operator cannot be placed on the NPU:

so = ort.SessionOptions()
so.add_session_config_entry("session.disable_cpu_ep_fallback", "1")

Building new versions of the onnxruntime package

See edgeimpulse/onnxruntime-qnn-linux-aarch64.
