eval_utils.py
forked from AnswerDotAI/fsdp_qlora

import re
from typing import List

import torch
import torch.nn as nn
from accelerate import init_empty_weights
from bitsandbytes.nn import Linear4bit, Params4bit
from datasets import load_dataset

def extract_last_number_or_ratio(s):
    # Find all sequences of digits, possibly with leading currency symbols, decimal points, and ratios
    patterns = re.findall(r'[\$€£]?\d+(?:\.\d+)?(?:\:\d+(?:\.\d+)?)?', s)
    # Return the last pattern found, or None if there are no matches
    if patterns:
        return patterns[-1]
    else:
        return None
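
# Illustrative usage (not part of the original module): the regex keeps the last
# number-like token, including currency prefixes and ratios, which is handy for
# pulling the final answer out of generated text.
# >>> extract_last_number_or_ratio("She paid $12.50 in total")
# '$12.50'
# >>> extract_last_number_or_ratio("The answer is 42. Final ratio: 3:2.5")
# '3:2.5'
# >>> extract_last_number_or_ratio("no numbers here") is None
# True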

def replace_linear(model:nn.Module, linear_replacement:nn.Module, quant_config:dict|None=None,
                   skip_modules:List[str]=["lm_head"], **kwargs):
    """
    Replace linear modules with a new Linear module.

    Parameters:
        model (`torch.nn.Module`):
            Input model or `torch.nn.Module` as the function is run recursively.
        linear_replacement (`torch.nn.Module`):
            The linear module that replaces the old one. Only expects standard arguments.
            If other arguments need to be passed, use a lambda.
        skip_modules (`List[str]`, *optional*, defaults to `["lm_head"]`):
            List of module names not to convert.
    """
    for name, module in model.named_children():
        if name in skip_modules:
            print(f"Skipping {name}")
            continue

        if len(list(module.children())) > 0:
            replace_linear(module, linear_replacement, quant_config, skip_modules, **kwargs)

        if isinstance(module, torch.nn.Linear):
            if issubclass(linear_replacement, Linear4bit):
                model._modules[name] = linear_replacement(
                    module.in_features,
                    module.out_features,
                    module.bias is not None,
                    **kwargs
                )
            # elif issubclass(linear_replacement, HQQLinear):
            #     model._modules[name] = linear_replacement(module, quant_config, **kwargs)
            else:
                raise ValueError(f"Unsupported linear replacement: {type(linear_replacement)}")
    return model
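
# Illustrative usage (a sketch; the checkpoint name and quantization kwargs are assumptions,
# not from this file): build the model skeleton on the meta device, then swap every nn.Linear
# for a bitsandbytes Linear4bit, leaving lm_head in full precision.
#
# from transformers import AutoConfig, AutoModelForCausalLM
# cfg = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf")
# with init_empty_weights():
#     model = AutoModelForCausalLM.from_config(cfg)
# model.model = replace_linear(model.model, Linear4bit, compute_dtype=torch.bfloat16,
#                              quant_type="nf4", quant_storage=torch.uint8)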

def load_and_quantize(module:nn.Module, name:str, value:torch.Tensor, device:torch.device=None, dtype:torch.dtype=None,
                      skip_names:list[str]=[], is_meta_rank:bool=False, low_memory:bool=True, verbose:bool=False,
                      quant_method:str='bnb', is_dora:bool=False):
    """
    Loads `value` tensor into submodule of `module`, optionally skipping `skip_names` and converting to `dtype`.

    Quantizes `Params4bit` on `device`, then places it on "meta" if is_meta_rank=True or "cpu" if low_memory=True.
    """
    def place_on_device(value):
        # Choose the target device; fall back to the `device` argument when neither
        # the meta nor the low-memory placement applies.
        if is_meta_rank:
            target_device = 'meta'
        elif low_memory:
            target_device = 'cpu'
        else:
            target_device = device
        return value.to(device=target_device, dtype=dtype)

    if any([skip_name in name for skip_name in skip_names]):
        if verbose:
            print(f"Skipping {name} because it is in skip_names")
        return

    module_key, _, value_key = name.rpartition('.')
    try:
        submodule = module.get_submodule(module_key)
    except AttributeError as e:
        print(f"Module {module_key} not found:\n{e}")
        return

    try:
        if quant_method=='bnb':
            param = submodule.get_parameter(value_key)
            if isinstance(param, Params4bit):
                # With `sync_module_states=True`, a meta device Params4bit needs to be the same
                # shape as the quantized Params4bit with an initialized quant_state. However,
                # FSDP only syncs parameters and buffers, so the quant_state isn't copied. This
                # workaround quantizes Params4bit to initialize quant_state on all ranks, then
                # replaces Params4bit's data with a meta tensor to free memory on non-rank 0.
                if is_dora:
                    setattr(submodule, "dora_scale", value.norm(p=2, dim=1).to(dtype=dtype).to("cpu"))
                    print("DORA scale initialized")
                value = type(param)(value.to(device=device, dtype=dtype).data, **param.__dict__).cuda(device)
                if is_meta_rank:
                    value = type(param)(value.data.to("meta"), **value.__dict__)
                elif low_memory:
                    value = type(param)(value.data.to("cpu"), **value.__dict__)
                # print("Loaded quantized layer")
            else:
                value = type(param)(place_on_device(value).data)
                # print("Loaded regular layer")
    except AttributeError:
        # Not a parameter, so it's a buffer: just move it to the requested device/dtype.
        value = place_on_device(value)

    setattr(submodule, value_key, value)

def load_and_quantize_parallel(name_param, model, **kwargs):
    name, param = name_param
    load_and_quantize(model, name, param, **kwargs)
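
# Illustrative usage (a sketch under assumed names; `files`, `rank`, and `model` are
# hypothetical): stream weights from safetensors shards into the quantized skeleton one
# tensor at a time. `load_and_quantize_parallel` just unpacks the (name, tensor) pair so
# the loader can also be mapped over `state_dict.items()` by a pool helper.
#
# from safetensors.torch import load_file
# for filename in files:
#     weights = load_file(filename)
#     for name, param in weights.items():
#         load_and_quantize(model, name, param, dtype=torch.bfloat16,
#                           device=torch.cuda.current_device(),
#                           is_meta_rank=(rank != 0), quant_method="bnb")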