Transformers框架之Model学习

Transformers快速入门中介绍了pipeline()方法的使用。Model是Transformers中重要组件之一,主要负责加载和使用模型。接下来,以翻译任务为例子,使用huggingface中的现有模型进行翻译。

一、快速使用

1
2
3
4
5
# pileline中指定从英文翻译到中文(只针对多语言,这里是可以不指定的)
from transformers import pipeline

pipe = pipeline("translation_EN_to_ZH", model="Helsinki-NLP/opus-mt-en-zh")
print(pipe("are you ok?")) #[{'translation_text': '你还好吗?'}]

接下来,就逐步操作,了解一下Model组件。

二、加载模型

我们需要先在huggingface中找到自己想要的模型,这里我选择Helsinki-NLP/opus-mt-en-zh,更多模型,可以在https://huggingface.co/models中获取

在线加载模型

1
2
3
4
5
# Load model directly
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-zh")
model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-zh")

它自动从huggingface下载完模型后,会保存到 ~/.cache/huggingface/hub/models 目录下。

保存模型到本地

保存模型到本地,方便下次加载:

1
2
3
4
# 保存模型
model.save_pretrained("../model/opus-mt-en-zh")
# 分词器也保存一下
tokenizer.save_pretrained("../model/opus-mt-en-zh")

离线加载模型

1
2
tokenizer = AutoTokenizer.from_pretrained("../model/opus-mt-en-zh")  
model = AutoModelForSeq2SeqLM.from_pretrained("../model/opus-mt-en-zh")

使用模型

1
2
3
4
5
6
7
# 编码
inputs = tokenizer("are you ok?", return_tensors="pt")
# 模型预测
outputs = model.generate(**inputs)
# 解码
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
# 你还好吗?

三、查看与修改模型配置

1
2
3
print(model.config)
# 内容太多,就不全放出来了。只看一个配置
# "max_length": 512 输入和输出最在长度是512

本章暂只介绍一下Model的基础用法,比如训练模型操作,后续会给出示例。

更多内容