Skip to content

Commit

Permalink
feat: implement mvi count items command (#525)
Browse files Browse the repository at this point in the history
Adds MVI count items, which returns the number of items in a vector
index. In the event an index does not exist, the response is a
NOT_FOUND error. That way we distinguish between an empty index that
does exist vs an index that does not exist.

This PR adds the response type, method, documentation, and integration
tests.
  • Loading branch information
malandis authored Jan 23, 2024
1 parent 5718a72 commit d3f4189
Show file tree
Hide file tree
Showing 7 changed files with 193 additions and 1 deletion.
11 changes: 11 additions & 0 deletions src/Momento.Sdk/IPreviewVectorIndexClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,17 @@ public interface IPreviewVectorIndexClient : IDisposable
///</returns>
public Task<DeleteIndexResponse> DeleteIndexAsync(string indexName);

/// <summary>
/// Gets the number of items in a vector index.
/// </summary>
/// <remarks>
/// In the event the index does not exist, the response will be an error.
/// A count of zero is reserved for an index that exists but has no items.
/// </remarks>
/// <param name="indexName">The name of the vector index to get the item count from.</param>
/// <returns>Task representing the result of the count items operation.</returns>
public Task<CountItemsResponse> CountItemsAsync(string indexName);

/// <summary>
/// Upserts a batch of items into a vector index.
/// If an item with the same ID already exists in the index, it will be replaced.
Expand Down
24 changes: 24 additions & 0 deletions src/Momento.Sdk/Internal/VectorIndexDataClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,30 @@ public VectorIndexDataClient(IVectorIndexConfiguration config, string authToken,
_exceptionMapper = new CacheExceptionMapper(config.LoggerFactory);
}

const string REQUEST_COUNT_ITEMS = "COUNT_ITEMS";
public async Task<CountItemsResponse> CountItemsAsync(string indexName)
{
try
{
_logger.LogTraceVectorIndexRequest(REQUEST_COUNT_ITEMS, indexName);
CheckValidIndexName(indexName);
var request = new _CountItemsRequest() { IndexName = indexName, All = new _CountItemsRequest.Types.All() };

var response =
await grpcManager.Client.CountItemsAsync(request, new CallOptions(deadline: CalculateDeadline()));
// To maintain CLS compliance we use a long here instead of a ulong.
// The max value of a long is still over 9 quintillion so we should be good for a while.
var itemCount = checked((long)response.ItemCount);
return _logger.LogTraceVectorIndexRequestSuccess(REQUEST_COUNT_ITEMS, indexName,
new CountItemsResponse.Success(itemCount));
}
catch (Exception e)
{
return _logger.LogTraceVectorIndexRequestError(REQUEST_COUNT_ITEMS, indexName,
new CountItemsResponse.Error(_exceptionMapper.Convert(e)));
}
}

const string REQUEST_UPSERT_ITEM_BATCH = "UPSERT_ITEM_BATCH";
public async Task<UpsertItemBatchResponse> UpsertItemBatchAsync(string indexName,
IEnumerable<Item> items)
Expand Down
7 changes: 7 additions & 0 deletions src/Momento.Sdk/Internal/VectorIndexDataGrpcManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ namespace Momento.Sdk.Internal;

public interface IVectorIndexDataClient
{
public Task<_CountItemsResponse> CountItemsAsync(_CountItemsRequest request, CallOptions callOptions);
public Task<_UpsertItemBatchResponse> UpsertItemBatchAsync(_UpsertItemBatchRequest request, CallOptions callOptions);
public Task<_SearchResponse> SearchAsync(_SearchRequest request, CallOptions callOptions);
public Task<_SearchAndFetchVectorsResponse> SearchAndFetchVectorsAsync(_SearchAndFetchVectorsRequest request, CallOptions callOptions);
Expand Down Expand Up @@ -47,6 +48,12 @@ public VectorIndexDataClientWithMiddleware(VectorIndex.VectorIndexClient generat
_middlewares = middlewares;
}

public async Task<_CountItemsResponse> CountItemsAsync(_CountItemsRequest request, CallOptions callOptions)
{
var wrapped = await _middlewares.WrapRequest(request, callOptions, (r, o) => _generatedClient.CountItemsAsync(r, o));
return await wrapped.ResponseAsync;
}

public async Task<_UpsertItemBatchResponse> UpsertItemBatchAsync(_UpsertItemBatchRequest request, CallOptions callOptions)
{
var wrapped = await _middlewares.WrapRequest(request, callOptions, (r, o) => _generatedClient.UpsertItemBatchAsync(r, o));
Expand Down
2 changes: 1 addition & 1 deletion src/Momento.Sdk/Momento.Sdk.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
<ItemGroup>
<PackageReference Include="Grpc.Net.Client" Version="2.49.0" />
<PackageReference Include="Microsoft.Bcl.AsyncInterfaces" Version="7.0.0" />
<PackageReference Include="Momento.Protos" Version="0.97.1" />
<PackageReference Include="Momento.Protos" Version="0.102.1" />
<PackageReference Include="JWT" Version="9.0.3" />
<PackageReference Include="System.Threading.Channels" Version="6.0.0" />
<PackageReference Include="Microsoft.Extensions.Logging" Version="6.0.0" />
Expand Down
7 changes: 7 additions & 0 deletions src/Momento.Sdk/PreviewVectorIndexClient.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using Momento.Sdk.Auth;
Expand Down Expand Up @@ -54,6 +55,12 @@ public async Task<DeleteIndexResponse> DeleteIndexAsync(string indexName)
return await controlClient.DeleteIndexAsync(indexName);
}

/// <inheritdoc />
public async Task<CountItemsResponse> CountItemsAsync(string indexName)
{
return await dataClient.CountItemsAsync(indexName);
}

/// <inheritdoc />
public async Task<UpsertItemBatchResponse> UpsertItemBatchAsync(string indexName,
IEnumerable<Item> items)
Expand Down
80 changes: 80 additions & 0 deletions src/Momento.Sdk/Responses/Vector/CountItemsResponse.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
using Momento.Sdk.Exceptions;

namespace Momento.Sdk.Responses.Vector;

/// <summary>
/// Parent response type for a count items request. The
/// response object is resolved to a type-safe object of one of
/// the following subtypes:
/// <list type="bullet">
/// <item><description>CountItemsResponse.Success</description></item>
/// <item><description>CountItemsResponse.Error</description></item>
/// </list>
/// Pattern matching can be used to operate on the appropriate subtype.
/// For example:
/// <code>
/// if (response is CountItemsResponse.Success successResponse)
/// {
/// return successResponse.ItemCount;
/// }
/// else if (response is CountItemsResponse.Error errorResponse)
/// {
/// // handle error as appropriate
/// }
/// else
/// {
/// // handle unexpected response
/// }
/// </code>
/// </summary>
public abstract class CountItemsResponse
{
/// <include file="../../docs.xml" path='docs/class[@name="Success"]/description/*' />
public class Success : CountItemsResponse
{
/// <summary>
/// The number of items in the vector index.
/// </summary>
public long ItemCount { get; }

/// <include file="../../docs.xml" path='docs/class[@name="Success"]/description/*' />
/// <param name="itemCount">The number of items in the vector index.</param>
public Success(long itemCount)
{
ItemCount = itemCount;
}

/// <inheritdoc />
public override string ToString()
{
return $"{base.ToString()}: {ItemCount}";
}

}

/// <include file="../../docs.xml" path='docs/class[@name="Error"]/description/*' />
public class Error : CountItemsResponse, IError
{
/// <include file="../../docs.xml" path='docs/class[@name="Error"]/constructor/*' />
public Error(SdkException error)
{
InnerException = error;
}

/// <inheritdoc />
public SdkException InnerException { get; }

/// <inheritdoc />
public MomentoErrorCode ErrorCode => InnerException.ErrorCode;

/// <inheritdoc />
public string Message => $"{InnerException.MessageWrapper}: {InnerException.Message}";

/// <inheritdoc />
public override string ToString()
{
return $"{base.ToString()}: {Message}";
}

}
}
63 changes: 63 additions & 0 deletions tests/Integration/Momento.Sdk.Tests/VectorIndexDataTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -520,4 +520,67 @@ public async Task GetItemAndGetItemMetadata_HappyPath<T>(GetItemDelegate<T> getI
assertOnGetItemResponse.Invoke(getResponse, expected);
}
}

[Fact]
public async Task CountItemsAsync_OnMissingIndex_ReturnsError()
{
var indexName = Utils.NewGuidString();
var response = await vectorIndexClient.CountItemsAsync(indexName);
Assert.True(response is CountItemsResponse.Error, $"Unexpected response: {response}");
var error = (CountItemsResponse.Error)response;
Assert.Equal(MomentoErrorCode.NOT_FOUND_ERROR, error.InnerException.ErrorCode);
}

[Fact]
public async Task CountItemsAsync_OnEmptyIndex_ReturnsZero()
{
var indexName = Utils.TestVectorIndexName("data-count-items-on-empty-index");
using (Utils.WithVectorIndex(vectorIndexClient, indexName, 2, SimilarityMetric.InnerProduct))
{
var response = await vectorIndexClient.CountItemsAsync(indexName);
Assert.True(response is CountItemsResponse.Success, $"Unexpected response: {response}");
var successResponse = (CountItemsResponse.Success)response;
Assert.Equal(0, successResponse.ItemCount);
}
}

[Fact]
public async Task CountItemsAsync_HasItems_CountsCorrectly()
{
var indexName = Utils.TestVectorIndexName("data-count-items-has-items-counts-correctly");
using (Utils.WithVectorIndex(vectorIndexClient, indexName, 2, SimilarityMetric.InnerProduct))
{
var items = new List<Item>
{
new("test_item_1", new List<float> { 1.0f, 2.0f }),
new("test_item_2", new List<float> { 3.0f, 4.0f }),
new("test_item_3", new List<float> { 5.0f, 6.0f }),
new("test_item_4", new List<float> { 7.0f, 8.0f }),
new("test_item_5", new List<float> { 9.0f, 10.0f }),
};

var upsertResponse = await vectorIndexClient.UpsertItemBatchAsync(indexName, items);
Assert.True(upsertResponse is UpsertItemBatchResponse.Success,
$"Unexpected response: {upsertResponse}");

await Task.Delay(2_000);

var response = await vectorIndexClient.CountItemsAsync(indexName);
Assert.True(response is CountItemsResponse.Success, $"Unexpected response: {response}");
var successResponse = (CountItemsResponse.Success)response;
Assert.Equal(5, successResponse.ItemCount);

// Delete two items
var deleteResponse = await vectorIndexClient.DeleteItemBatchAsync(indexName,
new List<string> { "test_item_1", "test_item_2" });
Assert.True(deleteResponse is DeleteItemBatchResponse.Success, $"Unexpected response: {deleteResponse}");

await Task.Delay(2_000);

response = await vectorIndexClient.CountItemsAsync(indexName);
Assert.True(response is CountItemsResponse.Success, $"Unexpected response: {response}");
successResponse = (CountItemsResponse.Success)response;
Assert.Equal(3, successResponse.ItemCount);
}
}
}

0 comments on commit d3f4189

Please sign in to comment.