TrainingRunner

class TrainingRunner<T : DType, V>(val model: Module<T, V>, val loss: Loss, val optimizer: Optimizer)

Holds the training components (model, loss function, and optimizer) and provides methods to execute training steps.

Constructors

constructor(model: Module<T, V>, loss: Loss, optimizer: Optimizer)

Properties

val loss: Loss
val model: Module<T, V>
val optimizer: Optimizer

Functions

fun step(ctx: ExecutionContext, x: Tensor<T, V>, y: Tensor<out DType, *>): Tensor<T, V>

Performs a single training step: forward pass, backward pass, optimizer step, and gradient zeroing.
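
The four phases of step can be illustrated with a small, self-contained Kotlin sketch. It does not use the library's Module, Loss, Optimizer, or Tensor types: ToyParam, ToyLinearModel, ToySgd, and toyStep are hypothetical stand-ins for a one-parameter model y = w * x trained with squared error and plain SGD, with the gradient written out by hand. Returning the loss value is an assumption of the sketch; the documentation does not state what the Tensor<T, V> returned by step contains.

```kotlin
// Hypothetical toy stand-ins (NOT the library's API): a one-parameter model
// y = w * x, a squared-error loss, and plain SGD.

class ToyParam(var value: Double) {
    var grad: Double = 0.0 // accumulated gradient, cleared by zeroGrad()
}

class ToyLinearModel(initialWeight: Double) {
    val w = ToyParam(initialWeight)
    fun forward(x: Double): Double = w.value * x
}

class ToySgd(private val params: List<ToyParam>, private val lr: Double) {
    // Move each parameter against its gradient.
    fun step() {
        for (p in params) p.value -= lr * p.grad
    }

    // Reset gradients so the next backward pass starts from zero.
    fun zeroGrad() {
        for (p in params) p.grad = 0.0
    }
}

// One training step in the order the documentation lists:
// forward, backward, optimizer step, zero grad. Returns the loss value (assumption).
fun toyStep(model: ToyLinearModel, opt: ToySgd, x: Double, y: Double): Double {
    val pred = model.forward(x)   // forward pass
    val diff = pred - y
    val loss = diff * diff        // squared error
    model.w.grad += 2.0 * diff * x // backward pass: dLoss/dw, derived by hand
    opt.step()                    // optimizer update
    opt.zeroGrad()                // clear gradients for the next step
    return loss
}
```

For example, starting from w = 0.0 with a learning rate of 0.1, one step on (x = 1.0, y = 2.0) returns a loss of 4.0, moves w to 0.4, and leaves the gradient at 0.0. Zeroing after the update means each call starts from clean gradients, so they do not accumulate across steps.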

fun train(ctx: ExecutionContext, dataset: Iterable<Pair<Tensor<T, V>, Tensor<out DType, *>>>, epochs: Int = 1)

Optional helper that runs a training loop over dataset for the given number of epochs (default 1).
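
The loop that train presumably performs can be sketched generically. trainLoop is a hypothetical helper, not part of the library: it assumes each epoch makes one pass over dataset and calls a step function once per (x, y) pair, where the step lambda stands in for TrainingRunner.step. Unlike the documented train, which returns nothing, the sketch collects the mean step result per epoch so its behavior can be observed.

```kotlin
// Hypothetical generic training loop mirroring the shape of train():
// one pass over the dataset per epoch, one step call per (x, y) pair.
fun <X, Y> trainLoop(
    dataset: Iterable<Pair<X, Y>>,
    epochs: Int = 1,
    step: (X, Y) -> Double, // stand-in for a step that returns a loss value
): List<Double> {
    val epochMeanLosses = mutableListOf<Double>()
    repeat(epochs) {
        var total = 0.0
        var count = 0
        // The for loop requests a fresh iterator from the Iterable on every epoch.
        for ((x, y) in dataset) {
            total += step(x, y)
            count++
        }
        // Record the mean step result for this epoch (0.0 for an empty dataset).
        epochMeanLosses += if (count == 0) 0.0 else total / count
    }
    return epochMeanLosses
}
```

In this sketch the Iterable must support repeated traversal, because it is iterated once per epoch: a List works, while a one-shot wrapper around an Iterator would yield data only in the first epoch.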