diff --git a/src/me/xiaosheng/word2vec/Test.java b/src/me/xiaosheng/word2vec/Test.java
index eb21912..07cc5ca 100644
--- a/src/me/xiaosheng/word2vec/Test.java
+++ b/src/me/xiaosheng/word2vec/Test.java
@@ -6,6 +6,8 @@
 import java.util.List;
 import java.util.Set;
 
+import com.ansj.vec.domain.WordEntry;
+
 public class Test {
 
 	public static void main(String[] args) {
@@ -17,16 +19,22 @@ public static void main(String[] args) {
 			// TODO Auto-generated catch block
 			e.printStackTrace();
 		}
+		// Word-to-word similarity
 		System.out.println(vec.wordSimilarity("狗", "猫"));
 		System.out.println(vec.wordSimilarity("计算机", "电脑"));
 		System.out.println(vec.wordSimilarity("计算机", "人"));
-		
+		// Sentence-to-sentence similarity
 		String s1 = "苏州 有 多条 公路 正在 施工 造成 局部 地区 汽车 行驶 非常 缓慢";
 		String s2 = "苏州 最近 有 多条 公路 在 施工 导致 部分 地区 交通 拥堵 汽车 难以 通行";
 		String s3 = "苏州 是 一座 美丽 的 城市 四季 分明 雨量 充沛";
 		System.out.println(vec.sentenceSimilairy(s1, s1));
 		System.out.println(vec.sentenceSimilairy(s1, s2));
 		System.out.println(vec.sentenceSimilairy(s1, s3));
+		// Look up the words most similar to a given word
+		Set<WordEntry> similarWords = vec.getSimilarWords("漂亮", 10);
+		for(WordEntry word : similarWords) {
+			System.out.println(word.name + " : " + word.score);
+		}
 //		try {
 //			Word2Vec.trainJavaModel("data/train.txt", "data/test.model");
 //		} catch (IOException e) {
diff --git a/src/me/xiaosheng/word2vec/Word2Vec.java b/src/me/xiaosheng/word2vec/Word2Vec.java
index fa55b9e..e785bf5 100644
--- a/src/me/xiaosheng/word2vec/Word2Vec.java
+++ b/src/me/xiaosheng/word2vec/Word2Vec.java
@@ -3,13 +3,17 @@
 import java.io.File;
 import java.io.IOException;
 import java.util.Arrays;
+import java.util.Collections;
 import java.util.HashSet;
 import java.util.LinkedList;
 import java.util.List;
+import java.util.Map;
 import java.util.Set;
+import java.util.TreeSet;
 
 import com.ansj.vec.Learn;
 import com.ansj.vec.Word2VEC;
+import com.ansj.vec.domain.WordEntry;
 
 public class Word2Vec {
 
@@ -49,6 +53,19 @@ public float[] getWordVector(String word) {
 		}
 		return vec.getWordVector(word);
 	}
+	/**
+	 * Computes the inner (dot) product of two vectors.
+	 * @param vec1 first vector; its length drives the loop
+	 * @param vec2 second vector; must be at least as long as vec1
+	 * @return sum of vec1[i] * vec2[i] over every index of vec1
+	 */
+	private float calDist(float[] vec1, float[] vec2) {
+		float dist = 0;
+		for (int i = 0; i < vec1.length; i++) {
+			dist += vec1[i] * vec2[i];
+		}
+		return dist;
+	}
 	/**
 	 * 计算词相似度
 	 * @param word1
@@ -64,11 +81,40 @@ public float wordSimilarity(String word1, String word2) {
 		if(word1Vec == null || word2Vec == null) {
 			return -1;
 		}
-		float dist = 0;
-		for (int i = 0; i < word1Vec.length; i++) {
-			dist += word1Vec[i] * word2Vec[i];
+		return calDist(word1Vec, word2Vec);
+	}
+	/**
+	 * Returns the words whose vectors score highest against the vector of {@code word}.
+	 * Scores are inner products from calDist; the query word itself is never returned.
+	 * @param word query word
+	 * @param maxReturnNum upper bound on the number of words returned; zero or negative gives an empty set
+	 * @return null when loadModel is false, an empty set when the word has no vector, otherwise at most maxReturnNum entries in WordEntry order
+	 */
+	public Set<WordEntry> getSimilarWords(String word, int maxReturnNum) {
+		if (!loadModel) {
+			return null;
+		}
+		float[] center = getWordVector(word);
+		int resultSize = Math.min(vec.getWords(), maxReturnNum);
+		if (center == null || resultSize <= 0) {
+			return Collections.emptySet();
 		}
-		return dist;
+		TreeSet<WordEntry> result = new TreeSet<WordEntry>();
+		for (Map.Entry<String, float[]> entry : vec.getWordMap().entrySet()) {
+			// Skip the query word explicitly instead of assuming it ranks first.
+			if (word.equals(entry.getKey())) {
+				continue;
+			}
+			float dist = calDist(center, entry.getValue());
+			if (result.size() < resultSize) {
+				result.add(new WordEntry(entry.getKey(), dist));
+			} else if (dist > result.last().score) {
+				// last() is treated as the weakest kept entry; swap it for the better match.
+				result.pollLast();
+				result.add(new WordEntry(entry.getKey(), dist));
+			}
+		}
+		return result;
 	}
 	/**
 	 * 计算句子相似度