Skip to content

Commit e85f004

Browse files
committed
[update]update ver:3.0.5
1 parent a23055f commit e85f004

29 files changed

Lines changed: 468 additions & 200 deletions

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ linger 基于 PyTorch 进行量化算子的搭建,因此只要符合 PyTorch
1818

1919
## 快速入门
2020
- [安装](doc/tutorial/install.md):支持pip、源码、docker三种安装方式
21-
- [量化训练快速入门](doc/tutorial/quant_quick_strat.md): 先进行浮点网络的约束训练,再针对量化友好的浮点模型进行量化训练微调
21+
- [量化训练快速入门](doc/tutorial/quant_quick_start.md): 先进行浮点网络的约束训练,再针对量化友好的浮点模型进行量化训练微调
2222
- [量化训练进阶指导](doc/tutorial/quant_advanced_guide.md): 量化进阶配置
2323
- [onnx导出教程](doc/tutorial/export_onnx.md):将量化无损的PyTorch模型导出为ONNX格式的模型
2424

doc/tutorial/install.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
### 创建虚拟环境
77
```Shell
8-
conda create -n linger_thinker_3.x python==3.12.10
8+
conda create -n linger_thinker_3.x python==3.10.0
99
conda activate linger_thinker_3.x
1010
pip install -U pip
1111
cat requirements.txt | xargs -n 1 pip install

doc/tutorial/linger_api.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
## linger 约束训练接口
1+
## linger normalize 约束训练接口
22

33

44
## linger 量化训练接口

doc/tutorial/quant_advanced_guide.md

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,16 @@
11
## 校准(PTQ)使用方法
22
* 因为校准时会默认按照weight的clip配置进行weight的初始化,故暂不支持循环多组数据校准(仅支持一轮输入校准)
33
* 校准时会创建add、bmm等小算子的module
4-
```python
5-
@linger.register_calibrate_method('custom_calibration')
6-
def test_init(self, tensor):
7-
with torch.no_grad():
8-
self.learning_data.fill_(torch.tensor(-999))
9-
self.scale.fill_(torch.tensor(-999))
10-
self.is_calibrate.fill_(True)
11-
12-
13-
with linger.calibration(a_calibrate_name="custom_calibration", w_calibrate_name="custom_calibration"):
14-
model = linger.init(model)
15-
model(torch.load("/yrfs4/inference/sqtu2/LLM/code/linger3.0/my_linger/calibrate_input.pt"))
4+
```python
5+
# 修改cfg.yaml,通过a_calibrate_name和w_calibrate_name设置校准方法,推荐使用默认配置即可;
6+
# 量化配置
7+
model = linger.init(model, config_file = 'cfg.yaml')
8+
# 加载预训练模型
9+
model.load_state_dict(torch.load("./pre_train.pt"))
10+
with linger.calibration(): # 校准开关
11+
model(torch.load("/yrfs4/inference/sqtu2/LLM/code/linger3.0/my_linger/calibrate_input.pt")) # 走一遍前向,开始校准
12+
13+
# 开始量化训练
1614
```
1715

1816
## linger.init/constrain中'disable_module'使用方法
@@ -30,7 +28,6 @@
3028
## linger.init中可通过yaml文件加载配置,当前配置可通过linger.config_save_to_yaml保存
3129
## config.yaml 介绍
3230
* 基础配置
33-
calibration: false # 校准开关
3431
clamp_info: # 约束信息配置
3532
clamp_activation_value: 8 # 激活约束浮点值,8代表约束到[-8, 8]
3633
clamp_bias_value: null # bias约束浮点值,默认值为None
Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ with torch.no_grad():
4141

4242
## config.yaml 介绍
4343
* 基础配置
44-
calibration: false # 校准开关
4544
clamp_info: # 约束信息配置
4645
clamp_activation_value: 8 # 激活约束浮点值,8代表约束到[-8, 8]
4746
clamp_bias_value: null # bias约束浮点值,默认值为None

doc/tutorial/release.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,7 @@
1010
* 解决LSTM和GRU算子一致性
1111
# V3.0.4 2026.01.07
1212
* 优化导图代码
13-
* 解决部分算子一致性问题
13+
* 解决部分算子一致性问题
14+
# V3.0.5 2026.02.05
15+
* 优化导图代码,解决cat算子导图问题
16+
* 优化代码适配不同torch版本

linger/__version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,5 @@ def _to_int(s):
88
return s
99

1010

11-
__version__ = "3.0.4"
11+
__version__ = "3.0.5"
1212
version_info = tuple(_to_int(s) for s in __version__.split("."))

linger/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ def _save_to_yaml(cls, config_path: str):
133133
将当前配置保存到 YAML 文件
134134
"""
135135
config_dict = cls._to_save_dict()
136+
config_dict.pop('calibration', None)
136137
os.makedirs(os.path.dirname(config_path), exist_ok=True)
137138
with open(config_path, 'w', encoding='utf-8') as f:
138139
yaml.dump(config_dict, f, default_flow_style=False, indent=2, allow_unicode=True)

linger/constrain/cmodule.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def cweight(self):
130130

131131
@property
132132
def cbias(self):
133-
return self.bias if self.clamp_bias is None else torch.clamp(self.bias, min = -self.clamp_bias, max = self.clamp_bias)
133+
return self.bias if (self.clamp_bias is None or self.bias is None) else torch.clamp(self.bias, min = -self.clamp_bias, max = self.clamp_bias)
134134

135135

136136
def cforward(self, input: torch.Tensor) -> torch.Tensor:

linger/initialize.py

Lines changed: 53 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,48 @@
1414
from .constrain.cmodule import constrain_module, _CMODULE_TABLE
1515
from typing import Any, Dict, List, Optional, Union
1616

17+
def fuse_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
    """Fold float ``conv`` + ``bn`` checkpoint entries into fused conv entries.

    The signature matches what ``load_state_dict`` pre-hooks are called with.
    When ``state_dict`` holds both ``<prefix>conv.weight`` and
    ``<prefix>bn.weight``, the batch-norm statistics are folded into the
    convolution and the result is stored as ``<prefix>weight`` /
    ``<prefix>bias``. The float ``conv.*`` / ``bn.*`` entries are then removed.
    ``state_dict`` is modified in place.

    Args:
        state_dict: Mapping of parameter names to tensors; mutated in place.
        prefix: Key prefix of the module being loaded (e.g. ``"layer1."``).
        local_metadata, strict, missing_keys, unexpected_keys, error_msgs:
            Unused; accepted only to match the hook call signature.

    Raises:
        AssertionError: If only one of ``<prefix>conv.weight`` and
            ``<prefix>bn.weight`` is present.
        KeyError: If the batch-norm running statistics or bias are missing
            while both weights are present.
    """
    # NOTE(review): eps is hard-coded; presumably it matches the eps of the
    # BatchNorm layers being fused -- confirm against the module definition.
    eps = 1e-5
    conv_key = prefix + 'conv'
    bn_key = prefix + 'bn'

    has_conv = conv_key + '.weight' in state_dict
    has_bn = bn_key + '.weight' in state_dict

    if not (has_conv and has_bn):
        # Either the checkpoint is already fused (neither key present, nothing
        # to do) or it is inconsistent (exactly one present). Raise explicitly
        # rather than via ``assert`` so the check survives ``python -O``.
        if has_conv or has_bn:
            raise AssertionError('load quanted model but contain float clamp params')
        return

    # Read everything first so a missing key fails before any mutation.
    running_mean = state_dict[bn_key + '.running_mean']
    running_var = state_dict[bn_key + '.running_var']
    bn_weight = state_dict[bn_key + '.weight']
    bn_bias = state_dict[bn_key + '.bias']
    conv_weight = state_dict[conv_key + '.weight']

    # y = bn_weight * (x - mean) / sqrt(var + eps) + bn_bias  ==  alpha * x + beta
    alpha = bn_weight / torch.sqrt(running_var + eps)
    beta = bn_bias - running_mean * alpha

    # Broadcast the per-channel scale along dim 0 of the conv weight.
    # NOTE(review): PyTorch transposed-conv weights keep output channels on
    # dim 1; confirm this scaling is right for the ConvTranspose variants.
    scale_shape = (-1,) + (1,) * (conv_weight.dim() - 1)
    state_dict[prefix + 'weight'] = conv_weight * alpha.view(*scale_shape)

    if conv_key + '.bias' in state_dict:
        state_dict[prefix + 'bias'] = state_dict.pop(conv_key + '.bias') * alpha + beta
    else:
        state_dict[prefix + 'bias'] = beta

    for suffix in ('.running_mean', '.running_var', '.weight', '.bias'):
        state_dict.pop(bn_key + suffix)
    # Not every checkpoint carries this counter; tolerate it being absent.
    state_dict.pop(bn_key + '.num_batches_tracked', None)
    state_dict.pop(conv_key + '.weight')
1750
@contextmanager
18-
def calibration(a_calibrate_name='top_10', w_calibrate_name='abs_max'):
51+
def calibration():
1952
# 保存旧值
20-
old_a_calibrate_name = QUANT_CONFIGS.quant_info.a_calibrate_name
21-
old_w_calibrate_name = QUANT_CONFIGS.quant_info.w_calibrate_name
53+
# old_a_calibrate_name = QUANT_CONFIGS.quant_info.a_calibrate_name
54+
# old_w_calibrate_name = QUANT_CONFIGS.quant_info.w_calibrate_name
2255
try:
2356
QUANT_CONFIGS.calibration = True
24-
QUANT_CONFIGS.quant_info.a_calibrate_name = a_calibrate_name
25-
QUANT_CONFIGS.quant_info.w_calibrate_name = w_calibrate_name
57+
# QUANT_CONFIGS.quant_info.a_calibrate_name = a_calibrate_name
58+
# QUANT_CONFIGS.quant_info.w_calibrate_name = w_calibrate_name
2659
yield # <<< 关键点:控制权交给 with 块
2760
finally:
2861
QUANT_CONFIGS.calibration = False
@@ -84,15 +117,20 @@ def init(model: nn.Module, config_file: str = None, disable_module=None, disable
84117
# traced_model = symbolic_trace(model)
85118
# model = _replace_ops(traced_model, q_configs)
86119

120+
has_replaced = []
87121
for name, m in model.named_modules():
88122
if disable_submodel is not None and any(fnmatch(name, pattern) for pattern in disable_submodel):
89123
continue
124+
if any(name.startswith(p + ".") for p in has_replaced):
125+
continue
90126

91127
m.register_forward_pre_hook(hook_pre_forward)
92128
m.register_forward_hook(hook_forward)
93129

94-
_quantize_submodule(model, name, m, weights_cfg=q_configs.quant_info.to_dict(), activations_cfg=q_configs.quant_info.to_dict(), bias_cfg=q_configs.quant_info.to_dict(), constrain = q_configs.clamp_info.to_dict())
95-
130+
is_replaced = _quantize_submodule(model, name, m, weights_cfg=q_configs.quant_info.to_dict(), activations_cfg=q_configs.quant_info.to_dict(), bias_cfg=q_configs.quant_info.to_dict(), constrain = q_configs.clamp_info.to_dict())
131+
if is_replaced:
132+
has_replaced.append(name)
133+
96134
def quant_tensor_pre_hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
97135

98136
def quant_tensor_layer(module, prefix=''):
@@ -212,13 +250,21 @@ def _quantize_submodule(
212250
constrain: Optional[Union[str]] = None,
213251
):
214252
qmodule = quantize_module(module, weights_cfg=weights_cfg, activations_cfg=activations_cfg, bias_cfg = bias_cfg, dim = getattr(module, "dim", None), constrain = constrain)
253+
if isinstance(module, ConvBN1d) or isinstance(module, CConvBN1d) \
254+
or isinstance(module, ConvBN2d) or isinstance(module, CConvBN2d) \
255+
or isinstance(module, ConvTransposeBN1d) or isinstance(module, CConvTransposeBN1d) \
256+
or isinstance(module, ConvTransposeBN2d) or isinstance(module, CConvTransposeBN2d):
257+
qmodule._register_load_state_dict_pre_hook(fuse_state_dict)
258+
215259
if qmodule is not None:
216260
_set_module_by_name(model, name, qmodule)
217261
qmodule.name = name
218262
for name, param in module.named_parameters():
219263
# Save device memory by clearing parameters
220264
setattr(module, name, None)
221265
del param
266+
return True
267+
return False
222268

223269
def _constrain_submodule(
224270
model: torch.nn.Module,

0 commit comments

Comments
 (0)