Initial fp8.
Narsil committed Jan 25, 2024
1 parent 7872b8c commit 0165a5f
Showing 3 changed files with 47 additions and 0 deletions.
6 changes: 6 additions & 0 deletions launcher/src/main.rs
@@ -47,6 +47,9 @@ enum Quantization {
    /// Bitsandbytes 4bit. nf4 should be preferred in most cases but maybe this one has better
    /// perplexity performance for your model
    BitsandbytesFP4,
    /// [FP8](https://developer.nvidia.com/blog/nvidia-arm-and-intel-publish-fp8-specification-for-standardization-as-an-interchange-format-for-ai/) (e4m3) works on H100 and above.
    /// This dtype has native ops and should be the fastest if available.
    Fp8,
}

impl std::fmt::Display for Quantization {
@@ -73,6 +76,9 @@ impl std::fmt::Display for Quantization {
            Quantization::Eetq => {
                write!(f, "eetq")
            }
            Quantization::Fp8 => {
                write!(f, "fp8")
            }
        }
    }
}
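With this variant in place, fp8 should be selectable like the other quantization modes, presumably via the launcher's existing --quantize flag (i.e. --quantize fp8, assuming the flag wiring for the other variants is unchanged); the Display impl above is what serializes the choice as the lowercase string "fp8" that is handed to the server.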
1 change: 1 addition & 0 deletions server/text_generation_server/cli.py
@@ -19,6 +19,7 @@ class Quantization(str, Enum):
    gptq = "gptq"
    awq = "awq"
    eetq = "eetq"
    fp8 = "fp8"


class Dtype(str, Enum):
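The server-side enum member mirrors the lowercase string emitted by the launcher's Display impl, so the value round-trips cleanly between the two processes. A minimal sketch of that round trip (standalone snippet for illustration, trimmed to the members relevant here; not part of the change):

    from enum import Enum

    class Quantization(str, Enum):
        # mirrors the subset of server/text_generation_server/cli.py shown above
        eetq = "eetq"
        fp8 = "fp8"

    # the launcher writes "fp8"; the server parses it back into the enum
    assert Quantization("fp8") is Quantization.fp8
    assert Quantization.fp8.value == "fp8"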
40 changes: 40 additions & 0 deletions server/text_generation_server/utils/layers.py
@@ -186,6 +186,40 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
        output = output + self.bias if self.bias is not None else output
        return output

class Fp8Linear(nn.Module):
    def __init__(
        self,
        weight,
        bias,
    ) -> None:
        super().__init__()
        device = weight.device
        qdtype = torch.float8_e4m3fn
        finfo = torch.finfo(qdtype)
        # Calculate the scale as the fp8 dtype max divided by the weight's absmax
        scale = finfo.max / weight.abs().max().clamp(min=1e-12)
        # Scale and clamp the tensor to bring it into
        # the representable range of the float8 data type
        # (as the default cast is unsaturated)
        x_scl_sat = (weight * scale).clamp(min=finfo.min, max=finfo.max)
        # Keep both the float8 data and the inverse scale (as float),
        # as both are required as inputs to torch._scaled_mm
        self.dtype = weight.dtype
        self.qweight = x_scl_sat.to(qdtype).to(device=device)
        self.scale = scale.float().reciprocal().to(device=device)
        self.bias = bias.cuda(device) if bias is not None else None

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        qdtype = torch.float8_e4m3fn
        finfo = torch.finfo(qdtype)
        # Quantize the activations on the fly with the same scheme as the weights
        scale = finfo.max / input.abs().max().clamp(min=1e-12)
        qinput = (input * scale).clamp(min=finfo.min, max=finfo.max).to(qdtype)

        # torch._scaled_mm expects fp8 operands, a column-major second matrix,
        # and the inverse (dequantization) scales
        output, _ = torch._scaled_mm(
            qinput,
            self.qweight.t(),
            out_dtype=self.dtype,
            scale_a=scale.float().reciprocal(),
            scale_b=self.scale,
        )
        output = output + self.bias if self.bias is not None else output
        return output
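To make the scaling scheme above concrete, here is a minimal round-trip sketch of the per-tensor quantization (illustrative only; it assumes an fp8-capable GPU such as an H100 and a PyTorch build with float8 support, and the shape is arbitrary):

    import torch

    w = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")
    finfo = torch.finfo(torch.float8_e4m3fn)
    scale = finfo.max / w.abs().max().clamp(min=1e-12)            # quantization scale
    qw = (w * scale).clamp(min=finfo.min, max=finfo.max).to(torch.float8_e4m3fn)
    w_hat = qw.to(torch.float16) * scale.reciprocal()             # dequantize
    print((w - w_hat).abs().max())                                # error from the 8-bit format

torch._scaled_mm consumes qw and the reciprocal scale directly instead of ever materializing w_hat, which is what makes the fp8 path fast.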


class Linear8bitLt(nn.Module):
    def __init__(
@@ -298,6 +332,12 @@ def get_linear(weight, bias, quantize):
            raise ImportError(
                "Please install EETQ from https://github.com/NetEase-FuXi/EETQ"
            )
    elif quantize == "fp8":
        linear = Fp8Linear(weight, bias)
    elif quantize == "bitsandbytes":
        warn_deprecate_bnb()
        linear = Linear8bitLt(
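Finally, a hedged usage sketch of the new dispatch path, using only names that appear in the diff (the shapes, the float16 dtype, and the CUDA device are illustrative assumptions; an fp8-capable GPU is required):

    import torch
    from text_generation_server.utils.layers import get_linear

    weight = torch.randn(1024, 1024, dtype=torch.float16, device="cuda")
    linear = get_linear(weight, bias=None, quantize="fp8")   # returns an Fp8Linear
    x = torch.randn(16, 1024, dtype=torch.float16, device="cuda")
    y = linear(x)                                            # fp8 matmul via torch._scaled_mm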
