trainStep

fun <T : DType, V> trainStep(model: Module<T, V>, loss: Loss, optimizer: Optimizer, ctx: ExecutionContext, x: Tensor<T, V>, y: Tensor<out DType, *>): Tensor<T, V>

Perform a single training step: forward pass, loss computation, backward pass, optimizer step, and gradient zeroing.
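As a minimal sketch of that sequence, the self-contained Kotlin below runs the five phases on a scalar linear model. It is not the library's implementation. `Param`, `LinearModel`, `Sgd`, and `trainStepSketch` are hypothetical stand-ins. The gradients are derived by hand rather than through a recording context. The loss is assumed to be squared error and the optimizer plain SGD.

```kotlin
// Hypothetical trainable parameter: a value plus its accumulated gradient.
class Param(var value: Double, var grad: Double = 0.0)

// Hypothetical scalar model: prediction = w * x + b.
class LinearModel(val w: Param = Param(0.0), val b: Param = Param(0.0)) {
    val parameters: List<Param> get() = listOf(w, b)

    fun forward(x: Double): Double = w.value * x + b.value
}

// Plain stochastic gradient descent: value -= lr * grad.
class Sgd(private val lr: Double) {
    fun step(params: List<Param>) {
        for (p in params) p.value -= lr * p.grad
    }

    fun zeroGrad(params: List<Param>) {
        for (p in params) p.grad = 0.0
    }
}

// One training step on a single (x, y) sample with squared-error loss (pred - y)^2.
fun trainStepSketch(model: LinearModel, optimizer: Sgd, x: Double, y: Double): Double {
    // 1. Forward pass.
    val pred = model.forward(x)

    // 2. Loss computation.
    val diff = pred - y
    val loss = diff * diff

    // 3. Backward pass, with hand-derived gradients accumulated into each Param:
    //    dL/dw = 2 * diff * x, dL/db = 2 * diff.
    model.w.grad += 2.0 * diff * x
    model.b.grad += 2.0 * diff

    // 4. Optimizer step.
    optimizer.step(model.parameters)

    // 5. Zero the gradients so the next step does not reuse stale values.
    optimizer.zeroGrad(model.parameters)

    return loss
}
```

On a zero-initialized model, `trainStepSketch(model, Sgd(0.1), 1.0, 2.0)` returns a loss of 4.0 and moves both parameters to 0.4. A second call on the same sample returns about 1.44. Zeroing at the end matters because the backward pass accumulates with `+=`. Without it, the next step would add fresh gradients on top of stale ones.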

Return

A tensor holding the computed loss value.

Parameters

model

The neural network module to train.

loss

The loss function to minimize.

optimizer

The optimizer used to update model parameters.

ctx

The execution context. It must support recording (for example, a TrainingExecutionContext).

x

Input tensor.

y

Target tensor.
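The `ctx` requirement can be illustrated with a minimal tape-based sketch. It assumes that "recording" means capturing operations during the forward pass so the backward pass can replay them in reverse (reverse-mode differentiation). `Node`, `TapeContext`, and `record` are hypothetical and are not part of the library's TrainingExecutionContext API.

```kotlin
// Hypothetical scalar node: a value and the gradient of the final output w.r.t. it.
class Node(val value: Double) {
    var grad = 0.0
}

// Hypothetical context. When `recording` is true, every operation appends a
// closure to a tape; backward() replays the tape in reverse order.
class TapeContext(val recording: Boolean) {
    private val tape = mutableListOf<() -> Unit>()

    // Stores the backward closure only if this context records.
    private fun record(op: () -> Unit) {
        if (recording) tape.add(op)
    }

    fun add(a: Node, b: Node): Node {
        val out = Node(a.value + b.value)
        record {
            a.grad += out.grad
            b.grad += out.grad
        }
        return out
    }

    fun mul(a: Node, b: Node): Node {
        val out = Node(a.value * b.value)
        record {
            a.grad += b.value * out.grad
            b.grad += a.value * out.grad
        }
        return out
    }

    // Seeds d(output)/d(output) = 1 and propagates gradients back through the tape.
    fun backward(output: Node) {
        check(recording && tape.isNotEmpty()) { "backward() needs a recording context" }
        output.grad = 1.0
        for (op in tape.asReversed()) op()
        tape.clear()
    }
}
```

For `out = w * x + b` with `w = 3`, `x = 2`, `b = 1`, a recording context yields gradients 2, 3, and 1 for `w`, `x`, and `b`. With `recording = false` the tape stays empty, so `backward()` fails fast instead of silently leaving every gradient at zero. This is presumably why `trainStep` asks for a context that records.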