scaledDotProductAttention
open override fun <T : DType, V> scaledDotProductAttention(query: Tensor<T, V>, key: Tensor<T, V>, value: Tensor<T, V>, mask: Tensor<T, V>? = null, scale: Float = 0.0f, causal: Boolean = false): Tensor<T, V>
Scaled dot-product attention.
Computes: softmax((Q @ K^T) * scale + mask) @ V
This is a first-class op (not a composition of primitives) because it maps directly to platform-specific fused kernels: Flash Attention on CUDA, MPSGraph SDPA on Apple Silicon, etc.
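The fused kernel computes the same result as the formula above. As a minimal reference sketch (not this library's implementation), here is the computation on plain 2D arrays for a single head, assuming row-major (seqLen, headDim) layouts and the default 1/sqrt(headDim) scale:

```kotlin
import kotlin.math.exp
import kotlin.math.sqrt

// Reference scaled dot-product attention for one head, no batching:
// out = softmax((Q @ K^T) * scale [+ causal mask]) @ V.
// For the causal case this sketch assumes seqLen == kvLen.
fun sdpaReference(
    q: Array<FloatArray>, // (seqLen, headDim)
    k: Array<FloatArray>, // (kvLen, headDim)
    v: Array<FloatArray>, // (kvLen, headDim)
    causal: Boolean = false,
): Array<FloatArray> {
    val seqLen = q.size
    val kvLen = k.size
    val headDim = q[0].size
    val scale = 1.0f / sqrt(headDim.toFloat())
    val out = Array(seqLen) { FloatArray(headDim) }
    for (i in 0 until seqLen) {
        // scores[j] = (q_i · k_j) * scale, -inf where causally masked
        val scores = FloatArray(kvLen) { j ->
            if (causal && j > i) {
                Float.NEGATIVE_INFINITY
            } else {
                var dot = 0.0f
                for (d in 0 until headDim) dot += q[i][d] * k[j][d]
                dot * scale
            }
        }
        // Numerically stable softmax over the kv dimension.
        val max = scores.maxOrNull()!!
        val exps = FloatArray(kvLen) { j -> exp(scores[j] - max) }
        val sum = exps.sum()
        // Weighted sum of value rows.
        for (j in 0 until kvLen) {
            val w = exps[j] / sum
            for (d in 0 until headDim) out[i][d] += w * v[j][d]
        }
    }
    return out
}
```

A fused kernel avoids materializing the (seqLen, kvLen) score matrix that this naive version builds row by row, which is why the op is exposed directly rather than composed from matmul and softmax.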
Return
output tensor of shape (batch, nHeads, seqLen, headDim)
Parameters
query
query tensor of shape (batch, nHeads, seqLen, headDim)
key
key tensor of shape (batch, nKVHeads, kvLen, headDim)
value
value tensor of shape (batch, nKVHeads, kvLen, headDim)
mask
optional additive mask of shape (batch, 1, seqLen, kvLen), added to the attention scores before the softmax (e.g. a causal mask)
scale
scaling factor applied to Q @ K^T; the default of 0.0f means use 1/sqrt(headDim)
causal
if true, applies causal masking; the mask parameter is then ignored
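When passing a mask instead of setting causal = true, it must be additive: 0 where attention is allowed and negative infinity where it is blocked. A small hypothetical helper (causalMask is not part of this API) that builds such a mask on plain arrays, offsetting by kvLen - seqLen so cached key/value positions stay visible:

```kotlin
// Hypothetical helper: builds a (seqLen, kvLen) additive causal mask —
// 0.0f where query i may attend to position j, -inf where j is "in the future".
// The offset accounts for kvLen > seqLen (e.g. a KV cache holding past tokens).
fun causalMask(seqLen: Int, kvLen: Int): Array<FloatArray> {
    val offset = kvLen - seqLen
    return Array(seqLen) { i ->
        FloatArray(kvLen) { j ->
            if (j <= i + offset) 0.0f else Float.NEGATIVE_INFINITY
        }
    }
}
```

Because the mask is additive, softmax turns the -inf entries into exactly zero attention weight, which matches what causal = true does internally.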