Skip to content

Commit

Permalink
增加保存音频模型接口
Browse files Browse the repository at this point in the history
  • Loading branch information
v3ucn committed Nov 2, 2024
1 parent 3f139d7 commit fce3408
Showing 1 changed file with 68 additions and 0 deletions.
68 changes: 68 additions & 0 deletions api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
# Add the third_party/Matcha-TTS directory under ROOT_DIR to the import search path.
# NOTE(review): ROOT_DIR is read here but the only visible assignment is on the
# next line — confirm it is also defined earlier in the file, otherwise this
# raises NameError at import time.
sys.path.append('{}/third_party/Matcha-TTS'.format(ROOT_DIR))
# Absolute path of the directory that contains this file.
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))

import requests
from pydub import AudioSegment

import numpy as np
from flask import Flask, request, Response,send_from_directory
import torch
Expand All @@ -19,6 +22,8 @@
from flask_cors import CORS
from flask import make_response

import shutil

import json

# Module-level CosyVoice instance built from the given model path; it is created
# once at import time and used by the request handlers (e.g. save_voice).
cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-25Hz')
Expand All @@ -41,6 +46,38 @@
# Wrap the Flask app with flask_cors, passing supports_credentials=True.
CORS(app, supports_credentials=True)


def download_and_convert(mp3_url, wav_filename):
    """Download an MP3 file and convert it to WAV.

    The MP3 is streamed into a uniquely named temporary file (so concurrent
    calls do not overwrite each other's data), converted with pydub, and the
    temporary file is removed afterwards.

    Errors are reported by printing a message rather than raising, matching
    the original best-effort behaviour; the return value lets callers detect
    failure if they want to.

    Args:
        mp3_url: The URL of the MP3 file.
        wav_filename: The desired filename for the WAV file (including the
            .wav extension).

    Returns:
        True if the WAV file was written, False if the download or the
        conversion failed.
    """
    import os
    import tempfile

    # mkstemp gives a private, unique path; close the descriptor right away so
    # the file can be reopened by name (required on Windows).
    fd, tmp_mp3 = tempfile.mkstemp(suffix=".mp3")
    os.close(fd)

    try:
        # The timeout bounds connect/read stalls so a dead server cannot hang
        # the request handler forever; the context manager releases the
        # streamed connection.
        with requests.get(mp3_url, stream=True, timeout=30) as response:
            response.raise_for_status()  # Raise for bad status codes (4xx or 5xx)

            with open(tmp_mp3, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)

        # Convert MP3 to WAV using pydub
        sound = AudioSegment.from_mp3(tmp_mp3)
        sound.export(wav_filename, format="wav")

        print(f"音频已成功下载并转换为 {wav_filename}")
        return True

    except requests.exceptions.RequestException as e:
        print(f"下载音频时出错: {e}")
    except Exception as e:
        print(f"转换音频时出错: {e}")
    finally:
        # Clean up the temporary MP3 file
        try:
            os.remove(tmp_mp3)
        except OSError as e:
            print(f"删除临时文件时出错: {e}")

    return False
def speed_change(input_audio: np.ndarray, speed: float, sr: int):
# 检查输入数据类型和声道数
if input_audio.dtype != np.int16:
Expand Down Expand Up @@ -122,6 +159,37 @@ def generate():
response.headers['Content-Disposition'] = 'attachment; filename=sound.ogg'
return response

@app.route("/save_voice", methods=['GET'])
def save_voice():
    """Create a named voice file from a reference audio clip.

    Query parameters:
        text: Text passed to ``cosyvoice.inference_zero_shot`` as both of its
            first two arguments.
        audio: URL of an MP3 file to use as the prompt audio.
        voice_name: Name under which the voice is stored as
            ``<ROOT_DIR>/voices/<voice_name>.pt``.

    Returns:
        A JSON response ``{"voice_name": ...}`` with status 200 on success,
        status 400 when a parameter is missing or ``voice_name`` is not a
        plain file name, or status 502 when the prompt audio could not be
        downloaded/converted.
    """

    def _reply(payload, status=200):
        # Single place to build the JSON responses this endpoint returns.
        return app.response_class(
            response=json.dumps(payload),
            status=status,
            mimetype='application/json'
        )

    text = request.args.get('text')
    audio = request.args.get('audio')
    voice_name = request.args.get('voice_name')

    if not text or not audio or not voice_name:
        return _reply({"error": "text, audio and voice_name are required"}, 400)

    # voice_name comes straight from the query string and becomes part of a
    # filesystem path below; reject anything that could escape the voices
    # directory.
    if (voice_name in ('.', '..')
            or '/' in voice_name
            or '\\' in voice_name
            or '\x00' in voice_name):
        return _reply({"error": "invalid voice_name"}, 400)

    # NOTE(review): `audio` is an arbitrary caller-supplied URL that the server
    # fetches; consider restricting allowed hosts/schemes if this endpoint is
    # exposed to untrusted clients.
    prompt_wav = "zero_test.wav"

    # download_and_convert reports failures only by printing, so remove any
    # file left over from an earlier request and check that a fresh one was
    # produced instead of silently reusing stale prompt audio.
    if os.path.exists(prompt_wav):
        os.remove(prompt_wav)
    download_and_convert(audio, prompt_wav)
    if not os.path.exists(prompt_wav):
        return _reply({"error": "failed to download or convert audio"}, 502)

    prompt_speech_16k = load_wav(prompt_wav, 16000)
    tts_speeches = [
        chunk['tts_speech']
        for chunk in cosyvoice.inference_zero_shot(
            text, text, prompt_speech_16k, stream=False)
    ]

    audio_data = torch.concat(tts_speeches, dim=1)
    torchaudio.save('zero_shot.wav', audio_data, 22050, format="wav")

    # NOTE(review): output.pt is presumably written into ROOT_DIR as a side
    # effect of inference_zero_shot above — confirm against the CosyVoice code.
    voices_dir = os.path.join(ROOT_DIR, "voices")
    os.makedirs(voices_dir, exist_ok=True)
    shutil.copyfile(f"{ROOT_DIR}/output.pt", f"{voices_dir}/{voice_name}.pt")

    return _reply({"voice_name": voice_name})





@app.route("/", methods=['GET'])
def sft_get():
Expand Down

0 comments on commit fce3408

Please sign in to comment.