From 4e655d3679d970d02f0e3afe37fd10d745cac301 Mon Sep 17 00:00:00 2001 From: Audrey Sage Lorberfeld Date: Mon, 21 Oct 2024 10:27:01 -0700 Subject: [PATCH] Allow users to pass customField to /rerank endpoint (#303) ## Problem The current implementation of `/rerank` in the TS client does not (correctly) allow users to pass a custom field upon which to rerank. ## Solution Allow custom fields! Please reference this PR to account for all expected functionality: https://github.com/pinecone-io/python-plugin-inference/pull/21/files ## Type of Change - [ ] Bug fix (non-breaking change which fixes an issue) - [ ] New feature (non-breaking change which adds functionality) - [x] Breaking change (fix or feature that would cause existing functionality to not work as expected) - [ ] This change requires a documentation update - [ ] Infrastructure change (CI configs, etc) - [ ] Non-code change (docs, etc) - [ ] None of the above: (explain here) ## Test Plan CI passes + reviewer cross-references the PR above with the functionality introduced in this PR.
--- - To see the specific tasks where the Asana app for GitHub is being used, see below: - https://app.asana.com/0/0/1208523729730914 --------- Co-authored-by: Jesse Seldess --- README.md | 50 +++-- src/inference/__tests__/embed.test.ts | 42 +++++ src/inference/__tests__/inference.test.ts | 216 ---------------------- src/inference/__tests__/rerank.test.ts | 100 ++++++++++ src/inference/inference.ts | 145 ++++++++------- src/integration/inference/rerank.test.ts | 98 +++++++++- 6 files changed, 359 insertions(+), 292 deletions(-) create mode 100644 src/inference/__tests__/embed.test.ts delete mode 100644 src/inference/__tests__/inference.test.ts create mode 100644 src/inference/__tests__/rerank.test.ts diff --git a/README.md b/README.md index 548c6184..e4628ce0 100644 --- a/README.md +++ b/README.md @@ -1104,20 +1104,20 @@ import { Pinecone } from '@pinecone-database/pinecone'; const pc = new Pinecone(); const rerankingModel = 'bge-reranker-v2-m3'; const myQuery = 'What are some good Turkey dishes for Thanksgiving?'; -const myDocuments = [ - { text: 'I love turkey sandwiches with pastrami' }, - { - text: 'A lemon brined Turkey with apple sausage stuffing is a classic Thanksgiving main', - }, - { text: 'My favorite Thanksgiving dish is pumpkin pie' }, - { text: 'Turkey is a great source of protein' }, + +// Option 1: Documents as an array of strings +const myDocsStrings = [ + 'I love turkey sandwiches with pastrami', + 'A lemon brined Turkey with apple sausage stuffing is a classic Thanksgiving main', + 'My favorite Thanksgiving dish is pumpkin pie', + 'Turkey is a great source of protein', ]; -// >>> Sample without passing an `options` object: +// Option 1 response const response = await pc.inference.rerank( rerankingModel, myQuery, - myDocuments + myDocsStrings ); console.log(response); // { @@ -1131,15 +1131,43 @@ console.log(response); // usage: { rerankUnits: 1 } // } -// >>> Sample with an `options` object: +// Option 2: Documents as an array of objects +const 
myDocsObjs = [ + { + title: 'Turkey Sandwiches', + body: 'I love turkey sandwiches with pastrami', + }, + { + title: 'Lemon Turkey', + body: 'A lemon brined Turkey with apple sausage stuffing is a classic Thanksgiving main', + }, + { + title: 'Thanksgiving', + body: 'My favorite Thanksgiving dish is pumpkin pie', + }, + { + title: 'Protein Sources', + body: 'Turkey is a great source of protein', + }, +]; + +// Option 2: Options object declaring which custom key to rerank on +// Note: If no custom key is passed via `rankFields`, each doc must contain a `text` key, and that will act as the default. const rerankOptions = { topN: 3, returnDocuments: false, + rankFields: ['body'], + parameters: { + inputType: 'passage', + truncate: 'END', + }, }; + +// Option 2 response const response = await pc.inference.rerank( rerankingModel, myQuery, - myDocuments, + myDocsObjs, rerankOptions ); console.log(response); diff --git a/src/inference/__tests__/embed.test.ts b/src/inference/__tests__/embed.test.ts new file mode 100644 index 00000000..d3cb5f0e --- /dev/null +++ b/src/inference/__tests__/embed.test.ts @@ -0,0 +1,42 @@ +import { Inference } from '../inference'; +import type { PineconeConfiguration } from '../../data'; +import { inferenceOperationsBuilder } from '../inferenceOperationsBuilder'; + +let inference: Inference; + +beforeAll(() => { + const config: PineconeConfiguration = { apiKey: 'test-api-key' }; + const infApi = inferenceOperationsBuilder(config); + inference = new Inference(infApi); +}); + +describe('Inference Class: _formatInputs', () => { + test('Should format inputs correctly', () => { + const inputs = ['input1', 'input2']; + const expected = [{ text: 'input1' }, { text: 'input2' }]; + const result = inference._formatInputs(inputs); + expect(result).toEqual(expected); + }); +}); + +describe('Inference Class: embed', () => { + test('Should throw error if response is missing required fields', async () => { + const model = 'test-model'; + const inputs =
['input1', 'input2']; + const params = { inputType: 'text', truncate: 'END' }; + + const mockedIncorrectResponse = { model: 'test-model' }; + const expectedError = Error( + 'Response from Inference API is missing required fields' + ); + const embed = jest.spyOn(inference._inferenceApi, 'embed'); + // @ts-ignore + embed.mockResolvedValue(mockedIncorrectResponse); + + try { + await inference.embed(model, inputs, params); + } catch (error) { + expect(error).toEqual(expectedError); + } + }); +}); diff --git a/src/inference/__tests__/inference.test.ts b/src/inference/__tests__/inference.test.ts deleted file mode 100644 index 787f6c2a..00000000 --- a/src/inference/__tests__/inference.test.ts +++ /dev/null @@ -1,216 +0,0 @@ -import { Inference } from '../inference'; -import type { PineconeConfiguration } from '../../data'; -import { inferenceOperationsBuilder } from '../inferenceOperationsBuilder'; -import { PineconeArgumentError } from '../../errors'; -import { RerankResult } from '../../pinecone-generated-ts-fetch/inference'; - -describe('Inference Class: _formatInputs', () => { - let inference: Inference; - - beforeEach(() => { - const config: PineconeConfiguration = { apiKey: 'test-api-key' }; - const infApi = inferenceOperationsBuilder(config); - inference = new Inference(infApi); - }); - - it('Should format inputs correctly', () => { - const inputs = ['input1', 'input2']; - const expected = [{ text: 'input1' }, { text: 'input2' }]; - const result = inference._formatInputs(inputs); - expect(result).toEqual(expected); - }); -}); - -describe('Inference Class: _formatParams', () => { - let inference: Inference; - - beforeEach(() => { - const config: PineconeConfiguration = { apiKey: 'test-api-key' }; - const infApi = inferenceOperationsBuilder(config); - inference = new Inference(infApi); - }); - - it('Should format params correctly', () => { - const params = { inputType: 'text', truncate: 'END' }; - const expected = { inputType: 'text', truncate: 'END' }; - const 
result = inference._formatParams(params); - expect(result).toEqual(expected); - }); -}); - -describe('Inference Class: embed', () => { - let inference: Inference; - - beforeEach(() => { - const config: PineconeConfiguration = { apiKey: 'test-api-key' }; - const infApi = inferenceOperationsBuilder(config); - inference = new Inference(infApi); - }); - - it('Should throw error if response is missing required fields', async () => { - const model = 'test-model'; - const inputs = ['input1', 'input2']; - const params = { inputType: 'text', truncate: 'END' }; - - const mockedIncorrectResponse = { model: 'test-model' }; - const expectedError = Error( - 'Response from Inference API is missing required fields' - ); - const embed = jest.spyOn(inference._inferenceApi, 'embed'); - // @ts-ignore - embed.mockResolvedValue(mockedIncorrectResponse); - - try { - await inference.embed(model, inputs, params); - } catch (error) { - expect(error).toEqual(expectedError); - } - }); -}); - -describe('Inference Class: rerank', () => { - let inference: Inference; - - beforeEach(() => { - const config: PineconeConfiguration = { apiKey: 'test-api-key' }; - const infApi = inferenceOperationsBuilder(config); - inference = new Inference(infApi); - }); - - test('Throws error if no documents are passed', async () => { - const rerankingModel = 'test-model'; - const myQuery = 'test-query'; - const expectedError = new PineconeArgumentError( - 'You must pass at least one document to rerank' - ); - try { - await inference.rerank(rerankingModel, myQuery, []); - } catch (error) { - expect(error).toEqual(expectedError); - } - }); - - test('Throws error if list of objs is passed for docs, but does not contain `text` key', async () => { - const rerankingModel = 'test-model'; - const myQuery = 'test-query'; - const myDocuments = [{ id: '1' }, { id: '2' }]; - const expectedError = new PineconeArgumentError( - '`documents` can only be a list of strings or a list of objects with at least a `text` key, followed by 
a' + - ' string value' - ); - try { - await inference.rerank(rerankingModel, myQuery, myDocuments); - } catch (error) { - expect(error).toEqual(expectedError); - } - }); - - test('Confirm list of strings as docs is converted to list of objects with `text` key', async () => { - const rerankingModel = 'test-model'; - const myQuery = 'test-query'; - const myDocuments = ['doc1', 'doc2']; - const expectedDocuments = [{ text: 'doc1' }, { text: 'doc2' }]; - const rerank = jest.spyOn(inference._inferenceApi, 'rerank'); - rerank.mockResolvedValue({ - model: 'some-model', - data: [{}], - usage: { rerankUnits: 1 }, - } as RerankResult); - await inference.rerank(rerankingModel, myQuery, myDocuments); - - const expectedReq = { - model: rerankingModel, - query: myQuery, - documents: expectedDocuments, - // defaults: - parameters: {}, - rankFields: ['text'], - returnDocuments: true, - topN: 2, - }; - - expect(rerank).toHaveBeenCalledWith({ rerankRequest: expectedReq }); - }); - - test('Confirm error thrown if rankFields does match fields in passed documents', async () => { - const rerankingModel = 'test-model'; - const myQuery = 'test-query'; - const myDocuments = [ - { text: 'doc1', title: 'title1' }, - { text: 'doc2', title: 'title2' }, - ]; - const rankFields = ['OopsIMisspelledTheTextField', 'title']; - const rerank = jest.spyOn(inference._inferenceApi, 'rerank'); - // @ts-ignore - rerank.mockResolvedValue({ rerankResponse: {} }); - try { - await inference.rerank(rerankingModel, myQuery, myDocuments, { - rankFields, - }); - } catch (error) { - expect(error).toEqual( - new PineconeArgumentError( - 'The `rankField` value you passed ("OopsIMisspelledTheTextField") is missing in the document at index 0' - ) - ); - } - }); - - test('Confirm provided rankFields override default `text` field for ranking', async () => { - const rerankingModel = 'test-model'; - const myQuery = 'test-query'; - const myDocuments = [ - { text: 'doc1', title: 'title1' }, - { text: 'doc2', title: 'title2' 
}, - ]; - const rankFields = ['title']; - const rerank = jest.spyOn(inference._inferenceApi, 'rerank'); - // @ts-ignore - rerank.mockResolvedValue({ rerankResponse: {} }); - await inference.rerank(rerankingModel, myQuery, myDocuments, { - rankFields, - }); - - const expectedReq = { - model: rerankingModel, - query: myQuery, - documents: myDocuments, - rankFields, - // defaults: - parameters: {}, - returnDocuments: true, - topN: 2, - }; - - expect(rerank).toHaveBeenCalledWith({ rerankRequest: expectedReq }); - }); - - test('Confirm error thrown if query is missing', async () => { - const rerankingModel = 'test-model'; - const myQuery = ''; - const myDocuments = [{ text: 'doc1' }, { text: 'doc2' }]; - const expectedError = new PineconeArgumentError( - 'You must pass a query to rerank' - ); - try { - await inference.rerank(rerankingModel, myQuery, myDocuments); - } catch (error) { - expect(error).toEqual(expectedError); - } - }); - - test('Confirm error thrown if model is missing', async () => { - const rerankingModel = ''; - const myQuery = 'test-query'; - const myDocuments = [{ text: 'doc1' }, { text: 'doc2' }]; - const expectedError = new PineconeArgumentError( - 'You must pass the name of a supported reranking model in order to rerank' + - ' documents. See https://docs.pinecone.io/models for supported models.' 
- ); - try { - await inference.rerank(rerankingModel, myQuery, myDocuments); - } catch (error) { - expect(error).toEqual(expectedError); - } - }); -}); diff --git a/src/inference/__tests__/rerank.test.ts b/src/inference/__tests__/rerank.test.ts new file mode 100644 index 00000000..3f1f8f6b --- /dev/null +++ b/src/inference/__tests__/rerank.test.ts @@ -0,0 +1,100 @@ +import { Inference } from '../inference'; +import type { PineconeConfiguration } from '../../data'; +import { inferenceOperationsBuilder } from '../inferenceOperationsBuilder'; +import { PineconeArgumentError } from '../../errors'; +import { RerankResult } from '../../pinecone-generated-ts-fetch/inference'; + +let inference: Inference; +const rerankModel = 'test-model'; +const myQuery = 'test-query'; + +beforeAll(() => { + const config: PineconeConfiguration = { apiKey: 'test-api-key' }; + const infApi = inferenceOperationsBuilder(config); + inference = new Inference(infApi); +}); + +test('Confirm throws error if no documents are passed', async () => { + try { + await inference.rerank(rerankModel, myQuery, []); + } catch (error) { + expect(error).toEqual( + new PineconeArgumentError('You must pass at least one document to rerank') + ); + } +}); + +test('Confirm docs as list of strings is converted to list of objects with `text` key', async () => { + const myDocuments = ['doc1', 'doc2']; + const expectedDocuments = [{ text: 'doc1' }, { text: 'doc2' }]; + const rerank = jest.spyOn(inference._inferenceApi, 'rerank'); + rerank.mockResolvedValue({ + model: 'some-model', + data: [{}], + usage: { rerankUnits: 1 }, + } as RerankResult); + await inference.rerank(rerankModel, myQuery, myDocuments); + + const expectedReq = { + model: rerankModel, + query: myQuery, + documents: expectedDocuments, + parameters: {}, + rankFields: ['text'], + returnDocuments: true, + topN: 2, + }; + expect(rerank).toHaveBeenCalledWith({ rerankRequest: expectedReq }); +}); + +test('Confirm provided rankFields override default `text` 
field for reranking', async () => { + const myDocuments = [ + { text: 'doc1', title: 'title1' }, + { text: 'doc2', title: 'title2' }, + ]; + const rankFields = ['title']; + const rerank = jest.spyOn(inference._inferenceApi, 'rerank'); + // @ts-ignore + rerank.mockResolvedValue({ rerankResponse: {} }); + await inference.rerank(rerankModel, myQuery, myDocuments, { + rankFields, + }); + + const expectedReq = { + model: rerankModel, + query: myQuery, + documents: myDocuments, + rankFields, + parameters: {}, + returnDocuments: true, + topN: 2, + }; + expect(rerank).toHaveBeenCalledWith({ rerankRequest: expectedReq }); +}); + +test('Confirm error thrown if query is missing', async () => { + const myQuery = ''; + const myDocuments = [{ text: 'doc1' }, { text: 'doc2' }]; + try { + await inference.rerank(rerankModel, myQuery, myDocuments); + } catch (error) { + expect(error).toEqual( + new PineconeArgumentError('You must pass a query to rerank') + ); + } +}); + +test('Confirm error thrown if model is missing', async () => { + const rerankModel = ''; + const myDocuments = [{ text: 'doc1' }, { text: 'doc2' }]; + try { + await inference.rerank(rerankModel, myQuery, myDocuments); + } catch (error) { + expect(error).toEqual( + new PineconeArgumentError( + 'You must pass the name of a supported reranking model in order to rerank' + + ' documents. See https://docs.pinecone.io/models for supported models.' 
+ ) + ); + } +}); diff --git a/src/inference/inference.ts b/src/inference/inference.ts index dc16d0b8..5e1fed0f 100644 --- a/src/inference/inference.ts +++ b/src/inference/inference.ts @@ -1,14 +1,21 @@ import { EmbedOperationRequest, EmbedRequestInputsInner, - EmbedRequestParameters, InferenceApi, RerankResult, } from '../pinecone-generated-ts-fetch/inference'; import { EmbeddingsList } from '../models'; import { PineconeArgumentError } from '../errors'; -import { prerelease } from '../utils/prerelease'; +/** Options one can send with a request to {@link rerank} + * + * @param topN - The number of documents to return in the response. Default is the number of documents passed in the + * request. + * @param returnDocuments - Whether to return the documents in the response. Default is `true`. + * @param rankFields - The fields by which to rank the documents. If no field is passed, default is `['text']`. + * Note: some models only support 1 reranking field. See the [model documentation](https://docs.pinecone.io/guides/inference/understanding-inference#rerank) for more information. + * @param parameters - Additional model-specific parameters to send with the request, e.g. {truncate: "END"}. + * */ export interface RerankOptions { topN?: number; returnDocuments?: boolean; @@ -31,13 +38,6 @@ export class Inference { }); } - /* Format the parameters object into the correct format for the Inference API request. */ - public _formatParams( - parameters: Record - ): EmbedRequestParameters { - return parameters; - } - /* Generate embeddings for a list of input strings using a specified embedding model.
*/ async embed( model: string, @@ -46,12 +46,11 @@ export class Inference { ): Promise { const typedAndFormattedInputs: Array = this._formatInputs(inputs); - const typedParams: EmbedRequestParameters = this._formatParams(params); const typedRequest: EmbedOperationRequest = { embedRequest: { model: model, inputs: typedAndFormattedInputs, - parameters: typedParams, + parameters: params, }, }; const response = await this._inferenceApi.embed(typedRequest); @@ -61,25 +60,27 @@ export class Inference { /** Rerank documents against a query with a reranking model. Each document is ranked in descending relevance order * against the query provided. * - * Note: by default, the ['text'] field of each document is used for ranking; you can overwrite this default - * behavior by passing an {@link RerankOptions} `options` object specifying 1+ other fields. - * * @example - * ```typescript + * ```typescript * import { Pinecone } from '@pinecone-database/pinecone'; - * * const pc = new Pinecone(); - * const rerankingModel = "bge-reranker-v2-m3"; - * const myQuery = "What are some good Turkey dishes for Thanksgiving?"; - * const myDocuments = [ - * { text: "I love turkey sandwiches with pastrami" }, - * { text: "A lemon brined Turkey with apple sausage stuffing is a classic Thanksgiving main" }, - * { text: "My favorite Thanksgiving dish is pumpkin pie" }, - * { text: "Turkey is a great source of protein" }, + * const rerankingModel = 'bge-reranker-v2-m3'; + * const myQuery = 'What are some good Turkey dishes for Thanksgiving?'; + * + * // Option 1: Documents as an array of strings + * const myDocsStrings = [ + * 'I love turkey sandwiches with pastrami', + * 'A lemon brined Turkey with apple sausage stuffing is a classic Thanksgiving main', + * 'My favorite Thanksgiving dish is pumpkin pie', + * 'Turkey is a great source of protein', * ]; * - * // >>> Sample without passing an `options` object: - * const response = await pc.inference.rerank(rerankingModel, myQuery, myDocuments); + *
// Option 1 response + * const response = await pc.inference.rerank( + * rerankingModel, + * myQuery, + * myDocsStrings + * ); * console.log(response); * // { * // model: 'bge-reranker-v2-m3', @@ -92,13 +93,43 @@ export class Inference { * // usage: { rerankUnits: 1 } * // } * + * // Option 2: Documents as an array of objects + * const myDocsObjs = [ + * { + * title: 'Turkey Sandwiches', + * body: 'I love turkey sandwiches with pastrami', + * }, + * { + * title: 'Lemon Turkey', + * body: 'A lemon brined Turkey with apple sausage stuffing is a classic Thanksgiving main', + * }, + * { + * title: 'Thanksgiving', + * body: 'My favorite Thanksgiving dish is pumpkin pie', + * }, + * { title: 'Protein Sources', body: 'Turkey is a great source of protein' }, + * ]; * - * // >>> Sample with an `options` object: + * // Option 2: Options object declaring which custom key to rerank on + * // Note: If no custom key is passed via `rankFields`, each doc must contain a `text` key, and that will act as + * // the default. * const rerankOptions = { - * topN: 3, - * returnDocuments: false - * } - * const response = await pc.inference.rerank(rerankingModel, myQuery, myDocuments, rerankOptions); + * topN: 3, + * returnDocuments: false, + * rankFields: ['body'], + * parameters: { + * inputType: 'passage', + * truncate: 'END', + * }, + * }; + * + * // Option 2 response + * const response = await pc.inference.rerank( + * rerankingModel, + * myQuery, + * myDocsObjs, + * rerankOptions + * ); * console.log(response); * // { * // model: 'bge-reranker-v2-m3', @@ -109,17 +140,14 @@ export class Inference { * // ], * // usage: { rerankUnits: 1 } * //} - * ``` + * ``` * * @param model - (Required) The model to use for reranking. Currently, the only available model is "[bge-reranker-v2-m3](https://docs.pinecone.io/models/bge-reranker-v2-m3)"}. * @param query - (Required) The query to rerank documents against. - * @param documents - (Required) The documents to rerank.
Each document must be either a string or an object - * with (at minimum) a `text` key. - * @param options - (Optional) Additional options to send with the core reranking request. Options include: how many - * results to return, whether to return the documents in the response, alternative fields by which the model - * should to rank the documents, and additional model-specific parameters. See {@link RerankOptions} for more details. + * @param documents - (Required) An array of documents to rerank. The array can either be an array of strings or + * an array of objects. + * @param options - (Optional) Additional options to send with the reranking request. See {@link RerankOptions} for more details. * */ - @prerelease('2024-10') async rerank( model: string, query: string, @@ -140,41 +168,30 @@ export class Inference { ' documents. See https://docs.pinecone.io/models for supported models.' ); } - // Destructure `options` with defaults - // Note: If the user passes in key:value pairs in `options` that are not the following, they are ignored + const { topN = documents.length, returnDocuments = true, - rankFields = ['text'], parameters = {}, } = options; - // Allow documents to be passed a list of strings, or a list of objs w/at least a `text` key: - let newDocuments: Array<{ [key: string]: string }> = []; - if (typeof documents[0] === 'object' && !('text' in documents[0])) { - throw new PineconeArgumentError( - '`documents` can only be a list of strings or a list of objects with at least a `text` key, followed by a' + - ' string value' - ); - } else if (typeof documents[0] === 'string') { - newDocuments = documents.map((doc) => { - return { text: doc as string }; - }); - } else { - newDocuments = documents as Array<{ [key: string]: string }>; + let { rankFields = ['text'] } = options; + + // Validate and standardize documents to ensure they are in object format + const newDocuments = documents.map((doc) => + typeof doc === 'string' ? 
{ text: doc } : doc + ); + + if (!options.rankFields) { + if (!newDocuments.every((doc) => typeof doc === 'object' && doc.text)) { + throw new PineconeArgumentError( + 'Documents must be a list of strings or objects containing the "text" field' + ); + } } - // Ensure all rankFields, if passed, are present in each document - if (rankFields.length > 0) { - newDocuments.forEach((doc, index) => { - rankFields.forEach((field) => { - if (!(field in doc)) { - throw new PineconeArgumentError( - `The \`rankField\` value you passed ("${field}") is missing in the document at index ${index}` - ); - } - }); - }); + if (options.rankFields) { + rankFields = options.rankFields; } const req = { diff --git a/src/integration/inference/rerank.test.ts b/src/integration/inference/rerank.test.ts index bd6592b9..b88f5a2c 100644 --- a/src/integration/inference/rerank.test.ts +++ b/src/integration/inference/rerank.test.ts @@ -7,12 +7,12 @@ describe('Integration Test: Pinecone Inference API rerank endpoint', () => { let pinecone: Pinecone; beforeAll(() => { + model = 'bge-reranker-v2-m3'; query = 'What are some good Turkey dishes for Thanksgiving?'; documents = [ 'document content 1 yay I am about turkey', 'document content 2', ]; - model = 'bge-reranker-v2-m3'; const apiKey = process.env.PINECONE_API_KEY || ''; pinecone = new Pinecone({ apiKey }); }); @@ -34,4 +34,100 @@ describe('Integration Test: Pinecone Inference API rerank endpoint', () => { // (Just ignoring the fact that technically doc.document['text'] could be undefined) expect(response.data.map((doc) => doc.document['text'])).toBeDefined(); }); + + test('Confirm list of strings as docs + rankFields set to customField fails', async () => { + const myDocuments = ['doc1', 'doc2']; + const rankFields = ['customField']; + + await expect( + pinecone.inference.rerank(model, query, myDocuments, { rankFields }) + ).rejects.toThrow( + expect.objectContaining({ + message: expect.stringContaining( + "field 'customField' not found in document 
at index 0" + ), + }) + ); + }); + + test('Confirm docs as list of objects + no rankFields succeeds, if docs contain `text` key, succeeds', async () => { + const myDocuments = [{ text: 'doc1' }, { text: 'doc2' }]; + const resp = await pinecone.inference.rerank(model, query, myDocuments); + expect(resp.usage.rerankUnits).toBeGreaterThanOrEqual(1); + }); + + test('Confirm docs as list of objects with additional customField + no rankfields, succeeds', async () => { + const myDocuments = [ + { text: 'hi', customField: 'doc1' }, + { text: 'bye', customField: 'doc2' }, + ]; + const resp = await pinecone.inference.rerank(model, query, myDocuments); + expect(resp.usage.rerankUnits).toBeGreaterThanOrEqual(1); + }); + + test('Confirm docs as list of objects with only custom fields + custom rankFields, succeeds', async () => { + const myDocuments = [ + { customField2: 'hi', customField: 'doc1' }, + { customField2: 'bye', customField: 'doc2' }, + ]; + const rankFields = ['customField2']; + const resp = await pinecone.inference.rerank(model, query, myDocuments, { + rankFields: rankFields, + }); + expect(resp.usage.rerankUnits).toBeGreaterThanOrEqual(1); + }); + + test('Confirm error thrown if docs as list of objects only has custom fields + no custom rankFields obj is passed', async () => { + const myDocuments = [ + { customField2: 'hi', customField: 'doc1' }, + { customField2: 'bye', customField: 'doc2' }, + ]; + await expect( + pinecone.inference.rerank(model, query, myDocuments) + ).rejects.toThrow( + expect.objectContaining({ + message: expect.stringContaining( + 'Documents must be a list of strings or objects containing the "text" field' + ), + }) + ); + }); + + test('Confirm error thrown if rankFields does not match fields in passed documents', async () => { + const myDocuments = [ + { text: 'doc1', title: 'title1' }, + { text: 'doc2', title: 'title2' }, + ]; + const rankFields = ['NonExistentRankField']; + await expect( + pinecone.inference.rerank(model, query, 
myDocuments, { + rankFields: rankFields, + }) + ).rejects.toThrow( + expect.objectContaining({ + message: expect.stringContaining( + "field 'NonExistentRankField' not found in document at index" + ), + }) + ); + }); + + test('Confirm error thrown if rankFields > 1 for model that only allows 1', async () => { + const myDocuments = [ + { text: 'doc1', title: 'title1' }, + { text: 'doc2', title: 'title2' }, + ]; + const rankFields = ['title', 'text']; + await expect( + pinecone.inference.rerank(model, query, myDocuments, { + rankFields: rankFields, + }) + ).rejects.toThrow( + expect.objectContaining({ + message: expect.stringContaining( + '"Only one rank field is supported for model' + ), + }) + ); + }); });