take care of "NA" token and add more metrics
Riroaki authored May 25, 2020
1 parent 9c34757 commit a99c3dd
Showing 1 changed file with 121 additions and 60 deletions.
181 changes: 121 additions & 60 deletions code/eval_tacred.py
@@ -37,9 +37,9 @@
from knowledge_bert.optimization import BertAdam
from knowledge_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE

logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt = '%m/%d/%Y %H:%M:%S',
level = logging.INFO)
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt='%m/%d/%Y %H:%M:%S',
level=logging.INFO)
logger = logging.getLogger(__name__)


@@ -90,12 +90,13 @@ def get_dev_examples(self, data_dir):
def get_labels(self):
"""Gets the list of labels for this data set."""
raise NotImplementedError()

@classmethod
def _read_json(cls, input_file):
with open(input_file, "r", encoding='utf-8') as f:
return json.loads(f.read())


class TacredProcessor(DataProcessor):
"""Processor for the CoLA data set (GLUE version)."""

@@ -116,7 +117,6 @@ def get_test_examples(self, data_dir):
return self._create_examples(
self._read_json(os.path.join(data_dir, "test.json")), "dev")


def get_labels(self):
"""Useless"""
return ["0", "1"]
@@ -129,7 +129,7 @@ def _create_examples(self, lines, set_type):
for x in line['ents']:
if x[1] == 1:
x[1] = 0
#print(line['text'][x[1]:x[2]].encode("utf-8"))
# print(line['text'][x[1]:x[2]].encode("utf-8"))
text_a = (line['text'], line['ents'])
label = line['label']
examples.append(
@@ -139,9 +139,9 @@ def _create_examples(self, lines, set_type):

def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer, threshold):
"""Loads a data file into a list of `InputBatch`s."""

label_list = sorted(label_list)
label_map = {label : i for i, label in enumerate(label_list)}
label_map = {label: i for i, label in enumerate(label_list)}

entity2id = {}
with open("kg_embed/entity2id.txt") as fin:
@@ -157,11 +157,13 @@ def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer
h_name = ex_text_a[h[1]:h[2]]
t_name = ex_text_a[t[1]:t[2]]
if h[1] < t[1]:
ex_text_a = ex_text_a[:h[1]] + "# "+h_name+" #" + ex_text_a[h[2]:t[1]] + "$ "+t_name+" $" + ex_text_a[t[2]:]
ex_text_a = ex_text_a[:h[1]] + "# "+h_name+" #" + \
ex_text_a[h[2]:t[1]] + "$ "+t_name+" $" + ex_text_a[t[2]:]
else:
ex_text_a = ex_text_a[:t[1]] + "$ "+t_name+" $" + ex_text_a[t[2]:h[1]] + "# "+h_name+" #" + ex_text_a[h[2]:]
ex_text_a = ex_text_a[:t[1]] + "$ "+t_name+" $" + \
ex_text_a[t[2]:h[1]] + "# "+h_name+" #" + ex_text_a[h[2]:]

ent_pos = [x for x in example.text_b if x[-1]>threshold]
ent_pos = [x for x in example.text_b if x[-1] > threshold]
for x in ent_pos:
cnt = 0
if x[1] > h[2]:
@@ -178,11 +180,13 @@ def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer

tokens_b = None
if False:
tokens_b, entities_b = tokenizer.tokenize(example.text_b[0], [x for x in example.text_b[1] if x[-1]>threshold])
tokens_b, entities_b = tokenizer.tokenize(
example.text_b[0], [x for x in example.text_b[1] if x[-1] > threshold])
# Modifies `tokens_a` and `tokens_b` in place so that the total
# length is less than the specified length.
# Account for [CLS], [SEP], [SEP] with "- 3"
_truncate_seq_pair(tokens_a, tokens_b, entities_a, entities_b, max_seq_length - 3)
_truncate_seq_pair(tokens_a, tokens_b, entities_a,
entities_b, max_seq_length - 3)
else:
# Account for [CLS] and [SEP] with "- 2"
if len(tokens_a) > max_seq_length - 2:
@@ -252,22 +256,24 @@ def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer
logger.info("*** Example ***")
logger.info("guid: %s" % (example.guid))
logger.info("tokens: %s" % " ".join(
[str(x) for x in tokens]))
[str(x) for x in tokens]))
logger.info("ents: %s" % " ".join(
[str(x) for x in ents]))
logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
[str(x) for x in ents]))
logger.info("input_ids: %s" %
" ".join([str(x) for x in input_ids]))
logger.info("input_mask: %s" %
" ".join([str(x) for x in input_mask]))
logger.info(
"segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
"segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
logger.info("label: %s (id = %d)" % (example.label, label_id))

features.append(
InputFeatures(input_ids=input_ids,
input_mask=input_mask,
segment_ids=segment_ids,
input_ent=input_ent,
ent_mask=ent_mask,
label_id=label_id))
InputFeatures(input_ids=input_ids,
input_mask=input_mask,
segment_ids=segment_ids,
input_ent=input_ent,
ent_mask=ent_mask,
label_id=label_id))
return features


@@ -289,19 +295,51 @@ def _truncate_seq_pair(tokens_a, tokens_b, ents_a, ents_b, max_length):
tokens_b.pop()
ents_b.pop()

def accuracy(out, labels):
outputs = np.argmax(out, axis=1)
return np.sum(outputs == labels), outputs

def warmup_linear(x, warmup=0.002):
if x < warmup:
return x/warmup
return 1.0


def eval_result(pred_result, labels, na_id):
correct = 0
total = len(labels)
correct_positive = 0
pred_positive = 0
gold_positive = 0

for i in range(total):
if labels[i] == pred_result[i]:
correct += 1
if labels[i] != na_id:
correct_positive += 1
if labels[i] != na_id:
gold_positive += 1
if pred_result[i] != na_id:
pred_positive += 1
acc = float(correct) / float(total)
try:
micro_p = float(correct_positive) / float(pred_positive)
except:
micro_p = 0
try:
micro_r = float(correct_positive) / float(gold_positive)
except:
micro_r = 0
try:
micro_f1 = 2 * micro_p * micro_r / (micro_p + micro_r)
except:
micro_f1 = 0
result = {'acc': acc, 'micro_p': micro_p,
'micro_r': micro_r, 'micro_f1': micro_f1}
return result


def main():
parser = argparse.ArgumentParser()

## Required parameters
# Required parameters
parser.add_argument("--data_dir",
default=None,
type=str,
@@ -315,7 +353,7 @@ def main():
required=True,
help="The output directory where the model predictions and checkpoints will be written.")

## Other parameters
# Other parameters
parser.add_argument("--max_seq_length",
default=128,
type=int,
@@ -389,7 +427,8 @@
num_labels_task = 42

if args.local_rank == -1 or args.no_cuda:
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
device = torch.device(
"cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
n_gpu = torch.cuda.device_count()
else:
torch.cuda.set_device(args.local_rank)
@@ -402,7 +441,7 @@ def main():

if args.gradient_accumulation_steps < 1:
raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
args.gradient_accumulation_steps))
args.gradient_accumulation_steps))

random.seed(args.seed)
np.random.seed(args.seed)
@@ -411,13 +450,15 @@ def main():
torch.cuda.manual_seed_all(args.seed)

if not args.do_train and not args.do_eval:
raise ValueError("At least one of `do_train` or `do_eval` must be True.")
raise ValueError(
"At least one of `do_train` or `do_eval` must be True.")

processor = processors()
num_labels = num_labels_task
label_list = None

tokenizer = BertTokenizer.from_pretrained(args.ernie_model, do_lower_case=args.do_lower_case)
tokenizer = BertTokenizer.from_pretrained(
args.ernie_model, do_lower_case=args.do_lower_case)

train_examples = None
num_train_steps = None
@@ -429,8 +470,11 @@ def main():
for line in fin:
vec = line.strip().split('\t')
vec = [float(x) for x in vec]
if len(vec) != 100:
diff = 100 - len(vec)
vec = vec + [0 for _ in range(diff)]
vecs.append(vec)
embed = torch.FloatTensor(vecs)
embed = torch.tensor(vecs, dtype=torch.float)
embed = torch.nn.Embedding.from_pretrained(embed)

logger.info("Shape of entity embedding: "+str(embed.weight.size()))
@@ -451,22 +495,26 @@ def main():
test = convert_examples_to_features(
eval_examples, label_list, args.max_seq_length, tokenizer, args.threshold)


for x, mark in file_mark:
print(x, mark)
output_model_file = os.path.join(args.output_dir, x)
model_state_dict = torch.load(output_model_file)
model, _ = BertForSequenceClassification.from_pretrained(args.ernie_model, state_dict=model_state_dict, num_labels=len(label_list))
model, _ = BertForSequenceClassification.from_pretrained(
args.ernie_model, state_dict=model_state_dict, num_labels=len(label_list))
model.to(device)

if mark:
eval_features = dev
output_file = os.path.join(args.output_dir, "eval_pred_{}.txt".format(x.split("_")[-1]))
output_file_ = os.path.join(args.output_dir, "eval_gold_{}.txt".format(x.split("_")[-1]))
output_file = os.path.join(
args.output_dir, "eval_pred_{}.txt".format(x.split("_")[-1]))
output_file_ = os.path.join(
args.output_dir, "eval_gold_{}.txt".format(x.split("_")[-1]))
else:
eval_features = test
output_file = os.path.join(args.output_dir, "test_pred_{}.txt".format(x.split("_")[-1]))
output_file_ = os.path.join(args.output_dir, "test_gold_{}.txt".format(x.split("_")[-1]))
output_file = os.path.join(
args.output_dir, "test_pred_{}.txt".format(x.split("_")[-1]))
output_file_ = os.path.join(
args.output_dir, "test_gold_{}.txt".format(x.split("_")[-1]))
fpred = open(output_file, "w")
fgold = open(output_file_, "w")

@@ -476,22 +524,31 @@ def main():
# zeros = [0 for _ in range(args.max_seq_length)]
# zeros_ent = [0 for _ in range(100)]
# zeros_ent = [zeros_ent for _ in range(args.max_seq_length)]
all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
all_ent = torch.tensor([f.input_ent for f in eval_features], dtype=torch.long)
all_ent_masks = torch.tensor([f.ent_mask for f in eval_features], dtype=torch.long)
eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_ent, all_ent_masks, all_label_ids)
all_input_ids = torch.tensor(
[f.input_ids for f in eval_features], dtype=torch.long)
all_input_mask = torch.tensor(
[f.input_mask for f in eval_features], dtype=torch.long)
all_segment_ids = torch.tensor(
[f.segment_ids for f in eval_features], dtype=torch.long)
all_label_ids = torch.tensor(
[f.label_id for f in eval_features], dtype=torch.long)
all_ent = torch.tensor(
[f.input_ent for f in eval_features], dtype=torch.long)
all_ent_masks = torch.tensor(
[f.ent_mask for f in eval_features], dtype=torch.long)
eval_data = TensorDataset(all_input_ids, all_input_mask,
all_segment_ids, all_ent, all_ent_masks, all_label_ids)
# Run prediction for full data
eval_sampler = SequentialSampler(eval_data)
eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)
eval_dataloader = DataLoader(
eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)

model.eval()
eval_loss, eval_accuracy = 0, 0
eval_loss = 0
nb_eval_steps, nb_eval_examples = 0, 0
pred_all, label_all = [], []
for input_ids, input_mask, segment_ids, input_ent, ent_mask, label_ids in eval_dataloader:
input_ent = embed(input_ent+1) # -1 -> 0
input_ent = embed(input_ent+1) # -1 -> 0
input_ids = input_ids.to(device)
input_mask = input_mask.to(device)
segment_ids = segment_ids.to(device)
@@ -500,38 +557,42 @@ def main():
label_ids = label_ids.to(device)

with torch.no_grad():
tmp_eval_loss = model(input_ids, segment_ids, input_mask, input_ent, ent_mask, label_ids)
logits = model(input_ids, segment_ids, input_mask, input_ent, ent_mask)
tmp_eval_loss = model(
input_ids, segment_ids, input_mask, input_ent, ent_mask, label_ids)
logits = model(input_ids, segment_ids,
input_mask, input_ent, ent_mask)

logits = logits.detach().cpu().numpy()
label_ids = label_ids.to('cpu').numpy()
tmp_eval_accuracy, pred = accuracy(logits, label_ids)
pred = np.argmax(logits, axis=1)
for a, b in zip(pred, label_ids):
pred_all.append(a)
label_all.append(b)
fgold.write("{}\n".format(label_list[b]))
fpred.write("{}\n".format(label_list[a]))

eval_loss += tmp_eval_loss.mean().item()
eval_accuracy += tmp_eval_accuracy

nb_eval_examples += input_ids.size(0)
nb_eval_steps += 1

eval_loss = eval_loss / nb_eval_steps
eval_accuracy = eval_accuracy / nb_eval_examples

result = {'eval_loss': eval_loss,
'eval_accuracy': eval_accuracy
}
result = eval_result(pred_all, label_all, label_list.index("NA"))

if mark:
output_eval_file = os.path.join(args.output_dir, "eval_results_{}.txt".format(x.split("_")[-1]))
output_eval_file = os.path.join(
args.output_dir, "eval_results_{}.txt".format(x.split("_")[-1]))
else:
output_eval_file = os.path.join(args.output_dir, "test_results_{}.txt".format(x.split("_")[-1]))
output_eval_file = os.path.join(
args.output_dir, "test_results_{}.txt".format(x.split("_")[-1]))

with open(output_eval_file, "w") as writer:
logger.info("***** Eval results *****")
for key in sorted(result.keys()):
logger.info(" %s = %s", key, str(result[key]))
writer.write("%s = %s\n" % (key, str(result[key])))


if __name__ == "__main__":
main()
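
Note on the "more metrics" part of this commit: the new eval_result computes accuracy plus micro precision/recall/F1 that ignore the "NA" (no-relation) label when counting positives, and the final report is produced via label_list.index("NA") instead of plain accuracy. Below is a condensed standalone sketch of the same counting logic (guards written as conditionals rather than try/except), run on a toy label set that is purely illustrative; TACRED itself has 42 relation types.

def eval_result(pred_result, labels, na_id):
    """Accuracy plus micro P/R/F1 where the NA label never counts as a positive."""
    correct = 0
    correct_positive = 0   # correct predictions whose gold label is not NA
    pred_positive = 0      # predictions that are not NA
    gold_positive = 0      # gold labels that are not NA
    for pred, gold in zip(pred_result, labels):
        if pred == gold:
            correct += 1
            if gold != na_id:
                correct_positive += 1
        if gold != na_id:
            gold_positive += 1
        if pred != na_id:
            pred_positive += 1
    acc = correct / len(labels)
    micro_p = correct_positive / pred_positive if pred_positive else 0
    micro_r = correct_positive / gold_positive if gold_positive else 0
    micro_f1 = 2 * micro_p * micro_r / (micro_p + micro_r) if micro_p + micro_r else 0
    return {'acc': acc, 'micro_p': micro_p, 'micro_r': micro_r, 'micro_f1': micro_f1}

# Toy check: label ids 0..2, with 0 standing in for NA.
gold = [0, 1, 2, 1, 0]
pred = [0, 1, 1, 1, 2]
print(eval_result(pred, gold, na_id=0))
# acc = 3/5; micro_p = 2/4 (two of the four non-NA predictions are right);
# micro_r = 2/3 (two of the three non-NA gold labels are recovered); micro_f1 ≈ 0.571.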
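
And on the entity-embedding side: rows of kg_embed/entity2vec.vec that parse to fewer than 100 values are now zero-padded to the full 100 dimensions, and the matrix is built with torch.tensor(..., dtype=torch.float) before wrapping it in an nn.Embedding. A small self-contained sketch of that loading step, using an in-memory string where the script reads the real file (the two rows below are made up for illustration):

import io
import torch

# One short row (3 values) and one full 100-dim row; the script reads these
# from kg_embed/entity2vec.vec rather than a StringIO buffer.
raw = "0.1\t0.2\t0.3\n" + "\t".join(["0.0"] * 100) + "\n"

vecs = []
with io.StringIO(raw) as fin:
    for line in fin:
        vec = [float(x) for x in line.strip().split('\t')]
        if len(vec) != 100:
            vec = vec + [0.0] * (100 - len(vec))   # zero-pad short rows to 100 dims
        vecs.append(vec)

embed = torch.nn.Embedding.from_pretrained(torch.tensor(vecs, dtype=torch.float))
print(embed.weight.size())   # torch.Size([2, 100])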
