Flatten

class Flatten<T : DType, V>(ctx: ExecutionContext, val startDim: Int = 0, val endDim: Int = -1) : TensorTransform<T, V>

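Flattens a contiguous range of dimensions, from startDim through endDim inclusive, into a single dimension whose size is the product of the flattened sizes.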

Usage

val flatten = Flatten<FP32, Float>(ctx, startDim = 1)
val flattened = flatten.apply(convOutput) // [N, C, H, W] -> [N, C*H*W]

Parameters

T

The tensor data type (for example, FP32)

V

The element value type corresponding to T (for example, Float for FP32)

ctx

The execution context for tensor operations

startDim

The first dimension to flatten (default: 0)

endDim

The last dimension to flatten, inclusive (default: -1, meaning the last dimension)
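To make the default range concrete, the sketch below resolves startDim and endDim against a rank-4 [N, C, H, W] tensor. The helper resolveDim is hypothetical and not part of this API; it assumes PyTorch-style semantics where a negative index counts from the end, which is only documented here for -1.

```kotlin
// Hypothetical helper (not part of the library): resolves a possibly negative
// dimension index against a tensor rank. ASSUMPTION: -1 is the last dimension,
// -2 the one before it, and so on; only -1 is documented for Flatten.
fun resolveDim(dim: Int, rank: Int): Int {
    val resolved = if (dim < 0) dim + rank else dim
    require(resolved in 0 until rank) { "dim $dim out of range for rank $rank" }
    return resolved
}

fun main() {
    val rank = 4 // e.g. [N, C, H, W]
    // Defaults startDim = 0, endDim = -1 select every dimension.
    println("${resolveDim(0, rank)}..${resolveDim(-1, rank)}") // 0..3
    // startDim = 1 with the default endDim selects C, H and W.
    println("${resolveDim(1, rank)}..${resolveDim(-1, rank)}") // 1..3
}
```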

Constructors

constructor(ctx: ExecutionContext, startDim: Int = 0, endDim: Int = -1)

Properties

val endDim: Int
val startDim: Int

Functions

open override fun apply(input: Tensor<T, V>): Tensor<T, V>

Flattens dimensions startDim through endDim of input into a single dimension and returns the resulting tensor.
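For a contiguous row-major buffer, flattening only reinterprets the shape; no element moves in memory. The sketch below checks this for the [N, C, H, W] -> [N, C*H*W] case from the usage example. It assumes row-major contiguous storage, which this page does not state, and the helpers strides and offset are hypothetical illustrations rather than library functions.

```kotlin
// Hypothetical illustration, not the library's implementation.
// ASSUMPTION: tensors are stored contiguously in row-major order.

// Row-major strides: stride[i] is the product of shape[i + 1 until rank].
fun strides(shape: IntArray): IntArray {
    val s = IntArray(shape.size)
    var acc = 1
    for (i in shape.indices.reversed()) {
        s[i] = acc
        acc *= shape[i]
    }
    return s
}

// Position of a multi-dimensional index in the flat buffer.
fun offset(index: IntArray, shape: IntArray): Int {
    val st = strides(shape)
    var o = 0
    for (i in index.indices) o += index[i] * st[i]
    return o
}

fun main() {
    val before = intArrayOf(2, 3, 4, 5)  // [N, C, H, W]
    val after = intArrayOf(2, 3 * 4 * 5) // [N, C*H*W] with startDim = 1
    // Element [n=1, c=2, h=3, w=4] before flattening...
    val o1 = offset(intArrayOf(1, 2, 3, 4), before)
    // ...sits at [n=1, j = c*H*W + h*W + w] afterwards.
    val j = 2 * 4 * 5 + 3 * 5 + 4
    val o2 = offset(intArrayOf(1, j), after)
    println(o1 == o2) // true
}
```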

open override fun getOutputShape(inputShape: Shape): Shape

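Returns the shape that apply produces: dimensions startDim through endDim of inputShape are replaced by a single dimension whose size is their product. The TensorTransform default preserves the input shape; Flatten overrides it because flattening changes the shape.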
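The shape rule can be sketched as a plain function. This is not the library's getOutputShape: it uses IntArray in place of the library's Shape type, and flattenedShape is a hypothetical name. It assumes both bounds are inclusive and that a negative endDim counts from the end (only -1 is documented).

```kotlin
// Hypothetical sketch of the flatten shape rule (not the library's code).
// IntArray stands in for Shape. ASSUMPTIONS: inclusive bounds; negative
// indices count from the end, as documented for endDim = -1.
fun flattenedShape(dims: IntArray, startDim: Int = 0, endDim: Int = -1): IntArray {
    val rank = dims.size
    val start = if (startDim < 0) startDim + rank else startDim
    val end = if (endDim < 0) endDim + rank else endDim
    require(start in 0 until rank && end in start until rank) {
        "invalid range $startDim..$endDim for rank $rank"
    }
    // Multiply the sizes of dims start..end into one merged dimension.
    var merged = 1
    for (i in start..end) merged *= dims[i]
    // Keep the dims before and after the range unchanged.
    return dims.copyOfRange(0, start) + merged + dims.copyOfRange(end + 1, rank)
}

fun main() {
    val nchw = intArrayOf(8, 3, 32, 32)
    println(flattenedShape(nchw, startDim = 1).toList()) // [8, 3072]
    println(flattenedShape(nchw).toList())               // [24576]
    println(flattenedShape(nchw, 1, 2).toList())         // [8, 96, 32]
}
```

With startDim = 1, the batch dimension survives, which is the usual choice before a fully connected layer.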

open override fun toString(): String