Comments (14)
It is a bug, but setting input_fdim=128 should solve the problem, and I think that could also improve the performance.
from ast.
Hello @YuanGongND,
The issue with our data is that we do not have enough audio to use 128 n_mels for the spectrogram: we end up with mel filterbanks whose values are all zero. Is there any other workaround for this?
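One common way to end up with all-zero mel filters is asking for more mel bands than the FFT resolution can support. Whether that is the cause here is not confirmed in this thread. The plain-Python sketch below only illustrates the effect; the 16 kHz sample rate, the HTK-style mel scale, and all helper names are assumptions, and this is not the filterbank code AST itself uses.

```python
import math

def hz_to_mel(hz):
    # HTK-style mel scale (an assumption for this sketch)
    return 2595.0 * math.log10(1.0 + hz / 700.0)

def mel_to_hz(mel):
    return 700.0 * (10.0 ** (mel / 2595.0) - 1.0)

def count_empty_mel_filters(n_mels, n_fft, sample_rate=16000):
    """Count triangular mel filters that give zero weight to every FFT bin."""
    n_bins = n_fft // 2 + 1
    max_mel = hz_to_mel(sample_rate / 2)
    # n_mels + 2 edge frequencies, equally spaced on the mel scale
    edges_hz = [mel_to_hz(max_mel * i / (n_mels + 1)) for i in range(n_mels + 2)]
    bin_hz = sample_rate / n_fft  # FFT bin k sits at k * bin_hz
    empty = 0
    for m in range(1, n_mels + 1):
        lo, center, hi = edges_hz[m - 1], edges_hz[m], edges_hz[m + 1]
        has_weight = False
        for k in range(n_bins):
            f = k * bin_hz
            rising = (f - lo) / (center - lo)
            falling = (hi - f) / (hi - center)
            if min(rising, falling) > 0.0:
                has_weight = True
                break
        if not has_weight:
            empty += 1
    return empty

print(count_empty_mel_filters(128, 256))   # short window: some filters are empty
print(count_empty_mel_filters(128, 4096))  # long window: 0
```

Under these assumed settings, a short analysis window leaves the narrow low-frequency filters without any FFT bin inside them, so the window length and FFT size may be worth checking before reducing the number of mel bins.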
I am not sure why the mel filterbanks are all zero. But I think it is worth trying input_fdim=128 with audioset_pretrain=True. The AudioSet-pretrained model was trained with 128 bins, so it might not generalize well to 64-bin input anyway. Otherwise, you can just use the ImageNet-pretrained model. I would suggest trying both and comparing the results.
Hi, I also got this size-mismatch issue when setting input_tdim != 1024. Just running ast_models.py reproduces the error.
Hi,
How did you initialize the AST model? You need to specify input_tdim when you initialize the AST model, and your actual input length needs to match input_tdim.
-Yuan
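To see why the actual input length has to match input_tdim, it helps to count the patches. The sketch below follows what get_shape in ast_models.py measures with a dummy Conv2d (16x16 kernel, no padding), but applies the standard convolution output-size formula directly. The helper names are hypothetical.

```python
def patch_grid(input_fdim, input_tdim, fstride=10, tstride=10, patch=16):
    """Return (f_dim, t_dim): the number of patches along frequency and time
    for a patch x patch convolution without padding."""
    f_dim = (input_fdim - patch) // fstride + 1
    t_dim = (input_tdim - patch) // tstride + 1
    return f_dim, t_dim

def sequence_length(input_fdim, input_tdim, **kwargs):
    """Number of patches plus the two extra tokens (cls and distillation)."""
    f_dim, t_dim = patch_grid(input_fdim, input_tdim, **kwargs)
    return f_dim * t_dim + 2

print(patch_grid(128, 1024))       # (12, 101)
print(sequence_length(128, 1024))  # 1214
print(sequence_length(128, 256))   # 302
```

A model built with input_tdim=1024 holds a positional embedding for 1214 tokens, so a 256-frame input, which produces 302 tokens, cannot be added to it in the forward pass.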
Yeah, thank you for your reply!
I specified input_tdim = 256 with imagenet_pretrain and audioset_pretrain both set to True. After the following modification it works on my machine.
new_pos_embed = self.v.pos_embed[:, 2:, :].detach().reshape(1, num_patches, 768).transpose(1, 2).reshape(1, 768, 12, t_dim)
# if the input sequence length is larger than the original audioset (10s), then cut the positional embedding
# if t_dim < 101:
# print(new_pos_embed.shape)
# new_pos_embed = new_pos_embed[:, :, :, 50 - int(t_dim/2): 50 - int(t_dim/2) + t_dim]
# print(new_pos_embed.shape)
# # otherwise interpolate
# else:
# new_pos_embed = torch.nn.functional.interpolate(new_pos_embed, size=(12, t_dim), mode='bilinear')
new_pos_embed = new_pos_embed.reshape(1, 768, num_patches).transpose(1, 2)
self.v.pos_embed = nn.Parameter(torch.cat([self.v.pos_embed[:, :2, :].detach(), new_pos_embed], dim=1))
Hi @H-Liu1997,
Your solution can avoid the error, but would also cause a performance drop. I don't suggest commenting out the positional embedding adaptation code.
Could you share your code of initializing the AST model and forward pass (specifically your input shape)?
-Yuan
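For reference, the adaptation step that was commented out decides which part of the AudioSet positional embedding to reuse. Below is a minimal plain-Python sketch of the time-axis decision, using the constants from the snippet above (101 time patches, centre index 50). The function name is hypothetical, and the real code operates on tensors rather than index tuples.

```python
def time_axis_plan(t_dim, pretrained_t=101):
    """Decide how to adapt the time axis of the pretrained positional embedding."""
    if t_dim < pretrained_t:
        center = pretrained_t // 2       # 50 for 101 patches
        start = center - t_dim // 2
        return ("crop", start, start + t_dim)   # keep columns [start, end)
    return ("interpolate", pretrained_t, t_dim)  # stretch 101 columns to t_dim

print(time_axis_plan(12))   # ('crop', 44, 56)
print(time_axis_plan(25))   # ('crop', 38, 63)
print(time_axis_plan(150))  # ('interpolate', 101, 150)
```

For input_tdim=128 (12 time patches) this keeps columns 44 to 55 from the middle of the pretrained embedding, rather than reinterpreting the tensor with a plain reshape.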
Could you share your code of initializing the AST model and forward pass (specifically your input shape)?
@YuanGongND Sure! Thank you very much for your help!
- I tested the input shape batch_size*128*1024 and it worked well.
- Then, I wanted to train on custom data with input shape batch_size*128*128 and use the pretrained weights.
- I only copied ast_models.py and call ASTModel(label_dim=1, fstride=10, tstride=10, input_fdim=128, input_tdim=128, imagenet_pretrain=True, audioset_pretrain=True, model_size='base384', verbose=True) for initialization.
- I replaced all 1024 with 128, commented out the positional embedding code, and changed the order of
# will raise load_state_dict error on my machine
audio_model = torch.nn.DataParallel(audio_model)
audio_model.load_state_dict(sd, strict=False)
# to
audio_model.load_state_dict(sd, strict=False)
audio_model = torch.nn.DataParallel(audio_model)
- The full code I used is as follows:
# -*- coding: utf-8 -*-
# @Time : 6/10/21 5:04 PM
# @Author : Yuan Gong
# @Affiliation : Massachusetts Institute of Technology
# @Email : [email protected]
# @File : ast_models.py

import torch
import torch.nn as nn
from torch.cuda.amp import autocast
import os
import wget
# os.environ['TORCH_HOME'] = '../../pretrained_models'
import timm
from timm.models.layers import to_2tuple,trunc_normal_

# override the timm package to relax the input shape constraint.
class PatchEmbed(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()

        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x

class ASTModel(nn.Module):
    """
    The AST model.
    :param label_dim: the label dimension, i.e., the number of total classes, it is 527 for AudioSet, 50 for ESC-50, and 35 for speechcommands v2-35
    :param fstride: the stride of patch spliting on the frequency dimension, for 16*16 patchs, fstride=16 means no overlap, fstride=10 means overlap of 6
    :param tstride: the stride of patch spliting on the time dimension, for 16*16 patchs, tstride=16 means no overlap, tstride=10 means overlap of 6
    :param input_fdim: the number of frequency bins of the input spectrogram
    :param input_tdim: the number of time frames of the input spectrogram
    :param imagenet_pretrain: if use ImageNet pretrained model
    :param audioset_pretrain: if use full AudioSet and ImageNet pretrained model
    :param model_size: the model size of AST, should be in [tiny224, small224, base224, base384], base224 and base 384 are same model, but are trained differently during ImageNet pretraining.
    """
    def __init__(self, label_dim=1, fstride=10, tstride=10, input_fdim=128, input_tdim=128, imagenet_pretrain=True, audioset_pretrain=True, model_size='base384', verbose=True):

        super(ASTModel, self).__init__()
        # assert timm.__version__ == '0.4.5', 'Please use timm == 0.4.5, the code might not be compatible with newer versions.'

        if verbose == True:
            print('---------------AST Model Summary---------------')
            print('ImageNet pretraining: {:s}, AudioSet pretraining: {:s}'.format(str(imagenet_pretrain),str(audioset_pretrain)))

        # override timm input shape restriction
        timm.models.vision_transformer.PatchEmbed = PatchEmbed

        # if AudioSet pretraining is not used (but ImageNet pretraining may still apply)
        if audioset_pretrain == False:
            if model_size == 'tiny224':
                self.v = timm.create_model('vit_deit_tiny_distilled_patch16_224', pretrained=imagenet_pretrain)
            elif model_size == 'small224':
                self.v = timm.create_model('vit_deit_small_distilled_patch16_224', pretrained=imagenet_pretrain)
            elif model_size == 'base224':
                self.v = timm.create_model('vit_deit_base_distilled_patch16_224', pretrained=imagenet_pretrain)
            elif model_size == 'base384':
                self.v = timm.create_model('vit_deit_base_distilled_patch16_384', pretrained=imagenet_pretrain)
            else:
                raise Exception('Model size must be one of tiny224, small224, base224, base384.')
            self.original_num_patches = self.v.patch_embed.num_patches
            self.oringal_hw = int(self.original_num_patches ** 0.5)
            self.original_embedding_dim = self.v.pos_embed.shape[2]
            self.mlp_head = nn.Sequential(nn.LayerNorm(self.original_embedding_dim), nn.Linear(self.original_embedding_dim, label_dim))

            # automatcially get the intermediate shape
            f_dim, t_dim = self.get_shape(fstride, tstride, input_fdim, input_tdim)
            num_patches = f_dim * t_dim
            self.v.patch_embed.num_patches = num_patches
            if verbose == True:
                print('frequncey stride={:d}, time stride={:d}'.format(fstride, tstride))
                print('number of patches={:d}'.format(num_patches))

            # the linear projection layer
            new_proj = torch.nn.Conv2d(1, self.original_embedding_dim, kernel_size=(16, 16), stride=(fstride, tstride))
            if imagenet_pretrain == True:
                new_proj.weight = torch.nn.Parameter(torch.sum(self.v.patch_embed.proj.weight, dim=1).unsqueeze(1))
                new_proj.bias = self.v.patch_embed.proj.bias
            self.v.patch_embed.proj = new_proj

            # the positional embedding
            if imagenet_pretrain == True:
                # get the positional embedding from deit model, skip the first two tokens (cls token and distillation token), reshape it to original 2D shape (24*24).
                new_pos_embed = self.v.pos_embed[:, 2:, :].detach().reshape(1, self.original_num_patches, self.original_embedding_dim).transpose(1, 2).reshape(1, self.original_embedding_dim, self.oringal_hw, self.oringal_hw)
                # cut (from middle) or interpolate the second dimension of the positional embedding
                if t_dim <= self.oringal_hw:
                    new_pos_embed = new_pos_embed[:, :, :, int(self.oringal_hw / 2) - int(t_dim / 2): int(self.oringal_hw / 2) - int(t_dim / 2) + t_dim]
                else:
                    new_pos_embed = torch.nn.functional.interpolate(new_pos_embed, size=(self.oringal_hw, t_dim), mode='bilinear')
                # cut (from middle) or interpolate the first dimension of the positional embedding
                if f_dim <= self.oringal_hw:
                    new_pos_embed = new_pos_embed[:, :, int(self.oringal_hw / 2) - int(f_dim / 2): int(self.oringal_hw / 2) - int(f_dim / 2) + f_dim, :]
                else:
                    new_pos_embed = torch.nn.functional.interpolate(new_pos_embed, size=(f_dim, t_dim), mode='bilinear')
                # flatten the positional embedding
                new_pos_embed = new_pos_embed.reshape(1, self.original_embedding_dim, num_patches).transpose(1,2)
                # concatenate the above positional embedding with the cls token and distillation token of the deit model.
                self.v.pos_embed = nn.Parameter(torch.cat([self.v.pos_embed[:, :2, :].detach(), new_pos_embed], dim=1))
            else:
                # if not use imagenet pretrained model, just randomly initialize a learnable positional embedding
                # TODO can use sinusoidal positional embedding instead
                new_pos_embed = nn.Parameter(torch.zeros(1, self.v.patch_embed.num_patches + 2, self.original_embedding_dim))
                self.v.pos_embed = new_pos_embed
                trunc_normal_(self.v.pos_embed, std=.02)

        # now load a model that is pretrained on both ImageNet and AudioSet
        elif audioset_pretrain == True:
            if audioset_pretrain == True and imagenet_pretrain == False:
                raise ValueError('currently model pretrained on only audioset is not supported, please set imagenet_pretrain = True to use audioset pretrained model.')
            if model_size != 'base384':
                raise ValueError('currently only has base384 AudioSet pretrained model.')
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            if os.path.exists('../../Datasets/checkpoints/audioset_10_10_0.4593.pth') == False:
                # this model performs 0.4593 mAP on the audioset eval set
                audioset_mdl_url = 'https://www.dropbox.com/s/cv4knew8mvbrnvq/audioset_0.4593.pth?dl=1'
                wget.download(audioset_mdl_url, out='../../pretrained_models/audioset_10_10_0.4593.pth')
            sd = torch.load('../../Datasets/checkpoints/audioset_10_10_0.4593.pth', map_location=device)
            audio_model = ASTModel(label_dim=1, fstride=10, tstride=10, input_fdim=128, input_tdim=128, imagenet_pretrain=False, audioset_pretrain=False, model_size='base384', verbose=False)
            audio_model.load_state_dict(sd, strict=False)
            audio_model = torch.nn.DataParallel(audio_model)
            self.v = audio_model.module.v
            self.original_embedding_dim = self.v.pos_embed.shape[2]
            self.mlp_head = nn.Sequential(nn.LayerNorm(self.original_embedding_dim), nn.Linear(self.original_embedding_dim, label_dim))

            f_dim, t_dim = self.get_shape(fstride, tstride, input_fdim, input_tdim)
            num_patches = f_dim * t_dim
            self.v.patch_embed.num_patches = num_patches
            if verbose == True:
                print('frequncey stride={:d}, time stride={:d}'.format(fstride, tstride))
                print('t_dim={:d}, f_dim={:d}'.format(t_dim, f_dim))
                print('number of patches={:d}'.format(num_patches))
            # print(self.v.pos_embed[:, 2:, :].shape)
            new_pos_embed = self.v.pos_embed[:, 2:, :].detach().reshape(1, num_patches, 768).transpose(1, 2).reshape(1, 768, 12, t_dim)
            # if the input sequence length is larger than the original audioset (10s), then cut the positional embedding
            # if t_dim < 101:
            #     print(new_pos_embed.shape)
            #     new_pos_embed = new_pos_embed[:, :, :, 50 - int(t_dim/2): 50 - int(t_dim/2) + t_dim]
            #     print(new_pos_embed.shape)
            # # otherwise interpolate
            # else:
            #     new_pos_embed = torch.nn.functional.interpolate(new_pos_embed, size=(12, t_dim), mode='bilinear')
            new_pos_embed = new_pos_embed.reshape(1, 768, num_patches).transpose(1, 2)
            self.v.pos_embed = nn.Parameter(torch.cat([self.v.pos_embed[:, :2, :].detach(), new_pos_embed], dim=1))

    def get_shape(self, fstride, tstride, input_fdim=128, input_tdim=128):
        test_input = torch.randn(1, 1, input_fdim, input_tdim)
        test_proj = nn.Conv2d(1, self.original_embedding_dim, kernel_size=(16, 16), stride=(fstride, tstride))
        test_out = test_proj(test_input)
        f_dim = test_out.shape[2]
        t_dim = test_out.shape[3]
        #print(f_dim,t_dim)
        return f_dim, t_dim

    @autocast()
    def forward(self, x):
        """
        :param x: the input spectrogram, expected shape: (batch_size, time_frame_num, frequency_bins), e.g., (12, 1024, 128)
        :return: prediction
        """
        # expect input x = (batch_size, time_frame_num, frequency_bins), e.g., (12, 1024, 128)
        x = x.unsqueeze(1)
        x = x.transpose(2, 3)

        B = x.shape[0]
        x = self.v.patch_embed(x)
        cls_tokens = self.v.cls_token.expand(B, -1, -1)
        dist_token = self.v.dist_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, dist_token, x), dim=1)
        x = x + self.v.pos_embed
        x = self.v.pos_drop(x)
        for blk in self.v.blocks:
            x = blk(x)
        x = self.v.norm(x)
        x = (x[:, 0] + x[:, 1]) / 2

        x = self.mlp_head(x)
        return x

if __name__ == '__main__':
    input_tdim = 256
    ast_mdl = ASTModel(input_tdim=input_tdim)
    # input a batch of 10 spectrogram, each with 100 time frames and 128 frequency bins
    test_input = torch.rand([10, input_tdim, 128])
    test_output = ast_mdl(test_input)
    # output should be in shape [10, 527], i.e., 10 samples, each with prediction of 527 classes.
    print(test_output.shape)

    input_tdim = 512
    ast_mdl = ASTModel(input_tdim=input_tdim,label_dim=50, audioset_pretrain=True)
    # input a batch of 10 spectrogram, each with 512 time frames and 128 frequency bins
    test_input = torch.rand([10, input_tdim, 128])
    test_output = ast_mdl(test_input)
    # output should be in shape [10, 50], i.e., 10 samples, each with prediction of 50 classes.
    print(test_output.shape)
If you keep the original ast_models.py unchanged and run it, will you get an error? I just ran it and didn't get an error. The second example is similar to your case, right?
input_tdim = 512
ast_mdl = ASTModel(input_tdim=input_tdim,label_dim=50, audioset_pretrain=True)
# input a batch of 10 spectrogram, each with 512 time frames and 128 frequency bins
test_input = torch.rand([10, input_tdim, 128])
test_output = ast_mdl(test_input)
# output should be in shape [10, 50], i.e., 10 samples, each with prediction of 50 classes.
print(test_output.shape)
And I think switching the order of
audio_model = torch.nn.DataParallel(audio_model)
audio_model.load_state_dict(sd, strict=False)
at line 128 could lead to the issue. Can you tell me the error message if you don't switch the order?
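One reason the order can matter: torch.nn.DataParallel keeps the wrapped model under .module, so the parameter names in its state dict start with "module.". Assuming the AudioSet checkpoint was saved from a wrapped model (the original loading order suggests this, but the thread does not state it), loading it into an unwrapped model with strict=False would match no keys and skip them all without raising an error. The plain-Python sketch below mimics only the key matching; the key names and helper functions are hypothetical.

```python
def matching_keys(model_keys, checkpoint):
    """Keys that a non-strict load would actually be able to copy."""
    return sorted(set(model_keys) & set(checkpoint))

def strip_module_prefix(checkpoint, prefix="module."):
    """Return a copy of the checkpoint with the leading prefix removed from each key."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in checkpoint.items()
    }

# Hypothetical keys of a checkpoint saved from a DataParallel-wrapped model
checkpoint = {"module.v.pos_embed": 0, "module.v.cls_token": 1}
# Hypothetical parameter names of the same model without the wrapper
model_keys = ["v.pos_embed", "v.cls_token"]

print(matching_keys(model_keys, checkpoint))                       # []
print(matching_keys(model_keys, strip_module_prefix(checkpoint)))  # ['v.cls_token', 'v.pos_embed']
```

Either wrapping the model before loading, as the original code does, or stripping the prefix from the checkpoint keys lines the names up again. Checking the missing and unexpected keys reported by load_state_dict is a quick way to confirm that the weights were really loaded.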
@YuanGongND
Sorry for the inconvenience.
You are right, switching the order led to the issue. Everything works well if I leave the code unchanged and just call ASTModel(input_tdim=input_tdim,label_dim=50, audioset_pretrain=True)
- The issue is (maybe) because I changed all 1024 to 128 in the __init__ function.
- After that, I got a DDP error at line 128, so I switched the order, which led to another issue.
Now the code works well. Thank you for your help, and sorry again for the inconvenience.
Thanks for the clarification, and great to know the code works.
Hi @YuanGongND and @H-Liu1997,
I have the same use case as yours, and I tried the code below as suggested by @YuanGongND. The error disappears if I set audioset_pretrain=False.
`input_tdim = 100
input_fdim = 64
ast_mdl = ASTModel(input_tdim=input_tdim, input_fdim=input_fdim, label_dim=50, audioset_pretrain=True, imagenet_pretrain=True)
# input a batch of 10 spectrograms, each with 100 time frames and 64 frequency bins
test_input = torch.rand([10, input_tdim, input_fdim])
test_output = ast_mdl(test_input)
# output should be in shape [10, 50], i.e., 10 samples, each with prediction of 50 classes.
print(test_output.shape)`
and I got the following error!
Any help is appreciated. Thanks.
This issue is solved. The mismatch occurred because the input was on the CPU; I had to force the input onto the GPU to solve it.
However, this problem did not occur with the ImageNet-pretrained model.