Comments (14)

YuanGongND commented on August 16, 2024

It is a bug, but setting input_fdim=128 should solve the problem, and I think that could also improve the performance.
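To see why input_fdim matters, here is a minimal sketch of the patch-grid arithmetic that `get_shape` performs in the model code further down this thread. It assumes 16×16 patches with no padding, so each axis yields `(input_dim - 16) // stride + 1` patches (the usual Conv2d output-size formula). `patch_grid` is a hypothetical helper, not part of the AST code.

```python
def patch_grid(input_fdim, input_tdim, fstride=10, tstride=10, patch=16):
    """Return (f_dim, t_dim), the patch grid an unpadded 16x16 Conv2d produces.

    Hypothetical helper mirroring ASTModel.get_shape:
    output size per axis = (input - kernel) // stride + 1.
    """
    f_dim = (input_fdim - patch) // fstride + 1
    t_dim = (input_tdim - patch) // tstride + 1
    return f_dim, t_dim


# The AudioSet checkpoint was trained on 128 mel bins x 1024 frames.
print(patch_grid(128, 1024))  # (12, 101)
# 64 mel bins give only 5 frequency rows, so a positional embedding
# that is reshaped with a hard-coded 12 rows no longer fits.
print(patch_grid(64, 1024))   # (5, 101)
```

With input_fdim=128 the frequency axis has the same 12 rows as the AudioSet checkpoint, so only the time axis needs adapting.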

from ast.

devksingh4 commented on August 16, 2024

Hello @YuanGongND ,
The issue with our data is that we do not have enough audio to use 128 n_mels on the spectrogram, as we end up with mel filterbanks with all zero values. Is there any other workaround for this?
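One plausible cause of all-zero filters (an assumption; the thread does not confirm it): with triangular mel filters, a filter is empty when no FFT bin falls strictly inside its band. That happens when the frequency resolution `sample_rate / n_fft` is coarse compared with the narrow low-frequency mel bands, and torchaudio's `melscale_fbanks` warns about all-zero filterbanks in that situation. The pure-Python sketch below counts empty filters; the function names, the HTK mel formula, the 16 kHz rate and f_min = 0 are illustrative assumptions, and Kaldi-style fbank front ends use different band edges.

```python
import math


def hz_to_mel(f):
    # HTK-style mel scale (assumption: other front ends may use Slaney's scale)
    return 2595.0 * math.log10(1.0 + f / 700.0)


def mel_to_hz(m):
    return 700.0 * (10.0 ** (m / 2595.0) - 1.0)


def count_empty_mel_filters(n_fft, n_mels, sample_rate=16000):
    """Count triangular mel filters that cover no FFT bin (all-zero filters)."""
    n_freqs = n_fft // 2 + 1
    nyquist = sample_rate / 2
    # Center frequency of each FFT bin, from 0 Hz to Nyquist.
    bin_freqs = [i * nyquist / (n_freqs - 1) for i in range(n_freqs)]
    # n_mels + 2 band edges, equally spaced on the mel scale.
    mel_max = hz_to_mel(nyquist)
    edges = [mel_to_hz(mel_max * i / (n_mels + 1)) for i in range(n_mels + 2)]

    empty = 0
    for m in range(n_mels):
        lo, center, hi = edges[m], edges[m + 1], edges[m + 2]
        # The triangle is nonzero only for bins strictly between lo and hi.
        weights = [
            max(0.0, min((f - lo) / (center - lo), (hi - f) / (hi - center)))
            for f in bin_freqs
        ]
        if max(weights) == 0.0:
            empty += 1
    return empty


print(count_empty_mel_filters(n_fft=64, n_mels=128))    # > 0: 250 Hz bin spacing
print(count_empty_mel_filters(n_fft=2048, n_mels=128))  # 0
```

In this sketch the count depends only on n_fft, n_mels and the sample rate, not on clip length, so a larger FFT window can remove the empty filters while keeping 128 bins.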

YuanGongND commented on August 16, 2024

I am not sure why the mel filterbanks are all zeros. But I think it is worth trying input_fdim=128 with audioset_pretrain=True. The AudioSet pretrained model is trained with 128 bins, so it might not generalize well to 64-bin input anyway. Otherwise, you can just use the ImageNet pretrained model. I would suggest trying both and comparing the results.

H-Liu1997 commented on August 16, 2024

Hi, I also got this size-mismatch error when setting input_tdim != 1024; just running ast_models.py reproduces it.

YuanGongND commented on August 16, 2024

Hi,

How did you initialize the AST model? You need to specify the input_tdim when you initialize the AST model and your actual input length needs to match the input_tdim.
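A minimal sketch of making every input match input_tdim, assuming the spectrogram is held as a list of frames (time × frequency): zero-pad short clips and cut long ones before they reach the model. `pad_or_truncate` is a hypothetical helper; in practice this is usually done on tensors inside the data loader.

```python
def pad_or_truncate(frames, target_len, n_bins):
    """Zero-pad or cut a (time, freq) spectrogram to exactly target_len frames.

    frames: list of per-frame lists, each holding n_bins values.
    """
    # Keep at most target_len frames (copying rows so the input is untouched).
    out = [list(row) for row in frames[:target_len]]
    # Append all-zero frames until the length matches.
    while len(out) < target_len:
        out.append([0.0] * n_bins)
    return out


input_tdim, input_fdim = 1024, 128
short_clip = [[0.5] * input_fdim for _ in range(998)]
fixed = pad_or_truncate(short_clip, input_tdim, input_fdim)
print(len(fixed), len(fixed[0]))  # 1024 128
```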

-Yuan

H-Liu1997 commented on August 16, 2024

Yeah, thank you for your reply!
I set input_tdim = 256 with imagenet_pretrain = True and audioset_pretrain = True. After the following modification, it works on my machine.

            new_pos_embed = self.v.pos_embed[:, 2:, :].detach().reshape(1, num_patches, 768).transpose(1, 2).reshape(1, 768, 12, t_dim)
            # if the input sequence length is shorter than the original audioset (10s), then cut the positional embedding
#             if t_dim < 101:
#                 print(new_pos_embed.shape)
#                 new_pos_embed = new_pos_embed[:, :, :, 50 - int(t_dim/2): 50 - int(t_dim/2) + t_dim]
#                 print(new_pos_embed.shape)
#             # otherwise interpolate
#             else:
#                 new_pos_embed = torch.nn.functional.interpolate(new_pos_embed, size=(12, t_dim), mode='bilinear')
            new_pos_embed = new_pos_embed.reshape(1, 768, num_patches).transpose(1, 2)
            self.v.pos_embed = nn.Parameter(torch.cat([self.v.pos_embed[:, :2, :].detach(), new_pos_embed], dim=1))

YuanGongND commented on August 16, 2024

Hi @H-Liu1997,

Your solution can avoid the error, but would also cause a performance drop. I don't suggest commenting out the positional embedding adaptation code.
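For reference, here is a minimal pure-Python sketch of what the commented-out block does to each frequency row of the 12×101 AudioSet positional grid: keep the middle t_dim columns when the input is shorter than 101 frames' worth of patches, otherwise interpolate. Working row by row is an assumption made for illustration; the real code calls `torch.nn.functional.interpolate(..., size=(12, t_dim), mode='bilinear')` on the whole grid. Since the frequency size stays 12, that bilinear call reduces to 1-D linear interpolation along time, which `linear_resize` mimics with PyTorch's default align_corners=False convention. All function names are hypothetical.

```python
import math


def center_crop_range(orig_len, new_len):
    """Start/stop indices keeping new_len columns from the middle."""
    start = orig_len // 2 - new_len // 2
    return start, start + new_len


def linear_resize(values, new_len):
    """1-D linear resampling using half-pixel centers (align_corners=False)."""
    old_len = len(values)
    scale = old_len / new_len
    out = []
    for i in range(new_len):
        # Map the output position back to a (clamped) source position.
        src = max((i + 0.5) * scale - 0.5, 0.0)
        i0 = min(int(math.floor(src)), old_len - 1)
        i1 = min(i0 + 1, old_len - 1)
        w = src - i0
        out.append(values[i0] * (1.0 - w) + values[i1] * w)
    return out


def adapt_time_axis(row, t_dim):
    """Cut (from the middle) or interpolate one positional-grid row to t_dim."""
    if t_dim < len(row):
        start, stop = center_crop_range(len(row), t_dim)
        return row[start:stop]
    return linear_resize(row, t_dim)


row = [float(i) for i in range(101)]      # stand-in for one row of 101 columns
print(adapt_time_axis(row, 25)[:3])       # [38.0, 39.0, 40.0]
print(adapt_time_axis([0.0, 1.0], 4))     # [0.0, 0.25, 0.75, 1.0]
```

Dropping this step keeps the pretrained embedding values but assigns them to the wrong time positions (or fails to reshape), which is consistent with the performance drop mentioned above.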

Could you share your code of initializing the AST model and forward pass (specifically your input shape)?

-Yuan

H-Liu1997 commented on August 16, 2024

Could you share your code of initializing the AST model and forward pass (specifically your input shape)?

@YuanGongND Sure! Thank you very much for your help!

  • I tested the input shape batch_size*128*1024 and it worked well
  • Then I wanted to train on custom data with input shape batch_size*128*128 and use pretrained weights
  • I only copied ast_models.py and called ASTModel(label_dim=1, fstride=10, tstride=10, input_fdim=128, input_tdim=128, imagenet_pretrain=True, audioset_pretrain=True, model_size='base384', verbose=True) for initialization.
  • I replaced every 1024 with 128, commented out the positional embedding code, and changed the order of
# will raise load_state_dict error on my machine
audio_model = torch.nn.DataParallel(audio_model)
audio_model.load_state_dict(sd, strict=False)
# to
audio_model.load_state_dict(sd, strict=False)
audio_model = torch.nn.DataParallel(audio_model)
  • the full code I used is the following:
# -*- coding: utf-8 -*-
# @Time    : 6/10/21 5:04 PM
# @Author  : Yuan Gong
# @Affiliation  : Massachusetts Institute of Technology
# @Email   : [email protected]
# @File    : ast_models.py

import torch
import torch.nn as nn
from torch.cuda.amp import autocast
import os
import wget
# os.environ['TORCH_HOME'] = '../../pretrained_models'
import timm
from timm.models.layers import to_2tuple,trunc_normal_

# override the timm package to relax the input shape constraint.
class PatchEmbed(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()

        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x

class ASTModel(nn.Module):
    """
    The AST model.
    :param label_dim: the label dimension, i.e., the number of total classes, it is 527 for AudioSet, 50 for ESC-50, and 35 for speechcommands v2-35
    :param fstride: the stride of patch splitting on the frequency dimension; for 16*16 patches, fstride=16 means no overlap, fstride=10 means an overlap of 6
    :param tstride: the stride of patch splitting on the time dimension; for 16*16 patches, tstride=16 means no overlap, tstride=10 means an overlap of 6
    :param input_fdim: the number of frequency bins of the input spectrogram
    :param input_tdim: the number of time frames of the input spectrogram
    :param imagenet_pretrain: whether to use the ImageNet pretrained model
    :param audioset_pretrain: whether to use the model pretrained on both full AudioSet and ImageNet
    :param model_size: the model size of AST, should be in [tiny224, small224, base224, base384]; base224 and base384 are the same model, but are trained differently during ImageNet pretraining.
    """
    def __init__(self, label_dim=1, fstride=10, tstride=10, input_fdim=128, input_tdim=128, imagenet_pretrain=True, audioset_pretrain=True, model_size='base384', verbose=True):

        super(ASTModel, self).__init__()
        # assert timm.__version__ == '0.4.5', 'Please use timm == 0.4.5, the code might not be compatible with newer versions.'

        if verbose == True:
            print('---------------AST Model Summary---------------')
            print('ImageNet pretraining: {:s}, AudioSet pretraining: {:s}'.format(str(imagenet_pretrain),str(audioset_pretrain)))
        # override timm input shape restriction
        timm.models.vision_transformer.PatchEmbed = PatchEmbed

        # if AudioSet pretraining is not used (but ImageNet pretraining may still apply)
        if audioset_pretrain == False:
            if model_size == 'tiny224':
                self.v = timm.create_model('vit_deit_tiny_distilled_patch16_224', pretrained=imagenet_pretrain)
            elif model_size == 'small224':
                self.v = timm.create_model('vit_deit_small_distilled_patch16_224', pretrained=imagenet_pretrain)
            elif model_size == 'base224':
                self.v = timm.create_model('vit_deit_base_distilled_patch16_224', pretrained=imagenet_pretrain)
            elif model_size == 'base384':
                self.v = timm.create_model('vit_deit_base_distilled_patch16_384', pretrained=imagenet_pretrain)
            else:
                raise Exception('Model size must be one of tiny224, small224, base224, base384.')
            self.original_num_patches = self.v.patch_embed.num_patches
            self.oringal_hw = int(self.original_num_patches ** 0.5)
            self.original_embedding_dim = self.v.pos_embed.shape[2]
            self.mlp_head = nn.Sequential(nn.LayerNorm(self.original_embedding_dim), nn.Linear(self.original_embedding_dim, label_dim))

            # automatically get the intermediate shape
            f_dim, t_dim = self.get_shape(fstride, tstride, input_fdim, input_tdim)
            num_patches = f_dim * t_dim
            self.v.patch_embed.num_patches = num_patches
            if verbose == True:
                print('frequency stride={:d}, time stride={:d}'.format(fstride, tstride))
                print('number of patches={:d}'.format(num_patches))

            # the linear projection layer
            new_proj = torch.nn.Conv2d(1, self.original_embedding_dim, kernel_size=(16, 16), stride=(fstride, tstride))
            if imagenet_pretrain == True:
                new_proj.weight = torch.nn.Parameter(torch.sum(self.v.patch_embed.proj.weight, dim=1).unsqueeze(1))
                new_proj.bias = self.v.patch_embed.proj.bias
            self.v.patch_embed.proj = new_proj

            # the positional embedding
            if imagenet_pretrain == True:
                # get the positional embedding from deit model, skip the first two tokens (cls token and distillation token), reshape it to original 2D shape (24*24).
                new_pos_embed = self.v.pos_embed[:, 2:, :].detach().reshape(1, self.original_num_patches, self.original_embedding_dim).transpose(1, 2).reshape(1, self.original_embedding_dim, self.oringal_hw, self.oringal_hw)
                # cut (from middle) or interpolate the second dimension of the positional embedding
                if t_dim <= self.oringal_hw:
                    new_pos_embed = new_pos_embed[:, :, :, int(self.oringal_hw / 2) - int(t_dim / 2): int(self.oringal_hw / 2) - int(t_dim / 2) + t_dim]
                else:
                    new_pos_embed = torch.nn.functional.interpolate(new_pos_embed, size=(self.oringal_hw, t_dim), mode='bilinear')
                # cut (from middle) or interpolate the first dimension of the positional embedding
                if f_dim <= self.oringal_hw:
                    new_pos_embed = new_pos_embed[:, :, int(self.oringal_hw / 2) - int(f_dim / 2): int(self.oringal_hw / 2) - int(f_dim / 2) + f_dim, :]
                else:
                    new_pos_embed = torch.nn.functional.interpolate(new_pos_embed, size=(f_dim, t_dim), mode='bilinear')
                # flatten the positional embedding
                new_pos_embed = new_pos_embed.reshape(1, self.original_embedding_dim, num_patches).transpose(1,2)
                # concatenate the above positional embedding with the cls token and distillation token of the deit model.
                self.v.pos_embed = nn.Parameter(torch.cat([self.v.pos_embed[:, :2, :].detach(), new_pos_embed], dim=1))
            else:
                # if not use imagenet pretrained model, just randomly initialize a learnable positional embedding
                # TODO can use sinusoidal positional embedding instead
                new_pos_embed = nn.Parameter(torch.zeros(1, self.v.patch_embed.num_patches + 2, self.original_embedding_dim))
                self.v.pos_embed = new_pos_embed
                trunc_normal_(self.v.pos_embed, std=.02)

        # now load a model that is pretrained on both ImageNet and AudioSet
        elif audioset_pretrain == True:
            if audioset_pretrain == True and imagenet_pretrain == False:
                raise ValueError('currently model pretrained on only audioset is not supported, please set imagenet_pretrain = True to use audioset pretrained model.')
            if model_size != 'base384':
                raise ValueError('currently only has base384 AudioSet pretrained model.')
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            if os.path.exists('../../Datasets/checkpoints/audioset_10_10_0.4593.pth') == False:
                # this model performs 0.4593 mAP on the audioset eval set
                audioset_mdl_url = 'https://www.dropbox.com/s/cv4knew8mvbrnvq/audioset_0.4593.pth?dl=1'
                wget.download(audioset_mdl_url, out='../../Datasets/checkpoints/audioset_10_10_0.4593.pth')
            sd = torch.load('../../Datasets/checkpoints/audioset_10_10_0.4593.pth', map_location=device)
            audio_model = ASTModel(label_dim=1, fstride=10, tstride=10, input_fdim=128, input_tdim=128, imagenet_pretrain=False, audioset_pretrain=False, model_size='base384', verbose=False)
            
            audio_model.load_state_dict(sd, strict=False)
            audio_model = torch.nn.DataParallel(audio_model)
            self.v = audio_model.module.v
            self.original_embedding_dim = self.v.pos_embed.shape[2]
            self.mlp_head = nn.Sequential(nn.LayerNorm(self.original_embedding_dim), nn.Linear(self.original_embedding_dim, label_dim))

            f_dim, t_dim = self.get_shape(fstride, tstride, input_fdim, input_tdim)
            num_patches = f_dim * t_dim
            self.v.patch_embed.num_patches = num_patches
            if verbose == True:
                print('frequency stride={:d}, time stride={:d}'.format(fstride, tstride))
                print('t_dim={:d}, f_dim={:d}'.format(t_dim, f_dim))
                print('number of patches={:d}'.format(num_patches))
            # print(self.v.pos_embed[:, 2:, :].shape)
            new_pos_embed = self.v.pos_embed[:, 2:, :].detach().reshape(1, num_patches, 768).transpose(1, 2).reshape(1, 768, 12, t_dim)
            # if the input sequence length is shorter than the original audioset (10s), then cut the positional embedding
#             if t_dim < 101:
#                 print(new_pos_embed.shape)
#                 new_pos_embed = new_pos_embed[:, :, :, 50 - int(t_dim/2): 50 - int(t_dim/2) + t_dim]
#                 print(new_pos_embed.shape)
#             # otherwise interpolate
#             else:
#                 new_pos_embed = torch.nn.functional.interpolate(new_pos_embed, size=(12, t_dim), mode='bilinear')
            new_pos_embed = new_pos_embed.reshape(1, 768, num_patches).transpose(1, 2)
            self.v.pos_embed = nn.Parameter(torch.cat([self.v.pos_embed[:, :2, :].detach(), new_pos_embed], dim=1))

    def get_shape(self, fstride, tstride, input_fdim=128, input_tdim=128):
        test_input = torch.randn(1, 1, input_fdim, input_tdim)
        test_proj = nn.Conv2d(1, self.original_embedding_dim, kernel_size=(16, 16), stride=(fstride, tstride))
        test_out = test_proj(test_input)
        f_dim = test_out.shape[2]
        t_dim = test_out.shape[3]
        #print(f_dim,t_dim)
        return f_dim, t_dim

    @autocast()
    def forward(self, x):
        """
        :param x: the input spectrogram, expected shape: (batch_size, time_frame_num, frequency_bins), e.g., (12, 1024, 128)
        :return: prediction
        """
        # expect input x = (batch_size, time_frame_num, frequency_bins), e.g., (12, 1024, 128)
        x = x.unsqueeze(1)
        x = x.transpose(2, 3)

        B = x.shape[0]
        x = self.v.patch_embed(x)
        cls_tokens = self.v.cls_token.expand(B, -1, -1)
        dist_token = self.v.dist_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, dist_token, x), dim=1)
        x = x + self.v.pos_embed
        x = self.v.pos_drop(x)
        for blk in self.v.blocks:
            x = blk(x)
        x = self.v.norm(x)
        x = (x[:, 0] + x[:, 1]) / 2

        x = self.mlp_head(x)
        return x

if __name__ == '__main__':
    input_tdim = 256
    ast_mdl = ASTModel(input_tdim=input_tdim)
    # input a batch of 10 spectrograms, each with 256 time frames and 128 frequency bins
    test_input = torch.rand([10, input_tdim, 128])
    test_output = ast_mdl(test_input)
    # with label_dim=1 (this file's default), output should be in shape [10, 1]
    print(test_output.shape)

    input_tdim = 512
    ast_mdl = ASTModel(input_tdim=input_tdim,label_dim=50, audioset_pretrain=True)
    # input a batch of 10 spectrogram, each with 512 time frames and 128 frequency bins
    test_input = torch.rand([10, input_tdim, 128])
    test_output = ast_mdl(test_input)
    # output should be in shape [10, 50], i.e., 10 samples, each with prediction of 50 classes.
    print(test_output.shape)

YuanGongND commented on August 16, 2024

If you keep the original ast_models.py unchanged and run it, will you get an error? I just ran it and didn't get an error. The second example is similar to your case, right?

    input_tdim = 512
    ast_mdl = ASTModel(input_tdim=input_tdim,label_dim=50, audioset_pretrain=True)
    # input a batch of 10 spectrogram, each with 512 time frames and 128 frequency bins
    test_input = torch.rand([10, input_tdim, 128])
    test_output = ast_mdl(test_input)
    # output should be in shape [10, 50], i.e., 10 samples, each with prediction of 50 classes.
    print(test_output.shape)
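As a cross-check on the expected [10, 50] output, here is a sketch that traces tensor shapes through the forward pass posted above. It only does shape arithmetic and does not build the model; it assumes the base384 embedding width of 768 and the default 16×16 patches with stride 10. `ast_forward_shapes` is a hypothetical helper.

```python
def ast_forward_shapes(batch, input_tdim, input_fdim, label_dim,
                       fstride=10, tstride=10, patch=16, embed_dim=768):
    """Trace tensor shapes through ASTModel.forward for one input size."""
    f_dim = (input_fdim - patch) // fstride + 1
    t_dim = (input_tdim - patch) // tstride + 1
    n_patches = f_dim * t_dim
    return [
        ("input (batch, time, freq)", (batch, input_tdim, input_fdim)),
        ("unsqueeze(1).transpose(2, 3)", (batch, 1, input_fdim, input_tdim)),
        ("patch_embed", (batch, n_patches, embed_dim)),
        ("cat cls + dist tokens", (batch, n_patches + 2, embed_dim)),
        ("(x[:, 0] + x[:, 1]) / 2", (batch, embed_dim)),
        ("mlp_head", (batch, label_dim)),
    ]


for name, shape in ast_forward_shapes(10, 512, 128, 50):
    print(f"{name:32s} {shape}")
```

The positional embedding added after the token concatenation must have exactly `n_patches + 2` rows, which is why input_tdim has to be fixed at construction time.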

YuanGongND commented on August 16, 2024

And I think switching the order of

audio_model = torch.nn.DataParallel(audio_model)
audio_model.load_state_dict(sd, strict=False)

at line 128 could lead to the issue. Can you tell me the error message if you don't switch the order?
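A likely reason the order matters (an assumption, not confirmed in the thread): `torch.nn.DataParallel` stores the wrapped model as its `module` attribute, so its state_dict keys carry a `module.` prefix, and the original code wraps the model before loading, which suggests the checkpoint keys have that prefix too. With `strict=False`, keys that do not match are skipped silently instead of raising. The sketch below works on plain key lists; the key names and both helpers are hypothetical.

```python
def match_report(checkpoint_keys, model_keys):
    """Mimic what load_state_dict(strict=False) would do with two key sets.

    Returns (matched, missing, unexpected), each sorted.
    """
    ckpt, model = set(checkpoint_keys), set(model_keys)
    return sorted(ckpt & model), sorted(model - ckpt), sorted(ckpt - model)


def strip_prefix(keys, prefix="module."):
    """Drop a DataParallel-style prefix so keys match an unwrapped model."""
    return [k[len(prefix):] if k.startswith(prefix) else k for k in keys]


# Hypothetical key names, for illustration only.
ckpt_keys = ["module.v.pos_embed", "module.v.cls_token", "module.mlp_head.1.weight"]
plain_model_keys = ["v.pos_embed", "v.cls_token", "mlp_head.1.weight"]

matched, missing, unexpected = match_report(ckpt_keys, plain_model_keys)
print(len(matched))  # 0: nothing is loaded, and strict=False does not complain

matched, missing, unexpected = match_report(strip_prefix(ckpt_keys), plain_model_keys)
print(len(matched))  # 3
```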

H-Liu1997 commented on August 16, 2024

@YuanGongND
Sorry for the inconvenience.
You are right: switching the order led to the issue. Everything works if I leave the code unchanged and just call ASTModel(input_tdim=input_tdim,label_dim=50, audioset_pretrain=True)

  • the issue was (probably) caused by my changing every 1024 to 128 in the __init__ function.
  • after that, I got a DDP error at line 128, so I switched the order, which led to another issue.

Now the code works well. Thank you for your help, and sorry again for the inconvenience.

YuanGongND commented on August 16, 2024

Thanks for the clarification, and great to know the code works.

sreenivasaupadhyaya commented on August 16, 2024

Hi @YuanGongND and @H-Liu1997,
I have the same use case as yours, and I tried the code below as suggested by @YuanGongND. The error disappears if I set audioset_pretrain=False.

    input_tdim = 100
    input_fdim = 64
    ast_mdl = ASTModel(input_tdim=input_tdim, input_fdim=input_fdim, label_dim=50, audioset_pretrain=True, imagenet_pretrain=True)
    # input a batch of 10 spectrograms, each with 100 time frames and 64 frequency bins
    test_input = torch.rand([10, input_tdim, input_fdim])
    test_output = ast_mdl(test_input)
    # output should be in shape [10, 50], i.e., 10 samples, each with prediction of 50 classes.
    print(test_output.shape)

and I got the following error:

[screenshot of the error message]

Any help is appreciated. Thanks.

sreenivasaupadhyaya commented on August 16, 2024

This issue is solved. The problem was a device mismatch: the input was on the CPU, and I had to move it to the GPU to fix it.
However, this problem did not occur with the ImageNet-pretrained model.
