Skip to content

Commit

Permalink
feat: adapt nlu prediction
Browse files Browse the repository at this point in the history
  • Loading branch information
marrouchi committed Sep 23, 2024
1 parent fa055fa commit 7764465
Show file tree
Hide file tree
Showing 11 changed files with 133 additions and 110 deletions.
6 changes: 6 additions & 0 deletions api/src/chat/services/bot.service.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@ import { MenuService } from '@/cms/services/menu.service';
import { offlineEventText } from '@/extensions/channels/offline/__test__/events.mock';
import OfflineHandler from '@/extensions/channels/offline/index.channel';
import OfflineEventWrapper from '@/extensions/channels/offline/wrapper';
import { LanguageRepository } from '@/i18n/repositories/language.repository';
import { LanguageModel } from '@/i18n/schemas/language.schema';
import { I18nService } from '@/i18n/services/i18n.service';
import { LanguageService } from '@/i18n/services/language.service';
import { LoggerService } from '@/logger/logger.service';
import { NlpEntityRepository } from '@/nlp/repositories/nlp-entity.repository';
import { NlpSampleEntityRepository } from '@/nlp/repositories/nlp-sample-entity.repository';
Expand Down Expand Up @@ -107,6 +110,7 @@ describe('BlockService', () => {
NlpEntityModel,
NlpSampleEntityModel,
NlpSampleModel,
LanguageModel,
]),
],
providers: [
Expand All @@ -126,6 +130,7 @@ describe('BlockService', () => {
NlpEntityRepository,
NlpSampleEntityRepository,
NlpSampleRepository,
LanguageRepository,
BlockService,
CategoryService,
ContentTypeService,
Expand All @@ -143,6 +148,7 @@ describe('BlockService', () => {
NlpSampleEntityService,
NlpSampleService,
NlpService,
LanguageService,
{
provide: PluginService,
useValue: {},
Expand Down
24 changes: 22 additions & 2 deletions api/src/extensions/helpers/nlp/default/__test__/index.mock.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ export const nlpEmptyFormated: DatasetType = {
name: 'product',
elements: ['pizza', 'sandwich'],
},
{
elements: ['en', 'fr'],
name: 'language',
},
],
entity_synonyms: [
{
Expand All @@ -34,17 +38,33 @@ export const nlpEmptyFormated: DatasetType = {

export const nlpFormatted: DatasetType = {
common_examples: [
{ text: 'Hello', intent: 'greeting', entities: [] },
{
text: 'Hello',
intent: 'greeting',
entities: [
{
entity: 'language',
value: 'en',
},
],
},
{
text: 'i want to order a pizza',
intent: 'order',
entities: [{ entity: 'product', value: 'pizza', start: 19, end: 23 }],
entities: [
{ entity: 'product', value: 'pizza', start: 19, end: 23 },
{
entity: 'language',
value: 'en',
},
],
},
],
regex_features: [],
lookup_tables: [
{ name: 'intent', elements: ['greeting', 'order'] },
{ name: 'product', elements: ['pizza', 'sandwich'] },
{ name: 'language', elements: ['en', 'fr'] },
],
entity_synonyms: [
{
Expand Down
45 changes: 30 additions & 15 deletions api/src/extensions/helpers/nlp/default/__test__/index.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,14 @@
*/

import { HttpModule } from '@nestjs/axios';
import { CACHE_MANAGER } from '@nestjs/cache-manager';
import { EventEmitter2 } from '@nestjs/event-emitter';
import { MongooseModule } from '@nestjs/mongoose';
import { Test, TestingModule } from '@nestjs/testing';

import { LanguageRepository } from '@/i18n/repositories/language.repository';
import { LanguageModel } from '@/i18n/schemas/language.schema';
import { LanguageService } from '@/i18n/services/language.service';
import { LoggerService } from '@/logger/logger.service';
import { NlpEntityRepository } from '@/nlp/repositories/nlp-entity.repository';
import { NlpSampleEntityRepository } from '@/nlp/repositories/nlp-sample-entity.repository';
Expand Down Expand Up @@ -56,10 +60,24 @@ describe('NLP Default Helper', () => {
NlpValueModel,
NlpSampleModel,
NlpSampleEntityModel,
LanguageModel,
]),
HttpModule,
],
providers: [
NlpService,
NlpSampleService,
NlpSampleRepository,
NlpEntityService,
NlpEntityRepository,
NlpValueService,
NlpValueRepository,
NlpSampleEntityService,
NlpSampleEntityRepository,
LanguageService,
LanguageRepository,
EventEmitter2,
DefaultNlpHelper,
LoggerService,
{
provide: SettingService,
Expand All @@ -76,17 +94,14 @@ describe('NLP Default Helper', () => {
})),
},
},
NlpService,
NlpSampleService,
NlpSampleRepository,
NlpEntityService,
NlpEntityRepository,
NlpValueService,
NlpValueRepository,
NlpSampleEntityService,
NlpSampleEntityRepository,
EventEmitter2,
DefaultNlpHelper,
{
provide: CACHE_MANAGER,
useValue: {
del: jest.fn(),
get: jest.fn(),
set: jest.fn(),
},
},
],
}).compile();
settingService = module.get<SettingService>(SettingService);
Expand All @@ -103,15 +118,15 @@ describe('NLP Default Helper', () => {
expect(nlp).toBeDefined();
});

it('should format empty training set properly', () => {
it('should format empty training set properly', async () => {
const nlp = nlpService.getNLP();
const results = nlp.format([], entitiesMock);
const results = await nlp.format([], entitiesMock);
expect(results).toEqual(nlpEmptyFormated);
});

it('should format training set properly', () => {
it('should format training set properly', async () => {
const nlp = nlpService.getNLP();
const results = nlp.format(samplesMock, entitiesMock);
const results = await nlp.format(samplesMock, entitiesMock);
expect(results).toEqual(nlpFormatted);
});

Expand Down
85 changes: 12 additions & 73 deletions api/src/extensions/helpers/nlp/default/index.nlp.helper.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,13 @@ import { Injectable } from '@nestjs/common';
import { LoggerService } from '@/logger/logger.service';
import BaseNlpHelper from '@/nlp/lib/BaseNlpHelper';
import { Nlp } from '@/nlp/lib/types';
import { NlpEntity, NlpEntityFull } from '@/nlp/schemas/nlp-entity.schema';
import { NlpEntityFull } from '@/nlp/schemas/nlp-entity.schema';
import { NlpSampleFull } from '@/nlp/schemas/nlp-sample.schema';
import { NlpValue } from '@/nlp/schemas/nlp-value.schema';
import { NlpEntityService } from '@/nlp/services/nlp-entity.service';
import { NlpSampleService } from '@/nlp/services/nlp-sample.service';
import { NlpService } from '@/nlp/services/nlp.service';

import {
CommonExample,
DatasetType,
EntitySynonym,
ExampleEntity,
LookupTable,
NlpParseResultType,
} from './types';
import { DatasetType, NlpParseResultType } from './types';

@Injectable()
export default class DefaultNlpHelper extends BaseNlpHelper {
Expand Down Expand Up @@ -61,69 +53,16 @@ export default class DefaultNlpHelper extends BaseNlpHelper {
* @param entities - All available entities
* @returns {DatasetType} - The formatted RASA training set
*/
format(samples: NlpSampleFull[], entities: NlpEntityFull[]): DatasetType {
const entityMap = NlpEntity.getEntityMap(entities);
const valueMap = NlpValue.getValueMap(
NlpValue.getValuesFromEntities(entities),
async format(
samples: NlpSampleFull[],
entities: NlpEntityFull[],
): Promise<DatasetType> {
const nluData = await this.nlpSampleService.formatRasaNlu(
samples,
entities,
);

const common_examples: CommonExample[] = samples
.filter((s) => s.entities.length > 0)
.map((s) => {
const intent = s.entities.find(
(e) => entityMap[e.entity].name === 'intent',
);
if (!intent) {
throw new Error('Unable to find the `intent` nlp entity.');
}
const sampleEntities: ExampleEntity[] = s.entities
.filter((e) => entityMap[<string>e.entity].name !== 'intent')
.map((e) => {
const res: ExampleEntity = {
entity: entityMap[<string>e.entity].name,
value: valueMap[<string>e.value].value,
};
if ('start' in e && 'end' in e) {
Object.assign(res, {
start: e.start,
end: e.end,
});
}
return res;
});
return {
text: s.text,
intent: valueMap[intent.value].value,
entities: sampleEntities,
};
});
const lookup_tables: LookupTable[] = entities.map((e) => {
return {
name: e.name,
elements: e.values.map((v) => {
return v.value;
}),
};
});
const entity_synonyms = entities
.reduce((acc, e) => {
const synonyms = e.values.map((v) => {
return {
value: v.value,
synonyms: v.expressions,
};
});
return acc.concat(synonyms);
}, [] as EntitySynonym[])
.filter((s) => {
return s.synonyms.length > 0;
});
return {
common_examples,
regex_features: [],
lookup_tables,
entity_synonyms,
};
return nluData;
}

/**
Expand All @@ -138,7 +77,7 @@ export default class DefaultNlpHelper extends BaseNlpHelper {
entities: NlpEntityFull[],
): Promise<any> {
const self = this;
const nluData: DatasetType = self.format(samples, entities);
const nluData: DatasetType = await self.format(samples, entities);
// Train samples
const result = await this.httpService.axiosRef.post(
`${this.settings.endpoint}/train`,
Expand Down Expand Up @@ -169,7 +108,7 @@ export default class DefaultNlpHelper extends BaseNlpHelper {
entities: NlpEntityFull[],
): Promise<any> {
const self = this;
const nluTestData: DatasetType = self.format(samples, entities);
const nluTestData: DatasetType = await self.format(samples, entities);
// Evaluate model with test samples
return await this.httpService.axiosRef.post(
`${this.settings.endpoint}/evaluate`,
Expand Down
14 changes: 12 additions & 2 deletions api/src/i18n/controllers/language.controller.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -96,13 +96,23 @@ describe('LanguageController', () => {
});

describe('findPage', () => {
const pageQuery = getPageQuery<Language>();
const pageQuery = getPageQuery<Language>({ sort: ['code', 'asc'] });
it('should find languages', async () => {
jest.spyOn(languageService, 'findPage');
const result = await languageController.findPage(pageQuery, {});

expect(languageService.findPage).toHaveBeenCalledWith({}, pageQuery);
expect(result).toEqualPayload(languageFixtures);
expect(result).toEqualPayload(
languageFixtures.sort(({ code: codeA }, { code: codeB }) => {
if (codeA < codeB) {
return -1;
}
if (codeA > codeB) {
return 1;
}
return 0;
}),
);
});
});

Expand Down
2 changes: 1 addition & 1 deletion api/src/nlp/controllers/nlp-sample.controller.ts
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ export class NlpSampleController extends BaseController<
type ? { type } : {},
);
const entities = await this.nlpEntityService.findAllAndPopulate();
const result = this.nlpSampleService.formatRasaNlu(samples, entities);
const result = await this.nlpSampleService.formatRasaNlu(samples, entities);

// Sending the JSON data as a file
const buffer = Buffer.from(JSON.stringify(result));
Expand Down
11 changes: 11 additions & 0 deletions api/src/nlp/services/nlp-sample.service.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@
* 3. SaaS Restriction: This software, or any derivative of it, may not be used to offer a competing product or service (SaaS) without prior written consent from Hexastack. Offering the software as a service or using it in a commercial cloud environment without express permission is strictly prohibited.
*/

import { CACHE_MANAGER } from '@nestjs/cache-manager';
import { EventEmitter2 } from '@nestjs/event-emitter';
import { MongooseModule } from '@nestjs/mongoose';
import { Test, TestingModule } from '@nestjs/testing';

import { LanguageRepository } from '@/i18n/repositories/language.repository';
import { Language, LanguageModel } from '@/i18n/schemas/language.schema';
import { LanguageService } from '@/i18n/services/language.service';
import { nlpSampleFixtures } from '@/utils/test/fixtures/nlpsample';
import { installNlpSampleEntityFixtures } from '@/utils/test/fixtures/nlpsampleentity';
import { getPageQuery } from '@/utils/test/pagination';
Expand Down Expand Up @@ -68,7 +70,16 @@ describe('NlpSampleService', () => {
NlpSampleEntityService,
NlpEntityService,
NlpValueService,
LanguageService,
EventEmitter2,
{
provide: CACHE_MANAGER,
useValue: {
del: jest.fn(),
get: jest.fn(),
set: jest.fn(),
},
},
],
}).compile();
nlpSampleService = module.get<NlpSampleService>(NlpSampleService);
Expand Down
Loading

0 comments on commit 7764465

Please sign in to comment.