-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutil.py
48 lines (31 loc) · 1.18 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import torch
import numpy as np
def crop_roll_128_to_88(full_piano_roll):
    """Slice the 88 piano-key rows out of a 128-row piano roll.

    Args:
        full_piano_roll: array or tensor with exactly 128 rows on axis 0
            (presumably MIDI pitches 0-127 -- confirm with callers).

    Returns:
        ``full_piano_roll[21:109, :]``, i.e. the 88 rows 21..108 (under the
        MIDI assumption above: the piano range A0..C8).
    """
    assert full_piano_roll.shape[0] == 128
    first_key_row = 21
    n_keys = 88
    return full_piano_roll[first_key_row:first_key_row + n_keys, :]
def pad_roll_88_to_128(piano_roll):
    """Zero-pad an 88-row piano roll back to 128 rows along dimension 0.

    Inverse of ``crop_roll_128_to_88`` (which keeps rows 21:109): input row i
    lands on output row i + 21, so cropping the result gives back the input.

    Args:
        piano_roll: torch tensor whose first dimension has length 88. Any
            further dimensions are left unchanged.

    Returns:
        Tensor of shape (128, *piano_roll.shape[1:]), zeros outside rows 21..108.
    """
    assert piano_roll.shape[0] == 88
    # 21 empty rows below the first key and 128 - 109 = 19 above the last one.
    # (Padding 21 on both sides produced 130 rows.)
    # F.pad lists (before, after) pairs starting from the LAST dimension, so
    # leave every trailing dimension untouched and pad only dimension 0.
    pad = (0, 0) * (piano_roll.dim() - 1) + (21, 19)
    return torch.nn.functional.pad(piano_roll, pad)
def crop_augment_piano_roll(piano_roll, crop_size):
    """Randomly crop a window of ``crop_size`` consecutive pitch rows.

    Axis 0 is treated as pitch and axis 1 as time. The pitch axis is first
    zero-padded by ``crop_size`` rows on both sides, so the window may extend
    beyond the original pitch range. The window start is drawn from the
    global ``np.random`` state:

    * If all active pitches (rows with any entry > 0) fit in ``crop_size``
      rows, the start is uniform over every position whose window contains
      all of them, i.e. the result is a random pitch shift of the content.
    * Otherwise the start is uniform over every window lying entirely inside
      [lowest active row, highest active row].

    Args:
        piano_roll: 2-D numpy array of shape (n_pitches, n_timesteps).
        crop_size: number of pitch rows in the result (assumed positive).

    Returns:
        Array of shape (crop_size, n_timesteps) with the input's dtype. A roll
        with no active entries yields an all-zero array of that shape.
    """
    _n_pitches, n_timesteps = piano_roll.shape
    pad_size = crop_size
    piano_roll = np.pad(piano_roll, ((pad_size, pad_size), (0, 0)), 'constant', constant_values=0)
    # Number of active time steps per (padded) pitch row.
    time_sum = np.sum(piano_roll > 0, axis=-1)
    active_pitches = np.where(time_sum > 0)[0]
    if active_pitches.size == 0:
        # Any crop of a silent roll is silent; np.min/np.max would raise on
        # the empty index array.
        return np.zeros((crop_size, n_timesteps), dtype=piano_roll.dtype)
    min_pitch = np.min(active_pitches)
    max_pitch = np.max(active_pitches)
    # Smallest start whose window [start, start + crop_size) still includes
    # max_pitch; using max_pitch - crop_size let the top note fall outside.
    top_aligned_start = max_pitch - crop_size + 1
    low = min(top_aligned_start, min_pitch)
    high = max(top_aligned_start, min_pitch)
    if low == high:
        start_pitch = low
    else:
        # randint's `high` is exclusive; +1 makes both ends of [low, high] reachable.
        start_pitch = np.random.randint(low=low, high=high + 1)
    end_pitch = start_pitch + crop_size
    return piano_roll[start_pitch:end_pitch, :]
def get_onsets(rolls):
    """Flag the frames where a value increases relative to the previous frame.

    Time is the last axis (the original layout note says (batch, pitch, time)).
    A zero frame is prepended before differencing, so an entry that is already
    positive in the first frame counts as an onset.

    Returns:
        int32 array shaped like ``rolls``: 1 where
        ``rolls[..., t] - rolls[..., t - 1] > 0`` (with ``rolls[..., -1] := 0``), else 0.
    """
    frame_deltas = np.diff(rolls, n=1, axis=-1, prepend=0)
    rising = frame_deltas > 0
    return rising.astype(np.int32)