基于第二代 ChatGLM2

编程入门 行业动态 更新时间:2024-10-26 15:24:40

基于第二代 ChatGLM2

基于第二代 ChatGLM2

今天是教师节,恭祝全体老师们节日快乐!😊

一、ChatGLM2-6B

在本专栏前面文章中实验了使用 ChatYuan-large-v2 Freeze 微调训练医疗问答任务,训练后效果整体还可以,这篇文章继续探索使用最近比较火的 ChatGLM 官方推出的 p-tuning-v2 的方式训练医疗问答任务。而对于 ChatGLM 模型则使用新出不久的 ChatGLM2-6B

ChatGLM2-6BChatGLM-6B 的第二代版本,在保留了初代模型对话流畅、部署门槛较低等众多优秀特性的基础之上,同时引入了许多新特性,如:更强大的性能、更长的上下文、更高效的推理、更开放的协议 等。

更多详细的介绍可参考官方 github

官方 github 地址:

P-tuning v2 微调技术利用 deep prompt tuning,即对预训练 Transformer 的每一层输入应用 continuous promptsdeep prompt tuning 增加了 continuo us prompts 的能力,并缩小了跨各种设置进行微调的差距,特别是对于小型模型和困难任务。


上图左边为 P-Tuning,右边为P-Tuning v2P-Tuning v2 层与层之间的 continuous prompt 是相互独立的。

论文地址:.07602.pdf

github地址:

二、ChatGLM2-6B 模型下载

huggingface 地址:

三、数据集处理

数据集还是使用 GitHub 上的 Chinese-medical-dialogue-data 中文医疗对话数据集。

GitHub 地址如下:

数据分了 6 个科目类型:

数据格式如下所示:

其中 ask 为病症的问题描述,answer 为病症的回答。

整体加起来数据比较多,这里为了演示效果,只训练 内科、肿瘤科、儿科、外科 四个科目的数据,并且每个科目取前 10000 条数据进行训练、2000 条数据进行验证:

import json
import pandas as pddata_path = ["./data/Chinese-medical-dialogue-data-master/Data_数据/IM_内科/内科5000-33000.csv","./data/Chinese-medical-dialogue-data-master/Data_数据/Oncology_肿瘤科/肿瘤科5-10000.csv","./data/Chinese-medical-dialogue-data-master/Data_数据/Pediatric_儿科/儿科5-14000.csv","./data/Chinese-medical-dialogue-data-master/Data_数据/Surgical_外科/外科5-14000.csv",
]train_json_path = "./data/train.json"
val_json_path = "./data/val.json"
# 每个数据取 10000 条作为训练
train_size = 10000
# 每个数据取 2000 条作为验证
val_size = 2000def doHandler():train_f = open(train_json_path, "a", encoding='utf-8')val_f = open(val_json_path, "a", encoding='utf-8')for path in data_path:data = pd.read_csv(path, encoding='ANSI')train_count = 0val_count = 0for index, row in data.iterrows():ask = row["ask"]answer = row["answer"]line = {"content": ask,"summary": answer}line = json.dumps(line, ensure_ascii=False)if train_count < train_size:train_f.write(line + "\n")train_count = train_count + 1elif val_count < val_size:val_f.write(line + "\n")val_count = val_count + 1else:breakprint("数据处理完毕!")train_f.close()val_f.close()if __name__ == '__main__':doHandler()

处理之后可以看到两个生成的文件:

四、P-Tuning v2 训练

拉取官网训练脚本:

git clone 

下载相应依赖:

pip install -r requirements.txt -i 

此外还需安装:

pip install rouge_chinese nltk jieba datasets -i 

修改 ptuning 下的 train.sh 文件:

PRE_SEQ_LEN=300
LR=2e-2
NUM_GPUS=1torchrun --standalone --nnodes=1 --nproc-per-node=$NUM_GPUS main.py \--do_train \--train_file data/train.json \--validation_file data/val.json \--preprocessing_num_workers 10 \--prompt_column content \--response_column summary \--overwrite_cache \--model_name_or_path /home/chatglm2/chatglm-6b \--output_dir output/adgen-chatglm2-6b-pt-$PRE_SEQ_LEN-$LR \--overwrite_output_dir \--max_source_length 300 \--max_target_length 1024 \--per_device_train_batch_size 1 \--per_device_eval_batch_size 1 \--gradient_accumulation_steps 16 \--predict_with_generate \--max_steps 3000 \--logging_steps 10 \--save_steps 1000 \--learning_rate $LR \--pre_seq_len $PRE_SEQ_LEN \--quantization_bit 4

其中 参数解释如下:

–standalone` 以单机模式训练。

–nnodes` 节点数。这里只有一个节点,设置为 1。

–nproc-per-node` 每个节点上的进程数。

–do_train` 执行训练任务。

–train_file` 训练数据文件路径, 上面生成的 train.json 文件。

–validation_file` 验证数据文件路径, 上面生成的 val.json 文件。

–preprocessing_num_workers` 指定数据预处理时的 workers 数。

–prompt_column` 输入信息的字段名称。

–response_column` 输出信息的字段名称。

–overwrite_cache` 覆盖缓存文件。

–model_name_or_path` 预训练模型的名称或路径,注意这里我是用的下载后的模型存放地址,需要修改为你的。

–output_dir` 模型保存目录。

–overwrite_output_dir` 覆盖输出目录。

–max_source_length` 输入文本的最大长度。

–max_target_length` 输出文本的最大长度。

–per_device_train_batch_size` 训练时的批次大小。

–per_device_eval_batch_size` 验证时的批次大小。

–gradient_accumulation_steps` 累积多少个梯度之后再进行一次反向传播。

–predict_with_generate` 预测时使用生成模式。

–max_steps` 最大训练轮数。

–logging_steps` 多少轮打印一次日志。

–save_steps` 多少轮保存一次模型。

–learning_rate` 初始学习率。

–pre_seq_len` 预处理时选取的序列长度。

–quantization_bit` 量化位大小。

执行后可以看到如下打印日志:

训练过程:

训练结束:

最后在 output 目录下可以看到每 1000 步保存的模型。

五、模型测试

5.1 单独调用测试:

from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoTokenizer, AutoModel, AutoConfig
import uvicorn, json, datetime
import torch
import osdef main():pre_seq_len = 300# 训练权重地址checkpoint_path = "ptuning/output/adgen-chatglm2-6b-pt-300-2e-2/checkpoint-3000"tokenizer = AutoTokenizer.from_pretrained("chatglm-6b", trust_remote_code=True)config = AutoConfig.from_pretrained("chatglm-6b", trust_remote_code=True, pre_seq_len=pre_seq_len)model = AutoModel.from_pretrained("chatglm-6b", config=config, device_map="auto", trust_remote_code=True)prefix_state_dict = torch.load(os.path.join(checkpoint_path, "pytorch_model.bin"))new_prefix_state_dict = {}for k, v in prefix_state_dict.items():if k.startswith("transformer.prefix_encoder."):new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = vmodel.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)# 量化model = model.quantize(4)model.eval()# 问题question = "突然感到了不适,去检查后竟然得了这个病,请问:宝宝白天爱磨牙会是哪些情况呢"response, history = model.chat(tokenizer,question,history=[],max_length=2048,top_p=0.7,temperature=0.95)print("回答:", response)if torch.backends.mps.is_available():torch.mps.empty_cache()if __name__ == '__main__':main()

回答: 孩子磨牙可能会是缺钙引来的,建议带孩子去医院仔细检查下微量元素,明确病因后有针对性的治疗。平时要留意孩子的饮食卫生,防止排便辛辣刺激性食物,多给孩子喝温开水,多吃蔬菜水果,消化维生素,增进胃肠道扭动。对于家长朋友们来说,要尽可能的帮助孩子及时治疗疾病,另外宝宝在日常生活中饮食也要注意,要营养的均衡,不要过度进补也不要营养不良哦。

5.2 封装成 Api 测试

from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoTokenizer, AutoModel, AutoConfig
import uvicorn, json, datetime
import torch
import osapp = FastAPI()# 允许所有域的请求
app.add_middleware(CORSMiddleware,allow_origins=["*"],allow_credentials=True,allow_methods=["*"],allow_headers=["*"],
)@app.post("/")
async def create_item(request: Request):global model, tokenizerjson_post_raw = await request.json()json_post = json.dumps(json_post_raw)json_post_list = json.loads(json_post)prompt = json_post_list.get('prompt')history = json_post_list.get('history')max_length = json_post_list.get('max_length')top_p = json_post_list.get('top_p')temperature = json_post_list.get('temperature')response, history = model.chat(tokenizer,prompt,history=history,max_length=max_length if max_length else 2048,top_p=top_p if top_p else 0.7,temperature=temperature if temperature else 0.95)now = datetime.datetime.now()time = now.strftime("%Y-%m-%d %H:%M:%S")answer = {"response": response,"history": history,"status": 200,"time": time}log = "[" + time + "] " + '", prompt:"' + prompt + '", response:"' + repr(response) + '"'print(log)if torch.backends.mps.is_available():torch.mps.empty_cache()return answerif __name__ == '__main__':pre_seq_len = 300checkpoint_path = "ptuning/output/adgen-chatglm2-6b-pt-300-2e-2/checkpoint-3000"tokenizer = AutoTokenizer.from_pretrained("chatglm-6b", trust_remote_code=True)config = AutoConfig.from_pretrained("chatglm-6b", trust_remote_code=True, pre_seq_len=pre_seq_len)model = AutoModel.from_pretrained("chatglm-6b", config=config, device_map="auto", trust_remote_code=True)prefix_state_dict = torch.load(os.path.join(checkpoint_path, "pytorch_model.bin"))new_prefix_state_dict = {}for k, v in prefix_state_dict.items():if k.startswith("transformer.prefix_encoder."):new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = vmodel.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)## 量化model = model.quantize(4)model = model.cuda()model.eval()uvicorn.run(app, host='0.0.0.0', port=8103, workers=1)

使用 postMan 测试:

最后测试下原有知识的影响:

更多推荐

基于第二代 ChatGLM2

本文发布于:2024-02-17 10:11:25,感谢您对本站的认可!
本文链接:https://www.elefans.com/category/jswz/34/1693653.html
版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。
本文标签:

发布评论

评论列表 (有 0 条评论)
草根站长

>www.elefans.com

编程频道|电子爱好者 - 技术资讯及电子产品介绍!