Skip to content

Commit

Permalink
[feat]加权融合
Browse files Browse the repository at this point in the history
  • Loading branch information
ZWJason committed Jan 14, 2025
1 parent ab5e045 commit c7381a4
Show file tree
Hide file tree
Showing 5 changed files with 129 additions and 4 deletions.
11 changes: 11 additions & 0 deletions src/main/java/com/search/docsearch/constant/Constants.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,15 @@ public class Constants {


public static final String HTTPS_PREFIX = "https://";


/**
* Maxsocre that used to normlize the result
*/
public static final int MAX_SCORE = 10000000;

/**
* Min socre that used to normlize the result
*/
public static final int MIN_SCORE = -1;
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,13 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.search.docsearch.constant.Constants;
import com.search.docsearch.utils.MergeUtil;

public class DataComposite implements Component {

Expand Down Expand Up @@ -110,8 +114,40 @@ public Map<String, Object> mergeResult(){
aresList.add(pos, bresList.get(pos));
}
}

ares.put("records", aresList);
return ares;
}
}

/**
* merge the other recall results into one way, based one the index 0 of children
*
* @return the merged result lists
*/
public List<Map<String, Object>> weightedMerge(int pageSize){
List<Map<String, Object>> mergeList = new ArrayList<>();

for (Component recall : this.children){
double minScore = Constants.MAX_SCORE;
double maxScore = Constants.MIN_SCORE;
List<Map<String, Object>> rcords = (List<Map<String, Object>>) recall.getResList().get("records");
// find min and max
for (Map<String, Object> entity : rcords) {
double score = (double) entity.get("score");
minScore = Math.min(score,minScore);
maxScore = Math.max(score, maxScore);
}
// do norm
for (Map<String, Object> entity : rcords) {
double score = (double) entity.get("score");
double normedScore = MergeUtil.normalize(score, minScore, maxScore);
entity.put("score", normedScore);
mergeList.add(entity);
}
}

mergeList = mergeList.stream().sorted((a, b) -> Double.compare((Double) b.get("score"), (Double) a.get("score"))).collect(Collectors.toList());

return mergeList.subList(0, Math.min(pageSize, mergeList.size()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ private Component searchByCondition(SearchCondition condition) throws ServiceImp
} else {
map.put("lang", "zh");
}
map.put("score", 5000 - (count + start) * 50);
map.put("score", (double) (5000 - (count + start) * 50));
count++;
data.add(map);
}
Expand Down
28 changes: 28 additions & 0 deletions src/main/java/com/search/docsearch/utils/MergeUtil.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/* Copyright (c) 2024 openEuler Community
EasySoftware is licensed under the Mulan PSL v2.
You can use this software according to the terms and conditions of the Mulan PSL v2.
You may obtain a copy of Mulan PSL v2 at:
http://license.coscl.org.cn/MulanPSL2
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
See the Mulan PSL v2 for more details.
*/
package com.search.docsearch.utils;

public class MergeUtil {

/**
* normalize the score according to their own score
*
* @return the normalied score of search results
*/
public static double normalize(double score, double minScore, double maxScore) {
// 检查范围是否有效
if (maxScore <= minScore) {
throw new IllegalArgumentException("maxScore 必须大于 minScore");
}
// 归一化公式 (score - minScore) / (maxScore - minScore)
return (score - minScore) / (maxScore - minScore);
}
}
52 changes: 51 additions & 1 deletion src/test/java/com/search/docsearch/CompositeTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
package com.search.docsearch;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;

import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.Test;
import org.springframework.boot.test.context.SpringBootTest;

Expand All @@ -20,7 +20,11 @@
import com.search.docsearch.multirecall.composite.cdata.EsRecallData;
import com.search.docsearch.multirecall.composite.cdata.GRecallData;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@SpringBootTest
public class CompositeTest {

Expand Down Expand Up @@ -103,4 +107,50 @@ void testFliteringRecallWithError() {
assertEquals("error when process the recall res",exception.getMessage());
}

/**
* 测试: normlize加权融合测试
*/
@Test
void testWeightedMerge() {
// 设置mockComponent1的返回数据
List<Map<String, Object>> records1 = new ArrayList<>();
Map<String, Object> record1_1 = new HashMap<>();
record1_1.put("score", 3.0);
records1.add(record1_1);

Map<String, Object> record1_2 = new HashMap<>();
record1_2.put("score", 1.0);
records1.add(record1_2);

// 设置mockComponent2的返回数据
List<Map<String, Object>> records2 = new ArrayList<>();
Map<String, Object> record2_1 = new HashMap<>();
record2_1.put("score", 2.5);
records2.add(record2_1);

Map<String, Object> record2_2 = new HashMap<>();
record2_2.put("score", 4.0);
records2.add(record2_2);

DataComposite dataComposite = new DataComposite();

Component mockComponent1 = new EsRecallData(Collections.singletonMap("records", records1));
Component mockComponent2 = new EsRecallData(Collections.singletonMap("records", records2));

dataComposite.add(mockComponent1);
dataComposite.add(mockComponent2);
// 校验是否按pagesize返回正确个数
int pageSize = 3;
List<Map<String, Object>> result = dataComposite.weightedMerge(pageSize);
assertEquals(pageSize, result.size());

// 验证结果是否按分数降序排列
for (int i = 0; i < result.size() - 1; i++) {
double score1 = (Double) result.get(i).get("score");
double score2 = (Double) result.get(i + 1).get("score");
assertTrue(score1 >= score2, "Results should be sorted in descending order by score");
}
}


}

0 comments on commit c7381a4

Please sign in to comment.