Skip to content

Commit

Permalink
Add google translator
Browse files Browse the repository at this point in the history
  • Loading branch information
DelaramRajaei committed May 1, 2024
1 parent 9420c4c commit fe3d081
Show file tree
Hide file tree
Showing 11 changed files with 479 additions and 65 deletions.
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,4 @@ dependencies:
- tensorflow==2.11.0
- protobuf==3.19.6
- filesplit
- googletrans==4.0.0rc1
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ ir_datasets
filesplit
flake8
pylint
googletrans
10 changes: 7 additions & 3 deletions src/dal/ds.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,14 +197,17 @@ def aggregate(cls, original, refined_query, changes, output):
print(f'Saving original queries, better changes, and {metric} values based on {ranker} ...')
with open(f'{output}/{ranker}.{metric}.agg.gold.tsv', mode='w', encoding='UTF-8') as agg_gold, \
open(f'{output}/{ranker}.{metric}.agg.all.tsv', mode='w', encoding='UTF-8') as agg_all, \
open(f'{output}/{ranker}.{metric}.agg.platinum.tsv', mode='w', encoding='UTF-8') as agg_plat:
agg_gold.write(f'qid\torder\tquery\t{ranker}.{metric}\n')
open(f'{output}/{ranker}.{metric}.agg.platinum.tsv', mode='w', encoding='UTF-8') as agg_plat, \
open(f'{output}/{ranker}.{metric}.neg.example.tsv', mode='w', encoding='UTF-8') as neg_exp:
agg_all.write(f'qid\torder\tquery\t{ranker}.{metric}\n')
agg_gold.write(f'qid\torder\tquery\t{ranker}.{metric}\n')
agg_plat.write(f'qid\torder\tquery\t{ranker}.{metric}\n')
neg_exp.write(f'qid\torder\tquery\t{ranker}.{metric}\n')
for index, row in tqdm(original.iterrows(), total=original.shape[0]):
agg_all.write(f'{row.qid}\t-1\t{row.query}\t{row[f"original.{ranker}.{metric}"]}\n')
agg_gold.write(f'{row.qid}\t-1\t{row.query}\t{row[f"original.{ranker}.{metric}"]}\n')
agg_plat.write(f'{row.qid}\t-1\t{row.query}\t{row[f"original.{ranker}.{metric}"]}\n')
agg_all.write(f'{row.qid}\t-1\t{row.query}\t{row[f"original.{ranker}.{metric}"]}\n')
neg_exp.write(f'{row.qid}\t-1\t{row.query}\t{row[f"original.{ranker}.{metric}"]}\n')
all = list()
for change, metric_value in changes: all.append((row[change], row[f'{change}.{ranker}.{metric}'], change))
all = sorted(all, key=lambda x: x[1], reverse=True)
Expand All @@ -213,6 +216,7 @@ def aggregate(cls, original, refined_query, changes, output):
agg_all.write(f'{row.qid}\t{change}\t{query}\t{metric_value}\n')
if metric_value > 0 and metric_value >= row[f'original.{ranker}.{metric}']: agg_gold.write(f'{row.qid}\t{change}\t{query}\t{metric_value}\n')
if metric_value > 0 and metric_value > row[f'original.{ranker}.{metric}']: agg_plat.write(f'{row.qid}\t{change}\t{query}\t{metric_value}\n')
if ('bt' in change) and metric_value < row[f'original.{ranker}.{metric}']: neg_exp.write(f'{row.qid}\t{change}\t{query}\t{metric_value}\n')

@classmethod
def aggregate_refiner_rag(cls, original, changes, output):
Expand Down
18 changes: 9 additions & 9 deletions src/param.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

settings = {
'cmd': ['rag_fusion'], # steps of pipeline, ['query_refinement', 'similarity', 'search', 'rag_fusion', 'eval','agg', 'box','dense_retrieve', 'stats]
'datalist': ['./../data/raw/robust04'], # ['./../data/raw/robust04', './../data/raw/gov2', './../data/raw/antique', './../data/raw/dbpedia'],
'domainlist': ['robust04'], # ['robust04', 'gov2', 'antique', 'dbpedia'],
'fusion': ['all', 'global', 'local', 'bt'], #['all', 'global', 'local', 'bt']
'cmd': ['query_refinement'], # steps of pipeline, ['query_refinement', 'similarity', 'search', 'rag_fusion', 'eval','agg', 'box','dense_retrieve', 'stats]
'datalist': ['./../data/raw/robust04'], # ['./../data/raw/robust04', './../data/raw/gov2', './../data/raw/antique', './../data/raw/dbpedia', './../data/raw/clueweb09b']
'domainlist': ['robust04'], # ['robust04', 'gov2', 'antique', 'dbpedia', 'clueweb09b']
'fusion': ['all', 'global', 'local', 'bt'], # ['all', 'global', 'local', 'bt']
'ncore': 2,
'ranker': ['bm25'], # 'qld', 'bm25', 'tct_colbert'
'metric': ['map'], # any valid trec_eval.9.0.4 metrics like map, ndcg, recip_rank, ...
'ranker': ['qld', 'bm25'], # 'qld', 'bm25', 'tct_colbert'
'metric': ['map', 'ndcg', 'recip_rank'], # any valid trec_eval.9.0.4 metrics like 'map', 'ndcg', 'recip_rank', ...
'batch': None, # search per batch of queries for IR search using pyserini, if None, search per query
'topk': 100, # number of retrieved documents for a query
'large_ds': False,
Expand All @@ -33,7 +33,7 @@
'qrels': '../data/raw/msmarco.passage/qrels.train.tsv',
'dense_index': 'msmarco-passage-tct_colbert-hnsw',
'extcorpus': 'orcas',
'pairing': [None, 'docs', 'query'], # [context={msmarco does not have userinfo}, input={query, doc, doc(s)}, output={query, doc, doc(s)}], s means concat of docs
'pairing': [None, 'docs', 'query'], # [context={msmarco does not have userinfo}, input={query, doc, doc(s)}, output={query, doc, doc(s)}], s means concat of docs
'lseq': {"inputs": 32, "targets": 256}, # query length and doc length for t5 model,
},
'aol-ia': {
Expand All @@ -42,10 +42,10 @@
'dense_index': '../data/raw/aol-ia/dense-index/tct_colbert.title/', # change based on index_item
'qrels': '../data/raw/aol-ia/qrels.train.tsv',
'dense_encoder':'../data/raw/aol-ia/dense-encoder/tct_colbert.title/', # change based on index_item
'pairing': [None, 'docs', 'query'], # [context={2 scenarios, one with userID and one without userID). input={'userid','query','doc(s)'} output={'query','doc(s)'}
'pairing': [None, 'docs', 'query'], # [context={2 scenarios, one with userID and one without userID). input={'userid','query','doc(s)'} output={'query','doc(s)'}
'extcorpus': 'msmarco.passage',
'lseq': {"inputs": 32, "targets": 256}, # query length and doc length for t5 model,
'filter': {'minql': 1, 'mindocl': 10} # [min query length, min doc length], after merge queries with relevant 'index_item', if |query| <= minql drop the row, if |'index_item'| < mindocl, drop row
'filter': {'minql': 1, 'mindocl': 10} # [min query length, min doc length], after merge queries with relevant 'index_item', if |query| <= minql drop the row, if |'index_item'| < mindocl, drop row
},
'robust04': {
'index': '../data/raw/robust04/lucene-index.robust04.pos+docvectors+rawdocs',
Expand Down
316 changes: 316 additions & 0 deletions src/refinement/lang_code.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,316 @@
# Lookup table: lowercase English language name -> short language code.
# Keys are listed alphabetically. The two Chinese variants use an underscore
# in the key ('chinese_simplified' / 'chinese_traditional'), whereas other
# multi-word names use a space ('haitian creole', 'scots gaelic').
# NOTE(review): the codes are presumably the ones expected by the
# `googletrans` package -- confirm against `googletrans.LANGUAGES`.
google = {
    "afrikaans": "af",
    "albanian": "sq",
    "amharic": "am",
    "arabic": "ar",
    "armenian": "hy",
    "azerbaijani": "az",
    "basque": "eu",
    "belarusian": "be",
    "bengali": "bn",
    "bosnian": "bs",
    "bulgarian": "bg",
    "catalan": "ca",
    "cebuano": "ceb",
    "chichewa": "ny",
    "chinese_simplified": "zh-cn",
    "chinese_traditional": "zh-tw",
    "corsican": "co",
    "croatian": "hr",
    "czech": "cs",
    "danish": "da",
    "dutch": "nl",
    "english": "en",
    "esperanto": "eo",
    "estonian": "et",
    "filipino": "tl",
    "finnish": "fi",
    "french": "fr",
    "frisian": "fy",
    "galician": "gl",
    "georgian": "ka",
    "german": "de",
    "greek": "el",
    "gujarati": "gu",
    "haitian creole": "ht",
    "hausa": "ha",
    "hawaiian": "haw",
    "hebrew": "iw",
    "hindi": "hi",
    "hmong": "hmn",
    "hungarian": "hu",
    "icelandic": "is",
    "igbo": "ig",
    "indonesian": "id",
    "irish": "ga",
    "italian": "it",
    "japanese": "ja",
    "javanese": "jw",
    "kannada": "kn",
    "kazakh": "kk",
    "khmer": "km",
    "korean": "ko",
    "kurdish": "ku",
    "kyrgyz": "ky",
    "lao": "lo",
    "latin": "la",
    "latvian": "lv",
    "lithuanian": "lt",
    "luxembourgish": "lb",
    "macedonian": "mk",
    "malagasy": "mg",
    "malay": "ms",
    "malayalam": "ml",
    "maltese": "mt",
    "maori": "mi",
    "marathi": "mr",
    "mongolian": "mn",
    "myanmar": "my",
    "nepali": "ne",
    "norwegian": "no",
    "odia": "or",
    "pashto": "ps",
    "persian": "fa",
    "polish": "pl",
    "portuguese": "pt",
    "punjabi": "pa",
    "romanian": "ro",
    "russian": "ru",
    "samoan": "sm",
    "scots gaelic": "gd",
    "serbian": "sr",
    "sesotho": "st",
    "shona": "sn",
    "sindhi": "sd",
    "sinhala": "si",
    "slovak": "sk",
    "slovenian": "sl",
    "somali": "so",
    "spanish": "es",
    "sundanese": "su",
    "swahili": "sw",
    "swedish": "sv",
    "tajik": "tg",
    "tamil": "ta",
    "telugu": "te",
    "thai": "th",
    "turkish": "tr",
    "ukrainian": "uk",
    "urdu": "ur",
    "uyghur": "ug",
    "uzbek": "uz",
    "vietnamese": "vi",
    "welsh": "cy",
    "xhosa": "xh",
    "yiddish": "yi",
    "yoruba": "yo",
    "zulu": "zu",
}

# Lookup table: lowercase English language name -> code of the form
# '<three-letter language>_<four-letter script>' (e.g. 'eng_Latn').
# Languages listed in more than one script carry the script in the key,
# e.g. 'acehnese (arabic script)' and 'acehnese (latin script)'.
# Entry order is kept as originally written (it is not strictly alphabetical).
# NOTE(review): these look like the FLORES-200 style identifiers used by
# NLLB translation models -- confirm against the model's language list.
nllb = {
    "acehnese (arabic script)": "ace_Arab",
    "acehnese (latin script)": "ace_Latn",
    "mesopotamian arabic": "acm_Arab",
    "ta’izzi-adeni arabic": "acq_Arab",
    "tunisian arabic": "aeb_Arab",
    "afrikaans": "afr_Latn",
    "south levantine arabic": "ajp_Arab",
    "akan": "aka_Latn",
    "amharic": "amh_Ethi",
    "north levantine arabic": "apc_Arab",
    "arabic": "arb_Arab",
    "modern standard arabic (romanized)": "arb_Latn",
    "najdi arabic": "ars_Arab",
    "moroccan arabic": "ary_Arab",
    "egyptian arabic": "arz_Arab",
    "assamese": "asm_Beng",
    "asturian": "ast_Latn",
    "awadhi": "awa_Deva",
    "central aymara": "ayr_Latn",
    "south azerbaijani": "azb_Arab",
    "north azerbaijani": "azj_Latn",
    "bashkir": "bak_Cyrl",
    "bambara": "bam_Latn",
    "balinese": "ban_Latn",
    "belarusian": "bel_Cyrl",
    "bemba": "bem_Latn",
    "bengali": "ben_Beng",
    "bhojpuri": "bho_Deva",
    "banjar (arabic script)": "bjn_Arab",
    "banjar (latin script)": "bjn_Latn",
    "standard tibetan": "bod_Tibt",
    "bosnian": "bos_Latn",
    "buginese": "bug_Latn",
    "bulgarian": "bul_Cyrl",
    "catalan": "cat_Latn",
    "cebuano": "ceb_Latn",
    "czech": "ces_Latn",
    "chokwe": "cjk_Latn",
    "central kurdish": "ckb_Arab",
    "crimean tatar": "crh_Latn",
    "welsh": "cym_Latn",
    "danish": "dan_Latn",
    "german": "deu_Latn",
    "southwestern dinka": "dik_Latn",
    "dyula": "dyu_Latn",
    "dzongkha": "dzo_Tibt",
    "greek": "ell_Grek",
    "english": "eng_Latn",
    "esperanto": "epo_Latn",
    "estonian": "est_Latn",
    "basque": "eus_Latn",
    "ewe": "ewe_Latn",
    "faroese": "fao_Latn",
    "fijian": "fij_Latn",
    "finnish": "fin_Latn",
    "fon": "fon_Latn",
    "french": "fra_Latn",
    "friulian": "fur_Latn",
    "nigerian fulfulde": "fuv_Latn",
    "scottish gaelic": "gla_Latn",
    "irish": "gle_Latn",
    "galician": "glg_Latn",
    "guarani": "grn_Latn",
    "gujarati": "guj_Gujr",
    "haitian creole": "hat_Latn",
    "hausa": "hau_Latn",
    "hebrew": "heb_Hebr",
    "hindi": "hin_Deva",
    "chhattisgarhi": "hne_Deva",
    "croatian": "hrv_Latn",
    "hungarian": "hun_Latn",
    "armenian": "hye_Armn",
    "igbo": "ibo_Latn",
    "ilocano": "ilo_Latn",
    "indonesian": "ind_Latn",
    "icelandic": "isl_Latn",
    "italian": "ita_Latn",
    "javanese": "jav_Latn",
    "japanese": "jpn_Jpan",
    "kabyle": "kab_Latn",
    "jingpho": "kac_Latn",
    "kamba": "kam_Latn",
    "kannada": "kan_Knda",
    "kashmiri (arabic script)": "kas_Arab",
    "kashmiri (devanagari script)": "kas_Deva",
    "georgian": "kat_Geor",
    "central kanuri (arabic script)": "knc_Arab",
    "central kanuri (latin script)": "knc_Latn",
    "kazakh": "kaz_Cyrl",
    "kabiyè": "kbp_Latn",
    "kabuverdianu": "kea_Latn",
    "khmer": "khm_Khmr",
    "kikuyu": "kik_Latn",
    "kinyarwanda": "kin_Latn",
    "kyrgyz": "kir_Cyrl",
    "kimbundu": "kmb_Latn",
    "northern kurdish": "kmr_Latn",
    "kikongo": "kon_Latn",
    "korean": "kor_Hang",
    "lao": "lao_Laoo",
    "ligurian": "lij_Latn",
    "limburgish": "lim_Latn",
    "lingala": "lin_Latn",
    "lithuanian": "lit_Latn",
    "lombard": "lmo_Latn",
    "latgalian": "ltg_Latn",
    "luxembourgish": "ltz_Latn",
    "luba-kasai": "lua_Latn",
    "ganda": "lug_Latn",
    "luo": "luo_Latn",
    "mizo": "lus_Latn",
    "standard latvian": "lvs_Latn",
    "magahi": "mag_Deva",
    "maithili": "mai_Deva",
    "malayalam": "mal_Mlym",
    "marathi": "mar_Deva",
    "minangkabau (arabic script)": "min_Arab",
    "minangkabau (latin script)": "min_Latn",
    "macedonian": "mkd_Cyrl",
    "plateau malagasy": "plt_Latn",
    "maltese": "mlt_Latn",
    "meitei (bengali script)": "mni_Beng",
    "halh mongolian": "khk_Cyrl",
    "mossi": "mos_Latn",
    "maori": "mri_Latn",
    "burmese": "mya_Mymr",
    "dutch": "nld_Latn",
    "norwegian nynorsk": "nno_Latn",
    "norwegian bokmål": "nob_Latn",
    "nepali": "npi_Deva",
    "northern sotho": "nso_Latn",
    "nuer": "nus_Latn",
    "nyanja": "nya_Latn",
    "occitan": "oci_Latn",
    "west central oromo": "gaz_Latn",
    "odia": "ory_Orya",
    "pangasinan": "pag_Latn",
    "eastern panjabi": "pan_Guru",
    "papiamento": "pap_Latn",
    "persian": "pes_Arab",
    "polish": "pol_Latn",
    "portuguese": "por_Latn",
    "dari": "prs_Arab",
    "southern pashto": "pbt_Arab",
    "ayacucho quechua": "quy_Latn",
    "romanian": "ron_Latn",
    "rundi": "run_Latn",
    "russian": "rus_Cyrl",
    "sango": "sag_Latn",
    "sanskrit": "san_Deva",
    "santali": "sat_Olck",
    "sicilian": "scn_Latn",
    "shan": "shn_Mymr",
    "sinhala": "sin_Sinh",
    "slovak": "slk_Latn",
    "slovenian": "slv_Latn",
    "samoan": "smo_Latn",
    "shona": "sna_Latn",
    "sindhi": "snd_Arab",
    "somali": "som_Latn",
    "southern sotho": "sot_Latn",
    "spanish": "spa_Latn",
    "tosk albanian": "als_Latn",
    "sardinian": "srd_Latn",
    "serbian": "srp_Cyrl",
    "swati": "ssw_Latn",
    "sundanese": "sun_Latn",
    "swedish": "swe_Latn",
    "swahili": "swh_Latn",
    "silesian": "szl_Latn",
    "tamil": "tam_Taml",
    "tatar": "tat_Cyrl",
    "telugu": "tel_Telu",
    "tajik": "tgk_Cyrl",
    "tagalog": "tgl_Latn",
    "thai": "tha_Thai",
    "tigrinya": "tir_Ethi",
    "tamasheq (latin script)": "taq_Latn",
    "tamasheq (tifinagh script)": "taq_Tfng",
    "tok pisin": "tpi_Latn",
    "tswana": "tsn_Latn",
    "tsonga": "tso_Latn",
    "turkmen": "tuk_Latn",
    "tumbuka": "tum_Latn",
    "turkish": "tur_Latn",
    "twi": "twi_Latn",
    "central atlas tamazight": "tzm_Tfng",
    "uyghur": "uig_Arab",
    "ukrainian": "ukr_Cyrl",
    "umbundu": "umb_Latn",
    "urdu": "urd_Arab",
    "northern uzbek": "uzn_Latn",
    "venetian": "vec_Latn",
    "vietnamese": "vie_Latn",
    "waray": "war_Latn",
    "wolof": "wol_Latn",
    "xhosa": "xho_Latn",
    "eastern yiddish": "ydd_Hebr",
    "yoruba": "yor_Latn",
    "yue chinese": "yue_Hant",
    "chinese_simplified": "zho_Hans",
    "chinese_traditional": "zho_Hant",
    "malay": "zsm_Latn",
    "zulu": "zul_Latn",
}

Loading

0 comments on commit fe3d081

Please sign in to comment.