ONNX: Optimization Without Compromise
How ONNX enables blazing-fast inference while maintaining model quality
🎯 The Challenge We Faced
Running sophisticated AI models for knowledge intelligence presents several challenges:
- Inference speed: interactive exploration demands real-time responses
- Resource constraints: personal devices have limited compute
- Model portability: different users run different hardware
- Memory efficiency: large models strain system resources
We needed to optimize without sacrificing the quality that makes Oboyu intelligent.
💡 Why ONNX Was Our Answer
The ONNX Advantage
```python
# Before ONNX: PyTorch model
pytorch_time = "145 ms per inference"
pytorch_memory = "2.3 GB"

# After ONNX: optimized model
onnx_time = "43 ms per inference"   # 3.4x faster
onnx_memory = "890 MB"              # 61% less memory
accuracy_delta = 0.001              # negligible quality loss
```
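
The figures above are summary numbers; the sketch below shows one minimal way such a before/after latency comparison can be measured. The model id and ONNX path are placeholders, not Oboyu's actual artifacts.

```python
# Minimal latency comparison sketch (illustrative; paths and model id are placeholders)
import time

import onnxruntime as ort
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained("bert-base-multilingual-cased")  # placeholder model id
model.eval()
session = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
input_name = session.get_inputs()[0].name
input_ids = torch.randint(0, 1000, (1, 512))


def time_ms(fn, runs=50, warmup=5):
    """Average wall-clock latency in milliseconds."""
    for _ in range(warmup):
        fn()
    start = time.perf_counter()
    for _ in range(runs):
        fn()
    return (time.perf_counter() - start) / runs * 1000


with torch.no_grad():
    pytorch_ms = time_ms(lambda: model(input_ids))
onnx_ms = time_ms(lambda: session.run(None, {input_name: input_ids.numpy()}))
print(f"PyTorch: {pytorch_ms:.1f} ms, ONNX: {onnx_ms:.1f} ms")
```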
Where the Performance Gains Come From
```mermaid
graph LR
    subgraph "PyTorch Pipeline"
        A[Input] --> B[Python Overhead]
        B --> C[Dynamic Graph]
        C --> D[GPU Transfer]
        D --> E[Computation]
        E --> F[Result]
    end
    subgraph "ONNX Pipeline"
        G[Input] --> H[Direct Runtime]
        H --> I[Static Graph]
        I --> J[Optimized Ops]
        J --> K[Result]
    end
```
📊 Optimization Results
Model Performance Comparison
| Model | Framework | Model Size | Inference Time | Memory | Quality Score |
|---|---|---|---|---|---|
| Japanese BERT | PyTorch | 445 MB | 145 ms | 2.3 GB | 0.887 |
| Japanese BERT | ONNX | 443 MB | 43 ms | 890 MB | 0.886 |
| Japanese BERT | ONNX + Quantization | 112 MB | 28 ms | 450 MB | 0.881 |
| Entity Extractor | PyTorch | 678 MB | 89 ms | 1.8 GB | 0.923 |
| Entity Extractor | ONNX | 675 MB | 31 ms | 780 MB | 0.922 |
Batch Processing Performance
```python
# Benchmark: processing 1,000 documents
results = {
    "pytorch": {
        "total_time": 145.2,  # seconds
        "throughput": 6.9,    # docs/second
        "gpu_memory": 4.2,    # GB
    },
    "onnx": {
        "total_time": 43.1,   # seconds
        "throughput": 23.2,   # docs/second
        "gpu_memory": 1.8,    # GB
    },
}
```
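
For context, a minimal sketch of how such throughput numbers can be collected, assuming pre-tokenized inputs; the model path is a placeholder and batching is simplified to one document per call.

```python
# Throughput measurement sketch (illustrative; path and inputs are placeholders)
import time

import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
input_name = session.get_inputs()[0].name


def measure_throughput(encoded_docs):
    """Documents processed per second for a list of pre-tokenized int64 arrays."""
    start = time.perf_counter()
    for input_ids in encoded_docs:
        session.run(None, {input_name: input_ids})
    elapsed = time.perf_counter() - start
    return len(encoded_docs) / elapsed


# Stand-in for real tokenized documents: (1, 512) int64 arrays
encoded_docs = [np.random.randint(0, 1000, size=(1, 512), dtype=np.int64) for _ in range(100)]
print(f"{measure_throughput(encoded_docs):.1f} docs/second")
```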
🛠️ Implementation Journey
1. Model Conversion Pipeline
```python
import torch
import onnx
from onnxruntime.quantization import quantize_dynamic, QuantType
from transformers import AutoModel


class ONNXConverter:
    def __init__(self, model_name):
        self.model = AutoModel.from_pretrained(model_name)
        self.model.eval()

    def convert_to_onnx(self, output_path, sequence_length=512):
        # Create dummy input
        dummy_input = torch.randint(0, 1000, (1, sequence_length))

        # Export to ONNX
        torch.onnx.export(
            self.model,
            dummy_input,
            output_path,
            input_names=["input_ids"],
            output_names=["embeddings"],
            dynamic_axes={
                "input_ids": {0: "batch_size", 1: "sequence"},
                "embeddings": {0: "batch_size"},
            },
            opset_version=14,
            do_constant_folding=True,
        )

        # Verify the exported graph
        onnx_model = onnx.load(output_path)
        onnx.checker.check_model(onnx_model)
        return output_path
```
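
Usage looks roughly like the following; the model id and output path are placeholders rather than the exact artifacts Oboyu ships.

```python
# Example usage of the converter (placeholder model id and output path)
converter = ONNXConverter("cl-tohoku/bert-base-japanese-v2")
onnx_path = converter.convert_to_onnx("models/japanese_bert.onnx")
print(f"Exported and verified: {onnx_path}")
```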
2. Quantization Strategy
```python
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
# quantize_dynamic and QuantType are imported in the conversion step above.


class IntelligentQuantizer:
    def __init__(self, calibration_data):
        self.calibration_data = calibration_data

    def quantize_model(self, onnx_path, output_path):
        # Dynamic quantization for CPU deployment
        quantize_dynamic(
            onnx_path,
            output_path,
            weight_type=QuantType.QInt8,
            optimize_model=True,
        )

        # Validate quality preservation
        quality_score = self.validate_quality(onnx_path, output_path)
        if quality_score < 0.98:  # 98% similarity threshold
            raise ValueError(f"Quality degradation too high: {quality_score}")
        return output_path

    def validate_quality(self, original, quantized):
        # Compare outputs on the calibration data
        original_outputs = self.run_inference(original, self.calibration_data)
        quantized_outputs = self.run_inference(quantized, self.calibration_data)

        # Compare each document's embedding to its own quantized counterpart
        similarity = cosine_similarity(original_outputs, quantized_outputs)
        return float(np.diag(similarity).mean())
```
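
The validator above relies on a `run_inference` helper that is not shown. Below is one minimal shape it could take, written as a standalone function (in the class it would be a method taking `self`), assuming the calibration data is a list of pre-tokenized int64 arrays.

```python
import numpy as np
import onnxruntime as ort


def run_inference(model_path, calibration_data):
    """Embed each calibration example; calibration_data is assumed to be a list
    of pre-tokenized int64 arrays of shape (1, seq_len)."""
    session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
    input_name = session.get_inputs()[0].name
    pooled = []
    for input_ids in calibration_data:
        hidden = session.run(None, {input_name: input_ids})[0]  # (1, seq, dim)
        pooled.append(hidden.mean(axis=1))                      # simple mean pooling
    return np.concatenate(pooled, axis=0)                       # (n_docs, dim)
```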
3. Optimized Inference Engine
```python
import onnxruntime as ort
import numpy as np


class OptimizedInferenceEngine:
    def __init__(self, model_path):
        # Create a session with graph optimizations enabled
        self.session = ort.InferenceSession(
            model_path,
            providers=["CPUExecutionProvider"],
            sess_options=self._get_optimized_options(),
        )

        # Cache input/output names to avoid per-call lookups
        self.input_name = self.session.get_inputs()[0].name
        self.output_name = self.session.get_outputs()[0].name

    def _get_optimized_options(self):
        options = ort.SessionOptions()
        options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        options.intra_op_num_threads = 4
        options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
        return options

    def infer_batch(self, input_ids_batch):
        # Run the whole batch in a single session call
        outputs = self.session.run(
            [self.output_name],
            {self.input_name: input_ids_batch},
        )
        return outputs[0]
```
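
A typical call looks like the following; the tokenizer id is a placeholder, and padding/truncation is simplified for illustration.

```python
# Example usage (placeholder tokenizer id and model path)
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("cl-tohoku/bert-base-japanese-v2")
engine = OptimizedInferenceEngine("models/japanese_bert.onnx")

texts = ["知識グラフを構築する", "文書を検索する"]
encoded = tokenizer(texts, padding=True, truncation=True, max_length=512, return_tensors="np")
embeddings = engine.infer_batch(encoded["input_ids"].astype("int64"))
print(embeddings.shape)
```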
4. Hardware-Specific Optimization
```python
import platform

import onnxruntime as ort


class HardwareOptimizer:
    def optimize_for_platform(self, model_path):
        processor = platform.processor()
        if processor == "arm":
            # Apple Silicon optimization
            return self._optimize_for_apple_silicon(model_path)
        elif "Intel" in processor:
            # Intel optimization with VNNI
            return self._optimize_for_intel(model_path)
        else:
            # Generic optimization
            return self._generic_optimization(model_path)

    def _optimize_for_apple_silicon(self, model_path):
        # Prefer the CoreML provider on M1/M2, falling back to CPU
        providers = ["CoreMLExecutionProvider", "CPUExecutionProvider"]
        return ort.InferenceSession(model_path, providers=providers)

    # _optimize_for_intel and _generic_optimization follow the same pattern
    # with their own provider and session settings (omitted here).
```
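
A defensive variant, sketched below, asks ONNX Runtime which providers the installed build actually exposes before requesting them; the model path is a placeholder.

```python
# Only request execution providers that this onnxruntime build actually exposes
import onnxruntime as ort


def pick_providers(preferred=("CoreMLExecutionProvider", "CPUExecutionProvider")):
    available = set(ort.get_available_providers())
    return [p for p in preferred if p in available] or ["CPUExecutionProvider"]


session = ort.InferenceSession("models/japanese_bert.onnx", providers=pick_providers())
```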
🎯 Advanced Techniques
1. Graph Optimization
```python
# Custom graph optimizations for knowledge tasks
def optimize_knowledge_graph(onnx_model):
    # Fuse common patterns in knowledge extraction
    graph = onnx_model.graph

    # Pattern: embedding lookup + normalization
    fuse_embedding_norm(graph)

    # Pattern: multi-head attention optimization
    optimize_attention_heads(graph)

    # Pattern: entity-extraction-specific ops
    fuse_entity_extraction(graph)

    return onnx_model
```
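
The fusion passes above are project-specific helpers. As a minimal sketch of the first step such a pass takes, the snippet below only locates candidate node pairs in an ONNX graph; the op types and model path are illustrative assumptions, and the actual rewriting is omitted.

```python
# Locate (producer, consumer) pairs that a fusion pass might rewrite
import onnx


def find_fusion_candidates(model_path, first_op="Gather", second_op="LayerNormalization"):
    """Return (producer_name, consumer_name) pairs where first_op feeds second_op."""
    graph = onnx.load(model_path).graph
    producers = {out: node for node in graph.node for out in node.output}
    candidates = []
    for node in graph.node:
        if node.op_type != second_op:
            continue
        for inp in node.input:
            prev = producers.get(inp)
            if prev is not None and prev.op_type == first_op:
                candidates.append((prev.name, node.name))
    return candidates


print(find_fusion_candidates("models/japanese_bert.onnx"))
```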
2. Mixed Precision Strategy
```python
# Selective precision for different model components
precision_map = {
    "embeddings": "float32",    # keep full precision
    "attention": "float16",     # can reduce precision
    "feed_forward": "int8",     # aggressive quantization
    "output_layer": "float32",  # keep full precision
}
```
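
One way such a map can be applied (illustrative, not necessarily Oboyu's exact pass) is to translate the full-precision entries into the `nodes_to_exclude` argument of `quantize_dynamic`; the name-substring matching below is a simplifying assumption about how node names relate to components.

```python
# Keep full-precision components out of INT8 by excluding their nodes by name
import onnx
from onnxruntime.quantization import QuantType, quantize_dynamic


def nodes_matching(model_path, keywords):
    graph = onnx.load(model_path).graph
    return [n.name for n in graph.node if any(k in n.name.lower() for k in keywords)]


full_precision = [k for k, v in precision_map.items() if v == "float32"]
exclude = nodes_matching("models/japanese_bert.onnx", [k.lower() for k in full_precision])

quantize_dynamic(
    "models/japanese_bert.onnx",
    "models/japanese_bert_mixed.onnx",
    weight_type=QuantType.QInt8,
    nodes_to_exclude=exclude,
)
```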
3. Caching and Preprocessing
```python
from cachetools import LRUCache  # any LRU mapping works here


class InferenceCache:
    def __init__(self, cache_size=10000):
        self.cache = LRUCache(cache_size)
        self.hit_rate = 0  # counts cache hits; divide by total calls for a rate

    def infer_with_cache(self, text, model):
        cache_key = hash(text)
        if cache_key in self.cache:
            self.hit_rate += 1
            return self.cache[cache_key]

        result = model.infer(text)
        self.cache[cache_key] = result
        return result
```
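
The cache expects a model object with an `infer(text)` method, so the wiring below uses a thin adapter around the `OptimizedInferenceEngine`; the adapter, tokenizer id, and model path are assumptions for illustration, not Oboyu's actual API.

```python
# Illustrative wiring of cache + engine via an assumed text-level adapter
from transformers import AutoTokenizer


class TextEmbedder:
    def __init__(self, engine, tokenizer):
        self.engine = engine
        self.tokenizer = tokenizer

    def infer(self, text):
        encoded = self.tokenizer(text, truncation=True, max_length=512, return_tensors="np")
        return self.engine.infer_batch(encoded["input_ids"].astype("int64"))


tokenizer = AutoTokenizer.from_pretrained("cl-tohoku/bert-base-japanese-v2")  # placeholder id
embedder = TextEmbedder(OptimizedInferenceEngine("models/japanese_bert.onnx"), tokenizer)

cache = InferenceCache(cache_size=10000)
vec = cache.infer_with_cache("知識グラフとは何か", embedder)        # computed once
vec_again = cache.infer_with_cache("知識グラフとは何か", embedder)  # served from cache
```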
⚖️ Trade-offs and Alternatives
When ONNX Excels
- ✅ Production deployment with speed requirements
- ✅ Edge devices and resource constraints
- ✅ Cross-platform compatibility needed
- ✅ Batch processing workflows
When to Consider Alternatives
- ❌ Rapid prototyping → Stay with PyTorch
- ❌ Custom operators needed → TorchScript
- ❌ TPU deployment → TensorFlow/JAX
- ❌ Extreme quantization → TensorRT
🎓 Lessons Learned
- Not All Models Convert Equally: some architectures benefit far more from conversion than others
- Quantization Sweet Spot: INT8 works for most layers, but not all
- Profiling Is Essential: measure actual speedups, not theoretical ones (see the sketch below)
- Hardware Matters: platform-specific optimizations yield big gains
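
For the profiling point, ONNX Runtime's built-in profiler gives per-operator timings with very little setup; in this sketch the model path and input shape are placeholders.

```python
# Per-operator timings via ONNX Runtime's built-in profiler (placeholder path/shape)
import numpy as np
import onnxruntime as ort

options = ort.SessionOptions()
options.enable_profiling = True
session = ort.InferenceSession(
    "models/japanese_bert.onnx",
    sess_options=options,
    providers=["CPUExecutionProvider"],
)

input_name = session.get_inputs()[0].name
dummy = np.random.randint(0, 1000, size=(1, 512), dtype=np.int64)
for _ in range(20):
    session.run(None, {input_name: dummy})

trace_path = session.end_profiling()  # JSON trace, viewable in chrome://tracing
print(f"Profile written to {trace_path}")
```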
🔮 Future Optimizations
- Sparse Models: exploring structured sparsity for potentially order-of-magnitude speedups
- Custom Operators: Knowledge-specific ONNX operators
- Edge Deployment: WebAssembly compilation for browser execution
- Neural Architecture Search: Finding optimal architectures for ONNX
"ONNX proved that optimization doesn't require compromise. We achieved 3x speedup while maintaining the intelligence that makes Oboyu special. Sometimes the best optimization is choosing the right tool." - Oboyu Team