tensorflow内存泄漏或模型只加载不运行

编程入门 行业动态 更新时间:2024-10-25 02:18:40

tensorflow内存泄漏或<a href=https://www.elefans.com/category/jswz/34/1771358.html style=模型只加载不运行"/>

tensorflow内存泄漏或模型只加载不运行

使用tf2模型进行推理的过程中,发现模型的内存占用在逐步增加,甚至会因为OOM被kill掉进程,有时候模型只加载不运行,搜索得到很多五花八门的答案,有些认为是tf2本身的问题,但在使用内存追踪的时候发现,是模型的动态图没有得到释放,而导致这个问题出现的原因,是数据的加载方式存在问题!!!

        mhc_a_batches = list(chunks(mhc_seqs_a, self.batch_size))mhc_b_batches = list(chunks(mhc_seqs_b, self.batch_size))pep_batches = list(chunks(pep_seqs, self.batch_size))assert len(mhc_a_batches) == len(mhc_b_batches)assert len(mhc_a_batches) == len(pep_batches)size = len(mhc_a_batches)# 开始预测preds = []for i in range(size):_preds = self.model([mhc_a_batches[i], mhc_b_batches[i], pep_batches[i]], training = False)preds.extend(_preds.numpy().tolist())return preds

如这段代码,直接使用了list作为模型的输入,尽管tf2也支持numpy的输入格式,但却存在隐患,会产生大量的空tensor!!!

将其改为这样的形式,问题得到解决:

 mhc_seqs_a = tf.convert_to_tensor(mhc_seqs_a, dtype=tf.float32)mhc_seqs_b = tf.convert_to_tensor(mhc_seqs_b, dtype=tf.float32)pep_seqs   = tf.convert_to_tensor(pep_seqs, dtype=tf.float32)assert len(mhc_seqs_a) == len(mhc_seqs_b)assert len(mhc_seqs_a) == len(pep_seqs)ds = tf.data.Dataset.from_tensor_slices((mhc_seqs_a, mhc_seqs_b, pep_seqs)).batch(self.batch_size).prefetch(1)preds = []for x, y, z in ds:_preds = self.model([x,y,z], training=False)preds.extend(_preds.numpy().tolist())return preds

现在可以愉快的进行模型推理了,而且速度比之前要快几倍不止,实测在GPU上提速近30倍,可想而知对于上亿级别的数据,节省的时间多么可观!

更多推荐

tensorflow内存泄漏或模型只加载不运行

本文发布于:2023-11-15 23:05:28,感谢您对本站的认可!
本文链接:https://www.elefans.com/category/jswz/34/1608625.html
版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。
本文标签:模型   加载   内存   tensorflow

发布评论

评论列表 (有 0 条评论)
草根站长

>www.elefans.com

编程频道|电子爱好者 - 技术资讯及电子产品介绍!