From 995c2f782e85f847d1f6c73d4137d54ef4a22073 Mon Sep 17 00:00:00 2001 From: Kevin Heffernan Date: Thu, 30 Nov 2023 10:29:59 -0800 Subject: [PATCH] pxsim release --- README.md | 3 +- source/pxsim.py | 251 ++++++++++++++++++++++++++++++++++++++++++ source/xsim.py | 4 +- tasks/pxsim/README.md | 27 +++++ tasks/pxsim/eval.sh | 98 +++++++++++++++++ 5 files changed, 380 insertions(+), 3 deletions(-) create mode 100644 source/pxsim.py create mode 100644 tasks/pxsim/README.md create mode 100755 tasks/pxsim/eval.sh diff --git a/README.md b/README.md index 526f9632..07c4954b 100644 --- a/README.md +++ b/README.md @@ -3,10 +3,11 @@ LASER is a library to calculate and use multilingual sentence embeddings. **NEWS** +* 2023/11/30 Released [**P-xSIM**](tasks/pxsim), a dual approach extension to multilingual similarity search (xSIM) * 2023/11/16 Released [**laser_encoders**](laser_encoders), a pip-installable package supporting LASER-2 and LASER-3 models * 2023/06/26 [**xSIM++**](https://arxiv.org/abs/2306.12907) evaluation pipeline and data [**released**](tasks/xsimplusplus/README.md) * 2022/07/06 Updated LASER models with support for over 200 languages are [**now available**](nllb/README.md) -* 2022/07/06 Multilingual similarity search (**xsim**) evaluation pipeline [**released**](tasks/xsim/README.md) +* 2022/07/06 Multilingual similarity search (**xSIM**) evaluation pipeline [**released**](tasks/xsim/README.md) * 2022/05/03 [**Librivox S2S is available**](tasks/librivox-s2s): Speech-to-Speech translations automatically mined in Librivox [9] * 2019/11/08 [**CCMatrix is available**](tasks/CCMatrix): Mining billions of high-quality parallel sentences on the WEB [8] * 2019/07/31 Gilles Bodard and Jérémy Rapin provided a [**Docker environment**](docker) to use LASER diff --git a/source/pxsim.py b/source/pxsim.py new file mode 100644 index 00000000..be94769b --- /dev/null +++ b/source/pxsim.py @@ -0,0 +1,251 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# LASER Language-Agnostic SEntence Representations +# is a toolkit to calculate multilingual sentence embeddings +# and to use them for various tasks such as document classification, +# and bitext filtering +# +# -------------------------------------------------------- +# +# Tool to calculate the dual approach multilingual similarity error rate (P-xSIM) + +import typing as tp +from pathlib import Path + +import faiss +import numpy as np +import torch +from scipy.special import softmax +from sklearn.metrics.pairwise import cosine_similarity +from stopes.eval.auto_pcp.audio_comparator import Comparator, get_model_pred +from xsim import Margin, score_margin + + +def get_neighbors( + x: np.ndarray, y: np.ndarray, k: int, margin: str +) -> tp.Tuple[np.ndarray, np.ndarray, int]: + x_copy = x.astype(np.float32).copy() + y_copy = y.astype(np.float32).copy() + nbex, dim = x.shape + # create index + idx_x = faiss.IndexFlatIP(dim) + idx_y = faiss.IndexFlatIP(dim) + # L2 normalization needed for cosine distance + faiss.normalize_L2(x_copy) + faiss.normalize_L2(y_copy) + idx_x.add(x_copy) + idx_y.add(y_copy) + if margin == Margin.ABSOLUTE.value: + scores, indices = idx_y.search(x_copy, k) + else: + # return cosine similarity and indices of k closest neighbors + Cos_xy, Idx_xy = idx_y.search(x_copy, k) + Cos_yx, Idx_yx = idx_x.search(y_copy, k) + + # average cosines + Avg_xy = Cos_xy.mean(axis=1) + Avg_yx = Cos_yx.mean(axis=1) + + scores = score_margin(Cos_xy, Idx_xy, Avg_xy, Avg_yx, margin, k) + indices = Idx_xy + return scores, indices, nbex + + +def get_cosine_scores(src_emb: np.ndarray, neighbor_embs: np.ndarray) -> np.ndarray: + assert src_emb.shape[0] == neighbor_embs.shape[1] + src_embs = np.repeat( + np.expand_dims(src_emb, axis=0), neighbor_embs.shape[0], axis=0 + ) + cosine_scores = cosine_similarity(src_embs, neighbor_embs).diagonal() + return cosine_scores + + +def get_comparator_scores( + src_emb: np.ndarray, + neighbor_embs: np.ndarray, + comparator_model: tp.Any, + symmetrize_comparator: bool, +) -> np.ndarray: + src_embs = np.repeat( + np.expand_dims(src_emb, axis=0), neighbor_embs.shape[0], axis=0 + ) + a = torch.from_numpy(src_embs).unsqueeze(1) # restore depth dim + b = torch.from_numpy(neighbor_embs).unsqueeze(1) + res = get_comparator_preds(a, b, comparator_model, symmetrize_comparator) + scores_softmax = softmax(res) + return np.array(scores_softmax) + + +def get_comparator_preds( + src_emb: np.ndarray, tgt_emb: np.ndarray, model: tp.Any, symmetrize: bool +): + preds = ( + get_model_pred( + model, + src=src_emb[:, 0], + mt=tgt_emb[:, 0], + use_gpu=model.use_gpu, + batch_size=1, + )[:, 0] + .cpu() + .numpy() + ) + if symmetrize: + preds2 = ( + get_model_pred( + model, + src=tgt_emb[:, 0], + mt=src_emb[:, 0], + use_gpu=model.use_gpu, + batch_size=1, + )[:, 0] + .cpu() + .numpy() + ) + preds = (preds2 + preds) / 2 + return preds + + +def get_blended_predictions( + alpha: float, + nbex: int, + margin_scores: np.ndarray, + x_aux: np.ndarray, + y_aux: np.ndarray, + neighbor_indices: np.ndarray, + comparator_model: tp.Optional[tp.Any] = None, + symmetrize_comparator: bool = False, +) -> list[int]: + predictions = [] + for src_index in range(nbex): + neighbors = neighbor_indices[src_index] + neighbor_embs = y_aux[neighbors].astype(np.float32) + src_emb = x_aux[src_index].astype(np.float32) + aux_scores = ( + get_comparator_scores( + src_emb, neighbor_embs, comparator_model, symmetrize_comparator + ) + if comparator_model + else get_cosine_scores(src_emb, neighbor_embs) + ) + assert margin_scores[src_index].shape == aux_scores.shape + blended_scores = alpha * margin_scores[src_index] + (1 - alpha) * aux_scores + blended_neighbor_idx = blended_scores.argmax() + predictions.append(neighbors[blended_neighbor_idx]) + return predictions + + +def PxSIM( + x: np.ndarray, + y: np.ndarray, + x_aux: np.ndarray, + y_aux: np.ndarray, + alpha: float, + margin: str = Margin.RATIO.value, + k: int = 16, + comparator_path: tp.Optional[Path] = None, + symmetrize_comparator: bool = False, +) -> tp.Tuple[int, int, list[int]]: + """ + Parameters + ---------- + x : np.ndarray + source-side embedding array + y : np.ndarray + target-side embedding array + x_aux : np.ndarray + source-side embedding array using auxiliary model + y_aux : np.ndarray + target-side embedding array using auxiliary model + alpha : int + parameter to weight blended score + margin : str + margin scoring function (e.g. ratio, absolute, distance) + k : int + number of neighbors in k-nn search + comparator_path : Path + path to AutoPCP model config + symmetrize_comparator : bool + whether to symmetrize the comparator predictions + + Returns + ------- + err : int + Number of errors + nbex : int + Number of examples + preds : list[int] + List of (index-based) predictions + """ + assert Margin.has_value(margin), f"Margin type: {margin}, is not supported." + comparator_model = Comparator.load(comparator_path) if comparator_path else None + # get margin-based nearest neighbors + margin_scores, neighbor_indices, nbex = get_neighbors(x, y, k=k, margin=margin) + preds = get_blended_predictions( + alpha, + nbex, + margin_scores, + x_aux, + y_aux, + neighbor_indices, + comparator_model, + symmetrize_comparator, + ) + err = sum([idx != pred for idx, pred in enumerate(preds)]) + print(f"P-xSIM error: {100 * (err / nbex):.2f}") + return err, nbex, preds + + +def load_embeddings( + infile: Path, dim: int, fp16: bool = False, numpy_header: bool = False +) -> np.ndarray: + assert infile.exists(), f"file: {infile} does not exist." + if numpy_header: + return np.load(infile) + emb = np.fromfile(infile, dtype=np.float16 if fp16 else np.float32) + num_examples = emb.shape[0] // dim + emb.resize(num_examples, dim) + if fp16: + emb = emb.astype(np.float32) # faiss currently only supports fp32 + return emb + + +def run( + src_emb: Path, + tgt_emb: Path, + src_aux_emb: Path, + tgt_aux_emb: Path, + alpha: float, + margin: str = Margin.RATIO.value, + k: int = 16, + emb_fp16: bool = False, + aux_emb_fp16: bool = False, + emb_dim: int = 1024, + aux_emb_dim: int = 1024, + numpy_header: bool = False, + comparator_path: tp.Optional[Path] = None, + symmetrize_comparator: bool = False, + prediction_savepath: tp.Optional[Path] = None, +) -> None: + x = load_embeddings(src_emb, emb_dim, emb_fp16, numpy_header) + y = load_embeddings(tgt_emb, emb_dim, emb_fp16, numpy_header) + x_aux = load_embeddings(src_aux_emb, aux_emb_dim, aux_emb_fp16, numpy_header) + y_aux = load_embeddings(tgt_aux_emb, aux_emb_dim, aux_emb_fp16, numpy_header) + assert (x.shape == y.shape) and (x_aux.shape == y_aux.shape) + _, _, preds = PxSIM( + x, y, x_aux, y_aux, alpha, margin, k, comparator_path, symmetrize_comparator + ) + if prediction_savepath: + with open(prediction_savepath, "w") as outf: + for pred in preds: + print(pred, file=outf) + + +if __name__ == "__main__": + import func_argparse + + func_argparse.main() diff --git a/source/xsim.py b/source/xsim.py index 031f60cb..d87123ae 100644 --- a/source/xsim.py +++ b/source/xsim.py @@ -60,7 +60,7 @@ def _load_embeddings(infile: str, dim: int, fp16: bool = False) -> np.ndarray: return emb -def _score_margin( +def score_margin( Dxy: np.ndarray, Ixy: np.ndarray, Ax: np.ndarray, @@ -103,7 +103,7 @@ def _score_knn(x: np.ndarray, y: np.ndarray, k: int, margin: str) -> np.ndarray: Avg_xy = Cos_xy.mean(axis=1) Avg_yx = Cos_yx.mean(axis=1) - scores = _score_margin(Cos_xy, Idx_xy, Avg_xy, Avg_yx, margin, k) + scores = score_margin(Cos_xy, Idx_xy, Avg_xy, Avg_yx, margin, k) # find best best = scores.argmax(axis=1) diff --git a/tasks/pxsim/README.md b/tasks/pxsim/README.md new file mode 100644 index 00000000..e9b39b3b --- /dev/null +++ b/tasks/pxsim/README.md @@ -0,0 +1,27 @@ +# LASER: P-xSIM (dual approach multilingual similarity error rate) + +This README shows how to calculate the P-xSIM error rate (Seamless Communication et al., 2023) for a given language pair. + +P-xSIM returns the error rate for recreating gold alignments using a blended combination of two different approaches. +It works by performing a k-nearest-neighbor search and margin calculation (i.e. margin-based parallel alignment) using the +first approach, followed by the scoring of each candidate neighbor using an auxiliary model (the second approach). Finally, +the scores of both the margin-based alignment and the auxiliary model are combined together using a blended score defined as: + +$$ \text{blended-score}(x, y) = \alpha \cdot \text{margin} + (1 - \alpha) \cdot \text{auxiliary-score} $$ + +where the parameter $\alpha$ controls the combination of both the margin-based and auxiliary scores. By default, the auxiliary-score will be calculated as the cosine between the source and candidate neighbors using the auxiliary embeddings. However, there is also an option to perform inference using a comparator model (Seamless Communication et al., 2023). In this instance, the auxiliary-score will be the AutoPCP outputs. + +P-xSIM offers three margin-based scoring options (discussed in detail [here](https://arxiv.org/pdf/1811.01136.pdf)): +- distance +- ratio +- absolute + +## Example usage + +Simply run the example script `bash ./eval.sh` to download a sample dataset (flores200), sample encoders (laser2 and LaBSE), +and then perform P-xSIM. In this toy example, we use laser2 to provide the k-nearest-neighbors, followed by applying LaBSE as an +auxiliary model on each candidate neighbor, before then applying the blended scoring function defined above. Dependending on +your data sources, you may want to alter the approach used for either margin-based parallel alignment, or the scoring of each candidate neighbor +(i.e. the auxiliary model). + +In addition to LaBSE in the example above, you can also calculate P-xSIM using any model hosted on [HuggingFace sentence-transformers](https://huggingface.co/sentence-transformers). diff --git a/tasks/pxsim/eval.sh b/tasks/pxsim/eval.sh new file mode 100755 index 00000000..336cf949 --- /dev/null +++ b/tasks/pxsim/eval.sh @@ -0,0 +1,98 @@ +#!/bin/bash +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# LASER Language-Agnostic SEntence Representations +# is a toolkit to calculate multilingual sentence embeddings +# and to use them for various tasks such as document classification, +# and bitext filtering +# +#------------------------------------------------------- +# +# This bash script downloads the flores200 dataset, laser2, and then +# performs pxsim evaluation + +if [ -z ${LASER} ] ; then + echo "Please set the environment variable 'LASER'" + exit +fi + +ddir="${LASER}/data" +cd $ddir # move to data directory + +if [ ! -d $ddir/flores200 ] ; then + echo " - Downloading flores200..." + wget --trust-server-names -q https://tinyurl.com/flores200dataset + tar -xf flores200_dataset.tar.gz + /bin/mv flores200_dataset flores200 + /bin/rm flores200_dataset.tar.gz +else + echo " - flores200 already downloaded" +fi + +cd - + +mdir="${LASER}/models" +if [ ! -d ${mdir} ] ; then + echo " - creating model directory: ${mdir}" + mkdir -p ${mdir} +fi + +function download { + file=$1 + save_dir=$2 + if [ -f ${save_dir}/${file} ] ; then + echo " - ${save_dir}/$file already downloaded"; + else + cd $save_dir + echo " - Downloading $s3/${file}"; + wget -q $s3/${file}; + cd - + fi +} + +# available encoders +s3="https://dl.fbaipublicfiles.com/nllb/laser" + +if [ ! -f ${mdir}/laser2.pt ] ; then + cd $mdir + echo " - Downloading $s3/laser2.pt" + wget --trust-server-names -q https://tinyurl.com/nllblaser2 + cd - +else + echo " - ${mdir}/laser2.pt already downloaded" +fi +download "laser2.spm" $mdir +download "laser2.cvocab" $mdir + +# encode FLORES200 texts using both LASER2 and LaBSE +for lang in eng_Latn wol_Latn; do + infile=$LASER/data/flores200/devtest/$lang.devtest + python3 ${LASER}/source/embed.py \ + --input $infile \ + --encoder $mdir/laser2.pt \ + --spm-model $mdir/laser2.spm \ + --output $lang.devtest.laser2 \ + --verbose + + python3 ${LASER}/source/embed.py \ + --input $infile \ + --encoder LaBSE \ + --use-hugging-face \ + --output $lang.devtest.labse \ + --verbose +done + +# run pxsim using LaBSE as an auxiliary scoring model +echo " - calculating p-xsim" +python3 $LASER/source/pxsim.py run \ + --src_emb wol_Latn.devtest.laser2 \ + --tgt_emb eng_Latn.devtest.laser2 \ + --src_aux_emb wol_Latn.devtest.labse \ + --tgt_aux_emb eng_Latn.devtest.labse \ + --alpha 0.1 \ + --k 32 \ + --aux_emb_dim 768