
Commit

Update DA-ADB_llama
somehow77 committed Jun 26, 2024
1 parent 08a081c commit 36192ac
Showing 14 changed files with 981 additions and 17 deletions.
39 changes: 27 additions & 12 deletions open_intent_detection/README.md
@@ -60,6 +60,7 @@ The detailed results can be seen in [results.md](results/results.md).

* KIR means "Known Intent Ratio". "Open" and "Known" denote the macro F1-score over the open class and the known classes, respectively (a rough sketch of this computation is given below these notes).
* KNNCL (All) fine-tunes all 12 transformer layers. KNNCL (Last) fine-tunes only the last transformer layer, as the other baselines do.
* We also evaluate the DA-ADB method with a LLaMA backbone, reported in the tables as DA-ADB_llama.
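
For reference, a rough sketch (not repository code) of how the "Known" and "Open" columns can be computed with scikit-learn, assuming predictions and labels are encoded as class ids and the open (unseen) intent has its own dedicated id:

```
import numpy as np
from sklearn.metrics import f1_score

def known_open_f1(y_true, y_pred, label_ids, open_id):
    """Macro F1 over the known classes, and F1 of the open class (illustrative helper)."""
    per_class = f1_score(y_true, y_pred, labels=label_ids, average=None)
    known = [f for lid, f in zip(label_ids, per_class) if lid != open_id]
    open_f1 = per_class[list(label_ids).index(open_id)]
    return float(np.mean(known)), float(open_f1)
```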

| | | BANKING | | OOS | | StackOverflow | |
|:-----:|:-----:|:-----:|:-----:|:-----:|:-----:|:-----:|:-----:|
@@ -75,8 +76,9 @@ The detailed results can be seen in [results.md](results/results.md).
|0.25|ARPL|76.8|64.01|84.51|73.44|66.76|62.62|
|0.25|KNNCL (Last)|73.01|66.23|89.87|79.23|28.65|37.37|
|0.25|ADB|79.33|71.63|88.3|78.23|86.75|79.85|
|0.25|KNNCL (All)|**86.14**|**77.01**|**93.07**|**82.45**|85.04|79.06|
|0.25|KNNCL (All)|**86.14**|**77.01**|**93.07**|82.45|85.04|79.06|
|0.25|DA-ADB|81.19|73.73|89.48|79.92|**89.07**|**82.83**|
|0.25|DA-ADB_llama|83.29|76.78|91.53|**84.54**|85.54|77.81|
|||||||||
|0.5|MSP|61.67|72.51|66.68|72.7|53.23|62.7|
|0.5|SEG|55.11|63.32|60.67|62.55|43.04|55.1|
@@ -89,8 +91,9 @@ The detailed results can be seen in [results.md](results/results.md).
|0.5|ARPL|74.11|77.77|80.36|80.88|75.65|77.87|
|0.5|KNNCL (Last)|70.41|74.96|85.32|83.31|45.38|56.69|
|0.5|ADB|79.61|81.34|86.54|85.16|86.49|85.54|
|0.5|KNNCL (All)|**82.76**|81.31|**88.66**|83.99|86.69|86.15|
|0.5|DA-ADB|81.51|**82.53**|87.93|**85.64**|**87.78**|**86.91**|
|0.5|KNNCL (All)|**82.76**|81.31|88.66|83.99|86.69|86.15|
|0.5|DA-ADB|81.51|82.53|87.93|85.64|**87.78**|**86.91**|
|0.5|DA-ADB_llama|82.66|**83.67**|**90.29**|**88.86**|86.42|86.09|
|||||||||
|0.75|MSP|77.08|84.33|76.19|83.48|73.2|78.7|
|0.75|SEG|64.65|69.54|42.78|42.7|62.72|69.97|
@@ -102,9 +105,10 @@ The detailed results can be seen in [results.md](results/results.md).
|0.75|MDF|64.59|74.76|63.98|72.02|62.98|71.12|
|0.75|ARPL|79.6|85.16|81.29|86.0|79.64|83.85|
|0.75|KNNCL (Last)|74.78|81.25|84.12|86.1|65.01|71.85|
|0.75|ADB|**81.39**|**86.11**|86.99|**88.94**|82.89|86.11|
|0.75|ADB|81.39|86.11|86.99|88.94|82.89|86.11|
|0.75|KNNCL (All)|77.50|82.30|85.07|85.11|83.15|86.73|
|0.75|DA-ADB|81.12|85.65|**87.39**|88.41|**83.56**|**86.84**|
|0.75|DA-ADB|81.12|85.65|87.39|88.41|**83.56**|**86.84**|
|0.75|DA-ADB_llama|**82.19**|**86.52**|**89.23**|**90.48**|82.69|86.28|

#### Fine-grained Performance

@@ -122,8 +126,9 @@ The detailed results can be seen in [results.md](results/results.md).
|0.25|ARPL|83.39|62.99|89.63|73.01|72.95|60.55|
|0.25|KNNCL (Last)|79.34|65.54|93.56|78.85|15.26|41.79|
|0.25|ADB|85.05|70.92|92.36|77.85|90.96|77.62|
|0.25|KNNCL (All)|**90.55**|**76.30**|**95.73**|**82.10**|89.59|76.96|
|0.25|KNNCL (All)|**90.55**|**76.30**|**95.73**|82.10|89.59|76.96|
|0.25|DA-ADB|86.57|73.05|93.2|79.57|**92.65**|**80.87**|
|0.25|DA-ADB_llama|88.19|76.18|94.55|**84.27**|90.04|75.36|
|||||||||
|0.5|MSP|46.29|73.2|63.71|72.82|26.94|66.28|
|0.5|SEG|43.03|63.85|61.34|62.57|4.72|60.14|
@@ -136,8 +141,9 @@ The detailed results can be seen in [results.md](results/results.md).
|0.5|ARPL|71.79|77.93|81.81|80.87|73.97|78.26|
|0.5|KNNCL (Last)|67.21|75.16|87.85|83.25|8.5|61.5|
|0.5|ADB|79.43|81.39|88.6|85.12|87.7|85.32|
|0.5|KNNCL (All)|**84.28**|81.23|**91.17**|83.89|87.59|86.01|
|0.5|DA-ADB|81.93|**82.54**|90.1|**85.58**|**88.86**|**86.71**|
|0.5|KNNCL (All)|**84.28**|81.23|91.17|83.89|87.59|86.01|
|0.5|DA-ADB|81.93|82.54|90.1|85.58|**88.86**|**86.71**|
|0.5|DA-ADB_llama|83.23|**83.68**|**92.04**|**88.82**|87.43|85.95|
|||||||||
|0.75|MSP|46.05|84.99|63.86|83.65|37.86|81.42|
|0.75|SEG|37.22|70.1|40.74|42.72|6.0|74.24|
@@ -149,9 +155,10 @@ The detailed results can be seen in [results.md](results/results.md).
|0.75|MDF|33.43|75.47|51.33|72.21|28.52|73.96|
|0.75|ARPL|61.26|85.58|74.67|86.1|62.99|85.24|
|0.75|KNNCL (Last)|51.42|81.76|82.05|86.14|7.19|76.16|
|0.75|ADB|67.34|**86.44**|**84.85**|**88.97**|74.1|86.91|
|0.75|ADB|67.34|86.44|84.85|88.97|74.1|86.91|
|0.75|KNNCL (All)|67.01|82.56|84.31|85.11|72.81|**87.66**|
|0.75|DA-ADB|**69.37**|85.93|86.0|88.43|**74.55**|**87.66**|
|0.75|DA-ADB|69.37|85.93|86.0|88.43|**74.55**|**87.66**|
|0.75|DA-ADB_llama|**70.94**|**86.78**|**87.93**|**90.51**|72.84|87.17|

## Tutorials
### a. How to add a new dataset?
@@ -171,11 +178,11 @@ benchmark_labels = {

### b. How to add a new backbone?

1. Add a new backbone in the [backbones](./backbones) directory. For example, we provide some bert-based backbones in the [file](./backbones/bert.py).
1. Add a new backbone in the [backbones](./backbones) directory. For example, we provide some bert-based backbones in [bert.py](./backbones/bert.py) and llama-based backbones in [llama.py](./backbones/llama.py); you can add a new backbone in the same way (a minimal, illustrative sketch follows).
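
A minimal sketch of what such a backbone class can look like; the class name, embedding encoder, and hidden size below are illustrative placeholders, not repository code:

```
import torch
from torch import nn

class NEW_Backbone(nn.Module):
    """Illustrative backbone skeleton: encode inputs, return features or logits."""

    def __init__(self, args):
        super().__init__()
        self.encoder = nn.Embedding(30522, 768)            # stand-in for a pretrained encoder
        self.dense = nn.Linear(768, 768)
        self.classifier = nn.Linear(768, args.num_labels)

    def forward(self, input_ids, attention_mask=None, labels=None,
                feature_ext=False, mode=None, loss_fct=None):
        pooled = self.encoder(input_ids).mean(dim=1)        # mean-pool token representations
        features = torch.tanh(self.dense(pooled))
        if feature_ext:                                     # return features only, as the repo's backbones do
            return features
        logits = self.classifier(features)
        if mode == 'train':
            return loss_fct(logits, labels)
        return features, logits
```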

2. Add the new backbone mapping in the [file](./backbones/__init__.py) as follows:
```
from .bert import new_backbone_class
from .bert import new_backbone_class  # or, for a llama-based backbone: from .llama import new_backbone_class
backbones_map = {
'new_backbone': new_backbone_class
}
@@ -197,6 +204,14 @@ backbone_loader_map = {
}
```

The llama-based backbone is mapped to the llama dataloader in the same way, as follows (a short sketch of how both registries are resolved at run time is given after the code block).
```
from .llama_loader import LLAMA_Loader
backbone_loader_map = {
'llama_disaware': LLAMA_Loader,
}
```
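
For orientation, this is roughly how the two registries are consumed; the `backbones_map` lookup mirrors `set_model` in [base.py](./backbones/base.py), and the comments mark what is only assumed here:

```
# Illustrative only: resolve the registered classes by the shared backbone key.
from backbones import backbones_map
from dataloaders import backbone_loader_map

backbone_cls = backbones_map[args.backbone]        # e.g. 'llama_disaware' -> LLAMA_lora_Disaware
loader_cls = backbone_loader_map[args.backbone]    # e.g. 'llama_disaware' -> LLAMA_Loader (assumed usage)
```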

3. Add Methods (Take MSP as an example)
- Create a new directory named "MSP" in the [methods](./methods) directory.

4 changes: 3 additions & 1 deletion open_intent_detection/backbones/__init__.py
@@ -1,4 +1,5 @@
from .bert import BERT, BERT_Norm, BERT_K_1_way, BERT_SEG, BERT_Disaware, BERT_DOC, BERT_MDF, BERT_MDF_Pretrain, BERT_KNNCL
from .llama import LLAMA_lora_Disaware

backbones_map = {
'bert': BERT,
@@ -9,5 +10,6 @@
'bert_doc': BERT_DOC,
'bert_mdf': BERT_MDF,
'bert_mdf_pretrain': BERT_MDF_Pretrain,
'bert_knncl': BERT_KNNCL
'bert_knncl': BERT_KNNCL,
'llama_disaware': LLAMA_lora_Disaware,
}
2 changes: 1 addition & 1 deletion open_intent_detection/backbones/base.py
@@ -33,7 +33,7 @@ def set_model(self, args, pattern):
backbone = backbones_map[args.backbone]
args.device = self.device = torch.device('cuda:%d' % int(args.gpu_id) if torch.cuda.is_available() else 'cpu')

if pattern == 'bert':
if pattern == 'bert' or pattern == 'llama':
if hasattr(backbone, 'from_pretrained'):
model = backbone.from_pretrained('bert-base-uncased', args = args)
else:
100 changes: 100 additions & 0 deletions open_intent_detection/backbones/llama.py
@@ -0,0 +1,100 @@
import torch

from peft import (
LoraConfig,
get_peft_model,
)

from torch import nn
from transformers import AutoModelForCausalLM

from .bert import CosNorm_Classifier

activation_map = {'relu': nn.ReLU(), 'tanh': nn.Tanh()}

class LLAMA_lora_Disaware(nn.Module):
def __init__(self, args):
super().__init__()
self.num_labels = args.num_labels
self.llama = AutoModelForCausalLM.from_pretrained(
args.llama_model,
return_dict=True,
load_in_8bit=False,
device_map=args.device,
low_cpu_mem_usage=True,
torch_dtype=torch.float16,
)
self.llama.config.pad_token_id = 0 # unk
self.llama.config.bos_token_id = 1
self.llama.config.eos_token_id = 2
#self.llama.eval()
target_modules=[ "q_proj", "v_proj"]
config = LoraConfig(
r=4,
lora_alpha=8,
target_modules=target_modules,
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
)
print("lora", config)
self.llama = get_peft_model(self.llama, config)
hidden_dropout_prob = 0.1
hidden_size = self.llama.config.hidden_size
hidden_size_2 = hidden_size // 2
self.dense = nn.Linear(hidden_size, hidden_size).half()
self.activation = activation_map[args.activation]
self.dropout = nn.Dropout(hidden_dropout_prob).half()
self.dense = self.dense.to(args.device)
self.activation = self.activation.to(args.device)
self.dropout = self.dropout.to(args.device)
#self.init_weights()
self.cosnorm_classifier = CosNorm_Classifier(
hidden_size, args.num_labels, args.scale, args.device)


def forward(self, input_ids=None, token_type_ids=None, attention_mask=None, labels=None,
feature_ext=False, mode=None, loss_fct=None, centroids=None, dist_infos = None):
outputs = self.llama(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True )
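        # Take the final hidden layer of the (LoRA-wrapped) LLaMA model and mean-pool it over
        # the sequence dimension to obtain a single utterance-level feature vector.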
encoded_layer_ = outputs.hidden_states[-1].mean(dim=1)

#input_data = input_data.float()
pooled_output = self.dense(encoded_layer_)
pooled_output = self.activation(pooled_output)
pooled_output = self.dropout(pooled_output)

x = pooled_output

if feature_ext:
return pooled_output

else:

feat_size = x.shape[1]
batch_size = x.shape[0]

f_expand = x.unsqueeze(1).expand(-1, self.num_labels, -1)
centroids_expand = centroids.unsqueeze(0).expand(batch_size, -1, -1)
dist_cur = torch.norm(f_expand - centroids_expand, 2, 2)
values_nn, labels_nn = torch.sort(dist_cur, 1)

nearest_centers = centroids[labels_nn[:, 0]]
dist_denominator = torch.norm(x - nearest_centers, 2, 1)
second_nearest_centers = centroids[labels_nn[:, 1]]
dist_numerator = torch.norm(x - second_nearest_centers, 2, 1)

dist_info = dist_numerator - dist_denominator
dist_info = torch.exp(dist_info)
scalar = dist_info

reachability = scalar.unsqueeze(1).expand(-1, feat_size)
x = reachability * pooled_output

logits = self.cosnorm_classifier(x)

if mode == 'train':
loss = loss_fct(logits, labels)
return loss

elif mode == 'eval':
return pooled_output, logits
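
A hedged usage sketch of the new backbone for feature extraction; the argument field names mirror configs/DA-ADB_llama.py, while the tokenizer choice, dummy input, and Namespace values are illustrative assumptions rather than repository code:

```
import torch
from argparse import Namespace
from transformers import AutoTokenizer
from backbones.llama import LLAMA_lora_Disaware

# Illustrative run arguments; field names follow configs/DA-ADB_llama.py.
args = Namespace(num_labels=77, llama_model='/path/to/llama',
                 device='cuda:0', activation='relu', scale=4)

tokenizer = AutoTokenizer.from_pretrained(args.llama_model)
model = LLAMA_lora_Disaware(args)

batch = tokenizer(["how do I freeze my card?"], return_tensors='pt',
                  padding=True, truncation=True).to(args.device)

with torch.no_grad():
    # feature_ext=True returns the pooled utterance features before distance-aware scaling.
    feats = model(input_ids=batch['input_ids'],
                  attention_mask=batch['attention_mask'],
                  feature_ext=True)

print(feats.shape)  # (batch_size, hidden_size), e.g. (1, 4096)
```
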
47 changes: 47 additions & 0 deletions open_intent_detection/configs/DA-ADB_llama.py
@@ -0,0 +1,47 @@
class Param():

def __init__(self, args):

self.hyper_param = self.get_hyper_parameters(args)

def get_hyper_parameters(self, args):
"""
Args:
llama_model (directory): The path for the pre-trained llama model.
num_train_epochs (int): The number of training epochs.
num_labels (autofill): The output dimension.
max_seq_length (autofill): The maximum total input sequence length after tokenization. Sequences longer than this will be truncated; shorter sequences will be padded.
freeze_backbone_parameters (binary): Whether to freeze all parameters but the last layer.
feat_dim (int): The feature dimension.
warmup_proportion (float): The warmup ratio for learning rate.
scale (float): The scale factor of the cosine classifier.
lr_boundary (float): The learning rate of the decision boundary.
lr (float): The learning rate of the backbone.
activation (str): The activation function of the hidden layer (supports 'relu' and 'tanh').
train_batch_size (int): The batch size for training.
eval_batch_size (int): The batch size for evaluation.
test_batch_size (int): The batch size for testing.
wait_patient (int): Patience steps for early stopping.
"""
hyper_parameters = {

'llama_model': "/home/sharing/disk1/pretrained_embedding/llama/llama",
'num_train_epochs':100,
'num_labels': None,
'max_seq_length': None,
'freeze_backbone_parameters': False,
'feat_dim': 4096,
'warmup_proportion': 0.1,
'scale': 4,
'lr_boundary': 0.05,
'lr': 5e-8,
'activation': 'relu',
'train_batch_size': 32,
'eval_batch_size': 8,
'test_batch_size': 8,
'wait_patient': 10,

}
print("Hyper-parameters: ", hyper_parameters)

return hyper_parameters
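
A small, hypothetical sketch of how these hyper-parameters might be loaded and merged into the run arguments; only the Param class and its hyper_param attribute come from this file, and the dynamic import and merging loop are assumptions for illustration:

```
import importlib
from argparse import Namespace

args = Namespace(gpu_id='0')  # illustrative, partially filled run arguments

# The config file name contains a hyphen, so it is loaded dynamically here (assumed mechanism).
cfg = importlib.import_module('configs.DA-ADB_llama')
param = cfg.Param(args)

# Hypothetical merge: copy every hyper-parameter onto args for downstream code to read.
for key, value in param.hyper_param.items():
    setattr(args, key, value)

print(args.feat_dim, args.lr)  # 4096 5e-08
```
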
4 changes: 3 additions & 1 deletion open_intent_detection/dataloaders/__init__.py
@@ -1,4 +1,5 @@
from .bert_loader import BERT_Loader
from .llama_loader import LLAMA_Loader

max_seq_lengths = {
'stackoverflow':45,
@@ -16,7 +17,8 @@
'bert_seg': BERT_Loader,
'bert_disaware': BERT_Loader,
'bert_mdf': BERT_Loader,
'bert_knncl': BERT_Loader
'bert_knncl': BERT_Loader,
'llama_disaware': LLAMA_Loader,
}

benchmark_labels = {
Expand Down
Loading
