跳转至

init

功能说明

PruneTorch类方法,对用户输入的模型进行类初始化。

函数原型

__init__(network, inputs)

参数说明

参数名 输入/返回值 含义 使用限制
network 输入 待剪枝模型实例。 必选。
数据类型:PyTorch模型。
inputs 输入 模型的输入数据,用于解析模型。 可选。
数据类型:Tensor。

调用示例

from msmodelslim.pytorch.prune.prune_torch import PruneTorch
model = torchvision.models.vgg16(pretrained=False)
model.eval()
prune_torch = PruneTorch(model, torch.ones([1, 3, 224, 224]).type(torch.float32))