
Feature/modify model #15

Open · wants to merge 49 commits into master from feature/modify-model

Commits (49)
7b97769
Remove dropout
madpeh May 27, 2021
c480ada
Merge branch 'master' into feature/modify-model
madpeh May 27, 2021
dd01cc7
Refactor
madpeh May 27, 2021
b866c28
Add prediction data
madpeh May 27, 2021
70ae90b
Fix bug
madpeh May 27, 2021
7caf865
Add cli with clearing data and/or checkpoints
madpeh May 27, 2021
ce3ddf7
Increase patience
madpeh May 27, 2021
42e9438
Use full midi scale for data
madpeh May 28, 2021
6e57857
Clear gpu graph cache
madpeh May 28, 2021
bf29fdb
Add multiple model implementations
madpeh May 28, 2021
efa69f1
Results
madpeh May 28, 2021
a4be4e9
Save midi for single notes
madpeh May 28, 2021
13986b8
Add visualization of vocabulary
madpeh May 28, 2021
db57d9b
Vocabulary
madpeh May 28, 2021
623e01a
Add more models
madpeh May 28, 2021
c1e2f1c
Add dataset percent
madpeh May 28, 2021
482a4df
Add dataset percent for training
madpeh May 28, 2021
182aec0
Add cudnnlstm
madpeh May 29, 2021
464f592
Update functions for prediction with new representation
madpeh May 29, 2021
b09a673
Refactor
madpeh May 29, 2021
0fa39aa
Change constants
madpeh May 29, 2021
2393330
Results and vocabulary
madpeh May 29, 2021
ff6294f
Add reduce lr on plateau callback
madpeh May 30, 2021
3e18cf6
Moar modelz
madpeh May 30, 2021
36f71fe
Change dataset percent
madpeh May 30, 2021
f425f06
Results
madpeh May 30, 2021
365ceb9
Add original model
madpeh May 30, 2021
d5951dd
Update visualization
madpeh Jun 1, 2021
a2a2d0f
Add class weights
madpeh Jun 1, 2021
2819186
Remove test code
madpeh Jun 1, 2021
c0e81d7
Results
madpeh Jun 6, 2021
9bad2df
Use simpler model
madpeh Jun 6, 2021
2cfb19e
Vocabulary
madpeh Jun 6, 2021
965cb69
Change batch size
madpeh Jun 6, 2021
7b25afa
Catch exception while parsing
madpeh Jun 6, 2021
b0c9775
Moar results
madpeh Jun 6, 2021
9a86f79
Another vocabulary
madpeh Jun 6, 2021
1e78168
Fix parsing results to midi and add data augmentation
madpeh Jun 6, 2021
2e36148
Clean network
madpeh Jun 6, 2021
479fd5e
Turn off data augmentation for prediction
madpeh Jun 6, 2021
8acd254
Merge branch 'master' into feature/modify-model
madpeh Jun 6, 2021
01c40ed
Remove unused pylintrc file
madpeh Jun 6, 2021
ebc9d29
Reformat
madpeh Jun 6, 2021
e1d814a
Add results to gitignore
madpeh Jun 6, 2021
24d5c36
Drop files from .gitignore
madpeh Jun 6, 2021
4f1a898
Merge branch 'feature/modify-model' of https://github.com/aif-dev/gen…
madpeh Jun 6, 2021
83b1189
Add data augmentation to one octave
madpeh Jun 6, 2021
d2c4b02
Add midi reduction to selected octaves
madpeh Jun 6, 2021
34ef3a4
Remove redundant code
madpeh Jun 6, 2021
Files changed
5 changes: 3 additions & 2 deletions .gitignore
@@ -3,7 +3,8 @@ __pycache__
checkpoints/
.idea/
venv/
data/notes
data/dataset_hash
training_data/notes
training_data/dataset_hash
midi_songs/
logs/
results/
4 changes: 4 additions & 0 deletions Pipfile
@@ -16,6 +16,10 @@ tensorflow = "~=2.4.0"
wget = "~=3.2"
pyngrok = "*"
random-word = "*"
pandas = "*"
matplotlib = "*"
statsmodels = "*"
scipy = "*"

[requires]
python_version = "3.8"
218 changes: 202 additions & 16 deletions Pipfile.lock

Large diffs are not rendered by default.

Binary file removed data/vocabulary
188 changes: 149 additions & 39 deletions data_preparation.py
@@ -5,7 +5,9 @@
import math
import datetime
import shutil
import random
from multiprocessing import Pool, cpu_count
from collections import Counter
import checksumdir
from music21 import converter, instrument, stream, note, chord
from random_word import RandomWords
@@ -14,12 +16,12 @@

CHECKPOINTS_DIR = "checkpoints"
MIDI_SONGS_DIR = "midi_songs"
DATA_DIR = "data"
TRAINING_DATA_DIR = "training_data"
NOTES_FILENAME = "notes"
VOCABULARY_FILENAME = "vocabulary"
HASH_FILENAME = "dataset_hash"
RESULTS_DIR = "results"
SEQUENCE_LENGTH = 100
SEQUENCE_LENGTH = 60
VALIDATION_SPLIT = 0.2

"""
@@ -31,24 +33,33 @@
NUM_NOTES_TO_PREDICT = 1


def clean_data_and_checkpoints():
shutil.rmtree(DATA_DIR)
shutil.rmtree(CHECKPOINTS_DIR)
def clear_checkpoints():
try:
shutil.rmtree(CHECKPOINTS_DIR)
except FileNotFoundError:
print("Checkpoints directory doesn't exist")


def clear_training_data():
try:
shutil.rmtree(TRAINING_DATA_DIR)
except FileNotFoundError:
print("Training data directory doesn't exist")


def save_data_hash(hash_value):
if not os.path.isdir(DATA_DIR):
os.mkdir(DATA_DIR)
if not os.path.isdir(TRAINING_DATA_DIR):
os.mkdir(TRAINING_DATA_DIR)

hash_file_path = os.path.join(DATA_DIR, HASH_FILENAME)
hash_file_path = os.path.join(TRAINING_DATA_DIR, HASH_FILENAME)
with open(hash_file_path, "wb") as hash_file:
pickle.dump(hash_value, hash_file)


def is_data_changed():
current_hash = checksumdir.dirhash(MIDI_SONGS_DIR)

hash_file_path = os.path.join(DATA_DIR, HASH_FILENAME)
hash_file_path = os.path.join(TRAINING_DATA_DIR, HASH_FILENAME)
if not os.path.exists(hash_file_path):
save_data_hash(current_hash)
return True
@@ -63,37 +74,90 @@ def is_data_changed():
return False


def get_notes_from_file(file):
print(f"Parsing {file}")
def get_midi_in_default_octave(pattern):
if isinstance(pattern, note.Note):
note_in_default_octave = note.Note(pattern.name)
elif isinstance(pattern, int):
note_in_default_octave = note.Note(pattern)

return note_in_default_octave.pitch.midi


def map_midi_to_reduced_octaves(midi_value, min_midi=4 * 12, max_midi=5 * 12 - 1):
if midi_value > max_midi:
return midi_value - (math.ceil((midi_value - max_midi) / 12) * 12)

if midi_value < min_midi:
return midi_value + (math.ceil((min_midi - midi_value) / 12) * 12)

return midi_value
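
For intuition, the reduction above folds any MIDI value into the window [min_midi, max_midi] in whole-octave (12-semitone) steps. A few hand-checked values under the defaults (48..59):

map_midi_to_reduced_octaves(72)  # -> 48, two octaves above the window
map_midi_to_reduced_octaves(60)  # -> 48, one octave above
map_midi_to_reduced_octaves(40)  # -> 52, one octave below
map_midi_to_reduced_octaves(53)  # -> 53, already inside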

midi = converter.parse(file)

def get_notes_from_midi_stream(midi_stream, octave_transposition=0):
transposition = octave_transposition * 12
notes = []
s2 = instrument.partitionByInstrument(midi_stream)

    # Loop over all instrument parts
for part in s2.parts:

        # Select only the piano parts
if "Piano" in str(part):

notes_to_parse = part.recurse()

            # Determine whether each element is a note or a chord
for element in notes_to_parse:

# note
if isinstance(element, note.Note):
midi_value = (
map_midi_to_reduced_octaves(element.pitch.midi) + transposition
)
notes.append(str(midi_value))

# chord
elif isinstance(element, chord.Chord):
midi_values = [
map_midi_to_reduced_octaves(pitch.midi) + transposition
for pitch in element.pitches
]
midi_values = list(set(midi_values))
notes.append(".".join(str(midi) for midi in sorted(midi_values)))
return notes
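
For reference, the token format produced here: a single note becomes its (transposed) MIDI number as a string, and a chord becomes its de-duplicated, sorted MIDI numbers joined with dots. A hand-worked sketch under the reduction defaults, file name assumed:

tokens = get_notes_from_midi_stream(converter.parse("example.mid"))
# a C-major triad C4-E4-G4 (MIDI 60, 64, 67) reduces to 48, 52, 55 and is
# emitted as the single token "48.52.55"; a lone E4 becomes "52"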


def get_notes_from_file(file, augment_data=False, octave_augmentation=1):
print(f"Parsing {file}")

try:
# file has instrument parts
instrument_stream = instrument.partitionByInstrument(midi)
notes_to_parse = instrument_stream.parts[0].recurse()
midi_stream = converter.parse(file)
except:
# file has notes in a flat structure
notes_to_parse = midi.flat.notes
return []

for element in notes_to_parse:
if isinstance(element, note.Note):
notes.append(element.name)
elif isinstance(element, chord.Chord):
notes.append(".".join(str(n) for n in element.normalOrder))
if augment_data:
all_notes = []
for octave_transposition in range(
-octave_augmentation, octave_augmentation + 1
):
notes = get_notes_from_midi_stream(midi_stream, octave_transposition)
            all_notes.extend(notes)

return notes
else:
all_notes = get_notes_from_midi_stream(midi_stream)

return all_notes
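
With augment_data=True and the default octave_augmentation=1, each file is parsed three times, at transpositions of -12, 0, and +12 semitones, roughly tripling the token stream. A hypothetical call (file name assumed):

notes = get_notes_from_file("midi_songs/example.mid", augment_data=True)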


def get_notes_from_dataset():
notes_path = os.path.join(DATA_DIR, NOTES_FILENAME)
notes_path = os.path.join(TRAINING_DATA_DIR, NOTES_FILENAME)
notes = []
if is_data_changed():
try:
with Pool(cpu_count() - 1) as pool:
notes_from_files = pool.map(
get_notes_from_file, glob.glob(f"{MIDI_SONGS_DIR}/*.mid")
)
files = glob.glob(f"{MIDI_SONGS_DIR}/*.mid")
notes_from_files = pool.map(get_notes_from_file, files)

for notes_from_file in notes_from_files:
for note in notes_from_file:
@@ -103,7 +167,7 @@ def get_notes_from_dataset():
pickle.dump(notes, notes_data_file)

except:
hash_file_path = os.path.join(DATA_DIR, HASH_FILENAME)
hash_file_path = os.path.join(TRAINING_DATA_DIR, HASH_FILENAME)
os.remove(hash_file_path)
print("Removed the hash file")
sys.exit(1)
@@ -122,17 +186,19 @@ def create_vocabulary_for_training(notes):
sound_names = sorted(set(item for item in notes))
vocab = dict((note, number) for number, note in enumerate(sound_names))

vocab_path = os.path.join(DATA_DIR, VOCABULARY_FILENAME)
vocab_path = os.path.join(TRAINING_DATA_DIR, VOCABULARY_FILENAME)
with open(vocab_path, "wb") as vocab_data_file:
pickle.dump(vocab, vocab_data_file)

print(f"*** vocabulary size: {len(vocab)} ***")

return vocab


def load_vocabulary_from_training():
print("*** Restoring vocabulary used for training ***")

vocab_path = os.path.join(DATA_DIR, VOCABULARY_FILENAME)
vocab_path = os.path.join(TRAINING_DATA_DIR, VOCABULARY_FILENAME)
with open(vocab_path, "rb") as vocab_data_file:
return pickle.load(vocab_data_file)

@@ -173,16 +239,56 @@ def prepare_sequence_for_prediction(notes, vocab):
return network_input


def get_class_weights(notes, vocab):
mapped_notes = [vocab[note] for note in notes]
notes_counter = Counter(mapped_notes)

for key in notes_counter:
notes_counter[key] = 1 / notes_counter[key]

return notes_counter
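
The Counter returned above maps class index to inverse frequency, which is the mapping shape Keras accepts for class_weight. A hedged sketch of how it could be wired into training; the variable names here are assumed, not taken from the training script:

class_weights = get_class_weights(notes, vocab)
model.fit(network_input, network_output, class_weight=class_weights)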


def get_best_representation(vocab, pattern):
"""assumption: all 12 single notes are present in vocabulary"""
# assumption: all single notes (not necessarily from the same octave)
# are present in vocabulary

if pattern in vocab.keys():
return vocab[pattern]

chord_sounds = [int(sound) for sound in pattern.split(".")]
unknown_chord = chord.Chord(chord_sounds)
# either an unknown chord or an unknown single note
chord_midis = [int(midi) for midi in pattern.split(".")]
unknown_chord = chord.Chord(chord_midis)
root_note = unknown_chord.root()
print(f"*** Mapping {unknown_chord} to {root_note} ***")
return vocab[root_note.name]

nearest_note_midi = find_nearest_single_note_midi(vocab, root_note.midi)
print(f"*** Mapping {pattern} to {nearest_note_midi} ***")
return vocab[str(nearest_note_midi)]


def find_nearest_single_note_midi(vocab, midi_note):
if str(midi_note) in vocab.keys():
return midi_note

midi_note_down = midi_note
midi_note_up = midi_note

while midi_note_down >= 0 or midi_note_up <= 87:
midi_note_down -= 12
midi_note_up += 12

print(f"{midi_note} {midi_note_up} {midi_note_down}")

if midi_note_down >= 0 and str(midi_note_down) in vocab.keys():
return midi_note_down

if midi_note_up <= 87 and str(midi_note_up) in vocab.keys():
return midi_note_up

print(
f"ALERT: couldn't find any appropriate representation of {midi_note} in vocabulary. Returned a random representation."
)
    return random.choice([key for key in vocab.keys() if "." not in key])
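
A small hand-checked example of the fallback, using a toy vocabulary (assumed, for illustration only):

toy_vocab = {"50": 0, "52": 1, "48.52.55": 2}
find_nearest_single_note_midi(toy_vocab, 62)  # -> 50, one octave down from 62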


def save_midi_file(prediction_output):
@@ -193,18 +299,20 @@ def save_midi_file(prediction_output):
for pattern in prediction_output:
# pattern is a chord
if ("." in pattern) or pattern.isdigit():
notes_in_chord = pattern.split(".")
midis_in_chord = [int(midi) for midi in pattern.split(".")]
notes = []
for current_note in notes_in_chord:
new_note = note.Note(int(current_note))
for current_midi in midis_in_chord:
new_note = note.Note(current_midi)
new_note.storedInstrument = instrument.Piano()
notes.append(new_note)

new_chord = chord.Chord(notes)
new_chord.offset = offset
output_notes.append(new_chord)
# pattern is a note
else:
new_note = note.Note(pattern)
midi = int(pattern)
new_note = note.Note(midi)
new_note.offset = offset
new_note.storedInstrument = instrument.Piano()
output_notes.append(new_note)
@@ -224,3 +332,5 @@

midi_stream = stream.Stream(output_notes)
midi_stream.write("midi", fp=f"{RESULTS_DIR}/{output_name}.mid")

print(f"Result saved as {output_name}")
32 changes: 16 additions & 16 deletions network.py
@@ -1,36 +1,36 @@
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Dropout
from keras.layers import LSTM
from keras.layers import BatchNormalization as BatchNorm
from keras.layers import Activation
from keras.layers import Dense, Dropout, LSTM, Activation
from data_preparation import SEQUENCE_LENGTH, NUM_NOTES_TO_PREDICT


def create_network(vocab_size, weights_filename=None):
lstm_units = 128
dense_units = vocab_size * 2
dropout_rate = 0.3
model = Sequential()
model.add(
LSTM(
512,
lstm_units,
input_shape=(SEQUENCE_LENGTH, NUM_NOTES_TO_PREDICT),
return_sequences=True,
)
)
model.add(LSTM(512, return_sequences=True))
model.add(LSTM(512))
model.add(BatchNorm())
model.add(Activation("relu"))
model.add(Dropout(0.3))
model.add(Dense(256))
model.add(BatchNorm())
model.add(Activation("relu"))
model.add(Dropout(0.3))
model.add(Dropout(dropout_rate))
model.add(LSTM(lstm_units, return_sequences=True))
model.add(Dropout(dropout_rate))
model.add(LSTM(lstm_units, return_sequences=True))
model.add(Dropout(dropout_rate))
model.add(LSTM(lstm_units))
model.add(Dense(dense_units))
model.add(Dropout(dropout_rate))
model.add(Dense(vocab_size))
model.add(Activation("softmax"))
model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=['acc'])
model.compile(loss="categorical_crossentropy", optimizer="rmsprop", metrics=["acc"])

if weights_filename:
print(f"*** Loading weights from {weights_filename} ***")
model.load_weights(weights_filename)

model.summary()

return model
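
For context, the rewrite trades the old stack (three 512-unit LSTMs with BatchNorm) for four 128-unit LSTM layers with dropout after each, a vocab_size * 2 Dense layer, and RMSprop in place of Adam; return_sequences=True on all but the last LSTM so each layer receives the full sequence. A minimal usage sketch with an assumed vocabulary size:

model = create_network(vocab_size=300)  # 300 is illustrative
# expects input batches shaped (batch, SEQUENCE_LENGTH, NUM_NOTES_TO_PREDICT)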
2 changes: 1 addition & 1 deletion notes_sequence.py
@@ -1,5 +1,5 @@
from keras.utils import Sequence, np_utils
import math
from keras.utils import Sequence, np_utils
import numpy as np


3 changes: 1 addition & 2 deletions predict.py
@@ -1,7 +1,6 @@
import os
import sys
import getopt
import pickle
import tensorflow as tf
import numpy as np
from network import create_network
@@ -64,7 +63,7 @@ def generate_notes(model, network_input, vocab, vocab_size):


def generate_music(file):
notes = get_notes_from_file(file)
notes = get_notes_from_file(file, augment_data=False)
vocab = load_vocabulary_from_training()
vocab_size = len(vocab)

Binary file added prediction_data/Fiend_Battle_(Piano).mid
Binary file added prediction_data/cosmo.mid
Binary file added prediction_data/ff4-town.mid
Binary file removed results/cryptus_riffle-bars.mid
Binary file removed results/nonstriking_waylayers.mid
Binary file removed results/output_2021-05-27 15:53:40.754080.mid
Binary file removed results/programs_hill-fort.mid
Binary file removed results/puszta_decolorize.mid
Binary file removed results/tottori_lipide.mid
Binary file removed results/waterbomber_phatter.mid