TernaryMatmul
Optimized matrix multiplication for BitNet-style ternary weights.
When the weight matrix contains only {-1, 0, +1} values, multiplication becomes addition-only, which is significantly faster:
output[i] = sum over j of: activation[j] * ternary_weight[j, i]
          = (sum of activation[j] where ternary_weight[j, i] = +1)
          - (sum of activation[j] where ternary_weight[j, i] = -1)
This avoids all floating-point multiplications, replacing them with conditional additions/subtractions based on the ternary weight value.
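The addition-only formulation above can be sketched as a plain Kotlin loop. This is an illustrative implementation, not the library's `matmul`: the function name `ternaryMatmul` and the flat `ByteArray` weight layout (row-major `[j, i]`, values in {-1, 0, +1}) are assumptions for the sketch.

```kotlin
// Sketch only: assumes weights are a row-major ByteArray of {-1, 0, +1}.
fun ternaryMatmul(activation: FloatArray, weights: ByteArray, rows: Int, cols: Int): FloatArray {
    require(activation.size == rows) { "activation length must equal weight rows" }
    require(weights.size == rows * cols) { "weights must hold rows * cols entries" }
    val output = FloatArray(cols)
    for (j in 0 until rows) {
        val a = activation[j]
        for (i in 0 until cols) {
            // No floating-point multiply: add, subtract, or skip per ternary value.
            when (weights[j * cols + i].toInt()) {
                1 -> output[i] += a
                -1 -> output[i] -= a
                // 0 contributes nothing
            }
        }
    }
    return output
}
```

Because the inner loop only branches and accumulates, the zero weights are effectively free, which is where the speedup over a dense FP32 multiply comes from.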
Functions
Check whether a tensor's underlying data is ternary. This can be used to dispatch to the optimized ternary matmul.
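The function's actual name and signature are not shown on this page; a hypothetical version of the check, operating on raw float values, might look like:

```kotlin
// Hypothetical sketch: the real library function and its signature are not
// shown here. Returns true only if every value is exactly -1, 0, or +1.
fun isTernaryData(values: FloatArray): Boolean =
    values.all { it == -1.0f || it == 0.0f || it == 1.0f }
```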
fun matmul(input: Tensor<FP32, Float>, ternaryWeights: TernaryTensorData, ctx: ExecutionContext): Tensor<FP32, Float>
Perform matrix multiplication with ternary weights.
fun matmulAutoDispatch(input: Tensor<FP32, Float>, weight: Tensor<*, *>, ctx: ExecutionContext): Tensor<FP32, Float>
Perform matmul with automatic dispatch based on the weight type. Uses the ternary-optimized path when the weights are TernaryTensorData; otherwise falls back to the standard matmul.
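The dispatch idea can be sketched with a stand-in type hierarchy; `WeightData`, `TernaryData`, and `DenseData` below are illustrative placeholders, not the library's real `Tensor` or `TernaryTensorData` types.

```kotlin
// Stand-in types to illustrate type-based dispatch; not the library's API.
sealed interface WeightData
class TernaryData(val values: ByteArray) : WeightData
class DenseData(val values: FloatArray) : WeightData

// Mirrors matmulAutoDispatch's decision: ternary data takes the
// addition-only path, anything else falls back to standard matmul.
fun dispatchKind(weight: WeightData): String = when (weight) {
    is TernaryData -> "ternary-optimized matmul"
    is DenseData -> "standard matmul"
}
```

Keeping the dispatch in one entry point lets callers pass any weight tensor without knowing whether it was quantized to ternary form.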