diff --git a/src/Momento.Sdk/IPreviewVectorIndexClient.cs b/src/Momento.Sdk/IPreviewVectorIndexClient.cs index fb4edbb2..67aaf0ea 100644 --- a/src/Momento.Sdk/IPreviewVectorIndexClient.cs +++ b/src/Momento.Sdk/IPreviewVectorIndexClient.cs @@ -105,6 +105,17 @@ public interface IPreviewVectorIndexClient : IDisposable /// public Task DeleteIndexAsync(string indexName); + /// + /// Gets the number of items in a vector index. + /// + /// + /// In the event the index does not exist, the response will be an error. + /// A count of zero is reserved for an index that exists but has no items. + /// + /// The name of the vector index to get the item count from. + /// Task representing the result of the count items operation. + public Task CountItemsAsync(string indexName); + /// /// Upserts a batch of items into a vector index. /// If an item with the same ID already exists in the index, it will be replaced. diff --git a/src/Momento.Sdk/Internal/VectorIndexDataClient.cs b/src/Momento.Sdk/Internal/VectorIndexDataClient.cs index 66d219a2..6a78c782 100644 --- a/src/Momento.Sdk/Internal/VectorIndexDataClient.cs +++ b/src/Momento.Sdk/Internal/VectorIndexDataClient.cs @@ -28,6 +28,30 @@ public VectorIndexDataClient(IVectorIndexConfiguration config, string authToken, _exceptionMapper = new CacheExceptionMapper(config.LoggerFactory); } + const string REQUEST_COUNT_ITEMS = "COUNT_ITEMS"; + public async Task CountItemsAsync(string indexName) + { + try + { + _logger.LogTraceVectorIndexRequest(REQUEST_COUNT_ITEMS, indexName); + CheckValidIndexName(indexName); + var request = new _CountItemsRequest() { IndexName = indexName, All = new _CountItemsRequest.Types.All() }; + + var response = + await grpcManager.Client.CountItemsAsync(request, new CallOptions(deadline: CalculateDeadline())); + // To maintain CLS compliance we use a long here instead of a ulong. + // The max value of a long is still over 9 quintillion so we should be good for a while. + var itemCount = checked((long)response.ItemCount); + return _logger.LogTraceVectorIndexRequestSuccess(REQUEST_COUNT_ITEMS, indexName, + new CountItemsResponse.Success(itemCount)); + } + catch (Exception e) + { + return _logger.LogTraceVectorIndexRequestError(REQUEST_COUNT_ITEMS, indexName, + new CountItemsResponse.Error(_exceptionMapper.Convert(e))); + } + } + const string REQUEST_UPSERT_ITEM_BATCH = "UPSERT_ITEM_BATCH"; public async Task UpsertItemBatchAsync(string indexName, IEnumerable items) diff --git a/src/Momento.Sdk/Internal/VectorIndexDataGrpcManager.cs b/src/Momento.Sdk/Internal/VectorIndexDataGrpcManager.cs index ea77bd37..ff21ceee 100644 --- a/src/Momento.Sdk/Internal/VectorIndexDataGrpcManager.cs +++ b/src/Momento.Sdk/Internal/VectorIndexDataGrpcManager.cs @@ -19,6 +19,7 @@ namespace Momento.Sdk.Internal; public interface IVectorIndexDataClient { + public Task<_CountItemsResponse> CountItemsAsync(_CountItemsRequest request, CallOptions callOptions); public Task<_UpsertItemBatchResponse> UpsertItemBatchAsync(_UpsertItemBatchRequest request, CallOptions callOptions); public Task<_SearchResponse> SearchAsync(_SearchRequest request, CallOptions callOptions); public Task<_SearchAndFetchVectorsResponse> SearchAndFetchVectorsAsync(_SearchAndFetchVectorsRequest request, CallOptions callOptions); @@ -47,6 +48,12 @@ public VectorIndexDataClientWithMiddleware(VectorIndex.VectorIndexClient generat _middlewares = middlewares; } + public async Task<_CountItemsResponse> CountItemsAsync(_CountItemsRequest request, CallOptions callOptions) + { + var wrapped = await _middlewares.WrapRequest(request, callOptions, (r, o) => _generatedClient.CountItemsAsync(r, o)); + return await wrapped.ResponseAsync; + } + public async Task<_UpsertItemBatchResponse> UpsertItemBatchAsync(_UpsertItemBatchRequest request, CallOptions callOptions) { var wrapped = await _middlewares.WrapRequest(request, callOptions, (r, o) => _generatedClient.UpsertItemBatchAsync(r, o)); diff --git a/src/Momento.Sdk/Momento.Sdk.csproj b/src/Momento.Sdk/Momento.Sdk.csproj index 6093fce0..82fadc32 100644 --- a/src/Momento.Sdk/Momento.Sdk.csproj +++ b/src/Momento.Sdk/Momento.Sdk.csproj @@ -57,7 +57,7 @@ - + diff --git a/src/Momento.Sdk/PreviewVectorIndexClient.cs b/src/Momento.Sdk/PreviewVectorIndexClient.cs index f6357ddc..8adaea83 100644 --- a/src/Momento.Sdk/PreviewVectorIndexClient.cs +++ b/src/Momento.Sdk/PreviewVectorIndexClient.cs @@ -1,3 +1,4 @@ +using System; using System.Collections.Generic; using System.Threading.Tasks; using Momento.Sdk.Auth; @@ -54,6 +55,12 @@ public async Task DeleteIndexAsync(string indexName) return await controlClient.DeleteIndexAsync(indexName); } + /// + public async Task CountItemsAsync(string indexName) + { + return await dataClient.CountItemsAsync(indexName); + } + /// public async Task UpsertItemBatchAsync(string indexName, IEnumerable items) diff --git a/src/Momento.Sdk/Responses/Vector/CountItemsResponse.cs b/src/Momento.Sdk/Responses/Vector/CountItemsResponse.cs new file mode 100644 index 00000000..9ce645c2 --- /dev/null +++ b/src/Momento.Sdk/Responses/Vector/CountItemsResponse.cs @@ -0,0 +1,80 @@ +using Momento.Sdk.Exceptions; + +namespace Momento.Sdk.Responses.Vector; + +/// +/// Parent response type for a count items request. The +/// response object is resolved to a type-safe object of one of +/// the following subtypes: +/// +/// CountItemsResponse.Success +/// CountItemsResponse.Error +/// +/// Pattern matching can be used to operate on the appropriate subtype. +/// For example: +/// +/// if (response is CountItemsResponse.Success successResponse) +/// { +/// return successResponse.ItemCount; +/// } +/// else if (response is CountItemsResponse.Error errorResponse) +/// { +/// // handle error as appropriate +/// } +/// else +/// { +/// // handle unexpected response +/// } +/// +/// +public abstract class CountItemsResponse +{ + /// + public class Success : CountItemsResponse + { + /// + /// The number of items in the vector index. + /// + public long ItemCount { get; } + + /// + /// The number of items in the vector index. + public Success(long itemCount) + { + ItemCount = itemCount; + } + + /// + public override string ToString() + { + return $"{base.ToString()}: {ItemCount}"; + } + + } + + /// + public class Error : CountItemsResponse, IError + { + /// + public Error(SdkException error) + { + InnerException = error; + } + + /// + public SdkException InnerException { get; } + + /// + public MomentoErrorCode ErrorCode => InnerException.ErrorCode; + + /// + public string Message => $"{InnerException.MessageWrapper}: {InnerException.Message}"; + + /// + public override string ToString() + { + return $"{base.ToString()}: {Message}"; + } + + } +} diff --git a/tests/Integration/Momento.Sdk.Tests/VectorIndexDataTest.cs b/tests/Integration/Momento.Sdk.Tests/VectorIndexDataTest.cs index 2d77356b..dc66b6cd 100644 --- a/tests/Integration/Momento.Sdk.Tests/VectorIndexDataTest.cs +++ b/tests/Integration/Momento.Sdk.Tests/VectorIndexDataTest.cs @@ -520,4 +520,67 @@ public async Task GetItemAndGetItemMetadata_HappyPath(GetItemDelegate getI assertOnGetItemResponse.Invoke(getResponse, expected); } } + + [Fact] + public async Task CountItemsAsync_OnMissingIndex_ReturnsError() + { + var indexName = Utils.NewGuidString(); + var response = await vectorIndexClient.CountItemsAsync(indexName); + Assert.True(response is CountItemsResponse.Error, $"Unexpected response: {response}"); + var error = (CountItemsResponse.Error)response; + Assert.Equal(MomentoErrorCode.NOT_FOUND_ERROR, error.InnerException.ErrorCode); + } + + [Fact] + public async Task CountItemsAsync_OnEmptyIndex_ReturnsZero() + { + var indexName = Utils.TestVectorIndexName("data-count-items-on-empty-index"); + using (Utils.WithVectorIndex(vectorIndexClient, indexName, 2, SimilarityMetric.InnerProduct)) + { + var response = await vectorIndexClient.CountItemsAsync(indexName); + Assert.True(response is CountItemsResponse.Success, $"Unexpected response: {response}"); + var successResponse = (CountItemsResponse.Success)response; + Assert.Equal(0, successResponse.ItemCount); + } + } + + [Fact] + public async Task CountItemsAsync_HasItems_CountsCorrectly() + { + var indexName = Utils.TestVectorIndexName("data-count-items-has-items-counts-correctly"); + using (Utils.WithVectorIndex(vectorIndexClient, indexName, 2, SimilarityMetric.InnerProduct)) + { + var items = new List + { + new("test_item_1", new List { 1.0f, 2.0f }), + new("test_item_2", new List { 3.0f, 4.0f }), + new("test_item_3", new List { 5.0f, 6.0f }), + new("test_item_4", new List { 7.0f, 8.0f }), + new("test_item_5", new List { 9.0f, 10.0f }), + }; + + var upsertResponse = await vectorIndexClient.UpsertItemBatchAsync(indexName, items); + Assert.True(upsertResponse is UpsertItemBatchResponse.Success, + $"Unexpected response: {upsertResponse}"); + + await Task.Delay(2_000); + + var response = await vectorIndexClient.CountItemsAsync(indexName); + Assert.True(response is CountItemsResponse.Success, $"Unexpected response: {response}"); + var successResponse = (CountItemsResponse.Success)response; + Assert.Equal(5, successResponse.ItemCount); + + // Delete two items + var deleteResponse = await vectorIndexClient.DeleteItemBatchAsync(indexName, + new List { "test_item_1", "test_item_2" }); + Assert.True(deleteResponse is DeleteItemBatchResponse.Success, $"Unexpected response: {deleteResponse}"); + + await Task.Delay(2_000); + + response = await vectorIndexClient.CountItemsAsync(indexName); + Assert.True(response is CountItemsResponse.Success, $"Unexpected response: {response}"); + successResponse = (CountItemsResponse.Success)response; + Assert.Equal(3, successResponse.ItemCount); + } + } }