diff --git a/relik/reader/data/relik_reader_re_data.py b/relik/reader/data/relik_reader_re_data.py index 1c80bb3..3df5136 100644 --- a/relik/reader/data/relik_reader_re_data.py +++ b/relik/reader/data/relik_reader_re_data.py @@ -337,6 +337,7 @@ def _init_relation_labels(self, entities_untyped, sample): relation_labels = torch.zeros((num_entities, num_entities, num_relations)) if sample.window_triplet_labels_tokens is None: + relation_labels.fill_(-100) return relation_labels for relation in sample.window_triplet_labels_tokens: @@ -506,7 +507,7 @@ def generator(): continue # add gold candidates if missing - if self.add_gold_candidates: + if self.add_gold_candidates and sample.window_triplet_labels_tokens: candidates_set = set(sample.triplet_candidates) candidates_to_add = set() for candidate_title in sample.window_triplet_labels_tokens: @@ -699,28 +700,34 @@ def flip_cands(flip_candidates, candidates): acceptable_tokens_from_candidates = ( self.model_max_length - 20 - len(input_subwords) ) - + current_len = len(candidates_encoding_result[i]) + len( + candidates_encoding_result[i + len(sample.span_candidates)] + ) if len(self.special_symbols_types) > 0 else len( + candidates_encoding_result[i] + ) while ( - cum_len + len(candidates_encoding_result[i]) + cum_len + current_len < acceptable_tokens_from_candidates ): - cum_len += len(candidates_encoding_result[i]) + cum_len += current_len i += 1 + if len(self.special_symbols_types) == 0 or i + len(sample.span_candidates) >= len(candidates_encoding_result): + current_len = len(candidates_encoding_result[i]) + else: + current_len = len(candidates_encoding_result[i] + candidates_encoding_result[i + len(sample.span_candidates)]) assert i > 0 - candidates_encoding_result = candidates_encoding_result[:i] + candidates_encoding_result = candidates_encoding_result[:i] + candidates_encoding_result[len(sample.span_candidates):i] if len(self.special_symbols_types) > 0: - candidates_symbols = candidates_symbols[ - : i - len(sample.span_candidates) - ] - sample.triplet_candidates = sample.triplet_candidates[ - : i - len(sample.span_candidates) - ] - else: - candidates_symbols = candidates_symbols[:i] - sample.triplet_candidates = sample.triplet_candidates[:i] + candidates_entities_symbols = candidates_entities_symbols[:i] + sample.span_candidates = sample.span_candidates[:i] + + candidates_symbols = candidates_symbols[:i] + sample.triplet_candidates = sample.triplet_candidates[:i] else: + if len(sample.window_triplet_labels_tokens) == 0: + sample.window_triplet_labels_tokens = [] gold_candidates_set = set( [wl["relation"] for wl in sample.window_triplet_labels_tokens] ) @@ -733,11 +740,19 @@ def flip_cands(flip_candidates, candidates): gold_candidates_indices = [ i + len(sample.span_candidates) for i in gold_candidates_indices - ] - # add entities indices - gold_candidates_indices = gold_candidates_indices + list( - range(len(sample.span_candidates)) + ] + [len(sample.span_candidates)] + gold_candidates_set_entities = set( + [wl[2] for wl in sample.window_labels_tokens] ) + gold_candidates_indices += [ + i + for i, wc in enumerate(sample.span_candidates) + if wc in gold_candidates_set_entities + ] + # # add entities indices + # gold_candidates_indices = gold_candidates_indices + list( + # range(len(sample.span_candidates)) + # ) necessary_taken_tokens = sum( map( len, @@ -757,7 +772,7 @@ def flip_cands(flip_candidates, candidates): if acceptable_tokens_from_candidates <= 0: logger.warning( "Sample {} has no candidates after truncation due to max length".format( - sample.id + sample.doc_id ) ) continue @@ -784,10 +799,10 @@ def flip_cands(flip_candidates, candidates): if len(self.special_symbols_types) > 0: sample.triplet_candidates = [ sample.triplet_candidates[i - len(sample.span_candidates)] - for i in new_indices[len(sample.span_candidates) :] + for i in new_indices[len(sample.span_candidates) :-1] ] candidates_symbols = candidates_symbols[ - : i - len(sample.span_candidates) + : len(sample.triplet_candidates) ] else: candidates_symbols = [ @@ -799,7 +814,7 @@ def flip_cands(flip_candidates, candidates): if len(sample.triplet_candidates) == 0: logger.warning( "Sample {} has no candidates after truncation due to max length".format( - sample.sample_id + sample.doc_id ) ) continue @@ -838,6 +853,9 @@ def flip_cands(flip_candidates, candidates): sample, tokenization_output, ) + elif not self.for_inference: + continue + if self.materialize_samples: sample.materialize = { "tokenization_output": tokenization_output, @@ -1172,18 +1190,30 @@ def convert_to_char_annotations( entities.append(entity) sample.predicted_entities = entities for triplet in sample.predicted_relations: - triplet["subject"][0] = sample.token2char_start[ + triplet["subject"] = ( + sample.token2char_start[ str(sample.word2token_start[str(triplet["subject"][0])]) - ] - triplet["subject"][1] = sample.token2char_end[ - str(sample.word2token_end[str(triplet["subject"][1] - 1)]) - ] - triplet["object"][0] = sample.token2char_start[ - str(sample.word2token_start[str(triplet["object"][0])]) - ] - triplet["object"][1] = sample.token2char_end[ - str(sample.word2token_end[str(triplet["object"][1] - 1)]) - ] + ], + sample.token2char_end[ + str(sample.word2token_end[str(triplet["subject"][1] - 1)]) + ], + triplet["subject"][2], + ) + triplet["object"] = ( + sample.token2char_start[ + str(sample.word2token_start[str(triplet["object"][0])]) + ], + sample.token2char_end[ + str(sample.word2token_end[str(triplet["object"][1] - 1)]) + ], + triplet["object"][2], + ) + # triplet["object"][0] = sample.token2char_start[ + # str(sample.word2token_start[str(triplet["object"][0])]) + # ] + # triplet["object"][1] = sample.token2char_end[ + # str(sample.word2token_end[str(triplet["object"][1] - 1)]) + # ] sample = RelikREDataset._new_output_format(sample)