# How to Optimize StableHLO IR

## Why Optimize?

Optimization before transpilation reduces:

- **Code size** — fewer operations → smaller `.text` section → fits in 8 KB ITCM
- **Runtime** — fewer loops and memory accesses → fewer cycles
- **Memory** — fewer intermediate arrays → more DTCM for input/output tensors
## Apply Default Optimizations

```kotlin
import sk.ainet.compile.hlo.StableHloOptimizer

val module = StableHloModule(content = mlirCode)
val optimizer = StableHloOptimizer.createDefault()
val optimized = optimizer.optimize(module)
println(optimized.content)
```
The default pipeline applies:

- Constant Folding
- Operation Fusion
- Dead Code Elimination
## Apply Aggressive Optimizations

```kotlin
val optimizer = StableHloOptimizer.createAggressive()
val optimized = optimizer.optimize(module)
```
Aggressive mode runs constant folding twice — before and after fusion — to catch constants created by the fusion pass.
## Build a Custom Pipeline

```kotlin
val optimizer = StableHloOptimizer().apply {
    addPass(ConstantFoldingPass())
    // Skip fusion for debugging
    addPass(DeadCodeEliminationPass())
}
val optimized = optimizer.optimize(module)
```
## Measure Optimization Impact

```kotlin
fun countOps(mlir: String): Int =
    mlir.lines().count { it.trim().startsWith("%") || it.trim().startsWith("return") }

val before = countOps(module.content)
val optimized = optimizer.optimize(module)
val after = countOps(optimized.content)
println("Operations: $before → $after (${(before - after) * 100 / before}% reduction)")
```
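As a sanity check on this counting heuristic, here is a self-contained run of the same `countOps` function against the before/after snippets from the Constant Folding section below:

```kotlin
// Self-contained check of the line-counting heuristic using the
// constant-folding example from this guide.
fun countOps(mlir: String): Int =
    mlir.lines().count { it.trim().startsWith("%") || it.trim().startsWith("return") }

val beforeIr = """
    %0 = stablehlo.constant dense<2.0> : tensor<f32>
    %1 = stablehlo.constant dense<3.0> : tensor<f32>
    %2 = stablehlo.add %0, %1 : tensor<f32>
""".trimIndent()

val afterIr = "%2 = stablehlo.constant dense<5.0> : tensor<f32>"

fun main() {
    val before = countOps(beforeIr)
    val after = countOps(afterIr)
    // 3 ops fold down to 1, a 66% reduction with the integer math used above
    println("Operations: $before -> $after (${(before - after) * 100 / before}% reduction)")
}
```

Note that the heuristic counts SSA assignments and returns, not true operation semantics, so treat the percentage as a rough indicator rather than an exact measure.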
## Available Passes

### Constant Folding

Evaluates compile-time-known expressions:

```mlir
// Before
%0 = stablehlo.constant dense<2.0> : tensor<f32>
%1 = stablehlo.constant dense<3.0> : tensor<f32>
%2 = stablehlo.add %0, %1 : tensor<f32>

// After
%2 = stablehlo.constant dense<5.0> : tensor<f32>
```

Supports: `add`, `multiply`, `subtract`, `divide`.
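The folding rule can be pictured as a small table from op name to scalar evaluator. This is a toy model of the idea only, not the pass's actual implementation:

```kotlin
// Toy scalar constant folder for the four supported ops. The real pass
// rewrites StableHLO text; this models only the evaluation step.
val foldRules: Map<String, (Double, Double) -> Double> = mapOf(
    "add" to { a, b -> a + b },
    "multiply" to { a, b -> a * b },
    "subtract" to { a, b -> a - b },
    "divide" to { a, b -> a / b },
)

// Returns null when the op has no folding rule, leaving it in the IR.
fun tryFold(op: String, lhs: Double, rhs: Double): Double? =
    foldRules[op]?.invoke(lhs, rhs)

fun main() {
    println(tryFold("add", 2.0, 3.0))       // 5.0, matching the example above
    println(tryFold("remainder", 7.0, 3.0)) // null: not a supported op
}
```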
### Operation Fusion

Combines sequential operations into single compound operations:

- Add + ReLU → fused add-relu (eliminates intermediate tensor)
- Element-wise chain → fused multi-op (single loop instead of multiple)
- Conv + Bias → fused conv-bias (bias applied inside convolution loop)
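To see what add + ReLU fusion buys at the loop level, here is a sketch in plain Kotlin (not the generated code) of the unfused versus fused computation:

```kotlin
// Unfused: two loops and an intermediate array.
fun addThenRelu(a: FloatArray, b: FloatArray): FloatArray {
    val sum = FloatArray(a.size) { a[it] + b[it] }     // intermediate tensor
    return FloatArray(sum.size) { maxOf(0f, sum[it]) } // second pass over memory
}

// Fused: one loop, no temporary, same result.
fun fusedAddRelu(a: FloatArray, b: FloatArray): FloatArray =
    FloatArray(a.size) { maxOf(0f, a[it] + b[it]) }

fun main() {
    val a = floatArrayOf(1f, -2f, 3f)
    val b = floatArrayOf(-4f, 1f, 2f)
    println(fusedAddRelu(a, b).contentEquals(addThenRelu(a, b))) // true
}
```

On a target with tight DTCM, eliminating the temporary array matters as much as eliminating the second loop.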
### Dead Code Elimination

Removes operations whose results are not used:

```mlir
// Before
%0 = stablehlo.constant dense<1.0> : tensor<f32>
%1 = stablehlo.constant dense<2.0> : tensor<f32> // unused
%2 = stablehlo.add %arg0, %0 : tensor<f32>

// After
%0 = stablehlo.constant dense<1.0> : tensor<f32>
%2 = stablehlo.add %arg0, %0 : tensor<f32>
```
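The rule can be modeled as a sweep over SSA-style ops that keeps only results consumed somewhere. The types below are illustrative, not the pass's real ones:

```kotlin
// Minimal model of dead code elimination: drop ops whose result is never
// used as an operand and is not returned. Illustrative types only.
data class SsaOp(val result: String, val operands: List<String>)

fun eliminateDead(ops: List<SsaOp>, returned: Set<String>): List<SsaOp> {
    val live = ops.flatMap { it.operands }.toSet() + returned
    return ops.filter { it.result in live }
}

fun main() {
    val ops = listOf(
        SsaOp("%0", emptyList()),           // constant 1.0
        SsaOp("%1", emptyList()),           // constant 2.0, unused
        SsaOp("%2", listOf("%arg0", "%0")), // add
    )
    // %1 is removed; %0 survives because %2 uses it
    println(eliminateDead(ops, returned = setOf("%2")).map { it.result })
}
```

A production DCE repeats this sweep until a fixed point, since removing one op can make its operands dead in turn.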
## Validate After Optimization

Always verify that optimization preserved correctness:

```shell
# Verify original
cd iree-tools
uv run python main.py verify original.mlir

# Verify optimized
uv run python main.py verify optimized.mlir

# Outputs should match
```
## Write a Custom Pass

```kotlin
class QuantizationAwarePass : OptimizationPass {
    override val name = "quantization-aware"

    override fun apply(module: StableHloModule): StableHloModule {
        val parser = MlirParser()
        val structure = parser.parse(module.content).getOrThrow()

        // Replace f32 constants with quantized versions
        val optimized = structure.operations.map { op ->
            if (op is ConstantOp && op.values.all { it in -1.0..1.0 }) {
                quantizeToInt8(op) // user-supplied helper
            } else op
        }

        // Record this pass alongside any passes already noted in the metadata
        val existingPasses = module.metadata["optimizations"] as? List<String> ?: emptyList()
        return module.copy(
            content = structure.copy(operations = optimized).toMlirString(),
            metadata = module.metadata +
                ("optimizations" to existingPasses + name)
        )
    }
}
```
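The `quantizeToInt8` helper is left undefined above. Its scalar core might look like the following sketch, assuming a symmetric int8 scheme where 1.0 maps to 127; this is an assumption for illustration, not the library's actual quantization scheme:

```kotlin
// Hypothetical scalar quantizer for values in [-1.0, 1.0], symmetric scheme
// with scale 127. A real quantizeToInt8 would apply this per element of a
// ConstantOp and rewrite its tensor type to i8.
fun quantizeValue(value: Double): Byte {
    require(value in -1.0..1.0) { "pass only quantizes constants in [-1, 1]" }
    return Math.round(value * 127).toByte()
}

fun dequantizeValue(q: Byte): Double = q / 127.0

fun main() {
    val q = quantizeValue(0.25)
    println(q)                  // 32, since 0.25 * 127 = 31.75 rounds to 32
    println(dequantizeValue(q)) // close to 0.25, within quantization error
}
```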
See Optimization Passes for a detailed explanation of how each pass works.