THOP: 统计 PyTorch 模型的 FLOPs 和参数量

编程入门 行业动态 更新时间:2024-10-25 14:31:37

THOP: 统计 PyTorch <a href=https://www.elefans.com/category/jswz/34/1771358.html style=模型的 FLOPs 和参数量"/>

THOP: 统计 PyTorch 模型的 FLOPs 和参数量

THOP 是 PyTorch 非常实用的一个第三方库,可以统计模型的 FLOPs 和参数量。使用方法为:

from thop import clever_format
from thop import profileclass YourModule(nn.Module):# your definition
def count_your_model(model, x, y):# your rule hereinput = torch.randn(1, 3, 224, 224)
flops, params = profile(model, inputs=(input, ), custom_ops={YourModule: count_your_model})
flops, params = clever_format([flops, params], "%.3f")

profile 函数实现机制是利用了 PyTorch 的 torch.nn.Module.register_forward_hook。

profile

    handler_collection = []if custom_ops is None:custom_ops = {}

嵌套定义add_hooks函数,仅作用于网络的叶子节点。
torch.nn.Module.register_buffer 向模块添加持久缓冲区。这通常用于注册模型参数之外的缓冲区。例如,BatchNorm 的running_mean不是参数,而是持久状态的一部分。可以使用给定名称作为属性访问缓冲区。
torch.numel 返回input张量中的元素总数。
全局变量 register_hooks 定义了每种 op 对应的钩子函数,具体定义在 count_hooks.py 中。
torch.nn.Module.register_forward_hook 注册模块上的前向挂钩。

每次在 forward() 计算输出后都会调用该钩子。它应该有以下签名:

hook(module, input, output) -> None

钩子不应该修改输入或输出。返回类型为torch.utils.hooks.RemovableHandle

    def add_hooks(m):if len(list(m.children())) > 0:returnif hasattr(m, "total_ops") or hasattr(m, "total_params"):logger.warning("Either .total_ops or .total_params is already defined in %s." "Be careful, it might change your code's behavior." % str(m))m.register_buffer('total_ops', torch.zeros(1))m.register_buffer('total_params', torch.zeros(1))for p in m.parameters():m.total_params += torch.Tensor([p.numel()])m_type = type(m)fn = Noneif m_type in custom_ops:  # if defined both op maps, use custom_ops to overwrite.fn = custom_ops[m_type]elif m_type in register_hooks:fn = register_hooks[m_type]if fn is None:if verbose:print("THOP has not implemented counting method for ", m)else:if verbose:print("Register FLOP counter for module %s" % str(m))handler = m.register_forward_hook(fn)handler_collection.append(handler)

预先获取模型的模式,后面进行恢复。
torch.nn.Module.apply 将fn递归地应用于每个子模块(由.children()返回)以及自身。 典型用途包括初始化模型的参数(另请参见 torch-nn-init)。
运行网络。

    # original_device = model.parameters().__next__().devicetraining = model.trainingmodel.eval()model.apply(add_hooks)with torch.no_grad():model(*inputs)

遍历叶子节点,即有效实体,统计计算量和参数量。

    total_ops = 0total_params = 0for m in model.modules():if len(list(m.children())) > 0:  # skip for non-leaf modulecontinuetotal_ops += m.total_opstotal_params += m.total_paramstotal_ops = total_ops.item()total_params = total_params.item()

清空handler_collection中的元素句柄。

    # reset model to original statusmodel.train(training)for handler in handler_collection:handler.remove()return total_ops, total_params

更多推荐

THOP: 统计 PyTorch 模型的 FLOPs 和参数量

本文发布于:2024-02-17 04:14:06,感谢您对本站的认可!
本文链接:https://www.elefans.com/category/jswz/34/1692588.html
版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。
本文标签:模型   参数   THOP   PyTorch   FLOPs

发布评论

评论列表 (有 0 条评论)
草根站长

>www.elefans.com

编程频道|电子爱好者 - 技术资讯及电子产品介绍!