diff --git a/common/aws/dynamodb/client.go b/common/aws/dynamodb/client.go index 2e309a86bd..7fba0f0e26 100644 --- a/common/aws/dynamodb/client.go +++ b/common/aws/dynamodb/client.go @@ -17,8 +17,10 @@ import ( ) const ( - // dynamoBatchLimit is the maximum number of items that can be written in a single batch - dynamoBatchLimit = 25 + // dynamoBatchWriteLimit is the maximum number of items that can be written in a single batch + dynamoBatchWriteLimit = 25 + // dynamoBatchReadLimit is the maximum number of items that can be read in a single batch + dynamoBatchReadLimit = 100 ) type batchOperation uint @@ -163,6 +165,15 @@ func (c *Client) GetItem(ctx context.Context, tableName string, key Key) (Item, return resp.Item, nil } +func (c *Client) GetItems(ctx context.Context, tableName string, keys []Key) ([]Item, error) { + items, err := c.readItems(ctx, tableName, keys) + if err != nil { + return nil, err + } + + return items, nil +} + // QueryIndex returns all items in the index that match the given key func (c *Client) QueryIndex(ctx context.Context, tableName string, indexName string, keyCondition string, expAttributeValues ExpresseionValues) ([]Item, error) { response, err := c.dynamoClient.Query(ctx, &dynamodb.QueryInput{ @@ -277,7 +288,7 @@ func (c *Client) writeItems(ctx context.Context, tableName string, requestItems failedItems := make([]map[string]types.AttributeValue, 0) for startIndex < len(requestItems) { remainingNumKeys := float64(len(requestItems) - startIndex) - batchSize := int(math.Min(float64(dynamoBatchLimit), remainingNumKeys)) + batchSize := int(math.Min(float64(dynamoBatchWriteLimit), remainingNumKeys)) writeRequests := make([]types.WriteRequest, batchSize) for i := 0; i < batchSize; i += 1 { item := requestItems[startIndex+i] @@ -307,8 +318,42 @@ func (c *Client) writeItems(ctx context.Context, tableName string, requestItems } } - startIndex += dynamoBatchLimit + startIndex += dynamoBatchWriteLimit } return failedItems, nil } + +func (c *Client) readItems(ctx context.Context, tableName string, keys []Key) ([]Item, error) { + startIndex := 0 + items := make([]Item, 0) + for startIndex < len(keys) { + remainingNumKeys := float64(len(keys) - startIndex) + batchSize := int(math.Min(float64(dynamoBatchReadLimit), remainingNumKeys)) + keysBatch := keys[startIndex : startIndex+batchSize] + output, err := c.dynamoClient.BatchGetItem(ctx, &dynamodb.BatchGetItemInput{ + RequestItems: map[string]types.KeysAndAttributes{ + tableName: { + Keys: keysBatch, + }, + }, + }) + if err != nil { + return nil, err + } + + if len(output.Responses) > 0 { + for _, resp := range output.Responses { + items = append(items, resp...) + } + } + + if output.UnprocessedKeys != nil { + keys = append(keys, output.UnprocessedKeys[tableName].Keys...) + } + + startIndex += batchSize + } + + return items, nil +} diff --git a/common/aws/dynamodb/client_test.go b/common/aws/dynamodb/client_test.go index 9a905d8a28..7b885f7744 100644 --- a/common/aws/dynamodb/client_test.go +++ b/common/aws/dynamodb/client_test.go @@ -223,13 +223,17 @@ func TestBatchOperations(t *testing.T) { createTable(t, tableName) ctx := context.Background() - numItems := 30 + numItems := 33 items := make([]commondynamodb.Item, numItems) + expectedBlobKeys := make([]string, numItems) + expectedMetadataKeys := make([]string, numItems) for i := 0; i < numItems; i += 1 { items[i] = commondynamodb.Item{ "MetadataKey": &types.AttributeValueMemberS{Value: fmt.Sprintf("key%d", i)}, "BlobKey": &types.AttributeValueMemberS{Value: fmt.Sprintf("blob%d", i)}, } + expectedBlobKeys[i] = fmt.Sprintf("blob%d", i) + expectedMetadataKeys[i] = fmt.Sprintf("key%d", i) } unprocessed, err := dynamoClient.PutItems(ctx, tableName, items) assert.NoError(t, err) @@ -256,6 +260,18 @@ func TestBatchOperations(t *testing.T) { } } + fetchedItems, err := dynamoClient.GetItems(ctx, tableName, keys) + assert.NoError(t, err) + assert.Len(t, fetchedItems, numItems) + blobKeys := make([]string, numItems) + metadataKeys := make([]string, numItems) + for i := 0; i < numItems; i += 1 { + blobKeys[i] = fetchedItems[i]["BlobKey"].(*types.AttributeValueMemberS).Value + metadataKeys[i] = fetchedItems[i]["MetadataKey"].(*types.AttributeValueMemberS).Value + } + assert.ElementsMatch(t, expectedBlobKeys, blobKeys) + assert.ElementsMatch(t, expectedMetadataKeys, metadataKeys) + unprocessedKeys, err := dynamoClient.DeleteItems(ctx, tableName, keys) assert.NoError(t, err) assert.Len(t, unprocessedKeys, 0)