GithubHelp home page GithubHelp logo

globalpointer_pytorch's Introduction

GlobalPointer_pytorch

喜欢本项目的话,欢迎点击右上角的star,感谢每一个点赞的你。

项目介绍

本项目的模型参考苏剑林的文章GlobalPointer:用统一的方式处理嵌套和非嵌套NER,并用Pytorch实现。

GlobalPoniter多头识别嵌套实体示意图

GlobalPointer的设计思路与TPLinker-NER类似,但在实现方式上不同。具体体现在:

  1. 加性乘性Attention

TPLinker在Multi-Head上用的是加性Attention:

而GlobalPointer用的是乘性Attention:

  1. 位置编码

GlobalPointer在模型中还加入了一种旋转式位置编码RoPE。这是一种“通过绝对位置编码的方式实现相对位置编码”,在本模型中效果明显。

Usage

实验环境

本次实验进行时Python版本为3.6,其他主要的第三方库包括:

  • pytorch==1.8.1
  • wandb==0.10.26 #for logging the result
  • transformers==4.1.1
  • tqdm==4.54.1

下载预训练模型

请下载Bert的中文预训练模型bert-base-chinese存放至 pretrained_models/,并在config.py中配置正确的bert_path

Train

python train.py

Evaluation

python evaluate.py

实验结果

默认配置(超参数已在 config.py 文件中),数据集是 CLUENER

  • 验证集 Best F1:0.7966

globalpointer_pytorch's People

Contributors

gaohongkui avatar qznan avatar

Stargazers

 avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar

Watchers

 avatar  avatar

globalpointer_pytorch's Issues

关于wandb调用

raise UsageError("api_key not configured (no-tty). call " + directive) wandb.errors.UsageError: api_key not configured (no-tty). call wandb.login(key=[your_api_key])
请问这个问题怎么解决?试了各种办法无果,api_key我已经注册拿到了

实验结果

作者你好,请问你在cluener数据集上的实验结果是多少?谢谢!

loss的参数顺序需要修正

调用multilabel_categorical_crossentropy时出错:train.py / line 188
image
调用loss_fun出错:train.py / line 159
image
loss计算在整体上没有问题,但是是因为后面参数传递也发生了错误。

数据集划分问题

请问数据集文件有dev、train和test,test是没标签,请问带有标签的测试集用来评估测试结果是哪个文件呢?dev文件是验证集吗?evaluate.py这个文件是做什么的呢?评估测试集结果和预测未知标签数据集都是这个吗?

关于内存方面

你在generate_inputs时一次性加入所有的labels会不会导致内存爆啊,就如CMeEE数据集而言,13000 9 256 256 8=66GB

split的维度问题

苏建林的tf原版

def __init__(
        self,
        heads,
        head_size,
        RoPE=True,
        use_bias=True,
        kernel_initializer='glorot_uniform',
        **kwargs
    ):
        super(GlobalPointer, self).__init__(**kwargs)
        self.heads = heads
        self.head_size = head_size
        ...

def call(self, inputs, mask=None):
        # 输入变换
        inputs = self.dense(inputs)
        inputs = tf.split(inputs, self.heads, axis=-1)
        ...

您的版本:

def __init__(self, encoder, ent_type_size, inner_dim, RoPE=True):
    super().__init__()
    self.encoder = encoder
    self.ent_type_size = ent_type_size # 实体类型个数
    self.inner_dim = inner_dim # head_size??? head的维度大小???
    self.hidden_size = encoder.config.hidden_size
    self.dense = nn.Linear(self.hidden_size, self.ent_type_size * self.inner_dim * 2)
    ......

 def forward(self, input_ids, attention_mask, token_type_ids):
      .......
      outputs = self.dense(last_hidden_state)
      outputs = torch.split(outputs, self.inner_dim * 2, dim=-1)

按照苏建林版本我的理解是,head_size表示head头的大小,heads是head的个数,也就是实体类型的个数;在下面的split时按照实体类型的维度展开;
您的版本中torch.split按照head_size * 2的展开,这里是我理解的有问题还是有错误?麻烦指点,谢谢!

预测结果实体多且长的问题

请问作者有在其它数据集尝试GP模型吗?

有个疑惑,自己搜集了一些预料,用bert基础版预训练模型做了个训练,然后再测试集上预测时,出现了预测实体多且长的问题。

测试集一条样本:

{
        "text": "**政府宣布2019年国防开支将比前一年增长7.5%,超过预计今年的经济增长率。第十三届全国人大第二次会议星期二(2019年3月5日)在开幕时公布的政府预算报告显示,今年的国防开支将达到11899亿元人民币,相当于大约1780亿美元。外界一般认为,**实际的军事开支可能高出政府公开的国防预算金额。**国防部公布的消息说,今年的国防预算将重点支持国防和军队改革,全面推动国防和军队现代化建设。**每年一度的国防开支预告一直受到国际广泛关注。各国试图从中了解**战略意图的变化和发展。",
        "entities": [
            {
                "start_idx": 0,
                "end_idx": 4,
                "type": "ORG",
                "entity": "**政府"
            },
            {
                "start_idx": 6,
                "end_idx": 11,
                "type": "TIM",
                "entity": "2019年"
            },
            {
                "start_idx": 18,
                "end_idx": 20,
                "type": "NUM",
                "entity": "一年"
            },
            {
                "start_idx": 57,
                "end_idx": 66,
                "type": "TIM",
                "entity": "2019年3月5日"
            },
            {
                "start_idx": 93,
                "end_idx": 103,
                "type": "NUM",
                "entity": "11899亿元人民币"
            },
            {
                "start_idx": 109,
                "end_idx": 116,
                "type": "NUM",
                "entity": "1780亿美元"
            },
            {
                "start_idx": 124,
                "end_idx": 126,
                "type": "LOC",
                "entity": "**"
            },
            {
                "start_idx": 149,
                "end_idx": 154,
                "type": "ORG",
                "entity": "**国防部"
            },
            {
                "start_idx": 196,
                "end_idx": 198,
                "type": "LOC",
                "entity": "**"
            },
            {
                "start_idx": 228,
                "end_idx": 230,
                "type": "LOC",
                "entity": "**"
            }
        ]
}

预测结果:

{
        "text": "**政府宣布2019年国防开支将比前一年增长7.5%,超过预计今年的经济增长率。第十三届全国人大第二次会议星期二(2019年3月5日)在开幕时公布的政府预算报告显示,今年的国防开支将达到11899亿元人民币,相当于大约1780亿美元。外界一般认为,**实际的军事开支可能高出政府公开的国防预算金额。**国防部公布的消息说,今年的国防预算将重点支持国防和军队改革,全面推动国防和军队现代化建设。**每年一度的国防开支预告一直受到国际广泛关注。各国试图从中了解**战略意图的变化和发展。",
        "pred_entities": [
            {
                "start_idx": 0,
                "end_idx": 1,
                "type": "TIM",
                "entity": ""
            },
            {
                "start_idx": 0,
                "end_idx": 10,
                "type": "TIM",
                "entity": "**政府宣布2019"
            },
            {
                "start_idx": 0,
                "end_idx": 12,
                "type": "TIM",
                "entity": "**政府宣布2019年国"
            },
            {
                "start_idx": 0,
                "end_idx": 24,
                "type": "TIM",
                "entity": "**政府宣布2019年国防开支将比前一年增长7."
            },
            {
                "start_idx": 0,
                "end_idx": 34,
                "type": "TIM",
                "entity": "**政府宣布2019年国防开支将比前一年增长7.5%,超过预计今年的"
            },
            {
                "start_idx": 0,
                "end_idx": 46,
                "type": "TIM",
                "entity": "**政府宣布2019年国防开支将比前一年增长7.5%,超过预计今年的经济增长率。第十三届全国"
            },
            ...
            {
                "start_idx": 236,
                "end_idx": 240,
                "type": "WEA",
                "entity": "化和发展"
            },
            {
                "start_idx": 237,
                "end_idx": 240,
                "type": "WEA",
                "entity": "和发展"
            }
       ]
}

然后,分析了下代码,发现 decode_ent 这里预测实体的起止索引向量维度很高

d = np.where(pred_matrix > threshold)

print(np.array(d).shape)
# Out[4]: (3, 112304)

看起来,模型并没有很好地预测出实体的边界,自己检查过已标注的实体,index正常。想问下,作者有遇到类似情况吗?还是说GP模型在实体较长、嵌套较深、或者上下文信息较丰富的情况下就是会出现这种情况。感谢!

标签数量

您好,您的工作很好的解决了本人标签嵌套的问题,但本人所做任务的标签数足足有接近一万个(细粒度非常高),这使得self.dense成为了一个将近4G的线性层,且由于每个标签单独的占用一个(1, seq_len, seq_len)空间,则在训练时需要较大时间和显存成本,请问作者有没有针对这种高细粒度标签的NER模型呢?非常感谢!

去除padding部分,以及最后计算acc

我看苏神的代码里面有去除padding部分,还有就是最后计算acc,我看是除以y_pred.sum(),其实你计算的就是precision吧,感觉都是实体维度的,acc没必要了吧

ValueError,预训练模型问题

请问按照您提供的transfomer版本,下载不了您给出huggingface里的Bert-base-chineses,而且在其他代码里涉及transformer的地方也会有相同报错,可以指教一下吗
image
image
"Connection error, and we cannot find the requested files in the cached path."
ValueError: Connection error, and we cannot find the requested files in the cached path. Please try again or make sure your Internet connection is on.

似乎有个bug?

下图实现要找token_span,但是好像没考虑同名实体,比如例子(张三传是由张三在2021年拍摄),第一个 张三 可能是属于 movie实体,第二个张三是director实体;

但是下图while循环有个break,匹配到就跳出,以上面的例子看,如果要找第2个张三,似乎匹配到第一个张三就跳出了;

附:代码/common/utils.py/Preprocessor(clss)/get_ent2token_spans(func)
image
@gaohongkui

关于单标签问题

你好,我的数据集就只有一种实体需要识别,那么想请教一下对于本项目哪些地方需要改动,谢谢指点!

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.