add src_q/k/v_bias for cross_att (#2685)
* add src_q/k/v_bias for cross_att

* fix lint

* add src_key_bias for whisper cross attn
Mddct authored Feb 7, 2025
1 parent ead4d14 commit 59dc505
Showing 7 changed files with 11 additions and 2 deletions.
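
In short: the decoder's cross-attention layer previously reused the self-attention bias flags (query_bias, key_bias, value_bias). This commit gives cross-attention its own src_query_bias, src_key_bias, and src_value_bias flags, all defaulting to True so existing configs keep their behavior; the Whisper configs and converter then set src_key_bias to false, since Whisper's key projections carry no bias. A minimal sketch of how the new flags might be passed when building a decoder (assuming the class is wenet's TransformerDecoder from wenet/transformer/decoder.py; the sizes are illustrative assumptions, not values from this diff):

```python
from wenet.transformer.decoder import TransformerDecoder

# Illustrative sizes only; the bias flags are the point of this commit.
decoder = TransformerDecoder(
    vocab_size=5000,
    encoder_output_size=512,
    key_bias=False,        # self-attention: no bias on the key projection
    src_query_bias=True,   # new: cross-attention query projection bias
    src_key_bias=False,    # new: cross-attention key projection bias
    src_value_bias=True,   # new: cross-attention value projection bias
)
```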
Five decoder config files (names not preserved in this capture) receive the same one-line addition to decoder_conf:

```diff
@@ -25,6 +25,7 @@ decoder_conf:
     gradient_checkpointing: true
     input_layer: embed_learnable_pe
     key_bias: false
+    src_key_bias: false
     linear_units: 5120
     normalize_before: true
     num_blocks: 32
```
7 changes: 5 additions & 2 deletions wenet/transformer/decoder.py
```diff
@@ -90,6 +90,9 @@ def __init__(
         mlp_bias: bool = True,
         n_expert: int = 8,
         n_expert_activated: int = 2,
+        src_query_bias: bool = True,
+        src_key_bias: bool = True,
+        src_value_bias: bool = True,
     ):
         super().__init__()
         attention_dim = encoder_output_size
```
```diff
@@ -123,8 +126,8 @@ def __init__(
                     value_bias, use_sdpa, n_kv_head, head_dim),
                 WENET_ATTENTION_CLASSES["crossattn"](
                     attention_heads, attention_dim, src_attention_dropout_rate,
-                    query_bias, key_bias, value_bias, use_sdpa, n_kv_head,
-                    head_dim) if src_attention else None,
+                    src_query_bias, src_key_bias, src_value_bias, use_sdpa,
+                    n_kv_head, head_dim) if src_attention else None,
                 mlp_class(attention_dim,
                           linear_units,
                           dropout_rate,
```
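
The "crossattn" entry in WENET_ATTENTION_CLASSES now receives the src_* flags instead of the self-attention ones. A minimal sketch of what per-projection bias control looks like inside a multi-head attention module (a simplified stand-in, not the actual wenet class; only the projection layers are shown):

```python
import torch.nn as nn


class CrossAttentionProjections(nn.Module):
    """Sketch: each projection can independently drop its bias term,
    matching the query_bias/key_bias/value_bias positional arguments
    the decoder now fills with src_* values for cross-attention."""

    def __init__(self,
                 n_head: int,
                 n_feat: int,
                 dropout_rate: float,
                 query_bias: bool = True,
                 key_bias: bool = True,
                 value_bias: bool = True):
        super().__init__()
        assert n_feat % n_head == 0, "n_feat must be divisible by n_head"
        self.linear_q = nn.Linear(n_feat, n_feat, bias=query_bias)
        self.linear_k = nn.Linear(n_feat, n_feat, bias=key_bias)
        self.linear_v = nn.Linear(n_feat, n_feat, bias=value_bias)
        self.linear_out = nn.Linear(n_feat, n_feat)
        self.dropout = nn.Dropout(p=dropout_rate)
```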
1 change: 1 addition & 0 deletions wenet/whisper/convert_whisper_to_wenet_config_and_ckpt.py
```diff
@@ -84,6 +84,7 @@ def convert_to_wenet_yaml(tokenizer, dims, wenet_yaml_path: str):
     configs['decoder_conf']['normalize_before'] = True
     configs['decoder_conf']['src_attention'] = True
     configs['decoder_conf']['key_bias'] = False
+    configs['decoder_conf']['src_key_bias'] = False
     configs['decoder_conf']['activation_type'] = "gelu"

     configs['tokenizer'] = 'whisper'
```
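
For context: OpenAI Whisper defines its attention key projections without a bias term (nn.Linear(..., bias=False)) while query and value projections keep theirs, which is why the converter disables both key_bias and src_key_bias. A quick sanity check against a downloaded checkpoint (the file path and state-dict key names are assumptions about the OpenAI checkpoint layout):

```python
import torch

# Path is an assumption; any OpenAI Whisper checkpoint should do.
ckpt = torch.load("large-v3.pt", map_location="cpu")
state = ckpt["model_state_dict"]

# Key projections carry no bias, in self- and cross-attention alike;
# query and value projections do.
assert "decoder.blocks.0.cross_attn.key.bias" not in state
assert "decoder.blocks.0.cross_attn.query.bias" in state
assert "decoder.blocks.0.cross_attn.value.bias" in state
```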
