-
Notifications
You must be signed in to change notification settings - Fork 1.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
12d3484
commit 716f435
Showing
3 changed files
with
36 additions
and
56 deletions.
There are no files selected for viewing
56 changes: 13 additions & 43 deletions
56
apps/setting/models_provider/impl/tencent_model_provider/credential/embedding.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
33 changes: 21 additions & 12 deletions
33
apps/setting/models_provider/impl/tencent_model_provider/model/embedding.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,25 +1,34 @@ | ||
from setting.models_provider.base_model_provider import MaxKBBaseModel | ||
from typing import Dict | ||
import requests | ||
from typing import Dict, List | ||
|
||
from langchain_core.embeddings import Embeddings | ||
from tencentcloud.common import credential | ||
from tencentcloud.hunyuan.v20230901.hunyuan_client import HunyuanClient | ||
from tencentcloud.hunyuan.v20230901.models import GetEmbeddingRequest | ||
|
||
class TencentEmbeddingModel(MaxKBBaseModel): | ||
def __init__(self, secret_id: str, secret_key: str, api_base: str, model_name: str): | ||
|
||
class TencentEmbeddingModel(Embeddings): | ||
def embed_documents(self, texts: List[str]) -> List[List[float]]: | ||
return [self.embed_query(text) for text in texts] | ||
|
||
def embed_query(self, text: str) -> List[float]: | ||
request = GetEmbeddingRequest() | ||
request.Input = text | ||
res = self.client.GetEmbedding(request) | ||
return res.Data | ||
|
||
def __init__(self, secret_id: str, secret_key: str, model_name: str): | ||
self.secret_id = secret_id | ||
self.secret_key = secret_key | ||
self.api_base = api_base | ||
self.model_name = model_name | ||
cred = credential.Credential( | ||
secret_id, secret_key | ||
) | ||
self.client = HunyuanClient(cred, "") | ||
|
||
@staticmethod | ||
def new_instance(model_type: str, model_name: str, model_credential: Dict[str, str], **model_kwargs): | ||
return TencentEmbeddingModel( | ||
secret_id=model_credential.get('SecretId'), | ||
secret_key=model_credential.get('SecretKey'), | ||
api_base=model_credential.get('api_base'), | ||
model_name=model_name, | ||
) | ||
|
||
|
||
def _generate_auth_token(self): | ||
# Example method to generate an authentication token for the model API | ||
return f"{self.secret_id}:{self.secret_key}" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters