-
Notifications
You must be signed in to change notification settings - Fork 55
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
推理和生成相关调研和设计 #265
Comments
目前倾向于参考 Huggingface 的方案,将整个 inference 流程分解为 task-specific 的部分和 model-related 部分,学习 Huggingface 的推理 API,在模型内部支持 tensor 并行和 pipeline 并行的调用,先支持经典的 text_classification 和 text_generation 任务。 |
目前来说我想好的整个pineline和huggingface流程差不多, 首先我们得有一个基类的pipeline作为可继承的类使用: from libai.config import LazyConfig, try_get_key
from libai.engine.default import DefaultTrainer
from libai.utils.checkpoint import Checkpointer
from libai.data.structures import DistTensorData, Instance
class BasicPipeline:
def __init__(
self,
config_file,
**kwargs):
self.cfg = LazyConfig.load(config_file)
self.model = self.load_model(config_file)
self.tokenier = self.build_tokenizer(config_file)
...
def load_model(cfg):
model = DefaultTrainer.build_model(cfg).eval()
# 这里除了加载libai的模型用checkpointer以外,
# 也可以用户支持自定义, 从其他框架导入weight, 比如load_huggingface_weight
Checkpointer(model, save_dir=cfg.train.output_dir).resume_or_load(
cfg.train.load_weight, resume=False
)
if try_get_key(cfg, "train.graph.enabled", default=False):
model = DefaultTrainer.build_graph(cfg, model, is_train=False)
return model
def build_tokenizer(cfg):
...
def __call__(self, inputs, *args, batch_size=None, **kwargs):
model_inputs = self.preprocess(inputs, batch_size)
model_outputs = self.forward(model_inputs)
outputs = self.postprocess(model_outputs)
return outputs
def preprocess(self, inputs, batch_size, **kwargs):
...
return Instance(
input_ids=DistTensorData(...),
attention_mask=DistTensorData(...),
tokentype_ids=DistTensorData(...),
)
def forward(self, model_inputs, **kwargs):
...
model_outputs = self.model(model_inputs)
return model_outputs
def postprocess(self, model_outputs, **kwargs):
...
return outputs 对于其中 |
对于不同的任务, 我们的inference代码会不一样, 分类任务如果是对于只有encoder的分类任务, 那么模型会比较简单, 直接输出类别和分数就可以了. 生成任务但是如果是包含decoder的生成任务, 在进行
def couplet(model, src, data_loader, config):
vocab = data_loader.vocab
tokenizer = data_loader.tokenizer
model.eval()
tokens = [vocab.stoi[tok] for tok in tokenizer(src)] # 构造一个样本
num_tokens = len(tokens)
src = (torch.LongTensor(tokens).reshape(num_tokens, 1)) # 将src_len 作为第一个维度
tgt_tokens = greedy_decode(model, src, max_len=num_tokens + 5,
start_symbol=data_loader.BOS_IDX, config=config,
data_loader=data_loader).flatten() # 解码的预测结果
return "".join([vocab.itos[tok] for tok in tgt_tokens]).replace("<bos>", "").replace("<eos>", "")
def greedy_decode(model, src, max_len, start_symbol, config, data_loader):
src = src.to(config.device)
memory = model.encoder(src) # 对输入的Token序列进行解码翻译
ys = torch.ones(1, 1).fill_(start_symbol). \
type(torch.long).to(config.device) # 解码的第一个输入,起始符号
for i in range(max_len - 1):
memory = memory.to(config.device)
tgt_mask = (model.my_transformer.generate_square_subsequent_mask(ys.size(0))
.type(torch.bool)).to(config.device) # 根据tgt_len产生一个注意力mask矩阵(对称的)
out = model.decoder(ys, memory, tgt_mask) # [tgt_len,tgt_vocab_size]
out = out.transpose(0, 1) # [tgt_vocab_size, tgt_len]
prob = model.classification(out[:, -1]) # 只对对预测的下一个词进行分类
_, next_word = torch.max(prob, dim=1) # 选择概率最大者
next_word = next_word.item()
ys = torch.cat([ys, torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=0)
# 将当前时刻解码的预测输出结果,同之前所有的结果堆叠作为输入再去预测下一个词。
if next_word == data_loader.EOS_IDX: # 如果当前时刻的预测输出为结束标志,则跳出循环结束预测。
break
return ys
生成任务的加速从大体上看, 上述的代码是没有问题的, 但是有一个点我们可以加速的地方, 我们可以把decoder里面第一次运行的key-value保存起来, 在huggingface里面也是这么做的 由于decoder里面的key和value, 都是通过encoder的输出进行全连接得到的, 在网络是eval()模式, 而且encoder也只进行了一次前向的情况下, 在每次调用decoder期间, 用到的key和value都是同一个值, 也就是说在decoder里面key和value的生成只需要进行一次计算, 然后保存起来, 以后的计算都是重复的. 在LiBai的 def forward(
self,
hidden_states,
attention_mask=None,
encoder_states=None,
encoder_attention_mask=None,
past_key_value=None,
use_cache=False,
):
...
if past_key_value is not None:
if self.is_decoder:
assert len(past_key_value) == 4
self_attn_past_key_value = past_key_value[:2]
cross_attn_past_key_value = past_key_value[2:]
else:
self_attn_past_key_value = past_key_value
cross_attn_past_key_value = None
else:
self_attn_past_key_value, cross_attn_past_key_value = None, None
layernorm_output = self.input_layernorm(hidden_states)
attention_output = self.self_attention(
layernorm_output,
attention_mask=attention_mask,
past_key_value=self_attn_past_key_value,
use_cache=use_cache,
)
attention_output = self.drop_path(attention_output)
if use_cache:
attention_output, presents = attention_output
hidden_states = hidden_states + attention_output
layernorm_output = self.post_attention_layernorm(hidden_states)
if self.is_decoder:
# todo: use key-value to pass the arguments
attention_output = self.cross_attention(
layernorm_output,
encoder_states,
attention_mask=encoder_attention_mask,
past_key_value=cross_attn_past_key_value,
use_cache=use_cache,
)
if use_cache:
attention_output, decoder_presents = attention_output
presents += decoder_presents
attention_output = self.drop_path(attention_output)
hidden_states = hidden_states + attention_output
layernorm_output = self.post_cross_attention_layernorm(hidden_states)
mlp_output = self.mlp(layernorm_output)
mlp_output = self.drop_path(mlp_output)
output = hidden_states + mlp_output
if use_cache:
output = (output, presents)
return output 所以我们需要在写inference的时候, 需要在调用transformer_layer的地方, 设置 关于怎么修改代码有两个办法,
大致代码如下: from types import MethodType
def my_forward(self, ...):
...
dec_embedding_output = self.embedding(decoder_input_ids)
dec_hidden_states = dec_embedding_output
presents = []
if past_key_values is None:
past_key_values = [None] * self.decoder.layers
for layer, past_key_value in zip(self.decoder.layers, past_key_values):
dec_hidden_states, present = layer(
dec_hidden_states,
decoder_attn_mask,
encoder_states,
encoder_decoder_attn_mask,
past_key_value=past_key_value,
use_cache=True,
)
presents.append(present)
decoder_states = self.decoder.final_layernorm(dec_hidden_states)
logits = self.lm_head(decoder_states, self.embedding.word_embeddings.weight)
return logits, presents
# 重新指定model.forward()
model.forward = MethodType(my_forward, model)
其中方法1的好处是不用修改libai里面本来的代码, libai里面的代码让人看上去觉得比较干净, 坏处就是每个包含decoder的model, 可能都需要单独写一个forward()来重构一下. 方法2的好处是可以一劳永逸, 在inference里面会比较干净, 坏处就是在 |
我倾向于用方法2,megatron 和 huggingface 应该都是这样做的~ |
做推理生成任务的时候,输入序列是变长的是吧,那目前只能用 eager global 来做了? |
我理解 不止生成任务, 可能分类任务输入序列也是变长的, 只不过都会进行padding到max_length. |
正好我们下午要和 idea 开会,这个部分的问题涉及到 NLP 的 domain knowledge,我们和他们请教一下 |
调研了不同的 NLP 库在预测阶段的处理方式
FairSeq
针对生成任务的代码主要在 https://github.com/pytorch/fairseq/blob/main/fairseq/sequence_generator.py
针对序列预测任务的代码主要在 https://github.com/pytorch/fairseq/blob/7e758841da9e05cb21826a60d30a563a9e189d1d/fairseq/sequence_scorer.py#L12
主要针对生成的任务进行构建的,tasks 支持比较少,而且两种风格不统一,同时不支持模型并行模式的推理。
AllenNLP
主要代码在 https://github.com/allenai/allennlp/blob/426d894ceef591b406cb77a7b094c88c85ad0068/allennlp/models/model.py#L193
在模型层面进行实现,每种模型绑定一个推理方式,这种方式下,模型和任务没有解耦,在训练中耦合 generation 的逻辑
Megatron-LM
提供了 api 代码 https://github.com/NVIDIA/Megatron-LM/blob/e156d2fea7fc5c98e645f7742eb86b643956d840/megatron/text_generation/api.py#L30
支持的 tasks 比较少,不过可以支持复杂并行的模型推理,比如 pipeline 并行,但是整体实现以及调用流程比较复杂,对用户不友好
HuggingFace
主要代码在 https://github.com/huggingface/transformers/blob/eb5bdcdfa51f743887ee1d9c7f230444d7a8b23c/src/transformers/pipelines/base.py#L710
在整个流程抽象为如下的处理流
调用方式清晰简单
扩展任务比较方便,可以继承基类 Pipeline,解耦了任务相关的流程和模型推理的流程。
@thinksoso @xiezipeng-ML 遗漏的内容可以补充一下,有错误的地方可以修正~
The text was updated successfully, but these errors were encountered: