
The LLaMA implementation in Keras Hub shows significant accuracy deviations from the reference Hugging Face implementation #1993

Open
pass-lin opened this issue Nov 27, 2024 · 2 comments
pass-lin commented Nov 27, 2024

import os

os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"  # Chinese Hugging Face mirror
os.environ["KERAS_BACKEND"] = "torch"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
model_name = 'NousResearch/Meta-Llama-3.1-8B'

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
import keras

# Reference implementation: Hugging Face Transformers, eager attention, bf16.
hf_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="cuda:0",
    torch_dtype=torch.bfloat16,
    _attn_implementation='eager',
    trust_remote_code=False,
).eval()

import keras_hub

# Keras Hub implementation, loaded from the same checkpoint in bf16.
keras.config.set_dtype_policy('bfloat16')
keras_model = keras_hub.models.Llama3CausalLM.from_preset('hf://' + model_name)

tokenizer = AutoTokenizer.from_pretrained(model_name)
# Tokenize a short Chinese test sentence ("FLOPs determine how long the network
# takes to run; the parameter count determines how much GPU memory it uses").
input_ids, mask = tokenizer('计算量决定了网络执行时间的长短,参数量决定了占用显存的量').values()
input_ids = keras.ops.expand_dims(input_ids, 0)
mask = keras.ops.expand_dims(mask, 0)

x1 = hf_model.forward(input_ids, attention_mask=mask)
x2 = keras_model([mask, input_ids])

error = keras.ops.abs(x1.logits - x2)

# Global error statistics over all logits.
print(keras.ops.max(error))
print(keras.ops.min(error))
print(keras.ops.mean(error))
print(keras.ops.std(error))

# Per-token error statistics (reduced over the vocabulary axis).
print(keras.ops.max(error, -1))
print(keras.ops.min(error, -1))
print(keras.ops.mean(error, -1))
print(keras.ops.std(error, -1))

The output is

tensor(3.2188, device='cuda:0', dtype=torch.bfloat16, grad_fn=<MaxBackward1>)
tensor(0., device='cuda:0', dtype=torch.bfloat16, grad_fn=<MinBackward1>)
tensor(0.2441, device='cuda:0', dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)
tensor(0.2129, device='cuda:0', dtype=torch.bfloat16, grad_fn=<StdBackward0>)
tensor([[0.5938, 0.4062, 2.0938, 1.0781, 1.2188, 2.4062, 2.2812, 1.5625, 1.5234,
         1.3750, 1.4844, 2.9531, 2.3281, 1.7344, 2.4062, 2.1875, 2.4062, 3.2188,
         1.6953, 1.6250, 1.7969, 1.5078]], device='cuda:0',
       dtype=torch.bfloat16, grad_fn=<AmaxBackward0>)
tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],
       device='cuda:0', dtype=torch.bfloat16, grad_fn=<AminBackward0>)
tensor([[0.1060, 0.0669, 0.3340, 0.1436, 0.1621, 0.3320, 0.2617, 0.2236, 0.2246,
         0.2090, 0.2422, 0.2490, 0.2930, 0.2637, 0.2500, 0.3066, 0.3574, 0.3340,
         0.2676, 0.2598, 0.2344, 0.2461]], device='cuda:0',
       dtype=torch.bfloat16, grad_fn=<ToCopyBackward0>)
tensor([[0.0679, 0.0510, 0.2578, 0.1138, 0.1270, 0.2617, 0.2080, 0.1719, 0.1768,
         0.1631, 0.1855, 0.2090, 0.2285, 0.2061, 0.2002, 0.2432, 0.2812, 0.2715,
         0.2090, 0.2031, 0.1816, 0.1904]], device='cuda:0',
       dtype=torch.bfloat16, grad_fn=<StdBackward0>)

Both runs use the PyTorch backend, which should largely rule out framework-specific numerical differences, yet the logits from the reference Hugging Face implementation and from the Keras Hub implementation still differ significantly. In practical use, the Keras Hub LLaMA implementation also falls into repetitive decoding noticeably more often than the HF and vLLM implementations.
Is it necessary to fix this precision difference?
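
One way to narrow the gap down is to compare the rotary-embedding settings the two implementations actually use. The following is only a diagnostic sketch, assuming (as a hypothesis, not a confirmed cause) that Llama 3.1's extended RoPE frequency scaling (`rope_scaling` in the HF config) might not be carried over:

from transformers import AutoConfig

# Llama 3.1 checkpoints carry a `rope_scaling` block in config.json
# (rope_type, factor, low/high_freq_factor, original_max_position_embeddings).
hf_config = AutoConfig.from_pretrained(model_name)
print(hf_config.rope_scaling)

# Dump the Keras Hub backbone config and check whether an equivalent
# scaling option is present; no specific key name is assumed here.
print(keras_model.backbone.get_config())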

pass-lin commented Dec 3, 2024

Upon further attempts, I found that the issue is not limited to bf16; an error of similar magnitude also occurs under fp32. Under fp32, an element-wise tolerance of about 1e-5 is normally considered acceptable, but the error here is several orders of magnitude larger, so the Llama model implementation likely carries a considerable margin of error.

import os

os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"  # Chinese Hugging Face mirror
os.environ["KERAS_BACKEND"] = "torch"
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"
model_name = 'NousResearch/Meta-Llama-3.1-8B'

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
import keras

# Reference implementation: Hugging Face Transformers, eager attention, fp32.
hf_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="cuda:1",
    torch_dtype=torch.float32,
    _attn_implementation='eager',
    trust_remote_code=False,
).eval()

import keras_hub

# Keras Hub implementation in the default fp32 dtype policy.
# keras.config.set_dtype_policy('bfloat16')
keras_model = keras_hub.models.Llama3CausalLM.from_preset('hf://' + model_name)

tokenizer = AutoTokenizer.from_pretrained(model_name)
# Same Chinese test sentence as above ("FLOPs determine run time; parameter
# count determines GPU memory use").
input_ids, mask = tokenizer('计算量决定了网络执行时间的长短,参数量决定了占用显存的量').values()
input_ids = keras.ops.expand_dims(input_ids, 0)
mask = keras.ops.expand_dims(mask, 0)

x1 = hf_model.forward(input_ids.to("cuda:1"), attention_mask=mask.to("cuda:1"))
x2 = keras_model([mask, input_ids])

error = keras.ops.abs(x1.logits.cpu() - x2.cpu())

# Global error statistics over all logits.
print(keras.ops.max(error))
print(keras.ops.min(error))
print(keras.ops.mean(error))
print(keras.ops.std(error))

# Per-token error statistics (reduced over the vocabulary axis).
print(keras.ops.max(error, -1))
print(keras.ops.min(error, -1))
print(keras.ops.mean(error, -1))
print(keras.ops.std(error, -1))

The output is
tensor(3.3085, device='cuda:0', grad_fn=<MaxBackward1>)
tensor(0., device='cuda:0', grad_fn=<MinBackward1>)
tensor(0.2417, device='cuda:0', grad_fn=<MeanBackward1>)
tensor(0.2120, device='cuda:0', grad_fn=<StdBackward0>)
tensor([[0.4981, 0.5633, 1.9278, 1.0281, 0.9935, 2.5044, 2.2573, 1.5885, 1.5354,
         1.3483, 1.4797, 2.9066, 2.3571, 1.6378, 2.4488, 2.2407, 2.5110, 3.3085,
         1.7227, 1.6624, 1.7762, 1.5082]], device='cuda:0',
       grad_fn=<AmaxBackward0>)
tensor([[1.9670e-06, 4.7684e-07, 4.7684e-07, 0.0000e+00, 4.7684e-07, 1.4305e-06,
         1.5497e-06, 0.0000e+00, 4.7684e-07, 1.6689e-06, 2.3842e-06, 2.0266e-06,
         1.2398e-05, 1.1921e-07, 3.8147e-06, 9.0599e-06, 5.0068e-06, 4.5300e-06,
         2.3842e-07, 1.1921e-06, 2.6226e-06, 7.1526e-06]], device='cuda:0',
       grad_fn=<AminBackward0>)
tensor([[0.0929, 0.0694, 0.3153, 0.1340, 0.1513, 0.3295, 0.2608, 0.2181, 0.2219,
         0.2122, 0.2425, 0.2484, 0.2934, 0.2523, 0.2490, 0.3077, 0.3529, 0.3431,
         0.2756, 0.2599, 0.2309, 0.2557]], device='cuda:0',
       grad_fn=<MeanBackward1>)
tensor([[0.0589, 0.0522, 0.2429, 0.1061, 0.1178, 0.2618, 0.2068, 0.1680, 0.1737,
         0.1650, 0.1862, 0.2082, 0.2285, 0.1969, 0.1989, 0.2443, 0.2786, 0.2801,
         0.2143, 0.2034, 0.1799, 0.1978]], device='cuda:0',
       grad_fn=<StdBackward0>)
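
Raw logit deltas of this size can in principle still be compatible with identical greedy decoding, so a complementary check is to compare the argmax token at every position. A minimal sketch, reusing the tensors from the script above and assuming the torch backend (so x2 is a torch tensor):

# Compare the greedy (top-1) prediction of both models at every position.
hf_pred = x1.logits.argmax(dim=-1).cpu()   # shape [1, seq_len]
keras_pred = x2.argmax(dim=-1).cpu()
match = hf_pred.eq(keras_pred)
print("top-1 agreement:", match.float().mean().item())
print("mismatched positions:", (~match).nonzero()[:, 1].tolist())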

mattdangerw (Member) commented
Thanks! Will take a look.
