TernaryMatmul

Optimized matrix multiplication for BitNet-style ternary weights.

When the weight matrix contains only values in {-1, 0, +1}, every multiply-accumulate reduces to a conditional add or subtract, which is significantly cheaper:

output[i] = sum over j of: activation[j] * ternary_weight[j, i]
          = (sum of activation[j] where ternary_weight[j, i] = +1)
          - (sum of activation[j] where ternary_weight[j, i] = -1)

This avoids all floating-point multiplications, replacing them with conditional additions/subtractions based on the ternary weight value.
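The addition-only idea can be sketched as a standalone kernel. This is an illustrative sketch, not the library's implementation: it assumes flat row-major arrays, with ternary weights stored as a `ByteArray` of {-1, 0, +1} values (the real `TernaryTensorData` layout is not specified here).

```kotlin
// Hypothetical kernel sketch: activation has length k, weights has length k * n
// (row-major), and each weight byte is -1, 0, or +1.
fun ternaryMatmul(activation: FloatArray, weights: ByteArray, k: Int, n: Int): FloatArray {
    val out = FloatArray(n)
    for (j in 0 until k) {
        val a = activation[j]
        for (i in 0 until n) {
            // Multiplication-free: each ternary weight selects add, subtract, or skip.
            when (weights[j * n + i].toInt()) {
                1 -> out[i] += a
                -1 -> out[i] -= a
                // 0 -> contributes nothing
            }
        }
    }
    return out
}
```

Note that zero weights are skipped entirely, so sparsity in the ternary matrix directly reduces work.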

Functions

fun isTernaryWeight(tensor: Tensor<*, *>): Boolean

Check if a tensor's underlying data is ternary. This can be used to dispatch to optimized ternary matmul.
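The check itself amounts to verifying that every element is -1, 0, or +1. A simplified standalone sketch over a raw `FloatArray` (the real `isTernaryWeight` inspects the tensor's underlying data, whose representation is not shown here):

```kotlin
// Hypothetical helper: returns true if every value is exactly -1, 0, or +1.
fun isTernary(data: FloatArray): Boolean =
    data.all { it == -1.0f || it == 0.0f || it == 1.0f }
```

A caller could use such a check once at weight-loading time, then route all subsequent matmuls through the optimized path.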

fun matmul(input: Tensor<FP32, Float>, ternaryWeights: TernaryTensorData, ctx: ExecutionContext): Tensor<FP32, Float>

Perform matrix multiplication with ternary weights.


Perform matmul with automatic dispatch based on weight type. Uses ternary-optimized path when weights are TernaryTensorData, otherwise falls back to standard matmul.
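The dispatch pattern can be sketched as follows. This is an assumption-laden standalone version: the function names (`matmulAuto`, `standardMatmul`, `ternaryMatmulPath`) and the flat-array signatures are illustrative, not the library's API, and the ternary check stands in for the `TernaryTensorData` type test.

```kotlin
// Standard dense path: one multiply-accumulate per weight.
fun standardMatmul(act: FloatArray, w: FloatArray, k: Int, n: Int): FloatArray {
    val out = FloatArray(n)
    for (j in 0 until k) for (i in 0 until n) out[i] += act[j] * w[j * n + i]
    return out
}

// Ternary path: additions and subtractions only.
fun ternaryMatmulPath(act: FloatArray, w: FloatArray, k: Int, n: Int): FloatArray {
    val out = FloatArray(n)
    for (j in 0 until k) {
        val a = act[j]
        for (i in 0 until n) when (w[j * n + i]) {
            1.0f -> out[i] += a
            -1.0f -> out[i] -= a
        }
    }
    return out
}

// Hypothetical dispatcher: picks the ternary path when the weights qualify,
// otherwise falls back to the standard kernel. Both paths produce identical results.
fun matmulAuto(act: FloatArray, w: FloatArray, k: Int, n: Int): FloatArray =
    if (w.all { it == -1.0f || it == 0.0f || it == 1.0f }) ternaryMatmulPath(act, w, k, n)
    else standardMatmul(act, w, k, n)
```

In the real API the dispatch would test the weight tensor's data type rather than rescanning values, so the decision is O(1) per call.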