ONNX

ONNX (Open Neural Network Exchange) is a standard format for exporting models — typically created in frameworks like PyTorch — so they can run anywhere. On Dragonwing devices you can use ONNX Runtime with AI Engine Direct to execute ONNX models directly on the NPU for maximum performance.

onnxruntime wheel with AI Engine Direct

onnxruntime currently does not publish prebuilt wheels for aarch64 Linux with AI Engine Direct bindings, so you cannot install onnxruntime through pip. You can download prebuilt wheels here:

(Install via pip3 install onnxruntime_qnn-*-linux_aarch64.whl)

To build a wheel for other onnxruntime or Python versions, see edgeimpulse/onnxruntime-qnn-linux-aarch64.
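
After installing the wheel, you can quickly verify that the AI Engine Direct (QNN) execution provider was picked up. A minimal check using standard onnxruntime APIs:

import onnxruntime as ort

# QNNExecutionProvider should be listed alongside CPUExecutionProvider
# if the wheel was built with AI Engine Direct bindings
print(ort.__version__)
print(ort.get_available_providers())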

Preparing your ONNX file

The NPU only supports quantized uint8/int8 models with a fixed input shape. If your model is not quantized, or if it has a dynamic input shape, it will automatically be offloaded to the CPU. Here are some tips on how to prepare your model.

A full length tutorial for exporting a PyTorch model to ONNX is available in the PyTorch documentation.
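
As a quick reference, a typical export looks like the sketch below (torchvision's SqueezeNet is used purely as an illustration; any torch.nn.Module works the same way). Omitting dynamic_axes keeps the exported input shape fixed, which matters for the next section.

import torch
import torchvision

# Any torch.nn.Module works; SqueezeNet is used here only as an example
model = torchvision.models.squeezenet1_1(weights=None).eval()

# A fixed-size dummy input; not passing dynamic_axes keeps the exported input shape static
dummy_input = torch.randn(1, 3, 224, 224)

torch.onnx.export(
    model,
    dummy_input,
    "squeezenet1_1.onnx",
    input_names=["pixel_values"],
    output_names=["logits"],
    opset_version=17,
)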

Dynamic shapes

If you have a model with dynamic shapes, you'll need to make them fixed shape first. You can see the shape of your network via Netron.

For example, this model has dynamic shapes:

An ONNX model with a dynamic shape. Here the input tensor is named `pixel_values`.
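
If you'd rather script this than open Netron, the onnx Python package can report the same information. A small sketch, assuming the model file is named model_without_shapes.onnx:

import onnx

model = onnx.load("model_without_shapes.onnx")
for inp in model.graph.input:
    dims = []
    for d in inp.type.tensor_type.shape.dim:
        # Dynamic dimensions carry a symbolic dim_param (e.g. "batch_size") instead of a dim_value
        dims.append(d.dim_param if d.dim_param else d.dim_value)
    print(inp.name, dims)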

You can set a fixed shape via onnxruntime.tools.make_dynamic_shape_fixed:

python3 -m onnxruntime.tools.make_dynamic_shape_fixed \
    model_without_shapes.onnx \
    model_with_shapes.onnx \
    --input_name pixel_values \
    --input_shape 1,3,224,224

Afterwards your model has a fixed shape and is ready to run on your NPU.

An ONNX model with a fixed shape
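
To double-check the result without opening Netron again, you can print the input shape that onnxruntime reports; all dimensions should now be concrete integers:

import onnxruntime as ort

sess = ort.InferenceSession("model_with_shapes.onnx", providers=["CPUExecutionProvider"])
for inp in sess.get_inputs():
    # Expect something like: pixel_values [1, 3, 224, 224]
    print(inp.name, inp.shape)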

Quantizing models

The NPU only supports uint8/int8 quantized models. Unsupported models or layers will automatically fall back to the CPU. For a guide on quantizing models, see the ONNX Runtime docs: Quantize ONNX Models.

Don't want to quantize yourself? You can download a range of pre-quantized models from Qualcomm AI Hub, or use Edge Impulse to quantize new or existing models.
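
If you do want to quantize the model yourself, the flow from the ONNX Runtime docs boils down to providing a small calibration dataset and calling quantize_static. A minimal sketch, assuming a 1x3x224x224 input named pixel_values; the calibration reader below feeds random data and should be replaced with real, representative samples:

import numpy as np
from onnxruntime.quantization import CalibrationDataReader, QuantFormat, QuantType, quantize_static

class RandomCalibrationReader(CalibrationDataReader):
    # Feeds a handful of samples to the calibrator; replace the random data
    # with real inputs to get usable accuracy
    def __init__(self, input_name, num_samples=32):
        self.samples = iter(
            [{input_name: np.random.rand(1, 3, 224, 224).astype(np.float32)} for _ in range(num_samples)]
        )

    def get_next(self):
        return next(self.samples, None)

quantize_static(
    "model_with_shapes.onnx",              # float32 model with a fixed input shape
    "model_quantized.onnx",                # output path
    RandomCalibrationReader("pixel_values"),
    quant_format=QuantFormat.QDQ,          # QDQ format (quantize/dequantize nodes), commonly used with the QNN EP
    activation_type=QuantType.QUInt8,
    weight_type=QuantType.QInt8,
)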

Running a model on the NPU (Python)

To offload a model to the NPU, add the QNNExecutionProvider to the list of providers when creating the InferenceSession. For example:

import onnxruntime as ort

providers = [("QNNExecutionProvider", {
    "backend_type": "htp",
    "profiling_level": "detailed",
})]

so = ort.SessionOptions()

sess = ort.InferenceSession(MODEL_PATH, sess_options=so, providers=providers)
actual_providers = sess.get_providers()
print(f"Using providers: {actual_providers}")   # will show QNNExecutionProvider,CPUExecutionProvider if QNN can be loaded

(Make sure you use an onnxruntime wheel with AI Engine Direct bindings; see the top of this page.)
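
To see where time is spent, one option is ONNX Runtime's built-in profiler (the generic ORT profiler, separate from the QNN-specific profiling_level option shown above). Continuing the snippet above:

so = ort.SessionOptions()
so.enable_profiling = True  # write a JSON trace (viewable in chrome://tracing)

sess = ort.InferenceSession(MODEL_PATH, sess_options=so, providers=providers)
# ... run some inferences ...
profile_file = sess.end_profiling()  # returns the path of the generated trace file
print(f"Profile written to {profile_file}")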

Example: SqueezeNet-1.1 (Python)

Open the terminal on your development board, or an ssh session to your development board, and:

  1. Create a new venv, and install onnxruntime and Pillow:

    python3.12 -m venv .venv-onnxruntime-demo
    source .venv-onnxruntime-demo/bin/activate
    
    # onnxruntime with AI Engine Direct bindings (this wheel requires Python 3.12)
    wget https://cdn.edgeimpulse.com/qc-ai-docs/wheels/onnxruntime_qnn-1.23.0-cp312-cp312-linux_aarch64.whl
    pip3 install onnxruntime_qnn-1.23.0-cp312-cp312-linux_aarch64.whl
    rm onnxruntime*.whl
    
    # Other dependencies
    pip3 install Pillow
  2. Here's an end-to-end example running SqueezeNet-1.1. Save this file as inference_onnx.py:

    #!/usr/bin/env python3
    import os, sys, time, urllib.request
    import numpy as np
    from PIL import Image
    import onnxruntime as ort
    
    def curr_ms():
        return round(time.time() * 1000)
    
    use_npu = len(sys.argv) >= 2 and sys.argv[1] == '--use-npu'
    
    # Path to your quantized ONNX model and test image (will be downloaded automatically)
    MODEL_PATH = "model.onnx"
    MODEL_DATA_PATH = "model.data"
    IMAGE_PATH = "boa-constrictor.jpg"
    LABELS_PATH = "SqueezeNet-1.1_labels.txt"
    
    if not os.path.exists(MODEL_PATH):
        print("Downloading model...")
        model_url = 'https://cdn.edgeimpulse.com/qc-ai-docs/models/SqueezeNet-1.1_w8a8.onnx'
        urllib.request.urlretrieve(model_url, MODEL_PATH)
    
    if not os.path.exists(MODEL_DATA_PATH):
        print("Downloading model data...")
        model_url = 'https://cdn.edgeimpulse.com/qc-ai-docs/models/SqueezeNet-1.1_w8a8.data'
        urllib.request.urlretrieve(model_url, MODEL_DATA_PATH)
    
    if not os.path.exists(LABELS_PATH):
        print("Downloading labels...")
        labels_url = 'https://cdn.edgeimpulse.com/qc-ai-docs/models/SqueezeNet-1.1_labels.txt'
        urllib.request.urlretrieve(labels_url, LABELS_PATH)
    
    if not os.path.exists(IMAGE_PATH):
        print("Downloading image...")
        image_url = 'https://cdn.edgeimpulse.com/qc-ai-docs/examples/boa-constrictor.jpg'
        urllib.request.urlretrieve(image_url, IMAGE_PATH)
    
    with open(LABELS_PATH, 'r') as f:
        labels = [line for line in f.read().splitlines() if line.strip()]
    
    providers = []
    if use_npu:
        providers.append(("QNNExecutionProvider", {
            "backend_type": "htp",
        }))
    else:
        providers.append("CPUExecutionProvider")
    
    so = ort.SessionOptions()
    
    sess = ort.InferenceSession(MODEL_PATH, sess_options=so, providers=providers)
    actual_providers = sess.get_providers()
    print(f"Using providers: {actual_providers}") # Show which providers are actually loaded
    
    inputs  = sess.get_inputs()
    outputs = sess.get_outputs()
    
    def load_image_for_onnx(path, H, W):
        img = Image.open(path).convert("RGB").resize((W, H))
        arr = np.array(img)
        arr = arr.astype(np.float32) / 255.0
    
        arr = np.transpose(arr, (2, 0, 1))  # HWC -> CHW
        arr = np.expand_dims(arr, 0)        # -> NCHW
    
        return arr
    
    # input data scaled 0..1
    input_data_f32 = load_image_for_onnx(path=IMAGE_PATH, H=224, W=224)
    
    # Quantize the input (quantization parameters are hard-coded here; they are not read from the ONNX model)
    scale = 1.0 / 255.0
    zero_point = 0
    input_data_u8 = np.round(input_data_f32.astype(np.float32) / float(scale)) + int(zero_point)
    input_data_u8 = np.clip(input_data_u8, 0, 255).astype(np.uint8)
    
    # Warmup once
    _ = sess.run(None, {sess.get_inputs()[0].name: input_data_u8})
    
    # Run 10x so we can calculate avg. runtime per inference
    start = curr_ms()
    for i in range(10):
        out = sess.run(None, {sess.get_inputs()[0].name: input_data_u8})
    end = curr_ms()
    
    # Image classification models from AI Hub lack a Softmax() layer at the end of the model, so apply it manually
    def softmax(x, axis=-1):
        # subtract max for numerical stability
        x_max = np.max(x, axis=axis, keepdims=True)
        e_x = np.exp(x - x_max)
        return e_x / np.sum(e_x, axis=axis, keepdims=True)
    
    scores = softmax(np.squeeze(out[0], axis=0))
    
    # Take top 5
    top_k_idx = scores.argsort()[-5:][::-1]
    
    print("\nTop-5 predictions:")
    for i in top_k_idx:
        label = labels[i] if i < len(labels) else f"Class {i}"
        print(f"{label}: score={scores[i]}")
    
    print("")
    print(f"Inference took (on average): {(end - start) / 10:.2f} ms per image")

    Note: this script has hard-coded quantization parameters. If you swap out the model, you might need to change these.

  3. Run the model on the CPU:

    python3 inference_onnx.py
    
    # Top-5 predictions:
    # common iguana: score=0.3682704567909241
    # night snake: score=0.1186317503452301
    # water snake: score=0.1186317503452301
    # boa constrictor: score=0.0813227966427803
    # bullfrog: score=0.0813227966427803
    #
    # Inference took (on average): 6.50 ms per image
  4. Run the model on the NPU:

    python3 inference_onnx.py --use-npu
    
    # Top-5 predictions:
    # common iguana: score=0.30427297949790955
    # water snake: score=0.11838366836309433
    # night snake: score=0.11838366836309433
    # boa constrictor: score=0.11838366836309433
    # rock python: score=0.08115273714065552
    #
    # Inference took (on average): 1.60 ms per image

As you can see, this model runs significantly faster on the NPU, but there is a slight change in the model's output.
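
If you want to quantify that difference rather than eyeball the top-5 list, you can run the same input through a CPU session and an NPU session and compare the raw outputs. A small sketch, reusing MODEL_PATH and input_data_u8 from the script above:

import numpy as np
import onnxruntime as ort

sess_cpu = ort.InferenceSession(MODEL_PATH, providers=["CPUExecutionProvider"])
sess_npu = ort.InferenceSession(MODEL_PATH, providers=[("QNNExecutionProvider", {"backend_type": "htp"})])

input_name = sess_cpu.get_inputs()[0].name
out_cpu = sess_cpu.run(None, {input_name: input_data_u8})[0].astype(np.float32)
out_npu = sess_npu.run(None, {input_name: input_data_u8})[0].astype(np.float32)

# Maximum absolute difference between the raw (pre-softmax) outputs
print("max abs diff:", np.max(np.abs(out_cpu - out_npu)))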

Tips & tricks

Disable CPU fallback

When debugging, you might want to disable fallback to the CPU:

so = ort.SessionOptions()
so.add_session_config_entry("session.disable_cpu_ep_fallback", "1")
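
With fallback disabled, session creation fails if any node cannot be placed on the QNN execution provider, which makes unsupported layers easy to spot. A sketch that wraps session creation so the error is printed (MODEL_PATH is a placeholder for your model file):

import onnxruntime as ort

so = ort.SessionOptions()
so.add_session_config_entry("session.disable_cpu_ep_fallback", "1")

try:
    sess = ort.InferenceSession(
        MODEL_PATH,
        sess_options=so,
        providers=[("QNNExecutionProvider", {"backend_type": "htp"})],
    )
except Exception as e:
    # The error indicates that one or more nodes could not be assigned to the QNN EP
    print(f"Model cannot run fully on the NPU: {e}")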

Building new versions of the onnxruntime package

See edgeimpulse/onnxruntime-qnn-linux-aarch64.
