voidism / diffcse
Code for the NAACL 2022 long paper "DiffCSE: Difference-based Contrastive Learning for Sentence Embeddings"
Home Page: https://arxiv.org/abs/2204.10298
License: MIT License
Howdy,
I was wondering whether any experiments were run with the DiffCSE framework on long inputs (300-500 tokens). More generally, are there any conditions the training data must meet for training to converge?
PS - congrats on the paper, it was a really fun read :)
I ran run_diffcse.sh, but this error occurs:
AttributeError: module 'dill._dill' has no attribute 'stack'
Hello, I have solved the first two problems, but I ran into a new one 50% of the way through training, at the end of the first epoch:
Traceback (most recent call last):
File "/home/lizhaohui/DiffCSE/train.py", line 600, in <module>
main()
File "/home/lizhaohui/DiffCSE/train.py", line 564, in main
train_result = trainer.train(model_path=model_path)
File "/home/lizhaohui/DiffCSE/diffcse/trainers.py", line 513, in train
tr_loss += self.training_step(model, inputs)
File "/home/lizhaohui/miniconda3/envs/py39/lib/python3.9/site-packages/transformers/trainer.py", line 1248, in training_step
loss = self.compute_loss(model, inputs)
File "/home/lizhaohui/miniconda3/envs/py39/lib/python3.9/site-packages/transformers/trainer.py", line 1277, in compute_loss
outputs = model(**inputs)
File "/home/lizhaohui/miniconda3/envs/py39/lib/python3.9/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/lizhaohui/miniconda3/envs/py39/lib/python3.9/site-packages/torch/nn/parallel/data_parallel.py", line 162, in forward
return self.gather(outputs, self.output_device)
File "/home/lizhaohui/miniconda3/envs/py39/lib/python3.9/site-packages/torch/nn/parallel/data_parallel.py", line 174, in gather
return gather(outputs, output_device, dim=self.dim)
File "/home/lizhaohui/miniconda3/envs/py39/lib/python3.9/site-packages/torch/nn/parallel/scatter_gather.py", line 68, in gather
res = gather_map(outputs)
File "/home/lizhaohui/miniconda3/envs/py39/lib/python3.9/site-packages/torch/nn/parallel/scatter_gather.py", line 61, in gather_map
return type(out)(((k, gather_map([d[k] for d in outputs]))
File "<string>", line 7, in __init__
File "/home/lizhaohui/miniconda3/envs/py39/lib/python3.9/site-packages/transformers/file_utils.py", line 1383, in __post_init__
for element in iterator:
File "/home/lizhaohui/miniconda3/envs/py39/lib/python3.9/site-packages/torch/nn/parallel/scatter_gather.py", line 61, in <genexpr>
return type(out)(((k, gather_map([d[k] for d in outputs]))
File "/home/lizhaohui/miniconda3/envs/py39/lib/python3.9/site-packages/torch/nn/parallel/scatter_gather.py", line 55, in gather_map
return Gather.apply(target_device, dim, *outputs)
File "/home/lizhaohui/miniconda3/envs/py39/lib/python3.9/site-packages/torch/nn/parallel/_functions.py", line 71, in forward
return comm.gather(inputs, ctx.dim, ctx.target_device)
File "/home/lizhaohui/miniconda3/envs/py39/lib/python3.9/site-packages/torch/nn/parallel/comm.py", line 230, in gather
return torch._C._gather(tensors, dim, destination)
RuntimeError: Input tensor at index 3 has invalid shape [14, 14], but expected [14, 17]
50%|█████████████████████████████████████▍ | 3906/7814 [1:00:50<1:00:51, 1.07it/s]
Fatal Python error: PyEval_SaveThread: the function must be called with the GIL held, but the GIL is released (the current Python thread state is NULL)
Python runtime state: finalizing (tstate=0x559cd4786400)
run_diffcse.sh: line 30: 3465377 Aborted (core dumped) python train.py --model_name_or_path bert-base-uncased --generator_name distilbert-base-uncased --train_file data/wiki1m_for_simcse.txt --output_dir output_dir --num_train_epochs 2 --per_device_train_batch_size 64 --learning_rate 7e-6 --max_seq_length 32 --evaluation_strategy steps --metric_for_best_model stsb_spearman --load_best_model_at_end --eval_steps 125 --pooler_type cls --mlp_only_train --overwrite_output_dir --logging_first_step --logging_dir log_dir --temp 0.05 --do_train --do_eval --batchnorm --lambda_weight 0.005 --fp16 --masking_ratio 0.30
Hello, have you released the source code for alignment & uniformity? If not, where can I find the training details for these metrics?
Looking forward to your reply!
Hello, your work is excellent, and I am interested in the grid search mentioned in your paper. How did you run this grid search to find the best hyperparameters?
Did you set different parameters in run_diffcse.sh and then pick the best ones based on their performance on the STS-B development set? If so, that would require many training runs.
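The procedure the question describes (one training run per hyperparameter combination, then selection on the STS-B development score) can be sketched as below. This is only an illustration: the values in `GRID` are placeholders rather than the grid used in the paper, the helper names are hypothetical, and the flags are the ones that appear in the `train.py` command lines quoted on this page.

```python
from itertools import product

# Hypothetical search space: placeholder values, NOT the grid from the paper.
GRID = {
    "learning_rate": [2e-6, 7e-6, 1e-5],
    "lambda_weight": [0.005, 0.05],
    "masking_ratio": [0.15, 0.30],
}


def build_configs(grid):
    """Return one dict per combination of hyperparameter values."""
    keys = list(grid)
    return [dict(zip(keys, values)) for values in product(*(grid[k] for k in keys))]


def to_command(config, output_dir):
    """Render a train.py command line for one configuration."""
    flags = " ".join(f"--{name} {value}" for name, value in config.items())
    return f"python train.py {flags} --output_dir {output_dir}"


def pick_best(results):
    """results: list of (config, stsb_dev_spearman) pairs; return the best pair."""
    return max(results, key=lambda pair: pair[1])


configs = build_configs(GRID)
for i, cfg in enumerate(configs):
    print(to_command(cfg, f"output_dir/run_{i}"))
```

Each printed command would still need the remaining arguments from run_diffcse.sh. Once every run has reported its `stsb_spearman` on the development set, `pick_best` returns the winning configuration.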
I've been trying to run your code, but it does not work with my setup.
I ran into the error below.
Is there a solution?
warnings.warn('Was asked to gather along dimension 0, but all '
Traceback (most recent call last):
File "train.py", line 600, in <module>
main()
File "train.py", line 564, in main
train_result = trainer.train(model_path=model_path)
File "/home/qmin/DiffCSE/diffcse/trainers.py", line 514, in train
tr_pos_sim += model.sim.pos_avg
File "/home/qmin/anaconda3/envs/diffcse/lib/python3.7/site-packages/torch/nn/modules/module.py", line 779, in __getattr__
type(self).__name__, name))
torch.nn.modules.module.ModuleAttributeError: 'DataParallel' object has no attribute 'sim'
I ran the DiffCSE source code on my own hardware (Tesla V100-SXM2-32GB) with the exact configuration specified in run_diffcse.sh. The results below differ from those reported in your paper and on your GitHub repository: my average STS score is 3.24 points lower than the reported one (78.49 - 75.25 = 3.24).
Do you have any insights or suggestions about the source of this gap? (@voidism)
[INFO|trainer.py:358] 2023-09-21 19:27:21,467 >> Using amp fp16 backend
09/21/2023 19:27:21 - INFO - __main__ - *** Evaluate ***
tasks: ['STSBenchmark', 'SICKRelatedness', 'STS12', 'STS13', 'STS14', 'STS15', 'STS16', 'MR', 'CR', 'SUBJ', 'MPQA', 'SST2', 'MRPC', 'TREC']
./SentEval/senteval/sts.py:42: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray
sent1 = np.array([s.split() for s in sent1])[not_empty_idx]
./SentEval/senteval/sts.py:43: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray
sent2 = np.array([s.split() for s in sent2])[not_empty_idx]
09/21/2023 19:27:54 - INFO - root - Generating sentence embeddings
09/21/2023 19:28:02 - INFO - root - Generated sentence embeddings
09/21/2023 19:28:02 - INFO - root - Training pytorch-MLP-nhid0-rmsprop-bs128 with (inner) 5-fold cross-validation
09/21/2023 19:28:10 - INFO - root - Best param found at split 1: l2reg = 0.001 with score 82.31
09/21/2023 19:28:20 - INFO - root - Best param found at split 2: l2reg = 0.001 with score 81.99
09/21/2023 19:28:32 - INFO - root - Best param found at split 3: l2reg = 0.0001 with score 82.27
09/21/2023 19:28:42 - INFO - root - Best param found at split 4: l2reg = 0.01 with score 81.54
09/21/2023 19:28:53 - INFO - root - Best param found at split 5: l2reg = 0.0001 with score 82.04
09/21/2023 19:28:54 - INFO - root - Generating sentence embeddings
09/21/2023 19:28:56 - INFO - root - Generated sentence embeddings
09/21/2023 19:28:56 - INFO - root - Training pytorch-MLP-nhid0-rmsprop-bs128 with (inner) 5-fold cross-validation
09/21/2023 19:28:59 - INFO - root - Best param found at split 1: l2reg = 1e-05 with score 87.81
09/21/2023 19:29:03 - INFO - root - Best param found at split 2: l2reg = 0.0001 with score 88.15
09/21/2023 19:29:07 - INFO - root - Best param found at split 3: l2reg = 1e-05 with score 87.32
09/21/2023 19:29:11 - INFO - root - Best param found at split 4: l2reg = 1e-05 with score 87.05
09/21/2023 19:29:15 - INFO - root - Best param found at split 5: l2reg = 0.0001 with score 87.25
09/21/2023 19:29:15 - INFO - root - Generating sentence embeddings
09/21/2023 19:29:23 - INFO - root - Generated sentence embeddings
09/21/2023 19:29:23 - INFO - root - Training pytorch-MLP-nhid0-rmsprop-bs128 with (inner) 5-fold cross-validation
09/21/2023 19:29:32 - INFO - root - Best param found at split 1: l2reg = 0.001 with score 95.22
09/21/2023 19:29:42 - INFO - root - Best param found at split 2: l2reg = 1e-05 with score 95.51
09/21/2023 19:29:52 - INFO - root - Best param found at split 3: l2reg = 0.0001 with score 95.31
09/21/2023 19:30:01 - INFO - root - Best param found at split 4: l2reg = 0.001 with score 95.45
09/21/2023 19:30:09 - INFO - root - Best param found at split 5: l2reg = 0.0001 with score 95.46
09/21/2023 19:30:10 - INFO - root - Generating sentence embeddings
09/21/2023 19:30:12 - INFO - root - Generated sentence embeddings
09/21/2023 19:30:12 - INFO - root - Training pytorch-MLP-nhid0-rmsprop-bs128 with (inner) 5-fold cross-validation
09/21/2023 19:30:21 - INFO - root - Best param found at split 1: l2reg = 0.001 with score 89.16
09/21/2023 19:30:29 - INFO - root - Best param found at split 2: l2reg = 1e-05 with score 88.19
09/21/2023 19:30:37 - INFO - root - Best param found at split 3: l2reg = 0.001 with score 88.91
09/21/2023 19:30:45 - INFO - root - Best param found at split 4: l2reg = 0.001 with score 88.44
09/21/2023 19:30:54 - INFO - root - Best param found at split 5: l2reg = 0.001 with score 88.93
09/21/2023 19:30:55 - INFO - root - Computing embedding for train
09/21/2023 19:31:22 - INFO - root - Computed train embeddings
09/21/2023 19:31:22 - INFO - root - Computing embedding for dev
09/21/2023 19:31:23 - INFO - root - Computed dev embeddings
09/21/2023 19:31:23 - INFO - root - Computing embedding for test
09/21/2023 19:31:24 - INFO - root - Computed test embeddings
09/21/2023 19:31:24 - INFO - root - Training pytorch-MLP-nhid0-rmsprop-bs128 with standard validation..
09/21/2023 19:31:36 - INFO - root - [('reg:1e-05', 87.73), ('reg:0.0001', 87.84), ('reg:0.001', 87.61), ('reg:0.01', 86.93)]
09/21/2023 19:31:36 - INFO - root - Validation : best param found is reg = 0.0001 with score 87.84
09/21/2023 19:31:36 - INFO - root - Evaluating...
09/21/2023 19:31:39 - INFO - root - ***** Transfer task : MRPC *****
09/21/2023 19:31:39 - INFO - root - Computing embedding for train
09/21/2023 19:31:45 - INFO - root - Computed train embeddings
09/21/2023 19:31:45 - INFO - root - Computing embedding for test
09/21/2023 19:31:47 - INFO - root - Computed test embeddings
09/21/2023 19:31:47 - INFO - root - Training pytorch-MLP-nhid0-rmsprop-bs128 with 5-fold cross-validation
09/21/2023 19:31:51 - INFO - root - [('reg:1e-05', 74.85), ('reg:0.0001', 74.85), ('reg:0.001', 74.93), ('reg:0.01', 74.07)]
09/21/2023 19:31:51 - INFO - root - Cross-validation : best param found is reg = 0.001 with score 74.93
09/21/2023 19:31:51 - INFO - root - Evaluating...
09/21/2023 19:31:52 - INFO - root - ***** Transfer task : TREC *****
09/21/2023 19:31:54 - INFO - root - Computed train embeddings
09/21/2023 19:31:54 - INFO - root - Computed test embeddings
09/21/2023 19:31:54 - INFO - root - Training pytorch-MLP-nhid0-rmsprop-bs128 with 5-fold cross-validation
09/21/2023 19:32:00 - INFO - root - [('reg:1e-05', 84.15), ('reg:0.0001', 84.02), ('reg:0.001', 83.47), ('reg:0.01', 76.76)]
09/21/2023 19:32:00 - INFO - root - Cross-validation : best param found is reg = 1e-05 with score 84.15
09/21/2023 19:32:00 - INFO - root - Evaluating...
09/21/2023 19:32:00 - INFO - __main__ - ***** Eval results *****
09/21/2023 19:32:00 - INFO - __main__ - STS12 = 0.6466070114897755
09/21/2023 19:32:00 - INFO - __main__ - STS13 = 0.7940081784855644
09/21/2023 19:32:00 - INFO - __main__ - STS14 = 0.7106309581907064
09/21/2023 19:32:00 - INFO - __main__ - STS15 = 0.8022190201969241
09/21/2023 19:32:00 - INFO - __main__ - STS16 = 0.7800045550188356
09/21/2023 19:32:00 - INFO - __main__ - eval_CR = 87.52
09/21/2023 19:32:00 - INFO - __main__ - eval_MPQA = 88.73
09/21/2023 19:32:00 - INFO - __main__ - eval_MR = 82.03
09/21/2023 19:32:00 - INFO - __main__ - eval_MRPC = 74.93
09/21/2023 19:32:00 - INFO - __main__ - eval_SST2 = 87.84
09/21/2023 19:32:00 - INFO - __main__ - eval_SUBJ = 95.39
09/21/2023 19:32:00 - INFO - __main__ - eval_TREC = 84.15
09/21/2023 19:32:00 - INFO - __main__ - eval_avg_sts = 0.7525457395203998
09/21/2023 19:32:00 - INFO - __main__ - eval_avg_transfer = 85.79857142857144
09/21/2023 19:32:00 - INFO - __main__ - eval_sickr_spearman = 0.734116144071677
09/21/2023 19:32:00 - INFO - __main__ - eval_stsb_spearman = 0.8002343091893147
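The two averages in the log above can be re-derived from the per-task numbers, which is a quick way to confirm which tasks enter each average before comparing against the paper. The sketch below copies the scores from the log; the reported figure of 78.49 is the one quoted in the post, and the variable names are illustrative.

```python
from statistics import mean

# Spearman correlations copied from the log above (7 STS tasks).
sts = {
    "STS12": 0.6466070114897755,
    "STS13": 0.7940081784855644,
    "STS14": 0.7106309581907064,
    "STS15": 0.8022190201969241,
    "STS16": 0.7800045550188356,
    "STSBenchmark": 0.8002343091893147,
    "SICKRelatedness": 0.734116144071677,
}

# Accuracies copied from the log above (7 transfer tasks).
transfer = {
    "MR": 82.03, "CR": 87.52, "SUBJ": 95.39, "MPQA": 88.73,
    "SST2": 87.84, "TREC": 84.15, "MRPC": 74.93,
}

avg_sts = mean(sts.values())
avg_transfer = mean(transfer.values())

reported_avg_sts = 78.49  # figure quoted in the post
gap = reported_avg_sts - 100 * avg_sts

print(f"avg STS      = {100 * avg_sts:.2f}")
print(f"avg transfer = {avg_transfer:.2f}")
print(f"gap to reported STS average = {gap:.2f} points")
```

Both recomputed means agree with `eval_avg_sts` and `eval_avg_transfer` in the log, so the gap comes from the per-task scores themselves rather than from how the average is formed.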
BertForCL and RobertaForCL both have a .lm_head but the uses seem to be commented out.
If the models only use .bert and .discriminator, does that mean all of the transformer parameters are tied between the encoder and the discriminator?
Thanks
I used run_diffcse.sh as provided, but the average STS result is only 77.00. Is there anything I should pay attention to?
Why is there no script to convert the model to the Hugging Face format? Can the model be evaluated directly after it is saved, without conversion?
Hi, thanks for sharing the great work!
Here (
Line 189 in 33b29a3
Traceback (most recent call last):
File "/home/lizhaohui/DiffCSE/train.py", line 600, in <module>
main()
File "/home/lizhaohui/DiffCSE/train.py", line 564, in main
train_result = trainer.train(model_path=model_path)
File "/home/lizhaohui/DiffCSE/diffcse/trainers.py", line 514, in train
tr_pos_sim += model.sim.pos_avg
File "/home/lizhaohui/miniconda3/envs/py395/lib/python3.9/site-packages/torch/nn/modules/module.py", line 778, in __getattr__
raise ModuleAttributeError("'{}' object has no attribute '{}'".format(
torch.nn.modules.module.ModuleAttributeError: 'DataParallel' object has no attribute 'sim'
0%| | 0/7814 [00:11<?, ?it/s]
The vocab size of the sentence encoder is 50256, but the vocab size of the generator is 30522. This causes an index-out-of-range error.
The relevant call is cls.generator(mlm_input_ids, attention_mask), where mlm_input_ids.min() = 0 and mlm_input_ids.max() = 50264.
How should I initialize the two models so that their vocabularies are consistent?
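The numbers in the post indicate that token ids produced for the encoder are passed to a generator whose embedding table has only 30522 rows. A small check like the one below makes the mismatch visible before the forward pass. It is a plain-Python sketch: `find_out_of_range` is a hypothetical helper, and the middle id in the sample list is made up for illustration.

```python
def find_out_of_range(token_ids, vocab_size):
    """Return the token ids that an embedding table with `vocab_size` rows cannot index."""
    return [t for t in token_ids if t < 0 or t >= vocab_size]


generator_vocab_size = 30522      # generator vocabulary size reported in the post
mlm_input_ids = [0, 713, 50264]   # 0 and 50264 are the reported min and max; 713 is illustrative

bad = find_out_of_range(mlm_input_ids, generator_vocab_size)
if bad:
    print(f"{len(bad)} id(s) exceed the generator vocabulary: {bad}")
```

One likely remedy, stated as an assumption rather than the authors' answer, is to pick a generator that shares the encoder's tokenizer and vocabulary. Another post on this page pairs RoBERTa encoders with distilroberta-base and BERT encoders with distilbert-base-uncased.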
Could you please share how to install diffcse? I tried the following:
pip install simcse
git clone git@github.com:voidism/DiffCSE.git
cd transformers-4.2.1
pip install .
then running
from diffcse import DiffCSE
results in ModuleNotFoundError
Thanks!
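The file paths in the tracebacks on this page (for example `DiffCSE/diffcse/trainers.py`) suggest that `diffcse` is a directory inside the cloned repository rather than something installed by `pip install simcse`. If that is the case, Python needs the repository root on its import path. The sketch below assumes the clone lives at `./DiffCSE`, which is a hypothetical location, and `add_repo_to_path` is an illustrative helper.

```python
import sys
from pathlib import Path


def add_repo_to_path(repo_dir):
    """Put the cloned repository root on sys.path so `import diffcse` can resolve."""
    repo = Path(repo_dir).expanduser().resolve()
    if str(repo) not in sys.path:
        sys.path.insert(0, str(repo))
    return repo


repo = add_repo_to_path("./DiffCSE")  # hypothetical clone location
print("diffcse directory found:", (repo / "diffcse").is_dir())
```

Running Python from inside the cloned directory has the same effect. Note also that the steps listed in the post run `cd transformers-4.2.1` without first entering the cloned `DiffCSE` directory, which may be part of the problem.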
CUDA_VISIBLE_DEVICES=6 python train.py --model_name_or_path bert-base-uncased --generator_name distilbert-base-uncased --train_file data/nli_for_simcse.csv --num_train_epochs 2 --per_device_train_batch_size 64 --learning_rate 2e-6 --max_seq_length 32 --evaluation_strategy steps --metric_for_best_model stsb_spearman --load_best_model_at_end --eval_steps 125 --pooler_type cls --overwrite_output_dir --logging_first_step --logging_dir trained --temp 0.05 --do_train --do_eval --batchnorm --lambda_weight 0.05 --fp16 --masking_ratio 0.15 --output_dir trained_orig
I cannot replicate the transfer-task results after training the model with the parameters above (the lambda, learning rate, and batch size values are taken from the appendix of the paper).
Results obtained:
***** Eval results *****
epoch = 2.0
eval_CR = 88.03
eval_MPQA = 88.01
eval_MR = 80.52
eval_MRPC = 73.97
eval_SST2 = 85.32
eval_SUBJ = 93.48
eval_TREC = 77.07
eval_avg_sts = 0.8235647840844718
eval_avg_transfer = 83.77142857142859
eval_sickr_spearman = 0.8026336149884777
eval_stsb_spearman = 0.8444959531804657
Could you please share the right parameters for training the model from scratch in the supervised setting?
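Hello, thank you for sharing the model's source code!
I am trying to reproduce the experiments in the paper. Have you put the experiment behind Table 9 (Retrieved top-3 examples by SimCSE and DiffCSE from STS-B test set) on GitHub? If so, could you point me to where the source code is? Thank you!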
Hi, I am using your code to train DiffCSE with BERT-large and RoBERTa-large as the encoder, and with distilbert-base-uncased and distilroberta-base as the generator. I always hit "RuntimeError: mat1 dim 1 must match mat2 dim 0", whereas training works fine with BERT-base or RoBERTa-base. Could you give me some advice, please?