-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathpredict.py
executable file
·214 lines (169 loc) · 7.39 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
# never push DEBUG_MODE = True to Replicate!
# DEBUG_MODE = False
import os
import shutil
import time
import random
import sys
import json
import tempfile
import requests
import subprocess
import signal
from typing import Iterator, Optional
from dotenv import load_dotenv
from dataclasses import dataclass, asdict
from cog import BasePredictor, BaseModel, File, Input, Path as cogPath
import torch
import torchaudio
from audiocraft.data.audio import audio_write
from audiocraft.models import AudioGen, MusicGen
load_dotenv()

# Debug switch. Off unless the DEBUG_MODE environment variable (which may come
# from the .env file loaded just above) is "1", "true" or "yes". Enabling it via
# the environment means the code itself never has to be pushed with debug on —
# debug mode must stay off on Replicate.
DEBUG_MODE = os.environ.get("DEBUG_MODE", "").strip().lower() in ("1", "true", "yes")

# Keep model weight / hub caches under /src.
# NOTE(review): these are assigned after torch/audiocraft are imported above;
# libraries that read them at import time may not pick them up — consider
# moving these assignments above the imports.
os.environ["TORCH_HOME"] = "/src/.torch"
os.environ["TRANSFORMERS_CACHE"] = "/src/.huggingface/"
os.environ["DIFFUSERS_CACHE"] = "/src/.huggingface/"
os.environ["HF_HOME"] = "/src/.huggingface/"

if DEBUG_MODE:
    # Start every debug session with an empty output directory.
    debug_output_dir = "/src/tests/server/debug_output"
    if os.path.exists(debug_output_dir):
        shutil.rmtree(debug_output_dir)
    os.makedirs(debug_output_dir, exist_ok=True)
# Structured result yielded by Predictor.predict when DEBUG_MODE is off
# (in debug mode predict yields a bare cogPath instead).
class CogOutput(BaseModel):
    files: Optional[list[cogPath]] = []  # generated file(s); predict yields one mp3
    name: Optional[str] = None  # display name; predict sets it to the text prompt
    thumbnails: Optional[list[cogPath]] = []  # preview image(s); predict uses /src/sound.png
    attributes: Optional[dict] = None  # generation metadata (model, duration, prompt, timing)
    progress: Optional[float] = None  # predict sets 1.0 on its single, final result
    isFinal: bool = False  # True when this is the last item of the prediction stream
import subprocess
def convert_wav_to_mp3(wav_file_path, mp3_file_path, bitrate='192k'):
    """
    Converts a WAV file to an MP3 file using FFmpeg.

    Args:
        wav_file_path (str): The path to the input WAV file.
        mp3_file_path (str): The path to the output MP3 file. It is overwritten
            if it already exists.
        bitrate (str): The audio bitrate of the output MP3 file. Default is 192k.

    Raises:
        subprocess.CalledProcessError: if ffmpeg exits with a non-zero status
            (e.g. missing/corrupt input). The error is printed, then re-raised
            so callers never go on to use a missing or partial MP3.
        FileNotFoundError: if the ffmpeg executable cannot be found.
    """
    cmd = [
        'ffmpeg',
        '-y',  # overwrite without prompting; a prompt on stdin would stall a server
        '-i', wav_file_path,
        '-b:a', bitrate,  # canonical spelling of the older '-ab' alias
        mp3_file_path,
    ]
    try:
        # stdin=DEVNULL: ffmpeg must never read from (or wait on) our stdin.
        subprocess.run(cmd, check=True, stdin=subprocess.DEVNULL)
    except subprocess.CalledProcessError as e:
        print(f"An error occurred: {e}")
        raise
    print(f"Conversion complete: {mp3_file_path}")
# Supported models, see
# https://github.com/facebookresearch/audiocraft/blob/main/docs/MUSICGEN.md
# Maps a Hugging Face Hub model id to a short description. Insertion order is
# kept and is also the order of the `model_name` choices offered by predict().
MODEL_INFO = dict([
    ('facebook/audiogen-medium', '1.5B model, text to sound - 🤗 Hub'),
    ('facebook/musicgen-medium', '1.5B model, text to music - 🤗 Hub'),
    ('facebook/musicgen-large', '3.3B model, text to music - 🤗 Hub'),
    # Melody-conditioned variants, currently disabled:
    # ('facebook/musicgen-melody', '1.5B model, text to music and text+melody to music - 🤗 Hub'),
    # ('facebook/musicgen-melody-large', '3.3B model, text to music and text+melody to music - 🤗 Hub'),
])
class Predictor(BasePredictor):
    """Cog predictor: text prompt -> MP3, using an AudioGen or MusicGen model."""

    # Item type yielded by predict(): a bare file path in debug mode,
    # otherwise the structured CogOutput.
    GENERATOR_OUTPUT_TYPE = cogPath if DEBUG_MODE else CogOutput

    def setup(self):
        # Models are loaded per request in generate(); nothing to prepare here.
        print("cog:setup")

    # https://www.reddit.com/r/audiocraft/comments/146914g/longer_than_30_sec/
    def generate_long_audio(self, model, text, duration, topk=250, topp=0, temperature=1.0, cfg_coef=3.0, overlap=5):
        """Generate audio longer than one model window by chaining continuations.

        The first window is generated from the text alone. Each later window is
        produced by ``model.generate_continuation`` primed with the last
        ``overlap`` seconds generated so far; that tail is trimmed from the
        accumulated audio before the new window is concatenated.

        Args:
            model: a loaded audiocraft MusicGen/AudioGen model.
            text: text prompt used for every window.
            duration: total number of seconds to generate.
            topk, topp, temperature, cfg_coef: sampling settings forwarded to
                ``model.set_generation_params``.
            overlap: seconds of already-generated audio used as the prompt for
                each continuation window.

        Returns:
            The concatenated audio tensor, or None when ``duration <= 0``.

        Raises:
            ValueError: if ``overlap`` is not strictly between 0 and the model's
                window length. With ``overlap >= window`` a pass makes no
                progress (the loop never ends); with ``overlap == 0`` the
                ``[:-0]`` trim would discard all previously generated audio.
        """
        topk = int(topk)
        max_segment = model.lm.cfg.dataset.segment_duration
        if not 0 < overlap < max_segment:
            raise ValueError(f"overlap must be in (0, {max_segment}) seconds, got {overlap}")
        overlap_samples = int(overlap * model.sample_rate)
        output = None
        while duration > 0:
            if output is None:
                # First pass: one window, capped at the model's window length.
                segment_duration = min(duration, max_segment)
            else:
                # Later passes: the window also re-covers the `overlap` primer.
                segment_duration = min(duration + overlap, max_segment)
            print(f'Segment duration: {segment_duration}, duration: {duration}, overlap: {overlap}')
            model.set_generation_params(
                use_sampling=True,
                top_k=topk,
                top_p=topp,
                temperature=temperature,
                cfg_coef=cfg_coef,
                duration=min(segment_duration, 30),  # ensure duration does not exceed 30
            )
            if output is None:
                output = model.generate(descriptions=[text])
                duration -= segment_duration
            else:
                # The tail is re-supplied as the prompt and dropped from `output`
                # before concatenation; this assumes generate_continuation returns
                # the prompt followed by the new audio (as the trim implies).
                last_chunk = output[:, :, -overlap_samples:]
                next_segment = model.generate_continuation(last_chunk, model.sample_rate, descriptions=[text])
                output = torch.cat([output[:, :, :-overlap_samples], next_segment], 2)
                duration -= segment_duration - overlap
        return output

    def generate(self, model_name, text_input, desired_duration):
        """Load ``model_name`` and generate ``desired_duration`` seconds of audio.

        Returns:
            (wav, sample_rate): the first generated item moved to CPU, and the
            model's sample rate.

        Raises:
            ValueError: if ``model_name`` contains neither 'musicgen' nor 'audiogen'.
        """
        if "musicgen" in model_name:
            model = MusicGen.get_pretrained(model_name)
        elif "audiogen" in model_name:
            model = AudioGen.get_pretrained(model_name)
        else:
            raise ValueError("model_name must contain 'musicgen' or 'audiogen'")
        # Keep fractional seconds (int() used to truncate e.g. 10.7 -> 10 here,
        # while the long path below already honored the exact value).
        model.set_generation_params(duration=float(desired_duration))
        if desired_duration < 30:
            wav = model.generate([text_input])
        else:
            # Beyond a single 30 s window: stitch windows via continuation.
            wav = self.generate_long_audio(model, text_input, desired_duration)
        wav = wav[0].cpu()
        return wav, model.sample_rate

    def predict(
        self,
        # Universal args
        model_name: str = Input(
            description="Model name", default="facebook/audiogen-medium",
            choices=list(MODEL_INFO)
        ),
        # Create mode
        text_input: str = Input(
            description="Text description of the sound / music", default=None
        ),
        duration_seconds: float = Input(
            description="Duration of the audio in seconds",
            ge=1.0, le=120.0, default=10.0
        ),
    ) -> Iterator[GENERATOR_OUTPUT_TYPE]:
        """Generate one MP3 for ``text_input`` and yield it exactly once.

        Yields a cogPath in debug mode, otherwise a final CogOutput carrying the
        file, a thumbnail and generation attributes. The MP3 is written to the
        current working directory.

        Raises:
            ValueError: if ``text_input`` is empty or ``model_name`` is unknown.
        """
        t_start = time.time()
        for _ in range(3):
            print("-------------------------------------------------------")
        if not text_input:
            raise ValueError("text_input is required")
        if model_name not in MODEL_INFO:
            print(f"Invalid model_name: {model_name}")
            print("Valid options are:")
            print(MODEL_INFO)
            raise ValueError(f"Invalid audio model_name: {model_name}")
        print(f"cog:predict: {model_name}")
        wav, sample_rate = self.generate(model_name, text_input, duration_seconds)

        # Model ids contain '/', which must not become a directory separator.
        out_path = f"{str(int(time.time()))}_{model_name}_output.mp3".replace('/', "_")
        # The intermediate WAV goes into a private temp dir (removed afterwards)
        # instead of a fixed shared path; audio_write returns the file it wrote.
        with tempfile.TemporaryDirectory() as tmp_dir:
            wav_path = audio_write(
                os.path.join(tmp_dir, "tmp_wav"), wav, sample_rate,
                strategy="loudness", loudness_compressor=True,
            )
            convert_wav_to_mp3(str(wav_path), out_path)
        print(f"Final audio saved to {out_path}")

        attributes = {
            "model_name": model_name,
            "model_info": MODEL_INFO[model_name],
            "duration_seconds": duration_seconds,
            "text_input": text_input,
            "job_time_seconds": time.time() - t_start,
        }
        if DEBUG_MODE:
            print(attributes)
            yield cogPath(out_path)
        else:
            yield CogOutput(files=[cogPath(out_path)], name=text_input, thumbnails=[cogPath('/src/sound.png')], attributes=attributes, isFinal=True, progress=1.0)

        if DEBUG_MODE:
            print("--------------------------------")
            print("--- cog was in DEBUG mode!!! ---")
            print("--------------------------------")
        t_end = time.time()
        print(f"predict.py: done in {t_end - t_start:.2f} seconds")