Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
johnml1135 committed Nov 21, 2024
1 parent 7d3b138 commit dd337ab
Show file tree
Hide file tree
Showing 3 changed files with 193 additions and 163 deletions.
30 changes: 17 additions & 13 deletions src/SIL.Machine/Translation/IWordAlignmentEngine.cs
Original file line number Diff line number Diff line change
@@ -1,26 +1,30 @@
using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using SIL.Machine.Corpora;
using SIL.ObjectModel;

namespace SIL.Machine.Translation
{
public interface IWordAlignmentEngine : IWordAligner, IDisposable
{
Task<WordAlignmentResult> GetBestAlignmentAsync(
string sourceSegment,
string targetSegment,
CancellationToken cancellationToken = default
);
IWordVocabulary SourceWords { get; }
IWordVocabulary TargetWords { get; }
IReadOnlySet<int> SpecialSymbolIndices { get; }

IEnumerable<(string TargetWord, double Score)> GetTranslations(string sourceWord, double threshold = 0);
IEnumerable<(int TargetWordIndex, double Score)> GetTranslations(int sourceWordIndex, double threshold = 0);

Task<WordAlignmentResult> GetBestAlignmentAsync(
double GetTranslationScore(string sourceWord, string targetWord);
double GetTranslationScore(int sourceWordIndex, int targetWordIndex);

IReadOnlyCollection<AlignedWordPair> GetBestAlignedWordPairs(
IReadOnlyList<string> sourceSegment,
IReadOnlyList<string> targetSegment
);
void ComputeAlignedWordPairScores(
IReadOnlyList<string> sourceSegment,
IReadOnlyList<string> targetSegment,
CancellationToken cancellationToken = default
IReadOnlyCollection<AlignedWordPair> wordPairs
);

WordAlignmentResult GetBestAlignment(string sourceSegment, string targetSegment);

WordAlignmentResult GetBestAlignment(IReadOnlyList<string> sourceSegment, IReadOnlyList<string> targetSegment);
}
}
173 changes: 173 additions & 0 deletions src/SIL.Machine/Translation/SymmetrizedWordAlignmentEngine.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
using System;
using System.Collections.Generic;
using System.Linq;
using SIL.Machine.Corpora;
using SIL.ObjectModel;

namespace SIL.Machine.Translation
{
/// <summary>
/// Word alignment engine that wraps a direct engine and an inverse engine and combines their results.
/// Alignments are delegated to a <see cref="SymmetrizedWordAligner"/> built from the two engines, and
/// every score is the larger of the direct score and the corresponding inverse score (queried with the
/// source and target arguments swapped).
/// </summary>
/// <remarks>
/// This engine disposes both wrapped engines when it is disposed.
/// </remarks>
public class SymmetrizedWordAlignmentEngine : DisposableBase, IWordAlignmentEngine
{
    private readonly IWordAlignmentEngine _directWordAlignmentEngine;
    private readonly IWordAlignmentEngine _inverseWordAlignmentEngine;
    private readonly SymmetrizedWordAligner _aligner;

    /// <summary>
    /// Initializes a new instance of the <see cref="SymmetrizedWordAlignmentEngine"/> class.
    /// </summary>
    /// <param name="directWordAlignmentEngine">The engine queried with (source, target) arguments.</param>
    /// <param name="inverseWordAlignmentEngine">The engine queried with (target, source) arguments.</param>
    /// <exception cref="ArgumentNullException">Either engine is <c>null</c>.</exception>
    public SymmetrizedWordAlignmentEngine(
        IWordAlignmentEngine directWordAlignmentEngine,
        IWordAlignmentEngine inverseWordAlignmentEngine
    )
    {
        // Fail fast on null instead of surfacing a NullReferenceException on first use.
        _directWordAlignmentEngine =
            directWordAlignmentEngine ?? throw new ArgumentNullException(nameof(directWordAlignmentEngine));
        _inverseWordAlignmentEngine =
            inverseWordAlignmentEngine ?? throw new ArgumentNullException(nameof(inverseWordAlignmentEngine));
        // Use the fields directly; the public properties perform a disposal check that is pointless here.
        _aligner = new SymmetrizedWordAligner(_directWordAlignmentEngine, _inverseWordAlignmentEngine);
    }

    /// <summary>
    /// Gets or sets the symmetrization heuristic of the underlying <see cref="SymmetrizedWordAligner"/>.
    /// </summary>
    public SymmetrizationHeuristic Heuristic
    {
        get => _aligner.Heuristic;
        set => _aligner.Heuristic = value;
    }

    /// <summary>Gets the wrapped direct engine.</summary>
    public IWordAlignmentEngine DirectWordAlignmentEngine
    {
        get
        {
            CheckDisposed();

            return _directWordAlignmentEngine;
        }
    }

    /// <summary>Gets the wrapped inverse engine.</summary>
    public IWordAlignmentEngine InverseWordAlignmentEngine
    {
        get
        {
            CheckDisposed();

            return _inverseWordAlignmentEngine;
        }
    }

    /// <summary>Gets the source vocabulary of the direct engine.</summary>
    public IWordVocabulary SourceWords
    {
        get
        {
            CheckDisposed();

            return _directWordAlignmentEngine.SourceWords;
        }
    }

    /// <summary>Gets the target vocabulary of the direct engine.</summary>
    public IWordVocabulary TargetWords
    {
        get
        {
            CheckDisposed();

            return _directWordAlignmentEngine.TargetWords;
        }
    }

    /// <summary>Gets the special symbol indices reported by the direct engine.</summary>
    public IReadOnlySet<int> SpecialSymbolIndices
    {
        get
        {
            // Guarded like SourceWords/TargetWords, which read from the same wrapped engine.
            CheckDisposed();

            return _directWordAlignmentEngine.SpecialSymbolIndices;
        }
    }

    /// <summary>
    /// Aligns a single segment pair using the symmetrized aligner.
    /// </summary>
    /// <param name="sourceSegment">The source segment tokens.</param>
    /// <param name="targetSegment">The target segment tokens.</param>
    /// <returns>The alignment matrix produced by the symmetrized aligner.</returns>
    public WordAlignmentMatrix Align(IReadOnlyList<string> sourceSegment, IReadOnlyList<string> targetSegment)
    {
        CheckDisposed();

        return _aligner.Align(sourceSegment, targetSegment);
    }

    /// <summary>
    /// Aligns a batch of segment pairs using the symmetrized aligner.
    /// </summary>
    /// <param name="segments">The (source, target) segment pairs to align.</param>
    /// <returns>The alignment matrices returned by the symmetrized aligner for the batch.</returns>
    public IReadOnlyList<WordAlignmentMatrix> AlignBatch(
        IReadOnlyList<(IReadOnlyList<string> SourceSegment, IReadOnlyList<string> TargetSegment)> segments
    )
    {
        CheckDisposed();

        return _aligner.AlignBatch(segments);
    }

    /// <summary>
    /// Enumerates the candidate target words the direct engine returns for <paramref name="sourceWord"/>,
    /// scoring each with the larger of the direct and inverse scores.
    /// </summary>
    /// <param name="sourceWord">The source word.</param>
    /// <param name="threshold">Only candidates whose combined score is strictly greater are returned.</param>
    /// <returns>A lazily evaluated sequence of (target word, combined score) tuples.</returns>
    public IEnumerable<(string TargetWord, double Score)> GetTranslations(string sourceWord, double threshold = 0)
    {
        // An iterator body does not run until enumeration starts, so check eagerly here to report a
        // disposed engine at the call site.
        CheckDisposed();

        return Enumerate();

        IEnumerable<(string TargetWord, double Score)> Enumerate()
        {
            // Re-check: the engine may have been disposed between the call and the enumeration.
            CheckDisposed();

            // The direct engine is queried without a threshold because the inverse score may lift a
            // candidate above the requested threshold.
            foreach ((string targetWord, double dirScore) in _directWordAlignmentEngine.GetTranslations(sourceWord))
            {
                double invScore = _inverseWordAlignmentEngine.GetTranslationScore(targetWord, sourceWord);
                double score = Math.Max(dirScore, invScore);
                if (score > threshold)
                    yield return (targetWord, score);
            }
        }
    }

    /// <summary>
    /// Enumerates the candidate target word indices the direct engine returns for
    /// <paramref name="sourceWordIndex"/>, scoring each with the larger of the direct and inverse scores.
    /// </summary>
    /// <param name="sourceWordIndex">The source word index.</param>
    /// <param name="threshold">Only candidates whose combined score is strictly greater are returned.</param>
    /// <returns>A lazily evaluated sequence of (target word index, combined score) tuples.</returns>
    public IEnumerable<(int TargetWordIndex, double Score)> GetTranslations(
        int sourceWordIndex,
        double threshold = 0
    )
    {
        // See the string overload: eager check at the call site, repeated when enumeration begins.
        CheckDisposed();

        return Enumerate();

        IEnumerable<(int TargetWordIndex, double Score)> Enumerate()
        {
            CheckDisposed();

            foreach (
                (int targetWordIndex, double dirScore) in _directWordAlignmentEngine.GetTranslations(sourceWordIndex)
            )
            {
                double invScore = _inverseWordAlignmentEngine.GetTranslationScore(targetWordIndex, sourceWordIndex);
                double score = Math.Max(dirScore, invScore);
                if (score > threshold)
                    yield return (targetWordIndex, score);
            }
        }
    }

    /// <summary>
    /// Gets the larger of the direct score for (source, target) and the inverse score for (target, source).
    /// </summary>
    /// <param name="sourceWord">The source word.</param>
    /// <param name="targetWord">The target word.</param>
    /// <returns>The combined translation score.</returns>
    public double GetTranslationScore(string sourceWord, string targetWord)
    {
        CheckDisposed();

        double dirScore = _directWordAlignmentEngine.GetTranslationScore(sourceWord, targetWord);
        double invScore = _inverseWordAlignmentEngine.GetTranslationScore(targetWord, sourceWord);
        return Math.Max(dirScore, invScore);
    }

    /// <summary>
    /// Gets the larger of the direct score for (source, target) and the inverse score for (target, source).
    /// </summary>
    /// <param name="sourceWordIndex">The source word index.</param>
    /// <param name="targetWordIndex">The target word index.</param>
    /// <returns>The combined translation score.</returns>
    public double GetTranslationScore(int sourceWordIndex, int targetWordIndex)
    {
        CheckDisposed();

        double dirScore = _directWordAlignmentEngine.GetTranslationScore(sourceWordIndex, targetWordIndex);
        double invScore = _inverseWordAlignmentEngine.GetTranslationScore(targetWordIndex, sourceWordIndex);
        return Math.Max(dirScore, invScore);
    }

    /// <summary>
    /// Aligns the segment pair, converts the alignment matrix to word pairs, and fills in their scores.
    /// </summary>
    /// <param name="sourceSegment">The source segment tokens.</param>
    /// <param name="targetSegment">The target segment tokens.</param>
    /// <returns>The aligned word pairs with combined scores.</returns>
    public IReadOnlyCollection<AlignedWordPair> GetBestAlignedWordPairs(
        IReadOnlyList<string> sourceSegment,
        IReadOnlyList<string> targetSegment
    )
    {
        CheckDisposed();

        WordAlignmentMatrix matrix = Align(sourceSegment, targetSegment);
        IReadOnlyCollection<AlignedWordPair> wordPairs = matrix.ToAlignedWordPairs();
        ComputeAlignedWordPairScores(sourceSegment, targetSegment, wordPairs);
        return wordPairs;
    }

    /// <summary>
    /// Scores the given word pairs in place. The direct engine scores the pairs as given, the inverse
    /// engine scores inverted copies against the swapped segments, and each pair keeps the larger
    /// translation score and the larger alignment score.
    /// </summary>
    /// <param name="sourceSegment">The source segment tokens.</param>
    /// <param name="targetSegment">The target segment tokens.</param>
    /// <param name="wordPairs">The word pairs to score; their score properties are overwritten.</param>
    public void ComputeAlignedWordPairScores(
        IReadOnlyList<string> sourceSegment,
        IReadOnlyList<string> targetSegment,
        IReadOnlyCollection<AlignedWordPair> wordPairs
    )
    {
        // Guarded like the other members that delegate to the wrapped engines.
        CheckDisposed();

        // Materialize the inverted copies once so they stay index-aligned with wordPairs for the Zip below.
        AlignedWordPair[] inverseWordPairs = wordPairs.Select(wp => wp.Invert()).ToArray();
        _directWordAlignmentEngine.ComputeAlignedWordPairScores(sourceSegment, targetSegment, wordPairs);
        _inverseWordAlignmentEngine.ComputeAlignedWordPairScores(targetSegment, sourceSegment, inverseWordPairs);
        foreach (var (wordPair, inverseWordPair) in wordPairs.Zip(inverseWordPairs, (wp, invWp) => (wp, invWp)))
        {
            wordPair.TranslationScore = Math.Max(wordPair.TranslationScore, inverseWordPair.TranslationScore);
            wordPair.AlignmentScore = Math.Max(wordPair.AlignmentScore, inverseWordPair.AlignmentScore);
        }
    }

    /// <summary>
    /// Disposes both wrapped engines.
    /// </summary>
    protected override void DisposeManagedResources()
    {
        try
        {
            _directWordAlignmentEngine.Dispose();
        }
        finally
        {
            // Runs even if disposing the direct engine throws, so the inverse engine is not leaked.
            _inverseWordAlignmentEngine.Dispose();
        }
    }
}
}
153 changes: 3 additions & 150 deletions src/SIL.Machine/Translation/SymmetrizedWordAlignmentModel.cs
Original file line number Diff line number Diff line change
@@ -1,167 +1,20 @@
using System;
using System.Collections.Generic;
using System.Linq;
using SIL.Machine.Corpora;
using SIL.ObjectModel;
using SIL.Machine.Corpora;

namespace SIL.Machine.Translation
{
public class SymmetrizedWordAlignmentModel : DisposableBase, IWordAlignmentModel
public class SymmetrizedWordAlignmentModel : SymmetrizedWordAlignmentEngine, IWordAlignmentModel
{
private readonly IWordAlignmentModel _directWordAlignmentModel;
private readonly IWordAlignmentModel _inverseWordAlignmentModel;
private readonly SymmetrizedWordAligner _aligner;

public SymmetrizedWordAlignmentModel(
IWordAlignmentModel directWordAlignmentModel,
IWordAlignmentModel inverseWordAlignmentModel
)
: base(directWordAlignmentModel, inverseWordAlignmentModel)
{
_directWordAlignmentModel = directWordAlignmentModel;
_inverseWordAlignmentModel = inverseWordAlignmentModel;
_aligner = new SymmetrizedWordAligner(DirectWordAlignmentModel, InverseWordAlignmentModel);
}

public SymmetrizationHeuristic Heuristic
{
get => _aligner.Heuristic;
set => _aligner.Heuristic = value;
}

public IWordAlignmentModel DirectWordAlignmentModel
{
get
{
CheckDisposed();

return _directWordAlignmentModel;
}
}

public IWordAlignmentModel InverseWordAlignmentModel
{
get
{
CheckDisposed();

return _inverseWordAlignmentModel;
}
}

public IWordVocabulary SourceWords
{
get
{
CheckDisposed();

return _directWordAlignmentModel.SourceWords;
}
}

public IWordVocabulary TargetWords
{
get
{
CheckDisposed();

return _directWordAlignmentModel.TargetWords;
}
}

public IReadOnlySet<int> SpecialSymbolIndices => _directWordAlignmentModel.SpecialSymbolIndices;

public WordAlignmentMatrix Align(IReadOnlyList<string> sourceSegment, IReadOnlyList<string> targetSegment)
{
CheckDisposed();

return _aligner.Align(sourceSegment, targetSegment);
}

public IReadOnlyList<WordAlignmentMatrix> AlignBatch(
IReadOnlyList<(IReadOnlyList<string> SourceSegment, IReadOnlyList<string> TargetSegment)> segments
)
{
CheckDisposed();

return _aligner.AlignBatch(segments);
}

public IEnumerable<(string TargetWord, double Score)> GetTranslations(string sourceWord, double threshold = 0)
{
CheckDisposed();

foreach ((string targetWord, double dirScore) in _directWordAlignmentModel.GetTranslations(sourceWord))
{
double invScore = _inverseWordAlignmentModel.GetTranslationScore(targetWord, sourceWord);
double score = Math.Max(dirScore, invScore);
if (score > threshold)
yield return (targetWord, score);
}
}

public IEnumerable<(int TargetWordIndex, double Score)> GetTranslations(
int sourceWordIndex,
double threshold = 0
)
{
CheckDisposed();

foreach (
(int targetWordIndex, double dirScore) in _directWordAlignmentModel.GetTranslations(sourceWordIndex)
)
{
double invScore = _inverseWordAlignmentModel.GetTranslationScore(targetWordIndex, sourceWordIndex);
double score = Math.Max(dirScore, invScore);
if (score > threshold)
yield return (targetWordIndex, score);
}
}

public double GetTranslationScore(string sourceWord, string targetWord)
{
CheckDisposed();

double dirScore = _directWordAlignmentModel.GetTranslationScore(sourceWord, targetWord);
double invScore = _inverseWordAlignmentModel.GetTranslationScore(targetWord, sourceWord);
return Math.Max(dirScore, invScore);
}

public double GetTranslationScore(int sourceWordIndex, int targetWordIndex)
{
CheckDisposed();

double dirScore = _directWordAlignmentModel.GetTranslationScore(sourceWordIndex, targetWordIndex);
double invScore = _inverseWordAlignmentModel.GetTranslationScore(targetWordIndex, sourceWordIndex);
return Math.Max(dirScore, invScore);
}

public IReadOnlyCollection<AlignedWordPair> GetBestAlignedWordPairs(
IReadOnlyList<string> sourceSegment,
IReadOnlyList<string> targetSegment
)
{
CheckDisposed();

WordAlignmentMatrix matrix = Align(sourceSegment, targetSegment);
IReadOnlyCollection<AlignedWordPair> wordPairs = matrix.ToAlignedWordPairs();
ComputeAlignedWordPairScores(sourceSegment, targetSegment, wordPairs);
return wordPairs;
}

public void ComputeAlignedWordPairScores(
IReadOnlyList<string> sourceSegment,
IReadOnlyList<string> targetSegment,
IReadOnlyCollection<AlignedWordPair> wordPairs
)
{
AlignedWordPair[] inverseWordPairs = wordPairs.Select(wp => wp.Invert()).ToArray();
_directWordAlignmentModel.ComputeAlignedWordPairScores(sourceSegment, targetSegment, wordPairs);
_inverseWordAlignmentModel.ComputeAlignedWordPairScores(targetSegment, sourceSegment, inverseWordPairs);
foreach (var (wordPair, inverseWordPair) in wordPairs.Zip(inverseWordPairs, (wp, invWp) => (wp, invWp)))
{
wordPair.TranslationScore = Math.Max(wordPair.TranslationScore, inverseWordPair.TranslationScore);
wordPair.AlignmentScore = Math.Max(wordPair.AlignmentScore, inverseWordPair.AlignmentScore);
}
}

public ITrainer CreateTrainer(IParallelTextCorpus corpus)
Expand Down

0 comments on commit dd337ab

Please sign in to comment.