diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index c555300071f..d37726245e9 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -47,6 +47,9 @@ enum Quantization {
     /// Bitsandbytes 4bit. nf4 should be preferred in most cases but maybe this one has better
     /// perplexity performance for you model
     BitsandbytesFP4,
+    /// [FP8](https://developer.nvidia.com/blog/nvidia-arm-and-intel-publish-fp8-specification-for-standardization-as-an-interchange-format-for-ai/) (e4m3) works on H100 and above.
+    /// This dtype has native ops and should be the fastest if available.
+    Fp8,
 }
 
 impl std::fmt::Display for Quantization {
@@ -73,6 +76,9 @@ impl std::fmt::Display for Quantization {
             Quantization::Eetq => {
                 write!(f, "eetq")
             }
+            Quantization::Fp8 => {
+                write!(f, "fp8")
+            }
         }
     }
 }
diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py
index 99be6c7eeff..6e55a84ae1a 100644
--- a/server/text_generation_server/cli.py
+++ b/server/text_generation_server/cli.py
@@ -19,6 +19,7 @@ class Quantization(str, Enum):
     gptq = "gptq"
     awq = "awq"
     eetq = "eetq"
+    fp8 = "fp8"
 
 
 class Dtype(str, Enum):
diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index 5a0de0d7bd6..bd3b3c2a07e 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -186,6 +186,48 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
         output = output + self.bias if self.bias is not None else output
         return output
 
+
+class Fp8Linear(nn.Module):
+    def __init__(
+        self,
+        weight,
+        bias,
+    ) -> None:
+        super().__init__()
+        device = weight.device
+        qdtype = torch.float8_e4m3fn
+        finfo = torch.finfo(qdtype)
+        # Calculate the scale as fp8 dtype max divided by absmax
+        scale = finfo.max / weight.abs().max().clamp(min=1e-12)
+        # Scale and clamp the tensor to bring it into the representable range
+        # of the float8 data type (as the default cast is unsaturated)
+        x_scl_sat = (weight * scale).clamp(min=finfo.min, max=finfo.max)
+        # Keep both the float8 data and the inverse scale (as float),
+        # as both are required as inputs to torch._scaled_mm
+        self.dtype = weight.dtype
+        self.qweight = x_scl_sat.to(qdtype).to(device=device)
+        self.scale = scale.float().reciprocal().to(device=device)
+        self.bias = bias.to(device) if bias is not None else None
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        # Quantize the input on the fly with its own per-tensor scale
+        finfo = torch.finfo(self.qweight.dtype)
+        scale = finfo.max / input.abs().max().clamp(min=1e-12)
+        qinput = (input * scale).clamp(min=finfo.min, max=finfo.max)
+        qinput = qinput.to(self.qweight.dtype)
+
+        # torch._scaled_mm expects the second operand in column-major layout
+        # and the inverse (dequantization) scales of both operands
+        output, _ = torch._scaled_mm(
+            qinput,
+            self.qweight.t(),
+            out_dtype=self.dtype,
+            scale_a=scale.float().reciprocal(),
+            scale_b=self.scale,
+        )
+        output = output + self.bias if self.bias is not None else output
+        return output
+
 
 class Linear8bitLt(nn.Module):
     def __init__(
@@ -298,6 +340,8 @@ def get_linear(weight, bias, quantize):
             raise ImportError(
                 "Please install EETQ from https://github.com/NetEase-FuXi/EETQ"
             )
+    elif quantize == "fp8":
+        linear = Fp8Linear(weight, bias)
     elif quantize == "bitsandbytes":
         warn_deprecate_bnb()
         linear = Linear8bitLt(
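For reference, below is a minimal standalone sketch (not part of the diff) of the per-tensor quantization plus `torch._scaled_mm` round trip that `Fp8Linear` builds on. It assumes an H100-class GPU and a PyTorch build (roughly 2.1/2.2, the versions this change targets) that exposes `torch.float8_e4m3fn` and the private `torch._scaled_mm` with a tuple return; the `fp8_quantize` helper name is illustrative only.

```python
import torch


def fp8_quantize(x: torch.Tensor, qdtype=torch.float8_e4m3fn):
    """Per-tensor fp8 quantization: returns the fp8 data and the inverse scale."""
    finfo = torch.finfo(qdtype)
    # Scale so that the absmax maps to the fp8 dtype max, then saturate
    scale = finfo.max / x.abs().max().clamp(min=1e-12)
    qx = (x * scale).clamp(min=finfo.min, max=finfo.max).to(qdtype)
    return qx, scale.float().reciprocal()


if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.randn(16, 64, dtype=torch.float16, device="cuda")   # activations
    w = torch.randn(128, 64, dtype=torch.float16, device="cuda")  # (out, in) weight

    qx, x_inv_scale = fp8_quantize(x)
    qw, w_inv_scale = fp8_quantize(w)

    # mat2 must be column-major, hence the transpose of the row-major weight;
    # _scaled_mm returns (output, amax) on the PyTorch versions targeted here.
    y_fp8, _ = torch._scaled_mm(
        qx, qw.t(), out_dtype=torch.float16, scale_a=x_inv_scale, scale_b=w_inv_scale
    )
    y_ref = x @ w.t()
    print((y_fp8 - y_ref).abs().max())  # small quantization error expected
```

With the launcher and CLI changes above, this path should be reachable end to end via `--quantize fp8`, subject to hardware and PyTorch support.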