onion12138 / casrelpytorch
Reimplement CasRel model in PyTorch. A PyTorch reimplementation of Jilin University's CasRel model, trained and tested on the Baidu relation extraction dataset.
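Hello! I defined my dataset in the format described in the README, but at runtime, when loss_fn computes the loss for predict['obj_heads'], the gold tensor passed in has one extra dimension. This does not happen with the baidu dataset bundled with the code, and the data formats are identical. Is there some difference in how the data is processed?
The full error message is: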
File "/CasRelPyTorch/Run.py", line 38, in loss_fn
loss = F.binary_cross_entropy(pred, gold, reduction='none')
File "/python3.7/site-packages/torch/nn/functional.py", line 2752, in binary_cross_entropy
"Please ensure they have the same size.".format(target.size(), input.size())
ValueError: Using a target size (torch.Size([8, 148, 1])) that is different to the input size (torch.Size([8, 148])) is deprecated. Please ensure they have the same size.
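Hello, and thank you for open-sourcing this code.
I have a question. In the __getitem__() function in data.py, when building the objects for a subject, random.choice selects only one subject and its objects as the training data. Are the remaining subjects not turned into data at all? A sentence may contain several one-to-many relations; if the others are never used as data, how is the precision and completeness of extraction for a single sentence ensured?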
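The sampling behaviour described in this question can be sketched as follows. This is a minimal illustration, not the repository's data.py: the `s2ro_map` name, the helper functions, and the toy spans are all assumptions made for the example. It contrasts picking one subject at random per sentence with emitting one training example per subject.

```python
import random

# Hypothetical mapping from a subject span to its (object span, relation id) pairs.
# The name `s2ro_map` and the toy values below are assumptions for illustration.
s2ro_map = {
    (0, 1): [((5, 8), 3)],
    (10, 11): [((13, 17), 7), ((20, 21), 2)],
}

def sample_one_subject(s2ro_map, rng=random):
    """Pick a single subject at random and return it with its objects."""
    subject = rng.choice(list(s2ro_map.keys()))
    return subject, s2ro_map[subject]

def expand_all_subjects(s2ro_map):
    """Return one (subject, objects) example per subject in the sentence."""
    return [(subject, objects) for subject, objects in s2ro_map.items()]

one_example = sample_one_subject(s2ro_map, random.Random(0))
all_examples = expand_all_subjects(s2ro_map)
```

If the random choice is made each time an item is fetched, a different subject may be drawn on later passes over the same sentence, so the other subjects can still be seen across epochs. Expanding all subjects makes that coverage explicit at the cost of a larger dataset.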
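Hello, after how many epochs does training converge on the Chinese training set? F1 and the other metrics stay at 0, even after 10 epochs. Is this normal?
I switched to the full DuIE Baidu dataset, and training failed about 10% of the way through the first epoch: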
Epoch 1/10: 10%|▉ | 10600/106960 [37:13<5:41:25, 4.70it/s, loss:0.09831]
0%| | 0/20652 [00:00<?, ?it/s]
Traceback (most recent call last):
File "C:\Users\Administrator\Desktop\NLP\CasRelPyTorch-master\Run.py", line 63, in <module>
trainer.train()
File "D:\Anaconda3\lib\site-packages\fastNLP\core\trainer.py", line 622, in train
raise e
File "D:\Anaconda3\lib\site-packages\fastNLP\core\trainer.py", line 615, in train
self._train()
File "D:\Anaconda3\lib\site-packages\fastNLP\core\trainer.py", line 720, in _train
self.callback_manager.on_epoch_end()
File "D:\Anaconda3\lib\site-packages\fastNLP\core\callback.py", line 314, in wrapper
returns.append(getattr(callback, func.__name__)(*arg))
File "C:\Users\Administrator\Desktop\NLP\CasRelPyTorch-master\model\callback.py", line 30, in on_epoch_end
precision, recall, f1_score = metric(self.data_iter, self.rel_vocab, self.config, self.model)
File "C:\Users\Administrator\Desktop\NLP\CasRelPyTorch-master\model\evaluate.py", line 71, in metric
gold_triples = set(to_tuple(batch_y['triples'][0]))
TypeError: unhashable type: 'dict'
Process finished with exit code 1
Could you help me figure out what is going wrong? Thank you.
Hello, while trying out this code I ran into the following bug:
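The TypeError in this traceback is what Python raises when a dict ends up inside something that is being added to a set. One possible cause, stated here as an assumption and not confirmed anywhere in the thread, is that the full DuIE data nests some field (for example the object) in a dict, unlike the bundled subset. The sketch below uses a hypothetical `to_hashable` helper and made-up field names to show how such a triple could be flattened before it goes into a set.

```python
def to_hashable(value):
    """Recursively turn dicts and lists into tuples so the result can be hashed."""
    if isinstance(value, dict):
        return tuple((key, to_hashable(value[key])) for key in sorted(value))
    if isinstance(value, (list, tuple)):
        return tuple(to_hashable(item) for item in value)
    return value

# Hypothetical triple whose object is nested in a dict (field names are assumptions).
triple = ["张三", "毕业院校", {"@value": "哈佛大学"}]

try:
    {tuple(triple)}  # a tuple that contains a dict is still unhashable
    failed = False
except TypeError:
    failed = True

# Flattening first makes the triple usable as a set element.
gold_triples = {to_hashable(triple)}
```

Printing one raw entry of batch_y['triples'] just before the failing line would show whether the data actually has this shape.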
Traceback (most recent call last):
File "run.py", line 64, in <module>
trainer.train()
File "/home/anaconda3/envs/pytorch/lib/python3.8/site-packages/fastNLP/core/trainer.py", line 622, in train
raise e
File "/home/anaconda3/envs/pytorch/lib/python3.8/site-packages/fastNLP/core/trainer.py", line 615, in train
self._train()
File "/home/anaconda3/envs/pytorch/lib/python3.8/site-packages/fastNLP/core/trainer.py", line 683, in _train
self._grad_backward(loss)
File "/home/anaconda3/envs/pytorch/lib/python3.8/site-packages/fastNLP/core/trainer.py", line 781, in _grad_backward
loss.backward()
File "/home/anaconda3/envs/pytorch/lib/python3.8/site-packages/torch/_tensor.py", line 255, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
File "/home/anaconda3/envs/pytorch/lib/python3.8/site-packages/torch/autograd/__init__.py", line 147, in backward
Variable._execution_engine.run_backward(
RuntimeError: Function MmBackward returned an invalid gradient at index 0 - got [712, 768] but expected shape compatible with [712, 21128]
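What puzzles me is that forward ran without error while backward failed. Also, were the reported results produced with the current master version? Do you have any ideas for debugging this?
Thanks!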
environment:
pytorch==1.9.0+cuda111
transformers==4.8.2
fastNLP==0.6.0
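At test time, can data only be processed one example at a time, with no batch processing? If so, would inference speed on CPU and GPU be about the same?
Could you tell me whether the results of this reimplementation are close to those of the original code? Do accuracy, recall, and precision differ much?
I looked at the code and the data. This targets "A Novel Cascade Binary Tagging Framework for Relational Triple Extraction", extracting triples without relation-annotated data. Judging from the Baidu data, though, this is still labeled data, and the model learns how those labels are expressed in the original sentence. As I understand it, if the dataset is changed, the model will have difficulty extracting the entity relations in it. I am not sure whether my understanding is correct.
Why does training show only about 5 GB of GPU memory in use when the official implementation uses 30 GB, and why is training so slow? The labels should take up a lot of memory: seqlen * number of relation types. Then I changed the batch size and got a GPU out-of-memory error. What is going on?
Hello, after switching the dataset to WebNLG, recall is much lower than in the paper. Is this because you changed the encoding and decoding scheme?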
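The "seqlen * number of relation types" estimate in the memory question can be turned into a rough size calculation. This is only a back-of-the-envelope sketch: it assumes the object head and tail tags are each stored as a float32 tensor of shape (batch, seq_len, num_relations), and all three sizes used below are hypothetical values, not figures from the repository.

```python
def tag_tensor_bytes(batch_size, seq_len, num_relations, bytes_per_value=4):
    """Bytes for object head and tail tags, each of shape (batch, seq_len, num_relations)."""
    return 2 * batch_size * seq_len * num_relations * bytes_per_value

# All three sizes below are hypothetical values chosen for the example.
size_bytes = tag_tensor_bytes(batch_size=8, seq_len=300, num_relations=50)
size_mib = size_bytes / (1024 ** 2)
print(f"{size_mib:.2f} MiB")
```

Under these assumed sizes the tags take under 1 MiB per batch, so on their own they are unlikely to explain a difference of several gigabytes.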
1. text = '王小红,女,汉族。张三毕业于哈佛大学。李四,1914年生。'
Output:
{('张三', '毕业院校', '哈佛大学'), ('张三', '民族', '汉族'), ('张三', '出生日期', '1914年'), ('王小红', '民族', '汉族'), ('王小红', '毕业院校', '哈佛大学')}
李四 was not detected, and the relations are combinations of the detected entities.
2. text = '张三毕业于哈佛大学。李四,1914年生。'
Output:
{('张三', '出生日期', '1914年'), ('张三', '毕业院校', '哈佛大学'), ('李四', '毕业院校', '哈佛大学'), ('李四', '出生日期', '1914年')}
The relations are again combinations of the detected entities.
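The output for the second sentence matches every detected subject paired with every detected (relation, object). The sketch below contrasts that with decoding in which each subject keeps only its own objects, which is what the cascade tagging idea aims for. It is a toy illustration with hand-written predictions, not the repository's inference code; the function names and the `objects_by_subject` structure are assumptions.

```python
def decode_cross_product(subjects, relation_objects):
    """Pair every subject with every (relation, object): reproduces the mixed-up output."""
    return {(s, r, o) for s in subjects for r, o in relation_objects}

def decode_per_subject(objects_by_subject):
    """Keep only the (relation, object) pairs predicted for each specific subject."""
    return {
        (s, r, o)
        for s, pairs in objects_by_subject.items()
        for r, o in pairs
    }

# Hand-written toy predictions for the second sentence in this report.
subjects = ["张三", "李四"]
relation_objects = [("毕业院校", "哈佛大学"), ("出生日期", "1914年")]
objects_by_subject = {
    "张三": [("毕业院校", "哈佛大学")],
    "李四": [("出生日期", "1914年")],
}

mixed = decode_cross_product(subjects, relation_objects)
expected = decode_per_subject(objects_by_subject)
```

Here `mixed` contains the same four triples as the reported output, while `expected` contains only the two triples the sentence actually states.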
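Hello, does this model recognize multiple (more than one) relation triples? I ask because in the Baidu dataset the model uses, each record seems to contain only one relation entity pair.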
I ran this PyTorch reimplementation of CasRel on the NYT dataset with bert-base-cased and the same settings as the original paper, but the best F1 I got was 70, well below the paper's 89.6. When I ran the original Keras version of CasRel on the same NYT dataset, I got a best F1 of 88.8, which is close to the paper's result.
So I question the correctness of this reimplementation. Have you had any similar problems?
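f1: 0.00, precision: 0.00, recall: 0.00. The model runs, but every metric is 0. Has anyone else seen this?
Hello! First of all, thank you for your work.
I downloaded your code and ran it several times without modification, but F1, precision, and recall on the test set are around 0.70, 0.80, and 0.63 respectively, quite far from the 0.78, 0.80, and 0.76 in the README. Low recall seems to be the main cause. Do you know the specific reason?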
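As a quick consistency check on the numbers quoted in this report, F1 is the harmonic mean of precision and recall. The helper below is a plain-Python sketch (the function name is an assumption), applied to the figures given in the issue.

```python
def f1_score(precision, recall):
    """Harmonic mean of precision and recall; 0.0 when both are zero."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)

observed_f1 = f1_score(0.80, 0.63)  # figures from the runs reported in this issue
readme_f1 = f1_score(0.80, 0.76)    # figures quoted from the README

print(round(observed_f1, 2), round(readme_f1, 2))
```

Both F1 values agree with their precision and recall (about 0.70 and 0.78), so the gap comes from recall (0.63 versus 0.76) and not from how F1 is computed.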