Skip to content

Commit

Permalink
Add support for $median aggregation operator.
Browse files Browse the repository at this point in the history
Closes #4472
Original Pull Request: #4515
  • Loading branch information
sxhinzvc authored and christophstrobl committed Sep 28, 2023
1 parent cd80c84 commit d506cd0
Show file tree
Hide file tree
Showing 7 changed files with 180 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,16 @@ public Percentile percentile(Double... percentages) {
return percentile.percentages(percentages);
}

/**
* Creates new {@link AggregationExpression} that calculates the median of the associated numeric value expression.
*
* @return new instance of {@link Median}.
* @since 4.2
*/
public Median median() {
return usesFieldRef() ? Median.medianOf(fieldReference) : Median.medianOf(expression);
}

private boolean usesFieldRef() {
return fieldReference != null;
}
Expand Down Expand Up @@ -1082,4 +1092,78 @@ protected String getMongoMethod() {
return "$percentile";
}
}

/**
* {@link AggregationExpression} for {@code $median}.
*
* @author Julia Lee
* @since 4.2
*/
public static class Median extends AbstractAggregationExpression {

private Median(Object value) {
super(value);
}

/**
* Creates new {@link Median}.
*
* @param fieldReference must not be {@literal null}.
* @return new instance of {@link Median}.
*/
public static Median medianOf(String fieldReference) {

Assert.notNull(fieldReference, "FieldReference must not be null");
Map<String, Object> fields = new HashMap<>();
fields.put("input", Fields.field(fieldReference));
fields.put("method", "approximate");
return new Median(fields);
}

/**
* Creates new {@link Median}.
*
* @param expression must not be {@literal null}.
* @return new instance of {@link Median}.
*/
public static Median medianOf(AggregationExpression expression) {

Assert.notNull(expression, "Expression must not be null");
Map<String, Object> fields = new HashMap<>();
fields.put("input", expression);
fields.put("method", "approximate");
return new Median(fields);
}

/**
* Creates new {@link Median} with all previously added inputs appending the given one. <br />
* <strong>NOTE:</strong> Only possible in {@code $project} stage.
*
* @param fieldReference must not be {@literal null}.
* @return new instance of {@link Median}.
*/
public Median and(String fieldReference) {

Assert.notNull(fieldReference, "FieldReference must not be null");
return new Median(appendTo("input", Fields.field(fieldReference)));
}

/**
* Creates new {@link Median} with all previously added inputs appending the given one. <br />
* <strong>NOTE:</strong> Only possible in {@code $project} stage.
*
* @param expression must not be {@literal null}.
* @return new instance of {@link Median}.
*/
public Median and(AggregationExpression expression) {

Assert.notNull(expression, "Expression must not be null");
return new Median(appendTo("input", expression));
}

@Override
protected String getMongoMethod() {
return "$median";
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.springframework.data.mongodb.core.aggregation.AccumulatorOperators.CovariancePop;
import org.springframework.data.mongodb.core.aggregation.AccumulatorOperators.CovarianceSamp;
import org.springframework.data.mongodb.core.aggregation.AccumulatorOperators.Max;
import org.springframework.data.mongodb.core.aggregation.AccumulatorOperators.Median;
import org.springframework.data.mongodb.core.aggregation.AccumulatorOperators.Min;
import org.springframework.data.mongodb.core.aggregation.AccumulatorOperators.Percentile;
import org.springframework.data.mongodb.core.aggregation.AccumulatorOperators.StdDevPop;
Expand Down Expand Up @@ -948,6 +949,18 @@ public Percentile percentile(Double... percentages) {
return percentile.percentages(percentages);
}

/**
* Creates new {@link AggregationExpression} that calculates the requested percentile(s) of the
* numeric value.
*
* @return new instance of {@link Median}.
* @since 4.2
*/
public Median median() {
return usesFieldRef() ? AccumulatorOperators.Median.medianOf(fieldReference)
: AccumulatorOperators.Median.medianOf(expression);
}

private boolean usesFieldRef() {
return fieldReference != null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,34 @@ void rendersPercentileWithExpression() {
.isEqualTo(Document.parse("{ $percentile: { input: [\"$scoreOne\", {\"$sum\": \"$scoreTwo\"}], method: \"approximate\", p: [0.1, 0.2] } }"));
}

@Test // GH-4472
void rendersMedianWithFieldReference() {

assertThat(valueOf("score").median().toDocument(Aggregation.DEFAULT_CONTEXT))
.isEqualTo(Document.parse("{ $median: { input: \"$score\", method: \"approximate\" } }"));

assertThat(valueOf("score").median().and("scoreTwo").toDocument(Aggregation.DEFAULT_CONTEXT))
.isEqualTo(Document.parse("{ $median: { input: [\"$score\", \"$scoreTwo\"], method: \"approximate\" } }"));
}

@Test // GH-4472
void rendersMedianWithExpression() {

assertThat(valueOf(Sum.sumOf("score")).median().toDocument(Aggregation.DEFAULT_CONTEXT))
.isEqualTo(Document.parse("{ $median: { input: {\"$sum\": \"$score\"}, method: \"approximate\" } }"));

assertThat(valueOf("scoreOne").median().and(Sum.sumOf("scoreTwo")).toDocument(Aggregation.DEFAULT_CONTEXT))
.isEqualTo(Document.parse("{ $median: { input: [\"$scoreOne\", {\"$sum\": \"$scoreTwo\"}], method: \"approximate\" } }"));
}

@Test // GH-4472
void rendersMedianCorrectlyWithTypedAggregationContext() {

assertThat(valueOf("midichlorianCount").median()
.toDocument(TestAggregationContext.contextFor(Jedi.class)))
.isEqualTo(Document.parse("{ $median: { input: \"$force\", method: \"approximate\" } }"));
}

static class Jedi {

String name;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1897,19 +1897,44 @@ void facetShouldCreateFacets() {
@EnableIfMongoServerVersion(isGreaterThanEqual = "7.0")
void percentileShouldBeAppliedCorrectly() {

mongoTemplate.insert(new DATAMONGO788(15, 16));
mongoTemplate.insert(new DATAMONGO788(17, 18));
DATAMONGO788 objectToSave = new DATAMONGO788(62, 81, 80);
DATAMONGO788 objectToSave2 = new DATAMONGO788(60, 83, 79);

mongoTemplate.insert(objectToSave);
mongoTemplate.insert(objectToSave2);

Aggregation agg = Aggregation.newAggregation(
project().and(ArithmeticOperators.valueOf("x").percentile(0.9).and("y"))
.as("ninetiethPercentile"));
project().and(ArithmeticOperators.valueOf("x").percentile(0.9, 0.4).and("y").and("xField"))
.as("percentileValues"));

AggregationResults<Document> result = mongoTemplate.aggregate(agg, DATAMONGO788.class, Document.class);

// MongoDB server returns $percentile as an array of doubles
List<Document> rawResults = (List<Document>) result.getRawResults().get("results");
assertThat((List<Object>) rawResults.get(0).get("ninetiethPercentile")).containsExactly(16.0);
assertThat((List<Object>) rawResults.get(1).get("ninetiethPercentile")).containsExactly(18.0);
assertThat((List<Object>) rawResults.get(0).get("percentileValues")).containsExactly(81.0, 80.0);
assertThat((List<Object>) rawResults.get(1).get("percentileValues")).containsExactly(83.0, 79.0);
}

@Test // GH-4472
@EnableIfMongoServerVersion(isGreaterThanEqual = "7.0")
void medianShouldBeAppliedCorrectly() {

DATAMONGO788 objectToSave = new DATAMONGO788(62, 81, 80);
DATAMONGO788 objectToSave2 = new DATAMONGO788(60, 83, 79);

mongoTemplate.insert(objectToSave);
mongoTemplate.insert(objectToSave2);

Aggregation agg = Aggregation.newAggregation(
project().and(ArithmeticOperators.valueOf("x").median().and("y").and("xField"))
.as("medianValue"));

AggregationResults<Document> result = mongoTemplate.aggregate(agg, DATAMONGO788.class, Document.class);

// MongoDB server returns $median a Double
List<Document> rawResults = (List<Document>) result.getRawResults().get("results");
assertThat(rawResults.get(0).get("medianValue")).isEqualTo(80.0);
assertThat(rawResults.get(1).get("medianValue")).isEqualTo(79.0);
}

@Test // DATAMONGO-1986
Expand Down Expand Up @@ -2152,6 +2177,12 @@ public DATAMONGO788() {}
this.y = y;
this.yField = y;
}

public DATAMONGO788(int x, int y, int xField) {
this.x = x;
this.y = y;
this.xField = xField;
}
}

// DATAMONGO-806
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import org.junit.jupiter.api.Test;

/**
* Unit tests for {@link Round}.
* Unit tests for {@link ArithmeticOperators}.
*
* @author Christoph Strobl
* @author Mark Paluch
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2261,6 +2261,22 @@ void shouldRenderPercentileWithMultipleArgsAggregationExpression() {
assertThat(agg).isEqualTo(Document.parse("{ $project: { scorePercentiles: { $percentile: { input: [\"$scoreOne\", \"$scoreTwo\"], method: \"approximate\", p: [0.4] } }} } }"));
}

@Test // GH-4472
void shouldRenderMedianAggregationExpressions() {

Document singleArgAgg = project()
.and(ArithmeticOperators.valueOf("score").median()).as("medianValue")
.toDocument(Aggregation.DEFAULT_CONTEXT);

assertThat(singleArgAgg).isEqualTo(Document.parse("{ $project: { medianValue: { $median: { input: \"$score\", method: \"approximate\" } }} } }"));

Document multipleArgsAgg = project()
.and(ArithmeticOperators.valueOf("score").median().and("scoreTwo")).as("medianValue")
.toDocument(Aggregation.DEFAULT_CONTEXT);

assertThat(multipleArgsAgg).isEqualTo(Document.parse("{ $project: { medianValue: { $median: { input: [\"$score\", \"$scoreTwo\"], method: \"approximate\" } }} } }"));
}

private static Document extractOperation(String field, Document fromProjectClause) {
return (Document) fromProjectClause.get(field);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ At the time of this writing, we provide support for the following Aggregation Op
| `setEquals`, `setIntersection`, `setUnion`, `setDifference`, `setIsSubset`, `anyElementTrue`, `allElementsTrue`

| Group/Accumulator Aggregation Operators
| `addToSet`, `bottom`, `bottomN`, `covariancePop`, `covarianceSamp`, `expMovingAvg`, `first`, `firstN`, `last`, `lastN` `max`, `maxN`, `min`, `minN`, `avg`, `push`, `sum`, `top`, `topN`, `count` (+++*+++), `percentile`, `stdDevPop`, `stdDevSamp`
| `addToSet`, `bottom`, `bottomN`, `covariancePop`, `covarianceSamp`, `expMovingAvg`, `first`, `firstN`, `last`, `lastN` `max`, `maxN`, `min`, `minN`, `avg`, `push`, `sum`, `top`, `topN`, `count` (+++*+++), `median`, `percentile`, `stdDevPop`, `stdDevSamp`

| Arithmetic Aggregation Operators
| `abs`, `acos`, `acosh`, `add` (+++*+++ via `plus`), `asin`, `asin`, `atan`, `atan2`, `atanh`, `ceil`, `cos`, `cosh`, `derivative`, `divide`, `exp`, `floor`, `integral`, `ln`, `log`, `log10`, `mod`, `multiply`, `pow`, `round`, `sqrt`, `subtract` (+++*+++ via `minus`), `sin`, `sinh`, `tan`, `tanh`, `trunc`
Expand Down

0 comments on commit d506cd0

Please sign in to comment.