# similarity.py
# Forked from btyu/MidiProcessor.
import torch
import midiprocessor as mp
def generate_bar_index(pos_info):
    """Map each bar id to the half-open slice of ``pos_info`` it occupies.

    Consecutive entries that share the same first element (``item[0]``,
    used here as the bar id) are grouped into one run.

    Args:
        pos_info: Sequence of entries whose first element is the bar id.
            Entries belonging to one bar are expected to be contiguous; if
            a bar id reappears later, its later run overwrites the earlier
            one in the result.

    Returns:
        dict mapping bar id -> ``(begin, end)`` such that
        ``pos_info[begin:end]`` holds that bar's entries. An empty
        ``pos_info`` yields an empty dict.
    """
    bar_index = {}
    # Nothing to index; avoids an IndexError on pos_info[0] below.
    if len(pos_info) == 0:
        return bar_index

    begin = 0
    cur_bar = pos_info[0][0]
    for idx, item in enumerate(pos_info):
        now_bar = item[0]
        if now_bar != cur_bar:
            # A new bar starts at idx: close the run of the previous bar.
            bar_index[cur_bar] = (begin, idx)
            cur_bar = now_bar
            begin = idx
    # Close the final run, which extends to the end of the sequence.
    bar_index[cur_bar] = (begin, len(pos_info))
    return bar_index
def generate_bar_insts_pos_index(bar):
    """Regroup the entries of one bar by instrument, then by position index.

    Args:
        bar: Sequence of entries for a single bar. The last element of each
            entry (``item[-1]``) is either ``None`` or a mapping from
            instrument to that instrument's notes at this entry.

    Returns:
        dict mapping instrument -> {index of the entry within ``bar`` ->
        notes}. Entries whose last element is ``None`` contribute nothing.
    """
    insts_pos_index = {}
    for idx, item in enumerate(bar):
        notes = item[-1]
        if notes is None:
            continue
        for inst in notes:
            # Create the per-instrument dict on first sight of the instrument.
            insts_pos_index.setdefault(inst, {})[idx] = notes[inst]
    return insts_pos_index
def _collect_note_keys(inst_bar, only_duration):
    """Flatten one instrument's {pos: notes} mapping into a set of note keys."""
    keys = set()
    for pos, pos_notes in inst_bar.items():
        for note in pos_notes:
            if only_duration:
                keys.add((pos, note[1]))
            else:
                keys.add((pos, note[0], note[1]))
    return keys


def cal_bar_similarity_basic(bar1, bar2, bar1_insts, bar2_insts, inter_insts, inter_pos, only_duration=False):
    """Compute a per-instrument Jaccard similarity between two bars.

    Each note is reduced to a key ``(pos, note[0], note[1])`` or, when
    ``only_duration`` is true, ``(pos, note[1])``. The similarity of an
    instrument is ``|keys1 & keys2| / |keys1 | keys2|``.
    NOTE(review): ``note[0]`` / ``note[1]`` are presumably pitch and
    duration -- confirm against the encoder that produces the notes.

    Args:
        bar1: dict instrument -> {pos -> iterable of notes} for the first bar.
        bar2: Same structure for the second bar.
        bar1_insts: Set of instruments present in ``bar1``.
        bar2_insts: Set of instruments present in ``bar2``.
        inter_insts: Instruments present in both bars.
        inter_pos: Unused; kept so existing callers keep working.
        only_duration: If true, drop ``note[0]`` from the note key.

    Returns:
        Tuple ``(name, o_sim)`` where ``name`` is ``'basic_dur'`` when
        ``only_duration`` is true and ``'basic'`` otherwise, and ``o_sim``
        maps instrument -> similarity in ``[0.0, 1.0]``. Instruments present
        in only one of the two bars get ``0.0``.
    """
    o_sim = {}
    # Instruments that appear in exactly one bar have nothing in common.
    for inst in ((bar1_insts - bar2_insts) | (bar2_insts - bar1_insts)):
        o_sim[inst] = 0.0

    for inst in inter_insts:
        inst_bar1_note = _collect_note_keys(bar1[inst], only_duration)
        inst_bar2_note = _collect_note_keys(bar2[inst], only_duration)
        union_size = len(inst_bar1_note | inst_bar2_note)
        if union_size == 0:
            # Both note sets are empty: treat them as identical rather than
            # dividing by zero.
            o_sim[inst] = 1.0
        else:
            o_sim[inst] = len(inst_bar1_note & inst_bar2_note) / union_size

    return 'basic_dur' if only_duration else 'basic', o_sim
def cal_bar_similarity(bar1, bar2):
    """Compute every available similarity measure between two bars.

    Both bars are first regrouped with ``generate_bar_insts_pos_index`` and
    then compared with ``cal_bar_similarity_basic``, once on the full note
    key and once with ``only_duration=True``.

    Args:
        bar1: Sequence of entries for the first bar (raw, not yet regrouped).
        bar2: Sequence of entries for the second bar.

    Returns:
        dict mapping the measure name returned by
        ``cal_bar_similarity_basic`` -> {instrument -> similarity}.
    """
    bar1 = generate_bar_insts_pos_index(bar1)
    bar2 = generate_bar_insts_pos_index(bar2)

    bar1_insts = set(bar1)
    bar2_insts = set(bar2)
    inter_insts = bar1_insts & bar2_insts
    # The value historically passed as ``inter_pos`` is the same intersection
    # of instrument keys; it is forwarded as a separate set object.
    inter_pos = set(inter_insts)

    sim = {}
    for only_duration in (False, True):
        name, o_sim = cal_bar_similarity_basic(
            bar1, bar2, bar1_insts, bar2_insts, inter_insts, inter_pos,
            only_duration=only_duration,
        )
        sim[name] = o_sim
    return sim
def cal_song_similarity(midi_path=None, midi_obj=None, encoder=None, device='cpu'):
    """Build bar-to-bar similarity matrices for a whole song.

    Every bar is compared with every earlier bar via ``cal_bar_similarity``.

    Args:
        midi_path: Path of the MIDI file to load; only used when
            ``midi_obj`` is ``None``.
        midi_obj: Already-loaded MIDI object; takes precedence over
            ``midi_path``.
        encoder: Object providing ``collect_pos_info(midi_obj)``. Defaults
            to ``mp.MidiEncoder('REMIGEN')``.
        device: Device on which the similarity matrices are allocated.

    Returns:
        dict measure name -> {instrument -> tensor of shape
        ``(num_bars, num_bars)``}. Each tensor starts as the identity
        matrix; entry ``[i, j]`` with ``i > j`` holds the similarity of bar
        ``i`` and bar ``j`` whenever the instrument occurs in at least one
        of the two bars. The upper triangle is never written. The dict is
        empty when there is no position info or only a single bar.

    Raises:
        ValueError: If both ``midi_path`` and ``midi_obj`` are ``None``.
        KeyError: If some bar id in ``range(num_bars)`` has no entry in the
            position info.
    """
    # Validate before doing any work; an assert would vanish under -O.
    if midi_obj is None and midi_path is None:
        raise ValueError('either midi_path or midi_obj must be provided')
    if encoder is None:
        encoder = mp.MidiEncoder('REMIGEN')
    if midi_obj is None:
        midi_obj = mp.midi_utils.load_midi(midi_path, file=None, midi_checker=None)

    pos_info = encoder.collect_pos_info(midi_obj)
    if len(pos_info) == 0:
        # No entries means no bars to compare.
        return {}

    bar_index = generate_bar_index(pos_info)
    # NOTE(review): assumes bar ids run contiguously from 0 up to the bar id
    # of the last entry -- confirm against collect_pos_info.
    num_bars = pos_info[-1][0] + 1

    sim_matrices = {}
    if num_bars < 2:
        # A single bar has no earlier bar to be compared with.
        return sim_matrices

    # Slice every bar once instead of re-slicing inside the O(n^2) pair loop.
    bars = [pos_info[slice(*bar_index[bar_id])] for bar_id in range(num_bars)]

    for bar_id_i in range(1, num_bars):
        bar_i = bars[bar_id_i]
        for bar_id_j in range(bar_id_i):
            sim_ij = cal_bar_similarity(bar_i, bars[bar_id_j])
            for sim_name, name_sim_ij in sim_ij.items():
                name_sim_matrices = sim_matrices.setdefault(sim_name, {})
                for inst, value in name_sim_ij.items():
                    if inst not in name_sim_matrices:
                        # Identity diagonal: a bar is fully similar to itself.
                        name_sim_matrices[inst] = torch.eye(num_bars, num_bars, device=device)
                    name_sim_matrices[inst][bar_id_i, bar_id_j] = value
    return sim_matrices
def skew_matrix(m):
    """Skew a square matrix so its main diagonal lands in the last column.

    A zero row is prepended, the ``(n + 1, n)`` result is reinterpreted as
    ``(n, n + 1)`` and the first column is dropped. For a lower-triangular
    input this places ``m[i, i - k]`` at ``out[i, n - 1 - k]``; entries above
    the diagonal wrap around into the following row.

    Args:
        m: 2-D tensor. It must be square, otherwise the reshape fails.

    Returns:
        Tensor with the same shape, dtype and device as ``m``.
    """
    assert m.ndim == 2, 'skew_matrix expects a 2-D tensor'
    m_size = m.size(1)
    # new_zeros keeps the dtype and device of m, so the padding row does not
    # promote non-float32 inputs to another dtype.
    padding = m.new_zeros((1, m_size))
    mm = torch.cat((padding, m), dim=0)
    mm = mm.view(m_size, m_size + 1)
    return mm[:, 1:]