Commit

Remove intermediate weight conversion to FP32 when possible (OpenNMT#839)

* Remove intermediate weight conversion to FP32 when possible

This can save memory during conversion when the model weights are
stored in FP16 and the model is converted with quantization.

* Check the weights were not copied or converted

guillaumekln authored Jun 17, 2022
1 parent 2c8db5e commit a71d783
Showing 2 changed files with 101 additions and 20 deletions.
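
As the commit message notes, FP16 weights no longer take a detour through FP32 during validation; each tensor is upcast only at the moment its quantization step needs it, and not at all when the target type is already FP16. The standalone NumPy sketch below (the `quantize_int8_rowwise` helper is illustrative and not part of the CTranslate2 API) shows the per-row int8 scheme used in the diff applied directly to an FP16 tensor:

```python
import numpy as np


def quantize_int8_rowwise(weight):
    """Per-row symmetric int8 quantization, mirroring the int8 branch in the diff below."""
    amax = np.amax(np.absolute(weight), axis=1).astype(np.float32)
    amax[amax == 0] = 127.0  # avoid division by zero for all-zero rows
    scale = 127.0 / amax
    if weight.dtype != scale.dtype:
        # Upcast this one tensor only when needed, instead of upcasting the whole model.
        weight = weight.astype(scale.dtype)
    quantized = np.rint(weight * np.expand_dims(scale, 1)).astype(np.int8)
    return quantized, scale


# An FP16 weight goes straight to int8; no model-wide FP32 copy is kept around.
w = np.array([[-10, -3, 5, 2]], dtype=np.float16)
q, scale = quantize_int8_rowwise(w)
print(q)      # [[-127  -38   64   25]]
print(scale)  # [12.7]
```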
python/ctranslate2/specs/model_spec.py (25 additions, 14 deletions)
```diff
@@ -73,8 +73,8 @@ def _check(spec, name, value):
                 raise ValueError("Missing value for attribute %s" % name)
 
             if isinstance(value, np.ndarray):
-                # Use float32 as the working floating point type.
-                if value.dtype in (np.float16, np.float64):
+                # float64 is not a supported type.
+                if value.dtype == np.float64:
                     value = value.astype(np.float32)
             elif isinstance(value, float):
                 value = np.dtype("float32").type(value)
```
```diff
@@ -144,32 +144,44 @@ def _quantize(self, quantization):
         def _quantize(spec, name, value):
             if not isinstance(value, np.ndarray):
                 return
-            if "weight" in name and quantization != "float16":
+
+            scale = None
+            is_quantizable = "weight" in name
+
+            if is_quantizable:
                 if quantization == "int16":
                     # Represent the value with 10 bits so the multiplication is 20 bits
                     # and 12 bits are left for accumulation.
-                    scale = np.dtype(value.dtype).type(
-                        2**10 / np.amax(np.absolute(value))
-                    )
+                    scale = np.float32(2**10 / np.amax(np.absolute(value)))
+                    if value.dtype != scale.dtype:
+                        value = value.astype(scale.dtype)
                     value *= scale
                     value = np.rint(value)
                     value = np.clip(
                         value, np.iinfo(np.int16).min, np.iinfo(np.int16).max
                     )
                     value = value.astype(np.int16)
                 elif quantization in ("int8", "int8_float16"):
-                    amax = np.amax(np.absolute(value), axis=1)
+                    amax = np.amax(np.absolute(value), axis=1).astype(np.float32)
                     amax[amax == 0] = 127.0
                     scale = 127.0 / amax
+                    if value.dtype != scale.dtype:
+                        value = value.astype(scale.dtype)
                     value *= np.expand_dims(scale, 1)
                     value = np.rint(value)
                     value = value.astype(np.int8)
-                setattr(spec, "weight_scale", scale)
-                setattr(spec, "weight", value)
-            elif quantization in ("float16", "int8_float16"):
+
+            if quantization in ("float16", "int8_float16"):
                 if value.dtype == np.float32:
-                    key = _split_scope(name)[-1]
-                    setattr(spec, key, value.astype(np.float16))
+                    value = value.astype(np.float16)
+            else:
+                if value.dtype == np.float16:
+                    value = value.astype(np.float32)
+
+            key = _split_scope(name)[-1]
+            setattr(spec, key, value)
+            if scale is not None:
+                setattr(spec, "%s_scale" % key, scale)
 
         self._visit(_quantize)
```
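A quick standalone check (not part of the repository) of the scaling rationale in the int16 comment above: mapping the largest magnitude to roughly 2**10 keeps each product of two quantized values within about 20 bits, which leaves about 12 bits of headroom in an int32 accumulator, so a few thousand products can be summed before overflow.

```python
import numpy as np

rng = np.random.default_rng(0)
value = rng.standard_normal((64, 64)).astype(np.float16)

# Same scaling as the int16 branch above.
scale = np.float32(2**10 / np.amax(np.absolute(value)))
q = np.rint(value.astype(np.float32) * scale)
q = np.clip(q, np.iinfo(np.int16).min, np.iinfo(np.int16).max).astype(np.int16)

qmax = int(np.max(np.abs(q)))
print(qmax)                                     # ~1024, i.e. about 10 bits of magnitude
print(np.iinfo(np.int32).max // (qmax * qmax))  # ~2000 products fit in an int32 accumulator
```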

```diff
@@ -184,8 +196,7 @@ def optimize(self, quantization: str = None) -> None:
             (possible values are: int8, int8_float16, int16, float16).
         """
         self._alias_variables()
-        if quantization is not None:
-            self._quantize(quantization)
+        self._quantize(quantization)
 
     def _visit(self, fn):
         """Recursively visits this layer and its children."""
```
python/tests/test.py (76 additions, 6 deletions)
```diff
@@ -1142,6 +1142,10 @@ def test_transformers_lm_scoring(tmpdir):
     assert generator.score_batch([["<|endoftext|>"]])[0] == []
 
 
+def _array_equal(a, b):
+    return a.dtype == b.dtype and np.array_equal(a, b)
+
+
 def test_layer_spec_validate():
     class SubSpec(ctranslate2.specs.LayerSpec):
         def __init__(self):
```
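The new `_array_equal` helper is needed because `np.array_equal` compares values only and ignores dtypes, so it cannot catch an unintended conversion. A quick illustration (not part of the test suite):

```python
import numpy as np

a = np.array([1, 2], dtype=np.int8)
b = np.array([1, 2], dtype=np.float32)

print(np.array_equal(a, b))                         # True: values match, dtypes are ignored
print(a.dtype == b.dtype and np.array_equal(a, b))  # False: _array_equal also requires the same dtype
```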
```diff
@@ -1160,12 +1164,12 @@ def __init__(self):
     spec = Spec()
     spec.validate()
     assert spec.a.dtype == np.float32
-    assert spec.b.dtype == np.float32
+    assert spec.b.dtype == np.float16
     assert spec.c.dtype == np.int32
     assert spec.d == OPTIONAL
-    assert spec.e.a.dtype == np.float32
-    assert spec.f.dtype == np.int8 and np.array_equal(spec.f, 1)
-    assert spec.g.dtype == np.int8 and np.array_equal(spec.g, [104, 101, 108, 108, 111])
+    assert spec.e.a.dtype == np.float16
+    assert _array_equal(spec.f, np.int8(1))
+    assert _array_equal(spec.g, np.array([104, 101, 108, 108, 111], dtype=np.int8))
 
 
 def test_layer_spec_optimize():
```
```diff
@@ -1208,10 +1212,76 @@ def __init__(self):
 
     spec = Spec()
     spec.optimize(quantization="int8")
-    assert np.array_equal(
+    assert _array_equal(
         spec.weight, np.array([[-127, -38, 64, 25], [0, 0, 0, 0]], dtype=np.int8)
     )
-    assert np.array_equal(spec.weight_scale, np.array([12.7, 1], dtype=np.float32))
+    assert _array_equal(spec.weight_scale, np.array([12.7, 1], dtype=np.float32))
+
+
+@pytest.mark.parametrize(
+    "quantization,expected_weight,expected_weight_scale,expected_bias",
+    [
+        (
+            None,
+            np.array([[-10, -3, 5, 2]], dtype=np.float32),
+            None,
+            np.array([4], dtype=np.float32),
+        ),
+        (
+            "float16",
+            np.array([[-10, -3, 5, 2]], dtype=np.float16),
+            None,
+            np.array([4], dtype=np.float16),
+        ),
+        (
+            "int8",
+            np.array([[-127, -38, 64, 25]], dtype=np.int8),
+            np.array([12.7], dtype=np.float32),
+            np.array([4], dtype=np.float32),
+        ),
+        (
+            "int8_float16",
+            np.array([[-127, -38, 64, 25]], dtype=np.int8),
+            np.array([12.7], dtype=np.float32),
+            np.array([4], dtype=np.float16),
+        ),
+        (
+            "int16",
+            np.array([[-1024, -307, 512, 205]], dtype=np.int16),
+            np.float32(102.4),
+            np.array([4], dtype=np.float32),
+        ),
+    ],
+)
+def test_fp16_weights(
+    quantization, expected_weight, expected_weight_scale, expected_bias
+):
+    class Spec(ctranslate2.specs.LayerSpec):
+        def __init__(self, weight, bias):
+            self.weight = weight
+            self.bias = bias
+
+    weight = np.array([[-10, -3, 5, 2]], dtype=np.float16)
+    bias = np.array([4], dtype=np.float16)
+
+    spec = Spec(weight, bias)
+    spec.validate()
+    spec.optimize(quantization=quantization)
+
+    assert _array_equal(spec.weight, expected_weight)
+    assert _array_equal(spec.bias, expected_bias)
+
+    # Check the weights were not copied or converted.
+    if quantization == "float16":
+        assert spec.weight is weight
+        assert spec.bias is bias
+    elif quantization == "int8_float16":
+        assert spec.bias is bias
+
+    if expected_weight_scale is None:
+        assert not hasattr(spec, "weight_scale")
+    else:
+        assert _array_equal(spec.weight_scale, expected_weight_scale)
 
 
 def test_index_spec():
```
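The `spec.weight is weight` identity assertions work because `_quantize()` only calls `astype()` when a conversion is actually needed; `ndarray.astype()` returns a copy by default even when the dtype already matches. A small standalone illustration:

```python
import numpy as np

w = np.array([[-10, -3, 5, 2]], dtype=np.float16)

print(w.astype(np.float16) is w)              # False: astype copies by default, even as a no-op
print(w.astype(np.float16, copy=False) is w)  # True: the copy is skipped when dtypes already match
```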
