Skip to content

Commit

Permalink
fix: 修复wss链接的证书问题
Browse files Browse the repository at this point in the history
  • Loading branch information
liuruibin committed Sep 4, 2024
1 parent b500404 commit 4791ff1
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from io import BytesIO
from typing import Dict
from urllib.parse import urlparse

import ssl
import websockets

from setting.models_provider.base_model_provider import MaxKBBaseModel
Expand Down Expand Up @@ -61,6 +61,10 @@
GZIP = 0b0001
CUSTOM_COMPRESSION = 0b1111

ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE


def generate_header(
version=PROTOCOL_VERSION,
Expand Down Expand Up @@ -292,7 +296,8 @@ async def segment_data_processor(self, wav_data: bytes, segment_size: int):
header = self.token_auth()
elif self.auth_method == "signature":
header = self.signature_auth(full_client_request)
async with websockets.connect(self.volcanic_api_url, extra_headers=header, max_size=1000000000) as ws:
async with websockets.connect(self.volcanic_api_url, extra_headers=header, max_size=1000000000,
ssl=ssl_context) as ws:
# 发送 full client request
await ws.send(full_client_request)
res = await ws.recv()
Expand All @@ -319,7 +324,8 @@ def check_auth(self):
header = self.token_auth()

async def check():
async with websockets.connect(self.volcanic_api_url, extra_headers=header, max_size=1000000000) as ws:
async with websockets.connect(self.volcanic_api_url, extra_headers=header, max_size=1000000000,
ssl=ssl_context) as ws:
pass

asyncio.run(check())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import json
import uuid
from typing import Dict

import ssl
import websockets

from setting.models_provider.base_model_provider import MaxKBBaseModel
Expand All @@ -35,6 +35,10 @@
# reserved data: 0x00 (1 byte)
default_header = bytearray(b'\x11\x10\x11\x00')

ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE


class VolcanicEngineTextToSpeech(MaxKBBaseModel, BaseTextToSpeech):
volcanic_app_id: str
Expand Down Expand Up @@ -68,7 +72,8 @@ def check_auth(self):
header = self.token_auth()

async def check():
async with websockets.connect(self.volcanic_api_url, extra_headers=header, ping_interval=None) as ws:
async with websockets.connect(self.volcanic_api_url, extra_headers=header, ping_interval=None,
ssl=ssl_context) as ws:
pass

asyncio.run(check())
Expand Down Expand Up @@ -113,7 +118,8 @@ async def submit(self, request_json):
full_client_request.extend((len(payload_bytes)).to_bytes(4, 'big')) # payload size(4 bytes)
full_client_request.extend(payload_bytes) # payload
header = {"Authorization": f"Bearer; {self.volcanic_token}"}
async with websockets.connect(self.volcanic_api_url, extra_headers=header, ping_interval=None) as ws:
async with websockets.connect(self.volcanic_api_url, extra_headers=header, ping_interval=None,
ssl=ssl_context) as ws:
await ws.send(full_client_request)
return await self.parse_response(ws)

Expand Down Expand Up @@ -161,4 +167,4 @@ async def parse_response(ws):
payload = gzip.decompress(payload)
else:
break
return result
return result
10 changes: 7 additions & 3 deletions apps/setting/models_provider/impl/xf_model_provider/model/stt.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from datetime import datetime
from typing import Dict
from urllib.parse import urlencode, urlparse

import ssl
import websockets

from setting.models_provider.base_model_provider import MaxKBBaseModel
Expand All @@ -21,6 +21,10 @@
STATUS_CONTINUE_FRAME = 1 # 中间帧标识
STATUS_LAST_FRAME = 2 # 最后一帧的标识

ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE


class XFSparkSpeechToText(MaxKBBaseModel, BaseSpeechToText):
spark_app_id: str
Expand Down Expand Up @@ -86,14 +90,14 @@ def create_url(self):

def check_auth(self):
async def check():
async with websockets.connect(self.create_url()) as ws:
async with websockets.connect(self.create_url(), ssl=ssl_context) as ws:
pass

asyncio.run(check())

def speech_to_text(self, file):
async def handle():
async with websockets.connect(self.create_url(), max_size=1000000000) as ws:
async with websockets.connect(self.create_url(), max_size=1000000000, ssl=ssl_context) as ws:
# 发送 full client request
await self.send(ws, file)
return await self.handle_message(ws)
Expand Down
10 changes: 7 additions & 3 deletions apps/setting/models_provider/impl/xf_model_provider/model/tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from datetime import datetime
from typing import Dict
from urllib.parse import urlencode, urlparse

import ssl
import websockets

from setting.models_provider.base_model_provider import MaxKBBaseModel
Expand All @@ -24,6 +24,10 @@
STATUS_CONTINUE_FRAME = 1 # 中间帧标识
STATUS_LAST_FRAME = 2 # 最后一帧的标识

ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE


class XFSparkTextToSpeech(MaxKBBaseModel, BaseTextToSpeech):
spark_app_id: str
Expand Down Expand Up @@ -89,7 +93,7 @@ def create_url(self):

def check_auth(self):
async def check():
async with websockets.connect(self.create_url(), max_size=1000000000) as ws:
async with websockets.connect(self.create_url(), max_size=1000000000, ssl=ssl_context) as ws:
pass

asyncio.run(check())
Expand All @@ -99,7 +103,7 @@ def text_to_speech(self, text):
# 使用小语种须使用以下方式,此处的unicode指的是 utf16小端的编码方式,即"UTF-16LE"”
# self.Data = {"status": 2, "text": str(base64.b64encode(self.Text.encode('utf-16')), "UTF8")}
async def handle():
async with websockets.connect(self.create_url(), max_size=1000000000) as ws:
async with websockets.connect(self.create_url(), max_size=1000000000, ssl=ssl_context) as ws:
# 发送 full client request
await self.send(ws, text)
return await self.handle_message(ws)
Expand Down

0 comments on commit 4791ff1

Please sign in to comment.