Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

change(web): track the base correction for generated predictions 📚 #11875

Merged
merged 3 commits into from
Jul 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,18 @@ describe('LMLayer using dummy model', function () {
// Since Firefox can't do JSON imports quite yet.
const hazelFixture = await fetch(new URL(`${domain}/resources/json/models/future_suggestions/i_got_distracted_by_hazel.json`));
hazelModel = await hazelFixture.json();
hazelModel = hazelModel.map((set) => set.map((entry) => {
return {
...entry,
// Dummy-model predictions all claim probability 1; there's no actual probability stuff
// used here.
'lexical-p': 1,
// We're predicting from a single transform, not a distribution, so probability 1.
'correction-p': 1,
// Multiply 'em together.
p: 1,
}
}));
});

describe('Prediction', function () {
Expand Down
13 changes: 12 additions & 1 deletion common/test/resources/model-helpers.mjs
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,18 @@ export function randomToken() {
}

/**
 * Loads the 'i_got_distracted_by_hazel' suggestion-fixture sets, annotating
 * each fixture entry with probability metadata.
 *
 * Dummy-model predictions all claim probability 1 — no actual probability
 * math is exercised by this fixture — so 'lexical-p', 'correction-p', and
 * their product 'p' are all fixed at 1.
 *
 * @returns {Array<Array<object>>} The fixture's suggestion sets, each entry
 *          extended with 'lexical-p', 'correction-p', and 'p' fields.
 */
export function iGotDistractedByHazel() {
  // NOTE: a stale pre-change `return jsonFixture(...)` line preceded this one,
  // making the mapped version below unreachable; it has been removed.
  return jsonFixture('models/future_suggestions/i_got_distracted_by_hazel').map((set) => set.map((entry) => {
    return {
      ...entry,
      // Dummy-model predictions all claim probability 1; there's no actual probability stuff
      // used here.
      'lexical-p': 1,
      // We're predicting from a single transform, not a distribution, so probability 1.
      'correction-p': 1,
      // Multiply 'em together.
      p: 1,
    }
  }));
}

export function jsonFixture(name, root, import_root) {
Expand Down
119 changes: 72 additions & 47 deletions common/web/lm-worker/src/main/model-compositor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@ import * as correction from './correction/index.js'

import TransformUtils from './transformUtils.js';

/**
 * Pairs a generated prediction with the correction it was predicted from,
 * so the base correction can be tracked alongside each suggestion.
 */
type CorrectionPredictionTuple = {
  // The predicted Suggestion and its probability as reported by the lexical model.
  prediction: ProbabilityMass<Suggestion>,
  // The corrected root text (post-transform wordbreak result) and the
  // probability assigned to that correction.
  correction: ProbabilityMass<string>,
  // Joint probability: prediction.p * correction.p.
  totalProb: number;
};

export default class ModelCompositor {
private lexicalModel: LexicalModel;
private contextTracker?: correction.ContextTracker;
Expand Down Expand Up @@ -39,9 +45,13 @@ export default class ModelCompositor {

private SUGGESTION_ID_SEED = 0;

private testMode: boolean = false
private testMode: boolean = false;
private verbose: boolean = true;

constructor(lexicalModel: LexicalModel, testMode?: boolean) {
constructor(
lexicalModel: LexicalModel,
testMode?: boolean
) {
this.lexicalModel = lexicalModel;
if(lexicalModel.traverseFromRoot) {
this.contextTracker = new correction.ContextTracker();
Expand All @@ -50,23 +60,32 @@ export default class ModelCompositor {
this.testMode = !!testMode;
}

private predictFromCorrections(corrections: ProbabilityMass<Transform>[], context: Context): Distribution<Suggestion> {
let returnedPredictions: Distribution<Suggestion> = [];
private predictFromCorrections(corrections: ProbabilityMass<Transform>[], context: Context): CorrectionPredictionTuple[] {
let returnedPredictions: CorrectionPredictionTuple[] = [];

for(let correction of corrections) {
let predictions = this.lexicalModel.predict(correction.sample, context);

const { sample: correctionTransform, p: correctionProb } = correction;
const correctionRoot = this.wordbreak(models.applyTransform(correction.sample, context));

let predictionSet = predictions.map(function(pair: ProbabilityMass<Suggestion>) {
let transform = correction.sample;
let inputProb = correction.p;

// Let's not rely on the model to copy transform IDs.
// Only bother if there IS an ID to copy.
if(transform.id !== undefined) {
pair.sample.transformId = transform.id;
if(correctionTransform.id !== undefined) {
pair.sample.transformId = correctionTransform.id;
}

let prediction = {sample: pair.sample, p: pair.p * inputProb};
return prediction;
let tuple: CorrectionPredictionTuple = {
prediction: pair,
correction: {
sample: correctionRoot,
p: correctionProb
},
totalProb: pair.p * correctionProb
};
return tuple;
}, this);

returnedPredictions = returnedPredictions.concat(predictionSet);
Expand All @@ -76,7 +95,7 @@ export default class ModelCompositor {
}

async predict(transformDistribution: Transform | Distribution<Transform>, context: Context): Promise<Suggestion[]> {
let suggestionDistribution: Distribution<Suggestion> = [];
let suggestionDistribution: CorrectionPredictionTuple[] = [];
let lexicalModel = this.lexicalModel;
let punctuation = this.punctuation;

Expand Down Expand Up @@ -119,7 +138,7 @@ export default class ModelCompositor {
let keepOptionText = this.wordbreak(postContext);
let keepOption: Outcome<Keep> = null;

let rawPredictions: Distribution<Suggestion> = [];
let rawPredictions: CorrectionPredictionTuple[] = [];

// Used to restore whitespaces if operations would remove them.
let prefixTransform: Transform;
Expand Down Expand Up @@ -318,10 +337,10 @@ export default class ModelCompositor {
// If we're getting the same prediction again, it's lower-cost. Update!
let oldPredictionSet = correctionPredictionMap[match.matchString];
if(oldPredictionSet) {
rawPredictions = rawPredictions.filter((entry) => !oldPredictionSet.find((match) => entry == match))
rawPredictions = rawPredictions.filter((entry) => !oldPredictionSet.find((match) => entry.prediction.sample == match.sample));
}

correctionPredictionMap[match.matchString] = predictions;
correctionPredictionMap[match.matchString] = predictions.map((entry) => entry.prediction);

rawPredictions = rawPredictions.concat(predictions);

Expand All @@ -337,13 +356,13 @@ export default class ModelCompositor {
} else {
// Sort the prediction list; we need them in descending order for the next check.
rawPredictions.sort(function(a, b) {
return b.p - a.p;
return b.totalProb - a.totalProb;
});

// If the best suggestion from the search's current tier fails to beat the worst
// pending suggestion from previous tiers, assume all further corrections will
// similarly fail to win; terminate the search-loop.
if(rawPredictions[ModelCompositor.MAX_SUGGESTIONS-1].p > Math.exp(-correctionCost)) {
if(rawPredictions[ModelCompositor.MAX_SUGGESTIONS-1].totalProb > Math.exp(-correctionCost)) {
break;
}
}
Expand All @@ -361,7 +380,7 @@ export default class ModelCompositor {
// Section 2 - post-analysis for our generated predictions, managing 'keep'.
// Assumption: Duplicated 'displayAs' properties indicate duplicated Suggestions.
// When true, we can use an 'associative array' to de-duplicate everything.
let suggestionDistribMap: {[key: string]: ProbabilityMass<Suggestion>} = {};
let suggestionDistribMap: {[key: string]: CorrectionPredictionTuple} = {};
let currentCasing: CasingForm = null;
if(lexicalModel.languageUsesCasing) {
currentCasing = this.detectCurrentCasing(postContext);
Expand All @@ -370,9 +389,12 @@ export default class ModelCompositor {
let baseWord = this.wordbreak(context);

// Deduplicator + annotator of 'keep' suggestions.
for(let prediction of rawPredictions) {
for(let tuple of rawPredictions) {
const prediction = tuple.prediction.sample;
const prob = tuple.totalProb;

// Combine duplicate samples.
let displayText = prediction.sample.displayAs;
let displayText = prediction.displayAs;
let preserveAsKeep = displayText == keepOptionText;

// De-duplication should be case-insensitive, but NOT
Expand All @@ -384,7 +406,7 @@ export default class ModelCompositor {
if(preserveAsKeep) {
// Preserve the original, pre-keyed version of the text.
if(!keepOption) {
let baseTransform = prediction.sample.transform;
let baseTransform = prediction.transform;

let keepTransform = {
insert: keepOptionText,
Expand All @@ -393,32 +415,32 @@ export default class ModelCompositor {
id: baseTransform.id
}

let intermediateKeep = models.transformToSuggestion(keepTransform, prediction.p);
let intermediateKeep = models.transformToSuggestion(keepTransform, prob);
keepOption = this.toAnnotatedSuggestion(intermediateKeep, 'keep', models.QuoteBehavior.noQuotes);
keepOption.matchesModel = true;

// Since we replaced the original Suggestion with a keep-annotated one,
// we must manually preserve the transform ID.
keepOption.transformId = prediction.sample.transformId;
} else if(keepOption.p && prediction.p) {
keepOption.p += prediction.p;
keepOption.transformId = prediction.transformId;
} else if(keepOption.p && prob) {
keepOption.p += prob;
}
} else {
// Apply capitalization rules now; facilitates de-duplication of suggestions
// that may be caused as a result.
//
// Example: "apple" and "Apple" are separate when 'lower', but identical for 'initial' and 'upper'.
if(currentCasing && currentCasing != 'lower') {
this.applySuggestionCasing(prediction.sample, baseWord, currentCasing);
this.applySuggestionCasing(prediction, baseWord, currentCasing);
// update the mapping string, too.
displayText = prediction.sample.displayAs;
displayText = prediction.displayAs;
}

let existingSuggestion = suggestionDistribMap[displayText];
if(existingSuggestion) {
existingSuggestion.p += prediction.p;
existingSuggestion.totalProb += prob;
} else {
suggestionDistribMap[displayText] = prediction;
suggestionDistribMap[displayText] = tuple;
}
}
}
Expand Down Expand Up @@ -448,28 +470,31 @@ export default class ModelCompositor {
}

suggestionDistribution = suggestionDistribution.sort(function(a, b) {
return b.p - a.p; // Use descending order - we want the largest probability suggestions first!
return b.totalProb - a.totalProb; // Use descending order - we want the largest probability suggestions first!
});

let suggestions = suggestionDistribution.splice(0, ModelCompositor.MAX_SUGGESTIONS).map(function(value) {
let sample: Suggestion & {
p?: number,
"lexical-p"?: number,
"correction-p"?: number
} = value.sample;

if(sample['p']) {
// For analysis / debugging
sample['lexical-p'] = sample['p'];
sample['correction-p'] = value.p / sample['p'];
// Use of the Trie model always exposed the lexical model's probability for a word to KMW.
// It's useful for debugging right now, so may as well repurpose it as the posterior.
//
// We still condition on 'p' existing so that test cases aren't broken.
sample['p'] = value.p;
let suggestions = suggestionDistribution.splice(0, ModelCompositor.MAX_SUGGESTIONS).map((tuple) => {
const prediction = tuple.prediction;

if(!this.verbose) {
return {
...prediction.sample,
p: tuple.totalProb
};
} else {
const sample: Suggestion & {
p?: number,
"lexical-p"?: number,
"correction-p"?: number
} = {
...prediction.sample,
p: tuple.totalProb,
"lexical-p": tuple.prediction.p,
"correction-p": tuple.correction.p
}

return sample;
}
//
return sample;
});

if(keepOption) {
Expand Down
Loading