pep8
leonardozcm committed Jul 4, 2024
1 parent 1db7261 commit ffe8924
Showing 1 changed file with 40 additions and 11 deletions.
51 changes: 40 additions & 11 deletions python/llm/src/ipex_llm/transformers/npu_models/linear.py
@@ -1,3 +1,21 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This file is adapted from
+# https://github.com/intel/intel-npu-acceleration-library/blob/main/intel_npu_acceleration_library/nn/linear.py
+
 #
 # Copyright © 2024 Intel Corporation
 # SPDX-License-Identifier: Apache 2.0
@@ -13,6 +31,8 @@
 import uuid
 import math
 
+from ipex_llm.utils.common import invalidInputError
+
 
 class Linear(torch.nn.Module):
     """Torch Linear operation NPU backend."""
@@ -22,15 +42,15 @@ def __init__(self, weight: torch.Tensor, bias: Optional[torch.Tensor] = None):
         Args:
             weight (torch.Tensor): Linear operation weight
-            bias (Optional[torch.Tensor], optional): Linear operation optional bias. Defaults to None.
+            bias (Optional[torch.Tensor], optional): Linear operation optional bias.
+                Defaults to None.
         """
         super().__init__()
 
         self.weight = torch.nn.Parameter(weight)
         self.bias = torch.nn.Parameter(bias) if isinstance(bias, torch.Tensor) else None
         self.outC, self.inC = self.weight.shape
         self.op_id = str(uuid.uuid4())
-        # assert self.weight.dtype == torch.float16
         self._mm = AutogradMatMul.apply
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
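
As a usage sketch for the class above (shapes are illustrative; actually running `forward` requires the NPU runtime backing `AutogradMatMul`/`run_matmul`):

```python
import torch

# Illustrative only: wrap an existing float16 weight in the NPU-backed layer.
weight = torch.rand(128, 256, dtype=torch.float16)  # (outC, inC)
bias = torch.rand(128, dtype=torch.float16)
fc = Linear(weight, bias)
print(fc.outC, fc.inC)  # 128 256
```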
@@ -100,9 +120,8 @@ def fromTensor(
             weights_quant, scale = quantize_tensor(weight)
             return QuantizedLinear(weights_quant, scale, bias)
         else:
-            raise RuntimeError(
-                f"intel-npu-acceleration-library library do not support yet the requeste datatype: {dtype}"
-            )
+            invalidInputError(False,
+                              f"NPU does not support the requested datatype yet: {dtype}")


class QuantizedLinear(torch.nn.Module):
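
The change above only swaps the error path: unsupported dtypes now fail through `invalidInputError` rather than a bare `RuntimeError`, while int8/uint8 requests still quantize the weight and return a `QuantizedLinear`. A hedged usage sketch, assuming `fromTensor` is a static factory on `Linear` with a `(weight, bias, dtype)` argument order as in the upstream intel-npu-acceleration-library:

```python
import torch

weight = torch.rand(128, 256, dtype=torch.float16)
int8_layer = Linear.fromTensor(weight, None, torch.int8)  # -> QuantizedLinear
# An unsupported dtype such as torch.float64 would now trip
# invalidInputError(False, ...) instead of raising RuntimeError directly.
```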
@@ -119,7 +138,8 @@ def __init__(
         Args:
             weight (torch.Tensor): Linear operation weight
             scale (torch.Tensor): Quantization scale
-            bias (Optional[torch.Tensor], optional): Linear operation optional bias. Defaults to None.
+            bias (Optional[torch.Tensor], optional): Linear operation optional bias.
+                Defaults to None.
 
         Raises:
             RuntimeError: Quantized weight must be in torch.int8 format
@@ -128,8 +148,12 @@ def __init__(
 
         self.weight = Parameter(weight, requires_grad=False)
         if self.weight.dtype not in (torch.int8, torch.uint8):
-            raise RuntimeError(
-                f"Quantized weight must be in torch.(u)int8 dtype instead of {self.weight.dtype}"
+            invalidInputError(
+                False,
+                (
+                    "Quantized weight must be in torch.(u)int8"
+                    f" dtype instead of {self.weight.dtype}"
+                )
             )
         self.outC, self.inC = self.weight.shape
         if self.weight.dtype == torch.uint8:
@@ -147,14 +171,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             x (torch.Tensor): Input tensor
 
         Raises:
-            RuntimeError: Training is not supported for QuantizedLinear layer. Use `.eval()` to do inference only
+            RuntimeError: Training is not supported for QuantizedLinear layer.
+                Use `.eval()` to do inference only
 
         Returns:
             torch.Tensor: result
         """
         if self.training:
-            raise RuntimeError(
-                "Training is not supported for QuantizedLinear layer. Use `.eval()` to do inference only"
+            invalidInputError(
+                False,
+                (
+                    "Training is not supported for QuantizedLinear layer. "
+                    "Use `.eval()` to do inference only"
+                )
             )
         out = run_matmul(x, self.weight.data, self.scale.data, self.op_id)
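
The inference-only contract is unchanged by the new error helper; callers still switch to eval mode first. A small usage sketch (again assuming the `fromTensor` factory above and an available NPU runtime for `run_matmul`):

```python
import torch

weight = torch.rand(128, 256, dtype=torch.float16)
layer = Linear.fromTensor(weight, None, torch.int8)
layer.eval()  # required: forward() rejects training mode via invalidInputError
with torch.no_grad():
    y = layer(torch.rand(1, 256, dtype=torch.float16))
```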
