|
14 | 14 | from .constrain.cmodule import constrain_module, _CMODULE_TABLE |
15 | 15 | from typing import Any, Dict, List, Optional, Union |
16 | 16 |
|
| 17 | +def fuse_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): |
| 18 | + |
| 19 | + eps = 1e-5 |
| 20 | + clamp_conv_name = prefix + 'conv' |
| 21 | + clamp_bn_name = prefix + 'bn' |
| 22 | + conv_int_name = prefix |
| 23 | + if clamp_conv_name + '.weight' in state_dict and clamp_bn_name + '.weight' in state_dict: |
| 24 | + b_mean = state_dict[clamp_bn_name + '.running_mean'] |
| 25 | + b_var = state_dict[clamp_bn_name + '.running_var'] |
| 26 | + b_w = state_dict[clamp_bn_name + '.weight'] |
| 27 | + b_b = state_dict[clamp_bn_name + '.bias'] |
| 28 | + sigma = 1 / torch.sqrt(b_var + eps) |
| 29 | + alpha = b_w * sigma |
| 30 | + beta = b_b - b_mean * alpha |
| 31 | + c_w = state_dict[clamp_conv_name + '.weight'] |
| 32 | + state_dict[conv_int_name + |
| 33 | + 'weight'] = (c_w * alpha.view(-1, *([1]*(len(c_w.shape)-1)))) |
| 34 | + if clamp_conv_name + '.bias' in state_dict: |
| 35 | + c_b = state_dict[clamp_conv_name + '.bias'] |
| 36 | + state_dict[conv_int_name + 'bias'] = (c_b * alpha + beta) |
| 37 | + state_dict.pop(clamp_conv_name + '.bias') |
| 38 | + else: |
| 39 | + state_dict[conv_int_name + 'bias'] = beta |
| 40 | + state_dict.pop(clamp_bn_name + '.running_mean') |
| 41 | + state_dict.pop(clamp_bn_name + '.running_var') |
| 42 | + state_dict.pop(clamp_bn_name + '.weight') |
| 43 | + state_dict.pop(clamp_bn_name + '.bias') |
| 44 | + state_dict.pop(clamp_bn_name + '.num_batches_tracked') |
| 45 | + state_dict.pop(clamp_conv_name + '.weight') |
| 46 | + else: |
| 47 | + assert clamp_conv_name + '.weight' not in state_dict and clamp_bn_name + \ |
| 48 | + '.weight' not in state_dict, 'load quanted model but contain float clamp params' |
| 49 | + |
17 | 50 | @contextmanager |
18 | | -def calibration(a_calibrate_name='top_10', w_calibrate_name='abs_max'): |
| 51 | +def calibration(): |
19 | 52 | # 保存旧值 |
20 | | - old_a_calibrate_name = QUANT_CONFIGS.quant_info.a_calibrate_name |
21 | | - old_w_calibrate_name = QUANT_CONFIGS.quant_info.w_calibrate_name |
| 53 | + # old_a_calibrate_name = QUANT_CONFIGS.quant_info.a_calibrate_name |
| 54 | + # old_w_calibrate_name = QUANT_CONFIGS.quant_info.w_calibrate_name |
22 | 55 | try: |
23 | 56 | QUANT_CONFIGS.calibration = True |
24 | | - QUANT_CONFIGS.quant_info.a_calibrate_name = a_calibrate_name |
25 | | - QUANT_CONFIGS.quant_info.w_calibrate_name = w_calibrate_name |
| 57 | + # QUANT_CONFIGS.quant_info.a_calibrate_name = a_calibrate_name |
| 58 | + # QUANT_CONFIGS.quant_info.w_calibrate_name = w_calibrate_name |
26 | 59 | yield # <<< 关键点:控制权交给 with 块 |
27 | 60 | finally: |
28 | 61 | QUANT_CONFIGS.calibration = False |
@@ -84,15 +117,20 @@ def init(model: nn.Module, config_file: str = None, disable_module=None, disable |
84 | 117 | # traced_model = symbolic_trace(model) |
85 | 118 | # model = _replace_ops(traced_model, q_configs) |
86 | 119 |
|
| 120 | + has_replaced = [] |
87 | 121 | for name, m in model.named_modules(): |
88 | 122 | if disable_submodel is not None and any(fnmatch(name, pattern) for pattern in disable_submodel): |
89 | 123 | continue |
| 124 | + if any(name.startswith(p + ".") for p in has_replaced): |
| 125 | + continue |
90 | 126 |
|
91 | 127 | m.register_forward_pre_hook(hook_pre_forward) |
92 | 128 | m.register_forward_hook(hook_forward) |
93 | 129 |
|
94 | | - _quantize_submodule(model, name, m, weights_cfg=q_configs.quant_info.to_dict(), activations_cfg=q_configs.quant_info.to_dict(), bias_cfg=q_configs.quant_info.to_dict(), constrain = q_configs.clamp_info.to_dict()) |
95 | | - |
| 130 | + is_replaced = _quantize_submodule(model, name, m, weights_cfg=q_configs.quant_info.to_dict(), activations_cfg=q_configs.quant_info.to_dict(), bias_cfg=q_configs.quant_info.to_dict(), constrain = q_configs.clamp_info.to_dict()) |
| 131 | + if is_replaced: |
| 132 | + has_replaced.append(name) |
| 133 | + |
96 | 134 | def quant_tensor_pre_hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): |
97 | 135 |
|
98 | 136 | def quant_tensor_layer(module, prefix=''): |
@@ -212,13 +250,21 @@ def _quantize_submodule( |
212 | 250 | constrain: Optional[Union[str]] = None, |
213 | 251 | ): |
214 | 252 | qmodule = quantize_module(module, weights_cfg=weights_cfg, activations_cfg=activations_cfg, bias_cfg = bias_cfg, dim = getattr(module, "dim", None), constrain = constrain) |
| 253 | + if isinstance(module, ConvBN1d) or isinstance(module, CConvBN1d) \ |
| 254 | + or isinstance(module, ConvBN2d) or isinstance(module, CConvBN2d) \ |
| 255 | + or isinstance(module, ConvTransposeBN1d) or isinstance(module, CConvTransposeBN1d) \ |
| 256 | + or isinstance(module, ConvTransposeBN2d) or isinstance(module, CConvTransposeBN2d): |
| 257 | + qmodule._register_load_state_dict_pre_hook(fuse_state_dict) |
| 258 | + |
215 | 259 | if qmodule is not None: |
216 | 260 | _set_module_by_name(model, name, qmodule) |
217 | 261 | qmodule.name = name |
218 | 262 | for name, param in module.named_parameters(): |
219 | 263 | # Save device memory by clearing parameters |
220 | 264 | setattr(module, name, None) |
221 | 265 | del param |
| 266 | + return True |
| 267 | + return False |
222 | 268 |
|
223 | 269 | def _constrain_submodule( |
224 | 270 | model: torch.nn.Module, |
|
0 commit comments