Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: 修复wss链接的证书问题 #1116

Merged
merged 1 commit into from
Sep 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from io import BytesIO
from typing import Dict
from urllib.parse import urlparse

import ssl
import websockets

from setting.models_provider.base_model_provider import MaxKBBaseModel
Expand Down Expand Up @@ -61,6 +61,10 @@
GZIP = 0b0001
CUSTOM_COMPRESSION = 0b1111

ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE


def generate_header(
version=PROTOCOL_VERSION,
Expand Down Expand Up @@ -292,7 +296,8 @@ async def segment_data_processor(self, wav_data: bytes, segment_size: int):
header = self.token_auth()
elif self.auth_method == "signature":
header = self.signature_auth(full_client_request)
async with websockets.connect(self.volcanic_api_url, extra_headers=header, max_size=1000000000) as ws:
async with websockets.connect(self.volcanic_api_url, extra_headers=header, max_size=1000000000,
ssl=ssl_context) as ws:
# 发送 full client request
await ws.send(full_client_request)
res = await ws.recv()
Expand All @@ -319,7 +324,8 @@ def check_auth(self):
header = self.token_auth()

async def check():
async with websockets.connect(self.volcanic_api_url, extra_headers=header, max_size=1000000000) as ws:
async with websockets.connect(self.volcanic_api_url, extra_headers=header, max_size=1000000000,
ssl=ssl_context) as ws:
pass

asyncio.run(check())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import json
import uuid
from typing import Dict

import ssl
import websockets

from setting.models_provider.base_model_provider import MaxKBBaseModel
Expand All @@ -35,6 +35,10 @@
# reserved data: 0x00 (1 byte)
default_header = bytearray(b'\x11\x10\x11\x00')

ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE


class VolcanicEngineTextToSpeech(MaxKBBaseModel, BaseTextToSpeech):
volcanic_app_id: str
Expand Down Expand Up @@ -68,7 +72,8 @@ def check_auth(self):
header = self.token_auth()

async def check():
async with websockets.connect(self.volcanic_api_url, extra_headers=header, ping_interval=None) as ws:
async with websockets.connect(self.volcanic_api_url, extra_headers=header, ping_interval=None,
ssl=ssl_context) as ws:
pass

asyncio.run(check())
Expand Down Expand Up @@ -113,7 +118,8 @@ async def submit(self, request_json):
full_client_request.extend((len(payload_bytes)).to_bytes(4, 'big')) # payload size(4 bytes)
full_client_request.extend(payload_bytes) # payload
header = {"Authorization": f"Bearer; {self.volcanic_token}"}
async with websockets.connect(self.volcanic_api_url, extra_headers=header, ping_interval=None) as ws:
async with websockets.connect(self.volcanic_api_url, extra_headers=header, ping_interval=None,
ssl=ssl_context) as ws:
await ws.send(full_client_request)
return await self.parse_response(ws)

Expand Down Expand Up @@ -161,4 +167,4 @@ async def parse_response(ws):
payload = gzip.decompress(payload)
else:
break
return result
return result
10 changes: 7 additions & 3 deletions apps/setting/models_provider/impl/xf_model_provider/model/stt.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from datetime import datetime
from typing import Dict
from urllib.parse import urlencode, urlparse

import ssl
import websockets

from setting.models_provider.base_model_provider import MaxKBBaseModel
Expand All @@ -21,6 +21,10 @@
STATUS_CONTINUE_FRAME = 1 # 中间帧标识
STATUS_LAST_FRAME = 2 # 最后一帧的标识

ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE


class XFSparkSpeechToText(MaxKBBaseModel, BaseSpeechToText):
spark_app_id: str
Expand Down Expand Up @@ -86,14 +90,14 @@ def create_url(self):

def check_auth(self):
async def check():
async with websockets.connect(self.create_url()) as ws:
async with websockets.connect(self.create_url(), ssl=ssl_context) as ws:
pass

asyncio.run(check())

def speech_to_text(self, file):
async def handle():
async with websockets.connect(self.create_url(), max_size=1000000000) as ws:
async with websockets.connect(self.create_url(), max_size=1000000000, ssl=ssl_context) as ws:
# 发送 full client request
await self.send(ws, file)
return await self.handle_message(ws)
Expand Down
10 changes: 7 additions & 3 deletions apps/setting/models_provider/impl/xf_model_provider/model/tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from datetime import datetime
from typing import Dict
from urllib.parse import urlencode, urlparse

import ssl
import websockets

from setting.models_provider.base_model_provider import MaxKBBaseModel
Expand All @@ -24,6 +24,10 @@
STATUS_CONTINUE_FRAME = 1 # 中间帧标识
STATUS_LAST_FRAME = 2 # 最后一帧的标识

ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE


class XFSparkTextToSpeech(MaxKBBaseModel, BaseTextToSpeech):
spark_app_id: str
Expand Down Expand Up @@ -89,7 +93,7 @@ def create_url(self):

def check_auth(self):
async def check():
async with websockets.connect(self.create_url(), max_size=1000000000) as ws:
async with websockets.connect(self.create_url(), max_size=1000000000, ssl=ssl_context) as ws:
pass

asyncio.run(check())
Expand All @@ -99,7 +103,7 @@ def text_to_speech(self, text):
# 使用小语种须使用以下方式,此处的unicode指的是 utf16小端的编码方式,即"UTF-16LE"”
# self.Data = {"status": 2, "text": str(base64.b64encode(self.Text.encode('utf-16')), "UTF8")}
async def handle():
async with websockets.connect(self.create_url(), max_size=1000000000) as ws:
async with websockets.connect(self.create_url(), max_size=1000000000, ssl=ssl_context) as ws:
# 发送 full client request
await self.send(ws, text)
return await self.handle_message(ws)
Expand Down
Loading