ds_utils.py
import math

import torch
from torch import nn
import torch.nn.functional as F

import deepspeed
from deepspeed.compression.helper import recursive_getattr, recursive_setattr
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
from transformers import (AutoConfig, AutoTokenizer, AutoModel,
                          AutoModelForCausalLM, get_scheduler,
                          DataCollatorForLanguageModeling, pipeline)
class LinearLayer_LoRA(nn.Module):
    # a simple implementation of LoRA
    # for now it only supports nn.Linear layers
    def __init__(self,
                 weight,
                 lora_dim=0,
                 lora_scaling=1,
                 lora_dropout=0,
                 bias=None):
        super(LinearLayer_LoRA, self).__init__()
        self.weight = weight
        self.bias = bias

        if lora_dim <= 0:
            raise ValueError(
                "You are trying to use LoRA, whose reduced dimension should be at least 1"
            )

        try:
            # for ZeRO stage 3, the full shape is stored in ds_shape
            rows, columns = weight.ds_shape
        except AttributeError:
            rows, columns = weight.shape
        # stored transposed so the forward pass does not need to transpose
        self.lora_right_weight = nn.Parameter(torch.zeros(columns, lora_dim))
        self.lora_left_weight = nn.Parameter(torch.zeros(lora_dim, rows))
        self.lora_scaling = lora_scaling / lora_dim

        if lora_dropout > 0:
            self.lora_dropout = nn.Dropout(lora_dropout)
        else:
            self.lora_dropout = nn.Identity()

        self.reset_parameters()
        # freeze the original weight; only the LoRA factors are trained
        self.weight.requires_grad = False
        # whether the LoRA update has been fused into the original weight
        self.fuse_lora = False

    def eval(self):
        self.lora_dropout.eval()
        # self.fuse_lora_weight()

    def train(self, mode=True):
        self.lora_dropout.train(mode)
        # self.unfuse_lora_weight()

    def reset_parameters(self):
        # Kaiming init for the right factor, zeros for the left factor,
        # so the initial LoRA update is exactly zero
        nn.init.kaiming_uniform_(self.lora_right_weight, a=math.sqrt(5))
        nn.init.zeros_(self.lora_left_weight)

    def fuse_lora_weight(self):
        if not self.fuse_lora:
            self.weight.data += self.lora_scaling * torch.matmul(
                self.lora_left_weight.t(), self.lora_right_weight.t())
        self.fuse_lora = True

    def unfuse_lora_weight(self):
        if self.fuse_lora:
            self.weight.data -= self.lora_scaling * torch.matmul(
                self.lora_left_weight.t(), self.lora_right_weight.t())
        self.fuse_lora = False

    def forward(self, input):
        if self.fuse_lora:
            return F.linear(input, self.weight, self.bias)
        else:
            return F.linear(
                input, self.weight,
                self.bias) + (self.lora_dropout(input) @ self.lora_right_weight
                              @ self.lora_left_weight) * self.lora_scaling
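
# Usage sketch (a minimal illustration, not part of the original module; the
# layer sizes and variable names below are assumptions):
#
#   base = nn.Linear(16, 8)
#   lora = LinearLayer_LoRA(base.weight, lora_dim=4, bias=base.bias)
#   y = lora(torch.randn(2, 16))   # frozen base weight + low-rank update
#   lora.fuse_lora_weight()        # fold the update into the base weight for inference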

# convert every nn.Linear whose name contains part_module_name into a LoRA layer
def convert_linear_layer_to_lora(model,
                                 part_module_name,
                                 lora_dim=0,
                                 lora_scaling=1,
                                 lora_dropout=0):
    replace_name = []
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear) and part_module_name in name:
            replace_name.append(name)
    for name in replace_name:
        module = recursive_getattr(model, name)
        tmp = LinearLayer_LoRA(
            module.weight, lora_dim, lora_scaling, lora_dropout,
            module.bias).to(module.weight.device).to(module.weight.dtype)
        recursive_setattr(model, name, tmp)
    return model
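
# Usage sketch (illustrative only; the checkpoint name and the module-name
# filter below are assumptions, not part of this file):
#
#   model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
#   model = convert_linear_layer_to_lora(model, "decoder.layers.", lora_dim=8)
#   # each matched nn.Linear is now a LinearLayer_LoRA with a frozen base
#   # weight plus trainable lora_right_weight / lora_left_weight factors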


def _z3_params_to_fetch(param_list):
    # under ZeRO stage 3, collect the parameters that are currently partitioned
    # away (NOT_AVAILABLE) so they can be gathered before being modified
    return [
        p for p in param_list
        if hasattr(p, 'ds_id') and p.ds_status == ZeroParamStatus.NOT_AVAILABLE
    ]


# fuse LoRA weights back so every LinearLayer_LoRA behaves like a plain linear layer
def convert_lora_to_linear_layer(model):
    replace_name = []
    for name, module in model.named_modules():
        if isinstance(module, LinearLayer_LoRA):
            replace_name.append(name)
    for name in replace_name:
        module = recursive_getattr(model, name)
        zero_stage_3 = hasattr(module.weight, 'ds_id')
        # under ZeRO stage 3 the partitioned weights must be gathered on rank 0
        # before they can be updated in place
        with deepspeed.zero.GatheredParameters(_z3_params_to_fetch([
                module.weight, module.bias, module.lora_left_weight,
                module.lora_right_weight
        ]),
                                               modifier_rank=0,
                                               enabled=zero_stage_3):
            module.fuse_lora_weight()
    return model
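
# Usage sketch (illustrative only; assumes `model` was converted with
# convert_linear_layer_to_lora and then fine-tuned):
#
#   model = convert_lora_to_linear_layer(model)
#   # every LinearLayer_LoRA has fused its low-rank update into weight.data,
#   # so forward() reduces to a single F.linear call and the model can be
#   # saved or evaluated as usual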