
Ensure laser_encoders has parity with existing LASER inference code for release #268

Merged: 2 commits into MLH-dev on Nov 20, 2023

Conversation

heffernankevin (Contributor) commented Nov 17, 2023

Why?

Before releasing laser_encoders we need to ensure that it has parity (within a reasonable tolerance) with the existing LASER inference code. This PR updates laser_encoders to restore parity, then rigorously compares inference from both laser_encoders and the older LASER code side by side, for both LASER(2) and LASER(3), on all of FLORES (approx. 0.5M lines of text). Results show embedding parity (atol=1e-3) with the exception of one line in shn_Mymr.devtest for LASER(2), i.e., we have parity on all of FLORES except for one sentence, in one language, on LASER(2). I'm not sure exactly what causes that one difference, but I think our current parity is acceptable.
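To make the tolerance concrete, here is a minimal sketch of what the atol=1e-3 comparison means for a pair of embedding rows (the values below are made up for illustration):

```python
import numpy as np

# hypothetical embedding rows; "new" differs from "old" by less than atol
old = np.array([0.1234, -0.5678, 0.9012], dtype=np.float32)
new = old + 5e-4  # off by 0.0005, within atol=1e-3

# np.allclose checks |old - new| <= atol + rtol * |new| elementwise
print(np.allclose(old, new, atol=1e-3, rtol=0.0))          # True
print(np.allclose(old, old + 2e-3, atol=1e-3, rtol=0.0))   # False
```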

How

  1. Resolved parity with script: remove-non-printing-char.perl

The existing code called the Python string method isprintable(). However, the MOSES perl script remove-non-printing-char.perl instead targets a range of specific Unicode characters falling under the general category "C" (other). Updated the code to match this behavior and added a new dependency on the unicategories library in pyproject.toml.
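For illustration, the category-"C" filter can be sketched with the stdlib unicodedata module (the PR itself depends on the unicategories library, and the exact replacement behavior should be checked against the perl script; the sketch below assumes each such character is replaced with a space):

```python
import unicodedata

def remove_non_printing_chars(line: str) -> str:
    """Replace characters in Unicode general category "C" (Cc control,
    Cf format, Co private use, Cs surrogate, Cn unassigned) with a space,
    in the spirit of MOSES remove-non-printing-char.perl."""
    return "".join(
        " " if unicodedata.category(ch).startswith("C") else ch for ch in line
    )

print(remove_non_printing_chars("foo\u200bbar"))  # "foo bar" (zero-width space is Cf)
```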

  2. Resolved parity with MOSES 4.0 release of punctuation-normalization.perl

Siddharth's update to sacremoses resolves parity with the current version of MOSES; unfortunately, LASER uses a specific deprecated version: 4.0. Instead of requesting another update to sacremoses to support that version, we update the regexes ourselves and freeze the sacremoses version in pyproject.toml.
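As a rough illustration of the kind of regex-based normalization involved, here is a small, hypothetical subset of rules in the spirit of punctuation-normalization.perl (not the actual MOSES 4.0 rule set, which is much larger):

```python
import re

# a few illustrative rules; the real perl script (and the pinned 4.0
# version it targets) contains many more substitutions than these
RULES = [
    (re.compile("[\u201c\u201d\u201e]"), '"'),  # double curly/low quotes -> "
    (re.compile("[\u2018\u2019]"), "'"),        # single curly quotes -> '
    (re.compile("\u2026"), "..."),              # ellipsis -> three dots
    (re.compile(" {2,}"), " "),                 # collapse repeated spaces
]

def normalize_punctuation(text: str) -> str:
    for pattern, replacement in RULES:
        text = pattern.sub(replacement, text)
    return text

print(normalize_punctuation("\u201cwait\u2026\u201d"))  # "wait..."
```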

  3. Updated test for laser_tokenizer

Test plans

  1. Check correctness against LASER2
    python parity.py run_comparison_parallel --laser_type laser2

  2. Check correctness against LASER3
    python parity.py run_comparison_parallel --laser_type laser3

Code below for parity.py (perhaps we should check this in somewhere?).

from pathlib import Path
from laser_encoders import LaserEncoderPipeline
from source.embed import embed_sentences
import numpy as np
import submitit

FLORES = Path("/path/to/flores200_dataset")
MODEL_DIR = Path("/path/to/model/dir")
LASER2_ENCODER = MODEL_DIR / "laser2.pt"
LASER2_SPM = MODEL_DIR / "laser2.spm"
LASER3_ENCODER = MODEL_DIR / "laser3-fuv.v1.pt"
LASER3_SPM = MODEL_DIR / "laser3-fuv.v1.spm"
OUTPUT_DIR = Path("/path/to/out_dir")


def run_old_laser(input_file: Path, laser_type: str) -> np.ndarray:
    encoder, spm_model = (
        (LASER2_ENCODER, LASER2_SPM)
        if laser_type == "laser2"
        else (LASER3_ENCODER, LASER3_SPM)
    )
    output_dir = OUTPUT_DIR / "old_laser" / laser_type
    output_dir.mkdir(parents=True, exist_ok=True)  # ensure output path exists
    output_file = output_dir / input_file.name
    embed_sentences(
        ifname=str(input_file),
        encoder_path=str(encoder),
        spm_model=str(spm_model),
        token_lang="--",
        bpe_codes=None,
        spm_lang="en",
        hugging_face=False,
        verbose=True,
        output=output_file,
        buffer_size=10000,
        max_tokens=12000,
        max_sentences=None,
        cpu=False,
        fp16=False,
        sort_kind="quicksort",
    )
    assert output_file.exists()
    # embeddings are stored as a flat float32 binary file; reshape to (n, 1024)
    dim = 1024
    X = np.fromfile(output_file, dtype=np.float32, count=-1)
    X.resize(X.shape[0] // dim, dim)
    return X


def run_new_laser(input_file: Path, laser_type: str) -> np.ndarray:
    lang = "eng" if laser_type == "laser2" else "fuv"
    encoder_pipeline = LaserEncoderPipeline(
        lang=lang, model_dir=str(MODEL_DIR), laser=laser_type
    )
    sentences = [line.rstrip("\n") for line in open(input_file)]
    # save the tokenized output for side-by-side inspection
    output_dir = OUTPUT_DIR / "new_laser" / laser_type
    output_dir.mkdir(parents=True, exist_ok=True)
    tokenized = [encoder_pipeline.tokenizer.tokenize(sent) for sent in sentences]
    with open(output_dir / f"{input_file.name}.spm", "w") as output_file:
        for tokenized_sent in tokenized:
            output_file.write(tokenized_sent + "\n")
    embeddings = encoder_pipeline.encode_sentences(sentences)
    return embeddings


def compare_laser_versions(input_file: Path, laser_type: str) -> bool:
    old_laser_embeddings = run_old_laser(input_file, laser_type=laser_type)
    new_laser_embeddings = run_new_laser(input_file, laser_type=laser_type)
    assert old_laser_embeddings.shape == new_laser_embeddings.shape
    return np.allclose(old_laser_embeddings, new_laser_embeddings, atol=1e-3)


def run_comparison_parallel(laser_type: str):
    assert laser_type in ["laser2", "laser3"]
    all_files = [
        file for ext in ["dev", "devtest"] for file in FLORES.glob(f"*/*.{ext}")
    ]
    njobs = len(all_files)
    submitit_folder = OUTPUT_DIR / "submitit_logs"
    executor = submitit.AutoExecutor(folder=submitit_folder)
    executor.update_parameters(
        slurm_partition="nllb,devaccel,learnaccel",
        slurm_array_parallelism=njobs,
        cpus_per_task=4,
        timeout_min=30,
        tasks_per_node=1,
        nodes=1,
        gpus_per_node=1,
    )
    print(f"sending job array for {njobs} items")
    jobs = executor.map_array(
        compare_laser_versions,
        all_files,
        [laser_type] * njobs,
    )
    print("gathering results")
    results = [job.result() for job in jobs]
    for index, file in enumerate(all_files):
        if not results[index]:
            print(f"not close on file: {file}")
    print(
        f"total files: {len(all_files)}, num true: {results.count(True)}, num false: {results.count(False)}"
    )


if __name__ == "__main__":
    import func_argparse

    func_argparse.main()

@heffernankevin heffernankevin merged commit b4aed58 into MLH-dev Nov 20, 2023
3 checks passed
@heffernankevin heffernankevin deleted the fix-parity branch November 20, 2023 15:29