scaledDotProductAttention

abstract fun <T : DType, V> scaledDotProductAttention(query: Tensor<T, V>, key: Tensor<T, V>, value: Tensor<T, V>, mask: Tensor<T, V>? = null, scale: Float = 0.0f, causal: Boolean = false): Tensor<T, V>

Scaled dot-product attention.

Computes: softmax((Q @ K^T) * scale + mask) @ V

This is a first-class op (not a composition) because it maps directly to platform-specific fused kernels: Flash Attention on CUDA, MPSGraph SDPA on Apple Silicon, etc.
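As a reference for the math only (not this Kotlin API), here is an unfused NumPy sketch. The KV-head broadcasting for nKVHeads < nHeads and the treatment of scale = 0 as "use the default" are assumptions inferred from the parameter docs below:

```python
import numpy as np

def sdpa_reference(query, key, value, mask=None, scale=0.0, causal=False):
    """Unfused reference: softmax((Q @ K^T) * scale + mask) @ V.

    Assumed shapes: query [batch, nHeads, seqLen, headDim],
    key/value [batch, nKVHeads, kvLen, headDim], mask [batch, 1, seqLen, kvLen].
    """
    head_dim = query.shape[-1]
    if scale == 0.0:  # assumption: 0 means "use the default 1/sqrt(headDim)"
        scale = 1.0 / np.sqrt(head_dim)

    # Assumption: when nKVHeads < nHeads, KV heads are shared grouped-query
    # style, so repeat them up to the query head count.
    n_heads, n_kv_heads = query.shape[1], key.shape[1]
    if n_kv_heads != n_heads:
        key = np.repeat(key, n_heads // n_kv_heads, axis=1)
        value = np.repeat(value, n_heads // n_kv_heads, axis=1)

    # [batch, nHeads, seqLen, kvLen]
    scores = query @ key.transpose(0, 1, 3, 2) * scale
    if causal:
        # Causal masking ignores the mask parameter; with kvLen >= seqLen,
        # query i attends keys up to i + (kvLen - seqLen).
        seq_len, kv_len = scores.shape[-2:]
        scores = scores + np.triu(
            np.full((seq_len, kv_len), -np.inf), k=kv_len - seq_len + 1
        )
    elif mask is not None:
        scores = scores + mask  # additive mask

    # Numerically stable softmax over the key dimension.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ value  # [batch, nHeads, seqLen, headDim]
```

A fused kernel computes the same result without materializing the full scores matrix; this sketch is for checking semantics, not performance.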

Return

output tensor of shape [batch, nHeads, seqLen, headDim]

Parameters

query

query tensor of shape [batch, nHeads, seqLen, headDim]

key

key tensor of shape [batch, nKVHeads, kvLen, headDim]

value

value tensor of shape [batch, nKVHeads, kvLen, headDim]

mask

optional additive mask of shape [batch, 1, seqLen, kvLen], added to the attention scores before the softmax (e.g. a causal mask)

scale

scaling factor applied to Q @ K^T; pass 0 (the default) to use 1/sqrt(headDim)

causal

if true, applies causal masking; the mask parameter is then ignored
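For callers that need an explicit mask instead of the causal flag, a hypothetical NumPy sketch of the equivalent additive causal mask (the [batch, 1, seqLen, kvLen] layout and the KV-cache offset convention are assumptions, not guarantees of this API):

```python
import numpy as np

def additive_causal_mask(seq_len, kv_len, dtype=np.float32):
    """Build an additive mask of shape [1, 1, seqLen, kvLen]: 0 where
    attention is allowed, -inf where query i may not see key j.

    Assumed KV-cache convention: with kvLen >= seqLen, query i attends
    keys up to position i + (kvLen - seqLen).
    """
    offset = kv_len - seq_len
    mask = np.triu(np.full((seq_len, kv_len), -np.inf, dtype=dtype),
                   k=offset + 1)
    return mask[None, None, :, :]  # broadcasts over batch and heads
```

Since the mask is additive, the -inf entries drive the corresponding softmax weights to zero, which is exactly what causal = true does internally.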