better cie data
LittlePea13 committed Jul 31, 2024
1 parent 89bb98c commit 2b86172
Showing 1 changed file with 63 additions and 33 deletions.
96 changes: 63 additions & 33 deletions relik/reader/data/relik_reader_re_data.py
@@ -337,6 +337,7 @@ def _init_relation_labels(self, entities_untyped, sample):
      relation_labels = torch.zeros((num_entities, num_entities, num_relations))

      if sample.window_triplet_labels_tokens is None:
+         relation_labels.fill_(-100)
          return relation_labels

      for relation in sample.window_triplet_labels_tokens:
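Note: filling the label tensor with -100 (rather than leaving zeros) marks windows without triplet annotations as "no supervision" instead of "all negatives". -100 is PyTorch's conventional ignore value (the default ignore_index of CrossEntropyLoss); assuming the reader's loss masks it accordingly, a minimal illustration of the convention:

    import torch

    # Targets equal to ignore_index (-100 by default) contribute nothing to the loss.
    labels = torch.full((4,), -100, dtype=torch.long)
    labels[0] = 2  # only one position is actually supervised
    logits = torch.randn(4, 5)
    loss = torch.nn.functional.cross_entropy(logits, labels, ignore_index=-100)
    # Identical to computing the loss on the single labeled position:
    assert torch.isclose(loss, torch.nn.functional.cross_entropy(logits[:1], labels[:1]))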
@@ -506,7 +507,7 @@ def generator():
          continue

      # add gold candidates if missing
-     if self.add_gold_candidates:
+     if self.add_gold_candidates and sample.window_triplet_labels_tokens:
          candidates_set = set(sample.triplet_candidates)
          candidates_to_add = set()
          for candidate_title in sample.window_triplet_labels_tokens:
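Note: the tightened condition relies on Python truthiness, so gold-candidate injection is skipped both when the label list is None and when it is empty, and the loop below can no longer iterate over None. A one-line illustration:

    for labels in (None, [], [{"relation": "born in"}]):
        print(bool(labels))  # False, False, True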
@@ -699,28 +700,34 @@ def flip_cands(flip_candidates, candidates):
      acceptable_tokens_from_candidates = (
          self.model_max_length - 20 - len(input_subwords)
      )

+     current_len = len(candidates_encoding_result[i]) + len(
+         candidates_encoding_result[i + len(sample.span_candidates)]
+     ) if len(self.special_symbols_types) > 0 else len(
+         candidates_encoding_result[i]
+     )
      while (
-         cum_len + len(candidates_encoding_result[i])
+         cum_len + current_len
          < acceptable_tokens_from_candidates
      ):
-         cum_len += len(candidates_encoding_result[i])
+         cum_len += current_len
          i += 1
+         if len(self.special_symbols_types) == 0 or i + len(sample.span_candidates) >= len(candidates_encoding_result):
+             current_len = len(candidates_encoding_result[i])
+         else:
+             current_len = len(candidates_encoding_result[i] + candidates_encoding_result[i + len(sample.span_candidates)])

      assert i > 0

-     candidates_encoding_result = candidates_encoding_result[:i]
+     candidates_encoding_result = candidates_encoding_result[:i] + candidates_encoding_result[len(sample.span_candidates):i]
      if len(self.special_symbols_types) > 0:
          candidates_symbols = candidates_symbols[
              : i - len(sample.span_candidates)
          ]
          sample.triplet_candidates = sample.triplet_candidates[
              : i - len(sample.span_candidates)
          ]
      else:
          candidates_symbols = candidates_symbols[:i]
          sample.triplet_candidates = sample.triplet_candidates[:i]
+     candidates_entities_symbols = candidates_entities_symbols[:i]
+     sample.span_candidates = sample.span_candidates[:i]

-     candidates_symbols = candidates_symbols[:i]
-     sample.triplet_candidates = sample.triplet_candidates[:i]
  else:
-     if len(sample.window_triplet_labels_tokens) == 0:
-         sample.window_triplet_labels_tokens = []
      gold_candidates_set = set(
          [wl["relation"] for wl in sample.window_triplet_labels_tokens]
      )
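Note on the truncation loop: in closed IE mode (special_symbols_types non-empty), candidates_encoding_result holds the entity candidate encodings followed by the relation candidate encodings, so each step charges the budget for the encoding at position i together with the one at i + len(sample.span_candidates), cutting both views at a consistent point. A minimal sketch of that idea, with illustrative names rather than the reader's actual API:

    def paired_budget_cutoff(lengths: list, num_entities: int, budget: int) -> int:
        """Return how many leading candidates fit when each one is charged
        together with its aligned partner stored num_entities positions
        later, mirroring the paired accounting above."""
        i, used = 0, 0
        while i < len(lengths):
            step = lengths[i]
            if i + num_entities < len(lengths):
                step += lengths[i + num_entities]
            if used + step >= budget:
                break
            used += step
            i += 1
        return i

    # e.g. entity encodings of lengths [5, 5] aligned with relations [7, 9]:
    print(paired_budget_cutoff([5, 5, 7, 9], num_entities=2, budget=13))  # 1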
@@ -733,11 +740,19 @@ def flip_cands(flip_candidates, candidates):
      gold_candidates_indices = [
          i + len(sample.span_candidates)
          for i in gold_candidates_indices
-     ]
-     # add entities indices
-     gold_candidates_indices = gold_candidates_indices + list(
-         range(len(sample.span_candidates))
-     )
+     ] + [len(sample.span_candidates)]
+     gold_candidates_set_entities = set(
+         [wl[2] for wl in sample.window_labels_tokens]
+     )
+     gold_candidates_indices += [
+         i
+         for i, wc in enumerate(sample.span_candidates)
+         if wc in gold_candidates_set_entities
+     ]
+     # # add entities indices
+     # gold_candidates_indices = gold_candidates_indices + list(
+     #     range(len(sample.span_candidates))
+     # )
      necessary_taken_tokens = sum(
          map(
              len,
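Note: gold relation indices are shifted by len(sample.span_candidates) because entity candidates occupy the front of the concatenated candidate list. The change also stops force-keeping every entity candidate (the commented-out range(...) lines) and instead keeps only span candidates that actually occur among the gold labels (wl[2] presumably being the entity label in window_labels_tokens). A small worked example with made-up candidates:

    span_candidates = ["Rome", "Italy", "Paris"]
    gold_entities = {"Rome", "Italy"}
    gold_relation_local = [0, 2]  # positions within the relation candidate list

    # relation candidates live after the entity candidates, hence the offset
    gold_indices = [i + len(span_candidates) for i in gold_relation_local]
    gold_indices += [i for i, c in enumerate(span_candidates) if c in gold_entities]
    print(sorted(gold_indices))  # [0, 1, 3, 5]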
@@ -757,7 +772,7 @@ def flip_cands(flip_candidates, candidates):
      if acceptable_tokens_from_candidates <= 0:
          logger.warning(
              "Sample {} has no candidates after truncation due to max length".format(
-                 sample.id
+                 sample.doc_id
              )
          )
          continue
@@ -784,10 +799,10 @@ def flip_cands(flip_candidates, candidates):
      if len(self.special_symbols_types) > 0:
          sample.triplet_candidates = [
              sample.triplet_candidates[i - len(sample.span_candidates)]
-             for i in new_indices[len(sample.span_candidates) :]
+             for i in new_indices[len(sample.span_candidates) :-1]
          ]
          candidates_symbols = candidates_symbols[
-             : i - len(sample.span_candidates)
+             : len(sample.triplet_candidates)
          ]
      else:
          candidates_symbols = [
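Note: slicing candidates_symbols by len(sample.triplet_candidates) keeps the special-symbol list in lockstep with the list it was just rebuilt from, instead of reusing the loop variable i, whose value is unrelated at this point; the new :-1 appears to drop one trailing index from new_indices, matching the extra index appended in the gold-candidates block above. A minimal illustration of the lockstep slice:

    triplet_candidates = ["rel_a", "rel_b"]
    candidates_symbols = ["[R-0]", "[R-1]", "[R-2]", "[R-3]"]
    candidates_symbols = candidates_symbols[: len(triplet_candidates)]
    assert len(candidates_symbols) == len(triplet_candidates)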
@@ -799,7 +814,7 @@ def flip_cands(flip_candidates, candidates):
      if len(sample.triplet_candidates) == 0:
          logger.warning(
              "Sample {} has no candidates after truncation due to max length".format(
-                 sample.sample_id
+                 sample.doc_id
              )
          )
          continue
@@ -838,6 +853,9 @@ def flip_cands(flip_candidates, candidates):
              sample,
              tokenization_output,
          )
+     elif not self.for_inference:
+         continue
+
      if self.materialize_samples:
          sample.materialize = {
              "tokenization_output": tokenization_output,
@@ -1172,18 +1190,30 @@ def convert_to_char_annotations(
          entities.append(entity)
      sample.predicted_entities = entities
  for triplet in sample.predicted_relations:
-     triplet["subject"][0] = sample.token2char_start[
-         str(sample.word2token_start[str(triplet["subject"][0])])
-     ]
-     triplet["subject"][1] = sample.token2char_end[
-         str(sample.word2token_end[str(triplet["subject"][1] - 1)])
-     ]
-     triplet["object"][0] = sample.token2char_start[
-         str(sample.word2token_start[str(triplet["object"][0])])
-     ]
-     triplet["object"][1] = sample.token2char_end[
-         str(sample.word2token_end[str(triplet["object"][1] - 1)])
-     ]
+     triplet["subject"] = (
+         sample.token2char_start[
+             str(sample.word2token_start[str(triplet["subject"][0])])
+         ],
+         sample.token2char_end[
+             str(sample.word2token_end[str(triplet["subject"][1] - 1)])
+         ],
+         triplet["subject"][2],
+     )
+     triplet["object"] = (
+         sample.token2char_start[
+             str(sample.word2token_start[str(triplet["object"][0])])
+         ],
+         sample.token2char_end[
+             str(sample.word2token_end[str(triplet["object"][1] - 1)])
+         ],
+         triplet["object"][2],
+     )
+     # triplet["object"][0] = sample.token2char_start[
+     #     str(sample.word2token_start[str(triplet["object"][0])])
+     # ]
+     # triplet["object"][1] = sample.token2char_end[
+     #     str(sample.word2token_end[str(triplet["object"][1] - 1)])
+     # ]

sample = RelikREDataset._new_output_format(sample)
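Note: predicted span boundaries are word-level and end-exclusive, so the conversion maps the start through word2token_start/token2char_start and the end through word2token_end/token2char_end at index end - 1 (the last word inside the span); the result is now stored as an immutable (char_start, char_end, label) tuple instead of mutating the list in place. A simplified illustration (plain int-keyed dicts instead of the reader's str-keyed maps):

    word2token_start = {0: 0, 1: 2}
    word2token_end = {0: 1, 1: 3}
    token2char_start = {0: 0, 2: 6}
    token2char_end = {1: 5, 3: 11}

    span = (1, 2, "PER")  # word-level, end-exclusive
    char_span = (
        token2char_start[word2token_start[span[0]]],
        token2char_end[word2token_end[span[1] - 1]],  # end - 1: last word in the span
        span[2],
    )
    print(char_span)  # (6, 11, 'PER')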
