基于第二代 ChatGLM2
今天是教师节,恭祝全体老师们节日快乐!😊
一、ChatGLM2-6B
在本专栏前面文章中实验了使用 ChatYuan-large-v2
Freeze
微调训练医疗问答任务,训练后效果整体还可以,这篇文章继续探索使用最近比较火的 ChatGLM
官方推出的 p-tuning-v2
的方式训练医疗问答任务。而对于 ChatGLM
模型则使用新出不久的 ChatGLM2-6B
。
ChatGLM2-6B
是 ChatGLM-6B
的第二代版本,在保留了初代模型对话流畅、部署门槛较低等众多优秀特性的基础之上,同时引入了许多新特性,如:更强大的性能、更长的上下文、更高效的推理、更开放的协议 等。
更多详细的介绍可参考官方 github
:
官方 github 地址:
P-tuning v2
微调技术利用 deep prompt tuning
,即对预训练 Transformer
的每一层输入应用 continuous prompts
。deep prompt tuning
增加了 continuo us prompts
的能力,并缩小了跨各种设置进行微调的差距,特别是对于小型模型和困难任务。
上图左边为 P-Tuning
,右边为P-Tuning v2
。P-Tuning v2
层与层之间的 continuous prompt
是相互独立的。
论文地址:.07602.pdf
github地址:
二、ChatGLM2-6B 模型下载
huggingface 地址:
三、数据集处理
数据集还是使用 GitHub
上的 Chinese-medical-dialogue-data
中文医疗对话数据集。
GitHub
地址如下:
数据分了 6
个科目类型:
数据格式如下所示:
其中 ask
为病症的问题描述,answer
为病症的回答。
整体加起来数据比较多,这里为了演示效果,只训练 内科、肿瘤科、儿科、外科
四个科目的数据,并且每个科目取前 10000
条数据进行训练、2000
条数据进行验证:
import json
import pandas as pddata_path = ["./data/Chinese-medical-dialogue-data-master/Data_数据/IM_内科/内科5000-33000.csv","./data/Chinese-medical-dialogue-data-master/Data_数据/Oncology_肿瘤科/肿瘤科5-10000.csv","./data/Chinese-medical-dialogue-data-master/Data_数据/Pediatric_儿科/儿科5-14000.csv","./data/Chinese-medical-dialogue-data-master/Data_数据/Surgical_外科/外科5-14000.csv",
]train_json_path = "./data/train.json"
val_json_path = "./data/val.json"
# 每个数据取 10000 条作为训练
train_size = 10000
# 每个数据取 2000 条作为验证
val_size = 2000def doHandler():train_f = open(train_json_path, "a", encoding='utf-8')val_f = open(val_json_path, "a", encoding='utf-8')for path in data_path:data = pd.read_csv(path, encoding='ANSI')train_count = 0val_count = 0for index, row in data.iterrows():ask = row["ask"]answer = row["answer"]line = {"content": ask,"summary": answer}line = json.dumps(line, ensure_ascii=False)if train_count < train_size:train_f.write(line + "\n")train_count = train_count + 1elif val_count < val_size:val_f.write(line + "\n")val_count = val_count + 1else:breakprint("数据处理完毕!")train_f.close()val_f.close()if __name__ == '__main__':doHandler()
处理之后可以看到两个生成的文件:
四、P-Tuning v2 训练
拉取官网训练脚本:
git clone
下载相应依赖:
pip install -r requirements.txt -i
此外还需安装:
pip install rouge_chinese nltk jieba datasets -i
修改 ptuning
下的 train.sh
文件:
PRE_SEQ_LEN=300
LR=2e-2
NUM_GPUS=1torchrun --standalone --nnodes=1 --nproc-per-node=$NUM_GPUS main.py \--do_train \--train_file data/train.json \--validation_file data/val.json \--preprocessing_num_workers 10 \--prompt_column content \--response_column summary \--overwrite_cache \--model_name_or_path /home/chatglm2/chatglm-6b \--output_dir output/adgen-chatglm2-6b-pt-$PRE_SEQ_LEN-$LR \--overwrite_output_dir \--max_source_length 300 \--max_target_length 1024 \--per_device_train_batch_size 1 \--per_device_eval_batch_size 1 \--gradient_accumulation_steps 16 \--predict_with_generate \--max_steps 3000 \--logging_steps 10 \--save_steps 1000 \--learning_rate $LR \--pre_seq_len $PRE_SEQ_LEN \--quantization_bit 4
其中 参数解释如下:
–standalone` 以单机模式训练。
–nnodes` 节点数。这里只有一个节点,设置为 1。
–nproc-per-node` 每个节点上的进程数。
–do_train` 执行训练任务。
–train_file` 训练数据文件路径, 上面生成的 train.json 文件。
–validation_file` 验证数据文件路径, 上面生成的 val.json 文件。
–preprocessing_num_workers` 指定数据预处理时的 workers 数。
–prompt_column` 输入信息的字段名称。
–response_column` 输出信息的字段名称。
–overwrite_cache` 覆盖缓存文件。
–model_name_or_path` 预训练模型的名称或路径,注意这里我是用的下载后的模型存放地址,需要修改为你的。
–output_dir` 模型保存目录。
–overwrite_output_dir` 覆盖输出目录。
–max_source_length` 输入文本的最大长度。
–max_target_length` 输出文本的最大长度。
–per_device_train_batch_size` 训练时的批次大小。
–per_device_eval_batch_size` 验证时的批次大小。
–gradient_accumulation_steps` 累积多少个梯度之后再进行一次反向传播。
–predict_with_generate` 预测时使用生成模式。
–max_steps` 最大训练轮数。
–logging_steps` 多少轮打印一次日志。
–save_steps` 多少轮保存一次模型。
–learning_rate` 初始学习率。
–pre_seq_len` 预处理时选取的序列长度。
–quantization_bit` 量化位大小。
执行后可以看到如下打印日志:
训练过程:
训练结束:
最后在 output
目录下可以看到每 1000
步保存的模型。
五、模型测试
5.1 单独调用测试:
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoTokenizer, AutoModel, AutoConfig
import uvicorn, json, datetime
import torch
import osdef main():pre_seq_len = 300# 训练权重地址checkpoint_path = "ptuning/output/adgen-chatglm2-6b-pt-300-2e-2/checkpoint-3000"tokenizer = AutoTokenizer.from_pretrained("chatglm-6b", trust_remote_code=True)config = AutoConfig.from_pretrained("chatglm-6b", trust_remote_code=True, pre_seq_len=pre_seq_len)model = AutoModel.from_pretrained("chatglm-6b", config=config, device_map="auto", trust_remote_code=True)prefix_state_dict = torch.load(os.path.join(checkpoint_path, "pytorch_model.bin"))new_prefix_state_dict = {}for k, v in prefix_state_dict.items():if k.startswith("transformer.prefix_encoder."):new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = vmodel.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)# 量化model = model.quantize(4)model.eval()# 问题question = "突然感到了不适,去检查后竟然得了这个病,请问:宝宝白天爱磨牙会是哪些情况呢"response, history = model.chat(tokenizer,question,history=[],max_length=2048,top_p=0.7,temperature=0.95)print("回答:", response)if torch.backends.mps.is_available():torch.mps.empty_cache()if __name__ == '__main__':main()
回答: 孩子磨牙可能会是缺钙引来的,建议带孩子去医院仔细检查下微量元素,明确病因后有针对性的治疗。平时要留意孩子的饮食卫生,防止排便辛辣刺激性食物,多给孩子喝温开水,多吃蔬菜水果,消化维生素,增进胃肠道扭动。对于家长朋友们来说,要尽可能的帮助孩子及时治疗疾病,另外宝宝在日常生活中饮食也要注意,要营养的均衡,不要过度进补也不要营养不良哦。
5.2 封装成 Api 测试
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoTokenizer, AutoModel, AutoConfig
import uvicorn, json, datetime
import torch
import osapp = FastAPI()# 允许所有域的请求
app.add_middleware(CORSMiddleware,allow_origins=["*"],allow_credentials=True,allow_methods=["*"],allow_headers=["*"],
)@app.post("/")
async def create_item(request: Request):global model, tokenizerjson_post_raw = await request.json()json_post = json.dumps(json_post_raw)json_post_list = json.loads(json_post)prompt = json_post_list.get('prompt')history = json_post_list.get('history')max_length = json_post_list.get('max_length')top_p = json_post_list.get('top_p')temperature = json_post_list.get('temperature')response, history = model.chat(tokenizer,prompt,history=history,max_length=max_length if max_length else 2048,top_p=top_p if top_p else 0.7,temperature=temperature if temperature else 0.95)now = datetime.datetime.now()time = now.strftime("%Y-%m-%d %H:%M:%S")answer = {"response": response,"history": history,"status": 200,"time": time}log = "[" + time + "] " + '", prompt:"' + prompt + '", response:"' + repr(response) + '"'print(log)if torch.backends.mps.is_available():torch.mps.empty_cache()return answerif __name__ == '__main__':pre_seq_len = 300checkpoint_path = "ptuning/output/adgen-chatglm2-6b-pt-300-2e-2/checkpoint-3000"tokenizer = AutoTokenizer.from_pretrained("chatglm-6b", trust_remote_code=True)config = AutoConfig.from_pretrained("chatglm-6b", trust_remote_code=True, pre_seq_len=pre_seq_len)model = AutoModel.from_pretrained("chatglm-6b", config=config, device_map="auto", trust_remote_code=True)prefix_state_dict = torch.load(os.path.join(checkpoint_path, "pytorch_model.bin"))new_prefix_state_dict = {}for k, v in prefix_state_dict.items():if k.startswith("transformer.prefix_encoder."):new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = vmodel.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)## 量化model = model.quantize(4)model = model.cuda()model.eval()uvicorn.run(app, host='0.0.0.0', port=8103, workers=1)
使用 postMan
测试:
最后测试下原有知识的影响:
更多推荐
基于第二代 ChatGLM2
发布评论