跳转至

Calibrator

功能说明

量化参数配置类,通过Calibrator类封装量化算法。

函数原型

Calibrator(model, cfg, calib_data=None, fuse_module_call_back=None)

参数说明

参数名 输入/返回值 含义 使用限制
model 输入 待量化模型实例。 必选。
数据类型:PyTorch模型。
cfg 输入 已配置的QuantConfig类。 必选。
数据类型:QuantConfig。
calib_data 输入 模型训练数据,可输入真实数据用于Label-Free量化,也可输入虚拟数据来实现Label-Free量化。 可选。
数据类型:list[list[Torch.Tensor]] 或list[Torch.Tensor]。
如果不输入数据,在模型支持单个float格式输入且指定了input_shape时,会自动调用Label-Free量化流程。针对多个输入或者需要自定义输入格式的模型,用户可随机构造输入数据来实现Label-Free量化。
fuse_module_call_back 输入 BN融合用户自定义函数,在量化前会调用该回调。 可选。
数据类型:function。
如果模型结构特殊,不是conv->bn并列结构的,需要用户传入自定义融合函数。

调用示例

from msmodelslim.pytorch.quant.ptq_tools import QuantConfig, Calibrator
disable_names = []
input_shape = [1, 3, 224, 224]
quant_config = QuantConfig(disable_names=disable_names, amp_num=0, input_shape=input_shape)
calib_data = []
image = cv2.imdecode(np.fromfile("./random_image.jpg", dtype=np.uint8), 1)
image = cv2.resize(image, (224, 224,), interpolation=cv2.INTER_CUBIC)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 
image = torch.from_numpy(image).permute(2, 0, 1)/255
image = image.unsqueeze(0)
calib_data.append([image])     #传入一张随机图片数据,用于提高精度
calibrator = Calibrator(model, quant_config, calib_data=calib_data)