diff --git a/.gitignore b/.gitignore index de5fa9f..dceea95 100644 --- a/.gitignore +++ b/.gitignore @@ -3,7 +3,8 @@ __pycache__ checkpoints/ .idea/ venv/ -data/notes -data/dataset_hash +training_data/notes +training_data/dataset_hash midi_songs/ logs/ +results/ diff --git a/Pipfile b/Pipfile index 7b16e4f..bacffc6 100644 --- a/Pipfile +++ b/Pipfile @@ -16,6 +16,10 @@ tensorflow = "~=2.4.0" wget = "~=3.2" pyngrok = "*" random-word = "*" +pandas = "*" +matplotlib = "*" +statsmodels = "*" +scipy = "*" [requires] python_version = "3.8" diff --git a/Pipfile.lock b/Pipfile.lock index f530b4b..97acb5e 100644 --- a/Pipfile.lock +++ b/Pipfile.lock @@ -1,7 +1,7 @@ { "_meta": { "hash": { - "sha256": "af39b89560373015ca0c1548f387e34ab503480d0240e56393799a8964a5ccee" + "sha256": "cb50950298afc7a0fb0898e35a0da93352b938d825b7906240f2c63a27351e48" }, "pipfile-spec": 6, "requires": { @@ -39,10 +39,10 @@ }, "certifi": { "hashes": [ - "sha256:1a4995114262bffbc2413b159f2a1a480c969de6e6eb13ee966d470af86af59c", - "sha256:719a74fb9e33b9bd44cc7f3a8d94bc35e4049deebe19ba7d8e108280cfd59830" + "sha256:2bbf76fd432960138b3ef6dda3dde0544f27cbf8546c458e60baf371917ba9ee", + "sha256:50b1e4f8446b06f41be7dd6338db18e0990601dce795c2b1686458aa7e8fa7d8" ], - "version": "==2020.12.5" + "version": "==2021.5.30" }, "chardet": { "hashes": [ @@ -59,6 +59,13 @@ "index": "pypi", "version": "==1.2.0" }, + "cycler": { + "hashes": [ + "sha256:1d8a5ae1ff6c5cf9b93e8811e581232ad8920aeec647c37316ceac982b08cb2d", + "sha256:cd7b2d1018258d7247a71425e9f26463dfb444d411c39569972f4ce586b0c9d8" + ], + "version": "==0.10.0" + }, "flatbuffers": { "hashes": [ "sha256:63bb9a722d5e373701913e226135b28a6f6ac200d5cc7b4d919fa38d73b44610", @@ -203,6 +210,43 @@ ], "version": "==1.1.2" }, + "kiwisolver": { + "hashes": [ + "sha256:0cd53f403202159b44528498de18f9285b04482bab2a6fc3f5dd8dbb9352e30d", + "sha256:1e1bc12fb773a7b2ffdeb8380609f4f8064777877b2225dec3da711b421fda31", + "sha256:225e2e18f271e0ed8157d7f4518ffbf99b9450fca398d561eb5c4a87d0986dd9", + "sha256:232c9e11fd7ac3a470d65cd67e4359eee155ec57e822e5220322d7b2ac84fbf0", + "sha256:31dfd2ac56edc0ff9ac295193eeaea1c0c923c0355bf948fbd99ed6018010b72", + "sha256:33449715e0101e4d34f64990352bce4095c8bf13bed1b390773fc0a7295967b3", + "sha256:401a2e9afa8588589775fe34fc22d918ae839aaaf0c0e96441c0fdbce6d8ebe6", + "sha256:44a62e24d9b01ba94ae7a4a6c3fb215dc4af1dde817e7498d901e229aaf50e4e", + "sha256:50af681a36b2a1dee1d3c169ade9fdc59207d3c31e522519181e12f1b3ba7000", + "sha256:563c649cfdef27d081c84e72a03b48ea9408c16657500c312575ae9d9f7bc1c3", + "sha256:5989db3b3b34b76c09253deeaf7fbc2707616f130e166996606c284395da3f18", + "sha256:5a7a7dbff17e66fac9142ae2ecafb719393aaee6a3768c9de2fd425c63b53e21", + "sha256:5c3e6455341008a054cccee8c5d24481bcfe1acdbc9add30aa95798e95c65621", + "sha256:5f6ccd3dd0b9739edcf407514016108e2280769c73a85b9e59aa390046dbf08b", + "sha256:72c99e39d005b793fb7d3d4e660aed6b6281b502e8c1eaf8ee8346023c8e03bc", + "sha256:78751b33595f7f9511952e7e60ce858c6d64db2e062afb325985ddbd34b5c131", + "sha256:834ee27348c4aefc20b479335fd422a2c69db55f7d9ab61721ac8cd83eb78882", + "sha256:8be8d84b7d4f2ba4ffff3665bcd0211318aa632395a1a41553250484a871d454", + "sha256:950a199911a8d94683a6b10321f9345d5a3a8433ec58b217ace979e18f16e248", + "sha256:a357fd4f15ee49b4a98b44ec23a34a95f1e00292a139d6015c11f55774ef10de", + "sha256:a53d27d0c2a0ebd07e395e56a1fbdf75ffedc4a05943daf472af163413ce9598", + "sha256:acef3d59d47dd85ecf909c359d0fd2c81ed33bdff70216d3956b463e12c38a54", + "sha256:b38694dcdac990a743aa654037ff1188c7a9801ac3ccc548d3341014bc5ca278", + 
"sha256:b9edd0110a77fc321ab090aaa1cfcaba1d8499850a12848b81be2222eab648f6", + "sha256:c08e95114951dc2090c4a630c2385bef681cacf12636fb0241accdc6b303fd81", + "sha256:c5518d51a0735b1e6cee1fdce66359f8d2b59c3ca85dc2b0813a8aa86818a030", + "sha256:c8fd0f1ae9d92b42854b2979024d7597685ce4ada367172ed7c09edf2cef9cb8", + "sha256:ca3820eb7f7faf7f0aa88de0e54681bddcb46e485beb844fcecbcd1c8bd01689", + "sha256:cf8b574c7b9aa060c62116d4181f3a1a4e821b2ec5cbfe3775809474113748d4", + "sha256:d3155d828dec1d43283bd24d3d3e0d9c7c350cdfcc0bd06c0ad1209c1bbc36d0", + "sha256:f8d6f8db88049a699817fd9178782867bf22283e3813064302ac59f61d95be05", + "sha256:fd34fbbfbc40628200730bc1febe30631347103fc8d3d4fa012c21ab9c11eca9" + ], + "version": "==1.3.1" + }, "markdown": { "hashes": [ "sha256:31b5b491868dcc87d6c24b7e3d19a0d730d59d3e46f4eea6430a321bed387a49", @@ -210,6 +254,31 @@ ], "version": "==3.3.4" }, + "matplotlib": { + "hashes": [ + "sha256:0bea5ec5c28d49020e5d7923c2725b837e60bc8be99d3164af410eb4b4c827da", + "sha256:1c1779f7ab7d8bdb7d4c605e6ffaa0614b3e80f1e3c8ccf7b9269a22dbc5986b", + "sha256:21b31057bbc5e75b08e70a43cefc4c0b2c2f1b1a850f4a0f7af044eb4163086c", + "sha256:32fa638cc10886885d1ca3d409d4473d6a22f7ceecd11322150961a70fab66dd", + "sha256:3a5c18dbd2c7c366da26a4ad1462fe3e03a577b39e3b503bbcf482b9cdac093c", + "sha256:5826f56055b9b1c80fef82e326097e34dc4af8c7249226b7dd63095a686177d1", + "sha256:6382bc6e2d7e481bcd977eb131c31dee96e0fb4f9177d15ec6fb976d3b9ace1a", + "sha256:6475d0209024a77f869163ec3657c47fed35d9b6ed8bccba8aa0f0099fbbdaa8", + "sha256:6a6a44f27aabe720ec4fd485061e8a35784c2b9ffa6363ad546316dfc9cea04e", + "sha256:7a58f3d8fe8fac3be522c79d921c9b86e090a59637cb88e3bc51298d7a2c862a", + "sha256:7ad19f3fb6145b9eb41c08e7cbb9f8e10b91291396bee21e9ce761bb78df63ec", + "sha256:85f191bb03cb1a7b04b5c2cca4792bef94df06ef473bc49e2818105671766fee", + "sha256:956c8849b134b4a343598305a3ca1bdd3094f01f5efc8afccdebeffe6b315247", + "sha256:a9d8cb5329df13e0cdaa14b3b43f47b5e593ec637f13f14db75bb16e46178b05", + "sha256:b1d5a2cedf5de05567c441b3a8c2651fbde56df08b82640e7f06c8cd91e201f6", + "sha256:b26535b9de85326e6958cdef720ecd10bcf74a3f4371bf9a7e5b2e659c17e153", + "sha256:c541ee5a3287efe066bbe358320853cf4916bc14c00c38f8f3d8d75275a405a9", + "sha256:d8d994cefdff9aaba45166eb3de4f5211adb4accac85cbf97137e98f26ea0219", + "sha256:df815378a754a7edd4559f8c51fc7064f779a74013644a7f5ac7a0c31f875866" + ], + "index": "pypi", + "version": "==3.4.2" + }, "more-itertools": { "hashes": [ "sha256:2cf89ec599962f2ddc4d568a05defc40e0a587fbc10d5989713638864c36be4d", @@ -273,10 +342,10 @@ }, "oauthlib": { "hashes": [ - "sha256:bee41cc35fcca6e988463cacc3bcb8a96224f470ca547e697b604cc697b2f889", - "sha256:df884cd6cbe20e32633f1db1072e9356f53638e4361bef4e8b03c9127c9328ea" + "sha256:42bf6354c2ed8c6acb54d971fce6f88193d97297e18602a3a886603f9d7730cc", + "sha256:8f0215fcc533dd8dd1bee6f4c412d4f0cd7297307d43ac61666389e3bc3198a3" ], - "version": "==3.1.0" + "version": "==3.1.1" }, "opt-einsum": { "hashes": [ @@ -285,6 +354,74 @@ ], "version": "==3.3.0" }, + "pandas": { + "hashes": [ + "sha256:167693a80abc8eb28051fbd184c1b7afd13ce2c727a5af47b048f1ea3afefff4", + "sha256:2111c25e69fa9365ba80bbf4f959400054b2771ac5d041ed19415a8b488dc70a", + "sha256:298f0553fd3ba8e002c4070a723a59cdb28eda579f3e243bc2ee397773f5398b", + "sha256:2b063d41803b6a19703b845609c0b700913593de067b552a8b24dd8eeb8c9895", + "sha256:2cb7e8f4f152f27dc93f30b5c7a98f6c748601ea65da359af734dd0cf3fa733f", + "sha256:52d2472acbb8a56819a87aafdb8b5b6d2b3386e15c95bde56b281882529a7ded", + 
"sha256:612add929bf3ba9d27b436cc8853f5acc337242d6b584203f207e364bb46cb12", + "sha256:649ecab692fade3cbfcf967ff936496b0cfba0af00a55dfaacd82bdda5cb2279", + "sha256:68d7baa80c74aaacbed597265ca2308f017859123231542ff8a5266d489e1858", + "sha256:8d4c74177c26aadcfb4fd1de6c1c43c2bf822b3e0fc7a9b409eeaf84b3e92aaa", + "sha256:971e2a414fce20cc5331fe791153513d076814d30a60cd7348466943e6e909e4", + "sha256:9db70ffa8b280bb4de83f9739d514cd0735825e79eef3a61d312420b9f16b758", + "sha256:b730add5267f873b3383c18cac4df2527ac4f0f0eed1c6cf37fcb437e25cf558", + "sha256:bd659c11a4578af740782288cac141a322057a2e36920016e0fc7b25c5a4b686", + "sha256:c601c6fdebc729df4438ec1f62275d6136a0dd14d332fc0e8ce3f7d2aadb4dd6", + "sha256:d0877407359811f7b853b548a614aacd7dea83b0c0c84620a9a643f180060950" + ], + "index": "pypi", + "version": "==1.2.4" + }, + "patsy": { + "hashes": [ + "sha256:5465be1c0e670c3a965355ec09e9a502bf2c4cbe4875e8528b0221190a8a5d40", + "sha256:f115cec4201e1465cd58b9866b0b0e7b941caafec129869057405bfe5b5e3991" + ], + "version": "==0.5.1" + }, + "pillow": { + "hashes": [ + "sha256:01425106e4e8cee195a411f729cff2a7d61813b0b11737c12bd5991f5f14bcd5", + "sha256:031a6c88c77d08aab84fecc05c3cde8414cd6f8406f4d2b16fed1e97634cc8a4", + "sha256:083781abd261bdabf090ad07bb69f8f5599943ddb539d64497ed021b2a67e5a9", + "sha256:0d19d70ee7c2ba97631bae1e7d4725cdb2ecf238178096e8c82ee481e189168a", + "sha256:0e04d61f0064b545b989126197930807c86bcbd4534d39168f4aa5fda39bb8f9", + "sha256:12e5e7471f9b637762453da74e390e56cc43e486a88289995c1f4c1dc0bfe727", + "sha256:22fd0f42ad15dfdde6c581347eaa4adb9a6fc4b865f90b23378aa7914895e120", + "sha256:238c197fc275b475e87c1453b05b467d2d02c2915fdfdd4af126145ff2e4610c", + "sha256:3b570f84a6161cf8865c4e08adf629441f56e32f180f7aa4ccbd2e0a5a02cba2", + "sha256:463822e2f0d81459e113372a168f2ff59723e78528f91f0bd25680ac185cf797", + "sha256:4d98abdd6b1e3bf1a1cbb14c3895226816e666749ac040c4e2554231068c639b", + "sha256:5afe6b237a0b81bd54b53f835a153770802f164c5570bab5e005aad693dab87f", + "sha256:5b70110acb39f3aff6b74cf09bb4169b167e2660dabc304c1e25b6555fa781ef", + "sha256:5cbf3e3b1014dddc45496e8cf38b9f099c95a326275885199f427825c6522232", + "sha256:624b977355cde8b065f6d51b98497d6cd5fbdd4f36405f7a8790e3376125e2bb", + "sha256:63728564c1410d99e6d1ae8e3b810fe012bc440952168af0a2877e8ff5ab96b9", + "sha256:66cc56579fd91f517290ab02c51e3a80f581aba45fd924fcdee01fa06e635812", + "sha256:6c32cc3145928c4305d142ebec682419a6c0a8ce9e33db900027ddca1ec39178", + "sha256:8b56553c0345ad6dcb2e9b433ae47d67f95fc23fe28a0bde15a120f25257e291", + "sha256:8bb1e155a74e1bfbacd84555ea62fa21c58e0b4e7e6b20e4447b8d07990ac78b", + "sha256:95d5ef984eff897850f3a83883363da64aae1000e79cb3c321915468e8c6add5", + "sha256:a013cbe25d20c2e0c4e85a9daf438f85121a4d0344ddc76e33fd7e3965d9af4b", + "sha256:a787ab10d7bb5494e5f76536ac460741788f1fbce851068d73a87ca7c35fc3e1", + "sha256:a7d5e9fad90eff8f6f6106d3b98b553a88b6f976e51fce287192a5d2d5363713", + "sha256:aac00e4bc94d1b7813fe882c28990c1bc2f9d0e1aa765a5f2b516e8a6a16a9e4", + "sha256:b91c36492a4bbb1ee855b7d16fe51379e5f96b85692dc8210831fbb24c43e484", + "sha256:c03c07ed32c5324939b19e36ae5f75c660c81461e312a41aea30acdd46f93a7c", + "sha256:c5236606e8570542ed424849f7852a0ff0bce2c4c8d0ba05cc202a5a9c97dee9", + "sha256:c6b39294464b03457f9064e98c124e09008b35a62e3189d3513e5148611c9388", + "sha256:cb7a09e173903541fa888ba010c345893cd9fc1b5891aaf060f6ca77b6a3722d", + "sha256:d68cb92c408261f806b15923834203f024110a2e2872ecb0bd2a110f89d3c602", + "sha256:dc38f57d8f20f06dd7c3161c59ca2c86893632623f33a42d592f097b00f720a9", + 
"sha256:e98eca29a05913e82177b3ba3d198b1728e164869c613d76d0de4bde6768a50e", + "sha256:f217c3954ce5fd88303fc0c317af55d5e0204106d86dea17eb8205700d47dec2" + ], + "version": "==8.2.0" + }, "protobuf": { "hashes": [ "sha256:03346c1fdaaaece437f2b611f3f784114b109985aea65c7d98050f8d38fe671c", @@ -329,6 +466,27 @@ "index": "pypi", "version": "==5.0.5" }, + "pyparsing": { + "hashes": [ + "sha256:1c6409312ce2ce2997896af5756753778d5f1603666dba5587804f09ad82ed27", + "sha256:f4896b4cc085a1f8f8ae53a1a90db5a86b3825ff73eb974dffee3d9e701007f4" + ], + "version": "==3.0.0b2" + }, + "python-dateutil": { + "hashes": [ + "sha256:73ebfe9dbf22e832286dafa60473e4cd239f8592f699aa5adaf10050e6e1823c", + "sha256:75bb3f31ea686f1197762692a9ee6a7550b59fc6ca3a1f4b5d7e32fb98e2da2a" + ], + "version": "==2.8.1" + }, + "pytz": { + "hashes": [ + "sha256:83a4a90894bf38e243cf052c8b58f381bfe9a7a483f6a9cab140bc7f702ac4da", + "sha256:eb10ce3e7736052ed3623d49975ce333bcd712c7bb19a58b9e2089d4057d0798" + ], + "version": "==2021.1" + }, "pyyaml": { "hashes": [ "sha256:08682f6b72c722394747bddaf0aa62277e02557c0fd1c42cb853016a38f8dedf", @@ -414,6 +572,7 @@ "sha256:f68eb46b86b2c246af99fcaa6f6e37c7a7a413e1084a794990b877f2ff71f7b6", "sha256:fdf606341cd798530b05705c87779606fcdfaf768a8129c348ea94441da15b04" ], + "index": "pypi", "version": "==1.6.3" }, "six": { @@ -423,6 +582,33 @@ ], "version": "==1.15.0" }, + "statsmodels": { + "hashes": [ + "sha256:0197855aa1d40c42532d6a75b4ca72e30826a50d90ec3047a404f9702d8b814f", + "sha256:1fa720e895112a1b04b27002218b0ea7f10dd1d9cffd1c018c88bbfb82520f57", + "sha256:37e107fa11299090ed90f93c7172162b850c28fd09999937b971926813e887c5", + "sha256:3aab85174444f1bcad1e9218a3d3db08f0f86eeb97985236ca8605a0a39ce305", + "sha256:3e94306d4c07e332532ea4911d1f1d1f661c79aa73f22c5bb22e6dd47b40d562", + "sha256:4184487e9c281acad3d0bda19445c69db292f0dbb18f25ebf56a7966a0a28eef", + "sha256:43de84bc08c8b9f778502aed7a476d6e68674e6878718e533b07d569cf0927a9", + "sha256:587deb788e7f8f3f866d28e812cf5c082b4d4a2d3f5beea94d0e9699ea71ef22", + "sha256:5d3e7333e1c5b234797ed57c3d1533371374c1e1e7e7ed54d27805611f96e2d5", + "sha256:8ad7a7ae7cdd929095684118e3b05836c0ccb08b6a01fe984159475d174a1b10", + "sha256:8f93cb3f7d87c1fc7e51b3b239371c25a17a0a8e782467fdf4788cfef600724a", + "sha256:93273aa1c31caf59bcce9790ca4c3f54fdc45a37c61084d06f1ba4fbe56e7752", + "sha256:94d3632d56c13eebebaefb52bd4b43144ad5a131337b57842f46db826fa7d2d3", + "sha256:a3bd3922463dda8ad33e5e5075d2080e9e012aeb2032b5cdaeea9b79c2472000", + "sha256:aaf3c75fd22cb9dcf9c1b28f8ae87521310870f4dd8a6a4f1010f1e46d992377", + "sha256:c1d98ce2072f5e772cbf91d05475490368da5d3ee4a3150062330c7b83221ceb", + "sha256:c3782ce846a52862ac72f89d22b6b1ca13d877bc593872309228a6f05d934321", + "sha256:c48b7cbb37a651bb1cd23614abc10f447845ad3c3a713bf74e2aad20cfc94ae7", + "sha256:cbbdf6f708c9a1f1fad5cdea5e4342d6fdb37e42e92288c2cf906b99976ffe15", + "sha256:f3a7622f3d0ce2fc204f43b74de4e03e42775609705bf94d656b730482ca935a", + "sha256:f61f33f64760a22100b6b146217823f73cfedd251c9bdbd58453ca94e63326c7" + ], + "index": "pypi", + "version": "==0.12.2" + }, "tensorboard": { "hashes": [ "sha256:e167460085b6528956b33bab1c970c989cdce47a6616273880733f5e7bde452e" @@ -480,10 +666,10 @@ }, "urllib3": { "hashes": [ - "sha256:2f4da4594db7e1e110a944bb1b551fdf4e6c136ad42e4234131391e21eb5b0df", - "sha256:e7b021f7241115872f92f43c6508082facffbd1c048e3c6e2bb9c2a157e28937" + "sha256:753a0374df26658f99d826cfe40394a686d05985786d946fbe4165b5148f5a7c", + "sha256:a7acd0977125325f516bda9735fa7142b909a8d01e8b2e4c8108d0984e6e0098" ], - "version": 
"==1.26.4" + "version": "==1.26.5" }, "webcolors": { "hashes": [ @@ -531,18 +717,18 @@ }, "astroid": { "hashes": [ - "sha256:4db03ab5fc3340cf619dbc25e42c2cc3755154ce6009469766d7143d1fc2ee4e", - "sha256:8a398dfce302c13f14bab13e2b14fe385d32b73f4e4853b9bdfb64598baa1975" + "sha256:3c9a2d84354185d13213ff2640ec03d39168dbcd13648abc84fb13ca3b2e2761", + "sha256:d66a600e1602736a0f24f725a511b0e50d12eb18f54b31ec276d2c26a0a62c6a" ], - "version": "==2.5.6" + "version": "==2.5.7" }, "black": { "hashes": [ - "sha256:23695358dbcb3deafe7f0a3ad89feee5999a46be5fec21f4f1d108be0bcdb3b1", - "sha256:8a60071a0043876a4ae96e6c69bd3a127dad2c1ca7c8083573eb82f92705d008" + "sha256:1fc0e0a2c8ae7d269dfcf0c60a89afa299664f3e811395d40b1922dff8f854b5", + "sha256:e5cf21ebdffc7a9b29d73912b6a6a9a4df4ce70220d523c21647da2eae0751ef" ], "index": "pypi", - "version": "==21.5b1" + "version": "==21.5b2" }, "click": { "hashes": [ diff --git a/data/vocabulary b/data/vocabulary deleted file mode 100644 index 17e4015..0000000 Binary files a/data/vocabulary and /dev/null differ diff --git a/data_preparation.py b/data_preparation.py index fa06609..2ce3d07 100644 --- a/data_preparation.py +++ b/data_preparation.py @@ -5,7 +5,9 @@ import math import datetime import shutil +import random from multiprocessing import Pool, cpu_count +from collections import Counter import checksumdir from music21 import converter, instrument, stream, note, chord from random_word import RandomWords @@ -14,12 +16,12 @@ CHECKPOINTS_DIR = "checkpoints" MIDI_SONGS_DIR = "midi_songs" -DATA_DIR = "data" +TRAINING_DATA_DIR = "training_data" NOTES_FILENAME = "notes" VOCABULARY_FILENAME = "vocabulary" HASH_FILENAME = "dataset_hash" RESULTS_DIR = "results" -SEQUENCE_LENGTH = 100 +SEQUENCE_LENGTH = 60 VALIDATION_SPLIT = 0.2 """ @@ -31,16 +33,25 @@ NUM_NOTES_TO_PREDICT = 1 -def clean_data_and_checkpoints(): - shutil.rmtree(DATA_DIR) - shutil.rmtree(CHECKPOINTS_DIR) +def clear_checkpoints(): + try: + shutil.rmtree(CHECKPOINTS_DIR) + except FileNotFoundError: + print("Checkpoints directory doesn't exist") + + +def clear_training_data(): + try: + shutil.rmtree(TRAINING_DATA_DIR) + except FileNotFoundError: + print("Training data directory doesn't exist") def save_data_hash(hash_value): - if not os.path.isdir(DATA_DIR): - os.mkdir(DATA_DIR) + if not os.path.isdir(TRAINING_DATA_DIR): + os.mkdir(TRAINING_DATA_DIR) - hash_file_path = os.path.join(DATA_DIR, HASH_FILENAME) + hash_file_path = os.path.join(TRAINING_DATA_DIR, HASH_FILENAME) with open(hash_file_path, "wb") as hash_file: pickle.dump(hash_value, hash_file) @@ -48,7 +59,7 @@ def save_data_hash(hash_value): def is_data_changed(): current_hash = checksumdir.dirhash(MIDI_SONGS_DIR) - hash_file_path = os.path.join(DATA_DIR, HASH_FILENAME) + hash_file_path = os.path.join(TRAINING_DATA_DIR, HASH_FILENAME) if not os.path.exists(hash_file_path): save_data_hash(current_hash) return True @@ -63,37 +74,90 @@ def is_data_changed(): return False -def get_notes_from_file(file): - print(f"Parsing {file}") +def get_midi_in_default_octave(pattern): + if isinstance(pattern, note.Note): + note_in_default_octave = note.Note(pattern.name) + elif isinstance(pattern, int): + note_in_default_octave = note.Note(pattern) + + return note_in_default_octave.pitch.midi + + +def map_midi_to_reduced_octaves(midi_value, min_midi=4 * 12, max_midi=5 * 12 - 1): + if midi_value > max_midi: + return midi_value - (math.ceil((midi_value - max_midi) / 12) * 12) + + if midi_value < min_midi: + return midi_value + (math.ceil((min_midi - midi_value) / 12) * 12) + + 
return midi_value - midi = converter.parse(file) + +def get_notes_from_midi_stream(midi_stream, octave_transposition=0): + transposition = octave_transposition * 12 notes = [] + s2 = instrument.partitionByInstrument(midi_stream) + + # Looping over all the instruments + for part in s2.parts: + + # select elements of only piano + if "Piano" in str(part): + + notes_to_parse = part.recurse() + + # finding whether a particular element is note or a chord + for element in notes_to_parse: + + # note + if isinstance(element, note.Note): + midi_value = ( + map_midi_to_reduced_octaves(element.pitch.midi) + transposition + ) + notes.append(str(midi_value)) + + # chord + elif isinstance(element, chord.Chord): + midi_values = [ + map_midi_to_reduced_octaves(pitch.midi) + transposition + for pitch in element.pitches + ] + midi_values = list(set(midi_values)) + notes.append(".".join(str(midi) for midi in sorted(midi_values))) + return notes + + +def get_notes_from_file(file, augment_data=False, octave_augmentation=1): + print(f"Parsing {file}") + try: - # file has instrument parts - instrument_stream = instrument.partitionByInstrument(midi) - notes_to_parse = instrument_stream.parts[0].recurse() + midi_stream = converter.parse(file) except: - # file has notes in a flat structure - notes_to_parse = midi.flat.notes + return [] - for element in notes_to_parse: - if isinstance(element, note.Note): - notes.append(element.name) - elif isinstance(element, chord.Chord): - notes.append(".".join(str(n) for n in element.normalOrder)) + if augment_data: + all_notes = [] + for octave_transposition in range( + -octave_augmentation, octave_augmentation + 1 + ): + notes = get_notes_from_midi_stream(midi_stream, octave_transposition) + for note in notes: + all_notes.append(note) - return notes + else: + all_notes = get_notes_from_midi_stream(midi_stream) + + return all_notes def get_notes_from_dataset(): - notes_path = os.path.join(DATA_DIR, NOTES_FILENAME) + notes_path = os.path.join(TRAINING_DATA_DIR, NOTES_FILENAME) notes = [] if is_data_changed(): try: with Pool(cpu_count() - 1) as pool: - notes_from_files = pool.map( - get_notes_from_file, glob.glob(f"{MIDI_SONGS_DIR}/*.mid") - ) + files = glob.glob(f"{MIDI_SONGS_DIR}/*.mid") + notes_from_files = pool.map(get_notes_from_file, files) for notes_from_file in notes_from_files: for note in notes_from_file: @@ -103,7 +167,7 @@ def get_notes_from_dataset(): pickle.dump(notes, notes_data_file) except: - hash_file_path = os.path.join(DATA_DIR, HASH_FILENAME) + hash_file_path = os.path.join(TRAINING_DATA_DIR, HASH_FILENAME) os.remove(hash_file_path) print("Removed the hash file") sys.exit(1) @@ -122,17 +186,19 @@ def create_vocabulary_for_training(notes): sound_names = sorted(set(item for item in notes)) vocab = dict((note, number) for number, note in enumerate(sound_names)) - vocab_path = os.path.join(DATA_DIR, VOCABULARY_FILENAME) + vocab_path = os.path.join(TRAINING_DATA_DIR, VOCABULARY_FILENAME) with open(vocab_path, "wb") as vocab_data_file: pickle.dump(vocab, vocab_data_file) + print(f"*** vocabulary size: {len(vocab)} ***") + return vocab def load_vocabulary_from_training(): print("*** Restoring vocabulary used for training ***") - vocab_path = os.path.join(DATA_DIR, VOCABULARY_FILENAME) + vocab_path = os.path.join(TRAINING_DATA_DIR, VOCABULARY_FILENAME) with open(vocab_path, "rb") as vocab_data_file: return pickle.load(vocab_data_file) @@ -173,16 +239,56 @@ def prepare_sequence_for_prediction(notes, vocab): return network_input +def get_class_weights(notes, vocab): + 
mapped_notes = [vocab[note] for note in notes] + notes_counter = Counter(mapped_notes) + + for key in notes_counter: + notes_counter[key] = 1 / notes_counter[key] + + return notes_counter + + def get_best_representation(vocab, pattern): - """assumption: all 12 single notes are present in vocabulary""" + # assumption: all single notes (not necessarily from the same octave) + # are present in vocabulary + if pattern in vocab.keys(): return vocab[pattern] - chord_sounds = [int(sound) for sound in pattern.split(".")] - unknown_chord = chord.Chord(chord_sounds) + # either an unknown chord or an unknown single note + chord_midis = [int(midi) for midi in pattern.split(".")] + unknown_chord = chord.Chord(chord_midis) root_note = unknown_chord.root() - print(f"*** Mapping {unknown_chord} to {root_note} ***") - return vocab[root_note.name] + + nearest_note_midi = find_nearest_single_note_midi(vocab, root_note.midi) + print(f"*** Mapping {pattern} to {nearest_note_midi} ***") + return vocab[str(nearest_note_midi)] + + +def find_nearest_single_note_midi(vocab, midi_note): + if str(midi_note) in vocab.keys(): + return midi_note + + midi_note_down = midi_note + midi_note_up = midi_note + + while midi_note_down >= 0 or midi_note_up <= 87: + midi_note_down -= 12 + midi_note_up += 12 + + print(f"{midi_note} {midi_note_up} {midi_note_down}") + + if midi_note_down >= 0 and str(midi_note_down) in vocab.keys(): + return midi_note_down + + if midi_note_up <= 87 and str(midi_note_up) in vocab.keys(): + return midi_note_up + + print( + f"ALERT: couldn't find any appropriate representation of {midi_note} in vocabulary. Returned a random representation." + ) + return random.choice([key for key in vocab.keys() if not "." in key]) def save_midi_file(prediction_output): @@ -193,18 +299,20 @@ def save_midi_file(prediction_output): for pattern in prediction_output: # pattern is a chord if ("." 
in pattern) or pattern.isdigit(): - notes_in_chord = pattern.split(".") + midis_in_chord = [int(midi) for midi in pattern.split(".")] notes = [] - for current_note in notes_in_chord: - new_note = note.Note(int(current_note)) + for current_midi in midis_in_chord: + new_note = note.Note(current_midi) new_note.storedInstrument = instrument.Piano() notes.append(new_note) + new_chord = chord.Chord(notes) new_chord.offset = offset output_notes.append(new_chord) # pattern is a note else: - new_note = note.Note(pattern) + midi = int(pattern) + new_note = note.Note(midi) new_note.offset = offset new_note.storedInstrument = instrument.Piano() output_notes.append(new_note) @@ -224,3 +332,5 @@ def save_midi_file(prediction_output): midi_stream = stream.Stream(output_notes) midi_stream.write("midi", fp=f"{RESULTS_DIR}/{output_name}.mid") + + print(f"Result saved as {output_name}") diff --git a/network.py b/network.py index fb1fc9c..48cb5f0 100644 --- a/network.py +++ b/network.py @@ -1,36 +1,36 @@ from keras.models import Sequential -from keras.layers import Dense -from keras.layers import Dropout -from keras.layers import LSTM -from keras.layers import BatchNormalization as BatchNorm -from keras.layers import Activation +from keras.layers import Dense, Dropout, LSTM, Activation from data_preparation import SEQUENCE_LENGTH, NUM_NOTES_TO_PREDICT def create_network(vocab_size, weights_filename=None): + lstm_units = 128 + dense_units = vocab_size * 2 + dropout_rate = 0.3 model = Sequential() model.add( LSTM( - 512, + lstm_units, input_shape=(SEQUENCE_LENGTH, NUM_NOTES_TO_PREDICT), return_sequences=True, ) ) - model.add(LSTM(512, return_sequences=True)) - model.add(LSTM(512)) - model.add(BatchNorm()) - model.add(Activation("relu")) - model.add(Dropout(0.3)) - model.add(Dense(256)) - model.add(BatchNorm()) - model.add(Activation("relu")) - model.add(Dropout(0.3)) + model.add(Dropout(dropout_rate)) + model.add(LSTM(lstm_units, return_sequences=True)) + model.add(Dropout(dropout_rate)) + model.add(LSTM(lstm_units, return_sequences=True)) + model.add(Dropout(dropout_rate)) + model.add(LSTM(lstm_units)) + model.add(Dense(dense_units)) + model.add(Dropout(dropout_rate)) model.add(Dense(vocab_size)) model.add(Activation("softmax")) - model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=['acc']) + model.compile(loss="categorical_crossentropy", optimizer="rmsprop", metrics=["acc"]) if weights_filename: print(f"*** Loading weights from {weights_filename} ***") model.load_weights(weights_filename) + model.summary() + return model diff --git a/notes_sequence.py b/notes_sequence.py index ee31eb9..065923c 100644 --- a/notes_sequence.py +++ b/notes_sequence.py @@ -1,5 +1,5 @@ -from keras.utils import Sequence, np_utils import math +from keras.utils import Sequence, np_utils import numpy as np diff --git a/predict.py b/predict.py index b158469..1a570ad 100644 --- a/predict.py +++ b/predict.py @@ -1,7 +1,6 @@ import os import sys import getopt -import pickle import tensorflow as tf import numpy as np from network import create_network @@ -64,7 +63,7 @@ def generate_notes(model, network_input, vocab, vocab_size): def generate_music(file): - notes = get_notes_from_file(file) + notes = get_notes_from_file(file, augment_data=False) vocab = load_vocabulary_from_training() vocab_size = len(vocab) diff --git a/prediction_data/Fiend_Battle_(Piano).mid b/prediction_data/Fiend_Battle_(Piano).mid new file mode 100644 index 0000000..94f5181 Binary files /dev/null and b/prediction_data/Fiend_Battle_(Piano).mid 
differ diff --git a/prediction_data/MIDI-Unprocessed_043_PIANO043_MID--AUDIO-split_07-06-17_Piano-e_1-03_wav--3.mid b/prediction_data/MIDI-Unprocessed_043_PIANO043_MID--AUDIO-split_07-06-17_Piano-e_1-03_wav--3.mid new file mode 100644 index 0000000..c4236e3 Binary files /dev/null and b/prediction_data/MIDI-Unprocessed_043_PIANO043_MID--AUDIO-split_07-06-17_Piano-e_1-03_wav--3.mid differ diff --git a/prediction_data/MIDI-Unprocessed_Recital5-7_MID--AUDIO_05_R1_2018_wav--3.mid b/prediction_data/MIDI-Unprocessed_Recital5-7_MID--AUDIO_05_R1_2018_wav--3.mid new file mode 100644 index 0000000..a089844 Binary files /dev/null and b/prediction_data/MIDI-Unprocessed_Recital5-7_MID--AUDIO_05_R1_2018_wav--3.mid differ diff --git a/prediction_data/MIDI-Unprocessed_Recital5-7_MID--AUDIO_07_R1_2018_wav--2.mid b/prediction_data/MIDI-Unprocessed_Recital5-7_MID--AUDIO_07_R1_2018_wav--2.mid new file mode 100644 index 0000000..89efc60 Binary files /dev/null and b/prediction_data/MIDI-Unprocessed_Recital5-7_MID--AUDIO_07_R1_2018_wav--2.mid differ diff --git a/prediction_data/cosmo.mid b/prediction_data/cosmo.mid new file mode 100644 index 0000000..3c05f61 Binary files /dev/null and b/prediction_data/cosmo.mid differ diff --git a/prediction_data/ff4-town.mid b/prediction_data/ff4-town.mid new file mode 100644 index 0000000..a748f7d Binary files /dev/null and b/prediction_data/ff4-town.mid differ diff --git a/results/cryptus_riffle-bars.mid b/results/cryptus_riffle-bars.mid deleted file mode 100644 index de72ac6..0000000 Binary files a/results/cryptus_riffle-bars.mid and /dev/null differ diff --git a/results/nonstriking_waylayers.mid b/results/nonstriking_waylayers.mid deleted file mode 100644 index 2413f35..0000000 Binary files a/results/nonstriking_waylayers.mid and /dev/null differ diff --git a/results/output_2021-05-27 15:53:40.754080.mid b/results/output_2021-05-27 15:53:40.754080.mid deleted file mode 100644 index 17aa15d..0000000 Binary files a/results/output_2021-05-27 15:53:40.754080.mid and /dev/null differ diff --git a/results/programs_hill-fort.mid b/results/programs_hill-fort.mid deleted file mode 100644 index 2413f35..0000000 Binary files a/results/programs_hill-fort.mid and /dev/null differ diff --git a/results/puszta_decolorize.mid b/results/puszta_decolorize.mid deleted file mode 100644 index 2413f35..0000000 Binary files a/results/puszta_decolorize.mid and /dev/null differ diff --git a/results/tottori_lipide.mid b/results/tottori_lipide.mid deleted file mode 100644 index 2413f35..0000000 Binary files a/results/tottori_lipide.mid and /dev/null differ diff --git a/results/waterbomber_phatter.mid b/results/waterbomber_phatter.mid deleted file mode 100644 index 37c0259..0000000 Binary files a/results/waterbomber_phatter.mid and /dev/null differ diff --git a/tensorboard.py b/tensorboard.py index 20e4920..ad76184 100644 --- a/tensorboard.py +++ b/tensorboard.py @@ -33,9 +33,7 @@ def parse_cli_args(): return port -if __name__ == "__main__": - port = parse_cli_args() - +def start_tensorboard(port): tensorboard_process = subprocess.Popen( ["tensorboard", "--logdir", LOG_DIR, "--port", str(port)] ) @@ -57,3 +55,8 @@ def parse_cli_args(): ngrok.kill() tensorboard_process.terminate() + + +if __name__ == "__main__": + port = parse_cli_args() + start_tensorboard(port) diff --git a/train.py b/train.py index 9298156..3b595b3 100644 --- a/train.py +++ b/train.py @@ -4,18 +4,21 @@ import sys import tensorflow as tf from keras.models import load_model -from keras.callbacks import ModelCheckpoint, EarlyStopping, 
TensorBoard +from keras.callbacks import ModelCheckpoint, TensorBoard from network import create_network from data_preparation import ( get_notes_from_dataset, prepare_sequences_for_training, create_vocabulary_for_training, - clean_data_and_checkpoints, + clear_training_data, + clear_checkpoints, + get_class_weights, ) LOG_DIR = "logs/" -BATCH_SIZE = 128 +BATCH_SIZE = 64 +DATASET_PERCENT = 1 def get_latest_checkpoint(): @@ -26,12 +29,13 @@ def get_latest_checkpoint(): checkpoints = ["checkpoints/" + name for name in os.listdir("checkpoints/")] if checkpoints: return max(checkpoints, key=os.path.getctime) - else: - return None + + return None def train_network(): notes = get_notes_from_dataset() + notes = notes[: int(len(notes) * DATASET_PERCENT)] vocab = create_vocabulary_for_training(notes) vocab_size = len(vocab) @@ -47,21 +51,21 @@ def train_network(): else: model = create_network(vocab_size) - train(model, training_sequence, validation_sequence) + class_weights = get_class_weights(notes, vocab) + + train(model, training_sequence, validation_sequence, class_weights) -def train(model, training_sequence, validation_sequence): +def train(model, training_sequence, validation_sequence, class_weights): filepath = "checkpoints/weights-improvement-{epoch:02d}-{loss:.4f}-bigger.hdf5" model_checkpoint = ModelCheckpoint( - filepath, monitor="loss", verbose=0, save_best_only=True, mode="min" + filepath, monitor="val_acc", verbose=0, save_best_only=True, mode="max" ) - early_stopping = EarlyStopping(monitor="val_loss", patience=3) - logdir = LOG_DIR + datetime.datetime.now().strftime("%Y%m%d-%H%M%S") tensorboard = TensorBoard(log_dir=logdir) - callbacks_list = [model_checkpoint, early_stopping, tensorboard] + callbacks_list = [model_checkpoint, tensorboard] model.fit( x=training_sequence, @@ -69,16 +73,15 @@ def train(model, training_sequence, validation_sequence): epochs=200, callbacks=callbacks_list, shuffle=True, + # class_weight=class_weights, ) def parse_cli_args(): - usage_str = ( - f"Usage: {sys.argv[0]} [-h] [-c | --clean (clean data/ and checkpoints/)]" - ) + usage_str = f"Usage: {sys.argv[0]} [-h] [--clear-data (clear training_data/)] [--clear-checkpoints (clear checkpoints/)]" try: - opts, _ = getopt.getopt(sys.argv[1:], "hc", ["clean"]) + opts, _ = getopt.getopt(sys.argv[1:], "h", ["clear-data", "clear-checkpoints"]) except getopt.GetoptError: print(usage_str) sys.exit(2) @@ -87,11 +90,15 @@ def parse_cli_args(): if opt == "-h": print(usage_str) sys.exit(0) - elif opt in ["-c", "--clean"]: - clean_data_and_checkpoints() + elif opt == "--clear-data": + clear_training_data() + elif opt == "--clear-checkpoints": + clear_checkpoints() if __name__ == "__main__": + tf.compat.v1.reset_default_graph() + gpus = tf.config.experimental.list_physical_devices("GPU") for gpu in gpus: tf.config.experimental.set_memory_growth(gpu, True) diff --git a/training_data/vocabulary b/training_data/vocabulary new file mode 100644 index 0000000..f345f7f Binary files /dev/null and b/training_data/vocabulary differ diff --git a/visualize_notes_data.py b/visualize_notes_data.py new file mode 100644 index 0000000..e694f90 --- /dev/null +++ b/visualize_notes_data.py @@ -0,0 +1,56 @@ +from collections import Counter +from math import ceil, sqrt +import matplotlib.pyplot as plt +import pandas as pd +from statsmodels.graphics.gofplots import qqplot +from data_preparation import get_notes_from_dataset, create_vocabulary_for_training +from train import DATASET_PERCENT + +print("Loading data...") +notes = 
get_notes_from_dataset()
+notes = notes[: int(len(notes) * DATASET_PERCENT)]
+
+print("Creating vocabulary...")
+vocab = create_vocabulary_for_training(notes)
+vocab_size = len(vocab)
+
+print("Mapping notes using vocabulary...")
+mapped_notes = [vocab[note] for note in notes]
+
+print("Counting occurrences...")
+notes_counter = Counter(mapped_notes)
+
+print("Rearranging occurrences...")
+counter_size = len(notes_counter)
+least_common = notes_counter.most_common(counter_size)
+least_common.reverse()
+occurrences = [0] * counter_size
+
+half_len = ceil(counter_size / 2)
+for i in range(half_len):
+    if counter_size % 2 == 1 and i == half_len - 1:
+        occurrences[i] = least_common[2 * i][1]
+    else:
+        occurrences[i] = least_common[2 * i][1]
+        occurrences[-i - 1] = least_common[2 * i + 1][1]
+
+sqrt_occurrences = [sqrt(occurrence) for occurrence in occurrences]
+
+print("Plotting...")
+fig, axes = plt.subplots(2, 1)
+fig.canvas.set_window_title("Maestro")
+fig.suptitle("MIDI chords/notes from the Maestro dataset")
+
+axes[0].set_title(
+    "Distribution of single MIDI notes and chords (note.note.note...) mapped to the vocabulary, sorted by the lowest MIDI value in each chord"
+)
+axes[0].bar([i for i in range(len(occurrences))], sqrt_occurrences)
+axes[0].set_ylabel("sqrt(occurrences)")
+axes[0].set_xlabel("chord")
+axes[0].set_xlim([0, len(occurrences)])
+
+df_pitches = pd.DataFrame(occurrences, columns=["occurrences"])
+axes[1].set_title("qqplot (normal distribution)")
+qqplot(df_pitches["occurrences"], line="s", ax=axes[1])
+
+plt.show()
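
Note on the class-weighting path in this change: train.py now computes inverse-frequency weights with get_class_weights(), but the class_weight argument in model.fit stays commented out. Below is a minimal sketch of how those weights could be normalized and enabled; the mean-normalization step is an illustrative assumption and is not part of this diff.

    # Hypothetical sketch (not part of this diff): get_class_weights() returns a
    # Counter of {class index: 1 / count}; rescaling so the weights average to 1
    # keeps the effective learning rate comparable to unweighted training.
    class_weights = get_class_weights(notes, vocab)
    mean_weight = sum(class_weights.values()) / len(class_weights)  # assumed normalization
    class_weights = {index: weight / mean_weight for index, weight in class_weights.items()}

    model.fit(
        x=training_sequence,
        validation_data=validation_sequence,
        epochs=200,
        callbacks=callbacks_list,
        shuffle=True,
        class_weight=class_weights,  # dict mapping class index to weight, as Keras expects
    )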