|
""" |
|
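This file manages the loading of the data: the midi files, the training/testing
lyrics sets and the word vectors, each of them cached in a pickle file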
|
""" |
|
import csv |
|
import os |
|
import pickle |
|
import string |
|
|
|
import numpy as np |
|
import pretty_midi |
|
|
|
|
|
def get_midi_files(midi_pickle, midi_folder, artists, names):
    """
    This function loads the midi files, using the pickle file as a cache.

    If the pickle file exists, its content is returned as is. Otherwise each
    (artist, song name) pair is turned into a file name of the form
    ``<artist>_-_<song name>.mid`` (spaces replaced by underscores), looked up
    case-insensitively in the midi folder and parsed with pretty_midi. Songs
    whose file is missing or cannot be parsed are reported and skipped, so the
    returned list can be shorter than ``names``. The parsed songs are then
    saved to the pickle file.
    :param midi_pickle: path for the pickle file
    :param midi_folder: path for the midi folder
    :param artists: list of artist
    :param names: list of song names
    :return: list of pretty midi objects
    :raises ValueError: if artists and names have different lengths
    """
    pretty_midi_songs = _read_pickle_if_exists(pickle_path=midi_pickle)
    if pretty_midi_songs is not None:
        return pretty_midi_songs

    # ValueError is a subclass of Exception, so callers catching Exception still work
    if len(artists) != len(names):
        raise ValueError('Artists and Names lengths are different.')

    pretty_midi_songs = []
    lower_upper_files = get_lower_upper_dict(midi_folder)
    for artist, song_name in zip(artists, names):
        # Drop a single leading space; startswith is safe for an empty name
        if song_name.startswith(" "):
            song_name = song_name[1:]
        song_file_name = f'{artist}_-_{song_name}.mid'.replace(" ", "_")
        # The dictionary keys are lower case, so the lookup key must be as well
        lookup_name = song_file_name.lower()
        if lookup_name not in lower_upper_files:
            print(f'Song {song_file_name} does not exist, even though'
                  f' the song is provided in the training or testing sets')
            continue
        midi_file_path = os.path.join(midi_folder, lower_upper_files[lookup_name])
        try:
            pretty_midi_format = pretty_midi.PrettyMIDI(midi_file_path)
        except Exception:
            # Best effort: a file that cannot be parsed is reported and skipped
            print(f'Exception raised from Mido using this file: {midi_file_path}')
        else:
            pretty_midi_songs.append(pretty_midi_format)

    _save_pickle(pickle_path=midi_pickle, content=pretty_midi_songs)
    return pretty_midi_songs
|
|
|
|
|
def get_lower_upper_dict(midi_folder):
    """
    This function maps between lower case name to the original file name
    of every midi file in the folder.

    The ``.mid`` extension is matched case-insensitively, so a file such as
    ``Song.MID`` is also included (under the key ``song.mid``).
    :param midi_folder: midi folder path
    :return: A dictionary between lower case name to the original file name
    """
    lower_upper_files = {}
    for file_name in os.listdir(midi_folder):
        lower_name = file_name.lower()
        # Test the lower-cased name so the extension check agrees with the keys
        if lower_name.endswith(".mid"):
            lower_upper_files[lower_name] = file_name
    return lower_upper_files
|
|
|
|
|
def _clean_lyrics(song_lyrics, word2vec, punctuation_table):
    """
    This function normalizes the raw lyrics of a single song.
    :param song_lyrics: the raw lyrics string
    :param word2vec: dictionary maps between a word and a vector
    :param punctuation_table: str.translate table that deletes punctuation
    :return: the lower case, purely alphabetic words that exist in word2vec, joined by a space
    """
    song_lyrics = song_lyrics.replace('&', '')
    song_lyrics = song_lyrics.replace('\'', '')
    # A double dash separates words, so it becomes a space instead of vanishing
    song_lyrics = song_lyrics.replace('--', ' ')

    words = []
    # split() without arguments already collapses any run of whitespace
    for token in song_lyrics.split():
        word = token.translate(punctuation_table)
        if not word.isalpha():
            continue
        word = word.lower()
        if word in word2vec:
            words.append(word)
    return ' '.join(words)


def get_input_sets(input_file, pickle_path, word2vec, midi_folder) -> dict:
    """
    This function loads the training and testing set that provided by the course staff.
    In addition some pre-processing methods work here.

    If the pickle file exists, its content is used as is. Otherwise the csv file
    is read row by row (artist, song name, lyrics). A song is kept only when
    its midi file exists in the midi folder and pretty_midi can parse it; its
    lyrics are cleaned and the three lists are saved to the pickle file.
    :param input_file: training or testing set path
    :param pickle_path: training or testing pickle path
    :param word2vec: dictionary maps between a word and a vector
    :param midi_folder: the midi folder that we use to validate if song is exists
    :return: dictionary with the keys 'artists', 'names' and 'lyrics', each one a list
    """
    pickle_value = _read_pickle_if_exists(pickle_path=pickle_path)
    if pickle_value is not None:
        artists, names, lyrics = pickle_value[0], pickle_value[1], pickle_value[2]
        return {'artists': artists, 'names': names, 'lyrics': lyrics}

    # The folder listing is only needed when the sets are rebuilt from the csv
    lower_upper_files = get_lower_upper_dict(midi_folder)
    # Loop invariant: build the punctuation table once instead of once per row
    punctuation_table = str.maketrans('', '', string.punctuation)

    artists, names, lyrics = [], [], []
    with open(input_file, newline='') as f:
        for row in csv.reader(f, delimiter=',', quotechar='|'):
            # csv.reader yields [] for a blank line; short rows cannot be used either
            if len(row) < 3:
                if row:
                    print(f'Skipping malformed row (expected 3 columns): {row}')
                continue
            artist_name, song_name, song_lyrics = row[0], row[1], row[2]
            # Drop a single leading space; startswith is safe for an empty name
            if song_name.startswith(" "):
                song_name = song_name[1:]
            song_file_name = f'{artist_name}_-_{song_name}.mid'.replace(" ", "_")
            # NOTE: the keys of lower_upper_files are lower case, so names that
            # are not already lower case in the csv are reported as missing.
            if song_file_name not in lower_upper_files:
                print(f'Song {song_file_name} does not exist, even though'
                      f' the song is provided in the training or testing sets')
                continue
            original_file_name = lower_upper_files[song_file_name]
            midi_file_path = os.path.join(midi_folder, original_file_name)
            try:
                # Parsed only to validate the file; the result is discarded
                pretty_midi.PrettyMIDI(midi_file_path)
            except Exception:
                print(f'Exception raised from Mido using this file: {midi_file_path}')
                continue

            artists.append(artist_name)
            names.append(song_name)
            lyrics.append(_clean_lyrics(song_lyrics, word2vec, punctuation_table))

    _save_pickle(pickle_path=pickle_path, content=[artists, names, lyrics])
    return {'artists': artists, 'names': names, 'lyrics': lyrics}
|
|
|
|
|
def get_word2vec(word2vec_path, pre_trained, vector_size, encoding='utf-8') -> dict:
    """
    This function returns a dictionary that maps between word and a vector.

    If the pickle file exists, its content is returned as is. Otherwise the
    pre-trained embedding file is parsed and the result is saved to the pickle file.
    :param word2vec_path: path for the pickle file
    :param pre_trained: path for the pre-trained embedding file
    :param vector_size: the vector size for each word
    :param encoding: the encoding of the pre_trained file
    :return: dictionary maps between a word and a vector
    """
    word2vec = _read_pickle_if_exists(word2vec_path)
    if word2vec is None:
        with open(pre_trained, 'r', encoding=encoding) as f:
            # The file object is passed directly and consumed line by line, so
            # the embedding file is never held in memory as a list of lines.
            word2vec = _iterate_over_glove_list(list_of_lines=f, vector_size=vector_size)
        _save_pickle(pickle_path=word2vec_path, content=word2vec)
    return word2vec
|
|
|
|
|
def _iterate_over_glove_list(list_of_lines, vector_size): |
|
""" |
|
This function iterates over the glove list line by line and returns a word2vec dictionary |
|
:param list_of_lines: List of glove lines |
|
:param vector_size: the size of the embedding vector size |
|
:return: dictionary maps between a word and a vector |
|
""" |
|
word2vec = {} |
|
punctuation = string.punctuation |
|
for line in list_of_lines: |
|
values = line.split(' ') |
|
word = values[0] |
|
if word in punctuation: |
|
continue |
|
vec = np.asarray(values[1:], "float32") |
|
if len(vec) != vector_size: |
|
raise Warning(f"Vector size is different than {vector_size}") |
|
else: |
|
word2vec[word] = vec |
|
return word2vec |
|
|
|
|
|
def _save_pickle(pickle_path, content): |
|
""" |
|
This function saves a value to pickle file |
|
:param pickle_path: path for the pickle file |
|
:param content: the value you want to save |
|
:return: Nothing |
|
""" |
|
with open(pickle_path, 'wb') as f: |
|
pickle.dump(content, f) |
|
|
|
|
|
def _read_pickle_if_exists(pickle_path): |
|
""" |
|
This function reads a pickle file |
|
:param pickle_path:path for the pickle file |
|
:return: the saved value in the pickle file |
|
""" |
|
pickle_file = None |
|
if os.path.exists(pickle_path): |
|
with open(pickle_path, 'rb') as f: |
|
pickle_file = pickle.load(f) |
|
return pickle_file |
|
|
|
|
|
# Module-level statement: runs once when the module is executed or imported.
print("Loaded Successfully")
|
|