diff --git a/environment.yml b/environment.yml index 6c33716..bb92c7d 100644 --- a/environment.yml +++ b/environment.yml @@ -21,3 +21,4 @@ dependencies: - tensorflow==2.11.0 - protobuf==3.19.6 - filesplit + - googletrans==4.0.0rc1 diff --git a/requirements.txt b/requirements.txt index 83a4f84..df94458 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,3 +12,4 @@ ir_datasets filesplit flake8 pylint +googletrans diff --git a/src/dal/ds.py b/src/dal/ds.py index d432765..0360db5 100644 --- a/src/dal/ds.py +++ b/src/dal/ds.py @@ -197,14 +197,17 @@ def aggregate(cls, original, refined_query, changes, output): print(f'Saving original queries, better changes, and {metric} values based on {ranker} ...') with open(f'{output}/{ranker}.{metric}.agg.gold.tsv', mode='w', encoding='UTF-8') as agg_gold, \ open(f'{output}/{ranker}.{metric}.agg.all.tsv', mode='w', encoding='UTF-8') as agg_all, \ - open(f'{output}/{ranker}.{metric}.agg.platinum.tsv', mode='w', encoding='UTF-8') as agg_plat: - agg_gold.write(f'qid\torder\tquery\t{ranker}.{metric}\n') + open(f'{output}/{ranker}.{metric}.agg.platinum.tsv', mode='w', encoding='UTF-8') as agg_plat, \ + open(f'{output}/{ranker}.{metric}.neg.example.tsv', mode='w', encoding='UTF-8') as neg_exp: agg_all.write(f'qid\torder\tquery\t{ranker}.{metric}\n') + agg_gold.write(f'qid\torder\tquery\t{ranker}.{metric}\n') agg_plat.write(f'qid\torder\tquery\t{ranker}.{metric}\n') + neg_exp.write(f'qid\torder\tquery\t{ranker}.{metric}\n') for index, row in tqdm(original.iterrows(), total=original.shape[0]): + agg_all.write(f'{row.qid}\t-1\t{row.query}\t{row[f"original.{ranker}.{metric}"]}\n') agg_gold.write(f'{row.qid}\t-1\t{row.query}\t{row[f"original.{ranker}.{metric}"]}\n') agg_plat.write(f'{row.qid}\t-1\t{row.query}\t{row[f"original.{ranker}.{metric}"]}\n') - agg_all.write(f'{row.qid}\t-1\t{row.query}\t{row[f"original.{ranker}.{metric}"]}\n') + neg_exp.write(f'{row.qid}\t-1\t{row.query}\t{row[f"original.{ranker}.{metric}"]}\n') all = list() for change, metric_value in changes: all.append((row[change], row[f'{change}.{ranker}.{metric}'], change)) all = sorted(all, key=lambda x: x[1], reverse=True) @@ -213,6 +216,7 @@ def aggregate(cls, original, refined_query, changes, output): agg_all.write(f'{row.qid}\t{change}\t{query}\t{metric_value}\n') if metric_value > 0 and metric_value >= row[f'original.{ranker}.{metric}']: agg_gold.write(f'{row.qid}\t{change}\t{query}\t{metric_value}\n') if metric_value > 0 and metric_value > row[f'original.{ranker}.{metric}']: agg_plat.write(f'{row.qid}\t{change}\t{query}\t{metric_value}\n') + if ('bt' in change) and metric_value < row[f'original.{ranker}.{metric}']: neg_exp.write(f'{row.qid}\t{change}\t{query}\t{metric_value}\n') @classmethod def aggregate_refiner_rag(cls, original, changes, output): diff --git a/src/param.py b/src/param.py index 0fb7669..6f7fd71 100644 --- a/src/param.py +++ b/src/param.py @@ -9,13 +9,13 @@ os.environ['CUDA_VISIBLE_DEVICES'] = '-1' settings = { - 'cmd': ['rag_fusion'], # steps of pipeline, ['query_refinement', 'similarity', 'search', 'rag_fusion', 'eval','agg', 'box','dense_retrieve', 'stats] - 'datalist': ['./../data/raw/robust04'], # ['./../data/raw/robust04', './../data/raw/gov2', './../data/raw/antique', './../data/raw/dbpedia'], - 'domainlist': ['robust04'], # ['robust04', 'gov2', 'antique', 'dbpedia'], - 'fusion': ['all', 'global', 'local', 'bt'], #['all', 'global', 'local', 'bt'] + 'cmd': ['query_refinement'], # steps of pipeline, ['query_refinement', 'similarity', 'search', 'rag_fusion', 'eval','agg', 'box','dense_retrieve', 'stats] + 'datalist': ['./../data/raw/robust04'], # ['./../data/raw/robust04', './../data/raw/gov2', './../data/raw/antique', './../data/raw/dbpedia', './../data/raw/clueweb09b'] + 'domainlist': ['robust04'], # ['robust04', 'gov2', 'antique', 'dbpedia', 'clueweb09b'] + 'fusion': ['all', 'global', 'local', 'bt'], # ['all', 'global', 'local', 'bt'] 'ncore': 2, - 'ranker': ['bm25'], # 'qld', 'bm25', 'tct_colbert' - 'metric': ['map'], # any valid trec_eval.9.0.4 metrics like map, ndcg, recip_rank, ... + 'ranker': ['qld', 'bm25'], # 'qld', 'bm25', 'tct_colbert' + 'metric': ['map', 'ndcg', 'recip_rank'], # any valid trec_eval.9.0.4 metrics like 'map', 'ndcg', 'recip_rank', ... 'batch': None, # search per batch of queries for IR search using pyserini, if None, search per query 'topk': 100, # number of retrieved documents for a query 'large_ds': False, @@ -33,7 +33,7 @@ 'qrels': '../data/raw/msmarco.passage/qrels.train.tsv', 'dense_index': 'msmarco-passage-tct_colbert-hnsw', 'extcorpus': 'orcas', - 'pairing': [None, 'docs', 'query'], # [context={msmarco does not have userinfo}, input={query, doc, doc(s)}, output={query, doc, doc(s)}], s means concat of docs + 'pairing': [None, 'docs', 'query'], # [context={msmarco does not have userinfo}, input={query, doc, doc(s)}, output={query, doc, doc(s)}], s means concat of docs 'lseq': {"inputs": 32, "targets": 256}, # query length and doc length for t5 model, }, 'aol-ia': { @@ -42,10 +42,10 @@ 'dense_index': '../data/raw/aol-ia/dense-index/tct_colbert.title/', # change based on index_item 'qrels': '../data/raw/aol-ia/qrels.train.tsv', 'dense_encoder':'../data/raw/aol-ia/dense-encoder/tct_colbert.title/', # change based on index_item - 'pairing': [None, 'docs', 'query'], # [context={2 scenarios, one with userID and one without userID). input={'userid','query','doc(s)'} output={'query','doc(s)'} + 'pairing': [None, 'docs', 'query'], # [context={2 scenarios, one with userID and one without userID). input={'userid','query','doc(s)'} output={'query','doc(s)'} 'extcorpus': 'msmarco.passage', 'lseq': {"inputs": 32, "targets": 256}, # query length and doc length for t5 model, - 'filter': {'minql': 1, 'mindocl': 10} # [min query length, min doc length], after merge queries with relevant 'index_item', if |query| <= minql drop the row, if |'index_item'| < mindocl, drop row + 'filter': {'minql': 1, 'mindocl': 10} # [min query length, min doc length], after merge queries with relevant 'index_item', if |query| <= minql drop the row, if |'index_item'| < mindocl, drop row }, 'robust04': { 'index': '../data/raw/robust04/lucene-index.robust04.pos+docvectors+rawdocs', diff --git a/src/refinement/lang_code.py b/src/refinement/lang_code.py new file mode 100644 index 0000000..c1a8746 --- /dev/null +++ b/src/refinement/lang_code.py @@ -0,0 +1,316 @@ +google = { + 'afrikaans': 'af', + 'albanian': 'sq', + 'amharic': 'am', + 'arabic': 'ar', + 'armenian': 'hy', + 'azerbaijani': 'az', + 'basque': 'eu', + 'belarusian': 'be', + 'bengali': 'bn', + 'bosnian': 'bs', + 'bulgarian': 'bg', + 'catalan': 'ca', + 'cebuano': 'ceb', + 'chichewa': 'ny', + 'chinese_simplified': 'zh-cn', + 'chinese_traditional': 'zh-tw', + 'corsican': 'co', + 'croatian': 'hr', + 'czech': 'cs', + 'danish': 'da', + 'dutch': 'nl', + 'english': 'en', + 'esperanto': 'eo', + 'estonian': 'et', + 'filipino': 'tl', + 'finnish': 'fi', + 'french': 'fr', + 'frisian': 'fy', + 'galician': 'gl', + 'georgian': 'ka', + 'german': 'de', + 'greek': 'el', + 'gujarati': 'gu', + 'haitian creole': 'ht', + 'hausa': 'ha', + 'hawaiian': 'haw', + 'hebrew': 'iw', + 'hindi': 'hi', + 'hmong': 'hmn', + 'hungarian': 'hu', + 'icelandic': 'is', + 'igbo': 'ig', + 'indonesian': 'id', + 'irish': 'ga', + 'italian': 'it', + 'japanese': 'ja', + 'javanese': 'jw', + 'kannada': 'kn', + 'kazakh': 'kk', + 'khmer': 'km', + 'korean': 'ko', + 'kurdish': 'ku', + 'kyrgyz': 'ky', + 'lao': 'lo', + 'latin': 'la', + 'latvian': 'lv', + 'lithuanian': 'lt', + 'luxembourgish': 'lb', + 'macedonian': 'mk', + 'malagasy': 'mg', + 'malay': 'ms', + 'malayalam': 'ml', + 'maltese': 'mt', + 'maori': 'mi', + 'marathi': 'mr', + 'mongolian': 'mn', + 'myanmar': 'my', + 'nepali': 'ne', + 'norwegian': 'no', + 'odia': 'or', + 'pashto': 'ps', + 'persian': 'fa', + 'polish': 'pl', + 'portuguese': 'pt', + 'punjabi': 'pa', + 'romanian': 'ro', + 'russian': 'ru', + 'samoan': 'sm', + 'scots gaelic': 'gd', + 'serbian': 'sr', + 'sesotho': 'st', + 'shona': 'sn', + 'sindhi': 'sd', + 'sinhala': 'si', + 'slovak': 'sk', + 'slovenian': 'sl', + 'somali': 'so', + 'spanish': 'es', + 'sundanese': 'su', + 'swahili': 'sw', + 'swedish': 'sv', + 'tajik': 'tg', + 'tamil': 'ta', + 'telugu': 'te', + 'thai': 'th', + 'turkish': 'tr', + 'ukrainian': 'uk', + 'urdu': 'ur', + 'uyghur': 'ug', + 'uzbek': 'uz', + 'vietnamese': 'vi', + 'welsh': 'cy', + 'xhosa': 'xh', + 'yiddish': 'yi', + 'yoruba': 'yo', + 'zulu': 'zu' +} + +nllb = { + 'acehnese (arabic script)': 'ace_Arab', + 'acehnese (latin script)': 'ace_Latn', + 'mesopotamian arabic': 'acm_Arab', + 'ta’izzi-adeni arabic': 'acq_Arab', + 'tunisian arabic': 'aeb_Arab', + 'afrikaans': 'afr_Latn', + 'south levantine arabic': 'ajp_Arab', + 'akan': 'aka_Latn', + 'amharic': 'amh_Ethi', + 'north levantine arabic': 'apc_Arab', + 'arabic': 'arb_Arab', + 'modern standard arabic (romanized)': 'arb_Latn', + 'najdi arabic': 'ars_Arab', + 'moroccan arabic': 'ary_Arab', + 'egyptian arabic': 'arz_Arab', + 'assamese': 'asm_Beng', + 'asturian': 'ast_Latn', + 'awadhi': 'awa_Deva', + 'central aymara': 'ayr_Latn', + 'south azerbaijani': 'azb_Arab', + 'north azerbaijani': 'azj_Latn', + 'bashkir': 'bak_Cyrl', + 'bambara': 'bam_Latn', + 'balinese': 'ban_Latn', + 'belarusian': 'bel_Cyrl', + 'bemba': 'bem_Latn', + 'bengali': 'ben_Beng', + 'bhojpuri': 'bho_Deva', + 'banjar (arabic script)': 'bjn_Arab', + 'banjar (latin script)': 'bjn_Latn', + 'standard tibetan': 'bod_Tibt', + 'bosnian': 'bos_Latn', + 'buginese': 'bug_Latn', + 'bulgarian': 'bul_Cyrl', + 'catalan': 'cat_Latn', + 'cebuano': 'ceb_Latn', + 'czech': 'ces_Latn', + 'chokwe': 'cjk_Latn', + 'central kurdish': 'ckb_Arab', + 'crimean tatar': 'crh_Latn', + 'welsh': 'cym_Latn', + 'danish': 'dan_Latn', + 'german': 'deu_Latn', + 'southwestern dinka': 'dik_Latn', + 'dyula': 'dyu_Latn', + 'dzongkha': 'dzo_Tibt', + 'greek': 'ell_Grek', + 'english': 'eng_Latn', + 'esperanto': 'epo_Latn', + 'estonian': 'est_Latn', + 'basque': 'eus_Latn', + 'ewe': 'ewe_Latn', + 'faroese': 'fao_Latn', + 'fijian': 'fij_Latn', + 'finnish': 'fin_Latn', + 'fon': 'fon_Latn', + 'french': 'fra_Latn', + 'friulian': 'fur_Latn', + 'nigerian fulfulde': 'fuv_Latn', + 'scottish gaelic': 'gla_Latn', + 'irish': 'gle_Latn', + 'galician': 'glg_Latn', + 'guarani': 'grn_Latn', + 'gujarati': 'guj_Gujr', + 'haitian creole': 'hat_Latn', + 'hausa': 'hau_Latn', + 'hebrew': 'heb_Hebr', + 'hindi': 'hin_Deva', + 'chhattisgarhi': 'hne_Deva', + 'croatian': 'hrv_Latn', + 'hungarian': 'hun_Latn', + 'armenian': 'hye_Armn', + 'igbo': 'ibo_Latn', + 'ilocano': 'ilo_Latn', + 'indonesian': 'ind_Latn', + 'icelandic': 'isl_Latn', + 'italian': 'ita_Latn', + 'javanese': 'jav_Latn', + 'japanese': 'jpn_Jpan', + 'kabyle': 'kab_Latn', + 'jingpho': 'kac_Latn', + 'kamba': 'kam_Latn', + 'kannada': 'kan_Knda', + 'kashmiri (arabic script)': 'kas_Arab', + 'kashmiri (devanagari script)': 'kas_Deva', + 'georgian': 'kat_Geor', + 'central kanuri (arabic script)': 'knc_Arab', + 'central kanuri (latin script)': 'knc_Latn', + 'kazakh': 'kaz_Cyrl', + 'kabiyè': 'kbp_Latn', + 'kabuverdianu': 'kea_Latn', + 'khmer': 'khm_Khmr', + 'kikuyu': 'kik_Latn', + 'kinyarwanda': 'kin_Latn', + 'kyrgyz': 'kir_Cyrl', + 'kimbundu': 'kmb_Latn', + 'northern kurdish': 'kmr_Latn', + 'kikongo': 'kon_Latn', + 'korean': 'kor_Hang', + 'lao': 'lao_Laoo', + 'ligurian': 'lij_Latn', + 'limburgish': 'lim_Latn', + 'lingala': 'lin_Latn', + 'lithuanian': 'lit_Latn', + 'lombard': 'lmo_Latn', + 'latgalian': 'ltg_Latn', + 'luxembourgish': 'ltz_Latn', + 'luba-kasai': 'lua_Latn', + 'ganda': 'lug_Latn', + 'luo': 'luo_Latn', + 'mizo': 'lus_Latn', + 'standard latvian': 'lvs_Latn', + 'magahi': 'mag_Deva', + 'maithili': 'mai_Deva', + 'malayalam': 'mal_Mlym', + 'marathi': 'mar_Deva', + 'minangkabau (arabic script)': 'min_Arab', + 'minangkabau (latin script)': 'min_Latn', + 'macedonian': 'mkd_Cyrl', + 'plateau malagasy': 'plt_Latn', + 'maltese': 'mlt_Latn', + 'meitei (bengali script)': 'mni_Beng', + 'halh mongolian': 'khk_Cyrl', + 'mossi': 'mos_Latn', + 'maori': 'mri_Latn', + 'burmese': 'mya_Mymr', + 'dutch': 'nld_Latn', + 'norwegian nynorsk': 'nno_Latn', + 'norwegian bokmål': 'nob_Latn', + 'nepali': 'npi_Deva', + 'northern sotho': 'nso_Latn', + 'nuer': 'nus_Latn', + 'nyanja': 'nya_Latn', + 'occitan': 'oci_Latn', + 'west central oromo': 'gaz_Latn', + 'odia': 'ory_Orya', + 'pangasinan': 'pag_Latn', + 'eastern panjabi': 'pan_Guru', + 'papiamento': 'pap_Latn', + 'persian': 'pes_Arab', + 'polish': 'pol_Latn', + 'portuguese': 'por_Latn', + 'dari': 'prs_Arab', + 'southern pashto': 'pbt_Arab', + 'ayacucho quechua': 'quy_Latn', + 'romanian': 'ron_Latn', + 'rundi': 'run_Latn', + 'russian': 'rus_Cyrl', + 'sango': 'sag_Latn', + 'sanskrit': 'san_Deva', + 'santali': 'sat_Olck', + 'sicilian': 'scn_Latn', + 'shan': 'shn_Mymr', + 'sinhala': 'sin_Sinh', + 'slovak': 'slk_Latn', + 'slovenian': 'slv_Latn', + 'samoan': 'smo_Latn', + 'shona': 'sna_Latn', + 'sindhi': 'snd_Arab', + 'somali': 'som_Latn', + 'southern sotho': 'sot_Latn', + 'spanish': 'spa_Latn', + 'tosk albanian': 'als_Latn', + 'sardinian': 'srd_Latn', + 'serbian': 'srp_Cyrl', + 'swati': 'ssw_Latn', + 'sundanese': 'sun_Latn', + 'swedish': 'swe_Latn', + 'swahili': 'swh_Latn', + 'silesian': 'szl_Latn', + 'tamil': 'tam_Taml', + 'tatar': 'tat_Cyrl', + 'telugu': 'tel_Telu', + 'tajik': 'tgk_Cyrl', + 'tagalog': 'tgl_Latn', + 'thai': 'tha_Thai', + 'tigrinya': 'tir_Ethi', + 'tamasheq (latin script)': 'taq_Latn', + 'tamasheq (tifinagh script)': 'taq_Tfng', + 'tok pisin': 'tpi_Latn', + 'tswana': 'tsn_Latn', + 'tsonga': 'tso_Latn', + 'turkmen': 'tuk_Latn', + 'tumbuka': 'tum_Latn', + 'turkish': 'tur_Latn', + 'twi': 'twi_Latn', + 'central atlas tamazight': 'tzm_Tfng', + 'uyghur': 'uig_Arab', + 'ukrainian': 'ukr_Cyrl', + 'umbundu': 'umb_Latn', + 'urdu': 'urd_Arab', + 'northern uzbek': 'uzn_Latn', + 'venetian': 'vec_Latn', + 'vietnamese': 'vie_Latn', + 'waray': 'war_Latn', + 'wolof': 'wol_Latn', + 'xhosa': 'xho_Latn', + 'eastern yiddish': 'ydd_Hebr', + 'yoruba': 'yor_Latn', + 'yue chinese': 'yue_Hant', + 'chinese_simplified': 'zho_Hans', + 'chinese_traditional': 'zho_Hant', + 'malay': 'zsm_Latn', + 'zulu': 'zul_Latn' +} + diff --git a/src/refinement/refiner_factory.py b/src/refinement/refiner_factory.py index 86a7bd2..83f629a 100644 --- a/src/refinement/refiner_factory.py +++ b/src/refinement/refiner_factory.py @@ -1,6 +1,7 @@ from refinement.refiners.abstractqrefiner import AbstractQRefiner from refinement.refiners.stem import Stem # Stem refiner is the wrapper for all stemmers as an refiner :) from refinement import refiner_param +from itertools import product from param import settings import os @@ -35,7 +36,7 @@ def get_nrf_refiner(): if refiners_name['SRemovalStemmer']: from refinement.stemmers.sstemmer import SRemovalStemmer; refiners_list.append(Stem(SRemovalStemmer())) if refiners_name['Trunc4Stemmer']: from refinement.stemmers.trunc4 import Trunc4Stemmer; refiners_list.append(Stem(Trunc4Stemmer())) if refiners_name['Trunc5Stemmer']: from refinement.stemmers.trunc5 import Trunc5Stemmer; refiners_list.append(Stem(Trunc5Stemmer())) - if refiners_name['BackTranslation']: from refinement.refiners.backtranslation import BackTranslation; refiners_list.extend([BackTranslation(each_lng) for index, each_lng in enumerate(refiner_param.backtranslation['tgt_lng'])]) + if refiners_name['BackTranslation']: from refinement.refiners.backtranslation import BackTranslation; refiners_list.extend([BackTranslation(trans, lang) for lang, trans in product(refiner_param.backtranslation['tgt_lng'], refiner_param.backtranslation['translator'])]) # since RF needs index and search output which depends on ir method and topics corpora, we cannot add this here. Instead, we run it individually # RF assumes that there exist abstractqueryexpansion files diff --git a/src/refinement/refiner_param.py b/src/refinement/refiner_param.py index 436120f..ba6d082 100644 --- a/src/refinement/refiner_param.py +++ b/src/refinement/refiner_param.py @@ -45,8 +45,9 @@ # Backtranslation settings backtranslation = { - 'src_lng': 'eng_Latn', - 'tgt_lng': ['yue_Hant', 'kor_Hang', 'arb_Arab', 'pes_Arab', 'fra_Latn', 'deu_Latn', 'rus_Cyrl', 'zsm_Latn', 'tam_Taml', 'swh_Latn'], # ['yue_Hant', 'kor_Hang', 'arb_Arab', 'pes_Arab', 'fra_Latn', 'deu_Latn', 'rus_Cyrl', 'zsm_Latn', 'tam_Taml', 'swh_Latn'] + 'translator': ['google'], # ['google', 'nllb'] + 'src_lng': 'english', + 'tgt_lng': ['persian', 'french', 'german', 'russian', 'malay', 'tamil', 'swahili', 'chinese_simplified', 'korean', 'arabic'], # ['persian', 'french', 'german', 'russian', 'malay', 'tamil', 'swahili', 'chinese_simplified', 'korean', 'arabic'] 'max_length': 512, 'device': 'cpu', 'model_card': 'facebook/nllb-200-distilled-600M', diff --git a/src/refinement/refiners/backtranslation.py b/src/refinement/refiners/backtranslation.py index c676ea2..27c2a3d 100644 --- a/src/refinement/refiners/backtranslation.py +++ b/src/refinement/refiners/backtranslation.py @@ -1,34 +1,46 @@ -from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM -from src.refinement.refiners.abstractqrefiner import AbstractQRefiner -from src.refinement.refiner_param import backtranslation - +from refinement.refiners.abstractqrefiner import AbstractQRefiner +from refinement.refiner_param import backtranslation +from refinement.lang_code import google, nllb class BackTranslation(AbstractQRefiner): - def __init__(self, tgt): + def __init__(self, translator, tgt): AbstractQRefiner.__init__(self) # Initialization + self.src = backtranslation['src_lng'] self.tgt = tgt - model = AutoModelForSeq2SeqLM.from_pretrained(backtranslation['model_card']) - tokenizer = AutoTokenizer.from_pretrained(backtranslation['model_card']) + self.translator_name = translator # Translation models - self.translator = pipeline("translation", model=model, tokenizer=tokenizer, src_lang=backtranslation['src_lng'], tgt_lang=self.tgt, max_length=backtranslation['max_length'], device=backtranslation['device']) - self.back_translator = pipeline("translation", model=model, tokenizer=tokenizer, src_lang=self.tgt, tgt_lang=backtranslation['src_lng'], max_length=backtranslation['max_length'], device=backtranslation['device']) + # Google + if self.translator_name == 'google': + from googletrans import Translator + self.translator = Translator(service_urls=['translate.google.com']) + # Meta's NLLB + else: + from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM + model = AutoModelForSeq2SeqLM.from_pretrained(backtranslation['model_card']) + tokenizer = AutoTokenizer.from_pretrained(backtranslation['model_card']) + self.translator = pipeline("translation", model=model, tokenizer=tokenizer, src_lang=nllb[self.src], tgt_lang=nllb[self.tgt], max_length=backtranslation['max_length'], device=backtranslation['device']) + self.back_translator = pipeline("translation", model=model, tokenizer=tokenizer, src_lang=nllb[self.tgt], tgt_lang=nllb[self.src], max_length=backtranslation['max_length'], device=backtranslation['device']) ''' Generates the backtranslated query then calculates the semantic similarity of the two queries ''' def get_refined_query(self, query, args=None): - translated_query = self.translator(query) - back_translated_query = self.back_translator(translated_query[0]['translation_text']) - return super().get_refined_query(back_translated_query[0]['translation_text']) + if self.translator_name == 'google': + translated_query = self.translator.translate(query, src=google[self.src], dest=google[self.tgt]) + backtranslated_query = (self.translator.translate(translated_query.text, src=google[self.tgt], dest=google[self.src])).text + else: + translated_query = self.translator(query) + backtranslated_query = (self.back_translator(translated_query[0]['translation_text']))[0]['translation_text'] + return super().get_refined_query(backtranslated_query) def get_refined_query_batch(self, queries, args=None): try: translated_queries = self.translator([query for query in queries]) - back_translated_queries = self.back_translator([tq_['translation_text'] for tq_ in translated_queries]) - q_s = [q_['translation_text'] for q_ in back_translated_queries] + backtranslated_queries = self.backtranslator([tq_['translation_text'] for tq_ in translated_queries]) + q_s = [q_['translation_text'] for q_ in backtranslated_queries] except: q_s = [None] * len(queries) return q_s @@ -38,7 +50,7 @@ def get_refined_query_batch(self, queries, args=None): Example: 'backtranslation_fra_latn' ''' def get_model_name(self): - return 'bt_' + self.tgt.lower() + return 'bt_' + self.translator_name + '_' + self.tgt.lower() if __name__ == "__main__": diff --git a/src/refinement/refiners/relevancefeedback.py b/src/refinement/refiners/relevancefeedback.py index 60dfff9..d613946 100644 --- a/src/refinement/refiners/relevancefeedback.py +++ b/src/refinement/refiners/relevancefeedback.py @@ -31,7 +31,7 @@ def get_topn_relevant_docids(self, q=None, qid=None): self.f.seek(0) i = 0 for x in self.f: - x_splited = x.split() + x_splited = x.split() try : if (int(x_splited[0]) == qid or x_splited[0] == qid): relevant_documents.append(x_splited[2]) diff --git a/src/stats/analyze.py b/src/stats/analyze.py index de36c8a..68df002 100644 --- a/src/stats/analyze.py +++ b/src/stats/analyze.py @@ -1,5 +1,6 @@ from src.param import settings from src.refinement.refiner_param import refiners +import matplotlib.pyplot as plt from itertools import product import pandas as pd import os @@ -28,14 +29,14 @@ def combine_refiner_results(infile, output, datasets, ranker_metrics): on the 'refiner' column, and writes the combined results to CSV files for each ranker-metric combination. """ for r, m in ranker_metrics: - df_list = [pd.read_csv(f'{infile}.{ds}.{r}.{m}.csv', skiprows=1, names=['category', 'refiner', ds]) for ds in datasets] + df_list = [pd.read_csv(f'{infile}.{ds}.{r}.{m}.csv', skiprows=1, names=['category', 'refiner', f'{ds}_|q**|', f'{ds}_%']) for ds in datasets] merged_df = df_list[0] for df in df_list[1:]: merged_df = pd.merge(merged_df, df, on=['refiner', 'category'], how='outer') merged_df = merged_df.fillna(0) merged_df.to_csv(f'{output}.{r}.{m}.csv', index=False) -def compare_refiners(infile, output, globalr, ranker, metric, overlap=False, agg_bt=False): +def compare_refiners(infile, output, globalr, ranker, metric, refiners_list=[], overlap=False, agg=False): """ Compare the performance of different query refiners. @@ -57,29 +58,39 @@ def compare_refiners(infile, output, globalr, ranker, metric, overlap=False, agg The comparison results are then sorted, saved to an output file, and categorized as global or local. """ df = pd.read_csv(infile, sep='\t', header=0) - df['order'] = df['order'].apply(lambda x: 'q^' if '-1' in x else x) num_q = df['qid'].nunique() # Number of queries in the dataset - - if agg_bt: - df['order'] = df['order'].apply(lambda x: 'bt' if 'bt' in x else x) - bt_rows = df[df['order'] == 'bt'] - max_indices = bt_rows.groupby(['qid', 'order'])[f'{ranker}.{metric}'].idxmax() - df_bt = bt_rows.loc[max_indices] - df = pd.concat([df_bt, df[df['order'] != 'bt']], ignore_index=True) - + if len(refiners_list) == 0: refiners_list = get_refiner_list('global') + get_refiner_list('local') + ['-1'] + if agg: + # df['order'] = df['order'].apply(lambda x: 'bt' if 'bt' in x else x) + # bt_rows = df[df['order'] == 'bt'] + # max_indices = bt_rows.groupby(['qid', 'order'])[f'{ranker}.{metric}'].idxmax() + # df_bt = bt_rows.loc[max_indices] + # df = pd.concat([df_bt, df[df['order'] != 'bt']], ignore_index=True) + filtered_dfs = [] + for refiner_name in refiners_list: + filtered_df = df.loc[df[df['order'].str.contains(refiner_name)].groupby('qid')[f'{ranker}.{metric}'].idxmax()] + filtered_df['order'] = refiner_name + filtered_dfs.append(filtered_df) + df = pd.concat(filtered_dfs, ignore_index=True) + df.reset_index(drop=True, inplace=True) + + # df = df[df['order'].str.contains('|'.join(refiners_list))] # Selecting q** (The best query among all refined queries for each qid) if not overlap: max_indices = df.groupby('qid')[f'{ranker}.{metric}'].idxmax() df = df.loc[max_indices] df.reset_index(drop=True, inplace=True) + # plot_chart(df, ranker, metric, f'{output}.png') + df['order'] = df['order'].apply(lambda x: 'q^' if '-1' in x else x) + # Write results in a file - final_table = pd.DataFrame(columns=['category', 'refiner', '%']) + final_table = pd.DataFrame(columns=['category', 'refiner', '|q**|', '%']) for refiner, group in df.groupby('order'): c = 'global 'if any(name in refiner for name in globalr) else 'local' - final_table.loc[len(final_table)] = [c, refiner, (len(group)/num_q)*100] - final_table = final_table.sort_values(by=['category', '%'], ascending=[True, False]) - final_table.to_csv(output, index=False) + final_table.loc[len(final_table)] = [c, refiner, len(group), (len(group)/num_q)*100] + final_table = final_table.sort_values(by=['category', 'refiner'], ascending=[True, True]) + final_table.to_csv(f'{output}.csv', index=False) def create_analyze_table(refiner_name): @@ -113,38 +124,87 @@ def create_analyze_table(refiner_name): """ final_table = pd.DataFrame(columns=["dataset", "|Q|", "|Q'|", "ir_metric", "|Q*|", "%", "delta"]) - for ds in ['clueweb09b']: + for ds in ['dbpedia', 'robust04', 'antique', 'gov2', 'clueweb09b']: for ranker, metric in product(settings['ranker'], settings['metric']): df = pd.read_csv(f'./output/{ds}/{ranker}.{metric}/{ranker}.{metric}.agg.platinum.tsv', sep='\t', header=0) df = df[df['order'].str.contains(refiner_name) | (df['order'] == '-1')] num_q = df['qid'].nunique() # Number of queries in the dataset q_prim = len(df[(df['order'] == '-1') & (df[f'{ranker}.{metric}'] != 1)]) - q_star = df[df['order'].str.contains(refiner_name)].groupby('qid').ngroups - avg_metric = 0 - for qid, group in df.groupby('qid'): - filtered_df = group[group['order'].str.contains(refiner_name)] - if len(filtered_df != 0): - max_index = filtered_df[f'{ranker}.{metric}'].idxmax() - avg_metric += filtered_df.loc[max_index, f'{ranker}.{metric}'] - group.loc[group['order'] == '-1', f'{ranker}.{metric}'].values[0] - percent = (q_star/q_prim)*100 + + avg_metric = 0 + max_metric_index = df[df['order'].str.contains(refiner_name)].groupby('qid')[f'{ranker}.{metric}'].idxmax() + for idx in max_metric_index: avg_metric += df.loc[idx, f'{ranker}.{metric}'] - df.loc[(df['qid'] == df.loc[idx, 'qid']) & (df['order'] == '-1'), f'{ranker}.{metric}'].values[0] avg_metric = avg_metric / q_star - final_table.loc[len(final_table)] = [ds, num_q, q_prim, f'{ranker}.{metric}', q_star, percent, f'+{avg_metric}'] + + extra_col = {} + if 'bt' in refiner_name: + lang = ['pes_arab', 'fra_latn', 'deu_latn', 'rus_cyrl', 'zsm_latn', 'tam_taml', 'swh_latn', 'yue_hant', 'kor_hang', 'arb_arab'] + extra_col = {f'{item1}_{item2}': 0 for item1, item2 in product(lang, ['%', 'delta'])} + for refiner, group in df.groupby('order'): + if refiner == '-1': continue + avg = group[f'{ranker}.{metric}'].sum() + qid_list = group['qid'].tolist() + avg_original = df[(df['qid'].isin(qid_list)) & (df['order'] == '-1')][f'{ranker}.{metric}'].sum() + extra_col[f'{refiner.replace("bt_", "", 1)}_%'] = len(group)/q_prim * 100 + extra_col[f'{refiner.replace("bt_", "", 1)}_delta'] = f'+{(avg - avg_original)/len(group)}' + + new_line = {"dataset":ds, "|Q|":num_q, "|Q'|":q_prim, "ir_metric":f'{ranker}.{metric}', "|Q*|":q_star, "%":percent, "delta":f'+{avg_metric}', **extra_col} + final_table = pd.concat([final_table, pd.DataFrame([new_line])], ignore_index=True) final_table.to_csv(f'./output/analyze/analyze_{refiner_name}.all.csv', index=False) +def refiner_distribution_table(infile, output): + for ds in ['dbpedia', 'robust04', 'antique', 'gov2', 'clueweb09b']: + for ranker, metric in product(settings['ranker'], settings['metric']): + df = pd.read_csv(f'{infile}/{ds}/{ranker}.{metric}/{ranker}.{metric}.agg.all.tsv', sep='\t', header=0) + df['delta'] = df['delta'] = df.groupby('qid')[f'{ranker}.{metric}'].transform(lambda x: x - x[df['order'] == '-1'].iloc[0]) + df.reset_index(drop=True, inplace=True) + filtered_dfs = [] + refiners_list = ['-1', 'bt', 'tagmee', 'relevancefeedback', 'anchor'] + for refiner_name in refiners_list: + filtered_df = df.loc[ + df[df['order'].str.contains(refiner_name)].groupby('qid')[f'{ranker}.{metric}'].idxmax()] + filtered_df['order'] = refiner_name + filtered_dfs.append(filtered_df) + df = pd.concat(filtered_dfs, ignore_index=True) + df.reset_index(drop=True, inplace=True) + df.to_csv(f'{output}/cal.delta.refiner.original.{ds}.{ranker}.{metric}.csv', index=False) + # plot_chart(df, ranker, metric, f'{output}/cal.delta.refiner.original.{ds}.{ranker}.{metric}.png') + + +def plot_chart(df, ranker, metric, output): + colors = ["#C2FF3300", "#7800FF00", "#FFC30000", "#0082FF00", "#FF573300", "#00FFA700", "#0082FF00", "#FF00E500", "#FF007700", "#00FFCC00","#FFDC0000"] + grouped_df = df.groupby('order') + index = 0 + for order, group in grouped_df: + if order == '-1': continue + trend = group['delta'].to_list() + # trend = sorted(trend, reverse=True) + plt.hist(trend, label=order.split('.')[0], alpha=0.8) + index += 1 + # plt.gca().axes.get_xaxis().set_visible(False) + plt.xlabel(f'{ranker}.{metric}') + plt.legend(loc='upper left', bbox_to_anchor=(1, 1)) + plt.title(f'Scatter plot {ranker}.{metric}') + plt.savefig(output) + plt.close() + # plt.show() + + if __name__ == '__main__': globalr = get_refiner_list('global') localr = get_refiner_list('local') - # for ds in ['robust04', 'antique', 'dbpedia']: - # infile = f'./output/{ds}' - # output = f'./output/analyze' - # if not os.path.isdir(output): os.makedirs(output) - # [compare_refiners(infile=f'{infile}/{ranker}.{metric}/{ranker}.{metric}.agg.platinum.tsv', output=f'{output}/compare.refiners.{ds}.{ranker}.{metric}.csv', globalr=globalr, ranker=ranker, metric=metric, overlap=False, agg_bt=True) for ranker, metric in product(settings['ranker'], settings['metric'])] - # [compare_refiners(infile=f'{infile}/{ranker}.{metric}/{ranker}.{metric}.agg.rag.platinum.tsv', output=f'{output}/compare.refiners.rag.{ds}.{ranker}.{metric}.csv', globalr=globalr, ranker=ranker, metric=metric, overlap=False, agg_bt=True) for ranker, metric in product(settings['ranker'], settings['metric'])] - # - # combine_refiner_results(infile='./output/analyze/compare.refiners', output='./output/analyze/compare.refiners.all.datasets', datasets=['robust04', 'antique', 'dbpedia'], ranker_metrics=product(settings['ranker'], settings['metric'])) + for ds in ['dbpedia', 'robust04', 'antique', 'gov2', 'clueweb09b']: + infile = f'./output/{ds}' + output = f'./output/analyze' + if not os.path.isdir(output): os.makedirs(output) + [compare_refiners(infile=f'{infile}/{ranker}.{metric}/{ranker}.{metric}.agg.platinum.tsv', output=f'{output}/compare.refiners.{ds}.{ranker}.{metric}', globalr=globalr, ranker=ranker, metric=metric, refiners_list=[], overlap=False, agg=True) for ranker, metric in product(settings['ranker'], settings['metric'])] + # [compare_refiners(infile=f'{infile}/{ranker}.{metric}/{ranker}.{metric}.agg.rag.platinum.tsv', output=f'{output}/compare.refiners.rag.{ds}.{ranker}.{metric}', globalr=globalr, ranker=ranker, metric=metric, overlap=False, agg=True) for ranker, metric in product(settings['ranker'], settings['metric'])] + # ['-1', 'bt', 'conceptluster', 'relevancefeedback', 'anchor'] + combine_refiner_results(infile='./output/analyze/compare.refiners', output='./output/analyze/compare.refiners.all.datasets', datasets=['dbpedia', 'robust04', 'antique', 'gov2', 'clueweb09b'], ranker_metrics=product(settings['ranker'], settings['metric'])) # combine_refiner_results(infile='./output/analyze/compare.refiners.rag', output='./output/analyze/compare.refiners.rag.all.datasets', datasets=['robust04', 'antique', 'dbpedia'], ranker_metrics=product(settings['ranker'], settings['metric'])) - create_analyze_table('bt') \ No newline at end of file + # create_analyze_table('bt') + # refiner_distribution_table(infile='./output', output='./output/analyze/chart') \ No newline at end of file diff --git a/translated_query.py b/translated_query.py new file mode 100644 index 0000000..198864c --- /dev/null +++ b/translated_query.py @@ -0,0 +1,18 @@ +from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM + +model = AutoModelForSeq2SeqLM.from_pretrained('facebook/nllb-200-distilled-600M') +tokenizer = AutoTokenizer.from_pretrained('facebook/nllb-200-distilled-600M') + +# ['yue_Hant', 'kor_Hang', 'arb_Arab', 'pes_Arab', 'fra_Latn', 'deu_Latn', 'rus_Cyrl', 'zsm_Latn', 'tam_Taml', 'swh_Latn'] +def translate(q, l): + translator = pipeline("translation", model=model, tokenizer=tokenizer, src_lang='eng_Latn', tgt_lang=l, max_length=512, device='cpu') + return translator(q) + +q = "murals" +l = 'deu_Latn' +t = translate(q, l) +bt = translate(q, 'eng_Latn') +print(t[0]['translation_text'], ) +print(bt[0]['translation_text']) + +