Dummy but working version.
Narsil committed Jan 25, 2024
1 parent 0165a5f commit 2091859
Showing 1 changed file with 25 additions and 22 deletions.
47 changes: 25 additions & 22 deletions in server/text_generation_server/utils/layers.py
@@ -186,41 +186,44 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
        output = output + self.bias if self.bias is not None else output
        return output

def fp8_quantize(weight, qdtype=torch.float8_e4m3fn):
    device = weight.device
    # weight, scale = quant_weights(weight, torch.int8, False)
    finfo = torch.finfo(qdtype)
    # Calculate the scale as dtype max divided by absmax
    scale = finfo.max / weight.abs().max().clamp(min=1e-12)
    # Scale and clamp the tensor to bring it to
    # the representative range of the float8 data type
    # (as the default cast is unsaturated).
    qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max)
    # Return both the float8 data and the inverse scale (as float),
    # as both are required as inputs to torch._scaled_mm.
    qweight = qweight.to(qdtype)
    scale = scale.float().reciprocal()
    return qweight, scale
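
As a sanity check, here is a minimal sketch (not part of the commit) of how the new fp8_quantize helper can be exercised on its own. It assumes a CUDA-enabled PyTorch build (2.1 or later) with float8_e4m3fn support; the tensor shape and dtype are illustrative only.

import torch

# Quantize a weight matrix with fp8_quantize (defined above), then dequantize
# it with the returned inverse scale to gauge the quantization error.
weight = torch.randn(1024, 1024, dtype=torch.float16, device="cuda")
qweight, inv_scale = fp8_quantize(weight)

dequant = qweight.to(torch.float16) * inv_scale   # approximate reconstruction
print(qweight.dtype)                              # torch.float8_e4m3fn
print((weight - dequant).abs().max().item())      # worst-case absolute error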

class Fp8Linear(nn.Module):
    def __init__(
        self,
        weight,
        bias,
    ) -> None:
        super().__init__()
        device = weight.device
        # weight, scale = quant_weights(weight, torch.int8, False)
        finfo = torch.finfo(weight.dtype)
        qdtype = torch.float8_e4m3fn
        # Calculate the scale as dtype max divided by absmax
        scale = finfo.max / weight.abs().max().clamp(min=1e-12)
        # Scale and clamp the tensor to bring it to
        # the representative range of the float8 data type
        # (as the default cast is unsaturated).
        x_scl_sat = (weight * scale).clamp(min=finfo.min, max=finfo.max)
        # Return both the float8 data and the inverse scale (as float),
        # as both are required as inputs to torch._scaled_mm.
        self.dtype = weight.dtype
        self.qweight = x_scl_sat.to(qdtype).to(device=device)
        self.scale = scale.float().reciprocal().to(device=device)
        self.qweight, self.scale = fp8_quantize(weight)
        self.bias = bias.cuda(device) if bias is not None else None

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        finfo = torch.finfo(input.dtype)
        scale = finfo.max / input.abs().max().clamp(min=1e-12)
        qinput = (input * scale).clamp(min=finfo.min, max=finfo.max)

        output, _ = torch._scaled_mm(qinput, self.qweight, out_dtype=torch.float16,
                                     scale_a=scale, scale_b=self.scale)
        output = output + self.bias if self.bias is not None else output
        qinput, scale = fp8_quantize(input)
        seqlen = qinput.shape[0]
        if seqlen % 16 != 0:
            # torch._scaled_mm expects dimensions that are multiples of 16,
            # so pad the sequence dimension up to the next multiple.
            missing = 16 - seqlen % 16
            qinput = F.pad(qinput, (0, 0, 0, missing), "constant", value=0)
        output, _ = torch._scaled_mm(qinput, self.qweight.t(), out_dtype=self.dtype,
                                     scale_a=scale, scale_b=self.scale, bias=self.bias)
        # Drop the padded rows before returning.
        output = output[:seqlen]
        return output
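
The new forward pads the sequence dimension because torch._scaled_mm (backed by cuBLASLt FP8 GEMM) expects dimensions that are multiples of 16, and the padded rows are sliced off again after the matmul. Below is a hedged usage sketch for the layer as shown. It assumes an FP8-capable GPU (compute capability 8.9 or newer), that the module-level F alias for torch.nn.functional used by F.pad exists, and a PyTorch version in which torch._scaled_mm still returns an (output, amax) tuple; the layer sizes are illustrative only.

import torch

# Wrap an existing fp16 linear layer with the Fp8Linear shown above
# (hypothetical sizes, not taken from the commit).
lin = torch.nn.Linear(4096, 4096, bias=True, dtype=torch.float16, device="cuda")
fp8_lin = Fp8Linear(lin.weight.data, lin.bias.data)

x = torch.randn(7, 4096, dtype=torch.float16, device="cuda")  # seqlen 7 is padded to 16 internally
y = fp8_lin(x)
print(y.shape)  # torch.Size([7, 4096]): the padded rows are removed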


class Linear8bitLt(nn.Module):
    def __init__(
        self,
