Skip to content

Commit

Permalink
Update multilingual_universal_sentence_encoder.py
Browse files Browse the repository at this point in the history
  • Loading branch information
yangheng95 authored Aug 30, 2023
1 parent 9acd23f commit d5bb399
Showing 1 changed file with 3 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,13 @@ def __init__(self, threshold=0.8, large=False, metric="angular", **kwargs):
tensorflow_text._load()
if large:
tfhub_url = "https://tfhub.dev/google/universal-sentence-encoder-multilingual-large/3"
mirror_tfhub_url = "https://hub.tensorflow.google.cn/google/universal-sentence-encoder-multilingual-large/3"
else:
tfhub_url = "https://https://tfhub.dev/google/universal-sentence-encoder-multilingual/3"
mirror_tfhub_url = "https://hub.tensorflow.google.cn/google/universal-sentence-encoder-multilingual/3"

# TODO add QA SET. Details at: https://hub.tensorflow.google.cn/google/universal-sentence-encoder-multilingual-qa/3
self._tfhub_url = tfhub_url
self.mirror_tfhub_url = mirror_tfhub_url
try:
self.model = hub.load(self._tfhub_url)
except Exception as e:
print('Error loading model from tfhub, trying mirror url')
self.model = hub.load(self.mirror_tfhub_url)
self.model = hub.load(self._tfhub_url)


def encode(self, sentences):
return self.model(sentences).numpy()
Expand All @@ -44,8 +38,4 @@ def __getstate__(self):

def __setstate__(self, state):
self.__dict__ = state
try:
self.model = hub.load(self._tfhub_url)
except Exception as e:
print('Error loading model from tfhub, trying mirror url')
self.model = hub.load(self.mirror_tfhub_url)
self.model = hub.load(self._tfhub_url)

0 comments on commit d5bb399

Please sign in to comment.