espnet2.speechlm.model.speechlm.lm.loss.fused_cross_entropy_loss
espnet2.speechlm.model.speechlm.lm.loss.fused_cross_entropy_loss(hidden_states: Tensor, input_ids: Tensor, loss_mask: Tensor, lm_head_weight: Tensor, multimodal_vocab_range: Tuple[int, int] | None, num_stream: int, training: bool, z_loss_weight: float = 1e-05, ce_weight: Tensor | None = None) → Tuple[Tensor, Tensor, Dict]
Compute cross-entropy loss using Liger’s fused linear + CE kernel.
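A fused linear + cross-entropy kernel saves memory by projecting hidden states to logits in token chunks, so the full [num_tokens, V] logits tensor never exists at once. The pure-Python sketch below illustrates only that idea. It assumes a flat list of per-token hidden states and omits gradients and GPU execution. The helper names (`matvec`, `unfused_ce_sum`, `chunked_ce_sum`) are hypothetical and are not Liger or ESPnet APIs.

```python
import math
from typing import List, Sequence

Matrix = List[List[float]]


def logsumexp(xs: Sequence[float]) -> float:
    # Numerically stable log(sum(exp(x))).
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def matvec(weight: Matrix, h: Sequence[float]) -> List[float]:
    # Logits for one token: weight is [V][H], h is [H], result is [V].
    return [sum(w * x for w, x in zip(row, h)) for row in weight]


def unfused_ce_sum(hidden: Matrix, weight: Matrix, targets: List[int]) -> float:
    # Materializes the full [num_tokens][V] logits matrix before computing the loss.
    all_logits = [matvec(weight, h) for h in hidden]
    return sum(logsumexp(l) - l[t] for l, t in zip(all_logits, targets))


def chunked_ce_sum(hidden: Matrix, weight: Matrix, targets: List[int],
                   chunk_size: int = 2) -> float:
    # Only a [chunk_size][V] block of logits is alive at any one time.
    total = 0.0
    for start in range(0, len(hidden), chunk_size):
        chunk_logits = [matvec(weight, h) for h in hidden[start:start + chunk_size]]
        chunk_targets = targets[start:start + chunk_size]
        total += sum(logsumexp(l) - l[t] for l, t in zip(chunk_logits, chunk_targets))
    return total
```

For any chunk size, `chunked_ce_sum` matches `unfused_ce_sum` up to floating-point rounding. The chunk size trades peak memory against the number of projection calls.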
Uses reduction="sum" with pre-masked targets to work around a bug in the backward pass of Liger's reduction="none" mode. Makes two Liger calls: one for stream 0 over the full vocabulary and one for streams 1 and above over the multimodal vocabulary subset.
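The reduction="sum" workaround relies on an identity. Replacing masked-out targets with an ignore index and summing the per-token losses gives the same total as computing unreduced losses and multiplying by the mask. The pure-Python sketch below checks that identity. `IGNORE_INDEX = -100` (PyTorch's default `ignore_index`) and the helper names are assumptions for illustration, not the ESPnet or Liger code.

```python
import math
from typing import List, Sequence, Tuple

IGNORE_INDEX = -100  # assumed ignore index, matching PyTorch's default


def ce(logits: Sequence[float], target: int) -> float:
    # Per-token cross-entropy: logsumexp(logits) - logits[target].
    m = max(logits)
    return m + math.log(sum(math.exp(x - m) for x in logits)) - logits[target]


def premask_targets(targets: List[int], loss_mask: List[float]) -> List[int]:
    # Masked-out positions become IGNORE_INDEX before the loss kernel sees them.
    return [t if m > 0 else IGNORE_INDEX for t, m in zip(targets, loss_mask)]


def ce_sum_ignore(logits_rows: List[List[float]], targets: List[int]) -> Tuple[float, int]:
    # Emulates reduction="sum" with an ignore index; ignored rows add nothing.
    total, count = 0.0, 0
    for logits, t in zip(logits_rows, targets):
        if t == IGNORE_INDEX:
            continue
        total += ce(logits, t)
        count += 1
    return total, count


def ce_none_then_mask(logits_rows: List[List[float]], targets: List[int],
                      loss_mask: List[float]) -> float:
    # What reduction="none" followed by a mask multiply and a sum would give.
    return sum(ce(l, t) * m for l, t, m in zip(logits_rows, targets, loss_mask))
```

Because ignored positions never enter the reduction, the summed path never needs the unreduced per-token loss tensor. The number of non-ignored targets gives the count returned alongside the sum.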
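- Parameters:
  - hidden_states – [B, T, N, H] hidden states fed to the LM head; unshifted.
  - input_ids – [B, T, N] token IDs; unshifted.
  - loss_mask – [B, T, N] float mask (0/1); unshifted.
  - lm_head_weight – [V, H] LM head weight; may be a DTensor.
  - multimodal_vocab_range – (mm_start, mm_end) token-ID range used for streams 1 and above, or None.
  - num_stream – int, number of streams N.
  - training – bool.
  - z_loss_weight – float, default 1e-05.
  - ce_weight – optional [V] per-class weight.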
- Returns: (ce_sum, count, stats), where ce_sum and count are scalar tensors and stats is a dict. ce_sum and stats['z_loss'] are raw sums, not divided by count; the caller is responsible for normalization.
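The parameter shapes, the two-call stream split, and the raw-sum return values can be tied together in a reference computation. This is a hypothetical, pure-Python sketch for a single sequence with the batch dimension dropped, not the ESPnet implementation. It makes several assumptions:

- The next-token shift pairs the hidden state at position t with the target at t + 1.
- mm_end is exclusive, and stream 1+ targets are remapped to target - mm_start.
- Streams 1+ fall back to the full vocabulary when multimodal_vocab_range is None.
- stats["z_loss"] is the unweighted sum of squared log-partition values.

z_loss_weight and ce_weight are omitted.

```python
import math
from typing import Dict, List, Optional, Sequence, Tuple

Vec = List[float]


def logsumexp(xs: Sequence[float]) -> float:
    # Numerically stable log(sum(exp(x))).
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def reference_stream_ce(
    hidden: List[List[Vec]],              # [T][N][H]; batch dim dropped for brevity
    ids: List[List[int]],                 # [T][N], unshifted
    mask: List[List[float]],              # [T][N], 0/1, unshifted
    weight: List[Vec],                    # [V][H]
    mm_range: Optional[Tuple[int, int]],  # (mm_start, mm_end), assumed half-open
) -> Tuple[float, float, Dict[str, float]]:
    ce_sum, count, z_sum = 0.0, 0.0, 0.0
    num_stream = len(ids[0])
    # Assumed next-token shift: the hidden state at t predicts the token at t + 1.
    for t in range(len(ids) - 1):
        for n in range(num_stream):
            if mask[t + 1][n] == 0:
                continue  # pre-masked target: contributes nothing
            target = ids[t + 1][n]
            if n == 0 or mm_range is None:
                rows, offset = weight, 0  # stream 0: full vocabulary
            else:
                # Streams 1+: only the multimodal rows of the weight (assumed layout).
                rows, offset = weight[mm_range[0]:mm_range[1]], mm_range[0]
            logits = [sum(w * h for w, h in zip(row, hidden[t][n])) for row in rows]
            lse = logsumexp(logits)
            ce_sum += lse - logits[target - offset]  # target remapped into the subset
            z_sum += lse ** 2  # assumed PaLM-style z-loss term, left unweighted
            count += 1.0
    return ce_sum, count, {"z_loss": z_sum}
```

A caller would then normalize, for example with `mean_ce = ce_sum / max(count, 1.0)`. Returning raw sums lets the caller reduce them across batches or ranks before dividing.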
