Embedding

class Embedding<OutT : DType, V>(val numEmbeddings: Int, val embeddingDim: Int, initWeight: Tensor<OutT, V>, val paddingIdx: Int? = null, val name: String = "Embedding") : DualModule<Int32, OutT, V>, ModuleParameters<OutT, V>

Embedding layer implemented as a DualModule: it consumes integer index tensors (Int32) and produces floating-point outputs (OutT). An optional paddingIdx zeroes the corresponding row of the embedding table.
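Conceptually, the layer is a table lookup: each index selects one row of a [numEmbeddings x embeddingDim] weight table, and the row at paddingIdx is all zeros. The sketch below shows those semantics on plain Kotlin arrays; buildTable and lookup are hypothetical helpers, not part of this library, and rejecting a negative or out-of-range paddingIdx is an assumption.

```kotlin
// Hypothetical helpers illustrating Embedding semantics on plain arrays;
// they are not part of the library API.

/** Copies a [numEmbeddings x embeddingDim] table and zeroes row [paddingIdx], if any. */
fun buildTable(init: Array<FloatArray>, paddingIdx: Int?): Array<FloatArray> {
    val table = Array(init.size) { init[it].copyOf() }
    if (paddingIdx != null) {
        // Assumption: paddingIdx must be a valid, non-negative row index.
        require(paddingIdx in table.indices) { "paddingIdx $paddingIdx out of range" }
        table[paddingIdx].fill(0f)
    }
    return table
}

/** Returns one embedding row per index (a gather along dimension 0). */
fun lookup(table: Array<FloatArray>, indices: IntArray): Array<FloatArray> =
    Array(indices.size) { i ->
        val idx = indices[i]
        require(idx in table.indices) { "index $idx out of range [0, ${table.size})" }
        table[idx].copyOf()
    }

val table = buildTable(
    arrayOf(floatArrayOf(1f, 2f), floatArrayOf(3f, 4f), floatArrayOf(5f, 6f)),
    paddingIdx = 0,
)
val out = lookup(table, intArrayOf(2, 0, 1))
// out rows: [5.0, 6.0], [0.0, 0.0], [3.0, 4.0]
```

Because the padding row is zeroed once in the table itself, every lookup of paddingIdx yields a zero vector without any special case in the lookup path.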

Constructors

constructor(numEmbeddings: Int, embeddingDim: Int, initWeight: Tensor<OutT, V>, paddingIdx: Int? = null, name: String = "Embedding")
constructor(ctx: ExecutionContext, dtype: KClass<OutT>, params: EmbeddingParams, name: String = "Embedding", mean: Float = 0.0f, std: Float = 0.1f, random: Random = Random.Default)

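Default-initializing constructor: instead of taking an explicit initWeight, it initializes the weights itself from mean, std, and random. Weights are FP32 by default.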
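The parameters mean, std, and random suggest Gaussian initialization, but the page does not say which distribution is used. The following is a minimal sketch assuming a normal distribution N(mean, std²), sampled with the Box-Muller transform because kotlin.random.Random has no built-in Gaussian sampler. initEmbeddingWeights is a hypothetical name, and the row-major FloatArray layout is an assumption.

```kotlin
import kotlin.math.PI
import kotlin.math.cos
import kotlin.math.ln
import kotlin.math.sqrt
import kotlin.random.Random

/**
 * Hypothetical sketch: fills a row-major [numEmbeddings x embeddingDim] FP32 buffer
 * with samples from N(mean, std^2) via the Box-Muller transform.
 * Assumption: the library draws from a normal distribution; the docs only list mean/std.
 */
fun initEmbeddingWeights(
    numEmbeddings: Int,
    embeddingDim: Int,
    mean: Float = 0.0f,
    std: Float = 0.1f,
    random: Random = Random.Default,
): FloatArray {
    require(numEmbeddings > 0 && embeddingDim > 0) { "dimensions must be positive" }
    return FloatArray(numEmbeddings * embeddingDim) {
        val u1 = 1.0 - random.nextDouble() // in (0, 1], so ln(u1) is finite
        val u2 = random.nextDouble()
        val z = sqrt(-2.0 * ln(u1)) * cos(2.0 * PI * u2) // standard normal sample
        (mean + std * z).toFloat()
    }
}

// A seeded Random makes the initialization reproducible.
val w = initEmbeddingWeights(100, 100, random = Random(42))
```

Passing a seeded Random, as the constructor's random parameter allows, is what makes two layers built with the same seed start from identical weights.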

Types

object Companion

Properties

open override val modules: List<ModuleNode>

Child modules/nodes used for traversal. The list is dtype-agnostic.

open override val name: String

Human-readable module name.

open override val params: List<ModuleParameter<OutT, V>>

Parameters owned by this node (weights, biases, etc.).

Functions

fun forward(indices: IntArray, ctx: ExecutionContext): Tensor<OutT, V>

open override fun forward(input: Tensor<Int32, V>, ctx: ExecutionContext): Tensor<OutT, V>

Forward pass over an Int32 index tensor, returning the corresponding embedding rows. Requires an ExecutionContext.
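For multi-dimensional index tensors, a common convention is that the output shape equals the input shape with embeddingDim appended (for example [2, 2] indices give a [2, 2, embeddingDim] output). The page does not state the output shape, so treat that as an assumption. This sketch implements the batched gather on flat row-major buffers; gather is a hypothetical helper, not the library's forward.

```kotlin
/**
 * Hypothetical sketch of a batched embedding gather on flat buffers.
 * weights: row-major [numEmbeddings x embeddingDim]; indices: flattened input of any shape.
 * Assumption: output shape = input shape + [embeddingDim].
 */
fun gather(
    weights: FloatArray,
    numEmbeddings: Int,
    embeddingDim: Int,
    indices: IntArray,
    inputShape: IntArray,
): Pair<FloatArray, IntArray> {
    require(weights.size == numEmbeddings * embeddingDim) { "weight buffer has wrong size" }
    require(inputShape.fold(1) { acc, d -> acc * d } == indices.size) { "shape does not match indices" }
    val out = FloatArray(indices.size * embeddingDim)
    for ((i, idx) in indices.withIndex()) {
        require(idx in 0 until numEmbeddings) { "index $idx out of range [0, $numEmbeddings)" }
        // Copy row idx of the table into output slot i.
        weights.copyInto(
            out,
            destinationOffset = i * embeddingDim,
            startIndex = idx * embeddingDim,
            endIndex = (idx + 1) * embeddingDim,
        )
    }
    return out to (inputShape + embeddingDim)
}

// 3 x 2 table; a [2, 2] batch of indices [[2, 0], [1, 1]].
val table = floatArrayOf(10f, 11f, 20f, 21f, 30f, 31f)
val result = gather(table, 3, 2, intArrayOf(2, 0, 1, 1), intArrayOf(2, 2))
// result.first  = [30, 31, 10, 11, 20, 21, 20, 21]
// result.second = [2, 2, 2]
```

Validating each index before copying means an out-of-range index fails loudly instead of reading a neighbouring row of the flat buffer.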

fun forwardAny(input: Tensor<out DType, V>, ctx: ExecutionContext, strict: Boolean = true): Tensor<OutT, V>

Accepts a tensor of any dtype and coerces it to integer indices, validating the values when strict is true (the default). Useful for legacy code that stores indices in floating-point tensors.
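The exact checks forwardAny performs are not documented here. The following is a minimal sketch of one plausible strict/lenient policy for turning floating-point values into indices. coerceToIndices is hypothetical, and the policy is an assumption: both modes reject non-finite values and out-of-range indices, strict mode also rejects non-integral values, and lenient mode truncates toward zero.

```kotlin
/**
 * Hypothetical sketch of coercing floating-point values to embedding indices.
 * Assumed policy: non-finite and out-of-range values always fail; strict mode
 * additionally rejects non-integral values; lenient mode truncates toward zero.
 */
fun coerceToIndices(values: FloatArray, numEmbeddings: Int, strict: Boolean = true): IntArray =
    IntArray(values.size) { i ->
        val v = values[i]
        require(v.isFinite()) { "value $v at position $i is not finite" }
        if (strict) {
            require(v == kotlin.math.truncate(v)) { "value $v at position $i is not an integer" }
        }
        val idx = v.toInt() // truncates toward zero
        require(idx in 0 until numEmbeddings) { "index $idx out of range [0, $numEmbeddings)" }
        idx
    }

val exact = coerceToIndices(floatArrayOf(0f, 2f, 1f), numEmbeddings = 3) // [0, 2, 1]
val lenient = coerceToIndices(floatArrayOf(1.5f), numEmbeddings = 3, strict = false) // [1]
```

Failing fast in strict mode surfaces values such as 1.5 that usually indicate a bug upstream, rather than silently mapping them to a neighbouring embedding.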