
Question about converting the model to PyTorch (about albert_zh, OPEN)

dalinvip commented on June 5, 2024
Question about converting the model to PyTorch

from albert_zh.

Comments (5)

DoverDW commented on June 5, 2024

Have you solved this?


msclock commented on June 5, 2024

@DoverDW Yes, it can be converted. See the GLUE or albert_pytorch repositories for reference.


DoverDW commented on June 5, 2024

@msclock Is it enough to follow albert_pytorch/convert_albert_tf_checkpoint_to_pytorch.py?


msclock commented on June 5, 2024

@DoverDW
albert_zh seems to produce two kinds of checkpoints: one built with https://github.com/brightmart/albert_zh/blob/master/modeling.py and one built with https://github.com/brightmart/albert_zh/blob/master/modeling_google.py.

The modeling_google version can apparently be converted directly with the huggingface transformers ALBERT tooling. In my project, for a classification subtask, I was given a ckpt model generated and saved with https://github.com/brightmart/albert_zh/blob/master/modeling.py. The matching conversion script is /workspaces/ai-serving-solution/CLUE/baselines/models_pytorch/classifier_pytorch/convert_albert_original_tf_checkpoint_to_pytorch.py, but it still needs one change: the output weights of the classification subtask have to be bound to the corresponding model attributes. That change goes into the function load_tf_weights_in_albert in /workspaces/ai-serving-solution/CLUE/baselines/models_pytorch/classifier_pytorch/transformers/modeling_albert.py:

    for name, array in zip(names, arrays):
        name = name.split("/")
        # adam_v and adam_m are variables that AdamWeightDecayOptimizer uses to compute m and v;
        # they are not required for using the pretrained model
        if any(n in ["adam_v", "adam_m", "global_step"] for n in name):
            logger.info("Skipping {}".format("/".join(name)))
            continue

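        # Classifier: prefix the albert_zh output variables so that the code below
        # can bind them to the corresponding weights on the model's classifier attribute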
        if len(name) == 1 and ("output_bias" in name or "output_weights" in name):
            name = ["classifier"] + name

        pointer = model
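
The renaming step in the excerpt above can be tried in isolation. The following is a minimal sketch, not the code from the CLUE repository: `normalize_tf_variable_name` is a hypothetical helper, and the sample variable names are made up for illustration. It assumes, as the excerpt does, that the classification head is stored under the top-level names `output_weights` and `output_bias`.

```python
from typing import List, Optional

# Optimizer state and step counters that a pretrained model does not need.
SKIPPED_PARTS = {"adam_v", "adam_m", "global_step"}
# Top-level names assumed to hold the classification head (per the excerpt above).
CLASSIFIER_PARTS = {"output_weights", "output_bias"}


def normalize_tf_variable_name(tf_name: str) -> Optional[List[str]]:
    """Split a TF variable name into parts; return None if it should be skipped."""
    parts = tf_name.split("/")
    if any(part in SKIPPED_PARTS for part in parts):
        return None
    # Un-scoped head variables get a "classifier" prefix so that they can be
    # matched to the model's `classifier` attribute later on.
    if len(parts) == 1 and parts[0] in CLASSIFIER_PARTS:
        parts = ["classifier"] + parts
    return parts


if __name__ == "__main__":
    # Hypothetical variable names, for illustration only.
    samples = ["output_bias", "bert/pooler/dense/kernel", "output_weights/adam_m", "global_step"]
    for sample in samples:
        print(sample, "->", normalize_tf_variable_name(sample))
```

Running a checkpoint's variable names through a function like this before converting makes it easy to see which variables would be skipped and which would be routed to the classifier.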

Then load the model:

from transformers.modeling_albert import AlbertForSequenceClassification
from transformers.tokenization_bert import BertTokenizer
from transformers.configuration_bert import BertConfig
import torch

news_categories = [
    "other",
    "drawing_name",
    "draing_number",
]
idx2cate = {i: item for i, item in enumerate(news_categories)}

# Directory holding the config, the vocabulary and the TensorFlow checkpoint
model_dir = "/workspaces/ai-serving-solution/deploy/ai_recognition/analysis/albert/multiclass_output/signature1.1/"

config = BertConfig.from_pretrained(
    model_dir,
    num_labels=len(news_categories),
)
tokenizer = BertTokenizer.from_pretrained(
    model_dir,
    padding=True,
)
# from_tf=True loads the weights from the TensorFlow checkpoint in model_dir
model = AlbertForSequenceClassification.from_pretrained(
    model_dir,
    from_tf=True,
    config=config,
)
# Save the converted weights next to the original checkpoint
pytorch_dump_path = model_dir + "pytorch_model.bin"
print("Save PyTorch model to {}".format(pytorch_dump_path))
torch.save(model.state_dict(), pytorch_dump_path)
token_codes = tokenizer.encode("主体结构中板梁配筋图", max_length=24)
input_ids = torch.tensor(token_codes).unsqueeze(0)  # Batch size 1
outputs = model(input_ids)
# get output probabilities by doing softmax
probs = outputs[0].softmax(1)
# executing argmax function to get the candidate label index
label_index = probs.argmax(dim=1)[0].tolist()
# get the label name
label = idx2cate[label_index]
# get the label probability
proba = probs.tolist()[0][label_index]
print({"label": label, "proba": proba})
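
The post-processing at the end of the script (softmax, then argmax, then label lookup) does not depend on PyTorch. Here is a minimal standard-library sketch of that step, assuming made-up logits rather than real model output; `softmax` and `predict_label` are hypothetical helpers written for this illustration.

```python
import math
from typing import Dict, List


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax over a flat list of logits."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predict_label(logits: List[float], categories: List[str]) -> Dict[str, object]:
    """Return the most probable category together with its probability."""
    probs = softmax(logits)
    label_index = max(range(len(probs)), key=probs.__getitem__)
    return {"label": categories[label_index], "proba": probs[label_index]}


if __name__ == "__main__":
    # Illustrative logits only; real values would come from outputs[0] of the model.
    categories = ["other", "drawing_name", "draing_number"]
    print(predict_label([0.1, 2.3, -1.0], categories))
```

Subtracting the largest logit before exponentiating leaves the result unchanged and avoids overflow for large values.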

Finally, I recommend simply using the pretrained ALBERT models already available in huggingface transformers; the steps above are too cumbersome.
