Skip to content

Commit

Permalink
Word alignment Refinement (#151)
Browse files Browse the repository at this point in the history
* create separate files for aligning words, don't just use the training data.

* I think it's working now.

* Working test, minor updates to naming

---------

Co-authored-by: John Lambert <[email protected]>
  • Loading branch information
Enkidu93 and johnml1135 authored Jan 17, 2025
1 parent 6de035f commit 101d227
Show file tree
Hide file tree
Showing 10 changed files with 193 additions and 58 deletions.
3 changes: 2 additions & 1 deletion .devcontainer/devcontainer.json
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@
"tamasfe.even-better-toml",
"github.vscode-github-actions",
"mhutchie.git-graph",
"GitHub.copilot"
"GitHub.copilot",
"ms-toolsai.jupyter"
]
}
},
Expand Down
18 changes: 14 additions & 4 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,21 @@
},
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true,
"python.analysis.extraPaths": ["tests"],
"python.analysis.importFormat": "relative",
"[python]": {
"editor.defaultFormatter": "ms-python.black-formatter",
"editor.formatOnSave": true
"editor.formatOnSave": true,
"editor.codeActionsOnSave": {
"source.organizeImports": "explicit"
},
},
"black-formatter.path": ["poetry", "run", "black"]
}
"black-formatter.path": [
"poetry",
"run",
"black"
],
"isort.args": [
"--profile",
"black"
],
}
27 changes: 27 additions & 0 deletions machine/corpora/parallel_text_corpus.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,13 @@ def from_hf_dataset(
ref_factory,
)

@classmethod
def from_parallel_rows(
cls,
rows: Iterable[ParallelTextRow],
) -> ParallelTextCorpus:
return _FromParallelRowsTextCorpus(rows)

@property
@abstractmethod
def is_source_tokenized(self) -> bool: ...
Expand Down Expand Up @@ -754,3 +761,23 @@ def _get_translation(self, lang: str, example: dict) -> str:
except ValueError:
return ""
return ""


class _FromParallelRowsTextCorpus(ParallelTextCorpus):
def __init__(self, rows: Iterable[ParallelTextRow]) -> None:
self._rows = rows

def _get_rows(self, text_ids: Optional[Iterable[str]] = None) -> Generator[ParallelTextRow, None, None]:
if text_ids is None:
yield from self._rows
else:
text_ids = set(text_ids)
yield from [row for row in self._rows if row.text_id in text_ids]

@property
def is_source_tokenized(self) -> bool:
return True

@property
def is_target_tokenized(self) -> bool:
return True
2 changes: 1 addition & 1 deletion machine/jobs/build_word_alignment_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def run(args: dict):


def main() -> None:
parser = argparse.ArgumentParser(description="Trains an SMT model.")
parser = argparse.ArgumentParser(description="Trains a word alignment model.")
parser.add_argument("--model-type", required=True, type=str, help="Model type")
parser.add_argument("--build-id", required=True, type=str, help="Build id")
parser.add_argument("--clearml", default=False, action="store_true", help="Initializes a ClearML task")
Expand Down
17 changes: 16 additions & 1 deletion machine/jobs/shared_file_service_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,21 @@
from typing import Any, Iterator, TextIO


class PrettyFloat(float):
def __repr__(self):
return "%.8g" % self


def pretty_floats(obj):
if isinstance(obj, float):
return PrettyFloat(obj)
elif isinstance(obj, dict):
return dict((k, pretty_floats(v)) for k, v in obj.items())
elif isinstance(obj, (list, tuple)):
return list(map(pretty_floats, obj))
return obj


class DictToJsonWriter:
def __init__(self, file: TextIO) -> None:
self._file = file
Expand All @@ -12,7 +27,7 @@ def __init__(self, file: TextIO) -> None:
def write(self, pi: object) -> None:
if not self._first:
self._file.write(",\n")
self._file.write(" " + json.dumps(pi))
self._file.write(" " + json.dumps(pretty_floats(pi)))
self._first = False


Expand Down
52 changes: 37 additions & 15 deletions machine/jobs/word_alignment_build_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from ..corpora.aligned_word_pair import AlignedWordPair
from ..corpora.parallel_text_corpus import ParallelTextCorpus
from ..corpora.parallel_text_row import ParallelTextRow
from ..tokenization.tokenizer_factory import create_tokenizer
from ..utils.phased_progress_reporter import Phase, PhasedProgressReporter
from ..utils.progress_status import ProgressStatus
Expand Down Expand Up @@ -50,7 +51,8 @@ def run(
check_canceled()

logger.info("Generating alignments")
self._batch_inference(parallel_corpus, progress_reporter, check_canceled)

self._batch_inference(progress_reporter, check_canceled)

self._save_model()
return train_corpus_size
Expand Down Expand Up @@ -83,39 +85,59 @@ def _train_model(

def _batch_inference(
self,
parallel_corpus: ParallelTextCorpus,
progress_reporter: PhasedProgressReporter,
check_canceled: Optional[Callable[[], None]],
) -> None:
inference_step_count = parallel_corpus.count(include_empty=False)

inference_inputs = self._word_alignment_file_service.get_word_alignment_inputs()

inference_step_count = len(inference_inputs)

with ExitStack() as stack:
phase_progress = stack.enter_context(progress_reporter.start_next_phase())
alignment_model = stack.enter_context(self._word_alignment_model_factory.create_alignment_model())
writer = stack.enter_context(self._word_alignment_file_service.open_target_alignment_writer())
writer = stack.enter_context(self._word_alignment_file_service.open_alignment_output_writer())
current_inference_step = 0
phase_progress(ProgressStatus.from_step(current_inference_step, inference_step_count))
batch_size = self._config["inference_batch_size"]
segment_batch = list(parallel_corpus.lowercase().tokenize(self._tokenizer).take(batch_size))

parallel_corpus = ParallelTextCorpus.from_parallel_rows(
[
ParallelTextRow(
ii["textId"],
ii["refs"],
ii["refs"],
list(self._tokenizer.tokenize(ii["source"])),
list(self._tokenizer.tokenize(ii["target"])),
)
for ii in inference_inputs
]
).lowercase()

segment_batch = list(parallel_corpus.take(batch_size))
if check_canceled is not None:
check_canceled()
alignments = alignment_model.align_batch(segment_batch)
if check_canceled is not None:
check_canceled()

def format_score(score: float) -> str:
return f"{score:.8f}".rstrip("0").rstrip(".")

for row, alignment in zip(parallel_corpus.get_rows(), alignments):
source_segment = list(self._tokenizer.tokenize(row.source_text))
target_segment = list(self._tokenizer.tokenize(row.target_text))
for parallel_text_row, inference_input, alignment in zip(
parallel_corpus.get_rows(), inference_inputs, alignments
):
word_pairs = alignment.to_aligned_word_pairs(include_null=True)
alignment_model.compute_aligned_word_pair_scores(source_segment, target_segment, word_pairs)
alignment_model.compute_aligned_word_pair_scores(
parallel_text_row.source_segment, parallel_text_row.target_segment, word_pairs
)

word_alignment_info = {
"refs": [str(ref) for ref in row.source_refs],
"column_count": alignment.column_count,
"row_count": alignment.row_count,
"corpus_id": inference_input["corpusId"],
"text_id": inference_input["textId"],
"refs": [str(ref) for ref in inference_input["refs"]],
"source_tokens": parallel_text_row.source_segment,
"target_tokens": parallel_text_row.target_segment,
"confidences": [
word_pair.alignment_score * word_pair.translation_score for word_pair in word_pairs
],
"alignment": AlignedWordPair.to_string(word_pairs),
}
writer.write(word_alignment_info)
Expand Down
44 changes: 39 additions & 5 deletions machine/jobs/word_alignment_file_service.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,38 @@
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Iterator
from typing import Any, Iterator, List, TypedDict

import json_stream

from ..corpora.text_corpus import TextCorpus
from ..corpora.text_file_text_corpus import TextFileTextCorpus
from .shared_file_service_base import DictToJsonWriter, SharedFileServiceBase
from .shared_file_service_factory import SharedFileServiceType, get_shared_file_service


class WordAlignmentInput(TypedDict):
corpusId: str # noqa: N815
textId: str # noqa: N815
refs: List[str]
source: str
target: str


class WordAlignmentFileService:
def __init__(
self,
type: SharedFileServiceType,
config: Any,
source_filename: str = "train.src.txt",
target_filename: str = "train.trg.txt",
word_alignment_filename: str = "word_alignments.json",
word_alignment_input_filename: str = "word_alignments.inputs.json",
word_alignment_output_filename: str = "word_alignments.outputs.json",
) -> None:

self._source_filename = source_filename
self._target_filename = target_filename
self._word_alignment_filename = word_alignment_filename
self._word_alignment_input_filename = word_alignment_input_filename
self._word_alignment_output_filename = word_alignment_output_filename

self.shared_file_service: SharedFileServiceBase = get_shared_file_service(type, config)

Expand All @@ -34,15 +46,37 @@ def create_target_corpus(self) -> TextCorpus:
self.shared_file_service.download_file(f"{self.shared_file_service.build_path}/{self._target_filename}")
)

def get_word_alignment_inputs(self) -> List[WordAlignmentInput]:
src_pretranslate_path = self.shared_file_service.download_file(
f"{self.shared_file_service.build_path}/{self._word_alignment_input_filename}"
)
with src_pretranslate_path.open("r", encoding="utf-8-sig") as file:
wa_inputs = [
WordAlignmentInput(
corpusId=pi["corpusId"],
textId=pi["textId"],
refs=list(pi["refs"]),
source=pi["source"],
target=pi["target"],
)
for pi in json_stream.load(file)
]
return wa_inputs

def exists_source_corpus(self) -> bool:
return self.shared_file_service._exists_file(f"{self.shared_file_service.build_path}/{self._source_filename}")

def exists_target_corpus(self) -> bool:
return self.shared_file_service._exists_file(f"{self.shared_file_service.build_path}/{self._target_filename}")

def exists_word_alignment_inputs(self) -> bool:
return self.shared_file_service._exists_file(
f"{self.shared_file_service.build_path}/{self._word_alignment_input_filename}"
)

def save_model(self, model_path: Path, destination: str) -> None:
self.shared_file_service.upload_path(model_path, destination)

@contextmanager
def open_target_alignment_writer(self) -> Iterator[DictToJsonWriter]:
return self.shared_file_service.open_target_writer(self._word_alignment_filename)
def open_alignment_output_writer(self) -> Iterator[DictToJsonWriter]:
return self.shared_file_service.open_target_writer(self._word_alignment_output_filename)
56 changes: 28 additions & 28 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ charset-normalizer = "^2.1.1"
urllib3 = "<2"

sentencepiece = "^0.2.0"
sil-thot = "^3.4.4"
sil-thot = "^3.4.6"

transformers = ">=4.38.0,<4.46"
datasets = "^2.4.0"
Expand Down
Loading

0 comments on commit 101d227

Please sign in to comment.