From a71d7831fa44975e45162b9c6602e29e32af3614 Mon Sep 17 00:00:00 2001
From: Guillaume Klein
Date: Fri, 17 Jun 2022 11:44:10 +0200
Subject: [PATCH] Remove intermediate weight conversion to FP32 when possible
 (#839)

* Remove intermediate weight conversion to FP32 when possible

This can save memory during conversion when the model weights are stored
in FP16 and the model is converted with quantization.

* Check the weights were not copied or converted
---
 python/ctranslate2/specs/model_spec.py | 39 +++++++-----
 python/tests/test.py                   | 82 ++++++++++++++++++++++++--
 2 files changed, 101 insertions(+), 20 deletions(-)

diff --git a/python/ctranslate2/specs/model_spec.py b/python/ctranslate2/specs/model_spec.py
index 6a3a202aa..9ab3ec78d 100644
--- a/python/ctranslate2/specs/model_spec.py
+++ b/python/ctranslate2/specs/model_spec.py
@@ -73,8 +73,8 @@ def _check(spec, name, value):
                 raise ValueError("Missing value for attribute %s" % name)
 
             if isinstance(value, np.ndarray):
-                # Use float32 as the working floating point type.
-                if value.dtype in (np.float16, np.float64):
+                # float64 is not a supported type.
+                if value.dtype == np.float64:
                     value = value.astype(np.float32)
             elif isinstance(value, float):
                 value = np.dtype("float32").type(value)
@@ -144,13 +144,17 @@ def _quantize(self, quantization):
         def _quantize(spec, name, value):
             if not isinstance(value, np.ndarray):
                 return
-            if "weight" in name and quantization != "float16":
+
+            scale = None
+            is_quantizable = "weight" in name
+
+            if is_quantizable:
                 if quantization == "int16":
                     # Represent the value with 10 bits so the multiplication is 20 bits
                     # and 12 bits are left for accumulation.
-                    scale = np.dtype(value.dtype).type(
-                        2**10 / np.amax(np.absolute(value))
-                    )
+                    scale = np.float32(2**10 / np.amax(np.absolute(value)))
+                    if value.dtype != scale.dtype:
+                        value = value.astype(scale.dtype)
                     value *= scale
                     value = np.rint(value)
                     value = np.clip(
@@ -158,18 +162,26 @@ def _quantize(spec, name, value):
                     )
                     value = value.astype(np.int16)
                 elif quantization in ("int8", "int8_float16"):
-                    amax = np.amax(np.absolute(value), axis=1)
+                    amax = np.amax(np.absolute(value), axis=1).astype(np.float32)
                     amax[amax == 0] = 127.0
                     scale = 127.0 / amax
+                    if value.dtype != scale.dtype:
+                        value = value.astype(scale.dtype)
                     value *= np.expand_dims(scale, 1)
                     value = np.rint(value)
                     value = value.astype(np.int8)
-                setattr(spec, "weight_scale", scale)
-                setattr(spec, "weight", value)
-            elif quantization in ("float16", "int8_float16"):
+
+            if quantization in ("float16", "int8_float16"):
                 if value.dtype == np.float32:
-                    key = _split_scope(name)[-1]
-                    setattr(spec, key, value.astype(np.float16))
+                    value = value.astype(np.float16)
+            else:
+                if value.dtype == np.float16:
+                    value = value.astype(np.float32)
+
+            key = _split_scope(name)[-1]
+            setattr(spec, key, value)
+            if scale is not None:
+                setattr(spec, "%s_scale" % key, scale)
 
         self._visit(_quantize)
 
@@ -184,8 +196,7 @@ def optimize(self, quantization: str = None) -> None:
             (possible values are: int8, int8_float16, int16, float16).
         """
         self._alias_variables()
-        if quantization is not None:
-            self._quantize(quantization)
+        self._quantize(quantization)
 
     def _visit(self, fn):
         """Recursively visits this layer and its children."""
diff --git a/python/tests/test.py b/python/tests/test.py
index b74d2cc38..9147d9a52 100644
--- a/python/tests/test.py
+++ b/python/tests/test.py
@@ -1142,6 +1142,10 @@ def test_transformers_lm_scoring(tmpdir):
     assert generator.score_batch([["<|endoftext|>"]])[0] == []
 
 
+def _array_equal(a, b):
+    return a.dtype == b.dtype and np.array_equal(a, b)
+
+
 def test_layer_spec_validate():
     class SubSpec(ctranslate2.specs.LayerSpec):
         def __init__(self):
@@ -1160,12 +1164,12 @@ def __init__(self):
     spec = Spec()
     spec.validate()
     assert spec.a.dtype == np.float32
-    assert spec.b.dtype == np.float32
+    assert spec.b.dtype == np.float16
     assert spec.c.dtype == np.int32
     assert spec.d == OPTIONAL
-    assert spec.e.a.dtype == np.float32
-    assert spec.f.dtype == np.int8 and np.array_equal(spec.f, 1)
-    assert spec.g.dtype == np.int8 and np.array_equal(spec.g, [104, 101, 108, 108, 111])
+    assert spec.e.a.dtype == np.float16
+    assert _array_equal(spec.f, np.int8(1))
+    assert _array_equal(spec.g, np.array([104, 101, 108, 108, 111], dtype=np.int8))
 
 
 def test_layer_spec_optimize():
@@ -1208,10 +1212,76 @@ def __init__(self):
 
     spec = Spec()
     spec.optimize(quantization="int8")
-    assert np.array_equal(
+    assert _array_equal(
         spec.weight, np.array([[-127, -38, 64, 25], [0, 0, 0, 0]], dtype=np.int8)
     )
-    assert np.array_equal(spec.weight_scale, np.array([12.7, 1], dtype=np.float32))
+    assert _array_equal(spec.weight_scale, np.array([12.7, 1], dtype=np.float32))
+
+
+@pytest.mark.parametrize(
+    "quantization,expected_weight,expected_weight_scale,expected_bias",
+    [
+        (
+            None,
+            np.array([[-10, -3, 5, 2]], dtype=np.float32),
+            None,
+            np.array([4], dtype=np.float32),
+        ),
+        (
+            "float16",
+            np.array([[-10, -3, 5, 2]], dtype=np.float16),
+            None,
+            np.array([4], dtype=np.float16),
+        ),
+        (
+            "int8",
+            np.array([[-127, -38, 64, 25]], dtype=np.int8),
+            np.array([12.7], dtype=np.float32),
+            np.array([4], dtype=np.float32),
+        ),
+        (
+            "int8_float16",
+            np.array([[-127, -38, 64, 25]], dtype=np.int8),
+            np.array([12.7], dtype=np.float32),
+            np.array([4], dtype=np.float16),
+        ),
+        (
+            "int16",
+            np.array([[-1024, -307, 512, 205]], dtype=np.int16),
+            np.float32(102.4),
+            np.array([4], dtype=np.float32),
+        ),
+    ],
+)
+def test_fp16_weights(
+    quantization, expected_weight, expected_weight_scale, expected_bias
+):
+    class Spec(ctranslate2.specs.LayerSpec):
+        def __init__(self, weight, bias):
+            self.weight = weight
+            self.bias = bias
+
+    weight = np.array([[-10, -3, 5, 2]], dtype=np.float16)
+    bias = np.array([4], dtype=np.float16)
+
+    spec = Spec(weight, bias)
+    spec.validate()
+    spec.optimize(quantization=quantization)
+
+    assert _array_equal(spec.weight, expected_weight)
+    assert _array_equal(spec.bias, expected_bias)
+
+    # Check the weights were not copied or converted.
+    if quantization == "float16":
+        assert spec.weight is weight
+        assert spec.bias is bias
+    elif quantization == "int8_float16":
+        assert spec.bias is bias
+
+    if expected_weight_scale is None:
+        assert not hasattr(spec, "weight_scale")
+    else:
+        assert _array_equal(spec.weight_scale, expected_weight_scale)
 
 
 def test_index_spec():