模型的 FLOPs 和参数量"/>
THOP: 统计 PyTorch 模型的 FLOPs 和参数量
THOP 是 PyTorch 非常实用的一个第三方库,可以统计模型的 FLOPs 和参数量。使用方法为:
from thop import clever_format
from thop import profileclass YourModule(nn.Module):# your definition
def count_your_model(model, x, y):# your rule hereinput = torch.randn(1, 3, 224, 224)
flops, params = profile(model, inputs=(input, ), custom_ops={YourModule: count_your_model})
flops, params = clever_format([flops, params], "%.3f")
profile 函数实现机制是利用了 PyTorch 的 torch.nn.Module.register_forward_hook。
profile
handler_collection = []if custom_ops is None:custom_ops = {}
嵌套定义add_hooks
函数,仅作用于网络的叶子节点。
torch.nn.Module.register_buffer 向模块添加持久缓冲区。这通常用于注册模型参数之外的缓冲区。例如,BatchNorm 的running_mean
不是参数,而是持久状态的一部分。可以使用给定名称作为属性访问缓冲区。
torch.numel 返回input
张量中的元素总数。
全局变量 register_hooks 定义了每种 op 对应的钩子函数,具体定义在 count_hooks.py 中。
torch.nn.Module.register_forward_hook 注册模块上的前向挂钩。
每次在 forward() 计算输出后都会调用该钩子。它应该有以下签名:
hook(module, input, output) -> None
钩子不应该修改输入或输出。返回类型为torch.utils.hooks.RemovableHandle
。
def add_hooks(m):if len(list(m.children())) > 0:returnif hasattr(m, "total_ops") or hasattr(m, "total_params"):logger.warning("Either .total_ops or .total_params is already defined in %s." "Be careful, it might change your code's behavior." % str(m))m.register_buffer('total_ops', torch.zeros(1))m.register_buffer('total_params', torch.zeros(1))for p in m.parameters():m.total_params += torch.Tensor([p.numel()])m_type = type(m)fn = Noneif m_type in custom_ops: # if defined both op maps, use custom_ops to overwrite.fn = custom_ops[m_type]elif m_type in register_hooks:fn = register_hooks[m_type]if fn is None:if verbose:print("THOP has not implemented counting method for ", m)else:if verbose:print("Register FLOP counter for module %s" % str(m))handler = m.register_forward_hook(fn)handler_collection.append(handler)
预先获取模型的模式,后面进行恢复。
torch.nn.Module.apply 将fn
递归地应用于每个子模块(由.children()
返回)以及自身。 典型用途包括初始化模型的参数(另请参见 torch-nn-init)。
运行网络。
# original_device = model.parameters().__next__().devicetraining = model.trainingmodel.eval()model.apply(add_hooks)with torch.no_grad():model(*inputs)
遍历叶子节点,即有效实体,统计计算量和参数量。
total_ops = 0total_params = 0for m in model.modules():if len(list(m.children())) > 0: # skip for non-leaf modulecontinuetotal_ops += m.total_opstotal_params += m.total_paramstotal_ops = total_ops.item()total_params = total_params.item()
清空handler_collection
中的元素句柄。
# reset model to original statusmodel.train(training)for handler in handler_collection:handler.remove()return total_ops, total_params
更多推荐
THOP: 统计 PyTorch 模型的 FLOPs 和参数量
发布评论