Skip to content

Commit

Permalink
feat: add terms array length limit
Browse files Browse the repository at this point in the history
  • Loading branch information
AElfBourneShi committed Oct 18, 2024
1 parent 9d339a4 commit a3a8b6f
Show file tree
Hide file tree
Showing 7 changed files with 69 additions and 35 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
using AElf.EntityMapping.Elasticsearch.Linq;
using AElf.EntityMapping.Elasticsearch.Options;
using Microsoft.Extensions.Options;
using Nest;
using Volo.Abp.Domain.Entities;

Expand All @@ -14,14 +16,18 @@ public class ElasticsearchQueryableFactory<TEntity> : IElasticsearchQueryableFac
where TEntity : class, IEntity
{
private readonly ICollectionNameProvider<TEntity> _collectionNameProvider;
private readonly ElasticsearchOptions _elasticsearchOptions;

public ElasticsearchQueryableFactory(ICollectionNameProvider<TEntity> collectionNameProvider)
public ElasticsearchQueryableFactory(ICollectionNameProvider<TEntity> collectionNameProvider,
IOptions<ElasticsearchOptions> elasticsearchOptions)
{
_collectionNameProvider = collectionNameProvider;
_elasticsearchOptions = elasticsearchOptions.Value;
}

public ElasticsearchQueryable<TEntity> Create(IElasticClient client, string index = null)
public ElasticsearchQueryable<TEntity> Create(IElasticClient client,
string index = null)
{
return new ElasticsearchQueryable<TEntity>(client, _collectionNameProvider, index);
return new ElasticsearchQueryable<TEntity>(client, _collectionNameProvider, index, _elasticsearchOptions);
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
using System.Collections.ObjectModel;
using System.Linq.Expressions;
using AElf.EntityMapping.Elasticsearch.Options;
using AElf.EntityMapping.Linq;
using AElf.EntityMapping.Options;
using Microsoft.Extensions.Options;
using Nest;
using Remotion.Linq;
using Remotion.Linq.Clauses;
Expand All @@ -13,12 +16,15 @@ public class ElasticsearchGeneratorQueryModelVisitor<TU> : QueryModelVisitorBase
{
private readonly PropertyNameInferrerParser _propertyNameInferrerParser;
private readonly INodeVisitor _nodeVisitor;
private readonly ElasticsearchOptions _elasticsearchOptions;
private QueryAggregator QueryAggregator { get; set; } = new QueryAggregator();

public ElasticsearchGeneratorQueryModelVisitor(PropertyNameInferrerParser propertyNameInferrerParser)
public ElasticsearchGeneratorQueryModelVisitor(PropertyNameInferrerParser propertyNameInferrerParser,
ElasticsearchOptions elasticsearchOptions)
{
_propertyNameInferrerParser = propertyNameInferrerParser;
_nodeVisitor = new NodeVisitor();
_elasticsearchOptions = elasticsearchOptions;
}

public QueryAggregator GenerateElasticQuery<T>(QueryModel queryModel)
Expand Down Expand Up @@ -48,7 +54,7 @@ public override void VisitMainFromClause(MainFromClause fromClause, QueryModel q

public override void VisitWhereClause(WhereClause whereClause, QueryModel queryModel, int index)
{
var tree = new GeneratorExpressionTreeVisitor<TU>(_propertyNameInferrerParser);
var tree = new GeneratorExpressionTreeVisitor<TU>(_propertyNameInferrerParser, _elasticsearchOptions);
tree.Visit(whereClause.Predicate);
if (QueryAggregator.Query == null)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using System.Dynamic;
using System.Linq.Expressions;
using AElf.EntityMapping.Elasticsearch.Exceptions;
using AElf.EntityMapping.Elasticsearch.Options;
using Elasticsearch.Net;
using Nest;
using Newtonsoft.Json;
Expand All @@ -23,16 +24,20 @@ public class ElasticsearchQueryExecutor<TEntity>: IQueryExecutor
private readonly JsonSerializerSettings _deserializerSettings;
private readonly ICollectionNameProvider<TEntity> _collectionNameProvider;
private const int ElasticQueryLimit = 10000;
private readonly ElasticsearchOptions _elasticsearchOptions;

public ElasticsearchQueryExecutor(IElasticClient elasticClient,
ICollectionNameProvider<TEntity> collectionNameProvider, string index)
ICollectionNameProvider<TEntity> collectionNameProvider, string index,
ElasticsearchOptions elasticsearchOptions)
{
_elasticClient = elasticClient;
_collectionNameProvider = collectionNameProvider;
_index = index;
_propertyNameInferrerParser = new PropertyNameInferrerParser(_elasticClient);
_elasticsearchOptions = elasticsearchOptions;
_elasticsearchGeneratorQueryModelVisitor =
new ElasticsearchGeneratorQueryModelVisitor<TEntity>(_propertyNameInferrerParser);
new ElasticsearchGeneratorQueryModelVisitor<TEntity>(_propertyNameInferrerParser,
_elasticsearchOptions);
_deserializerSettings = new JsonSerializerSettings
{
// Nest maps TimeSpan as a long (TimeSpan ticks)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System.Linq.Expressions;
using AElf.EntityMapping.Elasticsearch.Options;
using AElf.EntityMapping.Linq;
using Nest;
using Remotion.Linq;
Expand All @@ -10,10 +11,10 @@ public class ElasticsearchQueryable<T> : QueryableBase<T>, IElasticsearchQueryab
where T : class, IEntity
{
public ElasticsearchQueryable(IElasticClient elasticClient, ICollectionNameProvider<T> collectionNameProvider,
string index)
: base(new DefaultQueryProvider(typeof(ElasticsearchQueryable<>),
string index, ElasticsearchOptions elasticsearchOptions)
: base(new DefaultQueryProvider(typeof(ElasticsearchQueryable<>),
QueryParserFactory.Create(),
new ElasticsearchQueryExecutor<T>(elasticClient, collectionNameProvider, index)))
new ElasticsearchQueryExecutor<T>(elasticClient, collectionNameProvider, index, elasticsearchOptions)))
{
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using System.Linq.Expressions;
using System.Reflection;
using AElf.EntityMapping.Elasticsearch.Options;
using Elasticsearch.Net;
using Remotion.Linq.Clauses;
using Remotion.Linq.Clauses.Expressions;
Expand All @@ -11,6 +11,7 @@ namespace AElf.EntityMapping.Elasticsearch.Linq
public class GeneratorExpressionTreeVisitor<T> : ThrowingExpressionVisitor
{
private readonly PropertyNameInferrerParser _propertyNameInferrerParser;
private readonly ElasticsearchOptions _elasticsearchOptions;

private object Value { get; set; }
private string PropertyName { get; set; }
Expand All @@ -20,9 +21,11 @@ public class GeneratorExpressionTreeVisitor<T> : ThrowingExpressionVisitor
public IDictionary<Expression, Node> QueryMap { get; } =
new Dictionary<Expression, Node>();

public GeneratorExpressionTreeVisitor(PropertyNameInferrerParser propertyNameInferrerParser)
public GeneratorExpressionTreeVisitor(PropertyNameInferrerParser propertyNameInferrerParser,
ElasticsearchOptions elasticsearchOptions)
{
_propertyNameInferrerParser = propertyNameInferrerParser;
_elasticsearchOptions = elasticsearchOptions;
}

protected override Expression VisitUnary(UnaryExpression expression)
Expand Down Expand Up @@ -203,7 +206,14 @@ protected override Expression VisitSubQuery(SubQueryExpression expression)
case ContainsResultOperator containsResultOperator:
Visit(containsResultOperator.Item);
Visit(expression.QueryModel.MainFromClause.FromExpression);


//Check if the number of items in the Terms query array within the Contains clause is too large.
if (expression.QueryModel.MainFromClause
.FromExpression is ConstantExpression constantExpression)
{
CheckTermsArrayLength(constantExpression);
}

// Handling different types
query = GetDifferentTypesTermsQueryNode();

Expand Down Expand Up @@ -290,6 +300,13 @@ private void HandleNestedContains(SubQueryExpression subQueryExpression, Express
if (subQueryExpression == null || expression == null)
throw new ArgumentNullException("SubQueryExpression or expression cannot be null.");

//Check if the number of items in the Terms query array within the Contains clause is too large.
if (subQueryExpression.QueryModel.MainFromClause
.FromExpression is ConstantExpression constantExpression)
{
CheckTermsArrayLength(constantExpression);
}

foreach (var resultOperator in subQueryExpression.QueryModel.ResultOperators)
{
switch (resultOperator)
Expand Down Expand Up @@ -375,18 +392,16 @@ private Node GetDifferentTypesTermsQueryNode()
return query;
}

private string GetMemberName(Expression expression)
private void CheckTermsArrayLength(ConstantExpression constantExpression)
{
if (expression is MemberExpression memberExpression)
return memberExpression.Member.Name;
throw new InvalidOperationException("Expression does not represent a member access.");
}

private object GetValueFromExpression(Expression expression)
{
if (expression is ConstantExpression constantExpression)
return constantExpression.Value;
throw new InvalidOperationException("Expression is not a constant.");
if (constantExpression.Value is System.Collections.IEnumerable objectList)
{
var count = objectList.Cast<object>().Count();
if (count > _elasticsearchOptions.TermsArrayMaxLength)
{
throw new Exception($"The array input for Terms query is too large, exceeding {_elasticsearchOptions.TermsArrayMaxLength} items.");
}
}
}

protected override Expression VisitQuerySourceReference(QuerySourceReferenceExpression expression)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@ public class ElasticsearchOptions
public int NumberOfReplicas { get; set; } = 1;
public Refresh Refresh { get; set; } = Refresh.False;
public int MaxResultWindow { get; set; } = 10000;
public int TermsArrayMaxLength { get; set; } = 100;
}
Original file line number Diff line number Diff line change
Expand Up @@ -619,11 +619,11 @@ public async Task GetList_Terms_Test()

var queryable = await _elasticsearchRepository.GetQueryableAsync();

var predicates = inputs
.Select(s => (Expression<Func<BlockIndex, bool>>)(info => info.BlockHash == s))
.Aggregate((prev, next) => prev.Or(next));
var filterList_predicate = queryable.Where(predicates).ToList();
filterList_predicate.Count.ShouldBe(3);
// var predicates = inputs
// .Select(s => (Expression<Func<BlockIndex, bool>>)(info => info.BlockHash == s))
// .Aggregate((prev, next) => prev.Or(next));
// var filterList_predicate = queryable.Where(predicates).ToList();
// filterList_predicate.Count.ShouldBe(3);

var filterList = queryable.Where(item => inputs.Contains(item.BlockHash)).ToList();
filterList.Count.ShouldBe(3);
Expand Down Expand Up @@ -659,12 +659,12 @@ public async Task GetNestedList_Terms_Test()
101,
103
};
var queryable_predicate = await _transactionIndexRepository.GetQueryableAsync();
var predicates = inputs
.Select(s => (Expression<Func<TransactionIndex, bool>>)(info => info.LogEvents.Any(x => x.BlockHeight == s)))
.Aggregate((prev, next) => prev.Or(next));
var filterList_predicate = queryable_predicate.Where(predicates).ToList();
filterList_predicate.Count.ShouldBe(2);
// var queryable_predicate = await _transactionIndexRepository.GetQueryableAsync();
// var predicates = inputs
// .Select(s => (Expression<Func<TransactionIndex, bool>>)(info => info.LogEvents.Any(x => x.BlockHeight == s)))
// .Aggregate((prev, next) => prev.Or(next));
// var filterList_predicate = queryable_predicate.Where(predicates).ToList();
// filterList_predicate.Count.ShouldBe(2);

Expression<Func<TransactionIndex, bool>> mustQuery = item =>
item.LogEvents.Any(x => inputs.Contains(x.BlockHeight));
Expand Down

0 comments on commit a3a8b6f

Please sign in to comment.