DALLE 2 Text-to-Image Model: A Practical Guide

Updated: 2024-10-20 01:24:50



Preface: I have recently been running inference with the DALLE 2 model, and this post records the relevant notes and resources.

Related post: Super detailed! A Practical Guide to the DALL·E Text-to-Image Model


Contents

  • 1. Environment Setup and Pretrained Model Preparation
    • Environment setup
    • Pretrained model download
  • 2. Code
  • 3. BUG&DEBUG
    • URLError
    • RuntimeError
    • CUDA error


1. Environment Setup and Pretrained Model Preparation

The code repository used in this post is:

Environment setup

pip install dalle2-pytorch
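Several of the bugs in section 3 come down to version mismatches (a checkpoint saved with a different dalle2-pytorch version, or a PyTorch build that does not match the GPU), so after installing it can help to record exactly what got installed. A minimal stdlib-only sketch, assuming Python 3.8+ for importlib.metadata; installed_versions is a hypothetical helper, and the default package list is only a suggestion:

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages=("dalle2-pytorch", "torch", "torchvision")):
    """Return {distribution name: installed version, or None if not installed}."""
    report = {}
    for name in packages:
        try:
            report[name] = version(name)
        except PackageNotFoundError:
            # The distribution is not installed in the current environment.
            report[name] = None
    return report


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name}: {ver or 'NOT INSTALLED'}")
```

Running it inside the conda environment used for inference (ldm in the tracebacks below) shows which versions the later errors were produced with.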

Pretrained model download

Download link:
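Whatever source the checkpoints come from, the inference script in section 2 reads four files from a weights/ directory. A quick sanity check before running it, assuming exactly those file names; check_weights_dir is a hypothetical helper, and it only checks that the files exist and that the JSON configs parse, not that the .pth files are valid checkpoints:

```python
import json
from pathlib import Path

# File names taken from the inference script in section 2; adjust them if your
# downloaded checkpoints are named differently.
EXPECTED_FILES = (
    "prior_config.json",
    "prior_latest.pth",
    "decoder_config.json",
    "decoder_latest.pth",
)


def check_weights_dir(root="weights"):
    """Return a list of problems found in the weights directory (empty if OK)."""
    root = Path(root)
    problems = []
    for name in EXPECTED_FILES:
        path = root / name
        if not path.is_file():
            problems.append(f"missing: {path}")
        elif path.suffix == ".json":
            # A truncated download or a saved HTML error page fails to parse as JSON.
            try:
                json.loads(path.read_text(encoding="utf-8"))
            except ValueError as exc:  # covers JSONDecodeError and UnicodeDecodeError
                problems.append(f"invalid JSON in {path}: {exc}")
    return problems


if __name__ == "__main__":
    issues = check_weights_dir()
    print("\n".join(issues) if issues else "weights/ looks complete")
```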

2. Code

The complete DALLE2 inference pipeline is as follows (from cest_andre):

import torch
from torchvision.transforms import ToPILImage
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter, Decoder, DALLE2
from dalle2_pytorch.train_configs import TrainDiffusionPriorConfig, TrainDecoderConfig

# Build the diffusion prior from its config and load the pretrained weights
prior_config = TrainDiffusionPriorConfig.from_json_path("weights/prior_config.json").prior
prior = prior_config.create().cuda()
prior_model_state = torch.load("weights/prior_latest.pth")
prior.load_state_dict(prior_model_state, strict=True)

# Build the decoder the same way; its checkpoint stores the weights under "model"
decoder_config = TrainDecoderConfig.from_json_path("weights/decoder_config.json").decoder
decoder = decoder_config.create().cuda()
decoder_model_state = torch.load("weights/decoder_latest.pth")["model"]

# Copy the CLIP weights (presumably not stored in the decoder checkpoint)
# from the freshly created CLIP model so that strict loading succeeds
clip_state = decoder.clip.state_dict()
for k in clip_state.keys():
    decoder_model_state["clip." + k] = clip_state[k]
decoder.load_state_dict(decoder_model_state, strict=True)

# Chain prior and decoder; cond_scale sets the classifier-free guidance strength
dalle2 = DALLE2(prior=prior, decoder=decoder).cuda()
images = dalle2(['your prompt here'], cond_scale=2.).cpu()
print(images.shape)

# Convert each generated image tensor to a PIL image and display it
for img in images:
    img = ToPILImage()(img)
    img.show()

3. BUG&DEBUG

URLError

The error message is as follows:

Traceback (most recent call last):
  File "/root/anaconda3/envs/ldm/lib/python3.8/urllib/request.py", line 1350, in do_open
    h.request(req.get_method(), req.selector, req.data, headers,
  File "/root/anaconda3/envs/ldm/lib/python3.8/http/client.py", line 1255, in request
    self._send_request(method, url, body, headers, encode_chunked)
  File "/root/anaconda3/envs/ldm/lib/python3.8/http/client.py", line 1301, in _send_request
    self.endheaders(body, encode_chunked=encode_chunked)
  File "/root/anaconda3/envs/ldm/lib/python3.8/http/client.py", line 1250, in endheaders
    self._send_output(message_body, encode_chunked=encode_chunked)
  File "/root/anaconda3/envs/ldm/lib/python3.8/http/client.py", line 1010, in _send_output
    self.send(msg)
  File "/root/anaconda3/envs/ldm/lib/python3.8/http/client.py", line 950, in send
    self.connect()
  File "/root/anaconda3/envs/ldm/lib/python3.8/http/client.py", line 1424, in connect
    self.sock = self._context.wrap_socket(self.sock,
  File "/root/anaconda3/envs/ldm/lib/python3.8/ssl.py", line 500, in wrap_socket
    return self.sslsocket_class._create(
  File "/root/anaconda3/envs/ldm/lib/python3.8/ssl.py", line 1040, in _create
    self.do_handshake()
  File "/root/anaconda3/envs/ldm/lib/python3.8/ssl.py", line 1309, in do_handshake
    self._sslobj.do_handshake()
ConnectionResetError: [Errno 104] Connection reset by peer

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/newdata/SD/extra/dalle2_cest.py", line 11, in <module>
    prior = prior_config.create().cuda()
  File "/root/anaconda3/envs/ldm/lib/python3.8/site-packages/dalle2_pytorch/train_configs.py", line 185, in create
    clip = self.clip.create()
  File "/root/anaconda3/envs/ldm/lib/python3.8/site-packages/dalle2_pytorch/train_configs.py", line 122, in create
    return OpenAIClipAdapter(self.model)
  File "/root/anaconda3/envs/ldm/lib/python3.8/site-packages/dalle2_pytorch/dalle2_pytorch.py", line 313, in __init__
    openai_clip, preprocess = clip.load(name)
  File "/root/anaconda3/envs/ldm/lib/python3.8/site-packages/clip/clip.py", line 122, in load
    model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
  File "/root/anaconda3/envs/ldm/lib/python3.8/site-packages/clip/clip.py", line 59, in _download
    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
  File "/root/anaconda3/envs/ldm/lib/python3.8/urllib/request.py", line 222, in urlopen
    return opener.open(url, data, timeout)
  File "/root/anaconda3/envs/ldm/lib/python3.8/urllib/request.py", line 525, in open
    response = self._open(req, data)
  File "/root/anaconda3/envs/ldm/lib/python3.8/urllib/request.py", line 542, in _open
    result = self._call_chain(self.handle_open, protocol, protocol +
  File "/root/anaconda3/envs/ldm/lib/python3.8/urllib/request.py", line 502, in _call_chain
    result = func(*args)
  File "/root/anaconda3/envs/ldm/lib/python3.8/urllib/request.py", line 1393, in https_open
    return self.do_open(http.client.HTTPSConnection, req,
  File "/root/anaconda3/envs/ldm/lib/python3.8/urllib/request.py", line 1353, in do_open
    raise URLError(err)
urllib.error.URLError: <urlopen error [Errno 104] Connection reset by peer>

The fix I used: open /root/anaconda3/envs/ldm/lib/python3.8/urllib/request.py and find the corresponding location (line 1349 in my case). The modification is shown in the code below.

try:
    h.request(req.get_method(), req.selector, req.data, headers,
              encode_chunked=req.has_header('Transfer-encoding'))
    time.sleep(0.5)  # the added line
except OSError as err:  # timeout error
    raise URLError(err)
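The traceback shows that the failing request comes from clip.load(), which downloads the CLIP weights into ~/.cache/clip. Editing the standard library's request.py changes behavior for every program in the environment and is lost whenever Python is reinstalled. A swapped-in alternative (not the author's method) is to fetch the file yourself with retries and exponential backoff. A minimal sketch; download_with_retry is a hypothetical helper:

```python
import os
import time
import urllib.error
import urllib.request


def download_with_retry(url, dest, attempts=5, base_delay=1.0,
                        opener=urllib.request.urlopen):
    """Download `url` to `dest`, retrying connection errors with exponential backoff.

    `opener` defaults to urllib.request.urlopen and can be swapped out for testing.
    """
    tmp = dest + ".part"  # write here first so `dest` never holds a partial file
    for attempt in range(1, attempts + 1):
        try:
            with opener(url) as source, open(tmp, "wb") as output:
                while True:
                    chunk = source.read(1 << 20)  # read 1 MiB at a time
                    if not chunk:
                        break
                    output.write(chunk)
            os.replace(tmp, dest)  # rename into place once the download is complete
            return dest
        except (urllib.error.URLError, ConnectionError) as err:
            if attempt == attempts:
                # Out of retries: clean up the partial file and re-raise.
                if os.path.exists(tmp):
                    os.remove(tmp)
                raise
            delay = base_delay * 2 ** (attempt - 1)  # 1s, 2s, 4s, ... by default
            print(f"attempt {attempt} failed ({err}); retrying in {delay:.1f}s")
            time.sleep(delay)
```

For this particular error, one option is to download the CLIP checkpoint URL (listed in the _MODELS dict of the installed clip/clip.py, visible in the traceback) into ~/.cache/clip under its original file name. clip.load() checks that cache before going to the network, so it should then skip the download, assuming the file's checksum matches what clip expects.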

RuntimeError

Traceback (most recent call last):
  File "/newdata/SD/extra/dalle2_cest.py", line 14, in <module>
    prior.load_state_dict(prior_model_state, strict=True)
  File "/root/anaconda3/envs/ldm/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1604, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for DiffusionPrior:
    Missing key(s) in state_dict: "net.null_text_encodings", "net.null_text_embeds", "net.null_image_embed".
    Unexpected key(s) in state_dict: "net.null_text_embed".

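Solution: change strict=True to strict=False in the load_state_dict() calls, as below. Note that this only disables the key check: the parameters reported as missing (here net.null_text_encodings, net.null_text_embeds and net.null_image_embed) keep their initial values instead of being loaded from the checkpoint, and the unexpected key is ignored.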

...
prior.load_state_dict(prior_model_state, strict=False)
decoder.load_state_dict(decoder_model_state, strict=False)
...
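The key names in the error (null_text_embed in the checkpoint versus null_text_embeds and others in the model) suggest the checkpoint was saved with a different dalle2-pytorch version than the one installed. The name check that strict=True performs is essentially a set difference between the model's parameter names and the checkpoint's. A minimal sketch on plain lists of key names; diff_state_dict_keys is a hypothetical helper, not part of PyTorch:

```python
def diff_state_dict_keys(model_keys, checkpoint_keys):
    """Mimic the key-name check done by load_state_dict(strict=True).

    Returns (missing, unexpected): keys the model expects but the checkpoint
    lacks, and keys the checkpoint contains but the model does not define.
    """
    model_keys, checkpoint_keys = set(model_keys), set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)
    unexpected = sorted(checkpoint_keys - model_keys)
    return missing, unexpected


# Key names taken from the error message above
model = ["net.null_text_encodings", "net.null_text_embeds", "net.null_image_embed"]
ckpt = ["net.null_text_embed"]
print(diff_state_dict_keys(model, ckpt))
```

With PyTorch itself there is no need for a helper: load_state_dict(..., strict=False) returns a named tuple with missing_keys and unexpected_keys fields, so print(prior.load_state_dict(prior_model_state, strict=False)) shows exactly which weights were skipped.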

CUDA error

RuntimeError: CUDA error: no kernel image is available for execution on the device
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.

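Solution: the installed PyTorch build does not match the machine, so it ships no GPU kernels this card can run. Install a PyTorch build compatible with the system's CUDA. For example, my CUDA version is 12.0, and PyTorch can be installed with the command below; the cu118 wheels bundle their own CUDA 11.8 runtime, which a newer (CUDA 12.0) driver can run: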

pip install torch==2.0.0+cu118 torchvision==0.15.1+cu118 torchaudio==2.0.1+cu118 -f .html
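To confirm the diagnosis before and after reinstalling, PyTorch reports the GPU architectures a build was compiled for via torch.cuda.get_arch_list() and the card's compute capability via torch.cuda.get_device_capability(). A rough sketch of comparing the two; has_kernels_for is a hypothetical helper that simplifies NVIDIA's compatibility rules (for example, it ignores arch-specific targets such as sm_90a):

```python
import re

# Matches entries like "sm_86" or "compute_90"; the last digit is the minor version.
_ARCH_RE = re.compile(r"^(sm|compute)_(\d+)(\d)$")


def has_kernels_for(capability, arch_list):
    """Rough check whether a PyTorch build ships GPU code this device can run.

    capability: (major, minor), e.g. from torch.cuda.get_device_capability().
    arch_list:  e.g. ["sm_50", "sm_86", "compute_86"] from torch.cuda.get_arch_list().
    Machine code (sm_XY) runs on GPUs with the same major version and an equal
    or newer minor version; PTX (compute_XY) no newer than the GPU can be
    JIT-compiled by the driver.
    """
    major, minor = capability
    for arch in arch_list:
        m = _ARCH_RE.match(arch)
        if m is None:
            continue  # entries this sketch does not model, e.g. "sm_90a"
        kind, a_major, a_minor = m.group(1), int(m.group(2)), int(m.group(3))
        if kind == "sm" and a_major == major and a_minor <= minor:
            return True
        if kind == "compute" and (a_major, a_minor) <= (major, minor):
            return True
    return False
```

On the affected machine, import torch and call has_kernels_for(torch.cuda.get_device_capability(), torch.cuda.get_arch_list()); a False result is consistent with the "no kernel image" error above.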

That's it: the model can now complete inference. Hehe!


References

  1. Python requests error ConnectionError: ('Connection aborted.', error(104, 'Connection reset by peer')), 铁朵斯提's blog, CSDN
  2. Installing GPU PyTorch (CUDA 12.1) quickly from the Tsinghua mirror, step by step: a beginner's tutorial, CSDN blog


Published: 2023-11-15 00:31:04
Article link: https://www.elefans.com/category/jswz/34/1590502.html