Skip to content

Commit

Permalink
add get similar words func
Browse files Browse the repository at this point in the history
  • Loading branch information
jsksxs360 committed Nov 11, 2016
1 parent f60b47f commit 32ecfe1
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 5 deletions.
10 changes: 9 additions & 1 deletion src/me/xiaosheng/word2vec/Test.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import java.util.List;
import java.util.Set;

import com.ansj.vec.domain.WordEntry;

public class Test {

public static void main(String[] args) {
Expand All @@ -17,16 +19,22 @@ public static void main(String[] args) {
// TODO Auto-generated catch block
e.printStackTrace();
}
//计算词语相似度
System.out.println(vec.wordSimilarity("狗", "猫"));
System.out.println(vec.wordSimilarity("计算机", "电脑"));
System.out.println(vec.wordSimilarity("计算机", "人"));

//计算句子相似度
String s1 = "苏州 有 多条 公路 正在 施工 造成 局部 地区 汽车 行驶 非常 缓慢";
String s2 = "苏州 最近 有 多条 公路 在 施工 导致 部分 地区 交通 拥堵 汽车 难以 通行";
String s3 = "苏州 是 一座 美丽 的 城市 四季 分明 雨量 充沛";
System.out.println(vec.sentenceSimilairy(s1, s1));
System.out.println(vec.sentenceSimilairy(s1, s2));
System.out.println(vec.sentenceSimilairy(s1, s3));
//获取相似的词语
Set<WordEntry> similarWords = vec.getSimilarWords("漂亮", 10);
for(WordEntry word : similarWords) {
System.out.println(word.name + " : " + word.score);
}
// try {
// Word2Vec.trainJavaModel("data/train.txt", "data/test.model");
// } catch (IOException e) {
Expand Down
55 changes: 51 additions & 4 deletions src/me/xiaosheng/word2vec/Word2Vec.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,17 @@
import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;

import com.ansj.vec.Learn;
import com.ansj.vec.Word2VEC;
import com.ansj.vec.domain.WordEntry;

public class Word2Vec {

Expand Down Expand Up @@ -49,6 +53,19 @@ public float[] getWordVector(String word) {
}
return vec.getWordVector(word);
}
/**
* 计算向量内积
* @param vec1
* @param vec2
* @return
*/
private float calDist(float[] vec1, float[] vec2) {
float dist = 0;
for (int i = 0; i < vec1.length; i++) {
dist += vec1[i] * vec2[i];
}
return dist;
}
/**
* 计算词相似度
* @param word1
Expand All @@ -64,11 +81,41 @@ public float wordSimilarity(String word1, String word2) {
if(word1Vec == null || word2Vec == null) {
return -1;
}
float dist = 0;
for (int i = 0; i < word1Vec.length; i++) {
dist += word1Vec[i] * word2Vec[i];
return calDist(word1Vec, word2Vec);
}
/**
* 获取相似词语
* @param word
* @param maxReturnNum
* @return
*/
public Set<WordEntry> getSimilarWords(String word, int maxReturnNum) {
if (loadModel == false)
return null;
float[] center = getWordVector(word);
if (center == null) {
return Collections.emptySet();
}
return dist;
int resultSize = vec.getWords() < maxReturnNum ? vec.getWords() : maxReturnNum;
TreeSet<WordEntry> result = new TreeSet<WordEntry>();
double min = Double.MIN_VALUE;
for (Map.Entry<String, float[]> entry : vec.getWordMap().entrySet()) {
float[] vector = entry.getValue();
float dist = calDist(center, vector);
if (result.size() <= resultSize) {
result.add(new WordEntry(entry.getKey(), dist));
min = result.last().score;
} else {
if (dist > min) {
result.add(new WordEntry(entry.getKey(), dist));
result.pollLast();
min = result.last().score;
}
}
}
result.pollFirst();

return result;
}
/**
* 计算句子相似度
Expand Down

0 comments on commit 32ecfe1

Please sign in to comment.