-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathAEDataset.py
72 lines (57 loc) · 2.42 KB
/
AEDataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
from pathlib import Path
import numpy as np
import torch
from torch.utils.data import Dataset
from Util import get_spectrogram
class AssignmentDataset(Dataset):
    """
    Dataset for accessing data points of auto generated audio tracks from MIDI.

    A data point is a pair of Mel spectrograms stored in the same directory as
    ``<name>_piano_mel.npy`` and ``<name>_synth_mel.npy``. Pairs are matched by
    their shared ``<name>`` prefix (not by file-system listing order) and are
    ordered by the sorted piano file path, so a given index always refers to
    the same pair.
    """

    # Filename suffixes that identify the two halves of a data point.
    _PIANO_SUFFIX = "_piano_mel.npy"
    _SYNTH_SUFFIX = "_synth_mel.npy"

    def __init__(self, root_dir, transform=None):
        """
        Args:
            root_dir (string or os.PathLike): Directory searched recursively for the
                stored spectrogram data points (*.npy files).
            transform (callable, optional): Transformation applied to each sample dict
                before it is returned, e.g. ``ToTensor`` to convert the ndarrays into
                float tensors.

        Raises:
            FileNotFoundError: If a ``*_piano_mel.npy`` file has no matching
                ``*_synth_mel.npy`` file next to it.
        """
        self.root_dir = root_dir
        self.transform = transform

        # rglob() yields files in file-system order, which is arbitrary, so two
        # independent listings cannot be zipped by index. Instead, derive each
        # synth path from its piano path so every pair is guaranteed to match.
        self.piano_mel_filenames = []
        self.synth_mel_filenames = []
        for piano_path in sorted(Path(root_dir).rglob("*" + self._PIANO_SUFFIX)):
            prefix = piano_path.name[:-len(self._PIANO_SUFFIX)]
            synth_path = piano_path.with_name(prefix + self._SYNTH_SUFFIX)
            if not synth_path.is_file():
                raise FileNotFoundError(
                    f"No synth spectrogram for {piano_path}: expected {synth_path}"
                )
            self.piano_mel_filenames.append(str(piano_path))
            self.synth_mel_filenames.append(str(synth_path))
        self.length = len(self.piano_mel_filenames)

    def __len__(self):
        """Return the number of (piano, synth) spectrogram pairs."""
        return self.length

    def __getitem__(self, idx):
        """
        Load the spectrogram pair at position ``idx``.

        Args:
            idx (int or scalar tensor): Index of the pair to load.

        Returns:
            dict: ``{'piano_mel': ..., 'synth_mel': ...}``. Each value is the result of
            ``get_spectrogram`` with a new leading axis of size 1 (the channel
            dimension). The dict is passed through ``self.transform`` when one is set.
        """
        # Accept a scalar tensor index. Retrieving several data points at once is not
        # supported: a multi-element tensor becomes a list, which cannot index the
        # filename lists.
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # Load both halves of the pair (presumably 2-D magnitude spectrograms;
        # confirm against Util.get_spectrogram).
        mel_piano = get_spectrogram(self.piano_mel_filenames[idx])
        mel_synth = get_spectrogram(self.synth_mel_filenames[idx])

        # Because only the magnitude is used, an additional channel dimension is
        # needed for PyTorch.
        mel_piano = np.expand_dims(mel_piano, axis=0)
        mel_synth = np.expand_dims(mel_synth, axis=0)

        sample = {'piano_mel': mel_piano, 'synth_mel': mel_synth}
        if self.transform is not None:
            sample = self.transform(sample)
        return sample
class ToTensor:
    """
    Convert the ndarrays of a sample dict into float32 PyTorch tensors.

    Only the ``'piano_mel'`` and ``'synth_mel'`` entries are carried over; the
    returned dict is a new object containing exactly those two keys.
    """

    def __call__(self, sample):
        # Look up both arrays first so a missing key fails before any conversion.
        arrays = (sample['piano_mel'], sample['synth_mel'])
        # torch.from_numpy shares memory with the array; .float() then yields float32.
        piano_tensor, synth_tensor = (torch.from_numpy(arr).float() for arr in arrays)
        return {'piano_mel': piano_tensor, 'synth_mel': synth_tensor}