gitabtion / bertbasedcorrectionmodels

PyTorch implementations of BERT-based Spelling Error Correction Models. (BERT-based text correction models, implemented in PyTorch.)

License: Apache License 2.0

Languages: Python 99.91%, Shell 0.09%
Topics: csc, pytorch, transformers

bertbasedcorrectionmodels's People

Contributors: gitabtion, okcd00, yazooliu

bertbasedcorrectionmodels's Issues

ValueError: Expected input batch_size (80) to match target batch_size (88)

I can train the model correctly with the public datasets, but when I use my company's data to train the model, an error occurs:
ValueError: Expected input batch_size (80) to match target batch_size (88)

Note: my data format is JSON, the same as the public datasets above. A typical example:
{
  "id": "--",
  "original_text": "播放我的世界之梦想大陆",
  "wrong_ids": [],
  "correct_text": "播放我的世界之梦想大陆"
}

Other languages

Great work!

I wonder if this approach can be adopted for other languages?

Cannot reproduce the result

I followed the steps but got a different result.

Epoch 9: 100%|█████████████████████████████████████████████████████████████████████████| 199/199 [00:55<00:00, 3.56it/s, loss=0.103, v_num=1]
/home/dell/workspace/jiangbingyu/correction/checkpoints/SoftMaskedBert/epoch=09-val_loss=0.13123.ckpt
Testing: 0it [00:00, ?it/s]2021-09-08 23:47:58,342 SoftMaskedBertModel INFO: Testing...
Testing: 97%|█████████████████████████████████████████████████████████████████████████████████████████████▏ | 67/69 [00:03<00:00, 18.43it/s]
2021-09-08 23:48:02,103 SoftMaskedBertModel INFO: Test.
2021-09-08 23:48:02,105 SoftMaskedBertModel INFO: loss: 0.08779423662285873
2021-09-08 23:48:02,105 SoftMaskedBertModel INFO: Detection:
acc: 0.5000
2021-09-08 23:48:02,106 SoftMaskedBertModel INFO: Correction:
acc: 0.6900
2021-09-08 23:48:02,114 SoftMaskedBertModel INFO: The detection result is precision=0.8228782287822878, recall=0.6308345120226309 and F1=0.7141713370696557
2021-09-08 23:48:02,115 SoftMaskedBertModel INFO: The correction result is precision=0.7399103139013453, recall=0.6534653465346535 and F1=0.694006309148265
2021-09-08 23:48:02,116 SoftMaskedBertModel INFO: Sentence Level: acc:0.690000, precision:0.829508, recall:0.466790, f1:0.597403
Testing: 100%|████████████████████████████████████████████████████████████████████████████████████████████████| 69/69 [00:03<00:00, 18.27it/s]

DATALOADER:0 TEST RESULTS
{'val_loss': 0.08779423662285873}

Evaluation script issue

I trained the BERT correction model with your configuration. Using your evaluation script:

Sentence Level: 
acc:0.793636, precision:0.828810, recall:0.732472, f1:0.777669

Using the evaluation script from the ReaLiSe model:

{'sent-detect-acc': 82.18181818181817, 
'sent-detect-p': 72.86689419795222, 
'sent-detect-r': 78.9279112754159, 
'sent-detect-f1': 75.77639751552793, 
'sent-correct-acc': 79.9090909090909, 
'sent-correct-p': 68.60068259385666, 
'sent-correct-r': 74.3068391866913, 
'sent-correct-f1': 71.33984028393967}

You only count false positives when src == tgt, so the FP count comes out too low; that makes the denominator in the precision computation too small, and the final precision is inflated.
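A minimal sketch of the counting logic in question (my reading of sentence-level metrics, not the repo's exact code): a false positive should be counted whenever the model changes a sentence and the result is not the target, regardless of whether src == tgt.

def sentence_level_prf(srcs, tgts, preds):
    """Sentence-level precision/recall over corrected outputs."""
    tp = fp = fn = 0
    for src, tgt, pred in zip(srcs, tgts, preds):
        if pred != src:          # the model made a change
            if pred == tgt:
                tp += 1
            else:
                fp += 1          # counted even when src != tgt
        elif src != tgt:
            fn += 1              # an error the model missed
    p = tp / (tp + fp) if tp + fp else 0.0
    r = tp / (tp + fn) if tp + fn else 0.0
    return p, r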

Question about HYPER_PARAMS in train_SoftMaskedBert

Does the HYPER_PARAMS value in train_SoftMaskedBert refer to the relative weighting of the detection loss and the correction loss?
Is this the value applied in loss = self.w * outputs[1] + (1 - self.w) * outputs[0] in CscTrainingModel.training_step?
That is, 0.8 would mean detection * 0.2 + correction * 0.8.
Can I adjust this value to make the model focus on improving detection precision/recall/F1?
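If that reading is correct, a minimal sketch of the weighting (my illustration; which output index holds which loss is an assumption based on the line quoted above):

w = cfg.MODEL.HYPER_PARAMS[0]     # e.g. 0.8
loss = w * correction_loss + (1 - w) * detection_loss
# Lowering w (e.g. to 0.5) would shift training weight toward detection.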

Running train_csc.py raises: AttributeError: Can't pickle local object 'get_csc_loader.<locals>._collate_fn'

Dependencies were installed strictly at the given version numbers.
D:\SoftRun\Anaconda3\envs\torch161\python.exe E:/nlpcode/BertBasedCorrectionModels-master/tools/train_csc.py
2021-04-30 13:58:02,496 bert4csc INFO: Namespace(config_file='', opts=[])
2021-04-30 13:58:02,496 bert4csc INFO: Loaded configuration file csc/train_bert4csc.yml
2021-04-30 13:58:02,496 bert4csc INFO:
MODEL:
  BERT_CKPT: "bert-base-chinese"
  DEVICE: "cuda:0"
  NAME: "bert4csc"
  # [loss_coefficient]
  HYPER_PARAMS: [ 1.0 ]
  GPU_IDS: [0]

DATASETS:
  TRAIN: "datasets/csc/train.json"
  VALID: "datasets/csc/dev.json"
  TEST: "datasets/csc/test.json"

SOLVER:
  BASE_LR: 0.001
  WEIGHT_DECAY: 0.00001
  BATCH_SIZE: 16
  WARMUP_EPOCHS: 8
  MAX_EPOCHS: 20
  ACCUMULATE_GRAD_BATCHES: 16

TEST:
  BATCH_SIZE: 16

TASK:
  NAME: "csc"

OUTPUT_DIR: "checkpoints/bert4csc"

2021-04-30 13:58:02,496 bert4csc INFO: Running with config:
DATALOADER:
  NUM_WORKERS: 4
DATASETS:
  TEST: datasets/csc/test.json
  TRAIN: datasets/csc/train.json
  VALID: datasets/csc/dev.json
INPUT:
  MAX_LEN: 512
MODE: ['train', 'test']
MODEL:
  BERT_CKPT: bert-base-chinese
  DEVICE: cuda:0
  GPU_IDS: [0]
  HYPER_PARAMS: [1.0]
  NAME: bert4csc
  NUM_CLASSES: 10
  WEIGHTS:
OUTPUT_DIR: checkpoints/bert4csc
SOLVER:
  ACCUMULATE_GRAD_BATCHES: 16
  BASE_LR: 0.001
  BATCH_SIZE: 16
  BIAS_LR_FACTOR: 2
  CHECKPOINT_PERIOD: 10
  GAMMA: 0.1
  LOG_PERIOD: 100
  MAX_EPOCHS: 20
  MOMENTUM: 0.9
  OPTIMIZER_NAME: AdamW
  STEPS: (30000,)
  WARMUP_EPOCHS: 8
  WARMUP_FACTOR: 0.3333333333333333
  WARMUP_ITERS: 500
  WARMUP_METHOD: linear
  WEIGHT_DECAY: 1e-05
  WEIGHT_DECAY_BIAS: 0
TASK:
  NAME: csc
TEST:
  BATCH_SIZE: 16
  CKPT_FN:
Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']

  • This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).

  • This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).

D:\SoftRun\Anaconda3\envs\torch161\lib\site-packages\pytorch_lightning\utilities\distributed.py:49: UserWarning: Checkpoint directory E:\nlpcode\BertBasedCorrectionModels-master\checkpoints\bert4csc exists and is not empty.
  warnings.warn(*args, **kwargs)
GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name | Type            | Params
-----------------------------------
0 | bert | BertForMaskedLM | 102 M
-----------------------------------
102 M     Trainable params
0         Non-trainable params
102 M     Total params
Validation sanity check: 0it [00:00, ?it/s]2021-04-30 13:58:11,384 bert4csc INFO: Valid.
Traceback (most recent call last):
  File "E:/nlpcode/BertBasedCorrectionModels-master/tools/train_csc.py", line 53, in <module>
    main()
  File "E:/nlpcode/BertBasedCorrectionModels-master/tools/train_csc.py", line 49, in main
    train(cfg, model, loaders, ckpt_callback)
  File "E:\nlpcode\BertBasedCorrectionModels-master\tools\bases.py", line 78, in train
    trainer.fit(model, train_loader, valid_loader)
  File "D:\SoftRun\Anaconda3\envs\torch161\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 470, in fit
    results = self.accelerator_backend.train()
  File "D:\SoftRun\Anaconda3\envs\torch161\lib\site-packages\pytorch_lightning\accelerators\gpu_accelerator.py", line 68, in train
    results = self.train_or_test()
  File "D:\SoftRun\Anaconda3\envs\torch161\lib\site-packages\pytorch_lightning\accelerators\accelerator.py", line 69, in train_or_test
    results = self.trainer.train()
  File "D:\SoftRun\Anaconda3\envs\torch161\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 492, in train
    self.run_sanity_check(self.get_model())
  File "D:\SoftRun\Anaconda3\envs\torch161\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 690, in run_sanity_check
    _, eval_results = self.run_evaluation(test_mode=False, max_batches=self.num_sanity_val_batches)
  File "D:\SoftRun\Anaconda3\envs\torch161\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 593, in run_evaluation
    for batch_idx, batch in enumerate(dataloader):
  File "D:\SoftRun\Anaconda3\envs\torch161\lib\site-packages\torch\utils\data\dataloader.py", line 291, in __iter__
    return _MultiProcessingDataLoaderIter(self)
  File "D:\SoftRun\Anaconda3\envs\torch161\lib\site-packages\torch\utils\data\dataloader.py", line 737, in __init__
    w.start()
  File "D:\SoftRun\Anaconda3\envs\torch161\lib\multiprocessing\process.py", line 105, in start
    self._popen = self._Popen(self)
  File "D:\SoftRun\Anaconda3\envs\torch161\lib\multiprocessing\context.py", line 223, in _Popen
    return _default_context.get_context().Process._Popen(process_obj)
  File "D:\SoftRun\Anaconda3\envs\torch161\lib\multiprocessing\context.py", line 322, in _Popen
    return Popen(process_obj)
  File "D:\SoftRun\Anaconda3\envs\torch161\lib\multiprocessing\popen_spawn_win32.py", line 65, in __init__
    reduction.dump(process_obj, to_child)
  File "D:\SoftRun\Anaconda3\envs\torch161\lib\multiprocessing\reduction.py", line 60, in dump
    ForkingPickler(file, protocol).dump(obj)
AttributeError: Can't pickle local object 'get_csc_loader.<locals>._collate_fn'

Process finished with exit code 1
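A minimal workaround sketch (my suggestion, not repo code): on Windows, DataLoader workers are started with spawn, so the collate function must be picklable, and a function defined inside get_csc_loader is not. Either set DATALOADER.NUM_WORKERS to 0 in the config, or hoist the collate function to module level, roughly like this (the batch fields are illustrative):

def _collate_fn(batch):
    # Module-level, therefore picklable by multiprocessing's ForkingPickler.
    ori_texts, cor_texts, wrong_idss = zip(*batch)   # illustrative field names
    return list(ori_texts), list(cor_texts), list(wrong_idss)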

The trained model cannot be loaded; the program downloads a model from HuggingFace instead. Why?

Hello, when I call the load_model_directly() method in inference.py, the trained model cannot be loaded. The code is as follows:

① Code:

def load_model_directly():
    ckpt_file = 'SoftMaskedBert/epoch=05-val_loss=0.03253.ckpt'
    config_file = 'csc/train_SoftMaskedBert.yml'

    from bbcm.config import cfg
    cp = get_abs_path('checkpoints', ckpt_file)
    cfg.merge_from_file(get_abs_path('configs', config_file))
    tokenizer = BertTokenizer.from_pretrained(cfg.MODEL.BERT_CKPT)
    print("### tokenizer loaded")
    print("### tokenizer: ", tokenizer)
    if cfg.MODEL.NAME in ['bert4csc', 'macbert4csc']:
        model = BertForCsc.load_from_checkpoint(cp,
                                                cfg=cfg,
                                                tokenizer=tokenizer)
    else:
        print("### loading model")
        print("### cp: ", cp)
        model = SoftMaskedBertModel.load_from_checkpoint(cp,
                                                         cfg=cfg,
                                                         tokenizer=tokenizer)
    print("### model loaded")
    model.eval()
    model.to(cfg.MODEL.DEVICE)
    return model

② Question:
This code does not seem to take effect: the ckpt file is not loaded, and the program still downloads from HuggingFace.

model = SoftMaskedBertModel.load_from_checkpoint(cp,
                                                 cfg=cfg,
                                                 tokenizer=tokenizer)

I looked into load_from_checkpoint(), but I do not understand how the cp and cfg arguments are passed.

Inference question

How can the position of the wrong word be output during inference?
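A minimal post-processing sketch (my own, not the repo's API): since CSC keeps the output the same length as the input, the positions of corrected characters can be recovered by a character-by-character comparison.

def error_positions(src, pred):
    """Indices where the prediction differs from the input."""
    return [i for i, (a, b) in enumerate(zip(src, pred)) if a != b]

texts = ['今天我很高心']
preds = model.predict(texts)                    # model loaded as in the README
print(error_positions(texts[0], preds[0]))      # e.g. [5]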

Problem with certain Chinese punctuation marks

Hello! Thanks for open-sourcing the model!
While testing I found that if the sentence to be corrected contains certain Chinese punctuation marks, "[UNK]" appears in the result. How should this be handled?

texts = ['今天我很“高心”']
model.predict(texts)

results:
['今天我很[UNK]高兴[UNK]']
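A minimal workaround sketch (my own, assuming the output is the input-aligned character sequence with '[UNK]' as a literal 5-character marker): copy the original character back wherever the tokenizer produced [UNK].

def restore_unk(src, pred):
    """Replace each literal '[UNK]' in pred with the aligned char from src."""
    out, i, j = [], 0, 0
    while j < len(pred):
        if pred.startswith('[UNK]', j):
            out.append(src[i]); i += 1; j += 5
        else:
            out.append(pred[j]); i += 1; j += 1
    return ''.join(out)

print(restore_unk('今天我很“高心”', '今天我很[UNK]高兴[UNK]'))
# -> 今天我很“高兴”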

Question for the author: where is the candidate character set of the correction network loaded from?

Other correction models I have used load pinyin and similar-shape character dictionaries; during correction, candidates from the dictionary are substituted in and scored by the model.
But in this Soft-Masked BERT, once e' is fed into the correction network, the corrected characters are produced directly. Where do these characters come from? From the vocab file?
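A minimal sketch of my understanding (the variable names are assumptions, not the repo's API): the correction network is a masked-LM head, so its candidate set is simply the BERT vocabulary (vocab.txt); no pinyin or shape dictionaries are involved.

# soft_masked_emb: embeddings produced by the soft-masking step (assumed name)
logits = model.bert(inputs_embeds=soft_masked_emb).logits   # [batch, seq, vocab]
pred_ids = logits.argmax(dim=-1)
pred_chars = tokenizer.convert_ids_to_tokens(pred_ids[0])   # drawn from vocab.txt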

Encoding error during data preprocessing

On the first run, an encoding error occurs while preprocessing B1_training.sgml. The file was downloaded from the URL given in the README. I tried passing encoding='utf-8' to open(), but it did not help. Inspecting the file manually, I could not see anything wrong either. The error first seems to occur while processing line 5842:

<PASSAGE id="B1-0826-1">因為那是我的第一次去北京,我的朋友就是我的導遊。跟他我們一起去了北京特別的地方,必如說長城、故宮、天堂公園什麼的。</PASSAGE>

Traceback (most recent call last):
  File "/home/BertBasedCorrect/tools/train_csc.py", line 51, in <module>
    main()
  File "/home/BertBasedCorrect/tools/train_csc.py", line 28, in main
    preproc()
  File "/home/BertBasedCorrect/bbcm/data/processors/csc.py", line 185, in preproc
    for item in read_data(get_abs_path('datasets', 'csc')):
  File "/home/BertBasedCorrect/bbcm/data/processors/csc.py", line 116, in read_data
    for line in f:
  File "/home/anaconda3/envs/torch/lib/python3.7/codecs.py", line 322, in decode
    (result, consumed) = self._buffer_decode(data, self.errors, final)
UnicodeDecodeError: 'utf-8' codec can't decode byte 0x80 in position 4867: invalid start byte

Process finished with exit code 1
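A minimal localization sketch (my own, not repo code; the path is an assumption): read the file leniently so the undecodable bytes can be found and inspected instead of crashing the run.

# errors='replace' turns bad bytes into U+FFFD instead of raising.
with open('datasets/csc/B1_training.sgml', encoding='utf-8', errors='replace') as f:
    for lineno, line in enumerate(f, 1):
        if '\ufffd' in line:
            print(lineno, line.strip())   # locate and inspect the bad bytes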

Why is the accuracy so high?

Why is the accuracy obtained here so much higher than in the paper? Even plain BERT fine-tuning reaches state of the art.
At first I thought it was because the training data differed, i.e. dirty data had been removed, so I added all the training data back, but the accuracy was still very high; even BERT fine-tuning qualifies as SOTA.
Why is that?

Could I join the project and open a pull request?

Hello!
Good job!

I reproduced your training, took the best model, and ran inference on my own business samples (I do not have much business data), and found an issue.
During inference the model overcorrects: it changes a correct word into a wrong one, like this:
{
  "paragraph": "在本合同中,除上下文另有规定外,下列用语应当具有如下含义:",
  "error_fragments": [
    {
      "error_init_id": 26,
      "error_end_id": 27,
      "src_fragment": "含",   // the correct word in the paragraph
      "tgt_fragment": "涵"    // the wrong word output by the model
    }
  ]
}
So a white_name_list config plus some code could fix this overcorrection issue; see the sketch below.
I'd like to contribute this code to tools/inference.py via a pull request.
Do you approve, and do you have any better ideas? Let's keep talking, thanks.
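A minimal sketch of the proposed post-processing (my illustration; white_name_list and the equal-length assumption are mine, not the author's code):

white_name_list = {'含'}    # characters the corrector must leave unchanged

def apply_white_list(src, pred):
    """Revert any 'correction' applied to a whitelisted source character."""
    return ''.join(s if s in white_name_list else p for s, p in zip(src, pred))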

Best regards,
Yazhou

MODEL.NUM_CLASSES = 10

When training on one's own data, should NUM_CLASSES be changed to the number of entries in the vocab?

Error when training on CPU

The config file used:
MODEL:
  BERT_CKPT: "bert-base-chinese"
  DEVICE: "cpu"
  NAME: "SoftMaskedBertModel"
  # [loss_coefficient]
  HYPER_PARAMS: [0.8]

DATASETS:
  TRAIN: "datasets/csc/train.json"
  VALID: "datasets/csc/dev.json"
  TEST: "datasets/csc/test.json"

SOLVER:
  BASE_LR: 0.0001
  WEIGHT_DECAY: 5e-8
  BATCH_SIZE: 32
  MAX_EPOCHS: 10
  ACCUMULATE_GRAD_BATCHES: 4

TEST:
  BATCH_SIZE: 16

TASK:
  NAME: "csc"

OUTPUT_DIR: "checkpoints/SoftMaskedBert"

Command:
python tools/train_csc.py --config_file csc/train_SoftMaskedBert.yml

Error:
/Users//opt/anaconda3/envs/sc/lib/python3.6/site-packages/pytorch_lightning/utilities/distributed.py:49: UserWarning: Checkpoint directory /Users///Documents/personal/BertBasedCorrectionModels/checkpoints/SoftMaskedBert exists and is not empty.
  warnings.warn(*args, **kwargs)
GPU available: False, used: False
TPU available: None, using: 0 TPU cores
Traceback (most recent call last):
  File "tools/train_csc.py", line 52, in <module>
    main()
  File "tools/train_csc.py", line 48, in main
    train(cfg, model, loaders, ckpt_callback)
  File "/Users///Documents/personal/BertBasedCorrectionModels/tools/bases.py", line 78, in train
    trainer.fit(model, train_loader, valid_loader)
  File "/Users//opt/anaconda3/envs/sc/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py", line 454, in fit
    self.accelerator_backend.setup(model)
  File "/Users///opt/anaconda3/envs/sc/lib/python3.6/site-packages/pytorch_lightning/accelerators/cpu_accelerator.py", line 49, in setup
    self.setup_optimizers(model)
  File "/Users///opt/anaconda3/envs/sc/lib/python3.6/site-packages/pytorch_lightning/accelerators/accelerator.py", line 145, in setup_optimizers
    optimizers, lr_schedulers, optimizer_frequencies = self.trainer.init_optimizers(model)
  File "/Users///opt/anaconda3/envs/sc/lib/python3.6/site-packages/pytorch_lightning/trainer/optimizers.py", line 30, in init_optimizers
    optim_conf = model.configure_optimizers()
  File "/Users///Documents/personal/BertBasedCorrectionModels/bbcm/engine/bases.py", line 21, in configure_optimizers
    scheduler = build_lr_scheduler(self.cfg, optimizer)
  File "/Users///Documents/personal/BertBasedCorrectionModels/bbcm/solver/build.py", line 49, in build_lr_scheduler
    scheduler = getattr(lr_scheduler, cfg.SOLVER.SCHED)(**scheduler_args)
  File "/Users///Documents/personal/BertBasedCorrectionModels/bbcm/solver/lr_scheduler.py", line 73, in __init__
    super().__init__(optimizer, last_epoch, verbose)
TypeError: __init__() takes from 2 to 3 positional arguments but 4 were given

About the meaning of det_labels

Hello, thanks for open-sourcing this. What is the purpose of det_labels during model training?

class BertForCsc(CscTrainingModel):
    def __init__(self, cfg, tokenizer):
        super().__init__(cfg)
        self.cfg = cfg
        self.bert = BertForMaskedLM.from_pretrained(cfg.MODEL.BERT_CKPT)
        self.tokenizer = tokenizer

    def forward(self, texts, cor_labels=None, det_labels=None):
        # print('text: ', texts)
        # print('cor_labels: ', cor_labels)
        # print('det labels: ', det_labels)
        if cor_labels is not None:
            # correct labels are provided
            text_labels = self.tokenizer(cor_labels, padding=True, return_tensors='pt')['input_ids']
            text_labels = text_labels.to(self.cfg.MODEL.DEVICE)
            print('text labels: ', text_labels)
            # Tokens with indices set to -100 are ignored (masked)
            text_labels[text_labels == 0] = -100

I see now: det_labels are for error detection, but it seems this model does not actually perform detection, right? It appears to do correction directly, which is why the computed det_acc is constantly 1.
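For reference, a minimal sketch of what det_labels encode (my reading, consistent with the wrong_ids field in the dataset format): a binary tag per character, 1 wherever the original text differs from the correct text.

orig, corr = '今天我很高心', '今天我很高兴'
det_labels = [int(o != c) for o, c in zip(orig, corr)]
print(det_labels)   # -> [0, 0, 0, 0, 0, 1]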

Accuracy problem when reproducing the model

Hello! In the GitHub README I see that the test accuracy differs considerably from the paper. Is this because the test datasets are different? How can the paper's accuracy be reached? Looking forward to your reply.
This project's results: (screenshot)
Results in the paper: (screenshot)

Error when training with my own dataset

Training with the author's dataset works, but training with my own dataset fails:
ValueError: Expected input batch_size (304) to match target batch_size (336).

Merging my data into the author's dataset also trains fine; only training on my data alone fails. Should some other parameter be changed as well?

Suggestion: update README.md to show the model's predicted output and its format (the output data type is a list)

Hi,
Good job!

In README.md part 2:
from tools.inference import *
ckpt_fn = 'SoftMaskedBert/epoch=02-val_loss=0.02904.ckpt' # find it in checkpoints/
config_file = 'csc/train_SoftMaskedBert.yml' # find it in configs/
model = load_model_directly(ckpt_fn=ckpt_fn, config_file=config_file)
texts = ['今天我很高心', '测试', '继续测试']
model.predict(texts)
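For illustration (my assumption of the output, based on examples in other issues here), the call would then print something like:

print(model.predict(texts))
# -> ['今天我很高兴', '测试', '继续测试']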

I suggest updating this part to show the model's predicted output and its format (the output data type is a list).
Or would you allow me to update it via a pull request?
Let's keep talking.

Best Regards
Yazhou
