MixedPrecisionModule

abstract class MixedPrecisionModule<TInput : DType, TOutput : DType, V>(inputType: TInput, outputType: TOutput, conversionOps: MixedPrecisionTensorOps<V>) : Module<TInput, V>

Abstract base class for modules that support mixed-precision operations. It converts tensors automatically between precision types, so that layers with different precision requirements can be combined without manual conversions at every boundary.

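Derived classes implement only their core computation; the base class performs the input and output conversions around it. The type parameters TInput and TOutput record the expected precisions in the type system, so tensors of the wrong precision type are rejected by the compiler, and a conversion cache avoids repeating the same conversion.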
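To illustrate how a precision type parameter gives compile-time safety, here is a minimal self-contained sketch. It does not use the library's real classes: DType, FP32, FP16, ToyTensor, convertTo and countHalfPrecision are hypothetical stand-ins that only mirror the shape of Tensor<TInput, V> from the signature above.

```kotlin
// Hypothetical, simplified stand-ins for the library's DType and Tensor types.
sealed interface DType { val bits: Int }
object FP32 : DType { override val bits = 32 }
object FP16 : DType { override val bits = 16 }

// The precision is carried as a type parameter (checked by the compiler)
// and as a runtime value (available when converting).
data class ToyTensor<T : DType, V>(val dtype: T, val data: List<V>)

// Hypothetical conversion: only the precision tag changes here; a real
// implementation would also round or re-encode the values.
fun <From : DType, To : DType, V> ToyTensor<From, V>.convertTo(target: To): ToyTensor<To, V> =
    ToyTensor(target, data)

// Accepts only FP16 tensors. Passing a ToyTensor<FP32, Float> is a compile-time error.
fun countHalfPrecision(t: ToyTensor<FP16, Float>): Int = t.data.size
```

Calling countHalfPrecision(ToyTensor(FP32, listOf(1f))) does not compile, while countHalfPrecision(ToyTensor(FP32, listOf(1f)).convertTo(FP16)) does; the explicit conversion is exactly the step that MixedPrecisionModule is described as performing for its subclasses.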

Parameters

TInput

The precision type for input tensors

TOutput

The precision type for output tensors

V

The value type corresponding to the DType

inputType

The input precision type instance

outputType

The output precision type instance

conversionOps

Mixed precision tensor operations for handling conversions

Example usage:

    class MixedPrecisionLinear<TInput : DType, TOutput : DType, V>(
        inputType: TInput,
        outputType: TOutput,
        conversionOps: MixedPrecisionTensorOps<V>,
        private val inFeatures: Int,
        private val outFeatures: Int
    ) : MixedPrecisionModule<TInput, TOutput, V>(inputType, outputType, conversionOps) {

        override fun forwardImpl(input: Tensor<TInput, V>): Tensor<TOutput, V> {
            // Implement the layer-specific logic here (mapping inFeatures to outFeatures).
            // Input and output conversions are handled automatically by the base class.
            TODO("layer-specific computation")
        }
    }

Constructors

constructor(inputType: TInput, outputType: TOutput, conversionOps: MixedPrecisionTensorOps<V>)

Properties

Whether to enable automatic input conversion. When false, input tensors must already be in the expected input type.

Whether to enable automatic output conversion. When false, output tensors will be in the module's internal precision.
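The two conversion switches can be sketched as follows. The property names are not shown on this page, so autoConvertInput and autoConvertOutput are hypothetical, as are Precision, ToyTensor and ConversionPolicy; throwing an IllegalArgumentException for a mismatched input is an assumption about how "must already be in the expected input type" is enforced.

```kotlin
enum class Precision { FP32, FP16 }
data class ToyTensor(val precision: Precision, val data: List<Float>)

// Hypothetical names: the real property names are not listed on this page.
class ConversionPolicy(
    val autoConvertInput: Boolean = true,
    val autoConvertOutput: Boolean = true,
) {
    // With input conversion disabled, a tensor of the wrong precision is rejected.
    fun prepareInput(t: ToyTensor, expected: Precision): ToyTensor = when {
        t.precision == expected -> t
        autoConvertInput -> t.copy(precision = expected)
        else -> throw IllegalArgumentException("expected $expected input but got ${t.precision}")
    }

    // With output conversion disabled, the result keeps the module's internal precision.
    fun finishOutput(t: ToyTensor, target: Precision): ToyTensor =
        if (autoConvertOutput && t.precision != target) t.copy(precision = target) else t
}
```

Disabling both switches trades convenience for predictability: no hidden conversions happen, but the caller becomes responsible for matching precisions on both sides.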

Statistics for monitoring conversion performance.
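The page does not list which statistics are collected. As a hypothetical illustration only, conversion statistics could track a conversion count, cache hits and elapsed time; every field name below is an assumption.

```kotlin
// Hypothetical shape for conversion statistics; the real fields are not documented here.
data class ConversionStats(
    val conversions: Long = 0,
    val cacheHits: Long = 0,
    val totalNanos: Long = 0,
) {
    // Returns a new snapshot that includes one more conversion.
    fun record(nanos: Long, cacheHit: Boolean): ConversionStats = copy(
        conversions = conversions + 1,
        cacheHits = cacheHits + (if (cacheHit) 1 else 0),
        totalNanos = totalNanos + nanos,
    )

    // Fraction of conversions served from the cache.
    val hitRate: Double get() = if (conversions == 0L) 0.0 else cacheHits.toDouble() / conversions

    // Average time spent per conversion, in nanoseconds.
    val meanNanos: Double get() = if (conversions == 0L) 0.0 else totalNanos.toDouble() / conversions
}
```

An immutable snapshot like this is easy to log or compare between training steps; a low hit rate would suggest the conversion cache is not paying for the memory it uses.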

Functions

Clears the conversion cache to free memory. Call it periodically in long-running applications, where cached conversions would otherwise keep accumulating.
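A minimal sketch of why periodic clearing matters, assuming the cache memoizes converted tensors by (tensor id, target precision) and never evicts entries on its own. ConversionCache, getOrConvert, runWithPeriodicClearing and the string ids are hypothetical; they are not the library's API.

```kotlin
enum class Precision { FP32, FP16 }

// Hypothetical conversion cache keyed by (tensor id, target precision).
// Nothing is evicted automatically, so memory grows until clear() is called.
class ConversionCache<T> {
    private val entries = HashMap<Pair<String, Precision>, T>()
    var misses = 0
        private set

    fun getOrConvert(id: String, target: Precision, convert: () -> T): T =
        entries.getOrPut(id to target) { misses++; convert() }

    val size: Int get() = entries.size
    fun clear() = entries.clear()
}

// Simulates a long-running loop in which every step converts a fresh tensor.
fun runWithPeriodicClearing(steps: Int, clearEvery: Int): ConversionCache<List<Float>> {
    val cache = ConversionCache<List<Float>>()
    for (step in 1..steps) {
        cache.getOrConvert("activations-$step", Precision.FP16) { listOf(step.toFloat()) }
        if (step % clearEvery == 0) cache.clear() // bound the cache to at most clearEvery entries
    }
    return cache
}
```

Because each step produces a new id, the cache never gets a hit in this loop and would grow linearly without the periodic clear; the clearing interval is a tuning choice, not a value taken from the library.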

override fun forward(input: Tensor<TInput, V>, ctx: ExecutionContext): Tensor<TInput, V>

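Runs the forward pass with automatic precision conversion: the input is converted to the expected input type (when input conversion is enabled), the subclass implementation (forwardImpl) is called, and its result is converted to the output type (when output conversion is enabled).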
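The wrapping described above is a template-method pattern. The sketch below drops the generics and the ExecutionContext parameter for brevity; Precision, ToyTensor, ConversionOps, MixedPrecisionSketch and Doubler are hypothetical stand-ins, and only the name forwardImpl is taken from the example on this page.

```kotlin
enum class Precision { FP32, FP16 }
data class ToyTensor(val precision: Precision, val data: List<Float>)

// Hypothetical stand-in for MixedPrecisionTensorOps: changes a tensor's precision.
fun interface ConversionOps { fun convert(t: ToyTensor, to: Precision): ToyTensor }

abstract class MixedPrecisionSketch(
    private val inputType: Precision,
    private val outputType: Precision,
    private val ops: ConversionOps,
) {
    // Subclasses implement only the core computation, on inputs of inputType.
    protected abstract fun forwardImpl(input: ToyTensor): ToyTensor

    // Template method: convert in, compute, convert out. Conversions are skipped
    // when the tensor already has the required precision.
    fun forward(input: ToyTensor): ToyTensor {
        val x = if (input.precision == inputType) input else ops.convert(input, inputType)
        val y = forwardImpl(x)
        return if (y.precision == outputType) y else ops.convert(y, outputType)
    }
}

class Doubler(ops: ConversionOps) : MixedPrecisionSketch(Precision.FP16, Precision.FP32, ops) {
    override fun forwardImpl(input: ToyTensor): ToyTensor {
        check(input.precision == Precision.FP16) // the base class already converted the input
        return input.copy(data = input.data.map { it * 2 })
    }
}
```

For example, Doubler(ConversionOps { t, to -> t.copy(precision = to) }).forward(ToyTensor(Precision.FP32, listOf(1f, 2f))) computes in FP16 and returns an FP32 tensor, without Doubler containing any conversion code.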

Returns the current size of the conversion cache, for monitoring purposes.

Validates that the configured input and output precision types are compatible. It should be called during module initialization.
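The page does not say which precision combinations are valid. The following sketch assumes a fixed allow-list of supported conversions and shows the fail-fast idea of validating in an init block; Precision, supportedConversions, validatePrecision and ValidatedModule are all hypothetical.

```kotlin
enum class Precision { FP32, FP16, BF16, INT8 }

// Hypothetical allow-list; the real compatibility rules are not documented here.
val supportedConversions: Set<Pair<Precision, Precision>> = setOf(
    Precision.FP32 to Precision.FP16,
    Precision.FP16 to Precision.FP32,
    Precision.FP32 to Precision.BF16,
    Precision.BF16 to Precision.FP32,
)

// Throws IllegalArgumentException if no conversion path exists.
fun validatePrecision(input: Precision, output: Precision) {
    require(input == output || (input to output) in supportedConversions) {
        "No conversion from $input to $output"
    }
}

// Validating in init fails at construction time rather than on the first forward pass.
class ValidatedModule(val input: Precision, val output: Precision) {
    init { validatePrecision(input, output) }
}
```

Checking during initialization surfaces a misconfigured layer immediately, instead of after data has already flowed through earlier layers.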