Flatten
class Flatten<T : DType, V>(ctx: ExecutionContext, val startDim: Int = 0, val endDim: Int = -1) : TensorTransform<T, V>
Flattens a contiguous range of a tensor's dimensions, from startDim through endDim, into a single dimension.
Usage
val flatten = Flatten<FP32, Float>(ctx, startDim = 1)
val flattened = flatten.apply(convOutput) // [N, C, H, W] -> [N, C*H*W]
Parameters
T
The tensor data type
V
The value type
ctx
The execution context for tensor operations
startDim
The first dimension to flatten (default: 0)
endDim
The last dimension to flatten, inclusive (default: -1, i.e. the last dimension)
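The output shape that startDim and endDim produce can be illustrated with a small shape calculation. This is a minimal sketch, assuming that negative indices count from the end (as the endDim default of -1 suggests) and that both bounds are inclusive (as the usage example implies). flattenShape is a hypothetical helper written for illustration only; it is not part of this library's API and does not reflect how Flatten is implemented internally.

```kotlin
/**
 * Hypothetical helper (NOT part of the library): computes the shape that
 * results from merging dimensions startDim..endDim (inclusive) into one.
 * Assumes negative indices count from the end, so -1 is the last dimension.
 */
fun flattenShape(shape: IntArray, startDim: Int = 0, endDim: Int = -1): IntArray {
    val rank = shape.size
    // Normalize negative indices, e.g. -1 -> rank - 1.
    val start = if (startDim < 0) startDim + rank else startDim
    val end = if (endDim < 0) endDim + rank else endDim
    require(start in 0 until rank && end in 0 until rank) {
        "startDim/endDim out of range for rank $rank"
    }
    require(start <= end) { "startDim must not come after endDim" }

    // Product of the sizes that are merged into a single dimension.
    var merged = 1
    for (i in start..end) merged *= shape[i]

    // Leading dims, then the merged dim, then any trailing dims.
    return shape.copyOfRange(0, start) + intArrayOf(merged) + shape.copyOfRange(end + 1, rank)
}
```

For example, flattenShape(intArrayOf(2, 3, 4, 5), startDim = 1) returns [2, 60], matching the [N, C, H, W] -> [N, C*H*W] case in the usage snippet, while the defaults (startDim = 0, endDim = -1) collapse the same shape into [120].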