MixedPrecisionModule
Abstract base class for modules that support mixed-precision operations. This class handles automatic type conversions between different precision types, enabling seamless integration of layers with different precision requirements.
The class takes care of the precision conversions itself, so derived classes can focus on their core functionality (forwardImpl in the example below) while retaining type safety and efficient conversions.
Parameters
TInput: The precision type for input tensors
TOutput: The precision type for output tensors
V: The value type corresponding to the DType
inputType: The input precision type instance
outputType: The output precision type instance
conversionOps: Mixed precision tensor operations for handling conversions
Example usage:
class MixedPrecisionLinear<TInput : DType, TOutput : DType, V>(
    inputType: TInput,
    outputType: TOutput,
    conversionOps: MixedPrecisionTensorOps<V>,
    private val inFeatures: Int,
    private val outFeatures: Int
) : MixedPrecisionModule<TInput, TOutput, V>(inputType, outputType, conversionOps) {
    override fun forwardImpl(input: Tensor<TInput, V>): Tensor<TOutput, V> {
        // Implement layer-specific logic here.
        // Input and output conversions are handled automatically.
        TODO("Layer-specific computation goes here")
    }
}
Properties
Whether to enable automatic input conversion. When false, input tensors must already be in the expected input type.
Whether to enable automatic output conversion. When false, output tensors will be in the module's internal precision.
Statistics for monitoring conversion performance.
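The interplay of the two conversion switches and the statistics can be illustrated with a small, self-contained sketch. It does not use the library: ToyPrecision, ToyTensor, ToyModule, ToyConversionStats, and the property names autoConvertInput and autoConvertOutput are all hypothetical stand-ins, and the "conversion" only relabels the tensor instead of re-encoding its values.

```kotlin
// Toy stand-ins for illustration only; none of these names come from the library.
enum class ToyPrecision { FP16, FP32 }

data class ToyTensor(val precision: ToyPrecision, val values: List<Float>)

// Counts conversions so callers can monitor how often they happen.
class ToyConversionStats {
    var inputConversions = 0
        private set
    var outputConversions = 0
        private set

    fun recordInput() { inputConversions++ }
    fun recordOutput() { outputConversions++ }
}

abstract class ToyModule(
    private val inputPrecision: ToyPrecision,
    private val outputPrecision: ToyPrecision
) {
    var autoConvertInput = true   // hypothetical property name
    var autoConvertOutput = true  // hypothetical property name
    val stats = ToyConversionStats()

    // Derived classes implement only the layer-specific computation.
    protected abstract fun forwardImpl(input: ToyTensor): ToyTensor

    fun forward(input: ToyTensor): ToyTensor {
        val prepared = when {
            input.precision == inputPrecision -> input
            autoConvertInput -> {
                stats.recordInput()
                // Relabels only; a real conversion would also re-encode the values.
                input.copy(precision = inputPrecision)
            }
            else -> throw IllegalArgumentException(
                "Expected $inputPrecision input but got ${input.precision}"
            )
        }
        val result = forwardImpl(prepared)
        // With output conversion disabled, the result stays in the internal precision.
        if (!autoConvertOutput || result.precision == outputPrecision) return result
        stats.recordOutput()
        return result.copy(precision = outputPrecision)
    }
}

// Example layer: works in FP16 internally and exposes FP32 results.
class ToyDoubler : ToyModule(ToyPrecision.FP16, ToyPrecision.FP32) {
    override fun forwardImpl(input: ToyTensor): ToyTensor =
        input.copy(values = input.values.map { it * 2 })
}
```

In this sketch, passing an FP32 tensor to ToyDoubler().forward(...) triggers one input and one output conversion. Setting autoConvertInput to false makes the same call fail with an IllegalArgumentException, which mirrors the requirement that inputs must already be in the expected type.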
Functions
Clears the conversion cache to free memory. Should be called periodically in long-running applications.
Gets the current cache size for monitoring purposes.
Validates precision compatibility for the module configuration. Should be called during module initialization.
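As a rough illustration of what these three functions might do, the following self-contained sketch models a conversion cache and a validation check. Every name here (ToyPrecision, ToyConversionCache, convert, getCacheSize, clearCache, validatePrecisions) is an assumption made for the example, and the actual library may cache and validate differently.

```kotlin
// Toy stand-ins for illustration only; none of these names come from the library.
enum class ToyPrecision { FP16, FP32, FP64 }

class ToyConversionCache {
    // One cached result per (tensor id, target precision) pair.
    private val cache = mutableMapOf<Pair<String, ToyPrecision>, List<Float>>()

    fun convert(tensorId: String, values: List<Float>, target: ToyPrecision): List<Float> =
        // Placeholder conversion: copies the values instead of re-encoding them.
        cache.getOrPut(tensorId to target) { values.toList() }

    // Number of cached conversions, for monitoring.
    fun getCacheSize(): Int = cache.size

    // Drops all cached conversions to free memory.
    fun clearCache() = cache.clear()
}

// Fails fast if the configured precisions are not in the supported set.
fun validatePrecisions(
    input: ToyPrecision,
    output: ToyPrecision,
    supported: Set<ToyPrecision>
) {
    require(input in supported) { "Unsupported input precision: $input" }
    require(output in supported) { "Unsupported output precision: $output" }
}
```

Under these assumptions, a long-running application could poll getCacheSize() and call clearCache() once it passes a chosen threshold, while validatePrecisions(...) would run once in the module's init block so that misconfiguration surfaces at construction time.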