Commit 684eb87
clean
Ubuntu committed Feb 19, 2024
1 parent b3a7d5b commit 684eb87
Showing 2 changed files with 4 additions and 54 deletions.
49 changes: 4 additions & 45 deletions medusa/model/utils.py
@@ -100,32 +100,8 @@ def generate_medusa_buffers(medusa_choices, device="cuda"):
     for i in range(len(depth_counts)):
         for j in range(depth_counts[i]):
             cur_medusa_choice = sorted_medusa_choices[start + j]
-            medusa_tree_indices[start + j + 1] = cur_medusa_choice[-1] + TOPK * i + 1  # compute the flat position from each group's last node and its depth
+            medusa_tree_indices[start + j + 1] = cur_medusa_choice[-1] + TOPK * i + 1
         start += depth_counts[i]
-    """
-    Logical structure:
-    A (token predicted by the original head; not present in sorted_medusa_choices)
-    B C ... K (tokens predicted by the first Medusa head, top-k of them)
-    banana cute ... key (tokens predicted by the second Medusa head, top-k of them)
-    Flattened: A B C ... K banana cute ... key (1 + topk * depth in total = 1 + 4*10 = 41)
-    A:0
-    ----
-    B:1
-    C:2
-    ...
-    k:11
-    ----
-    banana:12
-    cute:13
-    key:22
-    Not every path is kept, and a node may be chosen by several paths; 64 paths are selected ahead of time.
-    medusa_tree_indices: for the nodes visited by all chosen paths, ordered from short to long and from small to large, records each path's last node's index in the flattened layout.
-    """

     # Generate position IDs for the Medusa structure
     medusa_position_ids = torch.zeros(medusa_len, dtype=torch.long)
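For intuition, here is a minimal, runnable sketch of the index arithmetic described in the deleted comment. TOPK and the toy sorted_medusa_choices/depth_counts values below are illustrative assumptions, not the repository's configuration: a node with top-k rank k at head depth i lands at flat position k + TOPK * i + 1, with slot 0 reserved for the original head's token.

import torch

TOPK = 10  # top-k per Medusa head (assumed value)

# Toy data: each tuple is a path through the heads; the last element is
# the node's top-k rank within its own head.
sorted_medusa_choices = [(0,), (1,), (0, 0), (0, 1)]
depth_counts = [2, 2]  # number of choices ending at each depth

medusa_tree_indices = torch.zeros(len(sorted_medusa_choices) + 1, dtype=torch.long)

start = 0
for i in range(len(depth_counts)):
    for j in range(depth_counts[i]):
        cur = sorted_medusa_choices[start + j]
        # rank within head + TOPK * depth + 1 (offset past slot 0)
        medusa_tree_indices[start + j + 1] = cur[-1] + TOPK * i + 1
    start += depth_counts[i]

print(medusa_tree_indices)  # tensor([ 0,  1,  2, 11, 12])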
@@ -138,7 +114,7 @@ def generate_medusa_buffers(medusa_choices, device="cuda"):
     retrieve_indices_nest = []
     retrieve_paths = []
     for i in range(len(sorted_medusa_choices)):
-        cur_medusa_choice = sorted_medusa_choices[-i-1]  # iterate in reverse order
+        cur_medusa_choice = sorted_medusa_choices[-i-1]
         retrieve_indice = []
         if cur_medusa_choice in retrieve_paths:
             continue
@@ -356,12 +332,6 @@ def generate_candidates(medusa_logits, logits, tree_indices, retrieve_indices, t
     # Unsqueeze the tree candidates for dimension consistency.
     # tree_candidates = tree_candidates.unsqueeze(0)
     return cart_candidates, tree_candidates
-    """
-    cart_candidates.shape
-    torch.Size([2, 42, 5])
-    tree_candidates.shape
-    torch.Size([2, 64])
-    """

 def update_position_id(medusa_position_ids, attention_mask, input_ids):
     bs = input_ids.shape[0]
@@ -376,9 +346,7 @@ def update_position_id(medusa_position_ids, attention_mask, input_ids):
 def update_attention_mask(attention_mask, tree_candidates):
     bs = tree_candidates.shape[0]
     n = tree_candidates.shape[1]
-    # Create a tensor of n new tokens to append at the end
     new_tokens = torch.ones((bs, n), dtype=attention_mask.dtype, device=attention_mask.device)
-    # Extend attention_mask with torch.cat
     extended_attention_mask = torch.cat((attention_mask, new_tokens), dim=1)
     return extended_attention_mask

@@ -549,7 +517,6 @@ def evaluate_posterior(

     if sampling == 'typical':
         if fast:
-            # The last logit is the newest prediction, and candidates[..., 0] comes from the original head, so it needs no comparison
             posterior_prob = torch.softmax(logits[:,:,:-1] / temperature, dim=-1)
             candidates_prob = torch.gather(
                 posterior_prob, dim=-1, index=candidates[:,:,1:].unsqueeze(-1)
@@ -634,22 +601,16 @@ def gather_from_past_key_values(past_key_values_data, select_indices):
     layers, batch_size, head_num, _, hidden_size = past_key_values_data.shape
     seqlen = select_indices.shape[1]

-    # Initialize the result tensor that holds the selected data (or all-zero padding)
     result_data = torch.zeros(layers, batch_size, head_num, seqlen, hidden_size, device=past_key_values_data.device, dtype=past_key_values_data.dtype)

-    # Expand select_indices to match the dimensions past_key_values_data is indexed over
     expanded_indices = select_indices.unsqueeze(0).unsqueeze(2).expand(layers, batch_size, head_num, seqlen)

-    # Build a mask identifying the valid entries of select_indices (values other than -1)
     valid_indices_mask = expanded_indices != -1

-    # Replace -1 with a valid index (e.g. 0) so gather does not fail; the mask cleans this up afterwards
     corrected_indices = torch.where(valid_indices_mask, expanded_indices, torch.zeros_like(expanded_indices))

-    # Select the data with gather
     gathered_data = torch.gather(past_key_values_data, 3, corrected_indices.unsqueeze(-1).expand(-1, -1, -1, -1, hidden_size))

-    # Use the mask to zero out the positions that had -1 indices
     result_data = torch.where(valid_indices_mask.unsqueeze(-1), gathered_data, result_data)
     return result_data
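A self-contained sketch of the masked-gather pattern above, with toy shapes and made-up indices: entries of -1 in select_indices mark padding and come back as all-zero rows.

import torch

# Toy buffer: (layers, batch, heads, seq, hidden)
pkv = torch.arange(2 * 1 * 1 * 4 * 3, dtype=torch.float).reshape(2, 1, 1, 4, 3)
select_indices = torch.tensor([[2, 0, -1]])  # positions to keep; -1 = padding

layers, bs, heads, _, hidden = pkv.shape
seqlen = select_indices.shape[1]

idx = select_indices.unsqueeze(0).unsqueeze(2).expand(layers, bs, heads, seqlen)
valid = idx != -1
safe_idx = torch.where(valid, idx, torch.zeros_like(idx))  # -1 -> 0, masked out below

gathered = torch.gather(pkv, 3, safe_idx.unsqueeze(-1).expand(-1, -1, -1, -1, hidden))
result = torch.where(valid.unsqueeze(-1), gathered, torch.zeros_like(gathered))

print(result[0, 0, 0])  # rows for positions 2 and 0, then a zero row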

@@ -659,9 +620,7 @@ def update_ids_new(input_ids, new_ids):
     return input_ids

 def update_mask(attention_mask, accept_length):
-    # Build a range tensor whose every row is 0..max_seqlen-1
     range_tensor = torch.arange(accept_length.max().item(), device='cuda:0').expand(accept_length.shape[0], -1)
-    # Derive the mask from accept_length: positions within the accepted length are 1, the rest 0
     new_attention_mask = (range_tensor < accept_length.unsqueeze(1)).to(int)
     attention_mask = torch.cat((attention_mask, new_attention_mask), dim=-1)
     return attention_mask
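The range-comparison trick in update_mask is worth seeing in isolation; a small sketch with made-up accept lengths (on CPU rather than the hard-coded 'cuda:0'):

import torch

accept_length = torch.tensor([3, 1])  # accepted tokens per sequence (toy values)

# Every row counts 0..max-1; comparing against each row's length yields
# 1s at accepted positions and 0s elsewhere.
range_tensor = torch.arange(accept_length.max().item()).expand(accept_length.shape[0], -1)
new_mask = (range_tensor < accept_length.unsqueeze(1)).to(int)

print(new_mask)
# tensor([[1, 1, 1],
#         [1, 0, 0]])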
@@ -769,8 +728,8 @@ def update_inference_inputs(
         valid_length = accept_length
     else:
         # Extract logits and medusa logits for the last accepted tokens
-        logits = logits[batch_indices, best_candidate, accept_length-1]  # logits at the last accepted position
-        medusa_logits = medusa_logits[:, batch_indices, best_candidate, accept_length-1]  # logits at the last accepted position
+        logits = logits[batch_indices, best_candidate, accept_length-1]
+        medusa_logits = medusa_logits[:, batch_indices, best_candidate, accept_length-1]
         valid_length = None
     # Update the new token counter
     new_token += max_accept_length
9 changes: 0 additions & 9 deletions run.sh

This file was deleted.
