-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathprocess_audio_file.py
74 lines (57 loc) · 2 KB
/
process_audio_file.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
# Simplified inference script for the model: takes .wav audio as input and
# produces the corresponding .kern transcription.
import gc
import os
import fire
import torch
from lightning.pytorch import Trainer
from lightning.pytorch.loggers.wandb import WandbLogger
from transformer.model import Transformer
from data.ar_dataset import ARDataModule
from utils.seed import seed_everything
# Seed the random number generators through the project's helper once, at
# import time, so repeated inference runs are deterministic.
# NOTE(review): `benchmark=False` presumably turns off cuDNN autotuning inside
# utils.seed.seed_everything — confirm there.
_GLOBAL_SEED = 42
seed_everything(_GLOBAL_SEED, benchmark=False)
def test(
    input_audio_folder,
    output_path_folder,
    krn_encoding: str = "bekern",
    checkpoint_path: str = "",
):
    """Run a trained checkpoint on a folder of audio files.

    Args:
        input_audio_folder: Dataset name/folder forwarded to ``ARDataModule``
            as ``ds_name``.
        output_path_folder: Intended destination for the predictions.
            NOTE(review): currently unused — nothing in this function writes
            predictions to disk. TODO wire this up.
        krn_encoding: Kern encoding variant forwarded to the data module
            (default ``"bekern"``).
        checkpoint_path: Path to the Lightning checkpoint to load. The name of
            the directory that contains it is reported as the source dataset.

    Returns:
        Whatever ``Trainer.test`` returns (the logged test metrics).

    Raises:
        ValueError: If ``checkpoint_path`` is empty or not given.
        FileNotFoundError: If ``checkpoint_path`` does not exist.
    """
    # Release memory left over from earlier work in this process before
    # loading the model.
    gc.collect()
    torch.cuda.empty_cache()

    # Validate the checkpoint path. `not` also catches None, which would
    # otherwise crash inside os.path.exists with a TypeError.
    if not checkpoint_path:
        raise ValueError("Checkpoint path not provided")
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint path {checkpoint_path} does not exist")

    # The checkpoint is expected to sit in a folder named after the dataset it
    # was trained on (".../<src_ds_name>/<model>.ckpt"). Taking the parent
    # directory name works for any depth and for absolute paths, unlike
    # unpacking split("/") into exactly three parts.
    src_ds_name = os.path.basename(os.path.dirname(os.path.normpath(checkpoint_path)))

    # Experiment info
    print("INFERENCE ON AUDIO")
    print(f"\tSource dataset: {src_ds_name}")
    print(f"\tTest dataset: {input_audio_folder}")
    print(f"\tKern encoding: {krn_encoding}")
    print(f"\tCheckpoint path: {checkpoint_path}")

    input_modality = "audio"  # "audio" or "image" or "both"
    use_distorted_images = False  # Only used if input_modality == "image" or "both"
    img_height = None

    # Data module
    datamodule = ARDataModule(
        ds_name=input_audio_folder,
        krn_encoding=krn_encoding,
        input_modality=input_modality,
        use_distorted_images=use_distorted_images,
        img_height=img_height,
        inference=True,
    )
    # NOTE(review): the module is set up for "predict" but evaluated with
    # trainer.test below; confirm ARDataModule serves the same data for both.
    datamodule.setup(stage="predict")
    ytest_i2w = datamodule.test_ds.i2w

    # Model: the vocabulary mapping comes from the data module so decoded
    # tokens match the test set's encoding.
    model = Transformer.load_from_checkpoint(checkpoint_path, ytest_i2w=ytest_i2w)

    # Test
    trainer = Trainer(
        precision="16-mixed",  # 16-bit mixed precision
    )
    model.eval()
    return trainer.test(model, datamodule=datamodule)
def _main():
    """Expose ``test`` as a command-line interface through python-fire."""
    fire.Fire(test)


if __name__ == "__main__":
    _main()