Skip to content

Commit

Permalink
Merge pull request #282 from yangheng95/dev
Browse files Browse the repository at this point in the history
2.1.12
  • Loading branch information
yangheng95 authored Mar 18, 2023
2 parents 14e32b7 + 89b6d5b commit 3d6cc4a
Show file tree
Hide file tree
Showing 6 changed files with 100 additions and 67 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
from pyabsa import ABSAInstruction

if __name__ == "__main__":
generator = ABSAInstruction.ABSAGenerator("multilingual")
generator = ABSAInstruction.ABSAGenerator(
"checkpoints/multitask/googleflan-t5-base-instruction/checkpoint-2745"
)
example = [
"The food is good, but the service is bad.",
"The laptop is good, but the battery life is bad.",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,18 @@

import findfile
from pyabsa import ABSAInstruction as absa_instruction

warnings.filterwarnings("ignore")
import pandas as pd


task_name = "multitask"
experiment_name = "instruction"
# model_checkpoint = 'allenai/tk-instruct-base-def-pos'
model_checkpoint = "kevinscaria/ate_tk-instruct-base-def-pos-neg-neut-combined"
# model_checkpoint = "kevinscaria/ate_tk-instruct-base-def-pos-neg-neut-combined"
# model_checkpoint = 'allenai/tk-instruct-large-def-pos'
# model_checkpoint = 'allenai/tk-instruct-3b-def-pos'
# model_checkpoint = 'google/mt5-base'
model_checkpoint = "google/flan-t5-base"

print("Experiment Name: ", experiment_name)
model_out_path = "checkpoints"
Expand All @@ -33,12 +34,12 @@
# Load the data
# id_train_file_path = './integrated_datasets'
# id_test_file_path = './integrated_datasets'
# id_train_file_path = "./integrated_datasets/acos_datasets/"
# id_test_file_path = "./integrated_datasets/acos_datasets"
id_train_file_path = './integrated_datasets/acos_datasets/501.Laptop14'
id_test_file_path = './integrated_datasets/acos_datasets/501.Laptop14'
# id_train_file_path = './integrated_datasets/acos_datasets/504.Restaurant16'
# id_test_file_path = './integrated_datasets/acos_datasets/504.Restaurant16'
id_train_file_path = "./integrated_datasets/acos_datasets/"
id_test_file_path = "./integrated_datasets/acos_datasets"
# id_train_file_path = './integrated_datasets/acos_datasets/501.Laptop14'
# id_test_file_path = './integrated_datasets/acos_datasets/501.Laptop14'
# id_train_file_path = './integrated_datasets/acos_datasets/502.Restaurant14'
# id_test_file_path = './integrated_datasets/acos_datasets/502.Restaurant14'


id_tr_df = absa_instruction.data_utils.read_json(id_train_file_path, "train")
Expand Down Expand Up @@ -72,9 +73,9 @@
"evaluation_strategy": "epoch",
"save_strategy": "epoch",
"learning_rate": 5e-5,
"per_device_train_batch_size": 4,
"per_device_train_batch_size": 16,
"per_device_eval_batch_size": 16,
"num_train_epochs": 6,
"num_train_epochs": 3,
"weight_decay": 0.01,
"warmup_ratio": 0.1,
"load_best_model_at_end": True,
Expand Down
2 changes: 1 addition & 1 deletion pyabsa/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
# Copyright (C) 2021. All Rights Reserved.

__name__ = "pyabsa"
__version__ = "2.1.11"
__version__ = "2.1.12"

from pyabsa.framework.flag_class import *

Expand Down
61 changes: 36 additions & 25 deletions pyabsa/tasks/ABSAInstruction/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,57 +48,67 @@ def prepare_instruction_dataloader(self, df):
cat_instructor = CategoryInstruction()
alldata = []
for i, data in df.iterrows():
_aspects = [label["aspect"] for label in data["labels"]]
_aspects = ["aspect:" + label["aspect"] for label in data["labels"]]
aspects = []
for asp in _aspects:
if asp.strip() not in aspects:
aspects.append(asp.strip())
aspects = ", ".join(aspects)
alldata.append(
{"text": ate_instructor.prepare_input(data["text"]), "labels": aspects}
)
aspects = "|".join(aspects)

opinions = ", ".join(
polarities = []
_polarities = [
"{}:{}".format(label["aspect"], label["polarity"])
for label in data["labels"]
]
for pol in _polarities:
if pol not in polarities:
polarities.append(pol)
polarities = "|".join(polarities)

opinions = "|".join(
[
"{}:{}".format(label["aspect"], label["opinion"])
for label in data["labels"]
]
)
alldata.append(
{
"text": op_instructor.prepare_input(data["text"], aspects),
"labels": opinions,
}
)

polarities = ", ".join(
categories = "|".join(
[
"{}:{}".format(label["aspect"], label["polarity"])
"{}:{}".format(label["aspect"], label["category"])
for label in data["labels"]
]
)

# ATE task
alldata.append(
{"text": ate_instructor.prepare_input(data["text"]), "labels": aspects}
)

# APC task
alldata.append(
{
"text": apc_instructor.prepare_input(data["text"], aspects),
"labels": polarities,
}
)

categories = ", ".join(
[
"{}:{}".format(
label["aspect"], label["category"].replace("NULL", "")
)
for label in data["labels"]
]
)
# Opinion task
alldata.append(
{
"text": cat_instructor.prepare_input(data["text"], aspects),
"labels": categories,
"text": op_instructor.prepare_input(data["text"], aspects),
"labels": opinions,
}
)
# print(alldata[-1]['labels'])

# Category task
if "NULL" not in categories:
alldata.append(
{
"text": cat_instructor.prepare_input(data["text"], aspects),
"labels": categories,
}
)

alldata = pd.DataFrame(alldata)
return alldata

Expand Down Expand Up @@ -163,6 +173,7 @@ def read_json(data_path, data_type="train"):

files = findfile.find_files(data_path, [data_type, ".jsonl"], exclude_key=[".txt"])
for f in files:
print(f)
with open(f, "r", encoding="utf8") as fin:
for line in fin:
data.append(json.loads(line))
Expand Down
16 changes: 8 additions & 8 deletions pyabsa/tasks/ABSAInstruction/instruction.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,12 @@ def __init__(self, bos_instruction=None, eos_instruction=None):
example 1-
input: I charge it at night and skip taking the cord with me because of the good battery life.
{self.eos_instruction}
battery life, cord
aspect:battery life|aspect:cord
example 2-
input: Great food, good size menu, great service and an unpretensious setting.
{self.eos_instruction}
food, menu, service, setting
aspect:food|aspect:menu|aspect:service|aspect:setting
Now extract aspects from the following example:
input: """
Expand Down Expand Up @@ -64,13 +64,13 @@ def __init__(self, bos_instruction=None, eos_instruction=None):
input: I charge it at night and skip taking the cord with me because of the good battery life.
The aspects are: battery life, cord
{self.eos_instruction}
battery life:positive, cord:positive
battery life:positive|cord:positive
example 2-
input: Great food, good size menu, great service and an unpretensious setting.
The aspects are: food, menu, service, setting
{self.eos_instruction}
food:positive, menu:positive, service:positive, setting:positive
food:positive|menu:positive|service:positive|setting:positive
Now predict aspect sentiments from the following example:
Expand Down Expand Up @@ -103,13 +103,13 @@ def __init__(self, bos_instruction=None, eos_instruction=None):
input: I charge it at night and skip taking the cord with me because of the good battery life.
The aspects are: battery life, cord
{self.eos_instruction}
battery life:good, cord:NULL
battery life:good|cord:NULL
example 2-
input: Great food, good size menu, great service and an unpretensious setting.
The aspects are: food, menu, service, setting
{self.eos_instruction}
food:great, menu:good, service:great, setting:unpretensious
food:great|menu:good|service:great|setting:unpretensious
Now extract opinions for the following example:
input:"""
Expand Down Expand Up @@ -141,11 +141,11 @@ def __init__(self, bos_instruction=None, eos_instruction=None):
input: I charge it at night and skip taking the cord with me because of the good battery life.
The aspects are: battery life, cord
{self.eos_instruction}
battery life:POWER_SUPPLY#GENERAL, cord:NULL
battery life:POWER_SUPPLY#GENERAL|cord:NULL
example 2-
input: Great food, good size menu, great service and an unpretensious setting.
The aspects are: food, menu, service, setting
The aspects are: food:FOOD#QUALITY| menu:RESTAURANT#GENERAL|service:SERVICE#GENERAL|setting:SERVICE#GENERAL
{self.eos_instruction}
food:FOOD#QUALITY, menu:RESTAURANT#GENERAL, service:SERVICE#GENERAL, setting:SERVICE#GENERAL
Expand Down
63 changes: 41 additions & 22 deletions pyabsa/tasks/ABSAInstruction/model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import autocuda
import sklearn
import torch
from pyabsa.framework.checkpoint_class.checkpoint_template import CheckpointManager
from torch.utils.data import DataLoader
Expand Down Expand Up @@ -32,6 +33,7 @@ def __init__(self, checkpoint):

self.tokenizer = AutoTokenizer.from_pretrained(checkpoint)
self.model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
self.model.config.max_length = 128
self.data_collator = DataCollatorForSeq2Seq(self.tokenizer)
self.device = autocuda.auto_cuda()
self.model.to(self.device)
Expand Down Expand Up @@ -94,7 +96,7 @@ def predict(self, text, **kwargs):
ate_outputs = self.tokenizer.batch_decode(
ate_outputs, skip_special_tokens=True
)[0]
result["aspect"] = [asp.strip() for asp in ate_outputs.split(",")]
result["aspect"] = [asp.strip() for asp in ate_outputs.split("|")]

# APC inference
inputs = self.tokenizer(
Expand All @@ -106,7 +108,7 @@ def predict(self, text, **kwargs):
apc_outputs = self.tokenizer.batch_decode(
apc_outputs, skip_special_tokens=True
)[0]
result["sentiment"] = [sent.strip() for sent in apc_outputs.split(",")]
result["sentiment"] = [sent.strip() for sent in apc_outputs.split("|")]

# Opinion inference
inputs = self.tokenizer(
Expand All @@ -118,7 +120,7 @@ def predict(self, text, **kwargs):
op_outputs = self.tokenizer.batch_decode(op_outputs, skip_special_tokens=True)[
0
]
result["opinion"] = [op.strip() for op in op_outputs.split(",")]
result["opinion"] = [op.strip() for op in op_outputs.split("|")]

# Category inference
inputs = self.tokenizer(
Expand All @@ -130,7 +132,7 @@ def predict(self, text, **kwargs):
cat_outputs = self.tokenizer.batch_decode(
cat_outputs, skip_special_tokens=True
)[0]
result["category"] = [cat.strip() for cat in cat_outputs.split(",")]
result["category"] = [cat.strip() for cat in cat_outputs.split("|")]
ensemble_result = {
"text": text,
"Quadruples": [
Expand Down Expand Up @@ -207,26 +209,43 @@ def get_aspect_metrics(self, true_aspects, pred_aspects):
return aspect_p, aspect_r, aspect_f1

def get_classic_metrics(self, y_true, y_pred):
total_pred = 0
total_gt = 0
tp = 1e-6
valid_gts = []
valid_preds = []
for gt, pred in zip(y_true, y_pred):
print(gt)
print(pred)

gt_list = gt.split(", ")
pred_list = pred.split(", ")
total_pred += len(pred_list)
total_gt += len(gt_list)
for gt_val in gt_list:
gt_list = gt.split("|")
pred_list = pred.split("|")
while gt_list:
gt_val = gt_list[-1].strip().lower()
for pred_val in pred_list:
gt_val = gt_val.replace(" ", "")
pred_val = pred_val.replace(" ", "")
if pred_val.strip().lower() == gt_val.strip().lower():
tp += 1
p = tp / total_pred
r = tp / total_gt
return {"precision": p, "recall": r, "f1": 2 * p * r / (p + r)}
pred_val = pred_val.strip().lower()
gt_key, _, gt_label = gt_val.partition(":")
pred_key, _, pred_label = pred_val.partition(":")
if gt_key.startswith(pred_key):
if gt_label:
valid_gts.append(gt_label)
else:
break
if pred_label:
valid_preds.append(pred_label)
else:
valid_preds.append("")
break

gt_list.pop()

report = sklearn.metrics.classification_report(valid_gts, valid_preds)
print(report)
accuracy = sklearn.metrics.accuracy_score(valid_gts, valid_preds)
precision = precision_score(valid_gts, valid_preds, average="macro")
recall = recall_score(valid_gts, valid_preds, average="macro")
f1 = f1_score(valid_gts, valid_preds, average="macro")

return {
"accuracy": accuracy,
"precision": precision,
"recall": recall,
"f1": f1,
}

# def get_classic_metrics(self, y_true, y_pred):
#
Expand Down

0 comments on commit 3d6cc4a

Please sign in to comment.