labram's People

Contributors

935963004, itsaphel

labram's Issues

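About the ablation experiment with and without the codebook in LaBraM

Congratulations! I was glad to see the insightful paper your team published, and I would like to ask about one of the experimental details. The paper states that using the codebook or not does not make much difference to the results. I am curious about this experiment: how was the comparison with and without the codebook carried out?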

Classification accuracy calculation error in fine-tuning in the multi-class case

Hi @935963004, thanks for open-sourcing this excellent work!

As the title says, line 117 of the fine-tuning code computes the training accuracy as class_acc = (output.max(-1)[-1] == targets).float().mean(). The output tensor has shape (64, 3) and the target has shape (64, 1). After max(-1)[-1], the prediction has shape (64), so calling == directly triggers PyTorch broadcasting:

  • pred is treated as [1, 64] and expanded to [64, 64]
  • target is expanded from [64, 1] to [64, 64]

As a result, the two expanded tensors are compared element-wise: every pred[j] is compared with every target[i, 0], which yields a [64, 64] tensor of all pairwise comparisons instead of 64 per-sample comparisons. The mean of that tensor is not the accuracy.

The correct approach is either to replace the expression with
(output.max(-1)[-1] == targets.squeeze()).float().mean()
or to compute the metric the same way as in validation and testing:
class_acc = utils.get_metrics(output.detach().cpu().numpy(), targets.detach().cpu().numpy(), ['accuracy'], is_binary)['accuracy']

Here is the minimal reproducible code:

import torch
pred = torch.tensor([0, 0, 2, 1, 0, 2, 2, 2, 1, 1, 0, 2, 0, 2, 0, 0, 1, 1, 2, 0, 2, 0, 0, 2,
0, 0, 0, 0, 0, 2, 2, 0, 1, 2, 0, 2, 0, 2, 1, 2, 0, 0, 1, 0, 2, 1, 2, 0,
2, 1, 1, 2, 1, 1, 2, 2, 0, 1, 0, 0, 1, 1, 0, 2])
target = torch.tensor([[0],[0],[2],[1],[0],[2],[2], [2],[1],[1],[0],[2],[0],[2],[0],[1],[1], [1], [2], [0], [2], [1], [0], [2],[0],[0], [0],[0], [0], [2],[2],[0],[1],[2],
[0],[2],[0],[2],[1],[2],[0],[0],[1],[0],[2],[1], [2],[0],[2],[1],[1],[2], [1], [1],[2],[2],[0],[1],[0],[0],[1],[1],[0],[2]])

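# pred has shape [64] and target has shape [64, 1], so == broadcasts to [64, 64].
comparison = (pred == target).float()
# Squeezing target to shape [64] gives the intended per-sample comparison.
comparison_squeeze = (pred == target.squeeze()).float()

accuracy = comparison.mean()
accuracy_squeeze = comparison_squeeze.mean()

accuracy.item(), accuracy_squeeze.item()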

(0.3408203125, 0.96875)
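
The same pitfall can be reproduced, and guarded against, in NumPy, whose broadcasting rules PyTorch follows. The sketch below is illustrative only: top1_accuracy is a hypothetical helper, not a function from the LaBraM repository, and it assumes logits of shape (N, C) with integer class targets of shape (N,) or (N, 1).

```python
import numpy as np


def top1_accuracy(logits: np.ndarray, targets: np.ndarray) -> float:
    """Top-1 accuracy for logits of shape (N, C) and targets of shape (N,) or (N, 1)."""
    pred = logits.argmax(axis=-1)      # shape (N,)
    targets = targets.reshape(-1)      # flatten (N, 1) -> (N,)
    if pred.shape != targets.shape:
        raise ValueError(f"shape mismatch: {pred.shape} vs {targets.shape}")
    return float((pred == targets).mean())


# Four samples, three classes; the predicted classes are 0, 1, 2, 0.
logits = np.array([[2.0, 0.1, 0.3],
                   [0.2, 1.5, 0.1],
                   [0.1, 0.2, 3.0],
                   [1.0, 0.5, 0.2]])
targets = np.array([[0], [1], [2], [1]])   # shape (4, 1); the last sample is wrong

broadcast = logits.argmax(axis=-1) == targets   # silently broadcasts to shape (4, 4)
print(broadcast.shape, broadcast.mean())        # (4, 4) 0.3125
print(top1_accuracy(logits, targets))           # 0.75
```

In PyTorch, targets.squeeze(-1) or targets.view(-1) plays the role of reshape(-1) here; unlike a bare squeeze(), both keep the batch dimension when the batch size is 1.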

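Running the fine-tuning code on the public CHB-MIT dataset?

Hello, when I run your fine-tuning code on the CHB-MIT dataset (binary classification), the training accuracy gets very high and the loss decreases, but the evaluation metrics on the validation and test sets stay at 0. Have you run into this problem, and if so, how did you solve it?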

Using .cnt files for classification tasks

If the dataset consists of .cnt files, what code needs to be used or written to use the model? I see a couple of preprocessing Python scripts: make_h5dataset_for_pretrain.py and dataset.py. Can I just modify the folder path and run them directly, or do I need to write my own code? If I need to write my own code, what do I need to pay attention to? I've seen that a lot of people have questions about how to run the model on their own datasets; could you provide some help? Thank you.

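Dataset loading problem when training the model

Hello. Following the code you provided, I built several dataset files for training vqnsp, each saved as a separate .hdf5 file, and I load them with the ShockDataset() class provided in the repository. However, I found that a Sampler created with shuffle=True greatly slows down data loading, so the GPU sits idle waiting for data. Training speed only returns to normal with shuffle=False. I think this comes from how the dataloader samples: with random sampling it has to access data at many different disk locations, whereas with sequential sampling it only reads adjacent data stored physically close together, so sequential loading is faster. But sequential loading does not really match the logic of mini-batch stochastic gradient descent. Have you run into this problem during training, and if so, how did you solve it?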
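
One common compromise between fully random and fully sequential access is to shuffle at two levels: shuffle the order of contiguous blocks of indices, then shuffle within each block. Reads stay mostly local to one region of a file while batches still vary between epochs. The sketch below is not the authors' solution and has not been tested against ShockDataset; block_shuffled_indices and block_size are hypothetical names, and the resulting order is only approximately random, which may or may not be acceptable for a given training run.

```python
import random
from typing import List


def block_shuffled_indices(n_samples: int, block_size: int, seed: int) -> List[int]:
    """Return a permutation of range(n_samples) that is shuffled block by block."""
    rng = random.Random(seed)
    # Split the index range into contiguous blocks (the last one may be shorter).
    blocks = [list(range(start, min(start + block_size, n_samples)))
              for start in range(0, n_samples, block_size)]
    rng.shuffle(blocks)            # randomize the order in which blocks are visited
    order: List[int] = []
    for block in blocks:
        rng.shuffle(block)         # randomize the samples inside each block
        order.extend(block)
    return order


order = block_shuffled_indices(n_samples=10, block_size=4, seed=0)
print(order)
```

Such an index list could be fed to a DataLoader through a custom torch.utils.data.Sampler whose __iter__ yields these indices, regenerated with a new seed each epoch.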

How to deal with datasets with different numbers of channels?

Thank you very much for the code. You mentioned in the code that different time windows need to be provided so that the sequence length of each dataset equals 256. But for datasets with fewer than 64/32 channels, do we need to pad them to 64/32 channels?

RuntimeError: The size of tensor a (341) must match the size of tensor b (286) at non-singleton dimension 1

I ran the model on my modified .cnt dataset and got the following error.
D:\ProgramData\anaconda3\envs\labram\python.exe E:\lab\DL\LaBraM-main\run_class_finetuning.py
Not using distributed mode
Namespace(batch_size=64, epochs=30, update_freq=1, save_ckpt_freq=5, robust_test=None, model='labram_base_patch200_200', qkv_bias=True, rel_pos_bias=True, abs_pos_emb=True, layer_scale_init_value=0.1, input_size=200, drop=0.0, attn_drop_rate=0.0, drop_path=0.1, disable_eval_during_finetuning=False, model_ema=False, model_ema_decay=0.9999, model_ema_force_cpu=False, opt='adamw', opt_eps=1e-08, opt_betas=None, clip_grad=None, momentum=0.9, weight_decay=0.05, weight_decay_end=None, lr=0.0005, layer_decay=0.9, warmup_lr=1e-06, min_lr=1e-06, warmup_epochs=5, warmup_steps=-1, smoothing=0.1, reprob=0.25, remode='pixel', recount=1, resplit=False, finetune='', model_key='model|module', model_prefix='', model_filter_name='gzp', init_scale=0.001, use_mean_pooling=True, disable_weight_decay_on_rel_pos_bias=False, nb_classes=4, output_dir='E:/lab/DL/LaBraM-main/checkpoints/finetune_MI_base', log_dir='E:/lab/DL/LaBraM-main/log/finetune_MI_base', device='cuda', seed=0, resume='', auto_resume=True, save_ckpt=True, start_epoch=0, eval=False, dist_eval=False, num_workers=10, pin_mem=True, world_size=1, local_rank=-1, dist_on_itp=False, dist_url='env://', enable_deepspeed=False, dataset='MI', distributed=False)
2199 399 680
Sampler_train = <torch.utils.data.distributed.DistributedSampler object at 0x000001F93E230D50>
Patch size = 200
Model = NeuralTransformer(
(patch_embed): TemporalConv(
(conv1): Conv2d(1, 8, kernel_size=(1, 15), stride=(1, 8), padding=(0, 7))
(gelu1): GELU(approximate='none')
(norm1): GroupNorm(4, 8, eps=1e-05, affine=True)
(conv2): Conv2d(8, 8, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1))
(gelu2): GELU(approximate='none')
(norm2): GroupNorm(4, 8, eps=1e-05, affine=True)
(conv3): Conv2d(8, 8, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1))
(norm3): GroupNorm(4, 8, eps=1e-05, affine=True)
(gelu3): GELU(approximate='none')
)
(pos_drop): Dropout(p=0.0, inplace=False)
(blocks): ModuleList(
(0): Block(
(norm1): LayerNorm((200,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=200, out_features=600, bias=False)
(q_norm): LayerNorm((20,), eps=1e-06, elementwise_affine=True)
(k_norm): LayerNorm((20,), eps=1e-06, elementwise_affine=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=200, out_features=200, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(drop_path): Identity()
(norm2): LayerNorm((200,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=200, out_features=800, bias=True)
(act): GELU(approximate='none')
(fc2): Linear(in_features=800, out_features=200, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
(1): Block(
(norm1): LayerNorm((200,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=200, out_features=600, bias=False)
(q_norm): LayerNorm((20,), eps=1e-06, elementwise_affine=True)
(k_norm): LayerNorm((20,), eps=1e-06, elementwise_affine=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=200, out_features=200, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(drop_path): DropPath(p=0.00909090880304575)
(norm2): LayerNorm((200,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=200, out_features=800, bias=True)
(act): GELU(approximate='none')
(fc2): Linear(in_features=800, out_features=200, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
(2): Block(
(norm1): LayerNorm((200,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=200, out_features=600, bias=False)
(q_norm): LayerNorm((20,), eps=1e-06, elementwise_affine=True)
(k_norm): LayerNorm((20,), eps=1e-06, elementwise_affine=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=200, out_features=200, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(drop_path): DropPath(p=0.0181818176060915)
(norm2): LayerNorm((200,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=200, out_features=800, bias=True)
(act): GELU(approximate='none')
(fc2): Linear(in_features=800, out_features=200, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
(3): Block(
(norm1): LayerNorm((200,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=200, out_features=600, bias=False)
(q_norm): LayerNorm((20,), eps=1e-06, elementwise_affine=True)
(k_norm): LayerNorm((20,), eps=1e-06, elementwise_affine=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=200, out_features=200, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(drop_path): DropPath(p=0.027272727340459824)
(norm2): LayerNorm((200,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=200, out_features=800, bias=True)
(act): GELU(approximate='none')
(fc2): Linear(in_features=800, out_features=200, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
(4): Block(
(norm1): LayerNorm((200,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=200, out_features=600, bias=False)
(q_norm): LayerNorm((20,), eps=1e-06, elementwise_affine=True)
(k_norm): LayerNorm((20,), eps=1e-06, elementwise_affine=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=200, out_features=200, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(drop_path): DropPath(p=0.036363635212183)
(norm2): LayerNorm((200,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=200, out_features=800, bias=True)
(act): GELU(approximate='none')
(fc2): Linear(in_features=800, out_features=200, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
(5): Block(
(norm1): LayerNorm((200,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=200, out_features=600, bias=False)
(q_norm): LayerNorm((20,), eps=1e-06, elementwise_affine=True)
(k_norm): LayerNorm((20,), eps=1e-06, elementwise_affine=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=200, out_features=200, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(drop_path): DropPath(p=0.045454543083906174)
(norm2): LayerNorm((200,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=200, out_features=800, bias=True)
(act): GELU(approximate='none')
(fc2): Linear(in_features=800, out_features=200, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
(6): Block(
(norm1): LayerNorm((200,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=200, out_features=600, bias=False)
(q_norm): LayerNorm((20,), eps=1e-06, elementwise_affine=True)
(k_norm): LayerNorm((20,), eps=1e-06, elementwise_affine=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=200, out_features=200, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(drop_path): DropPath(p=0.054545458406209946)
(norm2): LayerNorm((200,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=200, out_features=800, bias=True)
(act): GELU(approximate='none')
(fc2): Linear(in_features=800, out_features=200, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
(7): Block(
(norm1): LayerNorm((200,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=200, out_features=600, bias=False)
(q_norm): LayerNorm((20,), eps=1e-06, elementwise_affine=True)
(k_norm): LayerNorm((20,), eps=1e-06, elementwise_affine=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=200, out_features=200, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(drop_path): DropPath(p=0.06363636255264282)
(norm2): LayerNorm((200,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=200, out_features=800, bias=True)
(act): GELU(approximate='none')
(fc2): Linear(in_features=800, out_features=200, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
(8): Block(
(norm1): LayerNorm((200,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=200, out_features=600, bias=False)
(q_norm): LayerNorm((20,), eps=1e-06, elementwise_affine=True)
(k_norm): LayerNorm((20,), eps=1e-06, elementwise_affine=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=200, out_features=200, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(drop_path): DropPath(p=0.0727272778749466)
(norm2): LayerNorm((200,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=200, out_features=800, bias=True)
(act): GELU(approximate='none')
(fc2): Linear(in_features=800, out_features=200, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
(9): Block(
(norm1): LayerNorm((200,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=200, out_features=600, bias=False)
(q_norm): LayerNorm((20,), eps=1e-06, elementwise_affine=True)
(k_norm): LayerNorm((20,), eps=1e-06, elementwise_affine=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=200, out_features=200, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(drop_path): DropPath(p=0.08181818574666977)
(norm2): LayerNorm((200,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=200, out_features=800, bias=True)
(act): GELU(approximate='none')
(fc2): Linear(in_features=800, out_features=200, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
(10): Block(
(norm1): LayerNorm((200,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=200, out_features=600, bias=False)
(q_norm): LayerNorm((20,), eps=1e-06, elementwise_affine=True)
(k_norm): LayerNorm((20,), eps=1e-06, elementwise_affine=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=200, out_features=200, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(drop_path): DropPath(p=0.09090909361839294)
(norm2): LayerNorm((200,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=200, out_features=800, bias=True)
(act): GELU(approximate='none')
(fc2): Linear(in_features=800, out_features=200, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
(11): Block(
(norm1): LayerNorm((200,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=200, out_features=600, bias=False)
(q_norm): LayerNorm((20,), eps=1e-06, elementwise_affine=True)
(k_norm): LayerNorm((20,), eps=1e-06, elementwise_affine=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=200, out_features=200, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(drop_path): DropPath(p=0.10000000149011612)
(norm2): LayerNorm((200,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=200, out_features=800, bias=True)
(act): GELU(approximate='none')
(fc2): Linear(in_features=800, out_features=200, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
)
(norm): Identity()
(fc_norm): LayerNorm((200,), eps=1e-06, elementwise_affine=True)
(head): Linear(in_features=200, out_features=4, bias=True)
)
number of params: 5825540
LR = 0.00050000
Batch size = 64
Update frequent = 1
Number of training examples = 2199
Number of training training per epoch = 34
Assigned values = [0.2541865828329001, 0.2824295364810001, 0.31381059609000006, 0.3486784401000001, 0.3874204890000001, 0.4304672100000001, 0.4782969000000001, 0.531441, 0.5904900000000001, 0.6561, 0.7290000000000001, 0.81, 0.9, 1.0]
Skip weight decay name marked in model: {'time_embed', 'cls_token', 'pos_embed'}
Param groups = {
"layer_0_no_decay": {
"weight_decay": 0.0,
"params": [
"cls_token",
"pos_embed",
"patch_embed.conv1.bias",
"patch_embed.norm1.weight",
"patch_embed.norm1.bias",
"patch_embed.conv2.bias",
"patch_embed.norm2.weight",
"patch_embed.norm2.bias",
"patch_embed.conv3.bias",
"patch_embed.norm3.weight",
"patch_embed.norm3.bias"
],
"lr_scale": 0.2541865828329001
},
"layer_13_no_decay": {
"weight_decay": 0.0,
"params": [
"time_embed",
"fc_norm.weight",
"fc_norm.bias",
"head.bias"
],
"lr_scale": 1.0
},
"layer_0_decay": {
"weight_decay": 0.05,
"params": [
"patch_embed.conv1.weight",
"patch_embed.conv2.weight",
"patch_embed.conv3.weight"
],
"lr_scale": 0.2541865828329001
},
"layer_1_no_decay": {
"weight_decay": 0.0,
"params": [
"blocks.0.gamma_1",
"blocks.0.gamma_2",
"blocks.0.norm1.weight",
"blocks.0.norm1.bias",
"blocks.0.attn.q_bias",
"blocks.0.attn.v_bias",
"blocks.0.attn.q_norm.weight",
"blocks.0.attn.q_norm.bias",
"blocks.0.attn.k_norm.weight",
"blocks.0.attn.k_norm.bias",
"blocks.0.attn.proj.bias",
"blocks.0.norm2.weight",
"blocks.0.norm2.bias",
"blocks.0.mlp.fc1.bias",
"blocks.0.mlp.fc2.bias"
],
"lr_scale": 0.2824295364810001
},
"layer_1_decay": {
"weight_decay": 0.05,
"params": [
"blocks.0.attn.qkv.weight",
"blocks.0.attn.proj.weight",
"blocks.0.mlp.fc1.weight",
"blocks.0.mlp.fc2.weight"
],
"lr_scale": 0.2824295364810001
},
"layer_2_no_decay": {
"weight_decay": 0.0,
"params": [
"blocks.1.gamma_1",
"blocks.1.gamma_2",
"blocks.1.norm1.weight",
"blocks.1.norm1.bias",
"blocks.1.attn.q_bias",
"blocks.1.attn.v_bias",
"blocks.1.attn.q_norm.weight",
"blocks.1.attn.q_norm.bias",
"blocks.1.attn.k_norm.weight",
"blocks.1.attn.k_norm.bias",
"blocks.1.attn.proj.bias",
"blocks.1.norm2.weight",
"blocks.1.norm2.bias",
"blocks.1.mlp.fc1.bias",
"blocks.1.mlp.fc2.bias"
],
"lr_scale": 0.31381059609000006
},
"layer_2_decay": {
"weight_decay": 0.05,
"params": [
"blocks.1.attn.qkv.weight",
"blocks.1.attn.proj.weight",
"blocks.1.mlp.fc1.weight",
"blocks.1.mlp.fc2.weight"
],
"lr_scale": 0.31381059609000006
},
"layer_3_no_decay": {
"weight_decay": 0.0,
"params": [
"blocks.2.gamma_1",
"blocks.2.gamma_2",
"blocks.2.norm1.weight",
"blocks.2.norm1.bias",
"blocks.2.attn.q_bias",
"blocks.2.attn.v_bias",
"blocks.2.attn.q_norm.weight",
"blocks.2.attn.q_norm.bias",
"blocks.2.attn.k_norm.weight",
"blocks.2.attn.k_norm.bias",
"blocks.2.attn.proj.bias",
"blocks.2.norm2.weight",
"blocks.2.norm2.bias",
"blocks.2.mlp.fc1.bias",
"blocks.2.mlp.fc2.bias"
],
"lr_scale": 0.3486784401000001
},
"layer_3_decay": {
"weight_decay": 0.05,
"params": [
"blocks.2.attn.qkv.weight",
"blocks.2.attn.proj.weight",
"blocks.2.mlp.fc1.weight",
"blocks.2.mlp.fc2.weight"
],
"lr_scale": 0.3486784401000001
},
"layer_4_no_decay": {
"weight_decay": 0.0,
"params": [
"blocks.3.gamma_1",
"blocks.3.gamma_2",
"blocks.3.norm1.weight",
"blocks.3.norm1.bias",
"blocks.3.attn.q_bias",
"blocks.3.attn.v_bias",
"blocks.3.attn.q_norm.weight",
"blocks.3.attn.q_norm.bias",
"blocks.3.attn.k_norm.weight",
"blocks.3.attn.k_norm.bias",
"blocks.3.attn.proj.bias",
"blocks.3.norm2.weight",
"blocks.3.norm2.bias",
"blocks.3.mlp.fc1.bias",
"blocks.3.mlp.fc2.bias"
],
"lr_scale": 0.3874204890000001
},
"layer_4_decay": {
"weight_decay": 0.05,
"params": [
"blocks.3.attn.qkv.weight",
"blocks.3.attn.proj.weight",
"blocks.3.mlp.fc1.weight",
"blocks.3.mlp.fc2.weight"
],
"lr_scale": 0.3874204890000001
},
"layer_5_no_decay": {
"weight_decay": 0.0,
"params": [
"blocks.4.gamma_1",
"blocks.4.gamma_2",
"blocks.4.norm1.weight",
"blocks.4.norm1.bias",
"blocks.4.attn.q_bias",
"blocks.4.attn.v_bias",
"blocks.4.attn.q_norm.weight",
"blocks.4.attn.q_norm.bias",
"blocks.4.attn.k_norm.weight",
"blocks.4.attn.k_norm.bias",
"blocks.4.attn.proj.bias",
"blocks.4.norm2.weight",
"blocks.4.norm2.bias",
"blocks.4.mlp.fc1.bias",
"blocks.4.mlp.fc2.bias"
],
"lr_scale": 0.4304672100000001
},
"layer_5_decay": {
"weight_decay": 0.05,
"params": [
"blocks.4.attn.qkv.weight",
"blocks.4.attn.proj.weight",
"blocks.4.mlp.fc1.weight",
"blocks.4.mlp.fc2.weight"
],
"lr_scale": 0.4304672100000001
},
"layer_6_no_decay": {
"weight_decay": 0.0,
"params": [
"blocks.5.gamma_1",
"blocks.5.gamma_2",
"blocks.5.norm1.weight",
"blocks.5.norm1.bias",
"blocks.5.attn.q_bias",
"blocks.5.attn.v_bias",
"blocks.5.attn.q_norm.weight",
"blocks.5.attn.q_norm.bias",
"blocks.5.attn.k_norm.weight",
"blocks.5.attn.k_norm.bias",
"blocks.5.attn.proj.bias",
"blocks.5.norm2.weight",
"blocks.5.norm2.bias",
"blocks.5.mlp.fc1.bias",
"blocks.5.mlp.fc2.bias"
],
"lr_scale": 0.4782969000000001
},
"layer_6_decay": {
"weight_decay": 0.05,
"params": [
"blocks.5.attn.qkv.weight",
"blocks.5.attn.proj.weight",
"blocks.5.mlp.fc1.weight",
"blocks.5.mlp.fc2.weight"
],
"lr_scale": 0.4782969000000001
},
"layer_7_no_decay": {
"weight_decay": 0.0,
"params": [
"blocks.6.gamma_1",
"blocks.6.gamma_2",
"blocks.6.norm1.weight",
"blocks.6.norm1.bias",
"blocks.6.attn.q_bias",
"blocks.6.attn.v_bias",
"blocks.6.attn.q_norm.weight",
"blocks.6.attn.q_norm.bias",
"blocks.6.attn.k_norm.weight",
"blocks.6.attn.k_norm.bias",
"blocks.6.attn.proj.bias",
"blocks.6.norm2.weight",
"blocks.6.norm2.bias",
"blocks.6.mlp.fc1.bias",
"blocks.6.mlp.fc2.bias"
],
"lr_scale": 0.531441
},
"layer_7_decay": {
"weight_decay": 0.05,
"params": [
"blocks.6.attn.qkv.weight",
"blocks.6.attn.proj.weight",
"blocks.6.mlp.fc1.weight",
"blocks.6.mlp.fc2.weight"
],
"lr_scale": 0.531441
},
"layer_8_no_decay": {
"weight_decay": 0.0,
"params": [
"blocks.7.gamma_1",
"blocks.7.gamma_2",
"blocks.7.norm1.weight",
"blocks.7.norm1.bias",
"blocks.7.attn.q_bias",
"blocks.7.attn.v_bias",
"blocks.7.attn.q_norm.weight",
"blocks.7.attn.q_norm.bias",
"blocks.7.attn.k_norm.weight",
"blocks.7.attn.k_norm.bias",
"blocks.7.attn.proj.bias",
"blocks.7.norm2.weight",
"blocks.7.norm2.bias",
"blocks.7.mlp.fc1.bias",
"blocks.7.mlp.fc2.bias"
],
"lr_scale": 0.5904900000000001
},
"layer_8_decay": {
"weight_decay": 0.05,
"params": [
"blocks.7.attn.qkv.weight",
"blocks.7.attn.proj.weight",
"blocks.7.mlp.fc1.weight",
"blocks.7.mlp.fc2.weight"
],
"lr_scale": 0.5904900000000001
},
"layer_9_no_decay": {
"weight_decay": 0.0,
"params": [
"blocks.8.gamma_1",
"blocks.8.gamma_2",
"blocks.8.norm1.weight",
"blocks.8.norm1.bias",
"blocks.8.attn.q_bias",
"blocks.8.attn.v_bias",
"blocks.8.attn.q_norm.weight",
"blocks.8.attn.q_norm.bias",
"blocks.8.attn.k_norm.weight",
"blocks.8.attn.k_norm.bias",
"blocks.8.attn.proj.bias",
"blocks.8.norm2.weight",
"blocks.8.norm2.bias",
"blocks.8.mlp.fc1.bias",
"blocks.8.mlp.fc2.bias"
],
"lr_scale": 0.6561
},
"layer_9_decay": {
"weight_decay": 0.05,
"params": [
"blocks.8.attn.qkv.weight",
"blocks.8.attn.proj.weight",
"blocks.8.mlp.fc1.weight",
"blocks.8.mlp.fc2.weight"
],
"lr_scale": 0.6561
},
"layer_10_no_decay": {
"weight_decay": 0.0,
"params": [
"blocks.9.gamma_1",
"blocks.9.gamma_2",
"blocks.9.norm1.weight",
"blocks.9.norm1.bias",
"blocks.9.attn.q_bias",
"blocks.9.attn.v_bias",
"blocks.9.attn.q_norm.weight",
"blocks.9.attn.q_norm.bias",
"blocks.9.attn.k_norm.weight",
"blocks.9.attn.k_norm.bias",
"blocks.9.attn.proj.bias",
"blocks.9.norm2.weight",
"blocks.9.norm2.bias",
"blocks.9.mlp.fc1.bias",
"blocks.9.mlp.fc2.bias"
],
"lr_scale": 0.7290000000000001
},
"layer_10_decay": {
"weight_decay": 0.05,
"params": [
"blocks.9.attn.qkv.weight",
"blocks.9.attn.proj.weight",
"blocks.9.mlp.fc1.weight",
"blocks.9.mlp.fc2.weight"
],
"lr_scale": 0.7290000000000001
},
"layer_11_no_decay": {
"weight_decay": 0.0,
"params": [
"blocks.10.gamma_1",
"blocks.10.gamma_2",
"blocks.10.norm1.weight",
"blocks.10.norm1.bias",
"blocks.10.attn.q_bias",
"blocks.10.attn.v_bias",
"blocks.10.attn.q_norm.weight",
"blocks.10.attn.q_norm.bias",
"blocks.10.attn.k_norm.weight",
"blocks.10.attn.k_norm.bias",
"blocks.10.attn.proj.bias",
"blocks.10.norm2.weight",
"blocks.10.norm2.bias",
"blocks.10.mlp.fc1.bias",
"blocks.10.mlp.fc2.bias"
],
"lr_scale": 0.81
},
"layer_11_decay": {
"weight_decay": 0.05,
"params": [
"blocks.10.attn.qkv.weight",
"blocks.10.attn.proj.weight",
"blocks.10.mlp.fc1.weight",
"blocks.10.mlp.fc2.weight"
],
"lr_scale": 0.81
},
"layer_12_no_decay": {
"weight_decay": 0.0,
"params": [
"blocks.11.gamma_1",
"blocks.11.gamma_2",
"blocks.11.norm1.weight",
"blocks.11.norm1.bias",
"blocks.11.attn.q_bias",
"blocks.11.attn.v_bias",
"blocks.11.attn.q_norm.weight",
"blocks.11.attn.q_norm.bias",
"blocks.11.attn.k_norm.weight",
"blocks.11.attn.k_norm.bias",
"blocks.11.attn.proj.bias",
"blocks.11.norm2.weight",
"blocks.11.norm2.bias",
"blocks.11.mlp.fc1.bias",
"blocks.11.mlp.fc2.bias"
],
"lr_scale": 0.9
},
"layer_12_decay": {
"weight_decay": 0.05,
"params": [
"blocks.11.attn.qkv.weight",
"blocks.11.attn.proj.weight",
"blocks.11.mlp.fc1.weight",
"blocks.11.mlp.fc2.weight"
],
"lr_scale": 0.9
},
"layer_13_decay": {
"weight_decay": 0.05,
"params": [
"head.weight"
],
"lr_scale": 1.0
}
}
Optimizer config: {'lr': 0.0005, 'weight_decay': 0.0, 'eps': 1e-08}
Use step level LR scheduler!
Set warmup steps = 170
Set warmup steps = 0
Max WD = 0.0500000, Min WD = 0.0500000
criterion = LabelSmoothingCrossEntropy()
Auto resume checkpoint:
Start training for 30 epochs
Traceback (most recent call last):
File "E:\lab\DL\LaBraM-main\run_class_finetuning.py", line 582, in <module>
main(opts, ds_init)
File "E:\lab\DL\LaBraM-main\run_class_finetuning.py", line 496, in main
train_stats = train_one_epoch(
^^^^^^^^^^^^^^^^
File "E:\lab\DL\LaBraM-main\engine_for_finetuning.py", line 77, in train_one_epoch
loss, output = train_class_batch(
^^^^^^^^^^^^^^^^^^
File "E:\lab\DL\LaBraM-main\engine_for_finetuning.py", line 19, in train_class_batch
outputs = model(samples, ch_names)
^^^^^^^^^^^^^^^^^^^^^^^^
File "D:\ProgramData\anaconda3\envs\labram\Lib\site-packages\torch\nn\modules\module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "D:\ProgramData\anaconda3\envs\labram\Lib\site-packages\torch\nn\modules\module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "E:\lab\DL\LaBraM-main\modeling_finetune.py", line 395, in forward
x = self.forward_features(x, input_chans=input_chans, return_patch_tokens=return_patch_tokens, return_all_tokens=return_all_tokens, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "E:\lab\DL\LaBraM-main\modeling_finetune.py", line 362, in forward_features
x = x + pos_embed
~~^~~~~~~~~~~
RuntimeError: The size of tensor a (341) must match the size of tensor b (286) at non-singleton dimension 1

Process finished with exit code 1
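
The two sizes in the error message can be decoded with a little arithmetic. Assuming, as the parameter names cls_token, pos_embed and time_embed in the log suggest, that each token sequence is one class token plus one token per (channel, time patch) pair, tensor a (the input x) holds 341 - 1 = 340 such tokens and tensor b (pos_embed) holds 286 - 1 = 285. The sketch below lists the patch counts that divide both numbers. The helper name is hypothetical, the upper bound of 16 patches is taken from the toy example in another issue on this page, and the interpretation is an assumption that is not confirmed anywhere in the thread.

```python
from typing import List, Tuple


def shared_layouts(tokens_a: int, tokens_b: int, max_patches: int = 16) -> List[Tuple[int, int, int]]:
    """List (n_patches, channels_a, channels_b) layouts consistent with both token counts.

    Assumes each sequence is 1 class token + channels * n_patches tokens.
    """
    a, b = tokens_a - 1, tokens_b - 1
    return [(p, a // p, b // p)
            for p in range(1, max_patches + 1)
            if a % p == 0 and b % p == 0]


print(shared_layouts(341, 286))   # [(1, 340, 285), (5, 68, 57)]
```

Under this assumption, the more plausible reading is 5 time patches, with 68 channels in the input tensor but only 57 channel positions selected for the positional embedding. It may therefore be worth checking that the ch_names list passed to the model matches the channels actually present in the data.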

Some help with reproduction

Hello @935963004,

I would like to start by thanking you for your work; I think it is a fundamental and necessary contribution to EEG decoding.

I am trying to understand and run your code, but some things are not working, and I would like to ask for your assistance, starting with a toy example.

import torch
from torch import nn

from modeling_finetune import NeuralTransformer

# As mentioned in the meeting, the expected input shape is:
# (batch size, channels, time // patch_size, patch_size)

in_chans = 1  # Does not work if in_chans is not 1; issue with temporal_embedding.
batch_size = 1
patch_size = 200
n_time_points_patched = 16  # Maximum number of patches; the value is hard-coded
# in the model
EEG_size = 1600

# Generating an empty vector just to get the output.
X = torch.zeros(batch_size, in_chans, n_time_points_patched, patch_size)
# Everything is default
model = NeuralTransformer(
    EEG_size=EEG_size,
    patch_size=patch_size,
    in_chans=in_chans,
    out_chans=8,
    num_classes=1000,
    embed_dim=200,
    depth=12,
    num_heads=10,
    mlp_ratio=4.,
    qkv_bias=False,
    qk_norm=None,
    qk_scale=None,
    drop_rate=0.,
    attn_drop_rate=0.,
    drop_path_rate=0.,
    norm_layer=nn.LayerNorm,
    init_values=0, # default value is not working, changed from None to zero.
    use_abs_pos_emb=False,  # Not working
    use_rel_pos_bias=False, 
    use_shared_rel_pos_bias=False,
    use_mean_pooling=True,
    init_scale=0.001,
)

with torch.no_grad():
    y_pred = model(X)

My questions are:

  • How can it be made to work with any number of channels?
  • How do we solve the issue with the positional embedding? And what about the temporal embedding?
  • How can the model be adapted to accept input of shape
    "(batch, channel, time_steps)"?

My naive intuition was that changing in_chans should be enough because of the TemporalConv module, but it is not.

FYI @LemonFace0309, @jonxuxu and @shahbuland, @RashikShahjahan

All the best!
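
On the last question, the input layout described in the comments of the toy example, (batch size, channels, time // patch_size, patch_size), can be obtained from a (batch, channel, time_steps) array with a plain reshape once the number of time steps is a multiple of the patch size. The sketch below uses NumPy; to_patches is a hypothetical helper, not part of the repository, and it simply drops trailing samples that do not fill a whole patch, which is one possible choice among several. It does not address the channel-count or embedding questions.

```python
import numpy as np


def to_patches(x: np.ndarray, patch_size: int = 200) -> np.ndarray:
    """Reshape (batch, channels, time_steps) into (batch, channels, n_patches, patch_size)."""
    batch, channels, time_steps = x.shape
    n_patches = time_steps // patch_size
    if n_patches == 0:
        raise ValueError("recording is shorter than one patch")
    usable = n_patches * patch_size            # drop the incomplete trailing patch
    return x[:, :, :usable].reshape(batch, channels, n_patches, patch_size)


x = np.arange(2 * 3 * 850, dtype=np.float32).reshape(2, 3, 850)
patches = to_patches(x, patch_size=200)
print(patches.shape)   # (2, 3, 4, 200)
```

The same slicing and reshape call works unchanged on a torch.Tensor.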

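Question about the optimal number of pre-training epochs

Dear author,

First of all, thank you very much for sharing such excellent work. I have a small question that I hope you can answer.
Was a validation set held out during pre-training to evaluate how well the model was pre-trained? If not, and pre-training simply used all of the pre-training data, how do you decide after how many epochs to stop in order to obtain the best pre-trained model?

Best wishes!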

TUAB fine-tuning replication: unexpected accuracies and loss values

Hi @935963004,

Thanks so much for your work with this project - really exciting stuff. I was trying out fine-tuning on the TUAB dataset and came across some unexpected numbers. It could be something on my end and I'll provide update(s) as I dig into it further but wanted to get a thread started. Thanks for any help!

I created the TUAB dataset with make_TUAB.py and then ran fine-tuning with the recommended settings in the README (except for using 1 GPU, increasing the batch size and the number of workers, and removing the --dist_eval flag):

!OMP_NUM_THREADS=1 torchrun --nnodes=1 --nproc_per_node=1 run_class_finetuning.py \
        --output_dir ./checkpoints/finetune_tuab_base/ \
        --log_dir ./log/finetune_tuab_base \
        --model labram_base_patch200_200 \
        --finetune ./checkpoints/labram-base.pth \
        --weight_decay 0.05 \
        --batch_size 512 \
        --lr 5e-4 \
        --update_freq 1 \
        --warmup_epochs 5 \
        --epochs 50 \
        --layer_decay 0.65 \
        --drop_path 0.1 \
        --save_ckpt_freq 5 \
        --disable_rel_pos_bias \
        --abs_pos_emb \
        --dataset TUAB \
        --disable_qkv_bias \
        --seed 0 \
        --num_workers 12

Below is some sample output I'm seeing:

...
Use step level LR scheduler!
Set warmup steps = 605
Set warmup steps = 0
Max WD = 0.0500000, Min WD = 0.0500000
criterion = BCEWithLogitsLoss()
Auto resume checkpoint: 
Start training for 50 epochs
[W reducer.cpp:1346] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration,  which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator())
Warning: NaN or Inf found in input tensor.
Epoch: [0]  [  0/121]  eta: 3 days, 16:01:13  lr: 0.000000  min_lr: 0.000000  loss: 0.6931 (0.6931)  class_acc: 1.0000 (1.0000)  loss_scale: 32768.0000 (32768.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: inf (inf)  time: 2618.7900  data: 2615.6816  max mem: 34391
Warning: NaN or Inf found in input tensor.
Warning: NaN or Inf found in input tensor.
Epoch: [0]  [ 10/121]  eta: 7:21:42  lr: 0.000008  min_lr: 0.000000  loss: 0.6929 (0.6927)  class_acc: 1.0000 (0.9998)  loss_scale: 8192.0000 (11170.9091)  weight_decay: 0.0500 (0.0500)  grad_norm: 6.2091 (inf)  time: 238.7637  data: 237.7895  max mem: 34433
Epoch: [0]  [ 20/121]  eta: 3:31:06  lr: 0.000017  min_lr: 0.000000  loss: 0.6915 (0.6911)  class_acc: 1.0000 (0.9999)  loss_scale: 8192.0000 (9752.3810)  weight_decay: 0.0500 (0.0500)  grad_norm: 6.1851 (inf)  time: 0.7382  data: 0.0030  max mem: 34433
Epoch: [0]  [ 30/121]  eta: 2:09:22  lr: 0.000025  min_lr: 0.000000  loss: 0.6854 (0.6872)  class_acc: 1.0000 (0.9999)  loss_scale: 8192.0000 (9249.0323)  weight_decay: 0.0500 (0.0500)  grad_norm: 6.0339 (inf)  time: 0.9033  data: 0.1924  max mem: 34433
Epoch: [0]  [ 40/121]  eta: 1:27:24  lr: 0.000033  min_lr: 0.000000  loss: 0.6681 (0.6793)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8991.2195)  weight_decay: 0.0500 (0.0500)  grad_norm: 5.7410 (inf)  time: 1.0505  data: 0.3342  max mem: 34433
Epoch: [0]  [ 50/121]  eta: 1:01:48  lr: 0.000041  min_lr: 0.000000  loss: 0.6368 (0.6673)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8834.5098)  weight_decay: 0.0500 (0.0500)  grad_norm: 5.5797 (inf)  time: 0.9733  data: 0.2568  max mem: 34433
Epoch: [0]  [ 60/121]  eta: 0:44:31  lr: 0.000050  min_lr: 0.000000  loss: 0.5923 (0.6513)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8729.1803)  weight_decay: 0.0500 (0.0500)  grad_norm: 5.6199 (inf)  time: 0.8588  data: 0.1478  max mem: 34433
Epoch: [0]  [ 70/121]  eta: 0:32:04  lr: 0.000058  min_lr: 0.000000  loss: 0.5391 (0.6321)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8653.5211)  weight_decay: 0.0500 (0.0500)  grad_norm: 5.5495 (inf)  time: 0.7452  data: 0.0358  max mem: 34433
Epoch: [0]  [ 80/121]  eta: 0:22:40  lr: 0.000066  min_lr: 0.000000  loss: 0.4821 (0.6105)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8596.5432)  weight_decay: 0.0500 (0.0500)  grad_norm: 5.2271 (inf)  time: 0.7910  data: 0.0819  max mem: 34433
Epoch: [0]  [ 90/121]  eta: 0:15:19  lr: 0.000075  min_lr: 0.000000  loss: 0.4252 (0.5875)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8552.0879)  weight_decay: 0.0500 (0.0500)  grad_norm: 4.8080 (inf)  time: 1.0040  data: 0.2929  max mem: 34433
Epoch: [0]  [100/121]  eta: 0:09:22  lr: 0.000083  min_lr: 0.000000  loss: 0.3711 (0.5638)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8516.4356)  weight_decay: 0.0500 (0.0500)  grad_norm: 4.3451 (inf)  time: 0.9471  data: 0.2359  max mem: 34433
Epoch: [0]  [110/121]  eta: 0:04:28  lr: 0.000091  min_lr: 0.000000  loss: 0.3215 (0.5400)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8487.2072)  weight_decay: 0.0500 (0.0500)  grad_norm: 3.8768 (inf)  time: 0.7331  data: 0.0248  max mem: 34433
Epoch: [0]  [120/121]  eta: 0:00:22  lr: 0.000099  min_lr: 0.000000  loss: 0.2762 (0.5166)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8462.8099)  weight_decay: 0.0500 (0.0500)  grad_norm: 3.4341 (inf)  time: 0.7080  data: 0.0001  max mem: 34433
Epoch: [0] Total time: 0:45:20 (22.4855 s / it)
Averaged stats: lr: 0.000099  min_lr: 0.000000  loss: 0.2762 (0.5166)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8462.8099)  weight_decay: 0.0500 (0.0500)  grad_norm: 3.4341 (inf)
Val:  [ 0/20]  eta: 2:50:41  loss: 0.2346 (0.2346)  accuracy: 0.0000 (0.0000)  balanced_accuracy: 0.0000 (0.0000)  pr_auc: 0.0000 (0.0000)  roc_auc: 0.0000 (0.0000)  time: 512.0632  data: 511.3996  max mem: 34433
Val:  [10/20]  eta: 0:09:48  loss: 0.2346 (0.2346)  accuracy: 0.0000 (0.0000)  balanced_accuracy: 0.0000 (0.0000)  pr_auc: 0.0000 (0.0000)  roc_auc: 0.0000 (0.0000)  time: 58.8147  data: 58.4929  max mem: 34433
Val:  [19/20]  eta: 0:01:01  loss: 0.2346 (0.2346)  accuracy: 0.0000 (0.0000)  balanced_accuracy: 0.0000 (0.0000)  pr_auc: 0.0000 (0.0000)  roc_auc: 0.0000 (0.0000)  time: 61.3067  data: 61.0130  max mem: 34433
Val: Total time: 0:20:26 (61.3127 s / it)
* loss 0.235
Accuracy of the network on the 14718 val EEG: 0.00%
Test:  [ 0/11]  eta: 1:34:08  loss: 0.2346 (0.2346)  accuracy: 0.0000 (0.0000)  balanced_accuracy: 0.0000 (0.0000)  pr_auc: 0.0000 (0.0000)  roc_auc: 0.0000 (0.0000)  time: 513.4699  data: 513.1569  max mem: 34433
Test:  [10/11]  eta: 0:00:48  loss: 0.2346 (0.2346)  accuracy: 0.0000 (0.0000)  balanced_accuracy: 0.0000 (0.0000)  pr_auc: 0.0000 (0.0000)  roc_auc: 0.0000 (0.0000)  time: 48.5245  data: 48.2462  max mem: 34433
Test: Total time: 0:08:53 (48.5333 s / it)
* loss 0.235
Accuracy of the network on the 14718 test EEG: 0.00%
Max accuracy val: 0.00%, max accuracy test: 0.00%
Epoch: [1]  [  0/121]  eta: 0:25:29  lr: 0.000100  min_lr: 0.000000  loss: 0.2362 (0.2362)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 3.0189 (3.0189)  time: 12.6372  data: 11.8822  max mem: 34435
Epoch: [1]  [ 10/121]  eta: 0:03:19  lr: 0.000108  min_lr: 0.000000  loss: 0.2185 (0.2189)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 2.8253 (2.8273)  time: 1.7946  data: 1.0805  max mem: 34435
Epoch: [1]  [ 20/121]  eta: 0:02:23  lr: 0.000117  min_lr: 0.000000  loss: 0.1991 (0.2033)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 2.6067 (2.6502)  time: 0.8571  data: 0.1461  max mem: 34435
Epoch: [1]  [ 30/121]  eta: 0:01:58  lr: 0.000125  min_lr: 0.000000  loss: 0.1711 (0.1894)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 2.2782 (2.4876)  time: 1.0270  data: 0.3148  max mem: 34435
Epoch: [1]  [ 40/121]  eta: 0:01:40  lr: 0.000133  min_lr: 0.000000  loss: 0.1473 (0.1768)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 1.9917 (2.3386)  time: 1.0426  data: 0.3300  max mem: 34435
Epoch: [1]  [ 50/121]  eta: 0:01:25  lr: 0.000142  min_lr: 0.000001  loss: 0.1273 (0.1656)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 1.7436 (2.2026)  time: 1.0637  data: 0.3517  max mem: 34435
Epoch: [1]  [ 60/121]  eta: 0:01:12  lr: 0.000150  min_lr: 0.000001  loss: 0.1103 (0.1555)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 1.5298 (2.0785)  time: 1.1007  data: 0.3882  max mem: 34435
Epoch: [1]  [ 70/121]  eta: 0:00:57  lr: 0.000158  min_lr: 0.000001  loss: 0.0962 (0.1463)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 1.3468 (1.9653)  time: 0.9096  data: 0.1978  max mem: 34435
Epoch: [1]  [ 80/121]  eta: 0:00:46  lr: 0.000166  min_lr: 0.000001  loss: 0.0842 (0.1381)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 1.1898 (1.8619)  time: 0.9218  data: 0.2102  max mem: 34435
Epoch: [1]  [ 90/121]  eta: 0:00:34  lr: 0.000175  min_lr: 0.000001  loss: 0.0741 (0.1306)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 1.0556 (1.7673)  time: 1.1428  data: 0.4293  max mem: 34435
Epoch: [1]  [100/121]  eta: 0:00:23  lr: 0.000183  min_lr: 0.000001  loss: 0.0654 (0.1238)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.9381 (1.6807)  time: 1.1746  data: 0.4622  max mem: 34435
Epoch: [1]  [110/121]  eta: 0:00:12  lr: 0.000191  min_lr: 0.000001  loss: 0.0580 (0.1176)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.8384 (1.6012)  time: 1.2109  data: 0.4988  max mem: 34435
Epoch: [1]  [120/121]  eta: 0:00:01  lr: 0.000200  min_lr: 0.000001  loss: 0.0518 (0.1120)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.7521 (1.5281)  time: 0.9662  data: 0.2559  max mem: 34435
Epoch: [1] Total time: 0:02:13 (1.1073 s / it)
Averaged stats: lr: 0.000200  min_lr: 0.000001  loss: 0.0518 (0.1120)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.7521 (1.5281)
Val:  [ 0/20]  eta: 0:05:09  loss: 0.0461 (0.0461)  accuracy: 0.0000 (0.0000)  balanced_accuracy: 0.0000 (0.0000)  pr_auc: 0.0000 (0.0000)  roc_auc: 0.0000 (0.0000)  time: 15.4873  data: 15.0481  max mem: 34435
Val:  [10/20]  eta: 0:00:16  loss: 0.0461 (0.0461)  accuracy: 0.0000 (0.0000)  balanced_accuracy: 0.0000 (0.0000)  pr_auc: 0.0000 (0.0000)  roc_auc: 0.0000 (0.0000)  time: 1.6699  data: 1.3682  max mem: 34435
Val:  [19/20]  eta: 0:00:01  loss: 0.0461 (0.0461)  accuracy: 0.0000 (0.0000)  balanced_accuracy: 0.0000 (0.0000)  pr_auc: 0.0000 (0.0000)  roc_auc: 0.0000 (0.0000)  time: 1.3343  data: 1.0552  max mem: 34435
Val: Total time: 0:00:26 (1.3377 s / it)
* loss 0.046
Accuracy of the network on the 14718 val EEG: 0.00%
Test:  [ 0/11]  eta: 0:02:29  loss: 0.0461 (0.0461)  accuracy: 0.0000 (0.0000)  balanced_accuracy: 0.0000 (0.0000)  pr_auc: 0.0000 (0.0000)  roc_auc: 0.0000 (0.0000)  time: 13.6329  data: 13.3226  max mem: 34435
Test:  [10/11]  eta: 0:00:01  loss: 0.0461 (0.0461)  accuracy: 0.0000 (0.0000)  balanced_accuracy: 0.0000 (0.0000)  pr_auc: 0.0000 (0.0000)  roc_auc: 0.0000 (0.0000)  time: 1.4797  data: 1.2113  max mem: 34435
Test: Total time: 0:00:16 (1.4852 s / it)
* loss 0.046
Accuracy of the network on the 14718 test EEG: 0.00%
Max accuracy val: 0.00%, max accuracy test: 0.00%
Epoch: [2]  [  0/121]  eta: 0:21:41  lr: 0.000200  min_lr: 0.000001  loss: 0.0463 (0.0463)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.6762 (0.6762)  time: 10.7563  data: 10.0010  max mem: 34435
Epoch: [2]  [ 10/121]  eta: 0:03:00  lr: 0.000209  min_lr: 0.000001  loss: 0.0437 (0.0438)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.6413 (0.6420)  time: 1.6236  data: 0.9095  max mem: 34435
Epoch: [2]  [ 20/121]  eta: 0:02:06  lr: 0.000217  min_lr: 0.000001  loss: 0.0410 (0.0416)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.6035 (0.6112)  time: 0.7737  data: 0.0637  max mem: 34435
Epoch: [2]  [ 30/121]  eta: 0:01:43  lr: 0.000225  min_lr: 0.000001  loss: 0.0370 (0.0396)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.5469 (0.5830)  time: 0.8707  data: 0.1589  max mem: 34435
Epoch: [2]  [ 40/121]  eta: 0:01:27  lr: 0.000233  min_lr: 0.000001  loss: 0.0335 (0.0378)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.4963 (0.5570)  time: 0.9061  data: 0.1927  max mem: 34435
Epoch: [2]  [ 50/121]  eta: 0:01:14  lr: 0.000242  min_lr: 0.000001  loss: 0.0304 (0.0361)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.4528 (0.5330)  time: 0.9294  data: 0.2166  max mem: 34435
Epoch: [2]  [ 60/121]  eta: 0:01:02  lr: 0.000250  min_lr: 0.000001  loss: 0.0277 (0.0345)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.4138 (0.5108)  time: 0.9026  data: 0.1909  max mem: 34435
Epoch: [2]  [ 70/121]  eta: 0:00:49  lr: 0.000258  min_lr: 0.000001  loss: 0.0252 (0.0331)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.3788 (0.4902)  time: 0.7824  data: 0.0718  max mem: 34435
Epoch: [2]  [ 80/121]  eta: 0:00:39  lr: 0.000267  min_lr: 0.000001  loss: 0.0231 (0.0317)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.3477 (0.4711)  time: 0.7819  data: 0.0711  max mem: 34435
Epoch: [2]  [ 90/121]  eta: 0:00:29  lr: 0.000275  min_lr: 0.000001  loss: 0.0212 (0.0305)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.3201 (0.4533)  time: 0.8553  data: 0.1451  max mem: 34435
Epoch: [2]  [100/121]  eta: 0:00:19  lr: 0.000283  min_lr: 0.000001  loss: 0.0195 (0.0293)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.2956 (0.4367)  time: 0.8650  data: 0.1550  max mem: 34435
Epoch: [2]  [110/121]  eta: 0:00:10  lr: 0.000291  min_lr: 0.000001  loss: 0.0180 (0.0283)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.2735 (0.4211)  time: 0.8804  data: 0.1692  max mem: 34435
Epoch: [2]  [120/121]  eta: 0:00:00  lr: 0.000300  min_lr: 0.000001  loss: 0.0166 (0.0272)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.2533 (0.4066)  time: 0.7990  data: 0.0883  max mem: 34435
Epoch: [2] Total time: 0:01:51 (0.9207 s / it)
Averaged stats: lr: 0.000300  min_lr: 0.000001  loss: 0.0166 (0.0272)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.2533 (0.4066)
Val:  [ 0/20]  eta: 0:04:56  loss: 0.0154 (0.0154)  accuracy: 0.0000 (0.0000)  balanced_accuracy: 0.0000 (0.0000)  pr_auc: 0.0000 (0.0000)  roc_auc: 0.0000 (0.0000)  time: 14.8285  data: 14.5373  max mem: 34435
Val:  [10/20]  eta: 0:00:15  loss: 0.0154 (0.0154)  accuracy: 0.0000 (0.0000)  balanced_accuracy: 0.0000 (0.0000)  pr_auc: 0.0000 (0.0000)  roc_auc: 0.0000 (0.0000)  time: 1.5975  data: 1.3217  max mem: 34435
Val:  [19/20]  eta: 0:00:01  loss: 0.0154 (0.0154)  accuracy: 0.0000 (0.0000)  balanced_accuracy: 0.0000 (0.0000)  pr_auc: 0.0000 (0.0000)  roc_auc: 0.0000 (0.0000)  time: 1.2781  data: 1.0134  max mem: 34435
Val: Total time: 0:00:25 (1.2811 s / it)
* loss 0.015
Accuracy of the network on the 14718 val EEG: 0.00%
Test:  [ 0/11]  eta: 0:02:27  loss: 0.0154 (0.0154)  accuracy: 0.0000 (0.0000)  balanced_accuracy: 0.0000 (0.0000)  pr_auc: 0.0000 (0.0000)  roc_auc: 0.0000 (0.0000)  time: 13.4538  data: 13.1486  max mem: 34435
Test:  [10/11]  eta: 0:00:01  loss: 0.0154 (0.0154)  accuracy: 0.0000 (0.0000)  balanced_accuracy: 0.0000 (0.0000)  pr_auc: 0.0000 (0.0000)  roc_auc: 0.0000 (0.0000)  time: 1.4634  data: 1.1955  max mem: 34435
Test: Total time: 0:00:16 (1.4693 s / it)
* loss 0.015
Accuracy of the network on the 14718 test EEG: 0.00%
Max accuracy val: 0.00%, max accuracy test: 0.00%
Epoch: [3]  [  0/121]  eta: 0:22:54  lr: 0.000300  min_lr: 0.000001  loss: 0.0154 (0.0154)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.2353 (0.2353)  time: 11.3583  data: 10.6085  max mem: 34435
Epoch: [3]  [ 10/121]  eta: 0:03:06  lr: 0.000309  min_lr: 0.000001  loss: 0.0148 (0.0148)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.2266 (0.2267)  time: 1.6779  data: 0.9647  max mem: 34435
Epoch: [3]  [ 20/121]  eta: 0:02:11  lr: 0.000317  min_lr: 0.000001  loss: 0.0142 (0.0143)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.2171 (0.2189)  time: 0.7983  data: 0.0875  max mem: 34435
Epoch: [3]  [ 30/121]  eta: 0:01:45  lr: 0.000325  min_lr: 0.000001  loss: 0.0132 (0.0138)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.2025 (0.2116)  time: 0.8703  data: 0.1595  max mem: 34435
Epoch: [3]  [ 40/121]  eta: 0:01:27  lr: 0.000334  min_lr: 0.000001  loss: 0.0122 (0.0133)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.1886 (0.2047)  time: 0.8464  data: 0.1368  max mem: 34435
Epoch: [3]  [ 50/121]  eta: 0:01:13  lr: 0.000342  min_lr: 0.000001  loss: 0.0114 (0.0129)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.1765 (0.1982)  time: 0.8564  data: 0.1469  max mem: 34435
Epoch: [3]  [ 60/121]  eta: 0:01:01  lr: 0.000350  min_lr: 0.000001  loss: 0.0107 (0.0125)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.1654 (0.1920)  time: 0.8801  data: 0.1703  max mem: 34435
Epoch: [3]  [ 70/121]  eta: 0:00:49  lr: 0.000358  min_lr: 0.000001  loss: 0.0100 (0.0121)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.1552 (0.1862)  time: 0.7979  data: 0.0883  max mem: 34435
Epoch: [3]  [ 80/121]  eta: 0:00:39  lr: 0.000367  min_lr: 0.000001  loss: 0.0094 (0.0117)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.1456 (0.1807)  time: 0.7810  data: 0.0717  max mem: 34435
Epoch: [3]  [ 90/121]  eta: 0:00:29  lr: 0.000375  min_lr: 0.000001  loss: 0.0088 (0.0114)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.1370 (0.1755)  time: 0.8510  data: 0.1414  max mem: 34435
Epoch: [3]  [100/121]  eta: 0:00:19  lr: 0.000383  min_lr: 0.000001  loss: 0.0083 (0.0110)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.1290 (0.1706)  time: 0.8608  data: 0.1515  max mem: 34435
Epoch: [3]  [110/121]  eta: 0:00:10  lr: 0.000392  min_lr: 0.000001  loss: 0.0078 (0.0107)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.1217 (0.1659)  time: 0.8725  data: 0.1636  max mem: 34435
Epoch: [3]  [120/121]  eta: 0:00:00  lr: 0.000400  min_lr: 0.000001  loss: 0.0073 (0.0104)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.1148 (0.1614)  time: 0.7904  data: 0.0820  max mem: 34435
Epoch: [3] Total time: 0:01:50 (0.9140 s / it)
Averaged stats: lr: 0.000400  min_lr: 0.000001  loss: 0.0073 (0.0104)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.1148 (0.1614)
Val:  [ 0/20]  eta: 0:04:54  loss: 0.0069 (0.0069)  accuracy: 0.0000 (0.0000)  balanced_accuracy: 0.0000 (0.0000)  pr_auc: 0.0000 (0.0000)  roc_auc: 0.0000 (0.0000)  time: 14.7361  data: 14.4306  max mem: 34435
Val:  [10/20]  eta: 0:00:15  loss: 0.0069 (0.0069)  accuracy: 0.0000 (0.0000)  balanced_accuracy: 0.0000 (0.0000)  pr_auc: 0.0000 (0.0000)  roc_auc: 0.0000 (0.0000)  time: 1.5994  data: 1.3224  max mem: 34435
Val:  [19/20]  eta: 0:00:01  loss: 0.0069 (0.0069)  accuracy: 0.0000 (0.0000)  balanced_accuracy: 0.0000 (0.0000)  pr_auc: 0.0000 (0.0000)  roc_auc: 0.0000 (0.0000)  time: 1.2757  data: 1.0101  max mem: 34435
Val: Total time: 0:00:25 (1.2790 s / it)
* loss 0.007
Accuracy of the network on the 14718 val EEG: 0.00%
Test:  [ 0/11]  eta: 0:02:28  loss: 0.0069 (0.0069)  accuracy: 0.0000 (0.0000)  balanced_accuracy: 0.0000 (0.0000)  pr_auc: 0.0000 (0.0000)  roc_auc: 0.0000 (0.0000)  time: 13.4998  data: 13.2017  max mem: 34435
Test:  [10/11]  eta: 0:00:01  loss: 0.0069 (0.0069)  accuracy: 0.0000 (0.0000)  balanced_accuracy: 0.0000 (0.0000)  pr_auc: 0.0000 (0.0000)  roc_auc: 0.0000 (0.0000)  time: 1.4676  data: 1.2003  max mem: 34435
Test: Total time: 0:00:16 (1.4737 s / it)
* loss 0.007
Accuracy of the network on the 14718 test EEG: 0.00%
Max accuracy val: 0.00%, max accuracy test: 0.00%
Epoch: [4]  [  0/121]  eta: 0:22:45  lr: 0.000401  min_lr: 0.000001  loss: 0.0069 (0.0069)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.1085 (0.1085)  time: 11.2845  data: 10.5294  max mem: 34435
Epoch: [4]  [ 10/121]  eta: 0:03:05  lr: 0.000409  min_lr: 0.000002  loss: 0.0067 (0.0067)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.1053 (0.1055)  time: 1.6719  data: 0.9574  max mem: 34435
Epoch: [4]  [ 20/121]  eta: 0:02:10  lr: 0.000417  min_lr: 0.000002  loss: 0.0065 (0.0065)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.1020 (0.1027)  time: 0.7890  data: 0.0790  max mem: 34435
Epoch: [4]  [ 30/121]  eta: 0:01:45  lr: 0.000425  min_lr: 0.000002  loss: 0.0061 (0.0063)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.0966 (0.1000)  time: 0.8712  data: 0.1616  max mem: 34435
Epoch: [4]  [ 40/121]  eta: 0:01:27  lr: 0.000434  min_lr: 0.000002  loss: 0.0058 (0.0062)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.0916 (0.0974)  time: 0.8668  data: 0.1574  max mem: 34435
Epoch: [4]  [ 50/121]  eta: 0:01:14  lr: 0.000442  min_lr: 0.000002  loss: 0.0055 (0.0060)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.0869 (0.0950)  time: 0.8685  data: 0.1592  max mem: 34435
Epoch: [4]  [ 60/121]  eta: 0:01:01  lr: 0.000450  min_lr: 0.000002  loss: 0.0052 (0.0059)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.0826 (0.0927)  time: 0.8772  data: 0.1675  max mem: 34435
Epoch: [4]  [ 70/121]  eta: 0:00:49  lr: 0.000459  min_lr: 0.000002  loss: 0.0049 (0.0057)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.0785 (0.0904)  time: 0.7934  data: 0.0831  max mem: 34435
Epoch: [4]  [ 80/121]  eta: 0:00:39  lr: 0.000467  min_lr: 0.000002  loss: 0.0047 (0.0056)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.0747 (0.0883)  time: 0.7916  data: 0.0816  max mem: 34435
Epoch: [4]  [ 90/121]  eta: 0:00:29  lr: 0.000475  min_lr: 0.000002  loss: 0.0045 (0.0054)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.0712 (0.0863)  time: 0.8624  data: 0.1533  max mem: 34435
Epoch: [4]  [100/121]  eta: 0:00:19  lr: 0.000483  min_lr: 0.000002  loss: 0.0042 (0.0053)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.0678 (0.0843)  time: 0.8471  data: 0.1383  max mem: 34435
Epoch: [4]  [110/121]  eta: 0:00:10  lr: 0.000492  min_lr: 0.000002  loss: 0.0040 (0.0052)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.0647 (0.0824)  time: 0.8458  data: 0.1372  max mem: 34435
Epoch: [4]  [120/121]  eta: 0:00:00  lr: 0.000500  min_lr: 0.000002  loss: 0.0039 (0.0051)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.0618 (0.0806)  time: 0.7797  data: 0.0708  max mem: 34435
Epoch: [4] Total time: 0:01:50 (0.9124 s / it)
Averaged stats: lr: 0.000500  min_lr: 0.000002  loss: 0.0039 (0.0051)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.0618 (0.0806)
Val:  [ 0/20]  eta: 0:04:55  loss: 0.0037 (0.0037)  accuracy: 0.0000 (0.0000)  balanced_accuracy: 0.0000 (0.0000)  pr_auc: 0.0000 (0.0000)  roc_auc: 0.0000 (0.0000)  time: 14.7555  data: 14.4493  max mem: 34435
Val:  [10/20]  eta: 0:00:15  loss: 0.0037 (0.0037)  accuracy: 0.0000 (0.0000)  balanced_accuracy: 0.0000 (0.0000)  pr_auc: 0.0000 (0.0000)  roc_auc: 0.0000 (0.0000)  time: 1.5907  data: 1.3137  max mem: 34435
Val:  [19/20]  eta: 0:00:01  loss: 0.0037 (0.0037)  accuracy: 0.0000 (0.0000)  balanced_accuracy: 0.0000 (0.0000)  pr_auc: 0.0000 (0.0000)  roc_auc: 0.0000 (0.0000)  time: 1.2756  data: 1.0098  max mem: 34435
Val: Total time: 0:00:25 (1.2788 s / it)
* loss 0.004
Accuracy of the network on the 14718 val EEG: 0.00%
Test:  [ 0/11]  eta: 0:02:27  loss: 0.0037 (0.0037)  accuracy: 0.0000 (0.0000)  balanced_accuracy: 0.0000 (0.0000)  pr_auc: 0.0000 (0.0000)  roc_auc: 0.0000 (0.0000)  time: 13.4220  data: 13.1036  max mem: 34435
Test:  [10/11]  eta: 0:00:01  loss: 0.0037 (0.0037)  accuracy: 0.0000 (0.0000)  balanced_accuracy: 0.0000 (0.0000)  pr_auc: 0.0000 (0.0000)  roc_auc: 0.0000 (0.0000)  time: 1.4606  data: 1.1914  max mem: 34435
Test: Total time: 0:00:16 (1.4663 s / it)
* loss 0.004
Accuracy of the network on the 14718 test EEG: 0.00%
Max accuracy val: 0.00%, max accuracy test: 0.00%
...

'RelativePositionBias' is lost

Traceback (most recent call last):
File "/root/autodl-tmp/LaBraM/run_labram_pretraining.py", line 28, in
import modeling_pretrain
File "/root/autodl-tmp/LaBraM/modeling_pretrain.py", line 16, in
from modeling_finetune import Block, _cfg, PatchEmbed, RelativePositionBias
ImportError: cannot import name 'RelativePositionBias' from 'modeling_finetune' (/root/autodl-tmp/LaBraM/modeling_finetune.py)

There is no 'RelativePositionBias' in modeling_finetune.py. How can I solve this?

Preprocessing data

Hi!

Thanks for the exciting work. I have a question regarding the preprocessing of the pretraining datasets.

In the readme, you say, "Notably, you can also write your own codes for preprocessing EEG data. Make sure that the preprocessing is consistent with that of our paper, that is, removing useless channels, filtering between 0.1 Hz and 75 Hz, notch filtering of 50 Hz, resampling to 200 Hz, and setting the unit to uV".

I'm wondering if you have some example code of precisely this type of preprocessing and if not, what do you mean by "removing useless channels," i.e., what is a useless channel?
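The steps quoted from the readme can be sketched with SciPy as follows. This is only an illustration under assumptions, not the authors' script: the function name `preprocess_eeg`, the 4th-order Butterworth band-pass, the notch quality factor of 30, and the assumption that the input is in volts are all choices made here. Channel selection is left out, because which channels count as "useless" is exactly the open question.

```python
import numpy as np
from scipy import signal


def preprocess_eeg(data_volts, sfreq, target_sfreq=200):
    """Filter, resample and rescale a (channels, samples) EEG array.

    Hypothetical helper: the filter orders and the volts-to-uV assumption
    are illustrative choices, not taken from the LaBraM repository.
    """
    x = np.asarray(data_volts, dtype=np.float64)

    # Band-pass 0.1-75 Hz (zero-phase, 4th-order Butterworth in SOS form).
    sos = signal.butter(4, [0.1, 75.0], btype="bandpass", fs=sfreq, output="sos")
    x = signal.sosfiltfilt(sos, x, axis=-1)

    # Notch filter at 50 Hz (power-line interference), also zero-phase.
    b, a = signal.iirnotch(50.0, Q=30.0, fs=sfreq)
    x = signal.filtfilt(b, a, x, axis=-1)

    # Resample to the target rate with a polyphase filter
    # (both rates must be integers).
    x = signal.resample_poly(x, int(target_sfreq), int(sfreq), axis=-1)

    # Convert volts to microvolts.
    return x * 1e6
```

For a recording sampled at 256 Hz, `preprocess_eeg(data, sfreq=256)` returns an array with 200 samples per second of input.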

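Could you provide an example of how to use the model?

I would like to analyze the encoding vectors that LaBraM produces for EEG data (the model weights are loaded from labram-base.pth and then frozen; no training or fine-tuning is needed).

Assume the signal has already been band-pass filtered at 0.1-75 Hz, notch filtered at 50 Hz, and resampled to 200 Hz.

Input:
sig.shape: (64, 200*60*60) # 64 channels, 200 Hz sampling rate, one hour of EEG signal

Output:
LaBraM's representation vectors for this EEG: V

V.shape: (64, n, V_dim)
64: number of channels
n: the EEG signal is split by window width (1 second?) into n = 3600 segments
V_dim: dimension of the model's encoding vector for the signal (32, I think)

V = LaBraM(sig)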

einops.EinopsError: Error while processing rearrange-reduction pattern "B N (A T) -> B N A T". Input tensor shape: torch.Size([23, 2000]). Additional info: {'T': 200}.

einops.EinopsError: Error while processing rearrange-reduction pattern "B N (A T) -> B N A T".
Input tensor shape: torch.Size([23, 2000]). Additional info: {'T': 200}.
Wrong shape: expected 3 dims. Received 2-dim tensor.
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 790759) of binary: /d/miniconda3/envs/lb_labram/bin/python

I am using the TUAB dataset and the labram-base.pth model, but the input and output shapes do not match.
How can I solve this error? Thank you so much.
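
The message says the pattern `B N (A T)` needs a 3-dimensional tensor (batch, channels, samples), while the tensor received is 2-dimensional, `[23, 2000]`, so a batch axis appears to be missing. A minimal sketch of the expected shapes follows. It uses NumPy `reshape` in place of the repository's einops call, the names are illustrative, and where the batch axis gets lost in the reporter's pipeline cannot be told from the issue.

```python
import numpy as np

PATCH_SIZE = 200  # T in the pattern: samples per patch (1 second at 200 Hz)


def to_patches(x, patch_size=PATCH_SIZE):
    """Reshape (B, N, A*T) -> (B, N, A, T), adding a batch axis if missing."""
    x = np.asarray(x)
    if x.ndim == 2:          # a single (channels, samples) recording
        x = x[np.newaxis]    # -> (1, channels, samples)
    b, n, total = x.shape
    if total % patch_size != 0:
        raise ValueError(f"{total} samples is not a multiple of {patch_size}")
    return x.reshape(b, n, total // patch_size, patch_size)


single = np.zeros((23, 2000))       # the shape from the error message
print(to_patches(single).shape)     # (1, 23, 10, 200)
```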

What are A and N in B N A T?

class TemporalConv(nn.Module):
    """ Image to Patch Embedding
    """
    def __init__(self, in_chans=1, out_chans=8):
        super().__init__()
        self.conv1 = nn.Conv2d(in_chans, out_chans, kernel_size=(1, 15), stride=(1, 8), padding=(0, 7))
        self.gelu1 = nn.GELU()
        self.norm1 = nn.GroupNorm(4, out_chans)
        self.conv2 = nn.Conv2d(out_chans, out_chans, kernel_size=(1, 3), padding=(0, 1))
        self.gelu2 = nn.GELU()
        self.norm2 = nn.GroupNorm(4, out_chans)
        self.conv3 = nn.Conv2d(out_chans, out_chans, kernel_size=(1, 3), padding=(0, 1))
        self.norm3 = nn.GroupNorm(4, out_chans)
        self.gelu3 = nn.GELU()

    def forward(self, x, **kwargs):
        x = rearrange(x, 'B N A T -> B (N A) T')
        B, NA, T = x.shape
        x = x.unsqueeze(1)
        x = self.gelu1(self.norm1(self.conv1(x)))
        x = self.gelu2(self.norm2(self.conv2(x)))
        x = self.gelu3(self.norm3(self.conv3(x)))
        x = rearrange(x, 'B C NA T -> B NA (T C)')
        return x

Hello. I am studying your code in order to use it. In the forward method of the TemporalConv class in modeling_pretrain.py above, the einops rearrange call expects a 4-dimensional input. I assume B is the batch size, N is the number of electrodes, and T is the sample length, but I couldn't figure out what A means. Because of A, I am also unsure whether N really is the number of electrodes.
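
One reading that is consistent with the shapes in this code (an interpretation, not a confirmed answer from the authors): N is the number of electrodes, A is the number of time patches per electrode, and T is the number of samples in each patch. Under that reading, and assuming T = 200 (one second at 200 Hz), the standard convolution output-length formula gives the size of each patch embedding produced by `TemporalConv`:

```python
def conv_out_len(length, kernel, stride=1, padding=0):
    """Output length of a convolution along one axis (dilation 1)."""
    return (length + 2 * padding - kernel) // stride + 1


T = 200          # samples per patch (assumed)
OUT_CHANS = 8    # default out_chans in TemporalConv

t1 = conv_out_len(T, kernel=15, stride=8, padding=7)   # conv1
t2 = conv_out_len(t1, kernel=3, stride=1, padding=1)   # conv2 keeps the length
t3 = conv_out_len(t2, kernel=3, stride=1, padding=1)   # conv3 keeps the length

# 'B C NA T -> B NA (T C)' flattens time steps and conv channels per patch.
embed_dim = t3 * OUT_CHANS
print(t1, t3, embed_dim)   # 25 25 200
```

So an input of shape (B, N, A, 200) would leave the module as (B, N*A, 200): one 200-dimensional vector per electrode and patch.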

Creating embeddings

Hello, thanks so much for making your code available!

I'd like to embed EEG data using the model and wanted to check whether the following makes sense to you. Given a pretrained model, e.g. your provided base model, I'd avoid adding a classification head to NeuralTransformer and catch the output of self.forward_features(...). Given input signals of [batch_size, channels, samples] and default repo parameters, the model would then provide embeddings [batch_size, embed_dim].

Does that seem OK or would you recommend another approach?

Thanks for your help!

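Why not use the model weights trained in the first step during pretraining?

Hello, and thank you for sharing such a wonderful project. I read the LaBraM code carefully and found that the model in the lower half of the overview figure is the same as the neural tokenizer model, only with a different head. In other words, the position and time embeddings are already included when the tokenizer is trained in the first step, and the model turns each individual patch into a feature and then models all patches together. What I don't understand is why, in the second step, the LaBraM model below is trained completely from scratch, with LaBraM acting as the student model. Wouldn't it work to take the neural tokenizer trained in the first step, load its weights, and just swap the head? That is what you do when fine-tuning in the third step: load the existing model weights and retrain only the head. I am asking sincerely and hope to get a reply.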

Error: Unexpected key(s) in state_dict: "logit_scale".

Hi! Thank you for your great work. I am trying to load the pre-trained base model to extract feature embeddings (no fine-tuning). When I load the model, I get the following error:

raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for NeuralTransformerForMEM:
Unexpected key(s) in state_dict: "logit_scale".

Is there a solution? Also, could you point me to the Python file I should run to extract the feature embeddings from the pre-trained model?
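
One common way around an "Unexpected key(s)" error is to drop checkpoint entries that the target model does not define before calling `load_state_dict`, or to pass `strict=False` to that call. The sketch below shows the key filtering on plain dictionaries. `filter_state_dict` is a hypothetical helper, not part of the repository, the key names are toy values, and whether discarding `logit_scale` is harmless for feature extraction is an assumption that should be checked.

```python
def filter_state_dict(checkpoint_state, model_keys):
    """Split checkpoint entries into those the model expects and the rest."""
    model_keys = set(model_keys)
    kept = {k: v for k, v in checkpoint_state.items() if k in model_keys}
    dropped = sorted(k for k in checkpoint_state if k not in model_keys)
    return kept, dropped


# Toy stand-ins for the loaded checkpoint and for model.state_dict().keys().
checkpoint_state = {"cls_token": 0.0, "pos_embed": 0.0, "logit_scale": 0.0}
model_keys = ["cls_token", "pos_embed"]

kept, dropped = filter_state_dict(checkpoint_state, model_keys)
print(dropped)   # ['logit_scale']
```

With a real model, the same helper would be called with `model.state_dict().keys()`, and `kept` would then be passed to `model.load_state_dict`. Printing `dropped` first makes it visible which entries were discarded.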

How to deal with different number of channels within one dataset?

Hi Weibang,

Thanks for sharing your excellent work! There's one thing that I'm not super clear about. When preprocessing a dataset like TUEP, in which the number of channels ranges from 19 to 23, how did you deal with this variation so that different samples could be batched and dumped into hdf5 files for later use?

Thanks in advance.

UserWarning: y_pred contains classes not in y_true warnings.warn("y_pred contains classes not in y_true")

Test: [7/8] eta: 0:00:06 loss: 1.3861 (1.3866) accuracy: 0.2396 (0.2500) balanced_accuracy: 0.2500 (0.2471) cohen_kappa: 0.0000 (0.0000) f1_weighted: 0.0926 (0.1013) time: 6.2649 data: 4.9884 max mem: 6735
Test: Total time: 0:00:51 (6.4069 s / it)

* loss 1.387
Accuracy of the network on the 680 test EEG: 0.25%
Max accuracy val: 0.25%, max accuracy test: 0.25%

I can't seem to see the effect of the training.
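
A quick sanity check on these numbers: for a K-class classifier that outputs a uniform distribution, the cross-entropy is ln K. Assuming this is a 4-class task (suggested by the 0.25 accuracy, but not stated in the issue), the reported loss of about 1.387 matches ln 4, which would mean the model is still predicting at chance level. The helper name below is illustrative.

```python
import math


def chance_level_cross_entropy(num_classes):
    """Cross-entropy (in nats) of a uniform prediction over num_classes."""
    return math.log(num_classes)


print(round(chance_level_cross_entropy(4), 4))   # 1.3863
```

A Cohen's kappa of 0 together with the "y_pred contains classes not in y_true" warning fits the same picture, although the underlying cause cannot be determined from the log alone.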

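Why does one TUAB sample consist of 2000 sampling points?

Hello, and thank you very much for your work. I have a question. Your time window is 4 or 8, but in the make_TUAB.py script you take one sample every 2000 sampling points. One second should be 200 sampling points, so the time window there is 10. Is there a particular reason for this setting? Thank you.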

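About preprocessing the TUAB and TUEV datasets

Hello, and thank you for your excellent work! I am currently reproducing your experiments and ran into a problem when converting the raw TUAB and TUEV data into h5 datasets. The code you provide reads .cnt files, whereas the files produced by make_TUAB.py and make_TUEV.py are saved in .pkl format. How should these be converted into an h5 dataset? My understanding is that the input to the pretraining model is always in h5 dataset format. Thank you in advance for your patient reply!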

AttributeError: 'VQNSP' object has no attribute 'module'. Did you mean: 'modules'?

Hello! I get the following error while training the vqnsp.

raise AttributeError("'{}' object has no attribute '{}'".format(
AttributeError: 'VQNSP' object has no attribute 'module'. Did you mean: 'modules'?

If I use "modules" instead of "module", then the code works. Does your code need to be updated, or am I missing something?
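
The `.module` attribute exists only on wrappers such as `torch.nn.parallel.DistributedDataParallel`. A bare model does not have it, which would explain the error when the model is not wrapped for distributed training. Note that `modules` is a different thing (a method that iterates over submodules), so swapping the name may silence the error without doing what the original line intended. A hedged sketch of a guard, using stand-in classes instead of real PyTorch objects (`unwrap_model` is a hypothetical helper):

```python
def unwrap_model(model):
    """Return the underlying model whether or not it is wrapped (e.g. by DDP)."""
    return model.module if hasattr(model, "module") else model


class BareModel:
    """Stand-in for an unwrapped model such as VQNSP."""
    name = "vqnsp"


class Wrapper:
    """Stand-in for a DDP-style wrapper that stores the model in .module."""
    def __init__(self, module):
        self.module = module


bare = BareModel()
print(unwrap_model(bare) is bare)            # True
print(unwrap_model(Wrapper(bare)) is bare)   # True
```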

CUDA Error

I'm trying to fine-tune the model, and it results in the following error:

Start training for 50 epochs
/opt/conda/conda-bld/pytorch_1682343995622/work/aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [0,0,0], thread: [52,0,0] Assertion idx_dim >= 0 && idx_dim < index_size && "index out of bounds" failed.
/opt/conda/conda-bld/pytorch_1682343995622/work/aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [0,0,0], thread: [13,0,0] Assertion idx_dim >= 0 && idx_dim < index_size && "index out of bounds" failed.
/opt/conda/conda-bld/pytorch_1682343995622/work/aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [0,0,0], thread: [25,0,0] Assertion idx_dim >= 0 && idx_dim < index_size && "index out of bounds" failed.
/opt/conda/conda-bld/pytorch_1682343995622/work/aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [0,0,0], thread: [18,0,0] Assertion idx_dim >= 0 && idx_dim < index_size && "index out of bounds" failed.
/opt/conda/conda-bld/pytorch_1682343995622/work/aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [0,0,0], thread: [60,0,0] Assertion idx_dim >= 0 && idx_dim < index_size && "index out of bounds" failed.

WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 3392161 closing signal SIGTERM
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: -6) local_rank: 1 (pid: 3392162) of binary: /home/anaconda3/envs/labram/bin/python
Traceback (most recent call last):
File "/home/anaconda3/envs/labram/bin/torchrun", line 33, in
sys.exit(load_entry_point('torch==2.0.1', 'console_scripts', 'torchrun')())

NVCC

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Wed_Nov_22_10:17:15_PST_2023
Cuda compilation tools, release 12.3, V12.3.107
Build cuda_12.3.r12.3/compiler.33567101_0

I am trying to run it on a node with 4 RTX A5000 GPUs.
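
A device-side `index out of bounds` assertion in `ScatterGatherKernel.cu` is often, though not always, caused by an index such as a class label falling outside `[0, num_classes)`. A minimal sketch for checking labels before training, assuming integer class labels; `find_out_of_range_labels` and its arguments are hypothetical names, not part of the repository:

```python
def find_out_of_range_labels(labels, num_classes):
    """Return (position, label) pairs for labels outside [0, num_classes).

    Hypothetical helper: `labels` is any iterable of integer class ids,
    `num_classes` is the size of the model's output layer.
    """
    return [
        (position, label)
        for position, label in enumerate(labels)
        if not 0 <= label < num_classes
    ]


# Example: a 3-class head cannot accept label 3 or label -1.
bad = find_out_of_range_labels([0, 2, 3, 1, -1], num_classes=3)
print(bad)
```

If the labels turn out to be valid, rerunning once on CPU or with `CUDA_LAUNCH_BLOCKING=1` usually gives a stack trace that points at the actual indexing operation.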

Issue with cuda

Hello, and thank you for the awesome work!
I got the following error, which I think is related to distributed training, while trying to run your example on Colab Pro+:

Traceback (most recent call last):
File "/content/LaBraM/run_class_finetuning.py", line 565, in <module>
main(opts, ds_init)
File "/content/LaBraM/run_class_finetuning.py", line 232, in main
utils.init_distributed_mode(args)
File "/content/LaBraM/utils.py", line 428, in init_distributed_mode
torch.cuda.set_device(args.gpu)
File "/usr/local/lib/python3.10/dist-packages/torch/cuda/__init__.py", line 399, in set_device
torch._C._cuda_setDevice(device)
RuntimeError: CUDA error: invalid device ordinal
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.

| distributed init (rank 0): env://, gpu 0
W0630 22:56:38.144000 139452743365248 torch/distributed/elastic/multiprocessing/api.py:851] Sending process 20757 closing signal SIGTERM
W0630 22:56:38.144000 139452743365248 torch/distributed/elastic/multiprocessing/api.py:851] Sending process 20758 closing signal SIGTERM
W0630 22:56:38.145000 139452743365248 torch/distributed/elastic/multiprocessing/api.py:851] Sending process 20759 closing signal SIGTERM
W0630 22:56:38.145000 139452743365248 torch/distributed/elastic/multiprocessing/api.py:851] Sending process 20761 closing signal SIGTERM
W0630 22:56:38.145000 139452743365248 torch/distributed/elastic/multiprocessing/api.py:851] Sending process 20763 closing signal SIGTERM
W0630 22:56:38.145000 139452743365248 torch/distributed/elastic/multiprocessing/api.py:851] Sending process 20764 closing signal SIGTERM
E0630 22:56:38.262000 139452743365248 torch/distributed/elastic/multiprocessing/api.py:826] failed (exitcode: 1) local_rank: 3 (pid: 20760) of binary: /usr/bin/python3
Traceback (most recent call last):
File "/usr/local/bin/torchrun", line 8, in <module>
sys.exit(main())
File "/usr/local/lib/python3.10/dist-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 347, in wrapper
return f(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/distributed/run.py", line 879, in main
run(args)
File "/usr/local/lib/python3.10/dist-packages/torch/distributed/run.py", line 870, in run
elastic_launch(
File "/usr/local/lib/python3.10/dist-packages/torch/distributed/launcher/api.py", line 132, in __call__
return launch_agent(self._config, self._entrypoint, list(args))
File "/usr/local/lib/python3.10/dist-packages/torch/distributed/launcher/api.py", line 263, in launch_agent
raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:

run_class_finetuning.py FAILED

Failures:
[1]:
time : 2024-06-30_22:56:38
host : 7283cc30eeb0
rank : 5 (local_rank: 5)
exitcode : 1 (pid: 20762)
error_file: <N/A>
traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html

Root Cause (first observed failure):
[0]:
time : 2024-06-30_22:56:38
host : 7283cc30eeb0
rank : 3 (local_rank: 3)
exitcode : 1 (pid: 20760)
error_file: <N/A>
traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
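
`invalid device ordinal` generally means a process asked for a GPU index that does not exist on the machine, for example when `torchrun` starts more processes per node than there are visible GPUs. A minimal sketch of a pre-flight check, assuming the local rank comes from the `LOCAL_RANK` environment variable set by `torchrun` and the GPU count from `torch.cuda.device_count()`; `check_local_rank` is a hypothetical helper, not part of the repository:

```python
def check_local_rank(local_rank: int, visible_gpus: int) -> int:
    """Return local_rank if it maps to an existing GPU, else raise ValueError."""
    if visible_gpus < 1:
        raise ValueError("no CUDA device is visible to this process")
    if not 0 <= local_rank < visible_gpus:
        raise ValueError(
            f"local_rank {local_rank} needs at least {local_rank + 1} GPUs, "
            f"but only {visible_gpus} are visible; lower --nproc_per_node"
        )
    return local_rank


# In a real training script this would be called roughly as follows
# (assumption, shown as a comment so the sketch runs without PyTorch):
#   import os, torch
#   gpu = check_local_rank(int(os.environ["LOCAL_RANK"]), torch.cuda.device_count())
#   torch.cuda.set_device(gpu)
```

In the log above, local ranks 3 and 5 fail, so at least six processes were launched; on a single-GPU machine, setting `--nproc_per_node=1` would be the first thing to try.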
