Accuracy

class Accuracy(dim: Int = -1, threshold: Float? = null) : Metric

Classification accuracy metric.

Computes the fraction of predictions that match the target labels. Supports both hard targets (class indices) and soft targets (one-hot or probabilities).

For predictions, the class with the maximum value along dim is selected. For soft targets, the class with the maximum value is used as the ground truth. For hard targets (Int32), each value is the class index itself.
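The selection rule above can be sketched in plain Kotlin. This is an illustration only, not the library's implementation: it uses FloatArray rows in place of Tensor, the helper names argmax, accuracyHard, and accuracySoft are hypothetical, and the tie-breaking rule (first maximum wins) and the 0.0 result for an empty batch are assumptions that the documentation does not state.

```kotlin
// Illustrative stand-ins only: plain arrays instead of the library's Tensor type.

/** Index of the largest value in a row. Assumption: the first maximum wins on ties. */
fun argmax(row: FloatArray): Int {
    require(row.isNotEmpty()) { "row must not be empty" }
    var best = 0
    for (i in 1 until row.size) {
        if (row[i] > row[best]) best = i
    }
    return best
}

/** Accuracy against hard targets, where each target is already a class index. */
fun accuracyHard(predictions: List<FloatArray>, targets: IntArray): Double {
    require(predictions.size == targets.size) { "batch sizes differ" }
    if (predictions.isEmpty()) return 0.0 // assumption: empty batch yields 0.0
    val correct = predictions.indices.count { argmax(predictions[it]) == targets[it] }
    return correct.toDouble() / predictions.size
}

/** Accuracy against soft targets: reduce each target row to its argmax first. */
fun accuracySoft(predictions: List<FloatArray>, targets: List<FloatArray>): Double {
    val hardTargets = IntArray(targets.size) { argmax(targets[it]) }
    return accuracyHard(predictions, hardTargets)
}
```

With three predictions whose argmax classes are 1, 0, 1 and hard targets 1, 0, 0, both helpers return 2/3 when the soft targets reduce to the same indices.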

Parameters

dim

The dimension along which to find the predicted class (default: -1, the last dimension).

threshold

Optional threshold for binary classification. If provided, predictions greater than threshold are classified as class 1, and all others as class 0. Applies only when the size of dimension dim is 1 or 2.
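A minimal sketch of the threshold rule for the single-score case (dim size 1), again in plain Kotlin rather than the library's Tensor API. The names thresholdClass and binaryAccuracy are hypothetical. How the library applies the threshold when the dimension has size 2 is not described here, so that case is left out.

```kotlin
// Illustrative only: one score per sample, compared strictly against the threshold.

/** Class 1 if score > threshold, otherwise class 0 (a score equal to the threshold maps to 0). */
fun thresholdClass(score: Float, threshold: Float): Int =
    if (score > threshold) 1 else 0

/** Fraction of thresholded scores that match the 0/1 targets. */
fun binaryAccuracy(scores: FloatArray, targets: IntArray, threshold: Float): Double {
    require(scores.size == targets.size) { "batch sizes differ" }
    if (scores.isEmpty()) return 0.0 // assumption: empty batch yields 0.0
    val correct = scores.indices.count { thresholdClass(scores[it], threshold) == targets[it] }
    return correct.toDouble() / scores.size
}
```

Because the comparison is strict, a score exactly equal to the threshold falls into class 0.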

Constructors

constructor(dim: Int = -1, threshold: Float? = null)

Properties

open override val name: String

The name of this metric for display purposes.

Functions

open override fun compute(): Double

Compute the metric value from accumulated statistics.

open override fun reset()

Reset the accumulated statistics to start a fresh evaluation.

open override fun <T : DType, V> update(predictions: Tensor<T, V>, targets: Tensor<out DType, *>, ctx: ExecutionContext)

Update the metric state with a batch of predictions and targets.
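The update, compute, and reset functions suggest a running-statistics lifecycle: feed batches, read the value, then clear the state. The sketch below shows one way such a metric could work, assuming it keeps a count of correct predictions and a count of samples seen. RunningAccuracy is a hypothetical stand-in that takes already-resolved class indices; it is not the library's Accuracy class and does not use Tensor or ExecutionContext.

```kotlin
// Hypothetical stand-in for the update/compute/reset lifecycle.
// Assumption: the accumulated statistics are two counters (correct, total).
class RunningAccuracy {
    private var correct = 0L
    private var total = 0L

    /** Add one batch of predicted and target class indices to the running counts. */
    fun update(predicted: IntArray, targets: IntArray) {
        require(predicted.size == targets.size) { "batch sizes differ" }
        for (i in predicted.indices) {
            if (predicted[i] == targets[i]) correct++
        }
        total += predicted.size
    }

    /** Accuracy over everything seen since the last reset (0.0 if nothing was seen). */
    fun compute(): Double = if (total == 0L) 0.0 else correct.toDouble() / total

    /** Clear the counters so the next evaluation starts fresh. */
    fun reset() {
        correct = 0L
        total = 0L
    }
}
```

Accumulating counts rather than averaging per-batch accuracies keeps the result correct when batches have different sizes.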