Comments (3)
It seems that you are using audio_model.mlp_heads rather than audio_model.module.mlp_heads for classification, which indicates your audio_model is not a torch.nn.DataParallel object. If so, not only your mlp_heads but also all other parts of the AST model should be in the audio_model dict rather than the audio_model.module dict.
One thing you could try is setting strict=True when you load the model and seeing what it reports.
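As a rough illustration of the point above, the sketch below shows how a checkpoint saved from a DataParallel-wrapped model fails to match a plain model, and how strict=True surfaces that. TinyModel is a hypothetical stand-in, not the AST model; only the mlp_head attribute name and the "module." key prefix are taken from this thread.

```python
import torch
import torch.nn as nn


class TinyModel(nn.Module):
    """Hypothetical stand-in for the real model (assumption, not AST code)."""

    def __init__(self):
        super().__init__()
        self.mlp_head = nn.Sequential(nn.LayerNorm(8), nn.Linear(8, 4))


# A state dict taken from a DataParallel wrapper has "module."-prefixed keys.
wrapped = nn.DataParallel(TinyModel())
sd = wrapped.state_dict()

plain = TinyModel()

# strict=False skips every mismatched key without raising,
# so the plain model keeps its random initialization.
result = plain.load_state_dict(sd, strict=False)
print("missing:", result.missing_keys)
print("unexpected:", result.unexpected_keys)

# strict=True raises a RuntimeError that lists the mismatched keys.
try:
    plain.load_state_dict(sd, strict=True)
    strict_error = None
except RuntimeError as err:
    strict_error = str(err)
    print(strict_error)
```

If the keys reported as missing and unexpected differ only by the "module." prefix, the checkpoint and the model disagree about whether DataParallel is in use.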
from ast.
Hi there,
I am not sure whether this is a training/test mismatch or a model loading issue. Are you using our code and getting test mAPs every epoch during training, and do those not match the test mAP you get when you load the model and run inference separately? If that is the case, is there a difference in data loading (especially the norm stats) between these two processes?
Or do you only get mAP on the training set during training, and that doesn't match the test mAP? I think it is normal for the training and test mAPs to differ.
I don't see an issue with your model loading, provided you also saved your model as a DataParallel object.
-Yuan
This issue was solved by initializing the MLP head correctly using:
# sd is presumably the state dict loaded from the DataParallel checkpoint
with torch.no_grad():
    self.mlp_head[0].weight = nn.Parameter(sd["module.mlp_head.0.weight"])
    self.mlp_head[0].bias = nn.Parameter(sd["module.mlp_head.0.bias"])
    self.mlp_head[1].weight = nn.Parameter(sd["module.mlp_head.1.weight"])
    self.mlp_head[1].bias = nn.Parameter(sd["module.mlp_head.1.bias"])
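A possible alternative to assigning each parameter by hand, sketched under the assumption that the only mismatch is the "module." prefix added by DataParallel: strip the prefix from every key and load with strict=True, so any remaining mismatch raises an error. TinyModel and strip_module_prefix are hypothetical names used for illustration, not part of the AST code.

```python
import torch
import torch.nn as nn


class TinyModel(nn.Module):
    """Hypothetical stand-in for the real model (assumption, not AST code)."""

    def __init__(self):
        super().__init__()
        self.mlp_head = nn.Sequential(nn.LayerNorm(8), nn.Linear(8, 4))


def strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with a leading prefix removed from each key."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


# Simulate a checkpoint saved from a DataParallel-wrapped model.
source = nn.DataParallel(TinyModel())
sd = source.state_dict()

# Load into a plain (unwrapped) model; strict=True fails on any leftover mismatch.
target = TinyModel()
target.load_state_dict(strip_module_prefix(sd), strict=True)

heads_match = torch.equal(
    target.mlp_head[1].weight, source.module.mlp_head[1].weight
)
print("head weights copied:", heads_match)
```

The other direction should also work: wrap the model in torch.nn.DataParallel before calling load_state_dict, so the model's keys carry the same prefix as the checkpoint.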