Comments (5)
Did you manage to solve this?
from albert_zh.
@DoverDW Yes, it can be converted. Use the GLUE or albert_pytorch repositories as a reference.
@msclock So I can just follow the albert_pytorch /convert_albert_tf_checkpoint_to_pytorch.py script?
@DoverDW
albert_zh seems to produce checkpoints in two versions: one saved by https://github.com/brightmart/albert_zh/blob/master/modeling.py and one saved by https://github.com/brightmart/albert_zh/blob/master/modeling_google.py.
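The modeling_google version can apparently be converted directly with the Hugging Face transformers ALBERT converter. In my project, the classification subtask gave me a ckpt model saved by https://github.com/brightmart/albert_zh/blob/master/modeling.py. That corresponds to the conversion script /workspaces/ai-serving-solution/CLUE/baselines/models_pytorch/classifier_pytorch/convert_albert_original_tf_checkpoint_to_pytorch.py, but it still needs a change: in the function load_tf_weights_in_albert in /workspaces/ai-serving-solution/CLUE/baselines/models_pytorch/classifier_pytorch/transformers/modeling_albert.py, add a way to bind the classification subtask's output weights to the corresponding model attributes: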
```python
for name, array in zip(names, arrays):
    name = name.split("/")
    # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v,
    # which are not required for using the pretrained model
    if any(n in ["adam_v", "adam_m", "global_step"] for n in name):
        logger.info("Skipping {}".format("/".join(name)))
        continue
    # Classifier: prefix albert_zh's top-level output variables with "classifier"
    # so that the code below binds them to the matching model attribute weights
    if len(name) == 1 and ("output_bias" in name or "output_weights" in name):
        name = ["classifier"] + name
    pointer = model
    # ... rest of the loop unchanged
```
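To check the renaming rule without loading TensorFlow, the skip-and-prefix logic above can be pulled out into a standalone function. This is a minimal sketch: `map_tf_name` is a hypothetical helper that only mirrors those two steps. The real loader then walks the returned parts with `getattr`, starting from `pointer = model`, to find the target PyTorch parameter.

```python
from typing import List, Optional

# Optimizer slot variables and the step counter carry no model weights.
SKIP_PARTS = {"adam_v", "adam_m", "global_step"}
# Top-level classifier variables saved by albert_zh's modeling.py fine-tuning run.
CLASSIFIER_VARS = {"output_bias", "output_weights"}


def map_tf_name(tf_name: str) -> Optional[List[str]]:
    """Split a TF variable name into scope parts, mirroring the patched loop.

    Hypothetical helper for illustration. Returns None for variables the
    loader would skip.
    """
    name = tf_name.split("/")
    if any(n in SKIP_PARTS for n in name):
        return None
    # Same condition as the patch: a single-part name that is a classifier output.
    if len(name) == 1 and name[0] in CLASSIFIER_VARS:
        name = ["classifier"] + name
    return name


if __name__ == "__main__":
    for var in ["output_weights", "output_bias", "output_weights/adam_m", "global_step"]:
        print(var, "->", map_tf_name(var))
```

Only single-part names are prefixed, so nested encoder variables pass through unchanged.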
Then load the model:
```python
import torch
from transformers.modeling_albert import AlbertForSequenceClassification
from transformers.tokenization_bert import BertTokenizer
from transformers.configuration_bert import BertConfig

# Directory holding the TF checkpoint, config and vocab
model_dir = "/workspaces/ai-serving-solution/deploy/ai_recognition/analysis/albert/multiclass_output/signature1.1/"

news_categories = [
    "other",
    "drawing_name",
    "drawing_number",
]
idx2cate = {i: item for i, item in enumerate(news_categories)}

config = BertConfig.from_pretrained(model_dir, num_labels=len(news_categories))
tokenizer = BertTokenizer.from_pretrained(model_dir, padding=True)
model = AlbertForSequenceClassification.from_pretrained(model_dir, from_tf=True, config=config)
model.eval()  # make sure dropout is disabled for inference

pytorch_dump_path = model_dir + "pytorch_model.bin"
print("Save PyTorch model to {}".format(pytorch_dump_path))
torch.save(model.state_dict(), pytorch_dump_path)

# Sample input: "slab and beam reinforcement drawing of the main structure"
token_codes = tokenizer.encode("主体结构中板梁配筋图", max_length=24)
input_ids = torch.tensor(token_codes).unsqueeze(0)  # batch size 1
with torch.no_grad():
    outputs = model(input_ids)
# softmax over the label dimension turns logits into probabilities
probs = outputs[0].softmax(dim=1)
# argmax picks the index of the most likely label
label_index = probs.argmax(dim=1)[0].item()
# map the index back to the label name and its probability
label = idx2cate[label_index]
proba = probs[0, label_index].item()
print({"label": label, "proba": proba})
```
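The post-processing at the end of the script (softmax, then argmax, then index-to-label lookup) does not depend on the model, so it can be tested on its own. A minimal pure-Python sketch; `predict_label` is a hypothetical name, and it takes one row of logits instead of a torch tensor:

```python
import math
from typing import Dict, Sequence


def predict_label(logits: Sequence[float], idx2cate: Dict[int, str]) -> dict:
    """Turn one row of classifier logits into {"label", "proba"} like the script above."""
    # Subtract the max logit before exponentiating for numerical stability;
    # this does not change the softmax result.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # argmax: index of the largest probability (the first one wins on ties here)
    label_index = max(range(len(probs)), key=probs.__getitem__)
    return {"label": idx2cate[label_index], "proba": probs[label_index]}


if __name__ == "__main__":
    categories = ["other", "drawing_name", "drawing_number"]
    idx2cate = {i: c for i, c in enumerate(categories)}
    print(predict_label([0.1, 2.0, -1.0], idx2cate))
```

Because softmax is monotonic, the argmax of the probabilities equals the argmax of the raw logits; the softmax is only needed to report a probability.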
Finally, I'd recommend simply using the existing pretrained ALBERT models available through Hugging Face transformers; the steps above are overly cumbersome.
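Related Issues (20)
- Planning to train a smaller pretrained model: how did you prepare your pretraining data?
- The file doesn't exist
- how to load albert_zh model
- Has anyone tried a classification model trained on the iflytek data? Using albert_tiny for the downstream task gave rather disappointing results.
- Are you considering publishing the models on TF Hub?
- The exact English pretraining data and Chinese pretraining data that are exact same to the BERT paper's pretraining data.
- Best practices for sentence-vector feature extraction
- Question about constructing the pretraining corpus
- 'str' object has no attribute 'size'
- Help wanted
- Is Chinese ALBERT trained with character-level or word-level segmentation?
- Is the pretrained model a text classification model? I want to build a language model that predicts the next character or word; what should I change?
- whats the difference between `albert_tiny_zh` and `albert_tiny_google_zh`
- Cannot download the dataset
- Where is the DataProcessor class?
- Question about the model license
- Differences in pretraining
- Is there support for ONNX model conversion and inference?
- Out-of-memory error when generating the pretraining files in the required format (tfrecords)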