espnet.nets.pytorch_backend.transformer.embedding.RelPositionalEncoding
class espnet.nets.pytorch_backend.transformer.embedding.RelPositionalEncoding(d_model, dropout_rate, max_len=5000)
Bases: Module
Relative positional encoding module (new implementation).
Details can be found in https://github.com/espnet/espnet/pull/2816.
See: Appendix B in https://arxiv.org/abs/1901.02860 (Transformer-XL).
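For reference, the relative attention score from Transformer-XL that this encoding supports decomposes as follows (notation from the paper: W_q, W_{k,E}, W_{k,R} are projection matrices, u and v are learned biases, and R_{i-j} is the relative positional encoding produced by this module):

```latex
A^{\mathrm{rel}}_{i,j}
  = \underbrace{E_{x_i}^{\top} W_q^{\top} W_{k,E}\, E_{x_j}}_{\text{content-content}}
  + \underbrace{E_{x_i}^{\top} W_q^{\top} W_{k,R}\, R_{i-j}}_{\text{content-position}}
  + \underbrace{u^{\top} W_{k,E}\, E_{x_j}}_{\text{global content bias}}
  + \underbrace{v^{\top} W_{k,R}\, R_{i-j}}_{\text{global position bias}}
```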
- Parameters:
- d_model (int) – Embedding dimension.
- dropout_rate (float) – Dropout rate.
- max_len (int) – Maximum input length.
Construct a RelPositionalEncoding object.
extend_pe(x)
Reset the positional encodings, extending them if needed to cover the length of the input x.
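For intuition, here is a minimal sketch of how a symmetric table of sinusoidal encodings over both positive and negative relative offsets can be built. `make_rel_pe` is a hypothetical helper, not part of the ESPnet API, and the exact layout used by extend_pe may differ:

```python
import math

import torch


def make_rel_pe(d_model: int, length: int) -> torch.Tensor:
    """Hypothetical helper (not ESPnet API): build a relative positional table.

    Covers relative offsets from +(length - 1) down to -(length - 1) using the
    standard sinusoidal formula, giving shape (1, 2 * length - 1, d_model).
    Assumes an even d_model.
    """
    # Offsets ordered +(T-1), ..., 1, 0, -1, ..., -(T-1).
    positions = torch.arange(length - 1, -length, -1.0).unsqueeze(1)
    div_term = torch.exp(
        torch.arange(0, d_model, 2, dtype=torch.float32)
        * -(math.log(10000.0) / d_model)
    )
    pe = torch.zeros(2 * length - 1, d_model)
    pe[:, 0::2] = torch.sin(positions * div_term)  # even dims: sine
    pe[:, 1::2] = torch.cos(positions * div_term)  # odd dims: cosine
    return pe.unsqueeze(0)
```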
forward(x: Tensor)
Add positional encoding.
- Parameters: x (torch.Tensor) – Input tensor (batch, time, *).
- Returns: Encoded tensor (batch, time, *).
- Return type: torch.Tensor
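A minimal usage sketch, assuming a standard ESPnet installation; the batch, time, and dimension values are illustrative:

```python
import torch

from espnet.nets.pytorch_backend.transformer.embedding import RelPositionalEncoding

pos_enc = RelPositionalEncoding(d_model=256, dropout_rate=0.1)
x = torch.randn(4, 100, 256)  # (batch, time, d_model)
out = pos_enc(x)
# Note: depending on the ESPnet version, the new implementation may return
# the relative positional embedding alongside the (scaled, dropout-applied)
# input so that downstream attention can consume it separately.
```

In ESPnet encoders this module is typically paired with relative-position multi-headed attention, which applies the Transformer-XL score decomposition shown above.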