voidism / diffcse

Code for the NAACL 2022 long paper "DiffCSE: Difference-based Contrastive Learning for Sentence Embeddings"

Home Page: https://arxiv.org/abs/2204.10298

License: MIT License

Python 92.42% Shell 0.35% Makefile 0.02% Dockerfile 0.06% Jsonnet 0.01% CSS 0.06% JavaScript 0.19% Jupyter Notebook 6.89%
contrastive-learning representation-learning self-supervised-learning sentence-embeddings sentence-similarity sentence-transformers

diffcse's Introduction

👋 You've reached the GitHub profile of Yung-Sung!

diffcse's People

Contributors

bm-k, voidism

diffcse's Issues

Grid search

Hello, your work is excellent, and I am interested in the grid search mentioned in your paper. How did you run this grid search to find the best hyperparameters?

Did you set different parameters in run_diffcse.sh and then pick the best ones based on their performance on the STS-B development set? In that case, it might take many training experiments.
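
One way to organize such a search is sketched below. The flag names (--learning_rate, --lambda_weight, --masking_ratio, --output_dir) appear in the training commands quoted elsewhere on this page, but the candidate values and the output directory naming scheme are hypothetical placeholders, not the grid used in the paper. The sketch only builds the command lines; each finished run would then be compared by its STS-B development score.

```python
import itertools

# Hypothetical candidate values; the actual grid from the paper is not given here.
GRID = {
    "learning_rate": ["5e-6", "7e-6", "1e-5"],
    "lambda_weight": ["0.005", "0.05"],
    "masking_ratio": ["0.15", "0.30"],
}


def build_commands(grid, base=("python", "train.py")):
    """Return one argument list per hyperparameter combination."""
    names = sorted(grid)
    commands = []
    for values in itertools.product(*(grid[n] for n in names)):
        args = list(base)
        for name, value in zip(names, values):
            args += [f"--{name}", value]
        # Give each run its own output directory (naming scheme is illustrative).
        tag = "_".join(f"{n}={v}" for n, v in zip(names, values))
        args += ["--output_dir", f"runs/{tag}"]
        commands.append(args)
    return commands


commands = build_commands(GRID)
print(len(commands))          # number of runs: 3 * 2 * 2
print(" ".join(commands[0]))  # first command line
```

The run count grows multiplicatively with each added hyperparameter, which is why such a search can indeed require many training experiments.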

RuntimeError: Input tensor at index 3 has invalid shape [14, 14], but expected [14, 17]

Hello, I have solved the first two problems, but I encountered a new problem at 50% of the first epoch of the code:
Traceback (most recent call last):
  File "/home/lizhaohui/DiffCSE/train.py", line 600, in <module>
    main()
  File "/home/lizhaohui/DiffCSE/train.py", line 564, in main
    train_result = trainer.train(model_path=model_path)
  File "/home/lizhaohui/DiffCSE/diffcse/trainers.py", line 513, in train
    tr_loss += self.training_step(model, inputs)
  File "/home/lizhaohui/miniconda3/envs/py39/lib/python3.9/site-packages/transformers/trainer.py", line 1248, in training_step
    loss = self.compute_loss(model, inputs)
  File "/home/lizhaohui/miniconda3/envs/py39/lib/python3.9/site-packages/transformers/trainer.py", line 1277, in compute_loss
    outputs = model(**inputs)
  File "/home/lizhaohui/miniconda3/envs/py39/lib/python3.9/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/lizhaohui/miniconda3/envs/py39/lib/python3.9/site-packages/torch/nn/parallel/data_parallel.py", line 162, in forward
    return self.gather(outputs, self.output_device)
  File "/home/lizhaohui/miniconda3/envs/py39/lib/python3.9/site-packages/torch/nn/parallel/data_parallel.py", line 174, in gather
    return gather(outputs, output_device, dim=self.dim)
  File "/home/lizhaohui/miniconda3/envs/py39/lib/python3.9/site-packages/torch/nn/parallel/scatter_gather.py", line 68, in gather
    res = gather_map(outputs)
  File "/home/lizhaohui/miniconda3/envs/py39/lib/python3.9/site-packages/torch/nn/parallel/scatter_gather.py", line 61, in gather_map
    return type(out)(((k, gather_map([d[k] for d in outputs]))
  File "<string>", line 7, in __init__
  File "/home/lizhaohui/miniconda3/envs/py39/lib/python3.9/site-packages/transformers/file_utils.py", line 1383, in __post_init__
    for element in iterator:
  File "/home/lizhaohui/miniconda3/envs/py39/lib/python3.9/site-packages/torch/nn/parallel/scatter_gather.py", line 61, in <genexpr>
    return type(out)(((k, gather_map([d[k] for d in outputs]))
  File "/home/lizhaohui/miniconda3/envs/py39/lib/python3.9/site-packages/torch/nn/parallel/scatter_gather.py", line 55, in gather_map
    return Gather.apply(target_device, dim, *outputs)
  File "/home/lizhaohui/miniconda3/envs/py39/lib/python3.9/site-packages/torch/nn/parallel/_functions.py", line 71, in forward
    return comm.gather(inputs, ctx.dim, ctx.target_device)
  File "/home/lizhaohui/miniconda3/envs/py39/lib/python3.9/site-packages/torch/nn/parallel/comm.py", line 230, in gather
    return torch._C._gather(tensors, dim, destination)
RuntimeError: Input tensor at index 3 has invalid shape [14, 14], but expected [14, 17]
50%|█████████████████████████████████████▍ | 3906/7814 [1:00:50<1:00:51, 1.07it/s]
Fatal Python error: PyEval_SaveThread: the function must be called with the GIL held, but the GIL is released (the current Python thread state is NULL)
Python runtime state: finalizing (tstate=0x559cd4786400)

run_diffcse.sh: line 30: 3465377 Aborted (core dumped) python train.py --model_name_or_path bert-base-uncased --generator_name distilbert-base-uncased --train_file data/wiki1m_for_simcse.txt --output_dir output_dir --num_train_epochs 2 --per_device_train_batch_size 64 --learning_rate 7e-6 --max_seq_length 32 --evaluation_strategy steps --metric_for_best_model stsb_spearman --load_best_model_at_end --eval_steps 125 --pooler_type cls --mlp_only_train --overwrite_output_dir --logging_first_step --logging_dir log_dir --temp 0.05 --do_train --do_eval --batchnorm --lambda_weight 0.005 --fp16 --masking_ratio 0.30
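
A small arithmetic check on the numbers in this log may help locate the failure. It uses only the progress bar (3906/7814) and the --num_train_epochs 2 flag from the command above. The reading that the crash occurred on the step after the last completed one is an assumption, and so is the suggested cause: the final batch of an epoch is often smaller than the others, which is a common source of shape mismatches when torch.nn.DataParallel gathers outputs from several GPUs. This is a hypothesis, not a confirmed diagnosis.

```python
# Numbers read off the progress bar and the command line above.
total_steps = 7814      # denominator of the progress bar
completed_steps = 3906  # numerator when the crash happened
num_epochs = 2          # --num_train_epochs 2

steps_per_epoch = total_steps // num_epochs
# Assumption: the step that crashed is the one after the last completed step.
failing_step = completed_steps + 1

is_last_step_of_epoch = failing_step % steps_per_epoch == 0
print(steps_per_epoch, failing_step, is_last_step_of_epoch)  # 3907 3907 True
```

If that reading is right, the error sits at the boundary of the first epoch rather than at an arbitrary point halfway through training.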

runtime error

Hi! I am reproducing the roberta-base results on my device with all parameters set as in your paper, but I get the following error:
[images: a, b]

torch.nn.modules.module.ModuleAttributeError: 'DataParallel' object has no attribute 'sim'

Traceback (most recent call last):
  File "/home/lizhaohui/DiffCSE/train.py", line 600, in <module>
    main()
  File "/home/lizhaohui/DiffCSE/train.py", line 564, in main
    train_result = trainer.train(model_path=model_path)
  File "/home/lizhaohui/DiffCSE/diffcse/trainers.py", line 514, in train
    tr_pos_sim += model.sim.pos_avg
  File "/home/lizhaohui/miniconda3/envs/py395/lib/python3.9/site-packages/torch/nn/modules/module.py", line 778, in __getattr__
    raise ModuleAttributeError("'{}' object has no attribute '{}'".format(
torch.nn.modules.module.ModuleAttributeError: 'DataParallel' object has no attribute 'sim'
0%| | 0/7814 [00:11<?, ?it/s]
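
The traceback is consistent with the model having been wrapped in torch.nn.DataParallel, which happens when several GPUs are visible. The wrapper keeps the original model in its .module attribute and does not expose custom attributes such as sim. The helper below is a minimal sketch of one possible workaround, not the repository's own fix; unwrap_model and the Fake* classes are hypothetical names used so the example runs without PyTorch.

```python
def unwrap_model(model):
    """Return the wrapped model if `model` exposes a `.module` attribute
    (as torch.nn.DataParallel does); otherwise return `model` unchanged."""
    return getattr(model, "module", model)


# Stand-in classes so the sketch runs without PyTorch or GPUs.
class FakeSim:
    pos_avg = 0.9


class FakeModel:
    sim = FakeSim()


class FakeDataParallel:
    def __init__(self, module):
        self.module = module


wrapped = FakeDataParallel(FakeModel())
print(unwrap_model(wrapped).sim.pos_avg)      # wrapped: goes through .module
print(unwrap_model(FakeModel()).sim.pos_avg)  # unwrapped: returned as-is
```

Under this assumption, the failing line would read the attribute from unwrap_model(model) instead of model. Another option is to expose a single GPU to the process, as the CUDA_VISIBLE_DEVICES prefix in a later command on this page does.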

Where is the code for the (fixed) generator?

Hi, thanks for sharing the great work!
In this line:

mlm_outputs = cls.discriminator(

I see that the discriminator is a BERT model, but I cannot find the generator referred to in Figure 1 of the paper, which says the generator is a DistilBERT. Could you show me where the generator is initialized and where it embeds the masked sentence, please?

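Is the source code for the Section 6 retrieval task available on GitHub?

Hello, and thank you for sharing the model's source code!
I am trying to reproduce the experiments in the paper. Have you put the code for the experiment in Table 9 (Retrieved top-3 examples by SimCSE and DiffCSE from STS-B test set) on GitHub? If so, could you point me to where it is? Thank you!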

RuntimeError: mat1 dim 1 must match mat2 dim 0

Hi, I am using your code to train DiffCSE with BERT-large and RoBERTa-large as the sentence encoder, and with distilbert-base-uncased and distilroberta-base as the generator. I always encounter "RuntimeError: mat1 dim 1 must match mat2 dim 0", which does not occur when I use BERT-base or RoBERTa-base. Could you give me some advice, please?
[image: screenshot]
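
This error is raised when the inner dimensions of a matrix product differ. According to the published model configurations, the large encoders use a hidden size of 1024, while the base and distilled models use 768. Whether that difference is what triggers the error here is an assumption, since the screenshot is not preserved; the sketch below only illustrates the shape rule. check_pair is a hypothetical helper, and the layer it imagines is illustrative.

```python
# Hidden sizes from the published model configurations.
HIDDEN_SIZE = {
    "bert-base-uncased": 768,
    "bert-large-uncased": 1024,
    "roberta-base": 768,
    "roberta-large": 1024,
    "distilbert-base-uncased": 768,
    "distilroberta-base": 768,
}


def can_matmul(shape_a, shape_b):
    """An (n, k) matrix can multiply a (k2, m) matrix only when k == k2."""
    return shape_a[1] == shape_b[0]


def check_pair(encoder, other, batch=64):
    """Hypothetical check: feed encoder outputs into a layer sized for `other`."""
    activations = (batch, HIDDEN_SIZE[encoder])         # plays the role of mat1
    weight = (HIDDEN_SIZE[other], HIDDEN_SIZE[other])   # plays the role of mat2
    return can_matmul(activations, weight)


print(check_pair("bert-base-uncased", "distilbert-base-uncased"))   # True
print(check_pair("bert-large-uncased", "distilbert-base-uncased"))  # False
```

If a layer somewhere in the pipeline is sized from one model and fed activations from the other, base-sized pairs line up by coincidence and large encoders do not, which would match the reported behaviour.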

Command to replicate transfer results

CUDA_VISIBLE_DEVICES=6 python train.py --model_name_or_path bert-base-uncased --generator_name distilbert-base-uncased --train_file data/nli_for_simcse.csv --num_train_epochs 2 --per_device_train_batch_size 64 --learning_rate 2e-6 --max_seq_length 32 --evaluation_strategy steps --metric_for_best_model stsb_spearman --load_best_model_at_end --eval_steps 125 --pooler_type cls --overwrite_output_dir --logging_first_step --logging_dir trained --temp 0.05 --do_train --do_eval --batchnorm --lambda_weight 0.05 --fp16 --masking_ratio 0.15 --output_dir trained_orig

I cannot replicate the transfer task results after training the model with the parameters above (note that the lambda, learning rate, and batch size values are taken from the appendix of the paper).

Results obtained:
***** Eval results *****
epoch = 2.0
eval_CR = 88.03
eval_MPQA = 88.01
eval_MR = 80.52
eval_MRPC = 73.97
eval_SST2 = 85.32
eval_SUBJ = 93.48
eval_TREC = 77.07
eval_avg_sts = 0.8235647840844718
eval_avg_transfer = 83.77142857142859
eval_sickr_spearman = 0.8026336149884777
eval_stsb_spearman = 0.8444959531804657

Could you please help with the right parameters to train the model from scratch in the supervised setting?

Index out of range in self

The vocab size of the sentence encoder is 50265, but the vocab size of the generator is 30522. This causes an "index out of range" error.
The relevant call and the observed id range are as follows:
cls.generator(mlm_input_ids, attention_mask)
mlm_input_ids.min()=0
mlm_input_ids.max()=50264
How should I initialize the two models so that their vocabularies are consistent?
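
An embedding lookup fails as soon as any token id is greater than or equal to the number of rows in the embedding table. The sketch below checks that condition with the numbers from this issue; ids_fit_vocab is a hypothetical helper, not part of the repository.

```python
def ids_fit_vocab(token_ids, vocab_size):
    """True when every id can index an embedding table with `vocab_size` rows."""
    return all(0 <= i < vocab_size for i in token_ids)


# Smallest and largest ids reported in the issue, and the generator vocabulary size.
mlm_input_ids = [0, 50264]
generator_vocab_size = 30522

print(ids_fit_vocab(mlm_input_ids, generator_vocab_size))  # False
print(ids_fit_vocab(mlm_input_ids, 50265))                 # True
```

One plausible way to keep the two consistent, stated here as an assumption, is to choose a generator that shares the encoder's tokenizer. Another issue on this page pairs RoBERTa encoders with distilroberta-base, whose vocabulary has 50265 entries.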

How to install?

Could you please share how to install DiffCSE? I tried the following:

pip install simcse
git clone git@github.com:voidism/DiffCSE.git
cd transformers-4.2.1
pip install .

then running

from diffcse import DiffCSE

results in ModuleNotFoundError

Thanks!
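
None of the steps above installs a package named diffcse: pip install simcse installs simcse, and pip install . inside transformers-4.2.1 installs the bundled transformers. The tracebacks on this page show a diffcse/ directory inside the clone (DiffCSE/diffcse/trainers.py), so the import can only succeed when the clone's root directory is on Python's search path. The sketch below shows one way to arrange that; add_repo_to_path is a hypothetical helper and ~/DiffCSE is a placeholder location.

```python
import sys
from pathlib import Path


def add_repo_to_path(repo_dir):
    """Put `repo_dir` at the front of sys.path if it is not already there."""
    repo = str(Path(repo_dir).expanduser().resolve())
    if repo not in sys.path:
        sys.path.insert(0, repo)
    return repo


# Hypothetical clone location; replace with wherever DiffCSE was cloned.
repo = add_repo_to_path("~/DiffCSE")
print(repo in sys.path)
# After this, `from diffcse import DiffCSE` can be resolved,
# provided the clone contains a `diffcse/` directory.
```

Starting Python from the root of the clone has the same effect without modifying sys.path.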

Implementing Error

I've been trying to run your code, but for some reason it doesn't seem to work in my setup.

I ran into this error mentioned below.

Is there any solution?

warnings.warn('Was asked to gather along dimension 0, but all '
Traceback (most recent call last):
  File "train.py", line 600, in <module>
    main()
  File "train.py", line 564, in main
    train_result = trainer.train(model_path=model_path)
  File "/home/qmin/DiffCSE/diffcse/trainers.py", line 514, in train
    tr_pos_sim += model.sim.pos_avg
  File "/home/qmin/anaconda3/envs/diffcse/lib/python3.7/site-packages/torch/nn/modules/module.py", line 779, in __getattr__
    type(self).__name__, name))
torch.nn.modules.module.ModuleAttributeError: 'DataParallel' object has no attribute 'sim'

What is the use of lm_head?

BertForCL and RobertaForCL both have a .lm_head but the uses seem to be commented out.

If the models only actually use .bert and .discriminator does that mean all of the transformer params are tied between the encoder and the discriminator?

Thanks

Discrepancies in DiffCSE Code Execution and Reported Results: Seeking Insight

I executed the DiffCSE source code on a Tesla V100-SXM2-32GB, using the identical configuration specified in the run_diffcse.sh file. The results I obtained, shown below, differ from those reported in your paper and on your GitHub repository. For example, there is a 3.24-point difference (78.49 - 75.25 = 3.24) in average STS score between your reported results and mine.

Do you have any insights or suggestions regarding the source of this disparity in performance when running the code to generate results? (@voidism)

[INFO|trainer.py:358] 2023-09-21 19:27:21,467 >> Using amp fp16 backend
09/21/2023 19:27:21 - INFO - __main__ -   *** Evaluate ***
tasks:  ['STSBenchmark', 'SICKRelatedness', 'STS12', 'STS13', 'STS14', 'STS15', 'STS16', 'MR', 'CR', 'SUBJ', 'MPQA', 'SST2', 'MRPC', 'TREC']
./SentEval/senteval/sts.py:42: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray
  sent1 = np.array([s.split() for s in sent1])[not_empty_idx]
./SentEval/senteval/sts.py:43: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray
  sent2 = np.array([s.split() for s in sent2])[not_empty_idx]
09/21/2023 19:27:54 - INFO - root -   Generating sentence embeddings
09/21/2023 19:28:02 - INFO - root -   Generated sentence embeddings
09/21/2023 19:28:02 - INFO - root -   Training pytorch-MLP-nhid0-rmsprop-bs128 with (inner) 5-fold cross-validation
09/21/2023 19:28:10 - INFO - root -   Best param found at split 1: l2reg = 0.001                 with score 82.31
09/21/2023 19:28:20 - INFO - root -   Best param found at split 2: l2reg = 0.001                 with score 81.99
09/21/2023 19:28:32 - INFO - root -   Best param found at split 3: l2reg = 0.0001                 with score 82.27
09/21/2023 19:28:42 - INFO - root -   Best param found at split 4: l2reg = 0.01                 with score 81.54
09/21/2023 19:28:53 - INFO - root -   Best param found at split 5: l2reg = 0.0001                 with score 82.04
09/21/2023 19:28:54 - INFO - root -   Generating sentence embeddings
09/21/2023 19:28:56 - INFO - root -   Generated sentence embeddings
09/21/2023 19:28:56 - INFO - root -   Training pytorch-MLP-nhid0-rmsprop-bs128 with (inner) 5-fold cross-validation
09/21/2023 19:28:59 - INFO - root -   Best param found at split 1: l2reg = 1e-05                 with score 87.81
09/21/2023 19:29:03 - INFO - root -   Best param found at split 2: l2reg = 0.0001                 with score 88.15
09/21/2023 19:29:07 - INFO - root -   Best param found at split 3: l2reg = 1e-05                 with score 87.32
09/21/2023 19:29:11 - INFO - root -   Best param found at split 4: l2reg = 1e-05                 with score 87.05
09/21/2023 19:29:15 - INFO - root -   Best param found at split 5: l2reg = 0.0001                 with score 87.25
09/21/2023 19:29:15 - INFO - root -   Generating sentence embeddings
09/21/2023 19:29:23 - INFO - root -   Generated sentence embeddings
09/21/2023 19:29:23 - INFO - root -   Training pytorch-MLP-nhid0-rmsprop-bs128 with (inner) 5-fold cross-validation
09/21/2023 19:29:32 - INFO - root -   Best param found at split 1: l2reg = 0.001                 with score 95.22
09/21/2023 19:29:42 - INFO - root -   Best param found at split 2: l2reg = 1e-05                 with score 95.51
09/21/2023 19:29:52 - INFO - root -   Best param found at split 3: l2reg = 0.0001                 with score 95.31
09/21/2023 19:30:01 - INFO - root -   Best param found at split 4: l2reg = 0.001                 with score 95.45
09/21/2023 19:30:09 - INFO - root -   Best param found at split 5: l2reg = 0.0001                 with score 95.46
09/21/2023 19:30:10 - INFO - root -   Generating sentence embeddings
09/21/2023 19:30:12 - INFO - root -   Generated sentence embeddings
09/21/2023 19:30:12 - INFO - root -   Training pytorch-MLP-nhid0-rmsprop-bs128 with (inner) 5-fold cross-validation
09/21/2023 19:30:21 - INFO - root -   Best param found at split 1: l2reg = 0.001                 with score 89.16
09/21/2023 19:30:29 - INFO - root -   Best param found at split 2: l2reg = 1e-05                 with score 88.19
09/21/2023 19:30:37 - INFO - root -   Best param found at split 3: l2reg = 0.001                 with score 88.91
09/21/2023 19:30:45 - INFO - root -   Best param found at split 4: l2reg = 0.001                 with score 88.44
09/21/2023 19:30:54 - INFO - root -   Best param found at split 5: l2reg = 0.001                 with score 88.93
09/21/2023 19:30:55 - INFO - root -   Computing embedding for train
09/21/2023 19:31:22 - INFO - root -   Computed train embeddings
09/21/2023 19:31:22 - INFO - root -   Computing embedding for dev
09/21/2023 19:31:23 - INFO - root -   Computed dev embeddings
09/21/2023 19:31:23 - INFO - root -   Computing embedding for test
09/21/2023 19:31:24 - INFO - root -   Computed test embeddings
09/21/2023 19:31:24 - INFO - root -   Training pytorch-MLP-nhid0-rmsprop-bs128 with standard validation..
09/21/2023 19:31:36 - INFO - root -   [('reg:1e-05', 87.73), ('reg:0.0001', 87.84), ('reg:0.001', 87.61), ('reg:0.01', 86.93)]
09/21/2023 19:31:36 - INFO - root -   Validation : best param found is reg = 0.0001 with score             87.84
09/21/2023 19:31:36 - INFO - root -   Evaluating...
09/21/2023 19:31:39 - INFO - root -   ***** Transfer task : MRPC *****


09/21/2023 19:31:39 - INFO - root -   Computing embedding for train
09/21/2023 19:31:45 - INFO - root -   Computed train embeddings
09/21/2023 19:31:45 - INFO - root -   Computing embedding for test
09/21/2023 19:31:47 - INFO - root -   Computed test embeddings
09/21/2023 19:31:47 - INFO - root -   Training pytorch-MLP-nhid0-rmsprop-bs128 with 5-fold cross-validation
09/21/2023 19:31:51 - INFO - root -   [('reg:1e-05', 74.85), ('reg:0.0001', 74.85), ('reg:0.001', 74.93), ('reg:0.01', 74.07)]
09/21/2023 19:31:51 - INFO - root -   Cross-validation : best param found is reg = 0.001             with score 74.93
09/21/2023 19:31:51 - INFO - root -   Evaluating...
09/21/2023 19:31:52 - INFO - root -   ***** Transfer task : TREC *****


09/21/2023 19:31:54 - INFO - root -   Computed train embeddings
09/21/2023 19:31:54 - INFO - root -   Computed test embeddings
09/21/2023 19:31:54 - INFO - root -   Training pytorch-MLP-nhid0-rmsprop-bs128 with 5-fold cross-validation
09/21/2023 19:32:00 - INFO - root -   [('reg:1e-05', 84.15), ('reg:0.0001', 84.02), ('reg:0.001', 83.47), ('reg:0.01', 76.76)]
09/21/2023 19:32:00 - INFO - root -   Cross-validation : best param found is reg = 1e-05             with score 84.15
09/21/2023 19:32:00 - INFO - root -   Evaluating...
09/21/2023 19:32:00 - INFO - __main__ -   ***** Eval results *****
09/21/2023 19:32:00 - INFO - __main__ -     STS12 = 0.6466070114897755
09/21/2023 19:32:00 - INFO - __main__ -     STS13 = 0.7940081784855644
09/21/2023 19:32:00 - INFO - __main__ -     STS14 = 0.7106309581907064
09/21/2023 19:32:00 - INFO - __main__ -     STS15 = 0.8022190201969241
09/21/2023 19:32:00 - INFO - __main__ -     STS16 = 0.7800045550188356
09/21/2023 19:32:00 - INFO - __main__ -     eval_CR = 87.52
09/21/2023 19:32:00 - INFO - __main__ -     eval_MPQA = 88.73
09/21/2023 19:32:00 - INFO - __main__ -     eval_MR = 82.03
09/21/2023 19:32:00 - INFO - __main__ -     eval_MRPC = 74.93
09/21/2023 19:32:00 - INFO - __main__ -     eval_SST2 = 87.84
09/21/2023 19:32:00 - INFO - __main__ -     eval_SUBJ = 95.39
09/21/2023 19:32:00 - INFO - __main__ -     eval_TREC = 84.15
09/21/2023 19:32:00 - INFO - __main__ -     eval_avg_sts = 0.7525457395203998
09/21/2023 19:32:00 - INFO - __main__ -     eval_avg_transfer = 85.79857142857144
09/21/2023 19:32:00 - INFO - __main__ -     eval_sickr_spearman = 0.734116144071677
09/21/2023 19:32:00 - INFO - __main__ -     eval_stsb_spearman = 0.8002343091893147
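
The averages and the gap quoted in this issue can be recomputed from the log itself. The sketch below copies the per-task values from the "Eval results" block and takes 78.49 from the issue text as the reported average; nothing else is assumed.

```python
# Values copied from the "Eval results" block in the log above.
sts = {
    "STS12": 0.6466070114897755,
    "STS13": 0.7940081784855644,
    "STS14": 0.7106309581907064,
    "STS15": 0.8022190201969241,
    "STS16": 0.7800045550188356,
    "STSBenchmark": 0.8002343091893147,
    "SICKRelatedness": 0.734116144071677,
}
transfer = [87.52, 88.73, 82.03, 74.93, 87.84, 95.39, 84.15]

avg_sts = sum(sts.values()) / len(sts)
avg_transfer = sum(transfer) / len(transfer)
reported_avg_sts = 78.49  # figure quoted in the issue text

print(round(avg_sts * 100, 2))                                # 75.25
print(round(avg_transfer, 2))                                 # 85.8
print(round(reported_avg_sts - round(avg_sts * 100, 2), 2))   # 3.24
```

Both recomputed means agree with eval_avg_sts and eval_avg_transfer in the log, so the shortfall comes from the individual STS scores rather than from how the average is formed.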

Question about scalability wrt. input length

Howdy,

I was wondering whether any experiments were done with the DiffCSE framework on long inputs (300-500 tokens), i.e., are there any conditions on the training data that are necessary for convergence?

PS - congrats on the paper, it was a really fun read :)

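Unable to load an offline dataset

[image]

As the screenshot shows, I have already downloaded the dataset for offline use, but how do I load it?
It keeps raising an error at this line:
datasets = load_dataset(extension, data_files=data_files, cache_dir="./data/")
Could you please take a look at why this happens?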
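
Since the screenshot is not preserved, the actual error is unknown. The sketch below covers only two checks that can be made before the call: that the local file exists, and which builder name its extension maps to. builder_for is a hypothetical helper, and the mapping of .txt to the "text" builder is a common convention assumed here. HF_DATASETS_OFFLINE is an environment variable honoured by recent versions of the datasets library. Some older releases fetch the builder script itself over the network, so the variable may not be sufficient there.

```python
import os
from pathlib import Path


def builder_for(train_file):
    """Map a local file name to the builder name passed to load_dataset.
    Assumes the convention that .txt files use the "text" builder."""
    path = Path(train_file)
    if not path.is_file():
        raise FileNotFoundError(f"training file not found: {path}")
    extension = path.suffix.lstrip(".")
    return "text" if extension == "txt" else extension


# Ask the datasets library not to contact the Hub (set before importing it).
os.environ["HF_DATASETS_OFFLINE"] = "1"

# Example call, assuming the file has been downloaded to data/:
# train_file = "data/wiki1m_for_simcse.txt"
# extension = builder_for(train_file)
# datasets = load_dataset(extension, data_files={"train": train_file}, cache_dir="./data/")
```

If the helper raises FileNotFoundError, the problem is the path rather than the library.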
