ScaledDotProductAttentionOperation

Scaled dot-product attention operation for tape recording. The output has the same shape as the query: [batch, nHeads, seqLen, headDim].
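For orientation, scaled dot-product attention is conventionally computed as softmax(Q·Kᵀ / sqrt(headDim))·V. The sketch below works through that formula on plain arrays for a single (batch, head) slice. It is not this class's implementation: `attentionSlice` is a hypothetical helper, and masking, dropout and any handling of the `parameters` map are left out.

```kotlin
import kotlin.math.exp
import kotlin.math.sqrt

// Hypothetical reference helper, not part of the documented API.
// q: [seqLen][headDim], k: [kvLen][headDim], v: [kvLen][headDim]
// Returns [seqLen][headDim], i.e. the same shape as q.
fun attentionSlice(
    q: Array<FloatArray>,
    k: Array<FloatArray>,
    v: Array<FloatArray>
): Array<FloatArray> {
    val headDim = q[0].size
    val scale = 1.0f / sqrt(headDim.toFloat())
    return Array(q.size) { i ->
        // Scaled scores of query row i against every key row.
        val scores = FloatArray(k.size) { j ->
            var dot = 0.0f
            for (d in 0 until headDim) dot += q[i][d] * k[j][d]
            dot * scale
        }
        // Numerically stable softmax over the key axis.
        val maxScore = scores.maxOrNull() ?: 0.0f
        var sum = 0.0f
        for (j in scores.indices) {
            scores[j] = exp(scores[j] - maxScore)
            sum += scores[j]
        }
        // Weighted sum of value rows: one output row per query row.
        FloatArray(headDim) { d ->
            var acc = 0.0f
            for (j in scores.indices) acc += (scores[j] / sum) * v[j][d]
            acc
        }
    }
}
```

Applying this helper to every (batch, head) pair gives one [seqLen, headDim] slice each, which is why the stacked result matches the query shape.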

Constructors

constructor(parameters: Map<String, Any> = emptyMap())

Functions

open override fun clone(newParameters: Map<String, Any> = parameters): Operation

Clone this operation with potentially different parameters

open override fun <T : DType, V> execute(inputs: List<Tensor<T, V>>): List<Tensor<T, V>>

Execute this operation with the given inputs

open override fun inferOutputs(inputs: List<TensorSpec>): List<TensorSpec>

Infer the output tensor specifications from input specifications
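Because the output shape equals the query shape, inference can in principle reduce to copying the query's shape. The sketch below illustrates that idea only. `SpecSketch` is a hypothetical stand-in for `TensorSpec` (whose fields are not shown on this page), and the input order query, key, value is an assumption.

```kotlin
// Hypothetical stand-in for TensorSpec; the real class's fields are not shown here.
data class SpecSketch(val name: String, val shape: List<Int>)

// Assumes inputs are ordered query, key, value (not confirmed by this page).
fun inferAttentionOutput(inputs: List<SpecSketch>): List<SpecSketch> {
    require(inputs.size >= 3) { "expected at least query, key and value" }
    val query = inputs[0]
    require(query.shape.size == 4) {
        "query must be [batch, nHeads, seqLen, headDim]"
    }
    // A single output whose shape mirrors the query.
    return listOf(SpecSketch("attention_output", query.shape))
}
```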

open override fun validateInputs(inputs: List<TensorSpec>): ValidationResult

Validate that the given inputs are compatible with this operation
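As an illustration of the kind of compatibility checks attention inputs usually need, the sketch below validates raw shapes. It is an assumption-laden example rather than this method's actual logic: `checkAttentionShapes` is hypothetical, it returns error messages instead of a `ValidationResult`, it assumes the input order query, key, value, and it assumes key and value carry the same number of heads as the query.

```kotlin
// Hypothetical checker over raw shapes; returns error messages, empty when valid.
fun checkAttentionShapes(shapes: List<List<Int>>): List<String> {
    if (shapes.size < 3) {
        return listOf("expected query, key and value, got ${shapes.size} input(s)")
    }
    val (q, k, v) = shapes
    if (q.size != 4 || k.size != 4 || v.size != 4) {
        return listOf("query, key and value must all be rank 4")
    }
    val errors = mutableListOf<String>()
    // Layout assumed throughout: [batch, nHeads, seqLen, headDim].
    if (q[0] != k[0] || q[0] != v[0]) errors += "batch sizes differ"
    if (q[1] != k[1] || q[1] != v[1]) errors += "head counts differ"
    if (k[2] != v[2]) errors += "key and value sequence lengths differ"
    if (q[3] != k[3]) errors += "query and key headDim differ"
    // Output shape equals query shape, so value headDim must match too.
    if (q[3] != v[3]) errors += "query and value headDim differ"
    return errors
}
```

Note that the query and key sequence lengths are allowed to differ here, since only the key and value lengths have to agree for the weighted sum to be defined.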