Fix absent fp8_kv property on llama and qwen models (#662)
ajtejankar authored Oct 30, 2024
1 parent 2ff1c71 commit c2441e2
Showing 3 changed files with 6 additions and 6 deletions.
@@ -256,11 +256,11 @@ def __init__(
         if is_fp8_kv(config.quantize):
             self.k_scale = weights.get_tensor(f"{prefix}.k_scale", use_self_dtype=False).item()
             self.v_scale = weights.get_tensor(f"{prefix}.v_scale", use_self_dtype=False).item()
-            self.kv_dtype = 'fp8'
+            self.fp8_kv = True
         else:
             self.k_scale = 1.0
             self.v_scale = 1.0
-            self.kv_dtype = 'auto'
+            self.fp8_kv = False

         self.query_key_value = load_attention(config, prefix, weights, layer_id)

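All three hunks make the same change: the string-valued kv_dtype attribute is replaced by the boolean fp8_kv flag that, per the commit title, the rest of the stack expects. Each hunk gates on is_fp8_kv(config.quantize); a plausible sketch of that helper follows, assuming quantize is an optional string naming the quantization scheme. The exact value "fp8-kv" is a guess, not confirmed by this diff.

def is_fp8_kv(quantize) -> bool:
    # Assumption: `quantize` is None or a string such as "fp8-kv".
    return quantize == "fp8-kv"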
@@ -196,11 +196,11 @@ def __init__(
         if is_fp8_kv(config.quantize):
             self.k_scale = weights.get_tensor(f"{prefix}.k_scale", use_self_dtype=False).item()
             self.v_scale = weights.get_tensor(f"{prefix}.v_scale", use_self_dtype=False).item()
-            self.kv_dtype = 'fp8'
+            self.fp8_kv = True
         else:
             self.k_scale = 1.0
             self.v_scale = 1.0
-            self.kv_dtype = 'auto'
+            self.fp8_kv = False

         self.query_key_value = load_attention(config, prefix, weights, layer_id)

@@ -204,11 +204,11 @@ def __init__(
         if is_fp8_kv(config.quantize):
             self.k_scale = weights.get_tensor(f"{prefix}.k_scale", use_self_dtype=False).item()
             self.v_scale = weights.get_tensor(f"{prefix}.v_scale", use_self_dtype=False).item()
-            self.kv_dtype = 'fp8'
+            self.fp8_kv = True
         else:
             self.k_scale = 1.0
             self.v_scale = 1.0
-            self.kv_dtype = 'auto'
+            self.fp8_kv = False

         self.c_attn = load_attention(config, prefix, weights, layer_id)

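For context on how such a flag might be consumed: a minimal sketch, assuming a PyTorch backend, of cache-write code that branches on fp8_kv and applies the per-layer k_scale/v_scale when quantizing keys and values. The KVCacheWriter class and its use of torch.float8_e4m3fn are illustrative assumptions, not this repository's API.

import torch

class KVCacheWriter:
    """Illustrative only: quantize K/V to fp8 before caching when enabled."""

    def __init__(self, fp8_kv: bool, k_scale: float = 1.0, v_scale: float = 1.0):
        self.fp8_kv = fp8_kv
        self.k_scale = k_scale  # read from the checkpoint when fp8_kv is True
        self.v_scale = v_scale

    def write(self, k: torch.Tensor, v: torch.Tensor):
        if self.fp8_kv:
            # Divide by the calibration scale, then cast to fp8; a kernel
            # reading the cache would rescale by the same factors.
            k = (k / self.k_scale).to(torch.float8_e4m3fn)
            v = (v / self.v_scale).to(torch.float8_e4m3fn)
        return k, v

# Mirrors the two branches in the diff: the fp8 path with loaded scales,
# or a pass-through with unit scales.
writer = KVCacheWriter(fp8_kv=True, k_scale=0.02, v_scale=0.03)
k, v = writer.write(torch.randn(1, 8, 64, dtype=torch.float16),
                    torch.randn(1, 8, 64, dtype=torch.float16))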
