How to Optimize StableHLO IR

Why Optimize?

Optimization before transpilation reduces:

  • Code size — fewer operations → smaller .text section → fits in 8 KB ITCM

  • Runtime — fewer loops and memory accesses → fewer cycles

  • Memory — fewer intermediate arrays → more DTCM for input/output tensors

Apply Default Optimizations

import sk.ainet.compile.hlo.StableHloOptimizer

val module = StableHloModule(content = mlirCode)
val optimizer = StableHloOptimizer.createDefault()
val optimized = optimizer.optimize(module)

println(optimized.content)

The default pipeline applies:

  1. Constant Folding

  2. Operation Fusion

  3. Dead Code Elimination

Apply Aggressive Optimizations

val optimizer = StableHloOptimizer.createAggressive()
val optimized = optimizer.optimize(module)

Aggressive mode runs constant folding twice — before and after fusion — to catch constants created by the fusion pass.

Build a Custom Pipeline

val optimizer = StableHloOptimizer().apply {
    addPass(ConstantFoldingPass())
    // Skip fusion for debugging
    addPass(DeadCodeEliminationPass())
}
val optimized = optimizer.optimize(module)

Measure Optimization Impact

fun countOps(mlir: String): Int =
    mlir.lines().count { it.trim().startsWith("%") || it.trim().startsWith("return") }

val before = countOps(module.content)
val optimized = optimizer.optimize(module)
val after = countOps(optimized.content)

println("Operations: $before → $after (${(before - after) * 100 / before}% reduction)")

Available Passes

Constant Folding

Evaluates compile-time-known expressions:

// Before
%0 = stablehlo.constant dense<2.0> : tensor<f32>
%1 = stablehlo.constant dense<3.0> : tensor<f32>
%2 = stablehlo.add %0, %1 : tensor<f32>

// After
%2 = stablehlo.constant dense<5.0> : tensor<f32>

Supports: add, multiply, subtract, divide.
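
The arithmetic core of this pass can be sketched in plain Kotlin. This is a minimal, self-contained illustration over scalars; the real pass folds tensors inside parsed MLIR operations:

```kotlin
// Folding rule for the four supported element-wise ops, scalars only.
fun foldBinary(op: String, a: Double, b: Double): Double = when (op) {
    "add" -> a + b
    "multiply" -> a * b
    "subtract" -> a - b
    "divide" -> a / b
    else -> error("not a foldable op: $op")
}

println(foldBinary("add", 2.0, 3.0)) // 5.0, matching the MLIR example above
```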

Operation Fusion

Combines sequential operations into single compound operations:

  • Add + ReLU → fused add-relu (eliminates intermediate tensor)

  • Element-wise chain → fused multi-op (single loop instead of multiple)

  • Conv + Bias → fused conv-bias (bias applied inside convolution loop)
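
For example, Add + ReLU fusion might rewrite IR along these lines. This is illustrative only: ReLU appears as a maximum with zero in StableHLO, and the spelling of the fused op is internal to the transpiler, so `fused_add_relu` here is a hypothetical name:

```mlir
// Before: %0 is materialized as an intermediate tensor
%0 = stablehlo.add %arg0, %arg1 : tensor<4xf32>
%1 = stablehlo.maximum %0, %zero : tensor<4xf32>

// After (hypothetical fused form): one loop, no intermediate tensor
%1 = "fused_add_relu"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
```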

Dead Code Elimination

Removes operations whose results are not used:

// Before
%0 = stablehlo.constant dense<1.0> : tensor<f32>
%1 = stablehlo.constant dense<2.0> : tensor<f32>  // unused
%2 = stablehlo.add %arg0, %0 : tensor<f32>

// After
%0 = stablehlo.constant dense<1.0> : tensor<f32>
%2 = stablehlo.add %arg0, %0 : tensor<f32>
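
The liveness check behind this pass can even be approximated at the textual level. This is a toy, self-contained sketch: the real pass works on the parsed operation graph, and substring matching on value names is only safe for a small example like this one:

```kotlin
// Keep a line that defines %N only if %N is referenced on a later line;
// lines that define nothing (e.g. return) are always kept.
fun eliminateDeadCode(mlir: String): String {
    val lines = mlir.lines()
    return lines.filterIndexed { i, line ->
        val def = Regex("""^(%\w+) =""").find(line.trim())?.groupValues?.get(1)
            ?: return@filterIndexed true
        lines.drop(i + 1).any { it.contains(def) }
    }.joinToString("\n")
}

val before = """
    %0 = stablehlo.constant dense<1.0> : tensor<f32>
    %1 = stablehlo.constant dense<2.0> : tensor<f32>
    %2 = stablehlo.add %arg0, %0 : tensor<f32>
    return %2 : tensor<f32>
""".trimIndent()

println(eliminateDeadCode(before)) // the unused %1 constant is gone
```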

Validate After Optimization

Always verify that optimization preserved correctness:

# Verify original
cd iree-tools
uv run python main.py verify original.mlir

# Verify optimized
uv run python main.py verify optimized.mlir

# Outputs should match

Write a Custom Pass

class QuantizationAwarePass : OptimizationPass {
    override val name = "quantization-aware"

    override fun apply(module: StableHloModule): StableHloModule {
        val parser = MlirParser()
        val structure = parser.parse(module.content).getOrThrow()

        // Replace f32 constants with quantized versions
        val optimized = structure.operations.map { op ->
            if (op is ConstantOp && op.values.all { it in -1.0..1.0 }) {
                quantizeToInt8(op) // helper (not shown) that rewrites the constant payload as int8
            } else op
        }

        return module.copy(
            content = structure.copy(operations = optimized).toMlirString(),
            metadata = module.metadata +
                ("optimizations" to
                    ((module.metadata["optimizations"] as? List<*>).orEmpty() + name))
        )
    }
}

See Optimization Passes for a detailed explanation of how each pass works.