Skip to content

Commit

Permalink
Make use of the new sqlite-based samplers for augmentation
Browse files Browse the repository at this point in the history
Should help with #778
  • Loading branch information
gcampax committed Dec 2, 2021
1 parent ea103e8 commit 6cc9dba
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 66 deletions.
10 changes: 1 addition & 9 deletions lib/dataset-tools/augmentation/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import Stream from 'stream';
import { SchemaRetriever } from 'thingtalk';
import * as Tp from 'thingpedia';

import ParameterReplacer from './replace_parameters';
import ParameterReplacer, { ParameterProvider } from './replace_parameters';
import SingleDeviceAugmenter from './single_device_augmenter';

import { SentenceExample } from '../parsers';
Expand Down Expand Up @@ -52,14 +52,6 @@ interface DatasetAugmenterOptions {
numAttempts : number;
}

// A single record as returned by ParameterProvider.get(): a piece of text
// plus a numeric weight.
interface ParameterRecord {
// text of the value; presumably already preprocessed/tokenized upstream — TODO confirm
preprocessed : string;
// numeric weight of the record; presumably its relative sampling weight — verify against the provider
weight : number;
}
// Asynchronous source of ParameterRecord lists, looked up by list type and key.
interface ParameterProvider {
// Resolves to the records stored under `key`, where `type` says whether the
// list is an 'entity' list or a 'string' list.
// NOTE(review): the key format and the behavior for unknown keys are not
// visible here — confirm against the concrete provider.
get(type : 'entity'|'string', key : string) : Promise<ParameterRecord[]>;
}

export default class DatasetAugmenter extends Stream.Transform {
private _options : DatasetAugmenterOptions;
private _rng : () => number;
Expand Down
119 changes: 62 additions & 57 deletions lib/dataset-tools/augmentation/replace_parameters.ts
Original file line number Diff line number Diff line change
Expand Up @@ -98,17 +98,10 @@ function constantToNN(constant : string) : string {
return entitytype + constant.substring(underscoreindex);
}

/**
 * Discounts a weight according to how many words a sentence has.
 *
 * Words are counted by splitting on single space characters, and the weight
 * is divided by exp((words - 1) / 3): a one-word sentence keeps its weight
 * unchanged, and each additional word shrinks the result further.
 *
 * @param sentence - space-separated text whose word count drives the discount
 * @param weight - the weight to discount
 * @returns the length-adjusted weight
 */
function adjustForLength(sentence : string, weight : number) : number {
    const wordCount = sentence.split(' ').length;
    const penalty = Math.exp((wordCount - 1) / 3);
    return weight / penalty;
}

// Contract shared by the value-list classes in this file: a collection of
// parameter records that can be sampled with a caller-supplied random source.
interface ValueList {
// number of values the list holds (the array-backed implementations visible
// in this file return the length of their backing array)
readonly size : number;
// Draws one record synchronously. `rng` is the random number source;
// presumably it returns a float in [0, 1) like Math.random — TODO confirm.
sample(rng : () => number) : ParameterRecord;
}
type ParameterRecord = Tp.FileParameterProvider.ParameterRecord;
type Sampler = Tp.FileParameterProvider.Sampler;

class NumberValueList implements ValueList {
class NumberSampler implements Sampler {
private _min : number;
private _max : number;
private _isMeasure : boolean;
Expand All @@ -132,16 +125,16 @@ class NumberValueList implements ValueList {
throw new Error(`Unexpected ${value} with bounds ${this._min} / ${this._max} (isMeasure = ${this._isMeasure})`);
}

sample(rng : () => number) : ParameterRecord {
async sample(rng : () => number) : Promise<ParameterRecord> {
if (this._isMeasure) {
// for measurements, sample uniformly between the (adjusted) bounds,

const value = (this._min + (this._max - this._min) * rng());
this._checkFinite(value);
if (Math.abs(value) > 2)
return { preprocessed: value.toFixed(coin(0.9, rng) ? 0 : 1) };
return { preprocessed: value.toFixed(coin(0.9, rng) ? 0 : 1), value: '', weight: 1.0 };
else
return { preprocessed: value.toPrecision(2) };
return { preprocessed: value.toPrecision(2), value: '', weight: 1.0 };
}

// sample an "easy" number
Expand Down Expand Up @@ -181,24 +174,22 @@ class NumberValueList implements ValueList {
} while (val < this._min || val > this._max || val === 0 || val === 1);

this._checkFinite(val);
return { preprocessed: String(val) };
return { preprocessed: String(val), value: '', weight: 1.0 };
}
}

class WeightedValueList implements ValueList {
class WeightedSampler implements Sampler {
private _values : ParameterRecord[];
private _cumsum : number[];

constructor(values : ParameterRecord[], weights : number[]) {
assert.strictEqual(values.length, weights.length);

constructor(values : ParameterRecord[]) {
this._values = values;

if (weights.length > 0) {
const cumsum = new Array(weights.length);
cumsum[0] = adjustForLength(values[0].preprocessed, weights[0]);
for (let i = 1; i < weights.length; i++)
cumsum[i] = cumsum[i-1] + adjustForLength(values[i].preprocessed, weights[i]);
if (values.length > 0) {
const cumsum = new Array(values.length);
cumsum[0] = values[0].weight ?? 1;
for (let i = 1; i < values.length; i++)
cumsum[i] = cumsum[i-1] + (values[i].weight ?? 1);
this._cumsum = cumsum;
} else {
this._cumsum = [];
Expand All @@ -209,13 +200,13 @@ class WeightedValueList implements ValueList {
return this._values.length;
}

sample(rng : () => number) : ParameterRecord {
async sample(rng : () => number) : Promise<ParameterRecord> {
const sample = rng() * this._cumsum[this._cumsum.length-1];
return this._values[binarySearch(this._cumsum, sample)];
}
}

class UniformValueList implements ValueList {
class UniformSampler implements Sampler {
private _values : ParameterRecord[];

constructor(values : ParameterRecord[]) {
Expand All @@ -226,12 +217,12 @@ class UniformValueList implements ValueList {
return this._values.length;
}

sample(rng : () => number) {
async sample(rng : () => number) {
return uniform(this._values, rng);
}
}

class SequentialValueList implements ValueList {
class SequentialSampler implements Sampler {
private _values : ParameterRecord[];
private _index : number;
constructor(values : ParameterRecord[]) {
Expand All @@ -243,7 +234,7 @@ class SequentialValueList implements ValueList {
return this._values.length;
}

sample(rng : () => number) {
async sample(rng : () => number) {
if (this._index === this._values.length)
this._index = 0;
const value = this._values[this._index];
Expand All @@ -252,20 +243,16 @@ class SequentialValueList implements ValueList {
}
}

// One sampled parameter value.
interface ParameterRecord {
// text of the value; _sampleParam splits it on single spaces to obtain the words
preprocessed : string;
// optional associated value; presumably the canonical/entity value that goes
// into the program when present — TODO confirm against _transformValue
value ?: string;
// optional sampling weight; a missing (or zero) weight is treated as 1 when
// the weighted list is built
weight ?: number;
}
interface ParameterProvider {
export interface ParameterProvider {
getSampler(type : 'entity'|'string', key : string, mode : Tp.FileParameterProvider.SampleMode) : Promise<Sampler>;
get(type : 'entity'|'string', key : string) : Promise<ParameterRecord[]>;
}

// Names of the sampling strategies that ValueListLoader accepts.
// 'sequential' makes the loader return a sequential sampler; for the other
// names the choice is made inside ValueListLoader._load.
// NOTE(review): only part of _load is visible here — confirm the exact effect
// of 'random', 'uniform' and 'default' there before relying on it.
type SamplingType = 'random' | 'uniform' | 'default' | 'sequential';

class ValueListLoader {
private _provider : ParameterProvider;
private _cache : Map<string, Promise<ValueList>>;
private _cache : Map<string, Promise<Sampler>>;
private _samplingType : SamplingType;
private _subsetParamSet : [number, number];
private _rng : () => number;
Expand All @@ -282,7 +269,7 @@ class ValueListLoader {
this._rng = rng;
}

get([valueListType, valueListName] : ['string'|'entity', string|string[]]) : Promise<ValueList> {
get([valueListType, valueListName] : ['string'|'entity', string|string[]]) : Promise<Sampler> {
const name = Array.isArray(valueListName) ? valueListName[0] : valueListName;
const key = valueListType + ':' + name;
if (this._cache.has(key))
Expand All @@ -293,9 +280,27 @@ class ValueListLoader {
return promise;
}

private async _load(valueListType : 'string'|'entity', valueListName : string|string[]) : Promise<ValueList> {
private async _load(valueListType : 'string'|'entity', valueListName : string|string[]) : Promise<Sampler> {
if (!Array.isArray(valueListName))
valueListName = [valueListName];

// try handling the common cases without loading the entire list of rows
// in memory
// the sampler returned in this case uses sqlite to sample
if (valueListName.length === 1 &&
this._samplingType !== 'random' &&
this._subsetParamSet[0] === 0 &&
this._subsetParamSet[1] === 1) {
let sampleMode;
if (this._samplingType === 'sequential')
sampleMode = Tp.FileParameterProvider.SampleMode.SEQUENTIAL;
else if (this._samplingType === 'uniform')
sampleMode = Tp.FileParameterProvider.SampleMode.UNIFORM;
else
sampleMode = Tp.FileParameterProvider.SampleMode.WEIGHTED;
return this._provider.getSampler(valueListType, valueListName[0], sampleMode);
}

let rows : ParameterRecord[] = [];
for (const name of valueListName)
rows = rows.concat(await this._provider.get(valueListType, name));
Expand Down Expand Up @@ -332,11 +337,11 @@ class ValueListLoader {
// (ie, the range is significantly smaller than the average)
// we use a uniform sampler, which is faster
if (this._samplingType === 'sequential')
return new SequentialValueList(rows);
return new SequentialSampler(rows);
else if ((maxWeight - minWeight) / (sumWeight / rows.length) < 0.0001)
return new UniformValueList(rows);
return new UniformSampler(rows);
else
return new WeightedValueList(rows, rows.map((r) => r.weight||1));
return new WeightedSampler(rows);
}
}

Expand Down Expand Up @@ -529,7 +534,7 @@ export default class ParameterReplacer {
this._entityDescendants = {};
const entities = await this._tpClient.getAllEntityTypes();
for (const entity of entities) {
if (!entity.subtype_of)
if (!entity.subtype_of)
continue;
for (const parent of entity.subtype_of) {
if (!(parent in this._entityDescendants))
Expand Down Expand Up @@ -640,7 +645,7 @@ export default class ParameterReplacer {
return { sentenceValue, programValue };
}

private async _getValueListForSlot(slot : Ast.AbstractSlot) : Promise<[ValueList, Ast.ArgumentDef|null, Type, string]> {
private async _getValueListForSlot(slot : Ast.AbstractSlot) : Promise<[Sampler, Ast.ArgumentDef|null, Type, string]> {
const arg = this._getSlotArg(slot);

let operator;
Expand Down Expand Up @@ -679,7 +684,7 @@ export default class ParameterReplacer {
min = 1;
max = 5;
}
return [new NumberValueList(min, max, !!unit), arg, slot.type, operator];
return [new NumberSampler(min, max, !!unit), arg, slot.type, operator];
}

let valueListKey = await this._getParamListKey(slot, arg);
Expand All @@ -701,12 +706,12 @@ export default class ParameterReplacer {
return [valueList, arg, slot.type, operator];
}

private async _getValueListForToken(token : string) : Promise<[ValueList, Type, string]> {
private async _getValueListForToken(token : string) : Promise<[Sampler, Type, string]> {
let type, valueListKey : ['string' | 'entity', string], fallbackKey;
if (token.startsWith('NUMBER_')) {
// choose reasonable bounds if we don't know the bound
const match = /NUMBER_[0-9]+__([a-zA-Z0-9]+)/.exec(token);
return [new NumberValueList(0, 1000, !!match),
return [new NumberSampler(0, 1000, !!match),
match ? new Type.Measure(match[1]) : Type.Number, '='];
} else if (token.startsWith('LOCATION_')) {
valueListKey = ['string', 'tt:location'];
Expand Down Expand Up @@ -755,18 +760,18 @@ export default class ParameterReplacer {
return [valueList, type, '='];
}

private _sampleParam(key : string,
arg : Ast.ArgumentDef|null,
valueList : ValueList,
type : Type,
operator : string,
replacedValuesSet : Set<ParameterRecord>) : ReplacementRecord|null {
private async _sampleParam(key : string,
arg : Ast.ArgumentDef|null,
valueList : Sampler,
type : Type,
operator : string,
replacedValuesSet : Set<ParameterRecord>) : Promise<ReplacementRecord|null> {
let typeValue = undefined;
if (type instanceof Type.Entity)
typeValue = type.type;
let attempts = this._numAttempts;
while (attempts > 0) {
const sampled = valueList.sample(this._rng);
const sampled = await valueList.sample(this._rng);
let words = sampled.preprocessed.split(' ');
words = Array.from(resampleIgnorableAndAbbreviations(this._paramLangPack, type, words, this._rng));

Expand All @@ -784,7 +789,7 @@ export default class ParameterReplacer {
continue;
}

return this._transformValue(candidate, { preprocessed: candidate }, arg);
return this._transformValue(candidate, { preprocessed: candidate, value: candidate, weight: 1.0 }, arg);
}

if (!type.isNumeric()) {
Expand Down Expand Up @@ -838,7 +843,7 @@ export default class ParameterReplacer {
[valueList, type, operator] = await this._getValueListForToken(token);
}
}
const replace = this._sampleParam(key, arg, valueList, type, operator, replacedValueSet);
const replace = await this._sampleParam(key, arg, valueList, type, operator, replacedValueSet);
if (!replace) {
output.push(token);
} else {
Expand All @@ -863,17 +868,17 @@ export default class ParameterReplacer {
if (token.startsWith('LOCATION_')) {
output.push('new', 'Location', '(', '"', string, '"', ')');
} else if (token.startsWith('GENERIC_ENTITY_')) {
if (this._includeEntityValue && value)
if (this._includeEntityValue && value)
output.push('"', value, '"');
else
else
output.push('null');
output.push('^^' + token.substring('GENERIC_ENTITY_'.length, token.length-2), '(', '"', string, '"', ')');
} else if (token.startsWith('NUMBER_')) {
output.push(string);
} else {
output.push('"', string, '"');
}

if (token.startsWith('HASHTAG_'))
output.push('^^tt:hashtag');
else if (token.startsWith('USERNAME_'))
Expand Down

0 comments on commit 6cc9dba

Please sign in to comment.