diff --git a/.github/ISSUE_TEMPLATE/sweep-template.yml b/.github/ISSUE_TEMPLATE/sweep-template.yml new file mode 100644 index 00000000..b77d1cf8 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/sweep-template.yml @@ -0,0 +1,22 @@ +name: Sweep Issue +title: 'Sweep: ' +description: For small bugs, features, refactors, and tests to be handled by Sweep, an AI-powered junior developer. +labels: sweep +body: + - type: textarea + id: description + attributes: + label: Details + description: Tell Sweep where and what to edit and provide enough context for a new developer to the codebase + placeholder: | + Unit Tests: Write unit tests for . Test each function in the file. Make sure to test edge cases. + Bugs: The bug might be in . Here are the logs: ... + Features: the new endpoint should use the ... class from because it contains ... logic. + Refactors: We are migrating this function to ... version because ... +# - type: input +# id: branch +# attributes: +# label: Branch +# description: The branch to work off of (optional) +# placeholder: | +# main diff --git a/.gitignore b/.gitignore index 02c7ca10..27a6dbea 100644 --- a/.gitignore +++ b/.gitignore @@ -1,7 +1,418 @@ -data/* +# File created using '.gitignore Generator' for Visual Studio Code: https://bit.ly/vscode-gig +# Created by https://www.toptal.com/developers/gitignore/api/audio,python,visualstudio,visualstudiocode,windows +# Edit at https://www.toptal.com/developers/gitignore?templates=audio,python,visualstudio,visualstudiocode,windows + +### Audio ### +*.aif +*.iff +*.m3u +*.m4a +*.mp3 +*.mpa +*.ra +*.wav +*.wma +*.ogg +*.flac + +### Python ### +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +Pipfile.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +### VisualStudioCode ### +.vscode/* +!.vscode/settings.json +!.vscode/tasks.json +!.vscode/launch.json +!.vscode/extensions.json +!.vscode/*.code-snippets + +# Local History for Visual Studio Code +.history/ + +# Built Visual Studio Code Extensions +*.vsix + +### VisualStudioCode Patch ### +# Ignore all local history of files +.history +.ionide + +### Windows ### +# Windows thumbnail cache files +Thumbs.db +Thumbs.db:encryptable +ehthumbs.db +ehthumbs_vista.db + +# Dump file +*.stackdump + +# Folder config file +[Dd]esktop.ini + +# Recycle Bin used on file shares +$RECYCLE.BIN/ + +# Windows Installer files +*.cab +*.msi +*.msix +*.msm +*.msp + +# Windows shortcuts +*.lnk + +### VisualStudio ### +## Ignore Visual Studio temporary files, build results, and +## files generated by popular Visual Studio add-ons. +## +## Get latest from https://github.com/github/gitignore/blob/main/VisualStudio.gitignore + +# User-specific files +*.rsuser +*.suo +*.user +*.userosscache +*.sln.docstates + +# User-specific files (MonoDevelop/Xamarin Studio) +*.userprefs + +# Mono auto generated files +mono_crash.* + +# Build results +[Dd]ebug/ +[Dd]ebugPublic/ +[Rr]elease/ +[Rr]eleases/ +x64/ +x86/ +[Ww][Ii][Nn]32/ +[Aa][Rr][Mm]/ +[Aa][Rr][Mm]64/ +bld/ +[Bb]in/ +[Oo]bj/ +[Ll]og/ +[Ll]ogs/ + +# Visual Studio 2015/2017 cache/options directory +.vs/ + +# Visual Studio 2017 auto generated files +Generated\ Files/ + +# MSTest test Results +[Tt]est[Rr]esult*/ +[Bb]uild[Ll]og.* + +# NUnit +*.VisualState.xml +TestResult.xml +nunit-*.xml + +# Build Results of an ATL Project +[Dd]ebugPS/ +[Rr]eleasePS/ +dlldata.c + +# Files built by Visual Studio +*_i.c +*_p.c +*_h.h +*.ilk +*.meta +*.obj +*.iobj +*.pch +*.pdb +*.ipdb +*.pgc +*.pgd +*.rsp +*.sbr +*.tlb +*.tli +*.tlh +*.tmp +*.tmp_proj +*_wpftmp.csproj +*.tlog +*.vspscc +*.vssscc +.builds +*.pidb +*.svclog +*.scc + +# Visual Studio profiler +*.psess +*.vsp +*.vspx +*.sap + +# Visual Studio Trace Files +*.e2e + +# Visual Studio code coverage results +*.coverage +*.coveragexml + +# NuGet Packages +*.nupkg + +# NuGet Symbol Packages +*.snupkg + +# The packages folder can be ignored because of Package Restore +**/[Pp]ackages/* + +# except build/, which is used as an MSBuild target. +!**/[Pp]ackages/build/ + +# NuGet v3's project.json files produces more ignorable files +*.nuget.props +*.nuget.targets + +# Visual Studio cache files +# files ending in .cache can be ignored +*.[Cc]ache +# but keep track of directories ending in .cache +!?*.[Cc]ache/ + +# Others +ClientBin/ +~$* +*~ +*.dbmdl +*.dbproj.schemaview +*.jfm +*.pfx +*.publishsettings +orleans.codegen.cs + +# Backup & report files from converting an old project file +# to a newer Visual Studio version. 
Backup files are not needed, +# because we have git ;-) +_UpgradeReport_Files/ +Backup*/ +UpgradeLog*.XML +UpgradeLog*.htm +ServiceFabricBackup/ +*.rptproj.bak + +# SQL Server files +*.mdf +*.ldf +*.ndf + +# Business Intelligence projects +*.rdl.data +*.bim.layout +*.bim_*.settings +*.rptproj.rsuser +*- [Bb]ackup.rdl +*- [Bb]ackup ([0-9]).rdl +*- [Bb]ackup ([0-9][0-9]).rdl + +# Node.js Tools for Visual Studio +.ntvs_analysis.dat +node_modules/ + +# Visual Studio 6 build log +*.plg + +# Visual Studio 6 workspace options file +*.opt + +# Visual Studio 6 auto-generated workspace file (contains which files were open etc.) +*.vbw + +# Visual Studio 6 auto-generated project file (contains which files were open etc.) +*.vbp + +# Visual Studio 6 workspace and project file (working project files containing files to include in project) +*.dsw +*.dsp + +# Visual Studio 6 technical files + +# Visual Studio LightSwitch build output +**/*.HTMLClient/GeneratedArtifacts +**/*.DesktopClient/GeneratedArtifacts +**/*.DesktopClient/ModelManifest.xml +**/*.Server/GeneratedArtifacts +**/*.Server/ModelManifest.xml +_Pvt_Extensions + +# CodeRush personal settings +.cr/personal + +# Python Tools for Visual Studio (PTVS) +*.pyc + +# NVidia Nsight GPU debugger configuration file +*.nvuser + +# Local History for Visual Studio +.localhistory/ + +# Visual Studio History (VSHistory) files +.vshistory/ + +# VS Code files for those working on multiple tools +*.code-workspace + +# Custom rules (everything added below won't be overriden by 'Generate .gitignore File' if you use 'Update' option) + +*.mid +act.bat config.txt +data/* history results/* -cuda10.dockerfile -docker-compose.yml -Dockerfile \ No newline at end of file +run.bat +slnx.sqlite +tensorflow-1.14.0-cp37-cp37m-win_amd64.whl +__pycache__/* + +.vs/**/* +.vs/Composer/FileContentIndex/17b639d7-f02e-47fd-a903-7ef32f5632b2.vsidx +.vs/ProjectSettings.json +.vs/VSWorkspaceState.json +.vs/Composer/FileContentIndex/5bcaee41-053c-41c7-b31d-0a445c89d351.vsidx +.vs/Composer/FileContentIndex/read.lock +.vs/Composer/v17/.wsuo +.vs/slnx.sqlite +.vscode/settings.json +.vs/ProjectSettings.json +.vs/VSWorkspaceState.json diff --git a/.vscode/settings.json b/.vscode/settings.json deleted file mode 100644 index ebed377d..00000000 --- a/.vscode/settings.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "python.pythonPath": "D:\\ProgramData\\Anaconda3\\envs\\TFOld\\python.exe" -} \ No newline at end of file diff --git a/Composer Controls.docx b/Composer Controls.docx new file mode 100644 index 00000000..7bc44f87 Binary files /dev/null and b/Composer Controls.docx differ diff --git a/GradiantVis.ipynb b/GradiantVis.ipynb new file mode 100644 index 00000000..270c9358 --- /dev/null +++ b/GradiantVis.ipynb @@ -0,0 +1,476 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "C:\\ProgramData\\Anaconda2\\envs\\tensor3\\lib\\site-packages\\tensorflow\\python\\framework\\dtypes.py:516: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", + " _np_qint8 = np.dtype([(\"qint8\", np.int8, 1)])\n", + "C:\\ProgramData\\Anaconda2\\envs\\tensor3\\lib\\site-packages\\tensorflow\\python\\framework\\dtypes.py:517: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", + " 
_np_quint8 = np.dtype([(\"quint8\", np.uint8, 1)])\n", + "C:\\ProgramData\\Anaconda2\\envs\\tensor3\\lib\\site-packages\\tensorflow\\python\\framework\\dtypes.py:518: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", + " _np_qint16 = np.dtype([(\"qint16\", np.int16, 1)])\n", + "C:\\ProgramData\\Anaconda2\\envs\\tensor3\\lib\\site-packages\\tensorflow\\python\\framework\\dtypes.py:519: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", + " _np_quint16 = np.dtype([(\"quint16\", np.uint16, 1)])\n", + "C:\\ProgramData\\Anaconda2\\envs\\tensor3\\lib\\site-packages\\tensorflow\\python\\framework\\dtypes.py:520: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", + " _np_qint32 = np.dtype([(\"qint32\", np.int32, 1)])\n", + "C:\\ProgramData\\Anaconda2\\envs\\tensor3\\lib\\site-packages\\tensorflow\\python\\framework\\dtypes.py:525: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", + " np_resource = np.dtype([(\"resource\", np.ubyte, 1)])\n", + "C:\\ProgramData\\Anaconda2\\envs\\tensor3\\lib\\site-packages\\tensorboard\\compat\\tensorflow_stub\\dtypes.py:541: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", + " _np_qint8 = np.dtype([(\"qint8\", np.int8, 1)])\n", + "C:\\ProgramData\\Anaconda2\\envs\\tensor3\\lib\\site-packages\\tensorboard\\compat\\tensorflow_stub\\dtypes.py:542: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", + " _np_quint8 = np.dtype([(\"quint8\", np.uint8, 1)])\n", + "C:\\ProgramData\\Anaconda2\\envs\\tensor3\\lib\\site-packages\\tensorboard\\compat\\tensorflow_stub\\dtypes.py:543: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", + " _np_qint16 = np.dtype([(\"qint16\", np.int16, 1)])\n", + "C:\\ProgramData\\Anaconda2\\envs\\tensor3\\lib\\site-packages\\tensorboard\\compat\\tensorflow_stub\\dtypes.py:544: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", + " _np_quint16 = np.dtype([(\"quint16\", np.uint16, 1)])\n", + "C:\\ProgramData\\Anaconda2\\envs\\tensor3\\lib\\site-packages\\tensorboard\\compat\\tensorflow_stub\\dtypes.py:545: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", + " _np_qint32 = np.dtype([(\"qint32\", np.int32, 1)])\n", + "C:\\ProgramData\\Anaconda2\\envs\\tensor3\\lib\\site-packages\\tensorboard\\compat\\tensorflow_stub\\dtypes.py:550: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", + " np_resource = np.dtype([(\"resource\", np.ubyte, 1)])\n", + "Using TensorFlow backend.\n" + ] + } + ], + "source": [ + "import 
tensorflow as tf\n", + "import keras\n", + "from keras.models import Model, load_model\n", + "from keras import backend as K\n", + "import numpy as np\n", + "import params\n", + "import random\n", + "from keras.optimizers import Adam, RMSprop\n", + "\n", + "import midi_utils\n", + "import plot_utils\n", + "import models\n", + "import params\n", + "\n", + "import os\n", + "\n", + "import matplotlib\n", + "import matplotlib.pyplot as plt\n", + "\n", + "%matplotlib inline\n", + "\n", + "# User constants\n", + "dir_name = 'results/history/'\n", + "sub_dir_name = 'basic'\n", + "num_measures = 16\n", + "use_pca = True" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING: Logging before flag parsing goes to stderr.\n", + "W0608 23:25:06.865839 31156 deprecation_wrapper.py:119] From C:\\ProgramData\\Anaconda2\\envs\\tensor3\\lib\\site-packages\\keras\\backend\\tensorflow_backend.py:517: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.\n", + "\n", + "W0608 23:25:06.904905 31156 deprecation_wrapper.py:119] From C:\\ProgramData\\Anaconda2\\envs\\tensor3\\lib\\site-packages\\keras\\backend\\tensorflow_backend.py:4138: The name tf.random_uniform is deprecated. Please use tf.random.uniform instead.\n", + "\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Keras version: 2.2.4\n", + "Loading encoder...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "W0608 23:25:07.065491 31156 deprecation_wrapper.py:119] From C:\\ProgramData\\Anaconda2\\envs\\tensor3\\lib\\site-packages\\keras\\backend\\tensorflow_backend.py:131: The name tf.get_default_graph is deprecated. Please use tf.compat.v1.get_default_graph instead.\n", + "\n", + "W0608 23:25:07.066457 31156 deprecation_wrapper.py:119] From C:\\ProgramData\\Anaconda2\\envs\\tensor3\\lib\\site-packages\\keras\\backend\\tensorflow_backend.py:133: The name tf.placeholder_with_default is deprecated. Please use tf.compat.v1.placeholder_with_default instead.\n", + "\n", + "W0608 23:25:07.201126 31156 deprecation.py:506] From C:\\ProgramData\\Anaconda2\\envs\\tensor3\\lib\\site-packages\\keras\\backend\\tensorflow_backend.py:3445: calling dropout (from tensorflow.python.ops.nn_ops) with keep_prob is deprecated and will be removed in a future version.\n", + "Instructions for updating:\n", + "Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.\n", + "W0608 23:25:07.821541 31156 deprecation_wrapper.py:119] From C:\\ProgramData\\Anaconda2\\envs\\tensor3\\lib\\site-packages\\keras\\backend\\tensorflow_backend.py:174: The name tf.get_default_session is deprecated. Please use tf.compat.v1.get_default_session instead.\n", + "\n", + "W0608 23:25:12.050011 31156 deprecation_wrapper.py:119] From C:\\ProgramData\\Anaconda2\\envs\\tensor3\\lib\\site-packages\\keras\\optimizers.py:790: The name tf.train.Optimizer is deprecated. 
Please use tf.compat.v1.train.Optimizer instead.\n", + "\n", + "W0608 23:25:12.060956 31156 deprecation.py:323] From C:\\ProgramData\\Anaconda2\\envs\\tensor3\\lib\\site-packages\\tensorflow\\python\\ops\\nn_impl.py:180: add_dispatch_support..wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n", + "Instructions for updating:\n", + "Use tf.where in 2.0, which has the same broadcast rule as np.where\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading gaussian/pca statistics...\n" + ] + } + ], + "source": [ + "print(\"Keras version: \" + keras.__version__)\n", + "\n", + "K.set_image_data_format('channels_first')\n", + "\n", + "print(\"Loading encoder...\")\n", + "# priority name.h5 in sub, model.h5 in sub, name.h5 in dir, model.h5 in dir\n", + "if os.path.isfile(dir_name + sub_dir_name + '/' + sub_dir_name + '.h5'):\n", + " model = load_model(dir_name + sub_dir_name + '/' + sub_dir_name + '.h5')\n", + "\n", + "elif os.path.isfile(dir_name + sub_dir_name + '/' + 'model.h5'):\n", + " model = load_model(dir_name + sub_dir_name + '/' + 'model.h5')\n", + "\n", + "elif os.path.isfile(dir_name + sub_dir_name + '.h5'):\n", + " model = load_model(dir_name + sub_dir_name + '.h5')\n", + "\n", + "else:\n", + " model = load_model(dir_name + 'model.h5')\n", + "\n", + "\n", + "encoder = Model(inputs=model.input,\n", + " outputs=model.get_layer('encoder').output)\n", + "encoderFunc = K.function([model.input, K.learning_phase()], \n", + " [model.get_layer('encoder').output])\n", + "decoder = K.function([model.get_layer('decoder').input, K.learning_phase()],\n", + " [model.layers[-1].output])\n", + "# decoder = Model(inputs=model.get_layer('decoder').input,\n", + "# outputs=model.layers[-1].output)\n", + "\n", + "print(\"Loading gaussian/pca statistics...\")\n", + "latent_means = np.load(dir_name + sub_dir_name + '/latent_means.npy')\n", + "latent_stds = np.load(dir_name + sub_dir_name + '/latent_stds.npy')\n", + "latent_pca_values = np.load(\n", + " dir_name + sub_dir_name + '/latent_pca_values.npy')\n", + "latent_pca_vectors = np.load(\n", + " dir_name + sub_dir_name + '/latent_pca_vectors.npy')" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "_________________________________________________________________\n", + "Layer (type) Output Shape Param # \n", + "=================================================================\n", + "input_1 (InputLayer) (None, 16, 96, 96) 0 \n", + "_________________________________________________________________\n", + "reshape_1 (Reshape) (None, 16, 9216) 0 \n", + "_________________________________________________________________\n", + "time_distributed_1 (TimeDist (None, 16, 2000) 18434000 \n", + "_________________________________________________________________\n", + "time_distributed_2 (TimeDist (None, 16, 200) 400200 \n", + "_________________________________________________________________\n", + "flatten_1 (Flatten) (None, 3200) 0 \n", + "_________________________________________________________________\n", + "dense_3 (Dense) (None, 1600) 5121600 \n", + "_________________________________________________________________\n", + "dense_4 (Dense) (None, 40) 64040 \n", + "_________________________________________________________________\n", + "encoder (BatchNormalization) (None, 40) 160 \n", + "=================================================================\n", + "Total params: 
24,020,000\n", + "Trainable params: 24,019,920\n", + "Non-trainable params: 80\n", + "_________________________________________________________________\n" + ] + } + ], + "source": [ + "# model.summary()\n", + "encoder.summary()\n", + "\n", + "encoder.compile(optimizer=RMSprop(lr=0.01), loss='binary_crossentropy')\n", + "# decoder.summary()" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading Data...\n", + "Loaded 8980 samples from 172 songs.\n", + "8980\n", + "Preparing song samples, padding songs...\n", + "saving sample\n" + ] + } + ], + "source": [ + "np.random.seed(42)\n", + "random.seed(42)\n", + "\n", + "# network params\n", + "DROPOUT_RATE = 0.1\n", + "BATCHNORM_MOMENTUM = 0.9 # weighted normalization with the past\n", + "USE_EMBEDDING = False\n", + "USE_VAE = False\n", + "VAE_B1 = 0.02\n", + "VAE_B2 = 0.1\n", + "\n", + "BATCH_SIZE = 350\n", + "MAX_WINDOWS = 16 # the maximal number of measures a song can have\n", + "LATENT_SPACE_SIZE = params.num_params\n", + "NUM_OFFSETS = 16 if USE_EMBEDDING else 1\n", + "\n", + "K.set_image_data_format('channels_first')\n", + "\n", + "samples_path='data/interim/samples.npy'\n", + "lengths_path='data/interim/lengths.npy'\n", + "\n", + "print(\"Loading Data...\")\n", + "if not os.path.exists(samples_path) or not os.path.exists(lengths_path):\n", + " print('No input data found, run preprocess_songs.py first.')\n", + " assert(False)\n", + "\n", + "y_samples = np.load(samples_path)\n", + "y_lengths = np.load(lengths_path)\n", + "\n", + "samples_qty = y_samples.shape[0]\n", + "songs_qty = y_lengths.shape[0]\n", + "\n", + "print(\"Loaded \" + str(samples_qty) + \" samples from \" + str(songs_qty) + \" songs.\")\n", + "print(np.sum(y_lengths))\n", + "assert (np.sum(y_lengths) == samples_qty)\n", + "\n", + "print(\"Preparing song samples, padding songs...\")\n", + "x_shape = (songs_qty * NUM_OFFSETS, 1) # for embedding\n", + "x_orig = np.expand_dims(np.arange(x_shape[0]), axis=-1)\n", + "\n", + "y_shape = (songs_qty * NUM_OFFSETS, MAX_WINDOWS) + y_samples.shape[1:] # (songs_qty, max number of windows, window pitch qty, window beats per measure)\n", + "y_orig = np.zeros(y_shape, dtype=np.float32) # prepare dataset array\n", + "\n", + "# fill in measure of songs into input windows for network\n", + "song_start_ix = 0\n", + "song_end_ix = y_lengths[0]\n", + "for song_ix in range(songs_qty):\n", + " for offset in range(NUM_OFFSETS):\n", + " ix = song_ix * NUM_OFFSETS + offset # calculate the index of the song with its offset\n", + " song_end_ix = song_start_ix + y_lengths[song_ix] # get song end ix\n", + " for window_ix in range(MAX_WINDOWS): # get a maximum number of measures from a song\n", + " song_measure_ix = (window_ix + offset) % y_lengths[song_ix] # chosen measure of song to be placed in window (modulo song length)\n", + " y_orig[ix, window_ix] = y_samples[song_start_ix + song_measure_ix] # move measure into window\n", + " song_start_ix = song_end_ix # new song start index is previous song end index\n", + "assert (song_end_ix == samples_qty)\n", + "x_train = np.copy(x_orig)\n", + "y_train = np.copy(y_orig)\n", + "\n", + "test_ix = 0\n", + "y_test_song = np.copy(y_train[test_ix: test_ix + 1])\n", + "x_test_song = np.copy(x_train[test_ix: test_ix + 1])\n", + "print(\"saving sample\")\n", + "midi_utils.samples_to_midi(y_test_song[0], 'data/interim/gt.mid')" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + 
"outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(172, 16, 96, 96)\n" + ] + } + ], + "source": [ + "print(y_train.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(172, 16, 96, 96)\n" + ] + } + ], + "source": [ + "offset = 0\n", + "epochs_qty = 1\n", + "\n", + "for epoch in range(epochs_qty):\n", + " if USE_EMBEDDING:\n", + " history = 0\n", + "# history = model.fit(x_train, y_train, batch_size=BATCH_SIZE, epochs=1)\n", + " else:\n", + " # produce songs from its samples with a different starting point of the song each time\n", + " song_start_ix = 0\n", + " for song_ix in range(songs_qty):\n", + " song_end_ix = song_start_ix + y_lengths[song_ix]\n", + " for window_ix in range(MAX_WINDOWS):\n", + " song_measure_ix = (window_ix + offset) % y_lengths[song_ix]\n", + " y_train[song_ix, window_ix] = y_samples[song_start_ix + song_measure_ix]\n", + " #if params.encode_volume:\n", + " #y_train[song_ix, window_ix] /= 100.0\n", + " song_start_ix = song_end_ix\n", + " assert (song_end_ix == samples_qty)\n", + " offset += 1\n", + " print(y_train.shape)\n", + "\n", + "# history = model.fit(y_train, y_train, batch_size=BATCH_SIZE, epochs=1) # train model on reconstruction loss" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "def flattenWindow(song):\n", + " flat = np.zeros((96 * 2, 96 * 8))\n", + " for i in range(2):\n", + " for j in range(8):\n", + " wix = j + i * 8\n", + " flat[i*96:i*96+96, j*96:j*96+96] = song[0][wix]\n", + " return flat" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(1, 16, 96, 96)\n", + "(192, 768)\n" + ] + }, + { + "data": { + "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXcAAAB2CAYAAAAz69PvAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/d3fzzAAAACXBIWXMAAAsTAAALEwEAmpwYAAApBElEQVR4nO2da5Aj13Xf/+fexnuAGcx7dgbz2pmdfc8ud7nc4UOkKMt6paS4XFGRKdssxzLzQa6KnA8pslSVKB9USVxJKp/iChMpUSoxFUWJLZUSW5GYKIltORIlkRIpacWV+Fpyn7PPeQLoPvnQ3ZgG0AAaj35gdH9VU3jMbfQfF90XjXPPPX9iZigUCoVifyHCFqBQKBSK3qMGd4VCodiHqMFdoVAo9iFqcFcoFIp9iBrcFQqFYh+iBneFQqHYh/g2uBPRB4noAhFdJKJn/NqPQqFQKOohP/LciUgC+BmA9wO4BOC7AJ5k5h/3fGcKhUKhqMOvK/dzAC4y8y+YuQjgiwA+5tO+FAqFQlGD5tPrTgN42/H4EoAHnA2I6GkATwOAhDyTRq7+VdIp0M4uOJ0ANrYb7oyEACfjgKCm7XxjIAVsbgOZVHOdmgZoEqyJ0HTS5g44nTT1NoAScQAAa7LSrtLHWzvB6NwpArEYWDb+TCmVBAwDEAK8beqimAZICd7ZBWnSfI3t3msmInA6CSrpAAAuFhs3zqRAZavdrtmOEnGAGVws9VxblU6rDwAApTK4XG7ceCAFKhuAroNLZju3PvZFZ3zvmKOdXbBhWMdcAtjaBogA+7gdSAEGg3aKYMPwTZOrzmQCKOvgZKz1OTyQAhimfrexYaD5eOGFe7h1g5nH3P7n1+BOLs9VxX+Y+TkAzwFAjob5AXpf/Rb2+95s8IrOV7bbNmvnF5uO22b7162/3Rbt/MLWudVi/0XHrd3O7uMgdNs6deu20T6dY43dpmz9Efb62y/NWy7799quWPPYL+w+qN2/G5uO+3Y7tz72g1LNLaH+mLOPWzedQbFr3bY61+02zvu17b28Rgu+yV9+s9H//ArLXAJQcDyeAfCuT/vqb4SEOH44bBWKENEmJ6BNHwh8v+L4YUDIwPfbLuLU0bAl9BRtYQ4yn/d9P34N7t8FsExEC0QUB/AEgK92+mJyZLhnwvykow/M0EFXrkPmXMJSfkAEOTTY9mZBHIxOSNMgstmOtw9Kr0gmIZLJtrahWBwik6k81m/dhrF+s9fSqhCZDCgWr35uYwvgYMMarZBDg2YIxgG9eTkkNY3p5vgyrt2AsbHZumGX+DK4M3MZwO8B+DqAnwD4EjO/2vHrFSZ7Jc0/iMCFiY425Y1NYGq8x4LcIS3W/r6EBE939t46RaTTEGMjHW/Ps8EcM2JoECI/1N42A5mq98a7uzB2/J3LEOOjEJlU1XPGQBqgiC11mRqvxN9t9Fu3QhLTGC5M1H0JecXY3ASXmszR9AjfPllm/u/MfIiZDzLzZ7t5rdJQEnT6WHBXtx1SHkyBH1ytu0JqhbGzg9LYQCA/kUkK6NkkjEdPe9/I0IE33wGdPe6fsFriMRgDbeq0IUI5197VdMekkuBsBvzgqudN9Fu3wDENWmEGAKDNz0JbnPdJoImRSYGG86Azx/Z0DCZBIozJn8bo2SRE4QDkyhJEOg2cOxG2JFfKuSTkkWXP4TT7OO7oeO4QX/Lc26XhhKpCoVAoGvJN/vL3mPms2/8i9pusHpFMQh46GLaMlshcDtrCXFevoU36H/qQY2PQpqIf5tIKM93NtQhpvlef+1QuLXQ1NxAUcmUJlEhUPzcx3nFowReIII8su/+CJYIcc834CxzSNFOnB8Ls48gP7pRKoTjd/gRg0FB2AKXJoa5eg0e6294T+RyMTvYT8AGqjw6Ccp0PmiQIGMqCh/09dkqTg6B0qnXDkNmdGayLZWMoF62YOwlszw6CpNvgLoB8NMKypGnYKdRP/LoyFJ5mFZaJEHJ5EbS9i/Kld8KWUo2Q4PPHQX/5cthK9gVyYhwYHoT+k9dC1cFrqxDffbX5wiZFpOnrsIycGMfGx897astr5mSmc9IoKOTSAnY/cr+3xudPQg4NQh49VPW0fvF1lGc6zxDxAp09DuPhU63bWf2oTR+ANjtdGdh5zfvEYTeUHz8DcbJ1/r/M5SCPrUAuL9b9bJf5POTRQ75q3vq1B6DNFVq206YPQFuYg1g9Ar63Af0nr4ESCdBp81glTfN1wvruk+frwlyipNe1s/uR7j8ReA48aRpu/+aaa2pp7XktTh6GyGQCOx6diGwWd37jvKcrd7r/hLky/fzJAJTV7DvyV+5CQqSSMDZb54VSIgHe3QXF4oGkGlXtW9NA8TiMra3WbRMJcLEIkrLuqsl+D35BMbNMg5d9UCwO1s0BgKSEHB2GfvOWr/psRDIJ1o3WnyORmd7JhqnVeTwTmT/xpfRNs0inYezsmhlFTRtKM1REAlwqmnnSug5je6fyHv08bkU2C2Njo6p/KBaHPDCB8pt7lUJI08C6DorHA/mcXXXeu+f6P21hDuU3L5lrQ2JxcLkUjk4iiIGBhjqrmtpjkk/ndbMrd7/KD/QOQ/c0sAOodF7QAzsAcLlF3Q5nW1unS3u/D9R2+qaqrZSApgV2InnO+2Zu/J6YzT72MexQ+TIXsvkAb+jV64WkML+QHNrtL1I/cBuIuFQExzRzQLf6qHIbwsAOuOuswAwRj8HY0Sv9FopOZk8DO+A418P4ogx8j4q+hEtFlN++FLaMSEKaBpxtb4m8fmMd+u071a9xLvhwon7xdRgPBLh+oQvKb7wF/b6VaGX4RJi+GNy9xIhDh6ithSxhIZLJYBcjdYgcGYY8thK2jJZocwXImQPAd37U3naL89Bmpn1SVY88eghy1H0+h/7ipcB0tILOHGtazoH+8uXq0FtI8Npq5L9k+mJwj62bYRmRyUQ6R1uzdMqR4Y5ytOXyYq8l1cG6AXlzA4A5yeesceKFIDQCMEv23jX7U5ufNa9s20AuLUCk074X5OLNLWBn1ywGtbTQ1na8vVfulQ2GXN/wQyIAgO5uVsoN136GQX2mXpC3NsG6Yea1t9GfQaNZnxXF4tDmZ0NW4070Y+4AyK55rOvgsn9xya5gBtm1zkudxXhpJ4C5AjYAqy43F0tAm3HeQDQC5me9Y8Upd9vfJ+2WwIbhe710FEtg3QBJCRJtXCsVS9VzLmyAtv2Ly/LODlAy+6L2MwzsM/XC9o55jDJHS1cNtLVj/oJgo6PjMwiiny2jUCgUClf6Os8dqC62YzwSXOGdtiCqmhtoVyc/dCqQvGKRTFaKMVEsHkqesBfk6AjEycMwHjkNmctV8sFboU1OeF4a3gu0+VnIQwfBa6vQpg+0VSqDEolK/rPfBaXksZXKOgC5slQJVwVZyMoLdP8Js2AY0FGxQHlsxVwk5jP80KlKzJ3XVtsOGwZBV1fuRPQGgHswvV7KzHyWiIYB/CcA8wDeAPBxZm5as/OX+cpdHluB/uqFsGV4RpuaBBeL0H2uP95PkKZBLC+0teJUm5wAl8vQb6ybi4
akQPnKVR9V1iNXlmD8/I19t0JVLi3AePvd0NI5g8TvK/f3MvMpxw6eAfACMy8DeMF63BVydKT5NyNR5du+3QnCXiLHxppefZOmVTIBKjqDjtcRtbyyoVjcWuwk6zIXjI1NIIirFKvwV7OiXCKZNBcHxeJ1RbHsY8Lvol4Ui0Pk803j5SKdNhdUJRKV49jY3AJv75hXpmzAuOffZCpgFeCrNWm5eady3gB7x6StNwxENtvyHLY/U7d2IpkEbW5D1B4PPUbm8/XHXK0Wuz9DKiznR1jmYwC+YN3/AoC/3u0Lbjx8EKJJ9olIJEAL5hJwPjzf7e46gwh3HluESDVO4xKDOdDMlDlwLpkz7MZQsF9GFI/jzqOLTb8s5eQ45OgwRCYNmpup+p+xsQFjxv/qfDI3gLuPLoJXGlfapNlpyIEM5EgecrL6C4u0GLA023T7nugcH8XGgwswhgYaN1qaBWkxyOkpCGuANe7dAxdLKJ9YBIaHPC/U6xRaKKC0Wp0Vo1+/DuPQXqYHH543v4QWChCpcIqhGUfmgeXmn5n9mfKh+bqLKZqZMi8+ZvzNqts5swg5PdW4AZGpD/D9GGwoocuwzOsAbsG0sv1XzPwcEd1m5iFHm1vMXOdJRURPA3gaAJJIn3mYPtyxjr5GSPDaiUjlGiuCg9dWIb7/01BDCMajpyH+9w9C27+ic/wMyzzEzPcB+BCATxLRe7xuyMzPMfNZZj4bQ+ufUPLQwVBDLl4Rxw+3N7li6KEM7N2YDlMsXlf0zC/6why5C5Nz7dXXIQoHAnmfIp12nfCNvXM7Ui5ncnSk4lLV0fZHlluGTHqBVphpuDAsCnQVPGXmd63ba0T0xwDOAbhKRFPMfJmIpgBc61akHBkGbW7DiPjEj8znQRtbKBvhp5c2hAhyMAe6u4VOrZG5XAIuXemprFpI00CpFEQXOoNAJJOAlBAbnek0trZAVxm07u88hshkACFAm9t1/+OrN6oWVIWJHBoEdGNvjUObiEwGuH4T3OFaE6/IfN7U6Pc6ii7o+MqdiDJElLXvA/hVAK8A+CqAp6xmTwH4SrciuTAJTifdi/hHBcsgmzOpyPlSOrENso1sFzFVZuh37/ZOlAu2QXZXOgNADA1CDOfBmc50crkM4949302gxfgoRC4LTtfPCRn37kUnY2ZqHJROgZKdXXmLsRFAN1pX6ewSLkyA0kkgHvN1P93QTVhmAsCfE9HLAL4D4L8x858B+McA3k9ErwF4v/W4K9ZXB7G11CJjJgJcO5/H9kw2Wu42NZAUuPqeUeyOpVs3dn0BCmatQTyGq49PoZSP9uCOVBI33ltAeTAgQ+4O0YcGcOvhWRi5aPfn7lQOdx8ouH4JeYEzKWBkyPf6SfeWB7FxciqQ8E+nqBWqCoVC0af0/QpVbXLCjLtH/MpdToxHz3S4li6No4Mw8QbM1Zsyn++JTj81i0zGdIPqYlVkL16jFTKfh8hmXQvayXzeTM+NAHJ0xNTZxQSvHB3xfbW3nBiHHBqsWicQNfpicOdh0yw50oM7ETA8CAxGPSwjgcGBjo2j/TactqF4HJQb6Ng0nB1myn5qpnQalEmbn3unr5FKmq/ho5kyDWTMWLaL6TgNZkFRiR0P5cz1FZkuBs2hHCjm81gxPAgaGFBhmVaosIyiEcYjpyH+r8rB9pP91Mfy6CFg/Tb0q10n6fUFfR+W4bXVpsbOIp02jYfDLIJFBJw7gfL7zjTMeZajI5ArS65muUEZElMiAXHqaFNjZ21+FtrCXJ3RuBwbC6z2t8znIY8sY6OQaFgbXx49ZJpl1Jp6EO31sc/GxFphBtrifENjZ8AsgKUtzEFbqF+paBul+33sypUlaHMFV2Pn2Pqm6el6/iTkoYN1RuNBIk4dhTY/i42Pn2/YhtdWKwbZTrSZadDWDnhj05O5elecOwG5soTdj9zv/n8hTXPsEIzGbfriyp0SCcDgpsbO7Rg/+wUlEiCixsbOTQyb/TbGrtpXLA6Kx2Bs77injNWYOe9tSK6m3v6INPdFqVSdsXOliaaBDTY11fS338bEFay+olSqoa8mxeKwDVTr+i4AE29gr69EJl2n02mKDV2vNxoPENuUXaSSDUsy2OMBl0vVOu1BlA2QFvPVS5kSCcDqs4qPrlsb+Ouf2t8G2fDWOWGYYtdp2N1F01OiiWEz7+5WGRX7CZeKzfur1sy5siEHlw9t72t7u+FAUzFzrv2CIgIXAzJQtvqKmxgmN+1r5zHRymS7C+y+cjXKDtkU24ndV81q7TTU6eg7P83GnRqanQ9h92dfhGV+WegXo+Ig6aRPtIU53+31ek1YBtn7EZHJQJyMvv+u36jBPSLw2iro2+2ZLAeJyGQ8G2b0knbr7siJcXBMQ/nSO/4I8gGtMANZmAb+6odhS9kXGJubMF76cdgyQqcvBne5vAitMNPUFT0KyKUFc9KszQkUSiQQu3Lb9yXTgHmFqM3Ptj0xyuUy5G1/a447EZmM6WzUpk796jXQnXvmBCH8N3+WI8OQE+Ouk6VuUCxeN5HN29vgDfe4ba/QpiYbFuSKkkG2NlcwEw86yPmX+Xxghbzk0oK5/iZCBddq6YuYO+0UwcWi6YoeYWi3BBgG3APWTTC4YlodCLtFULt9GbRGXQcXS+2ZTlvwbrESC/XdZLlUNvvG6xoMN0PlYgnwuR4Rl0pAuVyZi3ASKSPqoqmzowldXd8796yJbr/miGinCC6VfI/td0NfZMsoFApFO8iJcSCbgX7x9bCl+Epf57lrU5O4/VtrlcdRNciWK0vY/PUHKo+japDNa6soP34GxiOnvRtkCwl+MNg1BLsfur8S42/HIBtAxaxcjgz7nu9874nz0BbnAcCTQTadOVafnz1XqLyGX9z87bU6g2x+MFrGzqRpWP/dta4Msm3KSweg/+KtXsqrIHM5rH9ibX8bZPeKllfuRKHl3bZFhzopkYAoHAjuKqMf+rMLjdrUJLisQ79+vceiXOhApxwyl64HOunrolMeWYZ+4ReBzPV4pgfHpkgmQTNT/p5PETmHurpyJ6LPE9E1InrF8dwwEX2DiF6zbvOO/z1LRBeJ6AIRfaAb4bZRrxho4k9p7jQQM+RGOI2FK/ct82YAFWNkp0F2FboO3LpT9Vq+6Mxmzb5q4Y9ZMciuepICK5IkMhmARMMJdLc+tjHXCuhgnz1J7f4gLWZ+ObuZNdtGzjWG07yzCy6VginWZZmci0Si/pfhjdtV80MikzEHxhCuQu1jTqRSLQvviWy2qYk3xePAHX8m/+1jzsu5IDKZuuMiyD72Epb5dwA+WPPcMwBeYOZlAC9Yj0FERwE8AeCYtc2/JKLOYw2LsxCJREuDWdsgOywjWj40bw4qh+eB5TlzIC8cgMyZX0rywKRp6mAZZNdtXy5DX79pLr7x0eCbV+bMA39xtmk72yDbiUilQPOdW595xuoDmRsAFRrkqi+bGUm2QXaVzqFBUDbTcNVgz2TGTZNzOT4KOTXhaupcOR4tg2wbY2cHlExAjtRZC/ccux+pcABysDrEoV+/XnX1yYfnQTN7J
t5BIsdHISfGgMVZs2+bwCtzoPmZxhcphammi8q6gWamIAdzpvl1sy8h2yB7uTp7rtLHg/5n2XgKyxDRPICvMfNx6/EFAI85rPS+xcwrRPQsADDzP7LafR3AZ5j5281eX02oKhQKRfv4MaE6wcyXAcC6tZNSpwG87Wh3yXquDiJ6moheJKIXS/BnmW5UzZXF6hHHg87Nlf2ENK2+IBei16cim3XN067q4wggx8agzVSfCkEajXeKXF4MLdyp6I5eZ8u4/U5x/WnAzM8x81lmPhtD85rIjaoCthTz5mVztj2kqmyNoLcc5tKGDrp0OTwxDWBdB96pN8GmN6Ollbe3wVfqJ06r+thviExj5yYYd+/CuFntkxqE0biNyGQ6ivHzleugRDwwAxo5NBhtsxsLe5FclOl0cL9qhWNg3drFky8BcC6/mwHwbufyTLgw2dF2+q1bwNQ4RFSMCCxqzZD123dCUtIEZlddfhs5t4ttMF1LkDpJSmC6+THKu7v18wABGI3biPFRiA5MvI1794DxkeDM6afGW8bcowAXJiL/JdTp4P5VAE9Z958C8BXH808QUYKIFgAswzTP7orSULLjnFcjlwKkhP7Yfd3KaI5lHM0Prnq6QpL5fCghDpFMAudOwHjUWx6+VpipC3uIZNL3OulydATi5OGKTuOR001PJnn0UJ2dnhwZ9j08oy3MQSzOeTbIprPH6/Pc52d9z3M3MinQcL6uRr8nrt4AHVvuvSgX9GwSonDA9D1oE22uALm04IOqesq5JOSR5UgXqGs5oUpEzwN4DMAogKsA/gGAPwHwJQCzAN4C8DeY+abV/tMA/haAMoBPMfOfthKhJlQVCoWifbqaUGXmJ5l5ipljzDzDzJ9j5nVmfh8zL1u3Nx3tP8vMB5l5xcvA3gqRTLZc9RcFZC7nOQ88KJNpNzwbjRO5OvIEpV2OjTW8UneakFMs7jonE5ROr/txXYdhmZUHgVxZ6srvU2Qy/k+sEkEeWY7cHFktpGmmzogT+fIDlEqhON1l3m0AsTHKpD2fPEGZTLtBuaynmCZJCQzVn8yBac/nGhuND+39j2IaKFu/yM1pkO0nlf5ocYxRKglRY/pMglz72A92Zwa7imW76e85JLA9OxhcfL9DSNOwU4j+xG9/lB/oEn7olFkrPUrLrBX7BtI0GPcfA3375bCl+EpUjbTF6hHQW1ciN9kfBH1dOExOjDc1y3XCa+ZkpnPSSKwegfzBz3wf2OXSQmOz3Fp8noxsRjOj8ap2Nf1oE5QJefnxMw2LftHZ45XQkszlmhtk+8zWrz0AMV+AKDU/vrTpA3U13ymRCMwA5e6T5z2lFMvlRcixMVdj59jVu77WSydNa2o07sQ2yOa1VVC5gzLbXSCyWVejcTfo/hPmsRrCOR/9K3chm5rlOqmYIsfiFS9G530/IU1rapZb1TZAM+y6fbdhJO7Wd0FpF8lkQ6PxKl1ErmbIgelMp80a6bVm4nUN3euLB3V8imy2odF4lR6HWXadibv1Pz8LZolstqHReJWWWBxcLlVMveXUJMrvXA7m17lV78qTTp+N2vvbINvQPQ3sgMO01nGyVAZ5n82nuVz2/PqVD9lHQ+SG+25jIOFSce/qxDqhg9Ju7LgYg9glVp3vgTlcnfaXucskYNUxZ5uO1+hx0+6LTo+1VpqZZbNt4s0+9qdXnaUaA3QZYBCC2btOe0wK4WIu8mGZXqGfj5b5tBwahDga/Rn3huGE+4Jf3i+XF6E1sF+Thw5Cjldnnsh8HvJI+/nS7UKaBpytX7PgZu5Nq4fr89xnpqHNNy/mFgWibOJdfuMtNadWQ/TDMgCMh09B/PlLXe1DLi2AtndRfqfrBbMKRd8hjx4Crq1Dv7EetpSm0JljoFd/7v7LLULw2iror34Yek33vp5QBYDYuhmWEZkMtKnOShFASkCTvpps26vj5Miw53o4QZsT2wbZgHlV7ql+PJHrCko/tTsNsrX5WTO3eGK86SplN/Npv3U6DbJrV0c22q/IZuvz4xv0ca+gu5tgy7u10/7QpiZ9z3WXtzZNr2SitlebBmmQra2b9eIpFo/sr67ox9wB0Ma2eUfXweXOfnrR1o5psu3jNy3tlsw7Je+x/VDMia2TnIsl0yjEC8VS3VO+ancaZNtx63K5uSGxm/k0fNbpMMiuNfNuuF9dd5+fcenjXsE7O0Cp1FxXq9cotXG8dMr2jvk5Mrev02mQ7TO0tWNetTc45qJAX4RlFAqFQlFP34dlnEWuOjXIlsuLdfW0/aSVTpnPh15z3ItBtlaYqft5bBcfCwovBtnyyLJ74TCfDbKdOA2yG5lPh2WQLY+t1Blkt0MQGgEzL7wjg+zzJ7sqr9Au/NCpvewtZZDdmCCu3OWxFeivXvB1H4pgkCtL0F97PTrZEUJCLi9Av3ARAKAtzsO4cs13q7+gEJkMxNiImZGiiBR9f+UuR0eafzM6zJudV0UVg+JMBrR+20+JAKxiVx6LHvlphN0UIrPwlheE+wR0INqbFNWiuxsQyb2rNIrF66/aAjL0tguXkSPuynfuglw8BGyj9Coa9HGvEclkS0MRp9F7laF3sQTe3gnE0Ftksy2PL+d5Xfe/gMynZT7f8pdCpT9DcrJqObgT0eeJ6BoRveJ47jNE9A4RvWT9fdjxv2eJ6CIRXSCiD/RC5MbDByGaZJ/YBtkAqgymbYNiPjSP8tVrbpv2DiLceWwRItX6RKWYaa4cBhSP486ji55OAJFJg+ZqTLF9NvG2kbkB3H3UPauDd3ZBB/ZCMHIkDzlZ/YVFWiyQPpbjo9h4cAHG0F7xMn39Jnh+uu6LXk7Xm0/LgQxo1v9wIS0UUFptniXDh+fNFb8LhSrzaS4VQULU9bEfGEfmXc3GnTjP69o+Dsp8eufMIuR0vdn9nhDLIBsOo/SA8VLP/T0ANgD8e4dB9mcAbDDzP61pexTA8wDOATgA4JsADjE3X9IWRFhGf+w+yG9939d9KBQKRZB0W8/9/wC42aqdxccAfJGZd5n5dQAXYQ70XSMPHewqHJD4+bVAameL44c9/ywMc0K1HRcot7by2EogP3/d9i2OH64PfwnpOnkaSB+3YXIukklXl6EgXLlEOt2RN4LIZgNzOALMMKxWmGndMGS0wkxgefWd0E3M/feI6IdW2MZ2i50G8LajzSXruTqI6GkiepGIXiyhed0FOTIM2tzuqjaMsX4TRgB+lXTpcvNcbGfbIE2ca/fdhtG1a9t3rnh+n93gtm+65FIgytBBb7m0DaiPxYa3yVOjWAIu14cI/TYeF5kMICVoc7vtbSmZBK4Fs7JVDg0CugHeCaewnldkPm9q9HFtQrd0Orj/IYCDAE4BuAzgn1nPu9XAdI37MPNzzHyWmc/G0HxigguT4HSyqyL+xtZWIMV79Nt3PC9JDrP+dDv7dmvbzvvshob7dmsbkqE3CQJ7NZ82dFdTbL91ivFRiFwWnO5g8nYsH1zmz9Q4KJ0CJYNLa+wELkyA0knAZeI8KnQ0uDPzVWbWmdkA8K+xF3q5BMC5/nsGQNfFXNZXB7G1
1CJjJmyIcOPpNRQ/cLa9rAIhzZzZgBDJJG787TWUfuVMe9tls6Cze4Ww7Nr5fiFHR7D+ibU6Y3M60zz3WZucMOuoWHg1Au8UbWEOt/7m/Z4Nsm3k8mJV6MFvA3d9aAC3Hp41DePbQJw6Cly+7mtFVSe7UzncfaDQ2ZeQBd1/wveMrnvLg9g4ORVobn27eMpzJ6J5AF9zTKhOMfNl6/7vA3iAmZ8gomMA/gh7E6ovAFiOwoSqQqFQ7De6mlAloucBfBvAChFdIqLfAfAHRPQjIvohgPcC+H0AYOZXAXwJwI8B/BmAT7Ya2L2gTU54N3YOGad5cyvCNMpuZ9+1bYPU7dyXzOWa5oRX5e8HaD7dztoB0rSqSTiKxSHz+SZb9AaZz5sTox4L2oWFHB0xdXpdmVq7/dBgIFfTcmLcLNsdwFqKTvGSLfMkM08xc4yZZ5j5c8z8m8x8gplPMvNH7at4q/1nmfkgM68w85/2QiQPD5rGzn0wuGMw29jYuYagTJxd992G0XVt2yB1O/dN2QFQqkFYgcg0zrYfBmg+DRLm5+4FKUG5vbYUj4Fy9QbfvYYGMmYsOxfOghrPDOXM9RUdmnHTwABEEKGS4UHQwED/h2X8RoVlFAqFon36vvwAr602NXYW6TTE6pG6Ili8tgqxeiSYn05EwLkTZn3vIx4dlsIwzU0kIE4dhVaYqcslFum0a761W3ExOnPM3wnVfB7yyDLkylIllNC0yJmQpqlzLT73sVaYgbY4793Y2TpWa/HbeFyuLEGbKzQ3dj5/0nS0qglnyaHBqklqPxGnjkKbn8XGx883bMNrqxWD7FrkoYPB5J6fOwG5soTdj9zv/n/reHQzGg+Kvrhyp0TCrJndxNjZzfjZ3i4IA2J7f1wsgqT0lF0QllE2xeJ7Oeo1+eJeTbF9N3YmqqS+2qbMrfrLVafffWwZX1Mq5dlXMwzjcdI0sMEQmXRDnZRImLXma02wrc8iiIwZ+9gUqWRD7+TKeV0u1aXjBmHiXdFgGYk3ShO1QzZ+fq79bZANb53jNtBUDfQ+G2Q79xd1o+yGg7KQzfvRGsi4XPb/C5O5rh95d7dpX/HurnlVas95GLr/X56W8TVves8DrzLFJhGITrsvm30BNdTAbA6YRL4PmvZx1WhgB5qPB0GYeDs1NDvXw7hwc9IXYZleEDWD7ApCRsZ0WJxcaVjBTmQyECdXXE2fg6SZRgCQ42OQK4vgAD/vRgbZzdAW5qAtzIViNN4J/WLiDVgm5BHOYgmKvgjL9AJlkK1QdIbM58HTEzBe+WnYUloijywDN25Dv349bCmB0PcTqvZqvnbqXtcVOpIS5Xf9rd8hlxagLcx1NIGiLc57zo/vBtsg280kuZHBtBN7u07fp1ecBtm1eOmrilm5zwbktkG28ejptlN1nWblfuvUpiY7LshF+UHwhV/4oKoeba4AOTri3XPA3m5hzpwIvnknkIFdLi2Y6286zMcPgr6IudNO0TS31r2b31bMqu3HtqGtj9BuCTCMzkx6gyxAtFsEufUlGy11VEyL/dbrNMiuxcO+7c/fdwNyyyBbbrS/Hza4Yqbut04ulUyD8WIH+9ktBlIkDoD52ZbL7Z+rle0CMsjeKYJLpeD6pQN+acIyCoVCsd/o67CMNjWJ27+1VnncqUG238hDB6tMhz3rFBL8oL85zlW7O3UUMpeD8chpTwbZle0yGdCZY74X4rKxC5MZj5yuMsimM8eaTqjahcMarYnwA9I08NoqNn/9Ac9rHOTSghmKCWitw83fXuvYIDsoo3HSNKz/7prZN4cOejbIFqtHIPP5wMYGmcth/RNryiDbCy2v3Imgzc/CuHq949KjyiC7NZRIQBQOQL/4ethS+gshIVcWof/0YiBlkDuiyzRGOTYGkgLlK1d7KMqFANItu0UuL4LffhfGzk7YUvo/zx3M4Lv3OosXWgRhkN3vcKkM3Pbf0GTfwQZw43a0B6UutfHWFtht/qPXRLkPbW7dMU1XIk4krtyJ6DqATQA3wtbikVH0j1agv/T2k1ZA6fWTftIKhKN3jpldy59GYnAHACJ6sdHPi6jRT1qB/tLbT1oBpddP+kkrED29kZ9QVSgUCkX7qMFdoVAo9iFRGtyfC1tAG/STVqC/9PaTVkDp9ZN+0gpETG9kYu4KhUKh6B1RunJXKBQKRY9Qg7tCoVDsQ0If3Inog0R0gYguEtEzYesBACL6PBFdI6JXHM8NE9E3iOg16zbv+N+zlv4LRPSBgLUWiOh/EdFPiOhVIvo7UdVLREki+g4RvWxp/YdR1VqjWxLRD4joa1HXS0RvENGPiOglInoxynqJaIiIvkxEP7WO37UIa12x+tT+u0tEn4qqXgAAM4f2B0AC+DmARQBxAC8DOBqmJkvXewDcB+AVx3N/AOAZ6/4zAP6Jdf+opTsBYMF6PzJArVMA7rPuZwH8zNIUOb0ACMCAdT8G4P8BOB9FrTW6/y6APwLwtSgfC5aGNwCM1jwXSb0AvgDgE9b9OIChqGqt0S0BXAEwF2W9gXdMTSetAfi64/GzAJ4NU5NDyzyqB/cLAKas+1MALrhpBvB1AGsh6v4KgPdHXS+ANIDvA3ggyloBzAB4AcDjjsE9ynrdBvfI6QWQA/A6rKSOKGt10f6rAP4i6nrDDstMA3jb8fiS9VwUmWDmywBg3dpuApF5D0Q0D+A0zCviSOq1QhwvAbgG4BvMHFmtFv8CwN8D4CwUHmW9DOB/ENH3iOhp67ko6l0EcB3Av7VCXv+GiDIR1VrLEwCet+5HVm/Yg7ubnU6/5WZG4j0Q0QCA/wLgU8zcrPpXqHqZWWfmUzCviM8RUTOz01C1EtFfA3CNmb/ndROX54I+Fh5i5vsAfAjAJ4noPU3ahqlXgxn6/ENmPg2ztlSzObco9C2IKA7gowD+c6umLs8Fqjfswf0SAKev2wyAqJqcXiWiKQCwbq9Zz4f+HogoBnNg/4/M/F+tpyOrFwCY+TaAbwH4IKKr9SEAHyWiNwB8EcDjRPQfEF29YOZ3rdtrAP4YwDlEU+8lAJesX24A8GWYg30UtTr5EIDvM7Nd+ziyesMe3L8LYJmIFqxvxCcAfDVkTY34KoCnrPtPwYxt288/QUQJIloAsAzgO0GJIiIC8DkAP2Hmfx5lvUQ0RkRD1v0UgF8B8NMoagUAZn6WmWeYeR7msfk/mfk3oqqXiDJElLXvw4wNvxJFvcx8BcDbRLRiPfU+AD+OotYansReSMbWFU29YUxI1ExOfBhmhsfPAXw6bD2WpucBXAZQgvkN/DsARmBOrL1m3Q472n/a0n8BwIcC1vowzJ97PwTwkvX34SjqBXASwA8sra8A+PvW85HT6qL9MexNqEZSL8w49svW36v2+RRhvacAvGgdD38CIB9Vrdb+0wDWAQw6nousXlV+QKFQKPYhYYdlFAqFQuEDanBXKBSKfYga3BUKhWIfogZ3hUKh2IeowV2hUCj2IWpwVygUin2IGtwVCoViH/L/AZDlXJbMWnnzAAAAAElFTkSuQmCC\n", + "text/plain": [ 
+ "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "song_ix = 0\n", + "song = y_train[song_ix:song_ix+1]\n", + "print(song.shape)\n", + "flat = flattenWindow(song)\n", + "print(flat.shape)\n", + "\n", + "# plt.figure()\n", + "# matplotlib.rcParams['figure.figsize'] = [20, 20]\n", + "# for i in range(16):\n", + "# plt.subplot(2, 8, i+1)\n", + "# plt.imshow(song[0][i])\n", + "# plt.show()\n", + "\n", + "plt.figure()\n", + "plt.imshow(flat)\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Tensor(\"Const_1:0\", shape=(1, 16, 96, 96), dtype=float32)\n", + "Tensor(\"model_1/encoder/cond/Merge:0\", shape=(?, 40), dtype=float32)\n", + "None\n" + ] + } + ], + "source": [ + "song = tf.convert_to_tensor(song)\n", + "print(song)\n", + "\n", + "with tf.GradientTape() as tape:\n", + " tape.watch(song)\n", + " \n", + " latent_x = encoder(song)\n", + " \n", + " print(latent_x)\n", + "\n", + " if use_pca:\n", + " current_params = np.dot(\n", + " latent_x - latent_means, latent_pca_vectors.T) / latent_pca_values\n", + " else:\n", + " current_params = (\n", + " latent_x - latent_means) / latent_stds\n", + "\n", + "# print(current_params)\n", + " \n", + "gradient = tape.gradient(latent_x, song)\n", + "print(gradient)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "composer", + "language": "python", + "name": "composer" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.8" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/README.md b/README.md index ed6355be..7e2448ae 100644 --- a/README.md +++ b/README.md @@ -3,6 +3,9 @@ Generate tunable music using neural networks. Repository to ["Generating Songs With Neural Networks (Neural Composer)"](https://youtu.be/UWxfnNXlVy8). +## What is this fork +* I'd like to mess around with the models and create an accompanying web app so that people can play with this online. + ## How to install * Requires Python 3.5.6 * CUDA 10.0 AND CUDNN 7.6.5 https://developer.nvidia.com/rdp/cudnn-archive @@ -75,12 +78,7 @@ Exclusive to this fork Could be useful if you are listening while doing something else and want to look back on the history. 'B' - Blends smoothly through a series of preset songs by taking a linear combination of the latent vectors. Sounds awesome. asks in command line what song files (created by the S command above) to blend - Example result in blended song.mp3 which was recorded with audacity - https://drive.google.com/file/d/17MNTsHMXghApAa_GcUMB-pY0PTWWRnIF/view?usp=sharing -I would have liked the text entry to be in a box to the right of the notes rather than in command line, but I couldn't figure out how to do that -If you can help with that please make a pull request and email me at -michael_einhorn@yahoo.com ``` @@ -91,4 +89,8 @@ link below and extract it into the results/history folder of your project. The z named e and some number, indicating the number of epochs the model was trained and some model.h5. 
Pay attention when extracting the model not to override one of your own trained, most beloved models! -* Bach dataset: https://drive.google.com/open?id=1P_hOF0v55m1Snzvri5OBVP5yy099Lzgx +40 params, trained with 0.0002 noise +https://drive.google.com/open?id=1J_AHhXavLf_bQgmd1j8MpuMMSqhmKGpW + +Data set +https://www.ninsheetmusic.org/ diff --git a/TextInput LICENSE b/TextInput LICENSE new file mode 100644 index 00000000..1f8fd7b8 --- /dev/null +++ b/TextInput LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2017 Silas Gyger + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/TextInput README.md b/TextInput README.md new file mode 100644 index 00000000..c3fe232a --- /dev/null +++ b/TextInput README.md @@ -0,0 +1,61 @@ +# Pygame Text Input Module + +This small module can be used to write text in pygame. It includes a blinking cursor that can be moved using the left and right as well as the home and the end button. Any key can be pressed for an extended period of time to make that key re-enter itself many times a second. + +Here's an example of the module using the [Ubuntu font](http://font.ubuntu.com/): +![Example of module in use](http://i.imgur.com/enuCPEY.gif) + +# Usage + +The module is very easy to use. Simply create an instance of `InputText` in your code and then feed its `update`-method with pygame-events every frame. The surface with the text and the cursor can then be gotten using `get_surface()`. + +Here's a small example that displays a white window with the `InputText`-surface on it: + + +``` +#!/usr/bin/python3 +import pygame_textinput +import pygame +pygame.init() + +# Create TextInput-object +textinput = pygame_textinput.TextInput() + +screen = pygame.display.set_mode((1000, 200)) +clock = pygame.time.Clock() + +while True: + screen.fill((225, 225, 225)) + + events = pygame.event.get() + for event in events: + if event.type == pygame.QUIT: + exit() + + # Feed it with events every frame + textinput.update(events) + # Blit its surface onto the screen + screen.blit(textinput.get_surface(), (10, 10)) + + pygame.display.update() + clock.tick(30) +``` +If you want to catch the user input after the user hits `Return`, simply evaluate the return value of the `update()`-method - it is always `False` except for when the user hits `Return`, then it's `True`. To get the inputted text, use `get_text()`. 
Example: +``` +if textinput.update(events): + print(textinput.get_text()) +``` + +## Arguments: +Arguments for the initialisation of the `TextInput`-object (all of them are optional) + +| argument | description | +|-------------------------|--------------------------------------------------------------------------------------------------------------------------------| +| initial_string | Initial text to be displayed | +| font_family | Name or path of the font that should be used. If none or one that doesn't exist is specified, the pygame default font is used. | +| font_size | Size of the font in pixels. Default is 35. | +| antialias | (bool) Declare if antialiasing should be used on text or not. True uses more CPU cycles. | +| text_color | The color of the text. | +| cursor_color | The color of the cursor. | +| repeat_keys_initial_ms | Time in ms until the key presses get repeated when a key is not released | +| repeat_keys_interval_ms | Time in ms between key presses if key is not released | diff --git a/composer.py b/composer.py index 0db1e06d..bddf2bfe 100644 --- a/composer.py +++ b/composer.py @@ -6,24 +6,31 @@ """ import argparse +import itertools import math import wave +import typing +from wave import Wave_write import numpy as np import pyaudio import pygame +from numpy import ndarray + import params +import os + import midi_utils import keras -from keras.models import Model, load_model +from keras.models import Model,load_model from keras import backend as K # User constants dir_name = 'results/history/' sub_dir_name = 'e2000/' -sample_rate = 48000 +sample_rate: int = 48000 note_dt = 2000 # num samples note_duration = 20000 # num samples note_decay = 5.0 / sample_rate @@ -38,23 +45,23 @@ autosavenow = False blend = False blendfactor = np.float32(1.0) -#0 first sond 1 first to second 2 second song 3 second to first +# 0 first sond 1 first to second 2 second song 3 second to first blendstate = 0 # colors -background_color = (210, 210, 210) -edge_color = (60, 60, 60) -slider_colors = [(90, 20, 20), (90, 90, 20), (20, 90, 20), - (20, 90, 90), (20, 20, 90), (90, 20, 90)] +background_color = (210,210,210) +edge_color = (60,60,60) +slider_colors = [(90,20,20),(90,90,20),(20,90,20), + (20,90,90),(20,20,90),(90,20,90)] note_w = 96 note_h = 96 note_pad = 2 -notes_rows = int(num_measures / 8) +notes_rows = num_measures // 8 notes_cols = 8 -slider_num = min(40, num_params) +slider_num = min(40,num_params) slider_h = 200 slider_pad = 5 tick_pad = 4 @@ -63,8 +70,8 @@ control_h = 30 control_pad = 5 control_num = 5 -control_colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (0, 255, 255), (255, 255, 0)] -control_inits = [0.75, 0.5, 0.5, 0.5, 0.5] +control_colors = [(255,0,0),(0,255,0),(0,0,255),(0,255,255),(255,255,0)] +control_inits = [0.75,0.5,0.5,0.5,0.5] # derived constants notes_w = notes_cols * (note_w + note_pad * 2) @@ -73,7 +80,7 @@ sliders_h = slider_h + slider_pad * 2 controls_w = control_w * control_num controls_h = control_h -window_w = max(notes_w, controls_w) +window_w = max(notes_w,controls_w) window_h = notes_h + sliders_h + controls_h slider_w = int((window_w - slider_pad * 2) / slider_num) notes_x = 0 @@ -88,7 +95,7 @@ controls_y = notes_h + sliders_h # global variables -keyframe_paths = np.array(("song 1.txt", "song 2.txt", )) +keyframe_paths = np.array(("song 1.txt","song 2.txt",)) prev_mouse_pos = None mouse_pressed = 0 cur_slider_ix = 0 @@ -97,10 +104,10 @@ balance = 0.5 instrument = 0 needs_update = True -current_params = np.zeros((num_params,), dtype=np.float32) 
+current_params = np.zeros((num_params,),dtype=np.float32) keyframe_params = np.zeros((len(keyframe_paths),num_params),dtype=np.float32) -current_notes = np.zeros((num_measures, note_h, note_w), dtype=np.uint8) -cur_controls = np.array(control_inits, dtype=np.float32) +current_notes = np.zeros((num_measures,note_h,note_w),dtype=np.uint8) +cur_controls = np.array(control_inits,dtype=np.float32) keyframe_controls = np.zeros((len(keyframe_paths),len(cur_controls)),dtype=np.float32) blend_slerp = False keyframe_magnitudes = np.zeros((len(keyframe_paths),),dtype=np.float32) @@ -117,7 +124,7 @@ audio_pause = False -def audio_callback(in_data, frame_count, time_info, status): +def audio_callback(in_data,frame_count,time_info,status): """ Audio call-back to influence playback of music with input. :param in_data: @@ -151,8 +158,8 @@ def audio_callback(in_data, frame_count, time_info, status): # check if paused if audio_pause and status is not None: - data = np.zeros((frame_count,), dtype=np.float32) - return data.tobytes(), pyaudio.paContinue + data = np.zeros((frame_count,),dtype=np.float32) + return data.tobytes(),pyaudio.paContinue # find and add any notes in this time window cur_dt = note_dt @@ -162,54 +169,59 @@ def audio_callback(in_data, frame_count, time_info, status): break note_ix = note_time % note_h notes = np.where( - current_notes[measure_ix, note_ix] >= note_threshold)[0] + current_notes[measure_ix,note_ix] >= note_threshold)[0] for note in notes: - freq = 2 * 38.89 * pow(2.0, note / 12.0) / sample_rate + freq = 2 * 38.89 * pow(2.0,note / 12.0) / sample_rate if params.encode_length: - if not note in audio_notes_lengths or audio_notes_lengths[note][1] < audio_time: - audio_notes_lengths[note] = (note_time_dt, note_time_dt + note_dt, freq, current_notes[measure_ix, note_ix, note]) + if note not in audio_notes_lengths or audio_notes_lengths[note][1] < audio_time: + audio_notes_lengths[note] = ( + note_time_dt,note_time_dt + note_dt,freq,current_notes[measure_ix,note_ix,note]) else: - audio_notes_lengths[note] = (audio_notes_lengths[note][0], note_time_dt + note_dt, freq, current_notes[measure_ix, note_ix, note]) + audio_notes_lengths[note] = ( + audio_notes_lengths[note][0],note_time_dt + note_dt,freq,current_notes[measure_ix,note_ix,note]) else: - audio_notes.append((note_time_dt, note_time_dt + note_duration, freq, current_notes[measure_ix, note_ix, note])) + audio_notes.append( + (note_time_dt,note_time_dt + note_duration,freq,current_notes[measure_ix,note_ix,note])) note_time += 1 note_time_dt += cur_dt # generate the tones - data = np.zeros((frame_count,), dtype=np.float32) - for t, e, f, v in audio_notes_lengths if params.encode_volume else audio_notes: + data = np.zeros((frame_count,),dtype=np.float32) + for t,e,f,v in audio_notes_lengths if params.encode_volume else audio_notes: if e < audio_time: continue - startTime = 0 if params.encode_volume else t; - x = np.arange(audio_time - startTime, audio_time + frame_count - startTime) - x = np.maximum(x, 0) + startTime = 0 if params.encode_volume else t + x = np.arange(audio_time - startTime,audio_time + frame_count - startTime) + x = np.maximum(x,0) if instrument == 0: - w = np.sign(1 - np.mod(x * f, 2)) # Square + w = np.sign(1 - np.mod(x * f,2)) # Square elif instrument == 1: - w = np.mod(x * f - 1, 2) - 1 # Sawtooth + w = np.mod(x * f - 1,2) - 1 # Sawtooth elif instrument == 2: - w = 2 * np.abs(np.mod(x * f - 0.5, 2) - 1) - 1 # Triangle + w = 2 * np.abs(np.mod(x * f - 0.5,2) - 1) - 1 # Triangle elif instrument == 3: w = np.sin(x 
* f * math.pi) # Sine elif instrument == 4: - w = -1 * np.sign(np.mod(2*x*f,4)-2) * np.sqrt(1-((np.mod(2*x*f,2)-1) * ((np.mod(2*x*f,2)-1)))) # Circle + w = -1 * np.sign(np.mod(2 * x * f,4) - 2) * np.sqrt( + 1 - ((np.mod(2 * x * f,2) - 1) * (np.mod(2 * x * f,2) - 1))) # Circle + # w = np.floor(w*8)/8 w[x == 0] = 0 - n = 12 * np.log (f * sample_rate / 38.89) / np.log(2); - w *= volume * np.exp(-x * note_decay) * pow(balance, (n - 60) / 12.0) / np.log(2) + n = 12 * np.log(f * sample_rate / 38.89) / np.log(2) + w *= volume * np.exp(-x * note_decay) * pow(balance,(n - 60) / 12.0) / np.log(2) if params.encode_volume: w *= v / 255 data += w - data = np.clip(data, -32000, 32000).astype(np.int16) + data = np.clip(data,-32000,32000).astype(np.int16) # remove notes that are too old audio_time += frame_count - audio_notes = [(t, e, f, v) - for t, e, f, v in audio_notes if audio_time < t + note_duration] - blendfactor = (np.cos( ((note_time / note_h)/num_measures) * math.pi )+1)/2 - #print(blendfactor) + audio_notes = [(t,e,f,v) + for t,e,f,v in audio_notes if audio_time < t + note_duration] + blendfactor = (np.cos(((note_time / note_h) / num_measures) * math.pi) + 1) / 2 + print(blendfactor) # reset if loop occurs if note_time / note_h >= num_measures: audio_time = 0 @@ -217,15 +229,15 @@ def audio_callback(in_data, frame_count, time_info, status): note_time_dt = 0 audio_notes = [] audio_notes_lengths = {} - blendstate = (blendstate+1)%(2*len(keyframe_paths)) - #if blendstate == 0: - #audio_pause = True + blendstate = (blendstate + 1) % (2 * len(keyframe_paths)) + # if blendstate == 0: + # audio_pause = True blendfactor = 1 if autosave and not autosavenow: autosavenow = True # return the sound clip - return data.tobytes(), pyaudio.paContinue + return data.tobytes(),pyaudio.paContinue def update_mouse_click(mouse_pos): @@ -251,7 +263,7 @@ def update_mouse_click(mouse_pos): mouse_pressed = 2 -def apply_controls(): +def apply_controls() -> object: """ Change parameters based on controls. 
:return: @@ -265,11 +277,11 @@ def apply_controls(): global balance note_threshold = (1.0 - cur_controls[0]) * 200 + 10 - note_dt = (1.0 - cur_controls[1]) * 1800 + 200 + note_dt = ((1.0 - cur_controls[1]) * 1800 + 200) * params.timeScaleF volume = cur_controls[2] * 6000 - balance = pow(2, cur_controls[3] * 4 - 2); + balance = pow(2,cur_controls[3] * 4 - 2) - note_duration = 10000 / ((1-cur_controls[4]) + 0.001) + note_duration = 10000 / ((1 - cur_controls[4]) + 0.001) note_decay = 10 * (1 - cur_controls[4]) / sample_rate @@ -280,9 +292,7 @@ def update_mouse_move(mouse_pos): :return: """ global needs_update - t = 1 - if int(cur_control_ix) == 0: - t = 210.0 / 200 + t = 210.0 / 200 if int(cur_control_ix) == 0 else 1 if mouse_pressed == 1: # change sliders y = (mouse_pos[1] - sliders_y) @@ -314,8 +324,8 @@ def draw_controls(screen): h = control_h - control_pad * 2 col = control_colors[i] - pygame.draw.rect(screen, col, (x, y, int(w * t * cur_controls[i]), h)) - pygame.draw.rect(screen, (0, 0, 0), (x, y, w, h), 1) + pygame.draw.rect(screen,col,(x,y,int(w * t * cur_controls[i]),h)) + pygame.draw.rect(screen,(0,0,0),(x,y,w,h),1) t = 1 @@ -334,20 +344,20 @@ def draw_sliders(screen): cx = x + slider_w / 2 cy_1 = y cy_2 = y + slider_h - pygame.draw.line(screen, slider_color, (cx, cy_1), (cx, cy_2)) + pygame.draw.line(screen,slider_color,(cx,cy_1),(cx,cy_2)) cx_1 = x + tick_pad cx_2 = x + slider_w - tick_pad for j in range(int(num_sigmas * 2 + 1)): ly = y + slider_h / 2.0 + \ - (j - num_sigmas) * slider_h / (num_sigmas * 2.0) + (j - num_sigmas) * slider_h / (num_sigmas * 2.0) ly = int(ly) - col = (0, 0, 0) if j - num_sigmas == 0 else slider_color - pygame.draw.line(screen, col, (cx_1, ly), (cx_2, ly)) + col = (0,0,0) if j - num_sigmas == 0 else slider_color + pygame.draw.line(screen,col,(cx_1,ly),(cx_2,ly)) py = y + int((current_params[i] / (num_sigmas * 2) + 0.5) * slider_h) - pygame.draw.circle(screen, slider_color, (int( - cx), int(py)), int((slider_w - tick_pad) / 2)) + pygame.draw.circle(screen,slider_color,(int( + cx),int(py)),int((slider_w - tick_pad) / 2)) def get_pianoroll_from_notes(notes): @@ -356,27 +366,26 @@ def get_pianoroll_from_notes(notes): :param notes: :return: """ - - output = np.full((3, int(notes_h), int(notes_w)), 64, dtype=np.uint8) - for i in range(notes_rows): - for j in range(notes_cols): - x = note_pad + j * (note_w + note_pad * 2) - y = note_pad + i * (note_h + note_pad * 2) - ix = i * notes_cols + j + output = np.full((3,int(notes_h),int(notes_w)),64,dtype=np.uint8) + + for i,j in itertools.product(range(notes_rows),range(notes_cols)): + x = note_pad + j * (note_w + note_pad * 2) + y = note_pad + i * (note_h + note_pad * 2) + ix = i * notes_cols + j - measure = np.rot90(notes[ix]) + measure = np.rot90(notes[ix]) - played_only = np.where(measure >= note_threshold, 255, 0) - output[0, y:y + note_h, x:x + - note_w] = np.minimum(measure * (255.0 / note_threshold), 255.0) - output[1, y:y + note_h, x:x + note_w] = played_only - output[2, y:y + note_h, x:x + note_w] = played_only + played_only = np.where(measure >= note_threshold,255,0) + output[0,y:y + note_h,x:x + + note_w] = np.minimum(measure * (255.0 / note_threshold),255.0) + output[1,y:y + note_h,x:x + note_w] = played_only + output[2,y:y + note_h,x:x + note_w] = played_only - return np.transpose(output, (2, 1, 0)) + return np.transpose(output,(2,1,0)) -def draw_notes(screen, notes_surface): +def draw_notes(screen,notes_surface): """ Draw pianoroll notes to screen. 
:param screen: @@ -385,7 +394,7 @@ def draw_notes(screen, notes_surface): """ pygame.surfarray.blit_array( - notes_surface, get_pianoroll_from_notes(current_notes)) + notes_surface,get_pianoroll_from_notes(current_notes)) measure_ix = int(note_time / note_h) note_ix = note_time % note_h @@ -394,7 +403,7 @@ def draw_notes(screen, notes_surface): y = notes_y + note_pad + \ int(measure_ix / notes_cols) * (note_h + note_pad * 2) - pygame.draw.rect(screen, (255, 255, 0), (x, y, 4, note_h), 0) + pygame.draw.rect(screen,(255,255,0),(x,y,4,note_h),0) def play(): @@ -425,11 +434,25 @@ def play(): K.set_image_data_format('channels_first') print("Loading encoder...") - model = load_model(dir_name + 'model.h5') - encoder = Model(inputs=model.input, - outputs=model.get_layer('encoder').output) - decoder = K.function([model.get_layer('decoder').input, K.learning_phase()], - [model.layers[-1].output]) + # priority name.h5 in sub, model.h5 in sub, name.h5 in dir, model.h5 in dir + import torch + from models import AutoencoderModel + + model_path = '' + if os.path.isfile(dir_name + sub_dir_name + '/' + sub_dir_name + '.pth'): + model_path = dir_name + sub_dir_name + '/' + sub_dir_name + '.pth' + elif os.path.isfile(dir_name + sub_dir_name + '/' + 'model.pth'): + model_path = dir_name + sub_dir_name + '/' + 'model.pth' + elif os.path.isfile(dir_name + sub_dir_name + '.pth'): + model_path = dir_name + sub_dir_name + '.pth' + else: + model_path = dir_name + 'model.pth' + + model = torch.load(model_path) + model.eval() + # Using PyTorch model directly for both encoding and decoding + encoder = model.encoder + decoder = model.decoder print("Loading gaussian/pca statistics...") latent_means = np.load(dir_name + sub_dir_name + '/latent_means.npy') @@ -442,8 +465,8 @@ def play(): # open a window pygame.init() pygame.font.init() - screen = pygame.display.set_mode((int(window_w), int(window_h))) - notes_surface = screen.subsurface((notes_x, notes_y, notes_w, notes_h)) + screen = pygame.display.set_mode((int(window_w),int(window_h))) + notes_surface = screen.subsurface((notes_x,notes_y,notes_w,notes_h)) pygame.display.set_caption('Neural Composer') # start the audio stream @@ -466,11 +489,11 @@ def play(): if autosavenow: # generate random song current_params = np.clip(np.random.normal( - 0.0, 1.0, (num_params,)), -num_sigmas, num_sigmas) + 0.0,1.0,(num_params,)),-num_sigmas,num_sigmas) needs_update = True audio_reset = True # save slider values - with open("results/history/autosave" + str(autosavenum)+".txt", "w") as text_file: + with open("results/history/autosave" + str(autosavenum) + ".txt","w") as text_file: text_file.write(sub_dir_name + "\n") text_file.write(str(instrument) + "\n") for iter in cur_controls: @@ -482,12 +505,13 @@ def play(): audio_reset = True save_audio = b'' while True: - save_audio += audio_callback(None, 1024, None, None)[0] + save_audio += audio_callback(None,1024,None,None)[0] if audio_time == 0: break - wave_output = wave.open('results/history/autosave' + str(autosavenum)+'.wav', 'w') + wave_output: typing.Union[Wave_write,typing.Any] = wave.open( + 'results/history/autosave' + str(autosavenum) + '.wav','w') wave_output.setparams( - (1, 2, sample_rate, 0, 'NONE', 'not compressed')) + 1,2,sample_rate,0,'NONE','not compressed') wave_output.writeframes(save_audio) wave_output.close() audio_pause = False @@ -498,19 +522,25 @@ def play(): blendcycle += 1 if blend and blendcycle > 10: blendcycle = 0 - if blendstate%2 == 0: + if blendstate % 2 == 0: needs_update = True - current_params = 
np.copy(keyframe_params[int(blendstate/2)]) - cur_controls = np.copy(keyframe_controls[int(blendstate/2)]) + current_params = np.copy(keyframe_params[int(blendstate / 2)]) + cur_controls: ndarray = np.copy(keyframe_controls[int(blendstate / 2)]) apply_controls() - elif blendstate%2 == 1: + elif blendstate % 2 == 1: for x in range(0,len(current_params)): - current_params[x] = (blendfactor * keyframe_params[int(blendstate/2),x]) + ((1-blendfactor)*keyframe_params[((int(blendstate/2))+1)%len(keyframe_paths),x]) + current_params[x] = (blendfactor * keyframe_params[int(blendstate / 2),x]) + ( + (1 - blendfactor) * keyframe_params[ + ((int(blendstate / 2)) + 1) % len(keyframe_paths),x]) if blend_slerp: - magnitude = (blendfactor * keyframe_magnitudes[int(blendstate/2)]) + ((1-blendfactor)*keyframe_magnitudes[((int(blendstate/2))+1)%len(keyframe_paths)]) - current_params = current_params * ((sum(current_params*current_params)**-0.5) * magnitude) + magnitude = (blendfactor * keyframe_magnitudes[int(blendstate / 2)]) + ( + (1 - blendfactor) * keyframe_magnitudes[ + ((int(blendstate / 2)) + 1) % len(keyframe_paths)]) + current_params = current_params * ((sum(current_params * current_params) ** -0.5) * magnitude) for x in range(0,len(cur_controls)): - cur_controls[x] = (blendfactor * keyframe_controls[int(blendstate/2),x]) + ((1-blendfactor)*keyframe_controls[((int(blendstate/2))+1)%len(keyframe_paths),x]) + cur_controls[x] = (blendfactor * keyframe_controls[int(blendstate / 2),x]) + ( + (1 - blendfactor) * keyframe_controls[ + ((int(blendstate / 2)) + 1) % len(keyframe_paths),x]) apply_controls() needs_update = True for event in pygame.event.get(): @@ -524,10 +554,10 @@ def play(): update_mouse_click(prev_mouse_pos) update_mouse_move(prev_mouse_pos) elif pygame.mouse.get_pressed()[2]: - current_params = np.zeros((num_params,), dtype=np.float32) + current_params = np.zeros((num_params,),dtype=np.float32) needs_update = True - elif event.type == pygame.MOUSEBUTTONUP: # MOUSE BUTTON UP + elif event.type == pygame.MOUSEBUTTONUP: # MOUSE BUTTON UP mouse_pressed = 0 prev_mouse_pos = None @@ -538,17 +568,17 @@ def play(): if event.key == pygame.K_r: # KEYDOWN R # generate random song current_params = np.clip(np.random.normal( - 0.0, 1.0, (num_params,)), -num_sigmas, num_sigmas) + 0.0,1.0,(num_params,)),-num_sigmas,num_sigmas) needs_update = True audio_reset = True if event.key == pygame.K_t: # KEYDOWN T - for x in range(int(num_params/3)+1, num_params): - current_params[x] = np.clip(np.random.normal(0.0,1.0), -num_sigmas, num_sigmas) + for x in range(int(num_params / 3) + 1,num_params): + current_params[x] = np.clip(np.random.normal(0.0,1.0),-num_sigmas,num_sigmas) needs_update = True if event.key == pygame.K_x: # KEYDOWN X # generate random song current_params += np.clip(np.random.normal( - 0.0, 0.3, (num_params,)), -num_sigmas, num_sigmas) + 0.0,0.3,(num_params,)),-num_sigmas,num_sigmas) needs_update = True if event.key == pygame.K_a: # KEYDOWN A autosave = not autosave @@ -565,11 +595,11 @@ def play(): keyframe_controls = np.zeros((blendnum,len(cur_controls)),dtype=np.float32) keyframe_params = np.zeros((blendnum,num_params),dtype=np.float32) for y in range(blendnum): - fileName = input("The file name of the next song to be blended ") - if "." not in fileName: - fileName = fileName + ".txt" - keyframe_paths.append((fileName)) - fo = open("results/history/" + fileName, "r") + file_name = input("The file name of the next song to be blended ") + if "." 
not in file_name: + file_name = file_name + ".txt" + keyframe_paths.append(file_name) + fo: typing.TextIO = open("results/history/" + file_name,"r") if not sub_dir_name == fo.readline()[:-1]: running = False print("incompatable with current model") @@ -579,10 +609,10 @@ def play(): keyframe_controls[y,x] = float(fo.readline()) for x in range(len(current_params)): keyframe_params[y,x] = float(fo.readline()) - #keyframe_magnitudes[y] = sum(keyframe_params[y]*keyframe_params[y])**0.5 + # keyframe_magnitudes[y] = sum(keyframe_params[y]*keyframe_params[y])**0.5 if event.key == pygame.K_e: # KEYDOWN E # generate random song with larger variance - current_params = np.clip(np.random.normal(0.0, 2.0, (num_params,)), -num_sigmas, num_sigmas) + current_params = np.clip(np.random.normal(0.0,2.0,(num_params,)),-num_sigmas,num_sigmas) needs_update = True audio_reset = True if event.key == pygame.K_PERIOD: @@ -603,10 +633,10 @@ def play(): if event.key == pygame.K_s: # KEYDOWN S # save slider values audio_pause = True - fileName = input("File Name to save into ") - if "." not in fileName: - fileName = fileName + ".txt" - with open("results/history/" + fileName, "w") as text_file: + file_name = input("File Name to save into ") + if "." not in file_name: + file_name = file_name + ".txt" + with open("results/history/" + file_name,"w") as text_file: if blend: text_file.write(sub_dir_name + "\n") text_file.write("blended song" + "\n") @@ -625,6 +655,7 @@ def play(): needs_update = True audio_reset = True fileName = input("File Name to read ") + if "." not in fileName: fileName = fileName + ".txt" fo = open("results/history/" + fileName, "r") @@ -634,6 +665,7 @@ def play(): print("incompatable with current model") break tempDir = fo.readline() + if tempDir.startswith("blended song"): blend = True blendnum = int(fo.readline()) @@ -659,21 +691,26 @@ def play(): cur_controls[x] = float(fo.readline()) for x in range(len(current_params)): current_params[x] = float(fo.readline()) + apply_controls() + if event.key == pygame.K_o: # KEYDOWN O if not songs_loaded: print("Loading songs...") try: - y_samples = np.load('data/interim/samples.npy') + y_samples: typing.Union[ + typing.Union[np.ndarray,typing.Iterable,int,float,tuple,dict],typing.Any] = np.load( + 'data/interim/samples.npy') y_lengths = np.load('data/interim/lengths.npy') songs_loaded = True except Exception as e: - print("This functionality is to check if the model training went well by reproducing an original song. " - "The composer could not load samples and lengths from model training. " - "If you have the midi files, the model was trained with, process them by using" - " the preprocess_songs.py to find the requested files in data/interim " - "(Load exception: {0}".format(e)) + print( + "This functionality is to check if the model training went well by reproducing an original song. " + "The composer could not load samples and lengths from model training. 
" + "If you have the midi files, the model was trained with, process them by using" + " the preprocess_songs.py to find the requested files in data/interim " + "(Load exception: {0}".format(e)) if songs_loaded: # check how well the autoencoder can reconstruct a random song @@ -681,21 +718,23 @@ def play(): if is_ae: example_song = y_samples[cur_len:cur_len + num_measures] current_notes = example_song * 255 - latent_x = encoder.predict(np.expand_dims( - example_song, 0), batch_size=1)[0] + example_song_tensor = torch.tensor(example_song, dtype=torch.float).unsqueeze(0) + with torch.no_grad(): + latent_x = encoder(example_song_tensor).numpy() cur_len += y_lengths[random_song_ix] random_song_ix += 1 else: - random_song_ix = np.array( - [random_song_ix], dtype=np.int64) - latent_x = encoder.predict( - random_song_ix, batch_size=1)[0] + random_song_ix: ndarray = np.array( + [random_song_ix],dtype=np.int64) + random_song_ix_tensor = torch.tensor([random_song_ix], dtype=torch.int64) + with torch.no_grad(): + latent_x = encoder(random_song_ix_tensor).numpy() random_song_ix = ( random_song_ix + 1) % model.layers[0].input_dim if use_pca: current_params = np.dot( - latent_x - latent_means, latent_pca_vectors.T) / latent_pca_values + latent_x - latent_means,latent_pca_vectors.T) / latent_pca_values else: current_params = ( latent_x - latent_means) / latent_stds @@ -707,28 +746,28 @@ def play(): # save song as midi audio_pause = True audio_reset = True - fileName = input("File Name to save into ") - if "." not in fileName: - fileName = fileName + ".mid" + file_name = input("File Name to save into ") + if "." not in file_name: + file_name = file_name + ".mid" midi_utils.samples_to_midi( - current_notes, 'results/history/' + fileName, note_threshold) + current_notes,'results/history/' + file_name,note_threshold) audio_pause = False if event.key == pygame.K_w: # KEYDOWN W # save song as wave audio_pause = True audio_reset = True - fileName = input("File Name to save into ") - if "." not in fileName: - fileName = fileName + ".wav" + file_name = input("File Name to save into ") + if "." 
not in file_name: + file_name = file_name + ".wav" save_audio = b'' while True: - save_audio += audio_callback(None, 1024, None, None)[0] + save_audio += audio_callback(None,1024,None,None)[0] if audio_time == 0: break - wave_output = wave.open('results/history/' + fileName + '.wav', 'w') + wave_output = wave.open('results/history/' + file_name + '.wav','w') wave_output.setparams( - (1, 2, sample_rate, 0, 'NONE', 'not compressed')) + 1,2,sample_rate,0,'NONE','not compressed') wave_output.writeframes(save_audio) wave_output.close() audio_pause = False @@ -771,32 +810,36 @@ def play(): if event.key == pygame.K_c: # KEYDOWN C # y = np.expand_dims( - np.where(current_notes > note_threshold, 1, 0), 0) - latent_x = encoder.predict(y)[0] + np.where(current_notes > note_threshold,1,0),0) + y_tensor = torch.tensor(y, dtype=torch.float) + with torch.no_grad(): + latent_x = encoder(y_tensor).numpy() if use_pca: current_params = np.dot( - latent_x - latent_means, latent_pca_vectors.T) / latent_pca_values + latent_x - latent_means,latent_pca_vectors.T) / latent_pca_values else: current_params = ( - latent_x - latent_means) / latent_stds + latent_x - latent_means) / latent_stds needs_update = True # check if params were changed so that a new song should be generated if needs_update: if use_pca: latent_x = latent_means + \ - np.dot(current_params * latent_pca_values, - latent_pca_vectors) + np.dot(current_params * latent_pca_values, + latent_pca_vectors) else: latent_x = latent_means + latent_stds * current_params - latent_x = np.expand_dims(latent_x, axis=0) - y = decoder([latent_x, 0])[0][0] + latent_x = np.expand_dims(latent_x,axis=0) + latent_x_tensor = torch.tensor(latent_x, dtype=torch.float) + with torch.no_grad(): + y = decoder(latent_x_tensor).detach().cpu().numpy() current_notes = (y * (255)).astype(np.uint8) needs_update = False # draw GUI to the screen screen.fill(background_color) - draw_notes(screen, notes_surface) + draw_notes(screen,notes_surface) draw_sliders(screen) draw_controls(screen) @@ -810,46 +853,82 @@ def play(): audio.terminate() +def loadsongfile(file_name): + global mouse_pressed + global current_notes + global audio_pause + global needs_update + global current_params + global prev_mouse_pos + global audio_reset + global instrument + global songs_loaded + global autosavenow + global autosavenum + global autosave + global blend + global blendstate + global blendfactor + global keyframe_params + global keyframe_controls + global keyframe_paths + global cur_controls + global keyframe_magnitudes + global blend_slerp + + if "." 
not in file_name: + file_name = file_name + ".txt" + fo = open("results/history/" + file_name,"r") + print(fo.name) + if not sub_dir_name == fo.readline()[:-1]: + running = False + print("incompatable with current model") + return + tempDir = fo.readline() + if tempDir.startswith("blended song"): + blend = True + blendnum = int(fo.readline()) + keyframe_paths = [] + keyframe_controls = np.zeros((blendnum,len(cur_controls)),dtype=np.float32) + keyframe_params = np.zeros((blendnum,num_params),dtype=np.float32) + for y in range(blendnum): + fileName2 = fo.readline()[:-1] + keyframe_paths.append(file_name) + fo2 = open("results/history/" + fileName2,"r") + if not sub_dir_name == fo2.readline()[:-1]: + running = False + print("incompatable with current model") + return + instrument = int(fo2.readline()) + for x in range(len(cur_controls)): + keyframe_controls[y,x] = float(fo2.readline()) + for x in range(len(current_params)): + keyframe_params[y,x] = float(fo2.readline()) + else: + print(tempDir) + instrument = int(tempDir) + for x in range(len(cur_controls)): + cur_controls[x] = float(fo.readline()) + for x in range(len(current_params)): + current_params[x] = float(fo.readline()) + + if __name__ == "__main__": # configure parser and parse arguments parser = argparse.ArgumentParser( description='Neural Composer: Play and edit music of a trained model.') - parser.add_argument('--model_path', type=str, - help='The folder the model is stored in (e.g. a folder named e and a number located in results/history/).', required=True) + parser.add_argument('--model_path',type=str, + help='The folder the model is stored in (e.g. a folder named e and a number located in results/history/).', + required=True) args = parser.parse_args() if args.model_path.endswith(".txt"): - fo = open("results/history/" + args.model_path, "r") - print (fo.name) + fo = open("results/history/" + args.model_path,"r") + print(fo.name) sub_dir_name = fo.readline()[:-1] - tempDir = fo.readline() - if tempDir.startswith("blended song"): - blend = True - blendnum = int(fo.readline()) - keyframe_paths = [] - keyframe_controls = np.zeros((blendnum,len(cur_controls)),dtype=np.float32) - keyframe_params = np.zeros((blendnum,num_params),dtype=np.float32) - for y in range(blendnum): - fileName2 = fo.readline()[:-1] - keyframe_paths.append(fileName2) - fo2 = open("results/history/" + fileName2, "r") - if not sub_dir_name == fo2.readline()[:-1]: - running = false - print("incompatable with current model") - break - instrument = int(fo2.readline()) - for x in range(len(cur_controls)): - keyframe_controls[y,x] = float(fo2.readline()) - for x in range(len(current_params)): - keyframe_params[y,x] = float(fo2.readline()) - else: - print(sub_dir_name) - instrument = int(tempDir) - for x in range(len(cur_controls)): - cur_controls[x] = float(fo.readline()) - for x in range(len(current_params)): - current_params[x] = float(fo.readline()) - + fo.close() + loadsongfile(args.model_path) + else: sub_dir_name = args.model_path play() diff --git a/midi_utils.py b/midi_utils.py index 841b75c8..b9c51b2f 100644 --- a/midi_utils.py +++ b/midi_utils.py @@ -4,6 +4,7 @@ """ Utils to read and write midi. 
""" +from typing import Dict,Union,Any,List from mido import MidiFile, MidiTrack, Message import numpy as np @@ -18,11 +19,12 @@ def midi_to_samples(file_name, num_notes=96, samples_per_measure=96): :param samples_per_measure: :return: """ + global note has_time_sig = False mid = MidiFile(file_name) ticks_per_beat = mid.ticks_per_beat # get ticks per beat - ticks_per_measure = 4 * ticks_per_beat # get ticks per measure + ticks_per_measure = 4 * ticks_per_beat # get ticks per measure # detect the time signature of the midi for track in mid.tracks: @@ -37,9 +39,11 @@ def midi_to_samples(file_name, num_notes=96, samples_per_measure=96): ticks_per_measure = new_tpm has_time_sig = True + ticks_per_measure = ticks_per_measure * params.timeScaleF + # turn tracks into pianoroll representation maxVol = 1 - all_notes = {} + all_notes: Dict[Union[float, Any], List[Any]] = {} for track in mid.tracks: abs_time = 0 @@ -56,12 +60,12 @@ def midi_to_samples(file_name, num_notes=96, samples_per_measure=96): # we skip notes without a velocity (basically how strong a note is played to make it sound human) if msg.velocity == 0: continue - + if msg.velocity > maxVol: maxVol = msg.velocity # transform the notes into the 96 heights - note = msg.note - (128 - num_notes) / 2 + note: Union[float, Any] = msg.note - (128 - num_notes) / 2 if note < 0 or note >= num_notes: # ignore a note that is outside of that range print('Ignoring', file_name, 'note is outside 0-%d range' % (num_notes - 1)) return [] @@ -109,7 +113,7 @@ def midi_to_samples(file_name, num_notes=96, samples_per_measure=96): # get sample and find its start to encode the start of the note sample = samples[sample_ix] start_ix = int(start - sample_ix * samples_per_measure) - sample[start_ix, int(note)] = vel / maxVol if params.encode_volume else 1 + sample[start_ix, int(note)] = vel if params.encode_volume else 1 #print(vel) #print(maxVol) @@ -117,10 +121,10 @@ def midi_to_samples(file_name, num_notes=96, samples_per_measure=96): if params.encode_length: end_ix = min(end - sample_ix * samples_per_measure, samples_per_measure) while start_ix < end_ix: - sample[start_ix, int(note)] = vel / maxVol if params.encode_volume else 1 + sample[start_ix, int(note)] = vel if params.encode_volume else 1 start_ix += 1 - - + + return samples @@ -134,6 +138,7 @@ def samples_to_midi(samples, file_name, threshold=0.5, num_notes=96, samples_per :param num_notes: :param samples_per_measure: :return: + @rtype: object """ # TODO: Encode the certainties of the notes into the volume of the midi for the notes that are above threshold @@ -143,6 +148,7 @@ def samples_to_midi(samples, file_name, threshold=0.5, num_notes=96, samples_per ticks_per_beat = mid.ticks_per_beat ticks_per_measure = 4 * ticks_per_beat + ticks_per_measure = ticks_per_measure * params.timeScaleF ticks_per_sample = ticks_per_measure / samples_per_measure # add instrument for track diff --git a/models.py b/models.py index 92005c93..9b849482 100644 --- a/models.py +++ b/models.py @@ -1,113 +1,114 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -""" -The models used for music generation. -""" - -import torch -import torch.nn as nn - - - -import params - -def vae_sampling(args): - z_mean, z_log_sigma_sq, vae_b1 = args - epsilon = K.random_normal(shape=K.shape(z_mean), mean=0.0, stddev=vae_b1) - return z_mean + K.exp(z_log_sigma_sq * 0.5) * epsilon - - +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +""" +The models used for music generation. 
+""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import randn +import params + + +def vae_sampling(args): + z_mean, z_log_sigma_sq, vae_b1 = args + epsilon = randn(z_mean.size(), dtype=z_mean.dtype, device=z_mean.device) * vae_b1 + return z_mean + torch.exp(z_log_sigma_sq * 0.5) * epsilon + + class AutoencoderModel(nn.Module): def __init__(self, input_shape, latent_space_size, dropout_rate, max_windows, batchnorm_momentum, use_vae=False, vae_b1=0.02, use_embedding=False, embedding_input_shape=None, embedding_shape=None): - super(AutoencoderModel, self).__init__() - """ - Create larger autoencoder with the options of making it variational and embedding. - :param input_shape: - :param latent_space_size: - :param dropout_rate: - :param max_windows: - :param batchnorm_momentum: - :param use_vae: - :param vae_b1: - :param use_embedding: - :param embedding_input_shape: - :param embedding_shape: - :return: - """ - if use_embedding: - x_in = Input(shape=embedding_input_shape) - print((None,) + embedding_input_shape) - - x = Embedding(embedding_shape, latent_space_size, input_length=1)(x_in) - x = Flatten(name='encoder')(x) - else: - x_in = Input(shape=input_shape) - print((None,) + input_shape) - - x = Reshape((input_shape[0], -1))(x_in) - print(K.int_shape(x)) - - if params.noise_rate > 0: - x = Lambda(lambda x: 1 - x)(x) - x = Dropout(params.noise_rate)(x) - x = Lambda(lambda x: 1 - x)(x) - - print(K.int_shape(x)) - x = TimeDistributed(Dense(2000, activation='relu'))(x) - print(K.int_shape(x)) - - x = TimeDistributed(Dense(200, activation='relu'))(x) - print(K.int_shape(x)) - - x = Flatten()(x) - print(K.int_shape(x)) - - x = Dense(1600, activation='relu')(x) - print(K.int_shape(x)) - - if use_vae: - z_mean = Dense(latent_space_size)(x) - z_log_sigma_sq = Dense(latent_space_size)(x) - x = Lambda(vae_sampling, output_shape=(latent_space_size,), name='encoder')([z_mean, z_log_sigma_sq, vae_b1]) - else: - x = Dense(latent_space_size)(x) - x = BatchNormalization(momentum=batchnorm_momentum, name='encoder')(x) - print(K.int_shape(x)) - - # LATENT SPACE - - x = Dense(1600, name='decoder')(x) - x = BatchNormalization(momentum=batchnorm_momentum)(x) - x = Activation('relu')(x) - if dropout_rate > 0: - x = Dropout(dropout_rate)(x) - print(K.int_shape(x)) - - x = Dense(max_windows * 200)(x) - print(K.int_shape(x)) - x = Reshape((max_windows, 200))(x) - x = TimeDistributed(BatchNormalization(momentum=batchnorm_momentum))(x) - x = Activation('relu')(x) - if dropout_rate > 0: - x = Dropout(dropout_rate)(x) - print(K.int_shape(x)) - - x = TimeDistributed(Dense(2000))(x) - x = TimeDistributed(BatchNormalization(momentum=batchnorm_momentum))(x) - x = Activation('relu')(x) - if dropout_rate > 0: - x = Dropout(dropout_rate)(x) - print(K.int_shape(x)) - - #if params.encode_volume: - #x = TimeDistributed(Dense(input_shape[1] * input_shape[2]))(x) - #else: - x = TimeDistributed(Dense(input_shape[1] * input_shape[2], activation='sigmoid'))(x) - print(K.int_shape(x)) - x = Reshape((input_shape[0], input_shape[1], input_shape[2]))(x) - print(K.int_shape(x)) - - model = Model(x_in, x) - - return model + super(AutoencoderModel, self).__init__() + """ + Create larger autoencoder with the options of making it variational and embedding. 
+ :param input_shape: + :param latent_space_size: + :param dropout_rate: + :param max_windows: + :param batchnorm_momentum: + :param use_vae: + :param vae_b1: + :param use_embedding: + :param embedding_input_shape: + :param embedding_shape: + :return: + """ + if use_embedding: + x_in = Input(shape=embedding_input_shape) + print((None,) + embedding_input_shape) + + x = Embedding(embedding_shape, latent_space_size, input_length=1)(x_in) + x = Flatten(name='encoder')(x) + else: + x_in = Input(shape=input_shape) + print((None,) + input_shape) + + x = Reshape((input_shape[0], -1))(x_in) + print(K.int_shape(x)) + + if params.noise_rate > 0: + x = Lambda(lambda x: 1 - x)(x) + x = Dropout(params.noise_rate)(x) + x = Lambda(lambda x: 1 - x)(x) + + print(K.int_shape(x)) + + x = TimeDistributed(Dense(2000, activation='relu'))(x) + print(K.int_shape(x)) + + x = TimeDistributed(Dense(200, activation='relu'))(x) + print(K.int_shape(x)) + + x = Flatten()(x) + print(K.int_shape(x)) + + x = Dense(1600, activation='relu')(x) + print(K.int_shape(x)) + + if use_vae: + z_mean = Dense(latent_space_size)(x) + z_log_sigma_sq = Dense(latent_space_size)(x) + x = Lambda(vae_sampling, output_shape=(latent_space_size,), name='encoder')([z_mean, z_log_sigma_sq, vae_b1]) + else: + x = Dense(latent_space_size)(x) + x = BatchNormalization(momentum=batchnorm_momentum, name='encoder')(x) + print(K.int_shape(x)) + + # LATENT SPACE + + x = Dense(1600, name='decoder')(x) + x = BatchNormalization(momentum=batchnorm_momentum)(x) + x = Activation('relu')(x) + if dropout_rate > 0: + x = Dropout(dropout_rate)(x) + print(K.int_shape(x)) + + x = Dense(max_windows * 200)(x) + print(K.int_shape(x)) + x = Reshape((max_windows, 200))(x) + x = TimeDistributed(BatchNormalization(momentum=batchnorm_momentum))(x) + x = Activation('relu')(x) + if dropout_rate > 0: + x = Dropout(dropout_rate)(x) + print(K.int_shape(x)) + + x = TimeDistributed(Dense(2000))(x) + x = TimeDistributed(BatchNormalization(momentum=batchnorm_momentum))(x) + x = Activation('relu')(x) + if dropout_rate > 0: + x = Dropout(dropout_rate)(x) + print(K.int_shape(x)) + + + #if params.encode_volume: + #x = TimeDistributed(Dense(input_shape[1] * input_shape[2]))(x) + #else: + + x = TimeDistributed(Dense(input_shape[1] * input_shape[2], activation='sigmoid'))(x) + print(K.int_shape(x)) + x = Reshape((input_shape[0], input_shape[1], input_shape[2]))(x) + print(K.int_shape(x)) + + return Model(x_in, x) diff --git a/params.py b/params.py index 9bda1f40..ff4ae2da 100644 --- a/params.py +++ b/params.py @@ -1,3 +1,5 @@ num_params = 40 encode_volume = False -encode_length = False \ No newline at end of file +encode_length = False +noise_rate = 0.0002 +timeScaleF = 1 diff --git a/plot_utils.py b/plot_utils.py index 7f9520f8..b70cc2cc 100644 --- a/plot_utils.py +++ b/plot_utils.py @@ -23,4 +23,4 @@ def plot_samples(folder, samples, threshold=None): os.makedirs(folder) for i in range(samples.shape[0]): - plot_sample(folder + '/s' + str(i) + '.png', samples[i], threshold) + plot_sample(f'{folder}/s{str(i)}.png', samples[i], threshold) diff --git a/preprocess_songs.py b/preprocess_songs.py index e32ed511..6976a60d 100644 --- a/preprocess_songs.py +++ b/preprocess_songs.py @@ -12,6 +12,7 @@ import argparse import params + def preprocess_songs(data_folders): """ Load and preprocess the songs from the data folders and turn them into a dataset of samples/pitches and lengths of the tones. 
@@ -31,9 +32,9 @@ def preprocess_songs(data_folders): print("Loading songs...") # walk folders and look for midi files for folder in data_folders: - for root, _, files in os.walk(folder): + for root,_,files in os.walk(folder): for file in files: - path = os.path.join(root, file) + path = os.path.join(root,file) if not (path.endswith('.mid') or path.endswith('.midi')): continue @@ -41,39 +42,39 @@ def preprocess_songs(data_folders): try: samples = midi_utils.midi_to_samples(path) except Exception as e: - print("ERROR ", path) + print("ERROR ",path) print(e) failed += 1 continue # if the midi does not produce the minimal number of sample/measures, we skip it if len(samples) < 16: - print('WARN', path, 'Sample too short, unused') + print('WARN',path,'Sample too short, unused') ignored += 1 continue # transpose samples (center them in full range to get more training samples for the same tones) - samples, lengths = music_utils.generate_centered_transpose(samples) + samples,lengths = music_utils.generate_centered_transpose(samples) all_samples += samples all_lengths += lengths - print('SUCCESS', path, len(samples), 'samples') + print('SUCCESS',path,len(samples),'samples') succeeded += 1 assert (sum(all_lengths) == len(all_samples)) # assert equal number of samples and lengths # save all to disk - print("Saving " + str(len(all_samples)) + " samples...") - all_samples = np.array(all_samples, dtype=np.uint8) # reduce size when saving - all_lengths = np.array(all_lengths, dtype=np.uint32) - np.save('data/interim/samples.npy', all_samples) - np.save('data/interim/lengths.npy', all_lengths) - print('Done: ', succeeded, 'succeded,', ignored, 'ignored,', failed, 'failed of', succeeded + ignored + failed, 'in total') + print(f"Saving {len(all_samples)} samples...") + all_samples = np.array(all_samples,dtype=np.uint8) # reduce size when saving + all_lengths = np.array(all_lengths,dtype=np.uint32) + np.save('data/interim/samples.npy',all_samples) + np.save('data/interim/lengths.npy',all_lengths) + print('Done: ',succeeded,'succeded,',ignored,'ignored,',failed,'failed of',succeeded + ignored + failed,'in total') if __name__ == "__main__": # configure parser and parse arguments parser = argparse.ArgumentParser(description='Load songs, preprocess them and put them into a dataset.') - parser.add_argument('--data_folder', default=["data/raw"], type=str, help='The path to the midi data', action='append') + parser.add_argument('--data_folder',default=["data/raw"],type=str,help='The path to the midi data',action='append') args = parser.parse_args() preprocess_songs(args.data_folder) diff --git a/pygame_textinput.py b/pygame_textinput.py new file mode 100644 index 00000000..e024aee9 --- /dev/null +++ b/pygame_textinput.py @@ -0,0 +1,206 @@ +""" +Copyright 2017, Silas Gyger, silasgyger@gmail.com, All rights reserved. + +Borrowed from https://github.com/Nearoo/pygame-text-input under the MIT license. +""" + +import os.path + +import pygame +import pygame.locals as pl + +pygame.font.init() + + +class TextInput: + """ + This class lets the user input a piece of text, e.g. a name or a message. + This class lets the user input a short, one-lines piece of text at a blinking cursor + that can be moved using the arrow-keys. Delete, home and end work as well. 
+ """ + def __init__( + self, + initial_string="", + font_family="", + font_size=35, + antialias=True, + text_color=(0, 0, 0), + cursor_color=(0, 0, 1), + repeat_keys_initial_ms=400, + repeat_keys_interval_ms=35, + max_string_length=-1): + """ + :param initial_string: Initial text to be displayed + :param font_family: name or list of names for font (see pygame.font.match_font for precise format) + :param font_size: Size of font in pixels + :param antialias: Determines if antialias is applied to font (uses more processing power) + :param text_color: Color of text (duh) + :param cursor_color: Color of cursor + :param repeat_keys_initial_ms: Time in ms before keys are repeated when held + :param repeat_keys_interval_ms: Interval between key press repetition when held + :param max_string_length: Allowed length of text + """ + + # Text related vars: + self.antialias = antialias + self.text_color = text_color + self.font_size = font_size + self.max_string_length = max_string_length + self.input_string = initial_string # Inputted text + + if not os.path.isfile(font_family): + font_family = pygame.font.match_font(font_family) + + self.font_object = pygame.font.Font(font_family, font_size) + + # Text-surface will be created during the first update call: + self.surface = pygame.Surface((1, 1)) + self.surface.set_alpha(0) + + # Vars to make keydowns repeat after user pressed a key for some time: + self.keyrepeat_counters = {} # {event.key: (counter_int, event.unicode)} (look for "***") + self.keyrepeat_intial_interval_ms = repeat_keys_initial_ms + self.keyrepeat_interval_ms = repeat_keys_interval_ms + + # Things cursor: + self.cursor_surface = pygame.Surface((int(self.font_size / 20 + 1), self.font_size)) + self.cursor_surface.fill(cursor_color) + self.cursor_position = len(initial_string) # Inside text + self.cursor_visible = True # Switches every self.cursor_switch_ms ms + self.cursor_switch_ms = 500 # /|\ + self.cursor_ms_counter = 0 + + self.clock = pygame.time.Clock() + + def update(self, events): + for event in events: + if event.type == pygame.KEYDOWN: + self.cursor_visible = True # So the user sees where he writes + + # If none exist, create counter for that key: + if event.key not in self.keyrepeat_counters: + self.keyrepeat_counters[event.key] = [0, event.unicode] + + if event.key == pl.K_BACKSPACE: + self.input_string = ( + self.input_string[:max(self.cursor_position - 1, 0)] + + self.input_string[self.cursor_position:] + ) + + # Subtract one from cursor_pos, but do not go below zero: + self.cursor_position = max(self.cursor_position - 1, 0) + elif event.key == pl.K_DELETE: + self.input_string = ( + self.input_string[:self.cursor_position] + + self.input_string[self.cursor_position + 1:] + ) + + elif event.key == pl.K_RETURN: + return True + + elif event.key == pl.K_RIGHT: + # Add one to cursor_pos, but do not exceed len(input_string) + self.cursor_position = min(self.cursor_position + 1, len(self.input_string)) + + elif event.key == pl.K_LEFT: + # Subtract one from cursor_pos, but do not go below zero: + self.cursor_position = max(self.cursor_position - 1, 0) + + elif event.key == pl.K_END: + self.cursor_position = len(self.input_string) + + elif event.key == pl.K_HOME: + self.cursor_position = 0 + + elif len(self.input_string) < self.max_string_length or self.max_string_length == -1: + # If no special key is pressed, add unicode of key to input_string + self.input_string = ( + self.input_string[:self.cursor_position] + + event.unicode + + self.input_string[self.cursor_position:] + ) + 
self.cursor_position += len(event.unicode) # Some are empty, e.g. K_UP + + elif event.type == pl.KEYUP: + # *** Because KEYUP doesn't include event.unicode, this dict is stored in such a weird way + if event.key in self.keyrepeat_counters: + del self.keyrepeat_counters[event.key] + + # Update key counters: + for key in self.keyrepeat_counters: + self.keyrepeat_counters[key][0] += self.clock.get_time() # Update clock + + # Generate new key events if enough time has passed: + if self.keyrepeat_counters[key][0] >= self.keyrepeat_intial_interval_ms: + self.keyrepeat_counters[key][0] = ( + self.keyrepeat_intial_interval_ms + - self.keyrepeat_interval_ms + ) + + event_key, event_unicode = key, self.keyrepeat_counters[key][1] + pygame.event.post(pygame.event.Event(pl.KEYDOWN, key=event_key, unicode=event_unicode)) + + # Re-render text surface: + self.surface = self.font_object.render(self.input_string, self.antialias, self.text_color) + + # Update self.cursor_visible + self.cursor_ms_counter += self.clock.get_time() + if self.cursor_ms_counter >= self.cursor_switch_ms: + self.cursor_ms_counter %= self.cursor_switch_ms + self.cursor_visible = not self.cursor_visible + + if self.cursor_visible: + cursor_y_pos = self.font_object.size(self.input_string[:self.cursor_position])[0] + # Without this, the cursor is invisible when self.cursor_position > 0: + if self.cursor_position > 0: + cursor_y_pos -= self.cursor_surface.get_width() + self.surface.blit(self.cursor_surface, (cursor_y_pos, 0)) + + self.clock.tick() + return False + + def get_surface(self): + return self.surface + + def get_text(self): + return self.input_string + + def get_cursor_position(self): + return self.cursor_position + + def set_text_color(self, color): + self.text_color = color + + def set_cursor_color(self, color): + self.cursor_surface.fill(color) + + def clear_text(self): + self.input_string = "" + self.cursor_position = 0 + + + +if __name__ == "__main__": + pygame.init() + + # Create TextInput-object + textinput = TextInput() + + screen = pygame.display.set_mode((1000, 200)) + clock = pygame.time.Clock() + + while True: + screen.fill((225, 225, 225)) + + events = pygame.event.get() + for event in events: + if event.type == pygame.QUIT: + exit() + + # Feed it with events every frame + textinput.update(events) + # Blit its surface onto the screen + screen.blit(textinput.get_surface(), (10, 10)) + + pygame.display.update() + clock.tick(30) diff --git a/sweep.yaml b/sweep.yaml new file mode 100644 index 00000000..89e1d027 --- /dev/null +++ b/sweep.yaml @@ -0,0 +1,27 @@ +# Sweep AI turns bugs & feature requests into code changes (https://sweep.dev) +# For details on our config file, check out our docs at https://docs.sweep.dev/usage/config + +# This setting contains a list of rules that Sweep will check for. If any of these rules are broken in a new commit, Sweep will create an pull request to fix the broken rule. +rules: + - "All new business logic should have corresponding unit tests." + - "Refactor large functions to be more modular." + - "Add docstrings to all functions and file headers." + +# This is the branch that Sweep will develop from and make pull requests to. Most people use 'main' or 'master' but some users also use 'dev' or 'staging'. +branch: 'main' + +# By default Sweep will read the logs and outputs from your existing Github Actions. To disable this, set this to false. +gha_enabled: True + +# This is the description of your project. It will be used by sweep when creating PRs. 
You can tell Sweep what's unique about your project, what frameworks you use, or anything else you want. +# +# Example: +# +# description: sweepai/sweep is a python project. The main api endpoints are in sweepai/api.py. Write code that adheres to PEP8. +description: '' + +# This sets whether to create pull requests as drafts. If this is set to True, then all pull requests will be created as drafts and GitHub Actions will not be triggered. +draft: False + +# This is a list of directories that Sweep will not be able to edit. +blocked_dirs: [] diff --git a/tests.py b/tests.py new file mode 100644 index 00000000..5c91971b --- /dev/null +++ b/tests.py @@ -0,0 +1,53 @@ +import pytest +import numpy as np +from midi_utils import midi_to_samples, samples_to_midi +from mido import MidiFile +import os + +# Mock params module as it's not provided +class MockParams: + encode_volume = True + encode_length = True + +params = MockParams() + +@pytest.mark.parametrize("file_name, num_notes, samples_per_measure, expected_exception, test_id", [ + ("test_midi_1.mid", 96, 96, None, "happy_path_basic"), + ("test_midi_2.mid", 128, 48, None, "happy_path_extended_notes"), + ("test_midi_invalid.mid", 96, 96, NotImplementedError, "error_multiple_time_signatures"), + ("test_midi_nonexistent.mid", 96, 96, FileNotFoundError, "error_file_not_found"), +]) +def test_midi_to_samples(file_name, num_notes, samples_per_measure, expected_exception, test_id, tmpdir): + if expected_exception: + with pytest.raises(expected_exception): + midi_to_samples(file_name, num_notes, samples_per_measure) + else: + # Arrange + midi_path = os.path.join(tmpdir, file_name) + MidiFile().save(midi_path) # Create a simple, empty MIDI file for testing + + # Act + samples = midi_to_samples(midi_path, num_notes, samples_per_measure) + + # Assert + assert isinstance(samples, list), f"Test ID {test_id}: The result should be a list." + if samples: # If there are samples, check their structure + assert isinstance(samples[0], np.ndarray), f"Test ID {test_id}: Each sample should be a numpy array." + assert samples[0].shape == (samples_per_measure, num_notes), f"Test ID {test_id}: Incorrect shape of sample." + +@pytest.mark.parametrize("samples, file_name, threshold, num_notes, samples_per_measure, expected_notes, test_id", [ + ([np.zeros((96, 96))], "output_1.mid", 0.5, 96, 96, 0, "empty_samples"), + ([np.ones((96, 96)) * 0.6], "output_2.mid", 0.5, 96, 96, 96, "full_samples_above_threshold"), + ([np.ones((96, 96)) * 0.4], "output_3.mid", 0.5, 96, 96, 0, "full_samples_below_threshold"), +]) +def test_samples_to_midi(samples, file_name, threshold, num_notes, samples_per_measure, expected_notes, test_id, tmpdir): + # Arrange + output_path = os.path.join(tmpdir, file_name) + + # Act + samples_to_midi(samples, output_path, threshold, num_notes, samples_per_measure) + + # Assert + mid = MidiFile(output_path) + note_on_messages = sum(1 for track in mid.tracks for msg in track if msg.type == 'note_on') + assert note_on_messages == expected_notes, f"Test ID {test_id}: Expected {expected_notes} 'note_on' messages, found {note_on_messages}." diff --git a/train.py b/train.py index 4500f8cf..9c74d390 100644 --- a/train.py +++ b/train.py @@ -2,26 +2,26 @@ # -*- coding: utf-8 -*- """ -Train an autoencoder model to learn to encode songs. +Train an auto-encoder model to learn to encode songs. 
""" + +import argparse import random import numpy as np from matplotlib import pyplot as plt import midi_utils -import plot_utils import models import params +import plot_utils -import argparse - -# Load Keras -print("Loading keras...") +# Load Torch import os import torch import torch.nn as nn +import torch.nn.functional as F import torch.optim as optim import torch.utils.data @@ -31,23 +31,21 @@ # import tensorflow as tf # from tensorflow.python.client import device_lib # print(device_lib.list_local_devices()) - # config = tf.ConfigProto( device_count = {'GPU': 1 , 'CPU': 56} ) # sess = tf.Session(config=config) # K.set_session(sess) EPOCHS_QTY = 3000 EPOCHS_TO_SAVE = [1, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 120, 140, 160, 180, 200, 250, 300, 350, 400, 450, 800, 1000, 1500, 2000, 2500, 3000] -LEARNING_RATE = 0.001 # learning rate -CONTINUE_TRAIN = False GENERATE_ONLY = False - +LEARNING_RATE = 0.001 / 5 # learning rate +CONTINUE_TRAIN = True WRITE_HISTORY = True NUM_RAND_SONGS = 10 # network params DROPOUT_RATE = 0.1 -BATCHNORM_MOMENTUM = 0.9 # weighted normalization with the past +BATCHNORM_MOMENTUM: float = 0.9 # weighted normalization with the past USE_EMBEDDING = False USE_VAE = False VAE_B1 = 0.02 @@ -58,7 +56,7 @@ LATENT_SPACE_SIZE = params.num_params NUM_OFFSETS = 16 if USE_EMBEDDING else 1 -K.set_image_data_format('channels_first') +# PyTorch uses 'channels_first' by default for convolutional operations, so this line is not needed. # Fix the random seed so that training comparisons are easier to make np.random.seed(42) @@ -101,6 +99,7 @@ def plot_losses(scores, f_name, on_top=True): def save_training_config(num_songs, model, learning_rate): + # sourcery skip: use-fstring-for-concatenation """ Save configuration of training. :param num_songs: @@ -128,7 +127,9 @@ def generate_random_songs(decoder, write_dir, random_vectors): for i in range(random_vectors.shape[0]): random_latent_x = random_vectors[i:i + 1] y_song = decoder([random_latent_x, 0])[0] - midi_utils.samples_to_midi(y_song[0], write_dir + 'random_vectors' + str(i) + '.mid', 32) + midi_utils.samples_to_midi( + y_song[0], f'{write_dir}random_vectors{str(i)}.mid', 32 + ) def calculate_and_store_pca_statistics(encoder, x_orig, y_orig, write_dir): @@ -155,10 +156,10 @@ def calculate_and_store_pca_statistics(encoder, x_orig, y_orig, write_dir): print("Latent Mean values: ", latent_mean[:6]) print("Latent PCA values: ", latent_pca_values[:6]) - np.save(write_dir + 'latent_means.npy', latent_mean) - np.save(write_dir + 'latent_stds.npy', latent_stds) - np.save(write_dir + 'latent_pca_values.npy', latent_pca_values) - np.save(write_dir + 'latent_pca_vectors.npy', latent_pca_vectors) + np.save(f'{write_dir}latent_means.npy', latent_mean) + np.save(f'{write_dir}latent_stds.npy', latent_stds) + np.save(f'{write_dir}latent_pca_values.npy', latent_pca_values) + np.save(f'{write_dir}latent_pca_vectors.npy', latent_pca_vectors) return latent_mean, latent_stds, latent_pca_values, latent_pca_vectors @@ -178,28 +179,25 @@ def generate_normalized_random_songs(x_orig, y_orig, encoder, decoder, random_ve latent_vectors = latent_mean + np.dot(random_vectors * pca_values, pca_vectors) generate_random_songs(decoder, write_dir, latent_vectors) - title = '' - if '/' in write_dir: - title = 'Epoch: ' + write_dir.split('/')[-2][1:] - + title = 'Epoch: ' + write_dir.split('/')[-2][1:] if '/' in write_dir else '' plt.clf() pca_values[::-1].sort() plt.title(title) plt.bar(np.arange(pca_values.shape[0]), pca_values, align='center') plt.draw() - 
plt.savefig(write_dir + 'latent_pca_values.png') + plt.savefig(f'{write_dir}latent_pca_values.png') plt.clf() plt.title(title) plt.bar(np.arange(pca_values.shape[0]), latent_mean, align='center') plt.draw() - plt.savefig(write_dir + 'latent_means.png') + plt.savefig(f'{write_dir}latent_means.png') plt.clf() plt.title(title) plt.bar(np.arange(pca_values.shape[0]), latent_stds, align='center') plt.draw() - plt.savefig(write_dir + 'latent_stds.png') + plt.savefig(f'{write_dir}latent_stds.png') def train(samples_path='data/interim/samples.npy', lengths_path='data/interim/lengths.npy', epochs_qty=EPOCHS_QTY, learning_rate=LEARNING_RATE): @@ -239,7 +237,7 @@ def __getitem__(self, idx): samples_qty = y_samples.shape[0] songs_qty = y_lengths.shape[0] - print("Loaded " + str(samples_qty) + " samples from " + str(songs_qty) + " songs.") + print(f"Loaded {str(samples_qty)} samples from {str(songs_qty)} songs.") print(np.sum(y_lengths)) assert (np.sum(y_lengths) == samples_qty) @@ -248,7 +246,7 @@ def __getitem__(self, idx): x_orig = np.expand_dims(np.arange(x_shape[0]), axis=-1) y_shape = (songs_qty * NUM_OFFSETS, MAX_WINDOWS) + y_samples.shape[1:] # (songs_qty, max number of windows, window pitch qty, window beats per measure) - y_orig = np.zeros(y_shape, dtype=y_samples.dtype) # prepare dataset array + y_orig = np.zeros(y_shape, dtype=np.float32) # prepare dataset array # fill in measure of songs into input windows for network song_start_ix = 0 @@ -274,7 +272,7 @@ def __getitem__(self, idx): # create model if CONTINUE_TRAIN or GENERATE_ONLY: print("Loading model...") - model = load_model('results/history/model.h5') + model = torch.load('results/history/model.pth') else: print("Building model...") @@ -289,17 +287,19 @@ def __getitem__(self, idx): embedding_input_shape=x_shape[1:], embedding_shape=x_train.shape[0]) + # Define the optimizer and loss function for PyTorch if USE_VAE: - model.compile(optimizer=Adam(lr=learning_rate), loss=vae_loss) + optimizer = optim.Adam(model.parameters(), lr=learning_rate) + loss_function = vae_loss else: - model.compile(optimizer=RMSprop(lr=learning_rate), loss='binary_crossentropy') + optimizer = optim.RMSprop(model.parameters(), lr=learning_rate) + loss_function = nn.BCELoss() # plot model with graphvis if installed #try: - # plot_model(model, to_file='results/model.png', show_shapes=True) + # plot_model(model, to_file='results/model.png', show_shapes=True) #except OSError as e: # print(e) - # train print("Referencing sub-models...") decoder = K.function([model.get_layer('decoder').input, K.learning_phase()], [model.layers[-1].output]) @@ -313,8 +313,10 @@ def __getitem__(self, idx): generate_normalized_random_songs(x_orig, y_orig, encoder, decoder, random_vectors, 'results/') for save_epoch in range(20): x_test_song = x_train[save_epoch:save_epoch + 1] - y_song = model.predict(x_test_song, batch_size=BATCH_SIZE)[0] - midi_utils.samples_to_midi(y_song, 'results/gt' + str(save_epoch) + '.mid') + model.eval() + with torch.no_grad(): + y_song = model(x_test_song).cpu().numpy()[0] + midi_utils.samples_to_midi(y_song, f'results/gt{str(save_epoch)}.mid') exit(0) save_training_config(songs_qty, model, learning_rate) @@ -325,7 +327,14 @@ def __getitem__(self, idx): for epoch in range(epochs_qty): print("Training epoch: ", epoch, "of", epochs_qty) if USE_EMBEDDING: - history = model.fit(x_train, y_train, batch_size=BATCH_SIZE, epochs=1) + # Manual training loop in PyTorch + model.train() + for batch_idx, (data, target) in enumerate(train_loader): + optimizer.zero_grad() + 
output = model(data) + loss = loss_function(output, target) + loss.backward() + optimizer.step() else: # produce songs from its samples with a different starting point of the song each time song_start_ix = 0 @@ -334,16 +343,25 @@ def __getitem__(self, idx): for window_ix in range(MAX_WINDOWS): song_measure_ix = (window_ix + offset) % y_lengths[song_ix] y_train[song_ix, window_ix] = y_samples[song_start_ix + song_measure_ix] + #if params.encode_volume: + #y_train[song_ix, window_ix] /= 100.0 song_start_ix = song_end_ix assert (song_end_ix == samples_qty) offset += 1 - history = model.fit(y_train, y_train, batch_size=BATCH_SIZE, epochs=1) # train model on reconstruction loss + # Manual training loop in PyTorch for reconstruction loss + model.train() + for batch_idx, (data, target) in enumerate(train_loader): + optimizer.zero_grad() + output = model(data) + loss = loss_function(output, target) + loss.backward() + optimizer.step() # store last loss loss = history.history["loss"][-1] train_loss.append(loss) - print("Train loss: " + str(train_loss[-1])) + print(f"Train loss: {str(train_loss[-1])}") if WRITE_HISTORY: plot_losses(train_loss, 'results/history/losses.png', True) @@ -356,23 +374,27 @@ def __getitem__(self, idx): write_dir = '' if WRITE_HISTORY: # Create folder to save models into - write_dir += 'results/history/e' + str(save_epoch) + write_dir += f'results/history/e{str(save_epoch)}' if not os.path.exists(write_dir): os.makedirs(write_dir) write_dir += '/' - model.save('results/history/model.h5') + torch.save(model.state_dict(), 'results/history/model.pth') else: model.save('results/model.h5') print("...Saved.") if USE_EMBEDDING: - y_song = model.predict(x_test_song, batch_size=BATCH_SIZE)[0] + model.eval() + with torch.no_grad(): + y_song = model(x_test_song).cpu().numpy()[0] else: - y_song = model.predict(y_test_song, batch_size=BATCH_SIZE)[0] + model.eval() + with torch.no_grad(): + y_song = model(y_test_song).cpu().numpy()[0] - plot_utils.plot_samples(write_dir + 'test', y_song) - midi_utils.samples_to_midi(y_song, write_dir + 'test.mid') + plot_utils.plot_samples(f'{write_dir}test', y_song) + midi_utils.samples_to_midi(y_song, f'{write_dir}test.mid') generate_normalized_random_songs(x_orig, y_orig, encoder, decoder, random_vectors, write_dir)
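
The manual PyTorch training loop above replaces Keras' `model.fit`, so there is no `history` object left to read the epoch loss from; the loop has to accumulate it itself. A minimal sketch of one epoch, assuming `train_loader` yields `(data, target)` float tensor batches and `loss_function = nn.BCELoss()` as defined earlier (the helper name `train_one_epoch` is illustrative, not part of the repository):

    def train_one_epoch(model, train_loader, optimizer, loss_function):
        """Illustrative helper: run one reconstruction epoch, return the mean per-batch loss."""
        model.train()
        total_loss, batches = 0.0, 0
        for data, target in train_loader:
            optimizer.zero_grad()
            output = model(data)
            loss = loss_function(output, target)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()  # .item() detaches the scalar from the graph
            batches += 1
        return total_loss / max(batches, 1)

    # e.g. inside the epoch loop:
    # train_loss.append(train_one_epoch(model, train_loader, optimizer, loss_function))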
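
composer.py now reaches into the loaded model with `model.encoder` and `model.decoder`, so whatever replaces the old Keras graph needs to expose the two halves as callable submodules. A minimal sketch of a dense autoencoder with that interface, reusing the layer widths of the original network (2000 and 200 units per measure window, 1600 before the bottleneck); the class name, constructor arguments, and shapes below are assumptions for illustration, not the repository's actual `AutoencoderModel`:

    import torch
    import torch.nn as nn

    class DenseAutoencoder(nn.Module):
        """Hypothetical pianoroll autoencoder exposing encoder/decoder submodules."""

        def __init__(self, max_windows=16, window_size=96 * 96, latent_size=40):
            super().__init__()
            self.encoder = nn.Sequential(
                nn.Flatten(start_dim=2),                   # (B, W, 96, 96) -> (B, W, 9216)
                nn.Linear(window_size, 2000), nn.ReLU(),
                nn.Linear(2000, 200), nn.ReLU(),
                nn.Flatten(start_dim=1),                   # (B, W, 200) -> (B, W*200)
                nn.Linear(max_windows * 200, 1600), nn.ReLU(),
                nn.Linear(1600, latent_size),
            )
            self.decoder = nn.Sequential(
                nn.Linear(latent_size, 1600), nn.ReLU(),
                nn.Linear(1600, max_windows * 200), nn.ReLU(),
                nn.Unflatten(1, (max_windows, 200)),
                nn.Linear(200, 2000), nn.ReLU(),
                nn.Linear(2000, window_size), nn.Sigmoid(),
                nn.Unflatten(2, (96, 96)),                 # back to (B, W, 96, 96)
            )

        def forward(self, x):
            return self.decoder(self.encoder(x))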
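
train.py saves checkpoints with `torch.save(model.state_dict(), 'results/history/model.pth')`, while the composer calls `torch.load(model_path)` followed by `model.eval()`. If the .pth file holds a state dict rather than a pickled module, it has to be loaded into a freshly constructed model before `encoder`/`decoder` can be used. A sketch under the assumption that `AutoencoderModel` is an ordinary `nn.Module` like the dense sketch above; the constructor values are the same hypothetical shapes (16 measures of 96x96, 40 latent parameters):

    import torch
    from models import AutoencoderModel

    # Hypothetical shapes, chosen only for illustration.
    model = AutoencoderModel(input_shape=(16, 96, 96), latent_space_size=40,
                             dropout_rate=0.1, max_windows=16,
                             batchnorm_momentum=0.9)
    model.load_state_dict(torch.load(model_path))  # model_path resolved as in play()
    model.eval()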
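
For reference, the slider round trip that `play()` performs on every `needs_update` reduces to two mappings: encode a pianoroll and project it onto the stored PCA basis, or lift normalized slider values back to a latent vector and decode it. A condensed sketch, assuming `encoder`/`decoder` are submodules as above and the statistics were written by `calculate_and_store_pca_statistics`; the helper names are illustrative only:

    import numpy as np
    import torch

    def notes_to_params(notes, encoder, means, pca_vectors, pca_values):
        """Illustrative: encode a (W, 96, 96) pianoroll into normalized PCA slider values."""
        with torch.no_grad():
            latent = encoder(torch.tensor(notes, dtype=torch.float).unsqueeze(0)).numpy()[0]
        return np.dot(latent - means, pca_vectors.T) / pca_values

    def params_to_notes(params_vec, decoder, means, pca_vectors, pca_values):
        """Illustrative: decode normalized PCA slider values back into a 0-255 pianoroll."""
        latent = means + np.dot(params_vec * pca_values, pca_vectors)
        with torch.no_grad():
            y = decoder(torch.tensor(latent, dtype=torch.float).unsqueeze(0)).numpy()[0]
        return (y * 255).astype(np.uint8)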
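
The WAV exports go through the standard library's `wave` module, whose `Wave_write.setparams` expects a single 6-tuple `(nchannels, sampwidth, framerate, nframes, comptype, compname)`. For the mono 16-bit stream the audio callback produces, the write boils down to the following (output path illustrative; `sample_rate` and `save_audio` are the script's own variables):

    import wave

    with wave.open('results/history/output.wav', 'w') as wave_output:
        wave_output.setparams((1, 2, sample_rate, 0, 'NONE', 'not compressed'))
        wave_output.writeframes(save_audio)  # frame count in the header is patched on close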