Source code for espnet2.asr.encoder.avhubert_encoder

# Copyright 2023
# Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

# The original AVHubert work is in:
#     Paper: https://arxiv.org/pdf/2201.02184.pdf
#     Original code: https://github.com/facebookresearch/av_hubert


"""Encoder definition."""
import contextlib
import copy
import logging
import math
import os
import random
from collections import OrderedDict
from copy import deepcopy
from dataclasses import dataclass, field
from typing import Dict, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
from filelock import FileLock
from typeguard import typechecked

from espnet2.asr.encoder.abs_encoder import AbsEncoder
from espnet.nets.pytorch_backend.nets_utils import make_pad_mask

logger = logging.getLogger(__name__)


def conv3x3(in_planes, out_planes, stride=1):
    return nn.Conv2d(
        in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False
    )

def downsample_basic_block(inplanes, outplanes, stride):
    return nn.Sequential(
        nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=stride, bias=False),
        nn.BatchNorm2d(outplanes),
    )

def downsample_basic_block_v2(inplanes, outplanes, stride):
    return nn.Sequential(
        nn.AvgPool2d(
            kernel_size=stride, stride=stride, ceil_mode=True, count_include_pad=False
        ),
        nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=1, bias=False),
        nn.BatchNorm2d(outplanes),
    )

def time_masking(xs_pad, min_T=5, max_T=20):
    """Mask one contiguous span of frames per sample with a random length in [min_T, max_T]."""
    batch_size = xs_pad.size(0)
    mask = torch.ones_like(xs_pad)
    for b in range(batch_size):
        width = min(random.randint(min_T, max_T), xs_pad.size(1))
        start = random.randint(0, xs_pad.size(1) - width)
        mask[b, start : start + width] = 0.0
    return xs_pad * mask.to(xs_pad.device)

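# Hedged usage sketch (not part of the original source): time_masking zeroes one
# random contiguous span per batch element along dim 1. The tensor shape below is
# illustrative only.
def _example_time_masking():
    feats = torch.randn(2, 50, 8)  # (batch, time, feature), arbitrary sizes
    masked = time_masking(feats, min_T=5, max_T=20)
    # Each sample keeps its shape; 5-20 consecutive frames are set to zero.
    assert masked.shape == feats.shape
    return masked
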
# avhubert_url(noise_large):
#     'https://dl.fbaipublicfiles.com/avhubert/model/lrs3_vox/noise-pretrain/large_vox_iter5.pt'
# avhubert_url(noise_base):
#     'https://dl.fbaipublicfiles.com/avhubert/model/lrs3_vox/noise-pretrain/base_vox_iter5.pt'

class FairseqAVHubertEncoder(AbsEncoder):
    """FairSeq AVHubert pretrained encoder module.

    Args:
        input_size: input dim
        avhubert_url: download link for pre-trained avhubert model
        avhubert_dir_path: dir_path for downloading pre-trained avhubert model
    """

    @typechecked
    def __init__(
        self,
        input_size: int = 1,
        avhubert_url: str = "./",
        avhubert_dir_path: str = "./",
        freeze_finetune_updates: int = 0,
        encoder_embed_dim: int = 1024,
        encoder_layerdrop: float = 0.05,
        dropout_input: float = 0.1,
        dropout_features: float = 0.1,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        feature_grad_mult: float = 0.1,
        activation_dropout: float = 0.0,
        wav_input: bool = False,
        layer_norm_first: bool = True,
        audio_feat_dim: int = 104,
        encoder_layers: int = 24,
        encoder_ffn_embed_dim: int = 4096,
        encoder_attention_heads: int = 16,
        extracted: bool = False,
        pretrain: bool = True,
        modality_dropout: float = 0.0,
        audio_dropout: float = 0.0,
        noise_augmentation: bool = False,
        noise_path: str = "./data/babble_noise.pt",
        max_noise_weight: float = 0.5,
        audio_only: bool = False,
    ):
        super().__init__()
        self._output_size = encoder_embed_dim
        self.extracted = extracted
        self.modality_dropout = modality_dropout
        self.audio_dropout = audio_dropout
        self.audio_only = audio_only

        arg_overrides = {
            "encoder_embed_dim": encoder_embed_dim,
            "encoder_layerdrop": encoder_layerdrop,
            "dropout_input": dropout_input,
            "dropout_features": dropout_features,
            "dropout": dropout,
            "attention_dropout": attention_dropout,
            "feature_grad_mult": feature_grad_mult,
            "activation_dropout": activation_dropout,
            "wav_input": wav_input,
            "layer_norm_first": layer_norm_first,
            "audio_feat_dim": audio_feat_dim,
            "encoder_layers": encoder_layers,
            "encoder_ffn_embed_dim": encoder_ffn_embed_dim,
            "encoder_attention_heads": encoder_attention_heads,
            "audio_only": audio_only,
        }

        default_cfg = AVHubertConfig()
        for arg_name, arg_val in arg_overrides.items():
            setattr(default_cfg, arg_name, arg_val)

        model = AVHubertModel.build_model(cfg=default_cfg)

        self.modality_fuse = model.modality_fuse

        if pretrain:
            self.avhubert_model_path = download_avhubert(
                avhubert_url,
                avhubert_dir_path,
            )
            ckpt = torch.load(
                self.avhubert_model_path,
                map_location=torch.device("cpu"),
            )
            state = {
                k: v
                for k, v in ckpt["model"].items()
                if "label_embs_concat" not in k and "final_proj" not in k
            }
            del ckpt
            model.load_state_dict(state)
        else:
            logging.info(
                "Training from scratch without using pre-trained AV-HuBERT model"
            )

        self.pretrained_params = copy.deepcopy(model.state_dict())
        self.encoders = model

        if noise_augmentation:
            self.noise = torch.load(noise_path)
            self.max_noise_weight = max_noise_weight
        else:
            self.noise = None
            self.max_noise_weight = None

        self.freeze_finetune_updates = freeze_finetune_updates
        self.register_buffer("num_updates", torch.LongTensor([0]))

    def output_size(self) -> int:
        return self._output_size

    def forward(
        self,
        xs_pad: Dict[str, torch.Tensor],
        ilens: torch.Tensor,
        prev_states: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Forward AVHubert Encoder.

        Args:
            xs_pad["video"]: input tensor (B, 1, L, H, W)
            xs_pad["audio"]: input tensor (B, D, L)
            ilens: input length (B)
            prev_states: Not to be used now.

        Returns:
            position embedded tensor and mask
        """
        if not self.extracted:
            if "video" in xs_pad:
                masks = make_pad_mask(ilens, length_dim=2).to(xs_pad["video"].device)
            elif "audio" in xs_pad:
                masks = make_pad_mask(ilens, length_dim=2).to(xs_pad["audio"].device)
            else:
                raise ValueError("Input should have video or audio")

            ft = self.freeze_finetune_updates <= self.num_updates
            if self.num_updates <= self.freeze_finetune_updates:
                self.num_updates += 1
            elif ft and self.num_updates == self.freeze_finetune_updates + 1:
                self.num_updates += 1
                logging.info("Start fine-tuning AVHubert parameters!")
            else:
                self.num_updates += 1

            with torch.no_grad() if not ft else contextlib.nullcontext():
                enc_outputs = self.encoders.extract_finetune(
                    xs_pad,
                    padding_mask=masks,
                )
        else:
            masks = make_pad_mask(ilens, length_dim=1).to(xs_pad.device)

            ft = self.freeze_finetune_updates <= self.num_updates
            if self.training:
                xs_pad = time_masking(xs_pad)
                if self.modality_dropout > 0 and self.modality_fuse == "concat":
                    modality_drop_prob, audio_drop_prob = (
                        np.random.random(),
                        np.random.random(),
                    )
                    if modality_drop_prob < self.modality_dropout:
                        if audio_drop_prob < self.audio_dropout:
                            # first half dimension is audio features
                            modal_masks = torch.ones_like(xs_pad)
                            modal_masks[:, :, : modal_masks.size(2) // 2] = 0.0
                            xs_pad = xs_pad * modal_masks
                        else:
                            # last half dimension is video features
                            modal_masks = torch.ones_like(xs_pad)
                            modal_masks[:, :, modal_masks.size(2) // 2 :] = 0.0
                            xs_pad = xs_pad * modal_masks
                if self.noise is not None:
                    start_ind = torch.randint(
                        0, self.noise.size(0) - xs_pad.size(1), size=[xs_pad.size(0)]
                    )  # B
                    noise_ind = start_ind.view(-1, 1) + torch.arange(
                        0, xs_pad.size(1)
                    ).unsqueeze(0).repeat(
                        xs_pad.size(0), 1
                    )  # B, T
                    noise_weight = (
                        torch.rand([xs_pad.size(0), 1, 1]).to(xs_pad.device)
                        * self.max_noise_weight
                    )
                    xs_pad = (1 - noise_weight) * xs_pad + noise_weight * self.noise[
                        noise_ind
                    ].to(xs_pad.device)

            if self.audio_only:
                modal_masks = torch.ones_like(xs_pad)
                modal_masks[:, :, : modal_masks.size(2) // 2] = 0.0
                xs_pad = xs_pad * modal_masks

            if self.num_updates <= self.freeze_finetune_updates:
                self.num_updates += 1
            elif ft and self.num_updates == self.freeze_finetune_updates + 1:
                self.num_updates += 1
                logging.info("Start fine-tuning AVHubert parameters!")
            else:
                self.num_updates += 1

            with torch.no_grad() if not ft else contextlib.nullcontext():
                enc_outputs = self.encoders.forward_transformer(
                    xs_pad,
                    padding_mask=masks,
                )

        xs_pad = enc_outputs[0]
        masks = enc_outputs[1]

        # save gpu memory
        del enc_outputs

        olens = (~masks).sum(dim=1)

        return xs_pad, olens, None

    def forward_fusion(
        self,
        xs_pad: Dict[str, torch.Tensor],
    ) -> torch.Tensor:
        if xs_pad["audio"] is not None:
            audio_feats = self.encoders.forward_audio(xs_pad["audio"])
        else:
            audio_feats = None
        if xs_pad["video"] is not None:
            video_feats = self.encoders.forward_video(xs_pad["video"])
        else:
            video_feats = None
        return self.encoders.modality_fusion(audio_feats, video_feats)

    def reload_pretrained_parameters(self):
        self.encoders.load_state_dict(self.pretrained_params, strict=False)
        logging.info("Pretrained AVHubert model parameters reloaded!")

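# Hedged usage sketch (not part of the original source): build a small, randomly
# initialized encoder (pretrain=False, so nothing is downloaded) and run the
# multimodal forward pass. Requires fairseq to be installed; all sizes below are
# illustrative only.
def _example_fairseq_avhubert_encoder():
    encoder = FairseqAVHubertEncoder(
        pretrain=False,
        encoder_embed_dim=256,
        encoder_layers=2,
        encoder_ffn_embed_dim=512,
        encoder_attention_heads=4,
        audio_feat_dim=104,
    )
    batch, frames = 2, 12
    xs_pad = {
        "video": torch.randn(batch, 1, frames, 88, 88),  # (B, 1, L, H, W) lip ROIs
        "audio": torch.randn(batch, 104, frames),  # (B, D, L) audio features
    }
    ilens = torch.tensor([frames, frames - 3])
    out, olens, _ = encoder(xs_pad, ilens)
    # out: (B, L, encoder_embed_dim); olens: valid output length per utterance
    return out, olens
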
@dataclass
class AVHubertConfig:
    """Configuration from original AVHubert Github"""

    sample_rate: int = field(
        default=16_000,
        metadata={
            "help": "target sample rate. audio files will be up/down "
            "sampled to this rate"
        },
    )
    label_rate: int = field(
        default=-1,
        metadata={"help": "label frame rate. -1 for sequence label"},
    )
    encoder_layers: int = field(
        default=12, metadata={"help": "num encoder layers in the transformer"}
    )
    encoder_embed_dim: int = field(
        default=768, metadata={"help": "encoder embedding dimension"}
    )
    encoder_ffn_embed_dim: int = field(
        default=3072, metadata={"help": "encoder embedding dimension for FFN"}
    )
    encoder_attention_heads: int = field(
        default=12, metadata={"help": "num encoder attention heads"}
    )
    activation_fn: str = field(
        default="gelu", metadata={"help": "activation function to use"}
    )

    # dropouts
    dropout: float = field(
        default=0.1,
        metadata={"help": "dropout probability for the transformer"},
    )
    attention_dropout: float = field(
        default=0.1,
        metadata={"help": "dropout probability for attention weights"},
    )
    activation_dropout: float = field(
        default=0.0,
        metadata={"help": "dropout probability after activation in FFN"},
    )
    encoder_layerdrop: float = field(
        default=0.0,
        metadata={"help": "probability of dropping a transformer layer"},
    )
    dropout_input: float = field(
        default=0.0,
        metadata={"help": "dropout to apply to the input (after feat extr)"},
    )
    dropout_features: float = field(
        default=0.0,
        metadata={"help": "dropout to apply to the features (after feat extr)"},
    )

    final_dim: int = field(
        default=0,
        metadata={
            "help": "project final representations and targets to this many "
            "dimensions. set to encoder_embed_dim if <= 0"
        },
    )
    untie_final_proj: bool = field(
        default=False,
        metadata={"help": "use separate projection for each target"},
    )
    layer_norm_first: bool = field(
        default=False,
        metadata={"help": "apply layernorm first in the transformer"},
    )
    conv_feature_layers: str = field(
        default="[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2",
        metadata={
            "help": "string describing convolutional feature extraction "
            "layers in form of a python list that contains "
            "[(dim, kernel_size, stride), ...]"
        },
    )
    conv_bias: bool = field(
        default=False, metadata={"help": "include bias in conv encoder"}
    )
    logit_temp: float = field(
        default=0.1, metadata={"help": "temperature to divide logits by"}
    )
    target_glu: bool = field(
        default=False, metadata={"help": "adds projection + glu to targets"}
    )
    feature_grad_mult: float = field(
        default=1.0,
        metadata={"help": "multiply feature extractor var grads by this"},
    )

    # masking
    mask_length_audio: int = field(default=10, metadata={"help": "mask length"})
    mask_prob_audio: float = field(
        default=0.65,
        metadata={"help": "probability of replacing a token with mask"},
    )
    mask_length_image: int = field(default=10, metadata={"help": "mask length"})
    mask_prob_image: float = field(
        default=0.65,
        metadata={"help": "probability of replacing a token with mask"},
    )
    mask_selection: str = field(
        default="static", metadata={"help": "how to choose mask length"}
    )
    mask_other: float = field(
        default=0,
        metadata={
            "help": "secondary mask argument "
            "(used for more complex distributions), "
            "see help in compute_mask_indices"
        },
    )
    no_mask_overlap: bool = field(
        default=False, metadata={"help": "whether to allow masks to overlap"}
    )
    mask_min_space: int = field(
        default=1,
        metadata={"help": "min space between spans (if no overlap is enabled)"},
    )

    # channel masking
    mask_channel_length: int = field(
        default=10,
        metadata={"help": "length of the mask for features (channels)"},
    )
    mask_channel_prob: float = field(
        default=0.0,
        metadata={"help": "probability of replacing a feature with 0"},
    )
    mask_channel_selection: str = field(
        default="static",
        metadata={"help": "how to choose mask length for channel masking"},
    )
    mask_channel_other: float = field(
        default=0,
        metadata={
            "help": "secondary mask argument "
            "(used for more complex distributions), "
            "see help in compute_mask_indices"
        },
    )
    no_mask_channel_overlap: bool = field(
        default=False,
        metadata={"help": "whether to allow channel masks to overlap"},
    )
    mask_channel_min_space: int = field(
        default=1,
        metadata={"help": "min space between spans (if no overlap is enabled)"},
    )

    # positional embeddings
    conv_pos: int = field(
        default=128,
        metadata={"help": "number of filters for convolutional positional embeddings"},
    )
    conv_pos_groups: int = field(
        default=16,
        metadata={"help": "number of groups for convolutional positional embedding"},
    )

    latent_temp: Tuple[float, float, float] = field(
        default=(2, 0.5, 0.999995),
        metadata={"help": "legacy (to be removed)"},
    )

    # loss computation
    skip_masked: bool = field(
        default=False,
        metadata={"help": "skip computing losses over masked frames"},
    )
    skip_nomask: bool = field(
        default=False,
        metadata={"help": "skip computing losses over unmasked frames"},
    )

    resnet_relu_type: str = field(
        default="prelu", metadata={"help": "relu type for resnet"}
    )
    resnet_weights: Optional[str] = field(
        default=None, metadata={"help": "resnet weights"}
    )
    sim_type: str = field(default="cosine", metadata={"help": "similarity type"})

    sub_encoder_layers: int = field(
        default=0,
        metadata={"help": "number of transformer layers for single modality"},
    )
    audio_feat_dim: int = field(
        default=-1, metadata={"help": "audio feature dimension"}
    )
    modality_dropout: float = field(default=0, metadata={"help": "drop one modality"})
    audio_dropout: float = field(default=0, metadata={"help": "drop audio feature"})
    modality_fuse: str = field(
        default="concat", metadata={"help": "fusing two modalities: add,concat"}
    )
    selection_type: str = field(
        default="same_other_seq",
        metadata={
            "help": "type of selecting images, "
            "same_other_seq: replace masked span with span from another sequence, "
            "same_seq: replace masked span with span of the same sequence"
        },
    )
    masking_type: str = field(
        default="input", metadata={"help": "input or feature masking"}
    )

    decoder_embed_dim: int = field(
        default=768, metadata={"help": "decoder embedding dimension"}
    )
    decoder_ffn_embed_dim: int = field(
        default=3072, metadata={"help": "decoder embedding dimension for FFN"}
    )
    decoder_layers: int = field(default=6, metadata={"help": "num of decoder layers"})
    decoder_layerdrop: float = field(
        default=0.0, metadata={"help": "decoder layerdrop chance"}
    )
    decoder_attention_heads: int = field(
        default=4, metadata={"help": "num decoder attention heads"}
    )
    decoder_learned_pos: bool = field(
        default=False,
        metadata={"help": "use learned positional embeddings in the decoder"},
    )
    decoder_normalize_before: bool = field(
        default=False,
        metadata={"help": "apply layernorm before each decoder block"},
    )
    no_token_positional_embeddings: bool = field(
        default=False,
        metadata={
            "help": "if set, disables positional embeddings (outside self attention)"
        },
    )
    decoder_dropout: float = field(
        default=0.1, metadata={"help": "dropout probability in the decoder"}
    )
    decoder_attention_dropout: float = field(
        default=0.1,
        metadata={
            "help": "dropout probability for attention weights inside the decoder"
        },
    )
    decoder_activation_dropout: float = field(
        default=0.0,
        metadata={
            "help": "dropout probability after activation in FFN inside the decoder"
        },
    )
    max_target_positions: int = field(
        default=2048, metadata={"help": "max target positions"}
    )
    share_decoder_input_output_embed: bool = field(
        default=False,
        metadata={"help": "share decoder input and output embeddings"},
    )
    audio_only: bool = field(
        default=False,
        metadata={"help": "whether to use audio stream only"},
    )
    no_scale_embedding: bool = field(default=True, metadata={"help": "scale embedding"})

class SubModel(nn.Module):
    def __init__(self, resnet=None, input_dim=None, cfg=None):
        super().__init__()
        self.resnet = resnet
        self.proj = nn.Linear(input_dim, cfg.encoder_embed_dim)
        self.encoder = TransformerEncoder(cfg) if cfg.encoder_layers > 0 else None

    def forward(self, x):
        if self.resnet is not None:
            x = self.resnet(x)
        x = self.proj(x.transpose(1, 2))
        if self.encoder is not None:
            x = self.encoder(x)[0].transpose(1, 2)
        else:
            x = x.transpose(1, 2)
        return x

class AVHubertModel(nn.Module):
    def __init__(self, cfg: AVHubertConfig, **kwargs) -> None:
        super().__init__()
        logger.info(f"HubertModel Config: {cfg}")
        try:
            from fairseq.modules import LayerNorm
        except Exception as e:
            print("Error: FairSeq is not properly installed.")
            print("Please install FairSeq: cd ${MAIN_ROOT}/tools && make fairseq.done")
            raise e

        feature_ds_rate = 1
        self.feat2tar_ratio = cfg.label_rate * feature_ds_rate / cfg.sample_rate
        sub_cfg = deepcopy(cfg)
        sub_cfg.encoder_layers = sub_cfg.sub_encoder_layers
        resnet = ResEncoder(relu_type=cfg.resnet_relu_type, weights=cfg.resnet_weights)
        self.feature_extractor_audio = SubModel(
            resnet=None, input_dim=cfg.audio_feat_dim, cfg=sub_cfg
        )
        self.feature_extractor_video = SubModel(
            resnet=resnet, input_dim=resnet.backend_out, cfg=sub_cfg
        )
        self.modality_dropout, self.audio_dropout = (
            cfg.modality_dropout,
            cfg.audio_dropout,
        )
        self.modality_fuse = cfg.modality_fuse
        self.encoder_embed_dim = cfg.encoder_embed_dim
        if self.modality_fuse == "concat":
            self.embed = cfg.encoder_embed_dim * 2
        elif self.modality_fuse == "add":
            self.embed = cfg.encoder_embed_dim
        else:
            raise ValueError(f"unknown fusion method: {self.modality_fuse}")
        self.post_extract_proj = (
            nn.Linear(self.embed, cfg.encoder_embed_dim)
            if self.embed != cfg.encoder_embed_dim
            else None
        )

        self.mask_prob_image, self.mask_prob_audio = (
            cfg.mask_prob_image,
            cfg.mask_prob_audio,
        )
        self.mask_selection = cfg.mask_selection
        self.mask_other = cfg.mask_other
        self.mask_length_image, self.mask_length_audio = (
            cfg.mask_length_image,
            cfg.mask_length_audio,
        )
        self.no_mask_overlap = cfg.no_mask_overlap
        self.mask_min_space = cfg.mask_min_space

        self.mask_channel_prob = cfg.mask_channel_prob
        self.mask_channel_selection = cfg.mask_channel_selection
        self.mask_channel_other = cfg.mask_channel_other
        self.mask_channel_length = cfg.mask_channel_length
        self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
        self.mask_channel_min_space = cfg.mask_channel_min_space

        self.dropout_input = nn.Dropout(cfg.dropout_input)
        self.dropout_features = nn.Dropout(cfg.dropout_features)

        self.feature_grad_mult = cfg.feature_grad_mult
        self.logit_temp = cfg.logit_temp
        self.skip_masked = cfg.skip_masked
        self.skip_nomask = cfg.skip_nomask
        self.sim_type = cfg.sim_type
        self.selection_type = cfg.selection_type
        self.masking_type = cfg.masking_type

        self.mask_emb = nn.Parameter(
            torch.FloatTensor(cfg.audio_feat_dim).uniform_()
            if self.masking_type == "input"
            else torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
        )

        self.encoder = TransformerEncoder(cfg)
        self.layer_norm = LayerNorm(self.embed)

        self.audio_only = cfg.audio_only

    @classmethod
    def build_model(cls, cfg: AVHubertConfig):
        """Build a new model instance."""
        kwargs = {}
        model = cls(cfg, **kwargs)
        return model

    def forward_features(self, source: torch.Tensor, modality: str) -> torch.Tensor:
        extractor = eval(f"self.feature_extractor_{modality}")
        if self.feature_grad_mult > 0:
            features = extractor(source)
            if self.feature_grad_mult != 1.0:
                features = GradMultiply.apply(features, self.feature_grad_mult)
        else:
            with torch.no_grad():
                features = extractor(source)
        return features

    def forward_padding_mask(
        self,
        features: torch.Tensor,
        padding_mask: torch.Tensor,
    ) -> torch.Tensor:
        extra = padding_mask.size(1) % features.size(1)
        if extra > 0:
            padding_mask = padding_mask[:, :-extra]
        padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1)
        padding_mask = padding_mask.all(-1)
        return padding_mask

    def extract_finetune(
        self, source, padding_mask=None, mask=False, ret_conv=False, output_layer=None
    ):
        """Forward AVHubert Pretrain Encoder.

        Args:
            source["video"]: input tensor (B, 1, L, H, W)
            source["audio"]: input tensor (B, F, L)
            padding_mask: input tensor (B, L)

        Returns:
            encoded tensor and mask
        """
        src_audio, src_video = source["audio"], source["video"]

        if (src_audio is not None and src_video is None) or self.audio_only:
            features_audio = self.forward_features(
                src_audio, modality="audio"
            )  # features: [B, F, T]
            features_video = features_audio.new_zeros(
                features_audio.size(0), self.encoder_embed_dim, features_audio.size(-1)
            )
        elif src_audio is None and src_video is not None:
            features_video = self.forward_features(src_video, modality="video")
            features_audio = features_video.new_zeros(
                features_video.size(0), self.encoder_embed_dim, features_video.size(-1)
            )
        elif src_audio is not None and src_video is not None:
            features_video = self.forward_features(src_video, modality="video")
            features_audio = self.forward_features(
                src_audio, modality="audio"
            )  # features: [B, F, T]
        else:
            raise ValueError("Both audio and video are None")

        if self.modality_fuse == "concat":
            features = torch.cat([features_audio, features_video], dim=1)
        elif self.modality_fuse == "add":
            features = features_audio + features_video
        else:
            raise ValueError(f"unknown fusion method: {self.modality_fuse}")

        features = features.transpose(1, 2)  # B, 2F, T -> B, T, 2F
        features = self.layer_norm(features)

        if padding_mask is not None:
            padding_mask = self.forward_padding_mask(features, padding_mask)

        if self.post_extract_proj is not None:
            features = self.post_extract_proj(features)

        features = self.dropout_input(features)
        x = features

        # feature: (B, T, D), float
        # target: (B, T), long
        # x: (B, T, D), float
        # padding_mask: (B, T), bool
        # mask_indices: (B, T), bool
        x, _ = self.encoder(
            x,
            padding_mask=padding_mask,
            layer=None if output_layer is None else output_layer - 1,
        )

        return x, padding_mask

    def forward_audio(self, source_audio):
        with torch.no_grad():
            features_audio = self.forward_features(
                source_audio, modality="audio"
            )  # features: [B, F, T]
        return features_audio

    def forward_video(self, source_video):
        with torch.no_grad():
            features_video = self.forward_features(
                source_video, modality="video"
            )  # features: [B, F, T]
        return features_video

    def modality_fusion(self, features_audio, features_video):
        if features_video is None and features_audio is not None:
            features_video = features_audio.new_zeros(
                features_audio.size(0), self.encoder_embed_dim, features_audio.size(-1)
            )
        elif features_audio is None and features_video is not None:
            features_audio = features_video.new_zeros(
                features_video.size(0), self.encoder_embed_dim, features_video.size(-1)
            )
        else:
            # both modalities are present; use them as-is
            pass

        if self.modality_fuse == "concat":
            features = torch.cat([features_audio, features_video], dim=1)
        elif self.modality_fuse == "add":
            features = features_audio + features_video
        else:
            raise ValueError(f"unknown fusion method: {self.modality_fuse}")
        return features

    def forward_transformer(self, source, padding_mask=None, output_layer=None):
        """Forward AVHubert Pretrain Encoder (without frontend).

        Assume the source is already a fused feature.

        Args:
            source: input tensor (B, L, D * 2)
            padding_mask: input tensor (B, L)

        Returns:
            encoded tensor and mask
        """
        features = source
        features = self.layer_norm(features)

        if padding_mask is not None:
            padding_mask = self.forward_padding_mask(features, padding_mask)

        if self.post_extract_proj is not None:
            features = self.post_extract_proj(features)

        features = self.dropout_input(features)
        x = features

        # feature: (B, T, D), float
        # target: (B, T), long
        # x: (B, T, D), float
        # padding_mask: (B, T), bool
        # mask_indices: (B, T), bool
        x, _ = self.encoder(
            x,
            padding_mask=padding_mask,
            layer=None if output_layer is None else output_layer - 1,
        )

        return x, padding_mask

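# Hedged illustration (not part of the original source) of the reshaping used by
# AVHubertModel.forward_padding_mask above: a frame-level padding mask (B, L) is
# collapsed to the feature rate (B, T) by grouping L into T chunks and marking a
# chunk as padded only if every frame in it is padded. The 2:1 ratio below is
# illustrative; in this module the feature rate equals the frame rate.
def _example_padding_mask_downsampling():
    B, L, T = 2, 10, 5  # each output step covers L // T = 2 input frames
    padding_mask = torch.zeros(B, L, dtype=torch.bool)
    padding_mask[1, 6:] = True  # second sample has 4 padded frames
    features = torch.randn(B, T, 8)  # stand-in for fused features (B, T, D)
    extra = padding_mask.size(1) % features.size(1)
    if extra > 0:
        padding_mask = padding_mask[:, :-extra]
    downsampled = padding_mask.view(B, features.size(1), -1).all(-1)
    # downsampled[1] == tensor([False, False, False, True, True])
    return downsampled
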
def download_avhubert(model_url, dir_path):
    os.makedirs(dir_path, exist_ok=True)

    model_name = model_url.split("/")[-1]
    model_path = os.path.join(dir_path, model_name)

    if not os.path.exists(model_path):
        with FileLock(model_path + ".lock"):
            torch.hub.download_url_to_file(model_url, model_path)
            logging.info(f"AVHubert model downloaded {model_path}")
    else:
        logging.info(f"AVHubert model {model_path} already exists.")

    return model_path

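# Hedged usage sketch (not part of the original source): fetch one of the public
# checkpoints listed in the comment near the top of this file into a local
# directory. The target directory name is arbitrary; the download is skipped if
# the file already exists.
def _example_download_avhubert():
    url = (
        "https://dl.fbaipublicfiles.com/avhubert/model/lrs3_vox/"
        "noise-pretrain/base_vox_iter5.pt"
    )
    model_path = download_avhubert(url, "./downloads/avhubert")
    return model_path
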
class TransformerEncoder(nn.Module):
    """From AVHubert github"""

    def __init__(self, args):
        super().__init__()
        try:
            from fairseq.models.wav2vec.wav2vec2 import TransformerSentenceEncoderLayer
            from fairseq.modules import LayerNorm
            from fairseq.modules.transformer_sentence_encoder import init_bert_params
        except Exception as e:
            print("Error: FairSeq is not properly installed.")
            print("Please install FairSeq: cd ${MAIN_ROOT}/tools && make fairseq.done")
            raise e

        self.dropout = args.dropout
        self.embedding_dim = args.encoder_embed_dim

        self.pos_conv = nn.Conv1d(
            self.embedding_dim,
            self.embedding_dim,
            kernel_size=args.conv_pos,
            padding=args.conv_pos // 2,
            groups=args.conv_pos_groups,
        )
        dropout = 0
        std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
        nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
        nn.init.constant_(self.pos_conv.bias, 0)

        self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
        self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())

        self.layers = nn.ModuleList(
            [
                TransformerSentenceEncoderLayer(
                    embedding_dim=self.embedding_dim,
                    ffn_embedding_dim=args.encoder_ffn_embed_dim,
                    num_attention_heads=args.encoder_attention_heads,
                    dropout=self.dropout,
                    attention_dropout=args.attention_dropout,
                    activation_dropout=args.activation_dropout,
                    activation_fn=args.activation_fn,
                    layer_norm_first=args.layer_norm_first,
                )
                for _ in range(args.encoder_layers)
            ]
        )

        self.layer_norm_first = args.layer_norm_first
        self.layer_norm = LayerNorm(self.embedding_dim)
        self.layerdrop = args.encoder_layerdrop

        self.apply(init_bert_params)

    def forward(self, x, padding_mask=None, layer=None):
        x, layer_results = self.extract_features(x, padding_mask, layer)

        if self.layer_norm_first and layer is None:
            x = self.layer_norm(x)

        return x, layer_results

    def extract_features(self, x, padding_mask=None, tgt_layer=None):
        if padding_mask is not None:
            x = index_put(x, padding_mask, 0)

        x_conv = self.pos_conv(x.transpose(1, 2))
        x_conv = x_conv.transpose(1, 2)
        x = x + x_conv

        if not self.layer_norm_first:
            x = self.layer_norm(x)

        x = torch.nn.functional.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        layer_results = []
        r = None
        for i, layer in enumerate(self.layers):
            dropout_probability = np.random.random()
            if not self.training or (dropout_probability > self.layerdrop):
                x, z = layer(x, self_attn_padding_mask=padding_mask, need_weights=False)
                if tgt_layer is not None:
                    layer_results.append((x, z))
            if i == tgt_layer:
                r = x
                break

        if r is not None:
            x = r

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        return x, layer_results

    def max_positions(self):
        """Maximum output length supported by the encoder."""
        return self.args.max_positions

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""
        return state_dict

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, relu_type="relu"):
        super(BasicBlock, self).__init__()

        assert relu_type in ["relu", "prelu"]

        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)

        if relu_type == "relu":
            self.relu1 = nn.ReLU(inplace=True)
            self.relu2 = nn.ReLU(inplace=True)
        elif relu_type == "prelu":
            self.relu1 = nn.PReLU(num_parameters=planes)
            self.relu2 = nn.PReLU(num_parameters=planes)
        else:
            raise Exception("relu type not implemented")

        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)

        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu2(out)

        return out

class ResNet(nn.Module):
    def __init__(
        self,
        block,
        layers,
        num_classes=1000,
        relu_type="relu",
        gamma_zero=False,
        avg_pool_downsample=False,
    ):
        self.inplanes = 64
        self.relu_type = relu_type
        self.gamma_zero = gamma_zero
        self.downsample_block = (
            downsample_basic_block_v2 if avg_pool_downsample else downsample_basic_block
        )

        super(ResNet, self).__init__()
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d(1)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2.0 / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

        if self.gamma_zero:
            for m in self.modules():
                if isinstance(m, BasicBlock):
                    m.bn2.weight.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = self.downsample_block(
                inplanes=self.inplanes,
                outplanes=planes * block.expansion,
                stride=stride,
            )

        layers = []
        layers.append(
            block(self.inplanes, planes, stride, downsample, relu_type=self.relu_type)
        )
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, relu_type=self.relu_type))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        return x

class ResEncoder(nn.Module):
    def __init__(self, relu_type, weights):
        super(ResEncoder, self).__init__()
        self.frontend_nout = 64
        self.backend_out = 512
        frontend_relu = (
            nn.PReLU(num_parameters=self.frontend_nout)
            if relu_type == "prelu"
            else nn.ReLU()
        )
        self.frontend3D = nn.Sequential(
            nn.Conv3d(
                1,
                self.frontend_nout,
                kernel_size=(5, 7, 7),
                stride=(1, 2, 2),
                padding=(2, 3, 3),
                bias=False,
            ),
            nn.BatchNorm3d(self.frontend_nout),
            frontend_relu,
            nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)),
        )
        self.trunk = ResNet(BasicBlock, [2, 2, 2, 2], relu_type=relu_type)
        if weights is not None:
            logger.info(f"Load {weights} for resnet")
            std = torch.load(weights, map_location=torch.device("cpu"))[
                "model_state_dict"
            ]
            frontend_std, trunk_std = OrderedDict(), OrderedDict()
            for key, val in std.items():
                new_key = ".".join(key.split(".")[1:])
                if "frontend3D" in key:
                    frontend_std[new_key] = val
                if "trunk" in key:
                    trunk_std[new_key] = val
            self.frontend3D.load_state_dict(frontend_std)
            self.trunk.load_state_dict(trunk_std)

    def forward(self, x):
        B, C, T, H, W = x.size()
        x = self.frontend3D(x)
        Tnew = x.shape[2]
        x = self.threeD_to_2D_tensor(x)
        x = self.trunk(x)
        x = x.view(B, Tnew, x.size(1))
        x = x.transpose(1, 2).contiguous()
        return x

    def threeD_to_2D_tensor(self, x):
        n_batch, n_channels, s_time, sx, sy = x.shape
        x = x.transpose(1, 2).contiguous()
        return x.reshape(n_batch * s_time, n_channels, sx, sy)

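# Hedged usage sketch (not part of the original source): the video front-end maps
# a (B, 1, T, H, W) lip-ROI clip to a (B, 512, T) feature sequence; the temporal
# length is preserved because all temporal strides are 1. The 88x88 crop size is
# illustrative only.
def _example_res_encoder():
    frontend = ResEncoder(relu_type="prelu", weights=None)
    clip = torch.randn(2, 1, 12, 88, 88)
    feats = frontend(clip)
    assert feats.shape == (2, 512, 12)
    return feats
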
class SamePad(nn.Module):
    def __init__(self, kernel_size, causal=False):
        super().__init__()
        if causal:
            self.remove = kernel_size - 1
        else:
            self.remove = 1 if kernel_size % 2 == 0 else 0

    def forward(self, x):
        if self.remove > 0:
            x = x[:, :, : -self.remove]
        return x

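# Hedged illustration (not part of the original source): with an even kernel size
# and padding = kernel_size // 2, a 1-D convolution returns one extra time step;
# SamePad trims it so the positional convolution in TransformerEncoder keeps the
# input length. Channel and length sizes below are illustrative.
def _example_same_pad():
    conv = nn.Conv1d(4, 4, kernel_size=128, padding=64, groups=4)
    x = torch.randn(1, 4, 50)
    y = conv(x)  # (1, 4, 51): one step too long
    y = SamePad(kernel_size=128)(y)
    assert y.shape[-1] == x.shape[-1]
    return y
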
def index_put(tensor, indices, value):
    if is_xla_tensor(tensor):
        for _ in range(indices.dim(), tensor.dim()):
            indices = indices.unsqueeze(-1)
        if indices.size(-1) < tensor.size(-1):
            indices = indices.expand_as(tensor)
        tensor = torch.mul(tensor, ~indices) + torch.mul(value, indices)
    else:
        tensor[indices] = value
    return tensor

def is_xla_tensor(tensor):
    return torch.is_tensor(tensor) and tensor.device.type == "xla"

class GradMultiply(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scale):
        ctx.scale = scale
        res = x.new(x)
        return res

    @staticmethod
    def backward(ctx, grad):
        return grad * ctx.scale, None

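# Hedged illustration (not part of the original source): GradMultiply is an
# identity in the forward pass but scales the gradient in the backward pass,
# which is how feature_grad_mult damps the front-end gradients above.
def _example_grad_multiply():
    x = torch.randn(3, requires_grad=True)
    y = GradMultiply.apply(x, 0.1)
    y.sum().backward()
    # d(sum)/dx would be 1.0 everywhere; the 0.1 scale shrinks it to 0.1.
    assert torch.allclose(x.grad, torch.full_like(x, 0.1))
    return x.grad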