InternalMixedPrecisionModule

abstract class InternalMixedPrecisionModule<TIO : DType, TInternal : DType, V>(ioType: TIO, internalType: TInternal, conversionOps: MixedPrecisionTensorOps<V>) : Module<TIO, V>

Specialized mixed-precision module for cases where the input and output types are the same but the internal computation uses a different precision.

This is commonly used for layers where weights are stored in low precision (e.g., INT8) but computations are performed in higher precision (e.g., FP16) for accuracy.
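The low-precision-storage, higher-precision-compute pattern described above can be sketched without the library. This is a minimal illustration, not the library's implementation: `QuantizedWeights`, `scale`, and `dot` are hypothetical names, and `Float` stands in for FP16 because Kotlin has no built-in half-precision type.

```kotlin
// Hypothetical stand-in for INT8 weight storage; not part of the documented API.
class QuantizedWeights(val values: ByteArray, val scale: Float) {
    // Dequantize one INT8 weight into the higher-precision compute type.
    fun dequantize(index: Int): Float = values[index].toFloat() * scale
}

// Dot product: weights stay in INT8 storage, arithmetic runs in Float.
fun dot(weights: QuantizedWeights, input: FloatArray): Float {
    require(weights.values.size == input.size) { "size mismatch" }
    var acc = 0f
    for (i in input.indices) {
        acc += weights.dequantize(i) * input[i]
    }
    return acc
}

fun main() {
    val w = QuantizedWeights(byteArrayOf(2, -4, 6), scale = 0.5f)
    val x = floatArrayOf(1f, 2f, 3f)
    println(dot(w, x))
}
```

Dequantizing per element keeps memory use at one byte per weight, at the cost of a conversion on every access.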

Parameters

TIO

The input/output precision type

TInternal

The internal computation precision type

V

The value type corresponding to the DType

Constructors

constructor(ioType: TIO, internalType: TInternal, conversionOps: MixedPrecisionTensorOps<V>)

Functions

open override fun forward(input: Tensor<TIO, V>, ctx: ExecutionContext): Tensor<TIO, V>

Forward implementation that ensures the output is converted back to the input/output type (TIO).
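The shape of such a forward pass can be sketched with simplified stand-in types. Everything here is an assumption made for illustration: `Precision`, `SimpleTensor`, `ConversionOps`, `SketchMixedPrecisionModule`, and the `computeInternal` hook are hypothetical and do not reflect the real `Tensor`, `DType`, or `MixedPrecisionTensorOps` signatures, which this page does not show. Converting the input to the internal type before computing is also an assumption; the page only states that the output is converted back.

```kotlin
// Hypothetical stand-ins for the library's DType and Tensor.
enum class Precision { FP16, FP32 }

data class SimpleTensor(val precision: Precision, val data: List<Float>)

// Hypothetical conversion interface; the real conversion API may differ.
interface ConversionOps {
    fun convert(tensor: SimpleTensor, target: Precision): SimpleTensor
}

// Toy conversion that only retags the tensor and leaves the values unchanged.
object TaggingOps : ConversionOps {
    override fun convert(tensor: SimpleTensor, target: Precision): SimpleTensor =
        tensor.copy(precision = target)
}

abstract class SketchMixedPrecisionModule(
    private val ioType: Precision,
    private val internalType: Precision,
    private val ops: ConversionOps,
) {
    // Subclasses implement the computation in the internal precision (assumed hook).
    protected abstract fun computeInternal(input: SimpleTensor): SimpleTensor

    fun forward(input: SimpleTensor): SimpleTensor {
        require(input.precision == ioType) { "unexpected input precision" }
        val internal = ops.convert(input, internalType) // I/O type -> internal type
        val result = computeInternal(internal)          // compute in internal type
        return ops.convert(result, ioType)              // internal type -> I/O type
    }
}

// Example subclass: doubles every element in the internal precision.
class Doubler : SketchMixedPrecisionModule(Precision.FP16, Precision.FP32, TaggingOps) {
    override fun computeInternal(input: SimpleTensor): SimpleTensor =
        input.copy(data = input.data.map { it * 2 })
}

fun main() {
    val out = Doubler().forward(SimpleTensor(Precision.FP16, listOf(1f, 2.5f)))
    println(out)
}
```

Keeping both conversions in the base class means a subclass cannot accidentally return a tensor in the internal precision.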