TrainingLoop

Java-friendly training loop wrapping SKaiNET's trainStep function.

Example usage from Java:

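```java
TrainingLoop loop = TrainingLoop.builder()
    .model(model)
    .loss(Losses.crossEntropy())
    .optimizer(Optimizers.adam(0.001))
    .context(ctx)
    .build();

// Single step: returns the loss for this (input, target) batch
float stepLoss = loop.step(inputBatch, targetBatch);

// Full training for 10 epochs. train expects a Supplier that returns a fresh
// iterator for each epoch; batches is a hypothetical List of (input, target) Pairs.
TrainingResult result = loop.train(() -> batches.iterator(), 10);
```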

Types

class Builder

Builder for TrainingLoop.

object Companion

Functions

fun model(): Module<DType, Any?>

Returns the model being trained.

fun step(x: Tensor<*, *>, y: Tensor<*, *>): Float

Perform a single training step on one (input, target) batch and return its loss.
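Conceptually, a gradient-based training step runs the forward pass, computes the loss, back-propagates gradients, and lets the optimizer update the parameters. The sketch below illustrates that sequence in plain Java for a hypothetical one-weight linear model (prediction = w * x) trained with mean squared error and plain SGD. `LinearStepDemo`, its `step` method, and the learning rate are illustrative assumptions, not SKaiNET's trainStep implementation; returning the loss measured before the update is also an assumption.

```java
/** Hypothetical, self-contained illustration of one training step (not SKaiNET code). */
final class LinearStepDemo {
    private double w;             // single trainable weight
    private final double lr;      // learning rate for plain SGD

    LinearStepDemo(double initialWeight, double learningRate) {
        this.w = initialWeight;
        this.lr = learningRate;
    }

    double weight() { return w; }

    /** One step on a batch: forward, MSE loss, gradient, SGD update; returns the pre-update loss. */
    float step(double[] x, double[] y) {
        int n = x.length;
        double loss = 0.0;
        double grad = 0.0;
        for (int i = 0; i < n; i++) {
            double pred = w * x[i];          // forward pass
            double err = pred - y[i];
            loss += err * err;               // accumulate squared error
            grad += 2.0 * err * x[i];        // derivative of err^2 with respect to w
        }
        loss /= n;                           // mean squared error
        grad /= n;                           // mean gradient over the batch
        w -= lr * grad;                      // optimizer update (plain SGD)
        return (float) loss;
    }

    public static void main(String[] args) {
        LinearStepDemo model = new LinearStepDemo(0.0, 0.1);
        double[] x = {1.0, 2.0, 3.0};
        double[] y = {2.0, 4.0, 6.0};        // targets follow y = 2x
        for (int i = 0; i < 5; i++) {
            System.out.println("step " + i + " loss=" + model.step(x, y));
        }
    }
}
```

Repeated calls drive w toward 2 and the returned loss toward 0, which is the behavior a caller of step typically monitors.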

fun train(epochDataProvider: Supplier<Iterator<Pair<Tensor<*, *>, Tensor<*, *>>>>, epochs: Int): TrainingResult

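Train the model for the specified number of epochs. For each epoch, epochDataProvider supplies an iterator of (input, target) pairs; it should return a fresh iterator on every call, because an iterator can be consumed only once.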
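Taking a Supplier instead of a single Iterator lets the loop obtain new data for every epoch. The sketch below mimics that contract with plain Java types so it runs without SKaiNET: the `Batch` record stands in for Pair<Tensor, Tensor>, the `stepFn` callback stands in for a training step, and returning a list of per-epoch mean losses stands in for TrainingResult. All of these names are hypothetical, and treating an empty epoch as loss 0.0 is an arbitrary choice for the example.

```java
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.function.Supplier;
import java.util.function.ToDoubleFunction;

/** Hypothetical sketch of an epoch loop driven by a Supplier<Iterator<...>> (not SKaiNET code). */
final class EpochLoopDemo {
    /** Stand-in for one (input, target) pair. */
    record Batch(double input, double target) {}

    /** Runs `epochs` epochs and returns the mean step loss of each epoch. */
    static List<Double> train(Supplier<Iterator<Batch>> epochDataProvider,
                              int epochs,
                              ToDoubleFunction<Batch> stepFn) {
        List<Double> epochLosses = new ArrayList<>();
        for (int epoch = 0; epoch < epochs; epoch++) {
            Iterator<Batch> it = epochDataProvider.get();  // ask for this epoch's iterator
            double sum = 0.0;
            int count = 0;
            while (it.hasNext()) {
                sum += stepFn.applyAsDouble(it.next());   // one training step per pair
                count++;
            }
            epochLosses.add(count == 0 ? 0.0 : sum / count);  // empty epoch reported as 0.0
        }
        return epochLosses;
    }

    public static void main(String[] args) {
        List<Batch> batches = List.of(new Batch(1, 2), new Batch(2, 4));
        ToDoubleFunction<Batch> fakeStep = b -> Math.abs(b.target() - b.input());
        // Correct: the lambda creates a new iterator every time it is called.
        System.out.println(train(() -> batches.iterator(), 3, fakeStep));
        // Pitfall: reusing one iterator leaves every epoch after the first without data.
        Iterator<Batch> shared = batches.iterator();
        System.out.println(train(() -> shared, 3, fakeStep));
    }
}
```

The second call shows why a supplier such as `() -> shared` that hands back the same iterator silently trains on data only once.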

fun trainAsync(epochDataProvider: Supplier<Iterator<Pair<Tensor<*, *>, Tensor<*, *>>>>, epochs: Int): CompletableFuture<TrainingResult>

Train asynchronously using virtual threads.
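One plausible way to expose a blocking training call asynchronously on virtual threads is to pass it to CompletableFuture.supplyAsync together with a virtual-thread-per-task executor, which requires Java 21 or later. The sketch below assumes that design; `AsyncTrainDemo` and its `trainBlocking` stand-in are hypothetical, and SKaiNET's trainAsync may be implemented differently.

```java
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

/** Hypothetical sketch of an async wrapper that runs a blocking job on a virtual thread (Java 21+). */
final class AsyncTrainDemo {
    // Starts a new virtual thread per task. Virtual threads are daemon threads,
    // so the JVM can exit even though this executor is never shut down.
    private static final ExecutorService VIRTUAL = Executors.newVirtualThreadPerTaskExecutor();

    /** Stand-in for a blocking train(...) call; returns a fake "final loss". */
    static double trainBlocking(int epochs) {
        double loss = 1.0;
        for (int e = 0; e < epochs; e++) {
            loss *= 0.5;  // pretend each epoch halves the loss
        }
        return loss;
    }

    /** Async variant: the future completes when the blocking call returns on a virtual thread. */
    static CompletableFuture<Double> trainAsync(int epochs) {
        return CompletableFuture.supplyAsync(() -> trainBlocking(epochs), VIRTUAL);
    }

    /** Reports whether tasks submitted to VIRTUAL really run on a virtual thread. */
    static CompletableFuture<Boolean> runsOnVirtualThread() {
        return CompletableFuture.supplyAsync(() -> Thread.currentThread().isVirtual(), VIRTUAL);
    }

    public static void main(String[] args) {
        // Chain a callback, then block until it has run.
        trainAsync(3).thenAccept(loss -> System.out.println("final loss: " + loss)).join();
    }
}
```

A Java caller of a CompletableFuture-returning method such as trainAsync can either block with join() or chain callbacks such as thenAccept, as main does here.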