From e60ac35e0846a8fe40eb583e3b1778b908c948b7 Mon Sep 17 00:00:00 2001 From: Jaybit0 Date: Mon, 3 Feb 2025 09:32:03 +0100 Subject: [PATCH 1/9] Squashed commit of the following: commit 21758a82b2d64835e68ea3f8a99fd96ecefa5fe0 Author: Jaybit0 Date: Fri Jan 31 09:51:15 2025 +0100 Documentation fix commit 39af800bc3bed99e3a5bb1bab8ae1b3aaed00df0 Author: Jaybit0 Date: Fri Jan 31 09:41:51 2025 +0100 Minor fixes for documentation and legacy test removal commit 8bd149977fc5929c62c4185b6e5a305dbc5e90cd Author: Jaybit0 Date: Thu Jan 30 16:34:41 2025 +0100 Remove Legacy Test commit 98ef774f8560b2c21606637a5c7f5edfcdb223b4 Author: Jaybit0 Date: Thu Jan 30 16:01:18 2025 +0100 Revert changes in ApplyTransformTest commit 7db72d6aa68fc0abd36486f3264f2c6ccc531761 Author: Jaybit0 Date: Thu Jan 30 15:59:33 2025 +0100 Cleanup commit 6f9f59a41bb0be04df01e0565a44c33cedc9a61f Author: Jaybit0 Date: Thu Jan 30 15:53:12 2025 +0100 Some more cleanup commit 976eb8fe266a1c1551c06e5aadd7989334afe9be Author: Jaybit0 Date: Thu Jan 30 15:22:07 2025 +0100 Revert change from OptUtils commit 27cebd5a4025a4207da5306458a301876092b764 Author: Jaybit0 Date: Thu Jan 30 15:18:36 2025 +0100 Code Cleanup commit 4c447db2416289aa76feb022e45dde75d965329e Author: Jaybit0 Date: Thu Jan 30 15:02:33 2025 +0100 Update ProgramRewriter.java commit 8beb9b3858b6914570a86e118e77c749dcff5e9f Author: Jaybit0 Date: Thu Jan 30 14:58:38 2025 +0100 Update MetaPropagator.java commit fa6d98b7a65e637a65df1027bdf1553aee7981ac Author: Jaybit0 Date: Thu Jan 30 14:50:48 2025 +0100 Update RewriterFramework.java commit 98090c07f1ccb7abc23effd13c60d982afd18f77 Author: Jaybit0 Date: Thu Jan 30 13:59:47 2025 +0100 Bugfix commit 94d46a8d05bd02bd86cebbccba7b7b21f908b9e7 Author: Jaybit0 Date: Thu Jan 30 13:41:03 2025 +0100 Various small bugfixes commit 66350bac5e64b3660099ea0ebc3ce0bdb0d9d799 Author: Jaybit0 Date: Thu Jan 30 13:16:26 2025 +0100 Remove debug log commit d8292c99485cc0866917a0ee8d4dd7eac9473680 Author: Jaybit0 Date: Thu Jan 30 13:11:42 2025 +0100 Disable Logs commit 2d0fb82c0d5c787ccc7f31f9ef1ffd11d531d215 Author: Jaybit0 Date: Thu Jan 30 13:11:03 2025 +0100 Bugfix commit ce97d7a7e7f2e0a8e5918230836300ab746933cf Author: Jaybit0 Date: Thu Jan 30 10:40:34 2025 +0100 Some more bugfixes commit 13db978676353b1d0ab20b9c4d171eb62d0a714d Author: Jaybit0 Date: Thu Jan 30 10:17:40 2025 +0100 Logging via log4j commit f85493a0cfe6bb0be6232fa2693e7f2ec0dbfb2d Author: Jaybit0 Date: Thu Jan 30 09:55:05 2025 +0100 Disable all (knowingly) failing tests commit cd4315076cef78e0c22a343883e33573b4f37597 Merge: 3647c74241 8eed966ea2 Author: Jaybit0 Date: Wed Jan 29 16:49:06 2025 +0100 Merge branch 'main' of https://github.com/Jaybit0/systemds commit 3647c74241e5f721219ba8b92dc4f68e8e314043 Merge: c3ddfbff61 d44ec2bc85 Author: Jaybit0 Date: Wed Jan 29 16:48:58 2025 +0100 Merge branch 'RewriterRebase' commit c3ddfbff61335af71647f7df51cb38b4814252af Author: Jaybit0 Date: Thu Aug 15 13:32:51 2024 +0200 Update ProgramBlock.java commit 8181a0192f44f7fdb978b00b620239b7f2277adc Author: Jaybit0 Date: Sun Aug 4 10:12:44 2024 +0200 Workaround to avoid increasing object size of LineageItem commit e6bdcc4e4b67032066b88285fc61159e7775fe01 Author: Jaybit0 Date: Fri Aug 2 13:32:27 2024 +0200 NGrams now use lineage per default commit f85071e60d7075f197f30468d7e30246d98206c8 Author: Jaybit0 Date: Thu Jul 4 12:45:33 2024 +0200 Some more bugfixes commit e97e1408ea2b0e535fa53b0c6328882f3d259a99 Author: Jaybit0 Date: Mon Jul 1 16:12:52 2024 +0200 More extensive stats for n-grams commit d44ec2bc85d67c0e0aee2cf811311efbf87fa065 Author: Jaybit0 Date: Wed Jan 29 16:42:51 2025 +0100 Update RewriterFramework.java commit 54bc0b503f2f6c916018f5b3ecfcfcff6a7ef1ef Author: Jaybit0 Date: Wed Jan 29 16:12:24 2025 +0100 More Code Cleanup commit 1d316f272f379e8aeede3f0cccc45bece7f0d13b Author: Jaybit0 Date: Wed Jan 29 15:53:32 2025 +0100 Remove additional files commit 811d82a224c024d3e7710c4bbd8078f7ed9acea2 Author: Jaybit0 Date: Wed Jan 29 15:52:13 2025 +0100 Cleanup TestBase commit ab03423d05b632f298fa977623a627ef99231057 Author: Jaybit0 Date: Wed Jan 29 15:41:58 2025 +0100 Reverting some changes commit 313d8debec11a57852a83adbb15049620fa42ee8 Author: Jaybit0 Date: Wed Jan 29 14:57:25 2025 +0100 Minor bugfixes commit d74190293125da5530214a520bceaa91bda9fd68 Author: Jaybit0 Date: Wed Jan 29 14:39:59 2025 +0100 Some improvements commit 259696eaad5f142e02952f60015f0c8d4491c2ae Author: Jaybit0 Date: Wed Jan 29 12:27:31 2025 +0100 Better error handling commit d077f937dceba82cee0f9b6ce74349a8561c59b4 Author: Jaybit0 Date: Wed Jan 29 12:10:27 2025 +0100 Fix rowVec/colVec commit 994bfdbbc3aaafe03cc87992b57d5ace5e98c78f Author: Jaybit0 Date: Wed Jan 29 12:04:14 2025 +0100 Update log4j.properties commit efb805d098e2cb0df5f935fe8bc35b42f3299728 Author: Jaybit0 Date: Wed Jan 29 12:04:09 2025 +0100 Create RewriterFramework.java commit 95e440b64e1e7494f8081de77da9bfd85a0c0cc0 Author: Jaybit0 Date: Wed Jan 29 12:03:59 2025 +0100 Bugfix CostEstimator commit 4117aa43d137d634aaa408f02aa1d6e84e0a1d9f Author: Jaybit0 Date: Wed Jan 29 11:07:22 2025 +0100 Update TopologicalSort.java commit ab82249677d78ab949c697c171865887e46f24d0 Author: Jaybit0 Date: Tue Jan 28 16:16:53 2025 +0100 More Code Cleanup commit c13a1a815c88a1c23649b6aadc0acf80a8111ebc Author: Jaybit0 Date: Tue Jan 28 16:10:26 2025 +0100 Further Code Cleanup commit 3b07f756997bfdfe30d42ca4747f8a847e8deed6 Author: Jaybit0 Date: Tue Jan 28 15:42:26 2025 +0100 Code Cleanup and add Licenses commit 966db9344acbe09d253a050954fc0341ffd82e33 Author: Jaybit0 Date: Tue Jan 28 13:24:14 2025 +0100 Add License to Tests commit 4a58174ebeef79b7fdac855c01d4904390809c5d Author: Jaybit0 Date: Tue Jan 28 12:45:53 2025 +0100 Cleanup commit 0abd3784d609f3b53cd458f8a1a70583f01c7f93 Author: Jaybit0 Date: Tue Jan 28 12:45:43 2025 +0100 Increased stability of our sort algorithm commit 2f6de8fc74c5f53d349ba61fea326fc4b6049f6d Author: Jaybit0 Date: Tue Jan 28 12:03:35 2025 +0100 Bugfix: Diag commit ef49168dd7dd1df68f85bfc1df65c4801e1113e1 Author: Jaybit0 Date: Tue Jan 28 11:50:36 2025 +0100 Bugfix fix commit 8fca8e6d25338f256a94699692498589b5e39443 Author: Jaybit0 Date: Tue Jan 28 11:44:05 2025 +0100 Bugfix commit edc93d96763661809884e7ba5e64dd9e2437188c Author: Jaybit0 Date: Thu Jan 23 10:25:25 2025 +0100 Update commit e30fac577a91c1d92ff92679fdb460944bec5e50 Author: Jaybit0 Date: Mon Jan 20 18:57:47 2025 +0100 Update commit a1650c5369e0fab18347b8105ac251723952e4b3 Author: Jaybit0 Date: Mon Jan 20 10:32:37 2025 +0100 Various bugfixes commit 74d5235b8e5006edd2a74968fc5ea6add3651fb0 Author: Jaybit0 Date: Sun Jan 19 10:43:33 2025 +0100 Some updates commit adc21e7ade4958527d04df9ab4fa01f4bdeda8fe Author: Jaybit0 Date: Sat Jan 18 15:01:43 2025 +0100 Bugfix for BinaryOp construction commit 32dd096c25bfff0a3e72a2032fc8bbb683e15e18 Author: Jaybit0 Date: Fri Jan 17 12:57:53 2025 +0100 Update commit bda5cb9ced5583174f42e299fe83fd053db7b22e Author: Jaybit0 Date: Thu Jan 16 16:54:14 2025 +0100 Update commit 61f26c9813604d66f7b025f06ef0aba830607ea8 Author: Jaybit0 Date: Thu Jan 16 15:57:14 2025 +0100 Some updates commit 81f1b0fe78f9e9fb06e2c8c9a75e5709545c9bd9 Author: Jaybit0 Date: Thu Jan 16 15:07:43 2025 +0100 Some fixes commit ff36f126113def38b5b1fb71ab422d38f3a584cf Author: Jaybit0 Date: Thu Jan 16 14:26:01 2025 +0100 Some updates commit a22f3dd8930cd28c6f924d1a192d8e2feda76e7f Author: Jaybit0 Date: Thu Jan 16 13:16:24 2025 +0100 Update commit a4ca40a51d34b0fc5440f3a874e7c51d622df220 Author: Jaybit0 Date: Thu Jan 16 11:50:29 2025 +0100 Better literal support commit ad3c0798169db8087ce53a1003b69ad350c8cc4b Author: Jaybit0 Date: Thu Jan 16 11:18:31 2025 +0100 Add licenses commit a9c61d7ff384d21c2d50d701e206e3db3f332ac1 Author: Jaybit0 Date: Thu Jan 16 11:15:40 2025 +0100 Update TopologicalSort.java commit d46e9832b0d6a51a9baefd7af9bdb66a7ce96ff1 Author: Jaybit0 Date: Thu Jan 16 11:13:06 2025 +0100 Some bugfixes commit b6b8cb5af43690e0972d3cc65764e32681a5d3c3 Author: Jaybit0 Date: Tue Jan 14 09:38:49 2025 +0100 Update commit 4b6d103a6287eeea972b1853dacf6045ef634e96 Author: Jaybit0 Date: Mon Jan 13 14:56:40 2025 +0100 Some fixes commit 3bdc01c9635c36485c9c33d36f7c866dee678dff Author: Jaybit0 Date: Mon Jan 13 13:11:36 2025 +0100 Some more fixes commit 4f3a9b730385da22d580d7a516556b2882eb9e08 Author: Jaybit0 Date: Mon Jan 13 12:10:18 2025 +0100 Backup commit 4078d3476488c6ab3ccff7ae897caa4124e05d77 Author: Jaybit0 Date: Mon Jan 13 12:01:55 2025 +0100 Some more fixes commit 7d6299d48b1a2df4b5285b2ad90d5e6ecc959a89 Author: Jaybit0 Date: Sun Jan 12 18:32:04 2025 +0100 Update RewriterNormalFormTests.java commit e3c42ad5d46b71fa043706e7287139db72abc852 Author: Jaybit0 Date: Sun Jan 12 18:29:26 2025 +0100 Update RewriterNormalFormTests.java commit 62647fb36a8d5141bedf71bfbcdc18aa2454746b Author: Jaybit0 Date: Sun Jan 12 18:29:04 2025 +0100 Update RewriterNormalFormTests.java commit 2dc44c4ba38867b994bd7ebab489ac83bba04389 Author: Jaybit0 Date: Sun Jan 12 18:25:55 2025 +0100 Update RewriterNormalFormTests.java commit 13cb5d13c722f39ae4b156facee3b71fc1a64624 Author: Jaybit0 Date: Sun Jan 12 18:24:18 2025 +0100 Normal Form Tests commit 802360625e71044a9d0a91414da809ae8eeda58e Author: Jaybit0 Date: Sun Jan 12 14:59:06 2025 +0100 Some tests for experiments commit 2851b1f8e6583b336ad49512896e8caff419de74 Author: Jaybit0 Date: Sat Jan 11 17:42:52 2025 +0100 Some more updates commit 5aa69ea8494c1a398be854cc479efdcbe3741843 Author: Jaybit0 Date: Thu Jan 9 16:20:33 2025 +0100 Fix build commit 0c25af4aba416944b80bf2aaaca08fb6ad9f9055 Author: Jaybit0 Date: Thu Jan 9 13:46:52 2025 +0100 Some improvements and fixes commit a8244ef5ad545bac0d4e3a48c706517792b34b9b Author: Jaybit0 Date: Wed Jan 8 16:47:43 2025 +0100 Various fixes commit 549d4566eca5fc8714430504060ad07a5c10cc5b Author: Jaybit0 Date: Tue Jan 7 16:44:49 2025 +0100 Multirule parsing commit 53e58534d4a38b64baa704a2e7fddc9b750951a4 Author: Jaybit0 Date: Tue Jan 7 16:19:42 2025 +0100 Some more fancy stuff commit 927dedcceacd89b38e4e9d960021eda094d73742 Author: Jaybit0 Date: Tue Jan 7 15:45:27 2025 +0100 Further bugfixes commit 95de6b838ca414d82bf61beff9263591398ac7f5 Author: Jaybit0 Date: Mon Jan 6 16:51:50 2025 +0100 Some more fixes commit a26ea8c6dd96a84d362c439dee8e102e85158296 Author: Jaybit0 Date: Mon Jan 6 15:25:37 2025 +0100 Some more improvements commit a8bb0d4697026d801871dbb69a16a53e989fbb03 Author: Jaybit0 Date: Mon Jan 6 15:12:34 2025 +0100 Some more cleanup commit 7058b266a46a612fdd93698868dca1253ad1b6ff Author: Jaybit0 Date: Mon Jan 6 15:07:37 2025 +0100 Some bugfixes commit e5d16ba80341d41b354af7f7f8bf20b42148d868 Author: Jaybit0 Date: Sun Jan 5 17:09:03 2025 +0100 Update commit 78a5561c48025f3861b001e0a5e2b30f6f1b25bf Author: Jaybit0 Date: Sat Jan 4 16:15:41 2025 +0100 Some more bugfixes commit 2439b0fab73d0b565f8118cb777fa53a8f4187c2 Author: Jaybit0 Date: Sat Jan 4 15:31:30 2025 +0100 Some more bugfixes commit 08898c7a3223a4bf0fe3d37293713b1637a30d3b Author: Jaybit0 Date: Sat Dec 21 10:39:30 2024 +0100 Backup commit d98a93b1297b6c7ff059a3e42c16be0dab6f1167 Author: Jaybit0 Date: Wed Dec 18 13:05:59 2024 +0100 Update RewriterAlphabetEncoder.java commit 2520d16218ea13ef8f009ff30c608b0ba8be03fb Author: Jaybit0 Date: Wed Dec 18 12:31:47 2024 +0100 Some bugfixes commit b5ec81301249d1cfbe5694c573ffce53bc7f214b Author: Jaybit0 Date: Mon Dec 16 17:15:00 2024 +0100 Some more fixes commit f74b7f8e46e0e83c39390cf937abe17e18841058 Author: Jaybit0 Date: Mon Dec 16 16:40:28 2024 +0100 Some more updates commit 26eaeadd45dd5d9d9ef4144b150f64ef5b94841b Author: Jaybit0 Date: Mon Dec 16 16:00:07 2024 +0100 Some more bugfixes commit bfcbe440e9c9fb46e202fa8caaf06dd9505da6e3 Author: Jaybit0 Date: Mon Dec 16 15:35:28 2024 +0100 Some improvements commit 0d5651f3e116010abf13fd65d2b9c70d762521d2 Author: Jaybit0 Date: Sat Dec 14 12:50:47 2024 +0100 Update TopologicalSort.java commit d885d53ebcfce80cf125cd3c4d0a46d758ac6d73 Author: Jaybit0 Date: Sat Dec 14 12:48:17 2024 +0100 Several bugfixes commit 511efac805a7e6b0a2004e45b774dba06b41a52c Author: Jaybit0 Date: Sat Dec 14 11:55:10 2024 +0100 Some more changes commit 0346da1e2260943d2fde404a62ec4b9d95094817 Author: Jaybit0 Date: Sat Dec 14 11:31:49 2024 +0100 Some more bug-fixes commit b34d73e7bb733e56c7fbc4c56c645f208ec6988e Author: Jaybit0 Date: Fri Dec 13 19:13:45 2024 +0100 Bugfix type casting commit ecd9e80f80162664b11ad02425473ca0a0856e89 Author: Jaybit0 Date: Fri Dec 13 09:22:41 2024 +0100 Backup commit abfc8b4b321c56e9e18370eef48d4f35d267aa95 Author: Jaybit0 Date: Thu Dec 12 16:15:38 2024 +0100 Benchmarking infrastructure commit 0709da319a6b5fd15284034f12365641a020daaf Author: Jaybit0 Date: Wed Dec 11 16:41:50 2024 +0100 More improvements commit 927eb2b7679ef860d29881ad46af42725ed3067a Author: Jaybit0 Date: Wed Dec 11 15:33:26 2024 +0100 Some more improvements commit b90a4a4255a0a7c4370626f091ebfdac448478f3 Author: Jaybit0 Date: Wed Dec 11 15:22:49 2024 +0100 Fix commit 605d1bdba7600c01d39f9850f3449d62256a99a6 Author: Jaybit0 Date: Wed Dec 11 14:38:04 2024 +0100 Update RewriterRuleCollection.java commit 26414e8962aac4394eb5f01985af85e7df401478 Author: Jaybit0 Date: Wed Dec 11 14:13:58 2024 +0100 Bugfix commit 47515816044a65a5fc6dbb0b60743f69341af0f3 Author: Jaybit0 Date: Tue Dec 10 16:37:20 2024 +0100 Some progress commit f8250c106b1e52a3588db5899f8b8355df0095cd Author: Jaybit0 Date: Tue Dec 10 16:15:54 2024 +0100 Some fixes commit c4075cd5ff51879b0d8c7cd24eaa56fb0fdf9ed8 Author: Jaybit0 Date: Tue Dec 10 15:05:01 2024 +0100 Bugfix commit 55702af29462f3ed81c81093359eeda87804430e Author: Jaybit0 Date: Tue Dec 10 09:22:37 2024 +0100 Update GeneratedRewriteClass.java commit a012b3279d72130e22ff82e8d3ece4730eb3b36f Author: Jaybit0 Date: Mon Dec 9 17:34:21 2024 +0100 Bugfix commit a68d6b6b67bd32cd0c8dcbb690807ba4fcaa63bd Author: Jaybit0 Date: Sat Dec 7 14:54:09 2024 +0100 Backup commit 76fa4edcfdc817106a6a305e33d563792332963a Author: Jaybit0 Date: Sat Dec 7 12:38:25 2024 +0100 Update CodeGenTests.java commit 398881765f2d0165ae0588aca8a69ec2849258e0 Author: Jaybit0 Date: Fri Dec 6 13:37:14 2024 +0100 Bugfix commit 5f9be9e2d051466d5f688441494e8d1bee98e799 Author: Jaybit0 Date: Fri Dec 6 12:59:32 2024 +0100 Some improvements commit 616c199333332c4c6983f72cfd527f5786a21456 Author: Jaybit0 Date: Thu Dec 5 15:17:03 2024 +0100 Some changes commit 123523619475718ec96978622ef7b35679c3f57d Author: Jaybit0 Date: Thu Dec 5 11:58:37 2024 +0100 Some cleanup (TODO: Check if something broke) commit 0322c6061fde5c2543f367c7e80776db2429a37b Author: Jaybit0 Date: Thu Dec 5 11:50:41 2024 +0100 Some more improvements commit cb48d354f7766a7728767d6988af9537766bab2f Author: Jaybit0 Date: Thu Dec 5 10:46:57 2024 +0100 Some more improvements commit 8e7257845fb1816ba39f522d5d73170f39523f70 Author: Jaybit0 Date: Wed Dec 4 13:29:59 2024 +0100 Some improvements commit d51f5ed785d41ab0bb1ccce28e0cc0c1f597bb24 Author: Jaybit0 Date: Wed Dec 4 12:13:09 2024 +0100 Some more improvements commit 45113496093ad689079405098e34b78662aeb468 Author: Jaybit0 Date: Tue Dec 3 14:44:33 2024 +0100 Some more improvements commit 838fe934f81729ed1ecb36d3dadac622c59944a0 Author: Jaybit0 Date: Tue Dec 3 14:26:31 2024 +0100 Update RewriterCostEstimator.java commit e1848cef2b0c19f39bcee45d469f33bd2bbb9c23 Author: Jaybit0 Date: Tue Dec 3 13:07:56 2024 +0100 Bugfix commit ed84ce754e22e81252224178c354dbafb62fa235 Author: Jaybit0 Date: Tue Dec 3 12:57:23 2024 +0100 Some improvements commit 6e93733042ff4067adcc394077c90ca0ff95ff11 Author: Jaybit0 Date: Tue Dec 3 11:29:02 2024 +0100 Bugfix commit feecf16b354cadcc0c0762fd24753ddf6826a609 Author: Jaybit0 Date: Tue Dec 3 11:25:29 2024 +0100 Some more improvements commit 5bbda9a55f0b9f109ce4a9477d6797fefaa207c3 Author: Jaybit0 Date: Mon Dec 2 15:09:31 2024 +0100 Sparsity esimtation commit edc3f177acbb0cb7fe53f7ea4daed8384129057c Author: Jaybit0 Date: Mon Dec 2 13:59:41 2024 +0100 NNZ estimatior commit f7d4cb42db6ba207cea905ce74645fda2fe29dec Author: Jaybit0 Date: Mon Dec 2 11:59:25 2024 +0100 Some more improvements commit 1e17b4f326156c680a1852798aca2acfecf4168e Author: Jaybit0 Date: Mon Dec 2 11:34:13 2024 +0100 Some more improvements commit 81c457ccf0e2efd2734c5dd539e8d65adf041a39 Author: Jaybit0 Date: Sat Nov 30 12:02:35 2024 +0100 Some more improvements commit bf49b5227f8ec141789ab62dadba0e3537051572 Author: Jaybit0 Date: Sat Nov 30 11:45:01 2024 +0100 Some bugfixes commit ab0ccda6fc0f52c747ff9016e7384cfdeef7b3de Author: Jaybit0 Date: Fri Nov 29 12:41:42 2024 +0100 Some improvements commit 61248fc8b194a27e16f786926483360eecd8fb01 Author: Jaybit0 Date: Fri Nov 29 10:49:48 2024 +0100 Some progress commit a94575f11f0bc08bf1e769b0e94e5e21cebdbbd6 Author: Jaybit0 Date: Fri Nov 29 10:42:10 2024 +0100 Some improvements commit 4067c7e56d54462893466a54620e8171536b9664 Author: Jaybit0 Date: Thu Nov 28 18:06:15 2024 +0100 Some more improvements commit 77418caaca9c06ccdd275f95eacfc0174342d79d Author: Jaybit0 Date: Thu Nov 28 15:48:48 2024 +0100 Some breakthroughs commit a0441d886aca24d37d0b9174d31ea190e91af6ea Author: Jaybit0 Date: Thu Nov 28 12:03:10 2024 +0100 Some more stuff commit 32a46d616833b44d3fa00ca75a33137e9fcfb2cc Author: Jaybit0 Date: Wed Nov 27 17:27:32 2024 +0100 Backup commit d69232cd4c824ab51d8e13beb06eab3a835610d4 Author: Jaybit0 Date: Wed Nov 27 16:39:43 2024 +0100 Some more improvements commit 0a8ed4dd340d0fd3b8d89624e0332a6bbf2f8c83 Author: Jaybit0 Date: Wed Nov 27 16:25:29 2024 +0100 Some more fixes commit 7b4de8d4a0de0f5fd37c8f0edca0acc95c44518c Author: Jaybit0 Date: Wed Nov 27 16:13:37 2024 +0100 Some improvements commit 0a641e911e4332e228db69cd3b9591270322adfe Author: Jaybit0 Date: Wed Nov 27 16:01:04 2024 +0100 Some more bugfixes commit 77f5d57d6fd35a5c1612b1cdc15bfae085c55041 Author: Jaybit0 Date: Wed Nov 27 15:56:29 2024 +0100 Some bugfixes commit 31a410ff8f8e94acebebe2ee1d199e306d441c90 Author: Jaybit0 Date: Wed Nov 27 15:24:44 2024 +0100 Some more changes commit 81a272e1018dae514d395588b62518a831e28d8c Author: Jaybit0 Date: Wed Nov 27 15:15:08 2024 +0100 Some better constant handling commit 613657b6f820b383a40f3290830d1b7e67b4935a Author: Jaybit0 Date: Wed Nov 27 14:16:07 2024 +0100 Some more fixes commit b8d3424b83f9fb030fb101fcd840be1869ce9fe3 Author: Jaybit0 Date: Tue Nov 26 19:00:30 2024 +0100 Some more updates commit 05b39d813868722952771ea1250e64744b464161 Author: Jaybit0 Date: Tue Nov 26 18:38:19 2024 +0100 Some more fixes commit ab76040baeab4119317a6594703a1141b8b79f6f Author: Jaybit0 Date: Tue Nov 26 17:31:15 2024 +0100 Some more improvements commit 9852e362f8de47e563467a4e61d64ba06ebfc1ca Author: Jaybit0 Date: Tue Nov 26 17:08:21 2024 +0100 Some more improvements commit e354cc1c6edb07833192c66e4a6bdd557e2159bc Author: Jaybit0 Date: Tue Nov 26 15:52:43 2024 +0100 Some more changes commit 04867f37afa2ca356284ed2a558d0c20d32ed9cf Author: Jaybit0 Date: Tue Nov 26 11:38:06 2024 +0100 Some bugfixes (origin still unknown) commit 3bdc6c6354c51787c4504ad287468207810a06b2 Author: Jaybit0 Date: Tue Nov 26 10:40:25 2024 +0100 Some bugfixes commit c4b5a1fb8dc6d70a73258691c062be69403b01c9 Author: Jaybit0 Date: Mon Nov 25 19:51:01 2024 +0100 Major improvements to randomized search commit ca7d84dadb96f318570914f00956409e0d884349 Author: Jaybit0 Date: Mon Nov 25 17:54:59 2024 +0100 Some more improvements commit 470d1be36c18ca01c31814e49ca1656ae82e8cd0 Author: Jaybit0 Date: Mon Nov 25 17:13:41 2024 +0100 Some more bugfixes commit a3500efa9fbec51ea66a164d4d3ed153a4f01d15 Author: Jaybit0 Date: Mon Nov 25 15:11:30 2024 +0100 Some bugfixes commit 7e50769b3fc2f9ade9673aa98dda145cc80454e7 Author: Jaybit0 Date: Mon Nov 25 14:17:26 2024 +0100 Some more fixes commit 99e9d9aaeb67d41a46625f31785044c8318bec16 Author: Jaybit0 Date: Mon Nov 25 12:25:11 2024 +0100 Some more improvements commit 0f9f5ad04e8fead095ea27bd917f36c6282129de Author: Jaybit0 Date: Mon Nov 25 11:27:23 2024 +0100 Some improvements commit 58193f561058399515e6de03e69d04a4e1904e15 Author: Jaybit0 Date: Sun Nov 24 16:34:20 2024 +0100 Some improvements commit 91d0086badb80a2099a731192312341cb72fbc28 Author: Jaybit0 Date: Fri Nov 22 18:59:50 2024 +0100 Error handling commit 023d8d579bdaf6a742098d1b90a48c8c8f4d9a68 Author: Jaybit0 Date: Fri Nov 22 18:53:18 2024 +0100 Some more improvements commit ddf613606534b242bae9c758a1ff749cd5e38171 Author: Jaybit0 Date: Fri Nov 22 18:46:36 2024 +0100 Some more improvements commit 52075ac1e41a5a29d9a71babed1cc1a18e270aca Author: Jaybit0 Date: Fri Nov 22 18:10:36 2024 +0100 Some more improvements commit 970149f61af50c312d41f67a34c0cda5dc679d3c Author: Jaybit0 Date: Fri Nov 22 17:56:31 2024 +0100 Some improvements commit 47ee20ff6c43a48f860ee89556d213dea7766166 Author: Jaybit0 Date: Fri Nov 22 16:43:03 2024 +0100 Some more fixes commit 969e59325af0cc2df3509be86d00d780d7bbb23a Author: Jaybit0 Date: Fri Nov 22 15:33:22 2024 +0100 Some improvements commit b076ad5b10d1571f82258fceaef76e469eff2ac4 Author: Jaybit0 Date: Fri Nov 22 13:44:11 2024 +0100 Some cost fixes commit 18f5db9416d7c3eea2a4269d8e18922b672c356d Author: Jaybit0 Date: Fri Nov 22 13:04:54 2024 +0100 Some more fixes commit 6303ca411efd9cba75988ec0547cff9b928b6baa Author: Jaybit0 Date: Fri Nov 22 13:00:24 2024 +0100 LogNZ support commit e9ff9535e6012d5ef9080c078c70f640557193d8 Author: Jaybit0 Date: Fri Nov 22 12:07:18 2024 +0100 First fused op commit f1378e9bd9e1cb388bfe5d1f08da687c299e962c Author: Jaybit0 Date: Fri Nov 22 11:25:42 2024 +0100 Test fix commit 2522042e048760137354aabd318d727a846e1b69 Author: Jaybit0 Date: Fri Nov 22 11:22:02 2024 +0100 Some more minor fixes commit 1a0effe0ee1765a51bfdf6ec751857f819485610 Author: Jaybit0 Date: Fri Nov 22 11:16:38 2024 +0100 Some fixes commit 6283cb4757c273ec6b747aa92156807fd19e3b44 Author: Jaybit0 Date: Fri Nov 22 11:12:36 2024 +0100 Some more bugfixes commit 4f6587a706efafe27fa53fdcd385389a6eff0606 Author: Jaybit0 Date: Thu Nov 21 17:02:10 2024 +0100 Some fixes commit 89f119e3f5659596d3d9f57df52102ff1cd2eb41 Author: Jaybit0 Date: Thu Nov 21 16:21:02 2024 +0100 Some bugfixes commit 6f0115e8a79dcbeaae791591c5f68880e54d2094 Author: Jaybit0 Date: Thu Nov 21 15:03:38 2024 +0100 Some more improvements commit c9c2eda07f517ef514551d483f07d19b745a23fd Author: Jaybit0 Date: Thu Nov 21 14:43:18 2024 +0100 Bugfix commit 49716370aa65f383d6ce61c0b6357d3db93df1a4 Author: Jaybit0 Date: Thu Nov 21 14:14:29 2024 +0100 Bugfix commit 99b720e246004b68783618d8a4bd7d471b40b2f9 Author: Jaybit0 Date: Thu Nov 21 14:07:50 2024 +0100 Some improvements commit f11e557bc8c1f449d7015b17923e257a48bbcaf6 Author: Jaybit0 Date: Thu Nov 21 12:09:46 2024 +0100 Some more fixes commit dd5cad5f1dcd7f808993653bbe91b5504ce027ce Author: Jaybit0 Date: Thu Nov 21 11:44:49 2024 +0100 Some bugfixes commit 520ff961bd3b91c8423b5a14907b6848cbe6eab8 Author: Jaybit0 Date: Thu Nov 21 10:43:43 2024 +0100 Some improvements to the rules commit 7594ae9911ace64e8d719f72669b4b6c0b6cfd6d Author: Jaybit0 Date: Wed Nov 20 16:23:02 2024 +0100 Some more improvements but clustering not working yet commit 29cfa20e1389edf866c9889699ae10cb7b383ffe Author: Jaybit0 Date: Wed Nov 20 16:11:54 2024 +0100 Some improvements commit a4e0c829a0ef3d221af12f5de7087865b055e0cc Author: Jaybit0 Date: Wed Nov 20 15:40:05 2024 +0100 Some more improvements to the system commit 9c1a445afad143f80be3ea0f7868660cef56275d Author: Jaybit0 Date: Wed Nov 20 13:05:20 2024 +0100 Minor bugfix commit fa50fef9469974f67d55d42d887826d6d6bed045 Author: Jaybit0 Date: Wed Nov 20 12:56:51 2024 +0100 Some more improvements commit be2d33b6de241c862292f546ef4004af0b949d9c Author: Jaybit0 Date: Wed Nov 20 11:48:43 2024 +0100 Some more cost estimation functionality commit fa79b4a983e86d0faf2959d1307be83cc6d05e5c Author: Jaybit0 Date: Wed Nov 20 11:10:35 2024 +0100 Improved cost estimation commit 6411b5a19839fb97f4d7139444761aa11708d892 Author: Jaybit0 Date: Wed Nov 20 10:32:57 2024 +0100 Some bugfixes commit a32da49a8065b3a8f1bb38d721a21eb725f7218b Author: Jaybit0 Date: Tue Nov 19 14:29:05 2024 +0100 Some more improvements commit f0600da622194c1e18328d6e98e99abda8560019 Author: Jaybit0 Date: Tue Nov 19 12:24:05 2024 +0100 Some further improvements commit e309e909e607008eca626286e58e00687336d828 Author: Jaybit0 Date: Tue Nov 19 11:04:43 2024 +0100 Backup commit 642b7ac2d82d43381433af18688726df6698159f Author: Jaybit0 Date: Mon Nov 18 17:44:27 2024 +0100 Update DMLCodeGenTest.java commit 47c43eedae589eb0f9972081a5c227824796dc33 Author: Jaybit0 Date: Mon Nov 18 17:05:32 2024 +0100 Some improvements commit 65ef982c36d872066eb7dc3dcf2e382da61ec5d6 Author: Jaybit0 Date: Mon Nov 18 11:59:23 2024 +0100 Some more improvements commit 184a0df15bccff8a56a1e876a22d8bbd51634539 Author: Jaybit0 Date: Sun Nov 17 17:46:00 2024 +0100 Some more improvements commit 9c523e12ae80a21c3f9a3f4deef9787577fe648d Author: Jaybit0 Date: Sun Nov 17 16:53:08 2024 +0100 Bugfixes commit 6e037e201602846a1c1a516f0d3d38125d179ede Author: Jaybit0 Date: Fri Nov 15 13:59:07 2024 +0100 Backup commit 9a5a5d6d1bf730d17e70547f164f826263685fba Author: Jaybit0 Date: Fri Nov 15 13:41:46 2024 +0100 Some changes commit b858f933c0596cf4d174eed0b3d854bdd68b5601 Author: Jaybit0 Date: Fri Nov 15 12:21:28 2024 +0100 Some changes commit 3cddac5c56cb9f9d9a05b0a9762fc74acad72ac0 Author: Jaybit0 Date: Fri Nov 15 11:27:00 2024 +0100 Some debugging commit 61bb3bf97a1bfe33519e650dd9cd803ff0dce90f Author: Jaybit0 Date: Fri Nov 15 11:21:41 2024 +0100 Some improvements commit 4da4c8a05f0eb736a7e1f8fd8198d12d3569f09c Author: Jaybit0 Date: Thu Nov 14 17:55:03 2024 +0100 Some more improvements commit d5fe2cf849e39e62a9c891fa930f04b1b25f171c Author: Jaybit0 Date: Thu Nov 14 14:28:43 2024 +0100 AlphabetEncoder commit 8b866c6c41baf22075b0ff4fc2e9f72bbd1209c1 Author: Jaybit0 Date: Thu Nov 14 13:16:20 2024 +0100 Validation script implementation commit e153ae1431b160ab6a13810a8295aaabdef0a354 Author: Jaybit0 Date: Thu Nov 14 11:29:18 2024 +0100 Some changes commit 8877e0311c4d72080a983a65bd2797be1d4442d0 Author: Jaybit0 Date: Wed Nov 13 17:56:01 2024 +0100 Some improvements commit b7c5943b22cb90aba70c3a13b24817369a95306d Author: Jaybit0 Date: Wed Nov 13 16:47:27 2024 +0100 Some more stuff commit 09692e7c0ca840a3be7a60778a474e61c4e856a4 Author: Jaybit0 Date: Wed Nov 13 15:52:52 2024 +0100 Some more changes commit eeec2c11d401ef1898ecf2cd6c8520c3a6c368b0 Author: Jaybit0 Date: Wed Nov 13 14:23:18 2024 +0100 Some bugfixes commit 85141922cf71145e583025f827282901510fd553 Author: Jaybit0 Date: Wed Nov 13 12:18:36 2024 +0100 Update RewriterCodeGen.java commit 09619989c82f8ef7e6e7be9a160afff7de23ca71 Author: Jaybit0 Date: Wed Nov 13 11:29:09 2024 +0100 Some more codegen commit a8da7cfeba4d3f63608db24985d517854a0416ee Author: Jaybit0 Date: Tue Nov 12 16:25:04 2024 +0100 First CodeGen implementation commit d12aab19c8bcc29ec04e42e7ca4daf7c20cdb50c Author: Jaybit0 Date: Tue Nov 12 11:36:32 2024 +0100 Major improvements commit 5219001d4ec6c080611e767793436b8e684b18d4 Author: Jaybit0 Date: Mon Nov 11 16:46:34 2024 +0100 Some more improvements commit cb05568a80d7c26e6c99d93c0ab826688bb44a2b Author: Jaybit0 Date: Mon Nov 11 15:17:36 2024 +0100 Some improvements commit 8cbddbf473d4e5f8c96637f8d40ea5d2a9056c85 Author: Jaybit0 Date: Mon Nov 11 14:32:21 2024 +0100 Assertion breaking changes commit 497465df1929d253bba2ac1421b3c3f38894f42a Author: Jaybit0 Date: Fri Nov 8 11:43:23 2024 +0100 Some more improvements commit 60436f33b95d0dbbdc8dc7cd191d56f4615a7f46 Author: Jaybit0 Date: Thu Nov 7 18:05:21 2024 +0100 Some more improvements commit 44c85c340efd6df949faac6d91d63e693e75c6f9 Author: Jaybit0 Date: Thu Nov 7 17:23:12 2024 +0100 First RuleCreator implementation commit 4990f9bd2910f7ad61a4cd3d7438002709432b65 Author: Jaybit0 Date: Thu Nov 7 12:43:33 2024 +0100 Rule generator commit 0286e480a5ef163f748063346e54110bccaa77c6 Author: Jaybit0 Date: Thu Nov 7 11:18:42 2024 +0100 Checkpoint commit ba8d1b718bfe2fe154ba5fe26c43691fd7cb5a04 Author: Jaybit0 Date: Thu Nov 7 10:29:01 2024 +0100 Some fixes commit e63f5e95490a8d2b123b53ba8f45bfc700aedf84 Author: Jaybit0 Date: Wed Nov 6 12:46:02 2024 +0100 Cost estimator commit 7f0e47eb2085d7fb53906bdb85c3de092c9ae465 Author: Jaybit0 Date: Tue Nov 5 17:01:26 2024 +0100 Some more cost estimates commit ac83c6612b5ce65566c63cc8cf30d4624c9f52a1 Author: Jaybit0 Date: Tue Nov 5 16:28:33 2024 +0100 Some more fixes commit c57aa33082cc2579c976fe358d81722eb3d80a34 Author: Jaybit0 Date: Tue Nov 5 15:51:52 2024 +0100 Constant folding commit 6e7c77ecfe1806b32778cd565b053468976ed05c Author: Jaybit0 Date: Tue Nov 5 11:24:12 2024 +0100 Fix commit 3a90482f0c4ffa1d2c43833a4c90116a909f4a4a Author: Jaybit0 Date: Tue Nov 5 10:43:26 2024 +0100 Bugfix commit 143b68a734b25f2efe671c45cda490ecb001b1cb Author: Jaybit0 Date: Mon Nov 4 14:02:28 2024 +0100 Bugfix commit 8ee743152280ceb62592c9984c3d64838f07c33d Author: Jaybit0 Date: Sun Nov 3 12:53:08 2024 +0100 Minor bugfix updating types commit e84eb00e93c6510c37bbef0ab51a9308bd947486 Author: Jaybit0 Date: Sun Nov 3 12:26:35 2024 +0100 Bugfix commit 522c46bf50270107141e5b4c570bf24aa9b2d296 Author: Jaybit0 Date: Sun Nov 3 12:06:05 2024 +0100 Bugfix commit 5bda1004266234c3830e5e7ecaba9c8555d71c16 Author: Jaybit0 Date: Wed Oct 30 10:38:22 2024 +0100 Better match filtering commit ccf1e01e2b064bc64cda428ce5c9456685db25ed Author: Jaybit0 Date: Tue Oct 29 17:20:15 2024 +0100 Some first minimal difference implementation commit f4a6ec6d18267eac6d7bcb0c0cd3aa06904f0992 Author: Jaybit0 Date: Tue Oct 29 16:57:01 2024 +0100 Begin of finding minimal difference commit 31f802939703a933f8e45fc84807e308de84796c Author: Jaybit0 Date: Tue Oct 29 15:56:48 2024 +0100 Some major performance improvements commit 0e15664edeb721a8b92f6cd23c5a3286a11e6e28 Author: Jaybit0 Date: Tue Oct 29 15:21:26 2024 +0100 Some more improvements (probably expensive) commit ecd2e597926b4df760d38dd5180612e506710e74 Author: Jaybit0 Date: Mon Oct 28 15:59:46 2024 +0100 Some improvements commit 3337477d15f64eea6d0b6cc7cade4e784945daf6 Author: Jaybit0 Date: Fri Oct 25 16:38:10 2024 +0200 FIXXXX commit 9c527b7a9e4ea7eaa5a26c9b29946afd31f12df0 Author: Jaybit0 Date: Fri Oct 25 14:48:21 2024 +0200 FIX commit f223e4e6d9eaa635229755d4d8c11f595ab69549 Author: Jaybit0 Date: Thu Oct 24 15:22:55 2024 +0200 Update RewriterAssertions.java commit 2dcb143d5f2aca3f709ebfbd8acde58008e511c0 Author: Jaybit0 Date: Wed Oct 23 15:37:59 2024 +0200 Better equality matching commit 21dd17f9ee128086ed969875a7ab2dfec59c7824 Author: Jaybit0 Date: Wed Oct 23 15:20:10 2024 +0200 Topological sort bugfix commit 636ea9d4a99e2f8c50ea65162107c2d6bbb3fd0f Author: Jaybit0 Date: Wed Oct 23 13:11:49 2024 +0200 Some fixes commit 0aae424e30b67e448a4c599a1c8209851978114b Author: Jaybit0 Date: Wed Oct 23 12:02:35 2024 +0200 Some E-Class behavior commit 20eebf41ac930d927e9da07a444e5caedd07926d Author: Jaybit0 Date: Wed Oct 23 11:42:32 2024 +0200 Some progress with equivalence classes commit 5090b7a6deb4dafb4338cf7d76a96c88d4130c47 Author: Jaybit0 Date: Tue Oct 22 15:51:31 2024 +0200 Assertions initial commit commit 01622b99d9fb64588a457c4f6518586d1838feaa Author: Jaybit0 Date: Tue Oct 22 12:35:57 2024 +0200 Some fixes commit fabe1bcd4ed52e7e4b67d31051b61cecafe3356a Author: Jaybit0 Date: Tue Oct 22 10:59:02 2024 +0200 Some improvements to the system commit ea973d34aa9b24569b330b5091b4769c1a41287f Author: Jaybit0 Date: Mon Oct 21 17:11:42 2024 +0200 Partial fix (not working yet) commit ed4e67dd41e9f3e330a66799d8a24f3ecf6eb1bf Author: Jaybit0 Date: Mon Oct 21 16:27:48 2024 +0200 Some bugfixes (still buggy though) commit 708d72d2883f7140e816f67323f3b4b922b9a33b Author: Jaybit0 Date: Mon Oct 21 13:12:24 2024 +0200 Better topological sort commit 30db050d464794e31f3732c3f5a38064c40af6a1 Author: Jaybit0 Date: Fri Oct 18 19:38:25 2024 +0200 Some breaking changes for the topological sort commit 4e549baf8316ce8dc09bef8d13ab4ee35b8589bf Author: Jaybit0 Date: Thu Oct 17 16:07:51 2024 +0200 Some progress with the topological sort commit 20c12a023509f9e446d6546b98015897138c5b20 Author: Jaybit0 Date: Thu Oct 17 12:47:20 2024 +0200 Some more changes commit a6748eab392183863a310c522289bfa4045de2a6 Author: Jaybit0 Date: Thu Oct 17 11:08:34 2024 +0200 Some breaking changes Readability of printed statements Bugfixes regarding heuristics More heuristic for diag commit 1b6d9ce4098812235b976caf8b078ff4464a3a2c Author: Jaybit0 Date: Wed Oct 16 17:43:42 2024 +0200 Some more tests commit 4dab0e07f721463629abadf86c0370c327994cf7 Author: Jaybit0 Date: Wed Oct 16 16:03:01 2024 +0200 Some improvements commit e146e64de9ed69ec8a7dc108e32def25238d4c98 Author: Jaybit0 Date: Wed Oct 16 15:39:00 2024 +0200 Some more updates commit efa3e3366ef81f10dcf11979885d8a7aaa80039f Author: Jaybit0 Date: Wed Oct 16 11:26:16 2024 +0200 Some bugfixes commit 9bb9dbe2de8a8b80c1137d49a5489a570342d243 Author: Jaybit0 Date: Tue Oct 15 10:20:31 2024 +0200 Crucial bugfix commit 8a8ada472685fc9d2c3413bd2f345036f4c5f0ee Author: Jaybit0 Date: Mon Oct 14 15:14:26 2024 +0200 Get all substatements commit 2c8584a78d306551fdb17ae18584c34c3bc594c0 Author: Jaybit0 Date: Mon Oct 14 12:46:37 2024 +0200 Some more changes commit 9d8ff520fe58ead9d99ae71e1e63c78a2293a336 Author: Jaybit0 Date: Mon Oct 14 10:41:49 2024 +0200 Bugfix commit 26276d893546f186b1a5ff93492df18f7fa0fe45 Author: Jaybit0 Date: Sun Oct 13 12:54:18 2024 +0200 Update RewriterRuntimeUtils.java commit da990677943de37cd70d2688c1014870e4da6e18 Author: Jaybit0 Date: Sat Oct 12 13:20:41 2024 +0200 Setup commit 8eed966ea221daddca4b8f19af8778bfbbbadf9a Merge: 6af0c436d7 d0f0837d61 Author: Jaybit0 Date: Thu Oct 3 16:32:21 2024 +0200 Merge remote-tracking branch 'upstream/main' commit 6af0c436d7938de4dc9d78928a9f089dbb5f6ec2 Author: Jaybit0 Date: Fri Aug 16 11:54:39 2024 +0200 Update Statistics.java commit 88b46114114d656256329b4512159f25e0acb37d Author: Jaybit0 Date: Thu Aug 15 13:32:51 2024 +0200 Update ProgramBlock.java commit aad1728601fbda25725be14384b91480ff211dd3 Author: Jaybit0 Date: Thu Aug 15 13:42:45 2024 +0200 Update Statistics.java commit 62ee0ee44186114e3ff26e20acaef5a0782acd58 Author: Jaybit0 Date: Thu Aug 15 13:32:51 2024 +0200 Update ProgramBlock.java commit 746abdc0fb43c45e4594e680df3878966e45af90 Author: Jaybit0 Date: Thu Aug 15 13:17:42 2024 +0200 CSV Stream improvement commit 136880847729fe4c9b686b849c6f9122ff80c9fa Author: Jaybit0 Date: Thu Aug 15 12:53:02 2024 +0200 Add some MetaData commit 1271660db1f5e500c3d717b459d21804a1dea85e Author: Jaybit0 Date: Wed Aug 14 10:47:41 2024 +0200 More detailed statistics for matrices commit cb44f7a8334458eb29c2460e13ee8a164b831a21 Author: Jaybit0 Date: Fri Aug 9 10:28:10 2024 +0200 Update L2SVMTest.java commit d85363eb062f62aec4bb7b4b58e06e68c3ab0ee1 Author: Jaybit0 Date: Tue Aug 6 11:05:18 2024 +0200 Minor bugfix to prevent issues with multithreading commit 0cf96b4086dd85c7e80ba9c1999664c54d10d4cd Author: Jaybit0 Date: Sun Aug 4 10:47:22 2024 +0200 Bugfix commit dfa9c6ca1567aeed9d99f09e9958926f284ceb6a Author: Jaybit0 Date: Sun Aug 4 10:12:44 2024 +0200 Workaround to avoid increasing object size of LineageItem commit f2cf05c9bcead79309a6f5e41030a6256c583417 Author: Jaybit0 Date: Sat Aug 3 12:44:54 2024 +0200 Update Statistics.java commit 90f30186f74c2bd9b8cf249a26787142cce3e03b Author: Jaybit0 Date: Fri Aug 2 14:24:11 2024 +0200 Index-aware lineage ngrams and time averaging over number of recorded inputs commit 76b161b4b39d6bcc96cd289c3de50cfe9695d1cd Author: Jaybit0 Date: Fri Aug 2 13:32:27 2024 +0200 NGrams now use lineage per default commit c713a776f9934ab52c97c48aae420810d73f0018 Author: Jaybit0 Date: Thu Jul 4 12:45:33 2024 +0200 Some more bugfixes commit 05f50333343a881af04ada0625950a10dd574a4d Author: Jaybit0 Date: Mon Jul 1 16:12:52 2024 +0200 More extensive stats for n-grams --- .../java/org/apache/sysds/api/DMLOptions.java | 5 + .../java/org/apache/sysds/api/DMLScript.java | 15 +- .../sysds/hops/rewrite/HopRewriteUtils.java | 102 + .../sysds/hops/rewrite/ProgramRewriter.java | 9 + .../sysds/hops/rewriter/MetaPropagator.java | 369 + .../rewriter/RewriterContextSettings.java | 299 + .../sysds/hops/rewriter/RewriterDataType.java | 487 + .../sysds/hops/rewriter/RewriterDatabase.java | 107 + .../rewriter/RewriterEquivalenceDatabase.java | 127 + .../hops/rewriter/RewriterFramework.java | 493 + .../hops/rewriter/RewriterInstruction.java | 627 + .../hops/rewriter/RewriterRuntimeUtils.java | 938 + .../hops/rewriter/RewriterStatement.java | 1092 + .../hops/rewriter/RewriterStatementEntry.java | 58 + .../sysds/hops/rewriter/RuleContext.java | 208 + .../sysds/hops/rewriter/TopologicalSort.java | 543 + .../assertions/RewriterAssertionUtils.java | 89 + .../assertions/RewriterAssertions.java | 751 + .../rewriter/codegen/CodeGenCondition.java | 622 + .../rewriter/codegen/RewriterCodeGen.java | 807 + .../hops/rewriter/dml/DMLCodeGenerator.java | 383 + .../sysds/hops/rewriter/dml/DMLExecutor.java | 135 + .../estimators/RewriterCostEstimator.java | 947 + .../estimators/RewriterSparsityEstimator.java | 136 + .../generated/GeneratedRewriteClass.java | 4044 ++++ .../RewriteAutomaticallyGenerated.java | 137 + .../hops/rewriter/rule/RewriterHeuristic.java | 115 + .../rule/RewriterHeuristicTransformation.java | 47 + .../rewriter/rule/RewriterHeuristics.java | 140 + .../hops/rewriter/rule/RewriterRule.java | 489 + .../rewriter/rule/RewriterRuleBuilder.java | 543 + .../rewriter/rule/RewriterRuleCollection.java | 1445 ++ .../rewriter/rule/RewriterRuleCreator.java | 537 + .../hops/rewriter/rule/RewriterRuleSet.java | 345 + .../hops/rewriter/utils/CodeGenUtils.java | 498 + .../rewriter/utils/ConstantFoldingUtils.java | 184 + .../rewriter/utils/RewriterSearchUtils.java | 618 + .../hops/rewriter/utils/RewriterUtils.java | 1375 ++ .../hops/rewriter/utils/StatementUtils.java | 60 + .../org/apache/sysds/utils/Statistics.java | 30 + .../apache/sysds/test/AutomatedTestBase.java | 72 + .../sysds/test/applications/L2SVMTest.java | 2 +- .../rewrite/RewriterNormalFormTests.java | 561 + .../rewrite/RewriterRuleValidationTest.java | 89 + .../codegen/rewrite/RewriterStreamTests.java | 1751 ++ .../rewrite/RewriterTopologySortTests.java | 265 + .../rewrite/functions/AssertionTests.java | 63 + .../rewrite/functions/CodeExecutionTest.java | 41 + .../functions/CodeGenConditionTests.java | 149 + .../rewrite/functions/CodeGenTests.java | 314 + .../rewrite/functions/CostEstimates.java | 395 + .../rewrite/functions/DMLCodeGenTest.java | 248 + .../rewrite/functions/MinimalDifference.java | 66 + .../functions/RewriterSearchUtilsTest.java | 131 + .../rewrite/functions/RuleCreationTests.java | 295 + .../functions/RuleSerializationTest.java | 168 + .../functions/SparsityEstimationTest.java | 168 + .../functions/SubtreeGeneratorTest.java | 83 + .../rewrite/functions/TestRuleSet.java | 80 + .../rewriterframework/expressions.db | 18610 ++++++++++++++++ 60 files changed, 43505 insertions(+), 2 deletions(-) create mode 100644 src/main/java/org/apache/sysds/hops/rewriter/MetaPropagator.java create mode 100644 src/main/java/org/apache/sysds/hops/rewriter/RewriterContextSettings.java create mode 100644 src/main/java/org/apache/sysds/hops/rewriter/RewriterDataType.java create mode 100644 src/main/java/org/apache/sysds/hops/rewriter/RewriterDatabase.java create mode 100644 src/main/java/org/apache/sysds/hops/rewriter/RewriterEquivalenceDatabase.java create mode 100644 src/main/java/org/apache/sysds/hops/rewriter/RewriterFramework.java create mode 100644 src/main/java/org/apache/sysds/hops/rewriter/RewriterInstruction.java create mode 100644 src/main/java/org/apache/sysds/hops/rewriter/RewriterRuntimeUtils.java create mode 100644 src/main/java/org/apache/sysds/hops/rewriter/RewriterStatement.java create mode 100644 src/main/java/org/apache/sysds/hops/rewriter/RewriterStatementEntry.java create mode 100644 src/main/java/org/apache/sysds/hops/rewriter/RuleContext.java create mode 100644 src/main/java/org/apache/sysds/hops/rewriter/TopologicalSort.java create mode 100644 src/main/java/org/apache/sysds/hops/rewriter/assertions/RewriterAssertionUtils.java create mode 100644 src/main/java/org/apache/sysds/hops/rewriter/assertions/RewriterAssertions.java create mode 100644 src/main/java/org/apache/sysds/hops/rewriter/codegen/CodeGenCondition.java create mode 100644 src/main/java/org/apache/sysds/hops/rewriter/codegen/RewriterCodeGen.java create mode 100644 src/main/java/org/apache/sysds/hops/rewriter/dml/DMLCodeGenerator.java create mode 100644 src/main/java/org/apache/sysds/hops/rewriter/dml/DMLExecutor.java create mode 100644 src/main/java/org/apache/sysds/hops/rewriter/estimators/RewriterCostEstimator.java create mode 100644 src/main/java/org/apache/sysds/hops/rewriter/estimators/RewriterSparsityEstimator.java create mode 100644 src/main/java/org/apache/sysds/hops/rewriter/generated/GeneratedRewriteClass.java create mode 100644 src/main/java/org/apache/sysds/hops/rewriter/generated/RewriteAutomaticallyGenerated.java create mode 100644 src/main/java/org/apache/sysds/hops/rewriter/rule/RewriterHeuristic.java create mode 100644 src/main/java/org/apache/sysds/hops/rewriter/rule/RewriterHeuristicTransformation.java create mode 100644 src/main/java/org/apache/sysds/hops/rewriter/rule/RewriterHeuristics.java create mode 100644 src/main/java/org/apache/sysds/hops/rewriter/rule/RewriterRule.java create mode 100644 src/main/java/org/apache/sysds/hops/rewriter/rule/RewriterRuleBuilder.java create mode 100644 src/main/java/org/apache/sysds/hops/rewriter/rule/RewriterRuleCollection.java create mode 100644 src/main/java/org/apache/sysds/hops/rewriter/rule/RewriterRuleCreator.java create mode 100644 src/main/java/org/apache/sysds/hops/rewriter/rule/RewriterRuleSet.java create mode 100644 src/main/java/org/apache/sysds/hops/rewriter/utils/CodeGenUtils.java create mode 100644 src/main/java/org/apache/sysds/hops/rewriter/utils/ConstantFoldingUtils.java create mode 100644 src/main/java/org/apache/sysds/hops/rewriter/utils/RewriterSearchUtils.java create mode 100644 src/main/java/org/apache/sysds/hops/rewriter/utils/RewriterUtils.java create mode 100644 src/main/java/org/apache/sysds/hops/rewriter/utils/StatementUtils.java create mode 100644 src/test/java/org/apache/sysds/test/component/codegen/rewrite/RewriterNormalFormTests.java create mode 100644 src/test/java/org/apache/sysds/test/component/codegen/rewrite/RewriterRuleValidationTest.java create mode 100644 src/test/java/org/apache/sysds/test/component/codegen/rewrite/RewriterStreamTests.java create mode 100644 src/test/java/org/apache/sysds/test/component/codegen/rewrite/RewriterTopologySortTests.java create mode 100644 src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/AssertionTests.java create mode 100644 src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/CodeExecutionTest.java create mode 100644 src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/CodeGenConditionTests.java create mode 100644 src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/CodeGenTests.java create mode 100644 src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/CostEstimates.java create mode 100644 src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/DMLCodeGenTest.java create mode 100644 src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/MinimalDifference.java create mode 100644 src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/RewriterSearchUtilsTest.java create mode 100644 src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/RuleCreationTests.java create mode 100644 src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/RuleSerializationTest.java create mode 100644 src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/SparsityEstimationTest.java create mode 100644 src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/SubtreeGeneratorTest.java create mode 100644 src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/TestRuleSet.java create mode 100644 src/test/resources/rewriterframework/expressions.db diff --git a/src/main/java/org/apache/sysds/api/DMLOptions.java b/src/main/java/org/apache/sysds/api/DMLOptions.java index a289b29bcde..9bbcbddfaa7 100644 --- a/src/main/java/org/apache/sysds/api/DMLOptions.java +++ b/src/main/java/org/apache/sysds/api/DMLOptions.java @@ -58,6 +58,7 @@ public class DMLOptions { public int[] statsNGramSizes = { 3 }; // Default n-gram tuple sizes public int statsTopKNGrams = 10; // How many of the most heavy hitting n-grams are displayed public boolean statsNGramsUseLineage = true; // If N-Grams use lineage for data-dependent tracking + public boolean applyGeneratedRewrites = false; // If generated rewrites should be applied public boolean fedStats = false; // Whether to record and print the federated statistics public int fedStatsCount = 10; // Default federated statistics count public boolean memStats = false; // max memory statistics @@ -246,6 +247,8 @@ else if (lineageType.equalsIgnoreCase("debugger")) } } + dmlOptions.applyGeneratedRewrites = line.hasOption("applyGeneratedRewrites"); + dmlOptions.fedStats = line.hasOption("fedStats"); if (dmlOptions.fedStats) { String fedStatsCount = line.getOptionValue("fedStats"); @@ -372,6 +375,7 @@ private static Options createCLIOptions() { Option ngramsOpt = OptionBuilder//.withArgName("ngrams") .withDescription("monitors and reports the most occurring n-grams; -ngrams ") .hasOptionalArgs(2).create("ngrams"); + Option applyGeneratedRewritesOpt = OptionBuilder.withArgName("applyGeneratedRewrites").withDescription("if automatically generated rewrites should be applied").create("applyGeneratedRewrites"); Option fedStatsOpt = OptionBuilder.withArgName("count") .withDescription("monitors and reports summary execution statistics of federated workers; heavy hitter is 10 unless overridden; default off") .hasOptionalArg().create("fedStats"); @@ -434,6 +438,7 @@ private static Options createCLIOptions() { options.addOption(cleanOpt); options.addOption(statsOpt); options.addOption(ngramsOpt); + options.addOption(applyGeneratedRewritesOpt); options.addOption(fedStatsOpt); options.addOption(memOpt); options.addOption(explainOpt); diff --git a/src/main/java/org/apache/sysds/api/DMLScript.java b/src/main/java/org/apache/sysds/api/DMLScript.java index d6853891e24..5777128396a 100644 --- a/src/main/java/org/apache/sysds/api/DMLScript.java +++ b/src/main/java/org/apache/sysds/api/DMLScript.java @@ -35,6 +35,8 @@ import java.util.Date; import java.util.Map; import java.util.Scanner; +import java.util.function.BiConsumer; +import java.util.function.Function; import org.apache.commons.cli.AlreadySelectedException; import org.apache.commons.cli.HelpFormatter; @@ -106,6 +108,7 @@ public class DMLScript public static int STATISTICS_TOP_K_NGRAMS = DMLOptions.defaultOptions.statsTopKNGrams; // Set if N-Grams use lineage for data-dependent tracking public static boolean STATISTICS_NGRAMS_USE_LINEAGE = DMLOptions.defaultOptions.statsNGramsUseLineage; + public static boolean APPLY_GENERATED_REWRITES = DMLOptions.defaultOptions.applyGeneratedRewrites; // Set statistics maximum wrap length public static int STATISTICS_MAX_WRAP_LEN = 30; // Enable/disable to print federated statistics @@ -168,6 +171,9 @@ public class DMLScript public static String _uuid = IDHandler.createDistributedUniqueID(); private static final Log LOG = LogFactory.getLog(DMLScript.class.getName()); + public static Function preHopInterceptor = null; // Intercepts HOPs before they are rewritten + public static Function hopInterceptor = null; // Intercepts HOPs after they are rewritten + /////////////////////////////// // public external interface //////// @@ -261,6 +267,7 @@ public static boolean executeScript( String[] args ) STATISTICS_NGRAMS = dmlOptions.statsNGrams; STATISTICS_NGRAM_SIZES = dmlOptions.statsNGramSizes; STATISTICS_TOP_K_NGRAMS = dmlOptions.statsTopKNGrams; + APPLY_GENERATED_REWRITES = dmlOptions.applyGeneratedRewrites; FED_STATISTICS = dmlOptions.fedStats; FED_STATISTICS_COUNT = dmlOptions.fedStatsCount; JMLC_MEM_STATISTICS = dmlOptions.memStats; @@ -456,9 +463,15 @@ private static void execute(String dmlScriptStr, String fnameOptConfig, Map params = new HashMap<>(); + params.put(DataExpression.RAND_ROWS, rows); + params.put(DataExpression.RAND_COLS, cols); + params.put(DataExpression.RAND_MIN, val); + params.put(DataExpression.RAND_MAX, val); + params.put(DataExpression.RAND_PDF, new LiteralOp(DataExpression.RAND_PDF_UNIFORM)); + params.put(DataExpression.RAND_LAMBDA, new LiteralOp(-1.0)); + params.put(DataExpression.RAND_SPARSITY, new LiteralOp(1.0)); + params.put(DataExpression.RAND_SEED, new LiteralOp(DataGenOp.UNSPECIFIED_SEED) ); + + //note internal refresh size information + Hop datagen = new DataGenOp(OpOpDG.RAND, new DataIdentifier("tmp"), params); + datagen.setBlocksize(1000); + //copyLineNumbers(rowInput, datagen); + + if( value==0 ) + datagen.setNnz(0); + + return datagen; + } public static Hop createDataGenOp( Hop rowInput, Hop colInput, double value ) { @@ -661,6 +685,84 @@ public static BinaryOp createBinary(Hop input1, Hop input2, OpOp2 op, boolean ou bop.refreshSizeInformation(); return bop; } + + // To fix issues with createBinary, which does not always correctly set value types (e.g. INT-MATRIX+FLOAT-SCALAR -> bop(+)::INT) + public static BinaryOp createAutoGeneratedBinary(Hop input1, Hop input2, OpOp2 op) { + Hop mainInput = input1.getDataType().isMatrix() ? input1 : + input2.getDataType().isMatrix() ? input2 : input1; + BinaryOp bop = new BinaryOp(mainInput.getName(), getImplicitDataType(input1, input2), + getImplicitValueType(input1, input2), op, input1, input2); + //cleanup value type for relational operations + if( bop.isPPredOperation() && bop.getDataType().isScalar() ) + bop.setValueType(ValueType.BOOLEAN); + bop.setOuterVectorOperation(false); + bop.setBlocksize(mainInput.getBlocksize()); + copyLineNumbers(mainInput, bop); + bop.refreshSizeInformation(); + return bop; + } + + public static DataType getImplicitDataType(Hop... inputs) { + for (int i = 0; i < inputs.length; i++) + if (inputs[i].getDataType().isMatrix()) + return inputs[i].getDataType(); + + return inputs[0].getDataType(); + } + + public static ValueType getImplicitValueType(Hop... inputs) { + ValueType out = null; + for (int i = 0; i < inputs.length; i++) { + switch (inputs[i].getValueType()) { + case FP64: + return inputs[i].getValueType(); + case FP32: + out = inputs[i].getValueType(); + break; + case INT64: + out = implicitValueType(out, ValueType.INT64); + break; + case INT32: + out = implicitValueType(out, ValueType.INT32); + break; + case BOOLEAN: + out = implicitValueType(out, ValueType.BOOLEAN); + break; + } + } + + return out == null ? inputs[0].getValueType() : out; + } + + private static ValueType implicitValueType(ValueType type1, ValueType type2) { + int rank1 = getTypeRank(type1); + int rank2 = getTypeRank(type2); + + if (rank1 == Integer.MIN_VALUE && rank2 == Integer.MIN_VALUE) + return null; + + return rank1 > rank2 ? type1 : type2; + } + + private static int getTypeRank(ValueType vt) { + if (vt == null) + return Integer.MIN_VALUE; + + switch (vt) { + case FP64: + return 5; + case FP32: + return 4; + case INT64: + return 3; + case INT32: + return 2; + case BOOLEAN: + return 1; + } + + return Integer.MIN_VALUE; + } public static AggUnaryOp createSum( Hop input ) { return createAggUnaryOp(input, AggOp.SUM, Direction.RowCol); diff --git a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java index b08d836efe5..357f2860bc4 100644 --- a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java +++ b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java @@ -27,6 +27,9 @@ import org.apache.sysds.conf.CompilerConfig.ConfigType; import org.apache.sysds.hops.Hop; import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.hops.rewriter.generated.GeneratedRewriteClass; +import org.apache.sysds.hops.rewriter.generated.RewriteAutomaticallyGenerated; +import org.apache.sysds.hops.rewriter.dml.DMLExecutor; import org.apache.sysds.parser.DMLProgram; import org.apache.sysds.parser.ForStatement; import org.apache.sysds.parser.ForStatementBlock; @@ -83,6 +86,9 @@ public ProgramRewriter(boolean staticRewrites, boolean dynamicRewrites) _dagRuleSet.add( new RewriteCommonSubexpressionElimination() ); if( OptimizerUtils.ALLOW_CONSTANT_FOLDING ) _dagRuleSet.add( new RewriteConstantFolding() ); //dependency: cse + if ( DMLScript.APPLY_GENERATED_REWRITES ) { + _dagRuleSet.add(new RewriteAutomaticallyGenerated(new GeneratedRewriteClass())); + } if( OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION ) _dagRuleSet.add( new RewriteAlgebraicSimplificationStatic() ); //dependencies: cse if( OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION ) //dependency: simplifications (no need to merge leafs again) @@ -124,6 +130,9 @@ public ProgramRewriter(boolean staticRewrites, boolean dynamicRewrites) if ( DMLScript.USE_ACCELERATOR ){ _dagRuleSet.add( new RewriteGPUSpecificOps() ); // gpu-specific rewrites } + if ( DMLScript.APPLY_GENERATED_REWRITES ) { + _dagRuleSet.add(new RewriteAutomaticallyGenerated(new GeneratedRewriteClass())); + } if ( OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES) { _dagRuleSet.add( new RewriteMatrixMultChainOptimization() ); //dependency: cse if( OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES2 ) diff --git a/src/main/java/org/apache/sysds/hops/rewriter/MetaPropagator.java b/src/main/java/org/apache/sysds/hops/rewriter/MetaPropagator.java new file mode 100644 index 00000000000..bfe9a1a880f --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewriter/MetaPropagator.java @@ -0,0 +1,369 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.rewriter; + +import org.apache.commons.lang3.NotImplementedException; +import org.apache.commons.lang3.mutable.MutableObject; +import org.apache.sysds.hops.rewriter.assertions.RewriterAssertions; +import org.apache.sysds.hops.rewriter.utils.RewriterUtils; + +import java.util.HashMap; +import java.util.Optional; +import java.util.UUID; +import java.util.function.Function; + +/** + * This class is used to propagate dimension information. + * Each instruction that produces a matrix must be implemented here. + */ +public class MetaPropagator implements Function { + private final RuleContext ctx; + + public MetaPropagator(RuleContext ctx) { + this.ctx = ctx; + } + + public RewriterStatement apply(RewriterStatement root) { + RewriterAssertions assertions = root.getAssertions(ctx); + MutableObject out = new MutableObject<>(root); + HashMap literalMap = new HashMap<>(); + + root.forEachPostOrderWithDuplicates((el, parent, pIdx) -> { + RewriterStatement toSet = propagateDims(el, parent, pIdx, assertions); + + if (toSet != null && toSet != el) { + el = toSet; + if (parent == null) + out.setValue(toSet); + else + parent.getOperands().set(pIdx, toSet); + } + + // Assert + if (el.getResultingDataType(ctx).startsWith("MATRIX") + && (el.getNCol() == null || el.getNRow() == null)) + throw new IllegalArgumentException("Some properties have not been set by the meta propagator: " + el.toString(ctx) + " :: " + el.getResultingDataType(ctx)); + + + // Eliminate common literals + if (el.isLiteral()) { + RewriterStatement existingLiteral = literalMap.get(el.getLiteral()); + + if (existingLiteral != null) { + if (parent == null) + out.setValue(existingLiteral); + else + parent.getOperands().set(pIdx, existingLiteral); + } else { + literalMap.put(el.getLiteral(), el); + } + } + + validate(el); + }); + + return out.getValue(); + } + + private RewriterStatement propagateDims(RewriterStatement root, RewriterStatement parent, int pIdx, RewriterAssertions assertions) { + if (root.getResultingDataType(ctx) == null) + throw new IllegalArgumentException("Null type: " + root.toParsableString(ctx)); + if (!root.getResultingDataType(ctx).startsWith("MATRIX")) { + if (root.isInstruction()) { + String ti = root.trueTypedInstruction(ctx); + RewriterStatement ret = null; + + switch (ti) { + case "ncol(MATRIX)": + ret = (RewriterStatement)root.getOperands().get(0).getMeta("ncol"); + break; + case "nrow(MATRIX)": + ret = (RewriterStatement)root.getOperands().get(0).getMeta("nrow"); + break; + } + + if (ret == null) + return null; + + RewriterStatement asserted = assertions != null ? assertions.getAssertionStatement(ret, parent) : null; + + if (asserted == null) + return ret; + + return asserted; + } + return null; + } + + Object colAccess; + Object rowAccess; + + if (root.getOperands() == null || root.getOperands().isEmpty()) { + RewriterStatement ncol = root.getNCol(); + + if (ncol == null) { + root.unsafePutMeta("ncol", new RewriterInstruction().withInstruction("ncol").withOps(root).as(UUID.randomUUID().toString()).consolidate(ctx)); + } + + RewriterStatement nrow = root.getNRow(); + + if (nrow == null) { + root.unsafePutMeta("nrow", new RewriterInstruction().withInstruction("nrow").withOps(root).as(UUID.randomUUID().toString()).consolidate(ctx)); + } + + return null; + } + + if (root.isInstruction()) { + Optional firstMatrixStatement = root.getOperands().stream().filter(el -> el.getResultingDataType(ctx).startsWith("MATRIX")).findFirst(); + switch(root.trueInstruction()) { + // Handle generators + case "rand": + root.unsafePutMeta("nrow", root.getOperands().get(0)); + root.unsafePutMeta("ncol", root.getOperands().get(1)); + return null; + case "as.matrix": + root.unsafePutMeta("ncol", new RewriterDataType().ofType("INT").as("1").asLiteral(1L).consolidate(ctx)); + root.unsafePutMeta("nrow", new RewriterDataType().ofType("INT").as("1").asLiteral(1L).consolidate(ctx)); + return null; + case "argList": + // We assume argLists always occur if the matrix properties don't change + root.unsafePutMeta("nrow", firstMatrixStatement.get().getMeta("nrow")); + root.unsafePutMeta("ncol", firstMatrixStatement.get().getMeta("ncol")); + return null; + case "_map": + root.unsafePutMeta("nrow", root.getOperands().get(1).getMeta("nrow")); + root.unsafePutMeta("ncol", root.getOperands().get(1).getMeta("ncol")); + return null; + case "+": + case "-": + case "*": + case "inv": + case "==": + case "!=": + case "&": + case "|": + case "<": + case ">": + case "abs": + case "round": + case "exp": + case "^": + if (firstMatrixStatement.isEmpty()) + throw new IllegalArgumentException(root.toString(ctx) + " has empty args!"); + root.unsafePutMeta("nrow", firstMatrixStatement.get().getMeta("nrow")); + root.unsafePutMeta("ncol", firstMatrixStatement.get().getMeta("ncol")); + return null; + case "cast.MATRIX": + String mDT = root.getChild(0).getResultingDataType(ctx); + if (mDT.equals("BOOL") || mDT.equals("INT") || mDT.equals("FLOAT")) { + root.unsafePutMeta("ncol", new RewriterDataType().ofType("INT").as("1").asLiteral(1L).consolidate(ctx)); + root.unsafePutMeta("nrow", new RewriterDataType().ofType("INT").as("1").asLiteral(1L).consolidate(ctx)); + return null; + } + case "log_nz": + case "log": + root.unsafePutMeta("nrow", root.getOperands().get(0).getMeta("nrow")); + root.unsafePutMeta("ncol", root.getOperands().get(0).getMeta("ncol")); + return null; + } + + switch(root.trueTypedInstruction(ctx)) { + case "t(MATRIX)": + colAccess = root.getOperands().get(0).getMeta("ncol"); + rowAccess = root.getOperands().get(0).getMeta("nrow"); + root.unsafePutMeta("ncol", rowAccess); + root.unsafePutMeta("nrow", colAccess); + return null; + case "_m(INT,INT,FLOAT)": + case "_m(INT,INT,BOOL)": + case "_m(INT,INT,INT)": + if (root.getOperands().get(0).isInstruction() + && root.getOperands().get(0).trueTypedInstruction(ctx).equals("_idx(INT,INT)")) { + root.unsafePutMeta("nrow", root.getOperands().get(0).getOperands().get(1)); + } else { + root.unsafePutMeta("nrow", new RewriterDataType().ofType("INT").as("1").asLiteral(1L).consolidate(ctx)); + } + + if (root.getOperands().get(1).isInstruction() + && root.getOperands().get(1).trueTypedInstruction(ctx).equals("_idx(INT,INT)")) { + root.unsafePutMeta("ncol", root.getOperands().get(1).getOperands().get(1)); + } else { + root.unsafePutMeta("ncol", new RewriterDataType().ofType("INT").as("1").asLiteral(1L).consolidate(ctx)); + } + return null; + case "%*%(MATRIX,MATRIX)": + rowAccess = root.getOperands().get(0).getMeta("nrow"); + colAccess = root.getOperands().get(1).getMeta("ncol"); + root.unsafePutMeta("nrow", rowAccess); + root.unsafePutMeta("ncol", colAccess); + return null; + case "diag(MATRIX)": + root.unsafePutMeta("nrow", root.getOperands().get(0).getMeta("nrow")); + root.unsafePutMeta("ncol", root.getOperands().get(0).getMeta("ncol")); + return null; + case "[](MATRIX,INT,INT,INT,INT)": + Long[] ints = new Long[4]; + + for (int i = 1; i < 5; i++) + if (root.getChild(i).isLiteral()) + if (root.getChild(i).getLiteral() instanceof Integer) + ints[i-1] = (Long)root.getChild(i).getLiteral(); + + if (ints[0] != null && ints[1] != null) { + String literalString = Long.toString(ints[1] - ints[0] + 1); + root.unsafePutMeta("nrow", RewriterUtils.foldConstants(RewriterUtils.parse(literalString, ctx, "LITERAL_INT:" + literalString), ctx)); + } else { + HashMap subStmts = new HashMap<>(); + subStmts.put("i1", root.getOperands().get(2)); + subStmts.put("i0", root.getOperands().get(1)); + + if (ints[0] != null) { + root.unsafePutMeta("nrow", RewriterUtils.foldConstants(RewriterUtils.parse("+(argList(i1, " + (1 - ints[0]) + "))", ctx, subStmts, "LITERAL_INT:" + (1 - ints[0])), ctx)); + } else if (ints[1] != null) { + root.unsafePutMeta("nrow", RewriterUtils.foldConstants(RewriterUtils.parse("+(argList(" + (ints[1] + 1) + ", -(i0)))", ctx, subStmts, "LITERAL_INT:" + (ints[1] + 1)), ctx)); + } else { + root.unsafePutMeta("nrow", RewriterUtils.foldConstants(RewriterUtils.parse("+(argList(i1, -(i0), 1))", ctx, subStmts, "LITERAL_INT:1"), ctx)); + } + } + + if (ints[2] != null && ints[3] != null) { + root.unsafePutMeta("ncol", ints[3] - ints[2] + 1); + } else { + HashMap subStmts = new HashMap<>(); + subStmts.put("i3", root.getOperands().get(4)); + subStmts.put("i2", root.getOperands().get(3)); + if (ints[2] != null) { + root.unsafePutMeta("ncol", RewriterUtils.foldConstants(RewriterUtils.parse("+(argList(i3, " + (1 - ints[2]) + "))", ctx, subStmts, "LITERAL_INT:" + (1 - ints[2])), ctx)); + } else if (ints[3] != null) { + root.unsafePutMeta("ncol", RewriterUtils.foldConstants(RewriterUtils.parse("+(argList(" + (ints[3] + 1) + ", -(i2)))", ctx, subStmts, "LITERAL_INT:" + (ints[3] + 1)), ctx)); + } else { + root.unsafePutMeta("ncol", RewriterUtils.foldConstants(RewriterUtils.parse("+(argList(i3, -(i2), 1))", ctx, subStmts, "LITERAL_INT:1"), ctx)); + } + } + + return null; + case "rowSums(MATRIX)": + root.unsafePutMeta("nrow", root.getOperands().get(0).getMeta("nrow")); + root.unsafePutMeta("ncol", new RewriterDataType().ofType("INT").as("1").asLiteral(1L).consolidate(ctx)); + return null; + case "colSums(MATRIX)": + root.unsafePutMeta("ncol", root.getOperands().get(0).getMeta("ncol")); + root.unsafePutMeta("nrow", new RewriterDataType().ofType("INT").as("1").asLiteral(1L).consolidate(ctx)); + return null; + case "cast.MATRIX(MATRIX)": + root.unsafePutMeta("nrow", root.getOperands().get(0).getMeta("nrow")); + root.unsafePutMeta("ncol", root.getOperands().get(0).getMeta("ncol")); + return null; + case "RBind(MATRIX,MATRIX)": + HashMap mstmts = new HashMap<>(); + mstmts.put("row1", (RewriterStatement)root.getOperands().get(0).getMeta("nrow")); + mstmts.put("row2", (RewriterStatement)root.getOperands().get(1).getMeta("nrow")); + root.unsafePutMeta("nrow", RewriterUtils.parse("+(argList(row1, row2))", ctx, mstmts)); + root.unsafePutMeta("ncol", root.getOperands().get(0).getMeta("ncol")); + return null; + case "CBind(MATRIX,MATRIX)": + mstmts = new HashMap<>(); + mstmts.put("col1", (RewriterStatement)root.getOperands().get(0).getMeta("ncol")); + mstmts.put("col2", (RewriterStatement)root.getOperands().get(1).getMeta("ncol")); + root.unsafePutMeta("nrow", root.getOperands().get(0).getMeta("nrow")); + root.unsafePutMeta("ncol", RewriterUtils.parse("+(argList(col1, col2))", ctx, mstmts)); + return null; + + // Fused ops + case "1-*(MATRIX,MATRIX)": + case "log_nz(MATRIX)": + root.unsafePutMeta("nrow", root.getOperands().get(0).getMeta("nrow")); + root.unsafePutMeta("ncol", root.getOperands().get(0).getMeta("ncol")); + return null; + case "const(MATRIX,FLOAT)": + root.unsafePutMeta("nrow", root.getOperands().get(0).getMeta("nrow")); + root.unsafePutMeta("ncol", root.getOperands().get(0).getMeta("ncol")); + return null; + case "rowVec(MATRIX)": + root.unsafePutMeta("nrow", root.getOperands().get(0).getMeta("nrow")); + root.unsafePutMeta("ncol", RewriterStatement.literal(ctx, 1L)); + return null; + case "colVec(MATRIX)": + root.unsafePutMeta("ncol", root.getOperands().get(0).getMeta("ncol")); + root.unsafePutMeta("nrow", RewriterStatement.literal(ctx, 1L)); + return null; + case "cellMat(MATRIX)": + root.unsafePutMeta("ncol", RewriterStatement.literal(ctx, 1L)); + root.unsafePutMeta("nrow", RewriterStatement.literal(ctx, 1L)); + return null; + case "rev(MATRIX)": + case "replace(MATRIX,FLOAT,FLOAT)": + case "sumSq(MATRIX)": + case "+*(MATRIX,FLOAT,MATRIX)": + case "-*(MATRIX,FLOAT,MATRIX)": + case "*2(MATRIX)": + case "sq(MATRIX)": + case "!(MATRIX)": + root.unsafePutMeta("nrow", root.getChild(0).getMeta("nrow")); + root.unsafePutMeta("ncol", root.getChild(0).getMeta("ncol")); + return null; + } + + RewriterInstruction instr = (RewriterInstruction) root; + + if (instr.getProperties(ctx).contains("ElementWiseInstruction")) { + if (root.getOperands().get(0).getResultingDataType(ctx).startsWith("MATRIX")) { + root.unsafePutMeta("nrow", root.getOperands().get(0).getMeta("nrow")); + root.unsafePutMeta("ncol", root.getOperands().get(0).getMeta("ncol")); + } else { + root.unsafePutMeta("nrow", root.getOperands().get(1).getMeta("nrow")); + root.unsafePutMeta("ncol", root.getOperands().get(1).getMeta("ncol")); + } + + return null; + } + + if (instr.getProperties(ctx).contains("ElementWiseUnary.FLOAT")) { + if (root.getOperands().get(0).getResultingDataType(ctx).startsWith("MATRIX")) { + root.unsafePutMeta("nrow", root.getOperands().get(0).getMeta("nrow")); + root.unsafePutMeta("ncol", root.getOperands().get(0).getMeta("ncol")); + } else { + root.unsafePutMeta("nrow", root.getOperands().get(1).getMeta("nrow")); + root.unsafePutMeta("ncol", root.getOperands().get(1).getMeta("ncol")); + } + + return null; + } + + throw new NotImplementedException("Unknown instruction: " + instr.trueTypedInstruction(ctx) + "\n" + instr.toParsableString(ctx)); + } + + return null; + } + + private void validate(RewriterStatement stmt) { + if (stmt.isInstruction()) { + if (stmt.trueInstruction().equals("_idx") && (stmt.getMeta("ownerId") == null || stmt.getMeta("idxId") == null)) + throw new IllegalArgumentException(stmt.toString(ctx)); + + if (stmt.trueInstruction().equals("_m") && stmt.getMeta("ownerId") == null) + throw new IllegalArgumentException(stmt.toString(ctx)); + + if (stmt.getResultingDataType(ctx) == null) + throw new IllegalArgumentException(stmt.toString(ctx)); + } + } +} diff --git a/src/main/java/org/apache/sysds/hops/rewriter/RewriterContextSettings.java b/src/main/java/org/apache/sysds/hops/rewriter/RewriterContextSettings.java new file mode 100644 index 00000000000..f1cc25fa095 --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewriter/RewriterContextSettings.java @@ -0,0 +1,299 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.rewriter; + +import org.apache.sysds.hops.rewriter.utils.RewriterUtils; + +import java.util.List; +import java.util.Random; + +public class RewriterContextSettings { + + public static final List ALL_TYPES = List.of("FLOAT", "INT", "BOOL", "MATRIX"); + public static final List SCALARS = List.of("FLOAT", "INT", "BOOL"); + + public static String getDefaultContextString() { + StringBuilder builder = new StringBuilder(); + ALL_TYPES.forEach(t -> { + builder.append("argList(" + t + ")::" + t + "...\n"); + builder.append("argList(" + t + "...)::" + t + "...\n"); + }); // This is a meta function that can take any number of arguments + + builder.append("CBind(MATRIX,MATRIX)::MATRIX\n"); // This instruction is not really supported + builder.append("RBind(MATRIX,MATRIX)::MATRIX\n"); // This instruction is not really supported + + builder.append("sum(MATRIX)::FLOAT\n"); + builder.append("rowSums(MATRIX)::MATRIX\n"); + builder.append("colSums(MATRIX)::MATRIX\n"); + + builder.append("max(MATRIX)::FLOAT\n"); // Support for min/max is limited + builder.append("min(MATRIX)::FLOAT\n"); // Support for min/max is limited + + builder.append("%*%(MATRIX,MATRIX)::MATRIX\n"); + + builder.append("rev(MATRIX)::MATRIX\n"); + builder.append("t(MATRIX)::MATRIX\n"); + + RewriterUtils.buildBinaryPermutations(List.of("INT", "FLOAT", "BOOL"), (t1, t2) -> { + builder.append("BinaryScalarInstruction(" + t1 + "," + t2 + ")::" + RewriterUtils.defaultTypeHierarchy(t1, t2) + "\n"); + builder.append("impl ElementWiseInstruction\n"); + }); + + RewriterUtils.buildBinaryPermutations(List.of("MATRIX...", "MATRIX", "INT", "FLOAT", "BOOL"), (t1, t2) -> { + builder.append("ElementWiseInstruction(" + t1 + "," + t2 + ")::" + RewriterUtils.defaultTypeHierarchy(t1, t2) + "\n"); + builder.append("impl ElementWiseSumExpandableInstruction\n"); + builder.append("impl /\n"); + builder.append("impl max\n"); + builder.append("impl min\n"); + builder.append("impl ^\n"); + builder.append("impl >\n"); + builder.append("impl <\n"); + builder.append("impl >=\n"); + builder.append("impl <=\n"); + builder.append("impl ==\n"); + builder.append("impl |\n"); + builder.append("impl &\n"); + builder.append("impl /\n"); + builder.append("impl !=\n"); + }); + + builder.append("ElementWiseInstruction(MATRIX...)::MATRIX\n"); + builder.append("impl ElementWiseSumExpandableInstruction\n"); + builder.append("impl /\n"); + builder.append("impl max\n"); + builder.append("impl min\n"); + builder.append("impl ^\n"); + builder.append("impl >\n"); + builder.append("impl <\n"); + builder.append("impl >=\n"); + builder.append("impl <=\n"); + builder.append("impl ==\n"); + builder.append("impl |\n"); + builder.append("impl &\n"); + + RewriterUtils.buildBinaryPermutations(List.of("MATRIX...", "MATRIX", "INT", "FLOAT", "BOOL"), (t1, t2) -> { + builder.append("ElementWiseSumExpandableInstruction(" + t1 + "," + t2 + ")::" + RewriterUtils.defaultTypeHierarchy(t1, t2) + "\n"); // Any instruction that allows op(sum(A*), sum(B*)) = sum(op(A, B)) + builder.append("impl ElementWiseAdditiveInstruction\n"); + builder.append("impl *\n"); + + builder.append("ElementWiseAdditiveInstruction(" + t1 + "," + t2 + ")::" + RewriterUtils.defaultTypeHierarchy(t1, t2) + "\n"); + builder.append("impl +\n"); + builder.append("impl -\n"); + }); + + builder.append("ElementWiseAdditiveInstruction(MATRIX...)::MATRIX\n"); + builder.append("impl +\n"); + //builder.append("impl -\n"); + + + ALL_TYPES.forEach(t -> { + builder.append("UnaryElementWiseOperator(" + t + ")::" + t + "\n"); + builder.append("impl -\n"); + builder.append("impl abs\n"); + builder.append("impl !\n"); + builder.append("impl round\n"); + }); + + builder.append("rowSelect(MATRIX,INT,INT)::MATRIX\n"); + builder.append("colSelect(MATRIX,INT,INT)::MATRIX\n"); + builder.append("min(INT,INT)::INT\n"); + builder.append("max(INT,INT)::INT\n"); + + builder.append("index(MATRIX,INT,INT,INT,INT)::MATRIX\n"); + + RewriterUtils.buildBinaryPermutations(List.of("MATRIX...", "MATRIX", "INT", "FLOAT", "BOOL"), (t1, t2) -> { + builder.append("FusableBinaryOperator(" + t1 + "," + t2 + ")::" + RewriterUtils.defaultTypeHierarchy(t1, t2) + "\n"); + builder.append("impl +\n"); + builder.append("impl *\n"); + }); + + List.of("MATRIX", "INT", "FLOAT", "BOOL").forEach(t -> { + builder.append("FusedOperator(" + t + "...)::" + t + "\n"); + builder.append("impl +\n"); + builder.append("impl *\n"); + }); + + builder.append("ncol(MATRIX)::INT\n"); + builder.append("nrow(MATRIX)::INT\n"); + builder.append("length(MATRIX)::INT\n"); + + RewriterUtils.buildBinaryAlgebraInstructions(builder, "+", List.of("INT", "FLOAT", "BOOL", "MATRIX")); + RewriterUtils.buildBinaryAlgebraInstructions(builder, "*", List.of("INT", "FLOAT", "BOOL", "MATRIX")); + RewriterUtils.buildBinaryAlgebraInstructions(builder, "^", ALL_TYPES); + ALL_TYPES.forEach(t -> builder.append("-(" + t + ")::" + t + "\n")); + ALL_TYPES.forEach(t -> builder.append("inv(" + t + ")::" + t + "\n")); + + + builder.append("as.matrix(INT)::MATRIX\n"); + builder.append("as.matrix(FLOAT)::MATRIX\n"); + builder.append("as.matrix(BOOL)::MATRIX\n"); + builder.append("as.scalar(MATRIX)::FLOAT\n"); + builder.append("as.scalar(FLOAT)::FLOAT\n"); + builder.append("as.float(INT)::FLOAT\n"); + builder.append("as.float(BOOL)::FLOAT\n"); + builder.append("as.int(BOOL)::INT\n"); + + RewriterUtils.buildBinaryPermutations(ALL_TYPES, (tFrom, tTo) -> { + builder.append("cast." + tTo + "(" + tFrom + ")::" + tTo + "\n"); + }); + + builder.append("rand(INT,INT,FLOAT,FLOAT)::MATRIX\n"); // Args: rows, cols, min, max + builder.append("rand(INT,INT)::FLOAT\n"); // Just to make it possible to say that random is dependent on both matrix indices + builder.append("rand(INT...)::FLOAT\n"); + builder.append("matrix(INT,INT,INT)::MATRIX\n"); + + builder.append("trace(MATRIX)::FLOAT\n"); + + // Boole algebra + + RewriterUtils.buildBinaryPermutations(List.of("MATRIX", "FLOAT", "INT", "BOOL"), (t1, t2) -> { + String ret = t1.equals("MATRIX") || t2.equals("MATRIX") ? "MATRIX" : "BOOL"; + builder.append("==(" + t1 + "," + t2 + ")::" + ret + "\n"); + builder.append("!=(" + t1 + "," + t2 + ")::" + ret + "\n"); + builder.append("<(" + t1 + "," + t2 + ")::" + ret + "\n"); + builder.append("<=(" + t1 + "," + t2 + ")::" + ret + "\n"); + builder.append(">(" + t1 + "," + t2 + ")::" + ret + "\n"); + builder.append(">=(" + t1 + "," + t2 + ")::" + ret + "\n"); + builder.append("&(" + t1 + "," + t2 + ")::" + ret + "\n"); + builder.append("|(" + t1 + "," + t2 + ")::" + ret + "\n"); + }); + + List.of("MATRIX", "FLOAT", "INT", "BOOL").forEach(t -> { + builder.append("!(" + t + ")::" + (t.equals("MATRIX") ? "MATRIX" : "BOOL") + "\n"); + }); + + // Expressions that will be rewritten to an equivalent expression + RewriterUtils.buildBinaryPermutations(ALL_TYPES, (t1, t2) -> { + builder.append("-(" + t1+ "," + t2 + ")::" + RewriterUtils.defaultTypeHierarchy(t1, t2) + "\n"); + builder.append("/(" + t1+ "," + t2 + ")::" + RewriterUtils.defaultTypeHierarchy(t1, t2) + "\n"); + }); + + // Unary ops + ALL_TYPES.forEach(t -> { + builder.append("ElementWiseUnary.FLOAT(" + t + ")::" + (t.equals("MATRIX") ? "MATRIX" : "FLOAT") + "\n"); + builder.append("impl sqrt\n"); + builder.append("impl exp\n"); + builder.append("impl log\n"); + builder.append("impl inv\n"); + }); + + builder.append("[](MATRIX,INT,INT)::FLOAT\n"); + builder.append("[](MATRIX,INT,INT,INT,INT)::MATRIX\n"); + builder.append("diag(MATRIX)::MATRIX\n"); + builder.append("replace(MATRIX,FLOAT,FLOAT)::MATRIX\n"); // This is not supported + builder.append("_nnz(MATRIX)::INT\n"); + builder.append("sumSq(MATRIX)::FLOAT\n"); + builder.append("sq(MATRIX)::MATRIX\n"); + builder.append("+*(MATRIX,FLOAT,MATRIX)::MATRIX\n"); + builder.append("-*(MATRIX,FLOAT,MATRIX)::MATRIX\n"); + builder.append("*2(MATRIX)::MATRIX\n"); + + for (String t : SCALARS) { + for (String t2 : SCALARS) + builder.append("ifelse(BOOL," + t + "," + t2 + ")::" + RewriterUtils.convertibleType(t, t2) + "\n"); + } + + + List.of("INT", "FLOAT", "BOOL").forEach(t -> { + String newType = t.equals("BOOL") ? "INT" : t; + builder.append("sum(" + t + "...)::" + newType + "\n"); + builder.append("sum(" + t + "*)::" + newType + "\n"); + builder.append("sum(" + t + ")::" + newType + "\n"); + + builder.append("min(" + t + "...)::" + t + "\n"); + builder.append("min(" + t + "*)::" + t + "\n"); + builder.append("min(" + t + ")::" + t + "\n"); + + builder.append("max(" + t + "...)::" + t + "\n"); + builder.append("max(" + t + "*)::" + t + "\n"); + builder.append("max(" + t + ")::" + t + "\n"); + }); + + // Some fused operators + builder.append("1-*(MATRIX,MATRIX)::MATRIX\n"); // OpOp2.MINUS1_MULT + builder.append("log_nz(MATRIX)::MATRIX\n"); // OpOp1.LOG_NZ + SCALARS.forEach(t -> { + builder.append("log(MATRIX," + t + ")::MATRIX\n"); + builder.append("log_nz(MATRIX," + t + ")::MATRIX\n"); + }); + + builder.append("const(MATRIX,FLOAT)::MATRIX\n"); + + builder.append("rowVec(MATRIX)::MATRIX\n"); + builder.append("colVec(MATRIX)::MATRIX\n"); + builder.append("cellMat(MATRIX)::MATRIX\n"); + + builder.append("_m(INT,INT,FLOAT)::MATRIX\n"); + builder.append("_m(INT,INT,BOOL)::MATRIX\n"); + builder.append("_m(INT,INT,INT)::MATRIX\n"); + List.of("FLOAT", "INT", "BOOL").forEach(t -> { + builder.append("_idxExpr(INT," + t + ")::" + t + "*\n"); + builder.append("_idxExpr(INT," + t + "*)::" + t + "*\n"); + builder.append("_idxExpr(INT...," + t + ")::" + t + "*\n"); + builder.append("_idxExpr(INT...," + t + "*)::" + t + "*\n"); + }); + builder.append("_idx(INT,INT)::INT\n"); + + ALL_TYPES.forEach(t -> builder.append("_EClass(" + t + "...)::" + t + "\n")); + ALL_TYPES.forEach(t -> builder.append("_backRef." + t + "()::" + t + "\n")); + + for (String s : SCALARS) + builder.append("literal." + s + "()::" + s + "\n"); + + return builder.toString(); + } + public static RuleContext getDefaultContext() { + String ctxString = getDefaultContextString(); + + RuleContext ctx = RuleContext.createContext(ctxString); + + ctx.customStringRepr.put("rand(INT,INT,FLOAT,FLOAT)", (stmt, mctx) -> { + List ops = stmt.getOperands(); + return "rand(rows=(" + ops.get(0) + "), cols=(" + ops.get(1) + "), min=(" + ops.get(2) + "), max=(" + ops.get(3) + "))"; + }); + ctx.customStringRepr.put("rand(INT,INT,INT,INT)", ctx.customStringRepr.get("rand(INT,INT,FLOAT,FLOAT)")); + ctx.customStringRepr.put("rand(INT,INT,FLOAT,INT)", ctx.customStringRepr.get("rand(INT,INT,FLOAT,FLOAT)")); + ctx.customStringRepr.put("rand(INT,INT,INT,FLOAT)", ctx.customStringRepr.get("rand(INT,INT,FLOAT,FLOAT)")); + + RewriterUtils.putAsDefaultBinaryPrintable(List.of("<", "<=", ">", ">=", "==", "!=", "&", "|"), List.of("INT", "FLOAT", "BOOL", "MATRIX"), ctx.customStringRepr); + + RewriterUtils.putAsBinaryPrintable("*", List.of("INT", "FLOAT", "MATRIX", "BOOL"), ctx.customStringRepr, RewriterUtils.binaryStringRepr(" * ")); + RewriterUtils.putAsBinaryPrintable("+", List.of("INT", "FLOAT", "MATRIX", "BOOL"), ctx.customStringRepr, RewriterUtils.binaryStringRepr(" + ")); + + ctx.customStringRepr.put("%*%(MATRIX,MATRIX)", RewriterUtils.binaryStringRepr(" %*% ")); + ctx.customStringRepr.put("&&(INT,INT)", RewriterUtils.binaryStringRepr(" && ")); + ctx.customStringRepr.put("index(MATRIX,INT,INT,INT,INT)", (stmt, ctx2) -> { + String out; + RewriterInstruction mInstr = (RewriterInstruction) stmt; + List ops = mInstr.getOperands(); + RewriterStatement op1 = ops.get(0); + + if (op1 instanceof RewriterDataType) + out = op1.toString(ctx2); + else + out = "(" + op1.toString(ctx2) + ")"; + + out += "[" + ops.get(1).toString(ctx2) + " : " + ops.get(2).toString(ctx2) + ", " + ops.get(3).toString(ctx2) + " : " + ops.get(4).toString(ctx2) + "]"; + return out; + }); + + return ctx; + } +} diff --git a/src/main/java/org/apache/sysds/hops/rewriter/RewriterDataType.java b/src/main/java/org/apache/sysds/hops/rewriter/RewriterDataType.java new file mode 100644 index 00000000000..5e42f7dbd63 --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewriter/RewriterDataType.java @@ -0,0 +1,487 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.rewriter; + +import org.apache.commons.lang3.function.TriFunction; +import org.apache.sysds.hops.rewriter.assertions.RewriterAssertions; +import org.apache.sysds.hops.rewriter.utils.RewriterUtils; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; + +public class RewriterDataType extends RewriterStatement { + private String id; + private String type; + private Object literal = null; + private boolean consolidated = false; + private int hashCode; + private RewriterStatement ncol; + private RewriterStatement nrow; + + @Override + protected void compress(RewriterAssertions assertions) { + if (literal != null) + id = null; + + if (meta != null) { + if (type.equals("MATRIX")) { + nrow = getNRow(); + ncol = getNCol(); + + if (assertions != null) { + RewriterStatement mAss1 = assertions.getAssertionStatement(nrow, null); + RewriterStatement mAss2 = assertions.getAssertionStatement(ncol, null); + + if (mAss1 != null) + nrow = mAss1; + + if (mAss2 != null) + ncol = mAss2; + } + } + } + } + + @Override + public RewriterStatement getNRow() { + if (nrow != null) + return nrow; + + return super.getNRow(); + } + + public void setNRow(RewriterStatement stmt) { + nrow = stmt; + } + + @Override + public RewriterStatement getNCol() { + if (ncol != null) + return ncol; + + return super.getNCol(); + } + + public void setNCol(RewriterStatement stmt) { + ncol = stmt; + } + + @Override + public String getId() { + return id; + } + + @Override + public String getResultingDataType(final RuleContext ctx) { + return type; + } + + @Override + public void refreshReturnType(final RuleContext ctx) {} + + @Override + public boolean isLiteral() { + return literal != null && !(literal instanceof List); + } + + @Override + public Object getLiteral() { + return literal; + } + + @Override + public long intLiteral(boolean cast) { + if (getLiteral() instanceof Boolean) + return (boolean)getLiteral() ? 1 : 0; + + if (cast && getLiteral() instanceof Double) { + double val = floatLiteral(); + return (long)val; + } + + return (long)getLiteral(); + } + + @Override + public double floatLiteral() { + if (getLiteral() instanceof Boolean) + return (boolean)getLiteral() ? 1 : 0; + if (getLiteral() instanceof Long) + return Double.valueOf((Long)getLiteral()); + return (double)getLiteral(); + } + + @Override + public boolean boolLiteral() { + if (getLiteral() instanceof Boolean) + return (boolean)getLiteral(); + if (getLiteral() instanceof Long) + return (long)getLiteral() == 0L; + return (double)getLiteral() == 0.0D; + } + + @Override + public void setLiteral(Object literal) { + if (consolidated) + throw new IllegalArgumentException(); + + this.literal = literal; + } + + @Override + public RewriterStatement getLiteralStatement() { + return this; + } + + @Override + public boolean isArgumentList() { + return false; + } + + @Override + public List getArgumentList() { + return null; + } + + @Override + public boolean isInstruction() { + return false; + } + + @Override + public boolean isEClass() { + return false; + } + + @Override + public String trueInstruction() { + return null; + } + + @Override + public String trueTypedInstruction(RuleContext ctx) { + return null; + } + + @Override + public String trueTypedInstruction(boolean allowImplicitConversions, RuleContext ctx) { + return null; + } + + @Override + public RewriterStatement consolidate(final RuleContext ctx) { + if (consolidated) + return this; + + if (!isLiteral() && (id == null || id.isEmpty())) + throw new IllegalArgumentException("The id of a data type cannot be empty"); + if (type == null ||type.isEmpty()) + throw new IllegalArgumentException("The type of a data type cannot be empty"); + + if (isLiteral()) + hashCode = Objects.hash(-1, -1, type, literal); + else + hashCode = Objects.hash(rid, refCtr, type); + return this; + } + + @Override + public int recomputeHashCodes(boolean recursively, final RuleContext ctx) { + if (isLiteral()) + hashCode = Objects.hash(-1, -1, type, literal); + else + hashCode = Objects.hash(rid, refCtr, type); + return hashCode; + } + + @Override + public int structuralHashCode() { + return hashCode; + } + + @Override + public RewriterStatement rename(String id) { + this.id = id; + return this; + } + + @Override + public int hashCode() { + if (isLiteral()) + return hashCode; + + return super.hashCode(); + } + + @Override + public boolean equals(Object o) { + if (isLiteral()) + return o instanceof RewriterDataType && getLiteral().equals(((RewriterDataType)o).getLiteral()); + return super.equals(o); + } + + @Override + public int computeIds(int id) { + if (!isLiteral()) + return super.computeIds(id); + + rid = -1; + return id; + } + + @Override + public void computeRefCtrs() { + refCtr = -1; + } + + @Override + public boolean isConsolidated() { + return consolidated; + } + + @Override + public boolean match(final MatcherContext mCtx) { + RewriterStatement stmt = mCtx.currentStatement; + RuleContext ctx = mCtx.ctx; + String dType = stmt.getResultingDataType(ctx); + + if (!(stmt instanceof RewriterDataType) && !mCtx.statementsCanBeVariables) { + mCtx.setFirstMismatch(this, stmt); + return false; + } + + if (!dType.equals(type)) { + if (!mCtx.allowImplicitTypeConversions || !RewriterUtils.isImplicitlyConvertible(dType, type)) { + if (!mCtx.allowTypeHierarchy) { + mCtx.setFirstMismatch(this, stmt); + return false; + } + + Set types = ctx.typeHierarchy.get(dType); + if (types == null || !types.contains(type)) { + mCtx.setFirstMismatch(this, stmt); + return false; + } + } + } + + if (mCtx.literalsCanBeVariables) { + if (isLiteral()) { + if (!mCtx.ignoreLiteralValues && (!stmt.isLiteral() || !RewriterUtils.compareLiterals(this, (RewriterDataType)stmt, mCtx.allowImplicitTypeConversions))) { + mCtx.setFirstMismatch(this, stmt); + return false; + } + } + } else { + if (isLiteral() != stmt.isLiteral()) { + mCtx.setFirstMismatch(this, stmt); + return false; + } + if (!mCtx.ignoreLiteralValues && isLiteral() && !RewriterUtils.compareLiterals(this, (RewriterDataType)stmt, mCtx.allowImplicitTypeConversions)) { + mCtx.setFirstMismatch(this, stmt); + return false; + } + } + + // If matrix, check if the dimensions + if (!mCtx.statementsCanBeVariables && dType.equals("MATRIX")) { + RewriterStatement ncolEquiv = getNCol(); + RewriterStatement nrowEquiv = getNRow(); + + if (ncolEquiv != null && nrowEquiv != null) { + if (!mCtx.wasVisited(this)) { + mCtx.dontVisitAgain(this); + RewriterStatement ncolEquivThat = stmt.getNCol(); + RewriterStatement nrowEquivThat = stmt.getNRow(); + + RewriterAssertions assertionsThis = mCtx.getOldAssertionsThis(); + RewriterAssertions assertionsThat = mCtx.getOldAssertionsThat(); + + if (assertionsThis != null) { + RewriterStatement ncolAssertion = assertionsThis.getAssertionStatement(ncolEquiv, null); + + RewriterStatement nrowAssertion = assertionsThis.getAssertionStatement(nrowEquiv, null); + ncolEquiv = ncolAssertion == null ? ncolEquiv : ncolAssertion; + nrowEquiv = nrowAssertion == null ? nrowEquiv : nrowAssertion; + } + + if (assertionsThat != null) { + RewriterStatement ncolAssertionThat = assertionsThat.getAssertionStatement(ncolEquivThat, null); + + RewriterStatement nrowAssertionThat = assertionsThat.getAssertionStatement(nrowEquivThat, null); + ncolEquivThat = ncolAssertionThat == null ? ncolEquiv : ncolAssertionThat; + nrowEquivThat = nrowAssertionThat == null ? nrowEquiv : nrowAssertionThat; + } + + // Now, match those statements + mCtx.currentStatement = ncolEquivThat; + if (!ncolEquiv.match(mCtx)) { + if (mCtx.isDebug()) + System.out.println("MismatchNcolEquiv: " + ncolEquiv + " <=> " + ncolEquivThat); + return false; + } + mCtx.currentStatement = nrowEquivThat; + if (!nrowEquiv.match(mCtx)) { + if (mCtx.isDebug()) + System.out.println("MismatchNrowEquiv: " + nrowEquiv + " <=> " + nrowEquivThat); + return false; + } + } + } + } + + RewriterStatement assoc = mCtx.getDependencyMap().get(this); + if (assoc == null) { + if (!mCtx.allowDuplicatePointers && mCtx.getDependencyMap().containsValue(stmt)) { + mCtx.setFirstMismatch(this, stmt); + if (mCtx.isDebug()) + System.out.println("MismatchAssocNull: " + stmt); + return false; // Then the statement variable is already associated with another variable + } + mCtx.getDependencyMap().put(this, stmt); + return true; + } else if (assoc.equals(stmt)) { + return true; + } + + if (mCtx.isDebug()) + System.out.println("MismatchAssoc: " + stmt + " <=> " + assoc); + + mCtx.setFirstMismatch(this, stmt); + return false; + } + + @Override + public RewriterStatement clone() { + return new RewriterDataType().as(id).ofType(type); + } + + @Override + public RewriterStatement copyNode() { + return new RewriterDataType().as(id).ofType(type).asLiteral(literal); + } + + @Override + public RewriterStatement nestedCopyOrInject(Map copiedObjects, TriFunction injector, RewriterStatement parent, int pIdx) { + RewriterStatement mCpy = copiedObjects.get(this); + if (mCpy != null) + return mCpy; + mCpy = injector.apply(this, parent, pIdx); + if (mCpy != null) { + // Then change the reference to the injected object + copiedObjects.put(this, mCpy); + return mCpy; + } + + RewriterDataType mCopy = new RewriterDataType(); + mCopy.id = id; + mCopy.type = type; + if (literal != null && literal instanceof List) { + final ArrayList mList = new ArrayList<>(((List)literal).size()); + mCopy.literal = mList; + ((List) literal).forEach(el -> { + if (el instanceof RewriterStatement) + mList.add(((RewriterStatement)el).nestedCopyOrInject(copiedObjects, injector)); + }); + } else + mCopy.literal = literal; + mCopy.consolidated = consolidated; + mCopy.hashCode = hashCode; + if (meta != null) + mCopy.meta = new HashMap<>(meta); + copiedObjects.put(this, mCopy); + mCopy.nestedCopyOrInjectMetaStatements(copiedObjects, injector); + + return mCopy; + } + + @Override + public RewriterStatement simplify(final RuleContext ctx) { + return this; + } + + public String getType() { + return type; + } + + @Override + public RewriterDataType as(String id) { + if (consolidated) + throw new IllegalArgumentException("A data type cannot be modified after consolidation"); + this.id = id; + return this; + } + + public RewriterDataType ofType(String type) { + if (consolidated) + throw new IllegalArgumentException("A data type cannot be modified after consolidation"); + this.type = type; + return this; + } + + public RewriterDataType asLiteral(Object literal) { + if (consolidated) + throw new IllegalArgumentException("A data type cannot be modified after consolidation"); + this.literal = literal; + return this; + } + + @Override + public int toParsableString(StringBuilder sb, Map refs, int maxRefId, Map> vars, Set forceCreateRefs, final RuleContext ctx) { + String mType = type; + String varStr = id; + + if (isLiteral()) { + mType = "LITERAL_" + type; + varStr = getLiteral().toString(); + + if (getLiteral() instanceof Boolean) + varStr = varStr.toUpperCase(); + } + + Set varSet = vars.get(mType); + + if (varSet == null) { + varSet = new HashSet<>(); + vars.put(mType, varSet); + } + + varSet.add(varStr); + sb.append(varStr); + + return maxRefId; + } + + @Override + public String toString(final RuleContext ctx) { + if (!isLiteral()) + return getId() + "::" + getResultingDataType(ctx) + "[" + hashCode() + "]"; + + if (getLiteral() instanceof Boolean) + return getLiteral().toString().toUpperCase(); + + return getLiteral().toString() + "::" + getResultingDataType(ctx) + "[" + hashCode() + "]"; + } +} diff --git a/src/main/java/org/apache/sysds/hops/rewriter/RewriterDatabase.java b/src/main/java/org/apache/sysds/hops/rewriter/RewriterDatabase.java new file mode 100644 index 00000000000..7c38e4dc80d --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewriter/RewriterDatabase.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.rewriter; + +import org.apache.sysds.hops.rewriter.utils.RewriterUtils; + +import java.io.BufferedReader; +import java.io.BufferedWriter; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Consumer; + +public class RewriterDatabase { + + private ConcurrentHashMap db = new ConcurrentHashMap<>(); + + public void clear() { + db.clear(); + } + + public boolean containsEntry(RewriterStatement instr) { + return db.containsKey(instr); + } + + public boolean insertEntry(final RuleContext ctx, RewriterStatement stmt) { + return db.putIfAbsent(new RewriterStatementEntry(ctx, stmt), stmt) == null; + } + + public RewriterStatement find(final RuleContext ctx, RewriterStatement stmt) { + return db.get(new RewriterStatementEntry(ctx, stmt)); + } + + public RewriterStatement insertOrReturn(final RuleContext ctx, RewriterStatement stmt) { + return db.putIfAbsent(new RewriterStatementEntry(ctx, stmt), stmt); + } + + public void forEach(Consumer consumer) { + db.values().forEach(consumer); + } + + public void parForEach(Consumer consumer) { + db.values().parallelStream().forEach(consumer); + } + + public int size() {return db.size(); } + + public void serialize(BufferedWriter writer, final RuleContext ctx) throws IOException { + for (RewriterStatement entry : db.values()) { + writer.write("\n::STMT\n"); + writer.write(entry.toParsableString(ctx, true)); + } + } + + public void deserialize(BufferedReader reader, final RuleContext ctx) throws IOException { + List strBuffer = new ArrayList<>(); + + String line; + while ((line = reader.readLine()) != null) { + if (line.isBlank()) + continue; + + if (line.startsWith("::STMT")) { + if (strBuffer.isEmpty()) + continue; + try { + RewriterStatement stmt = RewriterUtils.parse(String.join("\n", strBuffer), ctx); + insertEntry(ctx, stmt); + strBuffer.clear(); + } catch (Exception e) { + System.err.println("An error occurred while parsing the string:\n" + String.join("\n", strBuffer)); + strBuffer.clear(); + e.printStackTrace(); + } + } else { + strBuffer.add(line); + } + } + + if (!strBuffer.isEmpty()) { + try { + RewriterStatement stmt = RewriterUtils.parse(String.join("\n", strBuffer), ctx); + insertEntry(ctx, stmt); + } catch (Exception e) { + e.printStackTrace(); + } + } + } +} diff --git a/src/main/java/org/apache/sysds/hops/rewriter/RewriterEquivalenceDatabase.java b/src/main/java/org/apache/sysds/hops/rewriter/RewriterEquivalenceDatabase.java new file mode 100644 index 00000000000..a134ecb893f --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewriter/RewriterEquivalenceDatabase.java @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.rewriter; + +import org.apache.sysds.hops.rewriter.utils.RewriterUtils; + +import java.io.BufferedReader; +import java.io.BufferedWriter; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Consumer; + +public class RewriterEquivalenceDatabase { + private ConcurrentHashMap db = new ConcurrentHashMap<>(); + + public void clear() { + db.clear(); + } + + public boolean containsEntry(RewriterStatement instr) { + return db.containsKey(instr); + } + + public DBEntry insert(final RuleContext ctx, RewriterStatement canonicalForm, RewriterStatement equivalence) { + return db.compute(new RewriterStatementEntry(ctx, canonicalForm), (k, v) -> { + if (v == null) + return new DBEntry(canonicalForm, equivalence); + + v.insertEquivalence(equivalence); + return v; + }); + } + + public DBEntry find(final RuleContext ctx, RewriterStatement canonicalForm) { + return db.get(new RewriterStatementEntry(ctx, canonicalForm)); + } + + public void forEach(Consumer consumer) { + db.values().forEach(consumer); + } + + public void parForEach(Consumer consumer) { + db.values().parallelStream().forEach(consumer); + } + + public int size() {return db.size(); } + + @Deprecated + public void serialize(BufferedWriter writer, final RuleContext ctx) throws IOException { + for (DBEntry entry : db.values()) { + writer.write("\n::STMT\n"); + writer.write(entry.canonicalForm.toParsableString(ctx, true)); + } + } + + @Deprecated + public void deserialize(BufferedReader reader, final RuleContext ctx) throws IOException { + List strBuffer = new ArrayList<>(); + + String line; + while ((line = reader.readLine()) != null) { + if (line.isBlank()) + continue; + + if (line.startsWith("::STMT")) { + if (strBuffer.isEmpty()) + continue; + try { + RewriterStatement stmt = RewriterUtils.parse(String.join("\n", strBuffer), ctx); + insert(ctx, stmt, null); + strBuffer.clear(); + } catch (Exception e) { + System.err.println("An error occurred while parsing the string:\n" + String.join("\n", strBuffer)); + strBuffer.clear(); + e.printStackTrace(); + } + } else { + strBuffer.add(line); + } + } + + if (!strBuffer.isEmpty()) { + try { + RewriterStatement stmt = RewriterUtils.parse(String.join("\n", strBuffer), ctx); + insert(ctx, stmt, null); + } catch (Exception e) { + e.printStackTrace(); + } + } + } + + public static class DBEntry { + public final RewriterStatement canonicalForm; + public final List equivalences; + + public DBEntry(RewriterStatement canonicalForm, RewriterStatement firstEquivalence) { + this.canonicalForm = canonicalForm; + this.equivalences = new ArrayList<>(3); + + if (firstEquivalence != null) + this.equivalences.add(firstEquivalence); + } + + public void insertEquivalence(RewriterStatement equivalence) { + equivalences.add(equivalence); + } + } +} diff --git a/src/main/java/org/apache/sysds/hops/rewriter/RewriterFramework.java b/src/main/java/org/apache/sysds/hops/rewriter/RewriterFramework.java new file mode 100644 index 00000000000..ce57eafa0f5 --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewriter/RewriterFramework.java @@ -0,0 +1,493 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.rewriter; + +import org.apache.commons.collections.list.SynchronizedList; +import org.apache.commons.lang3.mutable.MutableInt; +import org.apache.sysds.hops.rewriter.assertions.RewriterAssertionUtils; +import org.apache.sysds.hops.rewriter.assertions.RewriterAssertions; +import org.apache.sysds.hops.rewriter.estimators.RewriterCostEstimator; +import org.apache.sysds.hops.rewriter.generated.RewriteAutomaticallyGenerated; +import org.apache.sysds.hops.rewriter.rule.RewriterRule; +import org.apache.sysds.hops.rewriter.rule.RewriterRuleCreator; +import org.apache.sysds.hops.rewriter.rule.RewriterRuleSet; +import org.apache.sysds.hops.rewriter.utils.RewriterSearchUtils; +import org.apache.sysds.hops.rewriter.utils.RewriterUtils; +import scala.Tuple2; +import scala.Tuple4; + +import java.io.*; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.Set; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +public class RewriterFramework { + + // To test the framework + public static void main(String[] args) { + String dbPath = "./src/test/resources/rewriterframework/expressions.db"; + RewriterFramework rwf = new RewriterFramework(dbPath); + rwf.init(true,true); + rwf.dataDrivenSearch(1000); + rwf.systematicSearch(3); + //rwf.randomSearch(4, 4, 5000); + rwf.createRules(true); + rwf.removeInvalidRules(); + // Note that unconditional rules are not 'static' rules. + // It is a set of equivalences that have a single optimal expression + System.out.println(rwf.getUnconditionalRuleSet()); + //rwf.removeInapplicableRules(); + //System.out.println(rwf.getUnconditionalRuleSet().toJavaCode("GeneratedRewriteClass", true)); + + /*RewriterRuleSet rs = loadRuleSet(rPath); + saveJavaCode(sPath, rs, "GeneratedRewriteClass", true);*/ + } + + + private RuleContext ctx; + private Function converter; + private RewriterDatabase db; + private String dbFile; + + private int BATCH_SIZE = 1000; + private int MAX_COST_SAMPLES = 50; + + private RewriterEquivalenceDatabase equivalenceDB; + private List foundEquivalences; + private boolean pruneNovelExpressions = false; + + private RewriterRuleCreator unconditionalRuleCreator; + private RewriterRuleSet conditionalRuleSet; + + public RewriterFramework(String dbFile) { + this.dbFile = dbFile; + } + + private void setupDataDrivenSearch() { + if (db != null && db.size() > 0) + return; // Then a database has already been loaded + + try(BufferedReader reader = new BufferedReader(new FileReader(dbFile))) { + db.deserialize(reader, ctx); + } catch (IOException e) { + e.printStackTrace(); + } + } + + /** + * Initializes the rewriter framework + * @param allowInversionCanonicalization if the conversion from a/c => a*(c^-1) should be applied (during canonicalization) + * @param pruneNovelExpressions if only equivalence groups should be stored, where at least one expression was in the data-set + */ + public void init(boolean allowInversionCanonicalization, boolean pruneNovelExpressions) { + ctx = RewriterUtils.buildDefaultContext(); + converter = RewriterUtils.buildCanonicalFormConverter(ctx, allowInversionCanonicalization, false); + db = new RewriterDatabase(); + equivalenceDB = new RewriterEquivalenceDatabase(); + foundEquivalences = new ArrayList<>(); + this.pruneNovelExpressions = pruneNovelExpressions; + } + + public RuleContext getContext() { + return ctx; + } + + /** + * Performs a data-driven search where existing expressions and their subexpressions are considered + * @param exprPruningThreshold the maximum number of generated subexpressions (to avoid exploding numbers of subgraphs for big graphs) + */ + public void dataDrivenSearch(int exprPruningThreshold) { + setupDataDrivenSearch(); // Load the expression DB + + int size = db.size(); + RewriterDatabase exactExprDB = new RewriterDatabase(); + + MutableInt ctr = new MutableInt(0); + MutableInt failures = new MutableInt(0); + MutableInt generatedExpressions = new MutableInt(0); + MutableInt evaluatedExpressions = new MutableInt(0); + MutableInt totalCanonicalizationMillis = new MutableInt(0); + db.parForEach(expr -> { + if (ctr.incrementAndGet() % 10 == 0) + System.out.println("Done: " + ctr.intValue() + " / " + size); + + List subExprs = RewriterSearchUtils.generateSubtrees(expr, ctx, exprPruningThreshold); + if (subExprs.size() > exprPruningThreshold) + System.out.println("Critical number of subtrees: " + subExprs.size()); + if (subExprs.size() > 2 * exprPruningThreshold) { + System.out.println("Skipping subtrees..."); + subExprs = List.of(expr); + } + long evaluationCtr = 0; + long mCanonicalizationMillis = 0; + + for (RewriterStatement subExpr : subExprs) { + try { + if (!exactExprDB.insertEntry(ctx, subExpr)) + continue; + + evaluationCtr++; + + // Duplicate the statement as we do not want to canonicalize the original statement + long startMillis = System.currentTimeMillis(); + RewriterStatement canonicalForm = converter.apply(subExpr); + mCanonicalizationMillis += System.currentTimeMillis() - startMillis; + + synchronized (this) { + RewriterEquivalenceDatabase.DBEntry entry = equivalenceDB.insert(ctx, canonicalForm, subExpr); + + // Now, we use common variables + if (entry.equivalences.size() > 1) { + RewriterStatement commonForm = RewriterRuleCreator.createCommonForm(subExpr, entry.equivalences.get(0), canonicalForm, entry.canonicalForm, ctx)._1; + entry.equivalences.set(entry.equivalences.size()-1, commonForm); + } + + if (entry.equivalences.size() == 2) + foundEquivalences.add(entry); + } + } catch (Exception e) { + try { + System.err.println("Error from expression: " + subExpr.toParsableString(ctx)); + } catch (Exception e2) { + } + e.printStackTrace(); + failures.incrementAndGet(); + } + } + + generatedExpressions.addAndGet(subExprs.size()); + evaluatedExpressions.addAndGet(evaluationCtr); + totalCanonicalizationMillis.addAndGet(mCanonicalizationMillis); + }); + } + + /** + * Performs a systematic search + * @param maxDepth the maximum number of (virtual) operands + */ + public void systematicSearch(int maxDepth) { + systematicSearch(0, RewriterSearchUtils.getMaxSearchNumberForNumOps(maxDepth), true, false); + } + + /** + * Performs a systematic search + * @param maxDepth the maximum number of (virtual) operands + * @param includeDuplicateReferences if the search space should be extended to contain a shared variable (e.g. +(A,B) => [+(A,B), +(A,A)]) + */ + public void systematicSearch(int maxDepth, boolean includeDuplicateReferences) { + systematicSearch(0, RewriterSearchUtils.getMaxSearchNumberForNumOps(maxDepth), includeDuplicateReferences, false); + } + + /** + * Performs a systematic search + * @param fromIdx the start index + * @param toIdx the end index + * @param includeDuplicateReferences if the search space should be extended to contain a shared variable (e.g. +(A,B) => [+(A,B), +(A,A)]) + * @param includeRowColVectors if row-vectors and col-vectors should be included in the search (note that the data-driven approach does not support this) + */ + public void systematicSearch(int fromIdx, int toIdx, boolean includeDuplicateReferences, boolean includeRowColVectors) { + int diff = toIdx - fromIdx; + int maxN = toIdx; + + for (int batch = 0; batch < 10000 && batch * BATCH_SIZE < diff; batch++) { + List indices = IntStream.range(fromIdx + batch * BATCH_SIZE, fromIdx + Math.min((batch + 1) * BATCH_SIZE - 1, maxN)).boxed().collect(Collectors.toList()); + Collections.shuffle(indices); + MutableInt ctr2 = new MutableInt(0); + int maxSize = indices.size(); + final int mBATCH = batch; + indices.parallelStream().forEach(idx -> { + if (ctr2.incrementAndGet() % 10 == 0) + System.out.println("Done: " + (mBATCH * BATCH_SIZE + ctr2.intValue()) + " / " + (mBATCH * BATCH_SIZE + maxSize)); + + + List ops = RewriterSearchUtils.decodeOrderedStatements(idx); + List stmts = RewriterSearchUtils.buildAllPossibleDAGs(ops, ctx, true); + + for (RewriterStatement dag : stmts) { + List expanded = new ArrayList<>(); + expanded.add(dag); + if (includeDuplicateReferences) + expanded.addAll(RewriterSearchUtils.buildVariations(dag, ctx)); + if (includeRowColVectors) + expanded.addAll(RewriterSearchUtils.buildAssertionVariations(dag, ctx)); + + insertEquivalences(expanded); + } + }); + } + } + + public void randomSearch(int minExprSize, int maxExprSize, int numSamples) { + randomSearchFromIndex(RewriterSearchUtils.getMaxSearchNumberForNumOps(minExprSize-1)+1, RewriterSearchUtils.getMaxSearchNumberForNumOps(maxExprSize), numSamples, true, false); + } + + /** + * Performs a random search. Samples numSamples expression groups (groups of expressions encoded by a single integer) + * @param fromIdx the start index + * @param toIdx the end index + * @param numSamples the number of sampmles + * @param includeDuplicateReferences if expressions such as +(A,A) should be included in the search + * @param includeRowColVectors if row-col vectors should be included in the search + */ + public void randomSearchFromIndex(int fromIdx, int toIdx, int numSamples, boolean includeDuplicateReferences, boolean includeRowColVectors) { + // Now we will just do random sampling for a few rounds + Random rd = new Random(42); + for (int batch = 0; batch < 200 && batch * BATCH_SIZE < numSamples; batch++) { + List indices = IntStream.range(batch * BATCH_SIZE, (batch + 1) * BATCH_SIZE - 1).boxed().map(v -> fromIdx + rd.nextInt(toIdx-fromIdx)).collect(Collectors.toList()); + MutableInt ctr2 = new MutableInt(0); + int maxSize = indices.size(); + final int mBATCH = batch; + indices.parallelStream().forEach(idx -> { + if (ctr2.incrementAndGet() % 10 == 0) + System.out.println("Done: " + (mBATCH * BATCH_SIZE + ctr2.intValue()) + " / " + (mBATCH * BATCH_SIZE + maxSize)); + + List ops = RewriterSearchUtils.decodeOrderedStatements(idx); + List stmts = RewriterSearchUtils.buildAllPossibleDAGs(ops, ctx, true); + + for (RewriterStatement dag : stmts) { + List expanded = new ArrayList<>(); + expanded.add(dag); + if (includeDuplicateReferences) + expanded.addAll(RewriterSearchUtils.buildVariations(dag, ctx)); + if (includeRowColVectors) + expanded.addAll(RewriterSearchUtils.buildAssertionVariations(dag, ctx)); + + insertEquivalences(expanded); + } + }); + } + } + + private void insertEquivalences(List stmts) { + for (RewriterStatement stmt : stmts) { + try { + RewriterStatement canonicalForm = converter.apply(stmt); + + synchronized (this) { + if (pruneNovelExpressions && !equivalenceDB.containsEntry(canonicalForm)) + return; + + RewriterEquivalenceDatabase.DBEntry entry = equivalenceDB.insert(ctx, canonicalForm, stmt); + + // Now, we use common variables + if (entry.equivalences.size() > 1) { + RewriterStatement commonForm = RewriterRuleCreator.createCommonForm(stmt, entry.equivalences.get(0), canonicalForm, entry.canonicalForm, ctx)._1; + entry.equivalences.set(entry.equivalences.size()-1, commonForm); + } + + if (entry.equivalences.size() == 2) + foundEquivalences.add(entry); + } + } catch (Exception e) { + System.err.println("Faulty expression: " + stmt.toParsableString(ctx)); + e.printStackTrace(); + } + } + } + + /** + * Create rules from all observed equivalences + * @param freeDBMemory if all the stored equivalences that are not needed for rule creation should be dropped immediately + */ + public void createRules(boolean freeDBMemory) { + System.out.println("===== SUGGESTED REWRITES ====="); + List, Long, Boolean>> rewrites = findSuggestedRewrites(foundEquivalences, MAX_COST_SAMPLES); + + if (freeDBMemory) { + db.clear(); + foundEquivalences.clear(); + equivalenceDB.clear(); + } + + // Here, we create any rule + List> allRules = new ArrayList<>(); + int mCtr = 0; + for (Tuple4, Long, Boolean> rewrite : rewrites) { + if (++mCtr % 100 == 0) + System.out.println("Creating rule: " + mCtr + " / " + rewrites.size()); + + try { + RewriterRule rule; + if (rewrite._4()) + rule = RewriterRuleCreator.createRuleFromCommonStatements(rewrite._1(), rewrite._2().get(0), ctx); + else + rule = RewriterRuleCreator.createConditionalRuleFromCommonStatements(rewrite._1(), rewrite._2(), ctx); + + allRules.add(new Tuple4<>(rule, rewrite._3(), rule.getStmt1().countInstructions(), rewrite._4())); + } catch (Exception e) { + System.err.println("An error occurred while trying to create a rule:"); + System.err.println(rewrite._1().toParsableString(ctx, true)); + for (RewriterStatement stmt : rewrite._2()) + System.err.println(stmt.toParsableString(ctx, true)); + e.printStackTrace(); + } + } + + System.out.println("Rule creation complete!"); + + allRules.sort(Comparator.comparing(Tuple4::_3)); + + System.out.println("Rules sorted!"); + + unconditionalRuleCreator = new RewriterRuleCreator(ctx); + List conditionalRules = new ArrayList<>(); + + mCtr = 0; + + for (Tuple4 t : allRules) { + if (++mCtr % 100 == 0) + System.out.println("Registering rule: " + mCtr + " / " + allRules.size()); + + try { + // First, without validating correctness + // This might throw out some fallback options if a rule turns out to be incorrect but we there is a huge performance benefit + if (!t._1().isConditionalMultiRule()) { + unconditionalRuleCreator.registerRule(t._1(), converter, ctx); + } else { + conditionalRules.add(t._1()); + } + } catch (Exception e) { + e.printStackTrace(); + } + } + + conditionalRuleSet = new RewriterRuleSet(ctx, conditionalRules); + } + + /** + * This function removes rules where the output of the origin expression does not match + * the output of the target expression. + */ + public void removeInvalidRules() { + unconditionalRuleCreator.throwOutInvalidRules(true, false); + } + + /** + * This function removes rules where the origin expression is modified by the HOP-DAG rewriter. + * We aim to remove rules that are already implemented by intercepting the HOP-DAG after rewriting. + * We disable operator fusion and sum-product rewrites during execution. + * However, we throw away any rule that does not match our expected DAG structure, which may affect + * valid rules that are not correctly extracted during runtime. + */ + public void removeInapplicableRules() { + unconditionalRuleCreator.throwOutInvalidRules(false, true); + } + + /** + * + * @return the unconditional rule set (includes rules where there is exactly one possible optimum per equality set) + */ + public RewriterRuleSet getUnconditionalRuleSet() { + return unconditionalRuleCreator.getRuleSet(); + } + + /** + * + * @return the conditional rule set (rules where the optimal expression may change, e.g., (A*B)+(A*C) <=> A*(B+C)) + */ + public RewriterRuleSet getConditionalRuleSet() { + return conditionalRuleSet; + } + + public static boolean saveRuleSet(String filePath, RewriterRuleSet ruleSet) { + try (FileWriter writer = new FileWriter(filePath)) { + writer.write(ruleSet.serialize()); + } catch (IOException ex) { + ex.printStackTrace(); + return false; + } + + return true; + } + + public static RewriterRuleSet loadRuleSet(String filePath) { + try { + List lines = Files.readAllLines(Paths.get(filePath)); + return RewriterRuleSet.deserialize(lines, RewriterUtils.buildDefaultContext()); + } catch (IOException ex) { + ex.printStackTrace(); + return null; + } + } + + public static boolean saveJavaCode(String filePath, RewriterRuleSet ruleSet, String className, boolean optimize) { + try (FileWriter writer = new FileWriter(filePath)) { + writer.write(ruleSet.toJavaCode(className, optimize)); + } catch (IOException ex) { + ex.printStackTrace(); + return false; + } + + return true; + } + + /** + * This function computes rewrite suggestions based on cost-estimates. To enable random sampling, sample_size should be bigger than 1. + * Note that random sampling might generate incorrect suggestions due to inaccurate cost-estimates (especially for fused ops) + * @param equivalences + * @param sample_size how many sparsity and dimension values should be sampled; a sample size of 1 uses a fixed cost esimtate with ncols=nrows=2000 and fully dense matrices + * @return + */ + private List, Long, Boolean>> findSuggestedRewrites(List equivalences, int sample_size) { + List, Long, Boolean>> suggestions = SynchronizedList.decorate(new ArrayList<>()); + + AtomicLong idCtr = new AtomicLong(); + equivalences.parallelStream().forEach(entry -> { + try { + List mEq = entry.equivalences; + RewriterAssertions assertions = RewriterAssertionUtils.buildImplicitAssertions(mEq.get(0), ctx); + + for (int i = 1; i < mEq.size(); i++) + RewriterAssertionUtils.buildImplicitAssertions(mEq.get(i), assertions, ctx); + + List, List>> costs = RewriterCostEstimator.compareCosts(mEq, assertions, ctx, true, 0); + + Set> rewriteProposals = RewriterCostEstimator.findOptima(costs); + long mId = idCtr.incrementAndGet(); + + if (!rewriteProposals.isEmpty()) { + int targetIdx = rewriteProposals.stream().findFirst().get()._2; + boolean hasOneTarget = rewriteProposals.stream().allMatch(t -> t._2 == targetIdx); + + // Group by origin expression + Map>> grouped = rewriteProposals.stream().collect(Collectors.groupingBy(Tuple2::_1)); + + for (List> proposalsFromSameOrigin : grouped.values()) { + suggestions.add(new Tuple4<>(mEq.get(proposalsFromSameOrigin.get(0)._1), proposalsFromSameOrigin.stream().map(t -> mEq.get(t._2)).collect(Collectors.toList()), mId, hasOneTarget)); + } + } + } catch (Exception e) { + //e.printStackTrace(); + } + }); + + return suggestions; + } +} diff --git a/src/main/java/org/apache/sysds/hops/rewriter/RewriterInstruction.java b/src/main/java/org/apache/sysds/hops/rewriter/RewriterInstruction.java new file mode 100644 index 00000000000..bfd9ca615db --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewriter/RewriterInstruction.java @@ -0,0 +1,627 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.rewriter; + +import org.apache.commons.lang3.function.TriFunction; +import org.apache.sysds.hops.rewriter.assertions.RewriterAssertions; +import org.apache.sysds.hops.rewriter.rule.RewriterRule; +import org.apache.sysds.hops.rewriter.utils.RewriterUtils; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Random; +import java.util.Set; +import java.util.UUID; +import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.stream.Collectors; + +public class RewriterInstruction extends RewriterStatement { + + private String id; + private String returnType; + private String instr; + private ArrayList operands = new ArrayList<>(); + private Function, Long> costFunction = null; + private boolean consolidated = false; + private int hashCode; + + public RewriterInstruction() { + } + + public RewriterInstruction(String instr, final RuleContext ctx, RewriterStatement... ops) { + id = UUID.randomUUID().toString(); + this.instr = instr; + withOps(ops); + consolidate(ctx); + } + + @Override + protected void compress(RewriterAssertions assertions) { + id = null; + operands.trimToSize(); + meta = null; + } + + @Override + public String getId() { + if (isDataOrigin()) { + if (trueInstruction().equals("const")) { + boolean regen = id == null; + if (!regen) { + try { + UUID.fromString(id); + regen = true; + } catch (Exception e) { + } + } + if (regen) { + id = "mConst" + new Random().nextInt(10000); + } + } else { + return getChild(0).getId(); + } + } + + return id; + } + + @Override + public String getResultingDataType(final RuleContext ctx) { + if (returnType != null) + return returnType; + + if (isArgumentList()) + returnType = getOperands().stream().map(op -> op.getResultingDataType(ctx)).reduce(RewriterUtils::defaultTypeHierarchy).get() + "..."; + else + returnType = ctx.instrTypes.get(trueTypedInstruction(ctx));//getResult(ctx).getResultingDataType(ctx); + + if (returnType == null) + throw new IllegalArgumentException("Return type not found for: " + trueTypedInstruction(ctx)); + + return returnType; + } + + @Override + public void refreshReturnType(final RuleContext ctx) { + returnType = null; + } + + @Override + public boolean isLiteral() { + return false; + } + + @Override + public Object getLiteral() { + return null; + } + + @Override + public RewriterStatement getLiteralStatement() { + for (RewriterStatement op : getChild(0).getOperands()) + if (op.isLiteral()) + return op; + + return null; + } + + @Override + public long intLiteral(boolean cast) { + throw new UnsupportedOperationException(); + } + + @Override + public double floatLiteral() { + throw new UnsupportedOperationException(); + } + + @Override + public boolean boolLiteral() { + throw new UnsupportedOperationException(); + } + + @Override + public RewriterStatement consolidate(final RuleContext ctx) { + if (consolidated) + return this; + + if (instr == null || instr.isEmpty()) + throw new IllegalArgumentException("Instruction type cannot be empty"); + + if (getCostFunction(ctx) == null) + throw new IllegalArgumentException("Could not find a matching cost function for " + typedInstruction(ctx)); + + for (RewriterStatement operand : operands) + operand.consolidate(ctx); + + hashCode = Objects.hash(rid, refCtr, instr, getResultingDataType(ctx), operands); + consolidated = true; + + return this; + } + @Override + public int recomputeHashCodes(boolean recursively, final RuleContext ctx) { + if (recursively) { + operands.forEach(op -> op.recomputeHashCodes(true, ctx)); + } + + hashCode = Objects.hash(rid, refCtr, instr, getResultingDataType(ctx), operands.stream().map(RewriterStatement::structuralHashCode).collect(Collectors.toList())); + return hashCode; + } + + @Override + public boolean isConsolidated() { + return consolidated; + } + + @Override + public boolean match(final MatcherContext mCtx) { + RewriterStatement stmt = mCtx.currentStatement; + RuleContext ctx = mCtx.ctx; + + if (mCtx.isDebug()) + System.out.println("Matching: " + this.toString(ctx) + " <=> " + stmt.toString(ctx)); + + // Check for some meta information + if (mCtx.statementsCanBeVariables && getResultingDataType(ctx).equals("MATRIX")) { + if ((trueInstruction().equals("rowVec") && stmt.isRowVector()) + || (trueInstruction().equals("colVec") && stmt.isColVector())) { + RewriterStatement existingRef = mCtx.findInternalReference(this); + + if (existingRef != null) { + if (existingRef == stmt) + return true; + else { + mCtx.setFirstMismatch(this, stmt); + return false; + } + } + + if (!mCtx.allowDuplicatePointers && mCtx.getInternalReferences().containsValue(stmt)) { + mCtx.setFirstMismatch(this, stmt); + return false; + } + + mCtx.getInternalReferences().put(this, stmt); + + if (stmt.isInstruction() && (stmt.trueInstruction().equals("rowVec") || stmt.trueInstruction().equals("colVec"))) + mCtx.getDependencyMap().put(getChild(0), stmt.getChild(0)); + else + mCtx.getDependencyMap().put(getChild(0), stmt); + + + return true; + } + } + + if (stmt instanceof RewriterInstruction && (getResultingDataType(ctx).equals(stmt.getResultingDataType(ctx)) || (mCtx.allowImplicitTypeConversions && RewriterUtils.isImplicitlyConvertible(stmt.getResultingDataType(ctx), getResultingDataType(ctx))))) { + RewriterInstruction inst = (RewriterInstruction)stmt; + + if(!inst.instr.equals(this.instr)) { + if (!mCtx.allowPropertyScan) { + mCtx.setFirstMismatch(this, stmt); + return false; + } + Set props = inst.getProperties(ctx); + + if (props == null || !props.contains(typedInstruction(ctx))) { + mCtx.setFirstMismatch(this, stmt); + return false; + } + } + if (this.operands.size() != inst.operands.size()) { + mCtx.setFirstMismatch(this, stmt); + return false; + } + + RewriterStatement existingRef = mCtx.findInternalReference(this); + + if (existingRef != null) { + if (existingRef == stmt) + return true; + else { + mCtx.setFirstMismatch(this, stmt); + return false; + } + } + + if (!mCtx.allowDuplicatePointers && mCtx.getInternalReferences().containsValue(stmt)) { + mCtx.setFirstMismatch(this, stmt); + return false; + } + + RewriterRule.LinkObject ruleLink = mCtx.ruleLinks.get(this); + + if (ruleLink != null) + mCtx.getLinks().add(new RewriterRule.ExplicitLink(inst, ruleLink.stmt, ruleLink.transferFunction)); + + int s = inst.operands.size(); + + if (mCtx.findMinimalMismatchRoot) { + int mismatchCtr = 0; + + for (int i = 0; i < s; i++) { + mCtx.currentStatement = inst.operands.get(i); + + if (!operands.get(i).match(mCtx)) + mismatchCtr++; + } + + if (mismatchCtr == 0) + mCtx.getInternalReferences().put(this, stmt); + else if (mismatchCtr > 1) + mCtx.setFirstMismatch(this, stmt); + + return mismatchCtr == 0; + } else { + for (int i = 0; i < s; i++) { + mCtx.currentStatement = inst.operands.get(i); + + if (!operands.get(i).match(mCtx)) { + if (mCtx.isDebug()) + System.out.println("Mismatch: " + operands.get(i) + " <=> " + inst.operands.get(i)); + return false; + } + } + + mCtx.getInternalReferences().put(this, stmt); + return true; + } + } + + mCtx.setFirstMismatch(this, stmt); + return false; + } + + @Override + public RewriterStatement copyNode() { + RewriterInstruction mCopy = new RewriterInstruction(); + mCopy.instr = instr; + mCopy.id = id; + mCopy.costFunction = costFunction; + mCopy.consolidated = consolidated; + mCopy.operands = new ArrayList<>(operands); + mCopy.returnType = returnType; + if (meta != null) + mCopy.meta = new HashMap<>(meta); + else + mCopy.meta = null; + return mCopy; + } + + @Override + public RewriterStatement nestedCopyOrInject(Map copiedObjects, TriFunction injector, RewriterStatement parent, int pIdx) { + RewriterStatement mCpy = copiedObjects.get(this); + if (mCpy != null) + return mCpy; + mCpy = injector.apply(this, parent, pIdx); + if (mCpy != null) { + // Then change the reference to the injected object + copiedObjects.put(this, mCpy); + return mCpy; + } + + RewriterInstruction mCopy = new RewriterInstruction(); + mCopy.instr = instr; + mCopy.id = id; + mCopy.costFunction = costFunction; + mCopy.consolidated = consolidated; + mCopy.operands = new ArrayList<>(operands.size()); + mCopy.returnType = returnType; + mCopy.hashCode = hashCode; + if (meta != null) + mCopy.meta = new HashMap<>(meta); + else + mCopy.meta = null; + mCopy.nestedCopyOrInjectMetaStatements(copiedObjects, injector); + copiedObjects.put(this, mCopy); + + for (int i = 0; i < operands.size(); i++) + mCopy.operands.add(operands.get(i).nestedCopyOrInject(copiedObjects, injector, mCopy, i)); + + return mCopy; + } + + @Override + public boolean isArgumentList() { + return trueInstruction().equals("argList"); + } + + @Override + public List getArgumentList() { + return isArgumentList() ? getOperands() : null; + } + + @Override + public boolean isInstruction() { + return true; + } + + @Override + public boolean isEClass() { + return trueInstruction().equals("_EClass"); + } + + @Deprecated + @Override + public RewriterStatement clone() { + RewriterInstruction mClone = new RewriterInstruction(); + mClone.instr = instr; + mClone.id = id; + ArrayList clonedOperands = new ArrayList<>(operands.size()); + + for (RewriterStatement stmt : operands) + clonedOperands.add(stmt.clone()); + + mClone.operands = clonedOperands; + mClone.costFunction = costFunction; + mClone.consolidated = consolidated; + mClone.returnType = returnType; + mClone.meta = meta; + return mClone; + } + + @Override + public List getOperands() { + return operands == null ? Collections.emptyList() : operands; + } + + + @Override + public RewriterStatement simplify(final RuleContext ctx) { + for (int i = 0; i < operands.size(); i++) { + RewriterStatement stmt = operands.get(i).simplify(ctx); + if (stmt != null) + operands.set(i, stmt); + } + + Function rule = ctx.simplificationRules.get(typedInstruction(ctx)); + if (rule != null) { + RewriterStatement stmt = rule.apply(this); + + if (stmt != null) + return stmt; + } + return this; + } + + public RewriterInstruction withInstruction(String instr) { + if (consolidated) + throw new IllegalArgumentException("An instruction cannot be modified after consolidation"); + this.instr = instr; + return this; + } + + public RewriterInstruction withOps(RewriterStatement... operands) { + if (consolidated) + throw new IllegalArgumentException("An instruction cannot be modified after consolidation"); + this.operands = new ArrayList<>(Arrays.asList(operands)); + return this; + } + + public RewriterInstruction addOp(String id) { + if (consolidated) + throw new IllegalArgumentException("An instruction cannot be modified after consolidation"); + this.operands.add(new RewriterDataType().as(id)); + return this; + } + + public RewriterInstruction addOp(RewriterStatement operand) { + if (consolidated) + throw new IllegalArgumentException("An instruction cannot be modified after consolidation"); + this.operands.add(operand); + return this; + } + + public RewriterInstruction ofType(String type) { + if (consolidated) + throw new IllegalArgumentException("An instruction cannot be modified after consolidation"); + RewriterStatement stmt = this.operands.get(this.operands.size()-1); + + if (stmt instanceof RewriterDataType) + ((RewriterDataType)stmt).ofType(type); + else + throw new IllegalArgumentException("Can only set the data type of RewriterDataType class"); + + return this; + } + + public Function, Long> getCostFunction(final RuleContext ctx) { + if (this.costFunction == null) + this.costFunction = ctx.instrCosts.get(typedInstruction(ctx)); + + return this.costFunction; + } + + public RewriterInstruction withCostFunction(Function, Long> costFunction) { + if (consolidated) + throw new IllegalArgumentException("An instruction cannot be modified after consolidation"); + this.costFunction = costFunction; + return this; + } + + public Optional findOperand(String id) { + return this.operands.stream().filter(op -> op.getId().equals(id)).findFirst(); + } + + @Override + public RewriterInstruction as(String id) { + if (consolidated) + throw new IllegalArgumentException("An instruction cannot be modified after consolidation"); + this.id = id; + return this; + } + + public String typedInstruction(final RuleContext ctx) { + return typedInstruction(this.instr, false, ctx); + } + + public String getInstr() { + return instr; + } + + private String typedInstruction(String instrName, boolean allowImplicitConversions, final RuleContext ctx) { + StringBuilder builder = new StringBuilder(); + builder.append(instrName); + builder.append("("); + + if (!operands.isEmpty()) { + String resultingDataType = operands.get(0).getResultingDataType(ctx); + if (allowImplicitConversions) + resultingDataType = RewriterUtils.convertImplicitly(resultingDataType); + builder.append(resultingDataType); + } + + if (!isArgumentList()) { + for (int i = 1; i < operands.size(); i++) { + builder.append(","); + String resultingDataType = operands.get(i).getResultingDataType(ctx); + if (allowImplicitConversions) + resultingDataType = RewriterUtils.convertImplicitly(resultingDataType); + builder.append(resultingDataType); + } + } + + builder.append(")"); + return builder.toString(); + } + + @Override + public int toParsableString(StringBuilder sb, Map refs, int maxRefId, Map> vars, Set forceCreateRefs, final RuleContext ctx) { + Integer ref = refs.get(this); + + if (ref != null) { + sb.append('$'); + sb.append(ref); + return maxRefId; + } + + if (refCtr > 1 || forceCreateRefs.contains(this)) { + maxRefId++; + sb.append('$'); + sb.append(maxRefId); + sb.append(':'); + refs.put(this, maxRefId); + } + + sb.append(instr); + sb.append('('); + + for (int i = 0; i < getOperands().size(); i++) { + if (i > 0) + sb.append(','); + + RewriterStatement op = getOperands().get(i); + maxRefId = op.toParsableString(sb, refs, maxRefId, vars, forceCreateRefs, ctx); + } + + sb.append(')'); + + return maxRefId; + } + + @Override + public String toString(final RuleContext ctx) { + Object varName = getMeta(META_VARNAME); + if (varName != null) + return varName.toString(); + + Object trueInstrObj = getMeta("trueInstr"); + String typedInstr = trueInstrObj != null ? typedInstruction((String)trueInstrObj, false, ctx) : typedInstruction(ctx); + BiFunction customStringFunc = ctx.customStringRepr.get(typedInstr); + if (customStringFunc != null) + return customStringFunc.apply(this, ctx); + + String instrName = meta == null ? instr : meta.getOrDefault("trueName", instr).toString(); + + StringBuilder builder = new StringBuilder(); + builder.append(instrName); + builder.append("("); + for (int i = 0; i < operands.size(); i++) { + if (i > 0) + builder.append(", "); + builder.append(operands.get(i).toString(ctx)); + } + builder.append(")"); + return builder + "[" + System.identityHashCode(this) + "]"; + } + + @Override + public int structuralHashCode() { + return hashCode; + } + + @Override + public RewriterStatement rename(String id) { + this.id = id; + return this; + } + + public String changeConsolidatedInstruction(String newName, final RuleContext ctx) { + String typedInstruction = newName; + String newInstrReturnType = ctx.instrTypes.get(typedInstruction); + if (newInstrReturnType == null || !newInstrReturnType.equals(getResultingDataType(ctx))) + throw new IllegalArgumentException("An instruction name can only be changed if it has the same signature (return type) [" + typedInstruction + "::" + newInstrReturnType + " <-> " + typedInstruction(ctx) + "::" + getResultingDataType(ctx) + "]"); + String oldName = instr; + instr = newName.substring(0, newName.indexOf('(')); + recomputeHashCodes(false, ctx); + return oldName; + } + + public boolean hasProperty(String property, final RuleContext ctx) { + Set properties = getProperties(ctx); + + if (properties == null) + return false; + + return properties.contains(property); + } + + public String trueInstruction() { + return instr; + } + + public String trueTypedInstruction(final RuleContext ctx) { + return typedInstruction(trueInstruction(), false, ctx); + } + + public String trueTypedInstruction(boolean allowImplicitConversions, final RuleContext ctx) { + return typedInstruction(trueInstruction(), allowImplicitConversions, ctx); + } + + public Set getProperties(final RuleContext ctx) { + Set ret = ctx.instrProperties.get(trueTypedInstruction(ctx)); + if (ret == null) + return Collections.emptySet(); + return ret; + } + + public void unsafeSetInstructionName(String str) { + this.instr = str; + } + +} diff --git a/src/main/java/org/apache/sysds/hops/rewriter/RewriterRuntimeUtils.java b/src/main/java/org/apache/sysds/hops/rewriter/RewriterRuntimeUtils.java new file mode 100644 index 00000000000..9ed52b1506e --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewriter/RewriterRuntimeUtils.java @@ -0,0 +1,938 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.rewriter; + +import org.apache.sysds.api.DMLScript; +import org.apache.sysds.common.Types; +import org.apache.sysds.hops.AggBinaryOp; +import org.apache.sysds.hops.AggUnaryOp; +import org.apache.sysds.hops.BinaryOp; +import org.apache.sysds.hops.DataGenOp; +import org.apache.sysds.hops.DataOp; +import org.apache.sysds.hops.Hop; +import org.apache.sysds.hops.IndexingOp; +import org.apache.sysds.hops.LiteralOp; +import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.hops.ReorgOp; +import org.apache.sysds.hops.UnaryOp; +import org.apache.sysds.hops.rewrite.HopRewriteUtils; +import org.apache.sysds.hops.rewriter.dml.DMLExecutor; +import org.apache.sysds.hops.rewriter.utils.RewriterUtils; +import org.apache.sysds.parser.DMLProgram; +import org.apache.sysds.parser.ForStatement; +import org.apache.sysds.parser.ForStatementBlock; +import org.apache.sysds.parser.FunctionStatement; +import org.apache.sysds.parser.FunctionStatementBlock; +import org.apache.sysds.parser.IfStatement; +import org.apache.sysds.parser.IfStatementBlock; +import org.apache.sysds.parser.StatementBlock; +import org.apache.sysds.parser.WhileStatement; +import org.apache.sysds.parser.WhileStatementBlock; + +import javax.annotation.Nullable; +import java.io.BufferedReader; +import java.io.BufferedWriter; +import java.io.FileReader; +import java.io.FileWriter; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Random; +import java.util.Set; +import java.util.UUID; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.stream.Collectors; + +public class RewriterRuntimeUtils { + public static final boolean interceptAll = false; + public static boolean printUnknowns = false; + public static final String dbFile = "./src/test/resources/rewriterframework/expressions.db"; + public static final boolean readDB = true; + public static final boolean writeDB = true; + + private static boolean setupComplete = false; + + private static HashMap unknownOps = new HashMap<>(); + private static boolean ENFORCE_FLOAT_OBSERVATIONS = true; // To force every data type to float + private static boolean OBSERVE_SELECTIONS = false; + private static boolean OBSERVE_RAND = false; + + public static void setupIfNecessary() { + if (setupComplete) + return; + + setupComplete = true; + + if (interceptAll) { + System.out.println("INTERCEPTOR"); + OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = false; + OptimizerUtils.ALLOW_OPERATOR_FUSION = false; + System.out.println("OptLevel:" + OptimizerUtils.getOptLevel().toString()); + System.out.println("AllowOpFusion: " + OptimizerUtils.ALLOW_OPERATOR_FUSION); + System.out.println("AllowSumProductRewrites: " + OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES); + System.out.println("AllowConstantFolding: " + OptimizerUtils.ALLOW_CONSTANT_FOLDING); + + // Setup default context + RuleContext ctx = RewriterUtils.buildDefaultContext(); + + RewriterDatabase exactExprDB = new RewriterDatabase(); + + if (readDB) { + try(BufferedReader reader = new BufferedReader(new FileReader(dbFile))) { + exactExprDB.deserialize(reader, ctx); + } catch (IOException ex) { + ex.printStackTrace(); + } + } + + RewriterRuntimeUtils.attachPreHopInterceptor(prog -> { + RewriterRuntimeUtils.forAllUniqueTranslatableStatements(prog, 4, mstmt -> {}, exactExprDB, ctx); + return true; // We will continue to extract the rewritten hop + }); + + RewriterRuntimeUtils.attachHopInterceptor(prog -> { + RewriterRuntimeUtils.forAllUniqueTranslatableStatements(prog, 4, mstmt -> {}, exactExprDB, ctx); + return false; // Then we cancel the excecution to save time + }); + + Runtime.getRuntime().addShutdownHook(new Thread(() -> { + if (writeDB) { + try (BufferedWriter writer = new BufferedWriter(new FileWriter(dbFile))) { + exactExprDB.serialize(writer, ctx); + } catch (IOException e) { + e.printStackTrace(); + } + } + })); + } + } + + public static void attachHopInterceptor(Function interceptor) { + DMLScript.hopInterceptor = interceptor; + } + + public static void detachHopInterceptor() { + DMLScript.hopInterceptor = null; + } + + public static void attachPreHopInterceptor(Function interceptor) { + DMLScript.preHopInterceptor = interceptor; + } + + public static void detachPreHopInterceptor() { + DMLScript.preHopInterceptor = null; + } + + public static RewriterStatement buildDAGFromHop(Hop hop, int maxDepth, boolean mindDataCharacteristics, final RuleContext ctx) { + RewriterStatement out = buildDAGRecursively(hop, null, new HashMap<>(), 0, maxDepth, ctx); + + if (mindDataCharacteristics) + return populateDataCharacteristics(out, ctx); + + return out; + } + + public static RewriterStatement populateDataCharacteristics(RewriterStatement stmt, final RuleContext ctx) { + if (stmt == null) + return null; + + if (stmt instanceof RewriterDataType && stmt.getResultingDataType(ctx).equals("MATRIX")) { + Long nrow = (Long) stmt.getMeta("_actualNRow"); + Long ncol = (Long) stmt.getMeta("_actualNCol"); + int matType = 0; + + if (nrow != null && nrow == 1L) { + matType = 1; + } else if (ncol != null && ncol == 1L) { + matType = 2; + } + + if (matType > 0) { + return new RewriterInstruction() + .as(stmt.getId()) + .withInstruction(matType == 1L ? "rowVec" : "colVec") + .withOps(stmt) + .consolidate(ctx); + } + } + + Map createdObjects = new HashMap<>(); + + stmt.forEachPostOrder((cur, pred) -> { + for (int i = 0; i < cur.getOperands().size(); i++) { + RewriterStatement child = cur.getChild(i); + + if (child instanceof RewriterDataType && child.getResultingDataType(ctx).equals("MATRIX")) { + Long nrow = (Long) child.getMeta("_actualNRow"); + Long ncol = (Long) child.getMeta("_actualNCol"); + int matType = 0; + + if (nrow != null && nrow == 1L) { + matType = 1; + } else if (ncol != null && ncol == 1L) { + matType = 2; + } + + if (matType > 0) { + RewriterStatement created = createdObjects.get(child); + + if (created == null) { + created = new RewriterInstruction() + .as(stmt.getId()) + .withInstruction(matType == 1 ? "rowVec" : "colVec") + .withOps(child) + .consolidate(ctx); + createdObjects.put(child, created); + } + + cur.getOperands().set(i, created); + } + } + } + }, false); + + return stmt; + } + + public static void forAllUniqueTranslatableStatements(DMLProgram program, int maxDepth, Consumer stmt, RewriterDatabase db, final RuleContext ctx) { + try { + Set visited = new HashSet<>(); + + for (String namespaceKey : program.getNamespaces().keySet()) { + for (String fname : program.getFunctionStatementBlocks(namespaceKey).keySet()) { + FunctionStatementBlock fsblock = program.getFunctionStatementBlock(namespaceKey, fname); + handleStatementBlock(fsblock, maxDepth, stmt, visited, db, ctx); + } + } + + for (StatementBlock sb : program.getStatementBlocks()) { + handleStatementBlock(sb, maxDepth, stmt, visited, db, ctx); + } + } catch (Exception e) { + e.printStackTrace(); + } + } + + private static void handleStatementBlock(StatementBlock sb, int maxDepth, Consumer consumer, Set visited, RewriterDatabase db, final RuleContext ctx) { + if (sb instanceof FunctionStatementBlock) + { + FunctionStatementBlock fsb = (FunctionStatementBlock) sb; + FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0); + fstmt.getBody().forEach(s -> handleStatementBlock(s, maxDepth, consumer, visited, db, ctx)); + } + else if (sb instanceof WhileStatementBlock) + { + WhileStatementBlock wsb = (WhileStatementBlock) sb; + WhileStatement wstmt = (WhileStatement)wsb.getStatement(0); + forAllUniqueTranslatableStatements(wsb.getPredicateHops(), maxDepth, consumer, visited, db, ctx); + wstmt.getBody().forEach(s -> handleStatementBlock(s, maxDepth, consumer, visited, db, ctx)); + } + else if (sb instanceof IfStatementBlock) + { + IfStatementBlock isb = (IfStatementBlock) sb; + IfStatement istmt = (IfStatement)isb.getStatement(0); + forAllUniqueTranslatableStatements(isb.getPredicateHops(), maxDepth, consumer, visited, db, ctx); + istmt.getIfBody().forEach(s -> handleStatementBlock(s, maxDepth, consumer, visited, db, ctx)); + istmt.getElseBody().forEach(s -> handleStatementBlock(s, maxDepth, consumer, visited, db, ctx)); + } + else if (sb instanceof ForStatementBlock) + { + ForStatementBlock fsb = (ForStatementBlock) sb; + ForStatement fstmt = (ForStatement)fsb.getStatement(0); + forAllUniqueTranslatableStatements(fsb.getFromHops(), maxDepth, consumer, visited, db, ctx); + forAllUniqueTranslatableStatements(fsb.getToHops(), maxDepth, consumer, visited, db, ctx); + forAllUniqueTranslatableStatements(fsb.getIncrementHops(), maxDepth, consumer, visited, db, ctx); + fstmt.getBody().forEach(s -> handleStatementBlock(s, maxDepth, consumer, visited, db, ctx)); + } + else + { + if (sb.getHops() != null) + sb.getHops().forEach(hop -> forAllUniqueTranslatableStatements(hop, maxDepth, consumer, visited, db, ctx)); + } + } + + private static void forAllUniqueTranslatableStatements(Hop currentHop, int maxDepth, Consumer consumer, Set visited, RewriterDatabase db, final RuleContext ctx) { + if (currentHop == null || visited.contains(currentHop)) + return; + + visited.add(currentHop); + RewriterStatement stmt = buildDAGRecursively(currentHop, null, new HashMap<>(), 0, maxDepth, ctx); + + if (stmt instanceof RewriterInstruction) + stmt = ctx.metaPropagator.apply(stmt); + + if (stmt == null) { + // TODO: What to do about TWrite and PWrite? + // Just ignore these ops? + if (!currentHop.getOpString().startsWith("TWrite") && !currentHop.getOpString().startsWith("PWrite") && !currentHop.getValueType().toString().equals("STRING") && !currentHop.getOpString().startsWith("LiteralOp") && !currentHop.getOpString().startsWith("fcall") && !currentHop.getOpString().startsWith("TRead") && !currentHop.getOpString().startsWith("PRead")) + unknownOps.compute(currentHop.getOpString() + "::" + currentHop.getDataType() + "::" + currentHop.getValueType(), (k, v) -> v == null ? 1 : v + 1); + } + + if (stmt != null) { + stmt.prepareForHashing(); + stmt.recomputeHashCodes(ctx); + } + + if (stmt != null && db.insertEntry(ctx, stmt)) { + RewriterStatement cpy = stmt.nestedCopyOrInject(new HashMap<>(), el -> null); + consumer.accept(cpy); + } + + if (currentHop.getInput() != null) + currentHop.getInput().forEach(child -> forAllUniqueTranslatableStatements(child, maxDepth, consumer, visited, db, ctx)); + } + + private static RewriterStatement buildDAGRecursively(Hop next, @Nullable String expectedType, Map cache, int depth, int maxDepth, final RuleContext ctx) { + if (depth == maxDepth) + return buildLeaf(next, expectedType, ctx); + + if (cache.containsKey(next)) + return checkForCorrectTypes(cache.get(next), expectedType, next, ctx); + + if (next instanceof LiteralOp) { + RewriterStatement literal = buildLiteral((LiteralOp)next, expectedType, ctx); + literal = checkForCorrectTypes(literal, expectedType, next, ctx); + cache.put(next, literal); + return literal; + } + + if (next instanceof AggBinaryOp) { + RewriterStatement stmt = buildAggBinaryOp((AggBinaryOp) next, expectedType, ctx); + stmt = checkForCorrectTypes(stmt, expectedType, next, ctx); + + if (stmt == null) + return buildLeaf(next, expectedType, ctx); + + insertDataCharacteristics(next, stmt, ctx); + + if (buildInputs(stmt, next.getInput(), cache, true, depth, maxDepth, ctx)) + return stmt; + + return null; + } + + if (next instanceof AggUnaryOp) { + RewriterStatement stmt = buildAggUnaryOp((AggUnaryOp) next, expectedType, ctx); + stmt = checkForCorrectTypes(stmt, expectedType, next, ctx); + + if (stmt == null) + return buildLeaf(next, expectedType, ctx); + + insertDataCharacteristics(next, stmt, ctx); + + if (buildInputs(stmt, next.getInput(), cache, true, depth, maxDepth, ctx)) + return stmt; + + return null; + } + + if (next instanceof BinaryOp) { + RewriterStatement stmt = buildBinaryOp((BinaryOp) next, expectedType, ctx); + stmt = checkForCorrectTypes(stmt, expectedType, next, ctx); + + if (stmt == null) + return buildLeaf(next, expectedType, ctx); + + insertDataCharacteristics(next, stmt, ctx); + + if (buildInputs(stmt, next.getInput(), cache, true, depth, maxDepth, ctx)) + return stmt; + + return null; + } + + if (next instanceof ReorgOp) { + RewriterStatement stmt = buildReorgOp((ReorgOp) next, expectedType, ctx); + stmt = checkForCorrectTypes(stmt, expectedType, next, ctx); + + if (stmt == null) + return buildLeaf(next, expectedType, ctx); + + insertDataCharacteristics(next, stmt, ctx); + + if (buildInputs(stmt, next.getInput(), cache, true, depth, maxDepth, ctx)) + return stmt; + + return null; + } + + if (next instanceof UnaryOp) { + RewriterStatement stmt = buildUnaryOp((UnaryOp)next, expectedType, ctx); + stmt = checkForCorrectTypes(stmt, expectedType, next, ctx); + + if (stmt == null) + return buildLeaf(next, expectedType, ctx); + + insertDataCharacteristics(next, stmt, ctx); + + if (buildInputs(stmt, next.getInput(), cache, true, depth, maxDepth, ctx)) + return stmt; + + return null; + } + + if (next instanceof IndexingOp) { + RewriterStatement stmt = buildIndexingOp((IndexingOp) next, expectedType, ctx); + stmt = checkForCorrectTypes(stmt, expectedType, next, ctx); + + if (stmt == null) + return buildLeaf(next, expectedType, ctx); + + insertDataCharacteristics(next, stmt, ctx); + + if (buildInputs(stmt, next.getInput(), cache, true, depth, maxDepth, ctx)) + return stmt; + + return null; + } + + if (next instanceof DataGenOp) { + List interestingHops = new ArrayList<>(); + RewriterStatement stmt = buildDataGenOp((DataGenOp)next, expectedType, ctx, interestingHops); + stmt = checkForCorrectTypes(stmt, expectedType, next, ctx); + + if (stmt == null) + return buildLeaf(next, expectedType, ctx); + + insertDataCharacteristics(next, stmt, ctx); + + if (buildInputs(stmt, interestingHops, cache, true, depth, maxDepth, ctx)) + return stmt; + + return null; + } + + if (next instanceof DataOp) { + DataOp dop = (DataOp) next; + + if (dop.isRead()) + return buildLeaf(next, expectedType, ctx); + } + + if (printUnknowns) { + System.out.println("Unknown Op: " + next); + System.out.println("Class: " + next.getClass().getSimpleName()); + System.out.println("OPString: " + next.getOpString()); + } + + return null; + } + + private static void insertDataCharacteristics(Hop hop, RewriterStatement stmt, final RuleContext ctx) { + if (stmt.getResultingDataType(ctx).equals("MATRIX")) { + if (hop.getDataCharacteristics() != null) { + long nrows = hop.getDataCharacteristics().getRows(); + long ncols = hop.getDataCharacteristics().getCols(); + if (nrows > 0) + stmt.unsafePutMeta("_actualNRow", nrows); + if (ncols > 0) + stmt.unsafePutMeta("_actualNCol", ncols); + } + } + } + + private static RewriterStatement checkForCorrectTypes(RewriterStatement stmt, @Nullable String expectedType, Hop hop, final RuleContext ctx) { + if (stmt == null) + return null; + + if (expectedType == null) + expectedType = stmt.getResultingDataType(ctx); + + String actualType = resolveExactDataType(hop); + + if (actualType == null) + return null; + + if (actualType.equals(expectedType)) + return stmt; + + if (actualType.equals("MATRIX")) { + HashMap oldTypes = new HashMap<>(); + oldTypes.put("A", stmt); + RewriterStatement newStmt = RewriterUtils.parseExpression("as.matrix(A)", new HashMap<>(), oldTypes, ctx); + return newStmt; + } + + return null; + } + + private static RewriterStatement buildLeaf(Hop hop, @Nullable String expectedType, final RuleContext ctx) { + String hopName = hop.getName(); + + // Check if hopName collides with literal values + if (RewriterUtils.LONG_PATTERN.matcher(hopName).matches()) + hopName = "int" + new Random().nextInt(1000); + if (RewriterUtils.DOUBLE_PATTERN.matcher(hopName).matches() || RewriterUtils.SPECIAL_FLOAT_PATTERN.matcher(hopName).matches()) + hopName = "float" + new Random().nextInt(1000); + + if (expectedType != null) { + RewriterStatement stmt = RewriterUtils.parse(hopName, ctx, expectedType + ":" + hopName); + insertDataCharacteristics(hop, stmt, ctx); + return stmt; + } + + switch (hop.getDataType()) { + case SCALAR: + return buildScalarLeaf(hop, hopName, ctx); + case MATRIX: + RewriterStatement stmt = RewriterUtils.parse(hopName, ctx, "MATRIX:" + hopName); + insertDataCharacteristics(hop, stmt, ctx); + return stmt; + } + + return null; // Not supported then + } + + private static RewriterStatement buildScalarLeaf(Hop hop, final RuleContext ctx) { + return buildScalarLeaf(hop, null, ctx); + } + + private static RewriterStatement buildScalarLeaf(Hop hop, @Nullable String newName, final RuleContext ctx) { + if (newName == null) + newName = hop.getName(); + + switch (hop.getValueType()) { + case FP64: + case FP32: + return RewriterUtils.parse(newName, ctx, "FLOAT:" + newName); + case INT64: + case INT32: + if (ENFORCE_FLOAT_OBSERVATIONS) + return RewriterUtils.parse(newName, ctx, "FLOAT:" + newName); + return RewriterUtils.parse(newName, ctx, "INT:" + newName); + case BOOLEAN: + if (ENFORCE_FLOAT_OBSERVATIONS) + return RewriterUtils.parse(newName, ctx, "FLOAT:" + newName); + return RewriterUtils.parse(newName, ctx, "BOOL:" + newName); + } + + return null; // Not supported then + } + + private static boolean buildInputs(RewriterStatement stmt, List inputs, Map cache, boolean fixedSize, int depth, int maxDepth, final RuleContext ctx) { + if (fixedSize && stmt.getOperands().size() != inputs.size()) + return false; + + List children = new ArrayList<>(); + int ctr = 0; + for (Hop in : inputs) { + RewriterStatement childStmt = buildDAGRecursively(in, fixedSize ? stmt.getOperands().get(ctr).getResultingDataType(ctx) : null, cache, depth + 1, maxDepth, ctx); + + if (childStmt == null) { + //System.out.println("Could not build child: " + in); + // TODO: Then just build leaf + //return false; + childStmt = buildLeaf(in, stmt.getOperands().get(ctr).getResultingDataType(ctx), ctx); + + if (childStmt == null) + return false; + } + + if (fixedSize && !RewriterUtils.convertImplicitly(childStmt.getResultingDataType(ctx), ENFORCE_FLOAT_OBSERVATIONS).equals(stmt.getOperands().get(ctr).getResultingDataType(ctx))) + throw new IllegalArgumentException("Different data type than expected: " + stmt.toString(ctx) + "; [" + ctr + "] " + childStmt.toString(ctx) + " ::" + childStmt.getResultingDataType(ctx)); + + children.add(childStmt); + ctr++; + } + + stmt.getOperands().clear(); + stmt.getOperands().addAll(children); + stmt.consolidate(ctx); + return true; + } + + private static RewriterStatement buildIndexingOp(IndexingOp op, @Nullable String expectedType, final RuleContext ctx) { + if (!OBSERVE_SELECTIONS) + return null; + + if (expectedType == null) { + expectedType = resolveExactDataType(op); + + if (expectedType == null) + return null; + } + + switch (op.getOpString()) { + case "rix": + return RewriterUtils.parse("[](A, i, j, k, l)", ctx, "MATRIX:A", "INT:i,j,k,l"); + } + + return null; + } + + private static RewriterStatement buildUnaryOp(UnaryOp op, @Nullable String expectedType, final RuleContext ctx) { + if (expectedType == null) { + expectedType = resolveExactDataType(op); + + if (expectedType == null) + return null; + } + + String fromType = resolveExactDataType(op.getInput(0)); + Types.DataType toDT = op.getDataType(); + + if (!toDT.isMatrix() && !toDT.isScalar()) + return null; + + switch(op.getOpString()) { + case "u(castdts)": + if (toDT.isMatrix()) + return RewriterUtils.parse("cast.MATRIX(A)", ctx, "MATRIX:A"); + if (fromType != null) + return RewriterUtils.parse("cast." + expectedType + "(A)", ctx, fromType + ":A"); + + return null; + case "u(castdtm)": + if (fromType != null) + return RewriterUtils.parse("cast.MATRIX(a)", ctx, fromType + ":a"); + + return null; + case "u(sqrt)": + return RewriterUtils.parse("sqrt(A)", ctx, fromType + ":A"); + case "u(!)": + return RewriterUtils.parse("!(A)", ctx, fromType + ":A"); + case "u(ncol)": + return RewriterUtils.parse("ncol(A)", ctx, "MATRIX:A"); + case "u(nrow)": + return RewriterUtils.parse("nrow(A)", ctx, "MATRIX:A"); + case "u(length)": + return RewriterUtils.parse("length(A)", ctx, "MATRIX:A"); + case "u(exp)": + return RewriterUtils.parse("exp(A)", ctx, fromType + ":A"); + case "u(round)": + return RewriterUtils.parse("round(A)", ctx, fromType + ":A"); + case "u(abs)": + return RewriterUtils.parse("abs(A)", ctx, fromType + ":A"); + } + + if (printUnknowns) + DMLExecutor.println("Unknown UnaryOp: " + op.getOpString()); + return null; + } + + private static RewriterStatement buildAggBinaryOp(AggBinaryOp op, @Nullable String expectedType, final RuleContext ctx) { + if (expectedType != null && !expectedType.equals("MATRIX")) + throw new IllegalArgumentException(); + + // Some placeholder definitions + switch(op.getOpString()) { + case "ba(+*)": // Matrix multiplication + return RewriterUtils.parse("%*%(A, B)", ctx, "MATRIX:A,B"); + } + + if (printUnknowns) + DMLExecutor.println("Unknown AggBinaryOp: " + op.getOpString()); + return null; + } + + private static RewriterStatement buildAggUnaryOp(AggUnaryOp op, @Nullable String expectedType, final RuleContext ctx) { + // Some placeholder definitions + switch(op.getOpString()) { + case "ua(+C)": // Matrix multiplication + if (expectedType != null && !expectedType.equals("MATRIX")) + throw new IllegalArgumentException("Unexpected type: " + expectedType); + return RewriterUtils.parse("colSums(A)", ctx, "MATRIX:A"); + case "ua(+R)": + if (expectedType != null && !expectedType.equals("MATRIX")) + throw new IllegalArgumentException("Unexpected type:" + expectedType); + return RewriterUtils.parse("rowSums(A)", ctx, "MATRIX:A"); + case "ua(+RC)": + if (expectedType != null && !expectedType.equals("FLOAT")) + throw new IllegalArgumentException("Unexpected type: " + expectedType); + return RewriterUtils.parse("sum(A)", ctx, "MATRIX:A"); + case "ua(nrow)": + if (expectedType != null && !expectedType.equals("INT")) + throw new IllegalArgumentException("Unexpected type: " + expectedType); + return RewriterUtils.parse("nrow(A)", ctx, "MATRIX:A"); + case "ua(ncol)": + if (expectedType != null && !expectedType.equals("INT")) + throw new IllegalArgumentException("Unexpected type: " + expectedType); + return RewriterUtils.parse("ncol(A)", ctx, "MATRIX:A"); + case "ua(maxRC)": + if (expectedType != null && !expectedType.equals("FLOAT")) + throw new IllegalArgumentException("Unexpected type: " + expectedType); + return RewriterUtils.parse("max(A)", ctx, "MATRIX:A"); + case "ua(minRC)": + if (expectedType != null && !expectedType.equals("FLOAT")) + throw new IllegalArgumentException("Unexpected type: " + expectedType); + return RewriterUtils.parse("min(A)", ctx, "MATRIX:A"); + case "ua(traceRC)": + return RewriterUtils.parse("trace(A)", ctx, "MATRIX:A"); + } + + if (printUnknowns) + DMLExecutor.println("Unknown AggUnaryOp: " + op.getOpString()); + return null; + } + + private static RewriterStatement buildBinaryOp(BinaryOp op, @Nullable String expectedType, final RuleContext ctx) { + String t1 = resolveExactDataType(op.getInput().get(0)); + String t2 = resolveExactDataType(op.getInput().get(1)); + + if (t1 == null || t2 == null) + return null; + + t1 += ":a"; + t2 += ":b"; + + RewriterStatement parsed = null; + + switch(op.getOpString()) { + case "b(+)": // Addition + parsed = RewriterUtils.parse("+(a, b)", ctx, t1, t2); + break; + case "b(*)": // Matrix multiplication + parsed = RewriterUtils.parse("*(a, b)", ctx, t1, t2); + break; + case "b(-)": + parsed = RewriterUtils.parse("-(a, b)", ctx, t1, t2); + break; + case "b(/)": + parsed = RewriterUtils.parse("/(a, b)", ctx, t1, t2); + break; + case "b(||)": + parsed = RewriterUtils.parse("|(a, b)", ctx, t1, t2); + break; + case "b(!=)": + parsed = RewriterUtils.parse("!=(a, b)", ctx, t1, t2); + break; + case "b(==)": + parsed = RewriterUtils.parse("==(a, b)", ctx, t1, t2); + break; + case "b(&&)": + parsed = RewriterUtils.parse("&(a, b)", ctx, t1, t2); + break; + case "b(<)": + parsed = RewriterUtils.parse("<(a, b)", ctx, t1, t2); + break; + case "b(>)": + parsed = RewriterUtils.parse(">(a, b)", ctx, t1, t2); + break; + case "b(>=)": + parsed = RewriterUtils.parse(">=(a, b)", ctx, t1, t2); + break; + case "b(<=)": + parsed = RewriterUtils.parse("<=(a, b)", ctx, t1, t2); + break; + case "b(^)": + parsed = RewriterUtils.parse("^(a, b)", ctx, t1, t2); + break; + case "b(rbind)": + if (!t1.equals("MATRIX") || !t2.equals("MATRIX")) + return null; + return RewriterUtils.parse("RBind(a, b)", ctx, t1, t2); + case "b(cbind)": + if (!t1.equals("MATRIX") || !t2.equals("MATRIX")) + return null; + return RewriterUtils.parse("CBind(a, b)", ctx, t1, t2); + case "b(1-*)": + return RewriterUtils.parse("1-*(A, B)", ctx, "MATRIX:A,B"); + } + + if (parsed != null) + return parsed.rename(op.getName()); + + if (printUnknowns) + DMLExecutor.println("Unknown BinaryOp: " + op.getOpString()); + return null; + } + + private static String resolveExactDataType(Hop hop) { + if (hop.getDataType() == Types.DataType.MATRIX) + return "MATRIX"; + + switch (hop.getValueType()) { + case FP64: + case FP32: + return "FLOAT"; + case INT64: + case INT32: + if (ENFORCE_FLOAT_OBSERVATIONS) + return "FLOAT"; + return "INT"; + case BOOLEAN: + if (ENFORCE_FLOAT_OBSERVATIONS) + return "FLOAT"; + return "BOOL"; + } + + if (printUnknowns) + DMLExecutor.println("Unknown type: " + hop + " -> " + hop.getDataType() + " : " + hop.getValueType()); + + return null; + } + + private static RewriterStatement buildReorgOp(ReorgOp op, @Nullable String expectedType, final RuleContext ctx) { + if (expectedType != null && !expectedType.equals("MATRIX")) + throw new IllegalArgumentException(); + + switch(op.getOpString()) { + case "r(r')": // Matrix multiplication + return RewriterUtils.parse("t(A)", ctx, "MATRIX:A"); + case "r(rev)": + return RewriterUtils.parse("rev(A)", ctx, "MATRIX:A"); + case "r(rdiag)": + return RewriterUtils.parse("diag(A)", ctx, "MATRIX:A"); + } + + //System.out.println("Unknown BinaryOp: " + op.getOpString()); + if (printUnknowns) + DMLExecutor.println("Unknown ReorgOp: " + op.getOpString()); + return null; + } + + private static RewriterStatement buildDataGenOp(DataGenOp op, @Nullable String expectedType, final RuleContext ctx, List interestingHops) { + if (expectedType != null && !expectedType.equals("MATRIX")) + throw new IllegalArgumentException(); + + switch(op.getOpString()) { + case "dg(rand)": + if (OBSERVE_RAND) { + interestingHops.add(op.getParam("rows")); + interestingHops.add(op.getParam("cols")); + interestingHops.add(op.getParam("min")); + interestingHops.add(op.getParam("max")); + return RewriterUtils.parse("rand(i1, i2, f1, f2)", ctx, "INT:i1,i2", "FLOAT:f1,f2").rename(op.getName()); + } + return null; + } + + return null; + } + + private static RewriterStatement buildLiteral(LiteralOp literal, @Nullable String expectedType, final RuleContext ctx) { + if (literal.getDataType() != Types.DataType.SCALAR) + return null; // Then it is not supported yet + + String mType; + Object mValue; + + switch (literal.getValueType()) { + case FP64: + case FP32: + if (expectedType != null && !expectedType.equals("FLOAT")) + throw new IllegalArgumentException("Unexpected type: " + expectedType); + return new RewriterDataType().as(UUID.randomUUID().toString()).ofType("FLOAT").asLiteral(literal.getDoubleValue()).consolidate(ctx); + case INT32: + case INT64: + if (expectedType != null) { + if (expectedType.equals("INT")) { + mType = expectedType; + mValue = literal.getLongValue(); + } else if (expectedType.equals("FLOAT")) { + mType = "FLOAT"; + mValue = (double)literal.getLongValue(); + } else { + throw new IllegalArgumentException(); + } + } else { + mType = "INT"; + mValue = literal.getLongValue(); + } + return new RewriterDataType().as(UUID.randomUUID().toString()).ofType(mType).asLiteral(mValue).consolidate(ctx); + case BOOLEAN: + if (expectedType != null) { + if (expectedType.equals("FLOAT")) { + mType = expectedType; + mValue = literal.getBooleanValue() ? 1.0D : 0.0D; + } else if (expectedType.equals("INT")) { + mType = expectedType; + mValue = literal.getBooleanValue() ? 1L : 0L; + } else if (expectedType.equals("BOOL")) { + mType = expectedType; + mValue = literal.getBooleanValue(); + } else { + throw new IllegalArgumentException(); + } + } else { + mType = "BOOL"; + mValue = literal.getBooleanValue(); + } + return new RewriterDataType().as(UUID.randomUUID().toString()).ofType(mType).asLiteral(mValue).consolidate(ctx); + default: + return null; // Not supported yet + } + } + + public static boolean executeScript(String script) { + try { + return DMLScript.executeScript(new String[]{"-s", script}); + } catch (Exception ex) { + ex.printStackTrace(); + return false; + } + } + + + /** + * Validates matrix dimensions to ensure that broadcasting still works afer the transformation + * @param hop1 the first HOP + * @param hop2 the second HOP + * @return if the new binary op would work in terms of broadcasting + */ + public static boolean validateBinaryBroadcasting(Hop hop1, Hop hop2) { + if (hop1.isMatrix() && hop2.isMatrix()) { + if (!hop1.dimsKnown() || !hop2.dimsKnown()) + return false; + + if (hop1.getDim1() == hop2.getDim1()) { + if (hop1.getDim2() == hop2.getDim2()) + return true; // Then both dimensions match + + return hop2.getDim2() == 1; // Otherwise we require a column vector + } else if (hop1.getDim2() == hop2.getDim2()) { + return hop2.getDim1() == 1; // We require a row vector + } + + // At least one dimension must match + return false; + } + + return true; + } + + public static boolean hasMatchingDims(Hop hop1, Hop hop2) { + return hop1.dimsKnown() && hop2.dimsKnown() && hop1.getDim1() == hop2.getDim1() && hop1.getDim2() == hop2.getDim2(); + } + + public static boolean hasMatchingDims(Hop... hops) { + if (hops.length < 2) + return true; + + for (Hop hop : hops) + if (!hop.dimsKnown()) + return false; + + long dim1 = hops[0].getDim1(); + long dim2 = hops[0].getDim2(); + + for (int i = 1; i < hops.length; i++) + if (hops[i].getDim1() != dim1 && hops[i].getDim2() != dim2) + return false; + + return true; + } +} diff --git a/src/main/java/org/apache/sysds/hops/rewriter/RewriterStatement.java b/src/main/java/org/apache/sysds/hops/rewriter/RewriterStatement.java new file mode 100644 index 00000000000..faf2dbaea21 --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewriter/RewriterStatement.java @@ -0,0 +1,1092 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.rewriter; + +import org.apache.commons.collections4.bidimap.DualHashBidiMap; +import org.apache.commons.lang3.NotImplementedException; +import org.apache.commons.lang3.function.TriFunction; +import org.apache.commons.lang3.mutable.MutableInt; +import org.apache.logging.log4j.util.TriConsumer; +import org.apache.sysds.hops.rewriter.assertions.RewriterAssertions; +import org.apache.sysds.hops.rewriter.estimators.RewriterCostEstimator; +import org.apache.sysds.hops.rewriter.rule.RewriterRule; +import org.apache.sysds.hops.rewriter.utils.StatementUtils; +import scala.Tuple2; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.UUID; +import java.util.function.BiConsumer; +import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.regex.Pattern; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import java.util.stream.Stream; + +public abstract class RewriterStatement { + public static final String META_VARNAME = "_varName"; + + + protected int rid = 0; + public int refCtr = 0; + protected long cost = -2; + + protected HashMap meta = null; + + + public static class MatchingSubexpression { + private final RewriterStatement expressionRoot; + private final RewriterStatement matchRoot; + private final RewriterPredecessor pred; + private final Map assocs; + private final List links; + public RewriterStatement newExprRoot; + + public MatchingSubexpression(RewriterStatement expressionRoot, RewriterStatement matchRoot, RewriterPredecessor pred, Map assocs, List links) { + this.expressionRoot = expressionRoot; + this.matchRoot = matchRoot; + this.pred = pred; + this.assocs = assocs; + this.links = links; + } + + public RewriterStatement getExpressionRoot() { + return expressionRoot; + } + + public RewriterStatement getMatchRoot() { + return matchRoot; + } + + public RewriterPredecessor getPredecessor() { + return pred; + } + + public Map getAssocs() { + return assocs; + } + + public List getLinks() { + return links; + } + + public RewriterStatement getNewExprRoot() { + return newExprRoot; + } + + public void setNewExprRoot(RewriterStatement exprRoot) { + newExprRoot = exprRoot; + } + } + + public static class MatcherContext { + final RuleContext ctx; + final boolean statementsCanBeVariables; + final boolean literalsCanBeVariables; + final boolean ignoreLiteralValues; + final boolean allowDuplicatePointers; + final boolean allowPropertyScan; + final boolean allowTypeHierarchy; + final boolean terminateOnFirstMatch; + final boolean findMinimalMismatchRoot; + final boolean traceVariableEliminations; + final boolean allowImplicitTypeConversions; + final Map ruleLinks; + final RewriterStatement expressionRoot; + final RewriterStatement thisExpressionRoot; + RewriterStatement matchRoot; + RewriterPredecessor pred; + + public RewriterStatement currentStatement; + + private Map dependencyMap; + private List links; + private DualHashBidiMap internalReferences; + + private List subMatches; + private Tuple2 firstMismatch; + private boolean debug; + private boolean assertionsFetched = false; + private RewriterAssertions assertionsThat; + private RewriterAssertions assertionsThis; + private Set dontVisitAgain; + + public MatcherContext(final RuleContext ctx, RewriterStatement matchRoot, RewriterStatement expressionRoot, RewriterStatement thisExpressionRoot) { + this(ctx, matchRoot, expressionRoot, thisExpressionRoot, false, false, false, false, false, false, false, false, false, Collections.emptyMap()); + } + + public MatcherContext(final RuleContext ctx, RewriterStatement matchRoot, RewriterStatement expressionRoot, RewriterStatement thisExpressionRoot, final boolean statementsCanBeVariables, final boolean literalsCanBeVariables, final boolean ignoreLiteralValues, final boolean allowDuplicatePointers, final boolean allowPropertyScan, final boolean allowTypeHierarchy, final boolean terminateOnFirstMatch, final boolean findMinimalMismatchRoot, boolean traceVariableEliminations, final Map ruleLinks) { + this.ctx = ctx; + this.matchRoot = matchRoot; + this.pred = new RewriterPredecessor(); + this.expressionRoot = expressionRoot; + this.thisExpressionRoot = thisExpressionRoot; + this.statementsCanBeVariables = statementsCanBeVariables; + this.currentStatement = matchRoot; + this.literalsCanBeVariables = literalsCanBeVariables; + this.ignoreLiteralValues = ignoreLiteralValues; + this.allowDuplicatePointers = allowDuplicatePointers; + this.allowPropertyScan = allowPropertyScan; + this.allowTypeHierarchy = allowTypeHierarchy; + this.terminateOnFirstMatch = terminateOnFirstMatch; + this.ruleLinks = ruleLinks; + this.findMinimalMismatchRoot = findMinimalMismatchRoot; + this.traceVariableEliminations = traceVariableEliminations; + this.allowImplicitTypeConversions = false; + this.debug = false; + } + + public MatcherContext(final RuleContext ctx, RewriterStatement matchRoot, RewriterPredecessor pred, RewriterStatement expressionRoot, RewriterStatement thisExprRoot, final boolean statementsCanBeVariables, final boolean literalsCanBeVariables, final boolean ignoreLiteralValues, final boolean allowDuplicatePointers, final boolean allowPropertyScan, final boolean allowTypeHierarchy, final boolean terminateOnFirstMatch, final boolean findMinimalMismatchRoot, boolean traceVariableEliminations, boolean allowImplicitTypeConversions, final Map ruleLinks) { + this.ctx = ctx; + this.matchRoot = matchRoot; + this.pred = pred; + this.expressionRoot = expressionRoot; + this.thisExpressionRoot = thisExprRoot; + this.currentStatement = matchRoot; + this.statementsCanBeVariables = statementsCanBeVariables; + this.literalsCanBeVariables = literalsCanBeVariables; + this.ignoreLiteralValues = ignoreLiteralValues; + this.allowDuplicatePointers = allowDuplicatePointers; + this.allowPropertyScan = allowPropertyScan; + this.allowTypeHierarchy = allowTypeHierarchy; + this.terminateOnFirstMatch = terminateOnFirstMatch; + this.ruleLinks = ruleLinks; + this.findMinimalMismatchRoot = findMinimalMismatchRoot; + this.traceVariableEliminations = traceVariableEliminations; + this.allowImplicitTypeConversions = allowImplicitTypeConversions; + this.debug = false; + } + + private void fetchAssertions() { + if (!assertionsFetched) { + assertionsThat = (RewriterAssertions) expressionRoot.getMeta("_assertions"); + assertionsThis = (RewriterAssertions) thisExpressionRoot.getMeta("_assertions"); + assertionsFetched = true; + } + } + + public boolean allowsImplicitTypeConversions() { + return allowImplicitTypeConversions; + } + + public void dontVisitAgain(RewriterStatement stmt) { + if (dontVisitAgain == null) { + dontVisitAgain = new HashSet<>(); + } + + dontVisitAgain.add(stmt); + } + + public boolean wasVisited(RewriterStatement stmt) { + if (dontVisitAgain == null) + return false; + + return dontVisitAgain.contains(stmt); + } + + public RewriterAssertions getOldAssertionsThat() { + fetchAssertions(); + + return assertionsThat; + } + + public RewriterAssertions getOldAssertionsThis() { + fetchAssertions(); + + return assertionsThis; + } + + public Map getDependencyMap() { + if (dependencyMap == null) + if (allowDuplicatePointers) + dependencyMap = new HashMap<>(); + else + dependencyMap = new DualHashBidiMap(); + return dependencyMap; + } + + public List getLinks() { + if (links == null) + links = new ArrayList<>(); + return links; + } + + public RewriterStatement findInternalReference(RewriterStatement stmt) { + if (internalReferences == null) + return null; + return internalReferences.get(stmt); + } + + public RewriterStatement findReverseInternalReference(RewriterStatement stmt) { + if (internalReferences == null) + return null; + return internalReferences.getKey(stmt); + } + + public Map getInternalReferences() { + if (internalReferences == null) + internalReferences = new DualHashBidiMap<>(); + return internalReferences; + } + + public List getSubMatches() { + if (subMatches == null) + return Collections.emptyList(); + return subMatches; + } + + public boolean hasSubMatches() { + return subMatches != null && !subMatches.isEmpty(); + } + + public void addSubMatch(MatcherContext matcherContext) { + if (subMatches == null) + subMatches = new ArrayList<>(); + subMatches.addAll(matcherContext.getFlattenedSubMatches()); + } + + public List getFlattenedSubMatches() { + if (hasSubMatches()) + return subMatches.stream().flatMap(mCtx -> mCtx.getFlattenedSubMatches().stream()).collect(Collectors.toList()); + return Collections.emptyList(); + } + + public MatchingSubexpression toMatch() { + return new MatchingSubexpression(expressionRoot, matchRoot, pred, getDependencyMap(), getLinks()); + } + + public void reset() { + if (dependencyMap != null) + dependencyMap.clear(); + if (links != null) + links.clear(); + if (internalReferences != null) + internalReferences.clear(); + } + + public void setFirstMismatch(RewriterStatement stmt1, RewriterStatement stmt2) { + firstMismatch = new Tuple2<>(stmt1, stmt2); + } + + public Tuple2 getFirstMismatch() { + return firstMismatch; + } + + public MatcherContext debug(boolean debug) { + this.debug = debug; + return this; + } + + public boolean match() { + return thisExpressionRoot.match(this); + } + + public boolean isDebug() { + return debug; + } + + public static MatcherContext exactMatch(final RuleContext ctx, RewriterStatement stmt, RewriterStatement thisExprRoot) { + return new MatcherContext(ctx, stmt, stmt, thisExprRoot); + } + + public static MatcherContext exactMatchWithDifferentLiteralValues(final RuleContext ctx, RewriterStatement stmt, RewriterStatement thisExprRoot) { + return new MatcherContext(ctx, stmt, stmt, thisExprRoot, false, false, true, false, false, false, false, false, false, Collections.emptyMap()); + } + + public static MatcherContext findMinimalDifference(final RuleContext ctx, RewriterStatement stmt, RewriterStatement thisExpressionRoot) { + return new MatcherContext(ctx, stmt, stmt, thisExpressionRoot, false, false, true, false, false, false, false, true, false, Collections.emptyMap()); + } + } + + public static final class RewriterPredecessor { + private final Object obj; + private final Object meta; + + // Use iff the element is already the root + public RewriterPredecessor() { + obj = null; + meta = null; + } + + public RewriterPredecessor(RewriterStatement parent, Integer idx) { + obj = parent; + meta = idx; + } + + // Use iff the element is a meta object + public RewriterPredecessor(RewriterStatement parent, String meta) { + obj = parent; + this.meta = meta; + } + + public RewriterPredecessor(RewriterAssertions assertions, RewriterAssertions.RewriterAssertion assertion) { + obj = assertions; + meta = assertion; + } + + public boolean isOperand() { + return obj instanceof RewriterStatement && meta instanceof Integer; + } + + public boolean isRoot() { + return obj == null && meta == null; + } + + public boolean isMetaObject() { + return obj instanceof RewriterStatement && meta instanceof String; + } + + public boolean isAssertionObject() { + return obj instanceof RewriterAssertions && meta instanceof RewriterAssertions.RewriterAssertion; + } + + public RewriterStatement getParent() { + return (RewriterStatement) obj; + } + + public RewriterAssertions getAssertions() { + return (RewriterAssertions) obj; + } + + public RewriterAssertions.RewriterAssertion getAssertion() { + return (RewriterAssertions.RewriterAssertion) meta; + } + + public String getMetaKey() { + return (String) meta; + } + + public int getIndex() { + return (Integer) meta; + } + } + + public static enum ReferenceType { + ROOT, OPERAND, NCOL, NROW, BACKREF, ASSERTION + } + + public static class RewriterStatementReference { + public final ReferenceType referenceType; + public final RewriterStatement stmt; + public final Object parentRef; + public final Object ref; + + // TODO: What about root? + public RewriterStatementReference(ReferenceType type, RewriterStatement stmt, RewriterStatement parentRef) { + this.referenceType = type; + this.stmt = stmt; + this.parentRef = parentRef; + this.ref = null; + } + + public RewriterStatementReference(RewriterStatement stmt, RewriterStatement parentRef, int idx) { + this.referenceType = parentRef == null ? ReferenceType.ROOT : ReferenceType.OPERAND; + this.stmt = stmt; + this.parentRef = parentRef; + this.ref = idx; + } + + public RewriterStatementReference(RewriterStatement stmt, RewriterAssertions assertions, RewriterAssertions.RewriterAssertion assertion) { + this.referenceType = ReferenceType.ASSERTION; + this.stmt = stmt; + this.parentRef = assertions; + this.ref = assertion; + } + + public void replace(RewriterStatement newStmt) { + switch (referenceType) { + case ROOT: + throw new NotImplementedException(); + case OPERAND: + ((RewriterStatement) parentRef).getOperands().set((Integer)ref, newStmt); + break; + case NCOL: + ((RewriterStatement) parentRef).unsafePutMeta("ncol", newStmt); + break; + case NROW: + ((RewriterStatement) parentRef).unsafePutMeta("nrow", newStmt); + break; + case BACKREF: + ((RewriterStatement) parentRef).unsafePutMeta("backRef", newStmt); + break; + case ASSERTION: + ((RewriterAssertions) parentRef).replaceAssertionContent(stmt, newStmt, (RewriterAssertions.RewriterAssertion) ref); + break; + } + } + } + + public abstract String getId(); + public abstract String getResultingDataType(final RuleContext ctx); + public abstract boolean isLiteral(); + public abstract Object getLiteral(); + public abstract RewriterStatement getLiteralStatement(); + public long intLiteral() { + return intLiteral(false); + } + public abstract long intLiteral(boolean cast); + public abstract double floatLiteral(); + public abstract boolean boolLiteral(); + + public void setLiteral(Object literal) { + throw new IllegalArgumentException("This class does not support setting literals"); + } + public abstract RewriterStatement consolidate(final RuleContext ctx); + public abstract boolean isConsolidated(); + @Deprecated + public abstract RewriterStatement clone(); + public abstract RewriterStatement copyNode(); + // Performs a nested copy until a condition is met + public abstract RewriterStatement nestedCopyOrInject(Map copiedObjects, TriFunction injector, RewriterStatement parent, int pIdx); + // Returns the new maxRefId + public abstract int toParsableString(StringBuilder builder, Map refs, int maxRefId, Map> vars, Set forceCreateRefs, final RuleContext ctx); + public abstract void refreshReturnType(final RuleContext ctx); + protected abstract void compress(RewriterAssertions assertions); + + public static String parsableDefinitions(Map> defs) { + StringBuilder sb = new StringBuilder(); + defs.forEach((k, v) -> { + sb.append(k); + sb.append(':'); + + int i = 0; + for (String varName : v) { + if (i > 0) + sb.append(','); + + sb.append(varName); + i++; + } + + sb.append('\n'); + }); + + return sb.toString(); + } + + public String toParsableString(final RuleContext ctx, Map> defs) { + return toParsableString(ctx, defs, Collections.emptySet()); + } + + public String toParsableString(final RuleContext ctx, Map> defs, Set forceCreateRefs) { + StringBuilder sb = new StringBuilder(); + toParsableString(sb, new HashMap<>(), 0, defs, forceCreateRefs, ctx); + return sb.toString(); + } + + public String toParsableString(final RuleContext ctx, boolean includeDefinitions) { + return toParsableString(ctx, includeDefinitions, Collections.emptySet()); + } + + public String toParsableString(final RuleContext ctx, boolean includeDefinitions, Set forceCreateRefs) { + StringBuilder sb = new StringBuilder(); + HashMap> defs = new HashMap<>(); + toParsableString(sb, new HashMap<>(), 0, defs, forceCreateRefs, ctx); + + if (includeDefinitions) + return parsableDefinitions(defs) + sb; + + return sb.toString(); + } + + public String toParsableString(final RuleContext ctx) { + return toParsableString(ctx, false); + } + + public RewriterStatement nestedCopyOrInject(Map copiedObjects, TriFunction injector) { + return nestedCopyOrInject(copiedObjects, injector, null, -1); + } + + public RewriterStatement nestedCopyOrInject(Map copiedObjects, Function injector) { + return nestedCopyOrInject(copiedObjects, (el, parent, pIdx) -> injector.apply(el), null, -1); + } + + public RewriterStatement nestedCopy(boolean copyAssertions) { + return nestedCopy(copyAssertions, new HashMap<>()); + } + + public RewriterStatement nestedCopy(boolean copyAssertions, Map createdObjects) { + RewriterStatement cpy = nestedCopyOrInject(createdObjects, el -> null); + + if (copyAssertions) { + RewriterAssertions assertions = (RewriterAssertions) getMeta("_assertions"); + + if (assertions != null) { + cpy.unsafePutMeta("_assertions", RewriterAssertions.copy(assertions, createdObjects, true)); + } + } else { + cpy.unsafeRemoveMeta("_assertions"); + } + + return cpy; + } + + // Returns the root of the matching sub-statement, null if there is no match + public abstract boolean match(MatcherContext matcherContext); + + public abstract int recomputeHashCodes(boolean recursively, final RuleContext ctx); + public abstract RewriterStatement simplify(final RuleContext ctx); + public abstract RewriterStatement as(String id); + public abstract String toString(final RuleContext ctx); + public abstract boolean isArgumentList(); + public abstract List getArgumentList(); + public abstract boolean isInstruction(); + public abstract boolean isEClass(); + public abstract String trueInstruction(); + public abstract String trueTypedInstruction(final RuleContext ctx); + public abstract String trueTypedInstruction(boolean allowImplicitConversions, final RuleContext ctx); + public abstract int structuralHashCode(); + public abstract RewriterStatement rename(String id); + public void prepareDefinitions(final RuleContext ctx, final List strDefs, final Set varDefs) { + if (getMeta(META_VARNAME) != null) + return; + + if (getOperands() != null) + getOperands().forEach(op -> op.prepareDefinitions(ctx, strDefs, varDefs)); + + if (this instanceof RewriterInstruction) { + RewriterInstruction self = ((RewriterInstruction) this); + // Check if it is necessary to define variables + if (refCtr > 1 || self.trueInstruction().equals("_asVar")) { + Pattern pattern = Pattern.compile("[a-zA-Z0-9_]+"); + String instr = pattern.matcher(self.getInstr()).matches() ? self.getInstr() : "tmp"; + instr = instr.replace("_", ""); + String varName = "var_" + instr + "_"; + + int ctr = 1; + while (varDefs.contains(varName + ctr)) + ctr++; + + strDefs.add(varName + ctr + " = " + toString(ctx)); + varDefs.add(varName + ctr); + unsafePutMeta(META_VARNAME, varName + ctr); + } + } + } + + public void eraseDefinitions() { + unsafeRemoveMeta(META_VARNAME); + + if (getOperands() != null) + getOperands().forEach(RewriterStatement::eraseDefinitions); + } + + public List getOperands() { + return Collections.emptyList(); + } + + public int recomputeHashCodes(final RuleContext ctx) { + return recomputeHashCodes(true, ctx); + } + + public void prepareForHashing() { + resetRefCtrs(); + computeRefCtrs(); + resetIds(); + computeIds(1); + } + + protected void resetRefCtrs() { + refCtr = 0; + if (getOperands() != null) + getOperands().forEach(RewriterStatement::resetRefCtrs); + } + + protected void computeRefCtrs() { + refCtr++; + if (refCtr < 2 && getOperands() != null) + getOperands().forEach(RewriterStatement::computeRefCtrs); + } + + protected void resetIds() { + rid = 0; + if (getOperands() != null) + getOperands().forEach(RewriterStatement::resetIds); + } + + protected int computeIds(int id) { + rid = id++; + + if (getOperands() != null) { + for (RewriterStatement stmt : getOperands()) + id = stmt.computeIds(id); + } + + return id; + } + + /** + * Traverses the DAG in-order. If nodes with multiple parents exist, those are visited multiple times. + * If the function returns false, the sub-DAG of the current node will not be traversed. + * @param function test + */ + @Deprecated + public void forEachPreOrderWithDuplicates(Function function) { + if (function.apply(this) && getOperands() != null) + for (int i = 0; i < getOperands().size(); i++) + getOperands().get(i).forEachPreOrderWithDuplicates(function); + } + + public void forEachPreOrder(Function function, boolean includeMeta) { + forEachPreOrder((el, pred) -> function.apply(el), includeMeta); + } + + public void forEachPreOrder(BiFunction function, boolean includeMeta) { + forEachPreOrder(function, new HashSet<>(), new RewriterPredecessor(), includeMeta); + } + + // We will also include metadata + private void forEachPreOrder(BiFunction function, Set visited, RewriterPredecessor pred, boolean includeMeta) { + if (!visited.add(this)) + return; + + if (function.apply(this, pred)) { + for (int i = 0; i < getOperands().size(); i++) + getOperands().get(i).forEachPreOrder(function, visited, new RewriterPredecessor(this, i), includeMeta); + + if (includeMeta) + forEachMetaObject((stmt, mPred) -> stmt.forEachPreOrder(function, visited, mPred, includeMeta)); + } + } + + public void forEachPostOrder(BiConsumer consumer, boolean includeMeta) { + forEachPostOrder(consumer, new HashSet<>(), new RewriterPredecessor(), includeMeta); + } + + private void forEachPostOrder(BiConsumer consumer, Set visited, RewriterPredecessor pred, boolean includeMeta) { + if (!visited.add(this)) + return; + + if (getOperands() != null) + for (int i = 0; i < getOperands().size(); i++) + getOperands().get(i).forEachPostOrder(consumer, visited, new RewriterPredecessor(this, i), includeMeta); + + if (includeMeta) + forEachMetaObject((stmt, mPred) -> stmt.forEachPostOrder(consumer, visited, mPred, includeMeta)); + + consumer.accept(this, pred); + } + + @Deprecated + public void forEachPostOrderWithDuplicates(TriConsumer consumer) { + forEachPostOrderWithDuplicates(consumer, null, -1); + } + + @Deprecated + private void forEachPostOrderWithDuplicates(TriConsumer consumer, RewriterStatement parent, int pIdx) { + for (int i = 0; i < getOperands().size(); i++) + getOperands().get(i).forEachPostOrderWithDuplicates(consumer, this, i); + + consumer.accept(this, parent, pIdx); + } + + public void putMeta(String key, Object value) { + if (isConsolidated()) + throw new IllegalArgumentException("An instruction cannot be modified after consolidation"); + + if (meta == null) + meta = new HashMap<>(); + + meta.put(key, value); + } + + public void unsafePutMeta(String key, Object value) { + if (isLiteral()) + throw new UnsupportedOperationException("Cannot put meta for literals"); + + if (meta == null) + meta = new HashMap<>(); + + meta.put(key, value); + } + + public void unsafeRemoveMeta(String key) { + if (meta == null) + return; + + meta.remove(key); + + if (meta.isEmpty()) + meta = null; + } + + public Object getMeta(String key) { + if (meta == null) + return null; + + return meta.get(key); + } + + public long getCost() { + if (!isInstruction()) + return 0; + + return cost; + } + + public RewriterAssertions getAssertions(final RuleContext ctx) { + RewriterAssertions assertions = (RewriterAssertions) getMeta("_assertions"); + if (assertions == null) { + assertions = new RewriterAssertions(ctx); + if (!isLiteral()) // Otherwise the assertion object will just be temporary + unsafePutMeta("_assertions", assertions); + } + + return assertions; + } + + public RewriterStatement getNCol() { + return (RewriterStatement) getMeta("ncol"); + } + + public RewriterStatement getNRow() { + return (RewriterStatement) getMeta("nrow"); + } + + public RewriterStatement getBackRef() { + return (RewriterStatement) getMeta("_backRef"); + } + + public RewriterStatement getChild(int index) { + return getOperands().get(index); + } + + public RewriterStatement getChild(int... indices) { + RewriterStatement current = this; + + for (int i = 0; i < indices.length; i++) + current = current.getOperands().get(indices[i]); + + return current; + } + + // This can only be called from the root expression to add a new assertion manually + public RewriterStatement givenThatEqualDimensions(RewriterStatement stmt1, RewriterStatement stmt2, final RuleContext ctx) { + getAssertions(ctx).addEqualityAssertion(stmt1.getNRow(), stmt2.getNRow(), this); + getAssertions(ctx).addEqualityAssertion(stmt1.getNCol(), stmt2.getNCol(), this); + return this; + } + + // This can only be called from the root expression to add a new assertion manually + public RewriterStatement givenThatEqual(RewriterStatement stmt1, RewriterStatement stmt2, final RuleContext ctx) { + return givenThatEqual(stmt1, stmt2, this, ctx); + } + + public RewriterStatement givenThatEqual(RewriterStatement stmt1, RewriterStatement stmt2, RewriterStatement exprRoot, final RuleContext ctx) { + getAssertions(ctx).addEqualityAssertion(stmt1, stmt2, exprRoot); + return this; + } + + public RewriterStatement recomputeAssertions() { + RewriterAssertions assertions = (RewriterAssertions) getMeta("_assertions"); + + if (assertions != null) + return assertions.update(this); + + return this; + } + + public static void transferMeta(RewriterRule.ExplicitLink link) { + if (link.oldStmt instanceof RewriterInstruction) { + for (RewriterStatement mNew : link.newStmt) { + if (mNew instanceof RewriterInstruction && + !((RewriterInstruction)mNew).trueInstruction().equals(((RewriterInstruction)link.oldStmt).trueInstruction())) { + ((RewriterInstruction) mNew).unsafeSetInstructionName(((RewriterInstruction)link.oldStmt).trueInstruction()); + } + } + } + + if (link.oldStmt.meta != null) { + link.newStmt.forEach(stmt -> { + HashMap newMap = new HashMap<>(link.oldStmt.meta); + stmt.overwriteImplicitMetaObjects(newMap); + stmt.meta = newMap; + }); + } + else + link.newStmt.forEach(RewriterStatement::cleanupMeta/*stmt.meta = null*/); + } + + public void moveRootTo(RewriterStatement newRoot) { + RewriterAssertions assertions = (RewriterAssertions) getMeta("_assertions"); + + if (assertions != null && !newRoot.isLiteral()) + newRoot.unsafePutMeta("_assertions", assertions); + } + + private void overwriteImplicitMetaObjects(Map map) { + RewriterAssertions assertions = (RewriterAssertions) getMeta("_assertions"); + RewriterStatement ncol = getNCol(); + RewriterStatement nrow = getNRow(); + RewriterStatement backref = getBackRef(); + + if (assertions != null) + map.put("_assertions", assertions); + + if (ncol != null) + map.put("ncol", ncol); + + if (nrow != null) + map.put("nrow", nrow); + + if (backref != null) + map.put("_backRef", backref); + } + + private void cleanupMeta() { + if (meta == null) + return; + + RewriterAssertions assertions = (RewriterAssertions) getMeta("_assertions"); + RewriterStatement ncol = getNCol(); + RewriterStatement nrow = getNRow(); + RewriterStatement backref = getBackRef(); + + if (assertions == null && ncol == null && nrow == null && backref == null) + return; + + meta = new HashMap<>(); + + if (assertions != null) + meta.put("_assertions", assertions); + + if (ncol != null) + meta.put("ncol", ncol); + + if (nrow != null) + meta.put("nrow", nrow); + + if (backref != null) + meta.put("_backRef", ncol); + } + + @Override + public String toString() { + return toString(RuleContext.currentContext); + } + + public boolean isColVector() { + RewriterStatement nrow = getNRow(); + + if (nrow == null) + return false; + + if (nrow.isLiteral() && nrow.getLiteral().equals(1L)) + return true; + + if (nrow.isEClass() && nrow.getChild(0).getOperands().stream().anyMatch(el -> el.isLiteral() && el.getLiteral().equals(1L))) + return true; + + return false; + } + + public boolean isRowVector() { + RewriterStatement ncol = getNCol(); + + if (ncol == null) + return false; + + if (ncol.isLiteral() && ncol.getLiteral().equals(1L)) + return true; + + if (ncol.isEClass() && ncol.getChild(0).getOperands().stream().anyMatch(el -> el.isLiteral() && el.getLiteral().equals(1L))) + return true; + + return false; + } + + public List toExecutableString(final RuleContext ctx) { + ArrayList defList = new ArrayList<>(); + prepareDefinitions(ctx, defList, new HashSet<>()); + defList.add(toString(ctx)); + eraseDefinitions(); + + return defList; + } + + public void compress() { + RewriterAssertions assertions = (RewriterAssertions) getMeta("_assertions"); + this.forEachPostOrder((cur, pred) -> { + cur.compress(assertions); + }, true); + } + + public long getCost(final RuleContext ctx) { + if (!this.isInstruction()) + return 0; + + if (cost != -2) + return cost; + + try { + cost = RewriterCostEstimator.estimateCost(this, ctx); + } catch (Exception e) { + cost = -1L; + } + + return cost; + } + + // This may create cycles if visited objects are not tracked + public void forEachMetaObject(BiConsumer consumer) { + RewriterStatement backref = getBackRef(); + RewriterAssertions assertions = (RewriterAssertions) getMeta("_assertions"); + + if (backref != null) + consumer.accept(backref, new RewriterPredecessor(this, "_backRef")); + if (assertions != null) + assertions.forEachAssertionContents(consumer); + } + + public void updateMetaObjects(Function f) { + RewriterStatement backref = getBackRef(); + + RewriterStatement mNew; + + if (backref != null) { + mNew = f.apply(backref); + + if (backref != mNew) + unsafePutMeta("_backRef", backref); + } + + RewriterAssertions assertions = (RewriterAssertions) getMeta("_assertions"); + + if (assertions != null) + assertions.updateAssertionContents(f); + } + + protected void nestedCopyOrInjectMetaStatements(Map copiedObjects, TriFunction injector) { + if (getNCol() != null) { + unsafePutMeta("ncol", getNCol().nestedCopyOrInject(copiedObjects, injector, this, -1)); + } + + if (getNRow() != null) + unsafePutMeta("nrow", getNRow().nestedCopyOrInject(copiedObjects, injector, this, -1)); + + RewriterStatement backRef = (RewriterStatement) getMeta("_backRef"); + + if (backRef != null) + unsafePutMeta("_backRef", backRef.nestedCopyOrInject(copiedObjects, injector, this, -1)); + + RewriterAssertions assertions = (RewriterAssertions) getMeta("_assertions"); + + if (assertions != null) { + assertions = assertions.nestedCopyOrInject(copiedObjects, injector, this); + unsafePutMeta("_assertions", assertions); + } + } + + // This returns a stream of all children including metadata and assertions if available + // This may contain loops in case of back references + public Stream> allChildren() { + Stream> stream = IntStream.range(0, getOperands().size()).mapToObj(i -> new Tuple2<>(getOperands().get(i), new RewriterPredecessor(this, i))); + RewriterStatement ncol = getNCol(); + RewriterStatement nrow = getNRow(); + RewriterStatement backRef = getBackRef(); + + if (ncol != null) + stream = Stream.concat(stream, Stream.of(new Tuple2<>(ncol, new RewriterPredecessor(this, "ncol")))); + if (nrow != null) + stream = Stream.concat(stream, Stream.of(new Tuple2<>(nrow, new RewriterPredecessor(this, "nrow")))); + if (backRef != null) + stream = Stream.concat(stream, Stream.of(new Tuple2<>(backRef, new RewriterPredecessor(this, "_backRef")))); + + RewriterAssertions assertions = (RewriterAssertions) getMeta("_assertions"); + + if (assertions != null) + stream = Stream.concat(stream, assertions.streamOfContents()); + + return stream; + } + + public boolean isDataOrigin() { + if (!isInstruction()) + return true; + + switch (trueInstruction()) { + case "rowVec": + case "colVec": + case "const": + return true; + } + + return false; + } + + public int countInstructions() { + MutableInt i = new MutableInt(); + forEachPreOrder(cur -> { + if (!cur.isDataOrigin() || cur.isLiteral()) { + i.increment(); + } + return true; + }, false); + return i.getAndIncrement(); + } + + public static RewriterStatement argList(final RuleContext ctx, RewriterStatement... args) { + return new RewriterInstruction().as(UUID.randomUUID().toString()).withInstruction("argList").withOps(args).consolidate(ctx); + } + + public static RewriterStatement argList(final RuleContext ctx, List args) { + return argList(ctx, args.toArray(RewriterStatement[]::new)); + } + + public static RewriterStatement castFloat(final RuleContext ctx, RewriterStatement stmt) { + return new RewriterInstruction().as(UUID.randomUUID().toString()).withInstruction("cast.FLOAT").withOps(stmt).consolidate(ctx); + } + + public static RewriterStatement nnz(RewriterStatement of, final RuleContext ctx) { + return nnz(of, ctx, false); + } + + public static RewriterStatement nnz(RewriterStatement of, final RuleContext ctx, boolean treatAsDense) { + if (treatAsDense) + return StatementUtils.length(ctx, of); + return new RewriterInstruction().as(UUID.randomUUID().toString()).withInstruction("_nnz").withOps(of).consolidate(ctx); + } + + public static RewriterStatement literal(final RuleContext ctx, Object literal) { + if (literal == null) + throw new IllegalArgumentException(); + + if (literal instanceof Double) { // We need to differentiate between -0.0 and 0.0 because otherwise this may leed to bugs + return new RewriterDataType().as(literal.toString()).ofType("FLOAT").asLiteral(((Double) literal).doubleValue() == -0.0 ? 0.0 : literal).consolidate(ctx); + } else if (literal instanceof Long) { + return new RewriterDataType().as(literal.toString()).ofType("INT").asLiteral(literal).consolidate(ctx); + } else if (literal instanceof Boolean) { + return new RewriterDataType().as(literal.toString()).ofType("BOOL").asLiteral(literal).consolidate(ctx); + } + + throw new IllegalArgumentException(); + } + + public static RewriterStatement multiArgInstr(final RuleContext ctx, String instrName, RewriterStatement... ops) { + RewriterStatement argList = new RewriterInstruction().as(UUID.randomUUID().toString()).withInstruction("argList").withOps(ops).consolidate(ctx); + return new RewriterInstruction().as(UUID.randomUUID().toString()).withInstruction(instrName).withOps(argList).consolidate(ctx); + } +} diff --git a/src/main/java/org/apache/sysds/hops/rewriter/RewriterStatementEntry.java b/src/main/java/org/apache/sysds/hops/rewriter/RewriterStatementEntry.java new file mode 100644 index 00000000000..80daebebc32 --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewriter/RewriterStatementEntry.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.rewriter; + +import java.util.HashMap; + +public class RewriterStatementEntry { + private final RuleContext ctx; + final RewriterStatement instr; + + public RewriterStatementEntry(final RuleContext ctx, RewriterStatement instr) { + this.ctx = ctx; + this.instr = instr; + } + + @Override + public int hashCode() { + return instr.structuralHashCode(); + } + + @Override + public boolean equals(Object o) { + if (o instanceof RewriterStatement) { + if (instr == o) + return true; + if (instr.structuralHashCode() != ((RewriterStatement)o).structuralHashCode()) + return false; + return instr.match(new RewriterStatement.MatcherContext(ctx, (RewriterStatement) o, new RewriterStatement.RewriterPredecessor(), (RewriterStatement) o, instr, false, false, false, false, false, false, true, false, false, false, new HashMap<>())); + } + + if (o.hashCode() != hashCode()) + return false; + + if (o instanceof RewriterStatementEntry) { + if (instr == ((RewriterStatementEntry) o).instr) + return true; + return instr.match(new RewriterStatement.MatcherContext(ctx, ((RewriterStatementEntry) o).instr, new RewriterStatement.RewriterPredecessor(), ((RewriterStatementEntry) o).instr, instr, false, false, false, false, false, false, true, false, false, false, new HashMap<>())); + } + return false; + } +} diff --git a/src/main/java/org/apache/sysds/hops/rewriter/RuleContext.java b/src/main/java/org/apache/sysds/hops/rewriter/RuleContext.java new file mode 100644 index 00000000000..978cb62501f --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewriter/RuleContext.java @@ -0,0 +1,208 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.rewriter; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.BiFunction; +import java.util.function.Function; + +public class RuleContext { + public static RuleContext currentContext; + + public HashMap, Long>> instrCosts = new HashMap<>(); + + public HashMap instrTypes = new HashMap<>(); + + public HashMap> simplificationRules = new HashMap<>(); + + public HashMap> instrProperties = new HashMap<>(); + + public HashMap> typeHierarchy = new HashMap<>(); + + public HashMap> customStringRepr = new HashMap<>(); + + public Function metaPropagator = null; + + public static RuleContext floatArithmetic = new RuleContext(); + public static RuleContext selectionPushdownContext = new RuleContext(); + + static { + floatArithmetic.instrCosts.put("+(float,float)", d -> 1l); + floatArithmetic.instrCosts.put("*(float,float)", d -> 1l); + + floatArithmetic.instrTypes.put("+(float,float)", "float"); + floatArithmetic.instrTypes.put("*(float,float)", "float"); + + floatArithmetic.simplificationRules.put("+(float,float)", i -> { + RewriterStatement op1 = i.getOperands().get(0); + RewriterStatement op2 = i.getOperands().get(1); + + if (op1.isLiteral() && op2.isLiteral()) { + op1.setLiteral(((Float)op1.getLiteral()) + ((Float)op2.getLiteral())); + return op1; + } + + return null; + }); + floatArithmetic.simplificationRules.put("*(float, float)", i -> { + RewriterStatement op1 = i.getOperands().get(0); + RewriterStatement op2 = i.getOperands().get(1); + + if (op1.isLiteral() && op2.isLiteral()) { + op1.setLiteral(((Float)op1.getLiteral()) * ((Float)op2.getLiteral())); + return op1; + } + + return null; + }); + + selectionPushdownContext.instrCosts.put("RowSelectPushableBinaryInstruction(MATRIX,MATRIX)", d -> 1l); // Just temporary costs + selectionPushdownContext.instrTypes.put("RowSelectPushableBinaryInstruction(MATRIX,MATRIX)", "MATRIX"); + selectionPushdownContext.instrCosts.put("rowSelect(MATRIX,INT,INT)", d -> 1l); + selectionPushdownContext.instrTypes.put("rowSelect(MATRIX,INT,INT)", "MATRIX"); + selectionPushdownContext.instrCosts.put("min(INT,INT)", d -> 1l); + selectionPushdownContext.instrTypes.put("min(INT,INT)", "INT"); + selectionPushdownContext.instrCosts.put("max(INT,INT)", d -> 1l); + selectionPushdownContext.instrTypes.put("max(INT,INT)", "INT"); + + selectionPushdownContext.instrCosts.put("+(MATRIX,MATRIX)", d -> 1l); + selectionPushdownContext.instrTypes.put("+(MATRIX,MATRIX)", "MATRIX"); + } + + public static RuleContext createContext(String contextString) { + RuleContext ctx = new RuleContext(); + HashMap instrTypes = ctx.instrTypes; + HashMap> instrProps = ctx.instrProperties; + String[] lines = contextString.split("\n"); + String fName = null; + String fArgTypes = null; + String fReturnType = null; + for (String line : lines) { + line = line.replaceFirst("^\\s+", ""); + if (line.isEmpty()) + continue; + + if (line.startsWith("impl")) { + if (fArgTypes == null || fReturnType == null) + throw new IllegalArgumentException(); + String newFName = line.substring(4).replace(" ", ""); + if (newFName.isEmpty()) + throw new IllegalArgumentException(); + + instrTypes.put(newFName + fArgTypes, fReturnType); + + final String propertyFunction = fName + fArgTypes; + + if (instrProps.containsKey(newFName + fArgTypes)) { + HashSet props = instrProps.get(newFName + fArgTypes); + props.add(propertyFunction); + props.add(fName); + } else { + HashSet mset = new HashSet<>(); + mset.add(propertyFunction); + mset.add(fName); + instrProps.put(newFName + fArgTypes, mset); + } + + ctx.instrCosts.put(newFName + fArgTypes, d -> 1L); + } else if (line.startsWith("dtype ")) { + String[] dTypeStr = line.substring(6).split("::"); + if (dTypeStr.length > 1) { + Set mSet = ctx.typeHierarchy.compute(dTypeStr[0], (k, v) -> v == null ? new HashSet<>() : v); + for (int i = 1; i < dTypeStr.length; i++) + mSet.add(dTypeStr[i]); + } + + } else { + String[] keyVal = readFunctionDefinition(line); + fName = keyVal[0]; + fArgTypes = keyVal[1]; + fReturnType = keyVal[2]; + instrTypes.put(fName + fArgTypes, fReturnType); + ctx.instrCosts.put(fName + fArgTypes, d -> 1L); + } + } + + // Resolve transitive function properties + boolean changed = true; + while (changed) { + changed = false; + for (Map.Entry> pair : instrProps.entrySet()) { + HashSet toAdd = new HashSet<>(); + for (String propertyFunction : pair.getValue()) { + if (instrProps.containsKey(propertyFunction)) + toAdd.addAll(instrProps.get(propertyFunction)); + } + + changed |= pair.getValue().addAll(toAdd); + } + } + + changed = true; + while (changed) { + changed = false; + for (Map.Entry> pair : ctx.typeHierarchy.entrySet()) { + HashSet toAdd = new HashSet<>(); + for (String superTypes : pair.getValue()) { + if (instrProps.containsKey(superTypes)) + toAdd.addAll(instrProps.get(superTypes)); + } + + changed |= pair.getValue().addAll(toAdd); + } + } + + return ctx; + } + + public static String[] readFunctionDefinition(String line) { + int leftParanthesisIdx = line.indexOf('('); + + if (leftParanthesisIdx == -1) + throw new IllegalArgumentException(); + + String fName = line.substring(0, leftParanthesisIdx).replace(" ", ""); + String rest = line.substring(leftParanthesisIdx+1); + + int parenthesisCloseIdx = rest.indexOf(')'); + + if (parenthesisCloseIdx == -1) + throw new IllegalArgumentException(); + + String argsStr = rest.substring(0, parenthesisCloseIdx); + String[] args = argsStr.split(","); + + args = Arrays.stream(args).map(arg -> arg.replace(" ", "")).toArray(String[]::new); + + if (args.length != 1 && Arrays.stream(args).anyMatch(String::isEmpty)) + throw new IllegalArgumentException(); + + if (!rest.substring(parenthesisCloseIdx+1, parenthesisCloseIdx+3).equals("::")) + throw new IllegalArgumentException(); + + String returnDataType = rest.substring(parenthesisCloseIdx+3); + return new String[] { fName, "(" + String.join(",", args) + ")", returnDataType }; + } +} diff --git a/src/main/java/org/apache/sysds/hops/rewriter/TopologicalSort.java b/src/main/java/org/apache/sysds/hops/rewriter/TopologicalSort.java new file mode 100644 index 00000000000..94bf6f029fb --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewriter/TopologicalSort.java @@ -0,0 +1,543 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.rewriter; + +import org.apache.commons.lang3.mutable.MutableInt; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.function.BiFunction; + +// We assume that _argList() will have one unique parent +public class TopologicalSort { + protected static final Log LOG = LogFactory.getLog(TopologicalSort.class.getName()); + + public static boolean DEBUG = false; + + // All of these operators are sortable with argument lists (e.g. +(argList(1, 2, 3)) + private static final Set SORTABLE_ARGLIST_OPS = Set.of("+", "*", "_idxExpr", "_EClass", "rand", "_dummy"); + // All of these operators are sortable but have their operands directly as children (e.g. ==(a,b)) + private static final Set SORTABLE_OPS = Set.of("==", "!="); + + public static void sort(RewriterStatement root, final RuleContext ctx) { + sort(root, (el, parent) -> { + if (!el.isInstruction()) + return false; + + if (el.isArgumentList()) + return parent != null && SORTABLE_ARGLIST_OPS.contains(parent.trueInstruction()); + + return SORTABLE_OPS.contains(el.trueInstruction()); + }, ctx); + } + + public static void sort(RewriterStatement root, BiFunction isArrangable, final RuleContext ctx) { + // First, we setup an artificial root node to be able to sort E-Classes that are only included as meta-info not directly in the operand structure + Set hiddenEClasses = new HashSet<>(); + root.forEachPostOrder((stmt, pred) -> { + if (stmt instanceof RewriterDataType && !stmt.isLiteral() && stmt.getResultingDataType(ctx).equals("MATRIX")) { + if (stmt.getNRow().isInstruction() && stmt.getNRow().trueInstruction().equals("_EClass")) + hiddenEClasses.add(stmt.getNRow()); + + if (stmt.getNCol().isInstruction() && stmt.getNCol().trueInstruction().equals("_EClass")) + hiddenEClasses.add(stmt.getNCol()); + } + }, true); + + RewriterStatement oldRoot = root; + + if (!hiddenEClasses.isEmpty()) { + RewriterStatement argList = new RewriterInstruction().withInstruction("argList").withOps(hiddenEClasses.toArray(RewriterStatement[]::new)); + RewriterStatement dummy = new RewriterInstruction().withInstruction("_dummy").withOps(argList); + root = new RewriterInstruction().withInstruction("_root").withOps(root, dummy); + } + + List uncertainParents = setupOrderFacts(root, isArrangable, ctx); + + buildAddresses(root, ctx); + resolveAmbiguities(root, ctx, uncertainParents); + resetAddresses(uncertainParents); + + int factCtr = 0; + + // Now, we start introducing facts for the lowest level unordered sets + Set lowestUncertainties = findLowestUncertainties(root); + int ctr = 0; + + while (!lowestUncertainties.isEmpty()) { + if (DEBUG) { + LOG.trace("Uncertainties after iteration " + ctr + ": " + lowestUncertainties.size()); + LOG.trace("Lowest uncertainties: " + lowestUncertainties); + } + + factCtr = introduceFacts(lowestUncertainties, factCtr); + buildAddresses(root, ctx); + + if (DEBUG) { + LOG.trace("Built addresses:"); + for (UnorderedSet u : lowestUncertainties) { + for (RewriterStatement s : u.contents) { + LOG.trace("- " + s + " :: " + getAddress(s)); + } + } + } + + resolveAmbiguities(root, ctx, uncertainParents); + resetAddresses(uncertainParents); + + lowestUncertainties = findLowestUncertainties(root); + ctr++; + + if (ctr > 100) + throw new RuntimeException("Could not finish sorting process for expression:\n" + root.toParsableString(ctx)); // Should never get here but just to make sure + } + + // At the end + if (DEBUG) + LOG.trace("Before construction: " + oldRoot.toParsableString(ctx)); + constructNewDAG(oldRoot, ctx); + if (DEBUG) + LOG.trace("After construction: " + oldRoot.toParsableString(ctx)); + } + + // Returns all uncertain parents ordered in post order (elements without uncertain sub-DAGs come first in the list) + private static List setupOrderFacts(RewriterStatement root, BiFunction isArrangable, final RuleContext ctx) { + List uncertainParents = new ArrayList<>(); + + // Create a random global order which will be used for indistinguishable sub-DAGs + MutableInt nameCtr = new MutableInt(0); + root.forEachPostOrder((el, pred) -> { + if (el.isLiteral()) + return; + + el.unsafePutMeta("_tempName", nameCtr.intValue()); + nameCtr.increment(); + boolean arrangable = isArrangable.apply(el, pred.getParent()); + + el.unsafePutMeta("_arrangable", arrangable); + }, false); + + // Try to establish a first order + root.forEachPostOrder((el, pred) -> { + if (el.isLiteral()) + return; + + boolean arrangable = (boolean) el.getMeta("_arrangable"); + + List knownOrder = new ArrayList<>(); + el.unsafePutMeta("_knownOrder", knownOrder); + + if (arrangable) { + el.getOperands().sort((cmp1, cmp2) -> compare(cmp1, cmp2, ctx)); + + boolean containsUnorderedSet = false; + + List currSet = new ArrayList<>(); + currSet.add(el.getOperands().get(0)); + + for (int i = 1; i < el.getOperands().size(); i++) { + if (compare(el.getOperands().get(i-1), el.getOperands().get(i), ctx) != 0) { + if (currSet.size() == 1) { + knownOrder.add(currSet.get(0)); + currSet.clear(); + } else { + final RewriterStatement first = currSet.get(0); + if (currSet.stream().allMatch(mEl -> first == mEl)) { + // Then this is not an unordered set as it only contains one instance and the order doesn't matter + knownOrder.addAll(currSet); + currSet.clear(); + } else { + containsUnorderedSet = true; + currSet.forEach(cur -> { + if (!cur.isLiteral()) + cur.unsafePutMeta("_addresses", new ArrayList()); + }); + knownOrder.add(new UnorderedSet(currSet)); + currSet = new ArrayList<>(); + } + } + } + + currSet.add(el.getOperands().get(i)); + } + + if (currSet.size() == 1) + knownOrder.add(currSet.get(0)); + else { + final RewriterStatement first = currSet.get(0); + if (currSet.stream().allMatch(first::equals)) { + knownOrder.addAll(currSet); + } else { + containsUnorderedSet = true; + currSet.forEach(cur -> { + if (!cur.isLiteral()) + cur.unsafePutMeta("_addresses", new ArrayList()); + }); + knownOrder.add(new UnorderedSet(currSet)); + } + } + + if (containsUnorderedSet) + uncertainParents.add(el); + } else { + knownOrder.addAll(el.getOperands()); + } + + if (DEBUG) + LOG.trace("Initial known order of " + el.toParsableString(ctx) + ": " + knownOrder); + }, false); + + return uncertainParents; + } + + private static int introduceFacts(Collection sets, int factCtr) { + for (RewriterStatement stmt : allChildren(sets)) { + if (stmt.isLiteral()) + continue; + + if (stmt.getMeta("_addresses") == null) + stmt.unsafePutMeta("_addresses", new ArrayList<>()); + + if (stmt.getMeta("_fact") == null) + stmt.unsafePutMeta("_fact", factCtr++); + } + + return factCtr; + } + + // Returns a list of all unordered set that do not contain other unordered set + private static Set findLowestUncertainties(RewriterStatement root) { + Set set = new HashSet<>(); + recursivelyFindLowestUncertainties(root, set); + + List tmpList = new ArrayList<>(set); + Set minSet = new HashSet<>(); + // We have the issue that uncertainties might still depend on each other (e.g. {a,b}, {inv(a),inv(b)}), even if they are the lowest entries + // Theoretically, this comparison might still lead to amgibuities, but never occurred in our examples + int minCumSize = Integer.MAX_VALUE; + for (int i = 0; i < tmpList.size(); i++) { + int cumSize = tmpList.get(i).contents.stream().map(RewriterStatement::countInstructions).reduce(0, Integer::sum); + + if (cumSize < minCumSize) { + minSet.clear(); + minCumSize = cumSize; + } + + if (cumSize <= minCumSize) + minSet.add(tmpList.get(i)); + } + + return minSet; + } + + // All children in post order and unique + private static List allChildren(Collection unorderedSets) { + Set is = new HashSet<>(); + List children = new ArrayList<>(); + for (UnorderedSet set : unorderedSets) + for (RewriterStatement s : set.contents) + traverse(s, is, children); + + return children; + } + + private static void traverse(RewriterStatement stmt, Set visited, List l) { + if (visited.contains(stmt)) + return; + + visited.add(stmt); + stmt.getOperands().forEach(el -> traverse(el, visited, l)); + + l.add(stmt); + } + + private static boolean recursivelyFindLowestUncertainties(RewriterStatement current, Set lowestUncertainties) { + if (current.isLiteral()) + return false; + + List knownOrder = (List) current.getMeta("_knownOrder"); + boolean containsUncertainty = false; + + for (Object o : knownOrder) { + if (o instanceof RewriterStatement) { + containsUncertainty |= recursivelyFindLowestUncertainties((RewriterStatement) o, lowestUncertainties); + } else { + UnorderedSet set = (UnorderedSet) o; + containsUncertainty = true; + boolean foundEmbeddedUncertainty = set.contents.stream().anyMatch(stmt -> recursivelyFindLowestUncertainties(stmt, lowestUncertainties)); + + if (foundEmbeddedUncertainty) + lowestUncertainties.remove(set); + else + lowestUncertainties.add(set); + } + } + + return containsUncertainty; + } + + public static void constructNewDAG(RewriterStatement root, final RuleContext ctx) { + root.forEachPostOrder((cur, pred) -> { + List knownOrder = (List) cur.getMeta("_knownOrder"); + if (DEBUG) + LOG.trace("KnownOrder of " + cur.toParsableString(ctx) + ": " + knownOrder); + + for (int i = 0; i < cur.getOperands().size(); i++) + cur.getOperands().set(i, (RewriterStatement) knownOrder.get(i)); + + cur.unsafeRemoveMeta("_knownOrder"); + cur.unsafeRemoveMeta("_addresses"); + cur.unsafeRemoveMeta("_address"); + cur.unsafeRemoveMeta("_arrangable"); + cur.unsafeRemoveMeta("_tempName"); + }, false); + + root.prepareForHashing(); + root.recomputeHashCodes(ctx); + } + + // Here, we try to infer new information given the address information + // This step also resets all addresses as they will change after one sorting step + private static boolean resolveAmbiguities(RewriterStatement root, final RuleContext ctx, List uncertainParents) { + boolean couldResolve = false; + boolean couldResolveAnyUncertainty = true; + + while (couldResolveAnyUncertainty) { + couldResolveAnyUncertainty = false; + + for (int i = 0; i < uncertainParents.size(); i++) { + List knownOrder = (List) uncertainParents.get(i).getMeta("_knownOrder"); + boolean uncertaintyRemaining = false; + + for (int j = 0; j < knownOrder.size(); j++) { + if (knownOrder.get(j) instanceof UnorderedSet) { + UnorderedSet set = (UnorderedSet) knownOrder.get(j); + + if (tryResolveUncertainties(set, ctx)) { + couldResolveAnyUncertainty = true; + couldResolve = true; + knownOrder.set(j, set.contents.get(0)); + knownOrder.addAll(j+1, set.contents.subList(1, set.contents.size())); + set.contents.forEach(el -> { + el.unsafeRemoveMeta("_addresses"); + el.unsafeRemoveMeta("_address"); + }); + } else { + uncertaintyRemaining = true; + } + } + } + + if (!uncertaintyRemaining) { + uncertainParents.remove(i); + i--; + } + } + } + + return couldResolve; + } + + private static void resetAddresses(List uncertainParents) { + for (RewriterStatement uParent : uncertainParents) { + List knownOrder = (List) uParent.getMeta("_knownOrder"); + + for (Object o : knownOrder) { + if (o instanceof UnorderedSet) { + ((UnorderedSet) o).contents.forEach(el -> { + List addresses = (List) el.getMeta("_addresses"); + + if (addresses == null) { + addresses = new ArrayList<>(); + el.unsafePutMeta("_addresses", addresses); + el.unsafeRemoveMeta("_address"); + } + + addresses.clear(); + }); + } + } + } + } + + private static boolean tryResolveUncertainties(UnorderedSet set, final RuleContext ctx) { + set.contents.sort((el1, el2) -> compare(el1, el2, ctx)); // We assume that every statement has an address, as it is uncertain + + RewriterStatement compareTo = set.contents.get(0); + // Check if ambiguity could be resolved + for (int i = 1; i < set.contents.size(); i++) { + if (compareTo.equals(set.contents.get(i))) + continue; // Ignore same instances + + if (compare(set.contents.get(i), compareTo, ctx) == 0) + return false; // Then there are still some ambiguities + + compareTo = set.contents.get(i); + } + + return true; + } + + private static List buildAddresses(RewriterStatement root, final RuleContext ctx) { + // First, catch all addresses + List elementsWithAddress = new ArrayList<>(); + recursivelyBuildAddresses(root, null, ctx, elementsWithAddress); + + // Now, we sort all addresses + for (RewriterStatement el : elementsWithAddress) { + List addresses = (List) el.getMeta("_addresses"); + Collections.sort(addresses); + String address = String.join(";", addresses); + el.unsafePutMeta("_address", address); + + if (DEBUG) + LOG.trace("Address of " + el + " :: " + address); + } + + return elementsWithAddress; + } + + private static void recursivelyBuildAddresses(RewriterStatement current, String currentAddress, final RuleContext ctx, List elementsWithAddress) { + List knownOrder = (List)current.getMeta("_knownOrder"); + List addresses = (List)current.getMeta("_addresses"); + + if (knownOrder == null) + knownOrder = Collections.emptyList(); + + + + if (DEBUG) { + LOG.trace("CUR: " + current); + LOG.trace("KnownOrder: " + knownOrder); + } + + if (addresses != null) { + if (addresses.isEmpty()) + elementsWithAddress.add(current); + + addresses.add(currentAddress); + } + + for (int i = 0; i < knownOrder.size(); i++) { + Object next = knownOrder.get(i); + String addr = currentAddress == null ? Integer.toString(i) : currentAddress + "." + i; + + if (next instanceof RewriterStatement) { + recursivelyBuildAddresses((RewriterStatement) next, addr, ctx, elementsWithAddress); + } else { + UnorderedSet set = (UnorderedSet) next; + set.contents.forEach(el -> recursivelyBuildAddresses(el, addr, ctx, elementsWithAddress)); + } + } + } + + private static String getAddress(RewriterStatement stmt) { + String addr = (String) stmt.getMeta("_address"); + + if (addr == null) + return null; + + return addr + (stmt.getMeta("_fact") == null ? "_" : "_" + stmt.getMeta("_fact")); + } + + // Expects that the children have already been sorted to the best of the current knowledge + public static int compare(RewriterStatement stmt1, RewriterStatement stmt2, final RuleContext ctx) { + int comp = toOrderString(ctx, stmt1, false).compareTo(toOrderString(ctx, stmt2, false)); + + if (comp != 0 || stmt1.equals(stmt2)) + return comp; + + List knownOrder1 = (List)stmt1.getMeta("_knownOrder"); + List knownOrder2 = (List)stmt2.getMeta("_knownOrder"); + + // Then the two statements are distinguishable by their number of unknowns + if (knownOrder1.size() != knownOrder2.size()) + return Integer.compare(knownOrder1.size(), knownOrder2.size()); + + for (int i = 0; i < knownOrder1.size() && comp == 0; i++) + comp = compare(knownOrder1.get(i), knownOrder2.get(i), ctx); + + if (comp == 0) { + String addr1 = getAddress(stmt1); + String addr2 = getAddress(stmt2); + + if (addr1 != null && addr2 != null) + return addr1.compareTo(addr2); + } + + return comp; + } + + public static int compare(Object obj1, Object obj2, final RuleContext ctx) { + boolean isStmt1 = obj1 instanceof RewriterStatement; + boolean isStmt2 = obj2 instanceof RewriterStatement; + + if (isStmt1 && !isStmt2) + return 1; + if (!isStmt1 && isStmt2) + return -1; + + if (isStmt1 && isStmt2) + return compare((RewriterStatement) obj1, (RewriterStatement) obj2, ctx); + + UnorderedSet set1 = (UnorderedSet) obj1; + UnorderedSet set2 = (UnorderedSet) obj2; + + if (set1.contents.size() < 2 || set2.contents.size() < 2) + throw new IllegalArgumentException(); // This should never happen as this would not be an unknown ordering + + if (set1.contents.size() != set2.contents.size()) + return Integer.compare(set1.contents.size(), set2.contents.size()); + + // Now, we can just choose any representant of the set + return compare(set1.contents.get(0), set2.contents.get(0), ctx); + } + + public static String toOrderString(final RuleContext ctx, RewriterStatement stmt, boolean useGlobalOrder) { + String globalOrderAddition = useGlobalOrder ? ((Integer)stmt.getMeta("_tempName")).toString() : ""; + + if (stmt.isInstruction()) { + return stmt.getResultingDataType(ctx) + ":" + stmt.trueTypedInstruction(ctx) + "[" + stmt.refCtr + "](" + stmt.getOperands().size() + ")" + globalOrderAddition + ";"; + } else { + return stmt.getResultingDataType(ctx) + ":" + (stmt.isLiteral() ? "L:" + stmt.getLiteral() : "V") + "[" + stmt.refCtr + "](0)" + globalOrderAddition + ";"; + } + } + + + + static class UnorderedSet { + List contents; + + public UnorderedSet(List contents) { + this.contents = contents; + } + + public String toString() { + return contents.toString(); + } + } +} \ No newline at end of file diff --git a/src/main/java/org/apache/sysds/hops/rewriter/assertions/RewriterAssertionUtils.java b/src/main/java/org/apache/sysds/hops/rewriter/assertions/RewriterAssertionUtils.java new file mode 100644 index 00000000000..d6b15d25b5b --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewriter/assertions/RewriterAssertionUtils.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.rewriter.assertions; + +import org.apache.sysds.hops.rewriter.RewriterInstruction; +import org.apache.sysds.hops.rewriter.RewriterStatement; +import org.apache.sysds.hops.rewriter.RuleContext; + +public class RewriterAssertionUtils { + public static RewriterAssertions buildImplicitAssertions(RewriterStatement root, final RuleContext ctx) { + RewriterAssertions assertions = new RewriterAssertions(ctx); + buildImplicitAssertions(root, assertions, ctx); + return assertions; + } + + public static void buildImplicitAssertions(RewriterStatement root, RewriterAssertions assertions, final RuleContext ctx) { + root.forEachPreOrder(cur -> { + buildImplicitAssertion(cur, assertions, root, ctx); + return true; + }, false); + } + + public static boolean buildImplicitAssertion(RewriterStatement stmt, RewriterAssertions assertions, RewriterStatement exprRoot, final RuleContext ctx) { + if (!stmt.isInstruction()) + return false; + + switch (stmt.trueInstruction()) { + case "%*%": + assertions.addEqualityAssertion(stmt.getChild(0).getNCol(), stmt.getChild(1).getNRow(), exprRoot); + return true; + case "diag": + assertions.addEqualityAssertion(stmt.getChild(0).getNCol(), stmt.getChild(0).getNRow(), exprRoot); + return true; + case "RBind": + assertions.addEqualityAssertion(stmt.getChild(0).getNCol(), stmt.getChild(1).getNCol(), exprRoot); + return true; + case "CBind": + assertions.addEqualityAssertion(stmt.getChild(0).getNRow(), stmt.getChild(1).getNRow(), exprRoot); + return true; + case "1-*": + assertions.addEqualityAssertion(stmt.getChild(0).getNCol(), stmt.getChild(1).getNCol(), exprRoot); + assertions.addEqualityAssertion(stmt.getChild(0).getNRow(), stmt.getChild(1).getNRow(), exprRoot); + return true; + case "+*": + case "-*": + assertions.addEqualityAssertion(stmt.getChild(0).getNCol(), stmt.getChild(2).getNCol(), exprRoot); + assertions.addEqualityAssertion(stmt.getChild(0).getNRow(), stmt.getChild(2).getNRow(), exprRoot); + return true; + } + + switch (stmt.trueTypedInstruction(ctx)) { + case "trace(MATRIX)": + assertions.addEqualityAssertion(stmt.getChild(0).getNRow(), stmt.getChild(0).getNCol(), exprRoot); + return true; + case "cast.FLOAT(MATRIX)": + assertions.addEqualityAssertion(stmt.getChild(0).getNRow(), stmt.getChild(0).getNCol(), exprRoot); + assertions.addEqualityAssertion(stmt.getChild(0).getNRow(), RewriterStatement.literal(ctx, 1L), exprRoot); + return true; + } + + if (((RewriterInstruction) stmt).hasProperty("ElementWiseInstruction", ctx)) { + if (stmt.getChild(0).getResultingDataType(ctx).equals("MATRIX") + && stmt.getChild(1).getResultingDataType(ctx).equals("MATRIX")) { + assertions.addEqualityAssertion(stmt.getChild(0).getNCol(), stmt.getChild(1).getNCol(), exprRoot); + assertions.addEqualityAssertion(stmt.getChild(0).getNRow(), stmt.getChild(1).getNRow(), exprRoot); + return true; + } + } + + return false; + } +} diff --git a/src/main/java/org/apache/sysds/hops/rewriter/assertions/RewriterAssertions.java b/src/main/java/org/apache/sysds/hops/rewriter/assertions/RewriterAssertions.java new file mode 100644 index 00000000000..7da9da401dd --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewriter/assertions/RewriterAssertions.java @@ -0,0 +1,751 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.rewriter.assertions; + +import org.apache.commons.lang3.NotImplementedException; +import org.apache.commons.lang3.function.TriFunction; +import org.apache.sysds.hops.rewriter.RewriterInstruction; +import org.apache.sysds.hops.rewriter.RewriterStatement; +import org.apache.sysds.hops.rewriter.RuleContext; +import scala.Tuple2; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import java.util.UUID; +import java.util.function.BiConsumer; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +public class RewriterAssertions { + private final RuleContext ctx; + private Map assertionMatcher = new HashMap<>(); + // Tracks which statements are part of which assertions + private Map> partOfAssertion = new HashMap<>(); + private Set allAssertions = new HashSet<>(); + + public RewriterAssertions(final RuleContext ctx) { + this.ctx = ctx; + } + + public RewriterAssertions nestedCopyOrInject(Map createdObjects, TriFunction injector, RewriterStatement parent) { + RewriterAssertions out = new RewriterAssertions(ctx); + out.allAssertions = allAssertions.stream().map(assertion -> { + Set newSet = new HashSet<>(assertion.set.size()); + RewriterAssertion mapped = RewriterAssertion.from(newSet); + + if (assertion.stmt != null) { + mapped.stmt = assertion.stmt.nestedCopyOrInject(createdObjects, injector, parent, -1); + out.assertionMatcher.put(mapped.stmt, mapped); + } + + for (RewriterStatement entry : assertion.set) { + RewriterStatement newStmt = entry.nestedCopyOrInject(createdObjects, injector, parent, -1); + newSet.add(newStmt); + out.assertionMatcher.put(newStmt, mapped); + } + + if (assertion.backRef != null) { + mapped.backRef = assertion.backRef.nestedCopyOrInject(createdObjects, injector, parent, -1); + out.assertionMatcher.put(mapped.backRef, mapped); + } + + return mapped; + }).collect(Collectors.toSet()); + + for (RewriterAssertion assertion : out.allAssertions) { + forEachUniqueElementInAssertion(assertion, el -> { + Set partOfAssertions = out.partOfAssertion.get(el); + + if (partOfAssertions == null) { + partOfAssertions = new HashSet<>(); + out.partOfAssertion.put(el, partOfAssertions); + } + + partOfAssertions.add(assertion); + }); + } + + return out; + } + + public static RewriterAssertions copy(RewriterAssertions old, Map createdObjects, boolean removeOthers) { + RewriterAssertions newAssertions = new RewriterAssertions(old.ctx); + + Map mappedAssertions = new HashMap<>(); + + newAssertions.allAssertions = old.allAssertions.stream().map(assertion -> { + Set newSet = new HashSet<>(); + List backRefsToCheck = new ArrayList<>(); + + for (RewriterStatement oldEl : assertion.set) { + RewriterStatement cpy = createdObjects.get(oldEl); + + if (cpy == null) + cpy = oldEl.nestedCopyOrInject(createdObjects, stmt -> null); + + if (cpy.isInstruction() && cpy.trueInstruction().startsWith("_backRef.")) + backRefsToCheck.add(cpy); + + newSet.add(cpy); + } + + List backRefsToRemove = Collections.emptyList(); + + if (!backRefsToCheck.isEmpty()) { + backRefsToRemove = new ArrayList<>(); + + for (RewriterStatement backRef : backRefsToCheck) { + System.out.println("Candidate: " + backRef); + if (newSet.contains(backRef.getMeta("_backRef"))) { + newSet.remove(backRef); + backRefsToRemove.add(backRef); + } + } + } + + if (newSet.size() < 2) { + System.out.println("Removing E-Class: " + assertion); + return null; + } + + RewriterAssertion mapped = RewriterAssertion.from(newSet); + if (assertion.stmt != null) { + mapped.stmt = createdObjects.get(assertion.stmt); + + if (!backRefsToRemove.isEmpty()) { + mapped.stmt.getChild(0).getOperands().removeAll(backRefsToRemove); + } + } + if (assertion.backRef != null) + mapped.backRef = createdObjects.get(assertion.backRef); + mappedAssertions.put(assertion, mapped); + return mapped; + }).filter(Objects::nonNull).collect(Collectors.toSet()); + + for (Map.Entry> e : old.partOfAssertion.entrySet()) { + RewriterStatement k = createdObjects.get(e.getKey()); + + if (k == null) + continue; + + Set v = e.getValue(); + Set newV = v.stream().map(mappedAssertions::get).filter(Objects::nonNull).collect(Collectors.toSet()); + + newAssertions.partOfAssertion.put(k, newV); + } + + if (removeOthers) { + old.assertionMatcher.forEach((k, v) -> { + RewriterStatement newK = createdObjects.get(k); + + if (newK == null) + return; + + RewriterAssertion newV = mappedAssertions.get(v); + + if (newV == null) + return; + + newAssertions.assertionMatcher.put(newK, newV); + }); + } else { + old.assertionMatcher.forEach((k, v) -> { + RewriterStatement newK = createdObjects.getOrDefault(k, k); + RewriterAssertion newV = mappedAssertions.get(v); + + if (newV == null) + return; + + newAssertions.assertionMatcher.put(newK, newV); + }); + } + + return newAssertions; + } + + public void forEachAssertionContents(BiConsumer consumer) { + allAssertions.forEach(assertion -> assertion.set.forEach(set -> consumer.accept(set, new RewriterStatement.RewriterPredecessor(this, assertion)))); + } + + public void updateAssertionContents(Function f) { + for (RewriterAssertion assertion : allAssertions) { + Set toRemove = new HashSet<>(); + Map toReplace = new HashMap<>(); + + for (RewriterStatement stmt : assertion.set) { + RewriterStatement mNew = f.apply(stmt); + if (mNew != stmt) { + toRemove.add(stmt); + toReplace.put(stmt, mNew); + } + } + + if (toReplace.isEmpty()) + continue; + + toRemove.forEach(assertion.set::remove); + assertion.set.addAll(toReplace.values()); + + if (assertion.stmt != null) { + List argList = assertion.stmt.getChild(0).getOperands(); + for (int i = 0; i < argList.size(); i++) { + RewriterStatement replaced = toReplace.get(argList.get(i)); + + if (replaced != null) + argList.set(i, replaced); + } + } + + // Now, we have to recompute partOfAssertion for removed and newly added elements + for (RewriterStatement removed : toRemove) { + removed.forEachPreOrder((cur, pred) -> { + Set set = partOfAssertion.get(cur); + + if (set != null) + set.remove(assertion); + + return true; + }, false); + } + + forEachUniqueElementInAssertion(assertion, cur -> { + partOfAssertion.compute(cur, (k, v) -> { + if (v == null) + v = new HashSet<>(); + + v.add(assertion); + return v; + }); + }); + } + } + + public Stream> streamOfContents() { + return allAssertions.stream().flatMap(assertion -> { + if (assertion.stmt != null) { + if (assertion.backRef != null) + return Stream.of(new Tuple2<>(assertion.stmt, new RewriterStatement.RewriterPredecessor(this, assertion)), new Tuple2<>(assertion.backRef, new RewriterStatement.RewriterPredecessor(this, assertion))); + return Stream.of(new Tuple2<>(assertion.stmt, new RewriterStatement.RewriterPredecessor(this, assertion))); + } else { + return assertion.set.stream().map(stmt -> new Tuple2<>(stmt, new RewriterStatement.RewriterPredecessor(this, assertion))); + } + }); + } + + public void replaceAssertionContent(RewriterStatement oldStmt, RewriterStatement newStmt, RewriterAssertion assertion) { + if (oldStmt == assertion.stmt) { + // Then we will remove this assertion + allAssertions.remove(assertion); + assertion.set.forEach(s -> this.assertionMatcher.remove(s)); + } + + assertion.set.remove(oldStmt); + assertion.set.add(newStmt); + + if (assertion.stmt != null) { + assertion.stmt.getChild(); + } + + throw new NotImplementedException(); + } + + public void resolveExistingAssertions(RewriterStatement root) { + List backRefs = new ArrayList<>(); + root.forEachPreOrder(stmt -> { + if (stmt.isEClass()) { + if (!assertionMatcher.containsKey(stmt)) { + RewriterAssertion assertion = new RewriterAssertion(); + assertion.stmt = stmt; + assertion.set = new HashSet<>(stmt.getChild(0).getOperands()); + allAssertions.add(assertion); + + for (RewriterStatement eStmt : assertion.set) + assertionMatcher.put(eStmt, assertion); + + forEachUniqueElementInAssertion(assertion, cur -> { + partOfAssertion.compute(cur, (k, v) -> { + if (v == null) + v = new HashSet<>(); + + v.add(assertion); + return v; + }); + }); + } + } else if (stmt.isInstruction() && stmt.trueInstruction().equals("_backRef")) { + backRefs.add(stmt); + } + + return true; + }, false); + + for (RewriterStatement backRef : backRefs) { + RewriterAssertion assertion = getAssertionObj(backRef); + if (assertion != null) { + assertion.backRef = backRef; + } else { + // TODO + } + } + } + + public boolean addEqualityAssertion(RewriterStatement stmt1, RewriterStatement stmt2, RewriterStatement exprRoot) { + if (stmt1 == null || stmt2 == null) + throw new IllegalArgumentException("Cannot add an equality assertion to a null reference!"); + + if (stmt1 == stmt2 || (stmt1.isLiteral() && stmt2.isLiteral() && stmt1.getLiteral().equals(stmt2.getLiteral()))) + return false; + + if (stmt1.isLiteral() && stmt2.isLiteral() && !stmt1.getLiteral().equals(stmt2.getLiteral())) + throw new IllegalArgumentException("Cannot assert equality of two different literals!"); + + if (stmt1.hashCode() == 0) + throw new IllegalArgumentException(); + + RewriterStatement e1 = stmt1; + RewriterStatement e2 = stmt2; + RewriterAssertion stmt1Assertions = assertionMatcher.get(e1); + RewriterAssertion stmt2Assertions = assertionMatcher.get(e2); + + if (stmt1.isLiteral() || stmt2.isLiteral()) { + RewriterStatement literal = stmt1.isLiteral() ? stmt1 : stmt2; + + if (stmt1Assertions != null) { + Optional existingLiteral = stmt1Assertions.getLiteral(); + + if (existingLiteral.isPresent()) { + if (literal.getLiteral().equals(existingLiteral.get().getLiteral())) + return false; + else + throw new IllegalArgumentException("Cannot assert equality of two different literals!"); + } + } + + if (stmt2Assertions != null) { + Optional existingLiteral = stmt2Assertions.getLiteral(); + + if (existingLiteral.isPresent()) { + if (literal.getLiteral().equals(existingLiteral.get().getLiteral())) + return false; + else + throw new IllegalArgumentException("Cannot assert equality of two different literals!"); + } + } + + if (stmt1Assertions != null && stmt2Assertions != null) { + // Here, we need to check if both assertions already contain a literal + // If the literals are identical, we need to deduplicate, otherwise throw an error + Optional existingLiteral1 = stmt1Assertions.getLiteral(); + Optional existingLiteral2 = stmt2Assertions.getLiteral(); + + if (existingLiteral1.isPresent() && existingLiteral2.isPresent()) { + if (!existingLiteral1.get().getLiteral().equals(existingLiteral2.get().getLiteral())) + throw new IllegalArgumentException("Cannot assert equality of two different literal!"); + } + } + } + + if (stmt1Assertions == stmt2Assertions) { + if (stmt1Assertions == null) { + // Then we need to introduce a new equality set + Set newSet = new HashSet<>(); + newSet.add(e1); + newSet.add(e2); + + RewriterAssertion newAssertion = RewriterAssertion.from(newSet); + + assertionMatcher.put(e1, newAssertion); + assertionMatcher.put(e2, newAssertion); + + allAssertions.add(newAssertion); + + resolveCyclicAssertions(newAssertion); + + forEachUniqueElementInAssertion(newAssertion, cur -> { + partOfAssertion.compute(cur, (k, v) -> { + if (v == null) + v = new HashSet<>(); + + v.add(newAssertion); + return v; + }); + }); + + return true; + } + + return false; // The assertion already exists + } + + if (stmt1Assertions == null || stmt2Assertions == null) { + boolean assert1 = stmt1Assertions == null; + RewriterStatement toAssert = assert1 ? stmt1 : stmt2; + RewriterAssertion existingAssertion = assert1 ? stmt2Assertions : stmt1Assertions; + existingAssertion.set.add(toAssert); + assertionMatcher.put(assert1 ? e1 : e2, existingAssertion); + if (existingAssertion.stmt != null) + updateInstance(existingAssertion.stmt.getChild(0), existingAssertion.set); + + resolveCyclicAssertions(existingAssertion); + + toAssert.forEachPreOrder(cur -> { + partOfAssertion.compute(cur, (k, v) -> { + if (v == null) + v = new HashSet<>(); + + v.add(existingAssertion); + return v; + }); + return true; + }, false); + + return true; + } + + // Otherwise we need to merge the assertions + + // For that, we choose the smaller set as we will need fewer operations + if (stmt1Assertions.set.size() > stmt2Assertions.set.size()) { + RewriterAssertion tmp = stmt1Assertions; + stmt1Assertions = stmt2Assertions; + stmt2Assertions = tmp; + } + + stmt2Assertions.set.addAll(stmt1Assertions.set); + allAssertions.remove(stmt1Assertions); + if (stmt2Assertions.stmt != null) + updateInstance(stmt2Assertions.stmt.getChild(0), stmt2Assertions.set); + + for (RewriterStatement stmt : stmt1Assertions.set) + assertionMatcher.put(stmt, stmt2Assertions); + + if (stmt1Assertions.stmt != null) + assertionMatcher.put(stmt1Assertions.stmt, stmt2Assertions); // Only temporary + + resolveCyclicAssertions(stmt2Assertions); + stmt2Assertions.deduplicate(); + + final RewriterAssertion assertionToRemove = stmt1Assertions; + final RewriterAssertion assertionToExtend = stmt2Assertions; + forEachUniqueElementInAssertion(stmt1Assertions, cur -> { + Set v = partOfAssertion.get(cur); + + if (v == null) + throw new IllegalArgumentException(cur.toString()); + + v.remove(assertionToRemove); + v.add(assertionToExtend); + }); + + if (assertionToRemove.stmt != null) { + exprRoot.forEachPreOrder(cur -> { + for (int i = 0; i < cur.getOperands().size(); i++) { + RewriterStatement child = cur.getChild(i); + if (child == assertionToRemove.stmt) + cur.getOperands().set(i, assertionToExtend.getEClassStmt(ctx, this)); + } + return true; + }, false); + } + + return true; + } + + public static RewriterStatement updateMergedEClasses(RewriterStatement exprRoot, Map legacyEClasses) { + exprRoot.forEachPreOrder(cur -> { + for (int i = 0; i < cur.getOperands().size(); i++) { + RewriterStatement child = cur.getChild(i); + if (child.isEClass()) { + RewriterStatement mapped = legacyEClasses.get(child); + if (mapped != null) + cur.getOperands().set(i, mapped); + } + } + return true; + }, false); + + if (exprRoot.isEClass()) { + RewriterStatement mapped = legacyEClasses.get(exprRoot); + if (mapped != null) + return mapped; + } + + return exprRoot; + } + + private void forEachUniqueElementInAssertion(RewriterAssertion assertion, Consumer consumer) { + Set visited = new HashSet<>(); + for (RewriterStatement eq : assertion.set) { + eq.forEachPreOrderWithDuplicates(cur -> { + if (!visited.add(cur)) + return false; + + consumer.accept(cur); + return true; + }); + } + } + + // Replace cycles with _backRef() + private void resolveCyclicAssertions(RewriterAssertion assertion) { + if (assertion.stmt == null) + return; + + RewriterStatement backref = assertion.getBackRef(ctx, this); + + for (RewriterStatement eq : assertion.set) { + eq.forEachPreOrder((cur, pred) -> { + for (int i = 0; i < cur.getOperands().size(); i++) + if (!cur.getChild(i).isLiteral() && getAssertionObj(cur.getChild(i)) == assertion) + cur.getOperands().set(i, backref); + + return true; + }, false); + } + } + + public RewriterAssertion getAssertionObj(RewriterStatement stmt) { + return assertionMatcher.get(stmt); + } + + public Set getAssertions(RewriterStatement stmt) { + RewriterAssertion set = assertionMatcher.get(stmt); + return set == null ? Collections.emptySet() : set.set; + } + + public RewriterStatement getAssertionStatement(RewriterStatement stmt, RewriterStatement parent) { + RewriterAssertion set = assertionMatcher.get(stmt); + + if (set == null || set.getEClassStmt(ctx, this).getChild(0) == parent) { + return stmt; + } + + if (parent != null && parent != set.getEClassStmt(ctx, this).getChild(0) && partOfAssertion.getOrDefault(parent, Collections.emptySet()).contains(set)) + return set.getBackRef(ctx, this); + + return set.getEClassStmt(ctx, this); + } + + public RewriterStatement update(RewriterStatement root) { + RewriterStatement eClass = getAssertionStatement(root, null); + + if (eClass == null) + eClass = root; + else if (root.getMeta("_assertions") != null) + eClass.unsafePutMeta("_assertions", root.getMeta("_assertions")); + + updateRecursively(eClass); + + return eClass; + } + + // This removes E-Classes that are not actually E-Classes like _EClass(argList(nrow(A), nrow(A))), or _EClass(argList(nrow(A), _backRef.INT())) + public RewriterStatement cleanupEClasses(RewriterStatement expressionRoot) { + Set toRemoveList = new HashSet<>(); + Map toRemove = new HashMap<>(); + + for (RewriterAssertion assertion : allAssertions) { + int previousSize = assertion.set.size(); + if (assertion.stmt != null) { + // Eliminate top-level back-refs + assertion.set.removeIf(el -> el.isInstruction() && el.trueInstruction().startsWith("_backRef") && el.getMeta("_backRef").equals(assertion.stmt)); + } + + if (assertion.set.size() < 2) { + toRemoveList.add(assertion); + + if (assertion.stmt != null) + toRemove.put(assertion.stmt, assertion.set.stream().findFirst().get()); + } + + if (previousSize != assertion.set.size() && assertion.stmt != null) { + // Then we need to update the EClass + assertion.stmt.getChild(0).getOperands().removeIf(el -> !assertion.set.contains(el)); + + if (assertion.stmt.getChild(0).getOperands().size() != assertion.set.size()) { + // Then there are still duplicates which we need to rule out + Set visited = new HashSet<>(); + List eItems = assertion.stmt.getChild(0).getOperands(); + for (int i = 0; i < eItems.size(); i++) { + if (!visited.add(eItems.get(i))) + eItems.remove(i--); + } + } + } + } + + if (!toRemoveList.isEmpty()) { + allAssertions.removeAll(toRemoveList); + + if (!toRemove.isEmpty()) { + if (expressionRoot.isEClass()) { + RewriterStatement mNew = toRemove.get(expressionRoot); + + if (mNew != null) + expressionRoot = mNew; + } + + expressionRoot.forEachPostOrder((cur, pred) -> { + cur.allChildren().forEach(t -> { + if (t._1.isEClass()) { + RewriterStatement mNew = toRemove.get(t._1); + if (mNew != null) { + if (t._2.isOperand()) { + cur.getOperands().set(t._2.getIndex(), mNew); + } else if (t._2.isMetaObject()) { + cur.unsafePutMeta(t._2.getMetaKey(), mNew); + } + } + } + }); + }, true); + } + } + + return expressionRoot; + } + + private void updateRecursively(RewriterStatement cur) { + for (int i = 0; i < cur.getOperands().size(); i++) { + RewriterStatement child = cur.getChild(i); + RewriterStatement eClass = getAssertionStatement(child, cur); + + if (eClass != child) + cur.getOperands().set(i, eClass); + + updateRecursively(cur.getChild(i)); + } + } + + @Override + public String toString() { + return allAssertions.toString(); + } + + private void updateInstance(RewriterStatement stmt, Set set) { + if (stmt != null) { + stmt.getOperands().clear(); + stmt.getOperands().addAll(set); + } + } + + public static class RewriterAssertion { + Set set; + RewriterStatement stmt; + RewriterStatement backRef; // The back-reference to this assertion + + public Collection getEClass() { + return set; + } + + public RewriterStatement getEClassStmt(final RuleContext ctx, RewriterAssertions assertions) { + if (stmt != null) + return stmt; + + RewriterStatement argList = new RewriterInstruction().as(UUID.randomUUID().toString()).withInstruction("argList").withOps(set.toArray(RewriterStatement[]::new)); + stmt = new RewriterInstruction().as(UUID.randomUUID().toString()).withInstruction("_EClass").withOps(argList); + stmt.consolidate(ctx); + assertions.assertionMatcher.put(stmt, this); + assertions.partOfAssertion.compute(stmt, (k, v) -> { + if (v == null) + v = new HashSet<>(); + + v.add(this); + return v; + }); + assertions.partOfAssertion.compute(argList, (k, v) -> { + if (v == null) + v = new HashSet<>(); + + v.add(this); + return v; + }); + assertions.resolveCyclicAssertions(this); + return stmt; + } + + public RewriterStatement getBackRef(final RuleContext ctx, RewriterAssertions assertions) { + if (backRef != null) + return backRef; + + backRef = new RewriterInstruction() + .as(UUID.randomUUID().toString()) + .withInstruction("_backRef." + getEClassStmt(ctx, assertions).getResultingDataType(ctx)) + .consolidate(ctx); + backRef.unsafePutMeta("_backRef", getEClassStmt(ctx, assertions)); + assertions.partOfAssertion.compute(backRef, (k, v) -> { + if (v == null) + v = new HashSet<>(); + + v.add(this); + return v; + }); + return backRef; + } + + // Returns a literal if available, otherwise null + public Optional getLiteral() { + return set.stream().filter(RewriterStatement::isLiteral).findFirst(); + } + + // Removes duplicate entries (e.g. duplicate literals etc.) + public void deduplicate() { + if (stmt != null && stmt.getChild(0).getOperands().size() != set.size()) { + List operands = stmt.getChild(0).getOperands(); + Set elementTracker = new HashSet<>(); + + for (int i = 0; i < operands.size(); i++) { + RewriterStatement el = operands.get(i); + + if (elementTracker.contains(el)) { + operands.remove(i); + i--; + } else { + elementTracker.add(el); + } + } + } + } + + @Override + public String toString() { + if (stmt != null) + return stmt + " -- " + System.identityHashCode(this); + + return set.toString() + " -- " + System.identityHashCode(this); + } + + static RewriterAssertion from(Set set) { + RewriterAssertion a = new RewriterAssertion(); + a.set = set; + return a; + } + } +} diff --git a/src/main/java/org/apache/sysds/hops/rewriter/codegen/CodeGenCondition.java b/src/main/java/org/apache/sysds/hops/rewriter/codegen/CodeGenCondition.java new file mode 100644 index 00000000000..e8e30f4105e --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewriter/codegen/CodeGenCondition.java @@ -0,0 +1,622 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.rewriter.codegen; + +import org.apache.sysds.hops.rewriter.RewriterDataType; +import org.apache.sysds.hops.rewriter.rule.RewriterRule; +import org.apache.sysds.hops.rewriter.RewriterStatement; +import org.apache.sysds.hops.rewriter.RuleContext; +import org.apache.sysds.hops.rewriter.utils.CodeGenUtils; +import scala.Tuple2; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Queue; +import java.util.Set; +import java.util.stream.Collectors; + +public class CodeGenCondition { + public enum ConditionType { + DATA_TYPE, VALUE_TYPE, UNIQUE_PARENTS, LITERAL, OP_CLASS, OP_CODE, NUM_INPUTS, ELSE + } + + public enum ConditionDataType { + SCALAR, MATRIX + } + + private ConditionType conditionType; + private Object conditionValue; + private List rulesIf; + private List applyAnyway; + private List relativeChildPath; + private RewriterStatement representant; + + private CodeGenCondition(ConditionType cType, Object cValue, List relativeChildPath, RewriterStatement representant, final RuleContext ctx) { + conditionType = cType; + conditionValue = cValue; + rulesIf = new ArrayList<>(); + applyAnyway = new ArrayList<>(); + this.relativeChildPath = relativeChildPath; + this.representant = representant; + + if (conditionType != ConditionType.ELSE) + buildConditionCheck(new StringBuilder(), ctx); + } + + public static List buildCondition(List rules, int maxNumRules, final RuleContext ctx) { + return buildCondition(rules, 3, maxNumRules, ctx); + } + + public static List buildCondition(List rules, int maxDepth, int maxNumRules, final RuleContext ctx) { + if (rules.isEmpty()) + return Collections.emptyList(); + List transformed = rules.stream().map(rule -> new Tuple2(rule, rule.getStmt1())).collect(Collectors.toList()); + List out = populateLayerRecursively(transformed, Collections.emptyList(), new LinkedList<>(), maxDepth, maxNumRules, ctx); + List cond = out.stream().filter(o -> o instanceof CodeGenCondition).map(o -> ((CodeGenCondition)o)).collect(Collectors.toList()); + return cond.isEmpty() ? List.of(conditionalElse(transformed, Collections.emptyList(), ((Tuple2) transformed.get(0))._2, ctx)) : cond; + } + + private static List populateLayerRecursively(List rules, List relativeChildPath, Queue, List>> queue, int maxDepth, int maxNumRules, final RuleContext ctx) { + if (rules.size() <= maxNumRules || maxDepth == 0) + return rules; + + List out = populateDataTypeLayer(rules, relativeChildPath, ctx); + + for (int i = 0; i < out.size(); i++) { + CodeGenCondition c = (CodeGenCondition) out.get(i); + + if (c.rulesIf.size() <= maxNumRules) + continue; + + c.rulesIf = populateOpClassLayer(c.rulesIf, relativeChildPath, ctx); + + for (int j = 0; j < c.rulesIf.size(); j++) { + CodeGenCondition c2 = (CodeGenCondition) c.rulesIf.get(j); + + if (c2.rulesIf.size() <= maxNumRules) + continue; + + c2.rulesIf = populateOpCodeLayer(c2.rulesIf, relativeChildPath, ctx); + + for (int k = 0; k < c2.rulesIf.size(); k++) { + CodeGenCondition c3 = (CodeGenCondition) c2.rulesIf.get(k); + + if (c3.rulesIf.size() <= maxNumRules) + continue; + + c3.rulesIf = populateInputSizeLayer(c3.rulesIf, relativeChildPath, ctx); + + for (int l = 0; l < c3.rulesIf.size(); l++) { + CodeGenCondition c4 = (CodeGenCondition) c3.rulesIf.get(l); + + if (((Tuple2) c4.rulesIf.get(0))._2 == null) + continue; + + final int maxIndex = ((Tuple2) c4.rulesIf.get(0))._2.getOperands().size(); + Set activeRules = c4.rulesIf.stream().map(o -> ((Tuple2) o)._1).collect(Collectors.toSet()); + Queue, List>> mQueue = new LinkedList<>(); + + for (Tuple2, List> t : queue) { + List mObj = new ArrayList<>(); + for (Object o : t._1) { + if (activeRules.contains(((Tuple2) o)._1)) + mObj.add(o); + } + + if (!mObj.isEmpty()) + mQueue.add(new Tuple2<>(mObj, t._2)); + } + + for (int idx = 0; idx < maxIndex; idx++) { + final int mIdx = idx; + final List newRelativeChildPath = new ArrayList<>(relativeChildPath); + newRelativeChildPath.add(mIdx); + List mList = new ArrayList<>(); + mQueue.add(new Tuple2<>(mList, newRelativeChildPath)); + + c4.rulesIf.forEach(o -> { + Tuple2 t = (Tuple2) o; + mList.add(new Tuple2(t._1, (t._2 == null ? null : (t._2.getOperands().isEmpty() ? null : t._2.getChild(mIdx))))); + }); + } + + if (!mQueue.isEmpty()) { + Tuple2, List> next = mQueue.poll(); + c4.rulesIf = populateLayerRecursively(next._1, next._2(), mQueue, maxDepth-1, maxNumRules, ctx); + } + } + } + } + } + + return out; + } + + private static boolean validateSizeMaintenance(List rules, List generatedConditions) { + int origSize = rules.size(); + int newSize = generatedConditions.stream().mapToInt(o -> ((CodeGenCondition)o).rulesIf.size()).sum(); + return origSize <= newSize; + } + + private static List populateDataTypeLayer(List rules, List relativeChildPath, final RuleContext ctx) { + List conds = new ArrayList<>(); + List> defer = new ArrayList<>(); + + //System.out.println("====="); + + for (Object o : rules) { + Tuple2 t = (Tuple2) o; + + if (t._2 == null) { + defer.add(t); + continue; + } + + if (!conds.stream().anyMatch(cond -> ((CodeGenCondition) cond).insertIfMatches(t, ctx))) { + CodeGenCondition cond = CodeGenCondition.conditionalDataType(t._2, relativeChildPath, t._2, ctx); + cond.insertIfMatches(t, ctx); + conds.add(cond); + StringBuilder sb = new StringBuilder(); + cond.buildConditionCheck(sb, ctx); + } else { + CodeGenCondition condse = (CodeGenCondition) conds.stream().filter(cond -> ((CodeGenCondition) cond).matchesCondition(t._1, t._2, ctx)).findFirst().get(); + StringBuilder sb = new StringBuilder(); + condse.buildConditionCheck(sb, ctx); + } + } + + if (!defer.isEmpty()) { + conds.add(CodeGenCondition.conditionalElse(new ArrayList<>(defer), relativeChildPath, null, ctx)); + } + + for (Tuple2 deferred : defer) { + for (Object obj : conds) + ((CodeGenCondition) obj).rulesIf.add(deferred); + } + + if (!validateSizeMaintenance(rules, conds)) + throw new IllegalArgumentException(); + + return conds; + } + + private static List populateOpClassLayer(List l, List relativeChildPath, final RuleContext ctx) { + List conds = new ArrayList<>(); + List remaining = new ArrayList<>(); + List> defer = new ArrayList<>(); + + for (Object o : l) { + try { + Tuple2 t = (Tuple2) o; + + if (t._2 == null || (t._2 instanceof RewriterDataType && !t._2.isLiteral())) { + defer.add(t); + continue; + } + + if (canGenerateOpClassCheck(t._2, ctx)) { + if (!conds.stream().anyMatch(cond -> ((CodeGenCondition) cond).insertIfMatches(t, ctx))) { + CodeGenCondition cond = CodeGenCondition.conditionalOpClass(t._2, relativeChildPath, t._2, ctx); + cond.insertIfMatches(t, ctx); + conds.add(cond); + } + } else { + remaining.add(t); + } + } catch (Exception e) { + e.printStackTrace(); + } + } + + remaining.addAll(defer); + + for (Tuple2 deferred : defer) { + for (Object obj : conds) + ((CodeGenCondition) obj).rulesIf.add(deferred); + } + + if (!remaining.isEmpty()) { + conds.add(CodeGenCondition.conditionalElse(remaining, relativeChildPath, ((Tuple2) remaining.get(0))._2, ctx)); + } + + return conds; + } + + private static List populateOpCodeLayer(List l, List relativeChildPath, final RuleContext ctx) { + List conds = new ArrayList<>(); + List remaining = new ArrayList<>(); + List> defer = new ArrayList<>(); + + for (Object o : l) { + Tuple2 t = (Tuple2) o; + + if (t._2 == null || (t._2 instanceof RewriterDataType && !t._2.isLiteral())) { + defer.add(t); + continue; + } + + if (canGenerateOpCodeCheck(t._2, ctx)) { + if (!conds.stream().anyMatch(cond -> ((CodeGenCondition) cond).insertIfMatches(t, ctx))) { + CodeGenCondition cond = CodeGenCondition.conditionalOpCode(t._2, relativeChildPath, t._2, ctx); + cond.insertIfMatches(t, ctx); + conds.add(cond); + } + } else if (t._2 instanceof RewriterDataType && !t._2.isLiteral()) { + // Then we must add it to all conditions + for (Object obj : conds) + ((CodeGenCondition) obj).rulesIf.add(t); + } else { + remaining.add(t); + } + } + + remaining.addAll(defer); + + for (Tuple2 deferred : defer) { + for (Object obj : conds) + ((CodeGenCondition) obj).rulesIf.add(deferred); + } + + if (!remaining.isEmpty()) { + conds.add(CodeGenCondition.conditionalElse(remaining, relativeChildPath, ((Tuple2) remaining.get(0))._2, ctx)); + } + + if (!validateSizeMaintenance(l, conds)) + throw new IllegalArgumentException(); + + return conds; + } + + private static List populateInputSizeLayer(List l, List relativeChildPath, final RuleContext ctx) { + List conds = new ArrayList<>(); + List remaining = new ArrayList<>(); + List> defer = new ArrayList<>(); + + for (Object o : l) { + Tuple2 t = (Tuple2) o; + + if (t._2 == null || (t._2 instanceof RewriterDataType && !t._2.isLiteral())) { + defer.add(t); + continue; + } + + if (canGenerateInputSizeCheck(t._2, ctx)) { + if (!conds.stream().anyMatch(cond -> ((CodeGenCondition) cond).insertIfMatches(t, ctx))) { + CodeGenCondition cond = CodeGenCondition.conditionalInputSize(t._2.getOperands().size(), relativeChildPath, t._2, ctx); + cond.insertIfMatches(t, ctx); + conds.add(cond); + } + } else if (t._2 instanceof RewriterDataType && !t._2.isLiteral()) { + // Then we must add it to all conditions + for (Object obj : conds) + ((CodeGenCondition) obj).rulesIf.add(t); + } else { + remaining.add(t); + } + } + + remaining.addAll(defer); + + for (Tuple2 deferred : defer) { + for (Object obj : conds) + ((CodeGenCondition) obj).rulesIf.add(deferred); + } + + if (!remaining.isEmpty()) { + conds.add(CodeGenCondition.conditionalElse(remaining, relativeChildPath, ((Tuple2) remaining.get(0))._2, ctx)); + } + + if (!validateSizeMaintenance(l, conds)) + throw new IllegalArgumentException(); + + return conds; + } + + public String getVarName() { + if (relativeChildPath.isEmpty()) + return "hi"; + return "hi_" + relativeChildPath.stream().map(Object::toString).collect(Collectors.joining("_")); + } + + public void buildConditionCheck(StringBuilder sb, final RuleContext ctx) { + switch (conditionType) { + case DATA_TYPE: + sb.append("hi"); + if (!relativeChildPath.isEmpty()) { + sb.append("_"); + sb.append(relativeChildPath.stream().map(Object::toString).collect(Collectors.joining("_"))); + } + sb.append(".getDataType() == "); + sb.append(CodeGenUtils.getReturnType(getDataType() == ConditionDataType.MATRIX ? "MATRIX" : "FLOAT")[0]); + break; + case OP_CLASS: + sb.append("hi"); + if (!relativeChildPath.isEmpty()) { + sb.append("_"); + sb.append(relativeChildPath.stream().map(Object::toString).collect(Collectors.joining("_"))); + } + sb.append(" instanceof " + CodeGenUtils.getOpClass(representant, ctx)); + break; + case OP_CODE: + String hopVar = "hi"; + if (!relativeChildPath.isEmpty()) { + hopVar += "_"; + hopVar += relativeChildPath.stream().map(Object::toString).collect(Collectors.joining("_")); + } + + String specialInstr = CodeGenUtils.getSpecialOpCheck(representant, ctx, hopVar); + if (specialInstr != null) { + sb.append(specialInstr); + } else { + // Some type casting + sb.append("(( "); + sb.append(CodeGenUtils.getOpClass(representant, ctx)); + sb.append(" ) "); + sb.append(hopVar); + sb.append(" )"); + sb.append(".getOp() == "); + sb.append(CodeGenUtils.getOpCode(representant, ctx)); + } + break; + case NUM_INPUTS: + sb.append("hi"); + if (!relativeChildPath.isEmpty()) { + sb.append("_"); + sb.append(relativeChildPath.stream().map(Object::toString).collect(Collectors.joining("_"))); + } + sb.append(".getInput().size() == "); + sb.append(conditionValue.toString()); + break; + default: + throw new IllegalArgumentException(conditionType.name()); + } + } + + public boolean insertIfMatches(Tuple2 t, final RuleContext ctx) { + if (matchesCondition(t._1, t._2, ctx)) { + rulesIf.add(t); + return true; + } + + return false; + } + + public boolean matchesCondition(RewriterRule rule, RewriterStatement stmt, final RuleContext ctx) { + switch (conditionType) { + case DATA_TYPE: + return matchesDataTypeCondition(rule, stmt, ctx); + case OP_CLASS: + return matchesOpClassCondition(rule, stmt, ctx); + case OP_CODE: + return matchesOpCodeCondition(rule, stmt, ctx); + case NUM_INPUTS: + return matchesNumInputs(rule, stmt, ctx); + } + return false; + } + + public ConditionDataType getDataType() { + return (ConditionDataType) conditionValue; + } + + private boolean matchesNumInputs(RewriterRule rule, RewriterStatement stmt, final RuleContext ctx) { + return ((int)conditionValue) == stmt.getOperands().size(); + } + + private boolean matchesDataTypeCondition(RewriterRule rule, RewriterStatement stmt, final RuleContext ctx) { + ConditionDataType cdt = getDataType(); + String dType = stmt.getResultingDataType(ctx); + + if (dType.equals("MATRIX")) + return cdt.equals(ConditionDataType.MATRIX); + else + return cdt.equals(ConditionDataType.SCALAR); + } + + private boolean matchesOpClassCondition(RewriterRule rule, RewriterStatement stmt, final RuleContext ctx) { + String opClass = (String) conditionValue; + String actualClass = CodeGenUtils.getOpClass(stmt, ctx); + + return opClass.equals(actualClass); + } + + private boolean matchesOpCodeCondition(RewriterRule rule, RewriterStatement stmt, final RuleContext ctx) { + String opType = (String) conditionValue; + String actualOpType = CodeGenUtils.getOpCode(stmt, ctx); + + return actualOpType.equals(opType); + } + + + public static CodeGenCondition conditionalDataType(RewriterStatement stmt, List i, RewriterStatement representant, final RuleContext ctx) { + ConditionDataType cdt = stmt.getResultingDataType(ctx).equals("MATRIX") ? ConditionDataType.MATRIX : ConditionDataType.SCALAR; + return new CodeGenCondition(ConditionType.DATA_TYPE, cdt, i, representant, ctx); + } + + public static CodeGenCondition conditionalOpClass(RewriterStatement op, List i, RewriterStatement representant, final RuleContext ctx) { + String opClass = CodeGenUtils.getOpClass(op, ctx); + return new CodeGenCondition(ConditionType.OP_CLASS, opClass, i, representant, ctx); + } + + public static boolean canGenerateOpClassCheck(RewriterStatement op, final RuleContext ctx) { + return !op.isDataOrigin(); + } + + public static CodeGenCondition conditionalOpCode(RewriterStatement op, List i, RewriterStatement representant, final RuleContext ctx) { + String opCode = CodeGenUtils.getOpCode(op, ctx); + return new CodeGenCondition(ConditionType.OP_CODE, opCode, i, representant, ctx); + } + + public static boolean canGenerateOpCodeCheck(RewriterStatement op, final RuleContext ctx) { + return !op.isDataOrigin(); + } + + public static CodeGenCondition conditionalInputSize(int inputSize, List i, RewriterStatement representant, final RuleContext ctx) { + return new CodeGenCondition(ConditionType.NUM_INPUTS, inputSize, i, representant, ctx); + } + + public static boolean canGenerateInputSizeCheck(RewriterStatement op, final RuleContext ctx) { + return !op.isDataOrigin(); + } + + public static CodeGenCondition conditionalElse(List l, List relativeChildPath, RewriterStatement representant, final RuleContext ctx) { + CodeGenCondition cond = new CodeGenCondition(ConditionType.ELSE, null, relativeChildPath, representant, ctx); + cond.rulesIf = l; + return cond; + } + + public static String getSelectionString(List conds, int indentation, Map ruleFunctionMappings, final RuleContext ctx) { + StringBuilder sb = new StringBuilder(); + buildSelection(sb, conds, indentation, ruleFunctionMappings, ctx); + return sb.toString(); + } + + public static void buildSelection(StringBuilder sb, List conds, int indentation, Map ruleFunctionMappings, final RuleContext ctx) { + if (conds.isEmpty()) + return; + + CodeGenCondition firstCond = conds.get(0); + + if (firstCond.conditionType == ConditionType.ELSE) { + List nestedCondition = firstCond.rulesIf.stream().filter(o -> o instanceof CodeGenCondition).map(o -> (CodeGenCondition)o).collect(Collectors.toList()); + buildSelection(sb, nestedCondition, indentation, ruleFunctionMappings, ctx); + if (nestedCondition.isEmpty()) { + List> cur = firstCond.rulesIf.stream().map(o -> (Tuple2)o).collect(Collectors.toList()); + + for (Tuple2 t : cur) { + String fMapping = ruleFunctionMappings.get(t._1); + if (fMapping != null) { + RewriterCodeGen.indent(indentation, sb); + sb.append("hi = "); + sb.append(fMapping); + sb.append("(hi); // "); + sb.append(t._1.toString()); + sb.append("\n"); + } + } + } + return; + } + + RewriterCodeGen.indent(indentation, sb); + sb.append("if ( "); + firstCond.buildConditionCheck(sb, ctx); + sb.append(" ) {\n"); + + if (firstCond.conditionType == ConditionType.NUM_INPUTS) { + int numInputs = (int)firstCond.conditionValue; + + for (int i = 0; i < numInputs; i++) { + RewriterCodeGen.indent(indentation + 1, sb); + sb.append("Hop "); + sb.append(firstCond.getVarName()); + sb.append("_"); + sb.append(i); + sb.append(" = "); + sb.append(firstCond.getVarName()); + sb.append(".getInput("); + sb.append(i); + sb.append(");\n"); + } + } + + List nestedCondition = firstCond.rulesIf.stream().filter(o -> o instanceof CodeGenCondition).map(o -> (CodeGenCondition)o).collect(Collectors.toList()); + buildSelection(sb, nestedCondition, indentation + 1, ruleFunctionMappings, ctx); + + if (nestedCondition.isEmpty()) { + List> cur = firstCond.rulesIf.stream().map(o -> (Tuple2)o).collect(Collectors.toList()); + + if (cur.isEmpty()) + throw new IllegalArgumentException(firstCond.rulesIf.toString()); + + for (Tuple2 t : cur) { + String fMapping = ruleFunctionMappings.get(t._1); + if (fMapping != null) { + RewriterCodeGen.indent(indentation + 1, sb); + sb.append("hi = "); + sb.append(fMapping); + sb.append("(hi); // "); + sb.append(t._1.toString()); + sb.append("\n"); + } + } + } + + RewriterCodeGen.indent(indentation, sb); + sb.append("}"); + + for (CodeGenCondition cond : conds.subList(1, conds.size())) { + if (cond.conditionType == ConditionType.ELSE) { + sb.append(" else {\n"); + } else { + sb.append(" else if ( "); + cond.buildConditionCheck(sb, ctx); + sb.append(" ) {\n"); + } + + if (cond.conditionType == ConditionType.NUM_INPUTS) { + int numInputs = (int)cond.conditionValue; + + for (int i = 0; i < numInputs; i++) { + RewriterCodeGen.indent(indentation + 1, sb); + sb.append("Hop "); + sb.append(cond.getVarName()); + sb.append("_"); + sb.append(i); + sb.append(" = "); + sb.append(cond.getVarName()); + sb.append(".getInput("); + sb.append(i); + sb.append(");"); + } + } + + List mNestedCondition = cond.rulesIf.stream().filter(o -> o instanceof CodeGenCondition).map(o -> (CodeGenCondition)o).collect(Collectors.toList()); + buildSelection(sb, mNestedCondition, indentation + 1, ruleFunctionMappings, ctx); + + if (mNestedCondition.isEmpty()) { + List> cur = cond.rulesIf.stream().map(o -> (Tuple2)o).collect(Collectors.toList()); + + if (cur.isEmpty()) + throw new IllegalArgumentException(cond.rulesIf.toString()); + + for (Tuple2 t : cur) { + String fMapping = ruleFunctionMappings.get(t._1); + if (fMapping != null) { + RewriterCodeGen.indent(indentation + 1, sb); + sb.append("hi = "); + sb.append(fMapping); + sb.append("(hi); // "); + sb.append(t._1.toString()); + sb.append("\n"); + } + } + } + + RewriterCodeGen.indent(indentation, sb); + sb.append("}"); + } + + sb.append("\n"); + } +} diff --git a/src/main/java/org/apache/sysds/hops/rewriter/codegen/RewriterCodeGen.java b/src/main/java/org/apache/sysds/hops/rewriter/codegen/RewriterCodeGen.java new file mode 100644 index 00000000000..7af6984660c --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewriter/codegen/RewriterCodeGen.java @@ -0,0 +1,807 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.rewriter.codegen; + +import org.apache.sysds.hops.Hop; +import org.apache.sysds.hops.rewriter.rule.RewriterRuleSet; +import org.apache.sysds.hops.rewriter.assertions.RewriterAssertions; +import org.apache.sysds.hops.rewriter.estimators.RewriterCostEstimator; +import org.apache.sysds.hops.rewriter.RewriterDataType; +import org.apache.sysds.hops.rewriter.rule.RewriterRule; +import org.apache.sysds.hops.rewriter.RewriterStatement; +import org.apache.sysds.hops.rewriter.RuleContext; +import org.apache.sysds.hops.rewriter.utils.CodeGenUtils; +import org.codehaus.janino.SimpleCompiler; +import scala.Tuple2; + +import java.io.FileWriter; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.util.AbstractCollection; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.Function; +import java.util.stream.Collectors; + +public class RewriterCodeGen { + public static boolean DEBUG = true; + + public static String generateRewritesFromFiles(List filePaths, String targetFile, boolean optimize, final RuleContext ctx) throws IOException { + return generateRewritesFromFiles(filePaths, targetFile, optimize, 2, true, true, ctx); + } + + public static String generateRewritesFromFiles(List filePaths, String targetFile, boolean optimize, int maxOptimizationDepth, boolean includePackageInfo, boolean maintainStatistics, final RuleContext ctx) throws IOException { + List lines = new ArrayList<>(); + + for (String path : filePaths) { + lines.addAll(Files.readAllLines(Paths.get(path))); + } + + RewriterRuleSet ruleSet = RewriterRuleSet.deserialize(lines, ctx); + String javaCode = ruleSet.toJavaCode("GeneratedRewriteClass", optimize, maxOptimizationDepth, includePackageInfo, true, maintainStatistics); + + try (FileWriter writer = new FileWriter(targetFile)) { + writer.write(javaCode); + } catch (IOException e) { + throw e; + } + + return javaCode; + } + + public static Function compileRewrites(String className, List> rewrites, final RuleContext ctx, boolean ignoreErrors, boolean printErrors) throws Exception { + String code = generateClass(className, rewrites, false, false, ctx, ignoreErrors, printErrors); + System.out.println("Compiling code:\n" + code); + SimpleCompiler compiler = new SimpleCompiler(); + compiler.cook(code); + Class mClass = compiler.getClassLoader().loadClass(className); + Object instance = mClass.getDeclaredConstructor().newInstance(); + return (Function) instance; + } + + public static Function compile(String javaCode, String className) { + try { + SimpleCompiler compiler = new SimpleCompiler(); + compiler.cook(javaCode); + Class mClass = compiler.getClassLoader().loadClass(className); + Object instance = mClass.getDeclaredConstructor().newInstance(); + return (Function) instance; + } catch (Exception e) { + e.printStackTrace(); + return null; + } + } + + public static String generateClass(String className, List> rewrites, boolean optimize, boolean includePackageInfo, final RuleContext ctx, boolean ignoreErrors, boolean printErrors) { + return generateClass(className, rewrites, optimize, 2, includePackageInfo, ctx, ignoreErrors, printErrors, false); + } + + public static String generateClass(String className, List> rewrites, boolean optimize, int maxOptimizationDepth, boolean includePackageInfo, final RuleContext ctx, boolean ignoreErrors, boolean printErrors, boolean maintainRewriteStats) { + StringBuilder msb = new StringBuilder(); + + if (includePackageInfo) { + // Include license + msb.append("/*\n" + + " * Licensed to the Apache Software Foundation (ASF) under one\n" + + " * or more contributor license agreements. See the NOTICE file\n" + + " * distributed with this work for additional information\n" + + " * regarding copyright ownership. The ASF licenses this file\n" + + " * to you under the Apache License, Version 2.0 (the\n" + + " * \"License\"); you may not use this file except in compliance\n" + + " * with the License. You may obtain a copy of the License at\n" + + " *\n" + + " * http://www.apache.org/licenses/LICENSE-2.0\n" + + " *\n" + + " * Unless required by applicable law or agreed to in writing,\n" + + " * software distributed under the License is distributed on an\n" + + " * \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n" + + " * KIND, either express or implied. See the License for the\n" + + " * specific language governing permissions and limitations\n" + + " * under the License.\n" + + " */\n\n"); + msb.append("package org.apache.sysds.hops.rewriter.generated;\n\n"); + } + + msb.append("import java.util.ArrayList;\n"); + msb.append("import java.util.function.Function;\n"); + msb.append("\n"); + msb.append("import org.apache.sysds.utils.Statistics;\n"); + msb.append("import org.apache.sysds.hops.Hop;\n"); + msb.append("import org.apache.sysds.hops.LiteralOp;\n"); + msb.append("import org.apache.sysds.hops.UnaryOp;\n"); + msb.append("import org.apache.sysds.hops.BinaryOp;\n"); + msb.append("import org.apache.sysds.hops.ReorgOp;\n"); + msb.append("import org.apache.sysds.hops.AggUnaryOp;\n"); + msb.append("import org.apache.sysds.hops.AggBinaryOp;\n"); + msb.append("import org.apache.sysds.hops.DataGenOp;\n"); + msb.append("import org.apache.sysds.hops.TernaryOp;\n"); + msb.append("import org.apache.sysds.common.Types;\n"); + msb.append("import org.apache.sysds.hops.rewrite.HopRewriteUtils;\n"); + msb.append("import org.apache.sysds.hops.rewriter.dml.DMLExecutor;\n"); + msb.append("import org.apache.sysds.hops.rewriter.RewriterRuntimeUtils;\n"); + msb.append("\n"); + msb.append("public class " + className + " implements Function {\n\n"); + + StringBuilder implSb = new StringBuilder(); + Set implemented = new HashSet<>(); + for (Tuple2 appliedRewrites : rewrites) { + String mRewriteFn; + if (ignoreErrors) { + try { + mRewriteFn = generateRewriteFunction(appliedRewrites._2, appliedRewrites._1, 1, maintainRewriteStats, ctx); + } catch (Exception e) { + if (printErrors) + e.printStackTrace(); + + continue; + } + } else { + mRewriteFn = generateRewriteFunction(appliedRewrites._2, appliedRewrites._1, 1, maintainRewriteStats, ctx); + } + + implSb.append('\n'); + indent(1, implSb); + implSb.append("// Implementation of the rule " + appliedRewrites._2 + "\n"); + implSb.append(mRewriteFn); + implemented.add(appliedRewrites._1); + } + + indent(1, msb); + msb.append("@Override\n"); + indent(1, msb); + msb.append("public Object apply( Object _hi ) {\n"); + indent(2, msb); + msb.append("if ( _hi == null )\n"); + indent(3, msb); + msb.append("return null;\n\n"); + indent(2, msb); + msb.append("Hop hi = (Hop) _hi;\n\n"); + + if (optimize) { + List> implementedRewrites = rewrites.stream().filter(t -> implemented.contains(t._1)).collect(Collectors.toList()); + + List rules = rewrites.stream().map(t -> t._2).collect(Collectors.toList()); + Map ruleNames = new HashMap<>(); + + for (Tuple2 t : implementedRewrites) + ruleNames.put(t._2, t._1); + + List conditions = CodeGenCondition.buildCondition(rules, maxOptimizationDepth, 5, ctx); + CodeGenCondition.buildSelection(msb, conditions, 2, ruleNames, ctx); + } else { + for (Tuple2 appliedRewrites : rewrites) { + if (implemented.contains(appliedRewrites._1)) { + indent(2, msb); + msb.append("hi = " + appliedRewrites._1 + "((Hop) hi);\t\t// "); + msb.append(appliedRewrites._2.toString()); + msb.append('\n'); + } + } + } + + indent(2, msb); + msb.append("return hi;\n"); + + indent(1, msb); + msb.append("}\n"); + + msb.append(implSb); + + msb.append('\n'); + buildTypeCastFunction(msb, 1); + msb.append('\n'); + buildMinIdxFunction(msb, 1); + msb.append('\n'); + msb.append("}"); + return msb.toString(); + } + + private static String generateRewriteFunction(RewriterRule rule, String fName, int indentation, boolean maintainRewriteStats, final RuleContext ctx) { + try { + Tuple2, Boolean> t = RewriterCostEstimator.determineSingleReferenceRequirement(rule, ctx); + Set mSet = t._1; + if (mSet instanceof AbstractCollection) + mSet = new HashSet<>(mSet); + mSet.add(rule.getStmt1()); + boolean allowCombinedMultiRefs = t._2; + + StringBuilder sb = new StringBuilder(); + + // Append the function signature + indent(indentation, sb); + sb.append("private static Hop " + fName + "(Hop hi) {\n"); + + if (!allowCombinedMultiRefs) { + indent(indentation + 1, sb); + sb.append("boolean _multiReference = false;\n"); + } + + List tos = rule.isConditionalMultiRule() ? rule.getConditionalMultiRuleTargets() : List.of(rule.getStmt2()); + + // Build the function body + buildMatchingSequence(rule.toString(), rule.getStmt1(), tos, rule.getStmt1Cost(), rule.getStmt2Costs(), rule.getCombinedAssertions(), sb, ctx, indentation + 1, mSet, allowCombinedMultiRefs, maintainRewriteStats); + indent(indentation, sb); + + sb.append("}\n"); + + return sb.toString(); + } catch (Exception e) { + e.addSuppressed(new Exception("Failed to generate rewrite rule: " + rule.toString() + "\nAssertions: " + rule.getCombinedAssertions())); + throw e; + } + } + + private static void buildMatchingSequence(String name, RewriterStatement from, List tos, RewriterStatement fromCost, List toCosts, RewriterAssertions combinedAssertions, StringBuilder sb, final RuleContext ctx, int indentation, Set allowedMultiRefs, boolean allowCombinations, boolean maintainRewriteStats) { + Map vars = new HashMap<>(); + vars.put(from, "hi"); + recursivelyBuildMatchingSequence(from, sb, "hi", ctx, indentation, vars, allowedMultiRefs, allowCombinations); + + if (fromCost != null) { + List msb = new ArrayList<>(); + msb.add(new StringBuilder()); + Set> requirements = new HashSet<>(); + + buildCostFnRecursively(fromCost, vars, ctx, msb.get(0), requirements); + + for (RewriterStatement toCost : toCosts) { + StringBuilder msb2 = new StringBuilder(); + buildCostFnRecursively(toCost, vars, ctx, msb2, requirements); + msb.add(msb2); + } + + // First, we build the necessary checks (e.g. if we have nnz / dim information we need, otherwise this rewrite cannot be applied) + if (!requirements.isEmpty()) { + sb.append('\n'); + indent(indentation, sb); + sb.append("if ( "); + + int ctr = 0; + for (Tuple2 req : requirements) { + if (ctr != 0) + sb.append(" || "); + + sb.append(req._1); + switch (req._2) { + case "_nnz": + sb.append(".getNnz() == -1"); + break; + case "nrow": + sb.append(".getDim1() == -1"); + break; + case "ncol": + sb.append(".getDim2() == -1"); + break; + default: + throw new IllegalArgumentException(req._2); + } + + ctr++; + } + + sb.append(" )\n"); + indent(indentation + 1, sb); + sb.append("return hi;\n\n"); + } + + // Then we build the cost functions + sb.append('\n'); + indent(indentation, sb); + sb.append("double[] costs = new double["); + sb.append(msb.size()); + sb.append("];\n"); + + for (int i = 0; i < msb.size(); i++) { + indent(indentation, sb); + sb.append("costs["); + sb.append(i); + sb.append("] = "); + sb.append(msb.get(i)); + sb.append(";\n"); + } + + indent(indentation, sb); + sb.append("int minIdx = minIdx(costs);\n\n"); + indent(indentation, sb); + sb.append("switch( minIdx ) {\n"); + + for (int i = 1; i < msb.size(); i++) { + indent(indentation+1, sb); + sb.append("case " + i + ": {"); + buildNewHop(name, from, tos.get(i-1), sb, combinedAssertions, new HashMap<>(vars), ctx, indentation+2, maintainRewriteStats); + indent(indentation+1, sb); + sb.append("}\n"); + } + + indent(indentation, sb); + sb.append("}\n"); + + indent(indentation, sb); + sb.append("return hi;\n"); + } else { + buildNewHop(name, from, tos.get(0), sb, combinedAssertions, vars, ctx, indentation, maintainRewriteStats); + } + } + + private static void buildNewHop(String rewriteName, RewriterStatement from, RewriterStatement to, StringBuilder sb, RewriterAssertions combinedAssertions, Map vars, final RuleContext ctx, int indentation, boolean maintainRewriteStats) { + sb.append('\n'); + indent(indentation, sb); + sb.append("// Now, we start building the new HOP-DAG: "); + sb.append(to.toParsableString(ctx)); + sb.append('\n'); + + Set activeStatements = buildRewrite(to, sb, combinedAssertions, vars, ctx, indentation); + + String newRoot = vars.get(to); + + sb.append('\n'); + indent(indentation, sb); + sb.append("Hop newRoot = " + newRoot + ";\n"); + indent(indentation, sb); + sb.append("if ( " + newRoot + ".getValueType() != hi.getValueType() ) {\n"); + indent(indentation + 1, sb); + sb.append("newRoot = castIfNecessary(newRoot, hi);\n"); + indent(indentation + 1, sb); + sb.append("if ( newRoot == null )\n"); + indent(indentation + 2, sb); + sb.append("return hi;\n"); + indent(indentation, sb); + sb.append("}\n"); + + + sb.append('\n'); + indent(indentation, sb); + sb.append("ArrayList parents = new ArrayList<>(hi.getParent());\n\n"); + indent(indentation, sb); + sb.append("for ( Hop p : parents )\n"); + indent(indentation + 1, sb); + sb.append("HopRewriteUtils.replaceChildReference(p, hi, newRoot);\n\n"); + + indent(indentation, sb); + sb.append("// Remove old unreferenced Hops\n"); + removeUnreferencedHops(from, activeStatements, sb, vars, ctx, indentation); + sb.append('\n'); + + if (DEBUG) { + indent(indentation, sb); + sb.append("DMLExecutor.println(\"Applying rewrite: " + rewriteName + "\");\n"); + } + + if (maintainRewriteStats) { + indent(indentation, sb); + sb.append("Statistics.applyGeneratedRewrite(\"" + rewriteName + "\");\n"); + } + + indent(indentation, sb); + sb.append("return newRoot;\n"); + } + + private static void buildTypeCastFunction(StringBuilder sb, int indentation) { + String str = "private static Hop castIfNecessary(Hop newRoot, Hop oldRoot) {\n" + + "\tTypes.OpOp1 cast = null;\n" + + "\tswitch ( oldRoot.getValueType().toExternalString() ) {\n" + + "\t\tcase \"DOUBLE\":\n" + + "\t\t\tcast = Types.OpOp1.CAST_AS_DOUBLE;\n" + + "\t\t\tbreak;\n" + + "\t\tcase \"INT\":\n" + + "\t\t\tcast = Types.OpOp1.CAST_AS_INT;\n" + + "\t\t\tbreak;\n" + + "\t\tcase \"BOOLEAN\":\n" + + "\t\t\tcast = Types.OpOp1.CAST_AS_BOOLEAN;\n" + + "\t\t\tbreak;\n" + + "\t\tdefault:\n" + + "\t\t\treturn null;\n" + + "\t}\n" + + "\n" + + "\treturn new UnaryOp(\"tmp\", oldRoot.getDataType(), oldRoot.getValueType(), cast, newRoot);\n" + + "}\n"; + + sb.append(indentMultilineString(str, indentation)); + } + + private static void buildMinIdxFunction(StringBuilder sb, int indentation) { + String str = "private static int minIdx(double[] l) {\n" + + "\tdouble minValue = Double.MAX_VALUE;\n" + + "\tint minIdx = -1;\n" + + "\n" + + "\tfor (int i = 0; i < l.length; i++) {\n" + + "\t\tif (l[i] < minValue) {\n" + + "\t\t\tminValue = l[i];\n" + + "\t\t\tminIdx = i;\n" + + "\t\t}\n" + + "\t}\n" + + "\n" + + "\treturn minIdx;\n" + + "}\n"; + + sb.append(indentMultilineString(str, indentation)); + } + + private static String indentMultilineString(String str, int indentation) { + String tabs = "\t".repeat(indentation); + return str.lines() // Split the string into lines + .map(line -> tabs + line) // Add tabs to the beginning of each line + .collect(Collectors.joining("\n")); // Join the lines back together + } + + private static void buildCostFnRecursively(RewriterStatement costFn, Map vars, final RuleContext ctx, StringBuilder sb, Set> requirements) { + if (costFn.isLiteral()) { + sb.append(costFn.floatLiteral()); + return; + } + + if (!costFn.isInstruction()) + throw new IllegalArgumentException(); + + List operands; + + if (!costFn.getOperands().isEmpty() && costFn.getChild(0).isArgumentList()) + operands = costFn.getChild(0).getOperands(); + else + operands = costFn.getOperands(); + + String varName; + + // Then, the cost function is an instruction + switch (costFn.trueInstruction()) { + case "_nnz": + varName = vars.get(costFn.getChild(0)); + + if (varName == null) + throw new IllegalArgumentException(costFn.toParsableString(ctx)); + + requirements.add(new Tuple2<>(varName, "_nnz")); + sb.append(varName); + sb.append(".getNnz()"); + break; + + case "nrow": + varName = vars.get(costFn.getChild(0)); + + if (varName == null) + throw new IllegalArgumentException(); + + requirements.add(new Tuple2<>(varName, "nrow")); + sb.append(varName); + sb.append(".getDim1()"); + break; + + case "ncol": + varName = vars.get(costFn.getChild(0)); + + if (varName == null) + throw new IllegalArgumentException(); + + requirements.add(new Tuple2<>(varName, "ncol")); + sb.append(varName); + sb.append(".getDim2()"); + break; + + case "+": + case "*": + sb.append('('); + + for (int i = 0; i < operands.size(); i++) { + if (i != 0) { + sb.append(' '); + sb.append(costFn.trueInstruction()); + sb.append(' '); + } + + buildCostFnRecursively(operands.get(i), vars, ctx, sb, requirements); + } + + sb.append(')'); + break; + case "inv": + sb.append("(1.0 / "); + buildCostFnRecursively(operands.get(0), vars, ctx, sb, requirements); + sb.append(')'); + break; + case "min": + case "max": + sb.append("Math."); + sb.append(costFn.trueInstruction()); + sb.append('('); + for (int i = 0; i < operands.size(); i++) { + if (i != 0) + sb.append(", "); + + buildCostFnRecursively(operands.get(i), vars, ctx, sb, requirements); + } + sb.append(')'); + break; + case "_EClass": + // Here, we can just select a random representant + // Ideally, we would choose one that has dimensions available, but for now, we just take the first + buildCostFnRecursively(operands.get(0), vars, ctx, sb, requirements); + break; + default: + throw new IllegalArgumentException(costFn.trueInstruction()); + } + } + + // Returns the set of all active statements after the rewrite + private static Set buildRewrite(RewriterStatement newRoot, StringBuilder sb, RewriterAssertions assertions, Map vars, final RuleContext ctx, int indentation) { + Set visited = new HashSet<>(); + recursivelyBuildNewHop(sb, newRoot, assertions, vars, ctx, indentation, 1, visited, newRoot.getResultingDataType(ctx).equals("FLOAT")); + + return visited; + } + + private static void removeUnreferencedHops(RewriterStatement oldRoot, Set activeStatements, StringBuilder sb, Map vars, final RuleContext ctx, int indentation) { + oldRoot.forEachPreOrder(cur -> { + if (activeStatements.contains(cur)) + return true; + + indent(indentation, sb); + sb.append("HopRewriteUtils.cleanupUnreferenced(" + vars.get(cur) + ");\n"); + return true; + }, false); + } + + private static int recursivelyBuildNewHop(StringBuilder sb, RewriterStatement cur, RewriterAssertions assertions, Map vars, final RuleContext ctx, int indentation, int varCtr, Set visited, boolean enforceRootDataType) { + visited.add(cur); + if (vars.containsKey(cur)) + return varCtr; + + for (RewriterStatement child : cur.getOperands()) + varCtr = recursivelyBuildNewHop(sb, child, assertions, vars, ctx, indentation, varCtr, visited, false); + + if (cur instanceof RewriterDataType) { + if (cur.isLiteral()) { + indent(indentation, sb); + String name = "l" + (varCtr++); + String literalStr = cur.getLiteral().toString(); + + if (enforceRootDataType) { + sb.append("LiteralOp " + name + ";\n"); + indent(indentation, sb); + sb.append("switch (hi.getValueType()) {\n"); + indent(indentation+1, sb); + sb.append("case FP64:\n"); + indent(indentation+2, sb); + sb.append(name); + sb.append(" = new LiteralOp( " + cur.floatLiteral() + " );\n"); + indent(indentation+2, sb); + sb.append("break;\n"); + indent(indentation+1, sb); + sb.append("case INT64:\n"); + indent(indentation+2, sb); + sb.append(name); + sb.append(" = new LiteralOp( " + cur.intLiteral(true) + " );\n"); + indent(indentation+2, sb); + sb.append("break;\n"); + indent(indentation+1, sb); + sb.append("case BOOLEAN:\n"); + indent(indentation+2, sb); + sb.append(name); + sb.append(" = new LiteralOp( " + cur.boolLiteral() + " );\n"); + indent(indentation+2, sb); + sb.append("break;\n"); + indent(indentation+1, sb); + sb.append("default:\n"); + indent(indentation+2, sb); + sb.append("return hi;\n"); + indent(indentation+1, sb); + sb.append("}\n"); + } else { + sb.append("LiteralOp " + name + " = new LiteralOp( " + literalStr + " );\n"); + } + vars.put(cur, name); + } + + return varCtr; + } else { + String opClass = CodeGenUtils.getOpClass(cur, ctx); + String[] operandRefs = cur.getOperands().stream().map(vars::get).toArray(String[]::new); + + if (CodeGenUtils.opRequiresBinaryBroadcastingMatch(cur, ctx)) { + // Then we need to validate that broadcasting still works after rearranging + indent(indentation, sb); + sb.append("if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(" + operandRefs[0] + ", " + operandRefs[1] + ") )\n"); + indent(indentation+1, sb); + sb.append("return hi;\n"); + } else { + List matchingDims = CodeGenUtils.matchingDimRequirement(cur, ctx); + + if (!matchingDims.isEmpty()) { + // Then we need to validate that broadcasting still works after rearranging + sb.append("if ( !RewriterRuntimeUtils.hasMatchingDims(" + matchingDims.stream().map(idx -> operandRefs[idx]).collect(Collectors.joining(", ")) + ") )\n"); + indent(indentation+1, sb); + sb.append("return hi;\n"); + } + } + + String constructor = CodeGenUtils.getHopConstructor(cur, assertions, vars, ctx, operandRefs); + String name = "v" + (varCtr++); + indent(indentation, sb); + sb.append(opClass + " " + name + " = " + constructor + ";\n"); + + vars.put(cur, name); + } + + return varCtr; + } + + private static void recursivelyBuildMatchingSequence(RewriterStatement cur, StringBuilder sb, String curVar, final RuleContext ctx, int indentation, Map map, Set allowedMultiRefs, boolean allowCombinations) { + if (cur.isLiteral()) { + String[] types = CodeGenUtils.getReturnType(cur, ctx); + indent(indentation, sb); + sb.append("if ( !(" + curVar + " instanceof LiteralOp) )\n"); + indent(indentation + 1, sb); + sb.append("return hi;\n\n"); + indent(indentation, sb); + String lVar = "l_" + curVar; + sb.append("LiteralOp " + lVar + " = (LiteralOp) " + curVar + ";\n\n"); + indent(indentation, sb); + sb.append("if ( " + lVar + ".getDataType() != " + types[0]); + sb.append("|| !" + lVar + ".getValueType().isNumeric()"); + sb.append(" )\n"); + + indent(indentation + 1, sb); + sb.append("return hi;\n\n"); + + indent(indentation, sb); + sb.append("if ( " + lVar + "." + CodeGenUtils.literalGetterFunction(cur, ctx) + " != " + cur.getLiteral() + " )\n"); + indent(indentation + 1, sb); + sb.append("return hi;\n\n"); + + return; + } + + // Check if we have to ensure a single reference to this object + if (cur.isInstruction() && !allowedMultiRefs.contains(cur)) { + if (allowCombinations && !allowedMultiRefs.contains(cur)) { + indent(indentation, sb); + sb.append("if ("); + sb.append(curVar); + sb.append(".getParent().size() > 1)\n"); + indent(indentation + 1, sb); + sb.append("return hi;\n"); + } else if (!allowedMultiRefs.contains(cur)) { + indent(indentation, sb); + sb.append("if ("); + sb.append(curVar); + sb.append(".getParent().size() > 1) {\n"); + indent(indentation + 1, sb); + sb.append("if (_multiReference)\n"); + indent(indentation + 2, sb); + sb.append("return hi;\n"); + indent(indentation + 1, sb); + sb.append("else\n"); + indent(indentation + 2, sb); + sb.append("_multiReference = true;\n"); + indent(indentation + 1, sb); + sb.append("}\n"); + } + } + + String specialOpCheck = CodeGenUtils.getSpecialOpCheck(cur, ctx, curVar); + + // E.g. A %*% B, which is an AggBinaryOp consisting of multiple OpCodes + if (specialOpCheck != null) { + indent(indentation, sb); + sb.append("if ( !" + specialOpCheck + " )\n"); + indent(indentation + 1, sb); + sb.append("return hi;\n\n"); + } else if (!cur.isDataOrigin()) { + String opClass = CodeGenUtils.getOpClass(cur, ctx); + + // Generate initial class check + indent(indentation, sb); + sb.append("if ( !(" + curVar + " instanceof " + opClass + ") )\n"); + indent(indentation + 1, sb); + sb.append("return hi;\n\n"); + + // Cast the expression to the corresponding op-class + String cCurVar = "c_" + curVar; + indent(indentation, sb); + sb.append(opClass + " " + cCurVar + " = (" + opClass + ") " + curVar + ";\n\n"); + + String opCode = CodeGenUtils.getOpCode(cur, ctx); + + // Check if the instruction matches + indent(indentation, sb); + if (opCode != null) { + sb.append("if ( " + cCurVar + ".getOp() != " + opCode); + sb.append(" || !" + cCurVar + ".getValueType().isNumeric()"); + } else { + sb.append("if ( !" + cCurVar + ".getValueType().isNumeric()"); + } + + sb.append(" )\n"); + indent(indentation + 1, sb); + sb.append("return hi;\n\n"); + + String additionalCheck = CodeGenUtils.getAdditionalCheck(cur, ctx, cCurVar); + + if (additionalCheck != null) { + indent(indentation, sb); + sb.append("if ( !(" + additionalCheck + ") )\n"); + indent(indentation + 1, sb); + sb.append("return hi;\n\n"); + } + } else { + indent(indentation, sb); + String[] types = CodeGenUtils.getReturnType(cur, ctx); + sb.append("if ( " + curVar + ".getDataType() != " + types[0]); + sb.append(" || !" + curVar + ".getValueType().isNumeric()"); + + if (cur.isRowVector()) { + sb.append(" || " + curVar + ".getDim2() != 1L"); + } else if (cur.isColVector()) { + sb.append(" || " + curVar + ".getDim1() != 1L"); + } + + sb.append(" )\n"); + indent(indentation + 1, sb); + sb.append("return hi;\n\n"); + + String additionalCheck = CodeGenUtils.getAdditionalCheck(cur, ctx, curVar); + + if (additionalCheck != null) { + indent(indentation, sb); + sb.append("if ( !(" + additionalCheck + ") )\n"); + indent(indentation + 1, sb); + sb.append("return hi;\n\n"); + } + } + + // Now, we match the children + for (int i = 0; i < cur.getOperands().size(); i++) { + RewriterStatement stmt = cur.getChild(i); + + String existingVar = map.get(stmt); + + if (existingVar != null) { + String name = resolveOperand(cur, i, sb, curVar, ctx, indentation); + sb.append('\n'); + // Just check if they are identical + indent(indentation, sb); + sb.append("if ( " + existingVar + " != " + name + " )\n"); + indent(indentation + 1, sb); + sb.append("return hi;\n\n"); + continue; + } + + // Build the variable definition + String name = resolveOperand(cur, i, sb, curVar, ctx, indentation); + map.put(stmt, name); + sb.append('\n'); + recursivelyBuildMatchingSequence(stmt, sb, name, ctx, indentation, map, allowedMultiRefs, allowCombinations); + } + } + + private static String resolveOperand(RewriterStatement stmt, int idx, StringBuilder sb, String curVar, final RuleContext ctx, int indentation) { + String name = curVar + "_" + idx; + indent(indentation, sb); + sb.append("Hop " + name + " = " + curVar + ".getInput(" + idx + ");\n"); + return name; + } + + public static void indent(int depth, StringBuilder sb) { + sb.append("\t".repeat(depth)); + } +} diff --git a/src/main/java/org/apache/sysds/hops/rewriter/dml/DMLCodeGenerator.java b/src/main/java/org/apache/sysds/hops/rewriter/dml/DMLCodeGenerator.java new file mode 100644 index 00000000000..9baba0c3fa1 --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewriter/dml/DMLCodeGenerator.java @@ -0,0 +1,383 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.rewriter.dml; + +import org.apache.commons.lang3.NotImplementedException; +import org.apache.commons.lang3.function.TriFunction; +import org.apache.commons.lang3.mutable.MutableInt; +import org.apache.sysds.hops.rewriter.RewriterInstruction; +import org.apache.sysds.hops.rewriter.rule.RewriterRule; +import org.apache.sysds.hops.rewriter.RewriterStatement; +import org.apache.sysds.hops.rewriter.utils.RewriterUtils; +import org.apache.sysds.hops.rewriter.RuleContext; +import scala.Tuple2; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.Set; +import java.util.function.Consumer; +import java.util.stream.Collectors; + +public class DMLCodeGenerator { + public static final long MATRIX_DIMS = 100; + public static final double EPS = 1e-10; + public static Random rd = new Random(42); + + + private static final HashSet printAsBinary = new HashSet<>(); + private static final HashMap, Boolean>> customEncoders = new HashMap<>(); + private static final RuleContext ctx = RewriterUtils.buildDefaultContext(); + + static { + printAsBinary.add("+"); + printAsBinary.add("-"); + printAsBinary.add("*"); + printAsBinary.add("/"); + printAsBinary.add("^"); + printAsBinary.add("&"); + printAsBinary.add("|"); + printAsBinary.add("=="); + printAsBinary.add("!="); + printAsBinary.add(">"); + printAsBinary.add(">="); + printAsBinary.add("<"); + printAsBinary.add("<="); + printAsBinary.add("%*%"); + + customEncoders.put("[]", (stmt, sb, tmpVars) -> { + if (stmt.getOperands().size() == 3) { + appendExpression(stmt.getChild(0), sb, tmpVars); + sb.append('['); + appendExpression(stmt.getChild(1), sb, tmpVars); + sb.append(", "); + appendExpression(stmt.getChild(2), sb, tmpVars); + sb.append(']'); + return true; + } else if (stmt.getOperands().size() == 5) { + appendExpression(stmt.getChild(0), sb, tmpVars); + sb.append('['); + appendExpression(stmt.getChild(1), sb, tmpVars); + sb.append(" : "); + appendExpression(stmt.getChild(2), sb, tmpVars); + sb.append(", "); + appendExpression(stmt.getChild(3), sb, tmpVars); + sb.append(" : "); + appendExpression(stmt.getChild(4), sb, tmpVars); + sb.append(']'); + return true; + } + + return false; + }); + + customEncoders.put("const", (stmt, sb, tmpVars) -> { + sb.append("matrix("); + appendExpression(stmt.getChild(1), sb, tmpVars); + sb.append(", rows="); + sb.append(MATRIX_DIMS); + sb.append(", cols="); + sb.append(MATRIX_DIMS); + sb.append(')'); + + return true; + }); + + customEncoders.put("cast.MATRIX", (stmt, sb, tmpVars) -> { + sb.append("as.matrix("); + appendExpression(stmt.getChild(0), sb, tmpVars); + sb.append(')'); + + return true; + }); + + customEncoders.put("cast.FLOAT", (stmt, sb, tmpVars) -> { + if (stmt.getChild(0).getResultingDataType(ctx).equals("MATRIX")) { + sb.append("as.scalar("); + appendExpression(stmt.getChild(0), sb, tmpVars); + sb.append(')'); + } else { + sb.append("as.double("); + appendExpression(stmt.getChild(0), sb, tmpVars); + sb.append(')'); + } + + return true; + }); + } + + public static Consumer ruleValidationScript(String ruleName, String sessionId, Consumer validator) { + return line -> { + if (!line.startsWith(sessionId)) + return; + + if (line.endsWith("valid: TRUE")) { + validator.accept(true); + } else { + validator.accept(false); + } + }; + } + + public static String generateRuleValidationDML(RewriterRule rule, String sessionId, final RuleContext ctx) { + return generateRuleValidationDML(rule, EPS, sessionId, ctx); + } + + public static String generateRuleValidationDML(RewriterRule rule, double eps, String sessionId, final RuleContext ctx) { + RewriterStatement stmtFrom = RewriterUtils.unfuseOperators(rule.getStmt1(), ctx); + RewriterStatement stmtTo = RewriterUtils.unfuseOperators(rule.getStmt2(), ctx); + + Set vars = new HashSet<>(); + List> orderedTmpVars = new ArrayList<>(); + Map tmpVars = new HashMap<>(); + MutableInt tmpVarCtr = new MutableInt(0); + + stmtFrom.forEachPostOrder((stmt, pred) -> { + if (stmt.isDataOrigin() && !stmt.isLiteral()) + vars.add(stmt); + else + createTmpVars(stmt, orderedTmpVars, tmpVars, tmpVarCtr); + }, false); + + stmtTo.forEachPostOrder((stmt, pred) -> { + if (stmt.isDataOrigin() && !stmt.isLiteral()) + vars.add(stmt); + else + createTmpVars(stmt, orderedTmpVars, tmpVars, tmpVarCtr); + }, false); + + Set toRemove = vars.stream().filter(t -> t.isInstruction() && !t.trueInstruction().equals("const")).map(instr -> instr.getChild(0)).collect(Collectors.toSet()); + vars.removeAll(toRemove); + + StringBuilder sb = new StringBuilder(); + + sb.append(generateDMLVariables(vars)); + + Map incrementingTmpVars = new HashMap<>(); + + for (Tuple2 t : orderedTmpVars) { + sb.append(t._2); + sb.append(" = "); + sb.append(generateDML(t._1, incrementingTmpVars)); + sb.append('\n'); + incrementingTmpVars.put(t._1, t._2); + } + + sb.append('\n'); + sb.append("R1 = "); + sb.append(generateDML(stmtFrom, tmpVars)); + sb.append('\n'); + sb.append("R2 = "); + sb.append(generateDML(stmtTo, tmpVars)); + sb.append('\n'); + sb.append("print(\""); + sb.append(sessionId); + sb.append(" valid: \" + ("); + sb.append(generateEqualityCheck("R1", "R2", stmtFrom.getResultingDataType(ctx), eps)); + sb.append("))"); + + return sb.toString(); + } + + private static boolean createTmpVars(RewriterStatement stmt, List> orderedTmpVars, Map tmpVars, MutableInt tmpVarCtr) { + if (stmt.isInstruction() && stmt.trueInstruction().equals("[]")) { + // Then we need to put the child into a variable + RewriterStatement child = stmt.getChild(0); + if (child.isInstruction() || child.isLiteral()) { + String tmpVar = "tmp" + tmpVarCtr.getAndIncrement(); + tmpVars.put(child, tmpVar); + orderedTmpVars.add(new Tuple2<>(child, tmpVar)); + return true; + } + } + + return false; + } + + public static Set getVariables(RewriterStatement root) { + Set vars = new HashSet<>(); + root.forEachPostOrder((stmt, pred) -> { + if (stmt.isDataOrigin() && !stmt.isLiteral()) + vars.add(stmt); + }, false); + + Set toRemove = vars.stream().filter(stmt -> stmt.isInstruction() && !stmt.trueInstruction().equals("const")).map(instr -> instr.getChild(0)).collect(Collectors.toSet()); + vars.removeAll(toRemove); + + return vars; + } + + public static String generateDMLVariables(RewriterStatement root) { + return generateDMLVariables(getVariables(root)); + } + + public static String generateDMLVariables(Set vars) { + StringBuilder sb = new StringBuilder(); + + for (RewriterStatement var : vars) { + + switch (var.getResultingDataType(ctx)) { + case "MATRIX": + String mId = var.getId(); + long nrow = MATRIX_DIMS; + long ncol = MATRIX_DIMS; + if (var.isInstruction()) { + if (var.trueInstruction().equals("rowVec")) { + mId = var.getChild(0).getId(); + nrow = 1L; + } else if (var.trueInstruction().equals("colVec")) { + mId = var.getChild(0).getId(); + ncol = 1L; + } else if (var.trueInstruction().equals("const")) { + sb.append(var.getId()); + sb.append(" = matrix(" + var.getChild(1).getLiteral() + ", rows=" + nrow + ", cols=" + ncol + ")\n"); + continue; + } + } + sb.append(mId + " = cos((rand(rows=" + nrow + ", cols=" + ncol + ") * rand(rows=" + nrow + ", cols=" + ncol + ", min=(as.scalar(rand())+1.0), max=(as.scalar(rand())+2.0), seed=" + rd.nextInt(1000) + "))^as.scalar(rand()))\n"); + break; + case "FLOAT": + sb.append(var.getId() + " = cos(as.scalar(rand(min=(as.scalar(rand())+1.0), max=(as.scalar(rand())+2.0), seed=" + rd.nextInt(1000) + "))^as.scalar(rand()))\n"); + break; + case "INT": + sb.append(var.getId() + " = as.integer(cos(as.scalar(rand(min=(as.scalar(rand())+1.0), max=(as.scalar(rand()+200000.0)), seed=" + rd.nextInt(1000) + "))^as.scalar(rand())))\n"); + break; + case "BOOL": + sb.append(var.getId() + " = as.scalar(rand()) < 0.5\n"); + break; + default: + throw new NotImplementedException(var.getResultingDataType(ctx)); + } + } + + return sb.toString(); + } + + public static String generateEqualityCheck(String stmt1Var, String stmt2Var, String dataType, double eps) { + switch (dataType) { + case "MATRIX": + return "sum(abs(" + stmt1Var + " - " + stmt2Var + ") < " + eps + ") == length(" + stmt1Var + ")"; + case "INT": + case "BOOL": + return stmt1Var + " == " + stmt2Var; + case "FLOAT": + return "abs(" + stmt1Var + " - " + stmt2Var + ") < " + eps; + } + + throw new NotImplementedException(); + } + + public static String generateDMLDefs(RewriterStatement stmt) { + Map vars = new HashMap<>(); + + stmt.forEachPostOrder((cur, pred) -> { + if (!cur.isInstruction() && !cur.isLiteral()) + vars.put(cur.getId(), cur); + }, false); + + return generateDMLDefs(vars); + } + + public static String generateDMLDefs(Map defs) { + StringBuilder sb = new StringBuilder(); + + defs.forEach((k, v) -> { + sb.append(k); + sb.append(" = "); + sb.append(generateDML(v)); + sb.append('\n'); + }); + + return sb.toString(); + } + + public static String generateDML(RewriterStatement root) { + return generateDML(root, Collections.emptyMap()); + } + + public static String generateDML(RewriterStatement root, Map tmpVars) { + StringBuilder sb = new StringBuilder(); + appendExpression(root, sb, tmpVars); + + return sb.toString(); + } + + private static void appendExpression(RewriterStatement cur, StringBuilder sb, Map tmpVars) { + String tmpVar = tmpVars.get(cur); + + if (tmpVar != null) { + sb.append(tmpVar); + return; + } + + if (cur.isInstruction()) { + if (cur.isDataOrigin()) + sb.append(cur.getId()); + else + resolveExpression((RewriterInstruction) cur, sb, tmpVars); + } else { + if (cur.isLiteral()) + sb.append(cur.getLiteral()); + else + sb.append(cur.getId()); + } + } + + private static void resolveExpression(RewriterInstruction expr, StringBuilder sb, Map tmpVars) { + String typedInstr = expr.trueTypedInstruction(ctx); + String unTypedInstr = expr.trueInstruction(); + + if (expr.getOperands().size() == 2 && (printAsBinary.contains(typedInstr) || printAsBinary.contains(unTypedInstr))) { + sb.append('('); + appendExpression(expr.getChild(0), sb, tmpVars); + sb.append(") "); + sb.append(unTypedInstr); + sb.append(" ("); + appendExpression(expr.getChild(1), sb, tmpVars); + sb.append(')'); + return; + } + + TriFunction, Boolean> customEncoder = customEncoders.get(typedInstr); + + if (customEncoder == null) + customEncoder = customEncoders.get(unTypedInstr); + + if (customEncoder == null) { + sb.append(unTypedInstr); + sb.append('('); + + for (int i = 0; i < expr.getOperands().size(); i++) { + if (i != 0) + sb.append(", "); + + appendExpression(expr.getChild(i), sb, tmpVars); + } + + sb.append(')'); + } else { + customEncoder.apply(expr, sb, tmpVars); + } + } +} diff --git a/src/main/java/org/apache/sysds/hops/rewriter/dml/DMLExecutor.java b/src/main/java/org/apache/sysds/hops/rewriter/dml/DMLExecutor.java new file mode 100644 index 00000000000..0b07a84a7e9 --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewriter/dml/DMLExecutor.java @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.rewriter.dml; + +import org.apache.commons.lang3.mutable.MutableBoolean; +import org.apache.sysds.api.DMLScript; +import org.apache.sysds.hops.Hop; +import org.apache.sysds.hops.OptimizerUtils; + +import java.io.OutputStream; +import java.io.PrintStream; +import java.util.ArrayList; +import java.util.List; +import java.util.function.Consumer; +import java.util.function.Function; + +public class DMLExecutor { + private static PrintStream origPrintStream = System.out; + private static PrintStream origErrPrintStream = System.out; + + public static boolean APPLY_INJECTED_REWRITES = false; + public static Function REWRITE_FUNCTION = null; + + private static List lastErr; + + public static void executeCode(String code, boolean intercept, String... additionalArgs) { + executeCode(code, intercept ? s -> {} : null, additionalArgs); + } + + // Returns if true if the run was successful without any errors + public static boolean executeCode(String code, Consumer consoleInterceptor, String... additionalArgs) { + return executeCode(code, consoleInterceptor, null, additionalArgs); + } + + // This cannot run in parallel + public static synchronized boolean executeCode(String code, Consumer consoleInterceptor, Function injectedRewriteClass, String... additionalArgs) { + lastErr = new ArrayList<>(); + boolean exceptionOccurred = false; + + try { + if (consoleInterceptor != null) + System.setOut(new PrintStream(new CustomOutputStream(System.out, consoleInterceptor))); + + System.setErr(new PrintStream(new CustomOutputStream(System.err, lastErr::add))); + + String[] args = new String[additionalArgs.length + 2]; + + for (int i = 0; i < additionalArgs.length; i++) + args[i] = additionalArgs[i]; + + args[additionalArgs.length] = "-s"; + args[additionalArgs.length + 1] = code; + + if (injectedRewriteClass != null) { + APPLY_INJECTED_REWRITES = true; + REWRITE_FUNCTION = injectedRewriteClass; + } + + // To allow the discovery of sum((a*A)*B) which would usually be converted to n* + OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = false; + OptimizerUtils.ALLOW_OPERATOR_FUSION = false; + + DMLScript.executeScript(args); + + } catch (Exception e) { + e.printStackTrace(); + exceptionOccurred = true; + } + + APPLY_INJECTED_REWRITES = false; + REWRITE_FUNCTION = null; + + if (consoleInterceptor != null) + System.setOut(origPrintStream); + + System.setErr(origErrPrintStream); + + return !exceptionOccurred && lastErr.isEmpty(); + } + + public static List getLastErr() { + return lastErr; + } + + // Bypasses the interceptor + public static void println(Object o) { + origPrintStream.println(o); + } + + private static class CustomOutputStream extends OutputStream { + private PrintStream ps; + private StringBuilder buffer = new StringBuilder(); + private Consumer lineHandler; + + public CustomOutputStream(PrintStream actualPrintStream, Consumer lineHandler) { + this.ps = actualPrintStream; + this.lineHandler = lineHandler; + } + + @Override + public void write(int b) { + char c = (char) b; + if (c == '\n') { + lineHandler.accept(buffer.toString()); + buffer.setLength(0); // Clear the buffer after handling the line + } else { + buffer.append(c); // Accumulate characters until newline + } + } + + @Override + public void write(byte[] b, int off, int len) { + for (int i = off; i < off + len; i++) { + write(b[i]); + } + } + } +} diff --git a/src/main/java/org/apache/sysds/hops/rewriter/estimators/RewriterCostEstimator.java b/src/main/java/org/apache/sysds/hops/rewriter/estimators/RewriterCostEstimator.java new file mode 100644 index 00000000000..658a1114214 --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewriter/estimators/RewriterCostEstimator.java @@ -0,0 +1,947 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.rewriter.estimators; + +import org.apache.commons.lang3.mutable.MutableInt; +import org.apache.commons.lang3.mutable.MutableLong; +import org.apache.commons.lang3.mutable.MutableObject; +import org.apache.sysds.hops.rewriter.RewriterInstruction; +import org.apache.sysds.hops.rewriter.rule.RewriterRule; +import org.apache.sysds.hops.rewriter.RewriterStatement; +import org.apache.sysds.hops.rewriter.utils.RewriterUtils; +import org.apache.sysds.hops.rewriter.RuleContext; +import org.apache.sysds.hops.rewriter.assertions.RewriterAssertionUtils; +import org.apache.sysds.hops.rewriter.assertions.RewriterAssertions; +import org.apache.sysds.hops.rewriter.utils.StatementUtils; +import scala.Tuple2; +import scala.Tuple3; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Random; +import java.util.Set; +import java.util.UUID; +import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.stream.Collectors; + +public class RewriterCostEstimator { + private static final long INSTRUCTION_OVERHEAD = 10; + private static final long MALLOC_COST = 10000; + public static final Function DEFAULT_COST_FN = el -> 2000L; + public static final BiFunction, Long> DEFAULT_NNZ_FN = (el, tpl) -> tpl._1 * tpl._2; + + // This is an important check as many intermediate matrices do not contain any sparsity information + // Thus, we want to use cost functions without sparsity information if possible + public static boolean doesHaveAnImpactOnOptimalExpression(List, Long, Long>> list, boolean sparsity, boolean sort, int costThreshhold) { + if (sort) + sort(list); + + int diff = 0; + Tuple3, Long, Long> last = null; + + for (Tuple3, Long, Long> t : list) { + if (Math.abs(t._2() - t._3()) < costThreshhold) + continue; + + if (last == null || (sparsity && !hasSameDims(last._1(), t._1()))) { + last = t; + diff = Long.signum(t._2() - t._3()); + continue; + } + + int mDiff = Long.signum(t._2() - t._3()); + + if (diff != mDiff && Math.abs(t._2() - t._3() - last._2() + last._3()) > costThreshhold) + return true; + } + + return false; + } + + private static boolean hasSameDims(List l1, List l2) { + int maxN = Math.min(l1.size(), l2.size()); + + for (int i = 0; i < maxN; i++) { + Number el1 = l1.get(i); + Number el2 = l2.get(i); + + if (el1 instanceof Long && el1.longValue() != el2.longValue()) + return false; + } + + return true; + } + + private static void sort(List, Long, Long>> list) { + list.sort((t1, t2) -> { + int size = Math.min(t1._1().size(), t2._1().size()); + for (int i = 0; i < size; i++) { + int cmp = Double.compare(t1._1().get(i).doubleValue(), t2._1().get(i).doubleValue()); + if (cmp != 0) + return cmp; // Return non-zero comparison result if elements differ + } + + return Integer.compare(t1._1().size(), t2._1().size()); + }); + } + + public static Set> findOptima(List, List>> data) { + Set> outSet = new HashSet<>(); + data.stream().forEach(t -> { + int minIdx = -1; + long minValue = Long.MAX_VALUE; + for (int i = 0; i < t._2.size(); i++) { + if (t._2.get(i) < minValue) { + minValue = t._2.get(i); + minIdx = i; + } + } + + for (int i = 0; i < t._2.size(); i++) { + if (t._2.get(i) > minValue) + outSet.add(new Tuple2<>(i, minIdx)); + } + }); + + return outSet; + } + + public static List, List>> compareCosts(List statements, RewriterAssertions jointAssertions, final RuleContext ctx, boolean sample, int sampleSize) { + List> estimates = statements.stream().map(stmt -> RewriterSparsityEstimator.estimateAllNNZ(stmt, ctx)).collect(Collectors.toList()); + + MutableObject assertionRef = new MutableObject<>(jointAssertions); + List costFns = statements.stream().map(stmt -> getRawCostFunction(stmt, ctx, assertionRef, false)).collect(Collectors.toList()); + + for (int i = 0; i < estimates.size(); i++) { + costFns.set(i, RewriterSparsityEstimator.rollupSparsities(costFns.get(i), estimates.get(i), ctx)); + } + + long[] dimVals = new long[] {10, 5000}; + double[] sparsities = new double[] {1.0D, 0.000001D}; + + Map createdObjects = new HashMap<>(); + List costFnCpys = costFns.stream().map(fn -> fn.nestedCopy(false, createdObjects)).collect(Collectors.toList()); + RewriterAssertions jointAssertionsCpy = RewriterAssertions.copy(jointAssertions, createdObjects, false); + + Set dimsToPopulate = new HashSet<>(); + Set nnzsToPopulate = new HashSet<>(); + + List costs = costFnCpys.stream().map(costFnCpy -> { + try { + return computeCostFunction(costFnCpy, el -> { + dimsToPopulate.add(el); + return 2000L; + }, (nnz, tpl) -> { + nnzsToPopulate.add(nnz.getChild(0)); + return tpl._1 * tpl._2; + }, jointAssertionsCpy, ctx); + } catch (Exception e) { + //e.printStackTrace(); + System.err.println("Error while estimating the cost: " + e.getMessage()); + return null; + } + }).collect(Collectors.toList()); + + int nDimsToPopulate = dimsToPopulate.size(); + int nNNZsToPopulate = nnzsToPopulate.size(); + + List firstList = new ArrayList<>(); + for (int i = 0; i < nDimsToPopulate; i++) + firstList.add(2000L); + for (int i = 0; i < nNNZsToPopulate; i++) + firstList.add(1.0D); + + List, List>> out = new ArrayList<>(); + out.add(new Tuple2<>(firstList, costs)); + + if (sampleSize < 2) + return out; + + List> nums = new ArrayList<>(); + List dimList = Arrays.stream(dimVals).mapToObj(dim -> ((Number)dim)).collect(Collectors.toList()); + List sparsityList = Arrays.stream(sparsities).mapToObj(s -> ((Number)s)).collect(Collectors.toList()); + + int numCombinations = 1; + + for (int i = 0; i < nDimsToPopulate; i++) { + nums.add(dimList); + numCombinations *= dimList.size(); + } + + for (int i = 0; i < nNNZsToPopulate; i++) { + nums.add(sparsityList); + numCombinations *= sparsityList.size(); + } + + Set samples = new HashSet<>(); + + if (sample) { + if (sampleSize < numCombinations) { + Random rd = new Random(); + + while (samples.size() < sampleSize) + samples.add(rd.nextInt(numCombinations)); + } else { + sample = false; + } + } + + final boolean doSample = sample; + + MutableInt ctr = new MutableInt(); + + if (nums.size() > 16) { + System.err.println("Could not properly sample: " + statements); + return out; + } + + RewriterUtils.cartesianProduct(nums, new Number[nums.size()], stack -> { + if (doSample && !samples.contains(ctr.getAndIncrement())) + return true; + + int sparsityStart = 0; + + for (Number num : stack) { + if (num instanceof Double) + break; + + sparsityStart++; + } + + final int fSparsityStart = sparsityStart; + + Map replace = new HashMap<>(); + + MutableInt dimCtr = new MutableInt(); + MutableInt sCtr = new MutableInt(); + + Map mCreatedObjects = new HashMap<>(); + List mCostFnCpys = costFns.stream().map(cpy -> cpy.nestedCopy(false, mCreatedObjects)).collect(Collectors.toList()); + RewriterAssertions mAssertionsCpy = RewriterAssertions.copy(jointAssertions, mCreatedObjects, false); + + List mCosts = mCostFnCpys.stream().map(mCpy -> { + try { + return computeCostFunction(mCpy, el -> { + Long literal = replace.get(el); + + if (literal == null) { + literal = (Long) stack[dimCtr.getAndIncrement()]; + //System.out.println("populated size with: " + literal); + replace.put(el, literal); + } + + return literal; + }, (nnz, tpl) -> { + Long literal = replace.get(nnz.getChild(0)); + + if (literal == null) { + double sparsity = (double) stack[fSparsityStart + sCtr.getAndIncrement()]; + literal = (long) Math.ceil(sparsity * tpl._1 * tpl._2); + replace.put(nnz.getChild(0), literal); + } + + return literal; + }, mAssertionsCpy, ctx); + } catch (Exception e) { + e.printStackTrace(); + return null; + } + }).collect(Collectors.toList()); + + out.add(new Tuple2<>(new ArrayList<>(Arrays.asList(stack)), mCosts)); + + return true; + }); + + return out; + } + + // Computes the cost of an expression using different matrix dimensions and sparsities + public static List, Long, Long>> compareCosts(RewriterStatement stmt1, RewriterStatement stmt2, RewriterAssertions jointAssertions, final RuleContext ctx, boolean sample, int sampleSize, boolean returnOnDifference) { + Map estimates1 = RewriterSparsityEstimator.estimateAllNNZ(stmt1, ctx); + Map estimates2 = RewriterSparsityEstimator.estimateAllNNZ(stmt2, ctx); + + MutableObject assertionRef = new MutableObject<>(jointAssertions); + RewriterStatement costFn1 = getRawCostFunction(stmt1, ctx, assertionRef, false); + RewriterStatement costFn2 = getRawCostFunction(stmt2, ctx, assertionRef, false); + + costFn1 = RewriterSparsityEstimator.rollupSparsities(costFn1, estimates1, ctx); + costFn2 = RewriterSparsityEstimator.rollupSparsities(costFn2, estimates2, ctx); + + final RewriterStatement fCostFn1 = costFn1; + final RewriterStatement fCostFn2 = costFn2; + + long[] dimVals = new long[] {10, 5000}; + double[] sparsities = new double[] {1.0D, 0.05D}; + + Map createdObjects = new HashMap<>(); + RewriterStatement costFn1Cpy = costFn1.nestedCopy(true, createdObjects); + RewriterStatement costFn2Cpy = costFn2.nestedCopy(false, createdObjects); + RewriterAssertions jointAssertionsCpy = RewriterAssertions.copy(jointAssertions, createdObjects, false); + + Set dimsToPopulate = new HashSet<>(); + Set nnzsToPopulate = new HashSet<>(); + + long cost1 = computeCostFunction(costFn1Cpy, el -> { + dimsToPopulate.add(el); + return 2000L; + }, (nnz, tpl) -> { + nnzsToPopulate.add(nnz.getChild(0)); + return tpl._1 * tpl._2; + }, jointAssertionsCpy, ctx); + long cost2 = computeCostFunction(costFn2Cpy, el -> { + dimsToPopulate.add(el); + return 2000L; + }, (nnz, tpl) -> { + nnzsToPopulate.add(nnz.getChild(0)); + return tpl._1 * tpl._2; + }, jointAssertionsCpy, ctx); + + int nDimsToPopulate = dimsToPopulate.size(); + int nNNZsToPopulate = nnzsToPopulate.size(); + + List firstList = new ArrayList<>(); + for (int i = 0; i < nDimsToPopulate; i++) + firstList.add(2000L); + for (int i = 0; i < nNNZsToPopulate; i++) + firstList.add(1.0D); + + List, Long, Long>> out = new ArrayList<>(); + out.add(new Tuple3<>(firstList, cost1, cost2)); + + if (returnOnDifference && cost1 != cost2) + return out; + + List> nums = new ArrayList<>(); + List dimList = Arrays.stream(dimVals).mapToObj(dim -> ((Number)dim)).collect(Collectors.toList()); + List sparsityList = Arrays.stream(sparsities).mapToObj(s -> ((Number)s)).collect(Collectors.toList()); + + int numCombinations = 1; + + for (int i = 0; i < nDimsToPopulate; i++) { + nums.add(dimList); + numCombinations *= dimList.size(); + } + + for (int i = 0; i < nNNZsToPopulate; i++) { + nums.add(sparsityList); + numCombinations *= sparsityList.size(); + } + + Set samples = new HashSet<>(); + + if (sample) { + if (sampleSize < numCombinations) { + Random rd = new Random(); + + while (samples.size() < sampleSize) + samples.add(rd.nextInt(numCombinations)); + } else { + sample = false; + } + } + + final boolean doSample = sample; + + MutableInt ctr = new MutableInt(); + + RewriterUtils.cartesianProduct(nums, new Number[nums.size()], stack -> { + if (doSample && !samples.contains(ctr.getAndIncrement())) + return true; + + int sparsityStart = 0; + + for (Number num : stack) { + if (num instanceof Double) + break; + + sparsityStart++; + } + + final int fSparsityStart = sparsityStart; + + Map replace = new HashMap<>(); + + MutableInt dimCtr = new MutableInt(); + MutableInt sCtr = new MutableInt(); + + Map mCreatedObjects = new HashMap<>(); + RewriterStatement mCpy1 = fCostFn1.nestedCopy(false, mCreatedObjects); + RewriterStatement mCpy2 = fCostFn2.nestedCopy(false, mCreatedObjects); + RewriterAssertions mAssertionsCpy = RewriterAssertions.copy(jointAssertions, mCreatedObjects, false); + + long mCost1 = computeCostFunction(mCpy1, el -> { + Long literal = replace.get(el); + + if (literal == null) { + literal = (Long) stack[dimCtr.getAndIncrement()]; + replace.put(el, literal); + } + + return literal; + }, (nnz, tpl) -> { + Long literal = replace.get(nnz.getChild(0)); + + if (literal == null) { + double sparsity = (double) stack[fSparsityStart + sCtr.getAndIncrement()]; + literal = (long)Math.ceil(sparsity * tpl._1 * tpl._2); + replace.put(nnz.getChild(0), literal); + } + + return literal; + }, mAssertionsCpy, ctx); + long mCost2 = computeCostFunction(mCpy2, el -> { + Long literal = replace.get(el); + + if (literal == null) { + literal = (Long) stack[dimCtr.getAndIncrement()]; + replace.put(el, literal); + } + + return literal; + }, (nnz, tpl) -> { + Long literal = replace.get(nnz.getChild(0)); + + if (literal == null) { + double sparsity = (double) stack[fSparsityStart + sCtr.getAndIncrement()]; + literal = (long)Math.ceil(sparsity * tpl._1 * tpl._2); + replace.put(nnz.getChild(0), literal); + } + + return literal; + }, mAssertionsCpy, ctx); + + out.add(new Tuple3<>(new ArrayList<>(Arrays.asList(stack)), mCost1, mCost2)); + + return !returnOnDifference || mCost1 == mCost2; + }); + + return out; + } + + public static Tuple2, Boolean> determineSingleReferenceRequirement(RewriterRule rule, final RuleContext ctx) { + MutableObject assertionRef = new MutableObject<>(); + long fullCost = RewriterCostEstimator.estimateCost(rule.getStmt1(), ctx, assertionRef); + long maxCost = RewriterCostEstimator.estimateCost(rule.getStmt2(), ctx); + return RewriterCostEstimator.determineSingleReferenceRequirement(rule.getStmt2(), RewriterCostEstimator.DEFAULT_COST_FN, RewriterCostEstimator.DEFAULT_NNZ_FN, assertionRef.getValue(), fullCost, maxCost, ctx); + } + + public static Tuple2, Boolean> determineSingleReferenceRequirement(RewriterStatement root, Function costFn, RewriterAssertions assertions, long fullCost, long maxCost, final RuleContext ctx) { + return determineSingleReferenceRequirement(root, costFn, RewriterCostEstimator.DEFAULT_NNZ_FN, assertions, fullCost, maxCost, ctx); + } + + // Returns all (upmost) sub-DAGs that can have multiple references and true as a second arg if all statements can have multiple references at once + public static Tuple2, Boolean> determineSingleReferenceRequirement(RewriterStatement root, Function costFn, BiFunction, Long> nnzFn, RewriterAssertions assertions, long fullCost, long maxCost, final RuleContext ctx) { + if (fullCost >= maxCost) + return new Tuple2<>(Collections.emptySet(), true); + + List> subDAGCosts = new ArrayList<>(); + + root.forEachPreOrder((cur, pred) -> { + if (pred.isRoot() || !cur.isInstruction()) + return true; + + long cost = estimateCost(cur, costFn, nnzFn, ctx, new MutableObject<>(assertions)); + + if (fullCost + cost <= maxCost) { + subDAGCosts.add(new Tuple2<>(cur, cost)); + return false; + } + + return true; + }, true); + + boolean canCombine = true; + long curCost = fullCost; + + for (Tuple2 t : subDAGCosts) { + curCost += t._2; + + if (curCost > maxCost) { + canCombine = false; + break; + } + } + + return new Tuple2<>(subDAGCosts.stream().map(t -> t._1).collect(Collectors.toSet()), canCombine); + } + + public static long estimateCost(RewriterStatement stmt, final RuleContext ctx) { + return estimateCost(stmt, DEFAULT_COST_FN, ctx); + } + + public static long estimateCost(RewriterStatement stmt, final RuleContext ctx, MutableObject assertionRef) { + return estimateCost(stmt, DEFAULT_COST_FN, DEFAULT_NNZ_FN, ctx, assertionRef); + } + + public static long estimateCost(RewriterStatement stmt, Function propertyGenerator, final RuleContext ctx) { + return estimateCost(stmt, propertyGenerator, DEFAULT_NNZ_FN, ctx, null); + } + + public static long estimateCost(RewriterStatement stmt, Function propertyGenerator, BiFunction, Long> nnzGenerator, final RuleContext ctx, MutableObject assertionRef) { + if (assertionRef == null) + assertionRef = new MutableObject<>(); + + RewriterStatement costFn = getRawCostFunction(stmt, ctx, assertionRef, false); + return computeCostFunction(costFn, propertyGenerator, nnzGenerator, assertionRef.getValue(), ctx); + } + + public static RewriterStatement getRawCostFunction(RewriterStatement stmt, final RuleContext ctx, MutableObject assertionRef, boolean treatAsDense) { + RewriterAssertions assertions = assertionRef != null && assertionRef.getValue() != null ? assertionRef.getValue() : new RewriterAssertions(ctx); + + if (assertionRef != null) + assertionRef.setValue(assertions); + + RewriterStatement costFn = propagateCostFunction(stmt, ctx, assertions, treatAsDense); + Map estimations = RewriterSparsityEstimator.estimateAllNNZ(costFn, ctx); + RewriterSparsityEstimator.rollupSparsities(costFn, estimations, ctx); + costFn = assertions.update(costFn); + costFn = RewriterUtils.foldConstants(costFn, ctx); + + return costFn; + } + + public static long computeCostFunction(RewriterStatement costFn, Function propertyGenerator, BiFunction, Long> nnzGenerator, RewriterAssertions assertions, final RuleContext ctx) { + Map map = new HashMap<>(); + + costFn.forEachPreOrder((cur, pred) -> { + for (int i = 0; i < cur.getOperands().size(); i++) { + RewriterStatement op = cur.getChild(i); + + RewriterStatement mNew = map.get(op); + if (mNew != null) { + cur.getOperands().set(i, mNew); + continue; + } + + if (op.isEClass()) { + RewriterAssertions.RewriterAssertion assertion = assertions.getAssertionObj(op); + Optional literal = assertion != null ? assertion.getLiteral() : Optional.empty(); + + mNew = literal.orElseGet(() -> RewriterStatement.literal(ctx, propertyGenerator.apply(op))); + + map.put(op, mNew); + cur.getOperands().set(i, mNew); + } else if (op.isInstruction()) { + if (op.trueInstruction().equals("ncol") || op.trueInstruction().equals("nrow")) { + RewriterStatement eClassStmt = assertions.getAssertionStatement(op, null); + mNew = RewriterStatement.literal(ctx, propertyGenerator.apply(eClassStmt)); + map.put(eClassStmt, mNew); + cur.getOperands().set(i, mNew); + } + } + } + + return true; + }, false); + + costFn.forEachPreOrder((cur, pred) -> { + for (int i = 0; i < cur.getOperands().size(); i++) { + RewriterStatement op = cur.getChild(i); + + RewriterStatement mNew = map.get(op); + if (mNew != null) { + cur.getOperands().set(i, mNew); + continue; + } + + if (op.isInstruction() && op.trueInstruction().equals("_nnz")) { + RewriterStatement ncolLiteral = map.get(op.getChild(0).getNCol()); + + if (ncolLiteral == null) { + RewriterAssertions.RewriterAssertion assertion = assertions.getAssertionObj(op.getChild(0).getNCol()); + + if (assertion != null) { + RewriterStatement assStmt = assertion.getEClassStmt(ctx, assertions); + ncolLiteral = map.get(assStmt); + + if (ncolLiteral == null) { + ncolLiteral = RewriterStatement.literal(ctx, propertyGenerator.apply(assStmt)); + map.put(assStmt, ncolLiteral); + } + } else { + ncolLiteral = RewriterStatement.literal(ctx, propertyGenerator.apply(op.getChild(0).getNCol())); + map.put(op.getChild(0).getNCol(), ncolLiteral); + } + } + + RewriterStatement nrowLiteral = map.get(op.getChild(0).getNRow()); + + if (nrowLiteral == null) { + RewriterAssertions.RewriterAssertion assertion = assertions.getAssertionObj(op.getChild(0).getNRow()); + + if (assertion != null) { + RewriterStatement assStmt = assertion.getEClassStmt(ctx, assertions); + nrowLiteral = map.get(assStmt); + + if (nrowLiteral == null) { + nrowLiteral = RewriterStatement.literal(ctx, propertyGenerator.apply(assStmt)); + map.put(assStmt, nrowLiteral); + } + } else { + nrowLiteral = RewriterStatement.literal(ctx, propertyGenerator.apply(op.getChild(0).getNRow())); + map.put(op.getChild(0).getNRow(), nrowLiteral); + } + } + + mNew = RewriterStatement.literal(ctx, nnzGenerator.apply(op, new Tuple2<>(nrowLiteral.intLiteral(false), ncolLiteral.intLiteral(false)))); + map.put(op, mNew); + cur.getOperands().set(i, mNew); + } + } + + return true; + }, false); + + costFn.forEachPreOrder(cur -> { + if (cur.isInstruction()) + cur.refreshReturnType(ctx); + + return true; + }, false); + + costFn = RewriterUtils.foldConstants(costFn, ctx); + + if (!costFn.isLiteral()) { + throw new IllegalArgumentException("Cost function must be a literal: " + costFn.toParsableString(ctx)); + } + + if (costFn.getLiteral() instanceof Double) + return (long)((double)costFn.getLiteral()); + + return (long)costFn.getLiteral(); + } + + private static RewriterStatement propagateCostFunction(RewriterStatement stmt, final RuleContext ctx, RewriterAssertions assertions, boolean treatAsDense) { + List includedCosts = new ArrayList<>(); + MutableLong instructionOverhead = new MutableLong(0); + + stmt.forEachPostOrder((cur, pred) -> { + if (!(cur instanceof RewriterInstruction)) + return; + + computeCostOf((RewriterInstruction) cur, ctx, includedCosts, assertions, instructionOverhead, treatAsDense, stmt); + instructionOverhead.add(INSTRUCTION_OVERHEAD); + }, false); + + includedCosts.add(RewriterStatement.literal(ctx, instructionOverhead.longValue())); + + RewriterStatement argList = RewriterStatement.argList(ctx, includedCosts); + RewriterStatement add = new RewriterInstruction().as(UUID.randomUUID().toString()).withInstruction("+").withOps(argList).consolidate(ctx); + add.unsafePutMeta("_assertions", assertions); + return add; + } + + private static RewriterStatement computeCostOf(RewriterInstruction instr, final RuleContext ctx, List uniqueCosts, RewriterAssertions assertions, MutableLong instructionOverhead, boolean treatAsDense, RewriterStatement exprRoot) { + if (instr.getResultingDataType(ctx).equals("MATRIX")) + return computeMatrixOpCost(instr, ctx, uniqueCosts, assertions, instructionOverhead, treatAsDense, exprRoot); + else + return computeScalarOpCost(instr, ctx, uniqueCosts, assertions, instructionOverhead, treatAsDense, exprRoot); + } + + private static RewriterStatement computeMatrixOpCost(RewriterInstruction instr, final RuleContext ctx, List uniqueCosts, RewriterAssertions assertions, MutableLong overhead, boolean treatAsDense, RewriterStatement exprRoot) { + RewriterAssertionUtils.buildImplicitAssertion(instr, assertions, exprRoot, ctx); + + RewriterStatement cost = null; + Map map = new HashMap<>(); + + switch (instr.trueInstruction()) { + case "%*%": + map.put("A", instr.getChild(0)); + map.put("B", instr.getChild(1)); + map.put("nrowA", instr.getChild(0).getNRow()); + map.put("ncolA", instr.getChild(0).getNCol()); + map.put("nrowB", instr.getChild(1).getNRow()); + map.put("ncolB", instr.getChild(1).getNCol()); + map.put("mulCost", atomicOpCostStmt("*", ctx)); + map.put("sumCost", atomicOpCostStmt("+", ctx)); + map.put("nnzA", RewriterStatement.nnz(instr.getChild(0), ctx, treatAsDense)); + map.put("nnzB", RewriterStatement.nnz(instr.getChild(1), ctx, treatAsDense)); + // Rough estimation + cost = RewriterUtils.parse("*(argList(min(nnzA, nnzB), ncolA, +(argList(mulCost, sumCost))))", ctx, map); + overhead.add(MALLOC_COST); + break; + case "t": + case "rev": + cost = RewriterStatement.nnz(instr.getChild(0), ctx, treatAsDense);//RewriterUtils.parse("_nnz(A)", ctx, map); + overhead.add(MALLOC_COST); + break; + case "rowSums": + case "colSums": + map.put("A", instr.getChild(0)); + map.put("nnzA", RewriterStatement.nnz(instr.getChild(0), ctx, treatAsDense)); + RewriterStatement aoc = atomicOpCostStmt("+", ctx); + map.put("opcost", aoc); + // Rough estimation + cost = RewriterUtils.parse("*(argList(nnzA, opcost))", ctx, map); + overhead.add(MALLOC_COST); + break; + case "diag": + map.put("nrowA", instr.getChild(0).getNRow()); + map.put("nnzA", RewriterStatement.nnz(instr.getChild(0), ctx, treatAsDense)); + map.put("A", instr.getChild(0)); + cost = RewriterUtils.parse("min(nnzA, nrowA)", ctx, map); + overhead.add(MALLOC_COST); + break; + case "cast.MATRIX": + cost = RewriterStatement.literal(ctx, 20L); + break; + case "[]": + cost = RewriterStatement.literal(ctx, 0L); + break; // I assume that nothing is materialized + case "RBind": + case "CBind": + map.put("A", instr.getChild(0)); + map.put("B", instr.getChild(1)); + map.put("nnzA", RewriterStatement.nnz(instr.getChild(0), ctx, treatAsDense)); + map.put("nnzB", RewriterStatement.nnz(instr.getChild(1), ctx, treatAsDense)); + cost = RewriterUtils.parse("+(argList(nnzA, nnzB))", ctx, map); + overhead.add(MALLOC_COST); + break; + case "rand": + cost = RewriterStatement.nnz(instr, ctx, treatAsDense); + overhead.add(MALLOC_COST); + break; + case "1-*": + RewriterStatement subtractionCost = atomicOpCostStmt("-", ctx); + RewriterStatement mulCost = atomicOpCostStmt("*", ctx); + RewriterStatement sparsityAwareMul = RewriterStatement.multiArgInstr(ctx, "*", mulCost, StatementUtils.min(ctx, RewriterStatement.nnz(instr.getChild(0), ctx, treatAsDense), RewriterStatement.nnz(instr.getChild(1), ctx, treatAsDense))); + RewriterStatement oneMinus = RewriterStatement.multiArgInstr(ctx, "*", subtractionCost, instr.getNCol(), instr.getNRow()); + cost = RewriterStatement.multiArgInstr(ctx, "+", oneMinus, sparsityAwareMul); + overhead.add(MALLOC_COST); + break; + case "+*": + RewriterStatement additionCost = atomicOpCostStmt("+", ctx); + mulCost = atomicOpCostStmt("*", ctx); + RewriterStatement sum = RewriterStatement.multiArgInstr(ctx, "+", additionCost, mulCost); + cost = RewriterStatement.multiArgInstr(ctx, "*", sum, StatementUtils.min(ctx, RewriterStatement.nnz(instr.getChild(0), ctx, treatAsDense), RewriterStatement.nnz(instr.getChild(2), ctx, treatAsDense))); + overhead.add(MALLOC_COST + 50); // To make it worse than 1-* + break; + case "-*": + subtractionCost = atomicOpCostStmt("-", ctx); + mulCost = atomicOpCostStmt("*", ctx); + sum = RewriterStatement.multiArgInstr(ctx, "+", subtractionCost, mulCost); + cost = RewriterStatement.multiArgInstr(ctx, "*", sum, StatementUtils.min(ctx, RewriterStatement.nnz(instr.getChild(0), ctx, treatAsDense), RewriterStatement.nnz(instr.getChild(2), ctx, treatAsDense))); + overhead.add(MALLOC_COST + 50); // To make it worse than 1-* + break; + case "*2": + cost = RewriterStatement.multiArgInstr(ctx, "*", atomicOpCostStmt("*2", ctx), RewriterStatement.nnz(instr.getChild(0), ctx, treatAsDense)); + overhead.add(MALLOC_COST); + break; + case "sq": + cost = RewriterStatement.multiArgInstr(ctx, "*", atomicOpCostStmt("sq", ctx), RewriterStatement.nnz(instr.getChild(0), ctx, treatAsDense)); + overhead.add(MALLOC_COST); + break; + case "sqrt": + cost = RewriterStatement.multiArgInstr(ctx, "*", atomicOpCostStmt("sqrt", ctx), RewriterStatement.nnz(instr.getChild(0), ctx, treatAsDense)); + overhead.add(MALLOC_COST); + break; + case "exp": + cost = RewriterStatement.multiArgInstr(ctx, "*", atomicOpCostStmt("exp", ctx), RewriterStatement.nnz(instr.getChild(0), ctx, treatAsDense)); + overhead.add(MALLOC_COST); + break; + case "log_nz": { + // Must be a matrix + RewriterStatement logCost = atomicOpCostStmt("log", ctx); + RewriterStatement twoLogCost = RewriterStatement.multiArgInstr(ctx, "*", RewriterStatement.literal(ctx, 2L), logCost); + RewriterStatement neqCost = atomicOpCostStmt("!=", ctx); + sum = RewriterStatement.multiArgInstr(ctx, "+", neqCost, instr.getOperands().size() == 2 ? twoLogCost : logCost); + cost = RewriterStatement.multiArgInstr(ctx, "*", sum, RewriterStatement.nnz(instr.getChild(0), ctx, treatAsDense)); + overhead.add(MALLOC_COST); + break; + } + case "log": + if (instr.getChild(0).getResultingDataType(ctx).equals("MATRIX")) { + RewriterStatement logCost = atomicOpCostStmt("log", ctx); + RewriterStatement twoLogCost = RewriterStatement.multiArgInstr(ctx, "*", RewriterStatement.literal(ctx, 2L), logCost); + cost = RewriterStatement.multiArgInstr(ctx, "*", instr.getOperands().size() == 2 ? twoLogCost : logCost, instr.getNCol(), instr.getNRow()); + overhead.add(MALLOC_COST); + } else { + RewriterStatement logCost = atomicOpCostStmt("log", ctx); + RewriterStatement twoLogCost = RewriterStatement.multiArgInstr(ctx, "*", RewriterStatement.literal(ctx, 2L), logCost); + cost = instr.getOperands().size() == 2 ? twoLogCost : logCost; + } + break; + case "const": + case "rowVec": + case "colVec": + case "cellMat": + cost = RewriterStatement.literal(ctx, 0L); + break; + } + + if (cost == null) { + if (instr.hasProperty("ElementWiseInstruction", ctx)) { + RewriterStatement firstMatrix = null; + RewriterStatement secondMatrix = null; + if (instr.getChild(0).getResultingDataType(ctx).equals("MATRIX")) { + firstMatrix = instr.getChild(0); + } + + if (instr.getChild(1).getResultingDataType(ctx).equals("MATRIX")) { + if (firstMatrix == null) + firstMatrix = instr.getChild(1); + else + secondMatrix = instr.getChild(1); + } + + RewriterStatement opCost = atomicOpCostStmt(instr.trueInstruction(), ctx); + + if (firstMatrix != null) { + switch (instr.trueInstruction()) { + case "*": + cost = new RewriterInstruction().as(UUID.randomUUID().toString()) + .withInstruction("*") + .withOps(RewriterStatement.argList(ctx, opCost, secondMatrix != null ? StatementUtils.min(ctx, RewriterStatement.nnz(firstMatrix, ctx, treatAsDense), RewriterStatement.nnz(secondMatrix, ctx, treatAsDense)) : RewriterStatement.nnz(firstMatrix, ctx, treatAsDense))); + break; + case "/": + if (instr.getChild(0).getResultingDataType(ctx).equals("MATRIX")) + cost = new RewriterInstruction().as(UUID.randomUUID().toString()) + .withInstruction("*") + .withOps(RewriterStatement.argList(ctx, opCost, RewriterStatement.nnz(instr.getChild(0), ctx, treatAsDense))); + else + cost = new RewriterInstruction().as(UUID.randomUUID().toString()) + .withInstruction("*") + .withOps(RewriterStatement.argList(ctx, opCost, StatementUtils.length(ctx, firstMatrix))); + + break; + case "+": + case "-": + cost = new RewriterInstruction().as(UUID.randomUUID().toString()) + .withInstruction("*") + .withOps(RewriterStatement.argList(ctx, opCost, secondMatrix != null ? StatementUtils.add(ctx, RewriterStatement.nnz(firstMatrix, ctx, treatAsDense), RewriterStatement.nnz(secondMatrix, ctx, treatAsDense)) : RewriterStatement.nnz(firstMatrix, ctx, treatAsDense))); + break; + default: + cost = RewriterStatement.multiArgInstr(ctx, "*", opCost, instr.getNRow(), instr.getNCol()); + break; + } + + overhead.add(MALLOC_COST); + } else { + cost = opCost; + } + } else if (instr.hasProperty("UnaryElementWiseOperator", ctx)) { + RewriterStatement opCost = atomicOpCostStmt(instr.trueInstruction(), ctx); + cost = new RewriterInstruction().as(UUID.randomUUID().toString()) + .withInstruction("*") + .withOps(RewriterStatement.argList(ctx, opCost, RewriterStatement.nnz(instr.getChild(0), ctx, treatAsDense))); + overhead.add(MALLOC_COST); + } else { + throw new IllegalArgumentException("Unknown instruction: " + instr.trueTypedInstruction(ctx)); + } + } + + uniqueCosts.add(cost); + return cost; + } + + private static RewriterStatement computeScalarOpCost(RewriterInstruction instr, final RuleContext ctx, List uniqueCosts, RewriterAssertions assertions, MutableLong overhead, boolean treatAsDense, RewriterStatement exprRoot) { + RewriterAssertionUtils.buildImplicitAssertion(instr, assertions, exprRoot, ctx); + Map map = new HashMap<>(); + switch (instr.trueTypedInstruction(ctx)) { + case "sum(MATRIX)": + case "min(MATRIX)": + case "max(MATRIX)": + map.put("nnzA", RewriterStatement.nnz(instr.getChild(0), ctx, treatAsDense)); + uniqueCosts.add(RewriterUtils.parse("nnzA", ctx, map)); + return uniqueCosts.get(uniqueCosts.size()-1); + case "sumSq(MATRIX)": + map.put("nnzA", RewriterStatement.nnz(instr.getChild(0), ctx, treatAsDense)); + uniqueCosts.add(RewriterStatement.multiArgInstr(ctx, "*", RewriterStatement.nnz(instr.getChild(0), ctx, treatAsDense), RewriterStatement.literal(ctx, 2L))); + return uniqueCosts.get(uniqueCosts.size()-1); + case "trace(MATRIX)": + uniqueCosts.add(StatementUtils.min(ctx, RewriterStatement.nnz(instr.getChild(0), ctx, treatAsDense), instr.getChild(0).getNRow())); + return uniqueCosts.get(uniqueCosts.size()-1); + case "[](MATRIX,INT,INT)": + return RewriterStatement.literal(ctx, 0L); + case "cast.FLOAT(MATRIX)": + return RewriterStatement.literal(ctx, INSTRUCTION_OVERHEAD); + case "const(MATRIX,FLOAT)": + case "_nnz(MATRIX)": + return RewriterStatement.literal(ctx, 0L); + } + + double opCost = atomicOpCost(instr.trueInstruction()); + uniqueCosts.add(RewriterUtils.parse(Double.toString(opCost), ctx, "LITERAL_FLOAT:" + opCost)); + return uniqueCosts.get(uniqueCosts.size()-1); + } + + private static RewriterStatement atomicOpCostStmt(String op, final RuleContext ctx) { + double opCost = atomicOpCost(op); + return RewriterUtils.parse(Double.toString(opCost), ctx, "LITERAL_FLOAT:" + opCost); + } + + private static double atomicOpCost(String op) { + switch (op) { + case "+": + case "-": + return 1; + case "*": + return 2; + case "*2": + return 1; // To make *2 cheaper than A+A + case "/": + case "inv": + return 3; + case "length": + case "nrow": + case "ncol": + case "_nnz": + return 0; // These just fetch metadata + case "sqrt": + return 10; + case "sq": + return 1.8; // To make it cheaper than *(A,A) + case "exp": + case "log": + case "^": + return 20; + case "!": + case "|": + case "&": + case ">": + case ">=": + case "<": + case "<=": + case "==": + case "!=": + return 1; + case "round": + return 2; + case "abs": + return 2; + case "cast.FLOAT": + return 1; + case "literal.FLOAT": + case "literal.INT": + case "literal.BOOL": + return 0; + } + + throw new IllegalArgumentException("Unknown instruction: " + op); + } +} diff --git a/src/main/java/org/apache/sysds/hops/rewriter/estimators/RewriterSparsityEstimator.java b/src/main/java/org/apache/sysds/hops/rewriter/estimators/RewriterSparsityEstimator.java new file mode 100644 index 00000000000..06bd446bd9f --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewriter/estimators/RewriterSparsityEstimator.java @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.rewriter.estimators; + +import org.apache.sysds.hops.rewriter.utils.ConstantFoldingUtils; +import org.apache.sysds.hops.rewriter.RewriterInstruction; +import org.apache.sysds.hops.rewriter.RewriterStatement; +import org.apache.sysds.hops.rewriter.RuleContext; +import org.apache.sysds.hops.rewriter.utils.StatementUtils; + +import java.util.HashMap; +import java.util.Map; + +public class RewriterSparsityEstimator { + public static RewriterStatement rollupSparsities(RewriterStatement sparsityEstimate, Map sparsityMap, final RuleContext ctx) { + sparsityEstimate.forEachPreOrder(cur -> { + for (int i = 0; i < cur.getOperands().size(); i++) { + RewriterStatement child = cur.getChild(i); + + if (child.isInstruction() && child.trueInstruction().equals("_nnz")) { + RewriterStatement subEstimate = sparsityMap.get(child.getChild(0)); + + if (subEstimate != null) { + cur.getOperands().set(i, subEstimate); + } + } + } + return true; + }, false); + + return sparsityEstimate; + } + + public static Map estimateAllNNZ(RewriterStatement stmt, final RuleContext ctx) { + Map map = new HashMap<>(); + stmt.forEachPostOrder((cur, pred) -> { + RewriterStatement estimation = estimateNNZ(cur, ctx); + if (estimation != null) + map.put(cur, estimation); + }, false); + + return map; + } + + public static RewriterStatement estimateNNZ(RewriterStatement stmt, final RuleContext ctx) { + if (!stmt.isInstruction() || !stmt.getResultingDataType(ctx).equals("MATRIX")) + return null; + switch (stmt.trueInstruction()) { + case "%*%": + RewriterStatement min1 = StatementUtils.min(ctx, RewriterStatement.multiArgInstr(ctx, "*", RewriterStatement.nnz(stmt.getChild(0), ctx), new RewriterInstruction("inv", ctx, stmt.getChild(0).getNRow())), RewriterStatement.literal(ctx, 1.0D)); + RewriterStatement min2 = StatementUtils.min(ctx, RewriterStatement.multiArgInstr(ctx, "*", RewriterStatement.nnz(stmt.getChild(1), ctx), new RewriterInstruction("inv", ctx, stmt.getChild(1).getNCol())), RewriterStatement.literal(ctx, 1.0D)); + return RewriterStatement.multiArgInstr(ctx, "*", min1, min2, stmt.getNRow(), stmt.getNCol()); + } + + switch (stmt.trueTypedInstruction(ctx)) { + case "*(MATRIX,MATRIX)": + return StatementUtils.min(ctx, RewriterStatement.nnz(stmt.getChild(0), ctx), RewriterStatement.nnz(stmt.getChild(1), ctx)); + case "*(MATRIX,FLOAT)": + if (stmt.getChild(1).isLiteral() && ConstantFoldingUtils.overwritesLiteral(((Double) stmt.getChild(1).getLiteral()), "*", ctx) != null) + return RewriterStatement.literal(ctx, 0L); + return RewriterStatement.nnz(stmt.getChild(0), ctx); + case "*(FLOAT,MATRIX)": + if (stmt.getChild(0).isLiteral() && ConstantFoldingUtils.overwritesLiteral(((Double) stmt.getChild(0).getLiteral()), "*", ctx) != null) + return RewriterStatement.literal(ctx, 0L); + return RewriterStatement.nnz(stmt.getChild(1), ctx); + case "+(MATRIX,MATRIX)": + case "-(MATRIX,MATRIX)": + return StatementUtils.min(ctx, RewriterStatement.multiArgInstr(ctx, "+", RewriterStatement.nnz(stmt.getChild(0), ctx), RewriterStatement.nnz(stmt.getChild(1), ctx)), StatementUtils.length(ctx, stmt)); + case "+(MATRIX,FLOAT)": + case "-(MATRIX,FLOAT)": + if (stmt.getChild(1).isLiteral() && ConstantFoldingUtils.isNeutralElement(stmt.getChild(1).getLiteral(), "+")) + return RewriterStatement.nnz(stmt.getChild(0), ctx); + return StatementUtils.length(ctx, stmt); + case "+(FLOAT,MATRIX)": + case "-(FLOAT,MATRIX)": + if (stmt.getChild(0).isLiteral() && ConstantFoldingUtils.isNeutralElement(stmt.getChild(0).getLiteral(), "+")) + return RewriterStatement.nnz(stmt.getChild(1), ctx); + return StatementUtils.length(ctx, stmt); + case "!=(MATRIX,MATRIX)": + if (stmt.getChild(0).equals(stmt.getChild(1))) + return RewriterStatement.literal(ctx, 0L); + return StatementUtils.length(ctx, stmt); + + case "sqrt(MATRIX)": + return RewriterStatement.nnz(stmt.getChild(0), ctx); + + case "diag(MATRIX)": + return StatementUtils.min(ctx, stmt.getNRow(), RewriterStatement.nnz(stmt.getChild(0), ctx)); + + case "/(MATRIX,FLOAT)": + case "/(MATRIX,MATRIX)": + return RewriterStatement.nnz(stmt.getChild(0), ctx); + case "/(FLOAT,MATRIX)": + if (stmt.getChild(0).isLiteral() && ConstantFoldingUtils.isNeutralElement(stmt.getChild(0).getLiteral(), "+")) + return RewriterStatement.literal(ctx, 0L); + return StatementUtils.length(ctx, stmt); + + case "RBind(MATRIX,MATRIX)": + case "CBind(MATRIX,MATRIX)": + return StatementUtils.add(ctx, RewriterStatement.nnz(stmt.getChild(0), ctx), RewriterStatement.nnz(stmt.getChild(1), ctx)); + + // Fused operators + case "log_nz(MATRIX)": + case "*2(MATRIX)": + case "sq(MATRIX)": + case "t(MATRIX)": + return RewriterStatement.nnz(stmt.getChild(0), ctx); + case "1-*(MATRIX,MATRIX)": + return StatementUtils.length(ctx, stmt); + case "+*(MATRIX,FLOAT,MATRIX)": + case "-*(MATRIX,FLOAT,MATRIX)": + if (stmt.getChild(1).isLiteral() && ConstantFoldingUtils.isNeutralElement(stmt.getChild(1).getLiteral(), "+")) + return RewriterStatement.nnz(stmt.getChild(0), ctx); + return StatementUtils.min(ctx, RewriterStatement.multiArgInstr(ctx, "+", RewriterStatement.nnz(stmt.getChild(0), ctx), RewriterStatement.nnz(stmt.getChild(2), ctx)), StatementUtils.length(ctx, stmt)); + } + + return StatementUtils.length(ctx, stmt); + } +} diff --git a/src/main/java/org/apache/sysds/hops/rewriter/generated/GeneratedRewriteClass.java b/src/main/java/org/apache/sysds/hops/rewriter/generated/GeneratedRewriteClass.java new file mode 100644 index 00000000000..c239e6c6ef4 --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewriter/generated/GeneratedRewriteClass.java @@ -0,0 +1,4044 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.rewriter.generated; + +import java.util.ArrayList; +import java.util.function.Function; + +import org.apache.sysds.utils.Statistics; +import org.apache.sysds.hops.Hop; +import org.apache.sysds.hops.LiteralOp; +import org.apache.sysds.hops.UnaryOp; +import org.apache.sysds.hops.BinaryOp; +import org.apache.sysds.hops.ReorgOp; +import org.apache.sysds.hops.AggUnaryOp; +import org.apache.sysds.hops.AggBinaryOp; +import org.apache.sysds.hops.DataGenOp; +import org.apache.sysds.hops.TernaryOp; +import org.apache.sysds.common.Types; +import org.apache.sysds.hops.rewrite.HopRewriteUtils; +import org.apache.sysds.hops.rewriter.dml.DMLExecutor; +import org.apache.sysds.hops.rewriter.RewriterRuntimeUtils; + +public class GeneratedRewriteClass implements Function { + + @Override + public Object apply( Object _hi ) { + if ( _hi == null ) + return null; + + Hop hi = (Hop) _hi; + + if ( hi.getDataType() == Types.DataType.SCALAR ) { + hi = _applyRewrite0(hi); // *(0.0,a) => 0.0 + hi = _applyRewrite1(hi); // *(a,0.0) => 0.0 + hi = _applyRewrite23(hi); // sum(/(tmp83271,tmp60732)) => /(sum(tmp83271),tmp60732) + hi = _applyRewrite27(hi); // sum(*(*(tmp8790,tmp30390),tmp97178)) => *(tmp30390,sum(*(tmp97178,tmp8790))) + } else if ( hi.getDataType() == Types.DataType.MATRIX ) { + if ( hi instanceof BinaryOp ) { + if ( (( BinaryOp ) hi ).getOp() == Types.OpOp2.PLUS ) { + if ( hi.getInput().size() == 2 ) { + Hop hi_0 = hi.getInput(0); + Hop hi_1 = hi.getInput(1); + if ( hi_0.getDataType() == Types.DataType.MATRIX ) { + if ( hi_0 instanceof BinaryOp ) { + if ( (( BinaryOp ) hi_0 ).getOp() == Types.OpOp2.MINUS ) { + if ( hi_0.getInput().size() == 2 ) { + Hop hi_0_0 = hi_0.getInput(0); + Hop hi_0_1 = hi_0.getInput(1); + hi = _applyRewrite2(hi); // +(A,0.0) => A + hi = _applyRewrite7(hi); // +(-(0.0,A),B) => -(B,A) + hi = _applyRewrite10(hi); // +(-(A,a),b) => +(A,-(b,a)) + hi = _applyRewrite12(hi); // +(-(a,A),b) => -(+(a,b),A) + hi = _applyRewrite20(hi); // +(-(tmp80035,f12880),tmp63699) => -(+(tmp63699,tmp80035),f12880) + hi = _applyRewrite31(hi); // +(-(a,tmp98488),tmp82242) => +(-(tmp82242,tmp98488),a) + hi = _applyRewrite37(hi); // +(-(*(C,b),d),A) => -(+*(A,b,C),d) + hi = _applyRewrite38(hi); // +(-(*(D,c),B),A) => -(A,-*(B,c,D)) + hi = _applyRewrite39(hi); // +(M9347,*(K,*(M40316,f32765))) => +*(M9347,f32765,*(K,M40316)) + hi = _applyRewrite41(hi); // +(-(f45081,A),B) => +(f45081,-(B,A)) + hi = _applyRewrite42(hi); // +(M9347,*(*(f32765,K),M40316)) => +*(M9347,f32765,*(K,M40316)) + hi = _applyRewrite46(hi); // +(-(b,%*%(C,D)),A) => +(b,-(A,%*%(C,D))) + hi = _applyRewrite54(hi); // +(-(C,d),%*%(A,B)) => -(+(C,%*%(A,B)),d) + } else { + hi = _applyRewrite2(hi); // +(A,0.0) => A + hi = _applyRewrite39(hi); // +(M9347,*(K,*(M40316,f32765))) => +*(M9347,f32765,*(K,M40316)) + hi = _applyRewrite42(hi); // +(M9347,*(*(f32765,K),M40316)) => +*(M9347,f32765,*(K,M40316)) + } + } else if ( (( BinaryOp ) hi_0 ).getOp() == Types.OpOp2.MULT ) { + if ( hi_0.getInput().size() == 2 ) { + Hop hi_0_0 = hi_0.getInput(0); + Hop hi_0_1 = hi_0.getInput(1); + hi = _applyRewrite2(hi); // +(A,0.0) => A + hi = _applyRewrite18(hi); // +(*(*(y_corr,-(float599,is_zero_y_corr)),tmp8608),*(tmp20367,+(tmp23071,tmp55180))) => +(*(*(tmp8608,y_corr),-(float599,is_zero_y_corr)),*(tmp20367,+(tmp55180,tmp23071))) + hi = _applyRewrite32(hi); // +(*(tmp99142,missing_mask_Y),*(tmp58606,missing_mask_Y)) => *(missing_mask_Y,+(tmp99142,tmp58606)) + hi = _applyRewrite39(hi); // +(M9347,*(K,*(M40316,f32765))) => +*(M9347,f32765,*(K,M40316)) + hi = _applyRewrite42(hi); // +(M9347,*(*(f32765,K),M40316)) => +*(M9347,f32765,*(K,M40316)) + hi = _applyRewrite43(hi); // +(*(*(K,f32765),M40316),M9347) => +*(M9347,f32765,*(K,M40316)) + } else { + hi = _applyRewrite2(hi); // +(A,0.0) => A + hi = _applyRewrite39(hi); // +(M9347,*(K,*(M40316,f32765))) => +*(M9347,f32765,*(K,M40316)) + hi = _applyRewrite42(hi); // +(M9347,*(*(f32765,K),M40316)) => +*(M9347,f32765,*(K,M40316)) + } + } else { + hi = _applyRewrite2(hi); // +(A,0.0) => A + hi = _applyRewrite39(hi); // +(M9347,*(K,*(M40316,f32765))) => +*(M9347,f32765,*(K,M40316)) + hi = _applyRewrite42(hi); // +(M9347,*(*(f32765,K),M40316)) => +*(M9347,f32765,*(K,M40316)) + } + } else { + hi = _applyRewrite2(hi); // +(A,0.0) => A + hi = _applyRewrite39(hi); // +(M9347,*(K,*(M40316,f32765))) => +*(M9347,f32765,*(K,M40316)) + hi = _applyRewrite42(hi); // +(M9347,*(*(f32765,K),M40316)) => +*(M9347,f32765,*(K,M40316)) + } + } else if ( hi_0.getDataType() == Types.DataType.SCALAR ) { + hi = _applyRewrite3(hi); // +(0.0,A) => A + hi = _applyRewrite11(hi); // +(a,-(A,b)) => +(A,-(a,b)) + hi = _applyRewrite13(hi); // +(a,-(b,A)) => -(+(a,b),A) + } + } + } else if ( (( BinaryOp ) hi ).getOp() == Types.OpOp2.MINUS ) { + if ( hi.getInput().size() == 2 ) { + Hop hi_0 = hi.getInput(0); + Hop hi_1 = hi.getInput(1); + if ( hi_0.getDataType() == Types.DataType.MATRIX ) { + if ( hi_0 instanceof BinaryOp ) { + if ( (( BinaryOp ) hi_0 ).getOp() == Types.OpOp2.MINUS ) { + if ( hi_0.getInput().size() == 2 ) { + Hop hi_0_0 = hi_0.getInput(0); + Hop hi_0_1 = hi_0.getInput(1); + hi = _applyRewrite4(hi); // -(A,0.0) => A + hi = _applyRewrite14(hi); // -(-(A,a),b) => -(A,+(b,a)) + hi = _applyRewrite16(hi); // -(-(a,A),b) => -(-(a,b),A) + hi = _applyRewrite29(hi); // -(obj,+(tmp6500,tmp26035)) => -(-(obj,tmp6500),tmp26035) + hi = _applyRewrite30(hi); // -(-(tmp68530,tmp73960),tmp29113) => -(tmp68530,+(tmp73960,tmp29113)) + hi = _applyRewrite40(hi); // -(y,+(%*%(X,B),intercept)) => -(-(y,%*%(X,B)),intercept) + hi = _applyRewrite47(hi); // -(-(f43240,A),f67634) => -(-(f43240,f67634),A) + hi = _applyRewrite51(hi); // -(M22650,*(*(f97734,M67673),M97683)) => -*(M22650,f97734,*(M97683,M67673)) + hi = _applyRewrite52(hi); // -(-(f75306,M67233),*(A,M350)) => -(f75306,+(*(A,M350),M67233)) + hi = _applyRewrite53(hi); // -(-(f75306,*(A,M350)),M67233) => -(f75306,+(*(A,M350),M67233)) + } else { + hi = _applyRewrite4(hi); // -(A,0.0) => A + hi = _applyRewrite29(hi); // -(obj,+(tmp6500,tmp26035)) => -(-(obj,tmp6500),tmp26035) + hi = _applyRewrite40(hi); // -(y,+(%*%(X,B),intercept)) => -(-(y,%*%(X,B)),intercept) + hi = _applyRewrite51(hi); // -(M22650,*(*(f97734,M67673),M97683)) => -*(M22650,f97734,*(M97683,M67673)) + } + } else if ( (( BinaryOp ) hi_0 ).getOp() == Types.OpOp2.PLUS ) { + hi = _applyRewrite28(hi); // -(+(a,tmp82242),tmp98488) => +(-(tmp82242,tmp98488),a) + hi = _applyRewrite4(hi); // -(A,0.0) => A + hi = _applyRewrite29(hi); // -(obj,+(tmp6500,tmp26035)) => -(-(obj,tmp6500),tmp26035) + hi = _applyRewrite40(hi); // -(y,+(%*%(X,B),intercept)) => -(-(y,%*%(X,B)),intercept) + hi = _applyRewrite51(hi); // -(M22650,*(*(f97734,M67673),M97683)) => -*(M22650,f97734,*(M97683,M67673)) + } else { + hi = _applyRewrite4(hi); // -(A,0.0) => A + hi = _applyRewrite29(hi); // -(obj,+(tmp6500,tmp26035)) => -(-(obj,tmp6500),tmp26035) + hi = _applyRewrite40(hi); // -(y,+(%*%(X,B),intercept)) => -(-(y,%*%(X,B)),intercept) + hi = _applyRewrite51(hi); // -(M22650,*(*(f97734,M67673),M97683)) => -*(M22650,f97734,*(M97683,M67673)) + } + } else { + hi = _applyRewrite4(hi); // -(A,0.0) => A + hi = _applyRewrite29(hi); // -(obj,+(tmp6500,tmp26035)) => -(-(obj,tmp6500),tmp26035) + hi = _applyRewrite40(hi); // -(y,+(%*%(X,B),intercept)) => -(-(y,%*%(X,B)),intercept) + hi = _applyRewrite51(hi); // -(M22650,*(*(f97734,M67673),M97683)) => -*(M22650,f97734,*(M97683,M67673)) + } + } else if ( hi_0.getDataType() == Types.DataType.SCALAR ) { + hi = _applyRewrite8(hi); // -(0.0,-(B,A)) => -(A,B) + hi = _applyRewrite15(hi); // -(a,-(A,b)) => -(+(a,b),A) + hi = _applyRewrite17(hi); // -(a,-(b,A)) => +(-(a,b),A) + hi = _applyRewrite21(hi); // -(tmp66496,cast.MATRIX(tmp91996)) => cast.MATRIX(-(tmp66496,tmp91996)) + } + } + } else if ( (( BinaryOp ) hi ).getOp() == Types.OpOp2.MULT ) { + if ( hi.getInput().size() == 2 ) { + Hop hi_0 = hi.getInput(0); + Hop hi_1 = hi.getInput(1); + if ( hi_0.getDataType() == Types.DataType.MATRIX ) { + if ( hi_0 instanceof BinaryOp ) { + if ( (( BinaryOp ) hi_0 ).getOp() == Types.OpOp2.DIV ) { + if ( hi_0.getInput().size() == 2 ) { + Hop hi_0_0 = hi_0.getInput(0); + Hop hi_0_1 = hi_0.getInput(1); + hi = _applyRewrite5(hi); // *(A,0.0) => const(A,0.0) + hi = _applyRewrite9(hi); // *(A,/(1.0,B)) => /(A,B) + hi = _applyRewrite19(hi); // *(/(1.0,tmp5995),tmp41945) => /(tmp41945,tmp5995) + hi = _applyRewrite34(hi); // *(/(1.0,B),a) => /(a,B) + hi = _applyRewrite44(hi); // *(/(1.0,M13119),A) => /(A,M13119) + hi = _applyRewrite49(hi); // *(A,/(1.0,M13119)) => /(A,M13119) + } else { + hi = _applyRewrite5(hi); // *(A,0.0) => const(A,0.0) + hi = _applyRewrite9(hi); // *(A,/(1.0,B)) => /(A,B) + hi = _applyRewrite49(hi); // *(A,/(1.0,M13119)) => /(A,M13119) + } + } else if ( (( BinaryOp ) hi_0 ).getOp() == Types.OpOp2.MULT ) { + hi = _applyRewrite25(hi); // *(*(y_corr,-(float599,is_zero_y_corr)),tmp8608) => *(*(y_corr,tmp8608),-(float599,is_zero_y_corr)) + hi = _applyRewrite5(hi); // *(A,0.0) => const(A,0.0) + hi = _applyRewrite9(hi); // *(A,/(1.0,B)) => /(A,B) + hi = _applyRewrite49(hi); // *(A,/(1.0,M13119)) => /(A,M13119) + } else { + hi = _applyRewrite5(hi); // *(A,0.0) => const(A,0.0) + hi = _applyRewrite9(hi); // *(A,/(1.0,B)) => /(A,B) + hi = _applyRewrite49(hi); // *(A,/(1.0,M13119)) => /(A,M13119) + } + } else if ( hi_0 instanceof AggBinaryOp ) { + hi = _applyRewrite26(hi); // *(%*%(scale_lambda,parsertemp150455),tmp43267) => {%*%(*(tmp43267,scale_lambda),parsertemp150455)} + hi = _applyRewrite5(hi); // *(A,0.0) => const(A,0.0) + hi = _applyRewrite9(hi); // *(A,/(1.0,B)) => /(A,B) + hi = _applyRewrite49(hi); // *(A,/(1.0,M13119)) => /(A,M13119) + } else { + hi = _applyRewrite5(hi); // *(A,0.0) => const(A,0.0) + hi = _applyRewrite9(hi); // *(A,/(1.0,B)) => /(A,B) + hi = _applyRewrite49(hi); // *(A,/(1.0,M13119)) => /(A,M13119) + } + } else if ( hi_0.getDataType() == Types.DataType.SCALAR ) { + hi = _applyRewrite6(hi); // *(0.0,A) => const(A,0.0) + hi = _applyRewrite33(hi); // *(tmp43267,%*%(scale_lambda,parsertemp150455)) => {%*%(*(tmp43267,scale_lambda),parsertemp150455)} + hi = _applyRewrite36(hi); // *(a,cast.MATRIX(b)) => cast.MATRIX(*(a,b)) + hi = _applyRewrite50(hi); // *(f68833,-(0.0,M48693)) => *(M48693,-(0.0,f68833)) + } + } + } else if ( (( BinaryOp ) hi ).getOp() == Types.OpOp2.DIV ) { + hi = _applyRewrite35(hi); // /(a,cast.MATRIX(b)) => cast.MATRIX(/(a,b)) + hi = _applyRewrite45(hi); // /(M43656,2.0) => *(0.5,M43656) + hi = _applyRewrite48(hi); // /(M62235,2000.0) => *(5.0E-4,M62235) + } + } else if ( hi instanceof ReorgOp ) { + hi = _applyRewrite22(hi); // t(==(key_unique,t(key))) => ==(key,t(key_unique)) + } else if ( hi instanceof AggBinaryOp ) { + hi = _applyRewrite24(hi); // %*%(t(X_batch),tmp92007) => {t(%*%(t(tmp92007),X_batch))} + } + } + return hi; + } + + // Implementation of the rule *(0.0,a) => 0.0 + private static Hop _applyRewrite0(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.MULT || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if ( !(hi_0 instanceof LiteralOp) ) + return hi; + + LiteralOp l_hi_0 = (LiteralOp) hi_0; + + if ( l_hi_0.getDataType() != Types.DataType.SCALAR|| !l_hi_0.getValueType().isNumeric() ) + return hi; + + if ( l_hi_0.getDoubleValue() != 0.0 ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if ( hi_1.getDataType() != Types.DataType.SCALAR || !hi_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: 0.0 + + Hop newRoot = hi_0; + if ( hi_0.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_1); + + DMLExecutor.println("Applying rewrite: *(0.0,a) => 0.0"); + return newRoot; + } + + // Implementation of the rule *(a,0.0) => 0.0 + private static Hop _applyRewrite1(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.MULT || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if ( hi_0.getDataType() != Types.DataType.SCALAR || !hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if ( !(hi_1 instanceof LiteralOp) ) + return hi; + + LiteralOp l_hi_1 = (LiteralOp) hi_1; + + if ( l_hi_1.getDataType() != Types.DataType.SCALAR|| !l_hi_1.getValueType().isNumeric() ) + return hi; + + if ( l_hi_1.getDoubleValue() != 0.0 ) + return hi; + + + // Now, we start building the new HOP-DAG: 0.0 + + Hop newRoot = hi_1; + if ( hi_1.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_0); + + DMLExecutor.println("Applying rewrite: *(a,0.0) => 0.0"); + return newRoot; + } + + // Implementation of the rule +(A,0.0) => A + private static Hop _applyRewrite2(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.PLUS || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if ( hi_0.getDataType() != Types.DataType.MATRIX || !hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if ( !(hi_1 instanceof LiteralOp) ) + return hi; + + LiteralOp l_hi_1 = (LiteralOp) hi_1; + + if ( l_hi_1.getDataType() != Types.DataType.SCALAR|| !l_hi_1.getValueType().isNumeric() ) + return hi; + + if ( l_hi_1.getDoubleValue() != 0.0 ) + return hi; + + + // Now, we start building the new HOP-DAG: A + + Hop newRoot = hi_0; + if ( hi_0.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_1); + + DMLExecutor.println("Applying rewrite: +(A,0.0) => A"); + return newRoot; + } + + // Implementation of the rule +(0.0,A) => A + private static Hop _applyRewrite3(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.PLUS || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if ( !(hi_0 instanceof LiteralOp) ) + return hi; + + LiteralOp l_hi_0 = (LiteralOp) hi_0; + + if ( l_hi_0.getDataType() != Types.DataType.SCALAR|| !l_hi_0.getValueType().isNumeric() ) + return hi; + + if ( l_hi_0.getDoubleValue() != 0.0 ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if ( hi_1.getDataType() != Types.DataType.MATRIX || !hi_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: A + + Hop newRoot = hi_1; + if ( hi_1.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_0); + + DMLExecutor.println("Applying rewrite: +(0.0,A) => A"); + return newRoot; + } + + // Implementation of the rule -(A,0.0) => A + private static Hop _applyRewrite4(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.MINUS || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if ( hi_0.getDataType() != Types.DataType.MATRIX || !hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if ( !(hi_1 instanceof LiteralOp) ) + return hi; + + LiteralOp l_hi_1 = (LiteralOp) hi_1; + + if ( l_hi_1.getDataType() != Types.DataType.SCALAR|| !l_hi_1.getValueType().isNumeric() ) + return hi; + + if ( l_hi_1.getDoubleValue() != 0.0 ) + return hi; + + + // Now, we start building the new HOP-DAG: A + + Hop newRoot = hi_0; + if ( hi_0.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_1); + + DMLExecutor.println("Applying rewrite: -(A,0.0) => A"); + return newRoot; + } + + // Implementation of the rule *(A,0.0) => const(A,0.0) + private static Hop _applyRewrite5(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.MULT || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if ( hi_0.getDataType() != Types.DataType.MATRIX || !hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if ( !(hi_1 instanceof LiteralOp) ) + return hi; + + LiteralOp l_hi_1 = (LiteralOp) hi_1; + + if ( l_hi_1.getDataType() != Types.DataType.SCALAR|| !l_hi_1.getValueType().isNumeric() ) + return hi; + + if ( l_hi_1.getDoubleValue() != 0.0 ) + return hi; + + + // Now, we start building the new HOP-DAG: const(A,0.0) + DataGenOp v1 = ((DataGenOp) HopRewriteUtils.createDataGenOpFromDims(HopRewriteUtils.createUnary(hi_0, Types.OpOp1.NROW),HopRewriteUtils.createUnary(hi_0, Types.OpOp1.NCOL),0.0D)); + + Hop newRoot = v1; + if ( v1.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + + DMLExecutor.println("Applying rewrite: *(A,0.0) => const(A,0.0)"); + return newRoot; + } + + // Implementation of the rule *(0.0,A) => const(A,0.0) + private static Hop _applyRewrite6(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.MULT || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if ( !(hi_0 instanceof LiteralOp) ) + return hi; + + LiteralOp l_hi_0 = (LiteralOp) hi_0; + + if ( l_hi_0.getDataType() != Types.DataType.SCALAR|| !l_hi_0.getValueType().isNumeric() ) + return hi; + + if ( l_hi_0.getDoubleValue() != 0.0 ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if ( hi_1.getDataType() != Types.DataType.MATRIX || !hi_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: const(A,0.0) + DataGenOp v1 = ((DataGenOp) HopRewriteUtils.createDataGenOpFromDims(HopRewriteUtils.createUnary(hi_1, Types.OpOp1.NROW),HopRewriteUtils.createUnary(hi_1, Types.OpOp1.NCOL),0.0D)); + + Hop newRoot = v1; + if ( v1.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + + DMLExecutor.println("Applying rewrite: *(0.0,A) => const(A,0.0)"); + return newRoot; + } + + // Implementation of the rule +(-(0.0,A),B) => -(B,A) + private static Hop _applyRewrite7(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.PLUS || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if (hi_0.getParent().size() > 1) + return hi; + if ( !(hi_0 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_0 = (BinaryOp) hi_0; + + if ( c_hi_0.getOp() != Types.OpOp2.MINUS || !c_hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0 = hi_0.getInput(0); + + if ( !(hi_0_0 instanceof LiteralOp) ) + return hi; + + LiteralOp l_hi_0_0 = (LiteralOp) hi_0_0; + + if ( l_hi_0_0.getDataType() != Types.DataType.SCALAR|| !l_hi_0_0.getValueType().isNumeric() ) + return hi; + + if ( l_hi_0_0.getDoubleValue() != 0.0 ) + return hi; + + Hop hi_0_1 = hi_0.getInput(1); + + if ( hi_0_1.getDataType() != Types.DataType.MATRIX || !hi_0_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if ( hi_1.getDataType() != Types.DataType.MATRIX || !hi_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: -(B,A) + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_1, hi_0_1) ) + return hi; + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_1, hi_0_1, Types.OpOp2.MINUS); + + Hop newRoot = v1; + if ( v1.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_0); + HopRewriteUtils.cleanupUnreferenced(hi_0_0); + + DMLExecutor.println("Applying rewrite: +(-(0.0,A),B) => -(B,A)"); + return newRoot; + } + + // Implementation of the rule -(0.0,-(B,A)) => -(A,B) + private static Hop _applyRewrite8(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.MINUS || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if ( !(hi_0 instanceof LiteralOp) ) + return hi; + + LiteralOp l_hi_0 = (LiteralOp) hi_0; + + if ( l_hi_0.getDataType() != Types.DataType.SCALAR|| !l_hi_0.getValueType().isNumeric() ) + return hi; + + if ( l_hi_0.getDoubleValue() != 0.0 ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if (hi_1.getParent().size() > 1) + return hi; + if ( !(hi_1 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_1 = (BinaryOp) hi_1; + + if ( c_hi_1.getOp() != Types.OpOp2.MINUS || !c_hi_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1_0 = hi_1.getInput(0); + + if ( hi_1_0.getDataType() != Types.DataType.MATRIX || !hi_1_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1_1 = hi_1.getInput(1); + + if ( hi_1_1.getDataType() != Types.DataType.MATRIX || !hi_1_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: -(A,B) + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_1_1, hi_1_0) ) + return hi; + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_1_1, hi_1_0, Types.OpOp2.MINUS); + + Hop newRoot = v1; + if ( v1.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_0); + HopRewriteUtils.cleanupUnreferenced(hi_1); + + DMLExecutor.println("Applying rewrite: -(0.0,-(B,A)) => -(A,B)"); + return newRoot; + } + + // Implementation of the rule *(A,/(1.0,B)) => /(A,B) + private static Hop _applyRewrite9(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.MULT || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if ( hi_0.getDataType() != Types.DataType.MATRIX || !hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if (hi_1.getParent().size() > 1) + return hi; + if ( !(hi_1 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_1 = (BinaryOp) hi_1; + + if ( c_hi_1.getOp() != Types.OpOp2.DIV || !c_hi_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1_0 = hi_1.getInput(0); + + if ( !(hi_1_0 instanceof LiteralOp) ) + return hi; + + LiteralOp l_hi_1_0 = (LiteralOp) hi_1_0; + + if ( l_hi_1_0.getDataType() != Types.DataType.SCALAR|| !l_hi_1_0.getValueType().isNumeric() ) + return hi; + + if ( l_hi_1_0.getDoubleValue() != 1.0 ) + return hi; + + Hop hi_1_1 = hi_1.getInput(1); + + if ( hi_1_1.getDataType() != Types.DataType.MATRIX || !hi_1_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: /(A,B) + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_0, hi_1_1) ) + return hi; + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_0, hi_1_1, Types.OpOp2.DIV); + + Hop newRoot = v1; + if ( v1.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_1); + HopRewriteUtils.cleanupUnreferenced(hi_1_0); + + DMLExecutor.println("Applying rewrite: *(A,/(1.0,B)) => /(A,B)"); + return newRoot; + } + + // Implementation of the rule +(-(A,a),b) => +(A,-(b,a)) + private static Hop _applyRewrite10(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.PLUS || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if (hi_0.getParent().size() > 1) + return hi; + if ( !(hi_0 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_0 = (BinaryOp) hi_0; + + if ( c_hi_0.getOp() != Types.OpOp2.MINUS || !c_hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0 = hi_0.getInput(0); + + if ( hi_0_0.getDataType() != Types.DataType.MATRIX || !hi_0_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_1 = hi_0.getInput(1); + + if ( hi_0_1.getDataType() != Types.DataType.SCALAR || !hi_0_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if ( hi_1.getDataType() != Types.DataType.SCALAR || !hi_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: +(A,-(b,a)) + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_1, hi_0_1, Types.OpOp2.MINUS); + BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(hi_0_0, v1, Types.OpOp2.PLUS); + + Hop newRoot = v2; + if ( v2.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_0); + + DMLExecutor.println("Applying rewrite: +(-(A,a),b) => +(A,-(b,a))"); + return newRoot; + } + + // Implementation of the rule +(a,-(A,b)) => +(A,-(a,b)) + private static Hop _applyRewrite11(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.PLUS || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if ( hi_0.getDataType() != Types.DataType.SCALAR || !hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if (hi_1.getParent().size() > 1) + return hi; + if ( !(hi_1 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_1 = (BinaryOp) hi_1; + + if ( c_hi_1.getOp() != Types.OpOp2.MINUS || !c_hi_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1_0 = hi_1.getInput(0); + + if ( hi_1_0.getDataType() != Types.DataType.MATRIX || !hi_1_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1_1 = hi_1.getInput(1); + + if ( hi_1_1.getDataType() != Types.DataType.SCALAR || !hi_1_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: +(A,-(a,b)) + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_0, hi_1_1, Types.OpOp2.MINUS); + BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(hi_1_0, v1, Types.OpOp2.PLUS); + + Hop newRoot = v2; + if ( v2.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_1); + + DMLExecutor.println("Applying rewrite: +(a,-(A,b)) => +(A,-(a,b))"); + return newRoot; + } + + // Implementation of the rule +(-(a,A),b) => -(+(a,b),A) + private static Hop _applyRewrite12(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.PLUS || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if (hi_0.getParent().size() > 1) + return hi; + if ( !(hi_0 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_0 = (BinaryOp) hi_0; + + if ( c_hi_0.getOp() != Types.OpOp2.MINUS || !c_hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0 = hi_0.getInput(0); + + if ( hi_0_0.getDataType() != Types.DataType.SCALAR || !hi_0_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_1 = hi_0.getInput(1); + + if ( hi_0_1.getDataType() != Types.DataType.MATRIX || !hi_0_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if ( hi_1.getDataType() != Types.DataType.SCALAR || !hi_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: -(+(a,b),A) + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_0_0, hi_1, Types.OpOp2.PLUS); + BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(v1, hi_0_1, Types.OpOp2.MINUS); + + Hop newRoot = v2; + if ( v2.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_0); + + DMLExecutor.println("Applying rewrite: +(-(a,A),b) => -(+(a,b),A)"); + return newRoot; + } + + // Implementation of the rule +(a,-(b,A)) => -(+(a,b),A) + private static Hop _applyRewrite13(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.PLUS || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if ( hi_0.getDataType() != Types.DataType.SCALAR || !hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if (hi_1.getParent().size() > 1) + return hi; + if ( !(hi_1 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_1 = (BinaryOp) hi_1; + + if ( c_hi_1.getOp() != Types.OpOp2.MINUS || !c_hi_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1_0 = hi_1.getInput(0); + + if ( hi_1_0.getDataType() != Types.DataType.SCALAR || !hi_1_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1_1 = hi_1.getInput(1); + + if ( hi_1_1.getDataType() != Types.DataType.MATRIX || !hi_1_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: -(+(a,b),A) + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_0, hi_1_0, Types.OpOp2.PLUS); + BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(v1, hi_1_1, Types.OpOp2.MINUS); + + Hop newRoot = v2; + if ( v2.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_1); + + DMLExecutor.println("Applying rewrite: +(a,-(b,A)) => -(+(a,b),A)"); + return newRoot; + } + + // Implementation of the rule -(-(A,a),b) => -(A,+(b,a)) + private static Hop _applyRewrite14(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.MINUS || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if (hi_0.getParent().size() > 1) + return hi; + if ( !(hi_0 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_0 = (BinaryOp) hi_0; + + if ( c_hi_0.getOp() != Types.OpOp2.MINUS || !c_hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0 = hi_0.getInput(0); + + if ( hi_0_0.getDataType() != Types.DataType.MATRIX || !hi_0_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_1 = hi_0.getInput(1); + + if ( hi_0_1.getDataType() != Types.DataType.SCALAR || !hi_0_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if ( hi_1.getDataType() != Types.DataType.SCALAR || !hi_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: -(A,+(b,a)) + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_1, hi_0_1, Types.OpOp2.PLUS); + BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(hi_0_0, v1, Types.OpOp2.MINUS); + + Hop newRoot = v2; + if ( v2.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_0); + + DMLExecutor.println("Applying rewrite: -(-(A,a),b) => -(A,+(b,a))"); + return newRoot; + } + + // Implementation of the rule -(a,-(A,b)) => -(+(a,b),A) + private static Hop _applyRewrite15(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.MINUS || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if ( hi_0.getDataType() != Types.DataType.SCALAR || !hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if (hi_1.getParent().size() > 1) + return hi; + if ( !(hi_1 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_1 = (BinaryOp) hi_1; + + if ( c_hi_1.getOp() != Types.OpOp2.MINUS || !c_hi_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1_0 = hi_1.getInput(0); + + if ( hi_1_0.getDataType() != Types.DataType.MATRIX || !hi_1_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1_1 = hi_1.getInput(1); + + if ( hi_1_1.getDataType() != Types.DataType.SCALAR || !hi_1_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: -(+(a,b),A) + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_0, hi_1_1, Types.OpOp2.PLUS); + BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(v1, hi_1_0, Types.OpOp2.MINUS); + + Hop newRoot = v2; + if ( v2.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_1); + + DMLExecutor.println("Applying rewrite: -(a,-(A,b)) => -(+(a,b),A)"); + return newRoot; + } + + // Implementation of the rule -(-(a,A),b) => -(-(a,b),A) + private static Hop _applyRewrite16(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.MINUS || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if (hi_0.getParent().size() > 1) + return hi; + if ( !(hi_0 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_0 = (BinaryOp) hi_0; + + if ( c_hi_0.getOp() != Types.OpOp2.MINUS || !c_hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0 = hi_0.getInput(0); + + if ( hi_0_0.getDataType() != Types.DataType.SCALAR || !hi_0_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_1 = hi_0.getInput(1); + + if ( hi_0_1.getDataType() != Types.DataType.MATRIX || !hi_0_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if ( hi_1.getDataType() != Types.DataType.SCALAR || !hi_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: -(-(a,b),A) + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_0_0, hi_1, Types.OpOp2.MINUS); + BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(v1, hi_0_1, Types.OpOp2.MINUS); + + Hop newRoot = v2; + if ( v2.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_0); + + DMLExecutor.println("Applying rewrite: -(-(a,A),b) => -(-(a,b),A)"); + return newRoot; + } + + // Implementation of the rule -(a,-(b,A)) => +(-(a,b),A) + private static Hop _applyRewrite17(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.MINUS || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if ( hi_0.getDataType() != Types.DataType.SCALAR || !hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if (hi_1.getParent().size() > 1) + return hi; + if ( !(hi_1 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_1 = (BinaryOp) hi_1; + + if ( c_hi_1.getOp() != Types.OpOp2.MINUS || !c_hi_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1_0 = hi_1.getInput(0); + + if ( hi_1_0.getDataType() != Types.DataType.SCALAR || !hi_1_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1_1 = hi_1.getInput(1); + + if ( hi_1_1.getDataType() != Types.DataType.MATRIX || !hi_1_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: +(-(a,b),A) + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_0, hi_1_0, Types.OpOp2.MINUS); + BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(v1, hi_1_1, Types.OpOp2.PLUS); + + Hop newRoot = v2; + if ( v2.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_1); + + DMLExecutor.println("Applying rewrite: -(a,-(b,A)) => +(-(a,b),A)"); + return newRoot; + } + + // Implementation of the rule +(*(*(y_corr,-(float599,is_zero_y_corr)),tmp8608),*(tmp20367,+(tmp23071,tmp55180))) => +(*(*(tmp8608,y_corr),-(float599,is_zero_y_corr)),*(tmp20367,+(tmp55180,tmp23071))) + private static Hop _applyRewrite18(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.PLUS || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if (hi_0.getParent().size() > 1) + return hi; + if ( !(hi_0 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_0 = (BinaryOp) hi_0; + + if ( c_hi_0.getOp() != Types.OpOp2.MULT || !c_hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0 = hi_0.getInput(0); + + if (hi_0_0.getParent().size() > 1) + return hi; + if ( !(hi_0_0 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_0_0 = (BinaryOp) hi_0_0; + + if ( c_hi_0_0.getOp() != Types.OpOp2.MULT || !c_hi_0_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0_0 = hi_0_0.getInput(0); + + if ( hi_0_0_0.getDataType() != Types.DataType.MATRIX || !hi_0_0_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0_1 = hi_0_0.getInput(1); + + if (hi_0_0_1.getParent().size() > 1) + return hi; + if ( !(hi_0_0_1 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_0_0_1 = (BinaryOp) hi_0_0_1; + + if ( c_hi_0_0_1.getOp() != Types.OpOp2.MINUS || !c_hi_0_0_1.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0_1_0 = hi_0_0_1.getInput(0); + + if ( hi_0_0_1_0.getDataType() != Types.DataType.SCALAR || !hi_0_0_1_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0_1_1 = hi_0_0_1.getInput(1); + + if ( hi_0_0_1_1.getDataType() != Types.DataType.MATRIX || !hi_0_0_1_1.getValueType().isNumeric() ) + return hi; + + Hop hi_0_1 = hi_0.getInput(1); + + if ( hi_0_1.getDataType() != Types.DataType.MATRIX || !hi_0_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if (hi_1.getParent().size() > 1) + return hi; + if ( !(hi_1 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_1 = (BinaryOp) hi_1; + + if ( c_hi_1.getOp() != Types.OpOp2.MULT || !c_hi_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1_0 = hi_1.getInput(0); + + if ( hi_1_0.getDataType() != Types.DataType.SCALAR || !hi_1_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1_1 = hi_1.getInput(1); + + if (hi_1_1.getParent().size() > 1) + return hi; + if ( !(hi_1_1 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_1_1 = (BinaryOp) hi_1_1; + + if ( c_hi_1_1.getOp() != Types.OpOp2.PLUS || !c_hi_1_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1_1_0 = hi_1_1.getInput(0); + + if ( hi_1_1_0.getDataType() != Types.DataType.MATRIX || !hi_1_1_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1_1_1 = hi_1_1.getInput(1); + + if ( hi_1_1_1.getDataType() != Types.DataType.MATRIX || !hi_1_1_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: +(*(*(tmp8608,y_corr),-(float599,is_zero_y_corr)),*(tmp20367,+(tmp55180,tmp23071))) + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_0_1, hi_0_0_0) ) + return hi; + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_0_1, hi_0_0_0, Types.OpOp2.MULT); + BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(hi_0_0_1_0, hi_0_0_1_1, Types.OpOp2.MINUS); + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(v1, v2) ) + return hi; + BinaryOp v3 = HopRewriteUtils.createAutoGeneratedBinary(v1, v2, Types.OpOp2.MULT); + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_1_1_1, hi_1_1_0) ) + return hi; + BinaryOp v4 = HopRewriteUtils.createAutoGeneratedBinary(hi_1_1_1, hi_1_1_0, Types.OpOp2.PLUS); + BinaryOp v5 = HopRewriteUtils.createAutoGeneratedBinary(hi_1_0, v4, Types.OpOp2.MULT); + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(v3, v5) ) + return hi; + BinaryOp v6 = HopRewriteUtils.createAutoGeneratedBinary(v3, v5, Types.OpOp2.PLUS); + + Hop newRoot = v6; + if ( v6.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_0); + HopRewriteUtils.cleanupUnreferenced(hi_0_0); + HopRewriteUtils.cleanupUnreferenced(hi_0_0_1); + HopRewriteUtils.cleanupUnreferenced(hi_1); + HopRewriteUtils.cleanupUnreferenced(hi_1_1); + + DMLExecutor.println("Applying rewrite: +(*(*(y_corr,-(float599,is_zero_y_corr)),tmp8608),*(tmp20367,+(tmp23071,tmp55180))) => +(*(*(tmp8608,y_corr),-(float599,is_zero_y_corr)),*(tmp20367,+(tmp55180,tmp23071)))"); + return newRoot; + } + + // Implementation of the rule *(/(1.0,tmp5995),tmp41945) => /(tmp41945,tmp5995) + private static Hop _applyRewrite19(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.MULT || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if (hi_0.getParent().size() > 1) + return hi; + if ( !(hi_0 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_0 = (BinaryOp) hi_0; + + if ( c_hi_0.getOp() != Types.OpOp2.DIV || !c_hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0 = hi_0.getInput(0); + + if ( !(hi_0_0 instanceof LiteralOp) ) + return hi; + + LiteralOp l_hi_0_0 = (LiteralOp) hi_0_0; + + if ( l_hi_0_0.getDataType() != Types.DataType.SCALAR|| !l_hi_0_0.getValueType().isNumeric() ) + return hi; + + if ( l_hi_0_0.getDoubleValue() != 1.0 ) + return hi; + + Hop hi_0_1 = hi_0.getInput(1); + + if ( hi_0_1.getDataType() != Types.DataType.MATRIX || !hi_0_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if ( hi_1.getDataType() != Types.DataType.SCALAR || !hi_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: /(tmp41945,tmp5995) + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_1, hi_0_1, Types.OpOp2.DIV); + + Hop newRoot = v1; + if ( v1.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_0); + HopRewriteUtils.cleanupUnreferenced(hi_0_0); + + DMLExecutor.println("Applying rewrite: *(/(1.0,tmp5995),tmp41945) => /(tmp41945,tmp5995)"); + return newRoot; + } + + // Implementation of the rule +(-(tmp80035,f12880),tmp63699) => -(+(tmp63699,tmp80035),f12880) + private static Hop _applyRewrite20(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.PLUS || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if (hi_0.getParent().size() > 1) + return hi; + if ( !(hi_0 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_0 = (BinaryOp) hi_0; + + if ( c_hi_0.getOp() != Types.OpOp2.MINUS || !c_hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0 = hi_0.getInput(0); + + if ( hi_0_0.getDataType() != Types.DataType.MATRIX || !hi_0_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_1 = hi_0.getInput(1); + + if ( hi_0_1.getDataType() != Types.DataType.SCALAR || !hi_0_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if ( hi_1.getDataType() != Types.DataType.MATRIX || !hi_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: -(+(tmp63699,tmp80035),f12880) + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_1, hi_0_0) ) + return hi; + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_1, hi_0_0, Types.OpOp2.PLUS); + BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(v1, hi_0_1, Types.OpOp2.MINUS); + + Hop newRoot = v2; + if ( v2.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_0); + + DMLExecutor.println("Applying rewrite: +(-(tmp80035,f12880),tmp63699) => -(+(tmp63699,tmp80035),f12880)"); + return newRoot; + } + + // Implementation of the rule -(tmp66496,cast.MATRIX(tmp91996)) => cast.MATRIX(-(tmp66496,tmp91996)) + private static Hop _applyRewrite21(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.MINUS || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if ( hi_0.getDataType() != Types.DataType.SCALAR || !hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if (hi_1.getParent().size() > 1) + return hi; + if ( !(hi_1 instanceof UnaryOp) ) + return hi; + + UnaryOp c_hi_1 = (UnaryOp) hi_1; + + if ( c_hi_1.getOp() != Types.OpOp1.CAST_AS_MATRIX || !c_hi_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1_0 = hi_1.getInput(0); + + if ( hi_1_0.getDataType() != Types.DataType.SCALAR || !hi_1_0.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: cast.MATRIX(-(tmp66496,tmp91996)) + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_0, hi_1_0, Types.OpOp2.MINUS); + UnaryOp v2 = HopRewriteUtils.createUnary(v1, Types.OpOp1.CAST_AS_MATRIX); + + Hop newRoot = v2; + if ( v2.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_1); + + DMLExecutor.println("Applying rewrite: -(tmp66496,cast.MATRIX(tmp91996)) => cast.MATRIX(-(tmp66496,tmp91996))"); + return newRoot; + } + + // Implementation of the rule t(==(key_unique,t(key))) => ==(key,t(key_unique)) + private static Hop _applyRewrite22(Hop hi) { + if ( !(hi instanceof ReorgOp) ) + return hi; + + ReorgOp c_hi = (ReorgOp) hi; + + if ( c_hi.getOp() != Types.ReOrgOp.TRANS || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if (hi_0.getParent().size() > 1) + return hi; + if ( !(hi_0 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_0 = (BinaryOp) hi_0; + + if ( c_hi_0.getOp() != Types.OpOp2.EQUAL || !c_hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0 = hi_0.getInput(0); + + if ( hi_0_0.getDataType() != Types.DataType.MATRIX || !hi_0_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_1 = hi_0.getInput(1); + + if (hi_0_1.getParent().size() > 1) + return hi; + if ( !(hi_0_1 instanceof ReorgOp) ) + return hi; + + ReorgOp c_hi_0_1 = (ReorgOp) hi_0_1; + + if ( c_hi_0_1.getOp() != Types.ReOrgOp.TRANS || !c_hi_0_1.getValueType().isNumeric() ) + return hi; + + Hop hi_0_1_0 = hi_0_1.getInput(0); + + if ( hi_0_1_0.getDataType() != Types.DataType.MATRIX || !hi_0_1_0.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: ==(key,t(key_unique)) + ReorgOp v1 = HopRewriteUtils.createTranspose(hi_0_0); + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_0_1_0, v1) ) + return hi; + BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(hi_0_1_0, v1, Types.OpOp2.EQUAL); + + Hop newRoot = v2; + if ( v2.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_0); + HopRewriteUtils.cleanupUnreferenced(hi_0_1); + + DMLExecutor.println("Applying rewrite: t(==(key_unique,t(key))) => ==(key,t(key_unique))"); + return newRoot; + } + + // Implementation of the rule sum(/(tmp83271,tmp60732)) => /(sum(tmp83271),tmp60732) + private static Hop _applyRewrite23(Hop hi) { + if ( !(hi instanceof AggUnaryOp) ) + return hi; + + AggUnaryOp c_hi = (AggUnaryOp) hi; + + if ( c_hi.getOp() != Types.AggOp.SUM || !c_hi.getValueType().isNumeric() ) + return hi; + + if ( !(c_hi.getDirection() == Types.Direction.RowCol) ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if (hi_0.getParent().size() > 1) + return hi; + if ( !(hi_0 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_0 = (BinaryOp) hi_0; + + if ( c_hi_0.getOp() != Types.OpOp2.DIV || !c_hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0 = hi_0.getInput(0); + + if ( hi_0_0.getDataType() != Types.DataType.MATRIX || !hi_0_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_1 = hi_0.getInput(1); + + if ( hi_0_1.getDataType() != Types.DataType.SCALAR || !hi_0_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: /(sum(tmp83271),tmp60732) + AggUnaryOp v1 = HopRewriteUtils.createAggUnaryOp(hi_0_0, Types.AggOp.SUM, Types.Direction.RowCol); + BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(v1, hi_0_1, Types.OpOp2.DIV); + + Hop newRoot = v2; + if ( v2.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_0); + + DMLExecutor.println("Applying rewrite: sum(/(tmp83271,tmp60732)) => /(sum(tmp83271),tmp60732)"); + return newRoot; + } + + // Implementation of the rule %*%(t(X_batch),tmp92007) => {t(%*%(t(tmp92007),X_batch))} + private static Hop _applyRewrite24(Hop hi) { + if ( !HopRewriteUtils.isMatrixMultiply(hi) ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if (hi_0.getParent().size() > 1) + return hi; + if ( !(hi_0 instanceof ReorgOp) ) + return hi; + + ReorgOp c_hi_0 = (ReorgOp) hi_0; + + if ( c_hi_0.getOp() != Types.ReOrgOp.TRANS || !c_hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0 = hi_0.getInput(0); + + if ( hi_0_0.getDataType() != Types.DataType.MATRIX || !hi_0_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if ( hi_1.getDataType() != Types.DataType.MATRIX || !hi_1.getValueType().isNumeric() ) + return hi; + + + if ( hi_1.getDim2() == -1 || hi_1.getNnz() == -1 || hi_0_0.getNnz() == -1 || hi_0_0.getDim2() == -1 || hi_1.getDim1() == -1 ) + return hi; + + + double[] costs = new double[2]; + costs[0] = (hi_0_0.getNnz() + (Math.min(hi_0_0.getNnz(), hi_1.getNnz()) * hi_1.getDim1() * 3.0) + 20020.0); + costs[1] = (hi_1.getNnz() + (Math.min(hi_1.getNnz(), hi_0_0.getNnz()) * hi_1.getDim1() * 3.0) + (Math.min((hi_1.getNnz() * (1.0 / hi_1.getDim2())), 1.0) * Math.min((hi_0_0.getNnz() * (1.0 / hi_0_0.getDim2())), 1.0) * hi_1.getDim2() * hi_0_0.getDim2()) + 30030.0); + int minIdx = minIdx(costs); + + switch( minIdx ) { + case 1: { + // Now, we start building the new HOP-DAG: t(%*%(t(tmp92007),X_batch)) + ReorgOp v1 = HopRewriteUtils.createTranspose(hi_1); + AggBinaryOp v2 = HopRewriteUtils.createMatrixMultiply(v1, hi_0_0); + ReorgOp v3 = HopRewriteUtils.createTranspose(v2); + + Hop newRoot = v3; + if ( v3.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_0); + + DMLExecutor.println("Applying rewrite: %*%(t(X_batch),tmp92007) => {t(%*%(t(tmp92007),X_batch))}"); + return newRoot; + } + } + return hi; + } + + // Implementation of the rule *(*(y_corr,-(float599,is_zero_y_corr)),tmp8608) => *(*(y_corr,tmp8608),-(float599,is_zero_y_corr)) + private static Hop _applyRewrite25(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.MULT || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if (hi_0.getParent().size() > 1) + return hi; + if ( !(hi_0 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_0 = (BinaryOp) hi_0; + + if ( c_hi_0.getOp() != Types.OpOp2.MULT || !c_hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0 = hi_0.getInput(0); + + if ( hi_0_0.getDataType() != Types.DataType.MATRIX || !hi_0_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_1 = hi_0.getInput(1); + + if (hi_0_1.getParent().size() > 1) + return hi; + if ( !(hi_0_1 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_0_1 = (BinaryOp) hi_0_1; + + if ( c_hi_0_1.getOp() != Types.OpOp2.MINUS || !c_hi_0_1.getValueType().isNumeric() ) + return hi; + + Hop hi_0_1_0 = hi_0_1.getInput(0); + + if ( hi_0_1_0.getDataType() != Types.DataType.SCALAR || !hi_0_1_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_1_1 = hi_0_1.getInput(1); + + if ( hi_0_1_1.getDataType() != Types.DataType.MATRIX || !hi_0_1_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if ( hi_1.getDataType() != Types.DataType.MATRIX || !hi_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: *(*(y_corr,tmp8608),-(float599,is_zero_y_corr)) + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_0_0, hi_1) ) + return hi; + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_0_0, hi_1, Types.OpOp2.MULT); + BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(hi_0_1_0, hi_0_1_1, Types.OpOp2.MINUS); + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(v1, v2) ) + return hi; + BinaryOp v3 = HopRewriteUtils.createAutoGeneratedBinary(v1, v2, Types.OpOp2.MULT); + + Hop newRoot = v3; + if ( v3.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_0); + HopRewriteUtils.cleanupUnreferenced(hi_0_1); + + DMLExecutor.println("Applying rewrite: *(*(y_corr,-(float599,is_zero_y_corr)),tmp8608) => *(*(y_corr,tmp8608),-(float599,is_zero_y_corr))"); + return newRoot; + } + + // Implementation of the rule *(%*%(scale_lambda,parsertemp150455),tmp43267) => {%*%(*(tmp43267,scale_lambda),parsertemp150455)} + private static Hop _applyRewrite26(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.MULT || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if (hi_0.getParent().size() > 1) + return hi; + if ( !HopRewriteUtils.isMatrixMultiply(hi_0) ) + return hi; + + Hop hi_0_0 = hi_0.getInput(0); + + if ( hi_0_0.getDataType() != Types.DataType.MATRIX || !hi_0_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_1 = hi_0.getInput(1); + + if ( hi_0_1.getDataType() != Types.DataType.MATRIX || !hi_0_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if ( hi_1.getDataType() != Types.DataType.SCALAR || !hi_1.getValueType().isNumeric() ) + return hi; + + + if ( hi_0_0.getDim1() == -1 || hi_0_1.getDim2() == -1 || hi_0_1.getNnz() == -1 || hi_0_0.getNnz() == -1 || hi_0_1.getDim1() == -1 ) + return hi; + + + double[] costs = new double[2]; + costs[0] = ((Math.min(hi_0_0.getNnz(), hi_0_1.getNnz()) * hi_0_1.getDim1() * 3.0) + (2.0 * (Math.min((hi_0_0.getNnz() * (1.0 / hi_0_0.getDim1())), 1.0) * Math.min((hi_0_1.getNnz() * (1.0 / hi_0_1.getDim2())), 1.0) * hi_0_0.getDim1() * hi_0_1.getDim2())) + 20020.0); + costs[1] = ((2.0 * hi_0_0.getNnz()) + (Math.min(hi_0_0.getNnz(), hi_0_1.getNnz()) * hi_0_1.getDim1() * 3.0) + 20020.0); + int minIdx = minIdx(costs); + + switch( minIdx ) { + case 1: { + // Now, we start building the new HOP-DAG: %*%(*(tmp43267,scale_lambda),parsertemp150455) + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_1, hi_0_0, Types.OpOp2.MULT); + AggBinaryOp v2 = HopRewriteUtils.createMatrixMultiply(v1, hi_0_1); + + Hop newRoot = v2; + if ( v2.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_0); + + DMLExecutor.println("Applying rewrite: *(%*%(scale_lambda,parsertemp150455),tmp43267) => {%*%(*(tmp43267,scale_lambda),parsertemp150455)}"); + return newRoot; + } + } + return hi; + } + + // Implementation of the rule sum(*(*(tmp8790,tmp30390),tmp97178)) => *(tmp30390,sum(*(tmp97178,tmp8790))) + private static Hop _applyRewrite27(Hop hi) { + if ( !(hi instanceof AggUnaryOp) ) + return hi; + + AggUnaryOp c_hi = (AggUnaryOp) hi; + + if ( c_hi.getOp() != Types.AggOp.SUM || !c_hi.getValueType().isNumeric() ) + return hi; + + if ( !(c_hi.getDirection() == Types.Direction.RowCol) ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if (hi_0.getParent().size() > 1) + return hi; + if ( !(hi_0 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_0 = (BinaryOp) hi_0; + + if ( c_hi_0.getOp() != Types.OpOp2.MULT || !c_hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0 = hi_0.getInput(0); + + if (hi_0_0.getParent().size() > 1) + return hi; + if ( !(hi_0_0 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_0_0 = (BinaryOp) hi_0_0; + + if ( c_hi_0_0.getOp() != Types.OpOp2.MULT || !c_hi_0_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0_0 = hi_0_0.getInput(0); + + if ( hi_0_0_0.getDataType() != Types.DataType.MATRIX || !hi_0_0_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0_1 = hi_0_0.getInput(1); + + if ( hi_0_0_1.getDataType() != Types.DataType.SCALAR || !hi_0_0_1.getValueType().isNumeric() ) + return hi; + + Hop hi_0_1 = hi_0.getInput(1); + + if ( hi_0_1.getDataType() != Types.DataType.MATRIX || !hi_0_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: *(tmp30390,sum(*(tmp97178,tmp8790))) + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_0_1, hi_0_0_0) ) + return hi; + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_0_1, hi_0_0_0, Types.OpOp2.MULT); + AggUnaryOp v2 = HopRewriteUtils.createAggUnaryOp(v1, Types.AggOp.SUM, Types.Direction.RowCol); + BinaryOp v3 = HopRewriteUtils.createAutoGeneratedBinary(hi_0_0_1, v2, Types.OpOp2.MULT); + + Hop newRoot = v3; + if ( v3.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_0); + HopRewriteUtils.cleanupUnreferenced(hi_0_0); + + DMLExecutor.println("Applying rewrite: sum(*(*(tmp8790,tmp30390),tmp97178)) => *(tmp30390,sum(*(tmp97178,tmp8790)))"); + return newRoot; + } + + // Implementation of the rule -(+(a,tmp82242),tmp98488) => +(-(tmp82242,tmp98488),a) + private static Hop _applyRewrite28(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.MINUS || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if (hi_0.getParent().size() > 1) + return hi; + if ( !(hi_0 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_0 = (BinaryOp) hi_0; + + if ( c_hi_0.getOp() != Types.OpOp2.PLUS || !c_hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0 = hi_0.getInput(0); + + if ( hi_0_0.getDataType() != Types.DataType.SCALAR || !hi_0_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_1 = hi_0.getInput(1); + + if ( hi_0_1.getDataType() != Types.DataType.MATRIX || !hi_0_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if ( hi_1.getDataType() != Types.DataType.MATRIX || !hi_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: +(-(tmp82242,tmp98488),a) + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_0_1, hi_1) ) + return hi; + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_0_1, hi_1, Types.OpOp2.MINUS); + BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(v1, hi_0_0, Types.OpOp2.PLUS); + + Hop newRoot = v2; + if ( v2.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_0); + + DMLExecutor.println("Applying rewrite: -(+(a,tmp82242),tmp98488) => +(-(tmp82242,tmp98488),a)"); + return newRoot; + } + + // Implementation of the rule -(obj,+(tmp6500,tmp26035)) => -(-(obj,tmp6500),tmp26035) + private static Hop _applyRewrite29(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.MINUS || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if ( hi_0.getDataType() != Types.DataType.MATRIX || !hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if (hi_1.getParent().size() > 1) + return hi; + if ( !(hi_1 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_1 = (BinaryOp) hi_1; + + if ( c_hi_1.getOp() != Types.OpOp2.PLUS || !c_hi_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1_0 = hi_1.getInput(0); + + if ( hi_1_0.getDataType() != Types.DataType.MATRIX || !hi_1_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1_1 = hi_1.getInput(1); + + if ( hi_1_1.getDataType() != Types.DataType.SCALAR || !hi_1_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: -(-(obj,tmp6500),tmp26035) + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_0, hi_1_0) ) + return hi; + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_0, hi_1_0, Types.OpOp2.MINUS); + BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(v1, hi_1_1, Types.OpOp2.MINUS); + + Hop newRoot = v2; + if ( v2.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_1); + + DMLExecutor.println("Applying rewrite: -(obj,+(tmp6500,tmp26035)) => -(-(obj,tmp6500),tmp26035)"); + return newRoot; + } + + // Implementation of the rule -(-(tmp68530,tmp73960),tmp29113) => -(tmp68530,+(tmp73960,tmp29113)) + private static Hop _applyRewrite30(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.MINUS || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if (hi_0.getParent().size() > 1) + return hi; + if ( !(hi_0 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_0 = (BinaryOp) hi_0; + + if ( c_hi_0.getOp() != Types.OpOp2.MINUS || !c_hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0 = hi_0.getInput(0); + + if ( hi_0_0.getDataType() != Types.DataType.SCALAR || !hi_0_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_1 = hi_0.getInput(1); + + if ( hi_0_1.getDataType() != Types.DataType.MATRIX || !hi_0_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if ( hi_1.getDataType() != Types.DataType.MATRIX || !hi_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: -(tmp68530,+(tmp73960,tmp29113)) + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_0_1, hi_1) ) + return hi; + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_0_1, hi_1, Types.OpOp2.PLUS); + BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(hi_0_0, v1, Types.OpOp2.MINUS); + + Hop newRoot = v2; + if ( v2.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_0); + + DMLExecutor.println("Applying rewrite: -(-(tmp68530,tmp73960),tmp29113) => -(tmp68530,+(tmp73960,tmp29113))"); + return newRoot; + } + + // Implementation of the rule +(-(a,tmp98488),tmp82242) => +(-(tmp82242,tmp98488),a) + private static Hop _applyRewrite31(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.PLUS || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if (hi_0.getParent().size() > 1) + return hi; + if ( !(hi_0 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_0 = (BinaryOp) hi_0; + + if ( c_hi_0.getOp() != Types.OpOp2.MINUS || !c_hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0 = hi_0.getInput(0); + + if ( hi_0_0.getDataType() != Types.DataType.SCALAR || !hi_0_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_1 = hi_0.getInput(1); + + if ( hi_0_1.getDataType() != Types.DataType.MATRIX || !hi_0_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if ( hi_1.getDataType() != Types.DataType.MATRIX || !hi_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: +(-(tmp82242,tmp98488),a) + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_1, hi_0_1) ) + return hi; + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_1, hi_0_1, Types.OpOp2.MINUS); + BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(v1, hi_0_0, Types.OpOp2.PLUS); + + Hop newRoot = v2; + if ( v2.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_0); + + DMLExecutor.println("Applying rewrite: +(-(a,tmp98488),tmp82242) => +(-(tmp82242,tmp98488),a)"); + return newRoot; + } + + // Implementation of the rule +(*(tmp99142,missing_mask_Y),*(tmp58606,missing_mask_Y)) => *(missing_mask_Y,+(tmp99142,tmp58606)) + private static Hop _applyRewrite32(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.PLUS || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if (hi_0.getParent().size() > 1) + return hi; + if ( !(hi_0 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_0 = (BinaryOp) hi_0; + + if ( c_hi_0.getOp() != Types.OpOp2.MULT || !c_hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0 = hi_0.getInput(0); + + if ( hi_0_0.getDataType() != Types.DataType.SCALAR || !hi_0_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_1 = hi_0.getInput(1); + + if ( hi_0_1.getDataType() != Types.DataType.MATRIX || !hi_0_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if (hi_1.getParent().size() > 1) + return hi; + if ( !(hi_1 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_1 = (BinaryOp) hi_1; + + if ( c_hi_1.getOp() != Types.OpOp2.MULT || !c_hi_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1_0 = hi_1.getInput(0); + + if ( hi_1_0.getDataType() != Types.DataType.SCALAR || !hi_1_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1_1 = hi_1.getInput(1); + + if ( hi_0_1 != hi_1_1 ) + return hi; + + + // Now, we start building the new HOP-DAG: *(missing_mask_Y,+(tmp99142,tmp58606)) + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_0_0, hi_1_0, Types.OpOp2.PLUS); + BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(hi_0_1, v1, Types.OpOp2.MULT); + + Hop newRoot = v2; + if ( v2.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_0); + HopRewriteUtils.cleanupUnreferenced(hi_1); + + DMLExecutor.println("Applying rewrite: +(*(tmp99142,missing_mask_Y),*(tmp58606,missing_mask_Y)) => *(missing_mask_Y,+(tmp99142,tmp58606))"); + return newRoot; + } + + // Implementation of the rule *(tmp43267,%*%(scale_lambda,parsertemp150455)) => {%*%(*(tmp43267,scale_lambda),parsertemp150455)} + private static Hop _applyRewrite33(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.MULT || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if ( hi_0.getDataType() != Types.DataType.SCALAR || !hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if (hi_1.getParent().size() > 1) + return hi; + if ( !HopRewriteUtils.isMatrixMultiply(hi_1) ) + return hi; + + Hop hi_1_0 = hi_1.getInput(0); + + if ( hi_1_0.getDataType() != Types.DataType.MATRIX || !hi_1_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1_1 = hi_1.getInput(1); + + if ( hi_1_1.getDataType() != Types.DataType.MATRIX || !hi_1_1.getValueType().isNumeric() ) + return hi; + + + if ( hi_1_0.getNnz() == -1 || hi_1_1.getDim2() == -1 || hi_1_0.getDim1() == -1 || hi_1_0.getDim2() == -1 || hi_1_1.getNnz() == -1 ) + return hi; + + + double[] costs = new double[2]; + costs[0] = ((Math.min(hi_1_0.getNnz(), hi_1_1.getNnz()) * hi_1_0.getDim2() * 3.0) + (2.0 * (Math.min((hi_1_0.getNnz() * (1.0 / hi_1_0.getDim1())), 1.0) * Math.min((hi_1_1.getNnz() * (1.0 / hi_1_1.getDim2())), 1.0) * hi_1_0.getDim1() * hi_1_1.getDim2())) + 20020.0); + costs[1] = ((2.0 * hi_1_0.getNnz()) + (Math.min(hi_1_0.getNnz(), hi_1_1.getNnz()) * hi_1_0.getDim2() * 3.0) + 20020.0); + int minIdx = minIdx(costs); + + switch( minIdx ) { + case 1: { + // Now, we start building the new HOP-DAG: %*%(*(tmp43267,scale_lambda),parsertemp150455) + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_0, hi_1_0, Types.OpOp2.MULT); + AggBinaryOp v2 = HopRewriteUtils.createMatrixMultiply(v1, hi_1_1); + + Hop newRoot = v2; + if ( v2.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_1); + + DMLExecutor.println("Applying rewrite: *(tmp43267,%*%(scale_lambda,parsertemp150455)) => {%*%(*(tmp43267,scale_lambda),parsertemp150455)}"); + return newRoot; + } + } + return hi; + } + + // Implementation of the rule *(/(1.0,B),a) => /(a,B) + private static Hop _applyRewrite34(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.MULT || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if (hi_0.getParent().size() > 1) + return hi; + if ( !(hi_0 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_0 = (BinaryOp) hi_0; + + if ( c_hi_0.getOp() != Types.OpOp2.DIV || !c_hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0 = hi_0.getInput(0); + + if ( !(hi_0_0 instanceof LiteralOp) ) + return hi; + + LiteralOp l_hi_0_0 = (LiteralOp) hi_0_0; + + if ( l_hi_0_0.getDataType() != Types.DataType.SCALAR|| !l_hi_0_0.getValueType().isNumeric() ) + return hi; + + if ( l_hi_0_0.getDoubleValue() != 1.0 ) + return hi; + + Hop hi_0_1 = hi_0.getInput(1); + + if ( hi_0_1.getDataType() != Types.DataType.MATRIX || !hi_0_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if ( hi_1.getDataType() != Types.DataType.SCALAR || !hi_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: /(a,B) + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_1, hi_0_1, Types.OpOp2.DIV); + + Hop newRoot = v1; + if ( v1.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_0); + HopRewriteUtils.cleanupUnreferenced(hi_0_0); + + DMLExecutor.println("Applying rewrite: *(/(1.0,B),a) => /(a,B)"); + return newRoot; + } + + // Implementation of the rule /(a,cast.MATRIX(b)) => cast.MATRIX(/(a,b)) + private static Hop _applyRewrite35(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.DIV || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if ( hi_0.getDataType() != Types.DataType.SCALAR || !hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if (hi_1.getParent().size() > 1) + return hi; + if ( !(hi_1 instanceof UnaryOp) ) + return hi; + + UnaryOp c_hi_1 = (UnaryOp) hi_1; + + if ( c_hi_1.getOp() != Types.OpOp1.CAST_AS_MATRIX || !c_hi_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1_0 = hi_1.getInput(0); + + if ( hi_1_0.getDataType() != Types.DataType.SCALAR || !hi_1_0.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: cast.MATRIX(/(a,b)) + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_0, hi_1_0, Types.OpOp2.DIV); + UnaryOp v2 = HopRewriteUtils.createUnary(v1, Types.OpOp1.CAST_AS_MATRIX); + + Hop newRoot = v2; + if ( v2.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_1); + + DMLExecutor.println("Applying rewrite: /(a,cast.MATRIX(b)) => cast.MATRIX(/(a,b))"); + return newRoot; + } + + // Implementation of the rule *(a,cast.MATRIX(b)) => cast.MATRIX(*(a,b)) + private static Hop _applyRewrite36(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.MULT || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if ( hi_0.getDataType() != Types.DataType.SCALAR || !hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if (hi_1.getParent().size() > 1) + return hi; + if ( !(hi_1 instanceof UnaryOp) ) + return hi; + + UnaryOp c_hi_1 = (UnaryOp) hi_1; + + if ( c_hi_1.getOp() != Types.OpOp1.CAST_AS_MATRIX || !c_hi_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1_0 = hi_1.getInput(0); + + if ( hi_1_0.getDataType() != Types.DataType.SCALAR || !hi_1_0.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: cast.MATRIX(*(a,b)) + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_0, hi_1_0, Types.OpOp2.MULT); + UnaryOp v2 = HopRewriteUtils.createUnary(v1, Types.OpOp1.CAST_AS_MATRIX); + + Hop newRoot = v2; + if ( v2.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_1); + + DMLExecutor.println("Applying rewrite: *(a,cast.MATRIX(b)) => cast.MATRIX(*(a,b))"); + return newRoot; + } + + // Implementation of the rule +(-(*(C,b),d),A) => -(+*(A,b,C),d) + private static Hop _applyRewrite37(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.PLUS || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if (hi_0.getParent().size() > 1) + return hi; + if ( !(hi_0 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_0 = (BinaryOp) hi_0; + + if ( c_hi_0.getOp() != Types.OpOp2.MINUS || !c_hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0 = hi_0.getInput(0); + + if (hi_0_0.getParent().size() > 1) + return hi; + if ( !(hi_0_0 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_0_0 = (BinaryOp) hi_0_0; + + if ( c_hi_0_0.getOp() != Types.OpOp2.MULT || !c_hi_0_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0_0 = hi_0_0.getInput(0); + + if ( hi_0_0_0.getDataType() != Types.DataType.MATRIX || !hi_0_0_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0_1 = hi_0_0.getInput(1); + + if ( hi_0_0_1.getDataType() != Types.DataType.SCALAR || !hi_0_0_1.getValueType().isNumeric() ) + return hi; + + Hop hi_0_1 = hi_0.getInput(1); + + if ( hi_0_1.getDataType() != Types.DataType.SCALAR || !hi_0_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if ( hi_1.getDataType() != Types.DataType.MATRIX || !hi_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: -(+*(A,b,C),d) + if ( !RewriterRuntimeUtils.hasMatchingDims(hi_1, hi_0_0_0) ) + return hi; + TernaryOp v1 = HopRewriteUtils.createTernary(hi_1, hi_0_0_1, hi_0_0_0,Types.OpOp3.PLUS_MULT); + BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(v1, hi_0_1, Types.OpOp2.MINUS); + + Hop newRoot = v2; + if ( v2.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_0); + HopRewriteUtils.cleanupUnreferenced(hi_0_0); + + DMLExecutor.println("Applying rewrite: +(-(*(C,b),d),A) => -(+*(A,b,C),d)"); + return newRoot; + } + + // Implementation of the rule +(-(*(D,c),B),A) => -(A,-*(B,c,D)) + private static Hop _applyRewrite38(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.PLUS || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if (hi_0.getParent().size() > 1) + return hi; + if ( !(hi_0 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_0 = (BinaryOp) hi_0; + + if ( c_hi_0.getOp() != Types.OpOp2.MINUS || !c_hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0 = hi_0.getInput(0); + + if (hi_0_0.getParent().size() > 1) + return hi; + if ( !(hi_0_0 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_0_0 = (BinaryOp) hi_0_0; + + if ( c_hi_0_0.getOp() != Types.OpOp2.MULT || !c_hi_0_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0_0 = hi_0_0.getInput(0); + + if ( hi_0_0_0.getDataType() != Types.DataType.MATRIX || !hi_0_0_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0_1 = hi_0_0.getInput(1); + + if ( hi_0_0_1.getDataType() != Types.DataType.SCALAR || !hi_0_0_1.getValueType().isNumeric() ) + return hi; + + Hop hi_0_1 = hi_0.getInput(1); + + if ( hi_0_1.getDataType() != Types.DataType.MATRIX || !hi_0_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if ( hi_1.getDataType() != Types.DataType.MATRIX || !hi_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: -(A,-*(B,c,D)) + if ( !RewriterRuntimeUtils.hasMatchingDims(hi_0_1, hi_0_0_0) ) + return hi; + TernaryOp v1 = HopRewriteUtils.createTernary(hi_0_1, hi_0_0_1, hi_0_0_0,Types.OpOp3.MINUS_MULT); + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_1, v1) ) + return hi; + BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(hi_1, v1, Types.OpOp2.MINUS); + + Hop newRoot = v2; + if ( v2.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_0); + HopRewriteUtils.cleanupUnreferenced(hi_0_0); + + DMLExecutor.println("Applying rewrite: +(-(*(D,c),B),A) => -(A,-*(B,c,D))"); + return newRoot; + } + + // Implementation of the rule +(M9347,*(K,*(M40316,f32765))) => +*(M9347,f32765,*(K,M40316)) + private static Hop _applyRewrite39(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.PLUS || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if ( hi_0.getDataType() != Types.DataType.MATRIX || !hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if (hi_1.getParent().size() > 1) + return hi; + if ( !(hi_1 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_1 = (BinaryOp) hi_1; + + if ( c_hi_1.getOp() != Types.OpOp2.MULT || !c_hi_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1_0 = hi_1.getInput(0); + + if ( hi_1_0.getDataType() != Types.DataType.MATRIX || !hi_1_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1_1 = hi_1.getInput(1); + + if (hi_1_1.getParent().size() > 1) + return hi; + if ( !(hi_1_1 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_1_1 = (BinaryOp) hi_1_1; + + if ( c_hi_1_1.getOp() != Types.OpOp2.MULT || !c_hi_1_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1_1_0 = hi_1_1.getInput(0); + + if ( hi_1_1_0.getDataType() != Types.DataType.MATRIX || !hi_1_1_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1_1_1 = hi_1_1.getInput(1); + + if ( hi_1_1_1.getDataType() != Types.DataType.SCALAR || !hi_1_1_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: +*(M9347,f32765,*(K,M40316)) + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_1_0, hi_1_1_0) ) + return hi; + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_1_0, hi_1_1_0, Types.OpOp2.MULT); + if ( !RewriterRuntimeUtils.hasMatchingDims(hi_0, v1) ) + return hi; + TernaryOp v2 = HopRewriteUtils.createTernary(hi_0, hi_1_1_1, v1,Types.OpOp3.PLUS_MULT); + + Hop newRoot = v2; + if ( v2.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_1); + HopRewriteUtils.cleanupUnreferenced(hi_1_1); + + DMLExecutor.println("Applying rewrite: +(M9347,*(K,*(M40316,f32765))) => +*(M9347,f32765,*(K,M40316))"); + return newRoot; + } + + // Implementation of the rule -(y,+(%*%(X,B),intercept)) => -(-(y,%*%(X,B)),intercept) + private static Hop _applyRewrite40(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.MINUS || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if ( hi_0.getDataType() != Types.DataType.MATRIX || !hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if (hi_1.getParent().size() > 1) + return hi; + if ( !(hi_1 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_1 = (BinaryOp) hi_1; + + if ( c_hi_1.getOp() != Types.OpOp2.PLUS || !c_hi_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1_0 = hi_1.getInput(0); + + if (hi_1_0.getParent().size() > 1) + return hi; + if ( !HopRewriteUtils.isMatrixMultiply(hi_1_0) ) + return hi; + + Hop hi_1_0_0 = hi_1_0.getInput(0); + + if ( hi_1_0_0.getDataType() != Types.DataType.MATRIX || !hi_1_0_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1_0_1 = hi_1_0.getInput(1); + + if ( hi_1_0_1.getDataType() != Types.DataType.MATRIX || !hi_1_0_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1_1 = hi_1.getInput(1); + + if ( hi_1_1.getDataType() != Types.DataType.SCALAR || !hi_1_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: -(-(y,%*%(X,B)),intercept) + AggBinaryOp v1 = HopRewriteUtils.createMatrixMultiply(hi_1_0_0, hi_1_0_1); + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_0, v1) ) + return hi; + BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(hi_0, v1, Types.OpOp2.MINUS); + BinaryOp v3 = HopRewriteUtils.createAutoGeneratedBinary(v2, hi_1_1, Types.OpOp2.MINUS); + + Hop newRoot = v3; + if ( v3.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_1); + HopRewriteUtils.cleanupUnreferenced(hi_1_0); + + DMLExecutor.println("Applying rewrite: -(y,+(%*%(X,B),intercept)) => -(-(y,%*%(X,B)),intercept)"); + return newRoot; + } + + // Implementation of the rule +(-(f45081,A),B) => +(f45081,-(B,A)) + private static Hop _applyRewrite41(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.PLUS || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if (hi_0.getParent().size() > 1) + return hi; + if ( !(hi_0 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_0 = (BinaryOp) hi_0; + + if ( c_hi_0.getOp() != Types.OpOp2.MINUS || !c_hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0 = hi_0.getInput(0); + + if ( hi_0_0.getDataType() != Types.DataType.SCALAR || !hi_0_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_1 = hi_0.getInput(1); + + if ( hi_0_1.getDataType() != Types.DataType.MATRIX || !hi_0_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if ( hi_1.getDataType() != Types.DataType.MATRIX || !hi_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: +(f45081,-(B,A)) + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_1, hi_0_1) ) + return hi; + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_1, hi_0_1, Types.OpOp2.MINUS); + BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(hi_0_0, v1, Types.OpOp2.PLUS); + + Hop newRoot = v2; + if ( v2.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_0); + + DMLExecutor.println("Applying rewrite: +(-(f45081,A),B) => +(f45081,-(B,A))"); + return newRoot; + } + + // Implementation of the rule +(M9347,*(*(f32765,K),M40316)) => +*(M9347,f32765,*(K,M40316)) + private static Hop _applyRewrite42(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.PLUS || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if ( hi_0.getDataType() != Types.DataType.MATRIX || !hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if (hi_1.getParent().size() > 1) + return hi; + if ( !(hi_1 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_1 = (BinaryOp) hi_1; + + if ( c_hi_1.getOp() != Types.OpOp2.MULT || !c_hi_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1_0 = hi_1.getInput(0); + + if (hi_1_0.getParent().size() > 1) + return hi; + if ( !(hi_1_0 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_1_0 = (BinaryOp) hi_1_0; + + if ( c_hi_1_0.getOp() != Types.OpOp2.MULT || !c_hi_1_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1_0_0 = hi_1_0.getInput(0); + + if ( hi_1_0_0.getDataType() != Types.DataType.SCALAR || !hi_1_0_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1_0_1 = hi_1_0.getInput(1); + + if ( hi_1_0_1.getDataType() != Types.DataType.MATRIX || !hi_1_0_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1_1 = hi_1.getInput(1); + + if ( hi_1_1.getDataType() != Types.DataType.MATRIX || !hi_1_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: +*(M9347,f32765,*(K,M40316)) + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_1_0_1, hi_1_1) ) + return hi; + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_1_0_1, hi_1_1, Types.OpOp2.MULT); + if ( !RewriterRuntimeUtils.hasMatchingDims(hi_0, v1) ) + return hi; + TernaryOp v2 = HopRewriteUtils.createTernary(hi_0, hi_1_0_0, v1,Types.OpOp3.PLUS_MULT); + + Hop newRoot = v2; + if ( v2.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_1); + HopRewriteUtils.cleanupUnreferenced(hi_1_0); + + DMLExecutor.println("Applying rewrite: +(M9347,*(*(f32765,K),M40316)) => +*(M9347,f32765,*(K,M40316))"); + return newRoot; + } + + // Implementation of the rule +(*(*(K,f32765),M40316),M9347) => +*(M9347,f32765,*(K,M40316)) + private static Hop _applyRewrite43(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.PLUS || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if (hi_0.getParent().size() > 1) + return hi; + if ( !(hi_0 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_0 = (BinaryOp) hi_0; + + if ( c_hi_0.getOp() != Types.OpOp2.MULT || !c_hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0 = hi_0.getInput(0); + + if (hi_0_0.getParent().size() > 1) + return hi; + if ( !(hi_0_0 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_0_0 = (BinaryOp) hi_0_0; + + if ( c_hi_0_0.getOp() != Types.OpOp2.MULT || !c_hi_0_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0_0 = hi_0_0.getInput(0); + + if ( hi_0_0_0.getDataType() != Types.DataType.MATRIX || !hi_0_0_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0_1 = hi_0_0.getInput(1); + + if ( hi_0_0_1.getDataType() != Types.DataType.SCALAR || !hi_0_0_1.getValueType().isNumeric() ) + return hi; + + Hop hi_0_1 = hi_0.getInput(1); + + if ( hi_0_1.getDataType() != Types.DataType.MATRIX || !hi_0_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if ( hi_1.getDataType() != Types.DataType.MATRIX || !hi_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: +*(M9347,f32765,*(K,M40316)) + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_0_0_0, hi_0_1) ) + return hi; + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_0_0_0, hi_0_1, Types.OpOp2.MULT); + if ( !RewriterRuntimeUtils.hasMatchingDims(hi_1, v1) ) + return hi; + TernaryOp v2 = HopRewriteUtils.createTernary(hi_1, hi_0_0_1, v1,Types.OpOp3.PLUS_MULT); + + Hop newRoot = v2; + if ( v2.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_0); + HopRewriteUtils.cleanupUnreferenced(hi_0_0); + + DMLExecutor.println("Applying rewrite: +(*(*(K,f32765),M40316),M9347) => +*(M9347,f32765,*(K,M40316))"); + return newRoot; + } + + // Implementation of the rule *(/(1.0,M13119),A) => /(A,M13119) + private static Hop _applyRewrite44(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.MULT || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if (hi_0.getParent().size() > 1) + return hi; + if ( !(hi_0 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_0 = (BinaryOp) hi_0; + + if ( c_hi_0.getOp() != Types.OpOp2.DIV || !c_hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0 = hi_0.getInput(0); + + if ( !(hi_0_0 instanceof LiteralOp) ) + return hi; + + LiteralOp l_hi_0_0 = (LiteralOp) hi_0_0; + + if ( l_hi_0_0.getDataType() != Types.DataType.SCALAR|| !l_hi_0_0.getValueType().isNumeric() ) + return hi; + + if ( l_hi_0_0.getDoubleValue() != 1.0 ) + return hi; + + Hop hi_0_1 = hi_0.getInput(1); + + if ( hi_0_1.getDataType() != Types.DataType.MATRIX || !hi_0_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if ( hi_1.getDataType() != Types.DataType.MATRIX || !hi_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: /(A,M13119) + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_1, hi_0_1) ) + return hi; + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_1, hi_0_1, Types.OpOp2.DIV); + + Hop newRoot = v1; + if ( v1.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_0); + HopRewriteUtils.cleanupUnreferenced(hi_0_0); + + DMLExecutor.println("Applying rewrite: *(/(1.0,M13119),A) => /(A,M13119)"); + return newRoot; + } + + // Implementation of the rule /(M43656,2.0) => *(0.5,M43656) + private static Hop _applyRewrite45(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.DIV || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if ( hi_0.getDataType() != Types.DataType.MATRIX || !hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if ( !(hi_1 instanceof LiteralOp) ) + return hi; + + LiteralOp l_hi_1 = (LiteralOp) hi_1; + + if ( l_hi_1.getDataType() != Types.DataType.SCALAR|| !l_hi_1.getValueType().isNumeric() ) + return hi; + + if ( l_hi_1.getDoubleValue() != 2.0 ) + return hi; + + + // Now, we start building the new HOP-DAG: *(0.5,M43656) + LiteralOp l1 = new LiteralOp( 0.5 ); + BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(l1, hi_0, Types.OpOp2.MULT); + + Hop newRoot = v2; + if ( v2.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_1); + + DMLExecutor.println("Applying rewrite: /(M43656,2.0) => *(0.5,M43656)"); + return newRoot; + } + + // Implementation of the rule +(-(b,%*%(C,D)),A) => +(b,-(A,%*%(C,D))) + private static Hop _applyRewrite46(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.PLUS || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if (hi_0.getParent().size() > 1) + return hi; + if ( !(hi_0 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_0 = (BinaryOp) hi_0; + + if ( c_hi_0.getOp() != Types.OpOp2.MINUS || !c_hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0 = hi_0.getInput(0); + + if ( hi_0_0.getDataType() != Types.DataType.SCALAR || !hi_0_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_1 = hi_0.getInput(1); + + if (hi_0_1.getParent().size() > 1) + return hi; + if ( !HopRewriteUtils.isMatrixMultiply(hi_0_1) ) + return hi; + + Hop hi_0_1_0 = hi_0_1.getInput(0); + + if ( hi_0_1_0.getDataType() != Types.DataType.MATRIX || !hi_0_1_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_1_1 = hi_0_1.getInput(1); + + if ( hi_0_1_1.getDataType() != Types.DataType.MATRIX || !hi_0_1_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if ( hi_1.getDataType() != Types.DataType.MATRIX || !hi_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: +(b,-(A,%*%(C,D))) + AggBinaryOp v1 = HopRewriteUtils.createMatrixMultiply(hi_0_1_0, hi_0_1_1); + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_1, v1) ) + return hi; + BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(hi_1, v1, Types.OpOp2.MINUS); + BinaryOp v3 = HopRewriteUtils.createAutoGeneratedBinary(hi_0_0, v2, Types.OpOp2.PLUS); + + Hop newRoot = v3; + if ( v3.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_0); + HopRewriteUtils.cleanupUnreferenced(hi_0_1); + + DMLExecutor.println("Applying rewrite: +(-(b,%*%(C,D)),A) => +(b,-(A,%*%(C,D)))"); + return newRoot; + } + + // Implementation of the rule -(-(f43240,A),f67634) => -(-(f43240,f67634),A) + private static Hop _applyRewrite47(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.MINUS || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if (hi_0.getParent().size() > 1) + return hi; + if ( !(hi_0 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_0 = (BinaryOp) hi_0; + + if ( c_hi_0.getOp() != Types.OpOp2.MINUS || !c_hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0 = hi_0.getInput(0); + + if ( hi_0_0.getDataType() != Types.DataType.SCALAR || !hi_0_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_1 = hi_0.getInput(1); + + if ( hi_0_1.getDataType() != Types.DataType.MATRIX || !hi_0_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if ( hi_1.getDataType() != Types.DataType.SCALAR || !hi_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: -(-(f43240,f67634),A) + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_0_0, hi_1, Types.OpOp2.MINUS); + BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(v1, hi_0_1, Types.OpOp2.MINUS); + + Hop newRoot = v2; + if ( v2.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_0); + + DMLExecutor.println("Applying rewrite: -(-(f43240,A),f67634) => -(-(f43240,f67634),A)"); + return newRoot; + } + + // Implementation of the rule /(M62235,2000.0) => *(5.0E-4,M62235) + private static Hop _applyRewrite48(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.DIV || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if ( hi_0.getDataType() != Types.DataType.MATRIX || !hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if ( !(hi_1 instanceof LiteralOp) ) + return hi; + + LiteralOp l_hi_1 = (LiteralOp) hi_1; + + if ( l_hi_1.getDataType() != Types.DataType.SCALAR|| !l_hi_1.getValueType().isNumeric() ) + return hi; + + if ( l_hi_1.getDoubleValue() != 2000.0 ) + return hi; + + + // Now, we start building the new HOP-DAG: *(5.0E-4,M62235) + LiteralOp l1 = new LiteralOp( 5.0E-4 ); + BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(l1, hi_0, Types.OpOp2.MULT); + + Hop newRoot = v2; + if ( v2.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_1); + + DMLExecutor.println("Applying rewrite: /(M62235,2000.0) => *(5.0E-4,M62235)"); + return newRoot; + } + + // Implementation of the rule *(A,/(1.0,M13119)) => /(A,M13119) + private static Hop _applyRewrite49(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.MULT || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if ( hi_0.getDataType() != Types.DataType.MATRIX || !hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if (hi_1.getParent().size() > 1) + return hi; + if ( !(hi_1 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_1 = (BinaryOp) hi_1; + + if ( c_hi_1.getOp() != Types.OpOp2.DIV || !c_hi_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1_0 = hi_1.getInput(0); + + if ( !(hi_1_0 instanceof LiteralOp) ) + return hi; + + LiteralOp l_hi_1_0 = (LiteralOp) hi_1_0; + + if ( l_hi_1_0.getDataType() != Types.DataType.SCALAR|| !l_hi_1_0.getValueType().isNumeric() ) + return hi; + + if ( l_hi_1_0.getDoubleValue() != 1.0 ) + return hi; + + Hop hi_1_1 = hi_1.getInput(1); + + if ( hi_1_1.getDataType() != Types.DataType.MATRIX || !hi_1_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: /(A,M13119) + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_0, hi_1_1) ) + return hi; + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_0, hi_1_1, Types.OpOp2.DIV); + + Hop newRoot = v1; + if ( v1.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_1); + HopRewriteUtils.cleanupUnreferenced(hi_1_0); + + DMLExecutor.println("Applying rewrite: *(A,/(1.0,M13119)) => /(A,M13119)"); + return newRoot; + } + + // Implementation of the rule *(f68833,-(0.0,M48693)) => *(M48693,-(0.0,f68833)) + private static Hop _applyRewrite50(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.MULT || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if ( hi_0.getDataType() != Types.DataType.SCALAR || !hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if (hi_1.getParent().size() > 1) + return hi; + if ( !(hi_1 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_1 = (BinaryOp) hi_1; + + if ( c_hi_1.getOp() != Types.OpOp2.MINUS || !c_hi_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1_0 = hi_1.getInput(0); + + if ( !(hi_1_0 instanceof LiteralOp) ) + return hi; + + LiteralOp l_hi_1_0 = (LiteralOp) hi_1_0; + + if ( l_hi_1_0.getDataType() != Types.DataType.SCALAR|| !l_hi_1_0.getValueType().isNumeric() ) + return hi; + + if ( l_hi_1_0.getDoubleValue() != 0.0 ) + return hi; + + Hop hi_1_1 = hi_1.getInput(1); + + if ( hi_1_1.getDataType() != Types.DataType.MATRIX || !hi_1_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: *(M48693,-(0.0,f68833)) + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_1_0, hi_0, Types.OpOp2.MINUS); + BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(hi_1_1, v1, Types.OpOp2.MULT); + + Hop newRoot = v2; + if ( v2.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_1); + + DMLExecutor.println("Applying rewrite: *(f68833,-(0.0,M48693)) => *(M48693,-(0.0,f68833))"); + return newRoot; + } + + // Implementation of the rule -(M22650,*(*(f97734,M67673),M97683)) => -*(M22650,f97734,*(M97683,M67673)) + private static Hop _applyRewrite51(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.MINUS || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if ( hi_0.getDataType() != Types.DataType.MATRIX || !hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if (hi_1.getParent().size() > 1) + return hi; + if ( !(hi_1 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_1 = (BinaryOp) hi_1; + + if ( c_hi_1.getOp() != Types.OpOp2.MULT || !c_hi_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1_0 = hi_1.getInput(0); + + if (hi_1_0.getParent().size() > 1) + return hi; + if ( !(hi_1_0 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_1_0 = (BinaryOp) hi_1_0; + + if ( c_hi_1_0.getOp() != Types.OpOp2.MULT || !c_hi_1_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1_0_0 = hi_1_0.getInput(0); + + if ( hi_1_0_0.getDataType() != Types.DataType.SCALAR || !hi_1_0_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1_0_1 = hi_1_0.getInput(1); + + if ( hi_1_0_1.getDataType() != Types.DataType.MATRIX || !hi_1_0_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1_1 = hi_1.getInput(1); + + if ( hi_1_1.getDataType() != Types.DataType.MATRIX || !hi_1_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: -*(M22650,f97734,*(M97683,M67673)) + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_1_1, hi_1_0_1) ) + return hi; + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_1_1, hi_1_0_1, Types.OpOp2.MULT); + if ( !RewriterRuntimeUtils.hasMatchingDims(hi_0, v1) ) + return hi; + TernaryOp v2 = HopRewriteUtils.createTernary(hi_0, hi_1_0_0, v1,Types.OpOp3.MINUS_MULT); + + Hop newRoot = v2; + if ( v2.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_1); + HopRewriteUtils.cleanupUnreferenced(hi_1_0); + + DMLExecutor.println("Applying rewrite: -(M22650,*(*(f97734,M67673),M97683)) => -*(M22650,f97734,*(M97683,M67673))"); + return newRoot; + } + + // Implementation of the rule -(-(f75306,M67233),*(A,M350)) => -(f75306,+(*(A,M350),M67233)) + private static Hop _applyRewrite52(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.MINUS || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if (hi_0.getParent().size() > 1) + return hi; + if ( !(hi_0 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_0 = (BinaryOp) hi_0; + + if ( c_hi_0.getOp() != Types.OpOp2.MINUS || !c_hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0 = hi_0.getInput(0); + + if ( hi_0_0.getDataType() != Types.DataType.SCALAR || !hi_0_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_1 = hi_0.getInput(1); + + if ( hi_0_1.getDataType() != Types.DataType.MATRIX || !hi_0_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if (hi_1.getParent().size() > 1) + return hi; + if ( !(hi_1 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_1 = (BinaryOp) hi_1; + + if ( c_hi_1.getOp() != Types.OpOp2.MULT || !c_hi_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1_0 = hi_1.getInput(0); + + if ( hi_1_0.getDataType() != Types.DataType.MATRIX || !hi_1_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1_1 = hi_1.getInput(1); + + if ( hi_1_1.getDataType() != Types.DataType.MATRIX || !hi_1_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: -(f75306,+(*(A,M350),M67233)) + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_1_0, hi_1_1) ) + return hi; + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_1_0, hi_1_1, Types.OpOp2.MULT); + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(v1, hi_0_1) ) + return hi; + BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(v1, hi_0_1, Types.OpOp2.PLUS); + BinaryOp v3 = HopRewriteUtils.createAutoGeneratedBinary(hi_0_0, v2, Types.OpOp2.MINUS); + + Hop newRoot = v3; + if ( v3.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_0); + HopRewriteUtils.cleanupUnreferenced(hi_1); + + DMLExecutor.println("Applying rewrite: -(-(f75306,M67233),*(A,M350)) => -(f75306,+(*(A,M350),M67233))"); + return newRoot; + } + + // Implementation of the rule -(-(f75306,*(A,M350)),M67233) => -(f75306,+(*(A,M350),M67233)) + private static Hop _applyRewrite53(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.MINUS || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if (hi_0.getParent().size() > 1) + return hi; + if ( !(hi_0 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_0 = (BinaryOp) hi_0; + + if ( c_hi_0.getOp() != Types.OpOp2.MINUS || !c_hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0 = hi_0.getInput(0); + + if ( hi_0_0.getDataType() != Types.DataType.SCALAR || !hi_0_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_1 = hi_0.getInput(1); + + if (hi_0_1.getParent().size() > 1) + return hi; + if ( !(hi_0_1 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_0_1 = (BinaryOp) hi_0_1; + + if ( c_hi_0_1.getOp() != Types.OpOp2.MULT || !c_hi_0_1.getValueType().isNumeric() ) + return hi; + + Hop hi_0_1_0 = hi_0_1.getInput(0); + + if ( hi_0_1_0.getDataType() != Types.DataType.MATRIX || !hi_0_1_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_1_1 = hi_0_1.getInput(1); + + if ( hi_0_1_1.getDataType() != Types.DataType.MATRIX || !hi_0_1_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if ( hi_1.getDataType() != Types.DataType.MATRIX || !hi_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: -(f75306,+(*(A,M350),M67233)) + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_0_1_0, hi_0_1_1) ) + return hi; + BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_0_1_0, hi_0_1_1, Types.OpOp2.MULT); + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(v1, hi_1) ) + return hi; + BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(v1, hi_1, Types.OpOp2.PLUS); + BinaryOp v3 = HopRewriteUtils.createAutoGeneratedBinary(hi_0_0, v2, Types.OpOp2.MINUS); + + Hop newRoot = v3; + if ( v3.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_0); + HopRewriteUtils.cleanupUnreferenced(hi_0_1); + + DMLExecutor.println("Applying rewrite: -(-(f75306,*(A,M350)),M67233) => -(f75306,+(*(A,M350),M67233))"); + return newRoot; + } + + // Implementation of the rule +(-(C,d),%*%(A,B)) => -(+(C,%*%(A,B)),d) + private static Hop _applyRewrite54(Hop hi) { + if ( !(hi instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi = (BinaryOp) hi; + + if ( c_hi.getOp() != Types.OpOp2.PLUS || !c_hi.getValueType().isNumeric() ) + return hi; + + Hop hi_0 = hi.getInput(0); + + if (hi_0.getParent().size() > 1) + return hi; + if ( !(hi_0 instanceof BinaryOp) ) + return hi; + + BinaryOp c_hi_0 = (BinaryOp) hi_0; + + if ( c_hi_0.getOp() != Types.OpOp2.MINUS || !c_hi_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_0 = hi_0.getInput(0); + + if ( hi_0_0.getDataType() != Types.DataType.MATRIX || !hi_0_0.getValueType().isNumeric() ) + return hi; + + Hop hi_0_1 = hi_0.getInput(1); + + if ( hi_0_1.getDataType() != Types.DataType.SCALAR || !hi_0_1.getValueType().isNumeric() ) + return hi; + + Hop hi_1 = hi.getInput(1); + + if (hi_1.getParent().size() > 1) + return hi; + if ( !HopRewriteUtils.isMatrixMultiply(hi_1) ) + return hi; + + Hop hi_1_0 = hi_1.getInput(0); + + if ( hi_1_0.getDataType() != Types.DataType.MATRIX || !hi_1_0.getValueType().isNumeric() ) + return hi; + + Hop hi_1_1 = hi_1.getInput(1); + + if ( hi_1_1.getDataType() != Types.DataType.MATRIX || !hi_1_1.getValueType().isNumeric() ) + return hi; + + + // Now, we start building the new HOP-DAG: -(+(C,%*%(A,B)),d) + AggBinaryOp v1 = HopRewriteUtils.createMatrixMultiply(hi_1_0, hi_1_1); + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_0_0, v1) ) + return hi; + BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(hi_0_0, v1, Types.OpOp2.PLUS); + BinaryOp v3 = HopRewriteUtils.createAutoGeneratedBinary(v2, hi_0_1, Types.OpOp2.MINUS); + + Hop newRoot = v3; + if ( v3.getValueType() != hi.getValueType() ) { + newRoot = castIfNecessary(newRoot, hi); + if ( newRoot == null ) + return hi; + } + + ArrayList parents = new ArrayList<>(hi.getParent()); + + for ( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, newRoot); + + // Remove old unreferenced Hops + HopRewriteUtils.cleanupUnreferenced(hi); + HopRewriteUtils.cleanupUnreferenced(hi_0); + HopRewriteUtils.cleanupUnreferenced(hi_1); + + DMLExecutor.println("Applying rewrite: +(-(C,d),%*%(A,B)) => -(+(C,%*%(A,B)),d)"); + return newRoot; + } + + private static Hop castIfNecessary(Hop newRoot, Hop oldRoot) { + Types.OpOp1 cast = null; + switch ( oldRoot.getValueType().toExternalString() ) { + case "DOUBLE": + cast = Types.OpOp1.CAST_AS_DOUBLE; + break; + case "INT": + cast = Types.OpOp1.CAST_AS_INT; + break; + case "BOOLEAN": + cast = Types.OpOp1.CAST_AS_BOOLEAN; + break; + default: + return null; + } + + return new UnaryOp("tmp", oldRoot.getDataType(), oldRoot.getValueType(), cast, newRoot); + } + private static int minIdx(double[] l) { + double minValue = Double.MAX_VALUE; + int minIdx = -1; + + for (int i = 0; i < l.length; i++) { + if (l[i] < minValue) { + minValue = l[i]; + minIdx = i; + } + } + + return minIdx; + } +} \ No newline at end of file diff --git a/src/main/java/org/apache/sysds/hops/rewriter/generated/RewriteAutomaticallyGenerated.java b/src/main/java/org/apache/sysds/hops/rewriter/generated/RewriteAutomaticallyGenerated.java new file mode 100644 index 00000000000..d8a05d85ebd --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewriter/generated/RewriteAutomaticallyGenerated.java @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.rewriter.generated; + +import org.apache.sysds.hops.Hop; +import org.apache.sysds.hops.rewrite.HopRewriteRule; +import org.apache.sysds.hops.rewrite.ProgramRewriteStatus; +import org.apache.sysds.hops.rewriter.rule.RewriterRuleSet; +import org.apache.sysds.hops.rewriter.RuleContext; +import org.apache.sysds.hops.rewriter.utils.RewriterUtils; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.List; +import java.util.function.Function; + +public class RewriteAutomaticallyGenerated extends HopRewriteRule { + public static final String FILE_PATH = null; + public static RewriteAutomaticallyGenerated existingRewrites; + + private Function rewriteFn; + public static long totalTimeNanos = 0; + public static long callCount = 0; + public static long maxTimeNanos = -1; + + // This constructor could be used to dynamically compile generated rewrite rules from a file + @Deprecated + public RewriteAutomaticallyGenerated() { + // Try to read the file + try { + final RuleContext ctx = RewriterUtils.buildDefaultContext(); + List lines = Files.readAllLines(Paths.get(FILE_PATH)); + RewriterRuleSet ruleSet = RewriterRuleSet.deserialize(lines, ctx); + + rewriteFn = ruleSet.compile("AutomaticallyGeneratedRewriteFunction", false); + existingRewrites = this; + } catch (IOException e) { + } + } + + public RewriteAutomaticallyGenerated(Function rewriteFn) { + this.rewriteFn = rewriteFn; + } + + @Override + public ArrayList rewriteHopDAGs(ArrayList roots, ProgramRewriteStatus state) { + if( roots == null || rewriteFn == null ) + return roots; + + long startNanos = System.nanoTime(); + + //one pass rewrite-descend (rewrite created pattern) + for( Hop h : roots ) + rule_apply( h, false ); + Hop.resetVisitStatus(roots, true); + + //one pass descend-rewrite (for rollup) + for( Hop h : roots ) + rule_apply( h, true ); + + long diff = System.nanoTime() - startNanos; + totalTimeNanos += diff; + callCount++; + if (maxTimeNanos == -1 || maxTimeNanos < diff) + maxTimeNanos = diff; + + return roots; + } + + @Override + public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) { + if( root == null || rewriteFn == null ) + return root; + + long startNanos = System.nanoTime(); + + //one pass rewrite-descend (rewrite created pattern) + rule_apply( root, false ); + + root.resetVisitStatus(); + + //one pass descend-rewrite (for rollup) + rule_apply( root, true ); + + long diff = System.nanoTime() - startNanos; + totalTimeNanos += diff; + callCount++; + if (maxTimeNanos == -1 || maxTimeNanos < diff) + maxTimeNanos = diff; + + return root; + } + + private void rule_apply(Hop hop, boolean descendFirst) + { + if(hop.isVisited()) + return; + + //recursively process children + for( int i=0; i f; + private final boolean accelerated; + + public RewriterHeuristic(RewriterRuleSet ruleSet) { + this(ruleSet, true); + } + + public RewriterHeuristic(RewriterRuleSet ruleSet, boolean accelerated) { + this.ruleSet = ruleSet; + this.accelerated = accelerated; + this.f = null; + } + + public RewriterHeuristic(Function f) { + this.ruleSet = null; + this.accelerated = false; + this.f = f; + } + + public void forEachRuleSet(Consumer consumer, boolean printNames) { + consumer.accept(ruleSet); + } + + public RewriterStatement apply(RewriterStatement current) { + return apply(current, null); + } + + public RewriterStatement apply(RewriterStatement current, @Nullable BiFunction handler) { + return apply(current, handler, new MutableBoolean(false), true); + } + + public RewriterStatement apply(RewriterStatement currentStmt, @Nullable BiFunction handler, MutableBoolean foundRewrite, boolean print) { + if (f != null) + return f.apply(currentStmt); + + RuleContext.currentContext = ruleSet.getContext(); + + if (handler != null && !handler.apply(currentStmt, null)) + return currentStmt; + + RewriterRuleSet.ApplicableRule rule; + if (accelerated) + rule = ruleSet.acceleratedFindFirst(currentStmt); + else + throw new NotImplementedException("Must use accelerated mode"); + + if (rule != null) + foundRewrite.setValue(true); + + for (int i = 0; i < 500 && rule != null; i++) { + currentStmt = rule.rule.apply(rule.matches.get(0), currentStmt, rule.forward, false); + + if (handler != null && !handler.apply(currentStmt, rule.rule)) { + rule = null; + break; + } + + if (!(currentStmt instanceof RewriterInstruction)) { + rule = null; + break; + } + + if (accelerated) + rule = ruleSet.acceleratedFindFirst(currentStmt); + else + throw new IllegalArgumentException("Must use accelerated mode!"); + } + + if (rule != null) + throw new IllegalArgumentException("Expression did not converge:\n" + currentStmt.toParsableString(ruleSet.getContext(), true) + "\nRule: " + rule); + + return currentStmt; + } + + @Override + public String toString() { + return ruleSet.toString(); + } +} diff --git a/src/main/java/org/apache/sysds/hops/rewriter/rule/RewriterHeuristicTransformation.java b/src/main/java/org/apache/sysds/hops/rewriter/rule/RewriterHeuristicTransformation.java new file mode 100644 index 00000000000..4a62323c77b --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewriter/rule/RewriterHeuristicTransformation.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.rewriter.rule; + +import org.apache.commons.lang3.mutable.MutableBoolean; +import org.apache.sysds.hops.rewriter.RewriterStatement; +import org.apache.sysds.hops.rewriter.rule.RewriterRule; +import org.apache.sysds.hops.rewriter.rule.RewriterRuleSet; + +import javax.annotation.Nullable; +import java.util.function.BiFunction; +import java.util.function.Consumer; + +public interface RewriterHeuristicTransformation { + RewriterStatement apply(RewriterStatement stmt, @Nullable BiFunction func, MutableBoolean bool, boolean print); + + void forEachRuleSet(Consumer consumer, boolean printNames); + + default RewriterStatement apply(RewriterStatement stmt, @Nullable BiFunction func) { + return apply(stmt, func, new MutableBoolean(false), true); + } + + default RewriterStatement apply(RewriterStatement stmt, @Nullable BiFunction func, boolean print) { + return apply(stmt, func, new MutableBoolean(false), print); + } + + default RewriterStatement apply(RewriterStatement stmt) { + return apply(stmt, null, new MutableBoolean(false), true); + } +} diff --git a/src/main/java/org/apache/sysds/hops/rewriter/rule/RewriterHeuristics.java b/src/main/java/org/apache/sysds/hops/rewriter/rule/RewriterHeuristics.java new file mode 100644 index 00000000000..681ac34e1a9 --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewriter/rule/RewriterHeuristics.java @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.rewriter.rule; + +import org.apache.commons.lang3.mutable.MutableBoolean; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.hops.rewriter.RewriterStatement; +import org.jetbrains.annotations.Nullable; + +import java.util.ArrayList; +import java.util.List; +import java.util.function.BiFunction; +import java.util.function.Consumer; + +public class RewriterHeuristics implements RewriterHeuristicTransformation { + protected static final Log LOG = LogFactory.getLog(RewriterHeuristic.class.getName()); + List heuristics = new ArrayList<>(); + + public void forEachRuleSet(Consumer consumer, boolean printNames) { + heuristics.forEach(entry -> { + if (printNames) { + LOG.info("\n"); + LOG.info("> " + entry.name + " <"); + LOG.info("\n"); + } + entry.heuristics.forEachRuleSet(consumer, printNames); + }); + } + + public void add(String name, RewriterHeuristicTransformation heur) { + heuristics.add(new HeuristicEntry(name, heur)); + } + + public void addRepeated(String name, RewriterHeuristicTransformation heur) { + heuristics.add(new HeuristicEntry(name, new RepeatedHeuristics(heur))); + } + + @Override + public RewriterStatement apply(RewriterStatement stmt, @Nullable BiFunction func, MutableBoolean bool, boolean print) { + for (HeuristicEntry entry : heuristics) { + if (print) { + System.out.println("\n"); + System.out.println("> " + entry.name + " <"); + System.out.println("\n"); + } + + stmt = entry.heuristics.apply(stmt, func, bool, print); + } + + return stmt; + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + + for (HeuristicEntry entry : heuristics) { + sb.append("\n> "); + sb.append(entry.name); + sb.append(" <\n"); + + sb.append(entry.heuristics.toString()); + } + + return sb.toString(); + } + + class RepeatedHeuristics implements RewriterHeuristicTransformation { + RewriterHeuristicTransformation heuristic; + + public RepeatedHeuristics(RewriterHeuristicTransformation heuristic) { + this.heuristic = heuristic; + } + + @Override + public RewriterStatement apply(RewriterStatement stmt, @Nullable BiFunction func, MutableBoolean bool, boolean print) { + bool.setValue(true); + + while (bool.getValue()) { + bool.setValue(false); + stmt = heuristic.apply(stmt, func, bool, print); + } + + return stmt; + } + + @Override + public void forEachRuleSet(Consumer consumer, boolean printNames) { + heuristic.forEachRuleSet(consumer, printNames); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + + sb.append("\n===== REPEAT =====\n"); + + for (HeuristicEntry entry : heuristics) { + sb.append("\n> "); + sb.append(entry.name); + sb.append(" <\n"); + + sb.append(entry.heuristics.toString()); + } + + sb.append("\n===== END REPEAT ====="); + + return sb.toString(); + } + } + + + class HeuristicEntry { + String name; + RewriterHeuristicTransformation heuristics; + + public HeuristicEntry(String name, RewriterHeuristicTransformation heuristics) { + this.name = name; + this.heuristics = heuristics; + } + } +} diff --git a/src/main/java/org/apache/sysds/hops/rewriter/rule/RewriterRule.java b/src/main/java/org/apache/sysds/hops/rewriter/rule/RewriterRule.java new file mode 100644 index 00000000000..408abe71290 --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewriter/rule/RewriterRule.java @@ -0,0 +1,489 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.rewriter.rule; + +import org.apache.commons.lang3.NotImplementedException; +import org.apache.commons.lang3.mutable.MutableObject; +import org.apache.sysds.hops.rewriter.RewriterStatement; +import org.apache.sysds.hops.rewriter.RuleContext; +import org.apache.sysds.hops.rewriter.assertions.RewriterAssertionUtils; +import org.apache.sysds.hops.rewriter.assertions.RewriterAssertions; +import org.apache.sysds.hops.rewriter.estimators.RewriterCostEstimator; +import org.apache.sysds.hops.rewriter.estimators.RewriterSparsityEstimator; +import scala.Tuple2; +import scala.Tuple3; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.BiConsumer; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +public class RewriterRule { + + private final RuleContext ctx; + private final String name; + private final RewriterStatement fromRoot; + private final RewriterStatement toRoot; + private List toRoots; + private final HashMap linksStmt1ToStmt2; // Contains the explicit links a transformation has (like instructions, (a+b)-c = a+(b-c), but '+' and '-' are the same instruction still [important if instructions have metadata]) + private final HashMap linksStmt2ToStmt1; + private final List>> applyStmt1ToStmt2; + private final List>> applyStmt2ToStmt1; + private final Function iff1to2; + private final Function iff2to1; + private final boolean unidirectional; + private final Consumer postProcessor; + private Set allowedMultiReferences = Collections.emptySet(); + private RewriterAssertions combinedAssertions; + private boolean allowCombinations = false; + private boolean requireCostCheck = false; + private RewriterStatement fromCost = null; + private List toCosts = null; + + public RewriterRule(final RuleContext ctx, String name, RewriterStatement fromRoot, RewriterStatement toRoot, boolean unidirectional, HashMap linksStmt1ToStmt2, HashMap linksStmt2ToStmt1) { + this(ctx, name, fromRoot, toRoot, unidirectional, linksStmt1ToStmt2, linksStmt2ToStmt1, null, null, null, null, null); + } + + public RewriterRule(final RuleContext ctx, String name, RewriterStatement fromRoot, RewriterStatement toRoot, boolean unidirectional, HashMap linksStmt1ToStmt2, HashMap linksStmt2ToStmt1, Function iff1to2, Function iff2to1, List>> apply1To2, List>> apply2To1) { + this(ctx, name, fromRoot, toRoot, unidirectional, linksStmt1ToStmt2, linksStmt2ToStmt1, iff1to2, iff2to1, apply1To2, apply2To1, null); + } + + public RewriterRule(final RuleContext ctx, String name, RewriterStatement fromRoot, RewriterStatement toRoot, boolean unidirectional, HashMap linksStmt1ToStmt2, HashMap linksStmt2ToStmt1, Function iff1to2, Function iff2to1, List>> apply1To2, List>> apply2To1, Consumer postProcessor) { + this.ctx = ctx; + this.name = name; + this.fromRoot = fromRoot; + this.toRoot = toRoot; + this.unidirectional = unidirectional; + this.linksStmt1ToStmt2 = linksStmt1ToStmt2; + this.linksStmt2ToStmt1 = linksStmt2ToStmt1; + this.iff1to2 = iff1to2; + this.iff2to1 = iff2to1; + this.applyStmt1ToStmt2 = apply1To2; + this.applyStmt2ToStmt1 = apply2To1; + this.postProcessor = postProcessor; + } + + // Determine if this rule can universally be applied or only in some conditions (e.g. certain dimensions / sparsity) + public boolean determineConditionalApplicability() { + RewriterAssertions assertions = new RewriterAssertions(ctx); + RewriterAssertionUtils.buildImplicitAssertion(fromRoot, assertions, fromRoot, ctx); + for (RewriterStatement root : getStmt2AsList()) + RewriterAssertionUtils.buildImplicitAssertion(root, assertions, root, ctx); + + List, Long, Long>> costs = RewriterCostEstimator.compareCosts(fromRoot, getStmt2(), assertions, ctx, false, -1, false); + + requireCostCheck = isConditionalMultiRule() || RewriterCostEstimator.doesHaveAnImpactOnOptimalExpression(costs, false, true, 20); + + if (!requireCostCheck) + return false; + + List roots = toRoots == null ? List.of(toRoot) : toRoots; + + boolean integrateSparsityInCosts = isConditionalMultiRule() || RewriterCostEstimator.doesHaveAnImpactOnOptimalExpression(costs, true, false, 20); + + MutableObject assertionRef = new MutableObject<>(assertions); + fromCost = RewriterCostEstimator.getRawCostFunction(fromRoot, ctx, assertionRef, !integrateSparsityInCosts); + toCosts = getStmt2AsList().stream().map(root -> RewriterCostEstimator.getRawCostFunction(root, ctx, assertionRef, !integrateSparsityInCosts)).collect(Collectors.toList()); + + fromCost = RewriterSparsityEstimator.rollupSparsities(fromCost, RewriterSparsityEstimator.estimateAllNNZ(fromRoot, ctx), ctx); + toCosts = IntStream.range(0, toCosts.size()).mapToObj(i -> RewriterSparsityEstimator.rollupSparsities(toCosts.get(i), RewriterSparsityEstimator.estimateAllNNZ(roots.get(i), ctx), ctx)).collect(Collectors.toList()); + + return requireCostCheck; + } + + public boolean requiresCostCheck() { + return requireCostCheck; + } + + public RewriterStatement getStmt1Cost() { + return fromCost; + } + + public RewriterStatement getStmt2Cost() { + return toCosts.get(0); + } + + public List getStmt2Costs() { + return toCosts; + } + + public void buildCombinedAssertions() { + combinedAssertions = RewriterAssertionUtils.buildImplicitAssertions(fromRoot, ctx); + if (toRoot != null) + RewriterAssertionUtils.buildImplicitAssertions(toRoot, combinedAssertions, ctx); + else { + for (RewriterStatement root : toRoots) + RewriterAssertionUtils.buildImplicitAssertions(root, combinedAssertions, ctx); + } + } + + public RewriterAssertions getCombinedAssertions() { + if (combinedAssertions == null) + buildCombinedAssertions(); + + return combinedAssertions; + } + + public void setAllowedMultiReferences(Set allowed, boolean allowCombinations) { + this.allowedMultiReferences = allowed; + this.allowCombinations = allowCombinations; + } + + /** + * Overwrites the rule as a conditional rule + * @param targets all possible target statements + */ + public void setConditional(List targets) { + toRoots = targets; + } + + public boolean isConditionalMultiRule() { + return toRoots != null; + } + + public List getConditionalMultiRuleTargets() { + return toRoots; + } + + public String getName() { + return name; + } + + public RewriterStatement getStmt1() { + return fromRoot; + } + + /** + * Returns the target statement. + * @return the target statement; in case of a multi-rule, this will return the first option + */ + public RewriterStatement getStmt2() { + return toRoot != null ? toRoot : toRoots.get(0); + } + + public List getStmt2AsList() { + return toRoot != null ? List.of(toRoot) : toRoots; + } + + public boolean isUnidirectional() { + return unidirectional; + } + + public HashMap getForwardLinks() { + return linksStmt1ToStmt2; + } + + public HashMap getBackwardLinks() { + return linksStmt2ToStmt1; + } + + public RewriterStatement apply(RewriterStatement.MatchingSubexpression match, RewriterStatement rootNode, boolean forward, boolean inplace) { + return apply(match, rootNode, forward, inplace, false); + } + + public RewriterStatement apply(RewriterStatement.MatchingSubexpression match, RewriterStatement rootNode, boolean forward, boolean inplace, boolean updateTypes) { + return forward ? applyForward(match, rootNode, inplace, updateTypes) : applyBackward(match, rootNode, inplace, updateTypes); + } + + public RewriterStatement applyForward(RewriterStatement.MatchingSubexpression match, RewriterStatement rootNode, boolean inplace, boolean updateTypes) { + return applyForward(match, rootNode, inplace, updateTypes, new MutableObject<>(null)); + } + + public RewriterStatement applyForward(RewriterStatement.MatchingSubexpression match, RewriterStatement rootNode, boolean inplace, boolean updateTypes, MutableObject> modificationHandle) { + if (inplace) + throw new NotImplementedException("Inplace operations have been removed"); + RewriterStatement out = apply(match, rootNode, toRoot, modificationHandle, applyStmt1ToStmt2 == null ? Collections.emptyList() : applyStmt1ToStmt2); + if (updateTypes) + updateTypes(out, ctx); + return out; + } + + public RewriterStatement applyBackward(RewriterStatement.MatchingSubexpression match, RewriterStatement rootNode, boolean inplace, boolean updateTypes) { + return applyBackward(match, rootNode, inplace, updateTypes, new MutableObject<>(null)); + } + + public RewriterStatement applyBackward(RewriterStatement.MatchingSubexpression match, RewriterStatement rootNode, boolean inplace, boolean updateTypes, MutableObject> modificationHandle) { + if (inplace) + throw new NotImplementedException("Inplace operations have been removed"); + RewriterStatement out = apply(match, rootNode, fromRoot, modificationHandle, applyStmt2ToStmt1 == null ? Collections.emptyList() : applyStmt2ToStmt1); + if (updateTypes) + updateTypes(out, ctx); + return out; + } + + public RewriterStatement.MatchingSubexpression matchSingleStmt1(RewriterStatement exprRoot, RewriterStatement.RewriterPredecessor pred, RewriterStatement stmt, boolean allowImplicitTypeConversions) { + RewriterStatement.MatcherContext mCtx = new RewriterStatement.MatcherContext(ctx, stmt, pred, exprRoot, getStmt1(), true, true, false, true, true, false, true, false, false, allowImplicitTypeConversions, linksStmt1ToStmt2); + mCtx.currentStatement = stmt; + boolean match = getStmt1().match(mCtx); + + if (match) { + RewriterStatement.MatchingSubexpression matchExpr = mCtx.toMatch(); + + if (iff1to2 == null || iff1to2.apply(matchExpr)) + return matchExpr; + } + + return null; + } + + public RewriterStatement.MatchingSubexpression matchSingleStmt2(RewriterStatement exprRoot, RewriterStatement.RewriterPredecessor pred, RewriterStatement stmt, boolean allowImplicitTypeConversions) { + RewriterStatement.MatcherContext mCtx = new RewriterStatement.MatcherContext(ctx, stmt, pred, exprRoot, getStmt2(), true, true, false, true, true, false, true, false, false, allowImplicitTypeConversions, linksStmt2ToStmt1); + mCtx.currentStatement = stmt; + boolean match = getStmt2().match(mCtx); + + if (match) { + RewriterStatement.MatchingSubexpression matchExpr = mCtx.toMatch(); + + if (iff2to1 == null || iff2to1.apply(matchExpr)) + return matchExpr; + } + + return null; + } + + public void updateTypes(RewriterStatement root, final RuleContext ctx) { + root.forEachPostOrder((cur, pred) -> { + cur.refreshReturnType(ctx); + }, true); + } + + private RewriterStatement apply(RewriterStatement.MatchingSubexpression match, RewriterStatement rootInstruction, RewriterStatement dest, MutableObject> modificationHandle, List>> applyFunction) { + if (match.getPredecessor().isRoot()) { + final Map createdObjects = new HashMap<>(); + RewriterStatement cpy = dest.nestedCopyOrInject(createdObjects, obj -> { + RewriterStatement assoc = match.getAssocs().get(obj); + if (assoc != null) { + RewriterStatement assocCpy = createdObjects.get(assoc); + if (assocCpy == null) { + assocCpy = assoc.nestedCopyOrInject(createdObjects, obj2 -> null); + createdObjects.put(assoc, assocCpy); + } + + return assocCpy; + } + + return null; + }); + + RewriterStatement tmp = cpy.simplify(ctx); + if (tmp != null) + cpy = tmp; + + match.setNewExprRoot(cpy); + + RewriterStatement oldRootCpy = createdObjects.get(match.getExpressionRoot()); + RewriterAssertions assertions = null; + + if (oldRootCpy != null) { + assertions = (RewriterAssertions) oldRootCpy.getMeta("_assertions"); + oldRootCpy.unsafeRemoveMeta("_assertions"); + } else if (match.getExpressionRoot().getMeta("_assertions") != null) { + assertions = ((RewriterAssertions) match.getExpressionRoot().getMeta("_assertions")).nestedCopyOrInject(createdObjects, (obj, p, pIdx) -> { + RewriterStatement assoc = match.getAssocs().get(obj); + if (assoc != null) { + RewriterStatement assocCpy = createdObjects.get(assoc); + if (assocCpy == null) { + assocCpy = assoc.nestedCopyOrInject(createdObjects, obj2 -> null); + createdObjects.put(assoc, assocCpy); + } + + return assocCpy; + } + + return null; + }, match.getNewExprRoot()); + match.getExpressionRoot().unsafeRemoveMeta("_assertions"); + } + + if (assertions != null) { + if (!cpy.isLiteral()) + cpy.unsafePutMeta("_assertions", assertions); + } + + match.getLinks().forEach(lnk -> lnk.newStmt.replaceAll(createdObjects::get)); + match.getLinks().forEach(lnk -> lnk.transferFunction.accept(lnk)); + applyFunction.forEach(t -> t._2.accept(createdObjects.get(t._1), match)); + + if (postProcessor != null) + postProcessor.accept(cpy); + + if (ctx.metaPropagator != null) { + RewriterStatement mNew = ctx.metaPropagator.apply(cpy); + + if (mNew != cpy) { + mNew.unsafePutMeta("_assertions", cpy.getMeta("_assertions")); + cpy.unsafeRemoveMeta("_assertions"); + cpy = mNew; + } + } + + cpy.prepareForHashing(); + cpy.recomputeHashCodes(ctx); + + modificationHandle.setValue(new Tuple3<>(cpy, null, -1)); + + return cpy; + } + + final Map createdObjects = new HashMap<>(); + RewriterStatement cpy2 = rootInstruction.nestedCopyOrInject(createdObjects, (obj2, parent, pIdx) -> { + if (obj2.equals(match.getMatchRoot())) { + RewriterStatement cpy = dest.nestedCopyOrInject(createdObjects, obj -> { + RewriterStatement assoc = match.getAssocs().get(obj); + if (assoc != null) { + RewriterStatement assocCpy = createdObjects.get(assoc); + if (assocCpy == null) { + assocCpy = assoc.nestedCopyOrInject(createdObjects, obj3 -> null); + createdObjects.put(assoc, assocCpy); + } + return assocCpy; + } + return null; + }); + createdObjects.put(obj2, cpy); + modificationHandle.setValue(new Tuple3<>(cpy, parent, pIdx)); + return cpy; + } + return null; + }); + RewriterStatement tmp = cpy2.simplify(ctx); + if (tmp != null) + cpy2 = tmp; + + match.setNewExprRoot(cpy2); + + match.getLinks().forEach(lnk -> lnk.newStmt.replaceAll(createdObjects::get)); + cpy2.prepareForHashing(); + match.getLinks().forEach(lnk -> lnk.transferFunction.accept(lnk)); + applyFunction.forEach(t -> t._2.accept(createdObjects.get(t._1), match)); + + if (postProcessor != null) + postProcessor.accept(cpy2); + + if (ctx.metaPropagator != null) { + RewriterStatement mNew = ctx.metaPropagator.apply(cpy2); + + if (mNew != cpy2) { + mNew.unsafePutMeta("_assertions", cpy2.getMeta("_assertions")); + cpy2.unsafeRemoveMeta("_assertions"); + cpy2 = mNew; + } + } + + cpy2.prepareForHashing(); + cpy2.recomputeHashCodes(ctx); + + return cpy2; + } + + public String toString() { + if (isUnidirectional()) + if (isConditionalMultiRule()) + return fromRoot.toParsableString(ctx) + " => {" + toRoots.stream().map(stmt -> stmt.toParsableString(ctx)).collect(Collectors.joining("; ")) + "}"; + else + return fromRoot.toParsableString(ctx) + " => " + toRoot.toParsableString(ctx); + else + return fromRoot.toParsableString(ctx) + " <=> " + toRoot.toParsableString(ctx); + } + + public String toParsableString(final RuleContext ctx) { + Map> varDefs = new HashMap<>(); + StringBuilder sb = new StringBuilder(); + Map refs = new HashMap<>(); + int refIdx = fromRoot.toParsableString(sb, refs, 0, varDefs, allowedMultiReferences, ctx); + String stmt1 = sb.toString(); + sb = new StringBuilder(); + if (toRoot != null) { + toRoot.toParsableString(sb, refs, refIdx, varDefs, allowedMultiReferences, ctx); + } else { + for (RewriterStatement mToRoot : toRoots) { + mToRoot.toParsableString(sb, refs, refIdx, varDefs, allowedMultiReferences, ctx); + sb.append('\n'); + } + } + String stmt2 = sb.toString(); + String multiRefDefs = ""; + + if (!allowedMultiReferences.isEmpty()) { + multiRefDefs = "AllowedMultiRefs:" + allowedMultiReferences.stream().map(stmt -> "$" + refs.get(stmt)).collect(Collectors.joining(",")) + "\nAllowCombinations:" + allowCombinations + "\n"; + } + + String defs = RewriterStatement.parsableDefinitions(varDefs); + + if (toRoot != null) + return multiRefDefs + defs + "\n" + stmt1 + "\n=>\n" + stmt2; + else + return multiRefDefs + defs + "\n" + stmt1 + "\n=>\n{\n" + stmt2 + "}"; + } + + public static class LinkObject { + public List stmt; + public Consumer transferFunction; + + public LinkObject() { + stmt = new ArrayList<>(2); + } + + public LinkObject(List stmt, Consumer transferFunction) { + this.stmt = stmt; + this.transferFunction = transferFunction; + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < stmt.size(); i++) { + if (i != 0) + sb.append(", "); + sb.append(stmt.get(i)); + } + return sb.toString(); + } + + @Override + public boolean equals(Object o) { + return o instanceof LinkObject && ((LinkObject)o).stmt == stmt; + } + + @Override + public int hashCode() { + return stmt.hashCode(); + } + } + + public static class ExplicitLink { + public final RewriterStatement oldStmt; + public List newStmt; + public final Consumer transferFunction; + + public ExplicitLink(RewriterStatement oldStmt, List newStmt, Consumer transferFunction) { + this.oldStmt = oldStmt; + this.newStmt = new ArrayList<>(newStmt); + this.transferFunction = transferFunction; + } + } +} diff --git a/src/main/java/org/apache/sysds/hops/rewriter/rule/RewriterRuleBuilder.java b/src/main/java/org/apache/sysds/hops/rewriter/rule/RewriterRuleBuilder.java new file mode 100644 index 00000000000..078d79e2c95 --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewriter/rule/RewriterRuleBuilder.java @@ -0,0 +1,543 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.rewriter.rule; + +import org.apache.sysds.hops.rewriter.RewriterDataType; +import org.apache.sysds.hops.rewriter.RewriterInstruction; +import org.apache.sysds.hops.rewriter.RewriterStatement; +import org.apache.sysds.hops.rewriter.RuleContext; +import org.apache.sysds.hops.rewriter.utils.RewriterUtils; +import scala.Tuple2; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Set; +import java.util.UUID; +import java.util.function.BiConsumer; +import java.util.function.Consumer; +import java.util.function.Function; + +public class RewriterRuleBuilder { + private final RuleContext ctx; + private String ruleName = "?"; + private ArrayList instrSeq = new ArrayList<>(); + private ArrayList mappingSeq = new ArrayList<>(); + private HashMap globalIds = new HashMap<>(); + private HashMap instrSeqIds = new HashMap<>(); + private HashMap mappingSeqIds = new HashMap<>(); + private HashMap linksStmt1ToStmt2 = new HashMap<>(); + private ArrayList>> applyStmt1ToStmt2 = new ArrayList<>(); + private HashMap linksStmt2ToStmt1 = new HashMap<>(); + private ArrayList>> applyStmt2ToStmt1 = new ArrayList<>(); + private RewriterStatement fromRoot = null; + private RewriterStatement toRoot = null; + private List multiRuleRoots = null; + private Function iff1to2 = null; + private Function iff2to1 = null; + private boolean isUnidirectional = false; + private boolean buildSingleDAG = false; + + private RewriterStatement currentStatement = null; + private boolean mappingState = false; + + private boolean canBeModified = true; + + private Set allowedMultiReferences = Collections.emptySet(); + private boolean allowCombinations = false; + + public RewriterRuleBuilder(final RuleContext ctx) { + this.ctx = ctx; + } + + public RewriterRuleBuilder(final RuleContext ctx, String ruleName) { + this.ctx = ctx; + this.ruleName = ruleName; + } + + public RewriterRuleBuilder iff(Function iff, boolean forward) { + if (buildSingleDAG) + throw new IllegalArgumentException(); + + if (forward) + iff1to2 = iff; + else + iff2to1 = iff; + + return this; + } + + public RewriterRuleBuilder parseGlobalVars(String globalVarDefinition) { + if (!canBeModified) + throw new IllegalArgumentException(); + RewriterUtils.parseDataTypes(globalVarDefinition, globalIds, ctx); + return this; + } + + public RewriterRuleBuilder intLiteral(String id, int value) { + return intLiteral(id, value, "global"); + } + + public RewriterRuleBuilder intLiteral(String id, int value, String scope) { + switch (scope) { + case "global": + globalIds.put(id, new RewriterDataType().as(id).ofType("INT").asLiteral(value)); + break; + case "from": + instrSeqIds.put(id, new RewriterDataType().as(id).ofType("INT").asLiteral(value)); + break; + case "to": + mappingSeqIds.put(id, new RewriterDataType().as(id).ofType("INT").asLiteral(value)); + break; + } + + return this; + } + + public RewriterRuleBuilder parseGlobalStatementAsVariable(String varName, String expr) { + return parseGlobalStatementAsVariable(varName, expr, new HashMap<>()); + } + + public RewriterRuleBuilder parseGlobalStatementAsVariable(String varName, String expr, HashMap refMap) { + if (!canBeModified) + throw new IllegalArgumentException(); + + RewriterStatement parsed = RewriterUtils.parseExpression(expr, refMap, globalIds, ctx); + parsed.consolidate(ctx); + globalIds.put(varName, parsed); + return this; + } + + public RewriterRuleBuilder withParsedStatement(String stmt) { + return withParsedStatement(stmt, new HashMap<>()); + } + + public RewriterRuleBuilder withParsedStatement(String stmt, HashMap refMap) { + if (!canBeModified) + throw new IllegalArgumentException(); + fromRoot = RewriterUtils.parseExpression(stmt, refMap, globalIds, ctx); + fromRoot.forEachPreOrderWithDuplicates(el -> { + instrSeqIds.put(el.getId(), el); + return true; + }); + return this; + } + + public RewriterRuleBuilder toParsedStatement(String stmt) { + return toParsedStatement(stmt, new HashMap<>()); + } + + public RewriterRuleBuilder toParsedStatement(String stmt, HashMap refMap) { + if (!canBeModified) + throw new IllegalArgumentException(); + mappingState = true; + toRoot = RewriterUtils.parseExpression(stmt, refMap, globalIds, ctx); + toRoot.forEachPreOrderWithDuplicates(el -> { + mappingSeqIds.put(el.getId(), el); + return true; + }); + return this; + } + + public RewriterRuleBuilder prepare() { + if (!canBeModified) + return this; + if (buildSingleDAG) { + getCurrentInstruction().consolidate(ctx); + fromRoot.prepareForHashing(); + fromRoot.recomputeHashCodes(ctx); + canBeModified = false; + } else { + if (getCurrentInstruction() != null) + getCurrentInstruction().consolidate(ctx); + fromRoot.prepareForHashing(); + if (toRoot != null) + toRoot.prepareForHashing(); + else + multiRuleRoots.forEach(RewriterStatement::prepareForHashing); + fromRoot.recomputeHashCodes(ctx); + if (toRoot != null) + toRoot.recomputeHashCodes(ctx); + else + multiRuleRoots.forEach(rt -> rt.recomputeHashCodes(ctx)); + canBeModified = false; + } + + return this; + } + + public RewriterRule build() { + if (buildSingleDAG) + throw new IllegalArgumentException("Cannot build a rule if DAG was specified"); + if (!mappingState) + throw new IllegalArgumentException("No mapping expression"); + if (fromRoot == null) + throw new IllegalArgumentException("From-root statement cannot be null"); + if (toRoot == null && multiRuleRoots == null) + throw new IllegalArgumentException("To-root statement cannot be null"); + if (getCurrentInstruction() != null) + getCurrentInstruction().consolidate(ctx); + prepare(); + RewriterRule rule = new RewriterRule(ctx, ruleName, fromRoot, toRoot, isUnidirectional, linksStmt1ToStmt2, linksStmt2ToStmt1, iff1to2, iff2to1, applyStmt1ToStmt2, applyStmt2ToStmt1); + rule.setAllowedMultiReferences(allowedMultiReferences, allowCombinations); + if (multiRuleRoots != null) + rule.setConditional(multiRuleRoots); + return rule; + } + + public RewriterStatement buildDAG() { + if (!buildSingleDAG) + throw new IllegalArgumentException("Cannot build a DAG if rule was specified"); + prepare(); + return fromRoot; + } + + public RewriterRuleBuilder asDAGBuilder() { + buildSingleDAG = true; + return this; + } + + public RewriterRuleBuilder setUnidirectional(boolean unidirectional) { + this.isUnidirectional = unidirectional; + return this; + } + + public RewriterStatement getCurrentInstruction() { + if (mappingState) + if (mappingSeq.size() > 0) + return mappingSeq.get(mappingSeq.size()-1); + else if (toRoot != null) + return toRoot; + else if (multiRuleRoots != null) + return multiRuleRoots.get(0); // Just as a dummy + else + throw new IllegalArgumentException("There is no current instruction in the mapping sequence"); + else + if (instrSeq.size() > 0) + return instrSeq.get(instrSeq.size()-1); + else if (fromRoot != null) + return fromRoot; + else + throw new IllegalArgumentException("There is no current instruction in the instruction sequence"); + } + + public RewriterDataType getCurrentOperand() { + if (currentStatement instanceof RewriterDataType) + return (RewriterDataType)currentStatement; + else + throw new IllegalArgumentException("The current operand is not a data type"); + } + + public RewriterRuleBuilder withDataType(String id, String type) { + withDataType(id, type, null); + return this; + } + + public RewriterRuleBuilder withDataType(String id, String type, Object literal) { + if (!canBeModified) + throw new IllegalArgumentException("The DAG is final and cannot be modified"); + if (!instrSeq.isEmpty()) + throw new IllegalArgumentException("To define a single data type, the instruction sequence must be empty"); + fromRoot = new RewriterDataType().ofType(type).asLiteral(literal).as(id); + storeVar(fromRoot); + return this; + } + + public RewriterRuleBuilder withInstruction(String instr) { + if (!canBeModified) + throw new IllegalArgumentException("The DAG is final and cannot be modified"); + if (mappingState) + throw new IllegalArgumentException("Cannot add an instruction when a mapping instruction was already defined"); + if (instrSeq.size() > 0) + getCurrentInstruction().consolidate(ctx); + instrSeq.add(new RewriterInstruction().withInstruction(instr)); + return this; + } + + public RewriterRuleBuilder completeRule(RewriterStatement from, RewriterStatement to) { + if (!canBeModified) + throw new IllegalArgumentException("The DAG is final and cannot be modified"); + if (mappingState) + throw new IllegalArgumentException("Cannot add an instruction when a mapping instruction was already defined"); + this.fromRoot = from; + this.toRoot = to; + this.mappingState = true; + return this; + } + + public RewriterRuleBuilder completeConditionalRule(RewriterStatement from, List to) { + if (!canBeModified) + throw new IllegalArgumentException("The DAG is final and cannot be modified"); + if (mappingState) + throw new IllegalArgumentException("Cannot add an instruction when a mapping instruction was already defined"); + this.fromRoot = from; + this.multiRuleRoots = to; + this.mappingState = true; + return this; + } + + public RewriterRuleBuilder withAllowedMultiRefs(Set allowedMultiRefs, boolean allowCombinations) { + if (!canBeModified) + throw new IllegalArgumentException("The DAG is final and cannot be modified"); + if (!mappingState) + throw new IllegalArgumentException("Cannot add an instruction when a mapping instruction was already defined"); + + this.allowedMultiReferences = allowedMultiRefs; + this.allowCombinations = allowCombinations; + return this; + } + + public RewriterRuleBuilder withOps(RewriterDataType... operands) { + if (!canBeModified) + throw new IllegalArgumentException("The DAG is final and cannot be modified"); + ((RewriterInstruction)getCurrentInstruction()).withOps(operands); + currentStatement = null; + return this; + } + + public RewriterRuleBuilder addOp(String id) { + if (!canBeModified) + throw new IllegalArgumentException("The DAG is final and cannot be modified"); + RewriterDataType dt = new RewriterDataType().as(id); + storeVar(dt); + ((RewriterInstruction)getCurrentInstruction()).addOp(dt); + if (currentStatement != null) + currentStatement.consolidate(ctx); + currentStatement = dt; + return this; + } + + public RewriterRuleBuilder addDynamicOpListInstr(String id, String type, boolean fromInstr, String... ops) { + if (!canBeModified) + throw new IllegalArgumentException("The DAG is final and cannot be modified"); + + if (fromInstr) + withInstruction("argList"); + else + toInstruction("argList"); + + if (ops.length == 0 && type.endsWith("...")) { + // Add one placeholder operand to implicitly determine the data type + addOp(UUID.randomUUID().toString()).ofType(type.substring(0, type.length()-3)); + } else { + for (String op : ops) + addExistingOp(op); + } + + as(id); + return this; + } + + public RewriterRuleBuilder asLiteral(Object literal) { + if (!canBeModified) + throw new IllegalArgumentException("The DAG is final and cannot be modified"); + getCurrentOperand().asLiteral(literal); + return this; + } + + public RewriterRuleBuilder as(String id) { + if (!canBeModified) + throw new IllegalArgumentException("The DAG is final and cannot be modified"); + getCurrentInstruction().as(id); + currentVars().put(id, getCurrentInstruction()); + storeVar(getCurrentInstruction()); + return this; + } + + public RewriterRuleBuilder asRootInstruction() { + if (!canBeModified) + throw new IllegalArgumentException("The DAG is final and cannot be modified"); + if (mappingState) { + if (toRoot != null) + throw new IllegalArgumentException("Cannot have more than one root instruction"); + toRoot = getCurrentInstruction().as("result"); + mappingSeqIds.put("result", toRoot); + } else { + if (fromRoot != null) + throw new IllegalArgumentException("Cannot have more than one root instruction"); + fromRoot = getCurrentInstruction().as("result"); + instrSeqIds.put("result", fromRoot); + } + return this; + } + + public RewriterRuleBuilder addExistingOp(String id) { + if (!canBeModified) + throw new IllegalArgumentException("The DAG is final and cannot be modified"); + RewriterStatement operand = findVar(id); + + if (operand == null) + throw new IllegalArgumentException("Operand with id '" + id + "' does not exist"); + + if (currentStatement != null) + currentStatement.consolidate(ctx); + + currentStatement = operand; + ((RewriterInstruction)getCurrentInstruction()).addOp(operand); + + return this; + } + + public RewriterRuleBuilder ofType(String type) { + if (!canBeModified) + throw new IllegalArgumentException("The DAG is final and cannot be modified"); + getCurrentOperand().ofType(type); + return this; + } + + public RewriterRuleBuilder instrMeta(String key, Object value) { + if (!canBeModified) + throw new IllegalArgumentException("The DAG is final and cannot be modified"); + getCurrentInstruction().putMeta(key, value); + return this; + } + + public RewriterRuleBuilder operandMeta(String key, Object value) { + if (!canBeModified) + throw new IllegalArgumentException("The DAG is final and cannot be modified"); + getCurrentOperand().putMeta(key, value); + return this; + } + + public RewriterRuleBuilder toInstruction(String instr) { + if (!canBeModified) + throw new IllegalArgumentException("The DAG is final and cannot be modified"); + if (buildSingleDAG) + throw new IllegalArgumentException("Cannot create a mapping instruction when building a single DAG"); + getCurrentInstruction().consolidate(ctx); + mappingSeq.add(new RewriterInstruction().withInstruction(instr)); + mappingState = true; + return this; + } + + public RewriterRuleBuilder linkUnidirectional(String idFrom, String idTo, Consumer transferFunction, boolean forward) { + return linkManyUnidirectional(idFrom, List.of(idTo), transferFunction, forward); + } + + public RewriterRuleBuilder linkManyUnidirectional(String idFrom, List idsTo, Consumer transferFunction, boolean forward) { + prepare(); + RewriterStatement stmt1 = forward ? instrSeqIds.get(idFrom) : mappingSeqIds.get(idFrom); + if (stmt1 == null) + stmt1 = globalIds.get(idFrom); + if (stmt1 == null) + throw new IllegalArgumentException("Could not find instruction id: " + idFrom); + if (!stmt1.isConsolidated()) + stmt1.consolidate(ctx); + + List stmts2 = new ArrayList<>(); + + for (String idTo : idsTo) { + RewriterStatement stmt2 = forward ? mappingSeqIds.get(idTo) : instrSeqIds.get(idTo); + if (stmt2 == null) + stmt2 = globalIds.get(idTo); + if (stmt2 == null) + throw new IllegalArgumentException("Could not find instruction id: " + idTo); + if (!stmt2.isConsolidated()) + stmt2.consolidate(ctx); + + stmts2.add(stmt2); + } + + HashMap links = forward ? linksStmt1ToStmt2 : linksStmt2ToStmt1; + + RewriterRule.LinkObject lnk = new RewriterRule.LinkObject(stmts2, transferFunction); + + if (links.containsKey(stmt1) || links.containsValue(lnk)) + throw new IllegalArgumentException("Key or value already exists in explicit link map."); + + links.put(stmt1, lnk); + return this; + } + + public RewriterRuleBuilder link(String id, String id2, Consumer transferFunction) { + linkUnidirectional(id, id2, transferFunction, true); + linkUnidirectional(id2, id, transferFunction, false); + return this; + } + + public RewriterRuleBuilder apply(String id, Consumer applicationFunction, boolean forward) { + return apply(id, (stmt, match) -> applicationFunction.accept(stmt), forward); + } + + public RewriterRuleBuilder apply(String id, BiConsumer applicationFunction, boolean forward) { + prepare(); + RewriterStatement stmt1 = forward ? mappingSeqIds.get(id) : instrSeqIds.get(id); + if (stmt1 == null) + stmt1 = globalIds.get(id); + if (stmt1 == null) + throw new IllegalArgumentException("Could not find instruction id: " + id); + if (!stmt1.isConsolidated()) + stmt1.consolidate(ctx); + + if (forward) + applyStmt1ToStmt2.add(new Tuple2<>(stmt1, applicationFunction)); + else + applyStmt2ToStmt1.add(new Tuple2<>(stmt1, applicationFunction)); + + return this; + } + + public RewriterRuleBuilder toDataType(String id, String type) { + toDataType(id, type, null); + return this; + } + + public RewriterRuleBuilder toDataType(String id, String type, Object literal) { + if (!mappingSeq.isEmpty()) + throw new IllegalArgumentException("To define a single data type, the mapping sequence must be empty"); + toRoot = new RewriterDataType().ofType(type).asLiteral(literal).as(id); + storeVar(toRoot); + return this; + } + + private HashMap currentVars() { + return mappingState ? mappingSeqIds : instrSeqIds; + } + + private RewriterStatement findVar(String id) { + RewriterStatement stmt = null; + + if (mappingState) { + stmt = mappingSeqIds.get(id); + if (stmt != null) + return stmt; + } else { + stmt = instrSeqIds.get(id); + if (stmt != null) + return stmt; + } + return globalIds.get(id); + } + + private void storeVar(RewriterStatement var) { + if (var.getId() == null) + throw new IllegalArgumentException("The id of a statement cannot be null!"); + + if (mappingState) { + mappingSeqIds.put(var.getId(), var); + } else { + if (var instanceof RewriterDataType) + globalIds.put(var.getId(), var); + else + instrSeqIds.put(var.getId(), var); + } + } +} diff --git a/src/main/java/org/apache/sysds/hops/rewriter/rule/RewriterRuleCollection.java b/src/main/java/org/apache/sysds/hops/rewriter/rule/RewriterRuleCollection.java new file mode 100644 index 00000000000..0c5d4c99b98 --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewriter/rule/RewriterRuleCollection.java @@ -0,0 +1,1445 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.rewriter.rule; + +import org.apache.sysds.hops.rewriter.RewriterInstruction; +import org.apache.sysds.hops.rewriter.RewriterStatement; +import org.apache.sysds.hops.rewriter.RuleContext; +import org.apache.sysds.hops.rewriter.assertions.RewriterAssertions; +import org.apache.sysds.hops.rewriter.utils.RewriterUtils; + +import java.util.HashMap; +import java.util.List; +import java.util.UUID; + +import static org.apache.sysds.hops.rewriter.RewriterContextSettings.ALL_TYPES; +import static org.apache.sysds.hops.rewriter.RewriterContextSettings.SCALARS; + +public class RewriterRuleCollection { + public static void substituteEquivalentStatements(final List rules, final RuleContext ctx) { + HashMap hooks = new HashMap<>(); + + rules.add(new RewriterRuleBuilder(ctx, "as.scalar(A) => cast.FLOAT(A)") + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A") + .withParsedStatement("as.scalar(A)") + .toParsedStatement("cast.FLOAT(A)") + .build() + ); + + SCALARS.forEach(t -> { + rules.add(new RewriterRuleBuilder(ctx, "as.matrix(a) => cast.MATRIX(a)") + .setUnidirectional(true) + .parseGlobalVars(t + ":a") + .withParsedStatement("as.matrix(a)") + .toParsedStatement("cast.MATRIX(a)") + .build() + ); + }); + + // Some meta operators + rules.add(new RewriterRuleBuilder(ctx, "rowVec(A) => [](A, ...)") + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A") + .parseGlobalVars("LITERAL_INT:1") + .withParsedStatement("rowVec(A)") + .toParsedStatement("[]($1:A, 1, 1, 1, ncol(A))", hooks) + .build() + ); + + rules.add(new RewriterRuleBuilder(ctx, "colVec(A) => [](A, ...)") + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A") + .parseGlobalVars("LITERAL_INT:1") + .withParsedStatement("colVec(A)") + .toParsedStatement("[](A, 1, nrow(A), 1, 1)") + .build() + ); + + rules.add(new RewriterRuleBuilder(ctx, "cellMat(A) => [](A, ...)") + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A") + .parseGlobalVars("LITERAL_INT:1") + .withParsedStatement("cellMat(A)") + .toParsedStatement("[](A, 1, 1, 1, 1)") + .build() + ); + + substituteFusedOps(rules, ctx); + } + + public static void substituteFusedOps(final List rules, final RuleContext ctx) { + // Now resolve fused operators + rules.add(new RewriterRuleBuilder(ctx, "1-*(A,B) => -(1, *(A, B))") + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A,B") + .parseGlobalVars("LITERAL_FLOAT:1.0") // We take a float as this framework is optimized for floats + .withParsedStatement("1-*(A, B)") + .toParsedStatement("-(1.0, *(A, B))") + .build() + ); + rules.add(new RewriterRuleBuilder(ctx, "log_nz(A) => *(!=(A, 0.0), log(A))") + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A") + .parseGlobalVars("LITERAL_FLOAT:0.0") // We take a float as this framework is optimized for floats + .withParsedStatement("log_nz(A)") + .toParsedStatement("*(!=(A, 0.0), log(A))") + .build() + ); + + rules.add(new RewriterRuleBuilder(ctx, "sumSq(A) => sum(*(A,A))") + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A") + .parseGlobalVars("LITERAL_FLOAT:0.0") + .withParsedStatement("sumSq(A)") + .toParsedStatement("sum(*(A,A))") + .build() + ); + + rules.add(new RewriterRuleBuilder(ctx, "+*(A,s,Y) => +(A, *(s, Y))") + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A,Y") + .parseGlobalVars("FLOAT:s") + .withParsedStatement("+*(A,s,Y)") + .toParsedStatement("+(A, *(s, Y))") + .build() + ); + + rules.add(new RewriterRuleBuilder(ctx, "-*(A,s,Y) => -(A, *(s, Y))") + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A,Y") + .parseGlobalVars("FLOAT:s") + .withParsedStatement("-*(A,s,Y)") + .toParsedStatement("-(A, *(s, Y))") + .build() + ); + + rules.add(new RewriterRuleBuilder(ctx, "sq(A) => *(A,A)") + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A") + .withParsedStatement("sq(A)") + .toParsedStatement("*(A, A)") + .build() + ); + + rules.add(new RewriterRuleBuilder(ctx, "_nnz(A) => sum(!=(A,0.0))") + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A") + .parseGlobalVars("LITERAL_FLOAT:0.0") + .withParsedStatement("_nnz(A)") + .toParsedStatement("sum(!=(A,0.0))") + .build() + ); + + rules.add(new RewriterRuleBuilder(ctx, "*2(A) => +(A,A)") + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A") + .withParsedStatement("*2(A)") + .toParsedStatement("+(A,A)") + .build() + ); + + SCALARS.forEach(t -> { + rules.add(new RewriterRuleBuilder(ctx, "log_nz(A, a) => *(!=(A, 0.0), *(log(A), inv(log(a)))") + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A") + .parseGlobalVars("FLOAT:a") // We take a float as this framework is optimized for floats + .parseGlobalVars("LITERAL_FLOAT:0.0") + .withParsedStatement("log_nz(A, a)") + .toParsedStatement("*(!=(A, 0.0), *(log(A), inv(log(a))))") + .build() + ); + + rules.add(new RewriterRuleBuilder(ctx, "log(A, a) => *(log(A), inv(log(a)))") + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A") + .parseGlobalVars("FLOAT:a") + .withParsedStatement("log(A, a)") + .toParsedStatement("*(log(A), inv(log(a)))") + .build() + ); + }); + } + + public static void eliminateMultipleCasts(final List rules, final RuleContext ctx) { + SCALARS.forEach(t -> { + rules.add(new RewriterRuleBuilder(ctx, "cast.TYPE(cast.TYPE(A)) => cast.TYPE(A)") + .setUnidirectional(true) + .parseGlobalVars(t + ":a") + .withParsedStatement("cast.MATRIX(cast.MATRIX(a))") + .toParsedStatement("cast.MATRIX(a)") + .build() + ); + + rules.add(new RewriterRuleBuilder(ctx, "cast.TYPE(a::TYPE) => a") + .setUnidirectional(true) + .parseGlobalVars(t + ":a") + .withParsedStatement("cast." + t + "(a)") + .toParsedStatement("a") + .build() + ); + + SCALARS.forEach(t2 -> { + SCALARS.forEach(t3 -> { + rules.add(new RewriterRuleBuilder(ctx, "cast.TYPE(+(a, b)) => ...") + .setUnidirectional(true) + .parseGlobalVars(t2 + ":a") + .parseGlobalVars(t3 + ":b") + .withParsedStatement("cast." + t + "(+(a,b))") + .toParsedStatement("+(cast." + t + "(a), cast." + t + "(b))") + .build() + ); + + rules.add(new RewriterRuleBuilder(ctx, "cast.TYPE(*(a, b)) => ...") + .setUnidirectional(true) + .parseGlobalVars(t2 + ":a") + .parseGlobalVars(t3 + ":b") + .withParsedStatement("cast." + t + "(*(a,b))") + .toParsedStatement("*(cast." + t + "(a), cast." + t + "(b))") + .build() + ); + }); + + rules.add(new RewriterRuleBuilder(ctx, "cast.TYPE(cast.TYPE(A)) => cast.TYPE(A)") + .setUnidirectional(true) + .parseGlobalVars(t + ":a") + .withParsedStatement("cast." + t2 + "(cast." + t2 + "(a))") + .toParsedStatement("cast." + t2 + "(a)") + .build() + ); + + rules.add(new RewriterRuleBuilder(ctx, "cast.SCALAR(cast.MATRIX(a)) => a") + .setUnidirectional(true) + .parseGlobalVars(t2 + ":a") + .withParsedStatement("cast." + t + "(cast.MATRIX(a))") + .toParsedStatement("cast." + t + "(a)") + .build() + ); + }); + }); + } + + public static void canonicalizeAlgebraicStatements(final List rules, boolean allowInversionCanonicalization, final RuleContext ctx) { + HashMap hooks = new HashMap<>(); + + RewriterUtils.buildBinaryPermutations(ALL_TYPES, (t1, t2) -> { + rules.add(new RewriterRuleBuilder(ctx, "-(a,b) => +(a,-(b))") + .setUnidirectional(true) + .parseGlobalVars(t1 + ":a") + .parseGlobalVars(t2 + ":b") + .withParsedStatement("-(a, b)", hooks) + .toParsedStatement("+(a, -(b))", hooks) + .build() + ); + + if (allowInversionCanonicalization) { + rules.add(new RewriterRuleBuilder(ctx, "/(a,b) => *(a, inv(b))") + .setUnidirectional(true) + .parseGlobalVars(t1 + ":a") + .parseGlobalVars(t2 + ":b") + .withParsedStatement("/(a, b)", hooks) + .toParsedStatement("*(a, inv(b))", hooks) + .build() + ); + } + + rules.add(new RewriterRuleBuilder(ctx, "-(+(a, b)) => +(-(a), -(b))") + .setUnidirectional(true) + .parseGlobalVars(t1 + ":a") + .parseGlobalVars(t2 + ":b") + .withParsedStatement("-(+(a, b))", hooks) + .toParsedStatement("$1:+(-(a), -(b))", hooks) + .build() + ); + + rules.add(new RewriterRuleBuilder(ctx, "-(-(a)) => a") + .setUnidirectional(true) + .parseGlobalVars(t1 + ":a") + .parseGlobalVars(t2 + ":b") + .withParsedStatement("-(-(a))", hooks) + .toParsedStatement("a", hooks) + .build() + ); + }); + + rules.add(new RewriterRuleBuilder(ctx, "length(A) => nrow(A) * ncol(A)") + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A") + .withParsedStatement("length(A)", hooks) + .toParsedStatement("*(nrow(A), ncol(A))", hooks) + .build() + ); + + for (String t : ALL_TYPES) { + rules.add(new RewriterRuleBuilder(ctx, "-(inv(a)) => inv(-(a))") + .setUnidirectional(true) + .parseGlobalVars(t + ":A") + .withParsedStatement("-(inv(A))", hooks) + .toParsedStatement("inv(-(A))", hooks) + .build() + ); + } + + rules.add(new RewriterRuleBuilder(ctx, "-(sum(A)) => sum(-(A))") + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A") + .withParsedStatement("-(sum(A))", hooks) + .toParsedStatement("sum(-(A))", hooks) + .build() + ); + } + + public static void canonicalizeBooleanStatements(final List rules, final RuleContext ctx) { + HashMap hooks = new HashMap<>(); + + RewriterUtils.buildBinaryPermutations(ALL_TYPES, (t1, t2) -> { + rules.add(new RewriterRuleBuilder(ctx, ">(a, b) => <(b, a)") + .setUnidirectional(true) + .parseGlobalVars(t1 + ":a") + .parseGlobalVars(t2 + ":b") + .withParsedStatement(">(a, b)", hooks) + .toParsedStatement("<(b, a)", hooks) + .build() + ); + + // These hold only for boolean expressions + /*rules.add(new RewriterRuleBuilder(ctx, "!(!(a)) = a") + .setUnidirectional(true) + .parseGlobalVars(t1 + ":a") + .parseGlobalVars(t2 + ":b") + .withParsedStatement("!(!(a))", hooks) + .toParsedStatement("a", hooks) + .build() + );*/ + + rules.add(new RewriterRuleBuilder(ctx, "<=(a, b) => |(<(a, b), ==(a, b))") + .setUnidirectional(true) + .parseGlobalVars(t1 + ":a") + .parseGlobalVars(t2 + ":b") + .withParsedStatement("<=(a, b)", hooks) + .toParsedStatement("|(<(a, b), ==(a, b))", hooks) + .build() + ); + + rules.add(new RewriterRuleBuilder(ctx, ">=(a, b) => |(<(b, a), ==(b, a))") + .setUnidirectional(true) + .parseGlobalVars(t1 + ":a") + .parseGlobalVars(t2 + ":b") + .withParsedStatement(">=(a, b)", hooks) + .toParsedStatement("|(<(b, a), ==(b, a))", hooks) + .build() + ); + + rules.add(new RewriterRuleBuilder(ctx, "!(&(a, b)) => |(!(a), !(b))") + .setUnidirectional(true) + .parseGlobalVars(t1 + ":a") + .parseGlobalVars(t2 + ":b") + .withParsedStatement("!(&(a, b))", hooks) + .toParsedStatement("|(!(a), !(b))", hooks) + .build() + ); + + List.of("&(a, b)", "&(b, a)").forEach(exp -> { + List.of("|(" + exp + ", a)", "|(a, " + exp + ")").forEach(tExpr -> { + rules.add(new RewriterRuleBuilder(ctx, tExpr + " => a") + .setUnidirectional(true) + .parseGlobalVars(t1 + ":a") + .parseGlobalVars(t2 + ":b") + .withParsedStatement(tExpr, hooks) + .toParsedStatement("a", hooks) + .build() + ); + }); + }); + + List.of("|(a, b)", "|(b, a)").forEach(exp -> { + List.of("&(" + exp + ", a)", "&(a, " + exp + ")").forEach(tExpr -> { + rules.add(new RewriterRuleBuilder(ctx, tExpr + " => a") + .setUnidirectional(true) + .parseGlobalVars(t1 + ":a") + .parseGlobalVars(t2 + ":b") + .withParsedStatement(tExpr, hooks) + .toParsedStatement("a", hooks) + .build() + ); + }); + }); + + rules.add(new RewriterRuleBuilder(ctx, "|(<(b, a), <(a, b)) => b != a") + .setUnidirectional(true) + .parseGlobalVars(t1 + ":a") + .parseGlobalVars(t2 + ":b") + .withParsedStatement("|(<(b, a), <(a, b))", hooks) + .toParsedStatement("!=(b, a)", hooks) + .build() + ); + + rules.add(new RewriterRuleBuilder(ctx, "&(<(b, a), <(a, b)) => FALSE") + .setUnidirectional(true) + .parseGlobalVars(t1 + ":a") + .parseGlobalVars(t2 + ":b") + .parseGlobalVars("LITERAL_BOOL:FALSE") + .withParsedStatement("&(<(b, a), <(a, b))", hooks) + .toParsedStatement("FALSE", hooks) + .build() + ); + + rules.add(new RewriterRuleBuilder(ctx, "!(!=(a, b)) => ==(a, b)") + .setUnidirectional(true) + .parseGlobalVars(t1 + ":a") + .parseGlobalVars(t2 + ":b") + .parseGlobalVars("LITERAL_BOOL:FALSE") + .withParsedStatement("!(!=(a, b))", hooks) + .toParsedStatement("==(a, b)", hooks) + .build() + ); + + /*rules.add(new RewriterRuleBuilder(ctx, "==(a, b) => isZero(+(a, -(b)))") + .setUnidirectional(true) + .parseGlobalVars(t1 + ":a") + .parseGlobalVars(t2 + ":b") + .parseGlobalVars("LITERAL_BOOL:FALSE") + .withParsedStatement("!(!=(a, b))", hooks) + .toParsedStatement("==(a, b)", hooks) + .build() + );*/ + }); + } + + // E.g. expand A * B -> _m($1:_idx(), 1, nrow(A), _m($2:_idx(), 1, nrow(B), A[$1, $2] * B[$1, $2])) + public static void expandStreamingExpressions(final List rules, final RuleContext ctx) { + HashMap hooks = new HashMap<>(); + + // cast.MATRIX + rules.add(new RewriterRuleBuilder(ctx, "Expand const matrix") + .setUnidirectional(true) + .parseGlobalVars("FLOAT:a") + .parseGlobalVars("MATRIX:A") + .parseGlobalVars("LITERAL_INT:1") + .withParsedStatement("cast.MATRIX(a)", hooks) + .toParsedStatement("$4:_m(1, 1, a)", hooks) + .build() + ); + + // cast.FLOAT + rules.add(new RewriterRuleBuilder(ctx, "") + .setUnidirectional(true) + .parseGlobalVars("MATRIX:a") + .parseGlobalVars("MATRIX:A") + .parseGlobalVars("LITERAL_INT:1") + .withParsedStatement("cast.FLOAT(A)", hooks) + .toParsedStatement("[](A, 1, 1)", hooks) + .build() + ); + + // Const + rules.add(new RewriterRuleBuilder(ctx, "Expand const matrix") + .setUnidirectional(true) + .parseGlobalVars("FLOAT:a") + .parseGlobalVars("MATRIX:A") + .parseGlobalVars("LITERAL_INT:1") + .withParsedStatement("const(A, a)", hooks) + .toParsedStatement("$4:_m($1:_idx(1, nrow(A)), $2:_idx(1, ncol(A)), a)", hooks) + .apply(hooks.get(1).getId(), stmt -> stmt.unsafePutMeta("idxId", UUID.randomUUID()), true) // Assumes it will never collide + .apply(hooks.get(2).getId(), stmt -> stmt.unsafePutMeta("idxId", UUID.randomUUID()), true) // Assumes it will never collide + .apply(hooks.get(4).getId(), (stmt, match) -> { + UUID id = UUID.randomUUID(); + stmt.unsafePutMeta("ownerId", id); + stmt.getChild(0).unsafePutMeta("ownerId", id); + }, true) // Assumes it will never collide + .build() + ); + + // Diag + rules.add(new RewriterRuleBuilder(ctx, "Expand diag matrix") + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A") + .parseGlobalVars("LITERAL_INT:1") + .parseGlobalVars("LITERAL_FLOAT:0.0") + .withParsedStatement("diag(A)", hooks) + .toParsedStatement("$4:_m($1:_idx(1, nrow(A)), $2:_idx(1, ncol(A)), $5:ifelse(==($1,$2), [](A, $1, $2), 0.0))", hooks) + .apply(hooks.get(1).getId(), stmt -> stmt.unsafePutMeta("idxId", UUID.randomUUID()), true) // Assumes it will never collide + .apply(hooks.get(2).getId(), stmt -> stmt.unsafePutMeta("idxId", UUID.randomUUID()), true) // Assumes it will never collide + .apply(hooks.get(4).getId(), (stmt, match) -> { + UUID id = UUID.randomUUID(); + stmt.unsafePutMeta("ownerId", id); + stmt.getChild(0).unsafePutMeta("ownerId", id); + RewriterStatement aRef = stmt.getChild(0, 1, 0); + match.getNewExprRoot().getAssertions(ctx).addEqualityAssertion(aRef.getNCol(), aRef.getNRow(), match.getNewExprRoot()); + }, true) // Assumes it will never collide + .build() + ); + + + // Matrix Multiplication + rules.add(new RewriterRuleBuilder(ctx, "Expand matrix product") + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A,B") + .parseGlobalVars("LITERAL_INT:1") + .withParsedStatement("%*%(A, B)", hooks) + .toParsedStatement("$4:_m($1:_idx(1, nrow(A)), $2:_idx(1, ncol(B)), sum($5:_m($3:_idx(1, ncol(A)), 1, *([](A, $1, $3), [](B, $3, $2)))))", hooks) + .apply(hooks.get(1).getId(), stmt -> stmt.unsafePutMeta("idxId", UUID.randomUUID()), true) // Assumes it will never collide + .apply(hooks.get(2).getId(), stmt -> stmt.unsafePutMeta("idxId", UUID.randomUUID()), true) // Assumes it will never collide + .apply(hooks.get(3).getId(), stmt -> stmt.unsafePutMeta("idxId", UUID.randomUUID()), true) // Assumes it will never collide + .apply(hooks.get(4).getId(), (stmt, match) -> { + UUID id = UUID.randomUUID(); + stmt.unsafePutMeta("ownerId", id); + stmt.getOperands().get(0).unsafePutMeta("ownerId", id); + stmt.getOperands().get(1).unsafePutMeta("ownerId", id); + + RewriterStatement aRef = stmt.getChild(0, 1, 0); + RewriterStatement bRef = stmt.getChild(1, 1, 0); + RewriterAssertions assertions = match.getNewExprRoot().getAssertions(ctx); + assertions.addEqualityAssertion(aRef.getNCol(), bRef.getNRow(), match.getNewExprRoot()); + assertions.update(match.getNewExprRoot()); + }, true) // Assumes it will never collide + .apply(hooks.get(5).getId(), stmt -> { + UUID id = UUID.randomUUID(); + stmt.unsafePutMeta("ownerId", id); + stmt.getOperands().get(0).unsafePutMeta("ownerId", id); + }, true) // Assumes it will never collide + .build() + ); + + // E.g. A + B + rules.add(new RewriterRuleBuilder(ctx, "Expand Element Wise Instruction") + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A,B") + .parseGlobalVars("LITERAL_INT:1") + .withParsedStatement("$1:ElementWiseInstruction(A,B)", hooks) + .toParsedStatement("$7:_m($2:_idx(1, $5:nrow(A)), $3:_idx(1, $6:ncol(A)), $4:ElementWiseInstruction([](A, $2, $3), [](B, $2, $3)))", hooks) + .link(hooks.get(1).getId(), hooks.get(4).getId(), RewriterStatement::transferMeta) + .apply(hooks.get(2).getId(), stmt -> stmt.unsafePutMeta("idxId", UUID.randomUUID()), true) // Assumes it will never collide + .apply(hooks.get(3).getId(), stmt -> stmt.unsafePutMeta("idxId", UUID.randomUUID()), true) + .apply(hooks.get(7).getId(), (stmt, match) -> { + UUID id = UUID.randomUUID(); + stmt.unsafePutMeta("ownerId", id); + stmt.getOperands().get(0).unsafePutMeta("ownerId", id); + stmt.getOperands().get(1).unsafePutMeta("ownerId", id); + + // Now we assert that nrow(A) = nrow(B) and ncol(A) = ncol(B) + RewriterStatement aRef = stmt.getChild(2, 0, 0); + RewriterStatement bRef = stmt.getChild(2, 1, 0); + match.getNewExprRoot().getAssertions(ctx).addEqualityAssertion(aRef.getNRow(), bRef.getNRow(), match.getNewExprRoot()); + match.getNewExprRoot().getAssertions(ctx).addEqualityAssertion(aRef.getNCol(), bRef.getNCol(), match.getNewExprRoot()); + }, true) // Assumes it will never collide + .build() + ); + + List.of("$2:_m(i, j, v1), v2", "v1, $2:_m(i, j, v2)").forEach(s -> { + rules.add(new RewriterRuleBuilder(ctx) + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A,B") + .parseGlobalVars("LITERAL_INT:1") + .parseGlobalVars("INT:i,j") + .parseGlobalVars("FLOAT:v1,v2") + .withParsedStatement("$1:ElementWiseInstruction(" + s + ")", hooks) + .toParsedStatement("$3:_m(i, j, $4:ElementWiseInstruction(v1, v2))", hooks) + .link(hooks.get(1).getId(), hooks.get(4).getId(), RewriterStatement::transferMeta) + .link(hooks.get(2).getId(), hooks.get(3).getId(), RewriterStatement::transferMeta) + .build() + ); + }); + + // Trace(A) + rules.add(new RewriterRuleBuilder(ctx) + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A,B") + .parseGlobalVars("LITERAL_INT:1") + .withParsedStatement("trace(A)", hooks) + .toParsedStatement("sum($3:_m($1:_idx(1, $2:nrow(A)), 1, [](A, $1, $1)))", hooks) + .apply(hooks.get(1).getId(), stmt -> stmt.unsafePutMeta("idxId", UUID.randomUUID()), true) // Assumes it will never collide + .apply(hooks.get(2).getId(), stmt -> stmt.unsafePutMeta("dontExpand", true), true) + .apply(hooks.get(3).getId(), (stmt, match) -> { + UUID id = UUID.randomUUID(); + stmt.unsafePutMeta("ownerId", id); + stmt.getOperands().get(0).unsafePutMeta("ownerId", id); + + // Assert that the matrix is squared + RewriterStatement aRef = stmt.getChild(0, 1, 0); + match.getNewExprRoot().getAssertions(ctx).addEqualityAssertion(aRef.getNRow(), aRef.getNCol(), match.getNewExprRoot()); + }, true) + .build() + ); + + // t(A) + rules.add(new RewriterRuleBuilder(ctx) + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A,B") + .parseGlobalVars("LITERAL_INT:1") + .withParsedStatement("t(A)", hooks) + .toParsedStatement("$3:_m($1:_idx(1, ncol(A)), $2:_idx(1, nrow(A)), [](A, $2, $1))", hooks) + .apply(hooks.get(1).getId(), stmt -> stmt.unsafePutMeta("idxId", UUID.randomUUID()), true) // Assumes it will never collide + .apply(hooks.get(2).getId(), stmt -> stmt.unsafePutMeta("idxId", UUID.randomUUID()), true) + .apply(hooks.get(3).getId(), stmt -> { + UUID id = UUID.randomUUID(); + stmt.unsafePutMeta("ownerId", id); + stmt.getOperands().get(0).unsafePutMeta("ownerId", id); + stmt.getOperands().get(1).unsafePutMeta("ownerId", id); + }, true) + .build() + ); + + rules.add(new RewriterRuleBuilder(ctx) + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A,B") + .parseGlobalVars("LITERAL_INT:1") + .withParsedStatement("rev(A)", hooks) + .toParsedStatement("$3:_m($1:_idx(1, nrow(A)), $2:_idx(1, ncol(A)), [](A, -(+(ncol(A), 1), $1), $2))", hooks) + .apply(hooks.get(1).getId(), stmt -> stmt.unsafePutMeta("idxId", UUID.randomUUID()), true) // Assumes it will never collide + .apply(hooks.get(2).getId(), stmt -> stmt.unsafePutMeta("idxId", UUID.randomUUID()), true) + .apply(hooks.get(3).getId(), stmt -> { + UUID id = UUID.randomUUID(); + stmt.unsafePutMeta("ownerId", id); + stmt.getOperands().get(0).unsafePutMeta("ownerId", id); + stmt.getOperands().get(1).unsafePutMeta("ownerId", id); + }, true) + .build() + ); + + // rand(rows, cols, min, max) + rules.add(new RewriterRuleBuilder(ctx) + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A,B") + .parseGlobalVars("LITERAL_INT:1") + .parseGlobalVars("INT:n,m") + .parseGlobalVars("FLOAT:a,b") + .withParsedStatement("rand(n, m, a, b)", hooks) + .toParsedStatement("$3:_m($1:_idx(1, n), $2:_idx(1, m), +(a, $4:*(+(b, -(a)), rand(argList($1,$2)))))", hooks) + .apply(hooks.get(1).getId(), stmt -> stmt.unsafePutMeta("idxId", UUID.randomUUID()), true) // Assumes it will never collide + .apply(hooks.get(2).getId(), stmt -> stmt.unsafePutMeta("idxId", UUID.randomUUID()), true) + .apply(hooks.get(3).getId(), stmt -> { + UUID id = UUID.randomUUID(); + stmt.unsafePutMeta("ownerId", id); + stmt.getOperands().get(0).unsafePutMeta("ownerId", id); + stmt.getOperands().get(1).unsafePutMeta("ownerId", id); + }, true) + .build() + ); + + // sum(A) = sum(_m($1:_idx(1, nrow(A)), 1, sum(_m($2:_idx(1, ncol(A)), 1, [](A, $1, $2))))) + rules.add(new RewriterRuleBuilder(ctx) + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A,B") + .parseGlobalVars("LITERAL_INT:1") + .withParsedStatement("sum(A)", hooks) + .toParsedStatement("sum($3:_m($1:_idx(1, nrow(A)), 1, sum($4:_m($2:_idx(1, ncol(A)), 1, [](A, $1, $2)))))", hooks) + .iff(match -> { + RewriterStatement meta = (RewriterStatement) match.getMatchRoot().getOperands().get(0).getMeta("ncol"); + + if (meta == null) + throw new IllegalArgumentException("Column meta should not be null: " + match.getMatchRoot().getOperands().get(0).toString(ctx)); + + return !meta.isLiteral() || ((long)meta.getLiteral()) != 1; + }, true) + .apply(hooks.get(1).getId(), stmt -> stmt.unsafePutMeta("idxId", UUID.randomUUID()), true) // Assumes it will never collide + .apply(hooks.get(2).getId(), stmt -> stmt.unsafePutMeta("idxId", UUID.randomUUID()), true) + .apply(hooks.get(3).getId(), stmt -> { + UUID id = UUID.randomUUID(); + stmt.unsafePutMeta("ownerId", id); + stmt.getOperands().get(0).unsafePutMeta("ownerId", id); + }, true) + .apply(hooks.get(4).getId(), stmt -> { + UUID id = UUID.randomUUID(); + stmt.unsafePutMeta("ownerId", id); + stmt.getOperands().get(0).unsafePutMeta("ownerId", id); + }, true) + .build() + ); + + // rowSums(A) -> _m($1:_idx(1, nrow(A)), 1, sum(_m($2:_idx(1, ncol(A)), 1, [](A, $1, $2))) + rules.add(new RewriterRuleBuilder(ctx) + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A,B") + .parseGlobalVars("LITERAL_INT:1") + .withParsedStatement("rowSums(A)", hooks) + .toParsedStatement("$3:_m($1:_idx(1, nrow(A)), 1, sum($4:_m($2:_idx(1, ncol(A)), 1, [](A, $1, $2))))", hooks) + .apply(hooks.get(1).getId(), stmt -> stmt.unsafePutMeta("idxId", UUID.randomUUID()), true) // Assumes it will never collide + .apply(hooks.get(2).getId(), stmt -> stmt.unsafePutMeta("idxId", UUID.randomUUID()), true) + .apply(hooks.get(3).getId(), stmt -> { + UUID id = UUID.randomUUID(); + stmt.unsafePutMeta("ownerId", id); + stmt.getOperands().get(0).unsafePutMeta("ownerId", id); + }, true) + .apply(hooks.get(4).getId(), stmt -> { + UUID id = UUID.randomUUID(); + stmt.unsafePutMeta("ownerId", id); + stmt.getOperands().get(0).unsafePutMeta("ownerId", id); + }, true) + .build() + ); + + // colSums(A) -> _m($1:_idx(1, ncol(A)), 1, sum(_m($2:_idx(1, nrow(A)), 1, [](A, $2, $1))) + rules.add(new RewriterRuleBuilder(ctx) + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A,B") + .parseGlobalVars("LITERAL_INT:1") + .withParsedStatement("colSums(A)", hooks) + .toParsedStatement("$3:_m(1, $1:_idx(1, ncol(A)), sum($4:_m($2:_idx(1, nrow(A)), 1, [](A, $2, $1))))", hooks) + .apply(hooks.get(1).getId(), stmt -> stmt.unsafePutMeta("idxId", UUID.randomUUID()), true) // Assumes it will never collide + .apply(hooks.get(2).getId(), stmt -> stmt.unsafePutMeta("idxId", UUID.randomUUID()), true) + .apply(hooks.get(3).getId(), stmt -> { + UUID id = UUID.randomUUID(); + stmt.unsafePutMeta("ownerId", id); + stmt.getOperands().get(1).unsafePutMeta("ownerId", id); + }, true) + .apply(hooks.get(4).getId(), stmt -> { + UUID id = UUID.randomUUID(); + stmt.unsafePutMeta("ownerId", id); + stmt.getOperands().get(0).unsafePutMeta("ownerId", id); + }, true) + .build() + ); + + rules.add(new RewriterRuleBuilder(ctx) + .setUnidirectional(true) + .parseGlobalVars("INT:l") + .withParsedStatement("_idx(l, l)", hooks) + .toParsedStatement("l", hooks) + .build() + ); + + // Scalars dependent on matrix to index streams + rules.add(new RewriterRuleBuilder(ctx) + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A,B") + .parseGlobalVars("LITERAL_INT:1") + .withParsedStatement("sum(A)", hooks) + .toParsedStatement("sum($3:_idxExpr($1:_idx(1, nrow(A)), $4:_idxExpr($2:_idx(1, ncol(A)), [](A, $1, $2))))", hooks) + .apply(hooks.get(1).getId(), stmt -> stmt.unsafePutMeta("idxId", UUID.randomUUID()), true) // Assumes it will never collide + .apply(hooks.get(2).getId(), stmt -> stmt.unsafePutMeta("idxId", UUID.randomUUID()), true) + .apply(hooks.get(3).getId(), stmt -> { + UUID id = UUID.randomUUID(); + stmt.unsafePutMeta("ownerId", id); + stmt.getOperands().get(0).unsafePutMeta("ownerId", id); + }, true) + .apply(hooks.get(4).getId(), stmt -> { + UUID id = UUID.randomUUID(); + stmt.unsafePutMeta("ownerId", id); + stmt.getOperands().get(0).unsafePutMeta("ownerId", id); + }, true) + .build() + ); + + // diag(A) -> _m($1:_idx(1, nrow(A)), 1, [](A, $1, $1)) + /*rules.add(new RewriterRuleBuilder(ctx) + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A") + .parseGlobalVars("LITERAL_INT:1") + .withParsedStatement("diag(A)", hooks) + .toParsedStatement("$2:_m($1:_idx(1, nrow(A)), 1, [](A, $1, $1))", hooks) + .apply(hooks.get(1).getId(), stmt -> stmt.unsafePutMeta("idxId", UUID.randomUUID()), true) // Assumes it will never collide + .apply(hooks.get(2).getId(), (stmt, match) -> { + UUID id = UUID.randomUUID(); + stmt.unsafePutMeta("ownerId", id); + stmt.getOperands().get(0).unsafePutMeta("ownerId", id); + + RewriterStatement aRef = stmt.getChild(0, 1, 0); + match.getNewExprRoot().getAssertions(ctx).addEqualityAssertion(aRef.getNRow(), aRef.getNCol(), match.getNewExprRoot()); + }, true) + .build() + );*/ + + // cast.MATRIX(a) => _m(1, 1, a) + for (String t : List.of("INT", "BOOL", "FLOAT")) { + rules.add(new RewriterRuleBuilder(ctx) + .setUnidirectional(true) + .parseGlobalVars(t + ":a") + .parseGlobalVars("LITERAL_INT:1") + .withParsedStatement("cast.MATRIX(a)", hooks) + .toParsedStatement("$2:_m(1, 1, a)", hooks) + .apply(hooks.get(2).getId(), (stmt, match) -> stmt.unsafePutMeta("ownerId", UUID.randomUUID()), true) + .build() + ); + } + } + + public static void expandArbitraryMatrices(final List rules, final RuleContext ctx) { + HashMap hooks = new HashMap<>(); + // This must be the last rule in the heuristic as it handles any matrix that has not been written as a stream + // A -> _m() + rules.add(new RewriterRuleBuilder(ctx, "Expand arbitrary matrix expression") + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A") + .parseGlobalVars("LITERAL_INT:1") + .withParsedStatement("A", hooks) + .toParsedStatement("$3:_m($1:_idx(1, nrow(A)), $2:_idx(1, ncol(A)), [](A, $1, $2))", hooks) + .iff(match -> match.getMatchRoot().getMeta("dontExpand") == null && !(match.getMatchRoot().isInstruction() && match.getMatchRoot().trueInstruction().equals("_m")), true) + .apply(hooks.get(1).getId(), stmt -> stmt.unsafePutMeta("idxId", UUID.randomUUID()), true) // Assumes it will never collide + .apply(hooks.get(2).getId(), stmt -> stmt.unsafePutMeta("idxId", UUID.randomUUID()), true) + .apply(hooks.get(3).getId(), stmt -> { + UUID id = UUID.randomUUID(); + stmt.unsafePutMeta("ownerId", id); + stmt.getOperands().get(0).unsafePutMeta("ownerId", id); + stmt.getOperands().get(1).unsafePutMeta("ownerId", id); + RewriterStatement A = stmt.getChild(0, 1, 0); + A.unsafePutMeta("dontExpand", true); + if (A.getNRow().isInstruction() && A.getNRow().trueInstruction().equals("nrow") && A.getNRow().getChild(0) == stmt) + A.getNRow().getOperands().set(0, A); + if (A.getNCol().isInstruction() && A.getNCol().trueInstruction().equals("ncol") && A.getNCol().getChild(0) == stmt) + A.getNCol().getOperands().set(0, A); + }, true) + .build() + ); + } + + // TODO: Big issue when having multiple references to the same sub-dag + public static void pushdownStreamSelections(final List rules, final RuleContext ctx) { + HashMap hooks = new HashMap<>(); + + // ifelse merging + rules.add(new RewriterRuleBuilder(ctx) + .setUnidirectional(true) + .parseGlobalVars("FLOAT:a,b,c,d") + .parseGlobalVars("INT:l1,l2") + .withParsedStatement("$1:ElementWiseInstruction(ifelse(==(l1, l2), a, b), ifelse(==(l1, l2), c, d))", hooks) + .toParsedStatement("ifelse(==(l1, l2), $2:ElementWiseInstruction(a, c), $3:ElementWiseInstruction(b, d))", hooks) + .linkManyUnidirectional(hooks.get(1).getId(), List.of(hooks.get(2).getId(), hooks.get(3).getId()), RewriterStatement::transferMeta, true) + .build() + ); + + SCALARS.forEach(t -> { + SCALARS.forEach(t2 -> { + // redundant ifelse elimination + rules.add(new RewriterRuleBuilder(ctx, "Remove redundant ifelse") + .setUnidirectional(true) + .parseGlobalVars(t2 + ":c,d,e") + .parseGlobalVars(t + ":a,b") + .withParsedStatement("ifelse(==(a, b), ifelse(==(a, b), c, e), d)", hooks) + .toParsedStatement("ifelse(==(a, b), c, d)", hooks) + .build() + ); + rules.add(new RewriterRuleBuilder(ctx, "Remove redundant ifelse") + .setUnidirectional(true) + .parseGlobalVars(t2 + ":c,d,e") + .parseGlobalVars(t + ":a,b") + .withParsedStatement("ifelse(==(a, b), d, ifelse(==(a, b), c, e))", hooks) + .toParsedStatement("ifelse(==(a, b), d, e)", hooks) + .build() + ); + + // ifelse expression pullup + rules.add(new RewriterRuleBuilder(ctx, "Ifelse expression pullup") + .setUnidirectional(true) + .parseGlobalVars(t + ":a,c") + .parseGlobalVars(t2 + ":d") + .parseGlobalVars("BOOL:b") + .withParsedStatement("$1:ElementWiseInstruction(ifelse(b, a, c), d)", hooks) + .toParsedStatement("ifelse(b, $2:ElementWiseInstruction(a, d), $3:ElementWiseInstruction(c, d))", hooks) + .linkManyUnidirectional(hooks.get(1).getId(), List.of(hooks.get(2).getId(), hooks.get(3).getId()), RewriterStatement::transferMeta, true) + .build() + ); + rules.add(new RewriterRuleBuilder(ctx, "Ifelse expression pullup") + .setUnidirectional(true) + .parseGlobalVars(t + ":a,c") + .parseGlobalVars(t2 + ":d") + .parseGlobalVars("BOOL:b") + .withParsedStatement("$1:ElementWiseInstruction(d, ifelse(b, a, c))", hooks) + .toParsedStatement("ifelse(b, $2:ElementWiseInstruction(d, a), $3:ElementWiseInstruction(d, c))", hooks) + .linkManyUnidirectional(hooks.get(1).getId(), List.of(hooks.get(2).getId(), hooks.get(3).getId()), RewriterStatement::transferMeta, true) + .build() + ); + }); + + rules.add(new RewriterRuleBuilder(ctx, "Ifelse branch merge") + .setUnidirectional(true) + .parseGlobalVars(t + ":a,c,d") + .parseGlobalVars("BOOL:b") + .withParsedStatement("ifelse(b, a, a)", hooks) + .toParsedStatement("a", hooks) + .build() + ); + }); + + SCALARS.forEach(t -> { + rules.add(new RewriterRuleBuilder(ctx, "Fold true statement") + .setUnidirectional(true) + .parseGlobalVars(t + ":a") + .parseGlobalVars("LITERAL_BOOL:TRUE") + .withParsedStatement("==(a,a)", hooks) + .toParsedStatement("TRUE", hooks) + .build() + ); + }); + + rules.add(new RewriterRuleBuilder(ctx, "Eliminate unnecessary branches") + .setUnidirectional(true) + .parseGlobalVars("FLOAT:a,b") + .parseGlobalVars("LITERAL_BOOL:TRUE") + .withParsedStatement("ifelse(TRUE, a, b)", hooks) + .toParsedStatement("a", hooks) + .build() + ); + rules.add(new RewriterRuleBuilder(ctx, "Eliminate unnecessary branches") + .setUnidirectional(true) + .parseGlobalVars("FLOAT:a,b") + .parseGlobalVars("LITERAL_BOOL:FALSE") + .withParsedStatement("ifelse(FALSE, a, b)", hooks) + .toParsedStatement("b", hooks) + .build() + ); + + rules.add(new RewriterRuleBuilder(ctx) + .setUnidirectional(true) + .parseGlobalVars("INT:l") + .withParsedStatement("_idx(l, l)", hooks) + .toParsedStatement("l", hooks) + .build() + ); + + rules.add(new RewriterRuleBuilder(ctx, "Eliminate scalar matrices") + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A,B") + .parseGlobalVars("INT:i,j") + .parseGlobalVars("FLOAT:v") + .withParsedStatement("as.scalar(v)", hooks) + .toParsedStatement("v", hooks) + .build() + ); + + rules.add(new RewriterRuleBuilder(ctx, "Element selection pushdown") + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A,B") + .parseGlobalVars("INT:h,i,j,k,l,m") + .parseGlobalVars("FLOAT:v") + .parseGlobalVars("LITERAL_INT:1") + .withParsedStatement("[]($1:_m(h, i, v), l, m)", hooks) + .toParsedStatement("$3:as.scalar($2:_m(l, m, v))", hooks) + .iff(match -> { + List ops = match.getMatchRoot().getOperands().get(0).getOperands(); + return (ops.get(0).isInstruction() + && ops.get(0).trueTypedInstruction(ctx).equals("_idx(INT,INT)")) + || (ops.get(1).isInstruction() + && ops.get(1).trueTypedInstruction(ctx).equals("_idx(INT,INT)")); + }, true) + .linkUnidirectional(hooks.get(1).getId(), hooks.get(2).getId(), lnk -> { + RewriterStatement.transferMeta(lnk); + + for (int idx = 0; idx < 2; idx++) { + RewriterStatement oldRef = lnk.oldStmt.getChild(idx); + + if (!oldRef.isInstruction() || !oldRef.trueTypedInstruction(ctx).equals("_idx(INT,INT)")) + continue; + + UUID oldRefId = (UUID)oldRef.getMeta("idxId"); + + RewriterStatement newRef = lnk.newStmt.get(0).getChild(idx); + + RewriterStatement newOne = RewriterUtils.replaceReferenceAware(lnk.newStmt.get(0).getChild(2), stmt -> { + UUID idxId = (UUID) stmt.getMeta("idxId"); + if (idxId != null) { + if (idxId.equals(oldRefId)) + return newRef; + } + + return null; + }); + + if (newOne != null) + lnk.newStmt.get(0).getOperands().set(2, newOne); + } + }, true) + .apply(hooks.get(3).getId(), stmt -> { + stmt.getOperands().set(0, stmt.getChild(0, 2)); + }, true) + .build() + ); + + rules.add(new RewriterRuleBuilder(ctx, "Scalar matrix selection pushdown") + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A,B") + .parseGlobalVars("INT:h,i,j,k,l,m") + .parseGlobalVars("FLOAT:v") + .parseGlobalVars("LITERAL_INT:1") + .withParsedStatement("[]($1:_m(1, 1, v), j, k)", hooks) + .toParsedStatement("v", hooks) + .build() + ); + + rules.add(new RewriterRuleBuilder(ctx, "Selection pushdown") + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A,B") + .parseGlobalVars("INT:h,i,j,k,l,m") + .parseGlobalVars("FLOAT:v") + .parseGlobalVars("LITERAL_INT:1") + .withParsedStatement("[]($1:_m(h, i, v), j, k, l, m)", hooks) + .toParsedStatement("$2:_m(_idx(1, +(+(k, 1), -(j))), _idx(1, +(+(m, 1), -(l))), v)", hooks) // Assuming that selections are valid + .linkUnidirectional(hooks.get(1).getId(), hooks.get(2).getId(), lnk -> { + RewriterStatement.transferMeta(lnk); + + for (int idx = 0; idx < 2; idx++) { + RewriterStatement oldRef = lnk.oldStmt.getOperands().get(idx); + RewriterStatement newRef = lnk.newStmt.get(0).getChild(idx); + RewriterStatement mStmtC = new RewriterInstruction().as(UUID.randomUUID().toString()).withInstruction("+").withOps(newRef.getChild(1, 1, 0), RewriterStatement.literal(ctx, -1L)).consolidate(ctx); + RewriterStatement mStmt = new RewriterInstruction().as(UUID.randomUUID().toString()).withInstruction("+").withOps(newRef, mStmtC).consolidate(ctx); + final RewriterStatement newStmt = RewriterUtils.foldConstants(mStmt, ctx); + + UUID oldRefId = (UUID)oldRef.getMeta("idxId"); + + RewriterStatement newOne = RewriterUtils.replaceReferenceAware(lnk.newStmt.get(0).getChild(2), stmt -> { + UUID idxId = (UUID) stmt.getMeta("idxId"); + if (idxId != null) { + if (idxId.equals(oldRefId)) + return newStmt; + } + + return null; + }); + + if (newOne != null) + lnk.newStmt.get(0).getOperands().set(2, newOne); + } + }, true) + .build() + ); + + rules.add(new RewriterRuleBuilder(ctx, "_idx(a,a) => a") + .setUnidirectional(true) + .parseGlobalVars("INT:a") + .withParsedStatement("_idx(a,a)", hooks) + .toParsedStatement("a", hooks) + .build() + ); + + rules.add(new RewriterRuleBuilder(ctx, "_idxExpr(i::, v) => v") + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A,B") + .parseGlobalVars("INT:i,j") + .parseGlobalVars("FLOAT:v") + .withParsedStatement("_idxExpr(i, v)", hooks) + .toParsedStatement("v", hooks) + .iff(match -> { + List ops = match.getMatchRoot().getOperands(); + + boolean matching = (!ops.get(0).isInstruction() || !ops.get(0).trueInstruction().equals("_idx") || ops.get(0).getMeta("ownerId") != match.getMatchRoot().getMeta("ownerId")) + && (!ops.get(1).isInstruction() || !ops.get(1).trueInstruction().equals("_idx") || ops.get(1).getMeta("ownerId") != match.getMatchRoot().getMeta("ownerId")); + + return matching; + }, true) + .build() + ); + + rules.add(new RewriterRuleBuilder(ctx, "_idxExpr(i::, v) => v") + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A,B") + .parseGlobalVars("INT:i,j") + .parseGlobalVars("FLOAT*:v") + .withParsedStatement("_idxExpr(i, v)", hooks) + .toParsedStatement("v", hooks) + .iff(match -> { + List ops = match.getMatchRoot().getOperands(); + + boolean matching = (!ops.get(0).isInstruction() || !ops.get(0).trueInstruction().equals("_idx") || ops.get(0).getMeta("ownerId") != match.getMatchRoot().getMeta("ownerId")) + && (!ops.get(1).isInstruction() || !ops.get(1).trueInstruction().equals("_idx") || ops.get(1).getMeta("ownerId") != match.getMatchRoot().getMeta("ownerId")); + + return matching; + }, true) + .build() + ); + + rules.add(new RewriterRuleBuilder(ctx, "_idxExpr(i, sum(...)) => sum(_idxExpr(i, ...))") + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A,B") + .parseGlobalVars("INT:i,j") + .parseGlobalVars("FLOAT:v") + .withParsedStatement("$1:_idxExpr(i, sum(v))", hooks) + .toParsedStatement("sum($2:_idxExpr(i, v))", hooks) + .link(hooks.get(1).getId(), hooks.get(2).getId(), RewriterStatement::transferMeta) + .build() + ); + + rules.add(new RewriterRuleBuilder(ctx, "_idxExpr(i, sum(...)) => sum(_idxExpr(i, ...))") + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A,B") + .parseGlobalVars("INT:i,j") + .parseGlobalVars("FLOAT*:v") + .withParsedStatement("$1:_idxExpr(i, sum(v))", hooks) + .toParsedStatement("sum($2:_idxExpr(i, v))", hooks) + .link(hooks.get(1).getId(), hooks.get(2).getId(), RewriterStatement::transferMeta) + .build() + ); + + RewriterUtils.buildBinaryPermutations(List.of("FLOAT"), (t1, t2) -> { + rules.add(new RewriterRuleBuilder(ctx, "*(sum(_idxExpr(i, ...)), sum(_idxExpr(j, ...))) => _idxExpr(i, _idxExpr(j, sum(*(...)))") + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A,B") + .parseGlobalVars("INT:i,j") + .parseGlobalVars(t1 + ":v1") + .parseGlobalVars(t2 + ":v2") + .withParsedStatement("$1:*(sum($2:_idxExpr(i, v1)), sum($3:_idxExpr(j, v2)))", hooks) + .toParsedStatement("sum($4:_idxExpr(i, $5:_idxExpr(j, $6:*(v1, v2))))", hooks) + .link(hooks.get(1).getId(), hooks.get(6).getId(), RewriterStatement::transferMeta) + .link(hooks.get(2).getId(), hooks.get(4).getId(), RewriterStatement::transferMeta) + .link(hooks.get(3).getId(), hooks.get(5).getId(), RewriterStatement::transferMeta) + .build() + ); + }); + + rules.add(new RewriterRuleBuilder(ctx, "sum(sum(v)) => sum(v)") + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A,B") + .parseGlobalVars("INT:i,j") + .parseGlobalVars("FLOAT:v") + .withParsedStatement("sum(sum(v))", hooks) + .toParsedStatement("sum(v)", hooks) + .build() + ); + + rules.add(new RewriterRuleBuilder(ctx, "sum(sum(v)) => sum(v)") + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A,B") + .parseGlobalVars("INT:i,j") + .parseGlobalVars("FLOAT*:v") + .withParsedStatement("sum(sum(v))", hooks) + .toParsedStatement("sum(v)", hooks) + .build() + ); + + SCALARS.forEach(t -> { + rules.add(new RewriterRuleBuilder(ctx, "sum(v::" + t + ") => v::" + t) + .setUnidirectional(true) + .parseGlobalVars(t + ":v") + .withParsedStatement("sum(v)", hooks) + .toParsedStatement("v", hooks) + .build() + ); + }); + + rules.add(new RewriterRuleBuilder(ctx, "[](UnaryElementWiseOperator(A), i, j) => UnaryElementWiseOperator([](A, i, j))") + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A") + .parseGlobalVars("INT:i,j") + .withParsedStatement("[]($1:UnaryElementWiseOperator(A), i, j)", hooks) + .toParsedStatement("$2:UnaryElementWiseOperator([](A, i, j))", hooks) + .link(hooks.get(1).getId(), hooks.get(2).getId(), RewriterStatement::transferMeta) + .build() + ); + + rules.add(new RewriterRuleBuilder(ctx, "[](ElementWiseUnary.FLOAT(A), i, j) => ElementWiseUnary.FLOAT([](A, i, j))") + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A") + .parseGlobalVars("INT:i,j") + .withParsedStatement("[]($1:ElementWiseUnary.FLOAT(A), i, j)", hooks) + .toParsedStatement("$2:ElementWiseUnary.FLOAT([](A, i, j))", hooks) + .link(hooks.get(1).getId(), hooks.get(2).getId(), RewriterStatement::transferMeta) + .build() + ); + + for (String t : ALL_TYPES) { + if (t.equals("MATRIX")) { + rules.add(new RewriterRuleBuilder(ctx, "ElementWiseInstruction(_m(i, j, v), b) => _m(i, j, ElementWiseInstruction(v, b))") + .setUnidirectional(true) + .parseGlobalVars("FLOAT:v") + .parseGlobalVars(t + ":B") + .parseGlobalVars("INT:i,j") + .withParsedStatement("$1:ElementWiseInstruction($2:_m(i, j, v), B)", hooks) + .toParsedStatement("$3:_m(i, j, $4:ElementWiseInstruction(v, [](B, i, j)))", hooks) + .link(hooks.get(1).getId(), hooks.get(4).getId(), RewriterStatement::transferMeta) + .link(hooks.get(2).getId(), hooks.get(3).getId(), RewriterStatement::transferMeta) + .apply(hooks.get(3).getId(), (stmt, match) -> { + // Then we an infer that the two matrices have the same dimensions + match.getNewExprRoot().getAssertions(ctx).addEqualityAssertion(stmt.getNCol(), stmt.getChild(2, 1, 0).getNCol(), match.getNewExprRoot()); + match.getNewExprRoot().getAssertions(ctx).addEqualityAssertion(stmt.getNRow(), stmt.getChild(2, 1, 0).getNRow(), match.getNewExprRoot()); + }, true) + .build() + ); + + continue; + } + rules.add(new RewriterRuleBuilder(ctx, "ElementWiseInstruction(_m(i, j, A), b) => _m(i, j, ElementWiseInstruction(A, b))") + .setUnidirectional(true) + .parseGlobalVars("FLOAT:v") + .parseGlobalVars(t + ":b") + .parseGlobalVars("INT:i,j") + .withParsedStatement("$1:ElementWiseInstruction($2:_m(i, j, v), b)", hooks) + .toParsedStatement("$3:_m(i, j, $4:ElementWiseInstruction(v, b))", hooks) + .link(hooks.get(1).getId(), hooks.get(4).getId(), RewriterStatement::transferMeta) + .link(hooks.get(2).getId(), hooks.get(3).getId(), RewriterStatement::transferMeta) + .build() + ); + + rules.add(new RewriterRuleBuilder(ctx, "[](ElementWiseInstruction(A, v), i, j) => ElementWiseInstruction(v, [](A, i, j))") + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A") + .parseGlobalVars(t + ":v") + .parseGlobalVars("INT:i,j") + .withParsedStatement("[]($1:ElementWiseInstruction(A, v), i, j)", hooks) + .toParsedStatement("$2:ElementWiseInstruction([](A, i, j), v)", hooks) + .link(hooks.get(1).getId(), hooks.get(2).getId(), RewriterStatement::transferMeta) + .build() + ); + + rules.add(new RewriterRuleBuilder(ctx, "[](ElementWiseInstruction(v, A), i, j) => ElementWiseInstruction(v, [](A, i, j))") + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A") + .parseGlobalVars(t + ":v") + .parseGlobalVars("INT:i,j") + .withParsedStatement("[]($1:ElementWiseInstruction(v, A), i, j)", hooks) + .toParsedStatement("$2:ElementWiseInstruction(v, [](A, i, j))", hooks) + .link(hooks.get(1).getId(), hooks.get(2).getId(), RewriterStatement::transferMeta) + .build() + ); + } + } + + // This expands the statements to a common canonical form + public static void canonicalExpandAfterFlattening(final List rules, final RuleContext ctx) { + HashMap hooks = new HashMap<>(); + + rules.add(new RewriterRuleBuilder(ctx, "sum($1:_idxExpr(indices, -(A))) => -(sum($2:_idxExpr(indices, A)))") + .setUnidirectional(true) + .parseGlobalVars("FLOAT:a") + .parseGlobalVars("INT...:indices") + .withParsedStatement("sum($1:_idxExpr(indices, -(a)))", hooks) + .toParsedStatement("-(sum($2:_idxExpr(indices, a)))", hooks) + .link(hooks.get(1).getId(), hooks.get(2).getId(), RewriterStatement::transferMeta) + .build() + ); + + rules.add(new RewriterRuleBuilder(ctx, "sum($1:_idxExpr(indices, -(a))) => -(sum($2:_idxExpr(indices, a)))") + .setUnidirectional(true) + .parseGlobalVars("INT:a") + .parseGlobalVars("INT...:indices") + .withParsedStatement("sum($1:_idxExpr(indices, -(a)))", hooks) + .toParsedStatement("-(sum($2:_idxExpr(indices, a)))", hooks) + .link(hooks.get(1).getId(), hooks.get(2).getId(), RewriterStatement::transferMeta) + .build() + ); + + rules.add(new RewriterRuleBuilder(ctx, "sum(_idxExpr(indices, +(ops))) => +(argList(sum(_idxExpr(indices, op1)), sum(_idxExpr(...)), ...))") + .setUnidirectional(true) + .parseGlobalVars("INT...:indices") + .parseGlobalVars("FLOAT...:ops") + .withParsedStatement("sum($1:_idxExpr(indices, +(ops)))", hooks) + .toParsedStatement("$4:+($3:argList(sum($2:_idxExpr(indices, +(ops)))))", hooks) // The inner +(ops) is temporary and will be removed + .link(hooks.get(1).getId(), hooks.get(2).getId(), RewriterStatement::transferMeta) + .apply(hooks.get(3).getId(), newArgList -> { + RewriterStatement oldArgList = newArgList.getChild(0, 0, 1, 0); + newArgList.getChild(0, 0).getOperands().set(1, oldArgList.getChild(0)); + + for (int i = 1; i < oldArgList.getOperands().size(); i++) { + RewriterStatement newIdxExpr = newArgList.getChild(0, 0).copyNode(); + newIdxExpr.getOperands().set(1, oldArgList.getChild(i)); + RewriterStatement newSum = new RewriterInstruction() + .as(UUID.randomUUID().toString()) + .withInstruction("sum") + .withOps(newIdxExpr); + RewriterUtils.copyIndexList(newIdxExpr); + newIdxExpr.refreshReturnType(ctx); + newSum.consolidate(ctx); + newArgList.getOperands().add(newSum); + } + + newArgList.refreshReturnType(ctx); + }, true) + .apply(hooks.get(4).getId(), stmt -> { + stmt.refreshReturnType(ctx); + }, true) + .build() + ); + } + + public static void flattenedAlgebraRewrites(final List rules, final RuleContext ctx) { + HashMap hooks = new HashMap<>(); + + // Minus pushdown + rules.add(new RewriterRuleBuilder(ctx, "-(+(...)) => +(-(el1), -(el2), ...)") + .setUnidirectional(true) + .parseGlobalVars("FLOAT...:ops") + .withParsedStatement("-(+(ops))", hooks) + .toParsedStatement("$1:+(ops)", hooks) // Temporary + .apply(hooks.get(1).getId(), (stmt, match) -> { + RewriterStatement argList = stmt.getChild(0); + + for (int i = 0; i < argList.getOperands().size(); i++) { + RewriterInstruction newStmt = new RewriterInstruction().as(UUID.randomUUID().toString()).withInstruction("-").withOps(argList.getOperands().get(i)); + newStmt.consolidate(ctx); + argList.getOperands().set(i, newStmt); + } + + RewriterUtils.tryFlattenNestedOperatorPatterns(ctx, match.getNewExprRoot()); + }, true) + .build() + ); + } + + public static List buildElementWiseAlgebraicCanonicalization(final List rules, final RuleContext ctx) { + RewriterUtils.buildTernaryPermutations(List.of("FLOAT", "INT", "BOOL"), (t1, t2, t3) -> { + rules.add(new RewriterRuleBuilder(ctx, "*(+(a, b), c) => +(*(a, c), *(b, c))") + .setUnidirectional(true) + .parseGlobalVars(t1 + ":a") + .parseGlobalVars(t2 + ":b") + .parseGlobalVars(t3 + ":c") + .withParsedStatement("*(+(a, b), c)") + .toParsedStatement("+(*(a, c), *(b, c))") + .build() + ); + + rules.add(new RewriterRuleBuilder(ctx, "*(c, +(a, b)) => +(*(c, a), *(c, b))") + .setUnidirectional(true) + .parseGlobalVars(t1 + ":a") + .parseGlobalVars(t2 + ":b") + .parseGlobalVars(t3 + ":c") + .withParsedStatement("*(c, +(a, b))") + .toParsedStatement("+(*(c, a), *(c, b))") + .build() + ); + }); + + /*List.of("FLOAT", "INT").forEach(t -> { + rules.add(new RewriterRuleBuilder(ctx, "-(a) => *(-1.0, a)") + .setUnidirectional(true) + .parseGlobalVars(t + ":a") + .parseGlobalVars("LITERAL_" + t + ":-1") + .withParsedStatement("-(a)") + .toParsedStatement("*(-1, a)") + .build() + ); + });*/ + + return rules; + } + + public static List replaceNegation(final List rules, final RuleContext ctx) { + List.of("FLOAT", "INT").forEach(t -> { + rules.add(new RewriterRuleBuilder(ctx, "-(a) => *(-1.0, a)") + .setUnidirectional(true) + .parseGlobalVars(t + ":a") + .parseGlobalVars("LITERAL_" + t + ":-1") + .withParsedStatement("-(a)") + .toParsedStatement("*(-1, a)") + .build() + ); + }); + + return rules; + } + + @Deprecated + public static void streamifyExpressions(final List rules, final RuleContext ctx) { + HashMap hooks = new HashMap<>(); + + ALL_TYPES.forEach(t -> { + if (t.equals("MATRIX")) + return; + + rules.add(new RewriterRuleBuilder(ctx) + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A") + .parseGlobalVars(t + ":b") + .parseGlobalVars("INT:i,j") + .parseGlobalVars("FLOAT:v") + .withParsedStatement("$1:ElementWiseInstruction($3:_m(i, j, v), b)", hooks) + .toParsedStatement("$4:_m(i, j, $2:ElementWiseInstruction(v, b))", hooks) + .link(hooks.get(1).getId(), hooks.get(2).getId(), RewriterStatement::transferMeta) + .link(hooks.get(3).getId(), hooks.get(4).getId(), RewriterStatement::transferMeta) + .build()); + + rules.add(new RewriterRuleBuilder(ctx) + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A") + .parseGlobalVars(t + ":b") + .parseGlobalVars("INT:i,j") + .parseGlobalVars("FLOAT:v") + .withParsedStatement("$1:ElementWiseInstruction(b, $3:_m(i, j, v))", hooks) + .toParsedStatement("$4:_m(i, j, $2:ElementWiseInstruction(b, v))", hooks) + .link(hooks.get(1).getId(), hooks.get(2).getId(), RewriterStatement::transferMeta) + .link(hooks.get(3).getId(), hooks.get(4).getId(), RewriterStatement::transferMeta) + .build()); + }); + + + } + + public static void flattenOperations(final List rules, final RuleContext ctx) { + HashMap hooks = new HashMap<>(); + + RewriterUtils.buildBinaryPermutations(List.of("INT", "INT..."), (t1, t2) -> { + for (String t3 : List.of("FLOAT", "FLOAT*", "INT", "INT*", "BOOL", "BOOL*")) { + rules.add(new RewriterRuleBuilder(ctx, "Flatten nested index expression") + .setUnidirectional(true) + .parseGlobalVars(t1 + ":i") + .parseGlobalVars(t2 + ":j") + .parseGlobalVars(t3 + ":v") + .withParsedStatement("$1:_idxExpr(i, $2:_idxExpr(j, v))", hooks) + .toParsedStatement("$3:_idxExpr(argList(i, j), v)", hooks) + .link(hooks.get(1).getId(), hooks.get(3).getId(), RewriterStatement::transferMeta) + .apply(hooks.get(3).getId(), (stmt, match) -> { + UUID newOwnerId = (UUID) stmt.getMeta("ownerId"); + + if (newOwnerId == null) + throw new IllegalArgumentException(); + + if (!stmt.getChild(0, 1).isLiteral()) + stmt.getOperands().get(0).getOperands().get(1).unsafePutMeta("ownerId", newOwnerId); + }, true) + .build()); + + if (t1.equals("INT")) { + // This must be executed after the rule above + rules.add(new RewriterRuleBuilder(ctx, "Flatten nested index expression") + .setUnidirectional(true) + .parseGlobalVars(t1 + ":i") + .parseGlobalVars(t3 + ":v") + .withParsedStatement("$1:_idxExpr(i, v)", hooks) + .toParsedStatement("$3:_idxExpr(argList(i), v)", hooks) + .link(hooks.get(1).getId(), hooks.get(3).getId(), RewriterStatement::transferMeta) + .build()); + } + } + }); + + RewriterUtils.buildBinaryPermutations(List.of("MATRIX", "INT", "FLOAT", "BOOL"), (t1, t2) -> { + rules.add(new RewriterRuleBuilder(ctx, "Flatten fusable binary operator") + .setUnidirectional(true) + .parseGlobalVars(t1 + ":A") + .parseGlobalVars(t2 + ":B") + .withParsedStatement("$1:FusableBinaryOperator(A,B)", hooks) + .toParsedStatement("$2:FusedOperator(argList(A,B))", hooks) + .link(hooks.get(1).getId(), hooks.get(2).getId(), RewriterStatement::transferMeta) + .build()); + + rules.add(new RewriterRuleBuilder(ctx, "Flatten fusable binary operator") + .setUnidirectional(true) + .parseGlobalVars(t1 + "...:A") + .parseGlobalVars(t2 + ":B") + .withParsedStatement("$1:FusableBinaryOperator($2:FusedOperator(A), B)", hooks) + .toParsedStatement("$3:FusedOperator(argList(A, B))", hooks) + .iff(match -> { + return match.getMatchRoot().trueInstruction().equals(match.getMatchRoot().getOperands().get(0).trueInstruction()); + }, true) + .link(hooks.get(2).getId(), hooks.get(3).getId(), RewriterStatement::transferMeta) + .build()); + + rules.add(new RewriterRuleBuilder(ctx, "Flatten fusable binary operator") + .setUnidirectional(true) + .parseGlobalVars(t1 + "...:A") + .parseGlobalVars(t2 + ":B") + .withParsedStatement("$1:FusableBinaryOperator(B, $2:FusedOperator(A))", hooks) + .toParsedStatement("$3:FusedOperator(argList(B, A))", hooks) + .iff(match -> { + return match.getMatchRoot().trueInstruction().equals(match.getMatchRoot().getOperands().get(0).trueInstruction()); + }, true) + .link(hooks.get(2).getId(), hooks.get(3).getId(), RewriterStatement::transferMeta) + .build()); + }); + + } +} diff --git a/src/main/java/org/apache/sysds/hops/rewriter/rule/RewriterRuleCreator.java b/src/main/java/org/apache/sysds/hops/rewriter/rule/RewriterRuleCreator.java new file mode 100644 index 00000000000..28fc7bf6028 --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewriter/rule/RewriterRuleCreator.java @@ -0,0 +1,537 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.rewriter.rule; + +import org.apache.commons.collections4.bidimap.DualHashBidiMap; +import org.apache.commons.lang3.mutable.MutableBoolean; +import org.apache.sysds.hops.Hop; +import org.apache.sysds.hops.rewriter.RewriterDataType; +import org.apache.sysds.hops.rewriter.RewriterRuntimeUtils; +import org.apache.sysds.hops.rewriter.RewriterStatement; +import org.apache.sysds.hops.rewriter.RuleContext; +import org.apache.sysds.hops.rewriter.dml.DMLCodeGenerator; +import org.apache.sysds.hops.rewriter.dml.DMLExecutor; +import org.apache.sysds.hops.rewriter.estimators.RewriterCostEstimator; +import org.apache.sysds.hops.rewriter.utils.RewriterUtils; +import scala.Tuple2; + +import javax.annotation.Nullable; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Random; +import java.util.Set; +import java.util.UUID; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.stream.Collectors; + +public class RewriterRuleCreator { + + private RuleContext ctx; + private RewriterRuleSet ruleSet; + private List activeRules; + + public RewriterRuleCreator(final RuleContext ctx) { + this.ctx = ctx; + activeRules = Collections.synchronizedList(new LinkedList<>()); + ruleSet = new RewriterRuleSet(ctx, activeRules); + } + + public synchronized void forEachRule(Consumer consumer) { + activeRules.forEach(consumer); + } + + public boolean registerRule(RewriterRule rule, Function canonicalFormConverter, final RuleContext ctx) { + try { + return registerRule(rule, RewriterCostEstimator.estimateCost(rule.getStmt1(), ctx), RewriterCostEstimator.estimateCost(rule.getStmt2(), ctx), false, canonicalFormConverter); + } catch (Exception e) { + System.err.println("Error while registering a rule: " + rule); + e.printStackTrace(); + return false; + } + } + + public synchronized boolean registerRule(RewriterRule rule, long preCost, long postCost, boolean validateCorrectness, Function canonicalFormCreator) { + // First, we check if an existing rule already applies an equivalent rewrite (cost wise) + RewriterStatement toTest = rule.getStmt1().nestedCopy(false); + + RewriterStatement newStmt = rule.getStmt2().nestedCopy(false); + + boolean converged = false; + boolean changed = false; + + List appliedRules = new ArrayList<>(); + + for (int i = 0; i < 500; i++) { + RewriterRuleSet.ApplicableRule applicableRule = ruleSet.acceleratedFindFirst(newStmt, true); + + if (applicableRule == null) { + converged = true; + break; // Then we converged + } + + newStmt = applicableRule.rule.apply(applicableRule.matches.get(0), newStmt, applicableRule.forward, false); + RewriterUtils.mergeArgLists(newStmt, ctx); + newStmt = RewriterUtils.foldConstants(newStmt, ctx); + appliedRules.add(applicableRule.rule); + changed = true; + } + + if (!converged) + throw new IllegalArgumentException("The existing rule-set did not seem to converge for the example: \n" + toTest.toParsableString(ctx, true) + "\n" + String.join("\n", appliedRules.subList(appliedRules.size()-5, appliedRules.size()).stream().map(rl -> rl.toParsableString(ctx)).collect(Collectors.toList()))); + + appliedRules.clear(); + + for (int i = 0; i < 500; i++) { + RewriterRuleSet.ApplicableRule applicableRule = ruleSet.acceleratedFindFirst(toTest, true); + + if (applicableRule == null) { + converged = true; + break; // Then we converged + } + + toTest = applicableRule.rule.apply(applicableRule.matches.get(0), toTest, applicableRule.forward, false); + + RewriterUtils.mergeArgLists(toTest, ctx); + toTest = RewriterUtils.foldConstants(toTest, ctx); + appliedRules.add(applicableRule.rule); + changed = true; + } + + if (!converged) + throw new IllegalArgumentException("The existing rule-set did not seem to converge for the example: \n" + toTest.toParsableString(ctx, true) + "\n" + String.join("\n", appliedRules.stream().map(rl -> rl.toParsableString(ctx)).collect(Collectors.toList()))); + + if (newStmt != rule.getStmt2()) { + // Then the mapping has changed, and we need to + try { + postCost = RewriterCostEstimator.estimateCost(newStmt, ctx); + } catch (Exception e) { + System.err.println("Err in cost from orig: " + rule.getStmt2().toParsableString(ctx)); + System.err.println("NewStmt: " + newStmt.toParsableString(ctx)); + e.printStackTrace(); + return false; + } + } + + if (changed) { + long existingPostCost; + + try { + existingPostCost = RewriterCostEstimator.estimateCost(toTest, ctx); + } catch (Exception e) { + System.err.println("Err in cost from orig: " + rule.getStmt1().toParsableString(ctx)); + System.err.println("ToTest: " + toTest.toParsableString(ctx)); + System.err.println("AppliedRules: " + appliedRules); + e.printStackTrace(); + return false; + } + + if (existingPostCost <= postCost || preCost >= postCost) + return false; // Then this rule is not beneficial + } + + // We might have to rebuild the rule + if (changed || newStmt != rule.getStmt2()) { + try { + rule = createRule(toTest, newStmt, canonicalFormCreator.apply(toTest), canonicalFormCreator.apply(newStmt), ctx); + } catch (Exception e) { + System.err.println("Failed to create: " + toTest.toParsableString(ctx) + " => " + newStmt.toParsableString(ctx)); + } + } + + + if (validateCorrectness) { + // Now, we validate the rule by executing it in the system + if (!validateRuleCorrectnessAndGains(rule, ctx)) + return false; // Then, either the rule is incorrect or is already implemented + } + + //System.out.println("Rule is correct!"); + + RewriterRuleSet probingSet = new RewriterRuleSet(ctx, List.of(rule)); + List rulesToRemove = new ArrayList<>(); + List rulesThatMustComeBefore = new ArrayList<>(); + + // Check for interactions between different rules + for (RewriterRule existingRule : activeRules) { + RewriterStatement mProbe = existingRule.getStmt1(); + RewriterRuleSet.ApplicableRule applicableRule = probingSet.acceleratedFindFirst(mProbe); + + if (applicableRule != null) { + // Then we have to take a deeper look into the interaction between the rules + // Either the new rule achieves a better result -> the old rule can be eliminated + // Or the new rule finds a worse rewrite for the existing rule -> Then the existing rule must be kept and be applied before the new rule + mProbe = mProbe.nestedCopy(true); + + for (int i = 0; i < 20; i++) { + applicableRule = probingSet.acceleratedFindFirst(mProbe); + + if (i == 19) + throw new IllegalArgumentException("The following rule created a conflict with another rule:\nNew one:\n" + rule + "\t[Cost: " + preCost + " => " + postCost + "]\nExisting:\n" + existingRule + "\t[Cost: " + existingRule.getStmt1().getCost(ctx) + " => " + existingRule.getStmt2().getCost(ctx) + "]"); + if (applicableRule != null) + mProbe = applicableRule.rule.apply(applicableRule.matches.get(0), mProbe, applicableRule.forward, false); + else + break; + } + + long newCost = mProbe.getCost(ctx); + long existingRuleNewCost = existingRule.getStmt2().getCost(ctx); + + if (newCost == -1 || existingRuleNewCost == -1) + throw new IllegalArgumentException("The rule set or the new rule resulted in an invalid cost:\nNew one:\n" + rule + "\nExisting:\n" + existingRule); + + if (newCost <= existingRuleNewCost) { + // Then we remove the old rule + rulesToRemove.add(existingRule); + } else { + // Then the existing rule is still legitimate and must come before the new rule as it is more specific + rulesThatMustComeBefore.add(existingRule); + } + } + } + + // Check if rule is expansive (e.g. expands itself leading to an infinite loop) + RewriterRuleSet testSet = new RewriterRuleSet(ctx, List.of(rule)); + testSet.accelerate(); + RewriterStatement mProbe = rule.getStmt2(); + if (testSet.acceleratedFindFirst(mProbe) != null) + throw new IllegalArgumentException("Expansive rule detected!"); + + + activeRules.removeAll(rulesToRemove); + + // Now, we include the rule to the system + // TODO: Further checks are needed, especially if the new heuristic converges in all cases + activeRules.add(rule); + + ruleSet.accelerate(); + + return true; + } + + public RewriterRuleSet getRuleSet() { + return ruleSet; + } + + public void throwOutInvalidRules(boolean correctness, boolean relevance) { + if (!correctness && !relevance) + return; + + activeRules.removeIf(rule -> (correctness && !validateRuleCorrectness(rule, ctx)) || (relevance && !validateRuleApplicability(rule, ctx))); + ruleSet.accelerate(); + } + + + + + + + ///// STATIC METHODS ///// + + // This runs the rule from expressions + public static boolean validateRuleCorrectnessAndGains(RewriterRule rule, final RuleContext ctx) { + return validateRuleCorrectness(rule, ctx) && validateRuleApplicability(rule, ctx); + } + + public static boolean validateRuleCorrectness(RewriterRule rule, final RuleContext ctx) { + RewriterUtils.renameIllegalVarnames(ctx, rule.getStmt1(), rule.getStmt2()); + String sessionId = UUID.randomUUID().toString(); + String code = DMLCodeGenerator.generateRuleValidationDML(rule, sessionId, ctx); + + MutableBoolean isValid = new MutableBoolean(false); + boolean successful = DMLExecutor.executeCode(code, DMLCodeGenerator.ruleValidationScript(rule.toParsableString(ctx), sessionId, isValid::setValue)); + + if (!isValid.booleanValue()) { + String errStr = "An invalid rule was found: " + rule + "\n\tReason: " + (successful ? "Assertion" : "Error"); + + if (!successful && !DMLExecutor.getLastErr().isEmpty()) + errStr += " (" + DMLExecutor.getLastErr().get(0) + ")"; + + DMLExecutor.println(errStr); + } + + return isValid.booleanValue(); + } + + public static boolean validateRuleApplicability(RewriterRule rule, final RuleContext ctx) { + return validateRuleApplicability(rule, ctx, false, null); + } + + public static boolean validateRuleApplicability(RewriterRule rule, final RuleContext ctx, boolean print, @Nullable Function injectedRewriteClass) { + RewriterStatement _mstmt = rule.getStmt1(); + RewriterStatement _mstmt2 = rule.getStmt2(); + if (ctx.metaPropagator != null) { + ctx.metaPropagator.apply(_mstmt); + ctx.metaPropagator.apply(_mstmt2); + } + + final RewriterStatement stmt1 = RewriterUtils.unfuseOperators(_mstmt, ctx); + + Set vars = DMLCodeGenerator.getVariables(stmt1); + Set varNames = vars.stream().map(RewriterStatement::getId).collect(Collectors.toSet()); + String code2Header = DMLCodeGenerator.generateDMLVariables(vars); + String code2 = code2Header + "\nresult = " + DMLCodeGenerator.generateDML(stmt1); + + boolean isMatrix = stmt1.getResultingDataType(ctx).equals("MATRIX"); + + if (isMatrix) + code2 += "\nprint(lineage(result))"; + else + code2 += "\nprint(lineage(as.matrix(result)))"; + + MutableBoolean isRelevant = new MutableBoolean(false); + + final RewriterStatement expectedStmt = injectedRewriteClass != null ? _mstmt2 : _mstmt; + + RewriterRuntimeUtils.attachHopInterceptor(prog -> { + Hop hop; + + if (isMatrix) + hop = prog.getStatementBlocks().get(0).getHops().get(0).getInput(0).getInput(0); + else + hop = prog.getStatementBlocks().get(0).getHops().get(0).getInput(0).getInput(0).getInput(0); + + RewriterStatement stmt = RewriterRuntimeUtils.buildDAGFromHop(hop, 1000, true, ctx); + + if (stmt == null) + return false; + + Map nameAssocs = new HashMap<>(); + // Find the variables that are actually leafs in the original rule + stmt.forEachPreOrder(cur -> { + for (int i = 0; i < cur.getOperands().size(); i++) { + RewriterStatement child = cur.getChild(i); + + if (varNames.contains(child.getId())) { + RewriterStatement assoc = nameAssocs.get(child.getId()); + + if (assoc == null) { + assoc = new RewriterDataType().as(child.getId()).ofType(child.getResultingDataType(ctx)).consolidate(ctx); + + Long ncol = (Long) child.getMeta("_actualNCol"); + Long nrow = (Long) child.getMeta("_actualNRow"); + + if (ncol != null) + assoc.unsafePutMeta("_actualNCol", ncol); + + if (nrow != null) + assoc.unsafePutMeta("_actualNRow", nrow); + + nameAssocs.put(child.getId(), assoc); + } + + cur.getOperands().set(i, assoc); + } + } + + return true; + }, false); + + stmt = RewriterRuntimeUtils.populateDataCharacteristics(stmt, ctx); + stmt = ctx.metaPropagator.apply(stmt); + + stmt = stmt.nestedCopyOrInject(new HashMap<>(), mstmt -> { + if (mstmt.isInstruction() && (mstmt.trueInstruction().equals("ncol") || mstmt.trueInstruction().equals("nrow"))) + return RewriterStatement.literal(ctx, DMLCodeGenerator.MATRIX_DIMS); + return null; + }); + + stmt.prepareForHashing(); + stmt.recomputeHashCodes(ctx); + + Map createdObjects = new HashMap<>(); + + RewriterStatement stmt1ReplaceNCols = expectedStmt.nestedCopyOrInject(createdObjects, mstmt -> { + if (mstmt.isInstruction() && (mstmt.trueInstruction().equals("ncol") || mstmt.trueInstruction().equals("nrow"))) + return RewriterStatement.literal(ctx, DMLCodeGenerator.MATRIX_DIMS); + return null; + }); + + stmt1ReplaceNCols.prepareForHashing(); + stmt1ReplaceNCols.recomputeHashCodes(ctx); + + Set mVars = vars.stream().map(createdObjects::get).filter(Objects::nonNull).collect(Collectors.toSet()); + + if (print) { + DMLExecutor.println("Observed statement: " + stmt.toParsableString(ctx)); + DMLExecutor.println("Expected statement: " + stmt1ReplaceNCols.toParsableString(ctx)); + } + + RewriterStatement.MatcherContext mCtx = RewriterStatement.MatcherContext.exactMatch(ctx, stmt, stmt1ReplaceNCols); + if (stmt1ReplaceNCols.match(mCtx)) { + // Check if also the right variables are associated + boolean assocsMatching = true; + if (mCtx.getDependencyMap() != null) { + for (RewriterStatement var : mVars) { + RewriterStatement assoc = mCtx.getDependencyMap().get(var.isInstruction() && !var.trueInstruction().equals("const") ? var.getChild(0) : var); + + if (assoc == null) + throw new IllegalArgumentException("Association is null!"); + + if (!assoc.getId().equals(var.getId())) { + assocsMatching = false; + break; + } + } + } + + if (assocsMatching) { + // Then the rule matches, meaning that the statement is not rewritten by SystemDS + isRelevant.setValue(true); + } + } + + // TODO: Maybe we can still rewrite the new graph if it still has less cost + + // TODO: Evaluate cost and if our rule can still be applied + return injectedRewriteClass != null; // The program should not be executed as we just want to extract any rewrites that are applied to the current statement + }); + + MutableBoolean wasApplied = new MutableBoolean(true); + + if (injectedRewriteClass != null) { + String ruleStr = rule.toString(); + wasApplied.setValue(false); + DMLExecutor.executeCode(code2, s -> { + if (s.equals("Applying rewrite: " + ruleStr)) { + wasApplied.setValue(true); + } + }, injectedRewriteClass); + } else { + DMLExecutor.executeCode(code2, true); + } + + RewriterRuntimeUtils.detachHopInterceptor(); + + return isRelevant.booleanValue() && wasApplied.booleanValue(); + } + + public static RewriterRule createRule(RewriterStatement from, RewriterStatement to, RewriterStatement canonicalForm1, RewriterStatement canonicalForm2, final RuleContext ctx) { + Tuple2 commonForm = createCommonForm(from, to, canonicalForm1, canonicalForm2, ctx); + from = commonForm._1; + to = commonForm._2; + + return new RewriterRuleBuilder(ctx, "Autogenerated rule").setUnidirectional(true).completeRule(from, to).build(); + } + + public static RewriterRule createRuleFromCommonStatements(RewriterStatement from, RewriterStatement to, final RuleContext ctx) { + return new RewriterRuleBuilder(ctx, "Autogenerated rule").setUnidirectional(true).completeRule(from, to).build(); + } + + public static RewriterRule createConditionalRuleFromCommonStatements(RewriterStatement from, List to, final RuleContext ctx) { + return new RewriterRuleBuilder(ctx, "Autogenerated conditional rule").setUnidirectional(true).completeConditionalRule(from, to).build(); + } + + public static Tuple2 createCommonForm(RewriterStatement from, RewriterStatement to, RewriterStatement canonicalForm1, RewriterStatement canonicalForm2, final RuleContext ctx) { + from = from.nestedCopy(true); + Map assocs = getAssociations(from, to, canonicalForm1, canonicalForm2, ctx); + // Now, we replace all variables with a common element + from.forEachPreOrder((cur, pred) -> { + for (int i = 0; i < cur.getOperands().size(); i++) { + RewriterStatement child = cur.getChild(i); + + if (child instanceof RewriterDataType && !child.isLiteral()) { + RewriterStatement newRef = assocs.get(child); + + if (newRef != null) + cur.getOperands().set(i, newRef); + } + } + + return true; + }, false); + + from = ctx.metaPropagator.apply(from); + return new Tuple2<>(from, to); + } + + private static Map getAssociations(RewriterStatement from, RewriterStatement to, RewriterStatement canonicalFormFrom, RewriterStatement canonicalFormTo, final RuleContext ctx) { + Map fromCanonicalLink = getAssociationToCanonicalForm(from, canonicalFormFrom, true, ctx); + Map toCanonicalLink = getAssociationToCanonicalForm(to, canonicalFormTo, true, ctx); + + RewriterStatement.MatcherContext matcher = RewriterStatement.MatcherContext.exactMatch(ctx, canonicalFormTo, canonicalFormFrom); + canonicalFormFrom.match(matcher); + + Map assocs = new HashMap<>(); + matcher.getDependencyMap().forEach((k, v) -> { + if (k.isLiteral()) + return; + + RewriterStatement newKey = fromCanonicalLink.get(k); + RewriterStatement newValue = toCanonicalLink.get(v); + + if (newKey == null || newValue == null) + return; + + assocs.put(newKey, newValue); + }); + + return assocs; + } + + private static Random rd = new Random(); + private static Map getAssociationToCanonicalForm(RewriterStatement stmt, RewriterStatement canonicalForm, boolean reversed, final RuleContext ctx) { + // We identify all associations by their names + // If there are name collisions, this does not work + Map namedVariables = new HashMap<>(); + stmt.forEachPostOrder((cur, pred) -> { + if (!(cur instanceof RewriterDataType) || cur.isLiteral()) + return; + + if (namedVariables.put(cur.getId(), cur) != null) + throw new IllegalArgumentException("Duplicate variable name: " + cur.toParsableString(RuleContext.currentContext) + "\nEntire statement:\n" + stmt.toParsableString(ctx) + "\nRaw: " + stmt); + }, false); + + Map assoc = new DualHashBidiMap<>(); + + canonicalForm.forEachPostOrder((cur, pred) -> { + if (!(cur instanceof RewriterDataType) || cur.isLiteral()) + return; + + RewriterStatement ref = namedVariables.get(cur.getId()); + + if (ref == null) { + assoc.put(ref, ref); + } + + if (reversed) + assoc.put(cur, ref); + else + assoc.put(ref, cur); + }, false); + + namedVariables.values().forEach(ref -> { + if (reversed) { + if (!assoc.containsValue(ref)) + ref.rename("u_" + rd.nextInt(100000)); + } else { + if (!assoc.containsKey(ref)) + ref.rename("u_" + rd.nextInt(100000)); + } + }); + + return assoc; + } +} diff --git a/src/main/java/org/apache/sysds/hops/rewriter/rule/RewriterRuleSet.java b/src/main/java/org/apache/sysds/hops/rewriter/rule/RewriterRuleSet.java new file mode 100644 index 00000000000..d64de719c8d --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewriter/rule/RewriterRuleSet.java @@ -0,0 +1,345 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.rewriter.rule; + +import org.apache.commons.collections4.bidimap.DualHashBidiMap; +import org.apache.commons.lang3.mutable.MutableObject; +import org.apache.sysds.hops.Hop; +import org.apache.sysds.hops.rewriter.RewriterInstruction; +import org.apache.sysds.hops.rewriter.RewriterStatement; +import org.apache.sysds.hops.rewriter.RuleContext; +import org.apache.sysds.hops.rewriter.codegen.RewriterCodeGen; +import org.apache.sysds.hops.rewriter.utils.RewriterUtils; +import scala.Tuple2; +import scala.Tuple3; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.BiConsumer; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +public class RewriterRuleSet { + + public static class ApplicableRule { + public final ArrayList matches; + public final RewriterRule rule; + public final boolean forward; + + public ApplicableRule(ArrayList matches, RewriterRule rule, boolean forward) { + this.matches = matches; + this.rule = rule; + this.forward = forward; + } + + public String toString(final RuleContext ctx) { + StringBuilder builder = new StringBuilder(); + builder.append("Rule: " + rule + "\n\n"); + int ctr = 1; + for (RewriterStatement.MatchingSubexpression match : matches) { + builder.append("Match " + ctr++ + ": \n"); + builder.append(" " + match.getMatchRoot() + " = " + (forward ? rule.getStmt1() : rule.getStmt2()) + "\n\n"); + for (Map.Entry entry : match.getAssocs().entrySet()) { + builder.append(" - " + entry.getKey() + "::" + (ctx == null ? "?" : entry.getKey().getResultingDataType(ctx)) + " -> " + entry.getValue().getId() + "::" + (ctx == null ? "?" : entry.getValue().getResultingDataType(ctx)) + "\n"); + } + builder.append("\n"); + } + + return builder.toString(); + } + + @Override + public String toString() { + return toString(null); + } + } + + private RuleContext ctx; + private List rules; + private Map>> accelerator; + + public RewriterRuleSet(RuleContext ctx, List rules) { + this.ctx = ctx; + this.rules = rules; + accelerate(); + } + + public RuleContext getContext() { + return ctx; + } + + public void determineConditionalApplicability() { + rules.forEach(RewriterRule::determineConditionalApplicability); + } + + public void forEachRule(BiConsumer consumer) { + rules.forEach(r -> consumer.accept(r, ctx)); + } + + public List getRules() { + return rules; + } + + public ApplicableRule acceleratedFindFirst(RewriterStatement root) { + return acceleratedFindFirst(root, false); + } + + public ApplicableRule acceleratedFindFirst(RewriterStatement root, boolean allowImplicitTypeConversions) { + List match = acceleratedRecursiveMatch(root, true, allowImplicitTypeConversions); + if (match.isEmpty()) + return null; + else + return match.get(0); + } + + public List acceleratedRecursiveMatch(RewriterStatement root, boolean findFirst, boolean allowImplicitTypeConversions) { + List> matches = new ArrayList<>(); + MutableObject> dependencyMap = new MutableObject<>(new HashMap<>()); + MutableObject> links = new MutableObject<>(new ArrayList<>()); + MutableObject> linkObjects = new MutableObject<>(new HashMap<>()); + + root.forEachPreOrder((el, pred) -> { + String typedStr = el.isInstruction() ? el.trueTypedInstruction(allowImplicitTypeConversions, ctx) : RewriterUtils.convertImplicitly(el.getResultingDataType(ctx), allowImplicitTypeConversions); + Set props = el instanceof RewriterInstruction ? ((RewriterInstruction)el).getProperties(ctx) : Collections.emptySet(); + boolean found = acceleratedMatch(root, el, matches, typedStr, RewriterUtils.convertImplicitly(el.getResultingDataType(ctx), allowImplicitTypeConversions), props, pred, dependencyMap, links, linkObjects, findFirst, allowImplicitTypeConversions); + return !findFirst || !found; + }, true); + + Map, ApplicableRule> uniqueRules = new HashMap<>(); + + for (Tuple3 match : matches) { + Tuple2 t = new Tuple2<>(match._1(), match._2()); + + if (uniqueRules.containsKey(t)) + uniqueRules.get(t).matches.add(match._3()); + else { + ArrayList list = new ArrayList<>(); + list.add(match._3()); + uniqueRules.put(t, new ApplicableRule(list, match._1(), match._2())); + } + } + + return new ArrayList<>(uniqueRules.values()); + } + + public boolean acceleratedMatch(RewriterStatement exprRoot, RewriterStatement stmt, List> appRules, String realTypedInstr, String realType, Set properties, RewriterStatement.RewriterPredecessor pred, MutableObject> dependencyMap, MutableObject> links, MutableObject> linkObjects, boolean findFirst, boolean allowImplicitTypeConversions) { + List> potentialMatches; + boolean foundMatch = false; + + if (realTypedInstr != null) { + potentialMatches = accelerator.get(realTypedInstr); + if (potentialMatches != null) { + foundMatch |= checkPotentialMatches(stmt, potentialMatches, appRules, pred, dependencyMap, links, linkObjects, exprRoot, findFirst, allowImplicitTypeConversions); + + if (foundMatch && findFirst) + return true; + } + } + + potentialMatches = accelerator.get(realType); + if (potentialMatches != null) { + foundMatch |= checkPotentialMatches(stmt, potentialMatches, appRules, pred, dependencyMap, links, linkObjects, exprRoot, findFirst, allowImplicitTypeConversions); + + if (foundMatch && findFirst) + return true; + } + + if (properties != null) { + for (String props : properties) { + potentialMatches = accelerator.get(props); + if (potentialMatches != null) { + foundMatch |= checkPotentialMatches(stmt, potentialMatches, appRules, pred, dependencyMap, links, linkObjects, exprRoot, findFirst, allowImplicitTypeConversions); + + if (foundMatch && findFirst) + return true; + } + } + } + + return foundMatch; + } + + private boolean checkPotentialMatches(RewriterStatement stmt, List> potentialMatches, List> appRules, RewriterStatement.RewriterPredecessor pred, MutableObject> dependencyMap, MutableObject> links, MutableObject> linkObjects, RewriterStatement exprRoot, boolean findFirst, boolean allowImplicitTypeConversions) { + boolean anyMatch = false; + for (Tuple2 m : potentialMatches) { + RewriterStatement.MatchingSubexpression match; + + if (m._2()) { + match = m._1().matchSingleStmt1(exprRoot, pred, stmt, allowImplicitTypeConversions); + } else { + match = m._1().matchSingleStmt2(exprRoot, pred, stmt, allowImplicitTypeConversions); + } + + if (match != null) { + appRules.add(new Tuple3<>(m._1(), m._2(), match)); + dependencyMap.setValue(new HashMap<>()); + links.setValue(new ArrayList<>()); + linkObjects.setValue(new HashMap<>()); + + if (findFirst) + return true; + + anyMatch = true; + } else { + dependencyMap.getValue().clear(); + links.getValue().clear(); + linkObjects.getValue().clear(); + } + } + + return anyMatch; + } + + // Look for intersecting roots and try to find them once + public void accelerate() { + accelerator = new HashMap<>(); + for (RewriterRule rule : rules) { + accelerate(rule, true); + if (!rule.isUnidirectional()) + accelerate(rule, false); + } + } + + private void accelerate(RewriterRule rule, boolean forward) { + RewriterStatement stmt = forward ? rule.getStmt1() : rule.getStmt2(); + String t = stmt.isInstruction() ? stmt.trueTypedInstruction(ctx) : stmt.getResultingDataType(ctx); + List> l = accelerator.get(t); + + if (l == null) { + l = new ArrayList<>(); + accelerator.put(t, l); + } + + l.add(new Tuple2<>(rule, forward)); + } + + @Override + public String toString() { + return serialize(); + } + + public String serialize() { + StringBuilder sb = new StringBuilder(); + + for (RewriterRule rule : rules) { + try { + sb.append("::RULE\n"); + sb.append(rule.toParsableString(ctx)); + sb.append("\n\n"); + } catch (Exception e) { + e.printStackTrace(); + } + } + + return sb.toString(); + } + + public Set generateCodeAndTest(boolean optimize, boolean print) { + String javaCode = toJavaCode("MGeneratedRewriteClass", optimize, false, true, true); + Function f = RewriterCodeGen.compile(javaCode, "MGeneratedRewriteClass"); + + if (f == null) + return null; // Then, the code could not compile + + Set removed = new HashSet<>(); + + for (int i = 0; i < rules.size(); i++) { + if (!RewriterRuleCreator.validateRuleApplicability(rules.get(i), ctx, print, f)) { + System.out.println("Faulty rule: " + rules.get(i)); + removed.add(rules.get(i)); + } + } + + return removed; + } + + public static RewriterRuleSet deserialize(String data, final RuleContext ctx) { + return deserialize(data.split("\n"), ctx); + } + + public static RewriterRuleSet deserialize(List data, final RuleContext ctx) { + return deserialize(data.toArray(String[]::new), ctx); + } + + public static RewriterRuleSet deserialize(String[] data, final RuleContext ctx) { + List currentLines = new ArrayList<>(); + List rules = new ArrayList<>(); + + for (int i = 0; i < data.length; i++) { + if (data[i].equals("::RULE")) { + if (!currentLines.isEmpty()) { + rules.add(RewriterUtils.parseRule(String.join("\n", currentLines), ctx)); + currentLines.clear(); + } + } else { + currentLines.add(data[i]); + } + } + + if (!currentLines.isEmpty()) { + rules.add(RewriterUtils.parseRule(String.join("\n", currentLines), ctx)); + currentLines.clear(); + } + + for (RewriterRule rule : rules) { + try { + rule.determineConditionalApplicability(); + } catch (Exception e) { + System.err.println("Error while determining the conditional ability of " + rule.toString()); + e.printStackTrace(); + } + } + + return new RewriterRuleSet(ctx, rules); + } + + public String toJavaCode(String className, boolean optimize, boolean includePackageInfo, boolean printErrors, boolean maintainStatistics) { + List> mRules = IntStream.range(0, rules.size()).mapToObj(i -> new Tuple2<>("_applyRewrite" + i, rules.get(i))).collect(Collectors.toList()); + return RewriterCodeGen.generateClass(className, mRules, optimize, 2, includePackageInfo, ctx, true, printErrors, maintainStatistics); + } + + public String toJavaCode(String className, boolean optimize) { + List> mRules = IntStream.range(0, rules.size()).mapToObj(i -> new Tuple2<>("_applyRewrite" + i, rules.get(i))).collect(Collectors.toList()); + return RewriterCodeGen.generateClass(className, mRules, optimize, 2, true, ctx, true, true, false); + } + + public String toJavaCode(String className, boolean optimize, int maxOptimizationDepth, boolean includePackageInfo, boolean printErrors, boolean maintainStatistics) { + List> mRules = IntStream.range(0, rules.size()).mapToObj(i -> new Tuple2<>("_applyRewrite" + i, rules.get(i))).collect(Collectors.toList()); + return RewriterCodeGen.generateClass(className, mRules, optimize, maxOptimizationDepth, includePackageInfo, ctx, true, printErrors, maintainStatistics); + } + + public Function compile(String className, boolean printErrors) { + try { + List> mRules = IntStream.range(0, rules.size()).mapToObj(i -> new Tuple2<>("_applyRewrite" + i, rules.get(i))).collect(Collectors.toList()); + return RewriterCodeGen.compileRewrites(className, mRules, ctx, true, printErrors); + } catch (Exception e) { + if (printErrors) + e.printStackTrace(); + + return null; + } + } +} diff --git a/src/main/java/org/apache/sysds/hops/rewriter/utils/CodeGenUtils.java b/src/main/java/org/apache/sysds/hops/rewriter/utils/CodeGenUtils.java new file mode 100644 index 00000000000..533f646a7ed --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewriter/utils/CodeGenUtils.java @@ -0,0 +1,498 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.rewriter.utils; + +import org.apache.commons.lang3.NotImplementedException; +import org.apache.sysds.hops.rewriter.assertions.RewriterAssertions; +import org.apache.sysds.hops.rewriter.RewriterStatement; +import org.apache.sysds.hops.rewriter.RuleContext; + +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +public class CodeGenUtils { + public static String getSpecialOpCheck(RewriterStatement stmt, final RuleContext ctx, String hopVar) { + if (!stmt.isInstruction()) + return null; + switch (stmt.trueInstruction()) { + case "%*%": + return "HopRewriteUtils.isMatrixMultiply(" + hopVar + ")"; + } + + return null; + } + + public static String getAdditionalCheck(RewriterStatement stmt, final RuleContext ctx, String hopVar) { + if (!stmt.isInstruction()) + return null; + + switch (stmt.trueInstruction()) { + case "rowSums": + return hopVar + ".getDirection() == Types.Direction.Row"; + case "colSums": + return hopVar + ".getDirection() == Types.Direction.Col"; + case "sum": + return hopVar + ".getDirection() == Types.Direction.RowCol"; + } + + return null; + } + + public static String getOpCode(RewriterStatement stmt, final RuleContext ctx) { + if (stmt.getOperands().size() == 1) { + // Handle unary ops + // TODO: nrow, ncol, length + switch (stmt.trueInstruction()) { + case "t": + return "Types.ReOrgOp.TRANS"; + case "rev": + return "Types.ReOrgOp.REV"; + case "!": + return "Types.OpOp1.NOT"; + case "sqrt": + return "Types.OpOp1.SQRT"; + case "sq": + return "Types.OpOp1.POW2"; + case "log": + return "Types.OpOp1.LOG"; + case "abs": + return "Types.OpOp1.ABS"; + case "round": + return "Types.OpOp1.ROUND"; + case "rowSums": + case "colSums": + case "sum": + return "Types.AggOp.SUM"; + case "trace": + return "Types.AggOp.TRACE"; + case "*2": + return "Types.OpOp1.MULT2"; + case "cast.MATRIX": + return "Types.OpOp1.CAST_AS_MATRIX"; + case "cast.FLOAT": + return "Types.OpOp1.CAST_AS_SCALAR"; + case "const": + return "Types.OpOpDG.RAND"; + case "nrow": + return "Types.OpOp1.NROW"; + case "ncol": + return "Types.OpOp1.NCOL"; + case "length": + return "Types.OpOp1.LENGTH"; + } + } else if (stmt.getOperands().size() == 2) { + switch (stmt.trueInstruction()) { + case "+": + if (stmt.getOperands().size() != 2) + throw new IllegalArgumentException(); + + return "Types.OpOp2.PLUS"; + case "-": + if (stmt.getOperands().size() != 2) + throw new IllegalArgumentException(); + + return "Types.OpOp2.MINUS"; + case "*": + if (stmt.getOperands().size() != 2) + throw new IllegalArgumentException(); + + return "Types.OpOp2.MULT"; + case "/": + if (stmt.getOperands().size() != 2) + throw new IllegalArgumentException(); + + return "Types.OpOp2.DIV"; + case "min": + if (stmt.getOperands().size() != 2) + throw new IllegalArgumentException(); + + return "Types.OpOp2.MIN"; + case "max": + if (stmt.getOperands().size() != 2) + throw new IllegalArgumentException(); + + return "Types.OpOp2.MAX"; + case "!=": + if (stmt.getOperands().size() != 2) + throw new IllegalArgumentException(); + + return "Types.OpOp2.NOTEQUAL"; + case "==": + if (stmt.getOperands().size() != 2) + throw new IllegalArgumentException(); + + return "Types.OpOp2.EQUAL"; + case ">": + if (stmt.getOperands().size() != 2) + throw new IllegalArgumentException(); + + return "Types.OpOp2.GREATER"; + case ">=": + if (stmt.getOperands().size() != 2) + throw new IllegalArgumentException(); + + return "Types.OpOp2.GREATEREQUAL"; + case "<": + if (stmt.getOperands().size() != 2) + throw new IllegalArgumentException(); + + return "Types.OpOp2.LESS"; + case "<=": + if (stmt.getOperands().size() != 2) + throw new IllegalArgumentException(); + + return "Types.OpOp2.LESSEQUAL"; + case "&": + if (stmt.getOperands().size() != 2) + throw new IllegalArgumentException(); + + return "Types.OpOp2.AND"; + case "|": + if (stmt.getOperands().size() != 2) + throw new IllegalArgumentException(); + + return "Types.OpOp2.OR"; + case "^": + if (stmt.getOperands().size() != 2) + throw new IllegalArgumentException(); + + return "Types.OpOp2.POW"; + + case "RBind": + if (stmt.getOperands().size() != 2) + throw new IllegalArgumentException(); + + return "Types.OpOp2.RBIND"; + case "CBind": + if (stmt.getOperands().size() != 2) + throw new IllegalArgumentException(); + + return "Types.OpOp2.CBIND"; + case "1-*": + if (stmt.getOperands().size() != 2) + throw new IllegalArgumentException(); + + return "Types.OpOp2.MINUS1_MULT"; + case "log_nz": + if (stmt.getOperands().size() != 1) + throw new IllegalArgumentException(); + + return "Types.OpOp1.LOG_NZ"; + + case "%*%": + return "true"; // This should be resolved by the custom handler function + } + } else { + switch (stmt.trueInstruction()) { + case "+*": + if (stmt.getOperands().size() != 3) + throw new IllegalArgumentException(); + + return "Types.OpOp3.PLUS_MULT"; + case "-*": + if (stmt.getOperands().size() != 3) + throw new IllegalArgumentException(); + + return "Types.OpOp3.MINUS_MULT"; + case "literal.FLOAT": + return null; // There is no opcheck on literals + } + } + + throw new NotImplementedException(stmt.trueInstruction()); + } + + /** + * + * @param stmt the statement + * @param ctx the context + * @return a list of operand indices that must be matched + */ + public static List matchingDimRequirement(RewriterStatement stmt, final RuleContext ctx) { + switch (stmt.trueInstruction()) { + case "1-*": + return List.of(0, 1); + case "+*": + case "-*": + return List.of(0, 2); + default: + return Collections.emptyList(); + } + } + + public static boolean opRequiresBinaryBroadcastingMatch(RewriterStatement stmt, final RuleContext ctx) { + return getOpClass(stmt, ctx).equals("BinaryOp") && stmt.getChild(0).getResultingDataType(ctx).equals("MATRIX") && stmt.getChild(1).getResultingDataType(ctx).equals("MATRIX"); + } + + public static String getOpClass(RewriterStatement stmt, final RuleContext ctx) { + switch (stmt.trueInstruction()) { + case "!": + case "sqrt": + case "log": + case "abs": + case "round": + case "*2": + case "cast.MATRIX": + case "cast.FLOAT": + case "nrow": + case "ncol": + case "length": + case "sq": + return "UnaryOp"; + + case "rowSums": + case "colSums": + case "sum": + case "trace": + return "AggUnaryOp"; + + case "+": + case "-": + case "*": + case "/": + case "min": + case "max": + case "!=": + case "==": + case ">": + case ">=": + case "<": + case "<=": + case "&": + case "|": + case "^": + case "RBind": + case "CBind": + case "1-*": + if (stmt.getOperands().size() != 2) + throw new IllegalArgumentException(); + + return "BinaryOp"; + + case "%*%": + if (stmt.getOperands().size() != 2) + throw new IllegalArgumentException(); + + return "AggBinaryOp"; + + case "t": + case "rev": + return "ReorgOp"; + + case "+*": + case "-*": + return "TernaryOp"; + + case "const": + return "DataGenOp"; + + case "literal.FLOAT": + case "literal.INT": + case "literal.BOOL": + return "LiteralOp"; + } + + throw new NotImplementedException(stmt.trueTypedInstruction(ctx)); + } + + public static String[] getReturnType(RewriterStatement stmt, final RuleContext ctx) { + return getReturnType(stmt.getResultingDataType(ctx)); + } + + public static String[] getReturnType(String typeStr) { + switch (typeStr) { + case "FLOAT": + return new String[] { "Types.DataType.SCALAR", "Types.ValueType.FP64", "Types.ValueType.FP32" }; + case "INT": + return new String[] { "Types.DataType.SCALAR", "Types.ValueType.INT64", "Types.ValueType.INT32" }; + case "BOOL": + return new String[] { "Types.DataType.SCALAR", "Types.ValueType.BOOLEAN" }; + case "MATRIX": + return new String[] { "Types.DataType.MATRIX" }; + } + + throw new NotImplementedException(typeStr); + } + + public static String literalGetterFunction(RewriterStatement stmt, final RuleContext ctx) { + switch (stmt.getResultingDataType(ctx)) { + case "INT": + return "getLongValue()"; + case "FLOAT": + return "getDoubleValue()"; + case "BOOL": + return "getBooleanValue()"; + } + + throw new IllegalArgumentException(); + } + + public static String getHopConstructor(RewriterStatement cur, RewriterAssertions assertions, Map varNameMapping, final RuleContext ctx, String... children) { + String opClass = getOpClass(cur, ctx); + String opCode = null; + + // Special instructions + switch (cur.trueInstruction()) { + case "%*%": + if (children.length != 2) + throw new IllegalArgumentException(); + + return "HopRewriteUtils.createMatrixMultiply(" + children[0] + ", " + children[1] + ")"; + + case "t": + if (children.length != 1) + throw new IllegalArgumentException(); + return "HopRewriteUtils.createTranspose(" + children[0] + ")"; + + case "rev": + if (children.length != 1) + throw new IllegalArgumentException(); + return "HopRewriteUtils.createReorg(" + children[0] + ", Types.ReOrgOp.REV)"; + + case "rowSums": + if (children.length != 1) + throw new IllegalArgumentException(); + + return "HopRewriteUtils.createAggUnaryOp(" + children[0] + ", Types.AggOp.SUM, Types.Direction.Row)"; + + case "colSums": + if (children.length != 1) + throw new IllegalArgumentException(); + + return "HopRewriteUtils.createAggUnaryOp(" + children[0] + ", Types.AggOp.SUM, Types.Direction.Col)"; + + case "sum": + if (children.length != 1) + throw new IllegalArgumentException(); + + return "HopRewriteUtils.createAggUnaryOp(" + children[0] + ", Types.AggOp.SUM, Types.Direction.RowCol)"; + + case "trace": + if (children.length != 1) + throw new IllegalArgumentException(); + + return "HopRewriteUtils.createAggUnaryOp(" + children[0] + ", Types.AggOp.TRACE, Types.Direction.RowCol)"; + + case "ncol": + if (children.length != 1) + throw new IllegalArgumentException(); + + return "HopRewriteUtils.createUnary(" + children[0] + ", Types.OpOp1.NCOL)"; + + case "nrow": + if (children.length != 1) + throw new IllegalArgumentException(); + + return "HopRewriteUtils.createUnary(" + children[0] + ", Types.OpOp1.NROW)"; + + case "const": + String referredVarName = varNameMapping.get(cur.getChild(0)); + String nrowContent; + String ncolContent; + + if (referredVarName == null) { + Optional nrowLiteral = cur.getNRow().isLiteral() ? Optional.of(cur.getNRow()) : Optional.empty(); + Optional ncolLiteral = cur.getNCol().isLiteral() ? Optional.of(cur.getNCol()) : Optional.empty(); + + RewriterAssertions.RewriterAssertion nrowAssertion = assertions.getAssertionObj(cur.getNRow()); + RewriterAssertions.RewriterAssertion ncolAssertion = assertions.getAssertionObj(cur.getNCol()); + + nrowLiteral = nrowAssertion == null ? nrowLiteral : nrowAssertion.getLiteral(); + ncolLiteral = ncolAssertion == null ? ncolLiteral : ncolAssertion.getLiteral(); + + + if (nrowLiteral.isPresent()) { + nrowContent = "new LiteralOp(" + nrowLiteral.get().getLiteral().toString() + ")"; + } else { + // Find the first + nrowContent = null; + + if (nrowAssertion == null) + throw new IllegalArgumentException(); + + for (RewriterStatement stmt : nrowAssertion.getEClass()) { + String mappedName = varNameMapping.get(stmt); + + if (mappedName != null) { + nrowContent = getHopConstructor(stmt, assertions, varNameMapping, ctx, mappedName); + break; + } + } + + if (nrowContent == null) + throw new IllegalArgumentException(); + } + + if (ncolLiteral.isPresent()) { + ncolContent = "new LiteralOp(" + ncolLiteral.get().getLiteral().toString() + ")"; + } else { + // Find the first + ncolContent = null; + + if (ncolAssertion == null) + throw new IllegalArgumentException(); + + for (RewriterStatement stmt : ncolAssertion.getEClass()) { + String mappedName = varNameMapping.get(stmt); + + if (mappedName != null) { + ncolContent = getHopConstructor(stmt, assertions, varNameMapping, ctx, mappedName); + break; + } + } + + if (ncolContent == null) + throw new IllegalArgumentException(); + } + } else { + nrowContent = getHopConstructor(cur.getChild(0).getNRow(), assertions, varNameMapping, ctx, referredVarName); + ncolContent = getHopConstructor(cur.getChild(0).getNCol(), assertions, varNameMapping, ctx, referredVarName); + } + + return "((DataGenOp) HopRewriteUtils.createDataGenOpFromDims(" + nrowContent + "," + ncolContent + "," + cur.getChild(1).getLiteral() + "D))"; + } + + switch (opClass) { + case "UnaryOp": + if (children.length != 1) + throw new IllegalArgumentException(); + + opCode = getOpCode(cur, ctx); + return "HopRewriteUtils.createUnary(" + children[0] + ", " + opCode + ")"; + case "BinaryOp": + if (children.length != 2) + throw new IllegalArgumentException(); + + opCode = getOpCode(cur, ctx); + return "HopRewriteUtils.createAutoGeneratedBinary(" + children[0] + ", " + children[1] + ", " + opCode + ")"; + case "TernaryOp": + if (children.length != 3) + throw new IllegalArgumentException(); + + opCode = getOpCode(cur, ctx); + return "HopRewriteUtils.createTernary(" + children[0] + ", " + children[1] + ", " + children[2] + "," + opCode + ")"; + } + + throw new NotImplementedException(cur.trueTypedInstruction(ctx)); + } + +} diff --git a/src/main/java/org/apache/sysds/hops/rewriter/utils/ConstantFoldingUtils.java b/src/main/java/org/apache/sysds/hops/rewriter/utils/ConstantFoldingUtils.java new file mode 100644 index 00000000000..b46fb0e62b0 --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewriter/utils/ConstantFoldingUtils.java @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.rewriter.utils; + +import org.apache.sysds.hops.rewriter.RewriterStatement; +import org.apache.sysds.hops.rewriter.RuleContext; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.function.BiFunction; + +public class ConstantFoldingUtils { + static final double EPS = 1e-20; + + public static BiFunction foldingBiFunction(String op, String type) { + switch (op) { + case "+": + if (type.equals("FLOAT")) + return (num, stmt) -> foldSumFloat(num == null ? 0.0 : (double)num, stmt); + else if (type.equals("INT")) + return (num, stmt) -> foldSumInt(num == null ? 0L : (long)num, stmt); + else + throw new UnsupportedOperationException(); + case "*": + if (type.equals("FLOAT")) + return (num, stmt) -> foldMulFloat(num == null ? 1.0D : (double)num, stmt); + else if (type.equals("INT")) + return (num, stmt) -> foldMulInt(num == null ? 1L : (long)num, stmt); + else + throw new UnsupportedOperationException(); + case "min": + if (type.equals("FLOAT")) + return (num, stmt) -> num == null ? stmt.floatLiteral() : foldMinFloat((double)num, stmt); + else if (type.equals("INT")) + return (num, stmt) -> num == null ? stmt.intLiteral(false) : foldMinInt((long)num, stmt); + break; + case "max": + if (type.equals("FLOAT")) + return (num, stmt) -> num == null ? stmt.floatLiteral() : foldMaxFloat((double)num, stmt); + else if (type.equals("INT")) + return (num, stmt) -> num == null ? stmt.intLiteral(false) : foldMaxInt((long)num, stmt); + break; + } + + throw new UnsupportedOperationException(); + } + + public static boolean isNeutralElement(Object num, String op) { + switch (op) { + case "+": + return num.equals(0L) || num.equals(0.0D); + case "*": + return num.equals(1L) || num.equals(1.0D); + } + + return false; + } + + public static boolean isNegNeutral(Object num, String op) { + if (num == null) + return false; + + switch (op) { + case "*": + return num.equals(-1L) || num.equals(-1.0D); + } + + return false; + } + + public static boolean cancelOutNary(String op, List stmts) { + Set toRemove = new HashSet<>(); + switch (op) { + case "+": + for (int i = 0; i < stmts.size(); i++) { + RewriterStatement stmt1 = stmts.get(i); + for (int j = i+1; j < stmts.size(); j++) { + RewriterStatement stmt2 = stmts.get(j); + + if (stmt1.isInstruction() && stmt1.trueInstruction().equals("-") && stmt1.getChild(0).equals(stmt2) + || (stmt2.isInstruction() && stmt2.trueInstruction().equals("-") && stmt2.getChild(0).equals(stmt1))) { + if (!toRemove.contains(i) && !toRemove.contains(j)) { + toRemove.add(i); + toRemove.add(j); + } + } + + } + } + case "*": + for (int i = 0; i < stmts.size(); i++) { + RewriterStatement stmt1 = stmts.get(i); + for (int j = i+1; j < stmts.size(); j++) { + RewriterStatement stmt2 = stmts.get(j); + + if (stmt1.isInstruction() && stmt1.trueInstruction().equals("inv") && stmt1.getChild(0).equals(stmt2) + || (stmt2.isInstruction() && stmt2.trueInstruction().equals("inv") && stmt2.getChild(0).equals(stmt1))) { + if (!toRemove.contains(i) && !toRemove.contains(j)) { + toRemove.add(i); + toRemove.add(j); + } + } + + } + } + } + + if (toRemove.isEmpty()) + return false; + + List oldCpy = new ArrayList<>(stmts); + stmts.clear(); + + for (int i = 0; i < oldCpy.size(); i++) { + if (!toRemove.contains(i)) + stmts.add(oldCpy.get(i)); + } + + return true; + } + + // This function does not handle NaNs + public static RewriterStatement overwritesLiteral(Number num, String op, final RuleContext ctx) { + if (op.equals("*") && Math.abs(num.doubleValue()) < EPS) { + if (num instanceof Double) + return RewriterStatement.literal(ctx, 0.0); + else + return RewriterStatement.literal(ctx, 0L); + } + + return null; + } + + public static double foldSumFloat(double num, RewriterStatement next) { + return num + next.floatLiteral(); + } + + public static long foldSumInt(long num, RewriterStatement next) { + return num + next.intLiteral(false); + } + + public static double foldMulFloat(double num, RewriterStatement next) { + return num * next.floatLiteral(); + } + + public static long foldMulInt(long num, RewriterStatement next) { + return num * next.intLiteral(false); + } + + public static double foldMinFloat(double num, RewriterStatement next) { + return Math.min(num, next.floatLiteral()); + } + + public static long foldMinInt(long num, RewriterStatement next) { + return Math.min(num, next.intLiteral(false)); + } + + public static double foldMaxFloat(double num, RewriterStatement next) { + return Math.max(num, next.floatLiteral()); + } + + public static long foldMaxInt(long num, RewriterStatement next) { + return Math.max(num, next.intLiteral(false)); + } +} diff --git a/src/main/java/org/apache/sysds/hops/rewriter/utils/RewriterSearchUtils.java b/src/main/java/org/apache/sysds/hops/rewriter/utils/RewriterSearchUtils.java new file mode 100644 index 00000000000..daaafa71612 --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewriter/utils/RewriterSearchUtils.java @@ -0,0 +1,618 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.rewriter.utils; + +import org.apache.sysds.hops.rewriter.RewriterDataType; +import org.apache.sysds.hops.rewriter.RewriterInstruction; +import org.apache.sysds.hops.rewriter.RewriterStatement; +import org.apache.sysds.hops.rewriter.RuleContext; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.Set; +import java.util.UUID; +import java.util.stream.Collectors; + +public class RewriterSearchUtils { + public static final List ALL_TYPES = List.of("MATRIX", "FLOAT"); + public static final List SCALAR = List.of("FLOAT"); + public static final List MATRIX = List.of("MATRIX"); + + public static Operand[] instructionAlphabet = new Operand[] { + null, + new Operand("+", 2, ALL_TYPES, ALL_TYPES), + //new Operand("+", 2, MATRIX, SCALAR), + //new Operand("+", 2, MATRIX, MATRIX), + + new Operand("-", 2, ALL_TYPES, ALL_TYPES), + //new Operand("-", 2, MATRIX, SCALAR), + //new Operand("-", 2, MATRIX, MATRIX), + + new Operand("*", 2, ALL_TYPES, ALL_TYPES), + //new Operand("*", 2, MATRIX, SCALAR), + //new Operand("*", 2, MATRIX, MATRIX), + + new Operand("/", 2, ALL_TYPES, ALL_TYPES), + //new Operand("/", 2, MATRIX, SCALAR), + //new Operand("/", 2, MATRIX, MATRIX), + + new Operand("%*%", 2, MATRIX, MATRIX), + + new Operand("sum", 1, MATRIX), + new Operand("*sum", 2, MATRIX, ALL_TYPES), // To have a bigger search space for this instruction combination + new Operand("t", 1, MATRIX), + new Operand("rev", 1, MATRIX), + new Operand("diag", 1, MATRIX), + new Operand("trace", 1, MATRIX), + new Operand("rowSums", 1, MATRIX), + new Operand("colSums", 1, MATRIX), + new Operand("max", 1, MATRIX), + new Operand("min", 1, MATRIX), + new Operand("ncol", 0, true, MATRIX), + new Operand("nrow", 0, true, MATRIX), + new Operand("length", 0, true, MATRIX), + + new Operand("!=", 2, ALL_TYPES, ALL_TYPES), + new Operand("!=0", 1, MATRIX), + new Operand("0!=", 1, MATRIX), + + new Operand("cast.MATRIX",1, SCALAR), + new Operand("cast.FLOAT", 1, MATRIX), + + new Operand("1-*", 2, MATRIX, MATRIX), + new Operand("+*", 3, MATRIX, SCALAR, MATRIX), + new Operand("-*", 3, MATRIX, SCALAR, MATRIX), + new Operand("*2", 1, MATRIX), + new Operand("_nnz", 1, MATRIX), + new Operand("sumSq", 1, MATRIX), + new Operand("sq", 1, MATRIX), + //new Operand("log", 1, MATRIX), + + // constant stuff + new Operand("c_1+", 1, ALL_TYPES), + new Operand("c_+1", 1, ALL_TYPES), + new Operand("c_1-", 1, ALL_TYPES), + new Operand("c_-1", 1, ALL_TYPES), + + // ncol / nrow / length stuff + new Operand("c_length*", 1, ALL_TYPES), + new Operand("c_ncol*", 1, ALL_TYPES), + new Operand("c_nrow*", 1, ALL_TYPES), + + new Operand("log_nz", 1, MATRIX), + + // Placeholder operators + new Operand("zero", 0, true), + new Operand("one", 0, true) + }; + + private static String[] varNames = new String[] { + "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M" + }; + + private static RuleContext ctx; + + public static int getMaxSearchNumberForNumOps(int numOps) { + int out = 1; + for (int i = 0; i < numOps; i++) + out *= instructionAlphabet.length; + + return out; + } + + public static void rename(RewriterStatement stmt) { + Set namedVars = new HashSet<>(); + + stmt.forEachPostOrder((cur, pred) -> { + if (!cur.isInstruction() && !cur.isLiteral()) { + if (!namedVars.contains(cur)) { + if (cur.getResultingDataType(ctx).equals("MATRIX")) + cur.rename(varNames[namedVars.size()]); + else + cur.rename(varNames[namedVars.size()].toLowerCase()); + + namedVars.add(cur); + } + } + }, false); + } + + // To include structures like row/column vectors etc. + public static List buildAssertionVariations(RewriterStatement root, final RuleContext ctx) { + List interestingLeaves = new ArrayList<>(); + root.forEachPreOrder(cur -> { + if (!cur.isInstruction() && !cur.isLiteral() && cur.getResultingDataType(ctx).equals("MATRIX")) + interestingLeaves.add(cur); + return true; + }, true); + + if (interestingLeaves.isEmpty()) + return Collections.emptyList(); + + List out = new ArrayList<>(); + + for (int i = 0; i < interestingLeaves.size(); i++) { + RewriterStatement from = interestingLeaves.get(i); + RewriterStatement rv = createVectorizedStatement(root, from, true); + if (ctx.metaPropagator != null) + rv = ctx.metaPropagator.apply(rv); + out.add(rv); + RewriterStatement cv = createVectorizedStatement(root, from, false); + if (ctx.metaPropagator != null) + cv = ctx.metaPropagator.apply(cv); + out.add(cv); + + for (int j = i + 1; j < interestingLeaves.size(); j++) { + RewriterStatement to = interestingLeaves.get(i); + Map map = new HashMap<>(); + map.put(from, false); + map.put(to, false); + out.add(createVectorizedStatements(root, map)); + map.put(from, true); + out.add(createVectorizedStatements(root, map)); + map.put(to, true); + out.add(createVectorizedStatements(root, map)); + map.put(from, false); + out.add(createVectorizedStatements(root, map)); + } + } + + // Serialize and parse again as there may still be duplicate references + out = out.stream().map(stmt -> RewriterUtils.parse(stmt.toParsableString(ctx, true), ctx)).collect(Collectors.toList()); + + if (ctx.metaPropagator != null) + return out.stream().map(stmt -> ctx.metaPropagator.apply(stmt)).collect(Collectors.toList()); + + return out; + } + + private static RewriterStatement createVector(RewriterStatement of, boolean rowVector, Map createdObjects) { + // TODO: Why is it necessary to discard the old DataType? + RewriterStatement mCpy = createdObjects.get(of); + + if (mCpy == null) { + mCpy = new RewriterDataType().as(of.getId()).ofType(of.getResultingDataType(ctx)).consolidate(ctx); + createdObjects.put(of, mCpy); + } + //RewriterStatement nRowCol = new RewriterInstruction().as(UUID.randomUUID().toString()).withInstruction(rowVector ? "nrow" : "ncol").withOps(mCpy).consolidate(ctx); + //createdObjects.put(of, mCpy); + return new RewriterInstruction() + .as(of.getId()) + .withInstruction(rowVector ? "rowVec" : "colVec") + .withOps(mCpy) + .consolidate(ctx); + } + + private static RewriterStatement createVectorizedStatement(RewriterStatement root, RewriterStatement of, boolean rowVector) { + HashMap createdObjects = new HashMap<>(); + RewriterStatement out = root.nestedCopyOrInject(createdObjects, stmt -> { + if (stmt.equals(of)) + return createVector(of, rowVector, createdObjects); + + return null; + }); + + return out; + } + + private static RewriterStatement createVectorizedStatements(RewriterStatement root, Map of) { + HashMap createdObjects = new HashMap<>(); + + RewriterStatement out = root.nestedCopyOrInject(createdObjects, stmt -> { + if (!stmt.isInstruction() && !stmt.isLiteral() && stmt.getResultingDataType(ctx).equals("MATRIX")) { + Boolean rowVector = of.get(stmt); + + if (rowVector != null) + return createVector(stmt, rowVector, createdObjects); + } + + return null; + }); + + return out; + } + + // Builds variations of the same graph (e.g. +(A,B) -> +(A,A)) + public static List buildVariations(RewriterStatement root, final RuleContext ctx) { + List interestingLeaves = new ArrayList<>(); + root.forEachPreOrder(cur -> { + if (!cur.isInstruction() && !cur.isLiteral() && cur.getResultingDataType(ctx).equals("MATRIX")) + interestingLeaves.add(cur); + return true; + }, true); + + if (interestingLeaves.size() < 2) + return Collections.emptyList(); + + List out = new ArrayList<>(); + + for (int i = 0; i < interestingLeaves.size(); i++) { + RewriterStatement to = interestingLeaves.get(i); + for (int j = i + 1; j < interestingLeaves.size(); j++) { + RewriterStatement from = interestingLeaves.get(j); + HashMap createdObjects = new HashMap<>(); + RewriterStatement toCpy = new RewriterDataType().as(to.getId()).ofType(to.getResultingDataType(ctx)).consolidate(ctx); + createdObjects.put(from, toCpy); + createdObjects.put(to, toCpy); + RewriterStatement cpy = root.nestedCopyOrInject(createdObjects, stmt -> null); + if (ctx.metaPropagator != null) + cpy = ctx.metaPropagator.apply(cpy); + out.add(cpy); + } + } + + // Serialize and parse again as there may still be duplicate references + out = out.stream().map(stmt -> RewriterUtils.parse(stmt.toParsableString(ctx, true), ctx)).collect(Collectors.toList()); + + return out; + } + + public static List buildAllPossibleDAGs(List operands, final RuleContext ctx, boolean rename) { + if (operands == null) + return Collections.emptyList(); + + RewriterSearchUtils.ctx = ctx; + + List allStmts = recursivelyFindAllCombinations(operands, null, ALL_TYPES); + + if (rename) + allStmts.forEach(RewriterSearchUtils::rename); + + if (ctx.metaPropagator != null) + allStmts = allStmts.stream().map(stmt -> ctx.metaPropagator.apply(stmt)).collect(Collectors.toList()); + + // Serialize and parse all statements as there are still duplicate references + return allStmts.stream().map(stmt -> RewriterUtils.parse(stmt.toParsableString(ctx, true), ctx)).collect(Collectors.toList()); + } + + private static List recursivelyFindAllCombinations(List operands, Operand parent, List supportedTypes) { + if (operands.isEmpty()) + return supportedTypes.stream().map(t -> new RewriterDataType().as(UUID.randomUUID().toString()).ofType(t).consolidate(ctx)).collect(Collectors.toList()); + + // Check if op is a placeholder + Operand op = operands.get(0); + + if (op.isLeaf && operands.size() > 1) + return Collections.emptyList(); + + if (op.op.equals("zero") || op.op.equals("one")) { + List l = new ArrayList<>(2); + if (op.op.equals("zero")) { + if (supportedTypes.contains("FLOAT")) + l.add(RewriterStatement.literal(ctx, 0.0D)); + if (supportedTypes.contains("MATRIX")) + l.add(new RewriterInstruction().as(UUID.randomUUID().toString()).withInstruction("const").withOps(new RewriterDataType().as(UUID.randomUUID().toString()).ofType("MATRIX").consolidate(ctx), RewriterStatement.literal(ctx, 0.0D)).consolidate(ctx)); + } else { + if (supportedTypes.contains("FLOAT")) + l.add(RewriterStatement.literal(ctx, 1.0D)); + + if (supportedTypes.contains("MATRIX")) + l.add(new RewriterInstruction().as(UUID.randomUUID().toString()).withInstruction("const").withOps(new RewriterDataType().as(UUID.randomUUID().toString()).ofType("MATRIX").consolidate(ctx), RewriterStatement.literal(ctx, 1.0D)).consolidate(ctx)); + } + + return l; + } + + int nOps = operands.get(0).numArgs; + + if (nOps == 0) { + return List.of(buildStmt(op, null)); + } + + int[] slices = new int[Math.max(nOps-1, 0)]; + + List possibleStmts = new ArrayList<>(); + + forEachSlice(1, 0, operands.size()+1, slices, () -> { + List> cartesianBuilder = new ArrayList<>(); + + for (int i = 0; i < nOps; i++) { + int lIdx = i == 0 ? 1 : slices[i-1]; + int uIdx = i == slices.length ? operands.size() : slices[i]; + + List view; + if (lIdx == uIdx) + view = Collections.emptyList(); + else + view = operands.subList(lIdx, uIdx); + + List combs = recursivelyFindAllCombinations(view, op, op.supportedTypes[i]); + + if (combs.isEmpty()) + return; // Then no subgraph can be created from that order + + cartesianBuilder.add(combs); + } + + RewriterStatement[] stack = new RewriterStatement[nOps]; + RewriterUtils.cartesianProduct(cartesianBuilder, stack, mStack -> { + try { + for (int i = 0; i < stack.length; i++) + if (!op.supportedTypes[i].contains(stack[i].getResultingDataType(ctx))) + return true; + + RewriterStatement stmt = buildStmt(operands.get(0), stack); + possibleStmts.add(stmt); + } catch (Exception e) { + // Might fail as there could be wrong types + } + return true; // Should continue + }); + }); + + return possibleStmts; + } + + private static RewriterStatement buildStmt(Operand op, RewriterStatement[] stack) { + RewriterInstruction stmt = new RewriterInstruction().as(UUID.randomUUID().toString()); + switch (op.op) { + case "!=0": { + stmt.withInstruction("!=").addOp(stack[0]).addOp(RewriterStatement.literal(ctx, 0.0D)); + break; + } + case "0!=": { + stmt.withInstruction("!=").addOp(RewriterStatement.literal(ctx, 0.0D)).addOp(stack[0]); + break; + } + case "ncol": + case "nrow": + case "length": { + String actualOp = op.op; + stmt.withInstruction(actualOp).withOps(new RewriterDataType().as(UUID.randomUUID().toString()).ofType("MATRIX").consolidate(ctx)).consolidate(ctx); + break; + } + case "fncol": + case "fnrow": + case "flength": { + String actualOp = op.op.substring(1); + stmt.withInstruction(actualOp).withOps(stack).consolidate(ctx); + stmt = (RewriterInstruction) RewriterStatement.castFloat(ctx, stmt); + break; + } + case "*sum": { + RewriterStatement old = stmt.withInstruction("sum").withOps(stack[0]).consolidate(ctx); + stmt = new RewriterInstruction("*", ctx, old, stack[1]); + break; + } + case "c_1+": { + stmt = new RewriterInstruction("+", ctx, RewriterStatement.literal(ctx, 1.0D), stack[0]); + break; + } + case "c_+1": { + stmt = new RewriterInstruction("+", ctx, stack[0], RewriterStatement.literal(ctx, 1.0D)); + break; + } + case "c_1-": { + stmt = new RewriterInstruction("-", ctx, RewriterStatement.literal(ctx, 1.0D), stack[0]); + break; + } + case "c_-1": { + stmt = new RewriterInstruction("-", ctx, stack[0], RewriterStatement.literal(ctx, 1.0D)); + break; + } + case "c_length*": { + stmt = new RewriterInstruction("*", ctx, new RewriterInstruction("length", ctx, new RewriterDataType().as(UUID.randomUUID().toString()).ofType("MATRIX").consolidate(ctx)), stack[0]); + break; + } + case "c_nrow*": { + stmt = new RewriterInstruction("*", ctx, new RewriterInstruction("nrow", ctx, new RewriterDataType().as(UUID.randomUUID().toString()).ofType("MATRIX").consolidate(ctx)), stack[0]); + break; + } + case "c_col*": { + stmt = new RewriterInstruction("*", ctx, new RewriterInstruction("ncol", ctx, new RewriterDataType().as(UUID.randomUUID().toString()).ofType("MATRIX").consolidate(ctx)), stack[0]); + break; + } + default: { + stmt.withInstruction(op.op).withOps(stack); + break; + } + } + + stmt.consolidate(ctx); + return stmt; + } + + private static void forEachSlice(int startIdx, int pos, int maxIdx, int[] slices, Runnable trigger) { + if (pos >= slices.length) { + trigger.run(); + return; + } + + for (int idx = startIdx; idx < maxIdx; idx++) { + slices[pos] = idx; + + if (pos != slices.length-1) { + forEachSlice(idx, pos+1, maxIdx, slices, trigger); + } else { + trigger.run(); + } + } + } + + public static List decodeOrderedStatements(int stmt) { + int[] instructions = fromBaseNNumber(stmt, instructionAlphabet.length); + List out = new ArrayList<>(instructions.length); + + for (int i = 0; i < instructions.length; i++) { + Operand toAdd = instructionAlphabet[instructions[i]]; + if (toAdd == null) + return null; + out.add(toAdd); + } + + return out; + } + + public static int[] fromBaseNNumber(int l, int n) { + if (l == 0) + return new int[0]; + + int numDigits = (int)(Math.log(l) / Math.log(n)) + 1; + int[] digits = new int[numDigits]; + + for (int i = numDigits - 1; i >= 0; i--) { + digits[i] = l % n; + l = l / n; + } + + return digits; + } + + public static int toBaseNNumber(int[] digits, int n) { + if (digits.length == 0) + throw new IllegalArgumentException(); + + int multiplicator = 1; + int out = 0; + + for (int i = digits.length - 1; i >= 0; i--) { + out += multiplicator * digits[i]; + multiplicator *= n; + } + + return out; + } + + public static List mergeSubtreeCombinations(RewriterStatement stmt, List indices, List> mList, final RuleContext ctx, int maximumCombinations) { + if (indices.isEmpty()) + return List.of(stmt); + + List mergedTreeCombinations = new ArrayList<>(); + RewriterUtils.cartesianProduct(mList, new RewriterStatement[mList.size()], stack -> { + RewriterStatement cpy = stmt.copyNode(); + for (int i = 0; i < stack.length; i++) + cpy.getOperands().set(indices.get(i), stack[i]); + cpy.consolidate(ctx); + cpy.prepareForHashing(); + cpy.recomputeHashCodes(ctx); + mergedTreeCombinations.add(cpy); + return mergedTreeCombinations.size() < maximumCombinations; + }); + + return mergedTreeCombinations; + } + + public static List generateSubtrees(RewriterStatement stmt, final RuleContext ctx, int maximumCombinations) { + List l = generateSubtrees(stmt, new HashMap<>(), ctx, maximumCombinations); + + if (ctx.metaPropagator != null) + l.forEach(subtree -> ctx.metaPropagator.apply(subtree)); + + return l.stream().map(subtree -> { + if (ctx.metaPropagator != null) + subtree = ctx.metaPropagator.apply(subtree); + + subtree.prepareForHashing(); + subtree.recomputeHashCodes(ctx); + // We return a copy of the tree as there are still duplicate references + return RewriterUtils.parse(subtree.toParsableString(ctx, true), ctx); + }).collect(Collectors.toList()); + } + + private static Random rd = new Random(); + + private static List generateSubtrees(RewriterStatement stmt, Map> visited, final RuleContext ctx, int maxCombinations) { + if (stmt == null) + return Collections.emptyList(); + + RewriterStatement is = stmt; + List alreadyVisited = visited.get(is); + + if (alreadyVisited != null) + return alreadyVisited; + + if (stmt.getOperands().size() == 0) + return List.of(stmt); + + // Scan if operand is not a DataType + List indices = new ArrayList<>(); + for (int i = 0; i < stmt.getOperands().size(); i++) { + if (stmt.getChild(i).isInstruction() || stmt.getChild(i).isLiteral()) + indices.add(i); + } + + int n = indices.size(); + int totalSubsets = 1 << n; + + List mList = new ArrayList<>(); + + visited.put(is, mList); + + List> mOptions = indices.stream().map(i -> generateSubtrees(stmt.getOperands().get(i), visited, ctx, maxCombinations)).collect(Collectors.toList()); + List out = new ArrayList<>(); + + for (int subsetMask = 0; subsetMask < totalSubsets; subsetMask++) { + List> mOptionCpy = new ArrayList<>(mOptions); + + for (int i = 0; i < n; i++) { + // Check if the i-th child is included in the current subset + if ((subsetMask & (1 << i)) == 0) { + String dt = stmt.getOperands().get(indices.get(i)).getResultingDataType(ctx); + String namePrefix = "tmp"; + if (dt.equals("MATRIX")) + namePrefix = "M"; + else if (dt.equals("FLOAT")) + namePrefix = "f"; + else if (dt.equals("INT")) + namePrefix = "i"; + else if (dt.equals("BOOL")) + namePrefix = "b"; + RewriterDataType mT = new RewriterDataType().as(namePrefix + rd.nextInt(100000)).ofType(dt); + mT.consolidate(ctx); + mOptionCpy.set(i, List.of(mT)); + } + } + + out.addAll(mergeSubtreeCombinations(stmt, indices, mOptionCpy, ctx, maxCombinations)); + if (out.size() > maxCombinations) { + System.out.println("Aborting early due to too many combinations"); + return out; + } + } + + return out; + } + + public static final class Operand { + public final String op; + public final int numArgs; + public final List[] supportedTypes; + public final boolean isLeaf; + + public Operand(String op, int numArgs, List... supportedTypes) { + this(op, numArgs, false, supportedTypes); + } + public Operand(String op, int numArgs, boolean isLeaf, List... supportedTypes) { + this.op = op; + this.numArgs = numArgs; + this.supportedTypes = supportedTypes; + this.isLeaf = isLeaf; + } + + public String toString() { + return op; + } + } +} diff --git a/src/main/java/org/apache/sysds/hops/rewriter/utils/RewriterUtils.java b/src/main/java/org/apache/sysds/hops/rewriter/utils/RewriterUtils.java new file mode 100644 index 00000000000..258b65002fb --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewriter/utils/RewriterUtils.java @@ -0,0 +1,1375 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.rewriter.utils; + +import org.apache.commons.lang3.NotImplementedException; +import org.apache.commons.lang3.mutable.MutableBoolean; +import org.apache.commons.lang3.mutable.MutableInt; +import org.apache.commons.lang3.mutable.MutableObject; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.logging.log4j.util.TriConsumer; +import org.apache.sysds.hops.rewriter.MetaPropagator; +import org.apache.sysds.hops.rewriter.RewriterContextSettings; +import org.apache.sysds.hops.rewriter.RewriterDataType; +import org.apache.sysds.hops.rewriter.rule.RewriterHeuristic; +import org.apache.sysds.hops.rewriter.rule.RewriterHeuristics; +import org.apache.sysds.hops.rewriter.RewriterInstruction; +import org.apache.sysds.hops.rewriter.rule.RewriterRule; +import org.apache.sysds.hops.rewriter.rule.RewriterRuleBuilder; +import org.apache.sysds.hops.rewriter.rule.RewriterRuleCollection; +import org.apache.sysds.hops.rewriter.rule.RewriterRuleSet; +import org.apache.sysds.hops.rewriter.RewriterStatement; +import org.apache.sysds.hops.rewriter.RuleContext; +import org.apache.sysds.hops.rewriter.TopologicalSort; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.Set; +import java.util.UUID; +import java.util.function.BiConsumer; +import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +public class RewriterUtils { + protected static final Log LOG = LogFactory.getLog(RewriterUtils.class.getName()); + + public static final Pattern LONG_PATTERN = Pattern.compile("-?\\d+"); + public static final Pattern DOUBLE_PATTERN = Pattern.compile("-?\\d*\\.\\d+([eE][+-]?\\d+)?"); + public static final Pattern SPECIAL_FLOAT_PATTERN = Pattern.compile("Infinity|NaN"); + + public static String typedToUntypedInstruction(String instr) { + return instr.substring(0, instr.indexOf('(')); + } + + public static BiFunction binaryStringRepr(String op) { + return (stmt, ctx) -> { + List operands = stmt.getOperands(); + String op1Str = operands.get(0).toString(ctx); + if (operands.get(0) instanceof RewriterInstruction && operands.get(0).getOperands().size() > 1) + op1Str = "(" + op1Str + ")"; + String op2Str = operands.get(1).toString(ctx); + if (operands.get(1) instanceof RewriterInstruction && operands.get(1).getOperands().size() > 1) + op2Str = "(" + op2Str + ")"; + return op1Str + op + op2Str; + }; + } + + public static void mergeArgLists(RewriterStatement stmt, final RuleContext ctx) { + + stmt.forEachPreOrder(el -> { + tryFlattenNestedArgList(ctx, el, el, -1); + tryFlattenNestedOperatorPatterns(ctx, el); + el.refreshReturnType(ctx); + return true; + }, true); + + stmt.prepareForHashing(); + stmt.recomputeHashCodes(ctx); + } + + public static boolean tryFlattenNestedArgList(final RuleContext ctx, RewriterStatement stmt, RewriterStatement root, int insertAt) { + if (!stmt.isArgumentList()) + return false; + + if (stmt == root) { + boolean anyMatch = false; + + for (int i = 0; i < stmt.getOperands().size(); i++) { + RewriterStatement op = stmt.getOperands().get(i); + if (tryFlattenNestedArgList(ctx, op, root, i)) { + stmt.getOperands().remove(i); + anyMatch = true; + } + } + + return anyMatch; + } + + String dt1 = root.getResultingDataType(ctx); + String dt2 = stmt.getResultingDataType(ctx); + + String convertibleDataType = convertibleType(dt1.substring(0, dt1.length()-3), dt2.substring(0, dt2.length()-3)); + + if (convertibleDataType == null) + return false; + + root.getOperands().addAll(insertAt+1, stmt.getOperands()); + + return true; + } + + public static void tryFlattenNestedOperatorPatterns(final RuleContext ctx, RewriterStatement stmt) { + if (!stmt.isInstruction()) + return; + + RewriterInstruction instr = (RewriterInstruction) stmt; + + if (instr.hasProperty("FusedOperator", ctx)) { + for (int i = 0; i < instr.getOperands().get(0).getOperands().size(); i++) + if (flattenNestedOperatorPatterns(ctx, instr.getOperands().get(0).getOperands().get(i), instr, i)) + i--; + } + } + + private static boolean flattenNestedOperatorPatterns(final RuleContext ctx, RewriterStatement stmt, RewriterInstruction rootInstr, int insertAt) { + if (stmt.isInstruction() && ((RewriterInstruction)stmt).hasProperty("FusedOperator", ctx) && stmt.trueInstruction().equals(rootInstr.trueInstruction())) { + RewriterStatement origArgList = rootInstr.getOperands().get(0); + RewriterStatement subArgList = stmt.getOperands().get(0); + + origArgList.getOperands().set(insertAt, subArgList.getOperands().get(0)); + origArgList.getOperands().addAll(insertAt+1, subArgList.getOperands().subList(1, subArgList.getOperands().size())); + + return true; + } + + return false; + } + + public static RewriterStatement parse(String expr, final RuleContext ctx) { + String[] split = expr.split("\n"); + return parse(split[split.length-1], ctx, Arrays.copyOfRange(split, 0, split.length-1)); + } + + public static RewriterRule parseRule(String expr, final RuleContext ctx) { + // Remove empty lines + expr = expr.replaceAll("\n\\s*\n", "\n"); + String[] split = expr.split("\n"); + Set allowedMultiRefs = Collections.emptySet(); + boolean allowCombinations = false; + boolean parsedExtendedHeader = false; + + if (split[0].startsWith("AllowedMultiRefs:")) { + split[0] = split[0].substring(17); + String[] sSplit = split[0].split(","); + allowedMultiRefs = Arrays.stream(sSplit).map(s -> Integer.parseInt(s.substring(1))).collect(Collectors.toSet()); + + if (!split[1].startsWith("AllowCombinations:")) + throw new IllegalArgumentException(); + + split[1] = split[1].substring(18); + allowCombinations = Boolean.parseBoolean(split[1]); + parsedExtendedHeader = true; + } + + int condIdxStart = -1; + for (int i = 2; i < split.length; i++) { + if (split[i].startsWith("{")) { + // Then we have a conditional rule + condIdxStart = i; + break; + } + } + + if (condIdxStart != -1) { + // Then we have a conditional rule + List toExprs = Arrays.asList(split).subList(condIdxStart+1, split.length-1); + return parseRule(split[condIdxStart-2], toExprs, allowedMultiRefs, allowCombinations, ctx, Arrays.copyOfRange(split, parsedExtendedHeader ? 2 : 0, condIdxStart-2)); + } + + return parseRule(split[split.length-3], split[split.length-1], allowedMultiRefs, allowCombinations, ctx, Arrays.copyOfRange(split, parsedExtendedHeader ? 2 : 0, split.length-3)); + } + + public static RewriterStatement parse(String expr, final RuleContext ctx, String... varDefinitions) { + return parse(expr, ctx, new HashMap<>(), varDefinitions); + } + + public static RewriterRule parseRule(String exprFrom, String exprTo, Set allowedMultiRefs, boolean allowCombinations, final RuleContext ctx, String... varDefinitions) { + return parseRule(exprFrom, exprTo, ctx, new HashMap<>(), allowedMultiRefs, allowCombinations, varDefinitions); + } + + public static RewriterRule parseRule(String exprFrom, List exprsTo, Set allowedMultiRefs, boolean allowCombinations, final RuleContext ctx, String... varDefinitions) { + return parseRule(exprFrom, exprsTo, ctx, new HashMap<>(), allowedMultiRefs, allowCombinations, true, varDefinitions); + } + + public static RewriterStatement parse(String expr, final RuleContext ctx, Map dataTypes, String... varDefinitions) { + for (String def : varDefinitions) + parseDataTypes(def, dataTypes, ctx); + + RewriterStatement parsed = parseExpression(expr, new HashMap<>(), dataTypes, ctx); + if (ctx.metaPropagator == null) + return parsed; + else { + RewriterStatement out = ctx.metaPropagator.apply(parsed); + out.prepareForHashing(); + out.recomputeHashCodes(ctx); + return out; + } + } + + public static RewriterRule parseRule(String exprFrom, String exprTo, final RuleContext ctx, Map dataTypes, Set allowedMultiRefs, boolean allowCombinations, String... varDefinitions) { + for (String def : varDefinitions) + parseDataTypes(def, dataTypes, ctx); + + HashMap mmap = new HashMap<>(); + + RewriterStatement parsedFrom = parseExpression(exprFrom, mmap, dataTypes, ctx); + RewriterStatement parsedTo = parseExpression(exprTo, mmap, dataTypes, ctx); + + if (ctx.metaPropagator != null) { + parsedFrom = ctx.metaPropagator.apply(parsedFrom); + parsedTo = ctx.metaPropagator.apply(parsedTo); + } + + return new RewriterRuleBuilder(ctx).completeRule(parsedFrom, parsedTo).withAllowedMultiRefs(allowedMultiRefs.stream().map(mmap::get).collect(Collectors.toSet()), allowCombinations).setUnidirectional(true).build(); + } + + public static RewriterRule parseRule(String exprFrom, List exprsTo, final RuleContext ctx, Map dataTypes, Set allowedMultiRefs, boolean allowCombinations, boolean asConditional, String... varDefinitions) { + if (!asConditional && exprsTo.size() > 1) + throw new IllegalArgumentException(); + + for (String def : varDefinitions) + parseDataTypes(def, dataTypes, ctx); + + HashMap mmap = new HashMap<>(); + + RewriterStatement parsedFrom = parseExpression(exprFrom, mmap, dataTypes, ctx); + if (ctx.metaPropagator != null) { + parsedFrom = ctx.metaPropagator.apply(parsedFrom); + } + + List parsedTos = new ArrayList<>(); + for (String exprTo : exprsTo) { + RewriterStatement parsedTo = parseExpression(exprTo, mmap, dataTypes, ctx); + + if (ctx.metaPropagator != null) { + parsedTo = ctx.metaPropagator.apply(parsedTo); + parsedTo.prepareForHashing(); + parsedTo.recomputeHashCodes(ctx); + } + + parsedTos.add(parsedTo); + } + + return new RewriterRuleBuilder(ctx) + .completeConditionalRule(parsedFrom, parsedTos) + .withAllowedMultiRefs(allowedMultiRefs.stream().map(mmap::get).collect(Collectors.toSet()), allowCombinations) + .setUnidirectional(true).build(); + } + + /** + * Parses an expression + * @param expr the expression string + * @param refmap test + * @param dataTypes data type + * @param ctx context + * @return test + */ + public static RewriterStatement parseExpression(String expr, Map refmap, Map dataTypes, final RuleContext ctx) { + RuleContext.currentContext = ctx; + expr = expr.replaceAll("\\s+", ""); + MutableObject mexpr = new MutableObject<>(expr); + RewriterStatement stmt = doParseExpression(mexpr, refmap, dataTypes, ctx); + stmt.prepareForHashing(); + stmt.consolidate(ctx); + return stmt; + } + + private static RewriterStatement doParseExpression(MutableObject mexpr, Map refmap, Map dataTypes, final RuleContext ctx) { + String expr = mexpr.getValue(); + if (expr.startsWith("$")) { + expr = expr.substring(1); + Pattern pattern = Pattern.compile("^\\d+"); + Matcher matcher = pattern.matcher(expr); + + if (matcher.find()) { + String number = matcher.group(); + int n = Integer.parseInt(number); + if (expr.charAt(matcher.end()) != ':') { + // Then we inject the common subexpression + String remainder = expr.substring(matcher.end()); + mexpr.setValue(remainder); + RewriterStatement var = refmap.get(n); + + if (var == null) + throw new IllegalArgumentException("Variable '$" + n + "' does not exist!"); + + return var; + } + String remainder = expr.substring(matcher.end() + 1); + mexpr.setValue(remainder); + RewriterStatement stmt = parseRawExpression(mexpr, refmap, dataTypes, ctx); + refmap.put(n, stmt); + return stmt; + } else { + throw new IllegalArgumentException("Expected a number"); + } + } else { + return parseRawExpression(mexpr, refmap, dataTypes, ctx); + } + } + + public static boolean parseDataTypes(String expr, Map dataTypes, final RuleContext ctx) { + RuleContext.currentContext = ctx; + Pattern pattern = Pattern.compile("([A-Za-z0-9]|_|\\.|\\*|\\?)([A-Za-z0-9]|_|\\.|\\*|-)*"); + Matcher matcher = pattern.matcher(expr); + + if (!matcher.find()) + return false; + + String dType = matcher.group(); + boolean intLiteral = dType.equals("LITERAL_INT"); + boolean boolLiteral = dType.equals("LITERAL_BOOL"); + boolean floatLiteral = dType.equals("LITERAL_FLOAT"); + + if (intLiteral) { + pattern = Pattern.compile("(-)?[0-9]+"); + } else if (boolLiteral) { + pattern = Pattern.compile("(TRUE|FALSE)"); + } else if (floatLiteral) { + pattern = Pattern.compile("((-)?([0-9]+(\\.[0-9]*)?(E(-)?[0-9]+)?|Infinity)|NaN)"); + } + + if (expr.charAt(matcher.end()) != ':') + return false; + + expr = expr.substring(matcher.end() + 1); + + matcher = pattern.matcher(expr); + + while (matcher.find()) { + String varName = matcher.group(); + + RewriterDataType dt; + + if (intLiteral) { + dt = new RewriterDataType().as(varName).ofType("INT").asLiteral(Long.parseLong(varName)); + } else if (boolLiteral) { + dt = new RewriterDataType().as(varName).ofType("BOOL").asLiteral(Boolean.parseBoolean(varName)); + } else if (floatLiteral) { + dt = new RewriterDataType().as(varName).ofType("FLOAT").asLiteral(Double.parseDouble(varName)); + } else { + dt = new RewriterDataType().as(varName).ofType(dType); + } + + dt.consolidate(ctx); + dataTypes.put(varName, dt); + + if (expr.length() == matcher.end()) + return true; + + if (expr.charAt(matcher.end()) != ',') + return false; + + expr = expr.substring(matcher.end()+1); + matcher = pattern.matcher(expr); + } + + return false; + } + + private static RewriterStatement parseRawExpression(MutableObject mexpr, Map refmap, Map dataTypes, final RuleContext ctx) { + String expr = mexpr.getValue(); + + Pattern pattern = Pattern.compile("^[^(),:]+"); + Matcher matcher = pattern.matcher(expr); + + if (matcher.find()) { + String token = matcher.group(); + String remainder = expr.substring(matcher.end()); + + if (remainder.isEmpty()) { + mexpr.setValue(remainder); + if (dataTypes.containsKey(token)) + return dataTypes.get(token); + throw new IllegalArgumentException("DataType: '" + token + "' doesn't exist"); + } + + + char nextChar = remainder.charAt(0); + + switch (nextChar) { + case '(': + // Then this is a function + if (remainder.charAt(1) == ')') { + RewriterInstruction mInstr = new RewriterInstruction().withInstruction(token).as(UUID.randomUUID().toString()); + handleSpecialInstructions(mInstr); + mInstr.consolidate(ctx); + mexpr.setValue(remainder.substring(2)); + return mInstr; + } else { + List opList = new ArrayList<>(); + mexpr.setValue(remainder.substring(1)); + RewriterStatement cstmt = doParseExpression(mexpr, refmap, dataTypes, ctx); + opList.add(cstmt); + + while (mexpr.getValue().charAt(0) == ',') { + mexpr.setValue(mexpr.getValue().substring(1)); + cstmt = doParseExpression(mexpr, refmap, dataTypes, ctx); + opList.add(cstmt); + } + + if (mexpr.getValue().charAt(0) != ')') + throw new IllegalArgumentException(mexpr.getValue()); + + mexpr.setValue(mexpr.getValue().substring(1)); + RewriterInstruction instr = new RewriterInstruction().withInstruction(token).withOps(opList.toArray(RewriterStatement[]::new)).as(UUID.randomUUID().toString()); + handleSpecialInstructions(instr); + instr.consolidate(ctx); + return instr; + } + case ')': + case ',': + mexpr.setValue(remainder); + if (dataTypes.containsKey(token)) + return dataTypes.get(token); + throw new IllegalArgumentException("DataType: '" + token + "' doesn't exist"); + default: + throw new NotImplementedException(); + } + } else { + throw new IllegalArgumentException(mexpr.getValue()); + } + } + + private static void handleSpecialInstructions(RewriterInstruction instr) { + if (instr.trueInstruction().equals("_m")) { + UUID ownerId = UUID.randomUUID(); + instr.unsafePutMeta("ownerId", ownerId); + + if (instr.getOperands().get(0).isInstruction() && instr.getOperands().get(0).trueInstruction().equals("_idx")) { + instr.getOperands().get(0).unsafePutMeta("ownerId", ownerId); + instr.getOperands().get(0).unsafePutMeta("idxId", UUID.randomUUID()); + } + + if (instr.getOperands().get(1).isInstruction() && instr.getOperands().get(1).trueInstruction().equals("_idx")) { + instr.getOperands().get(1).unsafePutMeta("ownerId", ownerId); + instr.getOperands().get(1).unsafePutMeta("idxId", UUID.randomUUID()); + } + } else if (instr.trueInstruction().equals("_idxExpr")) { + UUID ownerId = UUID.randomUUID(); + instr.unsafePutMeta("ownerId", ownerId); + + if (instr.getOperands().get(0).isInstruction() && instr.getOperands().get(0).trueInstruction().equals("_idx")) { + instr.getOperands().get(0).unsafePutMeta("ownerId", ownerId); + instr.getOperands().get(0).unsafePutMeta("idxId", UUID.randomUUID()); + } + } + } + + public static void buildBinaryAlgebraInstructions(StringBuilder sb, String instr, List instructions) { + for (String arg1 : instructions) { + for (String arg2 : instructions) { + sb.append(instr + "(" + arg1 + "," + arg2 + ")::"); + + if (arg1.equals("MATRIX") || arg2.equals("MATRIX")) + sb.append("MATRIX\n"); + else if (arg1.equals("FLOAT") || arg2.equals("FLOAT")) + sb.append("FLOAT\n"); + else + sb.append("INT\n"); + } + } + } + + public static void buildTernaryPermutations(List args, TriConsumer func) { + buildBinaryPermutations(args, (t1, t2) -> args.forEach(t3 -> func.accept(t1, t2, t3))); + } + + public static void buildBinaryPermutations(List args, BiConsumer func) { + buildBinaryPermutations(args, args, func); + } + + public static void buildBinaryPermutations(List args1, List args2, BiConsumer func) { + for (String arg1 : args1) + for (String arg2 : args2) + func.accept(arg1, arg2); + } + + public static String defaultTypeHierarchy(String t1, String t2) { + boolean is1ArgList = t1.endsWith("..."); + boolean is2ArgList = t2.endsWith("..."); + + if (is1ArgList) + t1 = t1.substring(0, t1.length() - 3); + + if (is2ArgList) + t2 = t2.substring(0, t2.length() - 3); + + if (t1.equals("BOOL") && t2.equals("BOOL")) + return "BOOL"; + if (t1.equals("INT") && (t2.equals("INT") || t2.equals("BOOL"))) + return "INT"; + + if (t2.equals("INT") && (t1.equals("INT") || t1.equals("BOOL"))) + return "INT"; + + if (!t1.equals("MATRIX") && !t2.equals("MATRIX")) + return "FLOAT"; + return "MATRIX"; + } + + public static String convertibleType(String t1, String t2) { + if (t1.equals("MATRIX") && t2.equals("MATRIX")) + return "MATRIX"; + + if (t1.equals("MATRIX") || t2.equals("MATRIX")) + return null; // Then it is not convertible + + if (!List.of("FLOAT", "INT", "BOOL").contains(t1) || !List.of("FLOAT", "INT", "BOOL").contains(t2)) + return null; + + if (t1.equals("FLOAT") || t2.equals("FLOAT")) + return "FLOAT"; // This is the most "general" type + + if (t1.equals("INT") || t2.equals("INT")) + return "INT"; + + return "BOOL"; + } + + public static String convertImplicitly(String type, boolean allowTypeConversions) { + if (!allowTypeConversions) + return type; + return convertImplicitly(type); + } + + public static String convertImplicitly(String type) { + if (type == null) + return null; + + if (type.equals("INT") || type.equals("BOOL")) + return "FLOAT"; + return type; + } + + public static void putAsBinaryPrintable(String instr, List types, HashMap> printFunctions, BiFunction function) { + for (String type1 : types) + for (String type2 : types) + printFunctions.put(instr + "(" + type1 + "," + type2 + ")", function); + } + + public static void putAsDefaultBinaryPrintable(List instrs, List types, HashMap> funcs) { + for (String instr : instrs) + putAsBinaryPrintable(instr, types, funcs, binaryStringRepr(" " + instr + " ")); + } + + // Updates the references (including metadata UUIDs) for a copied _idxExpr(args(_idx(...),...),...) + public static void copyIndexList(RewriterStatement idxExprRoot) { + if (!idxExprRoot.isInstruction() || !idxExprRoot.trueInstruction().equals("_idxExpr")) + throw new IllegalArgumentException(); + + Map replacements = new HashMap<>(); + UUID newOwnerId = UUID.randomUUID(); + idxExprRoot.unsafePutMeta("ownerId", newOwnerId); + + RewriterStatement newArgList = idxExprRoot.getChild(0).copyNode(); + idxExprRoot.getOperands().set(0, newArgList); + + List operands = newArgList.getOperands(); + + for (int i = 0; i < operands.size(); i++) { + RewriterStatement idx = operands.get(i); + RewriterStatement cpy = idx.copyNode(); + UUID newId = UUID.randomUUID(); + cpy.unsafePutMeta("idxId", newId); + cpy.unsafePutMeta("ownerId", newOwnerId); + replacements.put((UUID)idx.getMeta("idxId"), cpy); + operands.set(i, cpy); + } + + RewriterStatement out = RewriterUtils.replaceReferenceAware(idxExprRoot.getChild(1), stmt -> { + UUID idxId = (UUID) stmt.getMeta("idxId"); + if (idxId != null) { + RewriterStatement newStmt = replacements.get(idxId); + if (newStmt != null) + return newStmt; + } + + return null; + }); + + if (out != null) + idxExprRoot.getOperands().set(1, out); + } + + public static RewriterStatement replaceReferenceAware(RewriterStatement root, Function comparer) { + return replaceReferenceAware(root, false, comparer, new HashMap<>()); + } + + // Replaces elements in a DAG. If a parent item has multiple references, the entire path is duplicated + public static RewriterStatement replaceReferenceAware(RewriterStatement root, boolean duplicateReferences, Function comparer, HashMap visited) { + if (visited.containsKey(root)) + return visited.get(root); + + RewriterStatement newOne = comparer.apply(root); + + if (newOne == root) + newOne = null; + + root = newOne != null ? newOne : root; + + if (newOne == null) + duplicateReferences |= root.refCtr > 1; + + if (root.getOperands() != null) { + for (int i = 0; i < root.getOperands().size(); i++) { + RewriterStatement newSub = replaceReferenceAware(root.getOperands().get(i), duplicateReferences, comparer, visited); + + if (newSub != null) { + if (duplicateReferences && newOne == null) { + root = root.copyNode(); + newOne = root; + } + + root.getOperands().set(i, newSub); + } + } + } + + return newOne; + } + + // Deduplicates the DAG (removes duplicate references with new nodes except for leaf data-types) + public static void unfoldExpressions(RewriterStatement root, RuleContext ctx) { + for (int i = 0; i < root.getOperands().size(); i++) { + RewriterStatement child = root.getChild(i); + if (child.isInstruction() && child.refCtr > 1) { + if (!child.trueInstruction().equals("_idx") + && !child.trueInstruction().equals("_m") + && !child.trueInstruction().equals("idxExpr") + && !child.trueInstruction().equals("rand") + && !child.trueInstruction().equals("_EClass")) { + RewriterStatement cpy = child.copyNode(); + root.getOperands().set(i, cpy); + child.refCtr--; + cpy.getOperands().forEach(op -> op.refCtr++); + } + } + + unfoldExpressions(child, ctx); + } + } + + public static boolean cartesianProduct(List> list, T[] stack, Function emitter) { + if (list.size() == 0) + return false; + + if (list.size() == 1) { + list.get(0).forEach(t -> { + stack[0] = t; + emitter.apply(stack); + }); + return true; + } + + return _cartesianProduct(0, list, stack, emitter, new MutableBoolean(true)); + } + + private static boolean _cartesianProduct(int index, List> sets, T[] currentStack, Function emitter, MutableBoolean doContinue) { + if (index >= sets.size()) { + if (!emitter.apply(currentStack)) + doContinue.setValue(false); + return true; + } + + int size = sets.get(index).size(); + boolean matchFound = false; + + for (int i = 0; i < size; i++) { + currentStack[index] = sets.get(index).get(i); + matchFound |= _cartesianProduct(index+1, sets, currentStack, emitter, doContinue); + + if (!doContinue.booleanValue()) + return matchFound; + } + + return matchFound; + } + + public static boolean isImplicitlyConvertible(String typeFrom, String typeTo) { + if (typeFrom.equals(typeTo)) + return true; + + if (typeFrom.equals("INT") && typeTo.equals("FLOAT")) + return true; + + return false; + } + + public static boolean compareLiterals(RewriterDataType lit1, RewriterDataType lit2, boolean allowImplicitTypeConversions) { + if (allowImplicitTypeConversions) + return lit1.getLiteral().equals(literalAs(lit1.getType(), lit2)); + return lit1.getLiteral().equals(lit2.getLiteral()); + } + + public static Object literalAs(String type, RewriterDataType literal) { + switch (type) { + case "FLOAT": + return literal.floatLiteral(); + case "INT": + return literal.intLiteral(false); + case "BOOL": + return literal.boolLiteral(); + default: + return null; + } + } + + public static RuleContext buildDefaultContext() { + RuleContext ctx = RewriterContextSettings.getDefaultContext(); + ctx.metaPropagator = new MetaPropagator(ctx); + return ctx; + } + + private static RuleContext lastCtx; + private static Function lastUnfuse; + public static RewriterStatement unfuseOperators(RewriterStatement stmt, final RuleContext ctx) { + return unfuseOperators(ctx).apply(stmt); + } + public static Function unfuseOperators(final RuleContext ctx) { + if (lastCtx == ctx) + return lastUnfuse; + + ArrayList unfuseRules = new ArrayList<>(); + RewriterRuleCollection.substituteFusedOps(unfuseRules, ctx); + RewriterHeuristic heur = new RewriterHeuristic(new RewriterRuleSet(ctx, unfuseRules)); + lastCtx = ctx; + lastUnfuse = heur::apply; + return lastUnfuse; + } + + public static Function buildCanonicalFormConverter(final RuleContext ctx, boolean debug) { + return buildCanonicalFormConverter(ctx, true, debug); + } + + public static Function buildCanonicalFormConverter(final RuleContext ctx, boolean allowInversionCanonicalization, boolean debug) { + ArrayList algebraicCanonicalizationRules = new ArrayList<>(); + RewriterRuleCollection.substituteEquivalentStatements(algebraicCanonicalizationRules, ctx); + RewriterRuleCollection.eliminateMultipleCasts(algebraicCanonicalizationRules, ctx); + RewriterRuleCollection.canonicalizeBooleanStatements(algebraicCanonicalizationRules, ctx); + RewriterRuleCollection.canonicalizeAlgebraicStatements(algebraicCanonicalizationRules, allowInversionCanonicalization, ctx); + RewriterRuleCollection.eliminateMultipleCasts(algebraicCanonicalizationRules, ctx); + RewriterRuleCollection.buildElementWiseAlgebraicCanonicalization(algebraicCanonicalizationRules, ctx); + RewriterHeuristic algebraicCanonicalization = new RewriterHeuristic(new RewriterRuleSet(ctx, algebraicCanonicalizationRules)); + + ArrayList expRules = new ArrayList<>(); + RewriterRuleCollection.expandStreamingExpressions(expRules, ctx); + RewriterHeuristic streamExpansion = new RewriterHeuristic(new RewriterRuleSet(ctx, expRules)); + + ArrayList expArbitraryMatricesRules = new ArrayList<>(); + RewriterRuleCollection.expandArbitraryMatrices(expArbitraryMatricesRules, ctx); + RewriterHeuristic expandArbitraryMatrices = new RewriterHeuristic(new RewriterRuleSet(ctx, expArbitraryMatricesRules)); + + ArrayList pd = new ArrayList<>(); + RewriterRuleCollection.pushdownStreamSelections(pd, ctx); + RewriterRuleCollection.buildElementWiseAlgebraicCanonicalization(pd, ctx); + RewriterRuleCollection.eliminateMultipleCasts(pd, ctx); + RewriterRuleCollection.canonicalizeBooleanStatements(pd, ctx); + RewriterRuleCollection.canonicalizeAlgebraicStatements(pd, allowInversionCanonicalization, ctx); + RewriterHeuristic streamSelectPushdown = new RewriterHeuristic(new RewriterRuleSet(ctx, pd)); + + ArrayList flatten = new ArrayList<>(); + RewriterRuleCollection.flattenOperations(flatten, ctx); + RewriterHeuristic flattenOperations = new RewriterHeuristic(new RewriterRuleSet(ctx, flatten)); + + RewriterHeuristics canonicalFormCreator = new RewriterHeuristics(); + canonicalFormCreator.add("ALGEBRAIC CANONICALIZATION", algebraicCanonicalization); + canonicalFormCreator.add("EXPAND STREAMING EXPRESSIONS", streamExpansion); + canonicalFormCreator.add("EXPAND ARBITRARY MATRICES", expandArbitraryMatrices); + canonicalFormCreator.add("PUSHDOWN STREAM SELECTIONS", streamSelectPushdown); + canonicalFormCreator.add("FOLD CONSTANTS", new RewriterHeuristic(t -> foldConstants(t, ctx))); + //canonicalFormCreator.add("CANON ALGB", new RewriterHeuristic(new RewriterRuleSet(ctx, RewriterRuleCollection.buildElementWiseAlgebraicCanonicalization(new ArrayList<>(), ctx)))); + canonicalFormCreator.add("REPLACE NEGATIONS", new RewriterHeuristic(new RewriterRuleSet(ctx, RewriterRuleCollection.replaceNegation(new ArrayList<>(), ctx)))); + canonicalFormCreator.add("PUSHDOWN STREAM SELECTIONS", streamSelectPushdown); + canonicalFormCreator.add("FLATTEN OPERATIONS", flattenOperations); + + ArrayList canonicalExpand = new ArrayList<>(); + RewriterRuleCollection.canonicalExpandAfterFlattening(canonicalExpand, ctx); + RewriterHeuristic canonicalExpandOps = new RewriterHeuristic(new RewriterRuleSet(ctx, canonicalExpand)); + + ArrayList flattenAlgebraicRewriteList = new ArrayList<>(); + RewriterRuleCollection.flattenedAlgebraRewrites(flattenAlgebraicRewriteList, ctx); + RewriterHeuristic flattenedAlgebraicRewrites = new RewriterHeuristic(new RewriterRuleSet(ctx, flattenAlgebraicRewriteList)); + + RewriterHeuristics afterFlattening = new RewriterHeuristics(); + afterFlattening.add("CANONICAL EXPAND", canonicalExpandOps); + afterFlattening.add("FLATTENED ALGEBRA REWRITES", flattenedAlgebraicRewrites); + + return stmt -> { + stmt = stmt.nestedCopy(true); + stmt = canonicalFormCreator.apply(stmt, (t, r) -> { + if (!debug) + return true; + + if (r != null) + System.out.println("Applying rule: " + r.getName()); + System.out.println(t.toParsableString(ctx)); + return true; + }, debug); + + for (int i = 0; i < 2; i++) { + RewriterUtils.mergeArgLists(stmt, ctx); + stmt = RewriterUtils.pullOutConstants(stmt, ctx); + } + RewriterUtils.mergeArgLists(stmt, ctx); + unfoldExpressions(stmt, ctx); + stmt = RewriterUtils.pullOutConstants(stmt, ctx); + cleanupUnecessaryIndexExpressions(stmt, ctx); + stmt.prepareForHashing(); + stmt.recomputeHashCodes(ctx); + + stmt = afterFlattening.apply(stmt, (t, r) -> { + if (!debug) + return true; + + if (r != null) + System.out.println("Applying rule: " + r.getName()); + System.out.println(t.toParsableString(ctx)); + return true; + }, debug); + + stmt = foldConstants(stmt, ctx); + + for (int i = 0; i < 2; i++) { + RewriterUtils.mergeArgLists(stmt, ctx); + stmt = RewriterUtils.pullOutConstants(stmt, ctx); + } + RewriterUtils.mergeArgLists(stmt, ctx); + + stmt = stmt.getAssertions(ctx).cleanupEClasses(stmt); + unfoldExpressions(stmt, ctx); + stmt.prepareForHashing(); + + if (debug) + System.out.println("PRE1: " + stmt.toParsableString(ctx, false)); + + stmt.compress(); // To remove unnecessary metadata such as assertions that are not encoded in the graph + TopologicalSort.sort(stmt, ctx); + + if (debug) + System.out.println("FINAL1: " + stmt.toParsableString(ctx, false)); + + return stmt; + }; + } + + public static RewriterStatement pullOutConstants(RewriterStatement oldRoot, final RuleContext ctx) { + RewriterStatement newRoot = pullOutConstantsRecursively(oldRoot, ctx, new HashMap<>()); + + // Check if we have to move the assertions to new root + if (newRoot != oldRoot) + oldRoot.moveRootTo(newRoot); + + return newRoot; + } + + private static RewriterStatement pullOutConstantsRecursively(RewriterStatement cur, final RuleContext ctx, Map alreadyModified) { + if (!cur.isInstruction()) + return cur; + + RewriterStatement modified = alreadyModified.get(cur); + + if (modified != null) + return modified; + + alreadyModified.put(cur, cur); + + for (int i = 0; i < cur.getOperands().size(); i++) + cur.getOperands().set(i, pullOutConstantsRecursively(cur.getChild(i), ctx, alreadyModified)); + + cur.updateMetaObjects(el -> pullOutConstantsRecursively(el, ctx, alreadyModified)); + + switch (cur.trueInstruction()) { + case "sum": + return tryPullOutSum(cur, ctx); + } + + return cur; + } + + private static RewriterStatement tryPullOutSum(RewriterStatement sum, final RuleContext ctx) { + // TODO: What happens on multi-index? Then, some unnecessary indices will currently not be pulled out + RewriterStatement idxExpr = sum.getChild(0); + UUID ownerId = (UUID) idxExpr.getMeta("ownerId"); + RewriterStatement sumBody = idxExpr.getChild(1); + + Map checked = new HashMap<>(); + + + if (!checkSubgraphDependency(sumBody, ownerId, checked)) { + // Then we have to remove the sum entirely + List indices = idxExpr.getChild(0).getOperands(); + List components = new ArrayList<>(); + + for (RewriterStatement idx : indices) { + if (idx.isLiteral()) + continue; + RewriterStatement idxFrom = idx.getChild(0); + RewriterStatement idxTo = idx.getChild(1); + RewriterStatement negation = new RewriterInstruction().as(UUID.randomUUID().toString()).withInstruction("-").withOps(/*RewriterStatement.ensureFloat(ctx, idxFrom)*/idxFrom).consolidate(ctx); + RewriterStatement add = RewriterStatement.multiArgInstr(ctx, "+", /*RewriterStatement.ensureFloat(ctx, idxTo)*/idxTo, RewriterStatement.literal(ctx, 1.0D), negation); + components.add(add); + } + + RewriterStatement out = RewriterStatement.multiArgInstr(ctx, "*", sumBody); + out.getChild(0).getOperands().addAll(components); + return foldConstants(out, ctx); + } + + if (isDirectlyDependent(sumBody, ownerId)) + return sum; + + if (sumBody.trueInstruction().equals("*")) { + // We have to assume here, that this instruction is not referenced anywhere else in the graph + List argList = sumBody.getChild(0).getOperands(); + List toRemove = new ArrayList<>(argList.size()); + + for (RewriterStatement stmt : argList) { + if (!checkSubgraphDependency(stmt, ownerId, checked)) + toRemove.add(stmt); + } + + if (!toRemove.isEmpty()) { + argList.removeAll(toRemove); + + if (argList.size() == 1) { + idxExpr.getOperands().set(1, argList.get(0)); + } + + toRemove.add(sum); + + return RewriterStatement.multiArgInstr(ctx, "*", toRemove.toArray(RewriterStatement[]::new)); + } + } else if (sumBody.trueInstruction().equals("+")) { + // TODO: What about sum(+(A, *(a, B)))? We could pull out a + + // We have to assume here, that this instruction is not referenced anywhere else in the graph + List argList = sumBody.getChild(0).getOperands(); + List toRemove = new ArrayList<>(argList.size()); + + for (RewriterStatement stmt : argList) { + if (!checkSubgraphDependency(stmt, ownerId, checked)) + toRemove.add(stmt); + } + + if (!toRemove.isEmpty()) { + argList.removeAll(toRemove); + + if (argList.size() == 1) { + idxExpr.getOperands().set(1, argList.get(0)); + } + + RewriterStatement outerSum = RewriterStatement.multiArgInstr(ctx, "+", toRemove.toArray(RewriterStatement[]::new)); + List mul = new ArrayList<>(); + + for (RewriterStatement idx : idxExpr.getChild(0).getOperands()) { + RewriterStatement neg = new RewriterInstruction().as(UUID.randomUUID().toString()).withInstruction("-").withOps(/*RewriterStatement.ensureFloat(ctx, idx.getChild(0))*/idx.getChild(0)).consolidate(ctx); + RewriterStatement msum = RewriterStatement.multiArgInstr(ctx, "+", /*RewriterStatement.ensureFloat(ctx, idx.getChild(1))*/idx.getChild(1), neg, RewriterStatement.literal(ctx, 1.0)); + mul.add(msum); + } + + mul.add(outerSum); + RewriterStatement mulStmt = RewriterStatement.multiArgInstr(ctx, "*", mul.toArray(RewriterStatement[]::new)); + + return RewriterStatement.multiArgInstr(ctx, "+", mulStmt, sum); + } + } + + return sum; + } + + // Returns true if the subgraph is dependent on the corresponding owner + private static boolean checkSubgraphDependency(RewriterStatement expr, UUID id, Map checked) { + Boolean b = checked.get(expr); + + if (b != null) + return b; + + if (expr.isInstruction() && expr.trueInstruction().equals("_idx")) { + UUID mid = (UUID) expr.getMeta("ownerId"); + boolean isDependent = id.equals(mid); + + if (isDependent) { + checked.put(expr, true); + return true; + } + } + + for (RewriterStatement stmt : expr.getOperands()) { + if (checkSubgraphDependency(stmt, id, checked)) { + checked.put(expr, true); + return true; + } + } + + checked.put(expr, false); + return false; + } + + private static boolean isDirectlyDependent(RewriterStatement child, UUID ownerId) { + if (child.isInstruction() && child.trueInstruction().equals("_idx")) { + UUID mid = (UUID) child.getMeta("_ownerId"); + return ownerId.equals(mid); + } + + return false; + } + + public static RewriterStatement foldConstants(RewriterStatement stmt, final RuleContext ctx) { + Map replaced = new HashMap<>(); + RewriterStatement ret = foldConstantsRecursively(stmt, ctx, replaced); + ret.prepareForHashing(); + ret.recomputeHashCodes(ctx); + return ret; + } + + private static RewriterStatement foldConstantsRecursively(RewriterStatement cur, final RuleContext ctx, Map alreadyFolded) { + if (!cur.isInstruction()) + return cur; + + RewriterStatement folded = alreadyFolded.get(cur); + + if (folded != null) + return folded; + + alreadyFolded.put(cur, cur); + + for (int i = 0; i < cur.getOperands().size(); i++) + cur.getOperands().set(i, foldConstantsRecursively(cur.getChild(i), ctx, alreadyFolded)); + + cur.updateMetaObjects(el -> foldConstantsRecursively(el, ctx, alreadyFolded)); + + RewriterStatement ret = cur; + + switch (cur.trueInstruction()) { + case "+": + case "*": + case "min": + case "max": + ret = foldNaryReducible(cur, ctx); + break; + case "_EClass": + ret = foldEClass(cur, ctx); + break; + default: + if (cur.getOperands().size() == 1) + ret = foldUnary(cur, ctx); + break; + } + + ret.refreshReturnType(ctx); + alreadyFolded.put(cur, ret); + return ret; + } + + private static RewriterStatement foldEClass(RewriterStatement stmt, final RuleContext ctx) { + RewriterStatement lit = stmt.getLiteralStatement(); + if (lit != null) + return lit; + return stmt; + } + + private static RewriterStatement foldNaryReducible(RewriterStatement stmt, final RuleContext ctx) { + List argList; + if (stmt.getChild(0).isArgumentList()) + argList = stmt.getChild(0).getOperands(); + else + argList = stmt.getOperands(); + + if (argList.isEmpty()) + throw new IllegalArgumentException(stmt.toString(ctx)); + + if (stmt.isInstruction() && (stmt.trueInstruction().equals("min") || stmt.trueInstruction().equals("max")) && argList.size() == 1 && !List.of("FLOAT", "INT", "BOOL").contains(argList.get(0).getResultingDataType(ctx))) + return stmt; + + if (argList.size() < 2) + return argList.get(0); + + int[] literals = IntStream.range(0, argList.size()).filter(i -> argList.get(i).isLiteral()).toArray(); + + if (literals.length == 1) { + Object literal = argList.get(literals[0]).getLiteral(); + if (literal instanceof Number) { + RewriterStatement overwrite = ConstantFoldingUtils.overwritesLiteral((Number) literal, stmt.trueInstruction(), ctx); + if (overwrite != null) + return overwrite; + } + + // Check if is neutral element + if (ConstantFoldingUtils.isNeutralElement(argList.get(literals[0]).getLiteral(), stmt.trueInstruction())) { + RewriterStatement neutral = argList.get(literals[0]); + argList.remove(literals[0]); + + if (argList.size() == 1) + return argList.get(0); + else if (argList.isEmpty()) + return neutral; + } + } + + if (literals.length < 2) + return stmt; + + String rType = stmt.getResultingDataType(ctx); + + BiFunction foldingFunction = ConstantFoldingUtils.foldingBiFunction(stmt.trueInstruction(), rType); + + RewriterDataType foldedLiteral = new RewriterDataType(); + Number val = null; + + for (int literal : literals) + val = foldingFunction.apply(val, argList.get(literal)); + + + RewriterStatement overwrite = ConstantFoldingUtils.overwritesLiteral(val, stmt.trueInstruction(), ctx); + if (overwrite != null) + return overwrite; + + foldedLiteral.as(val.toString()).ofType(rType).asLiteral(val).consolidate(ctx); + + argList.removeIf(RewriterStatement::isLiteral); + + if (argList.isEmpty() || !ConstantFoldingUtils.isNeutralElement(foldedLiteral.getLiteral(), stmt.trueInstruction())) + argList.add(foldedLiteral); + + ConstantFoldingUtils.cancelOutNary(stmt.trueInstruction(), argList); + + if (argList.size() == 1) + return argList.get(0); + + return stmt; + } + + private static RewriterStatement foldUnary(RewriterStatement stmt, final RuleContext ctx) { + RewriterStatement child = stmt.getChild(0); + + if (!child.isLiteral()) + return stmt; + + boolean isFloat = stmt.getResultingDataType(ctx).equals("FLOAT"); + + switch (stmt.trueInstruction()) { + case "inv": + if (isFloat) + return RewriterStatement.literal(ctx, 1.0 / child.floatLiteral()); + else + return RewriterStatement.literal(ctx, 1L / child.intLiteral()); + case "-": + if (isFloat) + return RewriterStatement.literal(ctx, -child.floatLiteral()); + else + return RewriterStatement.literal(ctx, -child.intLiteral()); + } + + // Not implemented yet + return stmt; + } + + public static RewriterStatement cleanupUnecessaryIndexExpressions(RewriterStatement stmt, final RuleContext ctx) { + RewriterStatement mNew = cleanupIndexExprRecursively(stmt, ctx); + + if (mNew != null) + stmt.moveRootTo(mNew); + + recursivePostCleanup(mNew != null ? mNew : stmt); + + return mNew; + } + + private static RewriterStatement cleanupIndexExprRecursively(RewriterStatement cur, final RuleContext ctx) { + for (int i = 0; i < cur.getOperands().size(); i++) { + RewriterStatement mNew = cleanupIndexExprRecursively(cur.getChild(i), ctx); + + if (mNew != null) + cur.getOperands().set(i, mNew); + } + + return cleanupIndexExpr(cur); + } + + private static void recursivePostCleanup(RewriterStatement cur) { + for (RewriterStatement child : cur.getOperands()) + recursivePostCleanup(child); + + postCleanupIndexExpr(cur); + } + + private static RewriterStatement cleanupIndexExpr(RewriterStatement cur) { + if (!cur.isInstruction() || !cur.trueInstruction().equals("sum")) + return null; + + RewriterStatement base = cur; + cur = cur.getChild(0); + + if (!cur.isInstruction() || !cur.trueInstruction().equals("_idxExpr")) + return null; + + if (!cur.getChild(1).isInstruction() || !cur.getChild(1).trueInstruction().equals("ifelse") || !cur.getChild(1,2).isLiteral() || cur.getChild(1,2).floatLiteral() != 0.0D) + return null; + + RewriterStatement query = cur.getChild(1, 0); + + if (query.isInstruction() && query.trueInstruction().equals("==")) { + RewriterStatement idx1 = query.getChild(0); + RewriterStatement idx2 = query.getChild(1); + + if (idx1.isInstruction() && idx2.isInstruction() && idx1.trueInstruction().equals("_idx") && idx2.trueInstruction().equals("_idx")) { + List indices = cur.getChild(0).getOperands(); + RewriterStatement indexFromUpperLevel = null; + if (idx1 == idx2) { + cur.getOperands().set(1, cur.getChild(1, 1)); + } else if (indices.contains(idx1)) { + boolean removed = indices.remove(idx2); + indexFromUpperLevel = removed ? null : idx2; + + if (removed) { + cur.getOperands().set(1, cur.getChild(1, 1)); + cur.getChild(1).forEachPreOrder(cur2 -> { + for (int i = 0; i < cur2.getOperands().size(); i++) { + if (cur2.getChild(i).equals(idx2)) + cur2.getOperands().set(i, idx1); + } + + return true; + }, true); + } + } else if (indices.contains(idx2)) { + indexFromUpperLevel = idx1; + } + + if (indexFromUpperLevel != null) { + cur.getOperands().set(1, cur.getChild(1, 1)); + final RewriterStatement fIdxUpperLevel = indexFromUpperLevel; + final RewriterStatement fIdxLowerLevel = idx1 == indexFromUpperLevel ? idx2 : idx1; + cur.getChild(1).forEachPreOrder(cur2 -> { + for (int i = 0; i < cur2.getOperands().size(); i++) { + if (cur2.getChild(i).equals(fIdxLowerLevel)) + cur2.getOperands().set(i, fIdxUpperLevel); + } + + return true; + }, true); + indices.remove(idx2); + } + + if (indices.isEmpty()) { + return cur.getChild(1); + } + } + } + + return base; + } + + // To unify ifelse (e.g. ifelse(a == b, a+b, a-b) => ifelse(a == b, a+a, a-b) + private static void postCleanupIndexExpr(RewriterStatement cur) { + if (!cur.isInstruction() || !cur.trueInstruction().equals("ifelse") || !cur.getChild(2).isLiteral() || cur.getChild(2).floatLiteral() != 0.0D) + return; + + RewriterStatement query = cur.getChild(0); + + if (query.isInstruction() && query.trueInstruction().equals("==")) { + RewriterStatement idx1 = query.getChild(0); + RewriterStatement idx2 = query.getChild(1); + + if (idx1.isInstruction() && idx2.isInstruction() && idx1.trueInstruction().equals("_idx") && idx2.trueInstruction().equals("_idx")) { + // Then we just choose the first index + cur.getChild(1).forEachPreOrder(cur2 -> { + for (int i = 0; i < cur2.getOperands().size(); i++) { + if (cur2.getChild(i).equals(idx2)) + cur2.getOperands().set(i, idx1); + } + + return true; + }, true); + cur.getChild(2).forEachPreOrder(cur2 -> { + for (int i = 0; i < cur2.getOperands().size(); i++) { + if (cur2.getChild(i).equals(idx2)) + cur2.getOperands().set(i, idx1); + } + + return true; + }, true); + } + } + } + + public static void renameIllegalVarnames(final RuleContext ctx, RewriterStatement... stmts) { + MutableInt matrixVarCtr = new MutableInt(0); + MutableInt scalarVarCtr = new MutableInt(0); + + Set varnames = new HashSet<>(); + for (RewriterStatement stmt : stmts) { + stmt.forEachPreOrder(cur -> { + if (cur.isInstruction()) + return true; + + varnames.add(cur.getId()); + return true; + }, false); + } + + for (RewriterStatement stmt : stmts) { + stmt.forEachPreOrder(cur -> { + if (cur.isInstruction() || cur.isLiteral()) + return true; + + boolean isMatrix = cur.getResultingDataType(ctx).equals("MATRIX"); + + if (cur.getId().equals("?")) { + cur.rename(getVarname(varnames, isMatrix ? matrixVarCtr : scalarVarCtr, isMatrix)); + return true; + } + + if (cur.getId().contains("_")) { + cur.rename(getVarname(varnames, isMatrix? matrixVarCtr : scalarVarCtr, isMatrix)); + } + + try { + UUID.fromString(cur.getId()); + // If it could parse, then we should rename + cur.rename(getVarname(varnames, isMatrix ? matrixVarCtr : scalarVarCtr, isMatrix)); + return true; + } catch (Exception e) { + // Then this is not a UUID + } + + return true; + }, false); + } + } + + private static String getVarname(Set existingNames, MutableInt mInt, boolean matrix) { + char origChar; + + if (matrix) + origChar = 'A'; + else + origChar = 'a'; + + char ch = (char)(origChar + mInt.getAndIncrement()); + + while (existingNames.contains(String.valueOf(ch))) + ch = (char)(origChar + mInt.getAndIncrement()); + + return String.valueOf(ch); + } +} diff --git a/src/main/java/org/apache/sysds/hops/rewriter/utils/StatementUtils.java b/src/main/java/org/apache/sysds/hops/rewriter/utils/StatementUtils.java new file mode 100644 index 00000000000..055e2691bfb --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewriter/utils/StatementUtils.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.rewriter.utils; + +import org.apache.sysds.hops.rewriter.RewriterInstruction; +import org.apache.sysds.hops.rewriter.RewriterStatement; +import org.apache.sysds.hops.rewriter.RuleContext; + +public class StatementUtils { + public static RewriterStatement max(final RuleContext ctx, RewriterStatement... of) { + if (of.length == 1) + return of[0]; + + if (of.length == 2) + return new RewriterInstruction("max", ctx, of); + + throw new UnsupportedOperationException(); + } + + public static RewriterStatement min(final RuleContext ctx, RewriterStatement... of) { + if (of.length == 1) + return of[0]; + + if (of.length == 2) + return new RewriterInstruction("min", ctx, of); + + throw new UnsupportedOperationException(); + } + + public static RewriterStatement length(final RuleContext ctx, RewriterStatement matrix) { + if (!matrix.getResultingDataType(ctx).equals("MATRIX")) + throw new IllegalArgumentException(matrix.toParsableString(ctx)); + + return new RewriterInstruction("*", ctx, matrix.getNRow(), matrix.getNCol()); + } + + public static RewriterStatement add(final RuleContext ctx, RewriterStatement... terms) { + if (terms.length == 1) + return terms[0]; + + return new RewriterInstruction("+", ctx, new RewriterInstruction("argList", ctx, terms)); + } +} diff --git a/src/main/java/org/apache/sysds/utils/Statistics.java b/src/main/java/org/apache/sysds/utils/Statistics.java index e6fdf5db3cd..561a99a7d3a 100644 --- a/src/main/java/org/apache/sysds/utils/Statistics.java +++ b/src/main/java/org/apache/sysds/utils/Statistics.java @@ -47,6 +47,7 @@ import org.apache.sysds.utils.stats.RecompileStatistics; import org.apache.sysds.utils.stats.SparkStatistics; import org.apache.sysds.utils.stats.TransformStatistics; +import scala.Tuple2; import java.lang.management.CompilationMXBean; import java.lang.management.GarbageCollectorMXBean; @@ -338,6 +339,35 @@ public static void stopRunTimer() { public static long getRunTime() { return execEndTime - execStartTime; } + + private static HashMap appliedGeneratedRewrites = new HashMap<>(); + private static HashMap, Integer> appliedGeneratedRewritesCounts = new HashMap<>(); + private static boolean recordGeneratedRewrites = false; + private static String currentTestName = ""; + + public static void recordAppliedGeneratedRewrites(boolean record) { + recordGeneratedRewrites = record; + } + + public static void applyGeneratedRewrite(String rewrite) { + if (recordGeneratedRewrites) { + appliedGeneratedRewrites.compute(rewrite, (k, v) -> v == null ? 1 : v + 1); + if (!currentTestName.isEmpty()) + appliedGeneratedRewritesCounts.compute(new Tuple2<>(rewrite, currentTestName), (k, v) -> v == null ? 1 : v + 1); + } + } + + public static Map getAppliedRewrites() { + return appliedGeneratedRewrites; + } + + public static Map, Integer> getAdvancedAppliedRewrites() { + return appliedGeneratedRewritesCounts; + } + + public static void setCurrentTestName(String testName) { + currentTestName = testName; + } public static void reset() { diff --git a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java index 6b280301afb..8e496c189f8 100644 --- a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java +++ b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java @@ -31,6 +31,8 @@ import java.net.InetSocketAddress; import java.net.ServerSocket; import java.nio.charset.Charset; +import java.nio.file.Files; +import java.nio.file.Paths; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; @@ -57,6 +59,8 @@ import org.apache.sysds.conf.DMLConfig; import org.apache.sysds.hops.OptimizerUtils; import org.apache.sysds.hops.fedplanner.FTypes.FType; +import org.apache.sysds.hops.rewriter.generated.RewriteAutomaticallyGenerated; +import org.apache.sysds.hops.rewriter.RewriterRuntimeUtils; import org.apache.sysds.lops.Lop; import org.apache.sysds.lops.compile.Dag; import org.apache.sysds.parser.ParseException; @@ -90,6 +94,7 @@ import org.junit.After; import org.junit.Assert; import org.junit.Before; +import scala.Tuple4; /** *

@@ -105,6 +110,51 @@ * */ public abstract class AutomatedTestBase { + protected static final boolean RECORD_GENERATED_REWRITES = false; + protected static final boolean ALLOW_GENERATED_REWRITES = false; + protected static final String BASE_DATA_DIR = null; + + + ///// THESE SHOULD NOT BE MODIFIED ///// + private static String currentTestName = ""; + + + static { + RewriterRuntimeUtils.setupIfNecessary(); + + if (RECORD_GENERATED_REWRITES) { + Runtime.getRuntime().addShutdownHook(new Thread(() -> { + StringBuilder csvBuilder2 = new StringBuilder(); + csvBuilder2.append("Rewrite;Count\n"); + + Statistics.getAppliedRewrites().forEach((k, v) -> { + csvBuilder2.append(k); + csvBuilder2.append(';'); + csvBuilder2.append(v); + csvBuilder2.append('\n'); + }); + + StringBuilder csvBuilder3 = new StringBuilder(); + csvBuilder3.append("Rewrite;TestName;Count\n"); + + Statistics.getAdvancedAppliedRewrites().forEach((k, v) -> { + csvBuilder3.append(k._1); + csvBuilder3.append(';'); + csvBuilder3.append(k._2); + csvBuilder3.append(';'); + csvBuilder3.append(v); + csvBuilder3.append('\n'); + }); + + try { + Files.writeString(Paths.get(BASE_DATA_DIR + "applied_rewrites.csv"), csvBuilder2.toString()); + Files.writeString(Paths.get(BASE_DATA_DIR + "rewrite_info.csv"), csvBuilder3.toString()); + } catch (IOException e) { + e.printStackTrace(); + } + })); + } + } private static final Log LOG = LogFactory.getLog(AutomatedTestBase.class.getName()); @@ -1139,6 +1189,9 @@ protected void runRScript() { */ protected void runRScript(boolean newWay) { + if (RewriterRuntimeUtils.interceptAll) + return; + String executionFile = sourceDirectory + selectedTest + ".R"; if(fullRScriptName != null) executionFile = fullRScriptName; @@ -1388,6 +1441,21 @@ protected ByteArrayOutputStream runTest(boolean newWay, boolean exceptionExpecte String errMessage, int maxSparkInst) { try{ final List out = new ArrayList<>(); + + if (RECORD_GENERATED_REWRITES) { + if (currentTestName == null || !currentTestName.equals(this.getClass().getSimpleName())) { + currentTestName = this.getClass().getSimpleName(); + } + + Statistics.reset(); + RewriteAutomaticallyGenerated.totalTimeNanos = 0; + RewriteAutomaticallyGenerated.callCount = 0; + RewriteAutomaticallyGenerated.maxTimeNanos = -1; + + Statistics.recordAppliedGeneratedRewrites(true); + Statistics.setCurrentTestName(currentTestName); + } + Thread t = new Thread( () -> out.add(runTestWithTimeout(newWay, exceptionExpected, expectedException, errMessage, maxSparkInst)), "TestRunner_main"); @@ -1437,6 +1505,10 @@ private ByteArrayOutputStream runTestWithTimeout(boolean newWay, boolean excepti cleanupScratchSpace(); ArrayList args = new ArrayList<>(); + if (ALLOW_GENERATED_REWRITES) { + args.add("-applyGeneratedRewrites"); + } + // setup arguments to SystemDS if(DEBUG) { args.add("-Dsystemds.logging=trace"); diff --git a/src/test/java/org/apache/sysds/test/applications/L2SVMTest.java b/src/test/java/org/apache/sysds/test/applications/L2SVMTest.java index 534b058425a..42bf618f5a2 100644 --- a/src/test/java/org/apache/sysds/test/applications/L2SVMTest.java +++ b/src/test/java/org/apache/sysds/test/applications/L2SVMTest.java @@ -49,7 +49,7 @@ public L2SVMTest(int rows, int cols, double sp, boolean intercept) { numRecords = rows; numFeatures = cols; sparsity = sp; - intercept = this.intercept; + this.intercept = intercept; } @Parameters diff --git a/src/test/java/org/apache/sysds/test/component/codegen/rewrite/RewriterNormalFormTests.java b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/RewriterNormalFormTests.java new file mode 100644 index 00000000000..add648bbc62 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/RewriterNormalFormTests.java @@ -0,0 +1,561 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.codegen.rewrite; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.hops.rewriter.RewriterStatement; +import org.apache.sysds.hops.rewriter.RuleContext; +import org.apache.sysds.hops.rewriter.TopologicalSort; +import org.apache.sysds.hops.rewriter.utils.RewriterUtils; +import org.junit.BeforeClass; +import org.junit.Test; + +import java.util.function.Function; + +public class RewriterNormalFormTests { + protected static final Log LOG = LogFactory.getLog(RewriterNormalFormTests.class.getName()); + + private static RuleContext ctx; + private static Function canonicalConverter; + + @BeforeClass + public static void setup() { + ctx = RewriterUtils.buildDefaultContext(); + canonicalConverter = RewriterUtils.buildCanonicalFormConverter(ctx, false); + } + + //e.g., matrix(1,nrow(X),ncol(X))/X -> 1/X + @Test + public void testUnnecessaryVectorize() { + RewriterStatement stmt1 = RewriterUtils.parse("/(const(A, 1.0), A)", ctx, "MATRIX:A", "LITERAL_FLOAT:1.0"); + RewriterStatement stmt2 = RewriterUtils.parse("/(1.0, A)", ctx, "MATRIX:A", "LITERAL_FLOAT:1.0"); + + assert match(stmt1, stmt2); + } + + @Test + public void testRemoveUnnecessaryBinaryOperation() { + RewriterStatement stmt1 = RewriterUtils.parse("*(1.0, A)", ctx, "MATRIX:A", "LITERAL_FLOAT:1.0"); + RewriterStatement stmt2 = RewriterUtils.parse("A", ctx, "MATRIX:A", "LITERAL_FLOAT:1.0"); + + assert match(stmt1, stmt2); + } + + @Test + public void testFuseDatagenAndBinaryOperation() { + RewriterStatement stmt1 = RewriterUtils.parse("*(rand(nrow(A), ncol(A), -1.0, 1.0), a)", ctx, "MATRIX:A", "FLOAT:a", "LITERAL_FLOAT:1.0,-1.0"); + RewriterStatement stmt2 = RewriterUtils.parse("rand(nrow(A), ncol(A), -(a), a)", ctx, "MATRIX:A", "FLOAT:a"); + + assert match(stmt1, stmt2); + } + + //@Test + public void testFuseDatagenAndMinusOperation() { + RewriterStatement stmt1 = RewriterUtils.parse("-(rand(nrow(A), ncol(A), -2.0, 1.0))", ctx, "MATRIX:A", "LITERAL_FLOAT:1.0,-2.0"); + RewriterStatement stmt2 = RewriterUtils.parse("rand(nrow(A), ncol(A), -1.0, 2.0)", ctx, "MATRIX:A", "LITERAL_FLOAT:-1.0,2.0"); + + assert match(stmt1, stmt2); + } + + @Test + public void testCanonicalizeMatrixMultScalarAdd() { + RewriterStatement stmt1 = RewriterUtils.parse("+(eps, %*%(A, t(B)))", ctx, "MATRIX:A,B", "FLOAT:eps"); + RewriterStatement stmt2 = RewriterUtils.parse("+(%*%(A, t(B)), eps)", ctx, "MATRIX:A,B", "FLOAT:eps"); + + assert match(stmt1, stmt2); + } + + @Test + public void testCanonicalizeMatrixMultScalarAdd2() { + RewriterStatement stmt1 = RewriterUtils.parse("-(%*%(A, t(B)), eps)", ctx, "MATRIX:A,B", "FLOAT:eps"); + RewriterStatement stmt2 = RewriterUtils.parse("+(%*%(A, t(B)), -(eps))", ctx, "MATRIX:A,B", "FLOAT:eps"); + + assert match(stmt1, stmt2); + } + + @Test + public void testSimplifyMultiBinaryToBinaryOperation() { + RewriterStatement stmt1 = RewriterUtils.parse("-(1.0, *(A,B))", ctx, "MATRIX:A,B", "LITERAL_FLOAT:1.0"); + RewriterStatement stmt2 = RewriterUtils.parse("1-*(A, B)", ctx, "MATRIX:A,B", "LITERAL_FLOAT:1.0"); + + assert match(stmt1, stmt2); + } + + @Test + public void testSimplifyDistributiveBinaryOperation() { + RewriterStatement stmt1 = RewriterUtils.parse("-(A, *(B,A))", ctx, "MATRIX:A,B,C", "LITERAL_FLOAT:1.0"); + RewriterStatement stmt2 = RewriterUtils.parse("*(-(1.0,B), A)", ctx, "MATRIX:A,B,C", "LITERAL_FLOAT:1.0"); + + assert match(stmt1, stmt2); + } + + @Test + public void testSimplifyBushyBinaryOperation() { + RewriterStatement stmt1 = RewriterUtils.parse("*(A,*(B, %*%(C, colVec(D))))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:1.0"); + RewriterStatement stmt2 = RewriterUtils.parse("*(*(A,B), %*%(C, colVec(D)))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:1.0"); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + assert RewriterStatement.MatcherContext.exactMatch(ctx, stmt1, stmt2).match(); + } + + @Test + public void testSimplifyUnaryAggReorgOperation() { + RewriterStatement stmt1 = RewriterUtils.parse("sum(t(A))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:1.0"); + RewriterStatement stmt2 = RewriterUtils.parse("sum(A)", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:1.0"); + + assert match(stmt1, stmt2); + } + + @Test + public void testRemoveUnnecessaryAggregates() { + RewriterStatement stmt1 = RewriterUtils.parse("sum(rowSums(A))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:1.0"); + RewriterStatement stmt2 = RewriterUtils.parse("sum(A)", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:1.0"); + + assert match(stmt1, stmt2); + } + + @Test + public void testSimplifyBinaryMatrixScalarOperation() { + RewriterStatement stmt1 = RewriterUtils.parse("as.scalar(*(A,a))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:1.0"); + RewriterStatement stmt2 = RewriterUtils.parse("*(as.scalar(A),a)", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:1.0"); + + assert match(stmt1, stmt2); + } + + @Test + public void testPushdownUnaryAggTransposeOperation() { + RewriterStatement stmt1 = RewriterUtils.parse("colSums(t(A))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:1.0"); + RewriterStatement stmt2 = RewriterUtils.parse("t(rowSums(A))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:1.0"); + + assert match(stmt1, stmt2); + } + + @Test + public void testPushdownCSETransposeScalarOperation() { + // Introduce a dummy instruction * as I don't support the assignment operator + RewriterStatement stmt1 = RewriterUtils.parse("*(t(A), t(sq(A)))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:1.0,2.0"); + RewriterStatement stmt2 = RewriterUtils.parse("*(t(A), sq(t(A)))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:1.0,2.0"); + + assert match(stmt1, stmt2); + } + + @Test + public void testPushdownSumBinaryMult() { + RewriterStatement stmt1 = RewriterUtils.parse("sum(*(a,A))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:1.0,2.0"); + RewriterStatement stmt2 = RewriterUtils.parse("*(a, sum(A))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:1.0,2.0"); + + assert match(stmt1, stmt2); + } + + @Test + public void testSimplifyTraceMatrixMult() { + RewriterStatement stmt1 = RewriterUtils.parse("trace(%*%(A,B))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:1.0,2.0"); + RewriterStatement stmt2 = RewriterUtils.parse("sum(*(A, t(B)))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:1.0,2.0"); + + assert match(stmt1, stmt2); + } + + @Test + public void testSimplifySlicedMatrixMult() { + RewriterStatement stmt1 = RewriterUtils.parse("[](%*%(A,B), 1, 1)", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:1.0,2.0", "LITERAL_INT:1"); + RewriterStatement stmt2 = RewriterUtils.parse("as.scalar(%*%(rowVec(A), colVec(B)))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:1.0,2.0", "LITERAL_INT:1"); + + assert match(stmt1, stmt2); + } + + @Test + public void testRemoveUnnecessaryReorgOperation() { + RewriterStatement stmt1 = RewriterUtils.parse("t(t(A))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:1.0,2.0", "LITERAL_INT:1"); + RewriterStatement stmt2 = RewriterUtils.parse("A", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:1.0,2.0", "LITERAL_INT:1"); + + assert match(stmt1, stmt2); + } + + //@Test + public void testRemoveUnnecessaryReorgOperation2() { + RewriterStatement stmt1 = RewriterUtils.parse("rev(rev(A))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:1.0,2.0", "LITERAL_INT:1"); + RewriterStatement stmt2 = RewriterUtils.parse("A", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:1.0,2.0", "LITERAL_INT:1"); + + assert match(stmt1, stmt2); + } + + @Test + public void testSimplifyTransposeAggBinBinaryChains() { + RewriterStatement stmt1 = RewriterUtils.parse("t(+(%*%(t(A),t(B)), C))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:1.0,2.0", "LITERAL_INT:1"); + RewriterStatement stmt2 = RewriterUtils.parse("+(%*%(B,A), t(C))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:1.0,2.0", "LITERAL_INT:1"); + + assert match(stmt1, stmt2); + } + + @Test + public void testRemoveUnnecessaryMinus() { + RewriterStatement stmt1 = RewriterUtils.parse("-(-(A))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:1.0,2.0", "LITERAL_INT:1"); + RewriterStatement stmt2 = RewriterUtils.parse("A", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:1.0,2.0", "LITERAL_INT:1"); + + assert match(stmt1, stmt2); + } + + @Test + public void testFuseLogNzUnaryOperation() { + RewriterStatement stmt1 = RewriterUtils.parse("*(!=(A,0.0), log(A))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1"); + RewriterStatement stmt2 = RewriterUtils.parse("log_nz(A)", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1"); + + assert match(stmt1, stmt2); + } + + @Test + public void testFuseLogNzBinaryOperation() { + RewriterStatement stmt1 = RewriterUtils.parse("*(!=(A,0.0), log(A, a))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1"); + RewriterStatement stmt2 = RewriterUtils.parse("log_nz(A, a)", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1"); + + assert match(stmt1, stmt2); + } + + //@Test + public void testSimplifyNotOverComparisons() { + RewriterStatement stmt1 = RewriterUtils.parse("!(>(A,B))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1"); + RewriterStatement stmt2 = RewriterUtils.parse("<=(A,B)", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1"); + + assert match(stmt1, stmt2); + } + + ///// DYNAMIC SIMPLIFICATIONS ////// + + @Test + public void testRemoveEmptyRightIndexing() { + // We do not directly support the specification of nnz, but we can emulate such a matrix by multiplying with 0 + RewriterStatement stmt1 = RewriterUtils.parse("[](*(A, 0.0), 1, nrow(A), 1, 1)", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1"); + RewriterStatement stmt2 = RewriterUtils.parse("const(colVec(A), 0.0)", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1"); + + assert match(stmt1, stmt2); + } + + @Test + public void testRemoveUnnecessaryRightIndexing() { + RewriterStatement stmt1 = RewriterUtils.parse("[](colVec(A), 1, nrow(A), 1, 1)", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1"); + RewriterStatement stmt2 = RewriterUtils.parse("colVec(A)", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1"); + + assert match(stmt1, stmt2); + } + + @Test + public void testRemoveUnnecessaryReorgOperation3() { + RewriterStatement stmt1 = RewriterUtils.parse("t(cellMat(A)))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1"); + RewriterStatement stmt2 = RewriterUtils.parse("cellMat(A))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1"); + + assert match(stmt1, stmt2); + } + + //@Test + public void testRemoveUnnecessaryOuterProduct() { + RewriterStatement stmt1 = RewriterUtils.parse("*(A, %*%(colVec(B), const(t(colVec(B)), 1.0)))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1"); + RewriterStatement stmt2 = RewriterUtils.parse("*(A, colVec(B))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1"); + + assert match(stmt1, stmt2); + } + + @Test + public void testRemoveUnnecessaryIfElseOperation() { + // Ifelse is not directly supported yet but only on scalars. Thus, we will our index expression syntax to reflect that statement + // Note that we "cheated" here a bit as we index using nrow(A) and ncol(A). We would not get a match if we used nrow(B)... + RewriterStatement stmt1 = RewriterUtils.parse("_m($1:_idx(1, nrow(A)), $2:_idx(1, ncol(A)), ifelse(TRUE, [](A, $1, $2), [](B, $1, $2)))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE"); + RewriterStatement stmt2 = RewriterUtils.parse("A", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE"); + + assert match(stmt1, stmt2); + } + + @Test + public void testFuseDatagenAndReorgOperation() { + RewriterStatement stmt1 = RewriterUtils.parse("t(rand(i, 1, 0.0, 1.0))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + RewriterStatement stmt2 = RewriterUtils.parse("rand(1, i, 0.0, 1.0)", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + + assert match(stmt1, stmt2); + } + + @Test + public void testSimplifyColwiseAggregate() { + RewriterStatement stmt1 = RewriterUtils.parse("colSums(rowVec(A))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + RewriterStatement stmt2 = RewriterUtils.parse("rowVec(A)", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + + assert match(stmt1, stmt2); + } + + @Test + public void testSimplifyRowwiseAggregate() { + RewriterStatement stmt1 = RewriterUtils.parse("rowSums(colVec(A))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + RewriterStatement stmt2 = RewriterUtils.parse("colVec(A)", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + + assert match(stmt1, stmt2); + } + + // We don't have broadcasting semantics + @Test + public void testSimplifyColSumsMVMult() { + RewriterStatement stmt1 = RewriterUtils.parse("colSums(*(colVec(A), colVec(B)))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + RewriterStatement stmt2 = RewriterUtils.parse("%*%(t(colVec(B)), colVec(A))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + + assert match(stmt1, stmt2); + } + + // We don't have broadcasting semantics + @Test + public void testSimplifyRowSumsMVMult() { + RewriterStatement stmt1 = RewriterUtils.parse("rowSums(*(rowVec(A), rowVec(B)))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + RewriterStatement stmt2 = RewriterUtils.parse("%*%(rowVec(A), t(rowVec(B)))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + + assert match(stmt1, stmt2); + } + + @Test + public void testSimplifyUnnecessaryAggregate() { + RewriterStatement stmt1 = RewriterUtils.parse("sum(cellMat(A)))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + RewriterStatement stmt2 = RewriterUtils.parse("as.scalar(A)", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + + assert match(stmt1, stmt2); + } + + @Test + public void testSimplifyEmptyAggregate() { + // We emulate an empty matrix by multiplying by zero + RewriterStatement stmt1 = RewriterUtils.parse("sum(*(A, 0.0))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + RewriterStatement stmt2 = RewriterUtils.parse("0.0", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + + assert match(stmt1, stmt2); + } + + @Test + public void testSimplifyEmptyReorgOperation() { + // We emulate an empty matrix by multiplying by zero + RewriterStatement stmt1 = RewriterUtils.parse("t(*(A, 0.0))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + RewriterStatement stmt2 = RewriterUtils.parse("const(t(A), 0.0)", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + + assert match(stmt1, stmt2); + } + + // This is a hacky workaround + @Test + public void testSimplifyEmptyMatrixMult() { + // We emulate an empty matrix by multiplying by zero + // Note that we pass the dimension info of the matrix multiply to get the same e-class assertions + RewriterStatement stmt1 = RewriterUtils.parse("%*%(*(A, 0.0), B)", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + RewriterStatement stmt2 = RewriterUtils.parse("const(%*%(A, B), 0.0)", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + + // We need to explicitly assert A and B + stmt2.givenThatEqual(stmt2.getChild(0, 1).getNRow(), stmt2.getChild(0, 0).getNCol(), ctx); + stmt2.recomputeAssertions(); + + assert match(stmt1, stmt2, true); + } + + @Test + public void testSimplifyEmptyMatrixMult2() { + RewriterStatement stmt1 = RewriterUtils.parse("%*%(colVec(A), cast.MATRIX(1.0))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + RewriterStatement stmt2 = RewriterUtils.parse("colVec(A)", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + + assert match(stmt1, stmt2); + } + + @Test + public void testSimplifyScalarMatrixMult() { + RewriterStatement stmt1 = RewriterUtils.parse("%*%(colVec(A), cast.MATRIX(a))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + RewriterStatement stmt2 = RewriterUtils.parse("*(colVec(A), as.scalar(cast.MATRIX(a)))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + + assert match(stmt1, stmt2); + } + + @Test + public void testSimplifyDistributiveMatrixMult() { + RewriterStatement stmt1 = RewriterUtils.parse("+(%*%(A, B), %*%(A, C))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + RewriterStatement stmt2 = RewriterUtils.parse("%*%(A, +(B, C)))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + + assert match(stmt1, stmt2); + } + + // Note that we did not implement the overloaded diag(A) operation as we defined diag(A) as setting all other entries to zero (which is not how it is actually handled by SystemDS) + // In this case, we obtain the same rewrite, even though the diag operation is different + @Test + public void testSimplifySumDiagToTrace() { + RewriterStatement stmt1 = RewriterUtils.parse("sum(diag(A))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + RewriterStatement stmt2 = RewriterUtils.parse("trace(A)", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + + assert match(stmt1, stmt2); + } + + // Note that we did not implement the overloaded diag(A) operation as we defined diag(A) as setting all other entries to zero (which is not how it is actually handled by SystemDS) + // In this case, we obtain the same equivalence, but in case of our implementation the rewrite would not be beneficial + @Test + public void testPushdownBinaryOperationOnDiag() { + RewriterStatement stmt1 = RewriterUtils.parse("*(diag(A), a)", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + RewriterStatement stmt2 = RewriterUtils.parse("diag(*(A, a))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + + assert match(stmt1, stmt2); + } + + @Test + public void testPushdownSumOnAdditiveBinary() { + RewriterStatement stmt1 = RewriterUtils.parse("sum(+(A, B))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + RewriterStatement stmt2 = RewriterUtils.parse("+(sum(A), sum(B))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + + // We need to assert that the dimensions are the same, which we currently cannot do implicitly through an expression + stmt2.givenThatEqualDimensions(stmt2.getChild(0, 0), stmt2.getChild(1, 0), ctx); + + assert match(stmt1, stmt2); + } + + @Test + public void testSimplifyDotProductSum() { + RewriterStatement stmt1 = RewriterUtils.parse("cast.MATRIX(sum(sq(colVec(A))))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + RewriterStatement stmt2 = RewriterUtils.parse("%*%(t(colVec(A)), colVec(A))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + + assert match(stmt1, stmt2); + } + + @Test + public void testFuseSumSquared() { + RewriterStatement stmt1 = RewriterUtils.parse("sum(sq(A))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + RewriterStatement stmt2 = RewriterUtils.parse("sumSq(A)", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + + assert match(stmt1, stmt2); + } + + @Test + public void testFuseAxpyBinaryOperationChain() { + RewriterStatement stmt1 = RewriterUtils.parse("+(A, *(a, B))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + RewriterStatement stmt2 = RewriterUtils.parse("+*(A, a, B)", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + + assert match(stmt1, stmt2); + } + + @Test + public void testFuseAxpyBinaryOperationChain2() { + RewriterStatement stmt1 = RewriterUtils.parse("-(A, *(a, B))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + RewriterStatement stmt2 = RewriterUtils.parse("-*(A, a, B)", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + + assert match(stmt1, stmt2); + } + + @Test + public void testReorderMinusMatrixMult() { + RewriterStatement stmt1 = RewriterUtils.parse("%*%(-(t(A)), B)", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + RewriterStatement stmt2 = RewriterUtils.parse("-(%*%(t(A), B))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + + assert match(stmt1, stmt2); + } + + @Test + public void testSimplifySumMatrixMult() { + RewriterStatement stmt1 = RewriterUtils.parse("sum(%*%(A, B))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + RewriterStatement stmt2 = RewriterUtils.parse("sum(*(t(colSums(A)), rowSums(B)))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + + assert match(stmt1, stmt2); + } + + @Test + public void testSimplifyEmptyBinaryOperation() { + RewriterStatement stmt1 = RewriterUtils.parse("*(A, const(A, 0.0))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + RewriterStatement stmt2 = RewriterUtils.parse("const(A, 0.0)", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + + assert match(stmt1, stmt2); + } + + @Test + public void testSimplifyEmptyBinaryOperation2() { + RewriterStatement stmt1 = RewriterUtils.parse("+(A, const(A, 0.0))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + RewriterStatement stmt2 = RewriterUtils.parse("A", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + + assert match(stmt1, stmt2); + } + + @Test + public void testSimplifyEmptyBinaryOperation3() { + RewriterStatement stmt1 = RewriterUtils.parse("-(A, const(A, 0.0))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + RewriterStatement stmt2 = RewriterUtils.parse("A", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + + assert match(stmt1, stmt2); + } + + //@Test + public void testSimplifyScalarMVBinaryOperation() { + RewriterStatement stmt1 = RewriterUtils.parse("*(A, colVec(colVec(B)))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + RewriterStatement stmt2 = RewriterUtils.parse("*(A, as.scalar(B))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + + assert match(stmt1, stmt2); + } + + @Test + public void testSimplifyNnzComputation() { + RewriterStatement stmt1 = RewriterUtils.parse("sum(!=(A, 0.0))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + RewriterStatement stmt2 = RewriterUtils.parse("_nnz(A)", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + + assert match(stmt1, stmt2); + } + + // We only support concrete literals (which is a current limitation of this framework) + @Test + public void testSimplifyNrowNcolComputation() { + // We simulate a matrix with known dimensions by doing a concrete left-indexing + RewriterStatement stmt1 = RewriterUtils.parse("nrow([](A, 1, 5, 1, 5))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1,5", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + RewriterStatement stmt2 = RewriterUtils.parse("5", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1,5", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + + assert match(stmt1, stmt2); + } + + // We only support concrete literals (which is a current limitation of this framework) + @Test + public void testSimplifyNrowNcolComputation2() { + // We simulate a matrix with known dimensions by doing a concrete left-indexing + RewriterStatement stmt1 = RewriterUtils.parse("ncol([](A, 1, 5, 1, 5))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1,5", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + RewriterStatement stmt2 = RewriterUtils.parse("5", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1,5", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + + assert match(stmt1, stmt2); + } + + // We only support concrete literals (which is a current limitation of this framework) + @Test + public void testSimplifyNrowNcolComputation3() { + // We simulate a matrix with known dimensions by doing a concrete left-indexing + RewriterStatement stmt1 = RewriterUtils.parse("length([](A, 1, 5, 1, 5))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1,5", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + RewriterStatement stmt2 = RewriterUtils.parse("25", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1,25", "LITERAL_BOOL:TRUE,FALSE", "INT:i"); + + assert match(stmt1, stmt2); + } + + private boolean match(RewriterStatement stmt1, RewriterStatement stmt2) { + return match(stmt1, stmt2, false); + } + + private boolean match(RewriterStatement stmt1, RewriterStatement stmt2, boolean debug) { + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + return RewriterStatement.MatcherContext.exactMatch(ctx, stmt1, stmt2).debug(debug).match(); + } +} diff --git a/src/test/java/org/apache/sysds/test/component/codegen/rewrite/RewriterRuleValidationTest.java b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/RewriterRuleValidationTest.java new file mode 100644 index 00000000000..3bab52ba8b0 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/RewriterRuleValidationTest.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.codegen.rewrite; + +import org.apache.sysds.hops.rewriter.generated.RewriteAutomaticallyGenerated; +import org.apache.sysds.hops.rewriter.estimators.RewriterCostEstimator; +import org.apache.sysds.hops.rewriter.rule.RewriterRule; +import org.apache.sysds.hops.rewriter.rule.RewriterRuleCreator; +import org.apache.sysds.hops.rewriter.rule.RewriterRuleSet; +import org.apache.sysds.hops.rewriter.RewriterStatement; +import org.apache.sysds.hops.rewriter.utils.RewriterUtils; +import org.apache.sysds.hops.rewriter.RuleContext; +import org.junit.BeforeClass; + +import java.io.FileWriter; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.util.List; +import java.util.function.Function; + +public class RewriterRuleValidationTest { + + public static String RAW_FILE_PATH; // Must be specified + public static String FILE_PATH; // Must be specified + + private static RuleContext ctx; + private static Function canonicalConverter; + + @BeforeClass + public static void setup() { + ctx = RewriterUtils.buildDefaultContext(); + canonicalConverter = RewriterUtils.buildCanonicalFormConverter(ctx, false); + } + + + //@Test + public void test() { + try { + List lines = Files.readAllLines(Paths.get(RAW_FILE_PATH)); + RewriterRuleSet ruleSet = RewriterRuleSet.deserialize(lines, ctx); + RewriterRuleCreator ruleCreator = new RewriterRuleCreator(ctx); + + int ctr = 0; + for (RewriterRule rule : ruleSet.getRules()) { + if (ctr % 10 == 0) + System.out.println("Done: " + ctr + " / " + ruleSet.getRules().size()); + + ctr++; + try { + System.out.println(rule.getStmt1().toParsableString(ctx) + " => " + rule.getStmt2().toParsableString(ctx)); + long preCost = RewriterCostEstimator.estimateCost(rule.getStmt1(), ctx); + long postCost = RewriterCostEstimator.estimateCost(rule.getStmt2(), ctx); + System.out.println(ruleCreator.registerRule(rule, preCost, postCost, true, canonicalConverter)); + } catch (Exception e) { + e.printStackTrace(); + } + } + //System.out.println(ruleSet.toJavaCode("GeneratedRewriteClass", false)); + String serialized = ruleCreator.getRuleSet().serialize(); + //System.out.println(serialized); + + try (FileWriter writer = new FileWriter(FILE_PATH)) { + writer.write(serialized); + } catch (IOException ex) { + ex.printStackTrace(); + } + } catch (IOException e) { + e.printStackTrace(); + } + } +} diff --git a/src/test/java/org/apache/sysds/test/component/codegen/rewrite/RewriterStreamTests.java b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/RewriterStreamTests.java new file mode 100644 index 00000000000..58c324c7f22 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/RewriterStreamTests.java @@ -0,0 +1,1751 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.codegen.rewrite; + +import org.apache.commons.lang3.mutable.MutableObject; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.hops.rewriter.rule.RewriterRuleCreator; +import org.apache.sysds.hops.rewriter.assertions.RewriterAssertionUtils; +import org.apache.sysds.hops.rewriter.assertions.RewriterAssertions; +import org.apache.sysds.hops.rewriter.estimators.RewriterCostEstimator; +import org.apache.sysds.hops.rewriter.RewriterDatabase; +import org.apache.sysds.hops.rewriter.rule.RewriterHeuristic; +import org.apache.sysds.hops.rewriter.rule.RewriterRule; +import org.apache.sysds.hops.rewriter.rule.RewriterRuleBuilder; +import org.apache.sysds.hops.rewriter.rule.RewriterRuleSet; +import org.apache.sysds.hops.rewriter.RewriterStatement; +import org.apache.sysds.hops.rewriter.utils.RewriterUtils; +import org.apache.sysds.hops.rewriter.RuleContext; +import org.junit.BeforeClass; +import org.junit.Test; +import scala.Tuple2; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.Function; + +public class RewriterStreamTests { + protected static final Log LOG = LogFactory.getLog(RewriterStreamTests.class.getName()); + + private static RuleContext ctx; + private static Function canonicalConverter; + + @BeforeClass + public static void setup() { + ctx = RewriterUtils.buildDefaultContext(); + canonicalConverter = RewriterUtils.buildCanonicalFormConverter(ctx, false); + } + + @Test + public void testAdditionFloat1() { + RewriterStatement stmt = RewriterUtils.parse("+(+(a, b), 1)", ctx, "MATRIX:A,B,C", "FLOAT:a,b", "LITERAL_INT:0,1"); + stmt = canonicalConverter.apply(stmt); + LOG.info(stmt.toParsableString(ctx, true)); + assert stmt.match(RewriterStatement.MatcherContext.exactMatch(ctx, RewriterUtils.parse("+(argList(a, b, 1))", ctx, "FLOAT:a,b", "LITERAL_INT:0,1"), stmt)); + } + + @Test + public void testAdditionFloat2() { + RewriterStatement stmt = RewriterUtils.parse("+(1, +(a, b))", ctx, "FLOAT:a,b", "LITERAL_INT:0,1"); + stmt = canonicalConverter.apply(stmt); + LOG.info(stmt.toParsableString(ctx, true)); + assert stmt.match(RewriterStatement.MatcherContext.exactMatch(ctx, RewriterUtils.parse("+(argList(a, b, 1))", ctx, "FLOAT:a,b", "LITERAL_INT:0,1"), stmt)); + } + + @Test + public void testAdditionMatrix1() { + RewriterStatement stmt1 = RewriterUtils.parse("+(+(A, B), 1)", ctx, "MATRIX:A,B", "LITERAL_INT:1"); + RewriterStatement stmt2 = RewriterUtils.parse("+(+(B, 1), A)", ctx, "MATRIX:A,B", "LITERAL_INT:1"); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testSubtractionFloat1() { + RewriterStatement stmt = RewriterUtils.parse("+(-(a, b), 1)", ctx, "MATRIX:A,B,C", "FLOAT:a,b", "LITERAL_INT:0,1"); + RewriterStatement stmt2 = RewriterUtils.parse("+(argList(-(b), a, 1))", ctx, "FLOAT:a,b", "LITERAL_INT:0,1"); + stmt = canonicalConverter.apply(stmt); + stmt2 = canonicalConverter.apply(stmt2); + LOG.info(stmt.toParsableString(ctx, true)); + LOG.info(stmt2.toParsableString(ctx, true)); + assert stmt.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt)); + } + + @Test + public void testSubtractionFloat2() { + RewriterStatement stmt = RewriterUtils.parse("+(1, -(a, -(b, c)))", ctx, "MATRIX:A,B,C", "FLOAT:a,b,c", "LITERAL_INT:0,1"); + RewriterStatement stmt2 = RewriterUtils.parse("+(argList(-(b), a, c, 1))", ctx, "FLOAT:a,b, c", "LITERAL_INT:0,1"); + stmt = canonicalConverter.apply(stmt); + stmt2 = canonicalConverter.apply(stmt2); + LOG.info(stmt.toParsableString(ctx, true)); + assert stmt.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt)); + } + + // Fusion will no longer be pursued + /*@Test + public void testFusedPlanMatrixGeneration() { + RewriterStatement stmt = RewriterUtils.parse("+(1, +(A, B))", ctx, "MATRIX:A,B", "LITERAL_INT:0,1"); + stmt = converter.apply(stmt); + RewriterStatement fused = RewriterUtils.buildFusedPlan(stmt, ctx); + LOG.info("Orig: " + stmt.toParsableString(ctx, true)); + LOG.info("Fused: " + (fused == null ? null : fused.toParsableString(ctx, true))); + } + + @Test + public void testFusedPlanAggregationGeneration() { + RewriterStatement stmt = RewriterUtils.parse("sum(*(/(A, B), B))", ctx, "MATRIX:A,B", "LITERAL_INT:0,1"); + stmt = converter.apply(stmt); + RewriterStatement fused = RewriterUtils.buildFusedPlan(stmt, ctx); + LOG.info("Orig: " + stmt.toParsableString(ctx, true)); + LOG.info("Fused: " + (fused == null ? null : fused.toParsableString(ctx, true))); + } + + @Test + public void testFusedPlanAdvancedAggregationGeneration() { + RewriterStatement stmt = RewriterUtils.parse("sum(*(t(A), B))", ctx, "MATRIX:A,B", "LITERAL_INT:0,1"); + stmt = converter.apply(stmt); + RewriterStatement fused = RewriterUtils.buildFusedPlan(stmt, ctx); + LOG.info("Orig: " + stmt.toParsableString(ctx, true)); + LOG.info("Fused: " + (fused == null ? null : fused.toParsableString(ctx, true))); + }*/ + + @Test + public void testReorgEquivalence() { + RewriterStatement stmt1 = RewriterUtils.parse("t(t(A))", ctx, "MATRIX:A"); + RewriterStatement stmt2 = RewriterUtils.parse("A", ctx, "MATRIX:A"); + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testTraceEquivalence1() { + RewriterStatement stmt = RewriterUtils.parse("trace(%*%(A, B))", ctx, "MATRIX:A,B"); + RewriterStatement stmt2 = RewriterUtils.parse("sum(*(t(A), B))", ctx, "MATRIX:A,B"); + stmt = canonicalConverter.apply(stmt); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + assert stmt.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt)); + } + + @Test + public void testTraceEquivalence2() { + RewriterStatement stmt = RewriterUtils.parse("trace(%*%(A, B))", ctx, "MATRIX:A,B"); + RewriterStatement stmt2 = RewriterUtils.parse("sum(*(A, t(B)))", ctx, "MATRIX:A,B"); + stmt = canonicalConverter.apply(stmt); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + assert stmt.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt)); + } + + @Test + public void testTraceEquivalence3() { + RewriterStatement stmt = RewriterUtils.parse("trace(*(A, B))", ctx, "MATRIX:A,B"); + RewriterStatement stmt2 = RewriterUtils.parse("sum(*(diag(A), diag(B)))", ctx, "MATRIX:A,B"); + stmt = canonicalConverter.apply(stmt); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + assert stmt.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt)); + } + + @Test + public void testAggEquivalence() { + RewriterStatement stmt = RewriterUtils.parse("sum(%*%(A, B))", ctx, "MATRIX:A,B"); + RewriterStatement stmt2 = RewriterUtils.parse("sum(*(colSums(A), t(rowSums(B))))", ctx, "MATRIX:A,B"); + stmt = canonicalConverter.apply(stmt); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + assert stmt.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt)); + } + + @Test + public void testSumEquality6() { + RewriterStatement stmt = RewriterUtils.parse("sum(+(B, sum(*(a, A))))", ctx, "MATRIX:A,B", "FLOAT:a"); + RewriterStatement stmt2 = RewriterUtils.parse("sum(+(B, *(a, sum(A))))", ctx, "MATRIX:A,B", "FLOAT:a"); + stmt = canonicalConverter.apply(stmt); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + assert stmt.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt)); + } + + @Test + public void testSumEquality() { + RewriterStatement stmt = RewriterUtils.parse("sum(+(B, sum(*(a, A))))", ctx, "MATRIX:A,B", "FLOAT:a"); + //RewriterStatement stmt2 = RewriterUtils.parse("+(*(a, length(A)), sum(+(B, sum(A))))", ctx, "MATRIX:A,B", "FLOAT:a"); + RewriterStatement stmt3 = RewriterUtils.parse("sum(+(B, *(a, sum(A))))", ctx, "MATRIX:A,B", "FLOAT:a"); + stmt = canonicalConverter.apply(stmt); + //stmt2 = canonicalConverter.apply(stmt2); + stmt3 = canonicalConverter.apply(stmt3); + + LOG.info("=========="); + LOG.info(stmt.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt3.toParsableString(ctx, true)); + LOG.info("=========="); + //LOG.info(stmt2.toParsableString(ctx, true)); + assert stmt.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt3, stmt)); + //assert stmt.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt)); + } + + @Test + public void testArgListSelectionPushdown() { + RewriterStatement stmt = RewriterUtils.parse("[](+(A, 1), 1, 1)", ctx, "MATRIX:A", "LITERAL_INT:1"); + RewriterStatement stmt2 = RewriterUtils.parse("+([](A, 1, 1), 1)", ctx, "MATRIX:A", "LITERAL_INT:1"); + stmt = canonicalConverter.apply(stmt); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + assert stmt.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt)); + } + + @Test + public void testDistributiveLaw1() { + RewriterStatement stmt = RewriterUtils.parse("*(+(a, b), c)", ctx, "FLOAT:a,b,c"); + RewriterStatement stmt2 = RewriterUtils.parse("+(*(a, c), *(b, c))", ctx, "FLOAT:a,b,c"); + stmt = canonicalConverter.apply(stmt); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + assert stmt.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt)); + } + + @Test + public void testDistributiveLaw2() { + RewriterStatement stmt = RewriterUtils.parse("*(a, +(b, c))", ctx, "FLOAT:a,b,c"); + RewriterStatement stmt2 = RewriterUtils.parse("+(*(a, b), *(a, c))", ctx, "FLOAT:a,b,c"); + stmt = canonicalConverter.apply(stmt); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + assert stmt.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt)); + } + + @Test + public void testEClassProperties() { + RewriterStatement stmt = RewriterUtils.parse("*(+(A, B), nrow(A))", ctx, "MATRIX:A,B"); + RewriterStatement stmt2 = RewriterUtils.parse("*(+(A, B), nrow(B))", ctx, "MATRIX:A,B"); + stmt = canonicalConverter.apply(stmt); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + assert stmt.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt)); + } + + @Test + public void testRealExamples1() { + RewriterStatement stmt1 = RewriterUtils.parse("t(%*%(t(U),V))", ctx, "MATRIX:U,V"); + RewriterStatement stmt2 = RewriterUtils.parse("%*%(t(V), U)", ctx, "MATRIX:U,V"); + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + //TopologicalSort.sort(stmt1, ctx); + //TopologicalSort.sort(stmt2, ctx); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void test() { + RewriterStatement stmt = RewriterUtils.parse("t(A)", ctx, "MATRIX:A,B"); + RewriterStatement stmt2 = RewriterUtils.parse("A", ctx, "FLOAT:A,B"); + stmt = canonicalConverter.apply(stmt); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + assert !stmt.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt)); + } + + @Test + public void test2() { + RewriterStatement stmt = RewriterUtils.parse("+(0.0,*(2,%*%(t(X),T)))", ctx, "MATRIX:T,X", "FLOAT:0.0", "INT:2"); + stmt = canonicalConverter.apply(stmt); + + LOG.info(stmt.toParsableString(ctx)); + } + + @Test + public void test3() { + RewriterStatement stmt = RewriterUtils.parse("+(+(A,X),t(X))", ctx, "MATRIX:X,A"); + stmt = canonicalConverter.apply(stmt); + + LOG.info(stmt.toParsableString(ctx)); + } + + @Test + public void test4() { + RewriterDatabase db = new RewriterDatabase(); + RewriterStatement stmt = RewriterUtils.parse("trace(%*%(A, B))", ctx, "MATRIX:A,B"); + RewriterStatement stmt2 = RewriterUtils.parse("sum(*(A, t(B)))", ctx, "MATRIX:A,B"); + stmt = canonicalConverter.apply(stmt); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + db.insertEntry(ctx, stmt); + + assert !db.insertEntry(ctx, stmt2); + } + + @Test + public void testForFailure() { + RewriterStatement stmt = RewriterUtils.parse("[](hIndex,i,i,1,1)", ctx, "MATRIX:hIndex", "INT:i", "LITERAL_INT:1"); + stmt = canonicalConverter.apply(stmt); + + LOG.info("=========="); + LOG.info(stmt.toParsableString(ctx, true)); + } + + @Test + public void testTypeConversions() { + RewriterStatement stmt1 = RewriterUtils.parse("+(TRUE, 1)", ctx, "LITERAL_BOOL:TRUE", "LITERAL_INT:1"); + RewriterStatement stmt2 = RewriterUtils.parse("+(1, 1)", ctx, "LITERAL_INT:1"); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testCSE() { + RewriterStatement stmt1 = RewriterUtils.parse("+(*(a, b), *(b, a))", ctx, "FLOAT:a,b"); + RewriterStatement stmt2 = RewriterUtils.parse("+($1:*(a, b), $1)", ctx, "FLOAT:a,b"); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + RewriterDatabase db = new RewriterDatabase(); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + db.insertEntry(ctx, stmt1); + + assert !db.insertEntry(ctx, stmt2); + } + + @Test + public void testExactMatch() { + RewriterStatement stmt1 = RewriterUtils.parse("+(*(a, b), *(b, a))", ctx, "FLOAT:a,b"); + RewriterStatement stmt2 = RewriterUtils.parse("+($1:*(a, b), $1)", ctx, "FLOAT:a,b"); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + assert stmt2.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt1, stmt2)); + } + + //@Test + public void testMinEquivalence() { + RewriterStatement stmt1 = RewriterUtils.parse("min(min(A), min(B))", ctx, "MATRIX:A,B"); + RewriterStatement stmt2 = RewriterUtils.parse("min(A, B)", ctx, "MATRIX:A,B"); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testSumEquivalence() { + RewriterStatement stmt1 = RewriterUtils.parse("sum(A)", ctx, "MATRIX:A,B"); + RewriterStatement stmt2 = RewriterUtils.parse("sum(t(A))", ctx, "MATRIX:A,B"); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + /*@Test + public void testSimpleAlgebra1() { + RewriterStatement stmt1 = RewriterUtils.parse("-(X, *(Y, X))", ctx, "MATRIX:X,Y"); + RewriterStatement stmt2 = RewriterUtils.parse("*(-(1, Y), X)", ctx, "MATRIX:X,Y", "LITERAL_INT:1"); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + }*/ + + @Test + public void testSimpleAlgebra2() { + RewriterStatement stmt1 = RewriterUtils.parse("diag(*(X, 7))", ctx, "MATRIX:X,Y", "LITERAL_INT:7"); + RewriterStatement stmt2 = RewriterUtils.parse("*(diag(X), 7)", ctx, "MATRIX:X,Y", "LITERAL_INT:7"); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + /*@Test + public void testSimpleAlgebra3() { + RewriterStatement stmt1 = RewriterUtils.parse("sum(+(+(X, 7), Y))", ctx, "MATRIX:X,Y", "LITERAL_INT:7"); + RewriterStatement stmt2 = RewriterUtils.parse("+(+(sum(X), 7), sum(Y))", ctx, "MATRIX:X,Y", "LITERAL_INT:7"); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + }*/ + + @Test + public void testSimpleAlgebra4() { + RewriterStatement stmt1 = RewriterUtils.parse("sum(-(+(+(X, 7), Y)))", ctx, "MATRIX:X,Y", "LITERAL_INT:7"); + + RewriterStatement matX = RewriterUtils.parse("X", ctx, "MATRIX:X"); + RewriterStatement matY = RewriterUtils.parse("Y", ctx, "MATRIX:Y"); + Map vars = new HashMap<>(); + vars.put("X", matX); + vars.put("Y", matY); + RewriterStatement stmt2 = RewriterUtils.parse("-(+(sum(+(X, 7)), sum(Y)))", ctx, vars, "LITERAL_INT:7"); + stmt2.givenThatEqual(vars.get("X").getNCol(), vars.get("Y").getNCol(), stmt2, ctx); + stmt2.givenThatEqual(vars.get("X").getNRow(), vars.get("Y").getNRow(), stmt2, ctx); + stmt2 = stmt2.recomputeAssertions(); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testSimpleSumPullOut() { + RewriterStatement stmt1 = RewriterUtils.parse("-(sum(+(A, 7)))", ctx, "MATRIX:A", "LITERAL_FLOAT:7"); + RewriterStatement stmt2 = RewriterUtils.parse("sum(+(-(A), -7))", ctx, "MATRIX:A", "LITERAL_FLOAT:-7"); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testSimpleInverseEquivalence() { + RewriterStatement stmt1 = RewriterUtils.parse("inv(A)", ctx, "MATRIX:A,B", "LITERAL_INT:7"); + RewriterStatement stmt2 = RewriterUtils.parse("-(inv(-(A)))", ctx, "MATRIX:A,B", "LITERAL_INT:7"); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + //@Test + public void testBackrefInequality() { + // Some example where _backRef() is not the same as another one + // As we need to compare to the meta-data + assert false; + } + + @Test + public void myTest() { + RewriterStatement stmt1 = RewriterUtils.parse("sum(-(X, 7))", ctx, "MATRIX:X,Y", "LITERAL_INT:1,7", "INT:a", "LITERAL_FLOAT:7.0"); + stmt1 = canonicalConverter.apply(stmt1); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + } + + @Test + public void myTest2() { + RewriterStatement stmt1 = RewriterUtils.parse("sum(_idxExpr(_idx(1, 7), -(a)))", ctx, "MATRIX:X,Y", "LITERAL_INT:1,7", "INT:a"); + stmt1 = canonicalConverter.apply(stmt1); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + } + + @Test + public void myTest3() { + RewriterStatement stmt = RewriterUtils.parse("%*%(X,[](B,1,ncol(X),1,ncol(B)))", ctx, "MATRIX:X,B,intercept", "LITERAL_INT:1"); + stmt = canonicalConverter.apply(stmt); + + LOG.info("=========="); + LOG.info(stmt.toParsableString(ctx, true)); + } + + @Test + public void myTest4() { + RewriterStatement stmt = RewriterUtils.parse("*(CBind(t(KM),KM_cols_select),KM_cols_select)", ctx, "MATRIX:KM,KM_cols_select"); + stmt = canonicalConverter.apply(stmt); + + LOG.info("=========="); + LOG.info(stmt.toParsableString(ctx, true)); + } + + @Test + public void myTest5() { + RewriterStatement stmt = RewriterUtils.parse("*(CBind(A, A),A)", ctx, "MATRIX:A"); + stmt = canonicalConverter.apply(stmt); + + LOG.info("=========="); + LOG.info(stmt.toParsableString(ctx, true)); + } + + @Test + public void myTest6() { + RewriterStatement stmt = RewriterUtils.parse("rowSums(<=(D,minD))", ctx, "MATRIX:D,minD"); + stmt = canonicalConverter.apply(stmt); + + LOG.info("=========="); + LOG.info(stmt.toParsableString(ctx, true)); + } + + @Test + public void myTest7() { + String stmtStr = "MATRIX:combined\n" + + "FLOAT:int0,int496,int236,int618\n" + + "LITERAL_INT:1,2\n" + + "INT:parsertemp71754,int497,int280\n" + + "&(RBind(!=([](combined,1,-(parsertemp71754,int497),1,ncol(combined)),[](combined,2,nrow(combined),1,ncol(combined))),rand(1,1,int0,int496)),RBind(rand(1,1,int618,int236),!=([](combined,1,-(parsertemp71754,int280),1,ncol(combined)),[](combined,2,nrow(combined),1,ncol(combined)))))"; + + RewriterStatement stmt = RewriterUtils.parse(stmtStr, ctx); + stmt = canonicalConverter.apply(stmt); + + LOG.info("=========="); + LOG.info(stmt.toParsableString(ctx, true)); + } + + @Test + public void myTest8() { + String stmtStr = "MATRIX:prec_chol,X,mu\n" + + "INT:i,k\n" + + "LITERAL_INT:1,5\n" + + "%*%(X,[](prec_chol,1,*(i,ncol(X)),1,5))"; + + RewriterStatement stmt = RewriterUtils.parse(stmtStr, ctx); + stmt = canonicalConverter.apply(stmt); + + LOG.info("=========="); + LOG.info(stmt.toParsableString(ctx, true)); + } + + @Test + public void myTest9() { + String stmtStr = "MATRIX:A,scale_X,shift_X,parsertemp282257,parsertemp282256,parsertemp282259,parsertemp282258\n" + + "INT:m_ext\n" + + "LITERAL_INT:1\n" + + "+(%*%(diag(scale_X),t(+(%*%(parsertemp282256,A),%*%(shift_X,A)))),%*%(shift_X,[](t(+(parsertemp282257,parsertemp282258)),m_ext,m_ext,1,nrow(parsertemp282259))))"; + + RewriterStatement stmt = RewriterUtils.parse(stmtStr, ctx); + stmt = canonicalConverter.apply(stmt); + + LOG.info("=========="); + LOG.info(stmt.toParsableString(ctx, true)); + } + + @Test + public void myTest10() { + String stmtStr = "MATRIX:P,minD,D,X\n" + + "/(%*%(t(/(<=(D,minD),rowSums(P))),X),t(colSums(/(<=(D,minD),rowSums(P)))))"; + + RewriterStatement stmt = RewriterUtils.parse(stmtStr, ctx); + stmt = canonicalConverter.apply(stmt); + + LOG.info("=========="); + LOG.info(stmt.toParsableString(ctx, true)); + } + + @Test + public void testConstantFolding1() { + RewriterStatement stmt1 = RewriterUtils.parse("*(1, A)", ctx, "MATRIX:A", "LITERAL_INT:1"); + RewriterStatement stmt2 = RewriterUtils.parse("A", ctx, "MATRIX:A"); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testConstantFolding2() { + RewriterStatement stmt1 = RewriterUtils.parse("+(A, 0)", ctx, "MATRIX:A", "LITERAL_INT:0"); + RewriterStatement stmt2 = RewriterUtils.parse("A", ctx, "MATRIX:A"); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testConstantFolding3() { + RewriterStatement stmt1 = RewriterUtils.parse("+(A, *(3, -(1, 1)))", ctx, "MATRIX:A", "LITERAL_INT:1,3"); + RewriterStatement stmt2 = RewriterUtils.parse("A", ctx, "MATRIX:A"); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testConstantFolding4() { + RewriterStatement stmt1 = RewriterUtils.parse("*(A, 0)", ctx, "MATRIX:A", "LITERAL_FLOAT:0"); + RewriterStatement stmt2 = RewriterUtils.parse("rand(nrow(A), ncol(A), 0, 0)", ctx, "MATRIX:A", "LITERAL_FLOAT:0"); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testAdvancedEquivalence1() { + RewriterStatement stmt1 = RewriterUtils.parse("sum(+(A, -7))", ctx, "MATRIX:A", "LITERAL_FLOAT:-7"); + RewriterStatement stmt2 = RewriterUtils.parse("sum(-(A, 7))", ctx, "MATRIX:A", "LITERAL_FLOAT:7"); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testInequality() { + RewriterStatement stmt1 = RewriterUtils.parse("/(*(A, A), B)", ctx, "MATRIX:A,B"); + RewriterStatement stmt2 = RewriterUtils.parse("/(*(A, A), sum(B))", ctx, "MATRIX:A,B"); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert !stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testDiagEquivalence() { + RewriterStatement stmt1 = RewriterUtils.parse("diag(diag(A))", ctx, "MATRIX:A,B"); + RewriterStatement stmt2 = RewriterUtils.parse("diag(A)", ctx, "MATRIX:A,B"); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testRIXInequality() { + RewriterStatement stmt1 = RewriterUtils.parse("+(A, [](B, 1, nrow(A), 1, ncol(A)))", ctx, "MATRIX:A,B", "LITERAL_INT:1"); + RewriterStatement stmt2 = RewriterUtils.parse("+(A, B)", ctx, "MATRIX:A,B"); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert !stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void convergenceTest() { + String stmtStr = "MATRIX:dl_matrix\n" + + "INT:i,j,46307663-5c68-48ba-aa86-c1c36de45dbe\n" + + "LITERAL_INT:1,2\n" + + "[](dl_matrix,+(i,-(2)),-(i,2),1,1)"; + + RewriterStatement stmt = RewriterUtils.parse(stmtStr, ctx); + stmt = canonicalConverter.apply(stmt); + + LOG.info("=========="); + LOG.info(stmt.toParsableString(ctx, true)); + } + + @Test + public void someTest() { + RewriterStatement stmt1 = RewriterUtils.parse("+([](%*%(A,B),151,151,1,ncol(B)),C)", ctx, "MATRIX:A,B,C", "LITERAL_INT:1,151"); + RewriterStatement stmt2 = RewriterUtils.parse("+([](C,151,151,1,ncol(B)),%*%(A,B))", ctx, "MATRIX:A,B,C", "LITERAL_INT:1,151"); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert !stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void my_Test() { + RewriterStatement stmt1 = RewriterUtils.parse("[](A, 1, 1, 151, 151)", ctx, "MATRIX:A,B,C", "LITERAL_INT:1,151"); + + stmt1 = canonicalConverter.apply(stmt1); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + } + + @Test + public void testSumEquality2() { + RewriterStatement stmt1 = RewriterUtils.parse("rowSums(colSums(A))", ctx, "MATRIX:A"); + RewriterStatement stmt2 = RewriterUtils.parse("as.matrix(sum(A))", ctx, "MATRIX:A"); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testSumEquality3() { + RewriterStatement stmt1 = RewriterUtils.parse("sum(%*%(A, B))", ctx, "MATRIX:A,B"); + RewriterStatement stmt2 = RewriterUtils.parse("as.scalar(%*%(colSums(A), rowSums(B)))", ctx, "MATRIX:A,B"); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testSumEquality4() { + RewriterStatement stmt1 = RewriterUtils.parse("%*%(t(colVec(A)), colVec(A))", ctx, "MATRIX:A,B", "LITERAL_INT:1"); + RewriterStatement stmt2 = RewriterUtils.parse("as.matrix(sum(*(colVec(A), colVec(A))))", ctx, "MATRIX:A,B", "LITERAL_INT:1"); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testSumEquality5() { + RewriterStatement stmt1 = RewriterUtils.parse("rowSums([](A, 1, nrow(A), 1, 1))", ctx, "MATRIX:A,B", "LITERAL_INT:1"); + RewriterStatement stmt2 = RewriterUtils.parse("[](A, 1, nrow(A), 1, 1)", ctx, "MATRIX:A,B", "LITERAL_INT:1"); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testSimpleConvergence() { + RewriterStatement stmt1 = RewriterUtils.parse("sum(a)", ctx, "FLOAT:a"); + + stmt1 = canonicalConverter.apply(stmt1); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + } + + @Test + public void testImplicitInequality() { + RewriterStatement stmt1 = RewriterUtils.parse("+([](A,1, nrow(A), 1, 1), B)", ctx, "MATRIX:A,B", "LITERAL_INT:1"); + RewriterStatement stmt2 = RewriterUtils.parse("+([](A,1, nrow(A), 1, 1), [](B, 1, nrow(B), 1, 1))", ctx, "MATRIX:A,B", "LITERAL_INT:1"); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert !stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testTraceEquivalence() { + RewriterStatement stmt1 = RewriterUtils.parse("trace(%*%(t(S),R))", ctx, "MATRIX:S,R", "LITERAL_INT:1"); + RewriterStatement stmt2 = RewriterUtils.parse("sum(*(S,R))", ctx, "MATRIX:S,R", "LITERAL_INT:1"); + + LOG.info("Cost1: " + RewriterCostEstimator.estimateCost(stmt1, ctx)); + LOG.info("Cost2: " + RewriterCostEstimator.estimateCost(stmt2, ctx)); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testMMEquivalence() { + RewriterStatement stmt1 = RewriterUtils.parse("%*%(A,*(b, B))", ctx, "MATRIX:A,B", "FLOAT:b", "LITERAL_INT:1"); + RewriterStatement stmt2 = RewriterUtils.parse("*(b, %*%(A, B))", ctx, "MATRIX:A,B", "FLOAT:b", "LITERAL_INT:1"); + + LOG.info("Cost1: " + RewriterCostEstimator.estimateCost(stmt1, ctx)); + LOG.info("Cost2: " + RewriterCostEstimator.estimateCost(stmt2, ctx)); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info(stmt1.getAssertions(ctx)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + LOG.info(stmt2.getAssertions(ctx)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testMMEquivalence2() { + RewriterStatement stmt1 = RewriterUtils.parse("cast.MATRIX(sum(*(t(rowVec(A)), colVec(B))))", ctx, "MATRIX:A,B", "FLOAT:b", "LITERAL_INT:1"); + RewriterStatement stmt2 = RewriterUtils.parse("%*%(rowVec(A), colVec(B))", ctx, "MATRIX:A,B", "FLOAT:b", "LITERAL_INT:1"); + + LOG.info("Cost1: " + RewriterCostEstimator.estimateCost(stmt1, ctx)); + LOG.info("Cost2: " + RewriterCostEstimator.estimateCost(stmt2, ctx)); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testColSumEquivalence4() { + RewriterStatement stmt1 = RewriterUtils.parse("colSums(*(A, b))", ctx, "MATRIX:A,B", "FLOAT:b", "LITERAL_INT:1"); + RewriterStatement stmt2 = RewriterUtils.parse("*(b, colSums(A))", ctx, "MATRIX:A,B", "FLOAT:b", "LITERAL_INT:1"); + + LOG.info("Cost1: " + RewriterCostEstimator.estimateCost(stmt1, ctx)); + LOG.info("Cost2: " + RewriterCostEstimator.estimateCost(stmt2, ctx)); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testColSumEquivalence5() { + RewriterStatement stmt1 = RewriterUtils.parse("colSums(*(A, b))", ctx, "MATRIX:A,B", "FLOAT:b", "LITERAL_INT:1"); + RewriterStatement stmt2 = RewriterUtils.parse("*(b, colSums(A))", ctx, "MATRIX:A,B", "FLOAT:b", "LITERAL_INT:1"); + + LOG.info("Cost1: " + RewriterCostEstimator.estimateCost(stmt1, ctx)); + LOG.info("Cost2: " + RewriterCostEstimator.estimateCost(stmt2, ctx)); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testZeroElimination() { + RewriterStatement stmt1 = RewriterUtils.parse("*(A,0.0)", ctx, "MATRIX:A,B", "FLOAT:b", "LITERAL_INT:1", "LITERAL_FLOAT:0.0"); + RewriterStatement stmt2 = RewriterUtils.parse("const(A, 0.0)", ctx, "MATRIX:A,B", "FLOAT:b", "LITERAL_INT:1", "LITERAL_FLOAT:0.0"); + + LOG.info("Cost1: " + RewriterCostEstimator.estimateCost(stmt1, ctx)); + LOG.info("Cost2: " + RewriterCostEstimator.estimateCost(stmt2, ctx)); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testMMScalarPullout() { + RewriterStatement stmt1 = RewriterUtils.parse("%*%(*(A, b), B)", ctx, "MATRIX:A,B", "FLOAT:b", "LITERAL_INT:1", "LITERAL_FLOAT:0.0"); + RewriterStatement stmt2 = RewriterUtils.parse("*(b, %*%(A, B))", ctx, "MATRIX:A,B", "FLOAT:b", "LITERAL_INT:1", "LITERAL_FLOAT:0.0"); + + long cost1 = RewriterCostEstimator.estimateCost(stmt1, ctx); + long cost2 = RewriterCostEstimator.estimateCost(stmt2, ctx); + + LOG.info("Cost1: " + cost1); + LOG.info("Cost2: " + cost2); + + assert cost2 == cost1; + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testWrong() { + RewriterStatement stmt1 = RewriterUtils.parse("*(sum(colVec(A)),colSums(B))", ctx, "MATRIX:A,B", "LITERAL_INT:1"); + RewriterStatement stmt2 = RewriterUtils.parse("%*%(colVec(A),colSums(B))", ctx, "MATRIX:A,B", "LITERAL_INT:1"); + + LOG.info("Cost1: " + RewriterCostEstimator.estimateCost(stmt1, ctx)); + LOG.info("Cost2: " + RewriterCostEstimator.estimateCost(stmt2, ctx)); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert !stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testWrong2() { + RewriterStatement stmt1 = RewriterUtils.parse("*(a,1.0)", ctx, "FLOAT:a", "LITERAL_FLOAT:1.0"); + RewriterStatement stmt2 = RewriterUtils.parse("a", ctx, "FLOAT:a", "LITERAL_INT:1"); + + LOG.info("Cost1: " + RewriterCostEstimator.estimateCost(stmt1, ctx)); + LOG.info("Cost2: " + RewriterCostEstimator.estimateCost(stmt2, ctx)); + + RewriterStatement newStmt = canonicalConverter.apply(stmt1); + LOG.info(newStmt); + LOG.info(stmt1); + //stmt2 = canonicalConverter.apply(stmt2); + + /*LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1));*/ + } + + //@Test + public void testRev() { + RewriterStatement stmt1 = RewriterUtils.parse("rev(rev(A))", ctx, "MATRIX:A,B", "LITERAL_INT:1"); + RewriterStatement stmt2 = RewriterUtils.parse("A", ctx, "MATRIX:A,B", "LITERAL_INT:1"); + + LOG.info("Cost1: " + RewriterCostEstimator.estimateCost(stmt1, ctx)); + LOG.info("Cost2: " + RewriterCostEstimator.estimateCost(stmt2, ctx)); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testTrace() { + RewriterStatement stmt1 = RewriterUtils.parse("trace(%*%(B,B))", ctx, "MATRIX:A,B", "LITERAL_INT:1"); + RewriterStatement stmt2 = RewriterUtils.parse("sum(*(A, t(B)))", ctx, "MATRIX:A,B", "LITERAL_INT:1"); + + LOG.info("Cost1: " + RewriterCostEstimator.estimateCost(stmt1, ctx)); + LOG.info("Cost2: " + RewriterCostEstimator.estimateCost(stmt2, ctx)); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + stmt1.compress(); + stmt2.compress(); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert !stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testFused1() { + RewriterStatement stmt1 = RewriterUtils.parse("1-*(A, B)", ctx, "MATRIX:A,B"); + RewriterStatement stmt2 = RewriterUtils.parse("-(1.0, *(A, B))", ctx, "MATRIX:A,B", "LITERAL_FLOAT:1.0"); + + LOG.info("Cost1: " + RewriterCostEstimator.estimateCost(stmt1, ctx)); + LOG.info("Cost2: " + RewriterCostEstimator.estimateCost(stmt2, ctx)); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testFused2() { + RewriterStatement stmt1 = RewriterUtils.parse("+(a, 1-*(A, B))", ctx, "MATRIX:A,B", "FLOAT:a"); + RewriterStatement stmt2 = RewriterUtils.parse("-(1.0, -(*(A, B), a))", ctx, "MATRIX:A,B", "FLOAT:a", "LITERAL_FLOAT:1.0"); + + LOG.info("Cost1: " + RewriterCostEstimator.estimateCost(stmt1, ctx)); + LOG.info("Cost2: " + RewriterCostEstimator.estimateCost(stmt2, ctx)); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testFused3() { + RewriterStatement stmt1 = RewriterUtils.parse("log_nz(A)", ctx, "MATRIX:A,B", "FLOAT:a"); + RewriterStatement stmt2 = RewriterUtils.parse("*(!=(0.0, A), log(A))", ctx, "MATRIX:A,B", "LITERAL_FLOAT:0.0"); + + LOG.info("Cost1: " + RewriterCostEstimator.estimateCost(stmt1, ctx)); + LOG.info("Cost2: " + RewriterCostEstimator.estimateCost(stmt2, ctx)); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testFused4() { + RewriterStatement stmt1 = RewriterUtils.parse("log_nz(A, a)", ctx, "MATRIX:A,B", "FLOAT:a"); + RewriterStatement stmt2 = RewriterUtils.parse("*(!=(0.0, A), log(A, a))", ctx, "MATRIX:A,B", "FLOAT:a", "LITERAL_FLOAT:0.0"); + + LOG.info("Cost1: " + RewriterCostEstimator.estimateCost(stmt1, ctx)); + LOG.info("Cost2: " + RewriterCostEstimator.estimateCost(stmt2, ctx)); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testFused5() { + RewriterStatement stmt1 = RewriterUtils.parse("sq(1-*(A,A))", ctx, "MATRIX:A,B", "FLOAT:a"); + + LOG.info("Cost1: " + RewriterCostEstimator.estimateCost(stmt1, ctx)); + + stmt1 = canonicalConverter.apply(stmt1); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + } + + @Test + public void testFused6() { + RewriterStatement stmt1 = RewriterUtils.parse("/(A,A)", ctx, "MATRIX:A,B", "FLOAT:a"); + RewriterStatement stmt2 = RewriterUtils.parse("/(A,rev(A))", ctx, "MATRIX:A,B", "FLOAT:a", "LITERAL_FLOAT:0.0"); + + LOG.info("Cost1: " + RewriterCostEstimator.estimateCost(stmt1, ctx)); + LOG.info("Cost2: " + RewriterCostEstimator.estimateCost(stmt2, ctx)); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert !stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testFused7() { + RewriterStatement stmt1 = RewriterUtils.parse("+*(A,a,B)", ctx, "MATRIX:A,B", "FLOAT:a"); + RewriterStatement stmt2 = RewriterUtils.parse("+(*(a, B), A)", ctx, "MATRIX:A,B", "FLOAT:a"); + + LOG.info("Cost1: " + RewriterCostEstimator.estimateCost(stmt1, ctx)); + LOG.info("Cost2: " + RewriterCostEstimator.estimateCost(stmt2, ctx)); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testFused8() { + RewriterStatement stmt1 = RewriterUtils.parse("sum(!=(0.0, A))", ctx, "MATRIX:A,B", "LITERAL_FLOAT:0.0"); + RewriterStatement stmt2 = RewriterUtils.parse("_nnz(A)", ctx, "MATRIX:A,B", "FLOAT:a"); + + LOG.info("Cost1: " + RewriterCostEstimator.estimateCost(stmt1, ctx)); + LOG.info("Cost2: " + RewriterCostEstimator.estimateCost(stmt2, ctx)); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testFusedCompilation() { + RewriterStatement stmt1 = RewriterUtils.parse("+(a,*2(1-*(B,B)))", ctx, "MATRIX:A,B", "FLOAT:a"); + + LOG.info("Cost1: " + RewriterCostEstimator.estimateCost(stmt1, ctx)); + + stmt1 = canonicalConverter.apply(stmt1); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + } + + @Test + public void testSum() { + RewriterStatement stmt1 = RewriterUtils.parse("sum(+(a,A))", ctx, "MATRIX:A,B", "FLOAT:a"); + RewriterStatement stmt2 = RewriterUtils.parse("+(*(a, length(A)), sum(A))", ctx, "MATRIX:A,B", "FLOAT:a", "LITERAL_FLOAT:0.0"); + + LOG.info("Cost1: " + RewriterCostEstimator.estimateCost(stmt1, ctx)); + LOG.info("Cost2: " + RewriterCostEstimator.estimateCost(stmt2, ctx)); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testRowSums() { + RewriterStatement stmt1 = RewriterUtils.parse("*(rowSums(/(a,C)),b)", ctx, "MATRIX:A,B,C", "FLOAT:a,b"); + RewriterStatement stmt2 = RewriterUtils.parse("rowSums(/(*(a,b),C))", ctx, "MATRIX:A,B,C", "FLOAT:a,b", "LITERAL_FLOAT:0.0"); + + LOG.info("Cost1: " + RewriterCostEstimator.estimateCost(stmt1, ctx)); + LOG.info("Cost2: " + RewriterCostEstimator.estimateCost(stmt2, ctx)); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testRowSums2() { + RewriterStatement stmt1 = RewriterUtils.parse("rowSums(*(A,+(B,1.0)))", ctx, "MATRIX:A,B,C", "FLOAT:a,b", "LITERAL_FLOAT:1.0"); + RewriterStatement stmt2 = RewriterUtils.parse("+(rowSums(A), rowSums(*(B,A)))", ctx, "MATRIX:A,B,C", "FLOAT:a,b", "LITERAL_FLOAT:1.0"); + + LOG.info("Cost1: " + RewriterCostEstimator.estimateCost(stmt1, ctx)); + LOG.info("Cost2: " + RewriterCostEstimator.estimateCost(stmt2, ctx)); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testDistrib3() { + RewriterStatement stmt1 = RewriterUtils.parse("*(A,+(B,1.0))", ctx, "MATRIX:A,B,C", "FLOAT:a,b", "LITERAL_FLOAT:1.0"); + RewriterStatement stmt2 = RewriterUtils.parse("+(A, *(B,A))", ctx, "MATRIX:A,B,C", "FLOAT:a,b", "LITERAL_FLOAT:1.0"); + + LOG.info("Cost1: " + RewriterCostEstimator.estimateCost(stmt1, ctx)); + LOG.info("Cost2: " + RewriterCostEstimator.estimateCost(stmt2, ctx)); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testRev2() { + RewriterStatement stmt1 = RewriterUtils.parse("trace(rev(A))", ctx, "MATRIX:A,B,C", "FLOAT:a,b", "LITERAL_FLOAT:1.0"); + RewriterStatement stmt2 = RewriterUtils.parse("trace(A)", ctx, "MATRIX:A,B,C", "FLOAT:a,b", "LITERAL_FLOAT:1.0"); + + LOG.info("Cost1: " + RewriterCostEstimator.estimateCost(stmt1, ctx)); + LOG.info("Cost2: " + RewriterCostEstimator.estimateCost(stmt2, ctx)); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert !stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testSumInequality() { + RewriterStatement stmt1 = RewriterUtils.parse("sum(+(a,*(B,c)))", ctx, "MATRIX:B", "FLOAT:a,c"); + RewriterStatement stmt2 = RewriterUtils.parse("*(a, sum(+(B,c)))", ctx, "MATRIX:B", "FLOAT:a,c", "LITERAL_FLOAT:0.0"); + + LOG.info("Cost1: " + RewriterCostEstimator.estimateCost(stmt1, ctx)); + LOG.info("Cost2: " + RewriterCostEstimator.estimateCost(stmt2, ctx)); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert !stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testDiag1() { + RewriterStatement stmt1 = RewriterUtils.parse("diag(+(A, B))", ctx, "MATRIX:A,B"); + RewriterStatement stmt2 = RewriterUtils.parse("+(diag(A), diag(B))", ctx, "MATRIX:A,B"); + + long cost1 = RewriterCostEstimator.estimateCost(stmt1, ctx); + long cost2 = RewriterCostEstimator.estimateCost(stmt2, ctx); + + LOG.info("Cost1: " + cost1); + LOG.info("Cost2: " + cost2); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + assert cost1 > cost2; + } + + @Test + public void testDiag2() { + RewriterStatement stmt1 = RewriterUtils.parse("trace(A)", ctx, "MATRIX:A,B"); + RewriterStatement stmt2 = RewriterUtils.parse("trace(diag(A))", ctx, "MATRIX:A,B"); + + long cost1 = RewriterCostEstimator.estimateCost(stmt1, ctx); + long cost2 = RewriterCostEstimator.estimateCost(stmt2, ctx); + + LOG.info("Cost1: " + cost1); + LOG.info("Cost2: " + cost2); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testDiag3() { + RewriterStatement stmt1 = RewriterUtils.parse("%*%(diag(A), diag(B))", ctx, "MATRIX:A,B"); + RewriterStatement stmt2 = RewriterUtils.parse("*(diag(A), diag(B))", ctx, "MATRIX:A,B"); + + long cost1 = RewriterCostEstimator.estimateCost(stmt1, ctx); + long cost2 = RewriterCostEstimator.estimateCost(stmt2, ctx); + + LOG.info("Cost1: " + cost1); + LOG.info("Cost2: " + cost2); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testConstFold() { + RewriterStatement stmt1 = RewriterUtils.parse("-(+(1.0,a), 1.0)", ctx, "FLOAT:a", "LITERAL_FLOAT:1.0"); + RewriterStatement stmt2 = RewriterUtils.parse("a", ctx, "FLOAT:a"); + + long cost1 = RewriterCostEstimator.estimateCost(stmt1, ctx); + long cost2 = RewriterCostEstimator.estimateCost(stmt2, ctx); + + LOG.info("Cost1: " + cost1); + LOG.info("Cost2: " + cost2); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + //@Test + public void testConst() { + RewriterStatement stmt1 = RewriterUtils.parse("min(const(A, a))", ctx, "FLOAT:a", "MATRIX:A"); + RewriterStatement stmt2 = RewriterUtils.parse("a", ctx, "FLOAT:a"); + + long cost1 = RewriterCostEstimator.estimateCost(stmt1, ctx); + long cost2 = RewriterCostEstimator.estimateCost(stmt2, ctx); + + LOG.info("Cost1: " + cost1); + LOG.info("Cost2: " + cost2); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testMin() { + RewriterStatement stmt1 = RewriterUtils.parse("+(A, min(B))", ctx, "MATRIX:A,B"); + + long cost1 = RewriterCostEstimator.estimateCost(stmt1, ctx); + + LOG.info("Cost1: " + cost1); + + stmt1 = canonicalConverter.apply(stmt1); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + } + + @Test + public void testBoolDiag() { + RewriterStatement stmt1 = RewriterUtils.parse("diag(!=(A,A))", ctx, "MATRIX:A,B"); + + long cost1 = RewriterCostEstimator.estimateCost(stmt1, ctx); + + LOG.info("Cost1: " + cost1); + + stmt1 = canonicalConverter.apply(stmt1); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + } + + @Test + public void testWrong3() { + RewriterStatement stmt1 = RewriterUtils.parse("*(A, /(A,C))", ctx, "MATRIX:A,C"); + RewriterStatement stmt2 = RewriterUtils.parse("*(sum(*(C,A)), A)", ctx, "MATRIX:A,C"); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert !stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testWrong4() { + // TODO: Rule "Element selection pushdown" seems to be an issue here + RewriterStatement stmt1 = RewriterUtils.parse("/(A, rev(A))", ctx, "MATRIX:A"); + RewriterStatement stmt2 = RewriterUtils.parse("/(A, A)", ctx, "MATRIX:A"); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert !stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testWrong5() { + RewriterStatement stmt1 = RewriterUtils.parse("*2(-(B,B))", ctx, "MATRIX:B"); + RewriterStatement stmt2 = RewriterUtils.parse("*2(-(a, B))", ctx, "MATRIX:B", "FLOAT:a"); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert !stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testWrong6() { + RewriterStatement stmt1 = RewriterUtils.parse("%*%(t(+(A,A)), B)", ctx, "MATRIX:A,B,C", "LITERAL_FLOAT:1.0"); + RewriterStatement stmt2 = RewriterUtils.parse("%*%(t(A), +(B, B))", ctx, "MATRIX:A,B,C", "FLOAT:a"); + + RewriterStatement can1 = canonicalConverter.apply(stmt1); + RewriterStatement can2 = canonicalConverter.apply(stmt2); + + stmt1 = RewriterRuleCreator.createCommonForm(stmt1, stmt2, can1, can2, ctx)._1; + RewriterAssertions assertions = RewriterAssertionUtils.buildImplicitAssertions(stmt1, ctx); + RewriterAssertionUtils.buildImplicitAssertion(stmt2, assertions, stmt1, ctx); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + LOG.info(RewriterCostEstimator.getRawCostFunction(stmt1, ctx, new MutableObject<>(assertions), false).toParsableString(ctx)); + LOG.info(RewriterCostEstimator.getRawCostFunction(stmt2, ctx, new MutableObject<>(assertions), false).toParsableString(ctx)); + LOG.info(RewriterCostEstimator.compareCosts(List.of(stmt1, stmt2), assertions, ctx, false, 5)); + Set> t = RewriterCostEstimator.findOptima(RewriterCostEstimator.compareCosts(List.of(stmt1, stmt2), assertions, ctx, true, 5)); + LOG.info(t); + + assert !stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, can1, can2)); + } + + @Test + public void testWrong7() { + RewriterStatement stmt1 = RewriterUtils.parse("*(+(B,B),A)", ctx, "MATRIX:A,B,C", "LITERAL_FLOAT:1.0"); + RewriterStatement stmt2 = RewriterUtils.parse("%*%(t(A), +(B, B))", ctx, "MATRIX:A,B,C", "FLOAT:a"); + + RewriterStatement can1 = canonicalConverter.apply(stmt1); + RewriterStatement can2 = canonicalConverter.apply(stmt2); + + stmt1 = RewriterRuleCreator.createCommonForm(stmt1, stmt2, can1, can2, ctx)._1; + RewriterAssertions assertions = RewriterAssertionUtils.buildImplicitAssertions(stmt1, ctx); + RewriterAssertionUtils.buildImplicitAssertion(stmt2, assertions, stmt1, ctx); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + LOG.info(RewriterCostEstimator.getRawCostFunction(stmt1, ctx, new MutableObject<>(assertions), false).toParsableString(ctx)); + LOG.info(RewriterCostEstimator.getRawCostFunction(stmt2, ctx, new MutableObject<>(assertions), false).toParsableString(ctx)); + LOG.info(RewriterCostEstimator.compareCosts(List.of(stmt1, stmt2), assertions, ctx, false, 5)); + Set> t = RewriterCostEstimator.findOptima(RewriterCostEstimator.compareCosts(List.of(stmt1, stmt2), assertions, ctx, true, 5)); + LOG.info(t); + + assert !stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, can1, can2)); + } + + @Test + public void testConstInequivality() { + RewriterStatement stmt1 = RewriterUtils.parse("%*%(const(A, 0.0), A)", ctx, "MATRIX:A", "LITERAL_FLOAT:0.0"); + RewriterStatement stmt2 = RewriterUtils.parse("const(A, 0.0)", ctx, "MATRIX:A", "LITERAL_FLOAT:0.0"); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert !stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testSumEquality7() { + RewriterStatement stmt1 = RewriterUtils.parse("sum(+(a, A))", ctx, "MATRIX:A", "FLOAT:a"); + RewriterStatement stmt2 = RewriterUtils.parse("+(*(a, length(A)), sum(A))", ctx, "MATRIX:A", "FLOAT:a"); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testSumEquality8() { + RewriterStatement stmt1 = RewriterUtils.parse("sum(const(A,1.0))", ctx, "MATRIX:A", "LITERAL_FLOAT:1.0"); + RewriterStatement stmt2 = RewriterUtils.parse("length(A)", ctx, "MATRIX:A", "FLOAT:a"); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testSparsityComparison() { + RewriterStatement stmt1 = RewriterUtils.parse("+(*(A, B),*(A, C))", ctx, "MATRIX:A,B,C", "LITERAL_FLOAT:1.0"); + RewriterStatement stmt2 = RewriterUtils.parse("*(A, +(B, C))", ctx, "MATRIX:A,B,C", "FLOAT:a"); + + RewriterStatement can1 = canonicalConverter.apply(stmt1); + RewriterStatement can2 = canonicalConverter.apply(stmt2); + + stmt1 = RewriterRuleCreator.createCommonForm(stmt1, stmt2, can1, can2, ctx)._1; + RewriterAssertions assertions = RewriterAssertionUtils.buildImplicitAssertions(stmt1, ctx); + RewriterAssertionUtils.buildImplicitAssertion(stmt2, assertions, stmt1, ctx); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + LOG.info(RewriterCostEstimator.getRawCostFunction(stmt1, ctx, new MutableObject<>(assertions), false).toParsableString(ctx)); + LOG.info(RewriterCostEstimator.getRawCostFunction(stmt2, ctx, new MutableObject<>(assertions), false).toParsableString(ctx)); + LOG.info(RewriterCostEstimator.compareCosts(List.of(stmt1, stmt2), assertions, ctx, false, 5)); + Set> t = RewriterCostEstimator.findOptima(RewriterCostEstimator.compareCosts(List.of(stmt1, stmt2), assertions, ctx, true, 5)); + LOG.info(t); + + assert can2.match(RewriterStatement.MatcherContext.exactMatch(ctx, can1, can2)); + } + + @Test + public void testTEST() { + RewriterStatement stmt1 = RewriterUtils.parse("t(/(<=(A,B),rowSums(<=(C,B))))", ctx, "MATRIX:A,B,C,D,E", "LITERAL_FLOAT:1.0", "LITERAL_INT:1"); + + stmt1 = canonicalConverter.apply(stmt1); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + } +} diff --git a/src/test/java/org/apache/sysds/test/component/codegen/rewrite/RewriterTopologySortTests.java b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/RewriterTopologySortTests.java new file mode 100644 index 00000000000..a34a73b3774 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/RewriterTopologySortTests.java @@ -0,0 +1,265 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.codegen.rewrite; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.hops.rewriter.RewriterStatement; +import org.apache.sysds.hops.rewriter.utils.RewriterUtils; +import org.apache.sysds.hops.rewriter.RuleContext; +import org.apache.sysds.hops.rewriter.TopologicalSort; +import org.junit.BeforeClass; +import org.junit.Test; + +import java.util.function.Function; + +public class RewriterTopologySortTests { + protected static final Log LOG = LogFactory.getLog(RewriterTopologySortTests.class.getName()); + private static RuleContext ctx; + private static Function converter; + + @BeforeClass + public static void setup() { + ctx = RewriterUtils.buildDefaultContext(); + converter = RewriterUtils.buildCanonicalFormConverter(ctx, false); + } + + @Test + public void testSimpleEquivalence1() { + RewriterStatement stmt = RewriterUtils.parse("+(*(a, b), *(a, c))", ctx, "FLOAT:a,b,c"); + RewriterStatement stmt2 = RewriterUtils.parse("+(*(b, a), *(c, a))", ctx, "FLOAT:a,b,c"); + stmt = converter.apply(stmt); + stmt2 = converter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + assert stmt.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt)); + } + + @Test + public void testSimpleEquivalence2() { + // Here, a and b are indistinguishable + // Thus, the topological sort has to decide a random but consistent order + RewriterStatement stmt = RewriterUtils.parse("+(*(a, b), *(b, a))", ctx, "FLOAT:a,b"); + RewriterStatement stmt2 = RewriterUtils.parse("+(*(b, a), *(b, a))", ctx, "FLOAT:a,b"); + stmt = converter.apply(stmt); + stmt2 = converter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + assert stmt.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt)); + } + + @Test + public void testSimpleEquivalence3() { + RewriterStatement stmt = RewriterUtils.parse("+(-(*(a, b)), *(b, a))", ctx, "FLOAT:a,b"); + RewriterStatement stmt2 = RewriterUtils.parse("+(*(b, a), -(*(b, a)))", ctx, "FLOAT:a,b"); + stmt = converter.apply(stmt); + stmt2 = converter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + assert stmt.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt)); + } + + @Test + public void testSimpleEquivalence4() { + RewriterStatement stmt = RewriterUtils.parse("+(*(-(a), b), *(b, a))", ctx, "FLOAT:a,b"); + RewriterStatement stmt2 = RewriterUtils.parse("+(*(a, -(b)), *(b, a))", ctx, "FLOAT:a,b"); + stmt = converter.apply(stmt); + stmt2 = converter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + assert stmt.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt)); + } + + @Test + public void testSimpleEquivalence5() { + RewriterStatement stmt = RewriterUtils.parse("+(1, 2)", ctx, "LITERAL_INT:1,2"); + RewriterStatement stmt2 = RewriterUtils.parse("+(2, 1)", ctx, "LITERAL_INT:1,2"); + stmt = converter.apply(stmt); + stmt2 = converter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + assert stmt.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt)); + } + + @Test + public void testSimpleEquivalence6() { + RewriterStatement stmt = RewriterUtils.parse("+(*(a, b), *(*(a, b), c))", ctx, "FLOAT:a,b,c"); + RewriterStatement stmt2 = RewriterUtils.parse("+(*(*(a, b), c), *(a, b))", ctx, "FLOAT:a,b,c"); + stmt = converter.apply(stmt); + stmt2 = converter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + assert stmt.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt)); + } + + @Test + public void testSimpleEquivalence7() { + RewriterStatement stmt = RewriterUtils.parse("+(*(a, b), *(/(a, b), /(b, a)))", ctx, "FLOAT:a,b,c"); + RewriterStatement stmt2 = RewriterUtils.parse("+(*(/(a, b), /(b, a)), *(a, b))", ctx, "FLOAT:a,b,c"); + stmt = converter.apply(stmt); + stmt2 = converter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + assert stmt.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt)); + } + + @Test + public void testSimpleEquivalence9() { + RewriterStatement stmt = RewriterUtils.parse("+(*(-(a), b), *(a, a))", ctx, "FLOAT:a,b"); + RewriterStatement stmt2 = RewriterUtils.parse("+(*(a, -(b)), *(a, a))", ctx, "FLOAT:a,b"); + stmt = converter.apply(stmt); + stmt2 = converter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + assert stmt.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt)); + } + + @Test + public void testSimpleEquivalence10() { + RewriterStatement stmt = RewriterUtils.parse("+(argList(*(argList(a,b)),*(argList(a,inv(b),b,inv(a)))))", ctx, "FLOAT:a,b,c"); + RewriterStatement stmt2 = RewriterUtils.parse("+(argList(*(argList(a,inv(b),b,inv(a))),*(argList(a,b))))", ctx, "FLOAT:a,b,c"); + TopologicalSort.sort(stmt, ctx); + TopologicalSort.sort(stmt2, ctx); + + LOG.info("=========="); + LOG.info(stmt.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + assert stmt.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt)); + } + + @Test + public void test4() { + RewriterStatement stmt = RewriterUtils.parse("sum(*(A, A))", ctx, "MATRIX:A"); + stmt = converter.apply(stmt); + + LOG.info(stmt.toParsableString(ctx, true)); + } + + @Test + public void test5() { + RewriterStatement stmt1 = RewriterUtils.parse("sum(_idxExpr($1:_idx(1,_EClass(argList(nrow(A),nrow(B)))),*(argList([](B,$1,$1),[](A,$1,$1)))))", ctx, "MATRIX:A,B", "LITERAL_INT:1"); + RewriterStatement stmt2 = RewriterUtils.parse("sum(_idxExpr($1:_idx(1,_EClass(argList(nrow(B),nrow(A)))),*(argList([](B,$1,$1),[](A,$1,$1)))))", ctx, "MATRIX:A,B", "LITERAL_INT:1"); + TopologicalSort.sort(stmt1, ctx); + TopologicalSort.sort(stmt2, ctx); + LOG.info(stmt1.toParsableString(ctx)); + LOG.info(stmt2.toParsableString(ctx)); + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testComplex1() { + RewriterStatement stmt1 = RewriterUtils.parse("_m($1:_idx(1,ncol(V)),$2:_idx(1,ncol(U)),sum(_idxExpr($3:_idx(1,_EClass(argList(nrow(V),nrow(U)))),*(argList([](V,$3,$1),[](U,$3,$2))))))", ctx, "MATRIX:U,V", "LITERAL_INT:1"); + RewriterStatement stmt2 = RewriterUtils.parse("_m($1:_idx(1,ncol(V)),$2:_idx(1,ncol(U)),sum(_idxExpr($3:_idx(1,_EClass(argList(nrow(U),nrow(V)))),*(argList([](U,$3,$2),[](V,$3,$1))))))", ctx, "MATRIX:U,V", "LITERAL_INT:1"); + + TopologicalSort.sort(stmt1, ctx); + TopologicalSort.sort(stmt2, ctx); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testComplex2() { + RewriterStatement stmt1 = RewriterUtils.parse("_m($1:_idx(1,ncol(V)),$2:_idx(1,ncol(U)),sum(_idxExpr($3:_idx(1,_EClass(argList(nrow(V),nrow(U)))),1.0)))", ctx, "MATRIX:U,V", "LITERAL_INT:1", "LITERAL_FLOAT:1.0"); + RewriterStatement stmt2 = RewriterUtils.parse("_m($1:_idx(1,ncol(V)),$2:_idx(1,ncol(U)),sum(_idxExpr($3:_idx(1,_EClass(argList(nrow(U),nrow(V)))),1.0)))", ctx, "MATRIX:U,V", "LITERAL_INT:1", "LITERAL_FLOAT:1.0"); + + TopologicalSort.sort(stmt1, ctx); + TopologicalSort.sort(stmt2, ctx); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testComplex3() { + RewriterStatement stmt1 = RewriterUtils.parse("_m(ncol(V),ncol(U),as.float(_EClass(argList(nrow(V),nrow(U))))))", ctx, "MATRIX:U,V", "LITERAL_INT:1", "LITERAL_FLOAT:1.0"); + RewriterStatement stmt2 = RewriterUtils.parse("_m(ncol(V),ncol(U),as.float(_EClass(argList(nrow(U),nrow(V))))))", ctx, "MATRIX:U,V", "LITERAL_INT:1", "LITERAL_FLOAT:1.0"); + + TopologicalSort.sort(stmt1, ctx); + TopologicalSort.sort(stmt2, ctx); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void testSimple() { + RewriterStatement stmt = RewriterUtils.parse("*(argList(a, sum(b), a))", ctx, "FLOAT:a,b"); + TopologicalSort.sort(stmt, ctx); + + String parsableString = stmt.toParsableString(ctx); + LOG.info(parsableString); + assert "*(argList(a,a,sum(b)))".equals(parsableString); + } + + @Test + public void test2() { + RewriterStatement stmt1 = RewriterUtils.parse("+(argList(_EClass(argList(1, ncol(A), ncol(B))), _EClass(argList(nrow(C),nrow(B),nrow(A)))))", ctx, "MATRIX:A,B,C", "LITERAL_INT:1", "LITERAL_FLOAT:1.0"); + RewriterStatement stmt2 = RewriterUtils.parse("+(argList(_EClass(argList(1, ncol(A), ncol(B))), _EClass(argList(nrow(A),nrow(C),nrow(B)))))", ctx, "MATRIX:A,B,C", "LITERAL_INT:1", "LITERAL_FLOAT:1.0"); + + TopologicalSort.sort(stmt1, ctx); + TopologicalSort.sort(stmt2, ctx); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + +} diff --git a/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/AssertionTests.java b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/AssertionTests.java new file mode 100644 index 00000000000..6f9db682bcb --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/AssertionTests.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.codegen.rewrite.functions; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.hops.rewriter.assertions.RewriterAssertions; +import org.apache.sysds.hops.rewriter.RewriterStatement; +import org.apache.sysds.hops.rewriter.utils.RewriterUtils; +import org.apache.sysds.hops.rewriter.RuleContext; +import org.apache.sysds.test.component.codegen.rewrite.RewriterTopologySortTests; +import org.junit.BeforeClass; +import org.junit.Test; + +public class AssertionTests { + protected static final Log LOG = LogFactory.getLog(AssertionTests.class.getName()); + + private static RuleContext ctx; + + @BeforeClass + public static void setup() { + ctx = RewriterUtils.buildDefaultContext(); + } + + @Test + public void test1() { + RewriterAssertions assertion = new RewriterAssertions(ctx); + RewriterStatement stmt1 = RewriterUtils.parse("*(*(nrow(A), nrow(B)), *(nrow(C), nrow(A)))", ctx, "MATRIX:A,B,C"); + RewriterStatement nrowA = stmt1.getOperands().get(0).getOperands().get(0); + RewriterStatement nrowB = stmt1.getOperands().get(0).getOperands().get(1); + RewriterStatement nrowC = stmt1.getOperands().get(1).getOperands().get(0); + RewriterStatement nrowA2 = stmt1.getOperands().get(1).getOperands().get(1); + + assert assertion.addEqualityAssertion(nrowA, nrowC, stmt1); + LOG.info(assertion.getAssertions(nrowA)); + + assert !assertion.addEqualityAssertion(nrowA, nrowC, stmt1); + LOG.info(assertion.getAssertions(nrowC)); + + assert assertion.addEqualityAssertion(nrowC, nrowB, stmt1); + LOG.info(assertion.getAssertions(nrowC)); + + LOG.info(assertion.getAssertions(nrowA2)); + } + +} diff --git a/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/CodeExecutionTest.java b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/CodeExecutionTest.java new file mode 100644 index 00000000000..481a896bad5 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/CodeExecutionTest.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.codegen.rewrite.functions; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.api.DMLScript; +import org.apache.sysds.hops.rewriter.dml.DMLExecutor; +import org.apache.sysds.test.component.codegen.rewrite.RewriterTopologySortTests; +import org.junit.Test; + +public class CodeExecutionTest { + protected static final Log LOG = LogFactory.getLog(CodeExecutionTest.class.getName()); + + @Test + public void test() { + String str = "X = rand(rows=5000, cols=5000, sparsity=0.1)\n" + + "Y = rand(rows=5000, cols=5000, sparsity=0.1)\n" + + "R = X*Y\n" + + "print(lineage(R))"; + DMLScript.APPLY_GENERATED_REWRITES = true; + DMLExecutor.executeCode(str, false, "-applyGeneratedRewrites"); + } +} diff --git a/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/CodeGenConditionTests.java b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/CodeGenConditionTests.java new file mode 100644 index 00000000000..5f0c6da1e0f --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/CodeGenConditionTests.java @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.codegen.rewrite.functions; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.hops.rewriter.rule.RewriterRule; +import org.apache.sysds.hops.rewriter.RewriterStatement; +import org.apache.sysds.hops.rewriter.utils.RewriterUtils; +import org.apache.sysds.hops.rewriter.RuleContext; +import org.apache.sysds.hops.rewriter.codegen.CodeGenCondition; +import org.apache.sysds.test.component.codegen.rewrite.RewriterTopologySortTests; +import org.junit.BeforeClass; +import org.junit.Test; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Function; + +public class CodeGenConditionTests { + protected static final Log LOG = LogFactory.getLog(CodeGenConditionTests.class.getName()); + + private static RuleContext ctx; + + @BeforeClass + public static void setup() { + ctx = RewriterUtils.buildDefaultContext(); + } + + @Test + public void test1() { + String ruleStr = "MATRIX:A\n" + + "\n" + + "t(t(A))\n" + + "=>\n" + + "A"; + + RewriterRule rule = RewriterUtils.parseRule(ruleStr, ctx); + + List cgcs = CodeGenCondition.buildCondition(List.of(rule), 1, ctx); + } + + @Test + public void test2() { + String ruleStr = "MATRIX:A\n" + + "\n" + + "t(t(A))\n" + + "=>\n" + + "A"; + + RewriterRule rule = RewriterUtils.parseRule(ruleStr, ctx); + + String ruleStr2 = "MATRIX:A,B\n" + + "\n" + + "+(t(A), t(B))\n" + + "=>\n" + + "t(+(A, B))"; + + RewriterRule rule2 = RewriterUtils.parseRule(ruleStr2, ctx); + + String ruleStr3 = "MATRIX:A,B\n" + + "\n" + + "%*%(t(A), t(B))\n" + + "=>\n" + + "t(%*%(B, A))"; + + RewriterRule rule3 = RewriterUtils.parseRule(ruleStr3, ctx); + + Map fNames = new HashMap<>(); + fNames.put(rule, "rule1"); + fNames.put(rule2, "rule2"); + fNames.put(rule3, "rule3"); + + List cgcs = CodeGenCondition.buildCondition(List.of(rule, rule2, rule3), 1, ctx); + LOG.info(CodeGenCondition.getSelectionString(cgcs, 0, fNames, ctx)); + } + + @Test + public void test3() { + String ruleStr = "MATRIX:A\nFLOAT:b\n" + + "\n" + + "!=(-(b,rev(A)),A)\n" + + "=>\n" + + "!=(A,-(b,A))"; + + RewriterRule rule = RewriterUtils.parseRule(ruleStr, ctx); + + String ruleStr2 = "MATRIX:A,B\n" + + "\n" + + "!=(-(B,rev(A)),A)\n" + + "=>\n" + + "!=(A,-(B,A))"; + + RewriterRule rule2 = RewriterUtils.parseRule(ruleStr2, ctx); + + String ruleStr3 = "MATRIX:A,B,C\n" + + "\n" + + "+(*(A,C),*(A,B))\n" + + "=>\n" + + "*(A,+(B,C))"; + + RewriterRule rule3 = RewriterUtils.parseRule(ruleStr3, ctx); + + String ruleStr4 = "MATRIX:A,B,C\n" + + "\n" + + "+(*(A,C),*(B,A))\n" + + "=>\n" + + "*(A,+(B,C))"; + + RewriterRule rule4 = RewriterUtils.parseRule(ruleStr4, ctx); + + String ruleStr5 = "MATRIX:B,C\nFLOAT:a\n" + + "\n" + + "+(*(a,C),*(B,a))\n" + + "=>\n" + + "*(a,+(B,C))"; + + RewriterRule rule5 = RewriterUtils.parseRule(ruleStr5, ctx); + + Map fNames = new HashMap<>(); + fNames.put(rule, "rule1"); + fNames.put(rule2, "rule2"); + fNames.put(rule3, "rule3"); + fNames.put(rule4, "rule4"); + fNames.put(rule5, "rule5"); + + List cgcs = CodeGenCondition.buildCondition(List.of(rule, rule2, rule3, rule4, rule5), 1, ctx); + LOG.info(cgcs); + LOG.info(CodeGenCondition.getSelectionString(cgcs, 0, fNames, ctx)); + } +} diff --git a/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/CodeGenTests.java b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/CodeGenTests.java new file mode 100644 index 00000000000..b439b92dd5c --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/CodeGenTests.java @@ -0,0 +1,314 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.codegen.rewrite.functions; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.common.Types; +import org.apache.sysds.hops.AggUnaryOp; +import org.apache.sysds.hops.BinaryOp; +import org.apache.sysds.hops.DataGenOp; +import org.apache.sysds.hops.Hop; +import org.apache.sysds.hops.LiteralOp; +import org.apache.sysds.hops.ReorgOp; +import org.apache.sysds.hops.rewrite.HopRewriteUtils; +import org.apache.sysds.hops.rewriter.codegen.RewriterCodeGen; +import org.apache.sysds.hops.rewriter.rule.RewriterRule; +import org.apache.sysds.hops.rewriter.rule.RewriterRuleBuilder; +import org.apache.sysds.hops.rewriter.rule.RewriterRuleSet; +import org.apache.sysds.hops.rewriter.RewriterStatement; +import org.apache.sysds.hops.rewriter.utils.RewriterUtils; +import org.apache.sysds.hops.rewriter.RuleContext; +import org.apache.sysds.parser.DataExpression; +import org.apache.sysds.parser.DataIdentifier; +import org.apache.sysds.test.component.codegen.rewrite.RewriterTopologySortTests; +import org.junit.BeforeClass; +import org.junit.Test; +import scala.Tuple2; + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.function.Function; + +public class CodeGenTests { + protected static final Log LOG = LogFactory.getLog(CodeGenTests.class.getName()); + + private static RuleContext ctx; + + @BeforeClass + public static void setup() { + ctx = RewriterUtils.buildDefaultContext(); + } + + @Test + public void test1() { + RewriterStatement stmt1 = RewriterUtils.parse("+(1, 1)", ctx, "LITERAL_INT:1"); + RewriterStatement stmt2 = RewriterUtils.parse("2", ctx, "LITERAL_INT:2"); + + RewriterRule rule = new RewriterRuleBuilder(ctx, "testRule") + .setUnidirectional(true) + .completeRule(stmt1, stmt2) + .build(); + + LOG.info(RewriterCodeGen.generateClass("MRuleTest", List.of(new Tuple2<>("testRule", rule)), false, false, ctx, false, false)); + + try { + Function f = RewriterCodeGen.compileRewrites("MRuleTest", List.of(new Tuple2<>("testRule", rule)), ctx, false, false); + Hop l = new LiteralOp(1); + Hop add = new BinaryOp("test", Types.DataType.SCALAR, Types.ValueType.INT64, Types.OpOp2.PLUS, l, l); + Hop result = f.apply(add); + + assert result instanceof LiteralOp && ((LiteralOp) result).getLongValue() == 2; + } catch (Exception e) { + e.printStackTrace(); + assert false; + } + } + + @Test + public void test2() { + HashMap vars = new HashMap<>(); + vars.put("A", RewriterUtils.parse("A", ctx, "MATRIX:A")); + vars.put("B", RewriterUtils.parse("B", ctx, "MATRIX:B")); + RewriterStatement stmt1 = RewriterUtils.parse("+(t(A), t(B))", ctx, vars); + RewriterStatement stmt2 = RewriterUtils.parse("t(+(A, B))", ctx, vars); + + RewriterRule rule = new RewriterRuleBuilder(ctx, "testRule") + .setUnidirectional(true) + .completeRule(stmt1, stmt2) + .build(); + + LOG.info(RewriterCodeGen.generateClass("MRuleTest", List.of(new Tuple2<>("testRule", rule)), false, false, ctx, false, false)); + + try { + Function f = RewriterCodeGen.compileRewrites("MRuleTest", List.of(new Tuple2<>("testRule", rule)), ctx, false, false); + HashMap inputParams = new HashMap<>(); + inputParams.put(DataExpression.RAND_ROWS, new LiteralOp(100)); + inputParams.put(DataExpression.RAND_COLS, new LiteralOp(100)); + inputParams.put(DataExpression.RAND_MIN, new LiteralOp(0.0)); + inputParams.put(DataExpression.RAND_MAX, new LiteralOp(1.0)); + Hop A = new DataGenOp(Types.OpOpDG.RAND, new DataIdentifier("A", Types.DataType.MATRIX, Types.ValueType.FP64), inputParams); + Hop B = new DataGenOp(Types.OpOpDG.RAND, new DataIdentifier("B", Types.DataType.MATRIX, Types.ValueType.FP64), inputParams); + Hop tA = new ReorgOp("t", Types.DataType.MATRIX, Types.ValueType.FP64, Types.ReOrgOp.TRANS, A); + Hop tB = new ReorgOp("t", Types.DataType.MATRIX, Types.ValueType.FP64, Types.ReOrgOp.TRANS, B); + Hop add = new BinaryOp("test", Types.DataType.MATRIX, Types.ValueType.FP64, Types.OpOp2.PLUS, tA, tB); + Hop result = f.apply(add); + + assert result instanceof ReorgOp && result.getInput().size() == 1 && result.getInput(0) instanceof BinaryOp; + } catch (Exception e) { + e.printStackTrace(); + assert false; + } + } + + @Test + public void test3() { + HashMap vars = new HashMap<>(); + vars.put("A", RewriterUtils.parse("A", ctx, "MATRIX:A")); + vars.put("B", RewriterUtils.parse("B", ctx, "MATRIX:B")); + RewriterStatement stmt1 = RewriterUtils.parse("^(t(A), t(B))", ctx, vars); + RewriterStatement stmt2 = RewriterUtils.parse("t(^(A, B))", ctx, vars); + + RewriterRule rule = new RewriterRuleBuilder(ctx, "testRule") + .setUnidirectional(true) + .completeRule(stmt1, stmt2) + .build(); + + LOG.info(RewriterCodeGen.generateClass("MRuleTest", List.of(new Tuple2<>("testRule", rule)), false, false, ctx, false, false)); + + try { + Function f = RewriterCodeGen.compileRewrites("MRuleTest", List.of(new Tuple2<>("testRule", rule)), ctx, false, false); + HashMap inputParams = new HashMap<>(); + inputParams.put(DataExpression.RAND_ROWS, new LiteralOp(100)); + inputParams.put(DataExpression.RAND_COLS, new LiteralOp(100)); + inputParams.put(DataExpression.RAND_MIN, new LiteralOp(0.0)); + inputParams.put(DataExpression.RAND_MAX, new LiteralOp(1.0)); + Hop A = new DataGenOp(Types.OpOpDG.RAND, new DataIdentifier("A", Types.DataType.MATRIX, Types.ValueType.FP64), inputParams); + Hop B = new DataGenOp(Types.OpOpDG.RAND, new DataIdentifier("B", Types.DataType.MATRIX, Types.ValueType.FP64), inputParams); + Hop tA = new ReorgOp("t", Types.DataType.MATRIX, Types.ValueType.FP64, Types.ReOrgOp.TRANS, A); + Hop tB = new ReorgOp("t", Types.DataType.MATRIX, Types.ValueType.FP64, Types.ReOrgOp.TRANS, B); + Hop pow = new BinaryOp("test", Types.DataType.MATRIX, Types.ValueType.FP64, Types.OpOp2.POW, tA, tB); + Hop result = f.apply(pow); + + assert result instanceof ReorgOp && result.getInput().size() == 1 && result.getInput(0) instanceof BinaryOp; + } catch (Exception e) { + e.printStackTrace(); + assert false; + } + } + + @Test + public void test4() { + HashMap vars = new HashMap<>(); + vars.put("A", RewriterUtils.parse("A", ctx, "MATRIX:A")); + vars.put("B", RewriterUtils.parse("B", ctx, "MATRIX:B")); + RewriterStatement stmt1 = RewriterUtils.parse("%*%(t(A), t(B))", ctx, vars); + RewriterStatement stmt2 = RewriterUtils.parse("t(%*%(B, A))", ctx, vars); + + RewriterRule rule = new RewriterRuleBuilder(ctx, "testRule") + .setUnidirectional(true) + .completeRule(stmt1, stmt2) + .build(); + + LOG.info(RewriterCodeGen.generateClass("MRuleTest", List.of(new Tuple2<>("testRule", rule)), false, false, ctx, false, false)); + + try { + Function f = RewriterCodeGen.compileRewrites("MRuleTest", List.of(new Tuple2<>("testRule", rule)), ctx, false, false); + HashMap inputParams = new HashMap<>(); + inputParams.put(DataExpression.RAND_ROWS, new LiteralOp(100)); + inputParams.put(DataExpression.RAND_COLS, new LiteralOp(100)); + inputParams.put(DataExpression.RAND_MIN, new LiteralOp(0.0)); + inputParams.put(DataExpression.RAND_MAX, new LiteralOp(1.0)); + Hop A = new DataGenOp(Types.OpOpDG.RAND, new DataIdentifier("A", Types.DataType.MATRIX, Types.ValueType.FP64), inputParams); + Hop B = new DataGenOp(Types.OpOpDG.RAND, new DataIdentifier("B", Types.DataType.MATRIX, Types.ValueType.FP64), inputParams); + Hop tA = new ReorgOp("t", Types.DataType.MATRIX, Types.ValueType.FP64, Types.ReOrgOp.TRANS, A); + Hop tB = new ReorgOp("t", Types.DataType.MATRIX, Types.ValueType.FP64, Types.ReOrgOp.TRANS, B); + Hop matmul = HopRewriteUtils.createMatrixMultiply(tA, tB); + Hop result = f.apply(matmul); + + assert result instanceof ReorgOp && result.getInput().size() == 1 && HopRewriteUtils.isMatrixMultiply(result.getInput(0)); + } catch (Exception e) { + e.printStackTrace(); + assert false; + } + } + + @Test + public void test5() { + HashMap vars = new HashMap<>(); + vars.put("A", RewriterUtils.parse("A", ctx, "MATRIX:A")); + vars.put("B", RewriterUtils.parse("B", ctx, "MATRIX:B")); + RewriterStatement stmt1 = RewriterUtils.parse("rowSums(t(A))", ctx, vars); + RewriterStatement stmt2 = RewriterUtils.parse("t(colSums(A))", ctx, vars); + + RewriterRule rule = new RewriterRuleBuilder(ctx, "testRule") + .setUnidirectional(true) + .completeRule(stmt1, stmt2) + .build(); + + LOG.info(RewriterCodeGen.generateClass("MRuleTest", List.of(new Tuple2<>("testRule", rule)), false, false, ctx, false, false)); + + try { + Function f = RewriterCodeGen.compileRewrites("MRuleTest", List.of(new Tuple2<>("testRule", rule)), ctx, false, false); + HashMap inputParams = new HashMap<>(); + inputParams.put(DataExpression.RAND_ROWS, new LiteralOp(100)); + inputParams.put(DataExpression.RAND_COLS, new LiteralOp(100)); + inputParams.put(DataExpression.RAND_MIN, new LiteralOp(0.0)); + inputParams.put(DataExpression.RAND_MAX, new LiteralOp(1.0)); + Hop A = new DataGenOp(Types.OpOpDG.RAND, new DataIdentifier("A", Types.DataType.MATRIX, Types.ValueType.FP64), inputParams); + Hop tA = new ReorgOp("t", Types.DataType.MATRIX, Types.ValueType.FP64, Types.ReOrgOp.TRANS, A); + Hop rowSums = HopRewriteUtils.createAggUnaryOp(tA, Types.AggOp.SUM, Types.Direction.Row); + Hop result = f.apply(rowSums); + + assert result instanceof ReorgOp && result.getInput().size() == 1 && result.getInput(0) instanceof AggUnaryOp; + } catch (Exception e) { + e.printStackTrace(); + assert false; + } + } + + @Test + public void generateExample() { + String ruleStr = "MATRIX:B\nFLOAT:a,c\n+(a,-(B,c))\n=>\n+(-(a,c),B)"; + RewriterRule rule = RewriterUtils.parseRule(ruleStr, ctx); + RewriterRuleSet rs = new RewriterRuleSet(ctx, List.of(rule)); + RewriterCodeGen.DEBUG = false; + String code = rs.toJavaCode("Test", false, false, true, false); + LOG.info(code); + } + + @Test + public void generateExample2() { + String ruleStr = "MATRIX:A\n+(A,A)\n=>\n*2(A)"; + RewriterRule rule = RewriterUtils.parseRule(ruleStr, ctx); + RewriterRuleSet rs = new RewriterRuleSet(ctx, List.of(rule)); + RewriterCodeGen.DEBUG = false; + String code = rs.toJavaCode("Test", false, false, true, false); + LOG.info(code); + } + + @Test + public void testConditional() { + String ruleStr = "MATRIX:Xm,tmp852\n" + + "FLOAT:tmp65855\n" + + "\n" + + "%*%(t(/(Xm,tmp65855)),tmp852)\n" + + "=>\n" + + "{\n" + + "%*%(t(Xm),/(tmp852,tmp65855))\n" + + "/(%*%(t(Xm),tmp852),tmp65855)\n" + + "t(/(%*%(t(tmp852),Xm),tmp65855))\n" + + "}"; + RewriterRule rule = RewriterUtils.parseRule(ruleStr, ctx); + RewriterRuleSet rs = new RewriterRuleSet(ctx, List.of(rule)); + rs.determineConditionalApplicability(); + RewriterCodeGen.DEBUG = false; + String code = rs.toJavaCode("GeneratedRewriteClass", false, true, true, false); + LOG.info(code); + } + + @Test + public void testLiteral() { + String ruleStr = "MATRIX:A\n" + + "\n" + + "-(+(A, $1:literal.FLOAT()), $2:literal.FLOAT())\n" + + "=>\n" + + "+(A, -($1, $2))"; + RewriterRule rule = RewriterUtils.parseRule(ruleStr, ctx); + RewriterRuleSet rs = new RewriterRuleSet(ctx, List.of(rule)); + rs.determineConditionalApplicability(); + RewriterCodeGen.DEBUG = false; + String code = rs.toJavaCode("GeneratedRewriteClass", false, true, true, false); + LOG.info(code); + } + + @Test + public void testCFold() { + String ruleStr = "LITERAL_FLOAT:1.0,2.0\n" + + "\n" + + "+(1.0,1.0)\n" + + "=>\n" + + "2.0"; + RewriterRule rule = RewriterUtils.parseRule(ruleStr, ctx); + RewriterRuleSet rs = new RewriterRuleSet(ctx, List.of(rule)); + rs.determineConditionalApplicability(); + RewriterCodeGen.DEBUG = false; + String code = rs.toJavaCode("GeneratedRewriteClass", false, true, true, false); + LOG.info(code); + } + + //@Test + public void codeGen() { + List files = List.of("/Users/janniklindemann/Dev/Rewrite-Generator-Reproducibility/data/rules_end_to_end.dml"); + //List files = List.of(RewriteAutomaticallyGenerated.FILE_PATH_MB); + String targetPath = "/Users/janniklindemann/Dev/MScThesis/other/GeneratedRewriteClass.java"; + + try { + // This is to specify that the generated code should print to the console if it modifies the DAG + // This should be disabled when generating production code + RewriterCodeGen.DEBUG = false; + RewriterCodeGen.generateRewritesFromFiles(files, targetPath, true, 3, true, false, ctx); + } catch (IOException e) { + e.printStackTrace(); + } + } + +} diff --git a/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/CostEstimates.java b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/CostEstimates.java new file mode 100644 index 00000000000..dde5f991378 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/CostEstimates.java @@ -0,0 +1,395 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.codegen.rewrite.functions; + +import org.apache.commons.lang3.mutable.MutableObject; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.hops.rewriter.rule.RewriterRule; +import org.apache.sysds.hops.rewriter.assertions.RewriterAssertions; +import org.apache.sysds.hops.rewriter.estimators.RewriterCostEstimator; +import org.apache.sysds.hops.rewriter.RewriterStatement; +import org.apache.sysds.hops.rewriter.utils.RewriterUtils; +import org.apache.sysds.hops.rewriter.RuleContext; +import org.apache.sysds.test.component.codegen.rewrite.RewriterTopologySortTests; +import org.junit.BeforeClass; +import org.junit.Test; +import scala.Tuple2; +import scala.Tuple3; + +import java.util.List; +import java.util.Set; +import java.util.function.Function; + +public class CostEstimates { + protected static final Log LOG = LogFactory.getLog(CostEstimates.class.getName()); + + private static RuleContext ctx; + private static Function canonicalConverter; + + @BeforeClass + public static void setup() { + ctx = RewriterUtils.buildDefaultContext(); + canonicalConverter = RewriterUtils.buildCanonicalFormConverter(ctx, true); + } + + @Test + public void test1() { + RewriterStatement stmt = RewriterUtils.parse("%*%(+(A,B), C)", ctx, "MATRIX:A,B,C"); + MutableObject assertionRef = new MutableObject<>(); + long cost1 = RewriterCostEstimator.estimateCost(stmt, ctx, assertionRef); + LOG.info(cost1); + long cost2 = RewriterCostEstimator.estimateCost(stmt.getChild(0), ctx, assertionRef); + LOG.info(cost2); + assert cost2 < cost1; + } + + @Test + public void test2() { + RewriterStatement stmt = RewriterUtils.parse("*(+(1, 1), 2)", ctx, "LITERAL_INT:1,2"); + LOG.info(canonicalConverter.apply(stmt)); + } + + @Test + public void test3() { + RewriterStatement stmt = RewriterUtils.parse("_EClass(argList(1, ncol(X)))", ctx, "LITERAL_INT:1", "MATRIX:X"); + LOG.info(canonicalConverter.apply(stmt)); + } + + @Test + public void test4() { + RewriterStatement stmt1 = RewriterUtils.parse("t(%*%(+(A,B), C))", ctx, "MATRIX:A,B,C"); + RewriterStatement stmt2 = RewriterUtils.parse("%*%(t(C), t(+(A,B)))", ctx, "MATRIX:A,B,C"); + long cost1 = RewriterCostEstimator.estimateCost(stmt1, el -> 2000L, ctx); + long cost2 = RewriterCostEstimator.estimateCost(stmt2, el -> 2000L, ctx); + LOG.info("Cost1: " + cost1); + LOG.info("Cost2: " + cost2); + LOG.info("Ratio: " + ((double)cost1)/cost2); + assert cost1 < cost2; + } + + @Test + public void test5() { + RewriterStatement stmt1 = RewriterUtils.parse("t(/(*(A, B), C))", ctx, "MATRIX:A,B,C"); + RewriterStatement stmt2 = RewriterUtils.parse("/(*(t(A), t(B)), t(C))", ctx, "MATRIX:A,B,C"); + long cost1 = RewriterCostEstimator.estimateCost(stmt1, el -> 2000L, ctx); + long cost2 = RewriterCostEstimator.estimateCost(stmt2, el -> 2000L, ctx); + LOG.info("Cost1: " + cost1); + LOG.info("Cost2: " + cost2); + LOG.info("Ratio: " + ((double)cost1)/cost2); + assert cost1 < cost2; + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void test6() { + RewriterStatement stmt1 = RewriterUtils.parse("sum(+(A, B))", ctx, "MATRIX:A,B,C"); + RewriterStatement stmt2 = RewriterUtils.parse("+(sum(A), sum(B))", ctx, "MATRIX:A,B,C"); + stmt2.givenThatEqualDimensions(stmt2.getChild(0, 0), stmt2.getChild(1, 0), ctx); + long cost1 = RewriterCostEstimator.estimateCost(stmt1, el -> 2000L, ctx); + long cost2 = RewriterCostEstimator.estimateCost(stmt2, el -> 2000L, ctx); + LOG.info("Cost1: " + cost1); + LOG.info("Cost2: " + cost2); + LOG.info("Ratio: " + ((double)cost2)/cost1); + assert cost2 < cost1; + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void test7() { + RewriterStatement stmt1 = RewriterUtils.parse("cast.MATRIX(sum(A))", ctx, "MATRIX:A,B,C"); + RewriterStatement stmt2 = RewriterUtils.parse("rowSums(colSums(A))", ctx, "MATRIX:A,B,C"); + long cost1 = RewriterCostEstimator.estimateCost(stmt1, el -> 2000L, ctx); + long cost2 = RewriterCostEstimator.estimateCost(stmt2, el -> 2000L, ctx); + LOG.info("Cost1: " + cost1); + LOG.info("Cost2: " + cost2); + LOG.info("Ratio: " + ((double)cost1)/cost2); + assert cost1 < cost2; + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void test8() { + RewriterStatement stmt1 = RewriterUtils.parse("sum(*(diag(A), diag(B)))", ctx, "MATRIX:A,B,C"); + RewriterStatement stmt2 = RewriterUtils.parse("trace(*(A, B))", ctx, "MATRIX:A,B,C"); + + long cost1 = RewriterCostEstimator.estimateCost(stmt1, el -> 2000L, ctx); + long cost2 = RewriterCostEstimator.estimateCost(stmt2, el -> 2000L, ctx); + LOG.info("Cost1: " + cost1); + LOG.info("Cost2: " + cost2); + LOG.info("Ratio: " + ((double)cost1)/cost2); + assert cost1 < cost2; + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void test9() { + String stmt1Str = "MATRIX:WM\n" + + "FLOAT:m2X,c19b086e-34d2-46dd-9651-7b6d1d16e459\n" + + "LITERAL_FLOAT:1.0\n" + + "sqrt(*(m2X,/(sum(WM),-(c19b086e-34d2-46dd-9651-7b6d1d16e459,1.0))))"; + String stmt2Str = "MATRIX:1167aa9b-102a-4bae-9801-8b18d210f954\n" + + "FLOAT:m2,41d7e6fb-d4a7-45cf-89cb-cea8ecf3430a\n" + + "LITERAL_FLOAT:1.0\n" + + "sqrt(/(*(m2,sum(1167aa9b-102a-4bae-9801-8b18d210f954)),-(41d7e6fb-d4a7-45cf-89cb-cea8ecf3430a,1.0)))"; + + RewriterStatement stmt1 = RewriterUtils.parse(stmt1Str, ctx); + RewriterStatement stmt2 = RewriterUtils.parse(stmt2Str, ctx); + + long cost1 = RewriterCostEstimator.estimateCost(stmt1, el -> 2000L, ctx); + long cost2 = RewriterCostEstimator.estimateCost(stmt2, el -> 2000L, ctx); + LOG.info("Cost1: " + cost1); + LOG.info("Cost2: " + cost2); + LOG.info("Ratio: " + ((double)cost1)/cost2); + assert cost1 == cost2; + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void test10() { + String stmt1Str = "INT:num_records\n" + + "LITERAL_INT:3\n" + + "*(num_records,3)"; + String stmt2Str = "LITERAL_INT:3\n" + + "INT:run_index\n" + + "*(3,run_index)"; + + RewriterStatement stmt1 = RewriterUtils.parse(stmt1Str, ctx); + RewriterStatement stmt2 = RewriterUtils.parse(stmt2Str, ctx); + + long cost1 = RewriterCostEstimator.estimateCost(stmt1, el -> 2000L, ctx); + long cost2 = RewriterCostEstimator.estimateCost(stmt2, el -> 2000L, ctx); + LOG.info("Cost1: " + cost1); + LOG.info("Cost2: " + cost2); + LOG.info("Ratio: " + ((double)cost1)/cost2); + assert cost1 == cost2; + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void test11() { + String stmtStr1 = "MATRIX:A,p_CG,z\n" + + "FLOAT:trust_delta_sq\n" + + "*(cast.FLOAT(A),cast.FLOAT(%*%(p_CG,z)))"; + String stmtStr2 = "MATRIX:A,p_CG,z\n" + + "FLOAT:trust_delta_sq\n" + + "*(cast.FLOAT(%*%(p_CG,z)),cast.FLOAT(A))"; + + RewriterStatement stmt1 = RewriterUtils.parse(stmtStr1, ctx); + RewriterStatement stmt2 = RewriterUtils.parse(stmtStr2, ctx); + + long cost1 = RewriterCostEstimator.estimateCost(stmt1, el -> 2000L, ctx); + long cost2 = RewriterCostEstimator.estimateCost(stmt2, el -> 2000L, ctx); + LOG.info("Cost1: " + cost1); + LOG.info("Cost2: " + cost2); + LOG.info("Ratio: " + ((double)cost1)/cost2); + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void test12() { + String stmtStr1 = "MATRIX:A,B\n" + + "LITERAL_INT:1\n" + + "+([](A, 1, nrow(A), 1, 1),B)"; + String stmtStr2 = "MATRIX:A,B\n" + + "LITERAL_INT:1\n" + + "+([](A, 1, nrow(A), 1, ncol(A)), B)"; + + RewriterStatement stmt1 = RewriterUtils.parse(stmtStr1, ctx); + RewriterStatement stmt2 = RewriterUtils.parse(stmtStr2, ctx); + + long cost1 = RewriterCostEstimator.estimateCost(stmt1, el -> 2000L, ctx); + long cost2 = RewriterCostEstimator.estimateCost(stmt2, el -> 2000L, ctx); + LOG.info("Cost1: " + cost1); + LOG.info("Cost2: " + cost2); + LOG.info("Ratio: " + ((double)cost1)/cost2); + + assert cost1 < cost2; + } + + @Test + public void test13() { + String stmtStr1 = "MATRIX:A,B\n" + + "LITERAL_INT:1\n" + + "[](rowSums(A), 1, nrow(A), 1, 1)"; + String stmtStr2 = "MATRIX:A,B\n" + + "LITERAL_INT:1\n" + + "rowSums(A)"; + + RewriterStatement stmt1 = RewriterUtils.parse(stmtStr1, ctx); + RewriterStatement stmt2 = RewriterUtils.parse(stmtStr2, ctx); + + long cost1 = RewriterCostEstimator.estimateCost(stmt1, el -> 2000L, ctx); + long cost2 = RewriterCostEstimator.estimateCost(stmt2, el -> 2000L, ctx); + LOG.info("Cost1: " + cost1); + LOG.info("Cost2: " + cost2); + LOG.info("Ratio: " + ((double)cost1)/cost2); + + assert cost2 < cost1; + + stmt1 = canonicalConverter.apply(stmt1); + stmt2 = canonicalConverter.apply(stmt2); + + LOG.info("=========="); + LOG.info(stmt1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(stmt2.toParsableString(ctx, true)); + assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1)); + } + + @Test + public void test14() { + RewriterStatement stmt1 = RewriterUtils.parse("t(t(A))", ctx, "MATRIX:A"); + MutableObject assertionRef = new MutableObject<>(); + long maxCost = RewriterCostEstimator.estimateCost(stmt1, ctx, assertionRef); + Tuple2, Boolean> allowedCombinations = RewriterCostEstimator.determineSingleReferenceRequirement(stmt1, RewriterCostEstimator.DEFAULT_COST_FN, assertionRef.getValue(), 0, maxCost, ctx); + LOG.info(allowedCombinations._1); + LOG.info("AllowCombinations: " + allowedCombinations._2); + assert allowedCombinations._1.size() == 1; + } + + @Test + public void test15() { + RewriterStatement stmt1 = RewriterUtils.parse("sum(rowSums(A))", ctx, "MATRIX:A"); + RewriterStatement stmt2 = RewriterUtils.parse("sum(A)", ctx, "MATRIX:A"); + MutableObject assertionRef = new MutableObject<>(); + long maxCost = RewriterCostEstimator.estimateCost(stmt1, ctx, assertionRef); + long fullCost = RewriterCostEstimator.estimateCost(stmt2, ctx, assertionRef); + Tuple2, Boolean> allowedCombinations = RewriterCostEstimator.determineSingleReferenceRequirement(stmt1, RewriterCostEstimator.DEFAULT_COST_FN, assertionRef.getValue(), fullCost, maxCost, ctx); + LOG.info(allowedCombinations._1); + LOG.info("AllowCombinations: " + allowedCombinations._2); + assert allowedCombinations._1.isEmpty(); + } + + @Test + public void test16() { + RewriterStatement stmt1 = RewriterUtils.parse("+(colSums(A),[](B,1,1,1,ncol(B)))", ctx, "MATRIX:A,B", "LITERAL_INT:1"); + RewriterStatement stmt2 = RewriterUtils.parse("+(colSums(A),colSums([](B,1,1,1,ncol(B))))", ctx, "MATRIX:A,B", "LITERAL_INT:1"); + long cost1 = RewriterCostEstimator.estimateCost(stmt1, ctx); + long cost2 = RewriterCostEstimator.estimateCost(stmt2, ctx); + assert cost1 < cost2; + } + + @Test + public void test17() { + RewriterStatement stmt1 = RewriterUtils.parse("%*%(colVec(A),B)", ctx, "MATRIX:A,B", "LITERAL_INT:1"); + RewriterStatement stmt2 = RewriterUtils.parse("%*%(colSums(colVec(A)),B)", ctx, "MATRIX:A,B", "LITERAL_INT:1"); + long cost1 = RewriterCostEstimator.estimateCost(stmt1, ctx); + long cost2 = RewriterCostEstimator.estimateCost(stmt2, ctx); + LOG.info("Cost1: " + cost1); + LOG.info("Cost2: " + cost2); + assert cost1 < cost2; + } + + @Test + public void test18() { + String ruleStr = + "MATRIX:tmp55220\n" + + "FLOAT:tmp23781\n" + + "\n" + + "/(t(tmp55220),tmp23781)\n" + + "=>\n" + + "t(/(tmp55220,tmp23781))"; + + RewriterRule rule = RewriterUtils.parseRule(ruleStr, ctx); + + List, Long, Long>> cmp = RewriterCostEstimator.compareCosts(rule.getStmt1(), rule.getStmt2(), rule.getStmt1().getAssertions(ctx), ctx, false, 0, false); + + LOG.info(cmp); + long cost1 = RewriterCostEstimator.estimateCost(rule.getStmt1(), ctx); + long cost2 = RewriterCostEstimator.estimateCost(rule.getStmt2(), ctx); + LOG.info("Cost1: " + cost1); + LOG.info("Cost2: " + cost2); + assert cost1 == cost2; + } + + @Test + public void test19() { + String ruleStr = + "MATRIX:tmp14587,tmp76084\n" + + "FLOAT:one_over_sqrt_two_pi\n" + + "\n" + + "*(tmp14587,/(one_over_sqrt_two_pi,tmp76084))\n" + + "=>\n" + + "/(*(one_over_sqrt_two_pi,tmp14587),tmp76084)"; + + RewriterRule rule = RewriterUtils.parseRule(ruleStr, ctx); + + List, Long, Long>> cmp = RewriterCostEstimator.compareCosts(rule.getStmt1(), rule.getStmt2(), rule.getStmt1().getAssertions(ctx), ctx, false, 0, false); + + LOG.info(cmp); + long cost1 = RewriterCostEstimator.estimateCost(rule.getStmt1(), ctx); + long cost2 = RewriterCostEstimator.estimateCost(rule.getStmt2(), ctx); + LOG.info("Cost1: " + cost1); + LOG.info("Cost2: " + cost2); + assert cost1 == cost2; + } +} diff --git a/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/DMLCodeGenTest.java b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/DMLCodeGenTest.java new file mode 100644 index 00000000000..46a6069a7c8 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/DMLCodeGenTest.java @@ -0,0 +1,248 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.codegen.rewrite.functions; + +import org.apache.commons.lang3.mutable.MutableBoolean; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.hops.rewriter.dml.DMLCodeGenerator; +import org.apache.sysds.hops.rewriter.dml.DMLExecutor; +import org.apache.sysds.hops.rewriter.rule.RewriterRule; +import org.apache.sysds.hops.rewriter.rule.RewriterRuleCreator; +import org.apache.sysds.hops.rewriter.RewriterStatement; +import org.apache.sysds.hops.rewriter.utils.RewriterUtils; +import org.apache.sysds.hops.rewriter.RuleContext; +import org.apache.sysds.test.component.codegen.rewrite.RewriterTopologySortTests; +import org.junit.BeforeClass; +import org.junit.Test; + +import java.util.UUID; +import java.util.function.Function; + +public class DMLCodeGenTest { + protected static final Log LOG = LogFactory.getLog(DMLCodeGenTest.class.getName()); + + private static RuleContext ctx; + private static Function canonicalConverter; + + @BeforeClass + public static void setup() { + ctx = RewriterUtils.buildDefaultContext(); + canonicalConverter = RewriterUtils.buildCanonicalFormConverter(ctx, false); + } + + @Test + public void test1() { + RewriterStatement stmt = RewriterUtils.parse("trace(+(A, t(B)))", ctx, "MATRIX:A,B"); + LOG.info(DMLCodeGenerator.generateDML(stmt)); + } + + @Test + public void test2() { + String ruleStr1 = "MATRIX:A\nt(t(A))\n=>\nA"; + String ruleStr2 = "MATRIX:A\nrowSums(t(A))\n=>\nt(colSums(A))"; + RewriterRule rule1 = RewriterUtils.parseRule(ruleStr1, ctx); + RewriterRule rule2 = RewriterUtils.parseRule(ruleStr2, ctx); + + //RewriterRuleSet ruleSet = new RewriterRuleSet(ctx, List.of(rule1, rule2)); + String sessionId = UUID.randomUUID().toString(); + String validationScript = DMLCodeGenerator.generateRuleValidationDML(rule2, DMLCodeGenerator.EPS, sessionId, ctx); + LOG.info("Validation script:"); + LOG.info(validationScript); + MutableBoolean valid = new MutableBoolean(true); + DMLExecutor.executeCode(validationScript, line -> { + if (!line.startsWith(sessionId)) + return; + + if (!line.endsWith("valid: TRUE")) { + DMLExecutor.println("An invalid rule was found!"); + DMLExecutor.println(line); + valid.setValue(false); + } + }); + + LOG.info("Exiting..."); + assert valid.booleanValue(); + } + + @Test + public void test3() { + String ruleStr2 = "MATRIX:A,B\nt(*(A,t(B)))\n=>\n*(t(A),B)"; + RewriterRule rule2 = RewriterUtils.parseRule(ruleStr2, ctx); + + assert RewriterRuleCreator.validateRuleCorrectnessAndGains(rule2, ctx); + } + + @Test + public void test4() { + // Should already be implemented + String ruleStr2 = "MATRIX:A,B\nt(+(A,t(B)))\n=>\n+(t(A),B)"; + RewriterRule rule2 = RewriterUtils.parseRule(ruleStr2, ctx); + + assert RewriterRuleCreator.validateRuleCorrectnessAndGains(rule2, ctx); + } + + @Test + public void test5() { + String ruleStr2 = "MATRIX:A\nLITERAL_FLOAT:1,2\n-(+(1,A), 1)\n=>\n*(1,A)"; + RewriterRule rule2 = RewriterUtils.parseRule(ruleStr2, ctx); + + assert RewriterRuleCreator.validateRuleCorrectnessAndGains(rule2, ctx); + } + @Test + public void test6() { + String ruleStr2 = "MATRIX:?,B\nLITERAL_INT:1,2\n+(?,B)\n=>\n*(1,+(?,B))"; + RewriterRule rule2 = RewriterUtils.parseRule(ruleStr2, ctx); + + assert RewriterRuleCreator.validateRuleCorrectnessAndGains(rule2, ctx); + } + + @Test + public void test7() { + String ruleStr2 = "MATRIX:?,B\nLITERAL_INT:1,2\n+(?,B)\n=>\n*(1,+(?,B))"; + RewriterRule rule2 = RewriterUtils.parseRule(ruleStr2, ctx); + + assert RewriterRuleCreator.validateRuleCorrectnessAndGains(rule2, ctx); + } + + @Test + public void test8() { + String ruleStr = "MATRIX:8cbda53a-49a8-479f-bf34-baeeb1eb8b0f,is_LT_infinite,flip_pos\n" + + "\n" + + "+(%*%(is_LT_infinite,flip_pos),%*%(8cbda53a-49a8-479f-bf34-baeeb1eb8b0f,flip_pos))\n" + + "=>\n" + + "%*%(+(8cbda53a-49a8-479f-bf34-baeeb1eb8b0f,is_LT_infinite),flip_pos)"; + + RewriterRule rule = RewriterUtils.parseRule(ruleStr, ctx); + + assert RewriterRuleCreator.validateRuleCorrectnessAndGains(rule, ctx); + } + + @Test + public void testRev() { + String ruleStr = "MATRIX:A\n" + + "FLOAT:b\n" + + "\n" + + "rev(*(rev(A),b))\n" + + "=>\n" + + "*(A,b)"; + + RewriterRule rule = RewriterUtils.parseRule(ruleStr, ctx); + + LOG.info(DMLCodeGenerator.generateRuleValidationDML(rule, "test", ctx)); + + assert RewriterRuleCreator.validateRuleCorrectness(rule, ctx); + + assert RewriterRuleCreator.validateRuleApplicability(rule, ctx); + } + + @Test + public void testFused1() { + String ruleStr = "MATRIX:A\nLITERAL_FLOAT:0.0\n" + + "sum(!=(0.0,A))\n" + + "=>\n" + + "_nnz(A)"; + + RewriterRule rule = RewriterUtils.parseRule(ruleStr, ctx); + + LOG.info(DMLCodeGenerator.generateRuleValidationDML(rule, "test", ctx)); + + assert RewriterRuleCreator.validateRuleCorrectness(rule, ctx); + + assert RewriterRuleCreator.validateRuleApplicability(rule, ctx); + } + + @Test + public void testFused2() { + String ruleStr = "MATRIX:A,B\nLITERAL_FLOAT:0.0,1.0\n" + + "-(0.0, -(*(A,B), 1.0))\n" + + "=>\n" + + "1-*(A,B)"; + + RewriterRule rule = RewriterUtils.parseRule(ruleStr, ctx); + + LOG.info(DMLCodeGenerator.generateRuleValidationDML(rule, "test", ctx)); + + assert RewriterRuleCreator.validateRuleCorrectness(rule, ctx); + + assert RewriterRuleCreator.validateRuleApplicability(rule, ctx); + } + + @Test + public void testFused3() { + String ruleStr = "MATRIX:A,B\nLITERAL_FLOAT:0.0,1.0\n" + + "+(-(A,B),A)\n" + + "=>\n" + + "-(*2(A), B)"; + + RewriterRule rule = RewriterUtils.parseRule(ruleStr, ctx); + + LOG.info(DMLCodeGenerator.generateRuleValidationDML(rule, "test", ctx)); + + assert RewriterRuleCreator.validateRuleCorrectness(rule, ctx); + + assert RewriterRuleCreator.validateRuleApplicability(rule, ctx); + } + + @Test + public void testFused4() { + String ruleStr = "MATRIX:A,B,C\nLITERAL_FLOAT:0.0,1.0\n" + + "1-*(A, const(A, 0.0))\n" + + "=>\n" + + "const(A, 1.0)"; + + RewriterRule rule = RewriterUtils.parseRule(ruleStr, ctx); + + LOG.info(canonicalConverter.apply(rule.getStmt1()).toParsableString(ctx)); + LOG.info(canonicalConverter.apply(rule.getStmt2()).toParsableString(ctx)); + + //assert rule.getStmt1().match(RewriterStatement.MatcherContext.exactMatch(ctx, rule.getStmt2(), rule.getStmt1())); + + LOG.info(DMLCodeGenerator.generateRuleValidationDML(rule, "test", ctx)); + + assert RewriterRuleCreator.validateRuleCorrectness(rule, ctx); + + // As we have disabled operator fusion + assert !RewriterRuleCreator.validateRuleApplicability(rule, ctx, true, null); + } + + @Test + public void testFused5() { + String ruleStr = "MATRIX:A\n" + + "LITERAL_FLOAT:0.0\n" + + "\n" + + "sum(!=(0.0,A))\n" + + "=>\n" + + "_nnz(A)"; + + RewriterRule rule = RewriterUtils.parseRule(ruleStr, ctx); + + LOG.info(canonicalConverter.apply(rule.getStmt1()).toParsableString(ctx)); + LOG.info(canonicalConverter.apply(rule.getStmt2()).toParsableString(ctx)); + + //assert rule.getStmt1().match(RewriterStatement.MatcherContext.exactMatch(ctx, rule.getStmt2(), rule.getStmt1())); + + LOG.info(DMLCodeGenerator.generateRuleValidationDML(rule, "test", ctx)); + + assert RewriterRuleCreator.validateRuleCorrectness(rule, ctx); + + assert RewriterRuleCreator.validateRuleApplicability(rule, ctx, true, null); + } +} diff --git a/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/MinimalDifference.java b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/MinimalDifference.java new file mode 100644 index 00000000000..778ef8ac7d1 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/MinimalDifference.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.codegen.rewrite.functions; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.hops.rewriter.RewriterStatement; +import org.apache.sysds.hops.rewriter.utils.RewriterUtils; +import org.apache.sysds.hops.rewriter.RuleContext; +import org.apache.sysds.test.component.codegen.rewrite.RewriterTopologySortTests; +import org.junit.BeforeClass; +import org.junit.Test; + +import java.util.function.Function; + +public class MinimalDifference { + protected static final Log LOG = LogFactory.getLog(MinimalDifference.class.getName()); + + private static RuleContext ctx; + + @BeforeClass + public static void setup() { + ctx = RewriterUtils.buildDefaultContext(); + } + + @Test + public void test1() { + RewriterStatement stmt1 = RewriterUtils.parse("t(t(A))", ctx, "MATRIX:A"); + RewriterStatement stmt2 = RewriterUtils.parse("t(A)", ctx, "MATRIX:A"); + + RewriterStatement.MatcherContext mCtx = RewriterStatement.MatcherContext.findMinimalDifference(ctx, stmt2, stmt1); + stmt1.match(mCtx); + LOG.info("Minimal Difference: "); + LOG.info(mCtx.getFirstMismatch()._1.toParsableString(ctx)); + LOG.info(mCtx.getFirstMismatch()._2.toParsableString(ctx)); + } + + @Test + public void test2() { + RewriterStatement stmt1 = RewriterUtils.parse("-(A, t(+(A, A)))", ctx, "MATRIX:A"); + RewriterStatement stmt2 = RewriterUtils.parse("-(A, t(*(2, A)))", ctx, "MATRIX:A", "LITERAL_INT:2"); + + RewriterStatement.MatcherContext mCtx = RewriterStatement.MatcherContext.findMinimalDifference(ctx, stmt2, stmt1); + stmt1.match(mCtx); + LOG.info("Minimal Difference: "); + LOG.info(mCtx.getFirstMismatch()._1.toParsableString(ctx)); + LOG.info(mCtx.getFirstMismatch()._2.toParsableString(ctx)); + } +} diff --git a/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/RewriterSearchUtilsTest.java b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/RewriterSearchUtilsTest.java new file mode 100644 index 00000000000..4014fc85b5c --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/RewriterSearchUtilsTest.java @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.codegen.rewrite.functions; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.hops.rewriter.utils.RewriterSearchUtils; +import org.apache.sysds.hops.rewriter.RewriterStatement; +import org.apache.sysds.hops.rewriter.utils.RewriterUtils; +import org.apache.sysds.hops.rewriter.RuleContext; +import org.apache.sysds.test.component.codegen.rewrite.RewriterTopologySortTests; +import org.junit.BeforeClass; +import org.junit.Test; + +import java.util.Arrays; +import java.util.List; +import java.util.function.Function; + +public class RewriterSearchUtilsTest { + protected static final Log LOG = LogFactory.getLog(RewriterSearchUtilsTest.class.getName()); + + private static RuleContext ctx; + private static Function canonicalConverter; + + @BeforeClass + public static void setup() { + ctx = RewriterUtils.buildDefaultContext(); + canonicalConverter = RewriterUtils.buildCanonicalFormConverter(ctx, false); + } + + @Test + public void testDecode1() { + int l = 27; + int n = 5; + int[] digits = RewriterSearchUtils.fromBaseNNumber(l, n); + assert digits.length == 3 && digits[0] == 1 && digits[1] == 0 && digits[2] == 2; + } + + @Test + public void testDecode2() { + int l = 5; + int n = 5; + int[] digits = RewriterSearchUtils.fromBaseNNumber(l, n); + LOG.info(Arrays.toString(digits)); + assert digits.length == 2 && digits[0] == 1 && digits[1] == 0; + } + + @Test + public void testEncode1() { + int[] digits = new int[] { 1, 0, 2 }; + int[] digits2 = new int[] {4, 4, 4}; + int n = 5; + int l = RewriterSearchUtils.toBaseNNumber(digits, n); + int l2 = RewriterSearchUtils.toBaseNNumber(digits2, n); + LOG.info(l); + LOG.info(Integer.toBinaryString(l)); + LOG.info(l2); + LOG.info(Integer.toBinaryString(l2)); + assert l == 27; + } + + @Test + public void testRandomStatementGeneration() { + LOG.info(RewriterSearchUtils.getMaxSearchNumberForNumOps(3)); + int ctr = 0; + for (int i = 0; i < 20; i++) { + List ops = RewriterSearchUtils.decodeOrderedStatements(i); + //LOG.info("Idx: " + i); + //LOG.info(ops); + //LOG.info(RewriterAlphabetEncoder.buildAllPossibleDAGs(ops, ctx, false).size()); + for (RewriterStatement stmt : RewriterSearchUtils.buildAllPossibleDAGs(ops, ctx, true)) { + LOG.info("Base: " + stmt.toParsableString(ctx)); + for (RewriterStatement sstmt : RewriterSearchUtils.buildAssertionVariations(stmt, ctx)) { + canonicalConverter.apply(sstmt); + LOG.info(sstmt.toParsableString(ctx)); + //LOG.info("Raw: " + sstmt); + ctr++; + } + } + } + + LOG.info("Total DAGs: " + ctr); + } + + @Test + public void testRandomStatementGeneration2() { + int ctr = 0; + //for (int i = 0; i < 20; i++) { + List ops = List.of(RewriterSearchUtils.instructionAlphabet[3], RewriterSearchUtils.instructionAlphabet[16], RewriterSearchUtils.instructionAlphabet[6]); + //LOG.info("Idx: " + i); + //LOG.info(ops); + //LOG.info(RewriterAlphabetEncoder.buildAllPossibleDAGs(ops, ctx, false).size()); + for (RewriterStatement stmt : RewriterSearchUtils.buildAllPossibleDAGs(ops, ctx, true)) { + LOG.info("Base: " + stmt.toParsableString(ctx)); + for (RewriterStatement sstmt : RewriterSearchUtils.buildVariations(stmt, ctx)) { + canonicalConverter.apply(sstmt); + LOG.info(sstmt.toParsableString(ctx)); + //LOG.info("Raw: " + sstmt); + ctr++; + } + } + //} + + LOG.info("Total DAGs: " + ctr); + } + + @Test + public void test() { + RewriterStatement stmt = RewriterUtils.parse("+([](A, 1, 1, 1, 1), B)", ctx, "MATRIX:A,B", "LITERAL_INT:1"); + stmt = canonicalConverter.apply(stmt); + LOG.info(stmt.toParsableString(ctx)); + } + +} diff --git a/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/RuleCreationTests.java b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/RuleCreationTests.java new file mode 100644 index 00000000000..eabe5138258 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/RuleCreationTests.java @@ -0,0 +1,295 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.codegen.rewrite.functions; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.hops.rewriter.rule.RewriterRule; +import org.apache.sysds.hops.rewriter.rule.RewriterRuleBuilder; +import org.apache.sysds.hops.rewriter.rule.RewriterRuleCreator; +import org.apache.sysds.hops.rewriter.rule.RewriterRuleSet; +import org.apache.sysds.hops.rewriter.RewriterStatement; +import org.apache.sysds.hops.rewriter.utils.RewriterUtils; +import org.apache.sysds.hops.rewriter.RuleContext; +import org.apache.sysds.test.component.codegen.rewrite.RewriterTopologySortTests; +import org.junit.BeforeClass; +import org.junit.Test; + +import java.util.List; +import java.util.function.Function; + +public class RuleCreationTests { + protected static final Log LOG = LogFactory.getLog(RuleCreationTests.class.getName()); + + private static RuleContext ctx; + private static Function canonicalConverter; + + @BeforeClass + public static void setup() { + ctx = RewriterUtils.buildDefaultContext(); + canonicalConverter = RewriterUtils.buildCanonicalFormConverter(ctx, false); + } + + @Test + public void test1() { + RewriterStatement from = RewriterUtils.parse("t(%*%(t(U),V))", ctx, "MATRIX:U,V"); + RewriterStatement to = RewriterUtils.parse("%*%(t(U), V)", ctx, "MATRIX:U,V"); + RewriterStatement canonicalForm1 = canonicalConverter.apply(from); + RewriterStatement canonicalForm2 = canonicalConverter.apply(to); + + LOG.info("=========="); + LOG.info(canonicalForm1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(canonicalForm2.toParsableString(ctx, true)); + assert canonicalForm1.match(RewriterStatement.MatcherContext.exactMatch(ctx, canonicalForm2, canonicalForm1)); + + RewriterRule rule = RewriterRuleCreator.createRule(from, to, canonicalForm1, canonicalForm2, ctx); + LOG.info(rule); + } + + @Test + public void test2() { + RewriterStatement from = RewriterUtils.parse("t(t(A))", ctx, "MATRIX:A"); + RewriterStatement to = RewriterUtils.parse("A", ctx, "MATRIX:A"); + RewriterStatement canonicalForm1 = canonicalConverter.apply(from); + RewriterStatement canonicalForm2 = canonicalConverter.apply(to); + + LOG.info("=========="); + LOG.info(canonicalForm1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(canonicalForm2.toParsableString(ctx, true)); + assert canonicalForm1.match(RewriterStatement.MatcherContext.exactMatch(ctx, canonicalForm2, canonicalForm1)); + + RewriterRule rule = RewriterRuleCreator.createRule(from, to, canonicalForm1, canonicalForm2, ctx); + LOG.info(rule); + + RewriterRuleSet rs = new RewriterRuleSet(ctx, List.of(rule)); + + RewriterStatement testStmt = RewriterUtils.parse("t(t([](A, 1, ncol(A), 1, 1)))", ctx, "MATRIX:A", "LITERAL_INT:1"); + + RewriterRuleSet.ApplicableRule ar = rs.acceleratedFindFirst(testStmt); + + assert ar != null; + } + + @Test + public void validationTest1() { + RewriterRule rule = new RewriterRuleBuilder(ctx) + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A") + .parseGlobalVars("FLOAT:b") + .withParsedStatement("sum(/(A, b))") + .toParsedStatement("/(sum(A), b)") + .build(); + + assert RewriterRuleCreator.validateRuleCorrectnessAndGains(rule, ctx); + } + + @Test + public void validationTest2() { + RewriterRule rule = new RewriterRuleBuilder(ctx) + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A,B") + .parseGlobalVars("FLOAT:b") + .withParsedStatement("rowSums(colSums(%*%(A, B)))") + .toParsedStatement("%*%(colSums(A), rowSums(B))") + .build(); + + assert RewriterRuleCreator.validateRuleCorrectness(rule, ctx); + assert !RewriterRuleCreator.validateRuleApplicability(rule, ctx); + } + + @Test + public void validationTest3() { + RewriterRule rule = new RewriterRuleBuilder(ctx) + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A,B") + .withParsedStatement("cast.MATRIX(sum(rowVec(A)))") + .toParsedStatement("rowSums(rowVec(A))") + .build(); + + assert RewriterRuleCreator.validateRuleCorrectness(rule, ctx); + assert !RewriterRuleCreator.validateRuleApplicability(rule, ctx); + } + + @Test + public void test3() { + RewriterStatement from = RewriterUtils.parse("%*%(A,%*%(B,rowVec(C)))", ctx, "MATRIX:A,B,C"); + RewriterStatement to = RewriterUtils.parse("%*%(%*%(A,B),rowVec(C))", ctx, "MATRIX:A,B,C"); + RewriterStatement canonicalForm1 = canonicalConverter.apply(from); + RewriterStatement canonicalForm2 = canonicalConverter.apply(to); + + LOG.info("=========="); + LOG.info(canonicalForm1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(canonicalForm2.toParsableString(ctx, true)); + + assert canonicalForm1.match(RewriterStatement.MatcherContext.exactMatch(ctx, canonicalForm2, canonicalForm1)); + } + + @Test + public void test4() { + RewriterStatement from = RewriterUtils.parse("*(a,0.0)", ctx, "FLOAT:a", "LITERAL_FLOAT:0.0"); + RewriterStatement to = RewriterUtils.parse("0.0", ctx, "LITERAL_FLOAT:0.0"); + RewriterStatement canonicalForm1 = canonicalConverter.apply(from); + RewriterStatement canonicalForm2 = canonicalConverter.apply(to); + + LOG.info("=========="); + LOG.info(canonicalForm1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(canonicalForm2.toParsableString(ctx, true)); + + assert canonicalForm1.match(RewriterStatement.MatcherContext.exactMatch(ctx, canonicalForm2, canonicalForm1)); + + RewriterRule rule = RewriterRuleCreator.createRule(from, to, canonicalForm1, canonicalForm2, ctx); + LOG.info(rule); + + RewriterStatement from2 = RewriterUtils.parse("/(0.0,a)", ctx, "FLOAT:a", "LITERAL_FLOAT:0.0"); + RewriterStatement to2 = RewriterUtils.parse("0.0", ctx, "LITERAL_FLOAT:0.0"); + RewriterStatement canonicalForm12 = canonicalConverter.apply(from2); + RewriterStatement canonicalForm22 = canonicalConverter.apply(to2); + + LOG.info("=========="); + LOG.info(canonicalForm12.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(canonicalForm22.toParsableString(ctx, true)); + + assert canonicalForm12.match(RewriterStatement.MatcherContext.exactMatch(ctx, canonicalForm22, canonicalForm12)); + + RewriterRule rule2 = RewriterRuleCreator.createRule(from2, to2, canonicalForm12, canonicalForm22, ctx); + LOG.info(rule2); + + RewriterRuleSet rs = new RewriterRuleSet(ctx, List.of(rule, rule2)); + + RewriterStatement testStmt = RewriterUtils.parse("/(*(a,0.0), b)", ctx, "FLOAT:a,b", "LITERAL_FLOAT:0.0"); + + RewriterRuleSet.ApplicableRule ar = rs.acceleratedFindFirst(testStmt); + + assert ar != null; + + testStmt = ar.rule.apply(ar.matches.get(0), testStmt, true, false); + + LOG.info("HERE"); + LOG.info(testStmt.toParsableString(ctx)); + + ar = rs.acceleratedFindFirst(testStmt); + + assert ar != null; + + testStmt = ar.rule.apply(ar.matches.get(0), testStmt, true, false); + + LOG.info(testStmt); + } + + @Test + public void test5() { + RewriterRule rule1 = RewriterUtils.parseRule("FLOAT:a\nLITERAL_FLOAT:0.0\n*(a, 0.0)\n=>\n0.0", ctx); + RewriterRule rule2 = RewriterUtils.parseRule("FLOAT:a\nLITERAL_FLOAT:0.0\n/(0.0, a)\n=>\n0.0", ctx); + RewriterRule rule3 = RewriterUtils.parseRule("FLOAT:a,b\nLITERAL_FLOAT:0.0\n/(*(a, 0.0), b)\n=>\n0.0", ctx); + RewriterRuleCreator rc = new RewriterRuleCreator(ctx); + rc.registerRule(rule3, rule3.getStmt1().getCost(ctx), rule3.getStmt2().getCost(ctx), false, canonicalConverter); + rc.registerRule(rule2, rule2.getStmt1().getCost(ctx), rule2.getStmt2().getCost(ctx), false, canonicalConverter); + rc.registerRule(rule1, rule1.getStmt1().getCost(ctx), rule1.getStmt2().getCost(ctx), false, canonicalConverter); + + LOG.info(rc.getRuleSet().serialize()); + } + + @Test + public void test6() { + RewriterStatement from = RewriterUtils.parse("%*%(const(colVec(A),0.0),log_nz(B))", ctx, "MATRIX:A,B", "LITERAL_FLOAT:0.0"); + RewriterStatement to = RewriterUtils.parse("%*%(colVec(A),const(B,0.0))", ctx, "MATRIX:A,B", "LITERAL_FLOAT:0.0"); + RewriterStatement canonicalForm1 = canonicalConverter.apply(from); + RewriterStatement canonicalForm2 = canonicalConverter.apply(to); + + LOG.info("=========="); + LOG.info(canonicalForm1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(canonicalForm2.toParsableString(ctx, true)); + /*LOG.info(canonicalForm1.getChild(1, 1, 0)); + LOG.info(canonicalForm1.getChild(1, 1, 0).getNCol()); + LOG.info(canonicalForm1.getChild(1, 1, 0).getNRow()); + LOG.info(canonicalForm2.getChild(1, 1, 0)); + LOG.info(canonicalForm2.getChild(1, 1, 0).getNCol()); + LOG.info(canonicalForm2.getChild(1, 1, 0).getNRow());*/ + RewriterStatement.MatcherContext mCtx = RewriterStatement.MatcherContext.exactMatch(ctx, canonicalForm2, canonicalForm1); + if (!canonicalForm1.match(mCtx)) { + LOG.info(mCtx.getFirstMismatch()._1); + LOG.info(mCtx.getFirstMismatch()._2); + assert false; + } + + RewriterRule rule = RewriterRuleCreator.createRule(from, to, canonicalForm1, canonicalForm2, ctx); + LOG.info(rule); + } + + @Test + public void testTypeInvariantRuleRegistration() { + RewriterRule rule1 = RewriterUtils.parseRule("FLOAT:a\nLITERAL_FLOAT:0\n*(a,0)\n=>\na", ctx); + RewriterRule rule2 = RewriterUtils.parseRule("INT:a\nLITERAL_INT:0\n*(a,0)\n=>\na", ctx); + RewriterRuleCreator ruleCreator = new RewriterRuleCreator(ctx); + ruleCreator.registerRule(rule1, canonicalConverter, ctx); + + assert !ruleCreator.registerRule(rule2, canonicalConverter, ctx); + } + + @Test + public void testRuleElimination() { + String rs1 = + "MATRIX:tmp34827,tmp40318\n" + + "LITERAL_FLOAT:0.0\n" + + "\n" + + "+(%*%(tmp34827,tmp40318),0.0)\n" + + "=>\n" + + "%*%(tmp34827,tmp40318)"; + String rs2 = + "MATRIX:tmp34827,tmp40318\n" + + "LITERAL_FLOAT:0.0\n" + + "\n" + + "+(tmp34827,0.0)\n" + + "=>\n" + + "tmp34827"; + + RewriterRule rule1 = RewriterUtils.parseRule(rs1, ctx); + RewriterRule rule2 = RewriterUtils.parseRule(rs2, ctx); + + RewriterRuleCreator ruleCreator = new RewriterRuleCreator(ctx); + ruleCreator.registerRule(rule1, canonicalConverter, ctx); + + assert ruleCreator.registerRule(rule2, canonicalConverter, ctx); + LOG.info(ruleCreator.getRuleSet().getRules()); + assert ruleCreator.getRuleSet().getRules().size() == 1; + } + + @Test + public void testExpansiveRule() { + String rs1 = + "MATRIX:A,B\n" + + "LITERAL_FLOAT:0.0\n" + + "\n" + + "+*(A,0.0,B)\n" + + "=>\n" + + "+*(A,0.0,!=(B,B))"; + + RewriterRule rule1 = RewriterUtils.parseRule(rs1, ctx); + + RewriterRuleCreator ruleCreator = new RewriterRuleCreator(ctx); + assert !ruleCreator.registerRule(rule1, canonicalConverter, ctx); + } +} diff --git a/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/RuleSerializationTest.java b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/RuleSerializationTest.java new file mode 100644 index 00000000000..d6ae07120f2 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/RuleSerializationTest.java @@ -0,0 +1,168 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.codegen.rewrite.functions; + +import org.apache.commons.lang3.mutable.MutableObject; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.hops.rewriter.assertions.RewriterAssertions; +import org.apache.sysds.hops.rewriter.estimators.RewriterCostEstimator; +import org.apache.sysds.hops.rewriter.rule.RewriterRule; +import org.apache.sysds.hops.rewriter.rule.RewriterRuleCreator; +import org.apache.sysds.hops.rewriter.rule.RewriterRuleSet; +import org.apache.sysds.hops.rewriter.RewriterStatement; +import org.apache.sysds.hops.rewriter.utils.RewriterUtils; +import org.apache.sysds.hops.rewriter.RuleContext; +import org.junit.BeforeClass; +import org.junit.Test; +import scala.Tuple2; + +import java.util.List; +import java.util.Set; +import java.util.function.Function; + +public class RuleSerializationTest { + protected static final Log LOG = LogFactory.getLog(RuleSerializationTest.class.getName()); + + private static RuleContext ctx; + private static Function canonicalConverter; + + @BeforeClass + public static void setup() { + ctx = RewriterUtils.buildDefaultContext(); + canonicalConverter = RewriterUtils.buildCanonicalFormConverter(ctx, false); + } + + @Test + public void test1() { + String ruleStr1 = "MATRIX:A\nt(t(A))\n=>\nA"; + String ruleStr2 = "MATRIX:A\nrowSums(t(A))\n=>\nt(colSums(A))"; + RewriterRule rule1 = RewriterUtils.parseRule(ruleStr1, ctx); + RewriterRule rule2 = RewriterUtils.parseRule(ruleStr2, ctx); + + RewriterRuleSet ruleSet = new RewriterRuleSet(ctx, List.of(rule1, rule2)); + String serialized = ruleSet.serialize(); + + LOG.info(serialized); + + RewriterRuleSet newRuleSet = RewriterRuleSet.deserialize(serialized, ctx); + String newSerialized = newRuleSet.serialize(); + + LOG.info(newSerialized); + + assert serialized.equals(newSerialized); + } + + @Test + public void test2() { + RewriterStatement from = RewriterUtils.parse("t(t(U))", ctx, "MATRIX:U,V"); + RewriterStatement to = RewriterUtils.parse("U", ctx, "MATRIX:U,V"); + RewriterStatement canonicalForm1 = canonicalConverter.apply(from); + RewriterStatement canonicalForm2 = canonicalConverter.apply(to); + + LOG.info("=========="); + LOG.info(canonicalForm1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(canonicalForm2.toParsableString(ctx, true)); + assert canonicalForm1.match(RewriterStatement.MatcherContext.exactMatch(ctx, canonicalForm2, canonicalForm1)); + + RewriterRule rule = RewriterRuleCreator.createRule(from, to, canonicalForm1, canonicalForm2, ctx); + from = rule.getStmt1(); + to = rule.getStmt2(); + + MutableObject assertionRef = new MutableObject<>(); + long fullCost = RewriterCostEstimator.estimateCost(to, ctx); + long maxCost = RewriterCostEstimator.estimateCost(from, ctx, assertionRef); + Tuple2, Boolean> result = RewriterCostEstimator.determineSingleReferenceRequirement(from, RewriterCostEstimator.DEFAULT_COST_FN, assertionRef.getValue(), fullCost, maxCost, ctx); + + assert result._1.size() == 1 && result._2; + + rule.setAllowedMultiReferences(result._1, result._2); + + String serialized = rule.toParsableString(ctx); + + LOG.info("::RULE"); + LOG.info(serialized); + LOG.info(""); + + RewriterRule newRule = RewriterUtils.parseRule(serialized, ctx); + String newSerialized = newRule.toParsableString(ctx); + + LOG.info(newSerialized); + + assert serialized.equals(newSerialized); + } + + @Test + public void test3() { + RewriterStatement from = RewriterUtils.parse("sum(t(U))", ctx, "MATRIX:U,V"); + RewriterStatement to = RewriterUtils.parse("sum(U)", ctx, "MATRIX:U,V"); + RewriterStatement canonicalForm1 = canonicalConverter.apply(from); + RewriterStatement canonicalForm2 = canonicalConverter.apply(to); + + LOG.info("=========="); + LOG.info(canonicalForm1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(canonicalForm2.toParsableString(ctx, true)); + assert canonicalForm1.match(RewriterStatement.MatcherContext.exactMatch(ctx, canonicalForm2, canonicalForm1)); + + RewriterRule rule = RewriterRuleCreator.createRule(from, to, canonicalForm1, canonicalForm2, ctx); + from = rule.getStmt1(); + to = rule.getStmt2(); + + MutableObject assertionRef = new MutableObject<>(); + long fullCost = RewriterCostEstimator.estimateCost(to, ctx); + long maxCost = RewriterCostEstimator.estimateCost(from, ctx, assertionRef); + Tuple2, Boolean> result = RewriterCostEstimator.determineSingleReferenceRequirement(from, RewriterCostEstimator.DEFAULT_COST_FN, assertionRef.getValue(), fullCost, maxCost, ctx); + + assert result._1.size() == 1 && result._2; + + rule.setAllowedMultiReferences(result._1, result._2); + + String serialized = rule.toParsableString(ctx); + + LOG.info("::RULE"); + LOG.info(serialized); + LOG.info(""); + + RewriterRule newRule = RewriterUtils.parseRule(serialized, ctx); + String newSerialized = newRule.toParsableString(ctx); + + LOG.info(newSerialized); + + assert serialized.equals(newSerialized); + } + + @Test + public void test4() { + String ruleStr1 = "MATRIX:W1_rand,tmp29911\n" + + "FLOAT:tmp65095\n" + + "\n" + + "*(tmp65095,%*%(W1_rand,t(tmp29911)))\n" + + "=>\n" + + "{\n" + + "t(%*%(*(tmp65095,tmp29911),t(W1_rand)))\n" + + "%*%(*(tmp65095,W1_rand),t(tmp29911))\n" + + "*(tmp65095,t(%*%(tmp29911,t(W1_rand))))\n" + + "}"; + RewriterRule rule1 = RewriterUtils.parseRule(ruleStr1, ctx); + LOG.info(rule1.toString()); + } +} diff --git a/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/SparsityEstimationTest.java b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/SparsityEstimationTest.java new file mode 100644 index 00000000000..63af60ea230 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/SparsityEstimationTest.java @@ -0,0 +1,168 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.codegen.rewrite.functions; + +import org.apache.commons.lang3.mutable.MutableObject; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.hops.rewriter.rule.RewriterRule; +import org.apache.sysds.hops.rewriter.rule.RewriterRuleCreator; +import org.apache.sysds.hops.rewriter.RewriterStatement; +import org.apache.sysds.hops.rewriter.utils.RewriterUtils; +import org.apache.sysds.hops.rewriter.RuleContext; +import org.apache.sysds.hops.rewriter.assertions.RewriterAssertionUtils; +import org.apache.sysds.hops.rewriter.assertions.RewriterAssertions; +import org.apache.sysds.hops.rewriter.estimators.RewriterCostEstimator; +import org.apache.sysds.hops.rewriter.estimators.RewriterSparsityEstimator; +import org.apache.sysds.test.component.codegen.rewrite.RewriterTopologySortTests; +import org.junit.BeforeClass; +import org.junit.Test; +import scala.Tuple3; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Function; + +public class SparsityEstimationTest { + protected static final Log LOG = LogFactory.getLog(SparsityEstimationTest.class.getName()); + + private static RuleContext ctx; + private static Function canonicalConverter; + + @BeforeClass + public static void setup() { + ctx = RewriterUtils.buildDefaultContext(); + canonicalConverter = RewriterUtils.buildCanonicalFormConverter(ctx, false); + } + + @Test + public void test1() { + RewriterStatement stmt = RewriterUtils.parse("+*(A, 0.0, B)", ctx, "MATRIX:A,B", "LITERAL_FLOAT:0.0"); + LOG.info(RewriterSparsityEstimator.estimateNNZ(stmt, ctx).toParsableString(ctx)); + } + + @Test + public void test2() { + RewriterStatement stmt = RewriterUtils.parse("+*(A, a, B)", ctx, "MATRIX:A,B", "FLOAT:a"); + LOG.info(RewriterSparsityEstimator.estimateNNZ(stmt, ctx).toParsableString(ctx)); + } + + @Test + public void test3() { + RewriterStatement stmt = RewriterUtils.parse("%*%(A, -(B, A))", ctx, "MATRIX:A,B", "FLOAT:a"); + RewriterAssertionUtils.buildImplicitAssertions(stmt, stmt.getAssertions(ctx), ctx); + + Map estimates = RewriterSparsityEstimator.estimateAllNNZ(stmt, ctx); + + estimates.forEach((k, v) -> { + stmt.getAssertions(ctx).update(v); + LOG.info("K: " + k.toParsableString(ctx)); + LOG.info("NNZ: " + v.toParsableString(ctx)); + }); + + LOG.info("Rollup: " + RewriterSparsityEstimator.rollupSparsities(estimates.get(stmt), estimates, ctx).toParsableString(ctx)); + + Map nnzs = new HashMap<>(); + nnzs.put(stmt.getChild(0), 3000L); + nnzs.put(stmt.getChild(1, 0), 50000L); + + MutableObject assertionRef = new MutableObject<>(); + RewriterStatement costFunction = RewriterCostEstimator.getRawCostFunction(stmt, ctx, assertionRef, false); + costFunction = RewriterSparsityEstimator.rollupSparsities(costFunction, estimates, ctx); + + LOG.info(costFunction.toParsableString(ctx)); + + LOG.info("Dense cost: " + RewriterCostEstimator.estimateCost(stmt, ctx)); + LOG.info("Sparse cost: " + RewriterCostEstimator.computeCostFunction(costFunction, RewriterCostEstimator.DEFAULT_COST_FN, (el, tpl) -> nnzs.get(el.getChild(0)), assertionRef.getValue(), ctx)); + } + + @Test + public void test4() { + RewriterStatement from = RewriterUtils.parse("+(*(A, B), *(A, C))", ctx, "MATRIX:A,B,C"); + RewriterStatement to = RewriterUtils.parse("*(A, +(B, C))", ctx, "MATRIX:A,B,C"); + RewriterStatement canonicalForm1 = canonicalConverter.apply(from); + RewriterStatement canonicalForm2 = canonicalConverter.apply(to); + + LOG.info("=========="); + LOG.info(canonicalForm1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(canonicalForm2.toParsableString(ctx, true)); + assert canonicalForm1.match(RewriterStatement.MatcherContext.exactMatch(ctx, canonicalForm2, canonicalForm1)); + + RewriterRule rule = RewriterRuleCreator.createRule(from, to, canonicalForm1, canonicalForm2, ctx); + LOG.info(rule); + + RewriterAssertionUtils.buildImplicitAssertion(rule.getStmt1(), rule.getStmt1().getAssertions(ctx), rule.getStmt1(), ctx); + RewriterAssertionUtils.buildImplicitAssertion(rule.getStmt2(), rule.getStmt1().getAssertions(ctx), rule.getStmt2(), ctx); + + RewriterCostEstimator.compareCosts(rule.getStmt1(), rule.getStmt2(), rule.getStmt1().getAssertions(ctx), ctx, true, 5, false); + } + + @Test + public void test5() { + RewriterStatement from = RewriterUtils.parse("t(%*%(t(A), B))", ctx, "MATRIX:A,B,C"); + RewriterStatement to = RewriterUtils.parse("%*%(t(B), A)", ctx, "MATRIX:A,B,C"); + RewriterStatement canonicalForm1 = canonicalConverter.apply(from); + RewriterStatement canonicalForm2 = canonicalConverter.apply(to); + + LOG.info("=========="); + LOG.info(canonicalForm1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(canonicalForm2.toParsableString(ctx, true)); + assert canonicalForm1.match(RewriterStatement.MatcherContext.exactMatch(ctx, canonicalForm2, canonicalForm1)); + + RewriterRule rule = RewriterRuleCreator.createRule(from, to, canonicalForm1, canonicalForm2, ctx); + LOG.info(rule); + + RewriterAssertionUtils.buildImplicitAssertion(rule.getStmt1(), rule.getStmt1().getAssertions(ctx), rule.getStmt1(), ctx); + RewriterAssertionUtils.buildImplicitAssertion(rule.getStmt2(), rule.getStmt1().getAssertions(ctx), rule.getStmt2(), ctx); + //rule.getStmt2().unsafePutMeta("_assertions", rule.getStmt1().getAssertions(ctx)); + + List, Long, Long>> costs = RewriterCostEstimator.compareCosts(rule.getStmt1(), rule.getStmt2(), rule.getStmt1().getAssertions(ctx), ctx, false, 5, false); + LOG.info(costs); + LOG.info("Does sparsity have an impact on optimal expression? >> " + RewriterCostEstimator.doesHaveAnImpactOnOptimalExpression(costs, true, true, 0)); + } + + @Test + public void test6() { + RewriterStatement from = RewriterUtils.parse("t(+(A, B))", ctx, "MATRIX:A,B,C"); + RewriterStatement to = RewriterUtils.parse("+(t(A), t(B))", ctx, "MATRIX:A,B,C"); + RewriterStatement canonicalForm1 = canonicalConverter.apply(from); + RewriterStatement canonicalForm2 = canonicalConverter.apply(to); + + LOG.info("=========="); + LOG.info(canonicalForm1.toParsableString(ctx, true)); + LOG.info("=========="); + LOG.info(canonicalForm2.toParsableString(ctx, true)); + assert canonicalForm1.match(RewriterStatement.MatcherContext.exactMatch(ctx, canonicalForm2, canonicalForm1)); + + RewriterRule rule = RewriterRuleCreator.createRule(from, to, canonicalForm1, canonicalForm2, ctx); + LOG.info(rule); + + RewriterAssertionUtils.buildImplicitAssertion(rule.getStmt1(), rule.getStmt1().getAssertions(ctx), rule.getStmt1(), ctx); + RewriterAssertionUtils.buildImplicitAssertion(rule.getStmt2(), rule.getStmt1().getAssertions(ctx), rule.getStmt2(), ctx); + + List, Long, Long>> costs = RewriterCostEstimator.compareCosts(rule.getStmt1(), rule.getStmt2(), rule.getStmt1().getAssertions(ctx), ctx, false, 5, false); + LOG.info(costs); + LOG.info("Does sparsity have an impact on optimal expression? >> " + RewriterCostEstimator.doesHaveAnImpactOnOptimalExpression(costs, true, true, 0)); + LOG.info("Does anything have an impact on optimal expression? >> " + RewriterCostEstimator.doesHaveAnImpactOnOptimalExpression(costs, true, false, 0)); + } +} diff --git a/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/SubtreeGeneratorTest.java b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/SubtreeGeneratorTest.java new file mode 100644 index 00000000000..de672c09ae4 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/SubtreeGeneratorTest.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.codegen.rewrite.functions; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.hops.rewriter.RewriterStatement; +import org.apache.sysds.hops.rewriter.utils.RewriterSearchUtils; +import org.apache.sysds.hops.rewriter.utils.RewriterUtils; +import org.apache.sysds.hops.rewriter.RuleContext; +import org.apache.sysds.test.component.codegen.rewrite.RewriterTopologySortTests; +import org.junit.BeforeClass; +import org.junit.Test; + +import java.util.List; +import java.util.function.Function; + +public class SubtreeGeneratorTest { + protected static final Log LOG = LogFactory.getLog(SubtreeGeneratorTest.class.getName()); + + private static RuleContext ctx; + + @BeforeClass + public static void setup() { + ctx = RewriterUtils.buildDefaultContext(); + } + + @Test + public void test1() { + RewriterStatement stmt = RewriterUtils.parse("+(1, a)", ctx, "LITERAL_INT:1", "FLOAT:a"); + List subtrees = RewriterSearchUtils.generateSubtrees(stmt, ctx, 100); + + for (RewriterStatement sub : subtrees) { + LOG.info("=========="); + LOG.info(sub.toParsableString(ctx, true)); + } + + assert subtrees.size() == 2; + } + + @Test + public void test2() { + RewriterStatement stmt = RewriterUtils.parse("+(+(1, b), a)", ctx, "LITERAL_INT:1", "FLOAT:a,b"); + List subtrees = RewriterSearchUtils.generateSubtrees(stmt, ctx, 100); + + for (RewriterStatement sub : subtrees) { + LOG.info("=========="); + LOG.info(sub.toParsableString(ctx, true)); + } + + assert subtrees.size() == 3; + } + + @Test + public void test3() { + RewriterStatement stmt = RewriterUtils.parse("-(+(1.0,A),B)", ctx, "LITERAL_FLOAT:1.0", "MATRIX:A,B"); + List subtrees = RewriterSearchUtils.generateSubtrees(stmt, ctx, 100); + + for (RewriterStatement sub : subtrees) { + LOG.info("=========="); + LOG.info(sub.toParsableString(ctx, true)); + } + + assert subtrees.size() == 3; + } +} diff --git a/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/TestRuleSet.java b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/TestRuleSet.java new file mode 100644 index 00000000000..0826a81cf51 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/TestRuleSet.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.codegen.rewrite.functions; + +import org.apache.sysds.hops.rewriter.rule.RewriterRule; +import org.apache.sysds.hops.rewriter.rule.RewriterRuleBuilder; +import org.apache.sysds.hops.rewriter.rule.RewriterRuleSet; +import org.apache.sysds.hops.rewriter.RewriterStatement; +import org.apache.sysds.hops.rewriter.utils.RewriterUtils; +import org.apache.sysds.hops.rewriter.RuleContext; +import org.junit.BeforeClass; +import org.junit.Test; + +import java.util.List; + +public class TestRuleSet { + private static RuleContext ctx; + + @BeforeClass + public static void setup() { + ctx = RewriterUtils.buildDefaultContext(); + } + + @Test + public void test1() { + RewriterRule rule = new RewriterRuleBuilder(ctx) + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A,B") + .withParsedStatement("sum(%*%(A, t(B)))") + .toParsedStatement("sum(*(A, B))") + .build(); + + RewriterRuleSet rs = new RewriterRuleSet(ctx, List.of(rule)); + + RewriterStatement stmt = RewriterUtils.parse("sum(%*%(colVec(A), t(colVec(B))))", ctx, "MATRIX:A,B"); + + RewriterRuleSet.ApplicableRule ar = rs.acceleratedFindFirst(stmt); + + assert ar != null; + + stmt = ar.rule.apply(ar.matches.get(0), stmt, ar.forward, false); + } + + @Test + public void test2() { + RewriterRule rule = new RewriterRuleBuilder(ctx) + .setUnidirectional(true) + .parseGlobalVars("MATRIX:A,B") + .withParsedStatement("as.matrix(sum(colVec(A)))") + .toParsedStatement("rowSums(rowVec(A))") + .build(); + + RewriterRuleSet rs = new RewriterRuleSet(ctx, List.of(rule)); + + RewriterStatement stmt = RewriterUtils.parse("as.matrix(sum(t(rowVec(A))))", ctx, "MATRIX:A,B"); + + RewriterRuleSet.ApplicableRule ar = rs.acceleratedFindFirst(stmt); + + assert ar != null; + + stmt = ar.rule.apply(ar.matches.get(0), stmt, ar.forward, false); + } +} diff --git a/src/test/resources/rewriterframework/expressions.db b/src/test/resources/rewriterframework/expressions.db new file mode 100644 index 00000000000..8b5397f8d4a --- /dev/null +++ b/src/test/resources/rewriterframework/expressions.db @@ -0,0 +1,18610 @@ + +::STMT +MATRIX:prediction,target +LITERAL_FLOAT:1.0 +*(/(1.0,nrow(target)),-(prediction,target)) +::STMT +MATRIX:parsertemp75086 +LITERAL_FLOAT:32.0 +*(parsertemp75086,32.0) +::STMT +LITERAL_FLOAT:1.0 +cast.MATRIX(1.0) +::STMT +MATRIX:y_corr,parsertemp171089,parsertemp171084,parsertemp171095 +FLOAT:float98,float133,float340 +LITERAL_FLOAT:-1.0,1.0,2.0 +*(+(*(sqrt(parsertemp171084),-1.0),/(+(float340,parsertemp171089),+(float98,parsertemp171095))),-(1.0,*(2.0,>(y_corr,float133)))) +::STMT +MATRIX:parsertemp109934 +LITERAL_FLOAT:42.0 +*(parsertemp109934,42.0) +::STMT +MATRIX:posSampleVariances,negSampleVariances +FLOAT:int210,parsertemp31048,parsertemp31047,int867,int429,parsertemp31053,parsertemp31052,int196 +LITERAL_FLOAT:2.0 +/(^(+(/(posSampleVariances,int429),/(negSampleVariances,int210)),2.0),+(/(^(posSampleVariances,int196),*(parsertemp31047,parsertemp31048)),/(^(negSampleVariances,int867),*(parsertemp31052,parsertemp31053)))) +::STMT +MATRIX:X +FLOAT:int40 +LITERAL_FLOAT:1764.0 +sqrt(/(colSums(^(X,int40)),1764.0)) +::STMT +MATRIX:id +diag(diag(==(id,t(id)))) +::STMT +MATRIX:scale_X,z,beta +*(cast.FLOAT(diag(scale_X)),+(cast.FLOAT(beta),cast.FLOAT(z))) +::STMT +MATRIX:X +FLOAT:int459 +LITERAL_FLOAT:1.0,1.0E-6 +/(*(1.0E-6,sum(^(X,int459))),1.0) +::STMT +MATRIX:parsertemp18128,X,parsertemp18133 +FLOAT:int389 +LITERAL_FLOAT:0.0 +rowSums(*(>(%*%(X,parsertemp18128),0.0),t(^(int389,parsertemp18133)))) +::STMT +MATRIX:hubs +FLOAT:parsertemp30953 +LITERAL_FLOAT:2.0 +sum(^(-(/(hubs,parsertemp30953),hubs),2.0)) +::STMT +FLOAT:index +LITERAL_FLOAT:2.0 ++(*(index,2.0),2.0) +::STMT +MATRIX:R,dssp,dsep +FLOAT:4_eAvg +LITERAL_FLOAT:1.0 +-(/(/(+(R,dsep),+(R,dssp)),4_eAvg),1.0) +::STMT +MATRIX:r_LS,parsertemp170556,p_LS,parsertemp170552 +FLOAT:norm_r2_LS,lambda_LS ++(r_LS,*(/(norm_r2_LS,sum(parsertemp170556)),+(%*%(parsertemp170552,p_LS),*(lambda_LS,p_LS)))) +::STMT +MATRIX:X,RMSE +/(RMSE,-(max(X),min(X))) +::STMT +MATRIX:parsertemp472412,fP +FLOAT:max_values,parsertemp472284 +t(<=(parsertemp472412,/(^(parsertemp472284,max_values),ncol(fP)))) +::STMT +MATRIX:ts +FLOAT:q +cast.FLOAT(+(-(q,%*%(ts,ts)),%*%(ts,ts))) +::STMT +MATRIX:Y +FLOAT:x,X +LITERAL_FLOAT:1.0 +*(-(1.0,/(-(x,X),-(X,X))),cast.FLOAT(Y)) +::STMT +MATRIX:X +LITERAL_FLOAT:200.0,2.0 +^(/(t(colSums(X)),200.0),2.0) +::STMT +MATRIX:R +FLOAT:int37,int162 +INT:int981,parsertemp503363 +t(+(R,diag(rand(parsertemp503363,int981,int162,int37)))) +::STMT +MATRIX:y +FLOAT:beta,n +LITERAL_FLOAT:2.0 +/(sum(^(-(beta,y),2.0)),n) +::STMT +FLOAT:g +LITERAL_FLOAT:1.0,2.0 ++(*(g,2.0),1.0) +::STMT +MATRIX:sv,s,w,X,Y,out +FLOAT:step_sz +-(%*%(t(X),*(*(sv,out),Y)),+(w,*(step_sz,s))) +::STMT +MATRIX:parsertemp10744,parsertemp10743,W,H,parsertemp10739 +%*%(W,%*%(*(H,/(parsertemp10739,parsertemp10743)),t(*(H,parsertemp10744)))) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,100.0 +*(*(-(i,1.0),100.0),100.0) +::STMT +MATRIX:Y_counts,Y +/(colSums(Y),sum(Y_counts)) +::STMT +MATRIX:minD,D +/(<=(D,minD),rowSums(<=(D,minD))) +::STMT +MATRIX:parsertemp472317,parsertemp472315,ig +t(rev(*(&(parsertemp472315,parsertemp472317),ig))) +::STMT +FLOAT:factor_up,parsertemp195892 +LITERAL_FLOAT:2.0 +-(*(2.0,factor_up),parsertemp195892) +::STMT +MATRIX:dY,W,Y,sumW +LITERAL_FLOAT:300.0,0.9 +-(*(0.9,dY),*(300.0,-(*(Y,sumW),%*%(W,Y)))) +::STMT +FLOAT:o_init,o +LITERAL_FLOAT:-1.0,2.0 +*(-(*(2.0,o_init),*(2.0,o)),-1.0) +::STMT +MATRIX:parsertemp265709,parsertemp265718 +LITERAL_FLOAT:2.0 +*(2.0,cast.FLOAT(%*%(colSums(parsertemp265718),rowSums(parsertemp265709)))) +::STMT +MATRIX:parsertemp555766,parsertemp555762,target +FLOAT:int381,int17 +sum(-(*(*(target,int17),parsertemp555762),*(-(int381,target),parsertemp555766))) +::STMT +MATRIX:ssX_V,X,P_1K +rowSums(*(P_1K,%*%(X,ssX_V))) +::STMT +LITERAL_FLOAT:8000.0 +8000.0 +::STMT +MATRIX:p,q,lambda,parsertemp116061,parsertemp116062,scale_X,shift_X ++(+(*(scale_X,%*%(parsertemp116061,parsertemp116062)),*(cast.FLOAT(q),shift_X)),*(lambda,p)) +::STMT +MATRIX:ss_avg_res_Y,ss_avg_tot_Y +LITERAL_FLOAT:1.0 +-(1.0,/(ss_avg_res_Y,ss_avg_tot_Y)) +::STMT +MATRIX:Xd,Xu +LITERAL_FLOAT:1.0 +/(1.0,-(Xu,Xd)) +::STMT +MATRIX:Y_counts,parsertemp560521,ent2_vec +sqrt(sum(*(Y_counts,-(ent2_vec,parsertemp560521)))) +::STMT +MATRIX:X,H,parsertemp16755 +LITERAL_FLOAT:0.0,2.0 +*(>(%*%(X,t(H)),0.0),^(2.0,cast.FLOAT(parsertemp16755))) +::STMT +MATRIX:cdf_min_distances +FLOAT:float467,float609 +INT:int767,num_runs +colSums(<(cdf_min_distances,*(rand(int767,num_runs,float609,float467),cdf_min_distances))) +::STMT +MATRIX:WM,Y +/(sum(*(Y,WM)),sum(WM)) +::STMT +MATRIX:scale_lambda +LITERAL_FLOAT:0.0 +*(scale_lambda,0.0) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:2.0 +*(linear_terms,2.0) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,80656.0 ++(*(-(i,1.0),80656.0),1.0) +::STMT +MATRIX:P +LITERAL_FLOAT:4.0 +*(P,4.0) +::STMT +MATRIX:fdom,X,parsertemp1688 ++(X,-(t(parsertemp1688),fdom)) +::STMT +MATRIX:sample_maps,X +LITERAL_FLOAT:2.0 +rowSums(^(%*%(sample_maps,X),2.0)) +::STMT +MATRIX:p,lambda,X +*(p,+(%*%(t(X),%*%(X,p)),*(lambda,p))) +::STMT +MATRIX:Ileft,_funvar2706,_funvar2707 +FLOAT:numI +-(cast.FLOAT(_funvar2706),*(/(rowSums(Ileft),numI),_funvar2707)) +::STMT +MATRIX:parsertemp31029,parsertemp31031 +FLOAT:int871 +LITERAL_FLOAT:149.0,150.0 +/(/(-(colSums(parsertemp31029),*(int871,parsertemp31031)),149.0),150.0) +::STMT +MATRIX:parsertemp130418 +LITERAL_FLOAT:1.0,4.0 ++(*(max(parsertemp130418),4.0),1.0) +::STMT +MATRIX:X +FLOAT:s +LITERAL_FLOAT:0.0 +-(+(nrow(X),0.0),s) +::STMT +MATRIX:parsertemp283570,tpr,fpr,parsertemp283568 +LITERAL_FLOAT:2.0 ++(cast.FLOAT(*(tpr,fpr)),sum(/(*(parsertemp283568,parsertemp283570),2.0))) +::STMT +MATRIX:xs +FLOAT:256_x +LITERAL_FLOAT:1000.0 +-(1000.0,sum(>=(xs,256_x))) +::STMT +MATRIX:parsertemp72182 +LITERAL_FLOAT:8.0 +*(parsertemp72182,8.0) +::STMT +FLOAT:num_centroids +LITERAL_FLOAT:3.0 +*(3.0,num_centroids) +::STMT +MATRIX:scale_X,X,parsertemp274503,parsertemp274506,P_1K +%*%(diag(scale_X),%*%(t(X),-(*(P_1K,parsertemp274503),*(P_1K,parsertemp274506)))) +::STMT +MATRIX:X +FLOAT:n +LITERAL_FLOAT:-1.0 +*(/(t(colSums(X)),n),-1.0) +::STMT +MATRIX:parsertemp42202,F +FLOAT:parsertemp42203,W,int416,meanX +t(*(/(F,-(W,int416)),-(+(parsertemp42202,parsertemp42203),meanX))) +::STMT +LITERAL_FLOAT:6.0,2001.0 +*(6.0,2001.0) +::STMT +MATRIX:parsertemp410987,parsertemp410989,parsertemp410978,W,H,parsertemp410980 +sum(%*%(/(*(W,parsertemp410987),t(parsertemp410989)),/(*(H,parsertemp410978),t(parsertemp410980)))) +::STMT +MATRIX:linear_terms +FLOAT:var_power,link_power,int701 +LITERAL_FLOAT:1.0 +^(linear_terms,-(/(-(int701,var_power),link_power),1.0)) +::STMT +MATRIX:parsertemp149339,parsertemp149335 +FLOAT:int257,obj,parsertemp149332 +LITERAL_FLOAT:0.5 +-(obj,+(+(*(parsertemp149332,int257),sum(parsertemp149335)),*(0.5,sum(parsertemp149339)))) +::STMT +MATRIX:parsertemp107030 +LITERAL_FLOAT:7.0 +*(parsertemp107030,7.0) +::STMT +MATRIX:y_batch,parsertemp459782,parsertemp459784 +FLOAT:loss ++(loss,/(sum(*(parsertemp459782,parsertemp459784)),nrow(y_batch))) +::STMT +MATRIX:parsertemp73634 +LITERAL_FLOAT:16.0 +*(parsertemp73634,16.0) +::STMT +MATRIX:P,Y +LITERAL_FLOAT:1.0 +/(P,+(-(ncol(Y),1.0),1.0)) +::STMT +MATRIX:ss,se,e +LITERAL_FLOAT:1.0,40.0 +-(/(/(se,ss),/(sum(e),40.0)),1.0) +::STMT +FLOAT:parsertemp254715,parsertemp254694,2123_sq_root_d,pp_CG,float162 ++(float162,*(parsertemp254715,/(-(parsertemp254694,2123_sq_root_d),pp_CG))) +::STMT +MATRIX:_sbcvar78,parsertemp22266 +FLOAT:int513 +LITERAL_FLOAT:2.0,10000.0 +/(^(-(_sbcvar78,/(parsertemp22266,int513)),2.0),/(%*%(rowSums(_sbcvar78),colSums(_sbcvar78)),10000.0)) +::STMT +MATRIX:linear_terms +FLOAT:int750,var_power,link_power +LITERAL_FLOAT:2.0 +^(linear_terms,-(/(-(int750,var_power),link_power),2.0)) +::STMT +MATRIX:linear_terms +FLOAT:var_power +LITERAL_FLOAT:1.0,2.0 +/(^(linear_terms,2.0),-(1.0,var_power)) +::STMT +MATRIX:tmp +FLOAT:norm_r2_LS +/(cast.FLOAT(%*%(t(tmp),tmp)),norm_r2_LS) +::STMT +MATRIX:parsertemp556355 +LITERAL_FLOAT:0.125 +*(parsertemp556355,0.125) +::STMT +FLOAT:num_records +LITERAL_FLOAT:1920.0 +/(1920.0,num_records) +::STMT +MATRIX:prec_chol,mu +LITERAL_FLOAT:2.0 +rowSums(*(^(mu,2.0),^(prec_chol,2.0))) +::STMT +LITERAL_FLOAT:100.0 +100.0 +::STMT +LITERAL_FLOAT:105.0 +105.0 +::STMT +LITERAL_FLOAT:81.0 +81.0 +::STMT +LITERAL_FLOAT:80.0 +80.0 +::STMT +LITERAL_FLOAT:127.0 +127.0 +::STMT +LITERAL_FLOAT:120.0 +120.0 +::STMT +MATRIX:parsertemp409212,ctab +LITERAL_FLOAT:0.45 +>(/(parsertemp409212,rowSums(ctab)),0.45) +::STMT +MATRIX:X +LITERAL_FLOAT:0.0 +t(colSums(!=(X,0.0))) +::STMT +MATRIX:negSampleMeans,negSamples +LITERAL_FLOAT:2.0,1500.0 +-(colSums(^(negSamples,2.0)),*(1500.0,^(negSampleMeans,2.0))) +::STMT +MATRIX:totalE,parsertemp31933,X2,parsertemp31935 +t(%*%(t(totalE),==(%*%(X2,parsertemp31935),t(parsertemp31933)))) +::STMT +LITERAL_FLOAT:16.0 +16.0 +::STMT +MATRIX:p,V +%*%(t(V),%*%(V,p)) +::STMT +FLOAT:mu +LITERAL_FLOAT:0.999 +-(0.999,mu) +::STMT +LITERAL_FLOAT:15.0 +15.0 +::STMT +FLOAT:int302,int418 +LITERAL_FLOAT:1.0 ++(+(+(+(int302,int418),1.0),1.0),1.0) +::STMT +MATRIX:subspace_idx,parsertemp73653 +LITERAL_FLOAT:16.0,1.0 +<(-(subspace_idx,*(parsertemp73653,16.0)),1.0) +::STMT +MATRIX:samples_vs_runs_map,centroids,X_samples +LITERAL_FLOAT:2.0 +*(2.0,rowSums(*(X_samples,%*%(samples_vs_runs_map,centroids)))) +::STMT +LITERAL_FLOAT:33.0 +33.0 +::STMT +LITERAL_FLOAT:32.0 +32.0 +::STMT +MATRIX:parsertemp43626 +LITERAL_FLOAT:-1.0 +*(parsertemp43626,-1.0) +::STMT +MATRIX:rowSums_X_sq +FLOAT:D +LITERAL_FLOAT:0.5 +/(*(0.5,sqrt(D)),max(sqrt(rowSums_X_sq))) +::STMT +MATRIX:scale_X,shift_X +FLOAT:r +LITERAL_FLOAT:2.0 +sum(^(+(*(scale_X,r),*(r,shift_X)),2.0)) +::STMT +LITERAL_FLOAT:31.0 +31.0 +::STMT +LITERAL_FLOAT:30.0 +30.0 +::STMT +LITERAL_FLOAT:50.0 +50.0 +::STMT +MATRIX:parsertemp500607,parsertemp500610 +FLOAT:tau +*(tau,sum(abs(*(parsertemp500607,parsertemp500610)))) +::STMT +MATRIX:os,y,o +LITERAL_FLOAT:0.0 +exp(*(-(0.0,y),+(o,os))) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:2.0 +/(^(linear_terms,2.0),2.0) +::STMT +MATRIX:p_LS +FLOAT:norm_r2_LS,parsertemp170552,lambda_LS +*(/(norm_r2_LS,*(cast.FLOAT(p_LS),+(parsertemp170552,lambda_LS))),+(*(cast.FLOAT(parsertemp170552),cast.FLOAT(p_LS)),*(lambda_LS,cast.FLOAT(p_LS)))) +::STMT +MATRIX:b4,2362_2360_Y,W4 +t(+(%*%(W4,t(2362_2360_Y)),b4)) +::STMT +MATRIX:g_new,s,g_old +FLOAT:int686,int503 +*(/(sum(^(g_new,int503)),sum(^(g_old,int686))),s) +::STMT +LITERAL_FLOAT:42.0 +42.0 +::STMT +MATRIX:means,variances +FLOAT:beta +t(-(means,*(beta,variances))) +::STMT +MATRIX:WM,CVars,CFreqs +FLOAT:float270,parsertemp31268,int751,W +LITERAL_FLOAT:1.0 +/(sum(*(-(CFreqs,int751),CVars)),*(-(sum(WM),1.0),/(*(parsertemp31268,W),-(W,float270)))) +::STMT +LITERAL_FLOAT:45.0 +45.0 +::STMT +MATRIX:parsertemp439367,mean,parsertemp439305,weight,parsertemp439306,avgMean +FLOAT:int994 +LITERAL_FLOAT:1.0E-6 ++(+(-(/(parsertemp439367,parsertemp439306),*(int994,avgMean)),/(*(mean,parsertemp439305),t(weight))),1.0E-6) +::STMT +MATRIX:U,X,parsertemp382851 +FLOAT:int910 +t(%*%(t(U),*(!=(X,int910),-(parsertemp382851,X)))) +::STMT +MATRIX:prec_chol,X +LITERAL_FLOAT:2.0 +%*%(rowSums(^(X,2.0)),t(^(prec_chol,2.0))) +::STMT +MATRIX:s,w +FLOAT:lambda,step_sz +*(lambda,+(w,*(step_sz,s))) +::STMT +LITERAL_FLOAT:1000.0 +1000.0 +::STMT +MATRIX:U,V,X +LITERAL_FLOAT:2.0 +^(-(%*%(U,t(V)),X),2.0) +::STMT +MATRIX:S,parsertemp42207 +LITERAL_FLOAT:1.0,2.0 ++(-(parsertemp42207,/(t(S),2.0)),/(1.0,2.0)) +::STMT +MATRIX:parsertemp10744,V,W,H,parsertemp10748 +FLOAT:Eps +/(%*%(V,t(*(H,parsertemp10744))),+(%*%(W,%*%(H,parsertemp10748)),Eps)) +::STMT +MATRIX:ss +LITERAL_FLOAT:0.050000000000000044,1.0,40.0 +*(0.050000000000000044,-(/(40.0,ss),1.0)) +::STMT +MATRIX:W,H,X,parsertemp410997 +-(sum(%*%(W,H)),sum(*(X,parsertemp410997))) +::STMT +MATRIX:mean,parsertemp437225,X,parsertemp437631,weight,parsertemp437222 ++(/(-(%*%(parsertemp437222,X),%*%(parsertemp437225,mean)),sum(weight)),diag(parsertemp437631)) +::STMT +MATRIX:Q3,X,IQR +LITERAL_FLOAT:1.5 +>(X,+(Q3,*(1.5,IQR))) +::STMT +MATRIX:Q1,X,IQR +LITERAL_FLOAT:1.5 +<(X,-(Q1,*(1.5,IQR))) +::STMT +LITERAL_FLOAT:0.0 +INT:int502,int777 +t(rand(int502,int777,0.0,0.0)) +::STMT +LITERAL_FLOAT:0.5000000001 +0.5000000001 +::STMT +MATRIX:C,Xm,parsertemp265701,XtZ +FLOAT:ZtZ_sum +colSums(%*%(%*%(Xm,%*%(C,parsertemp265701)),t(/(XtZ,ZtZ_sum)))) +::STMT +LITERAL_FLOAT:3136.0 +3136.0 +::STMT +MATRIX:d,parsertemp410052,d_r_rev +*(d,t(colSums(*(parsertemp410052,d_r_rev)))) +::STMT +MATRIX:subspace_variance,parsertemp72203 +FLOAT:int677 +LITERAL_FLOAT:1.0 +%*%(t(subspace_variance),diag(/(1.0,<(parsertemp72203,int677)))) +::STMT +FLOAT:parsertemp31330 +LITERAL_FLOAT:1.0,10000.0 +/(*(parsertemp31330,10000.0),-(10000.0,1.0)) +::STMT +MATRIX:ubScores +FLOAT:minsc +LITERAL_FLOAT:0.0 +&(>(ubScores,minsc),>(ubScores,0.0)) +::STMT +MATRIX:parsertemp31105,parsertemp31107 +LITERAL_FLOAT:7.996E9,1999.0,2.0 +/(^(/(-(parsertemp31105,parsertemp31107),1999.0),2.0),7.996E9) +::STMT +LITERAL_FLOAT:254.0 +254.0 +::STMT +LITERAL_FLOAT:255.0 +255.0 +::STMT +LITERAL_FLOAT:300.0 +300.0 +::STMT +MATRIX:p_LS,tmp +FLOAT:norm_r2_LS +/(norm_r2_LS,*(cast.FLOAT(p_LS),cast.FLOAT(tmp))) +::STMT +MATRIX:valueCount,Y +/(t(valueCount),nrow(Y)) +::STMT +MATRIX:selCols2 +sum(!(selCols2)) +::STMT +MATRIX:lambda,B,Grad +LITERAL_FLOAT:2.0 +^(+(Grad,*(lambda,B)),2.0) +::STMT +MATRIX:R,dsep,dssm +/(+(R,dsep),-(R,dssm)) +::STMT +MATRIX:2940_mask,2939_out +LITERAL_FLOAT:0.35 +/(*(2939_out,2940_mask),0.35) +::STMT +MATRIX:r,alpha,Hd +LITERAL_FLOAT:2.0 +^(-(r,*(cast.FLOAT(alpha),Hd)),2.0) +::STMT +MATRIX:Hdiff,betamax,beta +FLOAT:INF,int762,int239 +!=(+(*(>=(Hdiff,int762),betamax),*(<(Hdiff,int239),beta)),INF) +::STMT +MATRIX:out2,parsertemp146940,184_dtemp,W2,W3 +LITERAL_FLOAT:0.0 +%*%(*(>(out2,0.0),%*%(-(184_dtemp,parsertemp146940),t(W3))),t(W2)) +::STMT +MATRIX:W4_rand +LITERAL_FLOAT:0.0873148795050037 +*(0.0873148795050037,W4_rand) +::STMT +MATRIX:w,out +LITERAL_FLOAT:1.0,50.0,0.5 +*(1.0,+(*(0.5,cast.FLOAT(out)),*(50.0,cast.FLOAT(w)))) +::STMT +MATRIX:parsertemp460644 +LITERAL_FLOAT:0.0625 +*(parsertemp460644,0.0625) +::STMT +MATRIX:r,w +FLOAT:tau +LITERAL_FLOAT:0.5 ++(*(0.5,cast.FLOAT(%*%(r,r))),*(tau,sum(abs(w)))) +::STMT +LITERAL_FLOAT:500.0 +500.0 +::STMT +MATRIX:parsertemp31112,parsertemp31114 +LITERAL_FLOAT:1499.0,2.0,3.37275E9 +/(^(/(-(parsertemp31112,parsertemp31114),1499.0),2.0),3.37275E9) +::STMT +MATRIX:S,parsertemp42207 +LITERAL_FLOAT:2.0,0.5 ++(-(parsertemp42207,/(t(S),2.0)),0.5) +::STMT +MATRIX:out,parsertemp2798 +FLOAT:int695,int909,int977,int948 +sum(*(*(>(out,int948),-(int695,parsertemp2798)),*(>(out,int909),-(int977,parsertemp2798)))) +::STMT +MATRIX:parsertemp389760,permut +LITERAL_FLOAT:1.0 +%*%(t(permut),/(-(exp(parsertemp389760),1.0),+(exp(parsertemp389760),1.0))) +::STMT +MATRIX:parsertemp477715,Y,K +FLOAT:X +LITERAL_FLOAT:1.0 +*(-(*(cast.FLOAT(K),-(X,X)),-(cast.FLOAT(Y),cast.FLOAT(Y))),-(1.0,/(cast.FLOAT(parsertemp477715),-(X,X)))) +::STMT +MATRIX:parsertemp222703 +LITERAL_FLOAT:0.0 +==(t(parsertemp222703),0.0) +::STMT +MATRIX:d,parsertemp43998 +FLOAT:int973 +cast.FLOAT(%*%(t(d),+(d,*(int973,parsertemp43998)))) +::STMT +MATRIX:q,r +FLOAT:p,a,norm_r2 +%*%(t(+(r,*(a,q))),+(r,*(/(norm_r2,p),+(q,q)))) +::STMT +MATRIX:m_err +cast.FLOAT(rowSums(colSums(m_err))) +::STMT +FLOAT:m2X +LITERAL_FLOAT:1.000010000100001 +sqrt(*(m2X,1.000010000100001)) +::STMT +MATRIX:g_reg,p_CG +FLOAT:parsertemp170148,int960,q_CG,int952,z,parsertemp170170,pq_CG +*(+(+(*(parsertemp170170,pq_CG),*(z,q_CG)),sum(*(g_reg,p_CG))),/(-(*(z,int952),sqrt(parsertemp170148)),sum(^(p_CG,int960)))) +::STMT +MATRIX:sts,d,parsertemp44021,parsertemp44023 +FLOAT:delta2 +sqrt(+(*(%*%(parsertemp44021,d),%*%(parsertemp44021,d)),*(%*%(parsertemp44023,d),-(delta2,sts)))) +::STMT +FLOAT:offset_x +LITERAL_FLOAT:1.0 +-(1.0,round(offset_x)) +::STMT +MATRIX:parsertemp220853,parsertemp220854,betamax,Hneg,Hpos,beta +FLOAT:INF,logU +LITERAL_FLOAT:0.0 +*(>=(-(+(parsertemp220853,parsertemp220854),logU),0.0),!=(+(*(Hpos,betamax),*(Hneg,beta)),INF)) +::STMT +MATRIX:y_prob,ones_ctg +LITERAL_FLOAT:1.0 +%*%(y_prob,-(1.0,diag(ones_ctg))) +::STMT +MATRIX:A +LITERAL_FLOAT:1.0 ++(+(/(1.0,cast.FLOAT(A)),/(1.0,cast.FLOAT(A))),/(1.0,cast.FLOAT(A))) +::STMT +LITERAL_FLOAT:-0.001 +-0.001 +::STMT +LITERAL_FLOAT:0.001 +0.001 +::STMT +MATRIX:f,I +*(sum(I),max(f)) +::STMT +MATRIX:parsertemp379668 +FLOAT:int826 +LITERAL_FLOAT:1.0,-1.0 +*(sum(-(>=(parsertemp379668,int826),1.0)),-1.0) +::STMT +FLOAT:int713,int28 +LITERAL_FLOAT:0.0 +INT:parsertemp557199,int576 +==(diag(rand(parsertemp557199,int576,int713,int28)),0.0) +::STMT +MATRIX:parsertemp149283,parsertemp149281 +FLOAT:delta2,s2 +LITERAL_FLOAT:2.0 +sqrt(+(^(sum(parsertemp149281),2.0),*(sum(parsertemp149283),-(delta2,s2)))) +::STMT +MATRIX:s,parsertemp44016,d +FLOAT:parsertemp44015 +cast.FLOAT(%*%(t(-(s,parsertemp44016)),-(s,*(parsertemp44015,d)))) +::STMT +MATRIX:b,X +exp(%*%(X,b)) +::STMT +FLOAT:parsertemp41020,m2,int106 +LITERAL_FLOAT:2003.0 +/(sqrt(*(/(int106,parsertemp41020),m2)),sqrt(2003.0)) +::STMT +MATRIX:parsertemp497802,Y +LITERAL_FLOAT:0.0 +*(Y,!=(parsertemp497802,0.0)) +::STMT +MATRIX:p,lambda,scale_X,shift_X +FLOAT:q,norm_r2 +*(/(norm_r2,sum(*(p,q))),+(+(*(scale_X,q),*(q,shift_X)),*(lambda,p))) +::STMT +FLOAT:sample_block_size +LITERAL_FLOAT:1.0,3.0 ++(*(sample_block_size,3.0),1.0) +::STMT +MATRIX:2697_b,parsertemp459149,parsertemp459147 +rowSums(exp(-(+(parsertemp459147,2697_b),parsertemp459149))) +::STMT +MATRIX:output_values,initial_prediction +FLOAT:learning_rate ++(initial_prediction,*(learning_rate,sum(output_values))) +::STMT +FLOAT:m2,float276,int815 +LITERAL_FLOAT:2000.0 +sqrt(*(/(2000.0,-(int815,float276)),m2)) +::STMT +MATRIX:probs,out3,y_batch,184_scores,parsertemp146933 +FLOAT:float988,int950,183_N,int776 +LITERAL_FLOAT:1.0 +*(*(*(/(int950,183_N),-(int776,y_batch)),/(1.0,+(probs,float988))),/(exp(-(out3,parsertemp146933)),rowSums(exp(184_scores)))) +::STMT +MATRIX:parsertemp220853,parsertemp220854,beta +LITERAL_FLOAT:0.0,3.4011973816621555 +*(>=(-(+(parsertemp220853,parsertemp220854),3.4011973816621555),0.0),beta) +::STMT +LITERAL_FLOAT:0.002 +0.002 +::STMT +MATRIX:parsertemp382680,parsertemp382677 +FLOAT:parsertemp382674 +LITERAL_FLOAT:0.5,5.0E-7 ++(*(0.5,parsertemp382674),*(5.0E-7,+(sum(parsertemp382677),sum(parsertemp382680)))) +::STMT +MATRIX:p_LS,X +FLOAT:lambda_LS ++(%*%(%*%(t(X),X),p_LS),*(lambda_LS,p_LS)) +::STMT +LITERAL_FLOAT:8001.0 +8001.0 +::STMT +MATRIX:parsertemp396419,W4_rand +FLOAT:int485,int992 +LITERAL_FLOAT:0.08681986202598489 +%*%(*(0.08681986202598489,W4_rand),t(/(-(parsertemp396419,int992),+(parsertemp396419,int485)))) +::STMT +MATRIX:Y_prob,parsertemp171377,Y,parsertemp171380 +FLOAT:int900 +LITERAL_FLOAT:3.141592653589793 +/(*(rowSums(Y),-(*(Y,Y_prob),*(Y,Y_prob))),*(*(*(parsertemp171377,Y_prob),Y_prob),*(+(int900,parsertemp171380),3.141592653589793))) +::STMT +MATRIX:parsertemp220853,parsertemp220854 +FLOAT:logU +LITERAL_FLOAT:0.0,2.0 +*(2.0,>=(-(+(parsertemp220853,parsertemp220854),logU),0.0)) +::STMT +FLOAT:parsertemp500918,offset_x +-(parsertemp500918,round(offset_x)) +::STMT +FLOAT:var_power,link_power +LITERAL_FLOAT:1.0 +/(-(1.0,var_power),link_power) +::STMT +FLOAT:index +LITERAL_FLOAT:1.0,2.0 ++(+(*(index,2.0),2.0),1.0) +::STMT +MATRIX:Yhat_prime,H3_prime,E,W4 +colSums(*(H3_prime,%*%(*(E,Yhat_prime),W4))) +::STMT +MATRIX:classFeatureCounts +LITERAL_FLOAT:750.0 ++(rowSums(classFeatureCounts),750.0) +::STMT +MATRIX:LT,parsertemp149320,parsertemp150469 +rowSums(exp(-(LT,%*%(parsertemp149320,parsertemp150469)))) +::STMT +MATRIX:X,parsertemp429911 +FLOAT:int813,int704 +LITERAL_FLOAT:300.0,2.0 +-(t(colSums(^(X,int813))),*(300.0,^(/(parsertemp429911,int704),2.0))) +::STMT +MATRIX:y_hat,X_adapted +FLOAT:k,parsertemp176418 +>(X_adapted,+(sqrt(parsertemp176418),*(k,y_hat))) +::STMT +MATRIX:y_hat,X_adapted +FLOAT:parsertemp176421,k +<(X_adapted,-(sqrt(parsertemp176421),*(k,y_hat))) +::STMT +FLOAT:int630,i_iter,interval,i_process_item +LITERAL_FLOAT:1.0 +-(i_process_item,+(*(-(i_iter,int630),interval),1.0)) +::STMT +MATRIX:termination_bitmap,final_wcss_successful +LITERAL_FLOAT:1.0,10.0 +*(+(*(10.0,max(final_wcss_successful)),10.0),-(1.0,termination_bitmap)) +::STMT +MATRIX:sig +FLOAT:q,mu +LITERAL_FLOAT:4.0 +/(-(q,*(4.0,*(mu,mu))),*(4.0,*(cast.FLOAT(sig),cast.FLOAT(sig)))) +::STMT +MATRIX:negSampleMeans,negSamples +FLOAT:int169,int697 +LITERAL_FLOAT:1499.0,1500.0 +/(-(colSums(^(negSamples,int169)),*(1500.0,^(negSampleMeans,int697))),1499.0) +::STMT +MATRIX:X +LITERAL_FLOAT:300.0,2.0 +^(/(t(colSums(X)),300.0),2.0) +::STMT +FLOAT:log_l_change +LITERAL_FLOAT:2.0 +*(2.0,abs(log_l_change)) +::STMT +MATRIX:parsertemp132003,parsertemp132023,leftIdx +%*%(parsertemp132023,%*%(t(parsertemp132003),leftIdx)) +::STMT +MATRIX:d,X,logisticD +FLOAT:C +*(C,%*%(t(X),*(logisticD,%*%(X,d)))) +::STMT +MATRIX:parsertemp222700,X,parsertemp222696,parsertemp222693 +LITERAL_FLOAT:-2.0 +<=(+(*(-2.0,%*%(X,parsertemp222693)),t(rowSums(parsertemp222696))),parsertemp222700) +::STMT +MATRIX:X +FLOAT:int617 +LITERAL_FLOAT:0.0 +!=(t(colSums(!=(X,int617))),0.0) +::STMT +MATRIX:ss +FLOAT:alpha +LITERAL_FLOAT:1.0,20.0 +*(-(1.0,alpha),-(/(20.0,ss),1.0)) +::STMT +MATRIX:means,Y +colSums(-(Y,means)) +::STMT +FLOAT:link_power +LITERAL_FLOAT:-2.0 +*(-2.0,link_power) +::STMT +FLOAT:pow_two +LITERAL_FLOAT:2.0 +*(2.0,pow_two) +::STMT +MATRIX:p_CG +FLOAT:parsertemp285739,parsertemp285737,pp_CG +LITERAL_FLOAT:-1.0 +/(-(*(cast.FLOAT(p_CG),-1.0),sqrt(-(parsertemp285737,parsertemp285739))),pp_CG) +::STMT +MATRIX:p,q,V,parsertemp1939 +FLOAT:norm_r2,eps +*(/(norm_r2,cast.FLOAT(%*%(parsertemp1939,q))),+(%*%(t(V),%*%(V,p)),*(eps,p))) +::STMT +MATRIX:linear_terms +FLOAT:link_power,int370 +LITERAL_FLOAT:0.0,1.0 +-(^(+(linear_terms,==(linear_terms,int370)),/(1.0,link_power)),==(linear_terms,0.0)) +::STMT +MATRIX:w,X,y +LITERAL_FLOAT:-1.0 +exp(*(*(y,-1.0),%*%(X,w))) +::STMT +LITERAL_FLOAT:2.0,1500.0 +^(1500.0,2.0) +::STMT +MATRIX:parsertemp132494,rightHist,outBucket +%*%(==(outBucket,t(parsertemp132494)),rightHist) +::STMT +MATRIX:parsertemp174552 +LITERAL_FLOAT:0.0 +abs(==(parsertemp174552,0.0)) +::STMT +MATRIX:y_corr +LITERAL_FLOAT:1.0,2.0,0.5 +-(1.0,*(2.0,>(y_corr,0.5))) +::STMT +FLOAT:ytest,yhat,int56,parsertemp454076,int163 +LITERAL_FLOAT:1.0,2.0 +-(1.0,/(^(-(ytest,yhat),2.0),-(^(ytest,int56),*(int163,parsertemp454076)))) +::STMT +MATRIX:Q1,IQR +FLOAT:k +-(Q1,*(k,IQR)) +::STMT +MATRIX:xs +FLOAT:256_x +LITERAL_FLOAT:1.0,1000.0 ++(-(1000.0,sum(>=(xs,256_x))),1.0) +::STMT +MATRIX:w,parsertemp2794 +LITERAL_FLOAT:2.0,0.5 +*(0.5,sum(^(+(w,parsertemp2794),2.0))) +::STMT +MATRIX:linear_terms,Y +FLOAT:int668 +LITERAL_FLOAT:0.0,1.0 ++(*(linear_terms,-(1.0,==(Y,int668))),==(Y,0.0)) +::STMT +MATRIX:parsertemp410080,d_r_rev,parsertemp410079,parsertemp410090 +LITERAL_FLOAT:-1.0 ++(*(cast.FLOAT(%*%(parsertemp410079,parsertemp410080)),-1.0),sum(*(d_r_rev,parsertemp410090))) +::STMT +MATRIX:parsertemp132003,parsertemp132023,leftIdx +LITERAL_FLOAT:0.0 +>(%*%(parsertemp132023,%*%(t(parsertemp132003),leftIdx)),0.0) +::STMT +MATRIX:parsertemp410987,parsertemp410979,W,parsertemp410981 +/(*(W,parsertemp410987),t(rowSums(/(parsertemp410979,parsertemp410981)))) +::STMT +LITERAL_FLOAT:1.0,4.0 ++(4.0,1.0) +::STMT +MATRIX:D,ZERODIAG,parsertemp220891 +FLOAT:int374,int42 +LITERAL_FLOAT:1.0 +/(*(/(1.0,+(D,int42)),ZERODIAG),sum(*(/(int374,parsertemp220891),ZERODIAG))) +::STMT +MATRIX:184_probs,183_dpred,parsertemp146939,outr2 +LITERAL_FLOAT:2.0 +^(%*%(t(outr2),-(*(183_dpred,184_probs),*(184_probs,parsertemp146939))),2.0) +::STMT +MATRIX:q,r +FLOAT:alpha +sum(*(+(r,*(alpha,q)),+(r,*(alpha,q)))) +::STMT +MATRIX:vb1,parsertemp460691 +FLOAT:lr,mu +-(*(mu,vb1),*(lr,rowSums(parsertemp460691))) +::STMT +FLOAT:obj,obj_new,gs +-(-(obj_new,obj),gs) +::STMT +MATRIX:parsertemp76118 +LITERAL_FLOAT:0.5,4460.0 ++(0.5,/(parsertemp76118,4460.0)) +::STMT +MATRIX:r,parsertemp44050 +sqrt(sum(*(-(r,parsertemp44050),-(r,parsertemp44050)))) +::STMT +FLOAT:padh,Hin +LITERAL_FLOAT:2.0 ++(Hin,*(2.0,padh)) +::STMT +FLOAT:numRows +LITERAL_FLOAT:0.05 +*(0.05,numRows) +::STMT +MATRIX:grad +LITERAL_FLOAT:-1.0 +*(*(grad,-1.0),*(grad,-1.0)) +::STMT +MATRIX:xs +LITERAL_FLOAT:10.0,4.5 +-(10.0,sum(>=(xs,4.5))) +::STMT +MATRIX:parsertemp555766,parsertemp555762,target +LITERAL_FLOAT:-1.0,1.0 +-(*(*(target,-1.0),parsertemp555762),*(-(1.0,target),parsertemp555766)) +::STMT +FLOAT:191_beta2,191_t,191_lr +LITERAL_FLOAT:1.0 +*(191_lr,sqrt(-(1.0,^(191_beta2,191_t)))) +::STMT +MATRIX:w,X,y +sum(*(-(%*%(X,w),y),-(%*%(X,w),y))) +::STMT +LITERAL_FLOAT:0.08720414403938946 +0.08720414403938946 +::STMT +MATRIX:simplex +-(rowSums(simplex),simplex) +::STMT +MATRIX:scale_lambda +LITERAL_FLOAT:1.0E-7 +diag(*(scale_lambda,1.0E-7)) +::STMT +MATRIX:g +FLOAT:lambda +LITERAL_FLOAT:2.0 +sqrt(sum(^(+(g,lambda),2.0))) +::STMT +MATRIX:X,y +FLOAT:int442 +LITERAL_FLOAT:0.0 +INT:m,int706 +-(%*%(X,rand(m,int706,0.0,int442)),y) +::STMT +MATRIX:parsertemp77570 +LITERAL_FLOAT:0.5,2358.0 ++(0.5,/(parsertemp77570,2358.0)) +::STMT +MATRIX:p,q,r,lambda +FLOAT:norm_r2 ++(r,*(/(norm_r2,cast.FLOAT(p)),+(q,*(lambda,p)))) +::STMT +MATRIX:feature +LITERAL_FLOAT:1.0 +-(1.0,min(feature)) +::STMT +MATRIX:n_event_stratum,n_risk_stratum,n_risk +LITERAL_FLOAT:2.0 +*(^(n_risk_stratum,2.0),*(n_risk,n_event_stratum)) +::STMT +MATRIX:Y +FLOAT:check_max,check_min +LITERAL_FLOAT:2.0 +*(/(2.0,-(check_max,check_min)),Y) +::STMT +MATRIX:U,X,parsertemp382669 +LITERAL_FLOAT:0.0,2.0 +*(!=(X,0.0),^(-(%*%(U,parsertemp382669),X),2.0)) +::STMT +FLOAT:idx +LITERAL_FLOAT:256.0 +-(256.0,idx) +::STMT +MATRIX:paramLens,parsertemp387457 +rev(/(parsertemp387457,rev(paramLens))) +::STMT +MATRIX:p_CG,z +FLOAT:trust_delta_sq +-(*(cast.FLOAT(%*%(p_CG,z)),cast.FLOAT(%*%(p_CG,z))),*(cast.FLOAT(%*%(p_CG,p_CG)),-(cast.FLOAT(z),trust_delta_sq))) +::STMT +MATRIX:X,H,parsertemp18133 +LITERAL_FLOAT:0.0,2.0 +*(>(%*%(X,t(H)),0.0),t(^(2.0,parsertemp18133))) +::STMT +MATRIX:parsertemp429918,parsertemp429916,parsertemp429914 +FLOAT:int453,int941 +LITERAL_FLOAT:0.0,1.0,299.0 +*(/(-(t(parsertemp429914),*(int453,parsertemp429916)),299.0),-(1.0,<=(/(parsertemp429918,int941),0.0))) +::STMT +FLOAT:idx +LITERAL_FLOAT:253.0 +-(253.0,idx) +::STMT +MATRIX:parsertemp175075,parsertemp175079,X,R1 +-(R1,/(exp(-(X,parsertemp175075)),rowSums(exp(parsertemp175079)))) +::STMT +FLOAT:522_strideh,522_padh,522_Hin,int470 +LITERAL_FLOAT:1.0 +/(-(+(522_Hin,*(int470,522_padh)),1.0),522_strideh) +::STMT +MATRIX:WM +FLOAT:parsertemp31268 +LITERAL_FLOAT:1.0 +/(*(parsertemp31268,sum(WM)),-(sum(WM),1.0)) +::STMT +LITERAL_FLOAT:0.0,1.0 +INT:n_components,parsertemp506195 +/(rand(parsertemp506195,n_components,0.0,1.0),rowSums(rand(parsertemp506195,n_components,0.0,1.0))) +::STMT +FLOAT:covXY +covXY +::STMT +MATRIX:is_row_in_samples,parsertemp76114 +LITERAL_FLOAT:13381.0 +-(13381.0,*(is_row_in_samples,parsertemp76114)) +::STMT +MATRIX:output,output1 +LITERAL_FLOAT:0.1 +<(abs(-(output,output1)),0.1) +::STMT +MATRIX:prec,X,mu +LITERAL_FLOAT:2.0 +rowSums(^(-(%*%(X,prec),%*%(mu,prec)),2.0)) +::STMT +LITERAL_FLOAT:1.0,100.0 +-(100.0,1.0) +::STMT +MATRIX:parsertemp222310 +FLOAT:parsertemp222313 +LITERAL_FLOAT:0.5 +round(+(/(parsertemp222310,parsertemp222313),0.5)) +::STMT +MATRIX:resp,X,parsertemp437188 +FLOAT:float168 +LITERAL_FLOAT:2.0 +^(/(%*%(t(resp),X),t(+(parsertemp437188,float168))),2.0) +::STMT +MATRIX:y_residual,ytest +LITERAL_FLOAT:2.0 +*($1:nrow(ytest),^(/(sum(y_residual),$1),2.0)) +::STMT +LITERAL_FLOAT:5.0E-4 +5.0E-4 +::STMT +MATRIX:316_unnorm_probs,parsertemp175059 +LITERAL_FLOAT:1.0E-6 +<(abs(-(/(316_unnorm_probs,parsertemp175059),/(316_unnorm_probs,parsertemp175059))),1.0E-6) +::STMT +MATRIX:X +LITERAL_FLOAT:0.0 +t(-(0.0,t(colSums(X)))) +::STMT +MATRIX:y_train,prediction +FLOAT:float477 +/(sum(==(y_train,>(prediction,float477))),nrow(y_train)) +::STMT +MATRIX:z +FLOAT:trust_delta_sq +LITERAL_FLOAT:2.0 +-(^(cast.FLOAT(z),2.0),trust_delta_sq) +::STMT +MATRIX:e_r_rev_agg,d_r_rev,X_agg +t(colSums(/(*(X_agg,d_r_rev),e_r_rev_agg))) +::STMT +MATRIX:parsertemp222327,is_row_in_samples +FLOAT:sample_block_size,num_samples +LITERAL_FLOAT:1.0 +-(+(*(sample_block_size,num_samples),1.0),*(is_row_in_samples,parsertemp222327)) +::STMT +FLOAT:m2Y,sigmaX +LITERAL_FLOAT:1.0002795638803466 +*(sigmaX,sqrt(*(m2Y,1.0002795638803466))) +::STMT +MATRIX:X,permut +FLOAT:n +/(colSums(%*%(permut,X)),n) +::STMT +LITERAL_FLOAT:1.0E-10 +1.0E-10 +::STMT +MATRIX:output_values +LITERAL_FLOAT:0.3 +*(0.3,sum(output_values)) +::STMT +LITERAL_FLOAT:1.0,-1.0 +*(1.0,-1.0) +::STMT +MATRIX:Q,V,X,P_1K +%*%(t(X),-(*(P_1K,%*%(X,V)),*(P_1K,rowSums(Q)))) +::STMT +MATRIX:prec +diag(t(prec)) +::STMT +LITERAL_FLOAT:1.0,5.0 ++(5.0,1.0) +::STMT +LITERAL_FLOAT:0.0 +cast.MATRIX(0.0) +::STMT +MATRIX:parsertemp382680,col_nonzeros,parsertemp382677,row_nonzeros +FLOAT:reg +LITERAL_FLOAT:0.5 +*(*(0.5,reg),+(sum(*(parsertemp382677,row_nonzeros)),sum(*(parsertemp382680,col_nonzeros)))) +::STMT +MATRIX:d_r,parsertemp409781 +%*%(t(rev(d_r)),parsertemp409781) +::STMT +MATRIX:B,X,y +FLOAT:intercept +-(y,+(%*%(X,B),intercept)) +::STMT +MATRIX:A,scale_X,shift_X +t(+(%*%(diag(scale_X),A),%*%(shift_X,A))) +::STMT +FLOAT:idx +LITERAL_FLOAT:1.0,64.0 ++(-(64.0,idx),1.0) +::STMT +MATRIX:g_new,parsertemp2824,s,parsertemp2826 ++(*(/(sum(parsertemp2824),sum(parsertemp2826)),s),g_new) +::STMT +MATRIX:linear_terms +FLOAT:link_power +LITERAL_FLOAT:0.0,1.0 +^(+(linear_terms,==(linear_terms,0.0)),/(1.0,link_power)) +::STMT +MATRIX:parsertemp171600,g_Y,lambda,scale_X,shift_X,gXY,beta ++(+(%*%(diag(scale_X),%*%(parsertemp171600,g_Y)),%*%(shift_X,gXY)),*(lambda,beta)) +::STMT +MATRIX:s,parsertemp44016,d +FLOAT:delta2 +*(%*%(t(d),d),-(delta2,%*%(t(s),-(s,parsertemp44016)))) +::STMT +LITERAL_FLOAT:0.0 +INT:int1,int961 +exp(rand(int1,int961,0.0,0.0)) +::STMT +MATRIX:V,X,P_1K +*(P_1K,rowSums(*(P_1K,%*%(X,V)))) +::STMT +MATRIX:W2_rand +LITERAL_FLOAT:0.6546536707079771 +*(0.6546536707079771,W2_rand) +::STMT +MATRIX:linear_terms +FLOAT:var_power +LITERAL_FLOAT:2.0 +*(linear_terms,-(2.0,var_power)) +::STMT +MATRIX:X2 +LITERAL_FLOAT:4.0 +<(t(colSums(X2)),4.0) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0 +-(+(i,1.0),1.0) +::STMT +FLOAT:approx_sample_size +LITERAL_FLOAT:10.0 ++(approx_sample_size,round(*(10.0,sqrt(approx_sample_size)))) +::STMT +MATRIX:B +LITERAL_FLOAT:4.0 +-(4.0,nrow(B)) +::STMT +FLOAT:dist +t(cast.MATRIX(dist)) +::STMT +MATRIX:num_std +t(sqrt(num_std)) +::STMT +MATRIX:var_X_cols,tmp +FLOAT:int300,int338,int958,N +LITERAL_FLOAT:0.0,1.0 ++(*(/(tmp,-(N,int338)),-(1.0,<=(var_X_cols,int958))),<=(/(tmp,-(N,int300)),0.0)) +::STMT +LITERAL_FLOAT:1.0E-12 +1.0E-12 +::STMT +FLOAT:float824,int237,float466,int581 +LITERAL_FLOAT:1.0,3.0,6.0,2000.0 +/(*(*(6.0,2000.0),-(2000.0,1.0)),*(*(-(int237,float466),+(int581,float824)),+(2000.0,3.0))) +::STMT +MATRIX:is_LT_infinite,Y_prob +LITERAL_FLOAT:1.0 ++(*(Y_prob,-(1.0,rowSums(is_LT_infinite))),is_LT_infinite) +::STMT +MATRIX:means,Y_counts,Y,parsertemp560602 +-(-(Y,means),%*%(Y_counts,/(colSums(parsertemp560602),sum(Y_counts)))) +::STMT +MATRIX:parsertemp382947 +FLOAT:reg,parsertemp382956,loss_init,parsertemp382953,float925 +LITERAL_FLOAT:0.5 +-(loss_init,+(*(0.5,sum(parsertemp382947)),*(*(float925,reg),+(parsertemp382953,parsertemp382956)))) +::STMT +FLOAT:index +LITERAL_FLOAT:2.0,4.0 ++(*(index,4.0),2.0) +::STMT +MATRIX:R +FLOAT:i8 +LITERAL_FLOAT:24.0 +-(nrow(R),*(24.0,i8)) +::STMT +MATRIX:parsertemp436114 +FLOAT:int359,int471 +INT:2663_2662_n_col,int558 +*(cast.FLOAT(parsertemp436114),rand(int558,2663_2662_n_col,int359,int471)) +::STMT +FLOAT:parsertemp83 +abs(-(cast.MATRIX(parsertemp83),parsertemp83)) +::STMT +MATRIX:parsertemp31112,parsertemp31114 +FLOAT:int597,int905 +LITERAL_FLOAT:1.0,2.0,1500.0 +/(^(/(-(parsertemp31112,parsertemp31114),-(int905,int597)),2.0),*(^(1500.0,2.0),-(1500.0,1.0))) +::STMT +MATRIX:Ileft,Iright +FLOAT:min_leaf +&(>=(rowSums(Ileft),min_leaf),>=(rowSums(Iright),min_leaf)) +::STMT +MATRIX:codebook +FLOAT:j +LITERAL_FLOAT:1.0 +*(-(j,1.0),ncol(codebook)) +::STMT +MATRIX:parsertemp429916,parsertemp429914 +FLOAT:int441 +LITERAL_FLOAT:0.0,299.0 +<=(/(-(t(parsertemp429914),*(int441,parsertemp429916)),299.0),0.0) +::STMT +MATRIX:subspace_idx,parsertemp107049 +LITERAL_FLOAT:1.0,7.0 +<(-(subspace_idx,*(parsertemp107049,7.0)),1.0) +::STMT +MATRIX:W1_rand +LITERAL_FLOAT:0.08146881698903526 +*(0.08146881698903526,W1_rand) +::STMT +MATRIX:R +LITERAL_FLOAT:0.0 +cast.FLOAT(==(R,0.0)) +::STMT +MATRIX:parsertemp10743,V,parsertemp10742,H,parsertemp10739,parsertemp10738 +FLOAT:Eps +%*%(*(H,/(%*%(parsertemp10738,V),+(parsertemp10742,Eps))),t(*(H,/(parsertemp10739,parsertemp10743)))) +::STMT +MATRIX:P,Y,dP +sum(&(<=(P,dP),!(Y))) +::STMT +MATRIX:distances,ksmall +FLOAT:int819,int751 +LITERAL_FLOAT:0.0 +INT:parsertemp557199,int480 +*(<=(distances,ksmall),==(diag(rand(parsertemp557199,int480,int819,int751)),0.0)) +::STMT +FLOAT:2690_Hin,parsertemp459058 +LITERAL_FLOAT:1.0,2.0 ++(/(-(+(2690_Hin,parsertemp459058),2.0),2.0),1.0) +::STMT +FLOAT:252_Y,float605,int241,252_X,252_K,float60 +LITERAL_FLOAT:1.0 ++(*(-(*(252_K,252_X),-(252_Y,252_Y)),-(1.0,/(float605,252_X))),*(+(*(int241,252_X),-(252_Y,252_Y)),/(-(float60,252_X),-(252_X,252_X)))) +::STMT +MATRIX:Y,parsertemp2798,Xw +FLOAT:int496,int812 +LITERAL_FLOAT:0.0,1.0 +*(*(>(-(int496,parsertemp2798),0.0),-(1.0,*(Y,Xw))),*(>(-(int812,parsertemp2798),0.0),-(1.0,*(Y,Xw)))) +::STMT +MATRIX:2364_2359_Y_prime,W2,2364_2358_Y,parsertemp389612 +FLOAT:int492 +LITERAL_FLOAT:1.0 +t(*(-(1.0,^(2364_2358_Y,int492)),%*%(*(2364_2359_Y_prime,parsertemp389612),W2))) +::STMT +MATRIX:s +FLOAT:n +LITERAL_FLOAT:1.0 +-(*(/(1.0,s),n),1.0) +::STMT +MATRIX:y_corr +FLOAT:int922 +LITERAL_FLOAT:1.0 +*(*(y_corr,-(1.0,<=(y_corr,int922))),-(1.0,>=(y_corr,1.0))) +::STMT +FLOAT:429_C +LITERAL_FLOAT:1.0 +*(*(429_C,1.0),1.0) +::STMT +MATRIX:parsertemp220853,parsertemp220854,beta +FLOAT:logU +LITERAL_FLOAT:0.0 +*(>=(-(+(parsertemp220853,parsertemp220854),logU),0.0),beta) +::STMT +MATRIX:Y,2212_tp +/(2212_tp,sum(Y)) +::STMT +FLOAT:int489,lratio_t,N +LITERAL_FLOAT:1.0 +-(1.0,exp(/(*(lratio_t,int489),N))) +::STMT +MATRIX:parsertemp116096,X2 +LITERAL_FLOAT:0.0,32.0 +|(<(t(colSums(X2)),32.0),==(t(%*%(parsertemp116096,X2)),0.0)) +::STMT +MATRIX:H2_prime,2365_delta3,H1_prime,W2,W3 +*(H1_prime,%*%(*(H2_prime,%*%(2365_delta3,W3)),W2)) +::STMT +MATRIX:parsertemp44107,parsertemp44109,wnew +FLOAT:C +*(+(wnew,*(C,%*%(parsertemp44107,parsertemp44109))),+(wnew,*(C,%*%(parsertemp44107,parsertemp44109)))) +::STMT +MATRIX:Y +LITERAL_FLOAT:0.0,0.5 +*(-(0.5,Y),==(rowSums(Y),0.0)) +::STMT +MATRIX:s,d,parsertemp44021 +FLOAT:delta2 +*(cast.FLOAT(%*%(t(d),d)),-(delta2,cast.FLOAT(%*%(parsertemp44021,s)))) +::STMT +LITERAL_FLOAT:1.0,100.0,0.8 ++(*(100.0,0.8),1.0) +::STMT +MATRIX:tmp,X,parsertemp393475,parsertemp393466 +LITERAL_FLOAT:1.0E-17 +t(/(-(%*%(tmp,X),parsertemp393466),+(sqrt(parsertemp393475),1.0E-17))) +::STMT +MATRIX:parsertemp129018 +LITERAL_FLOAT:1.0,2.0 ++(*(max(parsertemp129018),2.0),1.0) +::STMT +MATRIX:surv,se_surv,parsertemp538736 +FLOAT:parsertemp538734 +^(surv,exp(/(*(parsertemp538734,se_surv),parsertemp538736))) +::STMT +FLOAT:i,k +LITERAL_FLOAT:4.0 ++(+(i,k),4.0) +::STMT +MATRIX:p,V +FLOAT:eps ++(%*%(t(V),%*%(V,p)),*(eps,p)) +::STMT +MATRIX:parsertemp552345,tab,catTotal +LITERAL_FLOAT:-1.0 +sum(*(*(/(tab,catTotal),-1.0),parsertemp552345)) +::STMT +MATRIX:X2 +LITERAL_FLOAT:32.0 +<(t(colSums(X2)),32.0) +::STMT +MATRIX:m_iter_err_sum,m_err +t(+(colSums(m_err),m_iter_err_sum)) +::STMT +MATRIX:r,d,parsertemp43998 +FLOAT:int723 +LITERAL_FLOAT:2.0 +/(sum(^(r,2.0)),%*%(t(d),+(d,*(int723,parsertemp43998)))) +::STMT +MATRIX:parsertemp122063,parsertemp122058 +FLOAT:eAvg,alpha,n +LITERAL_FLOAT:1.0 +-(*(alpha,-(/(parsertemp122058,eAvg),1.0)),*(-(1.0,alpha),-(*(parsertemp122063,n),1.0))) +::STMT +MATRIX:m_err_mean +LITERAL_FLOAT:-0.001 +-(-0.001,cast.FLOAT(m_err_mean)) +::STMT +MATRIX:X +LITERAL_FLOAT:300.0,0.0 +-(0.0,/(t(colSums(X)),300.0)) +::STMT +MATRIX:WM +FLOAT:m2X,W,float201 +sqrt(*(m2X,/(sum(WM),-(W,float201)))) +::STMT +LITERAL_FLOAT:1.0,3.0 ++(3.0,1.0) +::STMT +MATRIX:V,W,parsertemp10741,H +FLOAT:Eps +*(H,/(%*%(t(W),V),+(%*%(parsertemp10741,H),Eps))) +::STMT +MATRIX:parsertemp410118,g0_1,g_2 +cast.FLOAT(%*%(t(+(g0_1,g_2)),+(g0_1,t(parsertemp410118)))) +::STMT +MATRIX:E,F,parsertemp12849 +FLOAT:q,int210 +sqrt(/(sum(/(parsertemp12849,E)),*(sum(F),-(q,int210)))) +::STMT +MATRIX:log_prob,X +LITERAL_FLOAT:1.8378770664093453 ++(*(ncol(X),1.8378770664093453),log_prob) +::STMT +MATRIX:X,parsertemp16893,parsertemp16892 +/(%*%(X,t(X)),%*%(sqrt(rowSums(parsertemp16892)),t(sqrt(parsertemp16893)))) +::STMT +MATRIX:s,w +%*%(t(+(w,s)),+(w,s)) +::STMT +MATRIX:X +LITERAL_FLOAT:0.0,200.0 +-(0.0,/(t(colSums(X)),200.0)) +::STMT +MATRIX:parsertemp443530,mean,parsertemp443532,X,weight +FLOAT:float416 +-(%*%(t(X),X),%*%(*(t(mean),+(parsertemp443530,float416)),/(%*%(parsertemp443532,X),t(weight)))) +::STMT +MATRIX:linear_terms +FLOAT:var_power +LITERAL_FLOAT:2.0 +/(^(linear_terms,2.0),-(2.0,var_power)) +::STMT +MATRIX:parsertemp170101 +FLOAT:parsertemp170114,r_CG,g_reg,z,277_sq_root_d,parsertemp170093,pp_CG +LITERAL_FLOAT:0.5 ++(*(0.5,*(cast.FLOAT(z),+(r_CG,g_reg))),*(+(+(parsertemp170114,z),sum(parsertemp170101)),/(-(parsertemp170093,277_sq_root_d),pp_CG))) +::STMT +MATRIX:Y +FLOAT:maxv,minv ++(sum(==(Y,minv)),sum(==(Y,maxv))) +::STMT +FLOAT:num_records +LITERAL_FLOAT:960.0 +/(960.0,num_records) +::STMT +MATRIX:r,parsertemp44050 +FLOAT:norm_r2 +/(sum(*(-(r,parsertemp44050),-(r,parsertemp44050))),norm_r2) +::STMT +MATRIX:X,permut +colSums(%*%(permut,X)) +::STMT +FLOAT:batch_size,parsertemp145942 +LITERAL_FLOAT:1.0 +-(+(+(parsertemp145942,1.0),batch_size),1.0) +::STMT +MATRIX:lambda,V,shift_X,parsertemp274512,HV +*(V,+(+(%*%(parsertemp274512,HV),%*%(shift_X,HV)),*(lambda,V))) +::STMT +MATRIX:I,y2 +LITERAL_FLOAT:2.0 +^(/(%*%(I,y2),sum(I)),2.0) +::STMT +MATRIX:H3_prime,delta4,W4 +t(colSums(*(H3_prime,%*%(delta4,W4)))) +::STMT +MATRIX:tmp,parsertemp260786,X,Y,parsertemp260785,out +%*%(t(-(%*%(parsertemp260785,parsertemp260786),tmp)),-(%*%(t(X),*(out,Y)),tmp)) +::STMT +MATRIX:Y,missing_mask_Y +LITERAL_FLOAT:1.0 +*(missing_mask_Y,+(max(Y),1.0)) +::STMT +MATRIX:finite_linear_terms +LITERAL_FLOAT:1.0,0.231641888 ++(1.0,*(abs(finite_linear_terms),0.231641888)) +::STMT +MATRIX:ytest,yhat +FLOAT:int780,mean_y_test +LITERAL_FLOAT:1.0,2.0 +/(^(-(cast.FLOAT(ytest),cast.FLOAT(yhat)),2.0),-(^(cast.FLOAT(ytest),2.0),*(1.0,^(mean_y_test,int780)))) +::STMT +MATRIX:z +FLOAT:trust_delta_sq,pp_CG +LITERAL_FLOAT:2.0 +*(pp_CG,-(^(cast.FLOAT(z),2.0),trust_delta_sq)) +::STMT +MATRIX:parsertemp147188 +FLOAT:D +LITERAL_FLOAT:2.0 +*(parsertemp147188,sqrt(/(2.0,D))) +::STMT +MATRIX:X +FLOAT:int111 +LITERAL_FLOAT:1.0E-6 +/(*(1.0E-6,sum(^(X,int111))),ncol(X)) +::STMT +LITERAL_FLOAT:1.4142135623730951 +1.4142135623730951 +::STMT +MATRIX:sq_sums,mu +FLOAT:window_size +-(/(sq_sums,window_size),*(mu,mu)) +::STMT +MATRIX:663_img +t(rev(t(663_img))) +::STMT +MATRIX:W1_rand +LITERAL_FLOAT:0.08692913816996169 +*(0.08692913816996169,W1_rand) +::STMT +MATRIX:classes +LITERAL_FLOAT:1.0,0.7 +*(cast.FLOAT(classes),-(1.0,0.7)) +::STMT +MATRIX:parsertemp2832 +sum(==(round(parsertemp2832),max(round(parsertemp2832)))) +::STMT +FLOAT:i +LITERAL_FLOAT:18.0 ++(18.0,i) +::STMT +MATRIX:V +FLOAT:std_dev,int435,mu +*(<(V,-(mu,*(int435,std_dev))),V) +::STMT +MATRIX:V +FLOAT:std_dev,mu,int91 +*(>(V,+(mu,*(int91,std_dev))),V) +::STMT +MATRIX:d,X,logisticD +%*%(t(X),*(logisticD,%*%(X,d))) +::STMT +MATRIX:parsertemp477917,b +FLOAT:int929 +LITERAL_FLOAT:2.0 +sum(^(%*%(*(parsertemp477917,int929),b),2.0)) +::STMT +MATRIX:subspace_idx,parsertemp72201 +LITERAL_FLOAT:1.0,8.0 +<(-(subspace_idx,*(parsertemp72201,8.0)),1.0) +::STMT +MATRIX:Y +LITERAL_FLOAT:0.0 +/(Y,+(rowSums(Y),==(rowSums(Y),0.0))) +::STMT +MATRIX:w,X,y +FLOAT:int253 +LITERAL_FLOAT:1.0 ++(1.0,exp(*(*(y,int253),%*%(X,w)))) +::STMT +MATRIX:r_LS +FLOAT:norm_r2_LS,p_LS,parsertemp170552,lambda_LS +LITERAL_FLOAT:0.0 +-(0.0,+(r_LS,*(/(norm_r2_LS,p_LS),+(parsertemp170552,lambda_LS)))) +::STMT +MATRIX:parsertemp552530,Y +LITERAL_FLOAT:0.0 +INT:parsertemp552529,idx +==(+(rand(parsertemp552529,idx,0.0,0.0),t(parsertemp552530)),Y) +::STMT +MATRIX:y_corr +LITERAL_FLOAT:1.0,2.0 +-(1.0,*(2.0,y_corr)) +::STMT +MATRIX:linear_terms +FLOAT:int594 +LITERAL_FLOAT:1.0,2.0 ++(1.0,-(*(2.0,>=(linear_terms,int594)),1.0)) +::STMT +MATRIX:shift_X,w,parsertemp170066,X +*(cast.FLOAT(shift_X),cast.FLOAT(%*%(t(X),*(w,parsertemp170066)))) +::STMT +MATRIX:parsertemp437548,pred,parsertemp437666 +==(*(parsertemp437666,t(parsertemp437548)),pred) +::STMT +MATRIX:means,parsertemp389215 +FLOAT:int11 +LITERAL_FLOAT:1057.0,1058.0 +/(*(-(parsertemp389215,^(means,int11)),1058.0),1057.0) +::STMT +MATRIX:U,V_sum +rowSums(/(*(U,U),sum(V_sum))) +::STMT +FLOAT:padh,strideh,int428,Hin,Hf +/(-(+(Hin,*(int428,padh)),Hf),strideh) +::STMT +MATRIX:parsertemp31111,parsertemp31113 +FLOAT:int876 +LITERAL_FLOAT:1499.0,2.0 +^(/(-(colSums(parsertemp31111),*(int876,parsertemp31113)),1499.0),2.0) +::STMT +MATRIX:parsertemp16859,X +FLOAT:int570 +LITERAL_FLOAT:1.0E-6 ++(sqrt(rowSums(^(X,int570))),*(<(sqrt(parsertemp16859),1.0E-6),1.0E-6)) +::STMT +FLOAT:new_log_l,log_l,neg_log_l_change_predicted +LITERAL_FLOAT:-1.0 +/(+(*(new_log_l,-1.0),log_l),neg_log_l_change_predicted) +::STMT +FLOAT:i2 +LITERAL_FLOAT:24.0,1.0 ++(*(24.0,i2),1.0) +::STMT +MATRIX:grad +sqrt(sum(*(grad,grad))) +::STMT +FLOAT:res_eee +LITERAL_FLOAT:2.0,0.3 +round(-(/(res_eee,2.0),0.3)) +::STMT +MATRIX:parsertemp285531,z,parsertemp285533 +FLOAT:pp,sq_root_d,zq,parsertemp285544,parsertemp285526 +LITERAL_FLOAT:0.5 ++(*(0.5,sum(*(z,parsertemp285533))),*(+(+(parsertemp285544,zq),sum(parsertemp285531)),/(-(parsertemp285526,sq_root_d),pp))) +::STMT +MATRIX:parsertemp382919,parsertemp382916,S,col_nonzeros +FLOAT:reg +*(S,+(t(%*%(parsertemp382916,parsertemp382919)),*(*(reg,S),col_nonzeros))) +::STMT +MATRIX:vectors,pq_result +LITERAL_FLOAT:2.0,480.0 +/(sum(^(-(vectors,pq_result),2.0)),480.0) +::STMT +MATRIX:X,ScaleFactor +FLOAT:N +t(/(colSums(/(X,ScaleFactor)),N)) +::STMT +MATRIX:border,parsertemp386448,parsertemp386459,withinEps +LITERAL_FLOAT:0.0 +t(*(>(*(parsertemp386448,withinEps),0.0),==(-(border,parsertemp386459),0.0))) +::STMT +MATRIX:vectors,pq_result +LITERAL_FLOAT:800.0,2.0 +/(sum(^(-(vectors,pq_result),2.0)),800.0) +::STMT +MATRIX:p,lambda,parsertemp456801,parsertemp456800 +cast.FLOAT(%*%(t(p),+(%*%(parsertemp456800,parsertemp456801),*(lambda,p)))) +::STMT +MATRIX:parsertemp500609,parsertemp500606,parsertemp500604,w +FLOAT:int146,int367 +*(-(*(*(parsertemp500604,parsertemp500606),>(parsertemp500609,int146)),w),-(*(*(parsertemp500604,parsertemp500606),>(parsertemp500609,int367)),w)) +::STMT +LITERAL_FLOAT:0.21483446221182986 +0.21483446221182986 +::STMT +MATRIX:P,X,Y,parsertemp148868 +FLOAT:float9 +LITERAL_FLOAT:0.0,2.0 +^(+(%*%(t(X),-(P,Y)),*(*(parsertemp148868,float9),0.0)),2.0) +::STMT +MATRIX:parsertemp467675,Y,Xw +FLOAT:int437 +LITERAL_FLOAT:0.0,1.0 +*(*(>(-(int437,parsertemp467675),0.0),-(1.0,*(Y,Xw))),Y) +::STMT +MATRIX:simplex +/(-(rowSums(simplex),simplex),nrow(simplex)) +::STMT +MATRIX:d_r,parsertemp409781 +*(rev(d_r),parsertemp409781) +::STMT +FLOAT:W +LITERAL_FLOAT:1.0 +/(W,-(W,1.0)) +::STMT +MATRIX:is_LT_infinite,Y_prob +LITERAL_FLOAT:1.0 +*(/(Y_prob,rowSums(Y_prob)),-(1.0,rowSums(is_LT_infinite))) +::STMT +MATRIX:parsertemp409788,parsertemp409797 +LITERAL_FLOAT:-1.0,2.0 +^(+(*(t(parsertemp409788),-1.0),t(colSums(parsertemp409797))),2.0) +::STMT +MATRIX:parsertemp31111,parsertemp31113 +FLOAT:int649 +LITERAL_FLOAT:1499.0,1500.0 +/(/(-(colSums(parsertemp31111),*(int649,parsertemp31113)),1499.0),1500.0) +::STMT +LITERAL_FLOAT:1.0E-17 +1.0E-17 +::STMT +MATRIX:scale_lambda,parsertemp150455 +LITERAL_FLOAT:0.0,1.0E-5 +*(*(%*%(scale_lambda,parsertemp150455),1.0E-5),0.0) +::STMT +FLOAT:e,decay +LITERAL_FLOAT:1.0 ++(1.0,*(decay,e)) +::STMT +MATRIX:A +/(*(cast.FLOAT(A),cast.FLOAT(A)),*(cast.FLOAT(A),cast.FLOAT(A))) +::STMT +MATRIX:p,V +FLOAT:eps +%*%(t(p),+(%*%(t(V),%*%(V,p)),*(eps,p))) +::STMT +MATRIX:parsertemp43621,X,y +FLOAT:float787 +LITERAL_FLOAT:1.0 +%*%(t(X),*(-(/(float787,parsertemp43621),1.0),y)) +::STMT +MATRIX:g_new,g_old +LITERAL_FLOAT:2.0 +/(sum(^(g_new,2.0)),sum(^(g_old,2.0))) +::STMT +MATRIX:_sbcvar415,parsertemp116129 +FLOAT:eAvg,parsertemp116127 +LITERAL_FLOAT:0.050000000000000044,1.0,0.95 +-(*(0.95,-(/(parsertemp116129,eAvg),1.0)),*(0.050000000000000044,-(/(parsertemp116127,_sbcvar415),1.0))) +::STMT +MATRIX:w_X,X +FLOAT:int159 +cast.FLOAT(%*%(t(-(int159,w_X)),t(colSums(X)))) +::STMT +MATRIX:prec,X,mu +LITERAL_FLOAT:2.0 +^(-(%*%(X,prec),%*%(mu,prec)),2.0) +::STMT +FLOAT:i,Hin,Win +LITERAL_FLOAT:1.0 +*(*(-(i,1.0),Hin),Win) +::STMT +MATRIX:missing_val_maps +LITERAL_FLOAT:3.0 +-(3.0,nrow(missing_val_maps)) +::STMT +MATRIX:out +FLOAT:dd,step_sz,wd +*(-(+(wd,*(step_sz,dd)),sum(out)),-(+(wd,*(step_sz,dd)),sum(out))) +::STMT +MATRIX:W1_rand +LITERAL_FLOAT:0.07261134713572442 +*(0.07261134713572442,W1_rand) +::STMT +MATRIX:g +FLOAT:float990 +LITERAL_FLOAT:0.0,2.0 +sum(^(-(0.0,*(float990,g)),2.0)) +::STMT +MATRIX:cm,FD +LITERAL_FLOAT:1.0 ++(FD,==(cm,1.0)) +::STMT +FLOAT:parsertemp22485,parsertemp22452,parsertemp22453 +LITERAL_FLOAT:2.0 +-(parsertemp22485,*(2.0,sqrt(+(parsertemp22452,parsertemp22453)))) +::STMT +MATRIX:residual_matrix +LITERAL_FLOAT:0.0,2.0 +/(^(sum(residual_matrix),2.0),+(nrow(residual_matrix),0.0)) +::STMT +MATRIX:lambda,parsertemp285716,scale_X,p_CG,shift_X,parsertemp285714,temp_CG ++(+(*(lambda,p_CG),%*%(diag(scale_X),%*%(parsertemp285714,parsertemp285716))),%*%(shift_X,temp_CG)) +::STMT +MATRIX:parsertemp389212,parsertemp389215 +FLOAT:int362 +LITERAL_FLOAT:2.0,1058.0 +*(-(parsertemp389215,^(/(parsertemp389212,int362),2.0)),1058.0) +::STMT +MATRIX:Xm,parsertemp265706,Z,parsertemp265702 +FLOAT:ss +sum(+(%*%(t(Z),%*%(Xm,parsertemp265702)),*(parsertemp265706,ss))) +::STMT +FLOAT:delta +LITERAL_FLOAT:4.0 +*(4.0,delta) +::STMT +MATRIX:parsertemp42207,parsertemp42208,_sbcvar330,438_Ranks +FLOAT:parsertemp42222,meanY,meanX +LITERAL_FLOAT:0.5 +*(t(*(/(_sbcvar330,parsertemp42222),-(438_Ranks,meanX))),-(+(-(parsertemp42207,parsertemp42208),0.5),meanY)) +::STMT +MATRIX:z,parsertemp285752 +FLOAT:2234_sq_root_d,parsertemp285742,pp_CG,parsertemp285757 +LITERAL_FLOAT:0.5 ++(*(0.5,cast.FLOAT(%*%(z,parsertemp285752))),*(parsertemp285757,/(+(parsertemp285742,2234_sq_root_d),pp_CG))) +::STMT +FLOAT:batch_size,parsertemp145942 +LITERAL_FLOAT:1.0 ++(+(parsertemp145942,1.0),batch_size) +::STMT +FLOAT:m2X,W,float178 +sqrt(*(m2X,/(W,-(W,float178)))) +::STMT +MATRIX:p,q +FLOAT:norm_r2 +*(/(norm_r2,sum(*(p,q))),p) +::STMT +FLOAT:m2,mu +LITERAL_FLOAT:1.0005 +/(sqrt(*(1.0005,m2)),mu) +::STMT +MATRIX:Y_counts,Y,parsertemp560599 +FLOAT:parsertemp560600 +LITERAL_FLOAT:2.0 +^(-(Y,%*%(Y_counts,/(parsertemp560599,parsertemp560600))),2.0) +::STMT +MATRIX:Xd,parsertemp2775 +FLOAT:int805 +LITERAL_FLOAT:0.0 +*(*(Xd,>(-(int805,parsertemp2775),0.0)),Xd) +::STMT +MATRIX:parsertemp500663 +LITERAL_FLOAT:-1.0E30 +*(-1.0E30,parsertemp500663) +::STMT +MATRIX:parsertemp477829,2814_Y +FLOAT:2814_X,inp_x +*(+(*(cast.FLOAT(parsertemp477829),-(2814_X,2814_X)),-(cast.FLOAT(2814_Y),cast.FLOAT(2814_Y))),/(-(inp_x,cast.FLOAT(2814_X)),-(cast.FLOAT(2814_X),cast.FLOAT(2814_X)))) +::STMT +MATRIX:Xtest_dists +FLOAT:eps +LITERAL_FLOAT:0.0 +*(<=(Xtest_dists,eps),<(0.0,Xtest_dists)) +::STMT +MATRIX:parsertemp410250,event +FLOAT:parsertemp410251 +/(-(max(^(parsertemp410250,parsertemp410251)),min(^(parsertemp410250,parsertemp410251))),sum(event)) +::STMT +MATRIX:275_X,275_curr_X +FLOAT:275_value +&(==(275_X,275_curr_X),<(275_X,275_value)) +::STMT +MATRIX:r_CG,g_reg,z +cast.FLOAT(%*%(t(z),+(r_CG,g_reg))) +::STMT +MATRIX:X +FLOAT:var_lag,xq_lag,arch_coef,var_coef,a0 +LITERAL_FLOAT:2.0 +/(^(cast.FLOAT(X),2.0),+(+(a0,*(arch_coef,xq_lag)),*(var_coef,var_lag))) +::STMT +FLOAT:k,n +LITERAL_FLOAT:2.0,4.0 +-(+(-(n,4.0),2.0),k) +::STMT +MATRIX:X +FLOAT:x +-(x,X) +::STMT +MATRIX:Hdiff,beta,betamin +FLOAT:int455,int899 ++(beta,+(*(<(Hdiff,int899),betamin),*(>=(Hdiff,int455),beta))) +::STMT +MATRIX:Hdiff,betamax,beta +FLOAT:int866,int847 ++(beta,+(*(>=(Hdiff,int847),betamax),*(<(Hdiff,int866),beta))) +::STMT +MATRIX:z +LITERAL_FLOAT:2.0 +sqrt(^(cast.FLOAT(z),2.0)) +::STMT +MATRIX:X,H +LITERAL_FLOAT:0.0 +>(%*%(X,t(H)),0.0) +::STMT +MATRIX:Bx +diag(Bx) +::STMT +MATRIX:parsertemp31189,parsertemp31194,parsertemp31196,parsertemp31187 +LITERAL_FLOAT:1499.0,6999.0,1500.0,7000.0 ++(/(/(-(parsertemp31187,parsertemp31189),6999.0),7000.0),/(/(-(parsertemp31194,parsertemp31196),1499.0),1500.0)) +::STMT +MATRIX:parsertemp170244,parsertemp170240,parsertemp170238 +FLOAT:float847,float32,float42 +LITERAL_FLOAT:1.0,-0.284496736 +*(/(1.0,+(1.0,*(parsertemp170238,float847))),+(-0.284496736,*(/(float32,parsertemp170240),+(float42,parsertemp170244)))) +::STMT +FLOAT:2690_Hin +LITERAL_FLOAT:0.0,2.0 +-(+(2690_Hin,*(2.0,0.0)),2.0) +::STMT +MATRIX:A,B,X +<=(%*%(X,A),B) +::STMT +LITERAL_FLOAT:1.0,2.0,3.0,2001.0 +*(*(-(2001.0,2.0),+(2001.0,1.0)),+(2001.0,3.0)) +::STMT +MATRIX:P,I,X2 +*(t(%*%(X2,P)),I) +::STMT +LITERAL_FLOAT:0.06835859270246632 +0.06835859270246632 +::STMT +MATRIX:d,parsertemp410053 +sum(*(d,t(colSums(parsertemp410053)))) +::STMT +MATRIX:W +LITERAL_FLOAT:1.0,2.0 +*(^(sum(round(W)),2.0),+(sum(round(W)),1.0)) +::STMT +MATRIX:2883_ctab +LITERAL_FLOAT:0.0,1.0 +==(rowSums(!=(2883_ctab,0.0)),1.0) +::STMT +MATRIX:M2,X +-(nrow(X),nrow(M2)) +::STMT +MATRIX:parsertemp403496,W3_rand +FLOAT:int454,int938 +LITERAL_FLOAT:0.1651445647689541 +%*%(*(0.1651445647689541,W3_rand),t(/(-(parsertemp403496,int454),+(parsertemp403496,int938)))) +::STMT +MATRIX:w,parsertemp2794 +FLOAT:lambda +LITERAL_FLOAT:2.0 +*(/(lambda,2.0),sum(*(+(w,parsertemp2794),+(w,parsertemp2794)))) +::STMT +MATRIX:parsertemp31029,parsertemp31031 +FLOAT:int700 +LITERAL_FLOAT:1.0,2.0,150.0 +^(/(-(colSums(parsertemp31029),*(int700,parsertemp31031)),-(150.0,1.0)),2.0) +::STMT +MATRIX:p,z +*(sum(*(p,z)),sum(*(p,z))) +::STMT +MATRIX:X +LITERAL_FLOAT:-2.0 +*(-2.0,%*%(X,t(X))) +::STMT +MATRIX:parsertemp31189,parsertemp31194,parsertemp31196,parsertemp31187 +FLOAT:int893,int871,int192,int39 +LITERAL_FLOAT:1500.0,7000.0 ++(/(/(-(parsertemp31187,parsertemp31189),-(int893,int39)),7000.0),/(/(-(parsertemp31194,parsertemp31196),-(int192,int871)),1500.0)) +::STMT +MATRIX:scale_X,shift_X +LITERAL_FLOAT:2.0 +*(*(2.0,scale_X),shift_X) +::STMT +MATRIX:COMPONENTS,id +-(==(id,cast.FLOAT(id)),cast.FLOAT(diag(diag(COMPONENTS)))) +::STMT +MATRIX:252_X +LITERAL_FLOAT:4.5 +/(-(4.5,cast.FLOAT(252_X)),-(cast.FLOAT(252_X),cast.FLOAT(252_X))) +::STMT +MATRIX:252_Y,252_X,252_K +-(*(cast.FLOAT(252_K),-(cast.FLOAT(252_X),cast.FLOAT(252_X))),-(cast.FLOAT(252_Y),cast.FLOAT(252_Y))) +::STMT +MATRIX:gs +FLOAT:alpha2Scalar +LITERAL_FLOAT:-0.5 +/(*(-0.5,cast.FLOAT(gs)),alpha2Scalar) +::STMT +MATRIX:parsertemp146940,184_dtemp +FLOAT:beta2 +LITERAL_FLOAT:1.0,2.0 +*(-(1.0,beta2),^(colSums(-(184_dtemp,parsertemp146940)),2.0)) +::STMT +FLOAT:m2 +LITERAL_FLOAT:1.0,2000.0 +*(/(2000.0,-(2000.0,1.0)),m2) +::STMT +MATRIX:parsertemp387404,K_inv,Ks,Kss +-(cast.FLOAT(Kss),cast.FLOAT(%*%(%*%(parsertemp387404,K_inv),Ks))) +::STMT +MATRIX:parsertemp131907,parsertemp131918,cumLeftHist,parsertemp132092,leftHist,outBucket ++(%*%(==(outBucket,%*%(parsertemp132092,parsertemp131907)),-(cumLeftHist,leftHist)),parsertemp131918) +::STMT +MATRIX:e,X2 +LITERAL_FLOAT:0.0 +>(t(%*%(t(e),X2)),0.0) +::STMT +MATRIX:PRED,GT +/(sum(==(PRED,GT)),length(==(PRED,GT))) +::STMT +MATRIX:U,V,X +-(X,%*%(U,t(V))) +::STMT +FLOAT:m2X,float180,int20 +LITERAL_FLOAT:100000.0 +sqrt(*(m2X,/(100000.0,-(int20,float180)))) +::STMT +MATRIX:p,A +*(p,%*%(t(A),%*%(A,p))) +::STMT +MATRIX:V_nonzero,row_nonzeros,lambda_I ++(%*%(t(V_nonzero),V_nonzero),*(cast.FLOAT(row_nonzeros),lambda_I)) +::STMT +MATRIX:C,Xm,parsertemp265706,parsertemp265704,Z,parsertemp265701 +FLOAT:ss +/(%*%(t(Xm),%*%(Xm,%*%(C,parsertemp265701))),sum(+(%*%(parsertemp265704,Z),*(parsertemp265706,ss)))) +::STMT +FLOAT:parsertemp13703 +LITERAL_FLOAT:1.0,1000.0 +/(*(parsertemp13703,1000.0),-(1000.0,1.0)) +::STMT +MATRIX:WM +FLOAT:parsertemp31268 +LITERAL_FLOAT:1.0 +*(-(sum(WM),1.0),/(*(parsertemp31268,sum(WM)),-(sum(WM),1.0))) +::STMT +MATRIX:Xm,tmp,parsertemp265702 +t(/(%*%(t(Xm),%*%(Xm,parsertemp265702)),sum(tmp))) +::STMT +MATRIX:scale_X,shift_X,X,parsertemp271403 +FLOAT:int126,int545 +LITERAL_FLOAT:2.0 ++(+(%*%(^(X,int126),^(scale_X,int545)),%*%(X,*(parsertemp271403,shift_X))),sum(^(shift_X,2.0))) +::STMT +FLOAT:parsertemp271435 +LITERAL_FLOAT:1500.0 +*(1500.0,parsertemp271435) +::STMT +FLOAT:Hin +LITERAL_FLOAT:184.0 +*(+(Hin,184.0),+(Hin,184.0)) +::STMT +MATRIX:d,parsertemp43996,parsertemp43997 +LITERAL_FLOAT:2.0 +%*%(t(d),+(d,*(2.0,%*%(parsertemp43996,parsertemp43997)))) +::STMT +MATRIX:K_inv,Ks,Kss +-(Kss,%*%(%*%(t(Ks),K_inv),Ks)) +::STMT +MATRIX:parsertemp220900,parsertemp220899,Y +LITERAL_FLOAT:300.0,0.0 ++(Y,-(0.0,*(300.0,-(parsertemp220899,parsertemp220900)))) +::STMT +MATRIX:WM +LITERAL_FLOAT:1.0 +-(sum(WM),1.0) +::STMT +FLOAT:res_eee +LITERAL_FLOAT:2.0,0.3 +-(/(res_eee,2.0),0.3) +::STMT +MATRIX:parsertemp24102 +FLOAT:num_bins +LITERAL_FLOAT:1.0 +*(>(+(round(parsertemp24102),1.0),num_bins),num_bins) +::STMT +MATRIX:W +FLOAT:m2 +*(m2,sum(round(W))) +::STMT +MATRIX:2903_mask,dout,2902_W +FLOAT:2903_p +*(/(2903_mask,2903_p),%*%(dout,t(2902_W))) +::STMT +MATRIX:r_CG,q_CG +FLOAT:alpha_CG ++(r_CG,*(alpha_CG,cast.FLOAT(q_CG))) +::STMT +MATRIX:_sbcvar95,_sbcvar97 +FLOAT:221_my +LITERAL_FLOAT:0.0,2.0 +^(+(%*%(_sbcvar95,_sbcvar97),-(0.0,221_my)),2.0) +::STMT +MATRIX:parsertemp395002,W4_rand,parsertemp395005 +LITERAL_FLOAT:0.08692913816996169 +t(%*%(*(0.08692913816996169,W4_rand),t(/(parsertemp395002,parsertemp395005)))) +::STMT +MATRIX:X,Y,K +-(*(K,-(X,X)),-(Y,Y)) +::STMT +MATRIX:Xd,out +FLOAT:dd,parsertemp467655,wd +/(*(-(+(wd,parsertemp467655),sum(out)),-(+(wd,parsertemp467655),sum(out))),+(dd,sum(Xd))) +::STMT +FLOAT:i +LITERAL_FLOAT:27.0 ++(27.0,i) +::STMT +MATRIX:X,y +LITERAL_FLOAT:0.0 +%*%(-(0.0,t(X)),y) +::STMT +MATRIX:leaf_ids +FLOAT:boundary_left,step_size +<(leaf_ids,+(+(boundary_left,step_size),step_size)) +::STMT +FLOAT:idx +LITERAL_FLOAT:1.0,11.0 ++(-(11.0,idx),1.0) +::STMT +LITERAL_FLOAT:1.0,2.0 ++(2.0,1.0) +::STMT +MATRIX:p_gaps_vector +FLOAT:number_nans +/(number_nans,sum(p_gaps_vector)) +::STMT +FLOAT:g,h +/(*(g,g),h) +::STMT +MATRIX:var_X_cols,parsertemp414376,parsertemp414378 +FLOAT:int672 +LITERAL_FLOAT:0.0,1.0,199.0 ++(*(/(-(parsertemp414376,parsertemp414378),199.0),-(1.0,<=(var_X_cols,int672))),<=(/(-(parsertemp414376,parsertemp414378),199.0),0.0)) +::STMT +LITERAL_FLOAT:1.0,6.0,2001.0 +*(*(6.0,2001.0),-(2001.0,1.0)) +::STMT +MATRIX:D,beta +LITERAL_FLOAT:-1.0 +exp(*(*(D,-1.0),beta)) +::STMT +MATRIX:d,exp_Xb,X +rev(*(X,*(%*%(X,d),exp_Xb))) +::STMT +MATRIX:K_inv,parsertemp387408,Ks,Kss +cast.FLOAT(-(Kss,%*%(%*%(parsertemp387408,K_inv),Ks))) +::STMT +MATRIX:present_domain_vals_mat,parsertemp27485 +FLOAT:my +-(%*%(present_domain_vals_mat,parsertemp27485),my) +::STMT +MATRIX:p_CG,z +LITERAL_FLOAT:-1.0 +*(cast.FLOAT(%*%(t(p_CG),z)),-1.0) +::STMT +MATRIX:parsertemp43626 +LITERAL_FLOAT:-1.0,2.0 +*(2.0,sum(*(parsertemp43626,-1.0))) +::STMT +MATRIX:X_batch,W1_grad +FLOAT:step +*(/(step,nrow(X_batch)),W1_grad) +::STMT +MATRIX:_sbcvar1156,_sbcvar1155 +FLOAT:num_records +LITERAL_FLOAT:1.0 ++(*(_sbcvar1155,_sbcvar1156),*(+(num_records,1.0),-(1.0,_sbcvar1156))) +::STMT +MATRIX:e_r_rev_agg,select,d_r_rev,X_rev_agg +/(*(%*%(select,X_rev_agg),d_r_rev),e_r_rev_agg) +::STMT +MATRIX:classFeatureCounts +LITERAL_FLOAT:1.0,750.0 +/(+(classFeatureCounts,1.0),+(rowSums(classFeatureCounts),750.0)) +::STMT +MATRIX:parsertemp220853,Ws,beta +FLOAT:logU +LITERAL_FLOAT:0.0 +<(-(+(parsertemp220853,*(beta,Ws)),logU),0.0) +::STMT +MATRIX:P,D,ZERODIAG +LITERAL_FLOAT:1.0E-12 +/(rowSums(*(*(P,ZERODIAG),D)),+(rowSums(*(P,ZERODIAG)),1.0E-12)) +::STMT +MATRIX:tmp,w,out +LITERAL_FLOAT:50.0,0.5 ++(*(0.5,cast.FLOAT(%*%(out,out))),*(50.0,cast.FLOAT(%*%(w,tmp)))) +::STMT +MATRIX:p,G +FLOAT:alpha +*(alpha,%*%(G,p)) +::STMT +MATRIX:q,z +FLOAT:pp,pq,parsertemp285524 +LITERAL_FLOAT:0.5 ++(*(*(0.5,/(parsertemp285524,pp)),pq),sum(*(z,q))) +::STMT +MATRIX:p,q,lambda +sum(*(p,+(q,*(lambda,p)))) +::STMT +MATRIX:lambda,beta +LITERAL_FLOAT:0.0 +*(+(0.0,*(cast.FLOAT(lambda),cast.FLOAT(beta))),+(0.0,*(cast.FLOAT(lambda),cast.FLOAT(beta)))) +::STMT +MATRIX:_sbcvar78 +LITERAL_FLOAT:10000.0 +/(%*%(rowSums(_sbcvar78),colSums(_sbcvar78)),10000.0) +::STMT +MATRIX:parsertemp16859,77_Y_row_norm,parsertemp16868,X,Y,parsertemp16861 +FLOAT:float904 +/(%*%(X,t(Y)),%*%(+(sqrt(parsertemp16859),*(parsertemp16861,float904)),t(+(77_Y_row_norm,parsertemp16868)))) +::STMT +MATRIX:X_adapted,parsertemp176506 +FLOAT:intercept,parsertemp176418 +LITERAL_FLOAT:3.0 +<(X_adapted,-(sqrt(parsertemp176418),*(3.0,+(parsertemp176506,intercept)))) +::STMT +MATRIX:X_adapted,parsertemp176506 +FLOAT:intercept,parsertemp176418 +LITERAL_FLOAT:3.0 +>(X_adapted,+(sqrt(parsertemp176418),*(3.0,+(parsertemp176506,intercept)))) +::STMT +MATRIX:parsertemp171600,g_Y,lambda,parsertemp171602,beta +LITERAL_FLOAT:2.0 +^(+(*(cast.FLOAT(parsertemp171602),%*%(parsertemp171600,g_Y)),*(cast.FLOAT(lambda),cast.FLOAT(beta))),2.0) +::STMT +MATRIX:z +FLOAT:trust_delta_sq,pp_CG +*(pp_CG,-(cast.FLOAT(%*%(z,z)),trust_delta_sq)) +::STMT +LITERAL_FLOAT:-0.36651292058166435 +-0.36651292058166435 +::STMT +MATRIX:F +LITERAL_FLOAT:2.0 +/(rowSums(F),2.0) +::STMT +MATRIX:parsertemp31104,parsertemp31106 +FLOAT:int255 +LITERAL_FLOAT:1999.0,2000.0 +/(/(-(colSums(parsertemp31104),*(int255,parsertemp31106)),1999.0),2000.0) +::STMT +FLOAT:parsertemp5,m2X,parsertemp9,m2Y,covXY +/(covXY,*(sqrt(*(m2X,parsertemp5)),sqrt(*(m2Y,parsertemp9)))) +::STMT +MATRIX:diff_nominal +FLOAT:num_std_median +LITERAL_FLOAT:0.0 +*(!=(diff_nominal,0.0),num_std_median) +::STMT +MATRIX:W1_rand,X,parsertemp399148,parsertemp399138 +FLOAT:float154 +LITERAL_FLOAT:0.08692913816996169 +%*%(*(0.08692913816996169,W1_rand),t(/(-(X,parsertemp399138),+(parsertemp399148,float154)))) +::STMT +MATRIX:maskd1,out1,185_dX,parsertemp146947,W2 +FLOAT:p +LITERAL_FLOAT:0.0 +*(>(out1,0.0),*(/(maskd1,p),%*%(*(parsertemp146947,185_dX),t(W2)))) +::STMT +MATRIX:X +FLOAT:n +LITERAL_FLOAT:2.0 +^(/(t(colSums(X)),n),2.0) +::STMT +MATRIX:LT,Y,parsertemp149320 +sum(*(Y,-(LT,parsertemp149320))) +::STMT +MATRIX:V,X +LITERAL_FLOAT:0.0 +*(V,t(!=(X,0.0))) +::STMT +MATRIX:X,K +LITERAL_FLOAT:-1.0 +*(*(K,-1.0),-(X,X)) +::STMT +MATRIX:W +FLOAT:m4 +LITERAL_FLOAT:1.0,2.0 +*(*(^(sum(W),2.0),+(sum(W),1.0)),m4) +::STMT +MATRIX:parsertemp32006,simplex +LITERAL_FLOAT:2.0 +-(*(2.0,/(-(parsertemp32006,simplex),nrow(simplex))),simplex) +::STMT +MATRIX:resp,mean,X,weight,diff +/(%*%(t(*(diff,resp)),-(X,mean)),cast.FLOAT(weight)) +::STMT +MATRIX:X +/(t(colSums(X)),nrow(X)) +::STMT +MATRIX:parsertemp31023,parsertemp31025,parsertemp31030,parsertemp31032 +LITERAL_FLOAT:149.0,150.0,99.0,100.0 ++(/(/(-(parsertemp31023,parsertemp31025),99.0),100.0),/(/(-(parsertemp31030,parsertemp31032),149.0),150.0)) +::STMT +MATRIX:R +LITERAL_FLOAT:0.0,32.0 +&(>=(R,32.0),>(R,0.0)) +::STMT +MATRIX:parsertemp31763,parsertemp31756 +FLOAT:minSup +LITERAL_FLOAT:0.0 +sum(&(>=(t(parsertemp31756),minSup),>(t(parsertemp31763),0.0))) +::STMT +MATRIX:ytest,yhat +LITERAL_FLOAT:1.0,2.0 +^(/(-(cast.FLOAT(ytest),cast.FLOAT(yhat)),1.0),2.0) +::STMT +MATRIX:s,w +LITERAL_FLOAT:2.0 +*(2.0,cast.FLOAT(%*%(t(w),s))) +::STMT +LITERAL_FLOAT:0.2656844656620286 +0.2656844656620286 +::STMT +MATRIX:R,parsertemp40219,parsertemp40216,parsertemp40225,removedE +FLOAT:level +-(+(R,rowSums(*(parsertemp40216,parsertemp40225))),rowSums(*(==(parsertemp40219,level),t(removedE)))) +::STMT +MATRIX:Y_val,parsertemp459795 +FLOAT:int459 +LITERAL_FLOAT:50.0 +/(sum(*(-(int459,Y_val),parsertemp459795)),50.0) +::STMT +MATRIX:majority +LITERAL_FLOAT:0.0,1.0,2.0 +INT:int589,parsertemp282730 +*(>(rand(parsertemp282730,int589,1.0,2.0),0.0),majority) +::STMT +MATRIX:s,parsertemp44016,d +FLOAT:delta2 +*(%*%(t(d),d),-(delta2,%*%(t(s),-(s,parsertemp44016)))) +::STMT +MATRIX:parsertemp460691 +FLOAT:lr +*(lr,rowSums(parsertemp460691)) +::STMT +MATRIX:parsertemp171269,Y,linear_terms +FLOAT:int153,int429 +LITERAL_FLOAT:0.0 +-(/(+(Y,==(Y,int429)),+(*(linear_terms,parsertemp171269),==(Y,int153))),==(Y,0.0)) +::STMT +MATRIX:finite_linear_terms +LITERAL_FLOAT:1.0E7 +*(==(+(1.0E7,exp(finite_linear_terms)),1.0E7),exp(finite_linear_terms)) +::STMT +MATRIX:CI_l +LITERAL_FLOAT:0.5 +t(<=(CI_l,0.5)) +::STMT +MATRIX:m_iter_err_sum,m_err +-(t(+(colSums(m_err),m_iter_err_sum)),+(colSums(m_err),m_iter_err_sum)) +::STMT +MATRIX:F +colSums(F) +::STMT +MATRIX:ot2 +FLOAT:int521,Nt +LITERAL_FLOAT:100.0 +/(*(sum(>(ot2,int521)),100.0),Nt) +::STMT +MATRIX:eVals,eVecs +LITERAL_FLOAT:-1.0 +%*%(eVecs,diag(^(eVals,-1.0))) +::STMT +MATRIX:R,3_ss,dsep +FLOAT:3_eAvg +/(/(+(R,dsep),3_ss),3_eAvg) +::STMT +MATRIX:b,X +rev(*(X,exp(%*%(X,b)))) +::STMT +MATRIX:X +FLOAT:p +-(nrow(X),p) +::STMT +MATRIX:indexWithInGroups,parsertemp129475,groupIndex,selectedMatrix ++(-(*(groupIndex,max(parsertemp129475)),max(parsertemp129475)),rowSums(*(indexWithInGroups,selectedMatrix))) +::STMT +MATRIX:Y +-(Y,/(sum(Y),nrow(Y))) +::STMT +LITERAL_FLOAT:5.0,2003.0 ++(2003.0,5.0) +::STMT +FLOAT:i,s_cols +LITERAL_FLOAT:1.0 +*(-(i,1.0),s_cols) +::STMT +MATRIX:parsertemp271438,parsertemp271437 +LITERAL_FLOAT:2.0 +sqrt(sum(^(%*%(parsertemp271437,parsertemp271438),2.0))) +::STMT +FLOAT:max_depth +LITERAL_FLOAT:2.0 +^(2.0,max_depth) +::STMT +LITERAL_FLOAT:1.0,100000.0 +/(100000.0,-(100000.0,1.0)) +::STMT +MATRIX:parsertemp251811 +FLOAT:f +LITERAL_FLOAT:0.0 +==(<=(parsertemp251811,f),0.0) +::STMT +LITERAL_FLOAT:44.75488800120049 +44.75488800120049 +::STMT +MATRIX:H2_prime,H1_prime,W2,parsertemp389612 +t(*(H1_prime,%*%(*(H2_prime,parsertemp389612),W2))) +::STMT +LITERAL_FLOAT:0.001308 +0.001308 +::STMT +MATRIX:img_in +FLOAT:w +LITERAL_FLOAT:2.0 +/(-(ncol(img_in),w),2.0) +::STMT +MATRIX:out2,parsertemp146942,184_dscores,outd1 +FLOAT:int988 +LITERAL_FLOAT:2.0 +^(%*%(t(outd1),*(>(out2,int988),%*%(184_dscores,parsertemp146942))),2.0) +::STMT +MATRIX:z_LS +FLOAT:norm_r2_LS,p_LS ++(z_LS,*(/(norm_r2_LS,*(p_LS,p_LS)),cast.FLOAT(p_LS))) +::STMT +MATRIX:y_val,preds +t(-(y_val,preds)) +::STMT +MATRIX:parsertemp2832 +max(round(parsertemp2832)) +::STMT +MATRIX:parsertemp131906,parsertemp132092,rightHist,outBucket +%*%(==(outBucket,%*%(parsertemp132092,t(parsertemp131906))),rightHist) +::STMT +MATRIX:b_cumulant,Y,natural_parameters +-(*(Y,natural_parameters),b_cumulant) +::STMT +FLOAT:norm_r2,norm_r2_initial +/(norm_r2,norm_r2_initial) +::STMT +MATRIX:leaf_ids,out +FLOAT:boundary_left,step_size ++(out,&(>=(leaf_ids,boundary_left),<(leaf_ids,+(boundary_left,step_size)))) +::STMT +MATRIX:B,X,y +FLOAT:intercept +LITERAL_FLOAT:2.0 +^(-(y,+(%*%(X,B),intercept)),2.0) +::STMT +MATRIX:mean +LITERAL_FLOAT:2.0 +*(2.0,^(mean,2.0)) +::STMT +FLOAT:sv,rad,delta2,s2 +/(-(delta2,s2),+(sv,rad)) +::STMT +MATRIX:classes,X +FLOAT:split ++(-(nrow(X),split),nrow(classes)) +::STMT +MATRIX:parsertemp553014,M2,parsertemp553121,parsertemp553122,missing,parsertemp553008 +LITERAL_FLOAT:2.0 +-(+(%*%(rowSums(parsertemp553008),parsertemp553121),t(%*%(parsertemp553014,parsertemp553122))),*(2.0,%*%(M2,t(missing)))) +::STMT +LITERAL_FLOAT:1.6583123951777 +1.6583123951777 +::STMT +FLOAT:sum_y_test,n +LITERAL_FLOAT:2.0 +*(n,^(/(sum_y_test,n),2.0)) +::STMT +FLOAT:a,b,rad +LITERAL_FLOAT:-1.0,2.0 +/(*(-(b,rad),-1.0),*(2.0,a)) +::STMT +MATRIX:linear_terms +FLOAT:var_power +LITERAL_FLOAT:1.0 +*(linear_terms,-(1.0,var_power)) +::STMT +MATRIX:r,s,grad +-(%*%(t(s),grad),%*%(t(s),r)) +::STMT +MATRIX:parsertemp43631,parsertemp43633 +LITERAL_FLOAT:0.0,2.0 +^(+(0.0,*(2.0,%*%(parsertemp43631,parsertemp43633))),2.0) +::STMT +MATRIX:minD,D +colSums(/(<=(D,minD),rowSums(<=(D,minD)))) +::STMT +MATRIX:ones,classFeatureCounts +LITERAL_FLOAT:1.0,500.0 +%*%(+(rowSums(classFeatureCounts),*(500.0,1.0)),ones) +::STMT +MATRIX:b4,parsertemp389330,parsertemp389333,W4 ++(%*%(W4,t(/(parsertemp389330,parsertemp389333))),b4) +::STMT +MATRIX:M +LITERAL_FLOAT:2.0 +>=(rowSums(M),2.0) +::STMT +MATRIX:F +t(colSums(F)) +::STMT +MATRIX:parsertemp146957,188_dX +FLOAT:beta2 +LITERAL_FLOAT:1.0,2.0 +*(-(1.0,beta2),^(colSums(*(parsertemp146957,188_dX)),2.0)) +::STMT +MATRIX:W +LITERAL_FLOAT:1.0,2.0 +-(^(sum(round(W)),2.0),1.0) +::STMT +MATRIX:parsertemp220863,parsertemp220864,Hdiff,beta +FLOAT:int40,INF +LITERAL_FLOAT:2.0 +*(*(*(2.0,>=(Hdiff,int40)),==(+(parsertemp220863,parsertemp220864),INF)),beta) +::STMT +MATRIX:parsertemp42200,parsertemp42201,_sbcvar330 +FLOAT:meanX +LITERAL_FLOAT:1.0,0.5 +*(/(_sbcvar330,-(sum(_sbcvar330),1.0)),-(+(-(parsertemp42200,parsertemp42201),0.5),meanX)) +::STMT +FLOAT:nFeats +LITERAL_FLOAT:6.283185307179586 +^(6.283185307179586,nFeats) +::STMT +MATRIX:p_CG +FLOAT:trust_delta_sq,z,pp_CG +-(*(*(cast.FLOAT(z),sum(p_CG)),*(cast.FLOAT(z),sum(p_CG))),*(pp_CG,-(*(z,z),trust_delta_sq))) +::STMT +MATRIX:45_CVars,45_CFreqs +FLOAT:int43 +LITERAL_FLOAT:1000.0 +/(sum(*(-(45_CFreqs,int43),45_CVars)),-(1000.0,nrow(45_CFreqs))) +::STMT +MATRIX:parsertemp555613,X,Xc,parsertemp555606,parsertemp555615 +LITERAL_FLOAT:1.0 +/(/(%*%(t(Xc),-(X,parsertemp555606)),-(nrow(X),1.0)),%*%(t(sqrt(parsertemp555613)),sqrt(parsertemp555615))) +::STMT +MATRIX:Bxu,Bxd +LITERAL_FLOAT:2.0 +*(2.0,+(Bxd,Bxu)) +::STMT +MATRIX:lambda,beta +LITERAL_FLOAT:0.0 ++(0.0,*(lambda,beta)) +::STMT +FLOAT:parsertemp557360,parsertemp557354,parsertemp557356,parsertemp557358,prob_true,prob_false ++(/(*(prob_true,parsertemp557354),parsertemp557356),/(*(prob_false,parsertemp557358),parsertemp557360)) +::STMT +FLOAT:parsertemp557360,parsertemp557354,parsertemp557356,weight,parsertemp557358,prob_true,prob_false +LITERAL_FLOAT:-1.0 +*(*(-1.0,weight),+(/(*(prob_true,parsertemp557354),parsertemp557356),/(*(prob_false,parsertemp557358),parsertemp557360))) +::STMT +FLOAT:F1 +LITERAL_FLOAT:2.0 +*(*(F1,2.0),2.0) +::STMT +FLOAT:p,P,Q +LITERAL_FLOAT:1.0 ++(+(+(1.0,p),P),Q) +::STMT +MATRIX:scale_X,X +LITERAL_FLOAT:2.0 +%*%(^(X,2.0),^(scale_X,2.0)) +::STMT +MATRIX:ts +FLOAT:q +-(q,%*%(ts,ts)) +::STMT +FLOAT:s +LITERAL_FLOAT:1.0,4.0 +/(4.0,+(s,1.0)) +::STMT +MATRIX:parsertemp410978,W,H +/(*(H,t(parsertemp410978)),t(colSums(W))) +::STMT +MATRIX:classes +LITERAL_FLOAT:0.30000000000000004 +*(cast.FLOAT(classes),0.30000000000000004) +::STMT +MATRIX:g_reg,p_CG +FLOAT:q_CG,z,int78,pq_CG,pp_CG,parsertemp170107,parsertemp170091 +*(+(+(*(parsertemp170107,pq_CG),*(z,q_CG)),sum(*(g_reg,p_CG))),/(+(*(z,int78),sqrt(parsertemp170091)),pp_CG)) +::STMT +MATRIX:s,d,alpha +FLOAT:parsertemp44004 +%*%(t(+(s,*(parsertemp44004,d))),+(s,*(cast.FLOAT(alpha),d))) +::STMT +MATRIX:fP +FLOAT:max_values +^(ncol(fP),max_values) +::STMT +MATRIX:e,X,tS +FLOAT:l +t(%*%(t(e),==(%*%(X,tS),l))) +::STMT +MATRIX:parsertemp22683,id +-(==(id,t(id)),diag(diag(==(id,parsertemp22683)))) +::STMT +MATRIX:g +FLOAT:lambda,parsertemp169913 +LITERAL_FLOAT:2.0 +*(sum(^(+(g,lambda),2.0)),parsertemp169913) +::STMT +MATRIX:dout,out +LITERAL_FLOAT:1.0,2.0 +*(-(1.0,^(out,2.0)),dout) +::STMT +MATRIX:V,W,H,parsertemp10749 +LITERAL_FLOAT:1.0E-8 +*(W,/(%*%(V,t(H)),+(%*%(W,parsertemp10749),1.0E-8))) +::STMT +LITERAL_FLOAT:2.0,2003.0 +^(2003.0,2.0) +::STMT +MATRIX:out2,parsertemp146942,184_dscores +FLOAT:int386,beta1 +LITERAL_FLOAT:1.0 +*(-(1.0,beta1),colSums(*(>(out2,int386),%*%(184_dscores,parsertemp146942)))) +::STMT +MATRIX:X,Centering,ScaleFactor +t(/(-(X,Centering),ScaleFactor)) +::STMT +MATRIX:p +LITERAL_FLOAT:1.0E-8 +*(1.0E-8,p) +::STMT +MATRIX:2701_mask,2700_W,parsertemp459178,2699_dtemp,2702_X +LITERAL_FLOAT:0.0,0.5 +*(>(2702_X,0.0),*(/(2701_mask,0.5),%*%(-(2699_dtemp,parsertemp459178),t(2700_W)))) +::STMT +MATRIX:parsertemp171268,Y,linear_terms,parsertemp171271,vec1 +FLOAT:link_power,int612 +/(-(-(/(parsertemp171268,parsertemp171271),==(Y,int612)),*(*(Y,vec1),linear_terms)),link_power) +::STMT +MATRIX:lambda,shift_X,gXY,parsertemp171602,beta +LITERAL_FLOAT:2.0 +^(+(+(%*%(parsertemp171602,gXY),%*%(shift_X,gXY)),*(lambda,beta)),2.0) +::STMT +MATRIX:W1_rand +LITERAL_FLOAT:0.08333333333333333 +*(0.08333333333333333,W1_rand) +::STMT +MATRIX:P,Z,ZERODIAG,parsertemp220891 +FLOAT:int793 +-(P,/(*(/(int793,parsertemp220891),ZERODIAG),sum(*(Z,ZERODIAG)))) +::STMT +MATRIX:R,S,parsertemp40218 +FLOAT:level +-(R,rowSums(==(%*%(S,parsertemp40218),level))) +::STMT +MATRIX:A +LITERAL_FLOAT:1.0 +sum(>=(rowSums(abs(A)),1.0)) +::STMT +FLOAT:float320,parsertemp169813 +LITERAL_FLOAT:2.302585092994046,4.0 +*(2.302585092994046,-(4.0,round(-(parsertemp169813,float320)))) +::STMT +MATRIX:parsertemp393584,W4_rand,parsertemp393587 +LITERAL_FLOAT:0.08709382882250233 +t(%*%(*(0.08709382882250233,W4_rand),t(/(parsertemp393584,parsertemp393587)))) +::STMT +MATRIX:parsertemp414374,avg_X_cols +FLOAT:int635 +LITERAL_FLOAT:200.0,199.0 +/(-(t(colSums(parsertemp414374)),*(200.0,^(avg_X_cols,int635))),199.0) +::STMT +MATRIX:parsertemp10743,V,H,parsertemp10739 +%*%(V,t(*(H,/(parsertemp10739,parsertemp10743)))) +::STMT +MATRIX:D,beta +LITERAL_FLOAT:0.0 +*(-(0.0,D),beta) +::STMT +MATRIX:X_batch,2365_delta2,H1_prime,W2 +%*%(t(*(H1_prime,%*%(2365_delta2,W2))),X_batch) +::STMT +MATRIX:parsertemp409789,parsertemp409798,g0_2,g0_1 +FLOAT:int16 +cast.FLOAT(%*%(t(+(g0_1,g0_2)),+(-(int16,parsertemp409789),t(parsertemp409798)))) +::STMT +MATRIX:WM +FLOAT:parsertemp31268 +LITERAL_FLOAT:1.0 +/(*(parsertemp31268,sum(WM)),-(sum(WM),1.0)) +::STMT +FLOAT:arch_coef,var_coef,a0 +LITERAL_FLOAT:1.0 +/(a0,-(-(1.0,arch_coef),var_coef)) +::STMT +MATRIX:parsertemp220988,parsertemp220989,dY,Y +LITERAL_FLOAT:300.0,0.9 ++(Y,-(*(0.9,dY),*(300.0,-(parsertemp220988,parsertemp220989)))) +::STMT +MATRIX:p,q,r,parsertemp1947 +FLOAT:norm_r2,alpha +LITERAL_FLOAT:-1.0 ++(*(+(r,*(alpha,q)),-1.0),*(/(sum(parsertemp1947),norm_r2),p)) +::STMT +MATRIX:upd_W1 +LITERAL_FLOAT:0.9 +*(0.9,upd_W1) +::STMT +LITERAL_FLOAT:1.0E-30 +1.0E-30 +::STMT +MATRIX:p,q,parsertemp503394,Z +FLOAT:norm_r2 +*(/(norm_r2,cast.FLOAT(%*%(parsertemp503394,q))),%*%(Z,p)) +::STMT +FLOAT:n_group_cols +LITERAL_FLOAT:2.0 ++(2.0,n_group_cols) +::STMT +MATRIX:P,Phi,Theta +%*%(%*%(P,Theta),t(Phi)) +::STMT +MATRIX:2697_out,2697_b,parsertemp459149,parsertemp459147 +/(exp(-(+(parsertemp459147,2697_b),parsertemp459149)),rowSums(exp(-(2697_out,parsertemp459149)))) +::STMT +MATRIX:out +FLOAT:dd,step_sz,wd +-(+(wd,*(step_sz,dd)),sum(out)) +::STMT +MATRIX:A,scale_X,shift_X,parsertemp115874,X +t(+(%*%(diag(scale_X),%*%(parsertemp115874,X)),%*%(shift_X,A))) +::STMT +MATRIX:d +cast.FLOAT(%*%(t(d),d)) +::STMT +MATRIX:tmp +FLOAT:norm_r2_LS +/(*(cast.FLOAT(tmp),cast.FLOAT(tmp)),norm_r2_LS) +::STMT +MATRIX:r,s,grad +-(cast.FLOAT(%*%(t(s),grad)),cast.FLOAT(%*%(t(s),r))) +::STMT +FLOAT:o_init,N +LITERAL_FLOAT:-2.0 +exp(/(*(-2.0,o_init),N)) +::STMT +MATRIX:W2_rand +LITERAL_FLOAT:0.1092173494617922 +*(0.1092173494617922,W2_rand) +::STMT +LITERAL_FLOAT:1.0,-1.0E30 +INT:int11,M +*(-1.0E30,rand(M,int11,1.0,1.0)) +::STMT +MATRIX:the_exp,linear_terms,Y +FLOAT:int787 +*(*(exp(*(the_exp,int787)),exp(linear_terms)),rowSums(Y)) +::STMT +MATRIX:rowSums_X_sq +LITERAL_FLOAT:8.660254037844387 +/(8.660254037844387,max(sqrt(rowSums_X_sq))) +::STMT +MATRIX:parsertemp500607,w,parsertemp500610 +sum(*(-(*(parsertemp500607,parsertemp500610),w),-(*(parsertemp500607,parsertemp500610),w))) +::STMT +MATRIX:Xtest_dists +FLOAT:int953,eps +LITERAL_FLOAT:1.0 +>=(rowSums(*(<=(Xtest_dists,eps),<(int953,Xtest_dists))),1.0) +::STMT +MATRIX:ZtZ,C,Xm,parsertemp265709,Z,parsertemp265701 +%*%(%*%(Xm,%*%(C,parsertemp265701)),t(/(%*%(parsertemp265709,Z),sum(ZtZ)))) +::STMT +MATRIX:s,w +*(w,s) +::STMT +MATRIX:linear_terms +exp(linear_terms) +::STMT +MATRIX:269_Row_norm,parsertemp34343,X_block +LITERAL_FLOAT:0.3 +>(/(%*%(X_block,t(X_block)),%*%(sqrt(parsertemp34343),t(269_Row_norm))),0.3) +::STMT +FLOAT:int874,int128,width,parsertemp387147 +LITERAL_FLOAT:-1.0,2.0 +exp(/(*(-1.0,^(parsertemp387147,int128)),*(2.0,^(width,int874)))) +::STMT +MATRIX:parsertemp174552 +LITERAL_FLOAT:0.0 +sum(==(parsertemp174552,0.0)) +::STMT +FLOAT:s +LITERAL_FLOAT:1.0,5.0 +/(5.0,+(s,1.0)) +::STMT +MATRIX:parsertemp436668,X,parsertemp436672,bc_matrix +LITERAL_FLOAT:2.0 +-(*(bc_matrix,t(rowSums(parsertemp436668))),*(2.0,%*%(X,t(parsertemp436672)))) +::STMT +MATRIX:resp,X,weight +LITERAL_FLOAT:2.0 +/(%*%(t(resp),^(X,2.0)),t(weight)) +::STMT +MATRIX:B,C,D,E,parsertemp462474 +%*%(==(%*%(<=(parsertemp462474,B),C),D),E) +::STMT +MATRIX:X,permut +FLOAT:n +LITERAL_FLOAT:2.0 +/(colSums(^(%*%(permut,X),2.0)),n) +::STMT +MATRIX:parsertemp411208,parsertemp411210,parsertemp411199,X,parsertemp411201,parsertemp411217 +-(sum(%*%(/(parsertemp411208,parsertemp411210),/(parsertemp411199,parsertemp411201))),sum(*(X,parsertemp411217))) +::STMT +MATRIX:C,parsertemp174574 +FLOAT:numRows +/(sum(==(parsertemp174574,C)),numRows) +::STMT +LITERAL_FLOAT:1.0,2003.0 ++(2003.0,1.0) +::STMT +MATRIX:X_orig +FLOAT:parsertemp164950 ++(ncol(X_orig),parsertemp164950) +::STMT +MATRIX:parsertemp196005 +FLOAT:parsertemp191170,Wf +LITERAL_FLOAT:2.0 +*(parsertemp196005,sqrt(/(2.0,*(parsertemp191170,Wf)))) +::STMT +MATRIX:tmp,X +FLOAT:x +*(cast.FLOAT(tmp),/(-(x,cast.FLOAT(X)),-(cast.FLOAT(X),cast.FLOAT(X)))) +::STMT +MATRIX:W1_rand,parsertemp396312,X,parsertemp396302 +FLOAT:float297 +LITERAL_FLOAT:0.07808688094430302 +%*%(*(0.07808688094430302,W1_rand),t(/(-(X,parsertemp396302),+(parsertemp396312,float297)))) +::STMT +MATRIX:newbeta,lambda +LITERAL_FLOAT:2.0 +*(cast.FLOAT(lambda),^(cast.FLOAT(newbeta),2.0)) +::STMT +MATRIX:parsertemp393583,W4_rand +FLOAT:int268,int639 +LITERAL_FLOAT:0.08709382882250233 +%*%(*(0.08709382882250233,W4_rand),t(/(-(parsertemp393583,int639),+(parsertemp393583,int268)))) +::STMT +MATRIX:Nc +==(Nc,max(Nc)) +::STMT +MATRIX:parsertemp31030,parsertemp31032 +LITERAL_FLOAT:149.0,2.0,3352500.0 +/(^(/(-(parsertemp31030,parsertemp31032),149.0),2.0),3352500.0) +::STMT +MATRIX:parsertemp472404 +FLOAT:max_features,n +<=(parsertemp472404,/(^(n,max_features),n)) +::STMT +MATRIX:77_Y_row_norm,parsertemp16864 +FLOAT:float693 +LITERAL_FLOAT:1.0E-6 +t(+(sqrt(rowSums(parsertemp16864)),*(<(77_Y_row_norm,float693),1.0E-6))) +::STMT +MATRIX:g_reg,q_CG,p_CG,z +FLOAT:float720,277_tau_1,pq_CG ++(+(*(*(float720,277_tau_1),pq_CG),*(cast.FLOAT(z),cast.FLOAT(q_CG))),sum(*(g_reg,p_CG))) +::STMT +MATRIX:X +X +::STMT +MATRIX:p_CG +FLOAT:trust_delta_sq,z +sqrt(-(*(cast.FLOAT(p_CG),cast.FLOAT(p_CG)),*(cast.FLOAT(p_CG),-(z,trust_delta_sq)))) +::STMT +MATRIX:parsertemp389580,parsertemp389560,2365_delta3,W3 +FLOAT:int629 +LITERAL_FLOAT:1.0 +%*%(t(*(-(int629,parsertemp389580),%*%(2365_delta3,W3))),/(-(exp(parsertemp389560),1.0),+(exp(parsertemp389560),1.0))) +::STMT +MATRIX:parsertemp410245,parsertemp410248 +FLOAT:int652,float185 +LITERAL_FLOAT:0.6666666666666666 +min(^(/(-(int652,parsertemp410245),*(float185,parsertemp410248)),0.6666666666666666)) +::STMT +MATRIX:Q,R,parsertemp500360,parsertemp500308,parsertemp500359,parsertemp500300 +LITERAL_FLOAT:2.0 +-(+(%*%(rowSums(parsertemp500300),parsertemp500359),%*%(parsertemp500360,t(parsertemp500308))),*(2.0,%*%(R,t(Q)))) +::STMT +MATRIX:y_train,prediction +LITERAL_FLOAT:0.5 +==(y_train,>(prediction,0.5)) +::STMT +MATRIX:newbeta,lambda +LITERAL_FLOAT:2.0,0.5 +*(0.5,*(cast.FLOAT(lambda),^(cast.FLOAT(newbeta),2.0))) +::STMT +MATRIX:R +FLOAT:minSup +>=(R,minSup) +::STMT +MATRIX:w,ssX_p_CG,X +cast.FLOAT(%*%(t(X),*(w,%*%(X,ssX_p_CG)))) +::STMT +FLOAT:ratio +LITERAL_FLOAT:15.0 +*(15.0,ratio) +::STMT +MATRIX:G,minDist +FLOAT:int625 +LITERAL_FLOAT:-1.0 +^(+(G,*(!=(G,int625),minDist)),-1.0) +::STMT +LITERAL_FLOAT:3.0,2003.0 ++(2003.0,3.0) +::STMT +MATRIX:w,g +FLOAT:alpha,tau +-(abs(-(w,/(g,alpha))),/(tau,alpha)) +::STMT +MATRIX:parsertemp72334 +FLOAT:rows +cast.FLOAT(/(colSums(rowSums(parsertemp72334)),rows)) +::STMT +FLOAT:new_log_l,saturated_log_l +LITERAL_FLOAT:2.0 +*(2.0,-(saturated_log_l,new_log_l)) +::STMT +MATRIX:n_risk,n_event +/(n_event,*(n_risk,-(n_risk,n_event))) +::STMT +MATRIX:parsertemp283570,tpr,fpr,parsertemp283568 +LITERAL_FLOAT:2.0 ++(*(cast.FLOAT(tpr),cast.FLOAT(fpr)),sum(/(*(parsertemp283568,parsertemp283570),2.0))) +::STMT +MATRIX:X +LITERAL_FLOAT:-2.0,2.0 ++(*(-2.0,%*%(X,t(X))),rowSums(^(X,2.0))) +::STMT +MATRIX:prec,X,mu +*(-(%*%(X,prec),%*%(mu,prec)),-(%*%(X,prec),%*%(mu,prec))) +::STMT +MATRIX:mean,X,weight +-(%*%(t(X),X),%*%(*(t(mean),weight),mean)) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,3.0 ++(*(3.0,-(i,1.0)),1.0) +::STMT +MATRIX:X +LITERAL_FLOAT:0.0 +-(0.0,sum(X)) +::STMT +FLOAT:parsertemp85,int24,wt,parsertemp90 +LITERAL_FLOAT:1.0,2.0,4.0 +*(*(4.0,-(^(wt,int24),1.0)),^(sqrt(/(parsertemp85,parsertemp90)),2.0)) +::STMT +MATRIX:lambda,g,beta +LITERAL_FLOAT:-1.0 +*(+(g,*(lambda,beta)),-1.0) +::STMT +MATRIX:ubScores,fSizes +FLOAT:minsc +LITERAL_FLOAT:0.0 +&(fSizes,&(>(ubScores,minsc),>(ubScores,0.0))) +::STMT +LITERAL_FLOAT:1.0,2.0,4.0,2001.0 +*(4.0,-(^(2001.0,2.0),1.0)) +::STMT +LITERAL_FLOAT:-1.0 +INT:int571,n +diag(rand(n,int571,-1.0,-1.0)) +::STMT +LITERAL_FLOAT:1.0 +INT:int269,n +diag(rand(n,int269,1.0,1.0)) +::STMT +MATRIX:parsertemp71758,is_row_in_samples +FLOAT:sample_block_size +LITERAL_FLOAT:1.0,3.0 +-(+(*(sample_block_size,3.0),1.0),*(is_row_in_samples,parsertemp71758)) +::STMT +MATRIX:scale_X,w,parsertemp170066,X +*(cast.FLOAT(diag(scale_X)),cast.FLOAT(%*%(t(X),*(w,parsertemp170066)))) +::STMT +FLOAT:s +LITERAL_FLOAT:3.0 +^(3.0,s) +::STMT +MATRIX:m_iter_err_sum_squared,parsertemp379572,parsertemp379570,parsertemp379563 +FLOAT:i_process_item +LITERAL_FLOAT:1.0 +sqrt(/(+(-(parsertemp379570,parsertemp379572),+(parsertemp379563,m_iter_err_sum_squared)),-(i_process_item,1.0))) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,3.0 ++(*(3.0,-(i,1.0)),3.0) +::STMT +MATRIX:X,Y +FLOAT:x +LITERAL_FLOAT:1.0 +*(-(1.0,/(-(x,X),-(X,X))),Y) +::STMT +MATRIX:w,parsertemp500601 +FLOAT:alpha,tau +LITERAL_FLOAT:0.0 +>(-(abs(-(w,parsertemp500601)),/(tau,alpha)),0.0) +::STMT +MATRIX:parsertemp131967 +*(ncol(parsertemp131967),nrow(parsertemp131967)) +::STMT +MATRIX:parsertemp265718,parsertemp265715 +FLOAT:Xm +LITERAL_FLOAT:2.0,4000.0 +/(-(+(Xm,trace(parsertemp265715)),*(2.0,cast.FLOAT(parsertemp265718))),4000.0) +::STMT +MATRIX:m_iter_err_sum_squared,m_err +LITERAL_FLOAT:2.0 ++(colSums(^(m_err,2.0)),m_iter_err_sum_squared) +::STMT +MATRIX:obj,gs,parsertemp44066,parsertemp44078 +FLOAT:parsertemp44082 +LITERAL_FLOAT:-0.5 +cast.FLOAT(/(-(obj,+(parsertemp44078,parsertemp44082)),*(-0.5,-(gs,parsertemp44066)))) +::STMT +MATRIX:X +FLOAT:x +LITERAL_FLOAT:1.0 +-(1.0,/(-(x,cast.FLOAT(X)),-(cast.FLOAT(X),cast.FLOAT(X)))) +::STMT +MATRIX:parsertemp552530 +LITERAL_FLOAT:0.0 +INT:parsertemp552529,idx ++(rand(parsertemp552529,idx,0.0,0.0),t(parsertemp552530)) +::STMT +MATRIX:Q,ssX_V,X,parsertemp150463,P_1K +-(*(P_1K,%*%(X,ssX_V)),*(P_1K,%*%(rowSums(Q),parsertemp150463))) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,61.0 ++(*(-(i,1.0),61.0),61.0) +::STMT +MATRIX:prob,test_Y +FLOAT:threshold +LITERAL_FLOAT:0.0 +*(test_Y,==(>(prob,threshold),0.0)) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,61.0,34.0 ++(*(-(i,1.0),61.0),34.0) +::STMT +MATRIX:X +LITERAL_FLOAT:0.0 +!=(rowSums(!=(X,0.0)),0.0) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,2.0,3.0 ++(*(3.0,-(i,1.0)),2.0) +::STMT +MATRIX:ss +LITERAL_FLOAT:0.050000000000000044,1.0,20.0 +*(0.050000000000000044,-(/(20.0,ss),1.0)) +::STMT +MATRIX:subspace_idx,parsertemp109953 +LITERAL_FLOAT:1.0,42.0 +<(-(subspace_idx,*(parsertemp109953,42.0)),1.0) +::STMT +MATRIX:parsertemp2832 +sum(==(round(parsertemp2832),min(round(parsertemp2832)))) +::STMT +MATRIX:parsertemp24100 +FLOAT:bin_width +LITERAL_FLOAT:1.0,0.5 ++(round(-(/(parsertemp24100,bin_width),0.5)),1.0) +::STMT +MATRIX:p,Z +cast.FLOAT(%*%(t(p),%*%(Z,p))) +::STMT +MATRIX:By2,By1 +LITERAL_FLOAT:3.0 +*(3.0,+(By1,By2)) +::STMT +FLOAT:m2X +LITERAL_FLOAT:1.0,100000.0 +*(m2X,/(100000.0,-(100000.0,1.0))) +::STMT +MATRIX:C,parsertemp11014 +LITERAL_FLOAT:1000.0,100.0 +*(/(sum(==(parsertemp11014,C)),1000.0),100.0) +::STMT +MATRIX:parsertemp2832 +sum(==(round(parsertemp2832),max(round(parsertemp2832)))) +::STMT +LITERAL_FLOAT:3.5355339059327378 +3.5355339059327378 +::STMT +FLOAT:dist ++(cast.MATRIX(dist),t(cast.MATRIX(dist))) +::STMT +MATRIX:parsertemp409054,ctab +FLOAT:threshold +>(/(parsertemp409054,rowSums(ctab)),threshold) +::STMT +MATRIX:_sbcvar95,_sbcvar97 +FLOAT:221_my +LITERAL_FLOAT:0.0 ++(%*%(_sbcvar95,_sbcvar97),-(0.0,221_my)) +::STMT +LITERAL_FLOAT:0.0 +INT:int373,int452,int579,int618 +%*%(t(rand(int373,int618,0.0,0.0)),rand(int452,int579,0.0,0.0)) +::STMT +MATRIX:leaf_ids +FLOAT:boundary_left,step_size +&(!(<(leaf_ids,+(boundary_left,step_size))),<(leaf_ids,+(+(boundary_left,step_size),step_size))) +::STMT +MATRIX:parsertemp163357 +LITERAL_FLOAT:1.0 +t(/(1.0,parsertemp163357)) +::STMT +MATRIX:ss +LITERAL_FLOAT:1.0 +/(1.0,t(ss)) +::STMT +MATRIX:parsertemp149867,Y +FLOAT:int506 +LITERAL_FLOAT:100.0 +*(/(sum(==(parsertemp149867,int506)),nrow(Y)),100.0) +::STMT +MATRIX:W1_rand +LITERAL_FLOAT:0.06835859270246632 +*(0.06835859270246632,W1_rand) +::STMT +MATRIX:221_CFreqs,_sbcvar95,_sbcvar98 +FLOAT:int359 +LITERAL_FLOAT:1000.0 +/(sum(*(+(221_CFreqs,int359),%*%(_sbcvar95,_sbcvar98))),-(1000.0,nrow(_sbcvar95))) +::STMT +MATRIX:sv,Y,Xd,out +sum(*(*(*(out,sv),Y),Xd)) +::STMT +MATRIX:w,X,y +t(-(%*%(X,w),y)) +::STMT +MATRIX:output1,dataset +LITERAL_FLOAT:0.0,0.16 +==(<(abs(-(output1,dataset)),0.16),0.0) +::STMT +MATRIX:w,out +FLOAT:lambda +LITERAL_FLOAT:2.0,0.5 ++(*(0.5,sum(*(out,out))),*(/(lambda,2.0),sum(*(w,w)))) +::STMT +MATRIX:parsertemp447299 +LITERAL_FLOAT:1.0 +t(-(parsertemp447299,1.0)) +::STMT +MATRIX:w,X,y +%*%(t(X),-(%*%(X,w),y)) +::STMT +MATRIX:parsertemp170251,lt_pos_neg +FLOAT:int508 +LITERAL_FLOAT:2.0,0.5 +*(-(0.5,lt_pos_neg),exp(/(-(int508,parsertemp170251),2.0))) +::STMT +MATRIX:g_new,g_old +/(sum(*(g_new,g_new)),sum(*(g_old,g_old))) +::STMT +MATRIX:X +LITERAL_FLOAT:0.0,1.0 +==(<(X,1.0),0.0) +::STMT +MATRIX:sv,Xd +FLOAT:dd ++(dd,sum(*(*(Xd,sv),Xd))) +::STMT +MATRIX:parsertemp115729,parsertemp115724 +FLOAT:eAvg,n2 +LITERAL_FLOAT:0.050000000000000044,1.0,0.95 +-(*(0.95,-(/(parsertemp115724,eAvg),1.0)),*(0.050000000000000044,-(*(parsertemp115729,n2),1.0))) +::STMT +MATRIX:id +diag(==(id,t(id))) +::STMT +MATRIX:termination_bitmap,parsertemp222665 +FLOAT:int772,parsertemp222668,int577 +min(+(*(parsertemp222665,termination_bitmap),*(+(parsertemp222668,int772),-(int577,termination_bitmap)))) +::STMT +MATRIX:E,X +LITERAL_FLOAT:-1.0 +*(t(colSums(*(X,E))),-1.0) +::STMT +MATRIX:parsertemp150393 +LITERAL_FLOAT:0.0,0.1 +sum(==(<(abs(parsertemp150393),0.1),0.0)) +::STMT +MATRIX:means,parsertemp560511 +LITERAL_FLOAT:2.0 +^(rowSums(*(means,parsertemp560511)),2.0) +::STMT +MATRIX:finite_linear_terms +LITERAL_FLOAT:1.0,2.0 +-(1.0,/(exp(finite_linear_terms),2.0)) +::STMT +MATRIX:scale_lambda,parsertemp150455 +LITERAL_FLOAT:1.0E-5 +*(%*%(scale_lambda,parsertemp150455),1.0E-5) +::STMT +MATRIX:X +FLOAT:index +LITERAL_FLOAT:1.0,2.0 ++(*(index,-(ncol(X),1.0)),2.0) +::STMT +MATRIX:p_gaps_vector +LITERAL_FLOAT:0.0 +t(>(p_gaps_vector,0.0)) +::STMT +MATRIX:X +LITERAL_FLOAT:2.0 +-(ncol(X),2.0) +::STMT +MATRIX:F +FLOAT:q +LITERAL_FLOAT:1.0 +*(sum(F),-(q,1.0)) +::STMT +MATRIX:X,Centering +LITERAL_FLOAT:2.0 +colSums(^(-(X,Centering),2.0)) +::STMT +MATRIX:m_iter_err_sum_squared,parsertemp379560,m_err_mean,m_iter_err_sum,m_err +FLOAT:int71,int123,i_process_item,int826 ++(-(*(^(m_err_mean,int123),i_process_item),*(*(int826,m_err_mean),+(parsertemp379560,m_iter_err_sum))),+(colSums(^(m_err,int71)),m_iter_err_sum_squared)) +::STMT +MATRIX:d,alpha +*(cast.FLOAT(alpha),d) +::STMT +MATRIX:prevTK2,totalE,X2 +*(==(%*%(X2,t(prevTK2)),t(rowSums(prevTK2))),totalE) +::STMT +FLOAT:parsertemp166531 +LITERAL_FLOAT:10.0 +*(10.0,parsertemp166531) +::STMT +MATRIX:parsertemp170136 +FLOAT:278_sq_root_d,parsertemp170150,pq_CG +LITERAL_FLOAT:0.5 +*(*(0.5,/(+(parsertemp170150,278_sq_root_d),sum(parsertemp170136))),pq_CG) +::STMT +MATRIX:FXY +LITERAL_FLOAT:1.0 +-(ncol(FXY),1.0) +::STMT +MATRIX:G,authorities +/(%*%(t(G),%*%(G,authorities)),max(%*%(t(G),%*%(G,authorities)))) +::STMT +MATRIX:shift_X,ssX_newbeta,z,beta ++(ssX_newbeta,%*%(t(shift_X),+(beta,z))) +::STMT +MATRIX:_sbcvar96,_sbcvar95,_sbcvar98 +LITERAL_FLOAT:-1.0 +*(+(%*%(_sbcvar95,_sbcvar96),-1.0),%*%(_sbcvar95,_sbcvar98)) +::STMT +MATRIX:V,y +LITERAL_FLOAT:-1.0 +*(%*%(t(V),y),-1.0) +::STMT +FLOAT:idx +LITERAL_FLOAT:64.0 +-(64.0,idx) +::STMT +FLOAT:e,mu,epochs +LITERAL_FLOAT:0.999,1.0 +/(-(0.999,mu),-(+(1.0,epochs),e)) +::STMT +LITERAL_FLOAT:1.0E-6 +INT:int362,int452 +rand(int452,int362,1.0E-6,1.0E-6) +::STMT +FLOAT:parsertemp22485,parsertemp22452,parsertemp22453 +abs(/(parsertemp22485,sqrt(+(parsertemp22452,parsertemp22453)))) +::STMT +MATRIX:parsertemp43626 +LITERAL_FLOAT:-1.0,2.0 +INT:int411,int621 +rand(int411,int621,*(2.0,*(-1.0,sum(parsertemp43626))),*(2.0,*(-1.0,sum(parsertemp43626)))) +::STMT +MATRIX:W +LITERAL_FLOAT:1.0,2.0 +*(^(sum(round(W)),2.0),+(sum(round(W)),1.0)) +::STMT +FLOAT:int1,parsertemp86,int43,parsertemp87,int280,wt +sqrt(/(*(*(int280,wt),-(wt,int1)),*(*(parsertemp86,parsertemp87),+(wt,int43)))) +::STMT +MATRIX:classes +FLOAT:split ++(split,nrow(classes)) +::STMT +MATRIX:V,y +LITERAL_FLOAT:-1.0 +*(*(%*%(t(V),y),-1.0),*(%*%(t(V),y),-1.0)) +::STMT +FLOAT:n_group_cols +LITERAL_FLOAT:3.0 ++(3.0,n_group_cols) +::STMT +MATRIX:rowSums_X_sq +LITERAL_FLOAT:2.29128784747792 +/(2.29128784747792,max(sqrt(rowSums_X_sq))) +::STMT +MATRIX:means,Y_counts,Y,parsertemp560603 +FLOAT:parsertemp560604 +LITERAL_FLOAT:2.0 +^(-(-(Y,means),%*%(Y_counts,/(parsertemp560603,parsertemp560604))),2.0) +::STMT +MATRIX:2883_ctab +FLOAT:int703 +LITERAL_FLOAT:1.0 +sum(==(rowSums(!=(2883_ctab,int703)),1.0)) +::STMT +MATRIX:g_new,parsertemp468777,tmp,g_old +/(cast.FLOAT(%*%(t(g_new),-(parsertemp468777,tmp))),cast.FLOAT(%*%(t(g_old),g_old))) +::STMT +FLOAT:norm_r2,norm_r2_initial +sqrt(/(norm_r2,norm_r2_initial)) +::STMT +MATRIX:Y +FLOAT:parsertemp185166 +-(parsertemp185166,min(Y)) +::STMT +MATRIX:X,parsertemp386475 +FLOAT:int965 +sqrt(+(+(*(int965,parsertemp386475),X),t(X))) +::STMT +MATRIX:2701_mask,doutd3 +LITERAL_FLOAT:0.5 +*(/(2701_mask,0.5),doutd3) +::STMT +MATRIX:svUpBnd,R,svLowBnd +*(<=(R,cast.FLOAT(svUpBnd)),>(R,cast.FLOAT(svLowBnd))) +::STMT +MATRIX:P12,map +LITERAL_FLOAT:0.0 +rowSums(!=(%*%(map,P12),0.0)) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,51.0,64.0 ++(*(-(i,1.0),64.0),51.0) +::STMT +MATRIX:parsertemp1531,y +FLOAT:int824 +LITERAL_FLOAT:2.0 +sum(^(%*%(-(int824,parsertemp1531),y),2.0)) +::STMT +FLOAT:K +LITERAL_FLOAT:11.0 +*(11.0,K) +::STMT +FLOAT:C,K +LITERAL_FLOAT:1.0,2.0 +*(*(C,+(C,1.0)),^(K,2.0)) +::STMT +MATRIX:prediction,target +LITERAL_FLOAT:2.0,0.5 +*(0.5,rowSums(^(-(prediction,target),2.0))) +::STMT +MATRIX:os,y,o +FLOAT:int829 +LITERAL_FLOAT:1.0 ++(1.0,exp(*(*(y,int829),+(o,os)))) +::STMT +FLOAT:m2 +LITERAL_FLOAT:1.0005 +sqrt(*(1.0005,m2)) +::STMT +MATRIX:lambda,scale_X,p_CG,w,parsertemp170066,X ++(*(lambda,p_CG),*(cast.FLOAT(diag(scale_X)),%*%(t(X),*(w,parsertemp170066)))) +::STMT +MATRIX:parsertemp382670,X +LITERAL_FLOAT:0.0,2.0 +sum(*(!=(X,0.0),^(-(parsertemp382670,X),2.0))) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,60.0,64.0 ++(*(-(i,1.0),64.0),60.0) +::STMT +FLOAT:C,Hf,Wf +LITERAL_FLOAT:2.0 +/(2.0,*(*(C,Hf),Wf)) +::STMT +MATRIX:linear_terms,Y +FLOAT:parsertemp171226,link_power,parsertemp171223,int493 +/(*(^(linear_terms,-(parsertemp171226,int493)),-(Y,^(linear_terms,parsertemp171223))),link_power) +::STMT +FLOAT:int276,z,pp_CG,parsertemp170091 +LITERAL_FLOAT:0.5 +*(0.5,/(+(*(z,int276),sqrt(parsertemp170091)),pp_CG)) +::STMT +MATRIX:r,c,F +LITERAL_FLOAT:2.0 +^(-(F,/(%*%(r,c),sum(F))),2.0) +::STMT +FLOAT:float658,float239,float677,float221 +LITERAL_FLOAT:2.0 +INT:int110,int752,int269,int936 ++(sum(^(rand(int752,int936,float221,float239),2.0)),sum(^(rand(int269,int110,float677,float658),2.0))) +::STMT +MATRIX:R,B,parsertemp503780 +%*%(t(+(R,diag(parsertemp503780))),B) +::STMT +LITERAL_FLOAT:1.0,20.0 ++(20.0,1.0) +::STMT +MATRIX:X,mu,parsertemp183827,ScaleFactor +FLOAT:int264,N +LITERAL_FLOAT:1.0 +-(/(%*%(t(X),/(X,ScaleFactor)),-(N,1.0)),*(/(N,-(N,int264)),%*%(t(mu),/(parsertemp183827,N)))) +::STMT +LITERAL_FLOAT:1.0,7000.0 +-(7000.0,1.0) +::STMT +MATRIX:knn_index +FLOAT:s +LITERAL_FLOAT:100.0 +*(/(s,100.0),ncol(knn_index)) +::STMT +FLOAT:p,P,Q,q,int89 ++(+(+(+(int89,p),P),Q),q) +::STMT +FLOAT:2344_s_err_vars,2344_s_err_mean +LITERAL_FLOAT:-1.0,0.001 +/(-(*(0.001,-1.0),2344_s_err_mean),2344_s_err_vars) +::STMT +MATRIX:Y +FLOAT:class +LITERAL_FLOAT:1.0,2.0 +-(*(2.0,==(Y,class)),1.0) +::STMT +FLOAT:int520,int776,parsertemp459331,Win +LITERAL_FLOAT:2.0,64.0 +/(2.0,*(*(64.0,/(parsertemp459331,int776)),/(/(Win,int520),2.0))) +::STMT +MATRIX:W1_rand,stds,parsertemp400568 +LITERAL_FLOAT:0.08333333333333333 +t(%*%(*(0.08333333333333333,W1_rand),t(/(parsertemp400568,stds)))) +::STMT +MATRIX:p_CG +FLOAT:int158,parsertemp254749,z,parsertemp254772,int517 +*(parsertemp254772,/(-(*(z,int158),sqrt(parsertemp254749)),sum(^(p_CG,int517)))) +::STMT +MATRIX:ytest +FLOAT:mean_y_test,int293 +LITERAL_FLOAT:0.0,1.0,2.0 +/(-(^(cast.FLOAT(ytest),2.0),*(1.0,^(mean_y_test,int293))),0.0) +::STMT +MATRIX:X2 +FLOAT:minSup +>=(t(colSums(X2)),minSup) +::STMT +MATRIX:B,S +LITERAL_FLOAT:2.0 +^(+(B,S),2.0) +::STMT +MATRIX:parsertemp31105,parsertemp31107 +FLOAT:int559,int592 +LITERAL_FLOAT:1.0,2.0,2000.0 +/(^(/(-(parsertemp31105,parsertemp31107),-(int559,int592)),2.0),*(^(2000.0,2.0),-(2000.0,1.0))) +::STMT +MATRIX:D,parsertemp570375,classMeans +LITERAL_FLOAT:1.0,2.0 +*(/(1.0,2.0),%*%(%*%(-(D,classMeans),parsertemp570375),t(-(D,classMeans)))) +::STMT +FLOAT:link_power +LITERAL_FLOAT:0.0,1.0 +-(/(0.0,link_power),1.0) +::STMT +FLOAT:parsertemp496694,int349,parsertemp496686,n,a0 +LITERAL_FLOAT:1.0,2.0 +*(/(1.0,*(2.0,n)),+(parsertemp496694,/(^(parsertemp496686,int349),a0))) +::STMT +MATRIX:yhat +FLOAT:ytest,int615 +LITERAL_FLOAT:1.0,2.0 +-(^(-(cast.FLOAT(ytest),cast.FLOAT(yhat)),2.0),*(1.0,^(/(ytest,int615),2.0))) +::STMT +MATRIX:id +diag(==(id,cast.FLOAT(id))) +::STMT +MATRIX:parsertemp456742,X,y +LITERAL_FLOAT:0.0 +%*%(t(-(0.0,%*%(parsertemp456742,y))),%*%(t(X),y)) +::STMT +MATRIX:parsertemp410081,d_r_rev,parsertemp410090 +FLOAT:o +LITERAL_FLOAT:-1.0 +-(+(*(cast.FLOAT(parsertemp410081),-1.0),cast.FLOAT(%*%(d_r_rev,parsertemp410090))),o) +::STMT +MATRIX:parsertemp570396,classVars +*(diag(parsertemp570396),max(classVars)) +::STMT +MATRIX:subspace_idx,parsertemp72201 +FLOAT:subvector_size +LITERAL_FLOAT:1.0 +/(1.0,<(-(subspace_idx,*(parsertemp72201,subvector_size)),1.0)) +::STMT +MATRIX:252_X,252_K +*(cast.FLOAT(252_K),-(cast.FLOAT(252_X),cast.FLOAT(252_X))) +::STMT +MATRIX:is_natural_parameter_log_zero,Y +sum(*(is_natural_parameter_log_zero,abs(Y))) +::STMT +MATRIX:X_Train,X_Test ++(sum(X_Train),sum(X_Test)) +::STMT +MATRIX:G,authorities,hubs +LITERAL_FLOAT:2.0 +^(-(/(%*%(G,authorities),max(hubs)),hubs),2.0) +::STMT +FLOAT:parsertemp115827,sum_sq_y_test,n +LITERAL_FLOAT:1.0 +sqrt(/(-(sum_sq_y_test,*(n,parsertemp115827)),-(n,1.0))) +::STMT +FLOAT:link_power +LITERAL_FLOAT:2.0 +-(/(2.0,link_power),2.0) +::STMT +FLOAT:link_power +LITERAL_FLOAT:0.0,2.0 +-(/(0.0,link_power),2.0) +::STMT +MATRIX:images +LITERAL_FLOAT:2.0,255.0 +*(/(images,255.0),2.0) +::STMT +MATRIX:s,w +LITERAL_FLOAT:1.0 +*(1.0,cast.FLOAT(%*%(t(w),s))) +::STMT +FLOAT:m2X +LITERAL_FLOAT:1.000010000100001 +*(m2X,1.000010000100001) +::STMT +FLOAT:check_max,check_min +/(+(check_min,check_max),-(check_max,check_min)) +::STMT +MATRIX:_sbcvar14,_sbcvar13 +FLOAT:int143,parsertemp13703,int127 +LITERAL_FLOAT:999.0 +/(sum(*(-(_sbcvar13,int143),_sbcvar14)),*(999.0,/(*(parsertemp13703,int127),999.0))) +::STMT +FLOAT:wcss +LITERAL_FLOAT:1.0E-6 +*(1.0E-6,wcss) +::STMT +MATRIX:parsertemp31763,parsertemp31756,parsertemp31758,maxsc +FLOAT:minSup +LITERAL_FLOAT:0.0 +&(&(>=(t(parsertemp31756),minSup),>(t(parsertemp31763),0.0)),|(>(t(parsertemp31758),0.0),>(maxsc,0.0))) +::STMT +MATRIX:r,parsertemp44050 +sqrt(sum(*(-(r,parsertemp44050),-(r,parsertemp44050)))) +::STMT +FLOAT:deviance_nodisp +LITERAL_FLOAT:0.1 ++(deviance_nodisp,0.1) +::STMT +MATRIX:y +FLOAT:n +LITERAL_FLOAT:2.0 +/(sum(^(y,2.0)),n) +::STMT +LITERAL_FLOAT:0.0,1.0 +INT:int467,m +sum(abs(rand(m,int467,0.0,1.0))) +::STMT +MATRIX:parsertemp436667,precisions +LITERAL_FLOAT:1.0 +INT:parsertemp436666,int896 +*(rand(int896,parsertemp436666,1.0,1.0),t(rowSums(*(parsertemp436667,precisions)))) +::STMT +MATRIX:p,q,lambda +*(p,+(q,*(lambda,p))) +::STMT +MATRIX:svUpBnd,R,svLowBnd +diag(*(<=(R,cast.FLOAT(svUpBnd)),>(R,cast.FLOAT(svLowBnd)))) +::STMT +MATRIX:lambda,B,Grad +LITERAL_FLOAT:2.0 +sum(^(+(Grad,*(lambda,B)),2.0)) +::STMT +MATRIX:W4_rand +LITERAL_FLOAT:0.08720414403938946 +*(0.08720414403938946,W4_rand) +::STMT +MATRIX:parsertemp415351,parsertemp415356 +FLOAT:parsertemp415362,parsertemp415358,n +LITERAL_FLOAT:1.0 +-(1.0,/(-(sum(parsertemp415356),*(n,parsertemp415358)),-(sum(parsertemp415351),*(n,parsertemp415362)))) +::STMT +MATRIX:y_residual,ytest +FLOAT:int275,avg_res,mean_y_test,int699,int768,int838 +/(-(sum(^(y_residual,int838)),*($1:nrow(ytest),^(avg_res,int275))),-(sum(^(ytest,int768)),*($1,^(mean_y_test,int699)))) +::STMT +MATRIX:grad +FLOAT:int211 +LITERAL_FLOAT:2.0 +sqrt(sum(^(-(int211,grad),2.0))) +::STMT +MATRIX:_sbcvar92,parsertemp27721,220_r,220_c,220_E +FLOAT:int757 +LITERAL_FLOAT:2.0,1.0E-4 +/(^(-(_sbcvar92,+(parsertemp27721,220_E)),2.0),+(*(==(220_E,int757),1.0E-4),/(%*%(220_r,220_c),sum(_sbcvar92)))) +::STMT +MATRIX:s +LITERAL_FLOAT:1.0 +/(1.0,s) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:-1.0 +/(-1.0,linear_terms) +::STMT +MATRIX:184_probs,183_dpred,parsertemp146939 +FLOAT:beta1 +LITERAL_FLOAT:1.0 +*(-(1.0,beta1),colSums(-(*(183_dpred,184_probs),*(184_probs,parsertemp146939)))) +::STMT +MATRIX:r,Hd +FLOAT:parsertemp44049 +sum(*(-(r,*(parsertemp44049,Hd)),-(r,*(parsertemp44049,Hd)))) +::STMT +MATRIX:parsertemp222310 +FLOAT:parsertemp222313 +LITERAL_FLOAT:0.5 ++(/(parsertemp222310,parsertemp222313),0.5) +::STMT +MATRIX:parsertemp414372,X +FLOAT:int923,int309 +LITERAL_FLOAT:200.0,2.0 +-(t(colSums(^(X,int923))),*(200.0,^(/(parsertemp414372,int309),2.0))) +::STMT +FLOAT:k +LITERAL_FLOAT:1.0,4.0 +-(+(k,4.0),1.0) +::STMT +FLOAT:parsertemp477829,parsertemp477814,2814_K,2814_X,2814_Y,inp_x +LITERAL_FLOAT:1.0 ++(*(-(*(2814_K,2814_X),-(2814_Y,2814_Y)),-(1.0,/(parsertemp477814,2814_X))),*(+(*(parsertemp477829,2814_X),-(2814_Y,2814_Y)),/(-(inp_x,2814_X),-(2814_X,2814_X)))) +::STMT +FLOAT:output_values,log_odds,float34 +LITERAL_FLOAT:1.0,2.7182818284 ++(1.0,^(2.7182818284,+(log_odds,*(float34,output_values)))) +::STMT +FLOAT:run_index +LITERAL_FLOAT:24.0 +*(24.0,run_index) +::STMT +MATRIX:p,parsertemp1934,parsertemp1935 +FLOAT:eps +cast.FLOAT(%*%(t(p),+(%*%(parsertemp1934,parsertemp1935),*(eps,p)))) +::STMT +MATRIX:parsertemp43620,parsertemp43619 +FLOAT:float10 +LITERAL_FLOAT:1.0 +*(/(1.0,+(1.0,exp(parsertemp43619))),-(1.0,/(1.0,+(float10,parsertemp43620)))) +::STMT +MATRIX:X +FLOAT:N +/(colSums(X),N) +::STMT +MATRIX:parsertemp31024,parsertemp31022 +FLOAT:int627 +LITERAL_FLOAT:1.0,2.0,100.0 +^(/(-(colSums(parsertemp31022),*(int627,parsertemp31024)),-(100.0,1.0)),2.0) +::STMT +MATRIX:finite_linear_terms +FLOAT:int949 +LITERAL_FLOAT:0.0,2.0 +exp(/(-(0.0,^(finite_linear_terms,int949)),2.0)) +::STMT +MATRIX:os,y,o +LITERAL_FLOAT:-1.0 +*(*(y,-1.0),+(o,os)) +::STMT +MATRIX:g +LITERAL_FLOAT:2.0,0.01 +*(0.01,sum(^(g,2.0))) +::STMT +MATRIX:Y,parsertemp171319 +FLOAT:one_over_sqrt_two_pi,float696 +LITERAL_FLOAT:2.0 +*(*(exp(/(parsertemp171319,float696)),^(one_over_sqrt_two_pi,2.0)),rowSums(Y)) +::STMT +MATRIX:negSampleMeans +LITERAL_FLOAT:2.0,1500.0 +*(1500.0,^(negSampleMeans,2.0)) +::STMT +FLOAT:parsertemp169812 +LITERAL_FLOAT:2.302585092994046,0.5 +round(-(/(parsertemp169812,2.302585092994046),0.5)) +::STMT +MATRIX:P,X,Y +%*%(t(X),-(P,Y)) +::STMT +MATRIX:parsertemp285516 +FLOAT:pp,parsertemp285518,parsertemp285520 +LITERAL_FLOAT:-1.0 +/(+(*(sum(parsertemp285516),-1.0),sqrt(-(parsertemp285518,parsertemp285520))),pp) +::STMT +LITERAL_FLOAT:0.08692913816996169 +0.08692913816996169 +::STMT +MATRIX:WM,Y,CMeans +-(CMeans,/(sum(*(Y,WM)),sum(WM))) +::STMT +MATRIX:colSD,colMean +LITERAL_FLOAT:3.0 ++(colMean,*(3.0,colSD)) +::STMT +MATRIX:p,q,V,parsertemp1939 +FLOAT:norm_r2 +LITERAL_FLOAT:1.0E-8 +*(/(norm_r2,cast.FLOAT(%*%(parsertemp1939,q))),+(%*%(t(V),%*%(V,p)),*(1.0E-8,p))) +::STMT +MATRIX:subspace_idx,parsertemp109953 +LITERAL_FLOAT:42.0 +-(subspace_idx,*(parsertemp109953,42.0)) +::STMT +MATRIX:linear_terms +FLOAT:var_power +LITERAL_FLOAT:1.0 +*(^(exp(linear_terms),-(1.0,var_power)),exp(linear_terms)) +::STMT +MATRIX:scale_lambda +LITERAL_FLOAT:0.001 +*(scale_lambda,0.001) +::STMT +MATRIX:X +FLOAT:a0 +LITERAL_FLOAT:2.0 +/(^(cast.FLOAT(X),2.0),a0) +::STMT +MATRIX:parsertemp13711,_sbcvar14 +FLOAT:parsertemp13704,float583 +LITERAL_FLOAT:1.0,999.0 +-(1.0,/(sum(*(parsertemp13711,_sbcvar14)),*(999.0,/(parsertemp13704,float583)))) +::STMT +MATRIX:P,X,Y +LITERAL_FLOAT:2.0 +sum(^(%*%(t(X),-(P,Y)),2.0)) +::STMT +MATRIX:Y_counts,vars +FLOAT:dispersion +/(*(dispersion,colSums(vars)),sum(Y_counts)) +::STMT +MATRIX:termination_bitmap,parsertemp222665,parsertemp222670 +FLOAT:parsertemp222669 +==(*(parsertemp222665,termination_bitmap),min(+(*(parsertemp222665,termination_bitmap),*(parsertemp222669,parsertemp222670)))) +::STMT +MATRIX:parsertemp31026,parsertemp31033 +FLOAT:int571,parsertemp31034,int594,int55,parsertemp31027,int98,int225,int812,int171,int19 +LITERAL_FLOAT:2.0 ++(/(^(/(parsertemp31026,parsertemp31027),2.0),*(^(int571,int98),-(int225,int171))),/(^(/(parsertemp31033,parsertemp31034),2.0),*(^(int19,int594),-(int55,int812)))) +::STMT +MATRIX:X_train +LITERAL_FLOAT:256.0 +/(nrow(X_train),256.0) +::STMT +MATRIX:r,scale_X,shift_X ++(*(scale_X,r),*(cast.FLOAT(r),shift_X)) +::STMT +MATRIX:y_hat,b,R +-(-(b,%*%(R,y_hat)),y_hat) +::STMT +MATRIX:b,H,parsertemp410187,parsertemp410189 +%*%(%*%(t(b),-(+(H,parsertemp410187),diag(parsertemp410189))),b) +::STMT +MATRIX:U,V +FLOAT:int540,int757 +LITERAL_FLOAT:5.0E-7 +*(5.0E-7,+(sum(^(U,int540)),sum(^(V,int757)))) +::STMT +MATRIX:subspace_idx,parsertemp75105 +LITERAL_FLOAT:32.0 +-(subspace_idx,*(parsertemp75105,32.0)) +::STMT +MATRIX:X +LITERAL_FLOAT:-1.0 +*(*(t(colSums(X)),-1.0),-1.0) +::STMT +MATRIX:parsertemp410245,parsertemp410248 +FLOAT:int865,float295 +LITERAL_FLOAT:0.6666666666666666 +max(^(/(-(int865,parsertemp410245),*(float295,parsertemp410248)),0.6666666666666666)) +::STMT +MATRIX:parsertemp146931,184_dtemp,parsertemp146929,184_unnorm_probs,parsertemp146936 +colSums(-(*(*(parsertemp146929,parsertemp146931),/(184_unnorm_probs,parsertemp146936)),*(/(184_unnorm_probs,parsertemp146936),rowSums(184_dtemp)))) +::STMT +MATRIX:P +/(+(P,t(P)),sum(+(P,t(P)))) +::STMT +MATRIX:parsertemp265709,tmp,Z,XtZ +FLOAT:ZtZ_sum +*(tmp,%*%(t(/(XtZ,ZtZ_sum)),/(%*%(parsertemp265709,Z),sum(tmp)))) +::STMT +MATRIX:test_val +LITERAL_FLOAT:128.0 +/(nrow(test_val),128.0) +::STMT +MATRIX:s,w +%*%(t(+(w,s)),+(w,s)) +::STMT +MATRIX:Hdiff,betamax,beta +FLOAT:INF,int983,int39 +==(+(*(>=(Hdiff,int983),betamax),*(<(Hdiff,int39),beta)),INF) +::STMT +MATRIX:subspace_idx,parsertemp73653 +LITERAL_FLOAT:16.0 +-(subspace_idx,*(parsertemp73653,16.0)) +::STMT +MATRIX:subspace_idx,parsertemp107049 +LITERAL_FLOAT:7.0 +-(subspace_idx,*(parsertemp107049,7.0)) +::STMT +LITERAL_FLOAT:1.8378770664093453 +1.8378770664093453 +::STMT +MATRIX:s,parsertemp44016,d +FLOAT:parsertemp44015,delta2 +-(delta2,%*%(t(-(s,parsertemp44016)),-(s,*(parsertemp44015,d)))) +::STMT +MATRIX:parsertemp43626 +LITERAL_FLOAT:-1.0 +sum(*(parsertemp43626,-1.0)) +::STMT +MATRIX:r,d,parsertemp43999 +LITERAL_FLOAT:2.0 +/(sum(^(r,2.0)),cast.FLOAT(%*%(t(d),+(d,parsertemp43999)))) +::STMT +MATRIX:left,tmp,right +==(%*%(tmp,left),%*%(tmp,right)) +::STMT +FLOAT:Z_logl,dispersion +/(Z_logl,sqrt(dispersion)) +::STMT +FLOAT:int81,ytest,int874,parsertemp454076 +LITERAL_FLOAT:0.0 +sqrt(/(-(^(ytest,int874),*(int81,parsertemp454076)),0.0)) +::STMT +MATRIX:r_LS +FLOAT:norm_r2_LS,p_LS,parsertemp170552,lambda_LS +LITERAL_FLOAT:2.0 +^(+(cast.FLOAT(r_LS),*(/(norm_r2_LS,p_LS),+(parsertemp170552,lambda_LS))),2.0) +::STMT +MATRIX:subspace_idx,parsertemp72201 +LITERAL_FLOAT:8.0 +-(subspace_idx,*(parsertemp72201,8.0)) +::STMT +MATRIX:w_X,z_LS,X +/(nrow(X),*(cast.FLOAT(w_X),cast.FLOAT(z_LS))) +::STMT +MATRIX:col,more_than_ub,parsertemp24107,parsertemp24102,parsertemp24103 +FLOAT:int331,num_bins +LITERAL_FLOAT:1.0 ++(+(*(-(parsertemp24107,more_than_ub),+(parsertemp24103,int331)),*(>(col,num_bins),num_bins)),<(+(round(parsertemp24102),1.0),1.0)) +::STMT +MATRIX:parsertemp171315,Y,parsertemp171307,parsertemp171319 +FLOAT:float945,float368,float541 +*(*(exp(/(parsertemp171319,float368)),*(/(float945,parsertemp171307),+(float541,parsertemp171315))),rowSums(Y)) +::STMT +FLOAT:int92,n +LITERAL_FLOAT:1.0,2.0,0.02 +*(-(+(-(n,int92),1.0),2.0),0.02) +::STMT +MATRIX:subspace_idx,parsertemp75105 +LITERAL_FLOAT:1.0,32.0 +<(-(subspace_idx,*(parsertemp75105,32.0)),1.0) +::STMT +MATRIX:Y_prob,Y +*(rowSums(Y),-(*(Y,Y_prob),*(Y,Y_prob))) +::STMT +MATRIX:neighbors +-(neighbors,diag(diag(neighbors))) +::STMT +MATRIX:X +LITERAL_FLOAT:-1.0 +*(t(colSums(X)),-1.0) +::STMT +MATRIX:X,Y,K +FLOAT:int87,x +*(+(*(*(K,int87),-(X,X)),-(Y,Y)),/(-(x,X),-(X,X))) +::STMT +MATRIX:resp +LITERAL_FLOAT:2.22E-16 +t(+(colSums(resp),2.22E-16)) +::STMT +MATRIX:TopIxs,TopVals +LITERAL_FLOAT:0.0 +*(TopIxs,>(TopVals,0.0)) +::STMT +MATRIX:p_CG +FLOAT:parsertemp170089,z,pp_CG +LITERAL_FLOAT:-1.0 +-(*(*(cast.FLOAT(z),sum(p_CG)),-1.0),sqrt(-(*(z,z),*(pp_CG,parsertemp170089)))) +::STMT +MATRIX:r,X,y +FLOAT:int400 +cast.FLOAT(%*%(t(-(int400,r)),%*%(t(X),y))) +::STMT +MATRIX:X +LITERAL_FLOAT:1.0E-6 +*(1.0E-6,cast.FLOAT(%*%(t(X),X))) +::STMT +MATRIX:d,exp_Xb,X +rev(*(%*%(X,d),exp_Xb)) +::STMT +MATRIX:R +FLOAT:i8 +LITERAL_FLOAT:24.0 +-(ncol(R),*(24.0,i8)) +::STMT +MATRIX:W2_rand +LITERAL_FLOAT:0.21483446221182986 +*(0.21483446221182986,W2_rand) +::STMT +LITERAL_FLOAT:1.0E-12 +INT:int552,int757 +diag(rand(int552,int757,1.0E-12,1.0E-12)) +::STMT +MATRIX:C,parsertemp174574 +FLOAT:numRows +LITERAL_FLOAT:100.0 +*(/(sum(==(parsertemp174574,C)),numRows),100.0) +::STMT +FLOAT:a0 +LITERAL_FLOAT:1.0E-5 ++(a0,1.0E-5) +::STMT +MATRIX:parsertemp149307,parsertemp149305 +FLOAT:parsertemp149336,obj,parsertemp149333,parsertemp149340,float839 +LITERAL_FLOAT:-0.5 +/(-(obj,+(+(parsertemp149333,parsertemp149336),*(float839,parsertemp149340))),*(-0.5,-(sum(parsertemp149305),sum(parsertemp149307)))) +::STMT +MATRIX:Y,Xd,parsertemp2775,out +FLOAT:int664,int14 +*(*(*(-(int14,parsertemp2775),>(out,int664)),Y),Xd) +::STMT +MATRIX:2903_mask,dout,2904_X,2902_W +FLOAT:2903_p +LITERAL_FLOAT:0.0 +*(>(2904_X,0.0),*(/(2903_mask,2903_p),%*%(dout,t(2902_W)))) +::STMT +MATRIX:PRED,GT +/(sum(*(PRED,GT)),sum(PRED)) +::STMT +FLOAT:AIC_best_orig +LITERAL_FLOAT:0.001 +abs(*(0.001,AIC_best_orig)) +::STMT +MATRIX:s,d +FLOAT:norm_r2,alpha_deno +%*%(t(+(s,*(norm_r2,d))),+(s,*(/(norm_r2,alpha_deno),d))) +::STMT +MATRIX:resp,parsertemp443532,X,weight +LITERAL_FLOAT:2.22E-16 +*(t(/(%*%(parsertemp443532,X),t(weight))),+(colSums(resp),2.22E-16)) +::STMT +FLOAT:x +LITERAL_FLOAT:-1.0 +exp(*(x,-1.0)) +::STMT +FLOAT:num_records +LITERAL_FLOAT:3840.0 +/(3840.0,num_records) +::STMT +MATRIX:R,dssm +FLOAT:2_n +LITERAL_FLOAT:1.0 +-(/(2_n,-(R,dssm)),1.0) +::STMT +MATRIX:w,wnew +FLOAT:sigma,alpha +LITERAL_FLOAT:0.5 +*(*(*(0.5,sigma),alpha),sum(*(-(wnew,w),-(wnew,w)))) +::STMT +MATRIX:r,parsertemp1936,parsertemp1937 +FLOAT:parsertemp1941,norm_r2 +LITERAL_FLOAT:0.0 +-(0.0,+(r,*(/(norm_r2,parsertemp1941),+(parsertemp1936,parsertemp1937)))) +::STMT +MATRIX:qLow,length,qUp +LITERAL_FLOAT:2.0 +>=(rowSums(|(<(length,qLow),>(length,qUp))),2.0) +::STMT +MATRIX:_sbcvar11 +LITERAL_FLOAT:1000.0 +-(_sbcvar11,/(%*%(rowSums(_sbcvar11),colSums(_sbcvar11)),1000.0)) +::STMT +MATRIX:minD,D +rowSums(<=(D,minD)) +::STMT +LITERAL_FLOAT:0.99 +0.99 +::STMT +MATRIX:parsertemp500608,parsertemp500604,parsertemp500605 +FLOAT:lambda +LITERAL_FLOAT:0.0 +abs(*(*(parsertemp500604,-(parsertemp500605,lambda)),>(-(parsertemp500608,lambda),0.0))) +::STMT +MATRIX:parsertemp260769,w +FLOAT:reg +LITERAL_FLOAT:2.0 +*(/(reg,2.0),sum(*(+(w,parsertemp260769),+(w,parsertemp260769)))) +::STMT +MATRIX:_sbcvar1708 +LITERAL_FLOAT:0.7 +*(_sbcvar1708,0.7) +::STMT +MATRIX:WM +FLOAT:m2X +LITERAL_FLOAT:1.0 +*(m2X,/(sum(WM),-(sum(WM),1.0))) +::STMT +MATRIX:tmp,g_old +/(cast.FLOAT(%*%(t(tmp),tmp)),cast.FLOAT(%*%(t(g_old),g_old))) +::STMT +MATRIX:parsertemp409789,parsertemp409798,parsertemp409788,parsertemp409797 +FLOAT:int843 +LITERAL_FLOAT:0.0 +%*%(t(+(-(int843,parsertemp409789),t(parsertemp409798))),+(-(0.0,t(parsertemp409788)),t(colSums(parsertemp409797)))) +::STMT +MATRIX:tmp_Xw,Y,parsertemp2775 +FLOAT:int711 +LITERAL_FLOAT:0.0,1.0 +*(*(-(1.0,*(Y,tmp_Xw)),>(-(int711,parsertemp2775),0.0)),Y) +::STMT +MATRIX:parsertemp389212 +LITERAL_FLOAT:2.0,1058.0 +^(/(parsertemp389212,1058.0),2.0) +::STMT +MATRIX:2903_mask,dout,X,2904_X,parsertemp555692 +FLOAT:2903_p +LITERAL_FLOAT:0.0 +%*%(t(X),*(>(2904_X,0.0),*(/(2903_mask,2903_p),%*%(dout,parsertemp555692)))) +::STMT +MATRIX:w,g +FLOAT:alpha +-(w,/(g,alpha)) +::STMT +MATRIX:P,lambda,X,Y,B_new +LITERAL_FLOAT:2.0 +^(+(%*%(t(X),-(P,Y)),*(lambda,B_new)),2.0) +::STMT +FLOAT:i +LITERAL_FLOAT:4.0 ++(4.0,i) +::STMT +MATRIX:cdf_min_distances,random_row +colSums(<(cdf_min_distances,*(random_row,cdf_min_distances))) +::STMT +FLOAT:deviance_nodisp,eps +LITERAL_FLOAT:0.1 +*(eps,+(deviance_nodisp,0.1)) +::STMT +MATRIX:X +LITERAL_FLOAT:1.0,2.0 ++(exp(*(2.0,X)),1.0) +::STMT +MATRIX:colSD,X,colMean +LITERAL_FLOAT:3.0 +<(X,-(colMean,*(3.0,colSD))) +::STMT +MATRIX:colSD,X,colMean +LITERAL_FLOAT:3.0 +>(X,+(colMean,*(3.0,colSD))) +::STMT +MATRIX:p_LS,parsertemp170551,X +FLOAT:lambda_LS ++(*(cast.FLOAT(%*%(parsertemp170551,X)),cast.FLOAT(p_LS)),*(lambda_LS,cast.FLOAT(p_LS))) +::STMT +MATRIX:parsertemp10744,parsertemp10746,V,W,H +LITERAL_FLOAT:1.0E-8 +/(%*%(V,t(*(H,parsertemp10744))),+(%*%(W,%*%(H,parsertemp10746)),1.0E-8)) +::STMT +MATRIX:W +round(W) +::STMT +MATRIX:X +FLOAT:threshold +*(>(X,threshold),X) +::STMT +MATRIX:mu +FLOAT:window_size,q +-(q,*(window_size,cast.FLOAT(*(mu,mu)))) +::STMT +FLOAT:log_ten,parsertemp169814 +LITERAL_FLOAT:4.0 +exp(*(log_ten,-(4.0,round(parsertemp169814)))) +::STMT +MATRIX:parsertemp393570,W3_rand +FLOAT:int397,int924 +LITERAL_FLOAT:0.128920512778062 +%*%(*(0.128920512778062,W3_rand),t(/(-(parsertemp393570,int397),+(parsertemp393570,int924)))) +::STMT +MATRIX:classFeatureCounts +LITERAL_FLOAT:1.0,750.0 +/(+(classFeatureCounts,1.0),+(rowSums(classFeatureCounts),*(750.0,1.0))) +::STMT +FLOAT:a,b,x +LITERAL_FLOAT:2.0 ++(*(a,^(x,2.0)),*(b,x)) +::STMT +MATRIX:Q,R,parsertemp500307 +FLOAT:int723 +LITERAL_FLOAT:2.0 +-(+(rowSums(^(R,int723)),t(rowSums(parsertemp500307))),*(2.0,%*%(R,t(Q)))) +::STMT +LITERAL_FLOAT:0.25 +0.25 +::STMT +MATRIX:parsertemp115857,X,avg_X_cols +FLOAT:int636 +LITERAL_FLOAT:1.0 +/(-(t(colSums(parsertemp115857)),*($1:nrow(X),^(avg_X_cols,int636))),-($1,1.0)) +::STMT +FLOAT:run_index +LITERAL_FLOAT:1.0,2.0 ++(*(2.0,-(run_index,1.0)),1.0) +::STMT +FLOAT:parsertemp181047,parsertemp181040 +LITERAL_FLOAT:1.0,8.0 +sqrt(*(8.0,-(1.0,/(parsertemp181040,parsertemp181047)))) +::STMT +MATRIX:g0_1,d_r_rev,parsertemp410116 ++(g0_1,t(colSums(*(parsertemp410116,d_r_rev)))) +::STMT +MATRIX:parsertemp411194,parsertemp411197,W,H,parsertemp411205,parsertemp411206 +%*%(/(*(W,%*%(parsertemp411205,parsertemp411206)),t(rowSums(H))),/(*(H,%*%(parsertemp411194,parsertemp411197)),t(colSums(W)))) +::STMT +MATRIX:X,y +FLOAT:int879,int649 +INT:int378,m +%*%(t(X),-(%*%(X,rand(m,int378,int649,int879)),y)) +::STMT +MATRIX:ss +FLOAT:130_n +LITERAL_FLOAT:1.0 +-(/(130_n,ss),1.0) +::STMT +MATRIX:ot,yt +LITERAL_FLOAT:0.0,100.0 +*(sum(>(*(yt,ot),0.0)),100.0) +::STMT +LITERAL_FLOAT:-0.5 +-0.5 +::STMT +LITERAL_FLOAT:0.5 +0.5 +::STMT +MATRIX:W +FLOAT:int197,int575,m3,var,wt +LITERAL_FLOAT:2.0,3.0 +/(*(^(sum(W),2.0),m3),*(*(-(wt,int197),-(wt,int575)),^(sqrt(var),3.0))) +::STMT +LITERAL_FLOAT:0.254829592 +0.254829592 +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:1.0 +-(1.0,exp(linear_terms)) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,128.0,4.0 ++(*(-(i,1.0),128.0),4.0) +::STMT +FLOAT:522_padh,522_Hin +LITERAL_FLOAT:1.0,2.0 +-(+(522_Hin,*(2.0,522_padh)),1.0) +::STMT +FLOAT:obj,objnew +abs(-(objnew,obj)) +::STMT +LITERAL_FLOAT:1.0,150.0 +-(150.0,1.0) +::STMT +LITERAL_FLOAT:0.75 +0.75 +::STMT +MATRIX:p,A,r,parsertemp477951 +FLOAT:norm_r2 ++(r,*(/(norm_r2,sum(parsertemp477951)),%*%(t(A),%*%(A,p)))) +::STMT +MATRIX:parsertemp285848,X +LITERAL_FLOAT:0.0 +%*%(t(-(0.0,t(parsertemp285848))),t(colSums(X))) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,128.0,3.0 ++(*(-(i,1.0),128.0),3.0) +::STMT +MATRIX:lambda,beta +LITERAL_FLOAT:0.0,2.0 +^(+(0.0,*(cast.FLOAT(lambda),cast.FLOAT(beta))),2.0) +::STMT +MATRIX:p,r,Z +FLOAT:parsertemp31794,norm_r2 +*(+(r,*(/(norm_r2,parsertemp31794),%*%(Z,p))),+(r,*(/(norm_r2,parsertemp31794),%*%(Z,p)))) +::STMT +LITERAL_FLOAT:0.0625 +0.0625 +::STMT +LITERAL_FLOAT:1.0002795638803466 +1.0002795638803466 +::STMT +FLOAT:int263,2690_Hin,int538 +LITERAL_FLOAT:2.0 +/(-(+(2690_Hin,*(int538,int263)),2.0),2.0) +::STMT +MATRIX:A,CVars,CFreqs +FLOAT:int972 +/(sum(*(-(CFreqs,int972),CVars)),-(nrow(A),nrow(CFreqs))) +::STMT +MATRIX:linear_terms +FLOAT:link_power,int964,int879 +LITERAL_FLOAT:-2.0,1.0 +/(^(linear_terms,+(-2.0,/(int964,link_power))),-(1.0,^(linear_terms,/(int879,link_power)))) +::STMT +MATRIX:ss,parsertemp31463 +FLOAT:eAvg,alpha,n +LITERAL_FLOAT:1.0 +-(*(alpha,-(/(parsertemp31463,eAvg),1.0)),*(-(1.0,alpha),-(/(n,ss),1.0))) +::STMT +LITERAL_FLOAT:1.0,0.8 ++(1.0,0.8) +::STMT +LITERAL_FLOAT:0.125 +0.125 +::STMT +MATRIX:X,Centering +LITERAL_FLOAT:2.0,1764.0 +/(colSums(^(-(X,Centering),2.0)),1764.0) +::STMT +MATRIX:intercept,X,beta +exp(+(%*%(X,beta),intercept)) +::STMT +MATRIX:A,present_domain_vals_mat,CFreqs,parsertemp27487 +FLOAT:int999 +/(sum(*(-(CFreqs,int999),%*%(present_domain_vals_mat,parsertemp27487))),-(nrow(A),nrow(present_domain_vals_mat))) +::STMT +FLOAT:int453,F1 +LITERAL_FLOAT:2.0 +*(*(*(*(F1,int453),2.0),2.0),2.0) +::STMT +MATRIX:R,parsertemp503780 +t(+(R,diag(parsertemp503780))) +::STMT +MATRIX:means,Y,vars +LITERAL_FLOAT:2.0 +/(^(-(Y,means),2.0),vars) +::STMT +MATRIX:X,parsertemp438796 +*(ncol(X),parsertemp438796) +::STMT +LITERAL_FLOAT:4.0 +4.0 +::STMT +FLOAT:n +LITERAL_FLOAT:2.0,4.0 ++(-(n,4.0),2.0) +::STMT +LITERAL_FLOAT:4.5 +4.5 +::STMT +FLOAT:start_x,i,s_cols +LITERAL_FLOAT:1.0 ++(*(-(i,1.0),s_cols),start_x) +::STMT +MATRIX:2014_cnI,parsertemp230385 +t(%*%(parsertemp230385,2014_cnI)) +::STMT +MATRIX:obj,objnew,gs +-(-(objnew,obj),gs) +::STMT +MATRIX:P,Y,dP +&(>(P,dP),Y) +::STMT +MATRIX:parsertemp44107,parsertemp44109,wnew +FLOAT:C +*(+(wnew,*(C,%*%(parsertemp44107,parsertemp44109))),+(wnew,*(C,%*%(parsertemp44107,parsertemp44109)))) +::STMT +MATRIX:G +sum(!=(rowSums(G),t(colSums(G)))) +::STMT +MATRIX:parsertemp171245,Y +LITERAL_FLOAT:1.0 +*(rowSums(Y),/(1.0,-(exp(parsertemp171245),1.0))) +::STMT +LITERAL_FLOAT:6.0 +6.0 +::STMT +MATRIX:parsertemp43626 +LITERAL_FLOAT:-1.0,2.0 +*(2.0,*(-1.0,sum(parsertemp43626))) +::STMT +LITERAL_FLOAT:5.0 +5.0 +::STMT +MATRIX:cumLeftHist,parsertemp131906,parsertemp132092,leftHist,outBucket +%*%(==(outBucket,%*%(parsertemp132092,t(parsertemp131906))),-(cumLeftHist,leftHist)) +::STMT +LITERAL_FLOAT:1.0E-9 +1.0E-9 +::STMT +LITERAL_FLOAT:2.515517 +2.515517 +::STMT +MATRIX:_sbcvar96,_sbcvar95,_sbcvar97 +FLOAT:221_my,int469 +LITERAL_FLOAT:2.0 +*(%*%(_sbcvar95,_sbcvar96),^(+(%*%(_sbcvar95,_sbcvar97),-(int469,221_my)),2.0)) +::STMT +FLOAT:n +LITERAL_FLOAT:1.0,4.0 ++(-(n,4.0),1.0) +::STMT +LITERAL_FLOAT:8.0 +8.0 +::STMT +MATRIX:prec,X,mu +-(%*%(X,prec),%*%(mu,prec)) +::STMT +LITERAL_FLOAT:9.0 +9.0 +::STMT +LITERAL_FLOAT:7.0 +7.0 +::STMT +MATRIX:M +FLOAT:parsertemp178174 ++(max(M),parsertemp178174) +::STMT +MATRIX:sample_rec_ids +FLOAT:num_records +LITERAL_FLOAT:1.0 ++(*(sample_rec_ids,<=(sample_rec_ids,num_records)),*(+(num_records,1.0),-(1.0,<=(sample_rec_ids,num_records)))) +::STMT +MATRIX:W,H +%*%(%*%(t(W),W),H) +::STMT +MATRIX:feature +LITERAL_FLOAT:1.0 ++(feature,-(1.0,min(feature))) +::STMT +MATRIX:p,ssX_p,shift_X ++(ssX_p,%*%(t(shift_X),p)) +::STMT +MATRIX:parsertemp27461,r,c,E,F +FLOAT:int686 +LITERAL_FLOAT:2.0,1.0E-4 +/(^(-(F,+(parsertemp27461,E)),2.0),+(*(==(E,int686),1.0E-4),/(%*%(r,c),sum(F)))) +::STMT +MATRIX:Q3,IQR +LITERAL_FLOAT:2.0 ++(Q3,*(2.0,IQR)) +::STMT +LITERAL_FLOAT:10.0 +10.0 +::STMT +MATRIX:linear_terms +FLOAT:link_power +LITERAL_FLOAT:-2.0,1.0 +^(linear_terms,+(-2.0,/(1.0,link_power))) +::STMT +LITERAL_FLOAT:1.0 +1.0 +::STMT +LITERAL_FLOAT:-1.0 +-1.0 +::STMT +FLOAT:parsertemp2 +cast.MATRIX(parsertemp2) +::STMT +LITERAL_FLOAT:-Infinity +-Infinity +::STMT +LITERAL_FLOAT:Infinity +Infinity +::STMT +MATRIX:W,parsertemp411110,X,H +LITERAL_FLOAT:1.0E-8 +*(W,/(%*%(X,t(H)),+(%*%(W,parsertemp411110),1.0E-8))) +::STMT +MATRIX:parsertemp459193,vW3,parsertemp459200,2703_W +FLOAT:lr,mu,float473 +-(*(mu,vW3),*(lr,+(%*%(parsertemp459200,parsertemp459193),*(float473,2703_W)))) +::STMT +MATRIX:pred +LITERAL_FLOAT:1.0E-10 ++(pred,1.0E-10) +::STMT +FLOAT:factor_up,parsertemp195892 +LITERAL_FLOAT:1.0,2.0 +-(-(*(2.0,factor_up),parsertemp195892),1.0) +::STMT +LITERAL_FLOAT:NaN +NaN +::STMT +LITERAL_FLOAT:1.5 +1.5 +::STMT +MATRIX:P1,P2,S +LITERAL_FLOAT:0.0 +!=(+(%*%(P1,S),%*%(P2,S)),0.0) +::STMT +MATRIX:parsertemp539203 +LITERAL_FLOAT:-1.0,2.0 +/(*(parsertemp539203,-1.0),2.0) +::STMT +MATRIX:parsertemp222703 +LITERAL_FLOAT:0.0,1.0 ++(rowSums(==(t(parsertemp222703),0.0)),1.0) +::STMT +MATRIX:U,row_nonzeros +FLOAT:reg +*(*(reg,U),row_nonzeros) +::STMT +MATRIX:2701_mask,2700_W,parsertemp459178,2699_dtemp,2702_X,2703_W +FLOAT:int377,float760 +%*%(*(*(>(2702_X,int377),/(2701_mask,float760)),%*%(-(2699_dtemp,parsertemp459178),t(2700_W))),t(2703_W)) +::STMT +LITERAL_FLOAT:2.0 +2.0 +::STMT +LITERAL_FLOAT:0.0 +0.0 +::STMT +LITERAL_FLOAT:-0.0 +-0.0 +::STMT +LITERAL_FLOAT:-2.0 +-2.0 +::STMT +MATRIX:parsertemp220911,g,Y +FLOAT:float687 +LITERAL_FLOAT:0.0 +-(+(Y,-(0.0,*(float687,g))),parsertemp220911) +::STMT +MATRIX:E,F +LITERAL_FLOAT:0.001 +<(-(E,F),0.001) +::STMT +MATRIX:RDMean,parsertemp265748 +LITERAL_FLOAT:2.0 +t(-(parsertemp265748,^(RDMean,2.0))) +::STMT +LITERAL_FLOAT:3.0 +3.0 +::STMT +MATRIX:svUpBnd,R,svLowBnd +*(<=(R,cast.FLOAT(svUpBnd)),>(R,cast.FLOAT(svLowBnd))) +::STMT +MATRIX:sts,d,parsertemp44021,parsertemp44023 +FLOAT:delta2 +sqrt(+(*(%*%(parsertemp44021,d),%*%(parsertemp44021,d)),*(%*%(parsertemp44023,d),-(delta2,sts)))) +::STMT +MATRIX:t,parsertemp32834,parsertemp32843,X,parsertemp32837,parsertemp32827,parsertemp32824,parsertemp32846 +FLOAT:int882,x +LITERAL_FLOAT:1.0 +*(*(/(-(x,X),-(X,X)),-(1.0,/(parsertemp32824,parsertemp32827))),+(*(-(parsertemp32834,parsertemp32837),-(int882,t)),*(+(parsertemp32843,parsertemp32846),/(parsertemp32824,parsertemp32827)))) +::STMT +MATRIX:parsertemp145796,y +FLOAT:int717 +sum(rowSums(*(*(y,int717),parsertemp145796))) +::STMT +MATRIX:Y,vec1 +FLOAT:link_power +LITERAL_FLOAT:2.0 +/(*(rowSums(Y),vec1),^(link_power,2.0)) +::STMT +MATRIX:X_Xd_exp_Xb_rev_agg,select,D_r_rev +/(%*%(select,X_Xd_exp_Xb_rev_agg),D_r_rev) +::STMT +MATRIX:s,d,alpha +FLOAT:parsertemp44015 +%*%(t(-(s,*(parsertemp44015,d))),-(s,*(cast.FLOAT(alpha),d))) +::STMT +MATRIX:X +FLOAT:int681,epsilon +<(sqrt(rowSums(^(X,int681))),epsilon) +::STMT +FLOAT:K +LITERAL_FLOAT:21.0 +*(21.0,K) +::STMT +MATRIX:neighbors,corePts,withinEps +LITERAL_FLOAT:0.0 +colSums(>(*(*(neighbors,corePts),withinEps),0.0)) +::STMT +MATRIX:y_corr,parsertemp171002 +FLOAT:int375 +LITERAL_FLOAT:0.0,1.0 +-(parsertemp171002,/(==(y_corr,0.0),-(1.0,==(y_corr,int375)))) +::STMT +MATRIX:W +FLOAT:float615,m2,wt +/(sqrt(/(*(m2,wt),-(wt,float615))),sqrt(sum(round(W)))) +::STMT +MATRIX:X +FLOAT:index +LITERAL_FLOAT:1.0 +*(index,-(ncol(X),1.0)) +::STMT +MATRIX:parsertemp472326,parsertemp472314 +-(nrow(parsertemp472314),cast.FLOAT(parsertemp472326)) +::STMT +MATRIX:b,parsertemp410078,sb +LITERAL_FLOAT:-1.0 +*(cast.FLOAT(%*%(colSums(parsertemp410078),+(b,sb))),-1.0) +::STMT +MATRIX:parsertemp24102,parsertemp24103 +FLOAT:num_bins,int935 +LITERAL_FLOAT:1.0 +-(-(1.0,<(+(parsertemp24103,int935),1.0)),>(+(round(parsertemp24102),1.0),num_bins)) +::STMT +LITERAL_FLOAT:10.0,-8.0 +^(10.0,-8.0) +::STMT +MATRIX:2792_M2 +LITERAL_FLOAT:0.0 +|(!=(2792_M2,0.0),!=(2792_M2,0.0)) +::STMT +LITERAL_FLOAT:10.0,-10.0 +^(10.0,-10.0) +::STMT +LITERAL_FLOAT:-12.0,10.0 +^(10.0,-12.0) +::STMT +MATRIX:minD,D,parsertemp222603,parsertemp222600 +t(/(<=(+(parsertemp222600,parsertemp222603),minD),rowSums(<=(D,minD)))) +::STMT +MATRIX:parsertemp222703 +LITERAL_FLOAT:0.0 +rowSums(==(t(parsertemp222703),0.0)) +::STMT +FLOAT:num_func_invoc +LITERAL_FLOAT:1.0,5.0 +-(+(num_func_invoc,5.0),1.0) +::STMT +MATRIX:ss_res_Y,var_tot_Y +FLOAT:df_ss_res_Y +/(/(ss_res_Y,df_ss_res_Y),var_tot_Y) +::STMT +MATRIX:M +LITERAL_FLOAT:0.0,2.0 +&(>(rowSums(M),0.0),<(rowSums(M),2.0)) +::STMT +MATRIX:X,permut +FLOAT:n +-(%*%(permut,X),/(colSums(%*%(permut,X)),n)) +::STMT +MATRIX:CMeans,CFreqs +FLOAT:my +LITERAL_FLOAT:2.0 +sum(*(CFreqs,^(-(CMeans,my),2.0))) +::STMT +MATRIX:B +LITERAL_FLOAT:8.0 +/(nrow(B),8.0) +::STMT +LITERAL_FLOAT:0.0873148795050037 +0.0873148795050037 +::STMT +MATRIX:W +LITERAL_FLOAT:1.0,2.0,3.0 +*(*(-(sum(W),2.0),+(sum(W),1.0)),+(sum(round(W)),3.0)) +::STMT +MATRIX:parsertemp414371 +LITERAL_FLOAT:200.0,2.0 +*(200.0,^(/(t(parsertemp414371),200.0),2.0)) +::STMT +MATRIX:X +FLOAT:x +sum(>=(X,x)) +::STMT +MATRIX:border,parsertemp386448,parsertemp386459,parsertemp386449,parsertemp386460,withinEps +FLOAT:int478,int316 +LITERAL_FLOAT:0.0 ++(*(>(*(parsertemp386448,withinEps),0.0),==(-(border,parsertemp386459),0.0)),t(*(>(parsertemp386449,int478),==(parsertemp386460,int316)))) +::STMT +LITERAL_FLOAT:10.0,-30.0 +^(10.0,-30.0) +::STMT +LITERAL_FLOAT:10.0,30.0 +^(10.0,30.0) +::STMT +MATRIX:parsertemp191275,parsertemp191273 +FLOAT:397_C ++(parsertemp191273,*(397_C,t(parsertemp191275))) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:1.0,2.0 ++(1.0,^(linear_terms,2.0)) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,128.0 ++(*(-(i,1.0),128.0),128.0) +::STMT +MATRIX:B +LITERAL_FLOAT:4.0 +/(nrow(B),4.0) +::STMT +MATRIX:237_present_domain_vals_mat,parsertemp29514,237_CFreqs +FLOAT:int194 +LITERAL_FLOAT:10000.0 +/(sum(*(-(237_CFreqs,int194),%*%(237_present_domain_vals_mat,parsertemp29514))),-(10000.0,nrow(237_present_domain_vals_mat))) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,128.0,84.0 ++(*(-(i,1.0),128.0),84.0) +::STMT +MATRIX:S,parsertemp175056 +rowSums(exp(-(S,parsertemp175056))) +::STMT +MATRIX:dout +LITERAL_FLOAT:0.01 +*(0.01,dout) +::STMT +MATRIX:parsertemp122291,parsertemp122288 +LITERAL_FLOAT:0.0,4.0 +sum(|(<(t(parsertemp122288),4.0),==(t(parsertemp122291),0.0))) +::STMT +MATRIX:B +LITERAL_FLOAT:2.0 +/(nrow(B),2.0) +::STMT +MATRIX:Bx,Yd,Yu +LITERAL_FLOAT:2.0 +/(-(Yu,Yd),^(Bx,2.0)) +::STMT +MATRIX:Q1,Q3,X,IQR +FLOAT:k +|(<(X,-(Q1,*(k,IQR))),>(X,+(Q3,*(k,IQR)))) +::STMT +LITERAL_FLOAT:0.08681986202598489 +0.08681986202598489 +::STMT +FLOAT:i,k +LITERAL_FLOAT:1.0 +cast.MATRIX(-(+(i,k),1.0)) +::STMT +MATRIX:parsertemp2832 +sum(==(round(parsertemp2832),min(round(parsertemp2832)))) +::STMT +MATRIX:w +LITERAL_FLOAT:0.5 +*(0.5,%*%(t(w),w)) +::STMT +LITERAL_FLOAT:1.0,10.0 ++(10.0,1.0) +::STMT +MATRIX:vW1,dW,parsertemp459256 +FLOAT:lr,mu,float518 +-(*(mu,vW1),*(lr,+(dW,*(float518,parsertemp459256)))) +::STMT +MATRIX:p_LS +FLOAT:norm_r2_LS,parsertemp170552,lambda_LS +*(/(norm_r2_LS,*(cast.FLOAT(p_LS),+(parsertemp170552,lambda_LS))),cast.FLOAT(p_LS)) +::STMT +FLOAT:m2 +LITERAL_FLOAT:3.0,1.0005 +^(sqrt(*(1.0005,m2)),3.0) +::STMT +MATRIX:X +LITERAL_FLOAT:1.0 +/(nrow(X),1.0) +::STMT +MATRIX:Q1,X,IQR +LITERAL_FLOAT:2.0 +<(X,-(Q1,*(2.0,IQR))) +::STMT +MATRIX:Q3,X,IQR +LITERAL_FLOAT:2.0 +>(X,+(Q3,*(2.0,IQR))) +::STMT +FLOAT:o_init +LITERAL_FLOAT:-2.0,50.0 +/(*(-2.0,o_init),50.0) +::STMT +FLOAT:m2 +LITERAL_FLOAT:4.0,1.0005 +^(sqrt(*(1.0005,m2)),4.0) +::STMT +FLOAT:std,float498,float46 +INT:int895,int207 +cast.MATRIX(*(cast.FLOAT(rand(int207,int895,float46,float498)),std)) +::STMT +FLOAT:parsertemp190484,parsertemp190485,FN,TN,FP +sqrt(*(*(*(parsertemp190484,parsertemp190485),+(TN,FP)),+(TN,FN))) +::STMT +MATRIX:parsertemp443530,resp,X +FLOAT:float889 +t(/(%*%(t(resp),X),t(+(parsertemp443530,float889)))) +::STMT +MATRIX:W,H +FLOAT:Eps ++(%*%(%*%(t(W),W),H),Eps) +::STMT +MATRIX:mean,parsertemp437236,parsertemp437235,X,weight,parsertemp437241 +FLOAT:int326 +LITERAL_FLOAT:2.0 ++(-(/(%*%(parsertemp437235,parsertemp437236),t(weight)),*(2.0,^(mean,int326))),/(*(mean,%*%(parsertemp437241,X)),t(weight))) +::STMT +MATRIX:s,d +FLOAT:parsertemp44015 +%*%(t(-(s,*(parsertemp44015,d))),d) +::STMT +MATRIX:parsertemp410987,parsertemp410978,W,H +%*%(/(*(W,parsertemp410987),t(rowSums(H))),/(*(H,t(parsertemp410978)),t(colSums(W)))) +::STMT +FLOAT:_sbcvar1799 +LITERAL_FLOAT:9.0 +-(9.0,_sbcvar1799) +::STMT +FLOAT:i +LITERAL_FLOAT:9.0 ++(i,9.0) +::STMT +MATRIX:parsertemp460644 +FLOAT:float790,2715_D +LITERAL_FLOAT:2.0 +/(*(parsertemp460644,sqrt(/(float790,2715_D))),sqrt(2.0)) +::STMT +LITERAL_FLOAT:9.999999999 +9.999999999 +::STMT +MATRIX:_sbcvar11,43_r,43_c,43_E +LITERAL_FLOAT:2.0,1000.0 +sum(/(^(-(_sbcvar11,43_E),2.0),/(%*%(43_r,43_c),1000.0))) +::STMT +MATRIX:Xd,Xw +FLOAT:step_sz ++(Xw,*(step_sz,Xd)) +::STMT +MATRIX:is_zero_y_corr,is_one_y_corr +FLOAT:parsertemp171116 +*(parsertemp171116,+(is_zero_y_corr,is_one_y_corr)) +::STMT +MATRIX:_sbcvar332,parsertemp42290 +FLOAT:float884,meanX +LITERAL_FLOAT:9999.0 +t(*(/(_sbcvar332,9999.0),-(+(parsertemp42290,float884),meanX))) +::STMT +FLOAT:KM_offset +LITERAL_FLOAT:7.0 ++(KM_offset,7.0) +::STMT +MATRIX:R,3_ss,dsep +/(+(R,dsep),3_ss) +::STMT +MATRIX:Y,parsertemp2796,Xw +LITERAL_FLOAT:0.0,1.0 +*(>(-(1.0,*(Y,Xw)),0.0),-(1.0,*(Y,+(Xw,parsertemp2796)))) +::STMT +FLOAT:i +LITERAL_FLOAT:12.0 ++(i,12.0) +::STMT +FLOAT:i +LITERAL_FLOAT:192.0 ++(192.0,i) +::STMT +FLOAT:i +LITERAL_FLOAT:11.0 ++(i,11.0) +::STMT +MATRIX:parsertemp500607,X,y,parsertemp500610 +*(-(%*%(X,*(parsertemp500607,parsertemp500610)),y),-(%*%(X,*(parsertemp500607,parsertemp500610)),y)) +::STMT +MATRIX:parsertemp436669,prec_chol,X,parsertemp436673 +FLOAT:int93,int32,int745 +LITERAL_FLOAT:2.0 +INT:parsertemp436666,int445 ++(-(*(rand(int445,parsertemp436666,int32,int93),t(parsertemp436669)),*(2.0,%*%(X,parsertemp436673))),%*%(^(X,2.0),t(^(prec_chol,int745)))) +::STMT +FLOAT:502_strideh,502_padh,502_Hin +LITERAL_FLOAT:1.0,2.0 +-(*(502_strideh,-(502_Hin,1.0)),*(2.0,502_padh)) +::STMT +FLOAT:k +LITERAL_FLOAT:4.0 ++(k,4.0) +::STMT +MATRIX:y_corr +LITERAL_FLOAT:0.0,1.0 +-(1.0,==(y_corr,0.0)) +::STMT +MATRIX:ss +FLOAT:130_n,130_alpha +LITERAL_FLOAT:1.0 +*(-(1.0,130_alpha),-(/(130_n,ss),1.0)) +::STMT +MATRIX:parsertemp44080,obj,parsertemp44076,wnew +FLOAT:C +LITERAL_FLOAT:0.5 +-(obj,+(*(0.5,%*%(parsertemp44076,wnew)),*(C,sum(parsertemp44080)))) +::STMT +MATRIX:classCounts +FLOAT:numRows +/(classCounts,numRows) +::STMT +MATRIX:parsertemp500604,w,parsertemp500601 +FLOAT:alpha,tau +*(parsertemp500604,-(abs(-(w,parsertemp500601)),/(tau,alpha))) +::STMT +FLOAT:KM_offset +LITERAL_FLOAT:6.0 ++(KM_offset,6.0) +::STMT +MATRIX:mW2,dW2 +FLOAT:193_lr,parsertemp147034,193_beta1,int779,193_t +LITERAL_FLOAT:1.0 +*(/(*(193_lr,sqrt(parsertemp147034)),-(1.0,^(193_beta1,193_t))),+(*(193_beta1,mW2),*(-(int779,193_beta1),dW2))) +::STMT +MATRIX:parsertemp40482,X2,l +/(nrow(X2),t(colSums(==(parsertemp40482,l)))) +::STMT +MATRIX:parsertemp429910 +LITERAL_FLOAT:300.0,2.0 +*(300.0,^(/(t(parsertemp429910),300.0),2.0)) +::STMT +FLOAT:w_i +LITERAL_FLOAT:5.0 ++(w_i,5.0) +::STMT +MATRIX:parsertemp171246,Y +FLOAT:int23 +LITERAL_FLOAT:1.0 +-(Y,*(Y,/(1.0,-(parsertemp171246,int23)))) +::STMT +FLOAT:run_index +LITERAL_FLOAT:48.0 +*(48.0,run_index) +::STMT +MATRIX:weightMatrix +FLOAT:threshold +LITERAL_FLOAT:0.0 +&(<(weightMatrix,threshold),>(weightMatrix,0.0)) +::STMT +MATRIX:posSampleVariances,negSampleVariances +FLOAT:int21,int125,int667,int938 +LITERAL_FLOAT:3.42951E11,2.0,3.37275E9 +/(^(+(/(posSampleVariances,int21),/(negSampleVariances,int938)),2.0),+(/(^(posSampleVariances,int667),3.42951E11),/(^(negSampleVariances,int125),3.37275E9))) +::STMT +MATRIX:p_CG +FLOAT:parsertemp285794,parsertemp285796 +LITERAL_FLOAT:-1.0 +/(-(*(cast.FLOAT(p_CG),-1.0),sqrt(-(parsertemp285794,parsertemp285796))),cast.FLOAT(%*%(t(p_CG),p_CG))) +::STMT +FLOAT:norm_Grad_initial +LITERAL_FLOAT:1.0E-4 +*(1.0E-4,norm_Grad_initial) +::STMT +MATRIX:parsertemp498247,m_iter_err_sum +FLOAT:i_process_item +LITERAL_FLOAT:0.0,2.0 +^(/(-(0.0,-(parsertemp498247,m_iter_err_sum)),i_process_item),2.0) +::STMT +FLOAT:int200,parsertemp285740,p_CG,parsertemp285763,pp_CG +*(parsertemp285763,/(-(*(p_CG,int200),sqrt(parsertemp285740)),pp_CG)) +::STMT +MATRIX:classFeatureCounts +LITERAL_FLOAT:1.0,750.0 +/(+(classFeatureCounts,1.0),+(rowSums(classFeatureCounts),*(750.0,1.0))) +::STMT +MATRIX:P12,map +FLOAT:level +LITERAL_FLOAT:0.0 +==(rowSums(!=(%*%(map,P12),0.0)),level) +::STMT +MATRIX:W +FLOAT:m2,int91 +LITERAL_FLOAT:1.0,2.0,3.0 +*(*(*(3.0,^(m2,int91)),^(sum(W),2.0)),-(sum(round(W)),1.0)) +::STMT +MATRIX:W2_rand +LITERAL_FLOAT:0.128920512778062 +*(0.128920512778062,W2_rand) +::STMT +MATRIX:p,V +LITERAL_FLOAT:1.0E-8 +%*%(t(p),+(%*%(t(V),%*%(V,p)),*(1.0E-8,p))) +::STMT +MATRIX:X +LITERAL_FLOAT:2.0 +^(max(X),2.0) +::STMT +MATRIX:parsertemp31190,parsertemp31197 +FLOAT:parsertemp31191,parsertemp31198 +LITERAL_FLOAT:1500.0,7000.0 +sqrt(+(/(/(parsertemp31190,parsertemp31191),7000.0),/(/(parsertemp31197,parsertemp31198),1500.0))) +::STMT +LITERAL_FLOAT:0.007 +0.007 +::STMT +MATRIX:y_hat,b,R +*(-(-(b,%*%(R,y_hat)),y_hat),-(-(b,%*%(R,y_hat)),y_hat)) +::STMT +MATRIX:ytest +FLOAT:sum_y_test,n +LITERAL_FLOAT:2.0 +-(sum(^(ytest,2.0)),*(nrow(ytest),^(/(sum_y_test,n),2.0))) +::STMT +MATRIX:s,w +FLOAT:step_sz +*(+(w,*(step_sz,s)),+(w,*(step_sz,s))) +::STMT +MATRIX:dW,parsertemp459256 +FLOAT:lr +LITERAL_FLOAT:5.0E-4 +*(lr,+(dW,*(5.0E-4,parsertemp459256))) +::STMT +FLOAT:parsertemp40812,m2,int31,mu +/(sqrt(*(/(int31,parsertemp40812),m2)),mu) +::STMT +FLOAT:_sbcvar1783 +LITERAL_FLOAT:8.0 +-(8.0,_sbcvar1783) +::STMT +MATRIX:ss,se +FLOAT:parsertemp122358,int182 +LITERAL_FLOAT:1.0,0.95 +*(0.95,-(/(/(se,ss),/(parsertemp122358,int182)),1.0)) +::STMT +MATRIX:select,X_exp_Xb_rev_agg,D_r_rev,Xd_exp_Xb_rev_agg +LITERAL_FLOAT:2.0 +/(*(X_exp_Xb_rev_agg,%*%(select,Xd_exp_Xb_rev_agg)),^(D_r_rev,2.0)) +::STMT +MATRIX:Y_counts,parsertemp560508,parsertemp560522,ent1_vec +/(-(sum(rowSums(parsertemp560508)),sum(*(Y_counts,ent1_vec))),sqrt(sum(*(Y_counts,parsertemp560522)))) +::STMT +LITERAL_FLOAT:1.0,2000.0 +-(2000.0,1.0) +::STMT +FLOAT:lambda,beta +LITERAL_FLOAT:0.0 +sqrt(*(+(0.0,*(lambda,beta)),+(0.0,*(lambda,beta)))) +::STMT +MATRIX:g_Y,w +LITERAL_FLOAT:2.0 +/(^(g_Y,2.0),w) +::STMT +MATRIX:X +LITERAL_FLOAT:0.0 +INT:int47,int606 +%*%(X,rand(int606,int47,0.0,0.0)) +::STMT +MATRIX:A +abs(-(A,t(A))) +::STMT +MATRIX:Y +sum(==(Y,max(Y))) +::STMT +MATRIX:determinants +FLOAT:nFeats +LITERAL_FLOAT:3.141592653589793,2.0 +*(^(*(2.0,3.141592653589793),nFeats),determinants) +::STMT +LITERAL_FLOAT:44.73253849269008 +44.73253849269008 +::STMT +MATRIX:L,m +FLOAT:sum +/(-(m,sum),cast.FLOAT(L)) +::STMT +MATRIX:parsertemp260755,Xd +FLOAT:dd,step_sz,wd +*(-(+(wd,*(step_sz,dd)),sum(*(parsertemp260755,Xd))),-(+(wd,*(step_sz,dd)),sum(*(parsertemp260755,Xd)))) +::STMT +MATRIX:ss +LITERAL_FLOAT:40.0 +/(40.0,ss) +::STMT +MATRIX:prec_chol,mu +FLOAT:int750 +LITERAL_FLOAT:2.0 +t(*(rowSums(^(mu,int750)),^(prec_chol,2.0))) +::STMT +MATRIX:means,Y_counts,ones_ctg +LITERAL_FLOAT:1.0 +<(*(means,%*%(Y_counts,t(ones_ctg))),1.0) +::STMT +FLOAT:int18 +LITERAL_FLOAT:0.0 +INT:int193,m +abs(rand(m,int193,0.0,int18)) +::STMT +MATRIX:probs,scores,unnorm_probs,dprobs +-(*(dprobs,/(exp(scores),rowSums(unnorm_probs))),*(/(exp(scores),rowSums(unnorm_probs)),rowSums(*(dprobs,probs)))) +::STMT +LITERAL_FLOAT:3.0,2000.0 +-(2000.0,3.0) +::STMT +FLOAT:parsertemp230731 +LITERAL_FLOAT:2.0 ++(parsertemp230731,2.0) +::STMT +MATRIX:labels +LITERAL_FLOAT:1.0 ++(labels,-(1.0,min(labels))) +::STMT +MATRIX:tmp,leftIdx +LITERAL_FLOAT:0.0 +>(%*%(tmp,%*%(t(tmp),leftIdx)),0.0) +::STMT +MATRIX:t_gp,parsertemp560875,linear_terms,parsertemp560867 +FLOAT:int721,float396 +LITERAL_FLOAT:1.0,2.0,0.254829592 +*(*(/(1.0,+(float396,parsertemp560867)),+(0.254829592,*(t_gp,parsertemp560875))),-(*(2.0,>=(linear_terms,int721)),1.0)) +::STMT +FLOAT:parsertemp191177,strideh,Hin,Hf +LITERAL_FLOAT:1.0 ++(/(-(+(Hin,parsertemp191177),Hf),strideh),1.0) +::STMT +MATRIX:parsertemp539203 +FLOAT:int975 +LITERAL_FLOAT:2.0,0.6666666666666666 +max(^(/(*(parsertemp539203,int975),2.0),0.6666666666666666)) +::STMT +MATRIX:pred,y +LITERAL_FLOAT:1.0,-1.0,1.0E-10 +*(*(/(1.0,nrow(y)),*(y,-1.0)),/(1.0,+(pred,1.0E-10))) +::STMT +FLOAT:KM_offset +LITERAL_FLOAT:3.0 ++(KM_offset,3.0) +::STMT +MATRIX:parsertemp146972,parsertemp146970,W1,191_v +FLOAT:parsertemp146984,parsertemp146982,191_epsilon +-(W1,/(*(/(parsertemp146982,parsertemp146984),+(parsertemp146970,parsertemp146972)),+(sqrt(191_v),191_epsilon))) +::STMT +MATRIX:R,dssp,dsep +/(+(R,dsep),+(R,dssp)) +::STMT +MATRIX:e,X2 +LITERAL_FLOAT:0.0 +==(t(%*%(t(e),X2)),0.0) +::STMT +LITERAL_FLOAT:1.0E-6 +INT:int996,int118 +diag(rand(int996,int118,1.0E-6,1.0E-6)) +::STMT +FLOAT:parsertemp410218,parsertemp410219,N +LITERAL_FLOAT:-1.0 +exp(/(*(-(parsertemp410218,parsertemp410219),-1.0),N)) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0 ++(i,1.0) +::STMT +MATRIX:y_prob,elt +FLOAT:int410 +LITERAL_FLOAT:1.0,1.0E7 +*(-(1.0,==(+(int410,elt),1.0E7)),-(1.0,y_prob)) +::STMT +MATRIX:X +LITERAL_FLOAT:1.0E-7 +INT:int802,m ++(%*%(t(X),X),diag(rand(m,int802,1.0E-7,1.0E-7))) +::STMT +FLOAT:m2 +LITERAL_FLOAT:44.73253849269008,1.0005 +/(sqrt(*(1.0005,m2)),44.73253849269008) +::STMT +LITERAL_FLOAT:2.0,2000.0 +-(2000.0,2.0) +::STMT +LITERAL_FLOAT:3.42951E11 +3.42951E11 +::STMT +MATRIX:means,Y_counts,ones_ctg +LITERAL_FLOAT:5.0 +<(*(means,%*%(Y_counts,t(ones_ctg))),5.0) +::STMT +MATRIX:W2_rand +LITERAL_FLOAT:0.1651445647689541 +*(0.1651445647689541,W2_rand) +::STMT +MATRIX:X +LITERAL_FLOAT:0.0,1.0 +INT:int447,m +%*%(X,rand(m,int447,0.0,1.0)) +::STMT +MATRIX:Y +LITERAL_FLOAT:0.0,1.0 +!=(+(Y,1.0),0.0) +::STMT +MATRIX:parsertemp379565,m_iter_err_sum +FLOAT:i_process_item +LITERAL_FLOAT:-1.0,2.0 +^(/(*(-(parsertemp379565,m_iter_err_sum),-1.0),i_process_item),2.0) +::STMT +MATRIX:252_X +FLOAT:252_X,float360 +LITERAL_FLOAT:1.0,4.5 +*(/(-(4.5,cast.FLOAT(252_X)),-(cast.FLOAT(252_X),cast.FLOAT(252_X))),-(1.0,/(-(float360,252_X),-(252_X,252_X)))) +::STMT +MATRIX:parsertemp1517,parsertemp1515 +FLOAT:int869,n +LITERAL_FLOAT:0.0,1.0 +-(1.0,<=(/(-(parsertemp1515,parsertemp1517),-(n,int869)),0.0)) +::STMT +FLOAT:_sbcvar1847 +LITERAL_FLOAT:11.0 +-(11.0,_sbcvar1847) +::STMT +FLOAT:i +LITERAL_FLOAT:1048.0 ++(i,1048.0) +::STMT +MATRIX:WM +FLOAT:parsertemp31268 +LITERAL_FLOAT:1.0 +*(-(sum(WM),1.0),/(*(parsertemp31268,sum(WM)),-(sum(WM),1.0))) +::STMT +FLOAT:i +LITERAL_FLOAT:1024.0 ++(i,1024.0) +::STMT +MATRIX:p_CG,z +FLOAT:rr_CG,pq_CG ++(z,*(/(rr_CG,pq_CG),p_CG)) +::STMT +MATRIX:ot2 +FLOAT:int897 +LITERAL_FLOAT:1500.0,100.0 +/(*(sum(>(ot2,int897)),100.0),1500.0) +::STMT +MATRIX:X +FLOAT:int17 +LITERAL_FLOAT:0.0 +INT:m,int172 +%*%(X,rand(m,int172,0.0,int17)) +::STMT +MATRIX:lambda,B_new +LITERAL_FLOAT:2.0 +sum(*(lambda,^(B_new,2.0))) +::STMT +MATRIX:W +LITERAL_FLOAT:3.0,5.0 +*(+(sum(round(W)),5.0),-(sum(round(W)),3.0)) +::STMT +MATRIX:parsertemp171326,is_lt_pos,parsertemp171330,Y,parsertemp171329 +FLOAT:one_over_sqrt_two_pi,float268 +*(one_over_sqrt_two_pi,+(-(Y,*(parsertemp171326,is_lt_pos)),*(*(parsertemp171329,parsertemp171330),-(is_lt_pos,float268)))) +::STMT +MATRIX:r,parsertemp44063,grad +LITERAL_FLOAT:-0.5 +*(-0.5,-(cast.FLOAT(%*%(parsertemp44063,grad)),cast.FLOAT(%*%(parsertemp44063,r)))) +::STMT +MATRIX:p,e,u +FLOAT:alpha +LITERAL_FLOAT:1.0 +*(-(1.0,alpha),%*%(%*%(e,u),p)) +::STMT +MATRIX:p_CG +FLOAT:rr_CG,pq_CG +*(/(rr_CG,pq_CG),p_CG) +::STMT +LITERAL_FLOAT:-0.6931471805599453 +-0.6931471805599453 +::STMT +LITERAL_FLOAT:0.6931471805599453 +0.6931471805599453 +::STMT +LITERAL_FLOAT:1.0E-7 +INT:int329,m +diag(rand(m,int329,1.0E-7,1.0E-7)) +::STMT +MATRIX:r,parsertemp1936,parsertemp1937 +FLOAT:parsertemp1941,norm_r2 +*(+(r,*(/(norm_r2,parsertemp1941),+(parsertemp1936,parsertemp1937))),+(r,*(/(norm_r2,parsertemp1941),+(parsertemp1936,parsertemp1937)))) +::STMT +MATRIX:X,Centering,ScaleFactor +FLOAT:N +/(colSums(/(-(X,Centering),ScaleFactor)),N) +::STMT +MATRIX:classFeatureCounts +FLOAT:numFeatures,laplaceCorrection +/(+(classFeatureCounts,laplaceCorrection),+(rowSums(classFeatureCounts),*(numFeatures,laplaceCorrection))) +::STMT +FLOAT:std +LITERAL_FLOAT:0.0,1.0 +INT:int654,int49 +*(cast.FLOAT(rand(int49,int654,0.0,1.0)),std) +::STMT +MATRIX:X +LITERAL_FLOAT:1.0 ++(max(X),1.0) +::STMT +MATRIX:xs +LITERAL_FLOAT:4.5 +sum(>=(xs,4.5)) +::STMT +MATRIX:parsertemp13624,_sbcvar11 +FLOAT:int284 +LITERAL_FLOAT:2.0,1000.0 +/(^(-(_sbcvar11,/(parsertemp13624,int284)),2.0),/(%*%(rowSums(_sbcvar11),colSums(_sbcvar11)),1000.0)) +::STMT +MATRIX:R +LITERAL_FLOAT:1.0 +INT:parsertemp503363,int581 ++(R,diag(rand(parsertemp503363,int581,1.0,1.0))) +::STMT +LITERAL_FLOAT:2.22E-16 +2.22E-16 +::STMT +MATRIX:svUpBnd,R +<=(R,cast.FLOAT(svUpBnd)) +::STMT +MATRIX:vW1,dW1 +FLOAT:2727_mu,2727_lr +LITERAL_FLOAT:1.0 +*(+(1.0,2727_mu),-(*(2727_mu,vW1),*(2727_lr,dW1))) +::STMT +MATRIX:Y,predicted_Y +LITERAL_FLOAT:0.0 +/(sum(==(-(predicted_Y,Y),0.0)),nrow(Y)) +::STMT +LITERAL_FLOAT:0.025253813613805267 +0.025253813613805267 +::STMT +MATRIX:q,r +FLOAT:p,norm_r2 +t(+(r,*(/(norm_r2,p),+(q,q)))) +::STMT +MATRIX:codebook +FLOAT:j +LITERAL_FLOAT:1.0 ++(1.0,*(-(j,1.0),ncol(codebook))) +::STMT +FLOAT:_sbcvar1831 +LITERAL_FLOAT:10.0 +-(10.0,_sbcvar1831) +::STMT +FLOAT:sd_Y,sd_X +-(sqrt(sd_Y),sqrt(sd_X)) +::STMT +MATRIX:distT +LITERAL_FLOAT:0.0 +!=(distT,0.0) +::STMT +FLOAT:a,b +LITERAL_FLOAT:2.0 +*(2.0,*(a,b)) +::STMT +MATRIX:_sbcvar1006 +LITERAL_FLOAT:0.0 +>(t(_sbcvar1006),0.0) +::STMT +MATRIX:parsertemp31933,X2,parsertemp31935 +t(colSums(==(%*%(X2,parsertemp31935),t(parsertemp31933)))) +::STMT +LITERAL_FLOAT:999.0 +999.0 +::STMT +FLOAT:Hin +LITERAL_FLOAT:184.0 ++(Hin,184.0) +::STMT +MATRIX:leaf_ids +FLOAT:boundary_left,step_size +!(<(leaf_ids,+(boundary_left,step_size))) +::STMT +MATRIX:C,Xm,parsertemp265701,XtZ +FLOAT:ZtZ_sum +-(%*%(%*%(Xm,%*%(C,parsertemp265701)),t(/(XtZ,ZtZ_sum))),Xm) +::STMT +FLOAT:i +LITERAL_FLOAT:64.0 ++(i,64.0) +::STMT +MATRIX:filled_matrix,aligned +t(%*%(t(filled_matrix),aligned)) +::STMT +MATRIX:m_active_flag_tmp +LITERAL_FLOAT:1.0 +!=(m_active_flag_tmp,1.0) +::STMT +LITERAL_FLOAT:1.01 +1.01 +::STMT +MATRIX:p,r,parsertemp1934,parsertemp1935,parsertemp1940 +FLOAT:norm_r2,eps ++(r,*(/(norm_r2,cast.FLOAT(parsertemp1940)),+(%*%(parsertemp1934,parsertemp1935),*(eps,p)))) +::STMT +MATRIX:X +FLOAT:x +LITERAL_FLOAT:1.0 +-(1.0,/(-(x,X),-(X,X))) +::STMT +FLOAT:parsertemp98,int764,var,m4,parsertemp99,int59,parsertemp93,parsertemp94,wt,parsertemp105,parsertemp104 +LITERAL_FLOAT:4.0 +/(-(*(*(parsertemp93,parsertemp94),m4),*(*(parsertemp98,parsertemp99),-(wt,int59))),*(*(*(parsertemp104,parsertemp105),-(wt,int764)),^(sqrt(var),4.0))) +::STMT +MATRIX:resp,mean,X,weight +LITERAL_FLOAT:2.0 +-(/(%*%(t(resp),*(X,X)),t(weight)),*(2.0,^(mean,2.0))) +::STMT +LITERAL_FLOAT:3.141592653589793,2.0 +*(2.0,3.141592653589793) +::STMT +MATRIX:X +LITERAL_FLOAT:10.0 +!=(X,10.0) +::STMT +MATRIX:X,ScaleFactor +FLOAT:N +%*%(t(/(colSums(X),N)),/(colSums(/(X,ScaleFactor)),N)) +::STMT +FLOAT:beg +LITERAL_FLOAT:512.0 ++(beg,512.0) +::STMT +MATRIX:border,parsertemp386449,neighbors,parsertemp386460 +FLOAT:int891,int557 +LITERAL_FLOAT:0.0 +>(+(*(>(parsertemp386449,int557),==(parsertemp386460,int891)),t(*(neighbors,border))),0.0) +::STMT +MATRIX:parsertemp31188,parsertemp31186 +FLOAT:int184 +LITERAL_FLOAT:1.0,2.0,7000.0 +^(/(-(colSums(parsertemp31186),*(int184,parsertemp31188)),-(7000.0,1.0)),2.0) +::STMT +MATRIX:X,Y ++(abs(X),abs(Y)) +::STMT +MATRIX:mean,weight +%*%(*(t(mean),weight),mean) +::STMT +MATRIX:R,parsertemp40219,parsertemp40216 +FLOAT:numRows,level +/(numRows,-(+(R,rowSums(parsertemp40216)),rowSums(==(parsertemp40219,level)))) +::STMT +FLOAT:beg +LITERAL_FLOAT:256.0 ++(beg,256.0) +::STMT +FLOAT:i +LITERAL_FLOAT:253.0 ++(i,253.0) +::STMT +MATRIX:os,y,o +FLOAT:int917 +LITERAL_FLOAT:1.0 ++(1.0,exp(*(-(int917,y),+(o,os)))) +::STMT +MATRIX:X,tS +FLOAT:l +==(%*%(X,tS),l) +::STMT +LITERAL_FLOAT:2.0,83.0 +/(83.0,2.0) +::STMT +MATRIX:parsertemp171348,is_too_small,parsertemp171346,parsertemp171344,parsertemp171353,Y,the_exp,parsertemp171349 +FLOAT:int124,int429 +/(-(*(rowSums(Y),exp(parsertemp171344)),Y),+(/(*(parsertemp171348,parsertemp171349),+(the_exp,is_too_small)),*(==(parsertemp171346,int429),-(int124,parsertemp171353)))) +::STMT +FLOAT:i +LITERAL_FLOAT:3000.0 +-(3000.0,i) +::STMT +MATRIX:parsertemp400664,parsertemp400661,W3_rand +LITERAL_FLOAT:0.2656844656620286 +t(%*%(*(0.2656844656620286,W3_rand),t(/(parsertemp400661,parsertemp400664)))) +::STMT +MATRIX:240_elt,240_ones_ctg +/(240_elt,%*%(rowSums(240_elt),t(240_ones_ctg))) +::STMT +MATRIX:Bxu,Bxd ++(Bxd,Bxu) +::STMT +FLOAT:42_m2X +LITERAL_FLOAT:1.001001001001001 +*(42_m2X,1.001001001001001) +::STMT +MATRIX:parsertemp43634 +FLOAT:float614,int863,int687,float282,float925,float13 +INT:int241,int486,int281,int506 +sum(*(+(rand(int281,int486,float282,float614),*(int863,parsertemp43634)),+(rand(int506,int241,float925,float13),*(int687,parsertemp43634)))) +::STMT +MATRIX:221_present_domain_vals_mat,parsertemp27770 +sqrt(%*%(221_present_domain_vals_mat,parsertemp27770)) +::STMT +MATRIX:s +LITERAL_FLOAT:1.0,2.0 +*(1.0,sum(^(s,2.0))) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:0.0 +exp(-(0.0,linear_terms)) +::STMT +MATRIX:finite_linear_terms +LITERAL_FLOAT:0.0 +-(0.0,exp(finite_linear_terms)) +::STMT +FLOAT:i +LITERAL_FLOAT:16.0,1.0 ++(*(-(i,1.0),16.0),1.0) +::STMT +MATRIX:Y,parsertemp221025 +FLOAT:int526 +LITERAL_FLOAT:1.0 +sum(*(/(1.0,+(Y,int526)),+(diag(parsertemp221025),1.0))) +::STMT +MATRIX:logisticnew +LITERAL_FLOAT:1.0 +*(logisticnew,-(1.0,logisticnew)) +::STMT +MATRIX:parsertemp437238,parsertemp437237,mean,weight,parsertemp437242,avgMean +FLOAT:int92,reg_covar ++(+(-(/(parsertemp437237,parsertemp437238),*(int92,avgMean)),/(*(mean,parsertemp437242),t(weight))),reg_covar) +::STMT +MATRIX:simplex +LITERAL_FLOAT:2.0,4.0 +*(2.0,/(-(rowSums(simplex),simplex),4.0)) +::STMT +MATRIX:posSamples,posSampleMeans +LITERAL_FLOAT:2.0,100.0 +-(colSums(^(posSamples,2.0)),*(100.0,^(posSampleMeans,2.0))) +::STMT +MATRIX:X2,85_s +FLOAT:alpha,int392 +LITERAL_FLOAT:1.0 +*(-(1.0,alpha),-(*(/(int392,85_s),nrow(X2)),1.0)) +::STMT +MATRIX:shift_X,beta_unscaled +cast.FLOAT(%*%(t(shift_X),beta_unscaled)) +::STMT +MATRIX:Y +FLOAT:num_categories +LITERAL_FLOAT:-1.0 ++(*(Y,-1.0),num_categories) +::STMT +LITERAL_FLOAT:24.0,1.0 +*(24.0,1.0) +::STMT +MATRIX:Nc +/(Nc,sum(Nc)) +::STMT +LITERAL_FLOAT:0.0,1.0 +INT:int540,int910,int823,int161 ++(sum(rand(int161,int823,0.0,1.0)),sum(rand(int910,int540,0.0,1.0))) +::STMT +MATRIX:w_X,z_LS +LITERAL_FLOAT:10000.0 +/(10000.0,cast.FLOAT(%*%(t(w_X),z_LS))) +::STMT +MATRIX:Y_counts +FLOAT:num_features +LITERAL_FLOAT:1.0 +-(-(sum(Y_counts),num_features),1.0) +::STMT +LITERAL_FLOAT:1.0E-9,10.0 +-(10.0,1.0E-9) +::STMT +MATRIX:parsertemp570396,classVars +FLOAT:varSmoothing +*(*(diag(parsertemp570396),max(classVars)),varSmoothing) +::STMT +MATRIX:parsertemp460643 +LITERAL_FLOAT:0.025253813613805267 +*(parsertemp460643,0.025253813613805267) +::STMT +LITERAL_FLOAT:1.0,2.0,4.0,2000.0 +*(4.0,-(^(2000.0,2.0),1.0)) +::STMT +MATRIX:Bx,Yd,Yu +/(-(Yu,Yd),*(Bx,Bx)) +::STMT +MATRIX:252_X +LITERAL_FLOAT:1.0,4.5 +-(1.0,/(-(4.5,cast.FLOAT(252_X)),-(cast.FLOAT(252_X),cast.FLOAT(252_X)))) +::STMT +LITERAL_FLOAT:0.35 +0.35 +::STMT +FLOAT:parsertemp40916,int333,m2 +LITERAL_FLOAT:2001.0 +/(sqrt(*(/(int333,parsertemp40916),m2)),sqrt(2001.0)) +::STMT +MATRIX:P,scale_X,X,Y +%*%(diag(scale_X),%*%(t(X),-(P,Y))) +::STMT +MATRIX:s,w +LITERAL_FLOAT:100.0 +*(100.0,cast.FLOAT(%*%(t(w),s))) +::STMT +FLOAT:o_init +LITERAL_FLOAT:-2.0,50.0 +exp(/(*(-2.0,o_init),50.0)) +::STMT +MATRIX:G,authorities +/(%*%(t(G),%*%(G,authorities)),max(%*%(t(G),%*%(G,authorities)))) +::STMT +MATRIX:is_natural_parameter_log_zero,Y +*(is_natural_parameter_log_zero,abs(Y)) +::STMT +FLOAT:43_q +LITERAL_FLOAT:1.0,1000.0 +*(1000.0,-(43_q,1.0)) +::STMT +FLOAT:m2X,W +LITERAL_FLOAT:1.0 +*(m2X,/(W,-(W,1.0))) +::STMT +MATRIX:r,Hd +FLOAT:c +t(+(r,*(c,Hd))) +::STMT +MATRIX:TKC +/(cast.FLOAT(TKC),cast.FLOAT(TKC)) +::STMT +LITERAL_FLOAT:0.5,-0.5 +INT:rank,m +rand(m,rank,-0.5,0.5) +::STMT +MATRIX:parsertemp382917,U,W +t(%*%(t(U),*(W,%*%(U,parsertemp382917)))) +::STMT +LITERAL_FLOAT:1.0E8 +1.0E8 +::STMT +FLOAT:int384,i,Hin,Win +LITERAL_FLOAT:1.0 ++(*(*(-(i,int384),Hin),Win),1.0) +::STMT +MATRIX:X,weight +/(weight,nrow(X)) +::STMT +MATRIX:a,b,t,parsertemp32856,Y,parsertemp32827,parsertemp32824 +FLOAT:int228,int23 ++(+(*(-(int228,t),Y),*(/(parsertemp32824,parsertemp32827),Y)),*(*(/(parsertemp32824,parsertemp32827),-(int23,t)),+(*(a,parsertemp32856),*(b,t)))) +::STMT +MATRIX:parsertemp30951,G,authorities,hubs +-(/(%*%(t(G),%*%(G,authorities)),max(%*%(parsertemp30951,hubs))),authorities) +::STMT +FLOAT:_sbcvar1735 +LITERAL_FLOAT:12.0 +-(12.0,_sbcvar1735) +::STMT +FLOAT:i,num_centroids +LITERAL_FLOAT:2.0 ++(*(num_centroids,2.0),i) +::STMT +MATRIX:parsertemp150470,LT,parsertemp149320,parsertemp150469 +/(exp(-(LT,%*%(parsertemp149320,parsertemp150469))),%*%(rowSums(exp(LT)),parsertemp150470)) +::STMT +MATRIX:w,out +FLOAT:reg +LITERAL_FLOAT:2.0,0.5 ++(*(0.5,sum(*(out,out))),*(/(reg,2.0),sum(*(w,w)))) +::STMT +MATRIX:H_inv +sqrt(diag(H_inv)) +::STMT +MATRIX:parsertemp220853,W,sum_Pi,beta +FLOAT:logU +-(+(parsertemp220853,*(beta,/(W,sum_Pi))),logU) +::STMT +MATRIX:meanDiff,parsertemp570372,parsertemp570375 +LITERAL_FLOAT:-1.0,1.0,2.0 +-(*(/(-1.0,2.0),parsertemp570372),*(/(1.0,2.0),%*%(%*%(meanDiff,parsertemp570375),t(meanDiff)))) +::STMT +MATRIX:W,parsertemp411198,H,parsertemp411200 +LITERAL_FLOAT:1.0E-8 ++(%*%(W,/(*(H,parsertemp411198),t(parsertemp411200))),1.0E-8) +::STMT +FLOAT:parsertemp190487,parsertemp190486,FN,TN,FP,TP +/(-(*(TP,TN),*(FP,FN)),sqrt(*(*(parsertemp190486,parsertemp190487),+(TN,FN)))) +::STMT +MATRIX:vW1,parsertemp146961,dout1 +FLOAT:191_beta2 +LITERAL_FLOAT:1.0,2.0 ++(*(191_beta2,vW1),*(-(1.0,191_beta2),^(%*%(parsertemp146961,dout1),2.0))) +::STMT +MATRIX:r,parsertemp1945 +FLOAT:norm_r2 +LITERAL_FLOAT:2.0 +/(sum(^(+(r,parsertemp1945),2.0)),norm_r2) +::STMT +MATRIX:WM +LITERAL_FLOAT:1.0 +/(sum(WM),-(sum(WM),1.0)) +::STMT +MATRIX:output_values,initial_prediction +LITERAL_FLOAT:0.3 ++(initial_prediction,*(0.3,sum(output_values))) +::STMT +FLOAT:so_exact,so_linear_approx +LITERAL_FLOAT:-0.5 +/(*(-0.5,so_linear_approx),-(so_exact,so_linear_approx)) +::STMT +MATRIX:X +LITERAL_FLOAT:2.0 +rowSums(^(X,2.0)) +::STMT +MATRIX:p,z +LITERAL_FLOAT:-1.0 +*(sum(*(p,z)),-1.0) +::STMT +MATRIX:LT,Y,parsertemp149320,parsertemp150469 +sum(*(Y,-(LT,%*%(parsertemp149320,parsertemp150469)))) +::STMT +MATRIX:finite_linear_terms +LITERAL_FLOAT:0.0,2.0 +-(0.0,^(finite_linear_terms,2.0)) +::STMT +LITERAL_FLOAT:40.0,20.0 +*(20.0,40.0) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:0.0,1.0,2.0 +-(*(2.0,>=(linear_terms,0.0)),1.0) +::STMT +MATRIX:W +LITERAL_FLOAT:1.0,2.0 +*(-(sum(round(W)),1.0),-(sum(round(W)),2.0)) +::STMT +MATRIX:initial_prediction +INT:int744,parsertemp186173 +rand(parsertemp186173,int744,cast.FLOAT(initial_prediction),cast.FLOAT(initial_prediction)) +::STMT +MATRIX:s,w +sum(*(w,s)) +::STMT +MATRIX:252_X +LITERAL_FLOAT:4.5 +-(4.5,cast.FLOAT(252_X)) +::STMT +LITERAL_FLOAT:1.0,2.0,2003.0 +*(-(2003.0,2.0),+(2003.0,1.0)) +::STMT +MATRIX:scale_lambda +LITERAL_FLOAT:0.001 +diag(*(scale_lambda,0.001)) +::STMT +MATRIX:out1,187_dX,parsertemp146955 +FLOAT:int533 +LITERAL_FLOAT:2.0 +^(colSums(*(>(out1,int533),*(parsertemp146955,187_dX))),2.0) +::STMT +MATRIX:R +LITERAL_FLOAT:0.0,4.0 +&(>=(R,4.0),>(R,0.0)) +::STMT +MATRIX:precisions,X,parsertemp436695,bc_matrix,parsertemp436691 +LITERAL_FLOAT:2.0 +-(*(bc_matrix,t(*(parsertemp436691,precisions))),*(2.0,%*%(X,t(parsertemp436695)))) +::STMT +MATRIX:grad +LITERAL_FLOAT:0.0,2.0 +^(-(0.0,grad),2.0) +::STMT +MATRIX:id +==(id,t(id)) +::STMT +FLOAT:link_power +LITERAL_FLOAT:1.0 +-(/(1.0,link_power),1.0) +::STMT +FLOAT:link_power +LITERAL_FLOAT:-1.0,1.0 +-(/(-1.0,link_power),1.0) +::STMT +MATRIX:parsertemp10743,V,parsertemp10742,H,parsertemp10739,parsertemp10738 +FLOAT:Eps +%*%(*(H,/(%*%(parsertemp10738,V),+(parsertemp10742,Eps))),t(*(H,/(parsertemp10739,parsertemp10743)))) +::STMT +MATRIX:W +FLOAT:m2 +LITERAL_FLOAT:1.0 +/(*(m2,sum(round(W))),-(sum(round(W)),1.0)) +::STMT +MATRIX:parsertemp44076,wnew,parsertemp44079 +LITERAL_FLOAT:-1.0,2.0,0.5 ++(*(0.5,cast.FLOAT(%*%(parsertemp44076,wnew))),*(2.0,*(-1.0,sum(parsertemp44079)))) +::STMT +LITERAL_FLOAT:1.0,2.0,2001.0 +*(-(2001.0,2.0),+(2001.0,1.0)) +::STMT +MATRIX:A,foffb +LITERAL_FLOAT:0.0 +*(!=(A,0.0),+(A,foffb)) +::STMT +MATRIX:parsertemp397841,parsertemp397838,W4_rand +LITERAL_FLOAT:0.0873148795050037 +t(%*%(*(0.0873148795050037,W4_rand),t(/(parsertemp397838,parsertemp397841)))) +::STMT +MATRIX:parsertemp220900,parsertemp220899 +LITERAL_FLOAT:300.0,0.0,2.0 +^(-(0.0,*(300.0,-(parsertemp220899,parsertemp220900))),2.0) +::STMT +MATRIX:lambda,g,beta +LITERAL_FLOAT:0.0 +-(0.0,+(g,*(lambda,beta))) +::STMT +MATRIX:parsertemp76118 +LITERAL_FLOAT:0.5,4460.0 +round(+(0.5,/(parsertemp76118,4460.0))) +::STMT +MATRIX:knn_index +FLOAT:iter,i ++(*(iter,ncol(knn_index)),i) +::STMT +MATRIX:output,output1 +LITERAL_FLOAT:0.1 +sum(>=(abs(-(output,output1)),0.1)) +::STMT +MATRIX:n_event_stratum,n_risk_stratum,n_risk +FLOAT:int472 +LITERAL_FLOAT:1.0 +/(*(*(^(n_risk_stratum,int472),*(n_risk,n_event_stratum)),-(n_risk_stratum,n_event_stratum)),*(n_risk_stratum,-(n_risk_stratum,1.0))) +::STMT +MATRIX:X +FLOAT:parsertemp165083 +LITERAL_FLOAT:2.0 ++(*(2.0,ncol(X)),*(nrow(X),parsertemp165083)) +::STMT +FLOAT:float538,int243,42_m2X +LITERAL_FLOAT:1000.0 +sqrt(*(42_m2X,/(1000.0,-(int243,float538)))) +::STMT +MATRIX:C,Xm,parsertemp265706,parsertemp265702,parsertemp265701 +FLOAT:ss ++(%*%(t(%*%(Xm,parsertemp265702)),%*%(Xm,%*%(C,parsertemp265701))),*(parsertemp265706,ss)) +::STMT +FLOAT:parsertemp115814,sum_sq_y_test,n,ss_res +LITERAL_FLOAT:1.0 +-(1.0,/(ss_res,-(sum_sq_y_test,*(n,parsertemp115814)))) +::STMT +MATRIX:parsertemp560507,Y +sum(rowSums(*(Y,parsertemp560507))) +::STMT +FLOAT:parsertemp382948,parsertemp382956,loss_init,parsertemp382953 +LITERAL_FLOAT:0.5,5.0E-7 +-(loss_init,+(*(0.5,parsertemp382948),*(5.0E-7,+(parsertemp382953,parsertemp382956)))) +::STMT +MATRIX:p_CG +FLOAT:parsertemp285739,parsertemp285737,pp_CG +LITERAL_FLOAT:-1.0 +/(+(*(cast.FLOAT(p_CG),-1.0),sqrt(-(parsertemp285737,parsertemp285739))),pp_CG) +::STMT +MATRIX:p,q,A +FLOAT:norm_r2 +*(/(norm_r2,sum(*(p,q))),%*%(t(A),%*%(A,p))) +::STMT +MATRIX:X +FLOAT:n +/(t(colSums(X)),n) +::STMT +MATRIX:finite_linear_terms +LITERAL_FLOAT:1.0,0.231641888 +/(1.0,+(1.0,*(abs(finite_linear_terms),0.231641888))) +::STMT +MATRIX:parsertemp2832 +min(round(parsertemp2832)) +::STMT +MATRIX:parsertemp11277 +FLOAT:block_size +LITERAL_FLOAT:1.0 ++(1.0,*(block_size,parsertemp11277)) +::STMT +MATRIX:objvals +LITERAL_FLOAT:1.5000000000000002E-8 +*(1.5000000000000002E-8,cast.FLOAT(objvals)) +::STMT +FLOAT:std,float481,float926 +LITERAL_FLOAT:2.0 +INT:int300,int902 +^(*(cast.FLOAT(rand(int300,int902,float481,float926)),std),2.0) +::STMT +MATRIX:R,parsertemp40216 +FLOAT:numRows +LITERAL_FLOAT:1.0 +-(/(numRows,+(R,rowSums(parsertemp40216))),1.0) +::STMT +MATRIX:parsertemp147200,X_train +LITERAL_FLOAT:2.0 +*(parsertemp147200,sqrt(/(2.0,ncol(X_train)))) +::STMT +LITERAL_FLOAT:1.0,2.0,2003.0 +-(^(2003.0,2.0),1.0) +::STMT +MATRIX:categorical,X_sys,freq,mask +LITERAL_FLOAT:0.0 ++(*(X_sys,==(mask,0.0)),*(>(categorical,0.0),freq)) +::STMT +MATRIX:id +diag(diag(==(id,t(id)))) +::STMT +LITERAL_FLOAT:1.0,2.0,2000.0 +*(-(2000.0,2.0),+(2000.0,1.0)) +::STMT +MATRIX:parsertemp77570 +LITERAL_FLOAT:0.5,2358.0 +round(+(0.5,/(parsertemp77570,2358.0))) +::STMT +MATRIX:parsertemp379566,m_iter_err_sum,m_err +FLOAT:int404,i_process_item +LITERAL_FLOAT:2.0 +*(*(2.0,/(*(parsertemp379566,int404),i_process_item)),+(colSums(m_err),m_iter_err_sum)) +::STMT +FLOAT:m2,mu +LITERAL_FLOAT:1.0005002501250626 +/(sqrt(*(1.0005002501250626,m2)),mu) +::STMT +MATRIX:r_CG,g_reg,z +LITERAL_FLOAT:0.5 +*(0.5,*(cast.FLOAT(z),+(cast.FLOAT(r_CG),cast.FLOAT(g_reg)))) +::STMT +LITERAL_FLOAT:0.231641888 +0.231641888 +::STMT +MATRIX:W +FLOAT:int553,m3,var,wt,int628 +LITERAL_FLOAT:2.0,3.0 +/(*(^(sum(W),2.0),m3),*(*(-(wt,int553),-(wt,int628)),^(sqrt(var),3.0))) +::STMT +MATRIX:p,r +FLOAT:norm_r2 +*(/(sum(*(r,r)),norm_r2),p) +::STMT +MATRIX:parsertemp116094,parsertemp116097 +LITERAL_FLOAT:0.0,32.0 +sum(|(<(t(parsertemp116094),32.0),==(t(parsertemp116097),0.0))) +::STMT +FLOAT:link_power +LITERAL_FLOAT:1.0,2.0 +-(/(1.0,link_power),2.0) +::STMT +MATRIX:A,scale_X +%*%(diag(scale_X),A) +::STMT +FLOAT:sv,rad,v2 +/(-(rad,sv),v2) +::STMT +MATRIX:B2,ytest,Xtest +t(-(ytest,%*%(Xtest,B2))) +::STMT +MATRIX:V +min(V) +::STMT +MATRIX:diff_nominal,diff,mask +FLOAT:num_std_median +LITERAL_FLOAT:0.0 ++(*(!=(diff_nominal,0.0),num_std_median),*(diff,==(mask,0.0))) +::STMT +MATRIX:s,parsertemp44016,d +*(%*%(t(-(s,parsertemp44016)),d),%*%(t(-(s,parsertemp44016)),d)) +::STMT +MATRIX:col +FLOAT:min_val,bin_width +LITERAL_FLOAT:0.5 +-(/(-(col,min_val),bin_width),0.5) +::STMT +LITERAL_FLOAT:0.7 +0.7 +::STMT +MATRIX:Y_counts,means,Y +%*%(Y_counts,/(colSums(-(Y,means)),sum(Y_counts))) +::STMT +FLOAT:p,P +LITERAL_FLOAT:1.0 ++(+(1.0,p),P) +::STMT +FLOAT:int494,parsertemp115813,sum_sq_y_test,n,ss_res +/(ss_res,-(sum_sq_y_test,*(n,^(parsertemp115813,int494)))) +::STMT +FLOAT:a,c +LITERAL_FLOAT:4.0 +*(*(4.0,a),c) +::STMT +LITERAL_FLOAT:0.95 +0.95 +::STMT +MATRIX:parsertemp409058,parsertemp409054,ctab +LITERAL_FLOAT:0.6 +*(parsertemp409058,>(/(parsertemp409054,rowSums(ctab)),0.6)) +::STMT +MATRIX:cov +LITERAL_FLOAT:1.0 +/(1.0,sqrt(cov)) +::STMT +FLOAT:m2 +LITERAL_FLOAT:2.0 +^(m2,2.0) +::STMT +FLOAT:parsertemp459295 +LITERAL_FLOAT:1.0,128.0 ++(+(parsertemp459295,1.0),128.0) +::STMT +MATRIX:parsertemp472305,_funvar2708,Iright,_funvar2706,_funvar2707 +FLOAT:numI +-(-(cast.FLOAT(_funvar2706),*(/(parsertemp472305,numI),_funvar2707)),*(/(rowSums(Iright),numI),_funvar2708)) +::STMT +MATRIX:parsertemp170251,lt_pos_neg +FLOAT:int953 +LITERAL_FLOAT:2.0,0.5 +*(-(0.5,lt_pos_neg),exp(/(*(parsertemp170251,int953),2.0))) +::STMT +MATRIX:Xd,out +FLOAT:int515 +sum(*(*(Xd,>(out,int515)),Xd)) +::STMT +MATRIX:parsertemp500439,y +LITERAL_FLOAT:0.5 +*(0.5,sum(*(-(parsertemp500439,y),-(parsertemp500439,y)))) +::STMT +MATRIX:oldE +LITERAL_FLOAT:1.0 +/(sum(oldE),1.0) +::STMT +MATRIX:csgaps,csmask +*(csgaps,>(csgaps,csmask)) +::STMT +MATRIX:X_cluster_local,X_comp,X_sim +|(X_cluster_local,*(X_comp,X_sim)) +::STMT +MATRIX:2364_2360_Y_prime,W2,W3,2364_2359_Y,parsertemp389610 +FLOAT:int704 +LITERAL_FLOAT:1.0 +%*%(*(-(1.0,^(2364_2359_Y,int704)),%*%(*(2364_2360_Y_prime,parsertemp389610),W3)),W2) +::STMT +LITERAL_FLOAT:1.0E-8 +1.0E-8 +::STMT +MATRIX:Y,parsertemp2773,Xw +LITERAL_FLOAT:0.0,1.0 +>(-(1.0,*(Y,+(Xw,parsertemp2773))),0.0) +::STMT +MATRIX:W,H +LITERAL_FLOAT:1.0E-8 ++(%*%(W,H),1.0E-8) +::STMT +MATRIX:A,b +LITERAL_FLOAT:-1.0,2.0 +^(%*%(*(t(A),-1.0),b),2.0) +::STMT +MATRIX:C,C_old +LITERAL_FLOAT:2.0 +sum(^(-(C,C_old),2.0)) +::STMT +MATRIX:P,lambda,X,Y,B_new ++(%*%(t(X),-(P,Y)),*(lambda,B_new)) +::STMT +MATRIX:Xtest_dists +LITERAL_FLOAT:0.0,1.0 +rowSums(*(<=(Xtest_dists,1.0),<(0.0,Xtest_dists))) +::STMT +LITERAL_FLOAT:16.0,15.0 +*(15.0,16.0) +::STMT +MATRIX:parsertemp414376,parsertemp414378 +LITERAL_FLOAT:0.0,1.0,199.0 +-(1.0,<=(/(-(parsertemp414376,parsertemp414378),199.0),0.0)) +::STMT +LITERAL_FLOAT:0.05473123640475826 +0.05473123640475826 +::STMT +FLOAT:parsertemp164939 +LITERAL_FLOAT:100.0 +*(100.0,parsertemp164939) +::STMT +MATRIX:r,parsertemp1936,parsertemp1937 +FLOAT:parsertemp1941,norm_r2 +LITERAL_FLOAT:-1.0 +*(+(r,*(/(norm_r2,parsertemp1941),+(parsertemp1936,parsertemp1937))),-1.0) +::STMT +MATRIX:_sbcvar1716 +LITERAL_FLOAT:0.8 +*(_sbcvar1716,0.8) +::STMT +MATRIX:A +rowSums(abs(A)) +::STMT +MATRIX:parsertemp30951,G,authorities,hubs +-(/(%*%(t(G),%*%(G,authorities)),max(%*%(parsertemp30951,hubs))),authorities) +::STMT +MATRIX:negSampleMeans,negSamples +FLOAT:int960,int292 +LITERAL_FLOAT:1.0,1500.0 +/(-(colSums(^(negSamples,int960)),*(1500.0,^(negSampleMeans,int292))),-(1500.0,1.0)) +::STMT +MATRIX:X,Y +FLOAT:x +*(/(-(x,X),-(X,X)),Y) +::STMT +LITERAL_FLOAT:1.0,10000.0,0.8 ++(*(10000.0,0.8),1.0) +::STMT +MATRIX:Hdiff,betamax,beta +FLOAT:int762,int537 +LITERAL_FLOAT:1.0E20 +==(+(*(>=(Hdiff,int537),betamax),*(<(Hdiff,int762),beta)),1.0E20) +::STMT +MATRIX:addedE +LITERAL_FLOAT:20.0 +/(sum(addedE),20.0) +::STMT +MATRIX:parsertemp570372 +LITERAL_FLOAT:-1.0,2.0 +*(/(-1.0,2.0),parsertemp570372) +::STMT +MATRIX:parsertemp43634 +FLOAT:int332 +LITERAL_FLOAT:0.0,2.0 +sum(^(+(0.0,*(int332,parsertemp43634)),2.0)) +::STMT +MATRIX:dotMissing,parsertemp553021,dotM2 +FLOAT:int159 +t(sqrt(-(+(dotM2,dotMissing),*(int159,parsertemp553021)))) +::STMT +MATRIX:parsertemp436043 +LITERAL_FLOAT:1.0 +INT:int684,n_col +%*%(parsertemp436043,rand(int684,n_col,1.0,1.0)) +::STMT +MATRIX:features,beta_unscaled +FLOAT:intercept,parsertemp176418 +LITERAL_FLOAT:3.0 +-(sqrt(parsertemp176418),*(3.0,+(%*%(features,beta_unscaled),intercept))) +::STMT +MATRIX:X,I +LITERAL_FLOAT:1.0 +-(/(nrow(X),t(colSums(I))),1.0) +::STMT +MATRIX:parsertemp506990 +LITERAL_FLOAT:0.7 +<(parsertemp506990,0.7) +::STMT +MATRIX:252_K +LITERAL_FLOAT:0.0 +-(0.0,cast.FLOAT(252_K)) +::STMT +MATRIX:addedE +LITERAL_FLOAT:40.0 +/(sum(addedE),40.0) +::STMT +LITERAL_FLOAT:8.674675786448736 +8.674675786448736 +::STMT +MATRIX:e,X,tS +FLOAT:l +%*%(t(e),==(%*%(X,tS),l)) +::STMT +MATRIX:_sbcvar332 +LITERAL_FLOAT:9999.0 +/(_sbcvar332,9999.0) +::STMT +MATRIX:TK +LITERAL_FLOAT:0.0 ++(TK,==(TK,0.0)) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:-1.0,1.0 +-(exp(*(linear_terms,-1.0)),1.0) +::STMT +MATRIX:parsertemp31908,X +FLOAT:l +/(nrow(X),t(colSums(==(parsertemp31908,l)))) +::STMT +MATRIX:p,Z +cast.FLOAT(%*%(t(p),%*%(Z,p))) +::STMT +MATRIX:W +FLOAT:m2,int169 +LITERAL_FLOAT:1.0,2.0,3.0 +*(*(*(3.0,^(m2,int169)),^(sum(W),2.0)),-(sum(round(W)),1.0)) +::STMT +MATRIX:parsertemp43619 +LITERAL_FLOAT:1.0 +-(/(1.0,+(1.0,exp(parsertemp43619))),1.0) +::STMT +MATRIX:minD,parsertemp222602,parsertemp222599 +FLOAT:int967 +rowSums(<=(+(*(int967,parsertemp222599),t(parsertemp222602)),minD)) +::STMT +FLOAT:num_hidden1,m +LITERAL_FLOAT:6.0 +/(sqrt(6.0),sqrt(+(m,num_hidden1))) +::STMT +FLOAT:pad_size,Hin +LITERAL_FLOAT:1.0 +-(Hin,-(pad_size,1.0)) +::STMT +MATRIX:R,parsertemp500360,parsertemp500307,parsertemp500359 +FLOAT:int52 ++(%*%(rowSums(^(R,int52)),parsertemp500359),%*%(parsertemp500360,t(rowSums(parsertemp500307)))) +::STMT +MATRIX:RDMean,parsertemp265748 +LITERAL_FLOAT:2.0 +-(parsertemp265748,^(RDMean,2.0)) +::STMT +FLOAT:float503,float111 +LITERAL_FLOAT:1.0 +INT:int154,int585 +/(1.0,+(1.0,exp(rand(int585,int154,float503,float111)))) +::STMT +MATRIX:parsertemp460642 +LITERAL_FLOAT:0.05 +*(parsertemp460642,0.05) +::STMT +MATRIX:Y,missing_mask_Y +LITERAL_FLOAT:0.0,1.0 ++(*(missing_mask_Y,+(max(Y),1.0)),*(Y,==(missing_mask_Y,0.0))) +::STMT +LITERAL_FLOAT:1.0,1000.0 +-(1000.0,1.0) +::STMT +MATRIX:vW2,dW2 +FLOAT:193_beta2 +LITERAL_FLOAT:1.0,2.0 ++(*(193_beta2,vW2),*(-(1.0,193_beta2),^(dW2,2.0))) +::STMT +MATRIX:F +%*%(rowSums(F),colSums(F)) +::STMT +MATRIX:parsertemp146940,184_dtemp,mb3 +FLOAT:beta1 +LITERAL_FLOAT:1.0 ++(*(beta1,mb3),*(-(1.0,beta1),colSums(-(184_dtemp,parsertemp146940)))) +::STMT +MATRIX:S,V +LITERAL_FLOAT:2.0 +^(sum(*(S,V)),2.0) +::STMT +MATRIX:tmp,X ++(%*%(t(X),X),diag(tmp)) +::STMT +MATRIX:P,gradients,Theta +FLOAT:alpha +*(alpha,%*%(t(gradients),%*%(P,Theta))) +::STMT +MATRIX:parsertemp389212 +LITERAL_FLOAT:1058.0 +/(parsertemp389212,1058.0) +::STMT +MATRIX:linear_terms +FLOAT:link_power +LITERAL_FLOAT:0.0,1.0 +^(linear_terms,-(/(0.0,link_power),1.0)) +::STMT +FLOAT:parsertemp22485,parsertemp22452,parsertemp22453 +LITERAL_FLOAT:2.0 ++(parsertemp22485,*(2.0,sqrt(+(parsertemp22452,parsertemp22453)))) +::STMT +MATRIX:parsertemp10964,C +==(parsertemp10964,C) +::STMT +MATRIX:parsertemp146931,184_dtemp,parsertemp146929,184_unnorm_probs,parsertemp146936,W3 +%*%(-(*(*(parsertemp146929,parsertemp146931),/(184_unnorm_probs,parsertemp146936)),*(/(184_unnorm_probs,parsertemp146936),rowSums(184_dtemp))),t(W3)) +::STMT +MATRIX:linear_terms +FLOAT:link_power +LITERAL_FLOAT:2.0 +^(linear_terms,-(/(2.0,link_power),2.0)) +::STMT +MATRIX:linear_terms +FLOAT:link_power +LITERAL_FLOAT:0.0,2.0 +^(linear_terms,-(/(0.0,link_power),2.0)) +::STMT +FLOAT:s_rows,h +LITERAL_FLOAT:2.0 +/(-(s_rows,h),2.0) +::STMT +FLOAT:i +LITERAL_FLOAT:314.0 ++(314.0,i) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:1.0 +/(1.0,+(1.0,exp(linear_terms))) +::STMT +LITERAL_FLOAT:1.0,100.0 +INT:int212,int982 +rand(int212,int982,1.0,100.0) +::STMT +MATRIX:parsertemp181045 +FLOAT:window_size,q,parsertemp181038 +LITERAL_FLOAT:1.0 +-(1.0,/(-(q,*(window_size,parsertemp181038)),*(window_size,cast.FLOAT(parsertemp181045)))) +::STMT +MATRIX:col_nonzeros,parsertemp383019,parsertemp383016,row_nonzeros +FLOAT:reg +*(reg,+(sum(*(parsertemp383016,row_nonzeros)),sum(*(parsertemp383019,col_nonzeros)))) +::STMT +MATRIX:output1,dataset +LITERAL_FLOAT:0.16 +sum(>=(abs(-(output1,dataset)),0.16)) +::STMT +LITERAL_FLOAT:1.0,2.0,7000.0 +*(^(7000.0,2.0),-(7000.0,1.0)) +::STMT +MATRIX:P,scale_X,shift_X,X,Y,Grad ++(%*%(diag(scale_X),%*%(t(X),-(P,Y))),%*%(shift_X,Grad)) +::STMT +MATRIX:g_new,s,g_old +*(/(sum(*(g_new,g_new)),sum(*(g_old,g_old))),s) +::STMT +MATRIX:centroid_placer,All_Centroids,X_samples ++(All_Centroids,%*%(centroid_placer,%*%(centroid_placer,X_samples))) +::STMT +MATRIX:C,tmp,XtZ +FLOAT:ZtZ_sum +trace(*(tmp,%*%(t(C),/(XtZ,ZtZ_sum)))) +::STMT +MATRIX:ytest +FLOAT:mean_y_test,int501,int192 +LITERAL_FLOAT:1.0 +/(-(sum(^(ytest,int501)),*($1:nrow(ytest),^(mean_y_test,int192))),-($1,1.0)) +::STMT +FLOAT:m2 +LITERAL_FLOAT:1.0004995004995005,44.75488800120049 +/(sqrt(*(1.0004995004995005,m2)),44.75488800120049) +::STMT +LITERAL_FLOAT:0.5107539184552492 +0.5107539184552492 +::STMT +FLOAT:Woutc20,Houtc20,F1 +LITERAL_FLOAT:1.0 ++(*(*(F1,Houtc20),Woutc20),1.0) +::STMT +LITERAL_FLOAT:1.0005 +1.0005 +::STMT +MATRIX:e_r_rev_agg,Xi_agg_rev_agg,X_agg +LITERAL_FLOAT:2.0 +/(*(X_agg,Xi_agg_rev_agg),^(e_r_rev_agg,2.0)) +::STMT +LITERAL_FLOAT:12.0,4.0 +*(12.0,4.0) +::STMT +MATRIX:z +FLOAT:trust_delta_sq +-(sum(*(z,z)),trust_delta_sq) +::STMT +LITERAL_FLOAT:1.0E-12 +INT:int210,int691 +rand(int691,int210,1.0E-12,1.0E-12) +::STMT +MATRIX:X +LITERAL_FLOAT:0.0 +-(0.0,cast.MATRIX(sum(X))) +::STMT +MATRIX:parsertemp443530,parsertemp443534,resp,parsertemp443533,X +FLOAT:float582 +LITERAL_FLOAT:2.22E-16 +%*%(*(t(/(parsertemp443533,parsertemp443534)),+(colSums(resp),2.22E-16)),/(%*%(t(resp),X),t(+(parsertemp443530,float582)))) +::STMT +FLOAT:i,j +LITERAL_FLOAT:1.0,10.0 ++(*(-(i,1.0),10.0),j) +::STMT +FLOAT:norm_r2_LS,p_LS,parsertemp170552,lambda_LS +/(norm_r2_LS,*(cast.FLOAT(p_LS),+(*(parsertemp170552,p_LS),*(lambda_LS,p_LS)))) +::STMT +FLOAT:q +LITERAL_FLOAT:1.0,10000.0 +*(10000.0,-(q,1.0)) +::STMT +LITERAL_FLOAT:12.0,8.0 +*(12.0,8.0) +::STMT +MATRIX:parsertemp472359,I +LITERAL_FLOAT:0.0 +*(I,==(*(t(parsertemp472359),I),0.0)) +::STMT +MATRIX:Y +sum(==(Y,min(Y))) +::STMT +FLOAT:var_lag,xq_lag,arch_coef,var_coef,a0 +INT:int818,int723 +rand(int818,int723,+(+(a0,*(arch_coef,xq_lag)),*(var_coef,var_lag)),+(+(a0,*(arch_coef,xq_lag)),*(var_coef,var_lag))) +::STMT +MATRIX:means,parsertemp560530 +LITERAL_FLOAT:1.0 +/(sum(<(*(means,parsertemp560530),1.0)),*(nrow(means),ncol(means))) +::STMT +MATRIX:classCounts +LITERAL_FLOAT:10000.0 +/(classCounts,10000.0) +::STMT +MATRIX:ones,classFeatureCounts +LITERAL_FLOAT:1.0,750.0 +%*%(+(rowSums(classFeatureCounts),*(750.0,1.0)),ones) +::STMT +MATRIX:Y_prob +LITERAL_FLOAT:0.0,1.0 +*(Y_prob,-(1.0,<=(Y_prob,0.0))) +::STMT +LITERAL_FLOAT:12.0 +*(12.0,12.0) +::STMT +MATRIX:P,R,I,L +LITERAL_FLOAT:0.0 +*(==(%*%(P,I),0.0),%*%(%*%(P,L),R)) +::STMT +MATRIX:E +LITERAL_FLOAT:2.0,0.5 +*(0.5,sum(^(E,2.0))) +::STMT +LITERAL_FLOAT:12.0,40.0 +*(12.0,40.0) +::STMT +MATRIX:P,X,Y +LITERAL_FLOAT:2.0 +^(%*%(t(X),-(P,Y)),2.0) +::STMT +MATRIX:r,parsertemp1936,parsertemp1937 +FLOAT:parsertemp1941,norm_r2 +*(+(r,*(/(norm_r2,parsertemp1941),+(parsertemp1936,parsertemp1937))),+(r,*(/(norm_r2,parsertemp1941),+(parsertemp1936,parsertemp1937)))) +::STMT +MATRIX:linear_terms +FLOAT:link_power +LITERAL_FLOAT:2.0 +/(^(linear_terms,/(2.0,link_power)),2.0) +::STMT +MATRIX:linear_terms +FLOAT:var_power +LITERAL_FLOAT:1.0 +/(linear_terms,-(1.0,var_power)) +::STMT +MATRIX:Y_prob,Y +-(*(Y,Y_prob),*(Y,Y_prob)) +::STMT +MATRIX:P +LITERAL_FLOAT:1.0,100.0 +INT:int801,int859 +%*%(P,rand(int859,int801,1.0,100.0)) +::STMT +FLOAT:502_strideh,502_padh,int986,502_Hin,502_Hf +LITERAL_FLOAT:2.0 ++(-(*(502_strideh,-(502_Hin,int986)),*(2.0,502_padh)),502_Hf) +::STMT +MATRIX:parsertemp195899 +FLOAT:center +LITERAL_FLOAT:1.0 +t(-(1.0,abs(-(parsertemp195899,center)))) +::STMT +MATRIX:parsertemp539203 +FLOAT:int999 +LITERAL_FLOAT:2.0,0.6666666666666666 +min(^(/(*(parsertemp539203,int999),2.0),0.6666666666666666)) +::STMT +MATRIX:parsertemp32833,parsertemp32842,X,Y,parsertemp32827,parsertemp32824,K,parsertemp32839 +FLOAT:x +LITERAL_FLOAT:1.0 ++(*(-(*(K,parsertemp32833),-(Y,Y)),-(1.0,/(parsertemp32824,parsertemp32827))),*(+(*(parsertemp32839,parsertemp32842),-(Y,Y)),/(-(x,X),-(X,X)))) +::STMT +MATRIX:X,Y,out +%*%(t(X),*(out,Y)) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,33.0 ++(*(-(i,1.0),33.0),1.0) +::STMT +MATRIX:lambda,parsertemp149248,V,X,P_1K,parsertemp149251 ++(%*%(t(X),-(*(P_1K,parsertemp149248),*(P_1K,parsertemp149251))),*(lambda,V)) +::STMT +MATRIX:rowSums_X_sq +LITERAL_FLOAT:0.5 +/(0.5,max(sqrt(rowSums_X_sq))) +::STMT +MATRIX:X,Y,K +-(*(cast.FLOAT(K),-(cast.FLOAT(X),cast.FLOAT(X))),-(cast.FLOAT(Y),cast.FLOAT(Y))) +::STMT +LITERAL_FLOAT:110.0,3000.0 +*(3000.0,110.0) +::STMT +MATRIX:s +FLOAT:int741,alpha,n +LITERAL_FLOAT:1.0 +*(-(1.0,alpha),-(*(/(int741,s),n),1.0)) +::STMT +LITERAL_FLOAT:3.0,5.0,2000.0 +*(+(2000.0,5.0),-(2000.0,3.0)) +::STMT +MATRIX:the_exp +FLOAT:int91,int490 +LITERAL_FLOAT:1.0,1.0E7 +*(-(1.0,==(+(int91,the_exp),1.0E7)),-(1.0,exp(-(int490,the_exp)))) +::STMT +FLOAT:parsertemp557354,parsertemp557356,prob_true +/(*(prob_true,parsertemp557354),parsertemp557356) +::STMT +MATRIX:parsertemp42288,_sbcvar332,parsertemp42289 +FLOAT:meanX +LITERAL_FLOAT:9999.0,0.5 +*(/(_sbcvar332,9999.0),-(+(-(parsertemp42288,parsertemp42289),0.5),meanX)) +::STMT +MATRIX:_sbcvar11 +LITERAL_FLOAT:1000.0 +/(%*%(rowSums(_sbcvar11),colSums(_sbcvar11)),1000.0) +::STMT +MATRIX:parsertemp436682 +FLOAT:d +t(*(d,parsertemp436682)) +::STMT +MATRIX:parsertemp31023,parsertemp31025 +LITERAL_FLOAT:2.0,99.0,990000.0 +/(^(/(-(parsertemp31023,parsertemp31025),99.0),2.0),990000.0) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,32.0 ++(*(-(i,1.0),32.0),1.0) +::STMT +FLOAT:alpha_LS,r_LS,norm_r2_LS,p_LS,int933 +LITERAL_FLOAT:0.0 ++(-(0.0,+(r_LS,*(alpha_LS,p_LS))),*(/(^(r_LS,int933),norm_r2_LS),cast.FLOAT(p_LS))) +::STMT +MATRIX:resp,mean,X +*(mean,%*%(t(resp),X)) +::STMT +MATRIX:mW2,dW2 +FLOAT:193_beta1 +LITERAL_FLOAT:1.0 ++(*(193_beta1,mW2),*(-(1.0,193_beta1),dW2)) +::STMT +FLOAT:idx +LITERAL_FLOAT:1.0,12.0 ++(-(12.0,idx),1.0) +::STMT +MATRIX:_sbcvar1716 +LITERAL_FLOAT:30.0 ++(30.0,nrow(_sbcvar1716)) +::STMT +FLOAT:sig,q,mu,int505 +LITERAL_FLOAT:1.0,4.0 +-(1.0,/(-(q,*(int505,mu)),*(4.0,*(sig,sig)))) +::STMT +MATRIX:posSamples,posSampleMeans +FLOAT:int950,int417 +LITERAL_FLOAT:6999.0,7000.0 +/(-(colSums(^(posSamples,int950)),*(7000.0,^(posSampleMeans,int417))),6999.0) +::STMT +MATRIX:dout,X +LITERAL_FLOAT:0.0 +*(>(X,0.0),dout) +::STMT +MATRIX:features,beta_unscaled +FLOAT:intercept ++(%*%(features,beta_unscaled),intercept) +::STMT +MATRIX:X_batch,mW1,parsertemp146957,187_dX +FLOAT:beta1 +LITERAL_FLOAT:1.0 ++(*(beta1,mW1),*(-(1.0,beta1),%*%(t(X_batch),*(parsertemp146957,187_dX)))) +::STMT +FLOAT:parsertemp40813,m2,mu +LITERAL_FLOAT:5.0 +-(mu,*(5.0,sqrt(*(parsertemp40813,m2)))) +::STMT +MATRIX:Y,linear_terms +-(Y,exp(linear_terms)) +::STMT +LITERAL_FLOAT:61.0,4.0 +/(61.0,4.0) +::STMT +MATRIX:qLow,length +<(length,qLow) +::STMT +MATRIX:inactive_set,w +FLOAT:int224 +sum(abs(-(inactive_set,!=(w,int224)))) +::STMT +MATRIX:W1_rand,stds,parsertemp393478 +LITERAL_FLOAT:0.07261134713572442 +t(%*%(*(0.07261134713572442,W1_rand),t(/(parsertemp393478,stds)))) +::STMT +LITERAL_FLOAT:1.0004995004995005 +1.0004995004995005 +::STMT +LITERAL_FLOAT:12.0,2.0 +*(12.0,2.0) +::STMT +MATRIX:parsertemp496901 +FLOAT:std +cast.MATRIX(*(cast.FLOAT(parsertemp496901),std)) +::STMT +FLOAT:m2 +LITERAL_FLOAT:1.0,2003.0 +*(/(2003.0,-(2003.0,1.0)),m2) +::STMT +MATRIX:Y,parsertemp2796,Xw +LITERAL_FLOAT:0.0,1.0 +*(>(-(1.0,*(Y,Xw)),0.0),-(1.0,*(Y,+(Xw,parsertemp2796)))) +::STMT +LITERAL_FLOAT:3.4011973816621555 +3.4011973816621555 +::STMT +MATRIX:parsertemp396420,W4_rand,parsertemp396423 +LITERAL_FLOAT:0.08681986202598489 +t(%*%(*(0.08681986202598489,W4_rand),t(/(parsertemp396420,parsertemp396423)))) +::STMT +LITERAL_FLOAT:Infinity +INT:int207,parsertemp163324 +rand(parsertemp163324,int207,Infinity,Infinity) +::STMT +LITERAL_FLOAT:1.0 +INT:int223,int713 +rand(int223,int713,1.0,1.0) +::STMT +LITERAL_FLOAT:-1.0 +INT:int121,n +rand(n,int121,-1.0,-1.0) +::STMT +LITERAL_FLOAT:-1.0,1.0 +INT:num_hidden1,m +rand(num_hidden1,m,-1.0,1.0) +::STMT +MATRIX:parsertemp16858 +LITERAL_FLOAT:1.0E-6 +*(<(sqrt(rowSums(parsertemp16858)),1.0E-6),1.0E-6) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,3.0,32.0 ++(*(-(i,1.0),32.0),3.0) +::STMT +MATRIX:parsertemp129018 +LITERAL_FLOAT:2.0 +*(max(parsertemp129018),2.0) +::STMT +LITERAL_FLOAT:2.0,64.0 +/(64.0,2.0) +::STMT +MATRIX:p,parsertemp477949,parsertemp477948 +FLOAT:norm_r2 +/(norm_r2,sum(*(p,%*%(parsertemp477948,parsertemp477949)))) +::STMT +MATRIX:residual_matrix +FLOAT:273_lambda +LITERAL_FLOAT:2.0 +/(^(sum(residual_matrix),2.0),+(nrow(residual_matrix),273_lambda)) +::STMT +MATRIX:X +LITERAL_FLOAT:1.0,-1.0 ++(1.0,exp(*(X,-1.0))) +::STMT +MATRIX:prediction,target +LITERAL_FLOAT:1.0 +/(*(/(1.0,nrow(target)),-(prediction,target)),*(prediction,-(1.0,prediction))) +::STMT +MATRIX:parsertemp44107,parsertemp44109,wnew +LITERAL_FLOAT:2.0 +^(+(wnew,*(2.0,%*%(parsertemp44107,parsertemp44109))),2.0) +::STMT +LITERAL_FLOAT:1.0,2.0 +INT:int199,parsertemp282730 +rand(parsertemp282730,int199,1.0,2.0) +::STMT +MATRIX:R,parsertemp40216,parsertemp40215,parsertemp40225 +FLOAT:level +/(+(R,rowSums(*(parsertemp40216,parsertemp40225))),+(R,rowSums(==(parsertemp40215,level)))) +::STMT +MATRIX:r,d +FLOAT:r2 +*(/(cast.FLOAT(%*%(r,r)),r2),d) +::STMT +MATRIX:parsertemp130418 +LITERAL_FLOAT:4.0 +*(max(parsertemp130418),4.0) +::STMT +MATRIX:lambda,scale_X,gXY,beta +FLOAT:int164 +t(+(*(scale_X,-(int164,gXY)),*(lambda,beta))) +::STMT +MATRIX:ss,se +FLOAT:130_eAvg,130_alpha +LITERAL_FLOAT:1.0 +*(130_alpha,-(/(/(se,ss),130_eAvg),1.0)) +::STMT +MATRIX:D,parsertemp570375,classMeans +%*%(%*%(-(D,classMeans),parsertemp570375),t(-(D,classMeans))) +::STMT +FLOAT:nc +LITERAL_FLOAT:1.0,10.0 +*(+(10.0,1.0),-(nc,1.0)) +::STMT +LITERAL_FLOAT:3.0,5.0,2003.0 +*(+(2003.0,5.0),-(2003.0,3.0)) +::STMT +FLOAT:FN,FP,TN,TP +*(*(*(+(TP,FP),+(TP,FN)),+(TN,FP)),+(TN,FN)) +::STMT +LITERAL_FLOAT:64.0,8.0 +/(64.0,8.0) +::STMT +MATRIX:parsertemp170238 +FLOAT:float74 +LITERAL_FLOAT:1.0,1.061405429 +*(/(1.0,+(1.0,*(parsertemp170238,float74))),1.061405429) +::STMT +MATRIX:W,H +LITERAL_FLOAT:1.0E-8 ++(%*%(%*%(t(W),W),H),1.0E-8) +::STMT +MATRIX:classFeatureCounts +LITERAL_FLOAT:1.0,750.0 ++(rowSums(classFeatureCounts),*(750.0,1.0)) +::STMT +MATRIX:X,outlierFilter +LITERAL_FLOAT:0.0 +*(==(outlierFilter,0.0),X) +::STMT +MATRIX:Y,linear_terms +FLOAT:link_power +LITERAL_FLOAT:1.0 +-(Y,^(linear_terms,/(1.0,link_power))) +::STMT +LITERAL_FLOAT:4.0,64.0 +/(64.0,4.0) +::STMT +FLOAT:m2X +LITERAL_FLOAT:1.0005 +sqrt(*(m2X,1.0005)) +::STMT +MATRIX:parsertemp460644 +LITERAL_FLOAT:0.0625,1.4142135623730951 +/(*(parsertemp460644,0.0625),1.4142135623730951) +::STMT +MATRIX:_sbcvar415,X2 +LITERAL_FLOAT:0.050000000000000044,1.0 +*(0.050000000000000044,-(/(nrow(X2),_sbcvar415),1.0)) +::STMT +MATRIX:lambda,scale_X,p_CG,w,X,parsertemp285715 ++(*(lambda,p_CG),%*%(diag(scale_X),%*%(t(X),*(w,parsertemp285715)))) +::STMT +MATRIX:X +FLOAT:2917_N,2917_split +LITERAL_FLOAT:1.0 ++(-(nrow(X),round(*(2917_N,2917_split))),1.0) +::STMT +MATRIX:C,X +FLOAT:int301 +LITERAL_FLOAT:-2.0 ++(*(-2.0,%*%(X,t(C))),t(rowSums(^(C,int301)))) +::STMT +MATRIX:Y_counts,Y,avg_tot_Y +LITERAL_FLOAT:2.0 +colSums(^(-(Y,%*%(Y_counts,avg_tot_Y)),2.0)) +::STMT +MATRIX:parsertemp555766,target +LITERAL_FLOAT:1.0 +*(-(1.0,target),parsertemp555766) +::STMT +MATRIX:samples_vs_runs_map,centroid_placer,X_samples +LITERAL_FLOAT:2.0 +%*%(samples_vs_runs_map,rowSums(^(%*%(centroid_placer,X_samples),2.0))) +::STMT +MATRIX:parsertemp285718,p_CG,shift_X,parsertemp285720,temp_CG +sum(*(p_CG,+(+(parsertemp285718,parsertemp285720),%*%(shift_X,temp_CG)))) +::STMT +LITERAL_FLOAT:3.0,5.0,2001.0 +*(+(2001.0,5.0),-(2001.0,3.0)) +::STMT +MATRIX:parsertemp386457,parsertemp386448,parsertemp386451,parsertemp386453,withinEps +FLOAT:int257,int227 +LITERAL_FLOAT:0.0 +*(*(>(*(parsertemp386448,withinEps),0.0),&(==(parsertemp386451,int257),>(parsertemp386453,int227))),parsertemp386457) +::STMT +MATRIX:W +LITERAL_FLOAT:1.0,6.0 +*(*(6.0,sum(round(W))),-(sum(round(W)),1.0)) +::STMT +MATRIX:var_X_cols,parsertemp1522 +FLOAT:int590 +LITERAL_FLOAT:1.0 +/(1.0,sqrt(+(*(var_X_cols,parsertemp1522),<=(var_X_cols,int590)))) +::STMT +LITERAL_FLOAT:1.0,2003.0 +/(2003.0,-(2003.0,1.0)) +::STMT +MATRIX:mu +cast.FLOAT(*(mu,mu)) +::STMT +MATRIX:posSamples,posSampleMeans +FLOAT:int131,int672 +LITERAL_FLOAT:1.0,2000.0 +/(-(colSums(^(posSamples,int672)),*(2000.0,^(posSampleMeans,int131))),-(2000.0,1.0)) +::STMT +MATRIX:parsertemp43993,d,Hd,parsertemp44001 +*(cast.FLOAT(/(sum(parsertemp43993),%*%(parsertemp44001,Hd))),d) +::STMT +MATRIX:parsertemp399256,W4_rand,parsertemp399259 +LITERAL_FLOAT:0.08725945907447251 +t(%*%(*(0.08725945907447251,W4_rand),t(/(parsertemp399256,parsertemp399259)))) +::STMT +MATRIX:d,X,logisticD +*(logisticD,%*%(X,d)) +::STMT +MATRIX:P,I,X2 +LITERAL_FLOAT:0.0 +!=(*(t(%*%(X2,P)),I),0.0) +::STMT +MATRIX:is_zero_y_corr,is_one_y_corr,parsertemp171113 +FLOAT:parsertemp171116 +-(parsertemp171113,*(parsertemp171116,+(is_zero_y_corr,is_one_y_corr))) +::STMT +MATRIX:b,X +*(X,exp(%*%(X,b))) +::STMT +MATRIX:W4_rand +LITERAL_FLOAT:0.08725945907447251 +*(0.08725945907447251,W4_rand) +::STMT +FLOAT:i,n +LITERAL_FLOAT:-1.0,3.0 +*(n,^(3.0,*(i,-1.0))) +::STMT +MATRIX:2700_X,2700_W,2726_dpred,parsertemp459177,2699_probs +LITERAL_FLOAT:5.0E-4 ++(%*%(t(2700_X),-(*(2726_dpred,2699_probs),*(2699_probs,parsertemp459177))),*(5.0E-4,2700_W)) +::STMT +LITERAL_FLOAT:0.0,1.0 +INT:int840,int752,int382,int905 ++(%*%(rand(int382,int905,0.0,1.0),rand(int840,int752,0.0,1.0)),0.0) +::STMT +MATRIX:ts +LITERAL_FLOAT:4.0 +-(length(ts),4.0) +::STMT +MATRIX:linear_terms,Y +LITERAL_FLOAT:1.0 +*(^(exp(linear_terms),1.0),-(Y,exp(linear_terms))) +::STMT +MATRIX:linear_terms,Y +LITERAL_FLOAT:-1.0 +*(^(exp(linear_terms),-1.0),-(Y,exp(linear_terms))) +::STMT +MATRIX:W2_rand +LITERAL_FLOAT:0.5107539184552492 +*(0.5107539184552492,W2_rand) +::STMT +MATRIX:r +LITERAL_FLOAT:0.0,9.999999999999998E-15 +*(-(0.0,cast.FLOAT(%*%(r,r))),9.999999999999998E-15) +::STMT +FLOAT:p,i +LITERAL_FLOAT:1.0 +-(+(p,1.0),i) +::STMT +LITERAL_FLOAT:1.0,6.0,2000.0 +*(*(6.0,2000.0),-(2000.0,1.0)) +::STMT +MATRIX:s,g_old +FLOAT:step_sz +*(step_sz,cast.FLOAT(%*%(t(s),g_old))) +::STMT +MATRIX:lambda,parsertemp171604,beta,parsertemp171603 +LITERAL_FLOAT:2.0 +sum(^(+(+(parsertemp171603,parsertemp171604),*(lambda,beta)),2.0)) +::STMT +FLOAT:parsertemp40812,m2,int666 +LITERAL_FLOAT:5.0 +*(5.0,sqrt(*(/(int666,parsertemp40812),m2))) +::STMT +MATRIX:output,outputR,leading_NA ++(*(outputR,leading_NA),output) +::STMT +MATRIX:scale_X,parsertemp274081 +FLOAT:N +LITERAL_FLOAT:0.0 +*(-(0.0,/(t(parsertemp274081),N)),scale_X) +::STMT +MATRIX:parsertemp389187,parsertemp389190 +FLOAT:int284,int38 +LITERAL_FLOAT:1.0,2.0 +-(1.0,^(/(-(parsertemp389187,int284),+(parsertemp389190,int38)),2.0)) +::STMT +MATRIX:p,q,parsertemp1939 +FLOAT:norm_r2 +*(/(norm_r2,cast.FLOAT(%*%(parsertemp1939,q))),p) +::STMT +MATRIX:X +LITERAL_FLOAT:-1.0,2.0 +^(*(t(colSums(X)),-1.0),2.0) +::STMT +MATRIX:key_unique,key +t(==(key_unique,t(key))) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,42.0 ++(*(-(i,1.0),42.0),1.0) +::STMT +MATRIX:P ++(P,t(P)) +::STMT +MATRIX:ss +FLOAT:130_n +/(130_n,ss) +::STMT +MATRIX:Xm,Z,parsertemp265713 +cast.FLOAT(%*%(colSums(%*%(Z,parsertemp265713)),rowSums(t(Xm)))) +::STMT +MATRIX:W +LITERAL_FLOAT:1.0,2.0 +*(-(sum(round(W)),1.0),-(sum(round(W)),2.0)) +::STMT +MATRIX:out3,parsertemp146931,parsertemp146929,184_unnorm_probs,parsertemp146936,184_scores,parsertemp146933 +*(/(exp(-(out3,parsertemp146933)),rowSums(exp(184_scores))),rowSums(*(*(parsertemp146929,parsertemp146931),/(184_unnorm_probs,parsertemp146936)))) +::STMT +MATRIX:p_LS,parsertemp170552 +FLOAT:lambda_LS +sum(*(p_LS,+(%*%(parsertemp170552,p_LS),*(lambda_LS,p_LS)))) +::STMT +FLOAT:ss3,ss2,int486,ssPrev,Fn,m,n +/(/(-(+(Fn,ss2),*(int486,ss3)),*(n,m)),ssPrev) +::STMT +FLOAT:a,b,c +LITERAL_FLOAT:2.0,4.0 +-(^(b,2.0),*(*(4.0,a),c)) +::STMT +MATRIX:parsertemp16858,parsertemp16867,parsertemp16865,77_X_row_norm +FLOAT:float257,float144 +LITERAL_FLOAT:1.0E-6 +%*%(+(sqrt(rowSums(parsertemp16858)),*(<(77_X_row_norm,float144),1.0E-6)),t(+(sqrt(parsertemp16865),*(parsertemp16867,float257)))) +::STMT +MATRIX:WM +sum(WM) +::STMT +MATRIX:X +FLOAT:parsertemp78,parsertemp80 +/(-(X,parsertemp78),sqrt(parsertemp80)) +::STMT +MATRIX:Train,2342_m_colmin +LITERAL_FLOAT:2.0 +*(2.0,-(Train,2342_m_colmin)) +::STMT +MATRIX:E,O +*(sum(-(O,E)),sum(-(O,E))) +::STMT +MATRIX:D,parsertemp10958 +%*%(D,t(parsertemp10958)) +::STMT +FLOAT:run_index +LITERAL_FLOAT:96.0 +*(96.0,run_index) +::STMT +FLOAT:padh,int343,parsertemp195863,strideh,out_padh,Hf ++(+(-(*(strideh,parsertemp195863),*(int343,padh)),Hf),out_padh) +::STMT +MATRIX:P,Z,ZERODIAG,parsertemp220891 +FLOAT:int1,parsertemp220894 +rowSums(*(-(P,/(Z,parsertemp220894)),*(/(int1,parsertemp220891),ZERODIAG))) +::STMT +MATRIX:parsertemp386457,parsertemp386459,parsertemp386449,parsertemp386452,parsertemp386454 +FLOAT:int981 +-(*(*(>(parsertemp386449,int981),&(parsertemp386452,parsertemp386454)),parsertemp386457),parsertemp386459) +::STMT +MATRIX:p_CG,z +*(cast.FLOAT(%*%(t(p_CG),z)),cast.FLOAT(%*%(t(p_CG),z))) +::STMT +MATRIX:Q1,X,IQR +FLOAT:k +<(X,-(Q1,*(k,IQR))) +::STMT +MATRIX:Q3,X,IQR +FLOAT:k +>(X,+(Q3,*(k,IQR))) +::STMT +MATRIX:ubScores,fSizes,parsertemp31451 +FLOAT:int463,minsc,level,int864 +&(&(fSizes,&(>(ubScores,minsc),>(ubScores,int463))),==(rowSums(!=(parsertemp31451,int864)),level)) +::STMT +LITERAL_FLOAT:53.0,8.0 +/(53.0,8.0) +::STMT +MATRIX:pearson_residual_sq +LITERAL_FLOAT:900.0 +/(sum(pearson_residual_sq),900.0) +::STMT +MATRIX:W +FLOAT:int267,wt,int283 +LITERAL_FLOAT:1.0,3.0,6.0 +/(*(*(6.0,sum(W)),-(sum(W),1.0)),*(*(-(wt,int283),+(wt,int267)),+(sum(W),3.0))) +::STMT +MATRIX:linear_terms +FLOAT:var_power,link_power +LITERAL_FLOAT:2.0 +^(linear_terms,/(-(2.0,var_power),link_power)) +::STMT +FLOAT:m2X,W,float189 +sqrt(*(m2X,/(W,-(W,float189)))) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:2.0 +/(exp(*(linear_terms,2.0)),2.0) +::STMT +LITERAL_FLOAT:7.996E9 +7.996E9 +::STMT +MATRIX:classFeatureCounts +LITERAL_FLOAT:1.0,500.0 +INT:int259,int839 +%*%(+(rowSums(classFeatureCounts),*(500.0,1.0)),rand(int839,int259,1.0,1.0)) +::STMT +FLOAT:522_strideh,parsertemp193444,522_Hin +LITERAL_FLOAT:1.0 ++(/(-(+(522_Hin,parsertemp193444),1.0),522_strideh),1.0) +::STMT +MATRIX:R,dssp,parsertemp40220 +FLOAT:numRows +LITERAL_FLOAT:1.0 +-(/(numRows,-(+(R,dssp),rowSums(parsertemp40220))),1.0) +::STMT +MATRIX:parsertemp171377,Y_prob,Y,parsertemp171381 +FLOAT:float771 +LITERAL_FLOAT:2.0 +/(^(rowSums(Y),2.0),*(*(*(parsertemp171377,Y_prob),Y_prob),^(*(parsertemp171381,float771),2.0))) +::STMT +MATRIX:finite_linear_terms +LITERAL_FLOAT:1.0E7 ++(1.0E7,exp(finite_linear_terms)) +::STMT +MATRIX:pt_gp,Y,linear_terms,the_gauss_exp +FLOAT:int79,int185 +LITERAL_FLOAT:0.5 ++(-(Y,*(rowSums(Y),>=(linear_terms,int185))),*(*(*(the_gauss_exp,pt_gp),rowSums(Y)),-(>=(linear_terms,int79),0.5))) +::STMT +MATRIX:parsertemp1516,parsertemp1514 +FLOAT:n +LITERAL_FLOAT:0.0,1.0 +<=(/(-(t(parsertemp1514),*(n,parsertemp1516)),-(n,1.0)),0.0) +::STMT +MATRIX:err,ncCnts,maxsc,cCnts +FLOAT:int684,int597,float897,minSup +sum(&(&(>=(cCnts,minSup),>(err,int684)),|(>(ncCnts,int597),>(maxsc,float897)))) +::STMT +FLOAT:i1 +LITERAL_FLOAT:1.0,2.0 ++(1.0,*(i1,2.0)) +::STMT +LITERAL_FLOAT:-1.453152027 +-1.453152027 +::STMT +MATRIX:s +LITERAL_FLOAT:2.0 ++(s,2.0) +::STMT +FLOAT:i,cols,n +LITERAL_FLOAT:1.0 ++(-(n,-(+(i,cols),1.0)),1.0) +::STMT +MATRIX:means,parsertemp560511,parsertemp560515 +FLOAT:int468 +LITERAL_FLOAT:2.0 +-(rowSums(*(means,^(parsertemp560515,int468))),^(rowSums(*(means,parsertemp560511)),2.0)) +::STMT +MATRIX:X +FLOAT:m2X +LITERAL_FLOAT:1.0 +*(m2X,/(nrow(X),-(nrow(X),1.0))) +::STMT +MATRIX:parsertemp222331 +FLOAT:sample_block_size +LITERAL_FLOAT:0.5 ++(0.5,/(parsertemp222331,sample_block_size)) +::STMT +MATRIX:parsertemp387405,Ks,Kss +abs(-(cast.FLOAT(Kss),cast.FLOAT(%*%(parsertemp387405,Ks)))) +::STMT +MATRIX:X +LITERAL_FLOAT:1.0 ++(ncol(X),1.0) +::STMT +MATRIX:imputed_Y +LITERAL_FLOAT:NaN ++(imputed_Y,NaN) +::STMT +MATRIX:X_batch,parsertemp389604,parsertemp389600,parsertemp389601 +FLOAT:int708,int998 +LITERAL_FLOAT:1.0,2.0 +*(-(/(-(parsertemp389600,int708),+(parsertemp389600,int998)),X_batch),-(1.0,^(/(parsertemp389601,parsertemp389604),2.0))) +::STMT +MATRIX:parsertemp146961,dout1,mW1 +FLOAT:191_t,191_lr,191_beta1,parsertemp146980,int721 +LITERAL_FLOAT:1.0 +*(/(*(191_lr,sqrt(parsertemp146980)),-(1.0,^(191_beta1,191_t))),+(*(191_beta1,mW1),*(-(int721,191_beta1),%*%(parsertemp146961,dout1)))) +::STMT +MATRIX:q_CG,z +FLOAT:parsertemp170094,pp_CG,pq_CG +LITERAL_FLOAT:0.5 ++(*(*(0.5,/(parsertemp170094,pp_CG)),pq_CG),*(cast.FLOAT(z),cast.FLOAT(q_CG))) +::STMT +MATRIX:Y +FLOAT:minv +sum(==(Y,minv)) +::STMT +FLOAT:i +LITERAL_FLOAT:100.0 +*(i,100.0) +::STMT +MATRIX:X +LITERAL_FLOAT:0.0 ++(ncol(X),0.0) +::STMT +MATRIX:Hdiff,betamax,beta +FLOAT:int175,int467 +LITERAL_FLOAT:1.0E20 +!=(+(*(>=(Hdiff,int467),betamax),*(<(Hdiff,int175),beta)),1.0E20) +::STMT +MATRIX:B +FLOAT:ncolX +-(ncolX,nrow(B)) +::STMT +MATRIX:surv,se_surv +FLOAT:z_alpha_2 +LITERAL_FLOAT:-1.0 +/(*(*(z_alpha_2,-1.0),se_surv),surv) +::STMT +MATRIX:X +LITERAL_FLOAT:1.0 +<(X,1.0) +::STMT +MATRIX:parsertemp170239 +FLOAT:float481 +LITERAL_FLOAT:1.0,1.061405429,-1.453152027 ++(-1.453152027,*(/(1.0,+(float481,parsertemp170239)),1.061405429)) +::STMT +MATRIX:R,parsertemp503780 +%*%(t(+(R,diag(parsertemp503780))),+(R,diag(parsertemp503780))) +::STMT +FLOAT:var_power +LITERAL_FLOAT:2.0 +-(2.0,var_power) +::STMT +FLOAT:featureCorrection +LITERAL_FLOAT:0.0 +-(0.0,featureCorrection) +::STMT +MATRIX:parsertemp500606,parsertemp500607,parsertemp500604,w,parsertemp500610 +FLOAT:int952 +%*%(t(-(*(parsertemp500607,parsertemp500610),w)),-(*(*(parsertemp500604,parsertemp500606),>(parsertemp500606,int952)),w)) +::STMT +MATRIX:parsertemp472316,parsertemp472314,ig +FLOAT:min_leaf +rev(*(&(>=(parsertemp472314,min_leaf),>=(parsertemp472316,min_leaf)),ig)) +::STMT +MATRIX:parsertemp31026,parsertemp31033 +FLOAT:parsertemp31034,parsertemp31027 +LITERAL_FLOAT:150.0,100.0 +sqrt(+(/(/(parsertemp31026,parsertemp31027),100.0),/(/(parsertemp31033,parsertemp31034),150.0))) +::STMT +FLOAT:m2 +LITERAL_FLOAT:1.0005002501250626 +sqrt(*(1.0005002501250626,m2)) +::STMT +MATRIX:is_natural_parameter_log_zero,Y +FLOAT:int849 +LITERAL_FLOAT:0.0,1.0 +/(*(>(Y,0.0),is_natural_parameter_log_zero),-(1.0,*(>(Y,int849),is_natural_parameter_log_zero))) +::STMT +MATRIX:P,parsertemp222624,X +/(%*%(t(/(P,parsertemp222624)),X),t(colSums(/(P,parsertemp222624)))) +::STMT +FLOAT:m2 +LITERAL_FLOAT:1.0004995004995005,5.0 +*(5.0,sqrt(*(1.0004995004995005,m2))) +::STMT +MATRIX:Xd,out +FLOAT:int853 +sum(*(*(Xd,>(out,int853)),Xd)) +::STMT +MATRIX:id +diag(==(id,t(id))) +::STMT +MATRIX:z +FLOAT:trust_delta_sq +-(*(cast.FLOAT(z),cast.FLOAT(z)),trust_delta_sq) +::STMT +MATRIX:X,Y,out,parsertemp2798 +FLOAT:int662,int861 +%*%(t(X),*(*(>(out,int861),-(int662,parsertemp2798)),Y)) +::STMT +MATRIX:d,exp_Xb,X +*(X,*(%*%(X,d),exp_Xb)) +::STMT +MATRIX:output_values +FLOAT:log_odds +LITERAL_FLOAT:0.3 ++(log_odds,*(0.3,cast.FLOAT(output_values))) +::STMT +MATRIX:_sbcvar78 +LITERAL_FLOAT:10000.0 +/(%*%(rowSums(_sbcvar78),colSums(_sbcvar78)),10000.0) +::STMT +MATRIX:parsertemp403509,W4_rand +FLOAT:int45,int391 +LITERAL_FLOAT:0.086386842558136 +%*%(*(0.086386842558136,W4_rand),t(/(-(parsertemp403509,int391),+(parsertemp403509,int45)))) +::STMT +MATRIX:X,parsertemp32827,Y,parsertemp32824 +FLOAT:x +LITERAL_FLOAT:1.0 ++(*(-(1.0,/(parsertemp32824,parsertemp32827)),Y),*(/(-(x,X),-(X,X)),Y)) +::STMT +MATRIX:W,X ++(%*%(X,W),W) +::STMT +MATRIX:lambda,parsertemp170067,parsertemp170065,p_CG,shift_X,parsertemp170060,temp_CG ++(+(*(cast.FLOAT(lambda),cast.FLOAT(p_CG)),*(cast.FLOAT(parsertemp170060),cast.FLOAT(temp_CG))),*(cast.FLOAT(shift_X),cast.FLOAT(%*%(parsertemp170065,parsertemp170067)))) +::STMT +MATRIX:parsertemp115858,X,parsertemp115860 +FLOAT:n +LITERAL_FLOAT:0.0,1.0 +<=(/(-(t(parsertemp115858),*(n,parsertemp115860)),-(nrow(X),1.0)),0.0) +::STMT +MATRIX:I,y2 +/(%*%(I,y2),sum(I)) +::STMT +MATRIX:termination_bitmap,parsertemp441285,tmp +==(*(parsertemp441285,termination_bitmap),min(tmp)) +::STMT +MATRIX:the_exp,linear_terms,Y +FLOAT:int894 +*(*(exp(-(int894,the_exp)),exp(linear_terms)),rowSums(Y)) +::STMT +MATRIX:_sbcvar1156 +FLOAT:num_records +LITERAL_FLOAT:1.0 +*(+(num_records,1.0),-(1.0,_sbcvar1156)) +::STMT +MATRIX:parsertemp383010,U,X,X_nonzero_ind +LITERAL_FLOAT:2.0 +*(X_nonzero_ind,^(-(X,%*%(U,parsertemp383010)),2.0)) +::STMT +MATRIX:G,authorities +max(%*%(t(G),%*%(G,authorities))) +::STMT +FLOAT:i +LITERAL_FLOAT:42.0 ++(42.0,i) +::STMT +FLOAT:parsertemp13703 +LITERAL_FLOAT:1000.0 +*(parsertemp13703,1000.0) +::STMT +MATRIX:D,beta +LITERAL_FLOAT:-1.0 +*(*(D,-1.0),beta) +::STMT +LITERAL_FLOAT:1.0E-15 +1.0E-15 +::STMT +MATRIX:linear_terms +FLOAT:var_power +LITERAL_FLOAT:1.0 +/(/(1.0,linear_terms),-(1.0,var_power)) +::STMT +FLOAT:parsertemp380175,interval,i_process_item +LITERAL_FLOAT:1.0 ++(-(i_process_item,+(*(parsertemp380175,interval),1.0)),1.0) +::STMT +MATRIX:X2 +FLOAT:parsertemp31772 +-(ncol(X2),parsertemp31772) +::STMT +MATRIX:parsertemp132035,left,parsertemp132041,right +==(%*%(parsertemp132035,left),%*%(parsertemp132041,right)) +::STMT +FLOAT:int252,a,b,c,x ++(+(*(a,^(x,int252)),*(b,x)),c) +::STMT +MATRIX:parsertemp40482,totalE,l +/(t(%*%(t(totalE),==(parsertemp40482,l))),t(colSums(==(parsertemp40482,l)))) +::STMT +MATRIX:X_Train,X_Test +FLOAT:float605,float128,float454,float355 +INT:int571,int543,int998,int370 +-(+(sum(rand(int571,int370,float454,float128)),sum(rand(int998,int543,float605,float355))),+(sum(X_Train),sum(X_Test))) +::STMT +FLOAT:s_err_vars,s_err_mean +LITERAL_FLOAT:-0.001 +/(-(-0.001,s_err_mean),s_err_vars) +::STMT +FLOAT:qmle_val,_funvar2930 +LITERAL_FLOAT:1.0E-5 +/(-(_funvar2930,qmle_val),1.0E-5) +::STMT +MATRIX:W +FLOAT:m2 +LITERAL_FLOAT:2.0,3.0 +*(*(3.0,^(m2,2.0)),^(sum(round(W)),2.0)) +::STMT +MATRIX:parsertemp31338,_sbcvar264 +FLOAT:parsertemp31331,float537 +LITERAL_FLOAT:9999.0,1.0 +-(1.0,/(sum(*(parsertemp31338,_sbcvar264)),*(9999.0,/(parsertemp31331,float537)))) +::STMT +MATRIX:s,parsertemp44016 +FLOAT:delta2 +-(delta2,cast.FLOAT(%*%(t(s),-(s,parsertemp44016)))) +::STMT +LITERAL_FLOAT:6.0,2000.0 +*(6.0,2000.0) +::STMT +MATRIX:parsertemp467657,Xd,parsertemp467661 +FLOAT:dd,step_sz,wd +/(-(+(wd,*(step_sz,dd)),sum(*(parsertemp467657,Xd))),+(dd,sum(*(parsertemp467661,Xd)))) +::STMT +MATRIX:Y_counts,parsertemp560606,Y +LITERAL_FLOAT:1.0,2.0 +/(colSums(^(-(Y,parsertemp560606),2.0)),-(sum(Y_counts),1.0)) +::STMT +MATRIX:W +LITERAL_FLOAT:3.0 ++(sum(round(W)),3.0) +::STMT +MATRIX:K1 +cast.FLOAT(K1) +::STMT +MATRIX:proposer_pointers +LITERAL_FLOAT:1.0 ++(cast.FLOAT(proposer_pointers),1.0) +::STMT +MATRIX:finite_linear_terms +LITERAL_FLOAT:1.0E7 +==(+(1.0E7,exp(finite_linear_terms)),1.0E7) +::STMT +MATRIX:W +LITERAL_FLOAT:1.0 ++(sum(round(W)),1.0) +::STMT +MATRIX:parsertemp31277 +FLOAT:parsertemp31279,varY +LITERAL_FLOAT:1.0 +sqrt(-(1.0,/(sum(parsertemp31277),*(parsertemp31279,varY)))) +::STMT +MATRIX:2792_NID +LITERAL_FLOAT:1.0,2.0 ++(*(2.0,2792_NID),1.0) +::STMT +MATRIX:p,parsertemp116065,lambda,shift_X +sum(*(p,+(+(parsertemp116065,shift_X),*(lambda,p)))) +::STMT +FLOAT:191_beta2,191_t,int124 +LITERAL_FLOAT:1.0 +sqrt(-(1.0,^(191_beta2,+(191_t,int124)))) +::STMT +MATRIX:S +LITERAL_FLOAT:2.0,479.0 +/(^(diag(S),2.0),479.0) +::STMT +FLOAT:parsertemp164939,n +LITERAL_FLOAT:2.0 ++(2.0,*(n,parsertemp164939)) +::STMT +MATRIX:leaf_ids,out +FLOAT:boundary_right,boundary_left,step_size +-(+(out,&(>=(leaf_ids,boundary_left),<(leaf_ids,boundary_right))),&(!(<(leaf_ids,boundary_right)),<(leaf_ids,+(boundary_right,step_size)))) +::STMT +FLOAT:int313,int889 +LITERAL_FLOAT:0.0 +INT:int69,int17 +*(rand(int69,int17,int889,int313),0.0) +::STMT +MATRIX:X +FLOAT:x +cast.FLOAT(-(x,X)) +::STMT +MATRIX:w,yt,Xt +LITERAL_FLOAT:0.0 +sum(>(*(yt,%*%(Xt,w)),0.0)) +::STMT +MATRIX:ytest,yhat +/(sum(-(ytest,yhat)),nrow(ytest)) +::STMT +MATRIX:W,X,H +LITERAL_FLOAT:1.0E-8 +/(X,+(%*%(W,H),1.0E-8)) +::STMT +FLOAT:index +LITERAL_FLOAT:2.0 +*(index,2.0) +::STMT +MATRIX:parsertemp399243,parsertemp399246,W3_rand +LITERAL_FLOAT:0.6546536707079771 +t(%*%(*(0.6546536707079771,W3_rand),t(/(parsertemp399243,parsertemp399246)))) +::STMT +MATRIX:X,Centering +LITERAL_FLOAT:1.0,2.0 +/(colSums(^(-(X,Centering),2.0)),-(nrow(X),1.0)) +::STMT +MATRIX:X2p,maxsc +LITERAL_FLOAT:0.0 +|(>(t(colSums(X2p)),0.0),>(maxsc,0.0)) +::STMT +LITERAL_FLOAT:1.0,0.7 +-(1.0,0.7) +::STMT +MATRIX:_sbcvar92,parsertemp27718,parsertemp27720,220_E +FLOAT:220_W,float561 +LITERAL_FLOAT:2.0 +sum(/(^(-(_sbcvar92,220_E),2.0),+(*(parsertemp27720,float561),/(parsertemp27718,220_W)))) +::STMT +MATRIX:X_batch,dout1 +FLOAT:191_beta2 +LITERAL_FLOAT:1.0,2.0 +*(-(1.0,191_beta2),^(%*%(t(X_batch),dout1),2.0)) +::STMT +MATRIX:fP +FLOAT:max_values +/(^($1:ncol(fP),max_values),$1) +::STMT +FLOAT:g +LITERAL_FLOAT:1.0,2.0 +*(-(g,1.0),2.0) +::STMT +MATRIX:p,q,r,parsertemp1597,lambda +FLOAT:norm_r2 ++(r,*(/(norm_r2,sum(parsertemp1597)),+(q,*(lambda,p)))) +::STMT +MATRIX:parsertemp389212,parsertemp389214 +FLOAT:n +*(-(/(colSums(parsertemp389214),n),*(/(parsertemp389212,n),/(parsertemp389212,n))),n) +::STMT +MATRIX:y_hat,b,R +LITERAL_FLOAT:2.0 +^(-(-(b,%*%(R,y_hat)),y_hat),2.0) +::STMT +FLOAT:sample_block_size +LITERAL_FLOAT:3.0 +*(sample_block_size,3.0) +::STMT +MATRIX:is_zero_y_corr,is_one_y_corr,parsertemp317435 +FLOAT:float284 +LITERAL_FLOAT:1.0 +-(+(parsertemp317435,/(is_one_y_corr,-(float284,is_one_y_corr))),/(is_zero_y_corr,-(1.0,is_zero_y_corr))) +::STMT +MATRIX:parsertemp220853,parsertemp220854,Hneg,beta,betamin,Hpos +LITERAL_FLOAT:0.0,3.4011973816621555 +*(<(-(+(parsertemp220853,parsertemp220854),3.4011973816621555),0.0),+(beta,+(*(Hneg,betamin),*(Hpos,beta)))) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,7.0 ++(*(-(i,1.0),7.0),1.0) +::STMT +FLOAT:check_max,check_min +-(check_max,check_min) +::STMT +FLOAT:mantissa +LITERAL_FLOAT:-1.0 +*(mantissa,-1.0) +::STMT +FLOAT:m_orig +LITERAL_FLOAT:1.0 +*(m_orig,1.0) +::STMT +MATRIX:X,Y +LITERAL_FLOAT:1.0E-10 ++(+(abs(X),abs(Y)),1.0E-10) +::STMT +MATRIX:p,lambda,X +%*%(t(p),+(%*%(t(X),%*%(X,p)),*(lambda,p))) +::STMT +MATRIX:R,parsertemp40215 +FLOAT:numRows,level +/(numRows,+(R,rowSums(==(parsertemp40215,level)))) +::STMT +MATRIX:p,Z +FLOAT:norm_r2 +/(norm_r2,cast.FLOAT(%*%(t(p),%*%(Z,p)))) +::STMT +FLOAT:odds +LITERAL_FLOAT:1.0 +/(odds,-(1.0,odds)) +::STMT +MATRIX:parsertemp131906,parsertemp132092,outBucket +==(outBucket,%*%(parsertemp132092,t(parsertemp131906))) +::STMT +MATRIX:V,y +LITERAL_FLOAT:-1.0 +*(*(%*%(t(V),y),-1.0),*(%*%(t(V),y),-1.0)) +::STMT +MATRIX:p_CG +FLOAT:parsertemp254766,int972,parsertemp254749,int767,z +*(parsertemp254766,/(+(*(z,int972),sqrt(parsertemp254749)),sum(^(p_CG,int767)))) +::STMT +MATRIX:parsertemp122290,X2 +LITERAL_FLOAT:0.0,4.0 +&(>=(t(colSums(X2)),4.0),>(t(%*%(parsertemp122290,X2)),0.0)) +::STMT +FLOAT:i +LITERAL_FLOAT:8.0 +*(i,8.0) +::STMT +MATRIX:Y,parsertemp221025 +LITERAL_FLOAT:1.0 +*(/(1.0,+(Y,1.0)),+(diag(parsertemp221025),1.0)) +::STMT +MATRIX:sample_rec_ids +FLOAT:num_records +LITERAL_FLOAT:1.0 +-(1.0,<=(sample_rec_ids,num_records)) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,7.0 +*(-(i,1.0),7.0) +::STMT +FLOAT:i +LITERAL_FLOAT:7.0 +*(i,7.0) +::STMT +MATRIX:linear_terms +FLOAT:var_power +LITERAL_FLOAT:2.0 +/(linear_terms,-(2.0,var_power)) +::STMT +MATRIX:parsertemp171084,parsertemp171083,parsertemp171091 +FLOAT:float122 +LITERAL_FLOAT:-2.0,1.432788 +*(sqrt(*(-2.0,parsertemp171083)),+(1.432788,*(sqrt(parsertemp171084),+(float122,parsertemp171091)))) +::STMT +MATRIX:neighbors +LITERAL_FLOAT:0.0 +<(0.0,-(neighbors,diag(diag(neighbors)))) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,8.0 +*(-(i,1.0),8.0) +::STMT +LITERAL_FLOAT:2.302585092994046 +2.302585092994046 +::STMT +MATRIX:y_corr +LITERAL_FLOAT:3.141592653589793,0.5 +*(-(y_corr,0.5),3.141592653589793) +::STMT +MATRIX:m +FLOAT:sum +sqrt(-(m,sum)) +::STMT +MATRIX:z +LITERAL_FLOAT:2.0 +^(cast.FLOAT(z),2.0) +::STMT +FLOAT:i +LITERAL_FLOAT:12.0 +*(i,12.0) +::STMT +MATRIX:y_batch +LITERAL_FLOAT:0.0,1.0 +*(/(1.0,nrow(y_batch)),-(0.0,y_batch)) +::STMT +FLOAT:num_records +LITERAL_FLOAT:10.0 +*(num_records,10.0) +::STMT +MATRIX:parsertemp43631,parsertemp43633 +LITERAL_FLOAT:0.0,2.0 +INT:int81,int873,int500,int493 +*(+(rand(int493,int500,0.0,0.0),*(2.0,%*%(parsertemp43631,parsertemp43633))),+(rand(int81,int873,0.0,0.0),*(2.0,%*%(parsertemp43631,parsertemp43633)))) +::STMT +LITERAL_FLOAT:0.1651445647689541 +0.1651445647689541 +::STMT +FLOAT:p_CG,parsertemp170088,z,pp_CG,parsertemp170090 +LITERAL_FLOAT:-1.0 +/(+(*(*(z,p_CG),-1.0),sqrt(-(parsertemp170088,parsertemp170090))),pp_CG) +::STMT +FLOAT:index +LITERAL_FLOAT:4.0 +*(index,4.0) +::STMT +FLOAT:FN,TN,FP,TP +-(*(TP,TN),*(FP,FN)) +::STMT +MATRIX:R,S,parsertemp382932,HS +FLOAT:norm_R2,alpha ++(-(R,*(alpha,HS)),*(/(sum(parsertemp382932),norm_R2),S)) +::STMT +MATRIX:P1,P2,S ++(%*%(P1,S),%*%(P2,S)) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:0.0 +<(linear_terms,0.0) +::STMT +MATRIX:S,V +FLOAT:norm_R2,parsertemp149264 +LITERAL_FLOAT:2.0 +^(+(S,*(/(norm_R2,parsertemp149264),V)),2.0) +::STMT +MATRIX:scale_lambda +LITERAL_FLOAT:1.0E-7 +*(scale_lambda,1.0E-7) +::STMT +MATRIX:r +FLOAT:norm_r2_initial,int736 +sqrt(/(sum(^(r,int736)),norm_r2_initial)) +::STMT +MATRIX:U,V,X +LITERAL_FLOAT:2.0 +^(-(X,%*%(U,t(V))),2.0) +::STMT +LITERAL_FLOAT:0.0,1.0,2.0 +INT:int48,parsertemp282730 +>(rand(parsertemp282730,int48,1.0,2.0),0.0) +::STMT +FLOAT:int710,n +LITERAL_FLOAT:1.0,2.0,0.6 +*(-(+(-(n,int710),1.0),2.0),0.6) +::STMT +FLOAT:x_to_truncate +abs(x_to_truncate) +::STMT +MATRIX:R,dssp,dsep +FLOAT:4_eAvg +/(/(+(R,dsep),+(R,dssp)),4_eAvg) +::STMT +FLOAT:i +LITERAL_FLOAT:32.0 +*(i,32.0) +::STMT +MATRIX:_sbcvar2306 +max(t(_sbcvar2306)) +::STMT +MATRIX:class_counts +LITERAL_FLOAT:50000.0 +/(class_counts,50000.0) +::STMT +FLOAT:i +LITERAL_FLOAT:33.0 +*(i,33.0) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,33.0 +*(-(i,1.0),33.0) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,32.0 +*(-(i,1.0),32.0) +::STMT +MATRIX:p_CG +FLOAT:trust_delta_sq,int862,int622,z +-(*(*(cast.FLOAT(z),sum(p_CG)),*(cast.FLOAT(z),sum(p_CG))),*(sum(^(p_CG,int622)),-(^(z,int862),trust_delta_sq))) +::STMT +FLOAT:k +LITERAL_FLOAT:40.0 +*(k,40.0) +::STMT +FLOAT:var_power,link_power +LITERAL_FLOAT:1.0 +-(/(-(1.0,var_power),link_power),1.0) +::STMT +MATRIX:simplex +FLOAT:num_func_invoc +LITERAL_FLOAT:1.0 +-(+(num_func_invoc,ncol(simplex)),1.0) +::STMT +MATRIX:a,b,t,parsertemp32856,Y,parsertemp32827,parsertemp32824 +FLOAT:int277,int378 ++(+(*(-(int378,t),Y),*(/(parsertemp32824,parsertemp32827),Y)),*(*(/(parsertemp32824,parsertemp32827),-(int277,t)),+(*(a,parsertemp32856),*(b,t)))) +::STMT +FLOAT:i +LITERAL_FLOAT:42.0 +*(i,42.0) +::STMT +MATRIX:W +LITERAL_FLOAT:2.0 +^(sum(round(W)),2.0) +::STMT +FLOAT:i +LITERAL_FLOAT:16.0 +*(i,16.0) +::STMT +FLOAT:df,int687 +LITERAL_FLOAT:4.890349128221754 ++(int687,*(df,4.890349128221754)) +::STMT +MATRIX:parsertemp500608,parsertemp500604,parsertemp500605,X +FLOAT:lambda +LITERAL_FLOAT:0.0 +%*%(X,*(*(parsertemp500604,-(parsertemp500605,lambda)),>(-(parsertemp500608,lambda),0.0))) +::STMT +MATRIX:parsertemp459793,parsertemp459795 +FLOAT:val_loss +LITERAL_FLOAT:50.0 ++(val_loss,/(sum(*(parsertemp459793,parsertemp459795)),50.0)) +::STMT +MATRIX:classFeatureCounts +LITERAL_FLOAT:1.0 ++(classFeatureCounts,1.0) +::STMT +MATRIX:rowSums_X_sq +LITERAL_FLOAT:1.6583123951777 +/(1.6583123951777,max(sqrt(rowSums_X_sq))) +::STMT +FLOAT:i +LITERAL_FLOAT:16.0,1.0 +*(-(i,1.0),16.0) +::STMT +MATRIX:Q,parsertemp500360 +FLOAT:int245 +%*%(parsertemp500360,t(rowSums(^(Q,int245)))) +::STMT +MATRIX:X +LITERAL_FLOAT:7.0 +<(X,7.0) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,11.0 +*(-(i,1.0),11.0) +::STMT +MATRIX:prediction,target +sum(rowSums(abs(-(prediction,target)))) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,10.0 +*(-(i,1.0),10.0) +::STMT +MATRIX:CMeans,CFreqs +FLOAT:my +LITERAL_FLOAT:2.0 +*(CFreqs,^(-(CMeans,my),2.0)) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,12.0 +*(-(i,1.0),12.0) +::STMT +MATRIX:qLow,length,qUp +|(<(length,qLow),>(length,qUp)) +::STMT +MATRIX:G,authorities +/(%*%(G,authorities),max(%*%(G,authorities))) +::STMT +MATRIX:linear_terms +FLOAT:var_power,float356 +LITERAL_FLOAT:2.0 +/(exp(*(linear_terms,-(float356,var_power))),-(2.0,var_power)) +::STMT +FLOAT:log_ten,parsertemp169812 +LITERAL_FLOAT:0.5 +-(/(parsertemp169812,log_ten),0.5) +::STMT +MATRIX:parsertemp220853,parsertemp220854,betamin +LITERAL_FLOAT:0.0,3.4011973816621555 +*(<(-(+(parsertemp220853,parsertemp220854),3.4011973816621555),0.0),betamin) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,128.0 +*(-(i,1.0),128.0) +::STMT +MATRIX:R,S,parsertemp40214 +FLOAT:level ++(R,rowSums(==(%*%(S,parsertemp40214),level))) +::STMT +MATRIX:Y,predicted_Y +LITERAL_FLOAT:0.0 +==(-(predicted_Y,Y),0.0) +::STMT +MATRIX:parsertemp31046,parsertemp31051,parsertemp31042,parsertemp31043 +FLOAT:parsertemp31049,parsertemp31054 +LITERAL_FLOAT:2.0 +round(/(^(+(parsertemp31042,parsertemp31043),2.0),+(/(parsertemp31046,parsertemp31049),/(parsertemp31051,parsertemp31054)))) +::STMT +MATRIX:is_one_y_corr,parsertemp317435 +LITERAL_FLOAT:1.0 ++(parsertemp317435,/(is_one_y_corr,-(1.0,is_one_y_corr))) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,100.0 +*(-(i,1.0),100.0) +::STMT +MATRIX:Q,R,parsertemp500308,parsertemp500300 +FLOAT:int213,int786,int864,int854 +LITERAL_FLOAT:2.0 +INT:int279,parsertemp500306,int987,parsertemp500303 +-(+(%*%(rowSums(parsertemp500300),rand(int279,parsertemp500303,int854,int213)),%*%(rand(parsertemp500306,int987,int864,int786),t(parsertemp500308))),*(2.0,%*%(R,t(Q)))) +::STMT +FLOAT:s,parsertemp454319 +LITERAL_FLOAT:3.0 +*(parsertemp454319,^(3.0,s)) +::STMT +MATRIX:parsertemp553013,M2,parsertemp553121,parsertemp553122 ++(%*%(rowSums(*(M2,M2)),parsertemp553121),t(%*%(rowSums(parsertemp553013),parsertemp553122))) +::STMT +MATRIX:Y +LITERAL_FLOAT:0.0,24.0 +-(+(nrow(Y),0.0),24.0) +::STMT +MATRIX:neighbors,corePts,withinEps +LITERAL_FLOAT:0.0 +>(*(*(neighbors,corePts),withinEps),0.0) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,61.0 +*(-(i,1.0),61.0) +::STMT +MATRIX:log_prob,log_det_chol +FLOAT:parsertemp436710,float252 +LITERAL_FLOAT:-0.5 ++(*(-0.5,+(*(parsertemp436710,float252),log_prob)),log_det_chol) +::STMT +MATRIX:linear_terms +FLOAT:int709 +LITERAL_FLOAT:1.0 +/(1.0,-(exp(-(int709,linear_terms)),1.0)) +::STMT +MATRIX:w,parsertemp43626 +FLOAT:int89 +LITERAL_FLOAT:2.0,0.5 ++(*(0.5,%*%(t(w),w)),*(2.0,sum(*(parsertemp43626,int89)))) +::STMT +MATRIX:sq_sums,mu +LITERAL_FLOAT:2.0,4.0 +-(/(sq_sums,4.0),^(cast.FLOAT(mu),2.0)) +::STMT +MATRIX:parsertemp171314,t_gp,parsertemp171318,parsertemp171306 +FLOAT:float174,int607 +LITERAL_FLOAT:1.0,2.0,0.254829592 +*(exp(/(-(int607,parsertemp171318),2.0)),*(/(1.0,+(float174,parsertemp171306)),+(0.254829592,*(t_gp,parsertemp171314)))) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,64.0 +*(-(i,1.0),64.0) +::STMT +MATRIX:neighbors,corePts,withinEps +LITERAL_FLOAT:0.0,1.0 +*(>(*(*(neighbors,corePts),withinEps),0.0),&(==(t(corePts),0.0),>(colSums(neighbors),1.0))) +::STMT +MATRIX:parsertemp220853,Ws,beta +LITERAL_FLOAT:0.0,3.4011973816621555 +<(-(+(parsertemp220853,*(beta,Ws)),3.4011973816621555),0.0) +::STMT +MATRIX:r,parsertemp500439,y +LITERAL_FLOAT:0.5 +*(0.5,cast.FLOAT(%*%(t(r),-(parsertemp500439,y)))) +::STMT +MATRIX:parsertemp1510 +FLOAT:n +LITERAL_FLOAT:2.0 +*(n,^(/(t(parsertemp1510),n),2.0)) +::STMT +MATRIX:parsertemp31910,parsertemp31913 +FLOAT:eAvg +LITERAL_FLOAT:1.0 +-(/(/(t(parsertemp31913),t(parsertemp31910)),eAvg),1.0) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,42.0 +*(-(i,1.0),42.0) +::STMT +MATRIX:shift_X,w,ssX_p_CG,X +*(cast.FLOAT(shift_X),%*%(t(X),*(w,%*%(X,ssX_p_CG)))) +::STMT +MATRIX:classFeatureCounts +LITERAL_FLOAT:1.0,50.0 +/(+(classFeatureCounts,1.0),+(rowSums(classFeatureCounts),50.0)) +::STMT +MATRIX:X +FLOAT:parsertemp78,parsertemp80 +LITERAL_FLOAT:3.0 +^(/(-(X,parsertemp78),sqrt(parsertemp80)),3.0) +::STMT +MATRIX:W,H,X,parsertemp410975 +FLOAT:eps +*(H,%*%(t(W),/(X,+(parsertemp410975,eps)))) +::STMT +MATRIX:A +LITERAL_FLOAT:1.0 +/(1.0,cast.FLOAT(A)) +::STMT +FLOAT:i +LITERAL_FLOAT:133.0 +*(133.0,i) +::STMT +FLOAT:parsertemp40812,m2,int416 +LITERAL_FLOAT:2000.0 +/(sqrt(*(/(int416,parsertemp40812),m2)),sqrt(2000.0)) +::STMT +MATRIX:parsertemp410978,W,X,H,parsertemp410980 +FLOAT:eps +%*%(/(X,+(%*%(W,H),eps)),t(/(*(H,parsertemp410978),t(parsertemp410980)))) +::STMT +MATRIX:U,row_nonzeros +LITERAL_FLOAT:1.0E-6 +*(*(1.0E-6,U),row_nonzeros) +::STMT +MATRIX:A,B,C,D,X +==(%*%(<=(%*%(X,A),B),C),D) +::STMT +FLOAT:i +LITERAL_FLOAT:3.0 +-(3.0,i) +::STMT +MATRIX:classFeatureCounts +LITERAL_FLOAT:105.0,1.0 +/(+(classFeatureCounts,1.0),+(rowSums(classFeatureCounts),105.0)) +::STMT +FLOAT:Hin +LITERAL_FLOAT:2.0 +/(/(Hin,2.0),2.0) +::STMT +MATRIX:parsertemp24102 +LITERAL_FLOAT:1.0 +-(1.0,<(+(round(parsertemp24102),1.0),1.0)) +::STMT +MATRIX:parsertemp150470,parsertemp149323,LT +%*%(rowSums(exp(-(LT,parsertemp149323))),parsertemp150470) +::STMT +MATRIX:tpr,fpr +LITERAL_FLOAT:2.0 +/(*(-(fpr,fpr),+(tpr,tpr)),2.0) +::STMT +FLOAT:float878,m2,int725 +LITERAL_FLOAT:2001.0 +sqrt(*(/(2001.0,-(int725,float878)),m2)) +::STMT +MATRIX:p_CG,z +FLOAT:trust_delta_sq +LITERAL_FLOAT:2.0 +*(sum(^(p_CG,2.0)),-(*(cast.FLOAT(z),cast.FLOAT(z)),trust_delta_sq)) +::STMT +MATRIX:simplex +-(rowSums(simplex),simplex) +::STMT +FLOAT:m2,wt,float618 +LITERAL_FLOAT:5.0 +*(5.0,sqrt(/(*(m2,wt),-(wt,float618)))) +::STMT +MATRIX:parsertemp383172,X_nonzero_ind +FLOAT:parsertemp383177,reg,parsertemp383180,loss_init +-(loss_init,+(sum(*(X_nonzero_ind,parsertemp383172)),*(reg,+(parsertemp383177,parsertemp383180)))) +::STMT +MATRIX:C,parsertemp11064 +LITERAL_FLOAT:10000.0 +/(sum(==(parsertemp11064,C)),10000.0) +::STMT +MATRIX:classFeatureCounts +LITERAL_FLOAT:1.0,500.0 +/(+(classFeatureCounts,1.0),+(rowSums(classFeatureCounts),500.0)) +::STMT +LITERAL_FLOAT:2.7182818284 +2.7182818284 +::STMT +FLOAT:217_a22,int533,parsertemp22450,parsertemp22451,parsertemp22485 +/(parsertemp22485,sqrt(+(+(parsertemp22450,parsertemp22451),/(int533,217_a22)))) +::STMT +MATRIX:Grad +FLOAT:int907 +LITERAL_FLOAT:2.0 +sqrt(sum(^(*(Grad,int907),2.0))) +::STMT +MATRIX:parsertemp553017,M2,parsertemp553121,parsertemp553020,parsertemp553009 +LITERAL_FLOAT:2.0 +sqrt(-(+(%*%(parsertemp553009,parsertemp553121),t(parsertemp553017)),*(2.0,%*%(M2,parsertemp553020)))) +::STMT +MATRIX:parsertemp500609,parsertemp500606,parsertemp500604 +FLOAT:int192 +sum(abs(*(*(parsertemp500604,parsertemp500606),>(parsertemp500609,int192)))) +::STMT +MATRIX:R,dssp,dsep,dssm,dsem +/(-(+(R,dsep),dsem),-(+(R,dssp),dssm)) +::STMT +MATRIX:parsertemp131907,offset,parsertemp131910,parsertemp132092,rightHist,mask,outBucket +LITERAL_FLOAT:1.0 +/(-(-(offset,%*%(mask,parsertemp131910)),1.0),%*%(==(outBucket,%*%(parsertemp132092,parsertemp131907)),rightHist)) +::STMT +MATRIX:r,Hd +FLOAT:c +LITERAL_FLOAT:-1.0 +*(+(r,*(c,Hd)),-1.0) +::STMT +MATRIX:X +FLOAT:parsertemp496694,a0 +LITERAL_FLOAT:2.0 ++(parsertemp496694,/(^(cast.FLOAT(X),2.0),a0)) +::STMT +MATRIX:parsertemp379560,m_iter_err_sum,m_err +LITERAL_FLOAT:-1.0 +*(-(t(+(parsertemp379560,m_iter_err_sum)),+(colSums(m_err),m_iter_err_sum)),-1.0) +::STMT +FLOAT:parsertemp13703 +LITERAL_FLOAT:999.0,1000.0 +/(*(parsertemp13703,1000.0),999.0) +::STMT +MATRIX:W +FLOAT:parsertemp112,int190,parsertemp91 +LITERAL_FLOAT:2.0,3.0,4.0,5.0 +/(*(*(4.0,-(parsertemp112,int190)),^(sqrt(parsertemp91),2.0)),*(+(sum(W),5.0),-(sum(W),3.0))) +::STMT +MATRIX:parsertemp379566 +FLOAT:int699,i_process_item +LITERAL_FLOAT:2.0 +*(^(/(*(parsertemp379566,int699),i_process_item),2.0),i_process_item) +::STMT +MATRIX:Xm,Z,parsertemp265732 +/(sum(-(%*%(Z,parsertemp265732),Xm)),sum(Xm)) +::STMT +MATRIX:parsertemp396406,W3_rand +FLOAT:int564,int269 +LITERAL_FLOAT:0.16823164622761327 +%*%(*(0.16823164622761327,W3_rand),t(/(-(parsertemp396406,int564),+(parsertemp396406,int269)))) +::STMT +MATRIX:D,ZERODIAG,beta +FLOAT:int694 +*(exp(*(-(int694,D),beta)),ZERODIAG) +::STMT +LITERAL_FLOAT:3352500.0 +3352500.0 +::STMT +MATRIX:parsertemp171366,p_one_m_one +LITERAL_FLOAT:3.141592653589793,0.5 ++(0.5,/(%*%(parsertemp171366,p_one_m_one),3.141592653589793)) +::STMT +FLOAT:K +LITERAL_FLOAT:151.0 +*(151.0,K) +::STMT +MATRIX:r,c,E,F +FLOAT:int785 +LITERAL_FLOAT:1.0E-4 +-(F,+(*(==(E,int785),1.0E-4),/(%*%(r,c),sum(F)))) +::STMT +MATRIX:W +LITERAL_FLOAT:1.0,2.0,3.0 +*(*(-(sum(W),1.0),-(sum(W),2.0)),-(sum(round(W)),3.0)) +::STMT +MATRIX:os,y,o +LITERAL_FLOAT:0.0 +*(-(0.0,y),+(o,os)) +::STMT +MATRIX:B1 +FLOAT:nc +LITERAL_FLOAT:1.0 +/(nrow(B1),-(nc,1.0)) +::STMT +MATRIX:cumLens +FLOAT:i +LITERAL_FLOAT:1.0 +/(-(i,1.0),cumLens) +::STMT +MATRIX:W,H,parsertemp411100,parsertemp411104,parsertemp411105 +%*%(W,%*%(*(H,/(parsertemp411100,parsertemp411104)),t(*(H,parsertemp411105)))) +::STMT +MATRIX:p,z +FLOAT:pp,parsertemp169870,pz +LITERAL_FLOAT:-1.0 +-(*(sum(*(p,z)),-1.0),sqrt(-(*(pz,pz),*(pp,parsertemp169870)))) +::STMT +MATRIX:parsertemp185168,parsertemp185169,parsertemp185166,parsertemp185165 +>(-(parsertemp185165,parsertemp185166),-(parsertemp185168,parsertemp185169)) +::STMT +MATRIX:d_r,parsertemp409781 +sum(*(rev(d_r),parsertemp409781)) +::STMT +FLOAT:norm_grad_initial +LITERAL_FLOAT:0.001 +*(0.001,norm_grad_initial) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:1.0,2.0 +/(/(1.0,^(linear_terms,2.0)),2.0) +::STMT +MATRIX:r_CG,g_reg,z +*(cast.FLOAT(z),+(cast.FLOAT(r_CG),cast.FLOAT(g_reg))) +::STMT +MATRIX:selCols,selCols2 +-(sum(selCols),sum(selCols2)) +::STMT +MATRIX:_sbcvar92,220_r,220_c,220_E +FLOAT:int65 +LITERAL_FLOAT:1.0E-4 +-(_sbcvar92,+(*(==(220_E,int65),1.0E-4),/(%*%(220_r,220_c),sum(_sbcvar92)))) +::STMT +MATRIX:parsertemp16875 +FLOAT:epsilon +*(<(sqrt(rowSums(parsertemp16875)),epsilon),epsilon) +::STMT +MATRIX:s +LITERAL_FLOAT:2.0 +^(s,2.0) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:-0.0 +^(linear_terms,-0.0) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:-2.0 +^(linear_terms,-2.0) +::STMT +MATRIX:t_gp,pt_gp,parsertemp171320,Y,the_gauss_exp,parsertemp171316 +FLOAT:one_over_sqrt_two_pi,int5 +LITERAL_FLOAT:2.0,0.25 +/(*(*(exp(parsertemp171320),^(one_over_sqrt_two_pi,int5)),rowSums(Y)),*(*(0.25,*(t_gp,parsertemp171316)),-(2.0,*(the_gauss_exp,pt_gp)))) +::STMT +MATRIX:X +LITERAL_FLOAT:2.0 +t(colSums(^(X,2.0))) +::STMT +MATRIX:p,r,Z +FLOAT:norm_r2,parsertemp503396 +LITERAL_FLOAT:0.0 +-(0.0,+(r,*(/(norm_r2,parsertemp503396),%*%(Z,p)))) +::STMT +MATRIX:resp,X,weight +/(%*%(t(resp),*(X,X)),t(weight)) +::STMT +MATRIX:parsertemp472180,I,yhat +LITERAL_FLOAT:2.0 +rowSums(^(*(I,-(yhat,parsertemp472180)),2.0)) +::STMT +MATRIX:p,parsertemp285529,g +FLOAT:pp,pq,int710,pz,parsertemp285543,parsertemp285521 +*(+(+(*(parsertemp285543,pq),sum(parsertemp285529)),sum(*(g,p))),/(-(*(pz,int710),sqrt(parsertemp285521)),pp)) +::STMT +MATRIX:parsertemp220902,parsertemp220903 +FLOAT:tol +LITERAL_FLOAT:2.0 +*(sum(^(-(parsertemp220902,parsertemp220903),2.0)),tol) +::STMT +FLOAT:ssPrev,parsertemp265725,parsertemp265724 +LITERAL_FLOAT:1.0,4000.0 +-(1.0,/(/(-(parsertemp265724,parsertemp265725),4000.0),ssPrev)) +::STMT +LITERAL_FLOAT:0.0,1.0,2.0 +INT:D,M +*(rand(D,M,0.0,1.0),sqrt(/(2.0,D))) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:-1.0 +^(linear_terms,-1.0) +::STMT +MATRIX:e_r_rev_agg,select,d_r_rev,X_rev_agg +colSums(/(*(%*%(select,X_rev_agg),d_r_rev),e_r_rev_agg)) +::STMT +MATRIX:Y +FLOAT:num_categories +LITERAL_FLOAT:0.0,-1.0 +*(+(*(Y,-1.0),num_categories),<=(Y,0.0)) +::STMT +MATRIX:X +FLOAT:x +/(-(x,X),-(X,X)) +::STMT +MATRIX:G,authorities,hubs +-(/(%*%(G,authorities),max(%*%(G,authorities))),hubs) +::STMT +MATRIX:W1_rand,stds,parsertemp396314 +LITERAL_FLOAT:0.07808688094430302 +t(%*%(*(0.07808688094430302,W1_rand),t(/(parsertemp396314,stds)))) +::STMT +MATRIX:dist +FLOAT:i +LITERAL_FLOAT:1.0 +-(+(i,cast.FLOAT(dist)),1.0) +::STMT +MATRIX:residual_matrix +FLOAT:273_lambda ++(nrow(residual_matrix),273_lambda) +::STMT +MATRIX:diff_nominal,diff,_sbcvar1151 +FLOAT:num_std_median +LITERAL_FLOAT:0.0 ++(*(!=(diff_nominal,0.0),num_std_median),*(diff,_sbcvar1151)) +::STMT +MATRIX:Xd,parsertemp2775 +FLOAT:int811 +LITERAL_FLOAT:0.0 +*(*(Xd,>(-(int811,parsertemp2775),0.0)),Xd) +::STMT +MATRIX:Y_counts,means,parsertemp560511 +*(Y_counts,rowSums(*(means,parsertemp560511))) +::STMT +MATRIX:col,parsertemp24101,parsertemp24103 +FLOAT:int720,num_bins,float276,int627 +LITERAL_FLOAT:1.0 +*(-(-(1.0,<(col,int720)),>(+(parsertemp24103,int627),num_bins)),+(round(-(parsertemp24101,float276)),1.0)) +::STMT +MATRIX:n_event_stratum,n_risk_stratum,n_risk +LITERAL_FLOAT:1.0 +/(*(*(n_risk,n_event_stratum),-(n_risk_stratum,n_event_stratum)),*(n_risk_stratum,-(n_risk_stratum,1.0))) +::STMT +MATRIX:Y +FLOAT:num_categories,int206 +LITERAL_FLOAT:0.0 ++(Y,*(+(*(Y,int206),num_categories),<=(Y,0.0))) +::STMT +MATRIX:parsertemp409723,R +LITERAL_FLOAT:1.0 +-(+(cast.FLOAT(parsertemp409723),cast.FLOAT(R)),1.0) +::STMT +MATRIX:linear_terms +FLOAT:var_power +LITERAL_FLOAT:2.0 +exp(*(linear_terms,-(2.0,var_power))) +::STMT +MATRIX:parsertemp195898 +FLOAT:int22,parsertemp195894,factor_up +abs(-(/(parsertemp195898,factor_up),/(/(parsertemp195894,int22),factor_up))) +::STMT +FLOAT:index +LITERAL_FLOAT:2.0,3.0,4.0 ++(+(*(index,4.0),2.0),3.0) +::STMT +MATRIX:x,y +LITERAL_FLOAT:2.0 +cast.FLOAT(/(+(x,y),2.0)) +::STMT +MATRIX:w,out +LITERAL_FLOAT:1.0,0.5 +*(1.0,+(*(0.5,cast.FLOAT(out)),*(1.0,cast.FLOAT(w)))) +::STMT +MATRIX:V,W,H,parsertemp10738 +LITERAL_FLOAT:1.0E-8 +/(%*%(t(W),V),+(%*%(%*%(parsertemp10738,W),H),1.0E-8)) +::STMT +LITERAL_FLOAT:1.0 ++(1.0,1.0) +::STMT +FLOAT:link_power +LITERAL_FLOAT:1.0 +/(1.0,link_power) +::STMT +FLOAT:link_power +LITERAL_FLOAT:-1.0 +/(-1.0,link_power) +::STMT +MATRIX:y +FLOAT:beta +LITERAL_FLOAT:2.0 +sum(^(-(beta,y),2.0)) +::STMT +MATRIX:dout,parsertemp555660,parsertemp555659 +FLOAT:int582,int684 +LITERAL_FLOAT:1.0 +*(*(/(1.0,+(int582,parsertemp555659)),-(1.0,/(int684,parsertemp555660))),dout) +::STMT +MATRIX:p,e,u,G +FLOAT:alpha +LITERAL_FLOAT:1.0 ++(*(alpha,%*%(G,p)),*(-(1.0,alpha),%*%(%*%(e,u),p))) +::STMT +MATRIX:X +FLOAT:val +<=(X,val) +::STMT +MATRIX:prob,pred,test_Y +FLOAT:threshold ++(*(pred,>(prob,threshold)),*(test_Y,<=(prob,threshold))) +::STMT +MATRIX:parsertemp79022 +LITERAL_FLOAT:0.5,1270.0 ++(0.5,/(parsertemp79022,1270.0)) +::STMT +MATRIX:X +FLOAT:397_C +*(nrow(X),/(ncol(X),397_C)) +::STMT +MATRIX:output_values +FLOAT:log_odds +LITERAL_FLOAT:0.3,2.7182818284 +^(2.7182818284,+(log_odds,*(0.3,cast.FLOAT(output_values)))) +::STMT +LITERAL_FLOAT:0.0,1.0 ++(1.0,0.0) +::STMT +MATRIX:X2p +LITERAL_FLOAT:0.0 +>(t(colSums(X2p)),0.0) +::STMT +MATRIX:p,parsertemp169865,z +FLOAT:pp,trust_delta_sq +-(*(sum(*(p,z)),sum(*(p,z))),*(pp,-(sum(parsertemp169865),trust_delta_sq))) +::STMT +MATRIX:s,d,alpha_deno +FLOAT:norm_r2 ++(s,*(cast.FLOAT(/(norm_r2,alpha_deno)),d)) +::STMT +LITERAL_FLOAT:0.0,1.0 +INT:int434,int699,int424,int815 +%*%(rand(int434,int699,0.0,1.0),rand(int424,int815,0.0,1.0)) +::STMT +MATRIX:p,p2 +LITERAL_FLOAT:1.0E8 +sum(>(abs(-(p2,p)),1.0E8)) +::STMT +MATRIX:parsertemp171090,is_one_y_corr,t,parsertemp171099,parsertemp171096 +FLOAT:int352,float868 +LITERAL_FLOAT:1.0 ++(*(+(*(t,int352),/(parsertemp171090,parsertemp171096)),-(1.0,*(float868,parsertemp171099))),/(is_one_y_corr,-(1.0,is_one_y_corr))) +::STMT +MATRIX:parsertemp387409,Ks,Kss +abs(cast.FLOAT(-(Kss,%*%(parsertemp387409,Ks)))) +::STMT +MATRIX:Y +LITERAL_FLOAT:0.0,1.0 +-(+(nrow(Y),0.0),1.0) +::STMT +MATRIX:parsertemp170247,t_gp,parsertemp170252,lt_pos_neg,parsertemp170239 +FLOAT:float539,float739 +LITERAL_FLOAT:1.0,0.5,0.254829592 +*(*(-(0.5,lt_pos_neg),exp(/(parsertemp170252,float739))),*(/(1.0,+(float539,parsertemp170239)),+(0.254829592,*(t_gp,parsertemp170247)))) +::STMT +MATRIX:T_1,parsertemp410245,event,parsertemp410248 +FLOAT:float847,int506 +LITERAL_FLOAT:0.6666666666666666 +/(^(/(-(int506,parsertemp410245),*(float847,parsertemp410248)),0.6666666666666666),/(-(max(T_1),min(T_1)),sum(event))) +::STMT +MATRIX:classes +LITERAL_FLOAT:0.19999999999999996 +*(cast.FLOAT(classes),0.19999999999999996) +::STMT +FLOAT:ytest,int816 +LITERAL_FLOAT:1.0,2.0 +-(^(cast.FLOAT(ytest),2.0),*(1.0,^(/(ytest,int816),2.0))) +::STMT +LITERAL_FLOAT:2.0,7000.0 +^(7000.0,2.0) +::STMT +MATRIX:r,scale_X,shift_X,y,parsertemp116003 +FLOAT:int428 ++(*(scale_X,%*%(-(int428,parsertemp116003),y)),*(cast.FLOAT(r),shift_X)) +::STMT +MATRIX:d,X,logisticD +LITERAL_FLOAT:2.0 +*(2.0,%*%(t(X),*(logisticD,%*%(X,d)))) +::STMT +MATRIX:tmp_Xw,Y,Xd +LITERAL_FLOAT:0.0,1.0 +*(Xd,>(-(1.0,*(Y,tmp_Xw)),0.0)) +::STMT +MATRIX:W,H +LITERAL_FLOAT:1.0E-8 ++(%*%(%*%(t(W),W),H),1.0E-8) +::STMT +MATRIX:upd_W1 +LITERAL_FLOAT:0.95 +*(0.95,upd_W1) +::STMT +MATRIX:W +LITERAL_FLOAT:1.0,2.0 +*(-(sum(round(W)),2.0),+(sum(round(W)),1.0)) +::STMT +MATRIX:parsertemp222418,parsertemp222424 +FLOAT:sample_block_size +LITERAL_FLOAT:1.0 ++(*(sample_block_size,parsertemp222424),+(t(colSums(parsertemp222418)),1.0)) +::STMT +MATRIX:parsertemp254737 +FLOAT:parsertemp254766,2124_sq_root_d,float33,parsertemp254751 ++(float33,*(parsertemp254766,/(+(parsertemp254751,2124_sq_root_d),sum(parsertemp254737)))) +::STMT +MATRIX:parsertemp389328,parsertemp389331 +LITERAL_FLOAT:1.0 +t(/(-(exp(parsertemp389328),1.0),+(exp(parsertemp389331),1.0))) +::STMT +MATRIX:M +-(M,max(M)) +::STMT +MATRIX:img_in1,img_in2 +FLOAT:weight +LITERAL_FLOAT:1.0 ++(*(-(1.0,weight),img_in1),*(weight,img_in2)) +::STMT +MATRIX:parsertemp43993,os,d,X,alpha_deno ++(os,*(/(sum(parsertemp43993),cast.FLOAT(alpha_deno)),%*%(X,d))) +::STMT +FLOAT:n,norm +LITERAL_FLOAT:-2.0 +*(*(-2.0,norm),n) +::STMT +MATRIX:X +LITERAL_FLOAT:1.0 +*($1:ncol(X),+($1,1.0)) +::STMT +MATRIX:parsertemp265709,tmp,parsertemp265718,parsertemp265714 +FLOAT:Xm +LITERAL_FLOAT:2.0 +-(+(Xm,trace(*(tmp,parsertemp265714))),*(2.0,cast.FLOAT(%*%(parsertemp265718,parsertemp265709)))) +::STMT +MATRIX:minD,D,parsertemp222603,parsertemp222600 +colSums(/(<=(+(parsertemp222600,parsertemp222603),minD),rowSums(<=(D,minD)))) +::STMT +FLOAT:i +LITERAL_FLOAT:48.0 ++(48.0,i) +::STMT +MATRIX:log_prob,X +FLOAT:parsertemp436712 ++(*(ncol(X),parsertemp436712),log_prob) +::STMT +MATRIX:176_mask,W2,175_out +FLOAT:p +%*%(/(*(175_out,176_mask),p),W2) +::STMT +MATRIX:r,parsertemp1936,parsertemp1937 +FLOAT:parsertemp1941,norm_r2 +LITERAL_FLOAT:2.0 +^(+(r,*(/(norm_r2,parsertemp1941),+(parsertemp1936,parsertemp1937))),2.0) +::STMT +MATRIX:y_corr +LITERAL_FLOAT:2.0,0.5 +*(2.0,>(y_corr,0.5)) +::STMT +FLOAT:m2,float572,wt +LITERAL_FLOAT:4.0 +^(sqrt(/(*(m2,wt),-(wt,float572))),4.0) +::STMT +MATRIX:C,Xm,parsertemp265707,parsertemp265705,parsertemp265713 ++(sum(*(Xm,Xm)),trace(*(+(parsertemp265705,parsertemp265707),%*%(parsertemp265713,C)))) +::STMT +MATRIX:w_X,z_LS +LITERAL_FLOAT:2000.0 +/(2000.0,cast.FLOAT(%*%(t(w_X),z_LS))) +::STMT +FLOAT:float15,m2,wt +LITERAL_FLOAT:3.0 +^(sqrt(/(*(m2,wt),-(wt,float15))),3.0) +::STMT +MATRIX:n_risk_stratum +LITERAL_FLOAT:1.0 +*(n_risk_stratum,-(n_risk_stratum,1.0)) +::STMT +MATRIX:parsertemp498242,m_iter_err_sum,m_err +LITERAL_FLOAT:0.0 +-(0.0,-(t(+(parsertemp498242,m_iter_err_sum)),+(colSums(m_err),m_iter_err_sum))) +::STMT +MATRIX:col,more_than_ub,parsertemp24107,parsertemp24102,parsertemp24103 +FLOAT:int33,num_bins +LITERAL_FLOAT:1.0 ++(+(*(-(parsertemp24107,more_than_ub),+(parsertemp24103,int33)),*(>(col,num_bins),num_bins)),<(+(round(parsertemp24102),1.0),1.0)) +::STMT +MATRIX:R,S,Grad +LITERAL_FLOAT:-0.5 +*(-0.5,-(sum(*(S,Grad)),sum(*(S,R)))) +::STMT +MATRIX:sample_rec_ids +FLOAT:num_records +*(sample_rec_ids,<=(sample_rec_ids,num_records)) +::STMT +MATRIX:X +LITERAL_FLOAT:8.0 +==(X,8.0) +::STMT +LITERAL_FLOAT:990000.0 +990000.0 +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,11.0 ++(*(-(i,1.0),11.0),11.0) +::STMT +MATRIX:lambda,parsertemp171475 +FLOAT:new_log_l +LITERAL_FLOAT:0.5 +-(new_log_l,*(0.5,sum(*(lambda,parsertemp171475)))) +::STMT +MATRIX:parsertemp31112,parsertemp31114,parsertemp31105,parsertemp31107 +FLOAT:int146,int788,int637,int150 +LITERAL_FLOAT:1500.0,2000.0 ++(/(/(-(parsertemp31105,parsertemp31107),-(int150,int637)),2000.0),/(/(-(parsertemp31112,parsertemp31114),-(int788,int146)),1500.0)) +::STMT +MATRIX:Xi,X_rev_2 +*(X_rev_2,rev(Xi)) +::STMT +FLOAT:var_lag,xq_lag,arch_coef,var_coef,a0 ++(+(a0,*(arch_coef,xq_lag)),*(var_coef,var_lag)) +::STMT +MATRIX:minD,parsertemp72030,parsertemp72033,parsertemp72034,parsertemp72031 +FLOAT:int588 +/(<=(+(*(int588,parsertemp72030),t(parsertemp72033)),minD),rowSums(<=(+(parsertemp72031,parsertemp72034),minD))) +::STMT +MATRIX:G +!=(rowSums(G),t(colSums(G))) +::STMT +MATRIX:e,X,tS +FLOAT:l +*(==(%*%(X,tS),l),e) +::STMT +FLOAT:cmLabels +LITERAL_FLOAT:1.0,10000.0 +*(cmLabels,/(10000.0,-(10000.0,1.0))) +::STMT +MATRIX:y_corr +LITERAL_FLOAT:0.0,1.0 +/(==(y_corr,0.0),-(1.0,==(y_corr,0.0))) +::STMT +LITERAL_FLOAT:3.37275E9 +3.37275E9 +::STMT +FLOAT:i +LITERAL_FLOAT:96.0 ++(96.0,i) +::STMT +MATRIX:p_CG +FLOAT:parsertemp170089,z,pp_CG +LITERAL_FLOAT:-1.0 ++(*(*(cast.FLOAT(z),sum(p_CG)),-1.0),sqrt(-(*(z,z),*(pp_CG,parsertemp170089)))) +::STMT +MATRIX:V +t(V) +::STMT +MATRIX:ssX_p_CG,shift_X,p_CG ++(ssX_p_CG,cast.FLOAT(%*%(t(shift_X),p_CG))) +::STMT +MATRIX:U,V,X,parsertemp382841,row_nonzeros +FLOAT:int259 +LITERAL_FLOAT:1.0E-6 ++(%*%(*(!=(X,int259),-(parsertemp382841,X)),V),*(*(1.0E-6,U),row_nonzeros)) +::STMT +FLOAT:i,n +LITERAL_FLOAT:1.0,11.0 +-(n,-(+(i,11.0),1.0)) +::STMT +LITERAL_FLOAT:1.061405429 +1.061405429 +::STMT +MATRIX:X +LITERAL_FLOAT:1.0E-6 +*(1.0E-6,nrow(X)) +::STMT +MATRIX:m_active_flag_tmp,m_active_flag +LITERAL_FLOAT:1.0 +sum(-(>=(+(m_active_flag,m_active_flag_tmp),1.0),1.0)) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,12.0 ++(*(-(i,1.0),12.0),1.0) +::STMT +LITERAL_FLOAT:0.10938070012761454 +0.10938070012761454 +::STMT +MATRIX:prevTK2,totalE,X2 +%*%(t(totalE),==(%*%(X2,t(prevTK2)),t(rowSums(prevTK2)))) +::STMT +FLOAT:X +LITERAL_FLOAT:50.0,1.0E-6 +/(*(1.0E-6,X),50.0) +::STMT +MATRIX:os,d,X,alpha_deno +FLOAT:norm_r2 ++(os,*(cast.FLOAT(/(norm_r2,alpha_deno)),%*%(X,d))) +::STMT +MATRIX:M2 +LITERAL_FLOAT:0.0 +!(!=(M2,0.0)) +::STMT +MATRIX:S,parsertemp175056 +exp(-(S,parsertemp175056)) +::STMT +MATRIX:R +LITERAL_FLOAT:0.0 +==(colSums(!=(R,0.0)),0.0) +::STMT +MATRIX:parsertemp472147,I,y2 +%*%(/(%*%(I,y2),sum(I)),parsertemp472147) +::STMT +MATRIX:lambda,parsertemp149401,parsertemp149400,B_new +LITERAL_FLOAT:2.0 +sum(^(+(%*%(parsertemp149400,parsertemp149401),*(lambda,B_new)),2.0)) +::STMT +MATRIX:lambda +FLOAT:newbeta,new_log_l,int183 +LITERAL_FLOAT:0.5 +-(new_log_l,*(0.5,*(cast.FLOAT(lambda),^(newbeta,int183)))) +::STMT +MATRIX:2846_Q,X +LITERAL_FLOAT:2.0 ++(rowSums(^(X,2.0)),sum(^(2846_Q,2.0))) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:Infinity +==(linear_terms,Infinity) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:-Infinity +==(linear_terms,-Infinity) +::STMT +MATRIX:mask +LITERAL_FLOAT:1.0 +==(mask,1.0) +::STMT +FLOAT:X +LITERAL_FLOAT:1.0E-6,100.0 +/(*(1.0E-6,X),100.0) +::STMT +FLOAT:j +LITERAL_FLOAT:4.0 +-(4.0,j) +::STMT +MATRIX:parsertemp195898 +FLOAT:parsertemp195893,int52,factor_up +LITERAL_FLOAT:2.0 +-(/(parsertemp195898,factor_up),/(/(-(parsertemp195893,int52),2.0),factor_up)) +::STMT +MATRIX:T,parsertemp537734 +LITERAL_FLOAT:0.0 +sum(==(%*%(parsertemp537734,T),0.0)) +::STMT +MATRIX:X +FLOAT:m2X,float920,W +sqrt(*(m2X,/(nrow(X),-(W,float920)))) +::STMT +MATRIX:parsertemp385504 +LITERAL_FLOAT:0.0,6.0 +-(6.0,sum(!=(t(parsertemp385504),0.0))) +::STMT +MATRIX:w_X,z_LS,X +*(/(nrow(X),sum(*(w_X,z_LS))),z_LS) +::STMT +MATRIX:parsertemp31104,parsertemp31106 +FLOAT:int975 +LITERAL_FLOAT:1.0,2.0,2000.0 +^(/(-(colSums(parsertemp31104),*(int975,parsertemp31106)),-(2000.0,1.0)),2.0) +::STMT +FLOAT:i,n +LITERAL_FLOAT:1.0,12.0 +-(n,-(+(i,12.0),1.0)) +::STMT +MATRIX:parsertemp195899,parsertemp195900 +FLOAT:center +LITERAL_FLOAT:1.0 +%*%(-(1.0,abs(-(parsertemp195899,center))),t(-(1.0,abs(parsertemp195900)))) +::STMT +MATRIX:p,parsertemp1597,beta_unscaled +FLOAT:norm_r2 ++(beta_unscaled,*(/(norm_r2,sum(parsertemp1597)),p)) +::STMT +MATRIX:parsertemp174552 +LITERAL_FLOAT:0.0 +==(parsertemp174552,0.0) +::STMT +MATRIX:linear_terms +FLOAT:var_power +LITERAL_FLOAT:1.0,2.0 +/(/(1.0,linear_terms),-(2.0,var_power)) +::STMT +MATRIX:ss +LITERAL_FLOAT:20.0 +/(20.0,ss) +::STMT +MATRIX:X,Y +FLOAT:eps ++(+(abs(X),abs(Y)),eps) +::STMT +MATRIX:parsertemp146974,mW1,190_dW,parsertemp146977 +FLOAT:parsertemp146983,191_lr,parsertemp146981,int10,191_beta1,parsertemp146971,191_epsilon +/(*(/(*(191_lr,parsertemp146981),-(int10,parsertemp146983)),+(*(191_beta1,mW1),*(parsertemp146971,190_dW))),+(sqrt(+(parsertemp146974,parsertemp146977)),191_epsilon)) +::STMT +MATRIX:lambda,beta +LITERAL_FLOAT:0.0 ++(0.0,*(cast.FLOAT(lambda),cast.FLOAT(beta))) +::STMT +MATRIX:parsertemp31030,parsertemp31032 +FLOAT:int387,int994 +LITERAL_FLOAT:1.0,2.0,150.0 +/(^(/(-(parsertemp31030,parsertemp31032),-(int994,int387)),2.0),*(^(150.0,2.0),-(150.0,1.0))) +::STMT +MATRIX:C,Xm,parsertemp265701 +%*%(t(%*%(Xm,%*%(C,parsertemp265701))),%*%(Xm,%*%(C,parsertemp265701))) +::STMT +MATRIX:g_reg,g,parsertemp285556 +sqrt(cast.FLOAT(%*%(t(g_reg),+(g,parsertemp285556)))) +::STMT +MATRIX:linear_terms +FLOAT:link_power +LITERAL_FLOAT:1.0 +^(linear_terms,-(/(1.0,link_power),1.0)) +::STMT +MATRIX:linear_terms +FLOAT:link_power +LITERAL_FLOAT:-1.0,1.0 +^(linear_terms,-(/(-1.0,link_power),1.0)) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:1.0E7 ++(exp(linear_terms),==(+(1.0E7,exp(linear_terms)),1.0E7)) +::STMT +MATRIX:parsertemp10744,W,H +LITERAL_FLOAT:1.0E-8 ++(%*%(W,%*%(*(H,parsertemp10744),t(H))),1.0E-8) +::STMT +FLOAT:int53 +LITERAL_FLOAT:0.0 +INT:int403,m +rand(m,int403,0.0,int53) +::STMT +MATRIX:Xi_X_rev_agg,e_r_rev_agg,select,Xi_agg_rev_agg,X_agg +LITERAL_FLOAT:2.0 +-(/(%*%(select,Xi_X_rev_agg),e_r_rev_agg),/(*(X_agg,Xi_agg_rev_agg),^(e_r_rev_agg,2.0))) +::STMT +MATRIX:err,cCnts +FLOAT:minSup +LITERAL_FLOAT:0.0 +sum(|(<(cCnts,minSup),==(err,0.0))) +::STMT +FLOAT:m2 +LITERAL_FLOAT:1.0004995004995005,4.0 +^(sqrt(*(1.0004995004995005,m2)),4.0) +::STMT +MATRIX:linear_terms +FLOAT:link_power +LITERAL_FLOAT:1.0,2.0 +^(linear_terms,-(/(1.0,link_power),2.0)) +::STMT +FLOAT:m2 +LITERAL_FLOAT:1.0004995004995005,3.0 +^(sqrt(*(1.0004995004995005,m2)),3.0) +::STMT +MATRIX:is_zero_y_corr,is_one_y_corr,parsertemp171113 +FLOAT:parsertemp171116 +LITERAL_FLOAT:1.0 ++(-(parsertemp171113,*(parsertemp171116,+(is_zero_y_corr,is_one_y_corr))),/(is_one_y_corr,-(1.0,is_one_y_corr))) +::STMT +MATRIX:ZtZ,Xm,parsertemp265709,parsertemp265706,Z,parsertemp265702,XtZ +FLOAT:ss,ZtZ_sum +*(+(%*%(t(Z),%*%(Xm,parsertemp265702)),*(parsertemp265706,ss)),%*%(t(/(XtZ,ZtZ_sum)),/(%*%(parsertemp265709,Z),sum(ZtZ)))) +::STMT +MATRIX:tmp +FLOAT:N +LITERAL_FLOAT:0.0,1.0 +<=(/(tmp,-(N,1.0)),0.0) +::STMT +MATRIX:CFreqs1 +LITERAL_FLOAT:0.0,1.0 +diag(-(1.0,==(CFreqs1,0.0))) +::STMT +MATRIX:y_hat,X +sum(*(-(X,y_hat),-(X,y_hat))) +::STMT +MATRIX:parsertemp31024,parsertemp31022 +FLOAT:int346 +LITERAL_FLOAT:99.0,100.0 +/(/(-(colSums(parsertemp31022),*(int346,parsertemp31024)),99.0),100.0) +::STMT +FLOAT:D +LITERAL_FLOAT:2.0 +sqrt(/(2.0,D)) +::STMT +MATRIX:lengths +abs(-(cast.FLOAT(lengths),cast.FLOAT(lengths))) +::STMT +MATRIX:X,Y +LITERAL_FLOAT:2.0 +^(-(X,Y),2.0) +::STMT +MATRIX:resp,Y,parsertemp506189 +==(+(resp,t(parsertemp506189)),Y) +::STMT +MATRIX:e_r_rev_agg,parsertemp409787,parsertemp409796 +LITERAL_FLOAT:-1.0 ++(*(t(colSums(parsertemp409787)),-1.0),t(colSums(/(parsertemp409796,e_r_rev_agg)))) +::STMT +MATRIX:X,Centering,ScaleFactor +colSums(/(-(X,Centering),ScaleFactor)) +::STMT +MATRIX:parsertemp402079,W3_rand,parsertemp402082 +LITERAL_FLOAT:0.1092173494617922 +t(%*%(*(0.1092173494617922,W3_rand),t(/(parsertemp402079,parsertemp402082)))) +::STMT +MATRIX:parsertemp76118 +LITERAL_FLOAT:4460.0 +/(parsertemp76118,4460.0) +::STMT +MATRIX:W,Y,sumW +LITERAL_FLOAT:300.0,0.0 +-(0.0,*(300.0,-(*(Y,sumW),%*%(W,Y)))) +::STMT +MATRIX:grad +LITERAL_FLOAT:-1.0 +sum(*(*(grad,-1.0),*(grad,-1.0))) +::STMT +MATRIX:Kss,parsertemp387410 +sqrt(abs(cast.FLOAT(-(Kss,parsertemp387410)))) +::STMT +MATRIX:img +FLOAT:Hf,Wf +*(*(nrow(img),Hf),Wf) +::STMT +MATRIX:z +sqrt(sum(*(z,z))) +::STMT +MATRIX:p,V +FLOAT:eps ++(%*%(t(V),%*%(V,p)),*(eps,p)) +::STMT +MATRIX:posSamples,posSampleMeans +FLOAT:int892,int522 +LITERAL_FLOAT:1999.0,2000.0 +/(-(colSums(^(posSamples,int892)),*(2000.0,^(posSampleMeans,int522))),1999.0) +::STMT +MATRIX:parsertemp2832 +==(round(parsertemp2832),min(round(parsertemp2832))) +::STMT +MATRIX:parsertemp77570 +LITERAL_FLOAT:2358.0 +/(parsertemp77570,2358.0) +::STMT +FLOAT:factor_up,parsertemp195891,parsertemp195892 +LITERAL_FLOAT:1.0,2.0 +/(/(-(-(parsertemp195891,parsertemp195892),1.0),2.0),factor_up) +::STMT +MATRIX:439_Ranks,parsertemp42225 +FLOAT:parsertemp42214,parsertemp42216,parsertemp42218,meanY,parsertemp42220 +/(sum(*(t(parsertemp42225),-(439_Ranks,meanY))),*(sqrt(*(parsertemp42214,parsertemp42216)),sqrt(*(parsertemp42218,parsertemp42220)))) +::STMT +FLOAT:ssPrev,parsertemp265725,parsertemp265724,m,n +LITERAL_FLOAT:1.0 +-(1.0,/(/(-(parsertemp265724,parsertemp265725),*(n,m)),ssPrev)) +::STMT +MATRIX:ytest,yhat +LITERAL_FLOAT:2.0 +^(-(cast.FLOAT(ytest),cast.FLOAT(yhat)),2.0) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:0.0 +-(0.0,exp(-(0.0,linear_terms))) +::STMT +MATRIX:parsertemp170240,parsertemp170238 +FLOAT:float911,float541 +LITERAL_FLOAT:1.0,1.061405429,-1.453152027 +*(/(1.0,+(1.0,*(parsertemp170238,float541))),+(-1.453152027,*(/(float911,parsertemp170240),1.061405429))) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,12.0 ++(*(-(i,1.0),12.0),12.0) +::STMT +MATRIX:parsertemp389215,parsertemp389217 +LITERAL_FLOAT:1057.0,1058.0 +sqrt(/(*(-(parsertemp389215,parsertemp389217),1058.0),1057.0)) +::STMT +MATRIX:finite_linear_terms +LITERAL_FLOAT:2.0 +/(exp(finite_linear_terms),2.0) +::STMT +MATRIX:A,CFreqs +-(nrow(A),nrow(CFreqs)) +::STMT +MATRIX:parsertemp129186,parsertemp129185,key_unique,key +t(==(%*%(key_unique,parsertemp129185),%*%(parsertemp129186,t(key)))) +::STMT +MATRIX:F +-(F,/(%*%(rowSums(F),colSums(F)),sum(F))) +::STMT +MATRIX:S,V,W +*(W,%*%(S,t(V))) +::STMT +MATRIX:parsertemp220853,Ws,beta +FLOAT:logU +LITERAL_FLOAT:0.0 +>=(-(+(parsertemp220853,*(beta,Ws)),logU),0.0) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,12.0,8.0 ++(*(-(i,1.0),12.0),8.0) +::STMT +MATRIX:grad +FLOAT:psi +*(psi,sqrt(sum(*(grad,grad)))) +::STMT +MATRIX:r,parsertemp44063,grad +LITERAL_FLOAT:-0.5 +*(-0.5,-(cast.FLOAT(%*%(parsertemp44063,grad)),cast.FLOAT(%*%(parsertemp44063,r)))) +::STMT +MATRIX:W +LITERAL_FLOAT:3.0,5.0 +*(+(sum(round(W)),5.0),-(sum(round(W)),3.0)) +::STMT +MATRIX:p,q,lambda +FLOAT:norm_r2 +*(/(norm_r2,sum(*(p,q))),+(q,*(lambda,p))) +::STMT +MATRIX:Q1,IQR +LITERAL_FLOAT:2.0 +-(Q1,*(2.0,IQR)) +::STMT +FLOAT:rho +LITERAL_FLOAT:10000.0 +*(10000.0,rho) +::STMT +MATRIX:r,parsertemp44063,parsertemp44065,grad +LITERAL_FLOAT:-0.5 +cast.FLOAT(*(-0.5,-(%*%(parsertemp44063,grad),%*%(parsertemp44065,r)))) +::STMT +FLOAT:cols,parsertemp451837 +LITERAL_FLOAT:1.0 ++(+(*(parsertemp451837,cols),1.0),cols) +::STMT +MATRIX:finite_linear_terms +LITERAL_FLOAT:-1.0 +exp(*(exp(finite_linear_terms),-1.0)) +::STMT +MATRIX:parsertemp115947,TK +*(rowSums(TK),parsertemp115947) +::STMT +MATRIX:scale_lambda,parsertemp150455 +FLOAT:reg +*(%*%(scale_lambda,parsertemp150455),reg) +::STMT +MATRIX:inactive_set,w +LITERAL_FLOAT:0.0 +-(inactive_set,!=(w,0.0)) +::STMT +FLOAT:m2,mu,float907,wt +/(sqrt(/(*(m2,wt),-(wt,float907))),mu) +::STMT +MATRIX:valueCount,parsertemp552530,Y +FLOAT:int866,int933 +INT:parsertemp552529,idx +*(==(+(rand(parsertemp552529,idx,int933,int866),t(parsertemp552530)),Y),valueCount) +::STMT +MATRIX:prediction,target +/(-(prediction,target),nrow(target)) +::STMT +MATRIX:posSampleMeans +LITERAL_FLOAT:2.0,100.0 +*(100.0,^(posSampleMeans,2.0)) +::STMT +MATRIX:252_Y +FLOAT:252_X,float555 +LITERAL_FLOAT:1.0 +*(-(1.0,/(-(float555,252_X),-(252_X,252_X))),cast.FLOAT(252_Y)) +::STMT +FLOAT:vicinity,a0 +LITERAL_FLOAT:1.0 +*(-(1.0,vicinity),a0) +::STMT +MATRIX:Y +-(nrow(Y),sum(Y)) +::STMT +MATRIX:mu +FLOAT:window_size +*(window_size,cast.FLOAT(*(mu,mu))) +::STMT +MATRIX:parsertemp459193,2701_dX,vb3 +FLOAT:lr,mu +-(*(mu,vb3),*(lr,colSums(*(parsertemp459193,2701_dX)))) +::STMT +MATRIX:lambda,g,beta +LITERAL_FLOAT:0.0 +-(0.0,+(g,*(cast.FLOAT(lambda),cast.FLOAT(beta)))) +::STMT +MATRIX:parsertemp555752 +FLOAT:int398 +LITERAL_FLOAT:0.5 +sum(*(0.5,rowSums(^(parsertemp555752,int398)))) +::STMT +MATRIX:Xm,parsertemp265707,parsertemp265705,parsertemp265702 +t(/(%*%(t(Xm),%*%(Xm,parsertemp265702)),sum(+(parsertemp265705,parsertemp265707)))) +::STMT +MATRIX:parsertemp191275 +FLOAT:397_C +*(397_C,t(parsertemp191275)) +::STMT +MATRIX:ts +FLOAT:q ++(-(q,*(cast.FLOAT(ts),cast.FLOAT(ts))),*(cast.FLOAT(ts),cast.FLOAT(ts))) +::STMT +FLOAT:Z_logl +LITERAL_FLOAT:-1.0 +*(abs(Z_logl),-1.0) +::STMT +MATRIX:classFeatureCounts +FLOAT:numFeatures,laplaceCorrection ++(rowSums(classFeatureCounts),*(numFeatures,laplaceCorrection)) +::STMT +MATRIX:X +FLOAT:2917_split +round(*(nrow(X),2917_split)) +::STMT +FLOAT:parsertemp557354,parsertemp557358,prob_true,prob_false +LITERAL_FLOAT:0.6931471805599453 ++(/(*(prob_true,parsertemp557354),0.6931471805599453),/(*(prob_false,parsertemp557358),0.6931471805599453)) +::STMT +MATRIX:mn,mx +LITERAL_FLOAT:1.0 ++(-(mx,mn),1.0) +::STMT +MATRIX:parsertemp409803 +FLOAT:D +LITERAL_FLOAT:0.5 +/(*(0.5,sqrt(D)),max(sqrt(rowSums(parsertemp409803)))) +::STMT +MATRIX:r,parsertemp1945 +FLOAT:norm_r2 +/(sum(*(+(r,parsertemp1945),+(r,parsertemp1945))),norm_r2) +::STMT +FLOAT:x1,x2 +LITERAL_FLOAT:-1.0,2.0 +*(-1.0,^(-(x1,x2),2.0)) +::STMT +MATRIX:R,parsertemp40226 +FLOAT:eAvg +/(/(+(R,rowSums(parsertemp40226)),R),eAvg) +::STMT +MATRIX:V +max(V) +::STMT +MATRIX:Y_prob,Y,linear_terms +FLOAT:int926 +LITERAL_FLOAT:3.141592653589793,1.0 +*(*(*(rowSums(Y),Y_prob),Y_prob),*(+(1.0,^(linear_terms,int926)),3.141592653589793)) +::STMT +MATRIX:obj,objnew,gs +-(-(cast.FLOAT(objnew),cast.FLOAT(obj)),cast.FLOAT(gs)) +::STMT +MATRIX:prob,pred,test_Y +FLOAT:threshold +LITERAL_FLOAT:0.0 ++(*(pred,>(prob,threshold)),*(test_Y,==(>(prob,threshold),0.0))) +::STMT +FLOAT:K +LITERAL_FLOAT:300.0 +*(300.0,K) +::STMT +FLOAT:acc +LITERAL_FLOAT:1.0,100.0 +cast.MATRIX(-(1.0,/(acc,100.0))) +::STMT +MATRIX:u,minDist +!=(u,minDist) +::STMT +MATRIX:N_T,tmp,X +<=(rowSums(*(X,tmp)),%*%(tmp,t(N_T))) +::STMT +MATRIX:parsertemp32006,simplex +LITERAL_FLOAT:2.0,4.0 +-(*(2.0,/(-(parsertemp32006,simplex),4.0)),simplex) +::STMT +MATRIX:s,parsertemp44005,d +FLOAT:parsertemp44004 +cast.FLOAT(%*%(t(+(s,parsertemp44005)),+(s,*(parsertemp44004,d)))) +::STMT +MATRIX:parsertemp171348,is_too_small,parsertemp171346,parsertemp171344,parsertemp171353,linear_terms,Y,the_exp,parsertemp171349 +FLOAT:int369,int803 +/(*(*(exp(parsertemp171344),exp(linear_terms)),rowSums(Y)),+(/(*(parsertemp171348,parsertemp171349),+(the_exp,is_too_small)),*(==(parsertemp171346,int803),-(int369,parsertemp171353)))) +::STMT +MATRIX:betamax,parsertemp220870,Hpos,beta +FLOAT:INF,int237 +LITERAL_FLOAT:2.0 ++(*(*(*(int237,Hpos),==(betamax,INF)),beta),/(*(*(Hpos,parsertemp220870),+(beta,betamax)),2.0)) +::STMT +MATRIX:_sbcvar1782 +FLOAT:_sbcvar1783 +LITERAL_FLOAT:8.0 +/(_sbcvar1782,-(8.0,_sbcvar1783)) +::STMT +MATRIX:y_hat +FLOAT:parsertemp176421,k +-(sqrt(parsertemp176421),*(k,y_hat)) +::STMT +MATRIX:F +LITERAL_FLOAT:0.0 +==(/(%*%(rowSums(F),colSums(F)),sum(F)),0.0) +::STMT +MATRIX:w,X,y +FLOAT:int485,int701 +INT:int178,m +%*%(t(-(%*%(X,w),y)),-(%*%(X,rand(m,int178,int485,int701)),y)) +::STMT +MATRIX:X +LITERAL_FLOAT:480.0 +/(colSums(X),480.0) +::STMT +MATRIX:Yhat_prime,E +t(colSums(*(E,Yhat_prime))) +::STMT +MATRIX:t_gp,parsertemp171320,Y,linear_terms,parsertemp171316 +LITERAL_FLOAT:0.0,0.5 +*(*(*(exp(parsertemp171320),*(t_gp,parsertemp171316)),rowSums(Y)),-(>=(linear_terms,0.0),0.5)) +::STMT +FLOAT:prob_true,prob_false +LITERAL_FLOAT:1.0,2.0 +-(1.0,+(^(prob_true,2.0),^(prob_false,2.0))) +::STMT +LITERAL_FLOAT:1.0,100000.0 +-(100000.0,1.0) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,8.0 ++(*(-(i,1.0),8.0),1.0) +::STMT +MATRIX:2700_dX,parsertemp459190,2702_X +FLOAT:int389,lr +*(lr,colSums(*(>(2702_X,int389),*(parsertemp459190,2700_dX)))) +::STMT +MATRIX:R,B,parsertemp503364 +LITERAL_FLOAT:-1.0 +*(%*%(t(+(R,parsertemp503364)),B),-1.0) +::STMT +MATRIX:parsertemp230374 +t(t(parsertemp230374)) +::STMT +MATRIX:parsertemp409216,parsertemp409212,ctab +LITERAL_FLOAT:0.45 +*(parsertemp409216,>(/(parsertemp409212,rowSums(ctab)),0.45)) +::STMT +MATRIX:out2,184_probs,183_dpred,parsertemp146939,W3 +LITERAL_FLOAT:0.0 +*(>(out2,0.0),%*%(-(*(183_dpred,184_probs),*(184_probs,parsertemp146939)),t(W3))) +::STMT +FLOAT:n_components,n_features +LITERAL_FLOAT:1.0 +*(*(n_components,n_features),+(n_features,1.0)) +::STMT +MATRIX:parsertemp472298,I +LITERAL_FLOAT:0.0 +*(==(*(t(parsertemp472298),I),0.0),I) +::STMT +MATRIX:p,q,lambda +cast.FLOAT(%*%(t(p),+(q,*(lambda,p)))) +::STMT +LITERAL_FLOAT:0.999 +0.999 +::STMT +MATRIX:X +FLOAT:x +/(cast.FLOAT(-(x,X)),-(cast.FLOAT(X),cast.FLOAT(X))) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:1.0,2.0 +-(1.0,^(linear_terms,2.0)) +::STMT +MATRIX:output_values +LITERAL_FLOAT:0.3 +*(0.3,cast.FLOAT(output_values)) +::STMT +MATRIX:X,_sbcvar2948 +cast.FLOAT(%*%(t(-(X,_sbcvar2948)),-(X,_sbcvar2948))) +::STMT +MATRIX:P,parsertemp220889,Z,ZERODIAG,parsertemp220891 +FLOAT:int302,int765 +LITERAL_FLOAT:4.0 +-(*(P,4.0),/(*(/(int302,parsertemp220891),+(parsertemp220889,int765)),sum(*(Z,ZERODIAG)))) +::STMT +FLOAT:nc +LITERAL_FLOAT:1.0,20.0 +*(+(20.0,1.0),-(nc,1.0)) +::STMT +MATRIX:_sbcvar78,parsertemp22266 +FLOAT:int315 +LITERAL_FLOAT:2.0,10000.0 +/(^(-(_sbcvar78,/(parsertemp22266,int315)),2.0),/(%*%(rowSums(_sbcvar78),colSums(_sbcvar78)),10000.0)) +::STMT +MATRIX:negSampleMeans,negSamples +FLOAT:int318,int839 +LITERAL_FLOAT:149.0,150.0 +/(-(colSums(^(negSamples,int839)),*(150.0,^(negSampleMeans,int318))),149.0) +::STMT +MATRIX:A +LITERAL_FLOAT:1.0 ++(+(/(1.0,cast.FLOAT(A)),/(1.0,cast.FLOAT(A))),/(1.0,cast.FLOAT(A))) +::STMT +MATRIX:P,Q,parsertemp220896,Y,Z,ZERODIAG +-(*(Y,rowSums(*(parsertemp220896,Z))),%*%(*(-(P,Q),*(Z,ZERODIAG)),Y)) +::STMT +FLOAT:int496,parsertemp98,var,m4,parsertemp99,int864,parsertemp93,parsertemp94,wt,parsertemp105,parsertemp104 +LITERAL_FLOAT:4.0 +/(-(*(*(parsertemp93,parsertemp94),m4),*(*(parsertemp98,parsertemp99),-(wt,int496))),*(*(*(parsertemp104,parsertemp105),-(wt,int864)),^(sqrt(var),4.0))) +::STMT +MATRIX:parsertemp24101 +FLOAT:float99 +LITERAL_FLOAT:1.0 +<(+(round(-(parsertemp24101,float99)),1.0),1.0) +::STMT +MATRIX:parsertemp145796,y +LITERAL_FLOAT:-1.0 +rowSums(*(*(y,-1.0),parsertemp145796)) +::STMT +MATRIX:linear_terms +FLOAT:var_power,link_power,float123 +LITERAL_FLOAT:1.0 +/(^(linear_terms,/(-(float123,var_power),link_power)),-(1.0,var_power)) +::STMT +FLOAT:_sbcvar1751 +LITERAL_FLOAT:6.0 +-(6.0,_sbcvar1751) +::STMT +MATRIX:out2,parsertemp146940,184_dtemp,W3 +LITERAL_FLOAT:0.0 +colSums(*(>(out2,0.0),%*%(-(184_dtemp,parsertemp146940),t(W3)))) +::STMT +MATRIX:parsertemp555766,parsertemp555764,parsertemp555762,parsertemp555761,target +/(sum(-(*(parsertemp555761,parsertemp555762),*(parsertemp555764,parsertemp555766))),nrow(target)) +::STMT +MATRIX:parsertemp437192,parsertemp437191,parsertemp437190,mean,parsertemp437236,X,weight,parsertemp437188 +FLOAT:float202,int107 +LITERAL_FLOAT:2.0 ++(-(/(%*%(parsertemp437190,parsertemp437236),t(weight)),*(2.0,^(mean,int107))),/(*(/(parsertemp437191,parsertemp437192),%*%(parsertemp437190,X)),t(+(parsertemp437188,float202)))) +::STMT +MATRIX:parsertemp220896,W,Y,Z +LITERAL_FLOAT:300.0 +*(300.0,-(*(Y,rowSums(W)),%*%(*(parsertemp220896,Z),Y))) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:-1.0 +*(exp(*(exp(linear_terms),-1.0)),exp(linear_terms)) +::STMT +MATRIX:s +FLOAT:n +LITERAL_FLOAT:1.0 +*(/(1.0,s),n) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:3.141592653589793,1.0,2.0 +*(+(1.0,^(linear_terms,2.0)),3.141592653589793) +::STMT +MATRIX:CVars,CFreqs +LITERAL_FLOAT:1.0 +*(-(CFreqs,1.0),CVars) +::STMT +MATRIX:s,parsertemp44016,d +*(%*%(t(-(s,parsertemp44016)),d),%*%(t(-(s,parsertemp44016)),d)) +::STMT +MATRIX:P +sum(+(P,t(P))) +::STMT +MATRIX:A +FLOAT:a11,a12,int33,int524 +LITERAL_FLOAT:1.0 ++(+(+(/(int524,a11),/(int33,a12)),/(1.0,cast.FLOAT(A))),/(1.0,cast.FLOAT(A))) +::STMT +MATRIX:lambda,g,beta ++(g,*(cast.FLOAT(lambda),cast.FLOAT(beta))) +::STMT +MATRIX:Y +FLOAT:num_features,num_records +LITERAL_FLOAT:1.0 +*(-(num_records,num_features),-(ncol(Y),1.0)) +::STMT +MATRIX:L,m +FLOAT:sum +/(-(m,sum),L) +::STMT +FLOAT:e,initial_lr,decay +LITERAL_FLOAT:1.0 +*(initial_lr,/(1.0,+(1.0,*(decay,e)))) +::STMT +FLOAT:new_log_l,log_l +LITERAL_FLOAT:-1.0 ++(*(new_log_l,-1.0),log_l) +::STMT +MATRIX:r_CG,p_CG +FLOAT:rr_CG,old_rr_CG +LITERAL_FLOAT:0.0 ++(-(0.0,r_CG),*(/(rr_CG,old_rr_CG),p_CG)) +::STMT +LITERAL_FLOAT:1.0,2.0,100.0 +*(^(100.0,2.0),-(100.0,1.0)) +::STMT +MATRIX:parsertemp220911,dY,Y +-(+(Y,dY),parsertemp220911) +::STMT +MATRIX:X_train +LITERAL_FLOAT:2.0 +/(2.0,ncol(X_train)) +::STMT +MATRIX:parsertemp389218 +FLOAT:int620 +LITERAL_FLOAT:1.0E-17,1057.0 ++(sqrt(/(*(parsertemp389218,int620),1057.0)),1.0E-17) +::STMT +MATRIX:S,U,W +%*%(t(U),*(W,%*%(U,t(S)))) +::STMT +FLOAT:int602,avg_tot,sum_sq_y_test,n +LITERAL_FLOAT:1.0 +/(-(sum_sq_y_test,*(n,^(avg_tot,int602))),-(n,1.0)) +::STMT +MATRIX:W +LITERAL_FLOAT:1.0,2.0,4.0 +*(4.0,-(^(sum(W),2.0),1.0)) +::STMT +MATRIX:linear_terms +FLOAT:var_power,link_power +LITERAL_FLOAT:1.0 +^(linear_terms,/(-(1.0,var_power),link_power)) +::STMT +MATRIX:H +-(+(H,t(H)),diag(diag(H))) +::STMT +MATRIX:col +FLOAT:min_val +-(col,min_val) +::STMT +MATRIX:parsertemp146930,184_unnorm_probs,parsertemp146928,184_scores +FLOAT:int210,parsertemp146927 +rowSums(*(*(*(parsertemp146927,parsertemp146928),/(int210,parsertemp146930)),/(exp(184_scores),rowSums(184_unnorm_probs)))) +::STMT +MATRIX:W +LITERAL_FLOAT:1.0,2.0 +*(-(sum(round(W)),2.0),+(sum(round(W)),1.0)) +::STMT +MATRIX:P,parsertemp220889,Y,Z,parsertemp220891 +FLOAT:int593,int40,parsertemp220894,int923 +%*%(*(-(*(P,int923),/(Z,parsertemp220894)),*(/(int593,parsertemp220891),+(parsertemp220889,int40))),Y) +::STMT +MATRIX:W +LITERAL_FLOAT:5.0 ++(sum(round(W)),5.0) +::STMT +MATRIX:parsertemp437192,parsertemp437191,resp,X,parsertemp437188 +FLOAT:float205,int295 +LITERAL_FLOAT:2.0 +-(/(%*%(t(resp),^(X,int295)),t(+(parsertemp437188,float205))),*(2.0,^(/(parsertemp437191,parsertemp437192),2.0))) +::STMT +MATRIX:parsertemp386844,parsertemp386845 +LITERAL_FLOAT:0.0,2.0 +&(>(rowSums(|(parsertemp386844,parsertemp386845)),0.0),<(rowSums(|(parsertemp386844,parsertemp386845)),2.0)) +::STMT +MATRIX:parsertemp410977,W,H,parsertemp410974 +rowSums(/(*(H,%*%(parsertemp410974,parsertemp410977)),t(colSums(W)))) +::STMT +MATRIX:lambda,scale_X,gXY,beta +FLOAT:int58 +%*%(t(+(*(scale_X,gXY),*(lambda,beta))),+(*(scale_X,-(int58,gXY)),*(lambda,beta))) +::STMT +MATRIX:scale_X,X +%*%(diag(scale_X),%*%(t(X),X)) +::STMT +MATRIX:out2,parsertemp146942,184_dscores +FLOAT:int741 +LITERAL_FLOAT:2.0 +^(colSums(*(>(out2,int741),%*%(184_dscores,parsertemp146942))),2.0) +::STMT +MATRIX:p,q,lambda +%*%(t(p),+(q,*(lambda,p))) +::STMT +MATRIX:parsertemp220988,parsertemp220989,dY +LITERAL_FLOAT:300.0,2.0,0.9 +^(-(*(0.9,dY),*(300.0,-(parsertemp220988,parsertemp220989))),2.0) +::STMT +MATRIX:_sbcvar1674 +FLOAT:int964 +LITERAL_FLOAT:0.0,2.0 +INT:int411,parsertemp282730 +*(>(rand(parsertemp282730,int411,int964,2.0),0.0),_sbcvar1674) +::STMT +MATRIX:parsertemp555753,target +LITERAL_FLOAT:0.5 +/(sum(*(0.5,rowSums(parsertemp555753))),nrow(target)) +::STMT +MATRIX:W +sum(round(W)) +::STMT +MATRIX:one_featureX +FLOAT:287_x,287_y +LITERAL_FLOAT:2.0 +!(<(one_featureX,/(+(287_x,287_y),2.0))) +::STMT +LITERAL_FLOAT:0.0,1.0 +INT:n_components,parsertemp506195 +rowSums(rand(parsertemp506195,n_components,0.0,1.0)) +::STMT +MATRIX:R +FLOAT:s,i8 +-(ncol(R),*(s,i8)) +::STMT +MATRIX:p,r +FLOAT:norm_r2 +*(/(sum(*(r,r)),norm_r2),p) +::STMT +LITERAL_FLOAT:1.0,750.0 +*(750.0,1.0) +::STMT +MATRIX:ss,X2 +FLOAT:alpha +LITERAL_FLOAT:1.0 +*(-(1.0,alpha),-(/(nrow(X2),ss),1.0)) +::STMT +FLOAT:float634,parsertemp254709,parsertemp254694,2123_sq_root_d,pp_CG ++(float634,*(parsertemp254709,/(+(parsertemp254694,2123_sq_root_d),pp_CG))) +::STMT +FLOAT:b,int894,rad +/(int894,+(b,rad)) +::STMT +MATRIX:w,out +LITERAL_FLOAT:1.0,0.5 +*(1.0,+(*(0.5,cast.FLOAT(out)),*(0.5,cast.FLOAT(w)))) +::STMT +MATRIX:X +LITERAL_FLOAT:0.0 +colSums(!=(X,0.0)) +::STMT +LITERAL_FLOAT:0.1092173494617922 +0.1092173494617922 +::STMT +MATRIX:r_CG,g_reg,z +%*%(t(z),+(r_CG,g_reg)) +::STMT +MATRIX:parsertemp498247,m_iter_err_sum +FLOAT:i_process_item +LITERAL_FLOAT:0.0,2.0 +*(2.0,/(-(0.0,-(parsertemp498247,m_iter_err_sum)),i_process_item)) +::STMT +MATRIX:D,ZERODIAG +LITERAL_FLOAT:1.0 +*(/(1.0,+(D,1.0)),ZERODIAG) +::STMT +MATRIX:intercept,X,beta +LITERAL_FLOAT:1.0 +INT:num_records,int303 ++(%*%(X,beta),%*%(rand(num_records,int303,1.0,1.0),intercept)) +::STMT +MATRIX:t,parsertemp171088,parsertemp171083,parsertemp171094 +FLOAT:float536 +LITERAL_FLOAT:-1.0,1.0,2.515517 ++(*(sqrt(*(float536,parsertemp171083)),-1.0),/(+(2.515517,*(t,parsertemp171088)),+(1.0,*(t,parsertemp171094)))) +::STMT +MATRIX:y_batch,parsertemp146892 +LITERAL_FLOAT:0.0 +sum(*(-(0.0,y_batch),parsertemp146892)) +::STMT +MATRIX:output_values,current_prediction +LITERAL_FLOAT:0.3 ++(current_prediction,*(0.3,cast.FLOAT(output_values))) +::STMT +MATRIX:w_X,z_LS +LITERAL_FLOAT:1000.0 +/(1000.0,cast.FLOAT(%*%(t(w_X),z_LS))) +::STMT +MATRIX:B,S,X +%*%(X,+(B,S)) +::STMT +MATRIX:s,d,alpha +-(s,*(cast.FLOAT(alpha),d)) +::STMT +MATRIX:M +FLOAT:parsertemp178174 +cast.MATRIX(+(max(M),parsertemp178174)) +::STMT +MATRIX:parsertemp394988,W3_rand +FLOAT:int204,int625 +LITERAL_FLOAT:0.21483446221182986 +%*%(*(0.21483446221182986,W3_rand),t(/(-(parsertemp394988,int625),+(parsertemp394988,int204)))) +::STMT +MATRIX:F,parsertemp42207,parsertemp42208,438_Ranks +FLOAT:parsertemp42222,int325,meanY,meanX,int938 +*(t(*(/(F,parsertemp42222),-(438_Ranks,meanX))),-(+(-(parsertemp42207,parsertemp42208),/(int325,int938)),meanY)) +::STMT +FLOAT:2344_s_err_mean +LITERAL_FLOAT:-1.0,0.001 +-(*(0.001,-1.0),2344_s_err_mean) +::STMT +MATRIX:history +FLOAT:float452 +-(max(history),float452) +::STMT +MATRIX:colSD,colMean +LITERAL_FLOAT:3.0 +-(colMean,*(3.0,colSD)) +::STMT +MATRIX:is_zero_y_corr,is_one_y_corr +LITERAL_FLOAT:-0.36651292058166435 +*(-0.36651292058166435,+(is_zero_y_corr,is_one_y_corr)) +::STMT +MATRIX:H,betamin,beta +FLOAT:logU +LITERAL_FLOAT:0.0 ++(*(<(-(H,logU),0.0),betamin),*(>=(-(H,logU),0.0),beta)) +::STMT +MATRIX:parsertemp171087,parsertemp171084,parsertemp171093 +FLOAT:float298,float780 +LITERAL_FLOAT:1.0,2.515517 +/(+(2.515517,*(sqrt(parsertemp171084),+(float780,parsertemp171087))),+(1.0,*(sqrt(parsertemp171084),+(float298,parsertemp171093)))) +::STMT +MATRIX:parsertemp235660,parsertemp235671 +FLOAT:parsertemp235661 +LITERAL_FLOAT:0.0 +sum(*(-(0.0,/(parsertemp235660,parsertemp235661)),parsertemp235671)) +::STMT +MATRIX:qLow,length,qUp +rowSums(|(<(length,qLow),>(length,qUp))) +::STMT +MATRIX:_sbcvar1750 +FLOAT:_sbcvar1751 +LITERAL_FLOAT:6.0 +/(_sbcvar1750,-(6.0,_sbcvar1751)) +::STMT +MATRIX:intercept +LITERAL_FLOAT:1.0 +INT:num_records,int615 +%*%(rand(num_records,int615,1.0,1.0),intercept) +::STMT +FLOAT:cmLabels,int624,float396 +LITERAL_FLOAT:10000.0 +sqrt(*(cmLabels,/(10000.0,-(int624,float396)))) +::STMT +MATRIX:parsertemp98,X,Y +LITERAL_FLOAT:2.0 +/(abs(-(X,Y)),/(parsertemp98,2.0)) +::STMT +MATRIX:V +FLOAT:std_dev,int434,mu +*(>(V,+(mu,*(int434,std_dev))),V) +::STMT +MATRIX:V +FLOAT:std_dev,int654,mu +*(<(V,-(mu,*(int654,std_dev))),V) +::STMT +MATRIX:X,y +LITERAL_FLOAT:-1.0 +%*%(*(t(X),-1.0),y) +::STMT +MATRIX:X_batch,186_dX,parsertemp146949,parsertemp146957,parsertemp146955 +FLOAT:beta1 +LITERAL_FLOAT:1.0 +*(-(1.0,beta1),%*%(t(X_batch),*(*(parsertemp146957,parsertemp146955),%*%(186_dX,parsertemp146949)))) +::STMT +MATRIX:neighbors,corePts,withinEps +*(*(neighbors,corePts),withinEps) +::STMT +MATRIX:r_LS +FLOAT:norm_r2_LS,p_LS,parsertemp170552,lambda_LS ++(cast.FLOAT(r_LS),*(/(norm_r2_LS,*(p_LS,p_LS)),+(*(parsertemp170552,p_LS),*(lambda_LS,p_LS)))) +::STMT +FLOAT:parsertemp496735,var,arch_coef,var_coef,a0 +sqrt(+(+(a0,*(arch_coef,parsertemp496735)),*(var_coef,var))) +::STMT +LITERAL_FLOAT:0.0,2.0 +*(2.0,0.0) +::STMT +MATRIX:parsertemp171084,parsertemp171083 +LITERAL_FLOAT:-2.0,0.001308,0.189269 +*(sqrt(*(-2.0,parsertemp171083)),+(0.189269,*(sqrt(parsertemp171084),0.001308))) +::STMT +FLOAT:277_sq_root_d,parsertemp170093,pp_CG,pq_CG +LITERAL_FLOAT:0.5 +*(*(0.5,/(-(parsertemp170093,277_sq_root_d),pp_CG)),pq_CG) +::STMT +FLOAT:int199,s,num_groups,int805 +LITERAL_FLOAT:1.0 ++(+(*(-(s,int199),-(num_groups,int805)),1.0),num_groups) +::STMT +MATRIX:grad +FLOAT:int842 +LITERAL_FLOAT:0.1 +*(0.1,sqrt(sum(^(grad,int842)))) +::STMT +MATRIX:X,y +FLOAT:int276,int931,int845,int559 +INT:int786,m,int690 +*(-(%*%(X,rand(m,int690,int845,int559)),y),-(%*%(X,rand(m,int786,int931,int276)),y)) +::STMT +MATRIX:W +FLOAT:int221,int797,wt +LITERAL_FLOAT:1.0,3.0,6.0 +/(*(*(6.0,sum(W)),-(sum(W),1.0)),*(*(-(wt,int221),+(wt,int797)),+(sum(W),3.0))) +::STMT +MATRIX:scale_X,z,beta +%*%(diag(scale_X),+(beta,z)) +::STMT +MATRIX:X +-(X,round(X)) +::STMT +MATRIX:u,minDist +sum(!=(u,minDist)) +::STMT +LITERAL_FLOAT:0.0,1.0 +INT:int374,int767 +sum(rand(int374,int767,0.0,1.0)) +::STMT +MATRIX:parsertemp539203 +FLOAT:int106 +LITERAL_FLOAT:1.0,2.0,1.5 +min(^(/(*(parsertemp539203,int106),2.0),/(1.0,1.5))) +::STMT +FLOAT:ID +LITERAL_FLOAT:1.0,2.0 ++(*(2.0,ID),1.0) +::STMT +MATRIX:parsertemp472412,fP +FLOAT:max_values +<=(parsertemp472412,/(^($1:ncol(fP),max_values),$1)) +::STMT +MATRIX:prediction,target +rowSums(abs(-(prediction,target))) +::STMT +MATRIX:parsertemp382905,S,V,W,row_nonzeros +FLOAT:reg +*(S,+(%*%(*(W,parsertemp382905),V),*(*(reg,S),row_nonzeros))) +::STMT +MATRIX:X +cast.MATRIX(sum(X)) +::STMT +MATRIX:parsertemp43618,o +LITERAL_FLOAT:1.0 +/(1.0,+(1.0,exp(*(parsertemp43618,o)))) +::STMT +MATRIX:lambda,parsertemp286535,beta +FLOAT:float296 +LITERAL_FLOAT:0.0 +cast.FLOAT(%*%(t(+(float296,parsertemp286535)),+(0.0,*(lambda,beta)))) +::STMT +FLOAT:Hin +LITERAL_FLOAT:2.0,64.0 +*(64.0,/(/(Hin,2.0),2.0)) +::STMT +MATRIX:scale_X,w,ssX_p_CG,X +*(cast.FLOAT(diag(scale_X)),%*%(t(X),*(w,%*%(X,ssX_p_CG)))) +::STMT +MATRIX:s,w +FLOAT:step_sz +LITERAL_FLOAT:2.0 +^(+(w,*(step_sz,s)),2.0) +::STMT +MATRIX:Y +LITERAL_FLOAT:0.0 +==(rowSums(Y),0.0) +::STMT +MATRIX:Y,the_exp +FLOAT:int549 +-(*(rowSums(Y),exp(*(the_exp,int549))),Y) +::STMT +MATRIX:images +LITERAL_FLOAT:1.0,2.0,255.0 +-(*(/(images,255.0),2.0),1.0) +::STMT +FLOAT:parsertemp459295 +LITERAL_FLOAT:1.0,128.0 +-(+(+(parsertemp459295,1.0),128.0),1.0) +::STMT +MATRIX:negSampleMeans +LITERAL_FLOAT:2.0,150.0 +*(150.0,^(negSampleMeans,2.0)) +::STMT +MATRIX:output,output1 +LITERAL_FLOAT:0.0,0.1 +==(<(abs(-(output,output1)),0.1),0.0) +::STMT +MATRIX:_sbcvar1734 +FLOAT:_sbcvar1735 +LITERAL_FLOAT:12.0 +/(_sbcvar1734,-(12.0,_sbcvar1735)) +::STMT +MATRIX:r +FLOAT:int12 +LITERAL_FLOAT:9.999999999999998E-15 +sqrt(*(sum(^(r,int12)),9.999999999999998E-15)) +::STMT +MATRIX:parsertemp2832 +==(round(parsertemp2832),max(round(parsertemp2832))) +::STMT +MATRIX:A +-(A,t(A)) +::STMT +MATRIX:X,Y,K +FLOAT:x +LITERAL_FLOAT:1.0 +*(-(*(K,-(X,X)),-(Y,Y)),-(1.0,/(-(x,X),-(X,X)))) +::STMT +MATRIX:samples_vs_runs_map,centroid_placer,X_samples +rowSums(*(X_samples,%*%(samples_vs_runs_map,%*%(centroid_placer,X_samples)))) +::STMT +MATRIX:R,parsertemp72406,parsertemp72323 +LITERAL_FLOAT:2.0 +sum(^(-(%*%(parsertemp72323,R),diag(parsertemp72406)),2.0)) +::STMT +MATRIX:parsertemp220845,D,ZERODIAG +rowSums(*(*(exp(parsertemp220845),ZERODIAG),D)) +::STMT +MATRIX:X,y +FLOAT:int621,int319 +INT:int505,m +t(-(%*%(X,rand(m,int505,int319,int621)),y)) +::STMT +FLOAT:s_err_mean +LITERAL_FLOAT:-0.001 +-(-0.001,s_err_mean) +::STMT +FLOAT:batch,i,int558 +LITERAL_FLOAT:1.0 ++(+(*(-(i,int558),batch),1.0),batch) +::STMT +MATRIX:parsertemp447181,strings +/(parsertemp447181,length(strings)) +::STMT +FLOAT:a,b +LITERAL_FLOAT:2.0 +/(*(2.0,*(a,b)),+(a,b)) +::STMT +MATRIX:parsertemp500607,X,y,parsertemp500610 +%*%(t(X),-(%*%(X,*(parsertemp500607,parsertemp500610)),y)) +::STMT +MATRIX:g_Y,lambda,scale_X,parsertemp286673,beta +LITERAL_FLOAT:0.0 ++(*(scale_X,-(0.0,%*%(parsertemp286673,g_Y))),*(lambda,beta)) +::STMT +MATRIX:parsertemp170263,finite_linear_terms,parsertemp170261,the_exp +FLOAT:int120,int98,int745 +LITERAL_FLOAT:1.0 ++(*(-(1.0,==(parsertemp170263,int120)),-(1.0,exp(parsertemp170261))),*(*(==(parsertemp170263,int98),exp(finite_linear_terms)),-(1.0,/(the_exp,int745)))) +::STMT +MATRIX:r,c,E,_sbcvar78 +LITERAL_FLOAT:2.0,10000.0 +sum(/(^(-(_sbcvar78,E),2.0),/(%*%(r,c),10000.0))) +::STMT +MATRIX:X +FLOAT:val +>(X,val) +::STMT +MATRIX:parsertemp498248 +FLOAT:int60,i_process_item +LITERAL_FLOAT:2.0 +*(^(/(-(int60,parsertemp498248),i_process_item),2.0),i_process_item) +::STMT +MATRIX:R,w,ones_ns ++(R,diag(*(ones_ns,cast.FLOAT(w)))) +::STMT +MATRIX:foffb +LITERAL_FLOAT:1.0 +*(ncol(foffb),1.0) +::STMT +MATRIX:selCols2,maxscub +FLOAT:parsertemp31797 +LITERAL_FLOAT:-Infinity +&(selCols2,|(>=(maxscub,parsertemp31797),==(maxscub,-Infinity))) +::STMT +MATRIX:A +FLOAT:a11,a12,int566,int260 +LITERAL_FLOAT:1.0 ++(+(+(/(int260,a11),/(int566,a12)),/(1.0,cast.FLOAT(A))),/(1.0,cast.FLOAT(A))) +::STMT +MATRIX:t_gp,parsertemp170239 +FLOAT:float801,float299 +LITERAL_FLOAT:1.0,1.421413741,-1.453152027 ++(1.421413741,*(/(1.0,+(float299,parsertemp170239)),+(-1.453152027,*(t_gp,float801)))) +::STMT +MATRIX:275_X,275_curr_X +FLOAT:275_value +&(==(275_X,275_curr_X),>=(275_X,275_value)) +::STMT +MATRIX:float999,is_zero_y_corr,is_one_y_corr,parsertemp317445,parsertemp317451,parsertemp317462 +FLOAT:float898 +LITERAL_FLOAT:1.0 +-(+(*(+(parsertemp317451,parsertemp317462),1-*(float999,parsertemp317445)),/(is_one_y_corr,-(float898,is_one_y_corr))),/(is_zero_y_corr,-(1.0,is_zero_y_corr))) +::STMT +MATRIX:vW1,parsertemp146976 +FLOAT:parsertemp146975,191_beta2,191_epsilon ++(sqrt(+(*(191_beta2,vW1),*(parsertemp146975,parsertemp146976))),191_epsilon) +::STMT +MATRIX:X +LITERAL_FLOAT:3.0 +*(ncol(X),3.0) +::STMT +MATRIX:H,betamax,Hneg,Hpos,beta +FLOAT:float761 +LITERAL_FLOAT:0.0,2.0,1.0E20 +*(*(2.0,>=(-(H,float761),0.0)),==(+(*(Hpos,betamax),*(Hneg,beta)),1.0E20)) +::STMT +MATRIX:B +LITERAL_FLOAT:2.0 +*(ncol(B),2.0) +::STMT +MATRIX:parsertemp11251 +LITERAL_FLOAT:2.0 +^(2.0,parsertemp11251) +::STMT +MATRIX:parsertemp220853,Ws,beta +LITERAL_FLOAT:0.0,3.4011973816621555 +>=(-(+(parsertemp220853,*(beta,Ws)),3.4011973816621555),0.0) +::STMT +MATRIX:X +LITERAL_FLOAT:2.0,1764.0 +/(colSums(^(X,2.0)),1764.0) +::STMT +MATRIX:r,scale_X,shift_X ++(*(scale_X,r),%*%(shift_X,r)) +::STMT +MATRIX:_sbcvar1846 +FLOAT:_sbcvar1847 +LITERAL_FLOAT:11.0 +/(_sbcvar1846,-(11.0,_sbcvar1847)) +::STMT +LITERAL_FLOAT:2.0,2000.0 +^(2000.0,2.0) +::STMT +MATRIX:B +LITERAL_FLOAT:4.0 +*(ncol(B),4.0) +::STMT +MATRIX:parsertemp410118,g0_1,parsertemp410117 +%*%(t(+(g0_1,t(parsertemp410118))),+(g0_1,t(colSums(parsertemp410117)))) +::STMT +MATRIX:p,q +FLOAT:norm_r2 +*(/(norm_r2,sum(*(p,q))),p) +::STMT +FLOAT:float778,int543,parsertemp171819,parsertemp171815,sim_score_parent,int9,parsertemp171824,float6 +-(+(/(^(parsertemp171815,int543),+(parsertemp171819,float6)),/(^(parsertemp171824,int9),+(parsertemp171819,float778))),sim_score_parent) +::STMT +MATRIX:lambda,g,beta +sum(*(+(g,*(lambda,beta)),+(g,*(lambda,beta)))) +::STMT +FLOAT:dd,step_sz,wd ++(wd,*(step_sz,dd)) +::STMT +MATRIX:B +LITERAL_FLOAT:8.0 +*(ncol(B),8.0) +::STMT +MATRIX:R,parsertemp40220 +FLOAT:numRows +LITERAL_FLOAT:1.0 +-(/(numRows,-(R,rowSums(parsertemp40220))),1.0) +::STMT +MATRIX:parsertemp31188,parsertemp31186 +FLOAT:int947 +LITERAL_FLOAT:1.0,7000.0 +/(/(-(colSums(parsertemp31186),*(int947,parsertemp31188)),-(7000.0,1.0)),7000.0) +::STMT +MATRIX:f,parsertemp472177,I,parsertemp472179 +*(I,-(%*%(f,parsertemp472177),t(parsertemp472179))) +::STMT +FLOAT:m2X,W +LITERAL_FLOAT:1.0 +*(m2X,/(W,-(W,1.0))) +::STMT +FLOAT:m2X,parsertemp4,m2Y,parsertemp8,int635,int492 +*(sqrt(*(m2X,/(int492,parsertemp4))),sqrt(*(m2Y,/(int635,parsertemp8)))) +::STMT +MATRIX:y_corr +FLOAT:float657 +LITERAL_FLOAT:1.0,0.5 ++(y_corr,*(-(1.0,*(float657,y_corr)),>(y_corr,0.5))) +::STMT +MATRIX:parsertemp400660,W3_rand +FLOAT:int364,int747 +LITERAL_FLOAT:0.2656844656620286 +%*%(*(0.2656844656620286,W3_rand),t(/(-(parsertemp400660,int747),+(parsertemp400660,int364)))) +::STMT +MATRIX:ones,classFeatureCounts +LITERAL_FLOAT:1.0,50.0 +%*%(+(rowSums(classFeatureCounts),*(50.0,1.0)),ones) +::STMT +MATRIX:CFreqs1,present_domain_vals_mat,parsertemp27485 +FLOAT:my +LITERAL_FLOAT:2.0 +*(%*%(present_domain_vals_mat,CFreqs1),^(-(%*%(present_domain_vals_mat,parsertemp27485),my),2.0)) +::STMT +FLOAT:parsertemp13703 +LITERAL_FLOAT:999.0,1000.0 +*(999.0,/(*(parsertemp13703,1000.0),999.0)) +::STMT +MATRIX:p_LS,X +*(cast.FLOAT(%*%(t(X),X)),cast.FLOAT(p_LS)) +::STMT +MATRIX:cm,FD +FLOAT:int406,n +LITERAL_FLOAT:0.0 +!=(+(+(FD,==(cm,int406)),==(t(cm),n)),0.0) +::STMT +MATRIX:lambda,beta +LITERAL_FLOAT:0.0 +t(+(0.0,*(lambda,beta))) +::STMT +MATRIX:r,d,alpha,parsertemp44052,Hd +FLOAT:norm_r2 ++(-(r,*(cast.FLOAT(alpha),Hd)),*(/(sum(parsertemp44052),norm_r2),d)) +::STMT +MATRIX:fdom,parsertemp1688 +-(t(parsertemp1688),fdom) +::STMT +MATRIX:d,parsertemp410054 +FLOAT:r2 +/(r2,cast.FLOAT(%*%(t(d),t(parsertemp410054)))) +::STMT +MATRIX:p,lambda,scale_X,shift_X +FLOAT:q +*(p,+(+(*(scale_X,q),*(q,shift_X)),*(lambda,p))) +::STMT +MATRIX:B2,ytest,Xtest +%*%(t(-(ytest,%*%(Xtest,B2))),-(ytest,%*%(Xtest,B2))) +::STMT +MATRIX:parsertemp43632,X,y +LITERAL_FLOAT:0.0,2.0 ++(0.0,*(2.0,%*%(t(X),*(parsertemp43632,y)))) +::STMT +FLOAT:df,parsertemp437302,n,norm +LITERAL_FLOAT:-2.0 ++(*(*(-2.0,norm),n),*(df,parsertemp437302)) +::STMT +FLOAT:cols,parsertemp451837 +LITERAL_FLOAT:1.0 +-(+(+(*(parsertemp451837,cols),1.0),cols),1.0) +::STMT +MATRIX:codebook +FLOAT:j +*(j,ncol(codebook)) +::STMT +MATRIX:is_LT_infinite +LITERAL_FLOAT:1.0 +-(1.0,rowSums(is_LT_infinite)) +::STMT +LITERAL_FLOAT:1.02 +1.02 +::STMT +MATRIX:linear_terms +FLOAT:link_power +LITERAL_FLOAT:1.0 +-(1.0,^(linear_terms,/(1.0,link_power))) +::STMT +MATRIX:_sbcvar1830 +FLOAT:_sbcvar1831 +LITERAL_FLOAT:10.0 +/(_sbcvar1830,-(10.0,_sbcvar1831)) +::STMT +MATRIX:tmp +LITERAL_FLOAT:50.0 +*(50.0,cast.FLOAT(%*%(t(tmp),tmp))) +::STMT +MATRIX:y_hat +FLOAT:k,parsertemp176418 ++(sqrt(parsertemp176418),*(k,y_hat)) +::STMT +MATRIX:m_iter_err_sum,parsertemp498242 +FLOAT:i_process_item +LITERAL_FLOAT:0.0 +/(-(0.0,-(t(m_iter_err_sum),+(parsertemp498242,m_iter_err_sum))),i_process_item) +::STMT +LITERAL_FLOAT:1.0E-16 +1.0E-16 +::STMT +FLOAT:D,o +LITERAL_FLOAT:-2.0,-1.0,2.0 ++(*(-2.0,*(o,-1.0)),*(2.0,D)) +::STMT +MATRIX:W,parsertemp411099,X,H +LITERAL_FLOAT:1.0E-8 +/(%*%(t(W),X),+(%*%(%*%(parsertemp411099,W),H),1.0E-8)) +::STMT +MATRIX:g_reg,g,parsertemp285556 +FLOAT:parsertemp285562 +*(cast.FLOAT(%*%(t(g_reg),+(g,parsertemp285556))),parsertemp285562) +::STMT +MATRIX:X2 +LITERAL_FLOAT:4.0 +>=(t(colSums(X2)),4.0) +::STMT +MATRIX:select,d_r_rev,X_exp_Xb_rev_agg,D_r_rev +*(/(%*%(select,X_exp_Xb_rev_agg),D_r_rev),d_r_rev) +::STMT +MATRIX:w,ones_ns +diag(*(ones_ns,cast.FLOAT(w))) +::STMT +MATRIX:g +FLOAT:lambda,beta +LITERAL_FLOAT:2.0 +sum(^(+(g,*(lambda,beta)),2.0)) +::STMT +LITERAL_FLOAT:1.0,500.0 +*(500.0,1.0) +::STMT +MATRIX:s,w +LITERAL_FLOAT:1.0 +*(1.0,sum(*(w,s))) +::STMT +MATRIX:m_active_flag_tmp,m_active_flag +LITERAL_FLOAT:1.0 +>=(+(m_active_flag,m_active_flag_tmp),1.0) +::STMT +MATRIX:vW1,190_dW +FLOAT:191_beta2,int129,int49 +sqrt(+(*(191_beta2,vW1),*(-(int129,191_beta2),^(190_dW,int49)))) +::STMT +MATRIX:A +FLOAT:parsertemp12882 +LITERAL_FLOAT:1.0 +*(-(nrow(A),1.0),/(*(parsertemp12882,nrow(A)),-(nrow(A),1.0))) +::STMT +LITERAL_FLOAT:0.08709382882250233 +0.08709382882250233 +::STMT +MATRIX:classFeatureCounts +LITERAL_FLOAT:1.0,50.0 +/(+(classFeatureCounts,1.0),+(rowSums(classFeatureCounts),*(50.0,1.0))) +::STMT +MATRIX:parsertemp170665,residual_matrix,curr_prediction +FLOAT:282_lambda +LITERAL_FLOAT:2.0 +/(^(sum(residual_matrix),2.0),+(sum(*(curr_prediction,parsertemp170665)),282_lambda)) +::STMT +MATRIX:dY,g,parsertemp221002,Y +FLOAT:float831,float422 +-(+(Y,-(*(float422,dY),*(float831,g))),parsertemp221002) +::STMT +MATRIX:grad +LITERAL_FLOAT:-1.0 +*(*(grad,-1.0),*(grad,-1.0)) +::STMT +MATRIX:parsertemp379560,m_iter_err_sum +FLOAT:i_process_item +LITERAL_FLOAT:-1.0 +/(*(-(t(m_iter_err_sum),+(parsertemp379560,m_iter_err_sum)),-1.0),i_process_item) +::STMT +MATRIX:grad +FLOAT:int204,int415 +sqrt(sum(*(*(grad,int204),*(grad,int415)))) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:1.0,2.0 +/(1.0,^(linear_terms,2.0)) +::STMT +MATRIX:log_l_part_saturated +LITERAL_FLOAT:2.0 +*(2.0,sum(log_l_part_saturated)) +::STMT +FLOAT:eta,s,parsertemp454319 +*(parsertemp454319,^(eta,s)) +::STMT +MATRIX:output,Mask +LITERAL_FLOAT:1.0 +*(output,-(1.0,Mask)) +::STMT +MATRIX:paramLens,parsertemp387457 +/(parsertemp387457,rev(paramLens)) +::STMT +MATRIX:WM +FLOAT:parsertemp31268 +*(parsertemp31268,sum(WM)) +::STMT +MATRIX:X +LITERAL_FLOAT:2.0 +*(2.0,ncol(X)) +::STMT +LITERAL_FLOAT:1.0,2000.0 ++(2000.0,1.0) +::STMT +MATRIX:p,parsertemp1934,parsertemp1935 +FLOAT:eps +cast.FLOAT(%*%(t(p),+(%*%(parsertemp1934,parsertemp1935),*(eps,p)))) +::STMT +MATRIX:parsertemp410246,parsertemp410249 +FLOAT:float218,int106,int527,float484 +-(max(^(/(parsertemp410246,parsertemp410249),/(int106,float218))),min(^(/(parsertemp410246,parsertemp410249),/(int527,float484)))) +::STMT +MATRIX:XY_pairs_local,XY_pairs +|(XY_pairs,t(XY_pairs_local)) +::STMT +MATRIX:ssX_V,X,parsertemp150463,P_1K,parsertemp149251 +%*%(t(X),-(*(P_1K,%*%(X,ssX_V)),*(P_1K,%*%(parsertemp149251,parsertemp150463)))) +::STMT +MATRIX:parsertemp235671,I,y2 +LITERAL_FLOAT:0.0 +*(-(0.0,/(%*%(I,y2),sum(I))),parsertemp235671) +::STMT +MATRIX:X +FLOAT:N +%*%(t(/(colSums(X),N)),/(colSums(X),N)) +::STMT +MATRIX:parsertemp27746,parsertemp27872 +FLOAT:featureCorrection +-(%*%(parsertemp27872,t(parsertemp27746)),featureCorrection) +::STMT +MATRIX:_sbcvar1798 +FLOAT:_sbcvar1799 +LITERAL_FLOAT:9.0 +/(_sbcvar1798,-(9.0,_sbcvar1799)) +::STMT +FLOAT:num_records +LITERAL_FLOAT:1.0,100.0 +-(1.0,/(100.0,num_records)) +::STMT +MATRIX:d,X,logisticD +FLOAT:C +*(C,%*%(t(X),*(logisticD,%*%(X,d)))) +::STMT +MATRIX:parsertemp171083 +LITERAL_FLOAT:0.0,-2.0 +-(0.0,sqrt(*(-2.0,parsertemp171083))) +::STMT +MATRIX:S,addedE,addedX2 +FLOAT:level +*(==(%*%(S,t(addedX2)),level),t(addedE)) +::STMT +LITERAL_FLOAT:409.0 +409.0 +::STMT +FLOAT:parsertemp40812,m2,int727 +LITERAL_FLOAT:4.0 +^(sqrt(*(/(int727,parsertemp40812),m2)),4.0) +::STMT +FLOAT:int960,parsertemp285740,p_CG,pp_CG,parsertemp285757 +*(parsertemp285757,/(+(*(p_CG,int960),sqrt(parsertemp285740)),pp_CG)) +::STMT +FLOAT:n +LITERAL_FLOAT:1.0,2.0 +/(1.0,*(2.0,n)) +::STMT +LITERAL_FLOAT:5.0,2000.0 ++(2000.0,5.0) +::STMT +MATRIX:X +LITERAL_FLOAT:1.0E-12 +INT:int829,int420 ++(%*%(t(X),X),diag(rand(int829,int420,1.0E-12,1.0E-12))) +::STMT +MATRIX:is_row_in_samples,parsertemp79018 +LITERAL_FLOAT:3811.0 +-(3811.0,*(is_row_in_samples,parsertemp79018)) +::STMT +MATRIX:W,H +sum(%*%(W,H)) +::STMT +LITERAL_FLOAT:750.0 +750.0 +::STMT +LITERAL_FLOAT:0.08725945907447251 +0.08725945907447251 +::STMT +LITERAL_FLOAT:3.0,2000.0 ++(2000.0,3.0) +::STMT +MATRIX:scores,unnorm_probs,dprobs +rowSums(*(dprobs,/(exp(scores),rowSums(unnorm_probs)))) +::STMT +MATRIX:parsertemp472316,parsertemp472314,ig +FLOAT:min_leaf +max(*(&(>=(parsertemp472314,min_leaf),>=(parsertemp472316,min_leaf)),ig)) +::STMT +FLOAT:FN,FP,TP +*(+(TP,FP),+(TP,FN)) +::STMT +MATRIX:tmp,X,Y,out +t(-(%*%(t(X),*(out,Y)),tmp)) +::STMT +FLOAT:alpha +LITERAL_FLOAT:1.0,2.0 +-(1.0,/(alpha,2.0)) +::STMT +MATRIX:A,B +LITERAL_FLOAT:-1.0,2.0 +^(*(%*%(t(A),B),-1.0),2.0) +::STMT +LITERAL_FLOAT:1.432788 +1.432788 +::STMT +MATRIX:surv +LITERAL_FLOAT:0.5 +sum(<=(surv,0.5)) +::STMT +MATRIX:G,authorities,hubs +-(/(%*%(G,authorities),max(%*%(G,authorities))),hubs) +::STMT +MATRIX:X,parsertemp555606 +LITERAL_FLOAT:1.0 +/(%*%(t(-(X,parsertemp555606)),-(X,parsertemp555606)),-(nrow(X),1.0)) +::STMT +MATRIX:parsertemp42200,F +LITERAL_FLOAT:2.0 +-(parsertemp42200,/(rowSums(F),2.0)) +::STMT +MATRIX:R,parsertemp500307 +FLOAT:int715 +LITERAL_FLOAT:1.0 +INT:int807,int466,parsertemp500306,parsertemp500303 ++(%*%(rowSums(^(R,int715)),rand(int466,parsertemp500303,1.0,1.0)),%*%(rand(parsertemp500306,int807,1.0,1.0),t(rowSums(parsertemp500307)))) +::STMT +MATRIX:parsertemp171117,is_zero_y_corr,is_one_y_corr,parsertemp171113 +FLOAT:parsertemp171116,float156 +LITERAL_FLOAT:1.0 +-(+(-(parsertemp171113,*(parsertemp171116,parsertemp171117)),/(is_one_y_corr,-(float156,is_one_y_corr))),/(is_zero_y_corr,-(1.0,is_zero_y_corr))) +::STMT +MATRIX:parsertemp411207,parsertemp411209,W,parsertemp411198,H,parsertemp411200 +LITERAL_FLOAT:1.0E-8 ++(%*%(/(*(W,parsertemp411207),t(parsertemp411209)),/(*(H,parsertemp411198),t(parsertemp411200))),1.0E-8) +::STMT +MATRIX:subspace_idx,parsertemp72201 +FLOAT:subvector_size +-(subspace_idx,*(parsertemp72201,subvector_size)) +::STMT +MATRIX:p_CG,z +*(cast.FLOAT(z),sum(p_CG)) +::STMT +MATRIX:parsertemp459256 +LITERAL_FLOAT:5.0E-4 +*(5.0E-4,parsertemp459256) +::STMT +MATRIX:X +LITERAL_FLOAT:2.0 +^(/(t(colSums(X)),nrow(X)),2.0) +::STMT +FLOAT:o +LITERAL_FLOAT:-2.0,-1.0 +*(-2.0,*(o,-1.0)) +::STMT +MATRIX:dout1,vb1 +FLOAT:192_beta2 +LITERAL_FLOAT:1.0,2.0 ++(*(192_beta2,vb1),*(-(1.0,192_beta2),^(colSums(dout1),2.0))) +::STMT +MATRIX:X,Y +abs(-(X,Y)) +::STMT +MATRIX:parsertemp10744,W,H +FLOAT:Eps ++(%*%(W,%*%(*(H,parsertemp10744),t(H))),Eps) +::STMT +MATRIX:y_residual,parsertemp415351 +FLOAT:parsertemp415362,n,int152 +LITERAL_FLOAT:1.0 +-(1.0,/(sum(^(y_residual,int152)),-(sum(parsertemp415351),*(n,parsertemp415362)))) +::STMT +MATRIX:parsertemp10740,V,W,H +FLOAT:Eps +/(%*%(t(W),V),+(%*%(%*%(parsertemp10740,W),H),Eps)) +::STMT +MATRIX:in_m_data_target +LITERAL_FLOAT:100.0 +*(-(max(in_m_data_target),min(in_m_data_target)),100.0) +::STMT +MATRIX:parsertemp560919,parsertemp560920,elt,ones_ctg +LITERAL_FLOAT:1.0 +*(/(elt,%*%(rowSums(elt),t(ones_ctg))),%*%(/(elt,%*%(parsertemp560919,parsertemp560920)),-(1.0,diag(ones_ctg)))) +::STMT +MATRIX:termination_bitmap,parsertemp72096 +FLOAT:int497,worst_wcss +LITERAL_FLOAT:1.0,10.0 ++(*(parsertemp72096,termination_bitmap),*(+(*(int497,worst_wcss),10.0),-(1.0,termination_bitmap))) +::STMT +MATRIX:W1_rand,X,parsertemp401984,parsertemp401974 +FLOAT:float690 +LITERAL_FLOAT:0.06835859270246632 +%*%(*(0.06835859270246632,W1_rand),t(/(-(X,parsertemp401974),+(parsertemp401984,float690)))) +::STMT +MATRIX:I,y2 +LITERAL_FLOAT:0.0 +-(0.0,/(%*%(I,y2),sum(I))) +::STMT +MATRIX:M +exp(-(M,max(M))) +::STMT +MATRIX:entropy,parsertemp552397,resp,L +*(==(+(resp,t(parsertemp552397)),L),entropy) +::STMT +FLOAT:sd_X +sqrt(sd_X) +::STMT +FLOAT:j +LITERAL_FLOAT:1.0,4.0 ++(-(4.0,j),1.0) +::STMT +MATRIX:w_X,z_LS +LITERAL_FLOAT:10000.0 +*(/(10000.0,cast.FLOAT(%*%(w_X,z_LS))),z_LS) +::STMT +FLOAT:m2Y,sigmaX +LITERAL_FLOAT:1.0005 +*(sigmaX,sqrt(*(m2Y,1.0005))) +::STMT +FLOAT:deviance_nodisp +LITERAL_FLOAT:0.1,1.0E-12 +*(1.0E-12,+(deviance_nodisp,0.1)) +::STMT +MATRIX:parsertemp410979,W,X,parsertemp410981,parsertemp410983 +FLOAT:eps +*(W,%*%(/(X,+(parsertemp410983,eps)),t(/(parsertemp410979,parsertemp410981)))) +::STMT +FLOAT:n_components,n_features +LITERAL_FLOAT:1.0,2.0 +/(*(*(n_components,n_features),+(n_features,1.0)),2.0) +::STMT +MATRIX:mu +LITERAL_FLOAT:4.0 +*(4.0,*(cast.FLOAT(mu),cast.FLOAT(mu))) +::STMT +MATRIX:p,r,parsertemp1597,lambda,parsertemp1590,parsertemp1589 +FLOAT:norm_r2 ++(r,*(/(norm_r2,sum(parsertemp1597)),+(%*%(parsertemp1589,parsertemp1590),*(lambda,p)))) +::STMT +MATRIX:classFeatureCounts +LITERAL_FLOAT:500.0 ++(rowSums(classFeatureCounts),500.0) +::STMT +MATRIX:parsertemp13658,parsertemp13659,_sbcvar12 +FLOAT:44_meanX +LITERAL_FLOAT:999.0,0.5 +*(/(_sbcvar12,999.0),-(+(-(parsertemp13658,parsertemp13659),0.5),44_meanX)) +::STMT +MATRIX:linear_terms +FLOAT:var_power,link_power,float674 +LITERAL_FLOAT:2.0 +/(^(linear_terms,/(-(float674,var_power),link_power)),-(2.0,var_power)) +::STMT +FLOAT:int435,int13 +INT:int92,int565 +rand(int565,int92,int435,int13) +::STMT +MATRIX:prec_chol,bc_matrix,parsertemp436690 +FLOAT:int898 +*(bc_matrix,t(*(rowSums(parsertemp436690),^(prec_chol,int898)))) +::STMT +MATRIX:X +FLOAT:q1,q2 +|(<(X,q1),>(X,q2)) +::STMT +FLOAT:ytest,int697,int876,parsertemp454072,parsertemp454076,int481,int619 +LITERAL_FLOAT:1.0 +-(1.0,/(-(^(ytest,int619),*(int481,parsertemp454072)),-(^(ytest,int697),*(int876,parsertemp454076)))) +::STMT +MATRIX:parsertemp477918,b +FLOAT:tolerance +LITERAL_FLOAT:2.0 +*(sum(^(%*%(parsertemp477918,b),2.0)),^(tolerance,2.0)) +::STMT +MATRIX:X +FLOAT:M +*(nrow(X),M) +::STMT +MATRIX:leaf_ids +FLOAT:boundary_left,step_size +<(leaf_ids,+(boundary_left,step_size)) +::STMT +FLOAT:m2 +LITERAL_FLOAT:2.0,3.0 +*(3.0,^(m2,2.0)) +::STMT +MATRIX:curr_prediction +FLOAT:int644 +LITERAL_FLOAT:0.0 ++(sum(*(curr_prediction,-(int644,curr_prediction))),0.0) +::STMT +MATRIX:A,scale_X,shift_X,parsertemp1656,parsertemp1655 ++(%*%(diag(scale_X),t(+(parsertemp1655,parsertemp1656))),%*%(shift_X,A)) +::STMT +FLOAT:run_index +LITERAL_FLOAT:24.0,1.0 ++(*(24.0,-(run_index,1.0)),1.0) +::STMT +FLOAT:acc +LITERAL_FLOAT:1.0,100.0 +-(1.0,/(acc,100.0)) +::STMT +FLOAT:log_ten,d_eee,x,float396 +*(x,exp(*(log_ten,-(float396,d_eee)))) +::STMT +FLOAT:int244,parsertemp459332,int646,parsertemp459334 +LITERAL_FLOAT:2.0 +sqrt(/(2.0,*(*(int244,parsertemp459332),/(parsertemp459334,int646)))) +::STMT +MATRIX:X +FLOAT:N +LITERAL_FLOAT:0.0 +-(0.0,/(t(colSums(X)),N)) +::STMT +MATRIX:parsertemp410245,parsertemp410248 +FLOAT:int56,float64 +LITERAL_FLOAT:1.0,1.5 +max(^(/(*(parsertemp410245,int56),*(float64,parsertemp410248)),/(1.0,1.5))) +::STMT +MATRIX:parsertemp429913,avg_X_cols +FLOAT:int179 +LITERAL_FLOAT:300.0,299.0 +/(-(t(colSums(parsertemp429913)),*(300.0,^(avg_X_cols,int179))),299.0) +::STMT +MATRIX:P_denom +LITERAL_FLOAT:0.0 +sum(<=(P_denom,0.0)) +::STMT +LITERAL_FLOAT:0.0,1.0 +INT:int842,int402 +cast.FLOAT(rand(int842,int402,0.0,1.0)) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:0.0,2.0 +*(2.0,>=(linear_terms,0.0)) +::STMT +MATRIX:p,q,r,parsertemp1947 +FLOAT:norm_r2,alpha +LITERAL_FLOAT:0.0 ++(-(0.0,+(r,*(alpha,q))),*(/(sum(parsertemp1947),norm_r2),p)) +::STMT +MATRIX:Y_prob,Y +*(*(rowSums(Y),Y_prob),Y_prob) +::STMT +MATRIX:g_new,g_old +/(sum(*(g_new,g_new)),sum(*(g_old,g_old))) +::STMT +MATRIX:r_LS +FLOAT:norm_r2_LS,p_LS,parsertemp170552,lambda_LS ++(r_LS,*(/(norm_r2_LS,*(p_LS,p_LS)),+(*(parsertemp170552,p_LS),*(lambda_LS,p_LS)))) +::STMT +MATRIX:samples_vs_runs_map,X_samples_sq_norms,centroids +FLOAT:int785 ++(X_samples_sq_norms,%*%(samples_vs_runs_map,rowSums(^(centroids,int785)))) +::STMT +FLOAT:e,epochs +LITERAL_FLOAT:1.0 +-(+(1.0,epochs),e) +::STMT +MATRIX:t,parsertemp171083 +FLOAT:float488,float22 +LITERAL_FLOAT:0.802853,2.515517 ++(2.515517,*(sqrt(*(float488,parsertemp171083)),+(0.802853,*(t,float22)))) +::STMT +MATRIX:W1_rand +LITERAL_FLOAT:0.07808688094430302 +*(0.07808688094430302,W1_rand) +::STMT +MATRIX:output1,dataset +LITERAL_FLOAT:0.16 +>=(abs(-(output1,dataset)),0.16) +::STMT +MATRIX:X,parsertemp438796 +t(*(ncol(X),parsertemp438796)) +::STMT +MATRIX:t,tmp +FLOAT:parsertemp477715,int875,x,X,Y,K +*(cast.FLOAT(t),+(*(-(K,Y),-(int875,parsertemp477715)),*(cast.FLOAT(tmp),/(x,X)))) +::STMT +MATRIX:parsertemp12846,F +FLOAT:W +LITERAL_FLOAT:2.0 +/(^(-(F,/(parsertemp12846,W)),2.0),/(%*%(rowSums(F),colSums(F)),sum(F))) +::STMT +MATRIX:std,rad ++(cast.FLOAT(std),cast.FLOAT(rad)) +::STMT +LITERAL_FLOAT:1.0,2.0,150.0 +*(^(150.0,2.0),-(150.0,1.0)) +::STMT +MATRIX:meanDiff,parsertemp570372,parsertemp570375 +LITERAL_FLOAT:0.5,-0.5 +-(*(-0.5,parsertemp570372),*(0.5,%*%(%*%(meanDiff,parsertemp570375),t(meanDiff)))) +::STMT +MATRIX:parsertemp570372 +LITERAL_FLOAT:-0.5 +*(-0.5,parsertemp570372) +::STMT +MATRIX:parsertemp31912,I +FLOAT:eAvg +/(/(t(%*%(parsertemp31912,I)),t(colSums(I))),eAvg) +::STMT +MATRIX:node +LITERAL_FLOAT:1.0,2.0 ++(*(node,2.0),1.0) +::STMT +MATRIX:lambda,g,beta +LITERAL_FLOAT:2.0 +^(+(g,*(cast.FLOAT(lambda),cast.FLOAT(beta))),2.0) +::STMT +MATRIX:p,Z +FLOAT:norm_r2 +/(norm_r2,cast.FLOAT(%*%(t(p),%*%(Z,p)))) +::STMT +MATRIX:posSamples,posSampleMeans +LITERAL_FLOAT:2.0,2000.0 +-(colSums(^(posSamples,2.0)),*(2000.0,^(posSampleMeans,2.0))) +::STMT +MATRIX:parsertemp170665,residual_matrix,curr_prediction +LITERAL_FLOAT:0.0,2.0 +/(^(sum(residual_matrix),2.0),+(sum(*(curr_prediction,parsertemp170665)),0.0)) +::STMT +MATRIX:m_err_vars,m_err_mean +LITERAL_FLOAT:-0.001 +/(-(-0.001,cast.FLOAT(m_err_mean)),cast.FLOAT(m_err_vars)) +::STMT +MATRIX:S,V +FLOAT:int586,delta2 +LITERAL_FLOAT:2.0 +*(sum(^(V,2.0)),-(delta2,sum(^(S,int586)))) +::STMT +MATRIX:parsertemp389212,parsertemp389215 +LITERAL_FLOAT:2.0,1058.0 +-(parsertemp389215,^(/(parsertemp389212,1058.0),2.0)) +::STMT +MATRIX:avg_res_Y,means,Y_counts,Y +LITERAL_FLOAT:2.0 +colSums(^(-(-(Y,means),%*%(Y_counts,avg_res_Y)),2.0)) +::STMT +FLOAT:w_i +LITERAL_FLOAT:5.0 +-(w_i,5.0) +::STMT +MATRIX:r,scale_X,shift_X,y,parsertemp116004 +LITERAL_FLOAT:2.0 +^(+(*(scale_X,%*%(parsertemp116004,y)),*(cast.FLOAT(r),shift_X)),2.0) +::STMT +MATRIX:S,X +LITERAL_FLOAT:1.0,2.0 +/(^(diag(S),2.0),-(nrow(X),1.0)) +::STMT +MATRIX:2699_dscores,parsertemp459193,parsertemp459183,parsertemp459190,2703_X,2703_W +LITERAL_FLOAT:5.0E-4 ++(%*%(t(2703_X),*(*(parsertemp459193,parsertemp459190),%*%(2699_dscores,parsertemp459183))),*(5.0E-4,2703_W)) +::STMT +MATRIX:parsertemp285809,p_CG,z +FLOAT:parsertemp285799,2235_sq_root_d,parsertemp285814 +LITERAL_FLOAT:0.5 ++(*(0.5,cast.FLOAT(%*%(z,parsertemp285809))),*(parsertemp285814,/(+(parsertemp285799,2235_sq_root_d),cast.FLOAT(p_CG)))) +::STMT +FLOAT:obj +LITERAL_FLOAT:1.0E-10 +*(1.0E-10,obj) +::STMT +MATRIX:n_event_stratum,n_risk_stratum,n_risk +*(*(n_risk,n_event_stratum),-(n_risk_stratum,n_event_stratum)) +::STMT +MATRIX:y_hat,X +*(-(X,y_hat),-(X,y_hat)) +::STMT +FLOAT:n +LITERAL_FLOAT:4.0 +-(n,4.0) +::STMT +MATRIX:X,X_nonzero_ind +LITERAL_FLOAT:0.0 +-(nrow(X),sum(!=(rowSums(X_nonzero_ind),0.0))) +::STMT +MATRIX:X,permut +FLOAT:n +*(/(colSums(%*%(permut,X)),n),/(colSums(%*%(permut,X)),n)) +::STMT +MATRIX:W1_rand,stds,parsertemp401986 +LITERAL_FLOAT:0.06835859270246632 +t(%*%(*(0.06835859270246632,W1_rand),t(/(parsertemp401986,stds)))) +::STMT +MATRIX:X +FLOAT:int416 +LITERAL_FLOAT:1.0 +sqrt(/(colSums(^(X,int416)),-(nrow(X),1.0))) +::STMT +MATRIX:U_OE +rowSums(rowSums(U_OE)) +::STMT +MATRIX:Y,Xd,Xw +FLOAT:step_sz +LITERAL_FLOAT:1.0 +-(1.0,*(Y,+(Xw,*(step_sz,Xd)))) +::STMT +FLOAT:s +LITERAL_FLOAT:-1.0,3.0 +^(3.0,*(s,-1.0)) +::STMT +LITERAL_FLOAT:1.000100010001 +1.000100010001 +::STMT +MATRIX:252_Y,252_K +FLOAT:252_X,float532 +LITERAL_FLOAT:1.0 +*(-(*(cast.FLOAT(252_K),-(252_X,252_X)),-(cast.FLOAT(252_Y),cast.FLOAT(252_Y))),-(1.0,/(-(float532,252_X),-(252_X,252_X)))) +::STMT +FLOAT:window_size,parsertemp181047,parsertemp181040 +LITERAL_FLOAT:1.0,2.0 +sqrt(*(*(2.0,window_size),-(1.0,/(parsertemp181040,parsertemp181047)))) +::STMT +MATRIX:b_cumulant,Y,natural_parameters +sum(-(*(Y,natural_parameters),b_cumulant)) +::STMT +LITERAL_FLOAT:0.07808688094430302 +0.07808688094430302 +::STMT +MATRIX:y_corr,is_zero_y_corr +FLOAT:float599,float550,float570,int718 +LITERAL_FLOAT:1.0,0.5 ++(*(*(y_corr,-(float599,is_zero_y_corr)),-(1.0,>=(y_corr,float550))),*(0.5,+(<=(y_corr,int718),>=(y_corr,float570)))) +::STMT +MATRIX:2212_oY +!(2212_oY) +::STMT +MATRIX:parsertemp129475 +LITERAL_FLOAT:1.0,2.0 +-(+(*(max(parsertemp129475),2.0),1.0),1.0) +::STMT +FLOAT:idx +LITERAL_FLOAT:1.0,1024.0 ++(-(1024.0,idx),1.0) +::STMT +MATRIX:resp,mean,X +t(*(-(X,mean),resp)) +::STMT +MATRIX:r_CG,q_CG +FLOAT:alpha_CG +LITERAL_FLOAT:2.0 +^(+(r_CG,*(alpha_CG,cast.FLOAT(q_CG))),2.0) +::STMT +MATRIX:221_CFreqs,221_present_domain_vals_mat,parsertemp27770 +FLOAT:int792 +LITERAL_FLOAT:1000.0 +/(sum(*(-(221_CFreqs,int792),%*%(221_present_domain_vals_mat,parsertemp27770))),-(1000.0,nrow(221_present_domain_vals_mat))) +::STMT +MATRIX:Xtest_dists +LITERAL_FLOAT:0.0 +<(0.0,Xtest_dists) +::STMT +MATRIX:selCols,ncCnts,maxsc +FLOAT:parsertemp31781 +LITERAL_FLOAT:0.0 +&(selCols,|(>(ncCnts,0.0),>(maxsc,parsertemp31781))) +::STMT +MATRIX:b,X,sb +*(X,exp(%*%(X,+(b,sb)))) +::STMT +MATRIX:R,addedE,parsertemp40215 +FLOAT:level ++(R,rowSums(*(==(parsertemp40215,level),t(addedE)))) +::STMT +FLOAT:step +LITERAL_FLOAT:0.9 +*(step,0.9) +::STMT +MATRIX:finite_linear_terms +LITERAL_FLOAT:0.0,2.0 +/(-(0.0,^(finite_linear_terms,2.0)),2.0) +::STMT +MATRIX:Q1,IQR +LITERAL_FLOAT:1.5 +-(Q1,*(1.5,IQR)) +::STMT +LITERAL_FLOAT:1.0E-6 +1.0E-6 +::STMT +MATRIX:ytest +LITERAL_FLOAT:1.0,2.0 +^(/(cast.FLOAT(ytest),1.0),2.0) +::STMT +MATRIX:finite_linear_terms +LITERAL_FLOAT:1.0,-1.0 +-(1.0,exp(*(exp(finite_linear_terms),-1.0))) +::STMT +MATRIX:ytest,yhat +LITERAL_FLOAT:2.0 +^(/(sum(-(ytest,yhat)),nrow(ytest)),2.0) +::STMT +MATRIX:CVars,CFreqs +LITERAL_FLOAT:1.0 +sum(*(-(CFreqs,1.0),CVars)) +::STMT +FLOAT:window_size,i,k ++(+(i,k),window_size) +::STMT +MATRIX:ss,X2 +/(nrow(X2),ss) +::STMT +MATRIX:X +LITERAL_FLOAT:3.0 +*(3.0,ncol(X)) +::STMT +MATRIX:grad +LITERAL_FLOAT:0.0,2.0 +sum(^(-(0.0,grad),2.0)) +::STMT +MATRIX:parsertemp129475,groupIndex +*(groupIndex,max(parsertemp129475)) +::STMT +FLOAT:run_index +LITERAL_FLOAT:1.0,48.0 +*(48.0,-(run_index,1.0)) +::STMT +MATRIX:scale_X,shift_X,X +FLOAT:int959 +LITERAL_FLOAT:2.0 ++(%*%(^(X,2.0),^(scale_X,2.0)),%*%(X,*(*(int959,scale_X),shift_X))) +::STMT +MATRIX:parsertemp171083 +FLOAT:float680 +LITERAL_FLOAT:0.010328,0.802853 ++(0.802853,*(sqrt(*(float680,parsertemp171083)),0.010328)) +::STMT +MATRIX:g +LITERAL_FLOAT:0.01 +*(0.01,cast.FLOAT(%*%(t(g),g))) +::STMT +LITERAL_FLOAT:1.0,2.0,2001.0 +-(^(2001.0,2.0),1.0) +::STMT +MATRIX:parsertemp539203,T,event +FLOAT:int620 +LITERAL_FLOAT:1.0,2.0,1.5 +/(^(/(*(parsertemp539203,int620),2.0),/(1.0,1.5)),/(-(max(T),min(T)),sum(event))) +::STMT +FLOAT:int263 +LITERAL_FLOAT:0.0,2.0 +INT:int475,parsertemp282730 +>(rand(parsertemp282730,int475,int263,2.0),0.0) +::STMT +MATRIX:p,q,g,z +FLOAT:pq,float62,tau_1 ++(+(*(*(float62,tau_1),pq),sum(*(z,q))),sum(*(g,p))) +::STMT +MATRIX:prob,pred +FLOAT:threshold +*(pred,>(prob,threshold)) +::STMT +MATRIX:out,parsertemp2798 +FLOAT:int94,int771,int10,int37 +sum(*(*(>(out,int771),-(int94,parsertemp2798)),*(>(out,int37),-(int10,parsertemp2798)))) +::STMT +MATRIX:Q,R +FLOAT:int517 +LITERAL_FLOAT:2.0 ++(rowSums(^(R,2.0)),t(rowSums(^(Q,int517)))) +::STMT +FLOAT:float812,parsertemp382948,parsertemp382957,loss_init,parsertemp382950 +/(-(loss_init,+(*(float812,parsertemp382948),*(parsertemp382950,parsertemp382957))),loss_init) +::STMT +MATRIX:n_corr,Y +FLOAT:int495 +LITERAL_FLOAT:0.0,0.5 ++(/(Y,+(rowSums(Y),==(n_corr,int495))),*(-(0.5,Y),==(rowSums(Y),0.0))) +::STMT +FLOAT:level +LITERAL_FLOAT:2.0 +-(level,2.0) +::STMT +MATRIX:linear_terms +FLOAT:var_power +LITERAL_FLOAT:1.0,2.0 +/(/(1.0,^(linear_terms,2.0)),-(1.0,var_power)) +::STMT +MATRIX:w_X,z_LS +LITERAL_FLOAT:2000.0 +*(/(2000.0,cast.FLOAT(%*%(w_X,z_LS))),z_LS) +::STMT +MATRIX:X +/(colSums(X),nrow(X)) +::STMT +MATRIX:parsertemp389300 +LITERAL_FLOAT:1.0,2.0 ++(exp(*(2.0,t(parsertemp389300))),1.0) +::STMT +LITERAL_FLOAT:5.0E-7 +5.0E-7 +::STMT +MATRIX:r +FLOAT:tolerance +LITERAL_FLOAT:2.0 +*(sum(^(r,2.0)),^(tolerance,2.0)) +::STMT +MATRIX:parsertemp42223,parsertemp42224,parsertemp42209 +FLOAT:parsertemp42210,meanY +sum(*(t(*(parsertemp42223,parsertemp42224)),-(+(parsertemp42209,parsertemp42210),meanY))) +::STMT +MATRIX:2134_left,2134_right +LITERAL_FLOAT:0.0,2.0 ++(/(^(sum(2134_left),2.0),+(nrow(2134_left),0.0)),/(^(sum(2134_right),2.0),+(nrow(2134_right),0.0))) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0 +-(i,1.0) +::STMT +MATRIX:Y,parsertemp2798,Xw +FLOAT:int178 +LITERAL_FLOAT:0.0,1.0 +*(*(>(-(int178,parsertemp2798),0.0),-(1.0,*(Y,Xw))),Y) +::STMT +MATRIX:classFeatureCounts +LITERAL_FLOAT:1.0,500.0 +/(+(classFeatureCounts,1.0),+(rowSums(classFeatureCounts),*(500.0,1.0))) +::STMT +MATRIX:termination_bitmap,parsertemp222665 +LITERAL_FLOAT:10.0 +*(10.0,max(*(parsertemp222665,termination_bitmap))) +::STMT +MATRIX:U,V +LITERAL_FLOAT:2.0 ++(sum(^(U,2.0)),sum(^(V,2.0))) +::STMT +MATRIX:instance,X,mask +*(-(X,instance),mask) +::STMT +MATRIX:X,parsertemp129018 +LITERAL_FLOAT:1.0 ++(*(max(parsertemp129018),-(ncol(X),1.0)),1.0) +::STMT +MATRIX:parsertemp220900,dY,parsertemp220899 +FLOAT:lr,momentum +LITERAL_FLOAT:2.0 +^(-(*(momentum,dY),*(lr,-(parsertemp220899,parsertemp220900))),2.0) +::STMT +MATRIX:parsertemp175066,scores,dprobs +*(dprobs,/(exp(-(scores,parsertemp175066)),rowSums(exp(scores)))) +::STMT +MATRIX:solution,X +sum(*(-(X,solution),-(X,solution))) +::STMT +MATRIX:Q,lambda,V,X,parsertemp149253 +*(V,+(%*%(t(X),-(Q,parsertemp149253)),*(lambda,V))) +::STMT +MATRIX:r,alpha,Hd +*(-(r,*(cast.FLOAT(alpha),Hd)),-(r,*(cast.FLOAT(alpha),Hd))) +::STMT +MATRIX:G,minDist +LITERAL_FLOAT:0.0 ++(G,*(!=(G,0.0),minDist)) +::STMT +MATRIX:parsertemp12846,F +FLOAT:W +LITERAL_FLOAT:2.0 +/(^(-(F,/(parsertemp12846,W)),2.0),/(%*%(rowSums(F),colSums(F)),sum(F))) +::STMT +MATRIX:parsertemp409532,ctab,parsertemp409528 +LITERAL_FLOAT:0.4 +*(parsertemp409532,>(/(parsertemp409528,rowSums(ctab)),0.4)) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:1.0,2.0 +/(1.0,-(1.0,^(linear_terms,2.0))) +::STMT +MATRIX:w,g +FLOAT:alpha +abs(-(w,/(g,alpha))) +::STMT +MATRIX:Xtrain,Xtest,X,Y +-(+(sum(X),sum(Y)),+(sum(Xtrain),sum(Xtest))) +::STMT +MATRIX:parsertemp42200,R +FLOAT:int137,meanX +LITERAL_FLOAT:0.5 +-(+(-(parsertemp42200,/(R,int137)),0.5),meanX) +::STMT +MATRIX:newbeta,lambda +LITERAL_FLOAT:2.0 +cast.FLOAT(%*%(t(lambda),^(newbeta,2.0))) +::STMT +MATRIX:W4_rand +LITERAL_FLOAT:0.08681986202598489 +*(0.08681986202598489,W4_rand) +::STMT +MATRIX:scale_X +cast.FLOAT(diag(scale_X)) +::STMT +MATRIX:q,r +FLOAT:alpha +LITERAL_FLOAT:2.0 +sum(^(+(r,*(alpha,q)),2.0)) +::STMT +MATRIX:_sbcvar12,parsertemp13660 +FLOAT:float545,44_meanX +LITERAL_FLOAT:999.0 +t(*(/(_sbcvar12,999.0),-(+(parsertemp13660,float545),44_meanX))) +::STMT +MATRIX:2701_mask,2700_W,2726_dpred,parsertemp459177,2699_probs,2702_X +LITERAL_FLOAT:0.0,0.5 +*(*(>(2702_X,0.0),/(2701_mask,0.5)),%*%(-(*(2726_dpred,2699_probs),*(2699_probs,parsertemp459177)),t(2700_W))) +::STMT +MATRIX:std,sts,rad +FLOAT:delta2 +/(-(delta2,sts),+(cast.FLOAT(std),cast.FLOAT(rad))) +::STMT +MATRIX:w,out +LITERAL_FLOAT:0.5,0.001 +*(0.001,+(*(0.5,cast.FLOAT(out)),*(0.5,cast.FLOAT(w)))) +::STMT +MATRIX:A +FLOAT:parsertemp12882 +LITERAL_FLOAT:1.0 +/(*(parsertemp12882,nrow(A)),-(nrow(A),1.0)) +::STMT +MATRIX:eVals,eVecs +FLOAT:int192 +%*%(%*%(eVecs,diag(^(eVals,int192))),t(eVecs)) +::STMT +MATRIX:log_det_chol +FLOAT:int840,int149 +INT:int669,parsertemp436708 +*(rand(int669,parsertemp436708,int149,int840),log_det_chol) +::STMT +MATRIX:b,H_inv +/(b,sqrt(diag(H_inv))) +::STMT +FLOAT:alpha +LITERAL_FLOAT:1.0 +-(1.0,alpha) +::STMT +MATRIX:X +LITERAL_FLOAT:2.0 +t(rowSums(^(X,2.0))) +::STMT +MATRIX:X +FLOAT:lambda +LITERAL_FLOAT:2.0,0.5 +*(*(0.5,lambda),sum(^(X,2.0))) +::STMT +MATRIX:X_Xd_exp_Xb_rev_agg,select,X_exp_Xb_rev_agg,D_r_rev,Xd_exp_Xb_rev_agg +LITERAL_FLOAT:2.0 +-(/(%*%(select,X_Xd_exp_Xb_rev_agg),D_r_rev),/(*(X_exp_Xb_rev_agg,%*%(select,Xd_exp_Xb_rev_agg)),^(D_r_rev,2.0))) +::STMT +MATRIX:linear_terms,Y +LITERAL_FLOAT:0.0,1.0 +*(linear_terms,-(1.0,==(Y,0.0))) +::STMT +MATRIX:C,I +FLOAT:ss ++(%*%(t(C),C),*(I,ss)) +::STMT +MATRIX:X +LITERAL_FLOAT:1.8378770664093453 +*(ncol(X),1.8378770664093453) +::STMT +MATRIX:parsertemp31104,parsertemp31106 +FLOAT:int180 +LITERAL_FLOAT:1999.0,2.0 +^(/(-(colSums(parsertemp31104),*(int180,parsertemp31106)),1999.0),2.0) +::STMT +LITERAL_FLOAT:-0.284496736 +-0.284496736 +::STMT +FLOAT:width +LITERAL_FLOAT:2.0 +*(2.0,^(width,2.0)) +::STMT +MATRIX:parsertemp560880,parsertemp560876,parsertemp560863,parsertemp560868 +FLOAT:float715,float721,int346,int38 +LITERAL_FLOAT:1.0,2.0 +*(*(*(/(float715,parsertemp560868),+(float721,parsertemp560876)),-(*(int346,parsertemp560863),1.0)),exp(/(*(parsertemp560880,int38),2.0))) +::STMT +LITERAL_FLOAT:0.45 +0.45 +::STMT +MATRIX:COMPONENTS,id +-(==(id,cast.FLOAT(id)),cast.FLOAT(diag(diag(COMPONENTS)))) +::STMT +MATRIX:parsertemp130875 +LITERAL_FLOAT:1.0,4.0 +-(+(*(max(parsertemp130875),4.0),1.0),1.0) +::STMT +FLOAT:int84,se_g1,int223,int467,int512,parsertemp113,wt +sqrt(/(*(*(int512,parsertemp113),^(se_g1,int84)),*(+(wt,int467),-(wt,int223)))) +::STMT +MATRIX:gs +LITERAL_FLOAT:-0.5 +*(-0.5,cast.FLOAT(gs)) +::STMT +MATRIX:s,parsertemp44016,d +cast.FLOAT(%*%(t(-(s,parsertemp44016)),d)) +::STMT +MATRIX:samples_vs_runs_map,X_samples_sq_norms,parsertemp222444,is_row_in_samples,parsertemp222440 +LITERAL_FLOAT:2.0 +*(is_row_in_samples,-(+(X_samples_sq_norms,%*%(samples_vs_runs_map,parsertemp222440)),*(2.0,rowSums(parsertemp222444)))) +::STMT +MATRIX:X,parsertemp16892 +FLOAT:int275 +%*%(sqrt(rowSums(^(X,int275))),t(sqrt(rowSums(parsertemp16892)))) +::STMT +MATRIX:eVals +LITERAL_FLOAT:-1.0 +diag(^(eVals,-1.0)) +::STMT +MATRIX:s,w,wnew,parsertemp44079 +FLOAT:int330,C +LITERAL_FLOAT:0.5 ++(*(0.5,%*%(t(wnew),+(w,s))),*(C,sum(*(parsertemp44079,int330)))) +::STMT +MATRIX:X +FLOAT:x +LITERAL_FLOAT:1.0 ++(-(nrow(X),sum(>=(X,x))),1.0) +::STMT +MATRIX:parsertemp503368,B +LITERAL_FLOAT:-1.0,2.0 +sum(^(*(%*%(parsertemp503368,B),-1.0),2.0)) +::STMT +LITERAL_FLOAT:0.9 +0.9 +::STMT +MATRIX:g0_2,g0_1 +FLOAT:tol +LITERAL_FLOAT:2.0 +*(sum(^(+(g0_1,g0_2),2.0)),^(tol,2.0)) +::STMT +FLOAT:wcss +LITERAL_FLOAT:1.0E-5 +*(1.0E-5,wcss) +::STMT +MATRIX:WM,CVars,parsertemp31290,CFreqs,parsertemp31285 +LITERAL_FLOAT:1.0 +/(/(sum(*(CFreqs,parsertemp31285)),-(nrow(CFreqs),1.0)),/(sum(*(parsertemp31290,CVars)),-(sum(WM),nrow(CFreqs)))) +::STMT +MATRIX:q,ssX_p,scale_X,shift_X,X ++(*(scale_X,%*%(t(X),%*%(X,ssX_p))),*(cast.FLOAT(q),shift_X)) +::STMT +MATRIX:mean,X,parsertemp437224,weight +/(-(%*%(t(X),X),%*%(*(parsertemp437224,weight),mean)),sum(weight)) +::STMT +MATRIX:parsertemp43635,w +sqrt(sum(*(+(w,parsertemp43635),+(w,parsertemp43635)))) +::STMT +MATRIX:p_CG +FLOAT:parsertemp170148,int378,z,int271 +LITERAL_FLOAT:0.5 +*(0.5,/(-(*(z,int378),sqrt(parsertemp170148)),sum(^(p_CG,int271)))) +::STMT +MATRIX:252_Y +FLOAT:252_X,int54,int127,parsertemp32925,int877,parsertemp32915,float189,parsertemp32934,float807 ++(+(*(-(int877,parsertemp32915),cast.FLOAT(252_Y)),*(/(float807,252_X),cast.FLOAT(252_Y))),*(*(/(float189,252_X),-(int54,parsertemp32915)),+(*(parsertemp32925,int127),*(parsertemp32934,parsertemp32915)))) +::STMT +LITERAL_FLOAT:2.29128784747792 +2.29128784747792 +::STMT +MATRIX:parsertemp146931,184_dtemp,parsertemp146929,184_unnorm_probs,parsertemp146936,outr2 +%*%(t(outr2),-(*(*(parsertemp146929,parsertemp146931),/(184_unnorm_probs,parsertemp146936)),*(/(184_unnorm_probs,parsertemp146936),rowSums(184_dtemp)))) +::STMT +MATRIX:252_Y,252_X +LITERAL_FLOAT:4.5 +*(/(-(4.5,cast.FLOAT(252_X)),-(cast.FLOAT(252_X),cast.FLOAT(252_X))),cast.FLOAT(252_Y)) +::STMT +MATRIX:parsertemp16755 +LITERAL_FLOAT:2.0 +^(2.0,cast.FLOAT(parsertemp16755)) +::STMT +MATRIX:WM,CFreqs +-(sum(WM),nrow(CFreqs)) +::STMT +MATRIX:IQR +LITERAL_FLOAT:1.5 +*(1.5,IQR) +::STMT +MATRIX:sv,out +LITERAL_FLOAT:0.5 +*(0.5,sum(*(*(sv,out),*(sv,out)))) +::STMT +MATRIX:W1_rand,X,parsertemp394884,parsertemp394894 +FLOAT:float244 +LITERAL_FLOAT:0.08146881698903526 +%*%(*(0.08146881698903526,W1_rand),t(/(-(X,parsertemp394884),+(parsertemp394894,float244)))) +::STMT +MATRIX:u,parsertemp500604 +FLOAT:alpha,tau +LITERAL_FLOAT:0.0 +*(*(parsertemp500604,-(abs(u),/(tau,alpha))),>(-(abs(u),/(tau,alpha)),0.0)) +::STMT +MATRIX:V,W,H,parsertemp10749 +FLOAT:Eps +*(W,/(%*%(V,t(H)),+(%*%(W,parsertemp10749),Eps))) +::STMT +FLOAT:e,int622,mu,epochs +LITERAL_FLOAT:0.999 ++(mu,/(-(0.999,mu),-(+(int622,epochs),e))) +::STMT +MATRIX:hubs +FLOAT:parsertemp30953 +LITERAL_FLOAT:2.0 +sum(^(-(/(hubs,parsertemp30953),hubs),2.0)) +::STMT +MATRIX:_funvar2124,parsertemp437267,parsertemp437272 +exp(-(+(_funvar2124,parsertemp437267),parsertemp437272)) +::STMT +MATRIX:q_CG +FLOAT:alpha_CG +*(alpha_CG,cast.FLOAT(q_CG)) +::STMT +FLOAT:n_features +LITERAL_FLOAT:1.0,2.0 +/(*(n_features,+(n_features,1.0)),2.0) +::STMT +MATRIX:posSampleVariances,negSampleVariances +FLOAT:int496,int127,int944,int391 +LITERAL_FLOAT:2.0,3352500.0,990000.0 +/(^(+(/(posSampleVariances,int496),/(negSampleVariances,int127)),2.0),+(/(^(posSampleVariances,int391),990000.0),/(^(negSampleVariances,int944),3352500.0))) +::STMT +MATRIX:avg_X_cols,parsertemp1513 +FLOAT:int956,n +LITERAL_FLOAT:1.0 +/(-(t(colSums(parsertemp1513)),*(n,^(avg_X_cols,int956))),-(n,1.0)) +::STMT +FLOAT:n_group_cols +LITERAL_FLOAT:1.0,3.0 +-(+(3.0,n_group_cols),1.0) +::STMT +FLOAT:float725,float58 +LITERAL_FLOAT:0.0,0.5 +INT:int612,int943,int668,int29 +*(0.5,%*%(t(rand(int943,int612,float725,float58)),rand(int29,int668,0.0,0.0))) +::STMT +FLOAT:deviance_nodisp +LITERAL_FLOAT:0.1,1.0E-6 +*(1.0E-6,+(deviance_nodisp,0.1)) +::STMT +MATRIX:tmp_Xw,Y,parsertemp2773,Xw +LITERAL_FLOAT:0.0,1.0 +*(-(1.0,*(Y,+(Xw,parsertemp2773))),>(-(1.0,*(Y,tmp_Xw)),0.0)) +::STMT +MATRIX:parsertemp410976,W,H,X +/(*(H,%*%(t(W),/(X,parsertemp410976))),t(colSums(W))) +::STMT +MATRIX:surv +LITERAL_FLOAT:1.0 +sqrt(-(1.0,surv)) +::STMT +MATRIX:parsertemp539203 +LITERAL_FLOAT:-1.0,2.0,0.6666666666666666 +^(/(*(parsertemp539203,-1.0),2.0),0.6666666666666666) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,253.0 +-(+(i,253.0),1.0) +::STMT +MATRIX:r,c,F +LITERAL_FLOAT:2.0 +^(-(F,/(%*%(r,c),sum(F))),2.0) +::STMT +MATRIX:classFeatureCounts +LITERAL_FLOAT:50.0 ++(rowSums(classFeatureCounts),50.0) +::STMT +MATRIX:parsertemp389219,X,permut +FLOAT:parsertemp389220,n +LITERAL_FLOAT:1.0E-17 +/(-(%*%(permut,X),/(colSums(X),n)),+(sqrt(/(parsertemp389219,parsertemp389220)),1.0E-17)) +::STMT +MATRIX:r,c,E,F +LITERAL_FLOAT:2.0 +sum(/(^(-(F,E),2.0),/(%*%(r,c),sum(F)))) +::STMT +MATRIX:W +FLOAT:parsertemp112,int710,parsertemp91 +LITERAL_FLOAT:2.0,3.0,4.0,5.0 +/(*(*(4.0,-(parsertemp112,int710)),^(sqrt(parsertemp91),2.0)),*(+(sum(W),5.0),-(sum(W),3.0))) +::STMT +MATRIX:classCounts +LITERAL_FLOAT:1000.0 +/(classCounts,1000.0) +::STMT +MATRIX:P,I +LITERAL_FLOAT:1.0 +&(I,<=(rowSums(P),1.0)) +::STMT +FLOAT:beg +LITERAL_FLOAT:1.0,256.0 +-(+(beg,256.0),1.0) +::STMT +MATRIX:parsertemp410977,W,H,parsertemp410974 +t(/(*(H,%*%(parsertemp410974,parsertemp410977)),t(colSums(W)))) +::STMT +MATRIX:ss,se,e +LITERAL_FLOAT:40.0 +/(/(se,ss),/(sum(e),40.0)) +::STMT +MATRIX:X +LITERAL_FLOAT:0.0 +-(0.0,/(t(colSums(X)),nrow(X))) +::STMT +MATRIX:X +FLOAT:x +LITERAL_FLOAT:1.0 +*(/(-(x,X),-(X,X)),-(1.0,/(-(x,X),-(X,X)))) +::STMT +MATRIX:classFeatureCounts +LITERAL_FLOAT:105.0 ++(rowSums(classFeatureCounts),105.0) +::STMT +MATRIX:m_iter_err_sum,m_err ++(colSums(m_err),m_iter_err_sum) +::STMT +MATRIX:w,out +FLOAT:int362,int565 +LITERAL_FLOAT:0.5 ++(*(0.5,sum(^(out,int362))),*(0.5,sum(^(w,int565)))) +::STMT +MATRIX:M +LITERAL_FLOAT:2.0 +<(rowSums(M),2.0) +::STMT +LITERAL_FLOAT:1.0,100.0 ++(+(100.0,100.0),1.0) +::STMT +MATRIX:2663_X +LITERAL_FLOAT:1.0 +*(1.0,ncol(2663_X)) +::STMT +MATRIX:Ileft,_funvar2707 +FLOAT:numI +*(/(rowSums(Ileft),numI),_funvar2707) +::STMT +MATRIX:parsertemp31276,CVars +FLOAT:int850,parsertemp31269,W,parsertemp31270 +LITERAL_FLOAT:1.0 +-(1.0,/(sum(*(parsertemp31276,CVars)),*(-(W,int850),/(parsertemp31269,parsertemp31270)))) +::STMT +LITERAL_FLOAT:0.010328 +0.010328 +::STMT +MATRIX:parsertemp220863,parsertemp220864,Hdiff,beta +FLOAT:int935 +LITERAL_FLOAT:2.0,1.0E20 +*(*(*(2.0,>=(Hdiff,int935)),==(+(parsertemp220863,parsertemp220864),1.0E20)),beta) +::STMT +MATRIX:d,sb +LITERAL_FLOAT:2.0 +*(2.0,sum(*(sb,d))) +::STMT +MATRIX:parsertemp31190,parsertemp31197 +FLOAT:int867,int372 +LITERAL_FLOAT:3.42951E11,2.0,3.37275E9 ++(/(^(/(parsertemp31190,int372),2.0),3.42951E11),/(^(/(parsertemp31197,int867),2.0),3.37275E9)) +::STMT +LITERAL_FLOAT:2.0,100.0 ++(+(100.0,100.0),2.0) +::STMT +MATRIX:linear_terms +FLOAT:link_power +LITERAL_FLOAT:0.0,-1.0 +-(0.0,^(linear_terms,/(-1.0,link_power))) +::STMT +LITERAL_FLOAT:0.0 +/(0.0,0.0) +::STMT +MATRIX:LT,Y,parsertemp149320 +*(Y,-(LT,parsertemp149320)) +::STMT +MATRIX:posSamples,posSampleMeans +FLOAT:int113,int998 +LITERAL_FLOAT:99.0,100.0 +/(-(colSums(^(posSamples,int998)),*(100.0,^(posSampleMeans,int113))),99.0) +::STMT +MATRIX:cumLeftHist,parsertemp132494,leftHist,outBucket +%*%(==(outBucket,t(parsertemp132494)),-(cumLeftHist,leftHist)) +::STMT +MATRIX:U +LITERAL_FLOAT:1.0E-6 +*(1.0E-6,U) +::STMT +MATRIX:D,ZERODIAG +FLOAT:int802 +LITERAL_FLOAT:1.0 +sum(*(/(1.0,+(D,int802)),ZERODIAG)) +::STMT +MATRIX:_sbcvar0 +LITERAL_FLOAT:2000.0 +/(_sbcvar0,2000.0) +::STMT +MATRIX:X +LITERAL_FLOAT:1.0E-14 +>(abs(-(X,round(X))),1.0E-14) +::STMT +MATRIX:X,Y +LITERAL_FLOAT:1.0 ++(+(ncol(X),ncol(Y)),1.0) +::STMT +MATRIX:lambda,beta +*(cast.FLOAT(lambda),cast.FLOAT(beta)) +::STMT +MATRIX:ss,se,e +LITERAL_FLOAT:20.0 +/(/(se,ss),/(sum(e),20.0)) +::STMT +MATRIX:parsertemp43632,X,y +LITERAL_FLOAT:0.0,2.0 +INT:int440,int584 ++(rand(int440,int584,0.0,0.0),*(2.0,%*%(t(X),*(parsertemp43632,y)))) +::STMT +MATRIX:parsertemp477718,parsertemp477715,parsertemp477724,X,Y,parsertemp477733,K,parsertemp477730 +FLOAT:x +LITERAL_FLOAT:1.0 ++(*(-(*(K,parsertemp477724),-(Y,Y)),-(1.0,/(parsertemp477715,parsertemp477718))),*(+(*(parsertemp477730,parsertemp477733),-(Y,Y)),/(-(x,X),-(X,X)))) +::STMT +MATRIX:R,dssm +FLOAT:2_n +/(2_n,-(R,dssm)) +::STMT +MATRIX:n_risk_stratum,n_risk_i2j +FLOAT:I_i1i2 +-(I_i1i2,/(n_risk_i2j,n_risk_stratum)) +::STMT +MATRIX:parsertemp410978,W,H +t(/(*(H,t(parsertemp410978)),t(colSums(W)))) +::STMT +FLOAT:parsertemp89,parsertemp88,parsertemp83,parsertemp84 +LITERAL_FLOAT:2.0 +^(sqrt(/(*(parsertemp83,parsertemp84),*(parsertemp88,parsertemp89))),2.0) +::STMT +MATRIX:Y +LITERAL_FLOAT:0.0 +<=(Y,0.0) +::STMT +MATRIX:parsertemp383173 +FLOAT:reg,parsertemp383181,loss_init +/(-(loss_init,+(sum(parsertemp383173),*(reg,parsertemp383181))),loss_init) +::STMT +MATRIX:parsertemp437549,pred,parsertemp437666 +t(colSums(==(*(parsertemp437666,parsertemp437549),pred))) +::STMT +MATRIX:R,parsertemp40219,parsertemp40216,parsertemp40225 +FLOAT:level +/(+(R,rowSums(*(parsertemp40216,parsertemp40225))),-(R,rowSums(==(parsertemp40219,level)))) +::STMT +MATRIX:termination_bitmap,parsertemp222665 +max(*(parsertemp222665,termination_bitmap)) +::STMT +MATRIX:features,beta_unscaled +FLOAT:intercept,parsertemp176418 +LITERAL_FLOAT:3.0 ++(sqrt(parsertemp176418),*(3.0,+(%*%(features,beta_unscaled),intercept))) +::STMT +MATRIX:id +diag(diag(==(id,cast.FLOAT(id)))) +::STMT +MATRIX:parsertemp145796,parsertemp145794,y +/(sum(rowSums(*(parsertemp145794,parsertemp145796))),nrow(y)) +::STMT +MATRIX:Xd,out +FLOAT:dd,step_sz,wd +/(-(+(wd,*(step_sz,dd)),sum(out)),+(dd,sum(Xd))) +::STMT +MATRIX:X +LITERAL_FLOAT:4.0 +<=(X,4.0) +::STMT +MATRIX:X,y,logisticnew +LITERAL_FLOAT:1.0 +%*%(t(X),*(-(logisticnew,1.0),y)) +::STMT +MATRIX:classCounts +LITERAL_FLOAT:8000.0 +/(classCounts,8000.0) +::STMT +MATRIX:parsertemp570381,parsertemp570372,parsertemp570376,parsertemp570377 +LITERAL_FLOAT:0.5,-0.5 ++(parsertemp570381,-(*(-0.5,parsertemp570372),*(0.5,%*%(parsertemp570376,parsertemp570377)))) +::STMT +MATRIX:parsertemp31762,X2 +FLOAT:minSup +LITERAL_FLOAT:0.0 +&(>=(t(colSums(X2)),minSup),>(t(%*%(parsertemp31762,X2)),0.0)) +::STMT +MATRIX:parsertemp220896,W,Y,Z +FLOAT:lr +*(lr,-(*(Y,rowSums(W)),%*%(*(parsertemp220896,Z),Y))) +::STMT +MATRIX:X +FLOAT:N +t(/(colSums(X),N)) +::STMT +MATRIX:classesUnBalanced,classesBalanced +cast.FLOAT(-(classesUnBalanced,classesBalanced)) +::STMT +MATRIX:posSampleMeans +LITERAL_FLOAT:2.0,7000.0 +*(7000.0,^(posSampleMeans,2.0)) +::STMT +MATRIX:r,c,E,F +LITERAL_FLOAT:2.0 +sum(/(^(-(F,E),2.0),/(%*%(r,c),sum(F)))) +::STMT +MATRIX:scale_X,X,z,beta +*(*(cast.FLOAT(diag(scale_X)),+(cast.FLOAT(beta),cast.FLOAT(z))),X) +::STMT +MATRIX:X,y +LITERAL_FLOAT:0.0 +-(0.0,%*%(-(0.0,t(X)),y)) +::STMT +MATRIX:linear_terms +FLOAT:link_power +LITERAL_FLOAT:1.0 +^(linear_terms,/(1.0,link_power)) +::STMT +MATRIX:linear_terms +FLOAT:link_power +LITERAL_FLOAT:-1.0 +^(linear_terms,/(-1.0,link_power)) +::STMT +MATRIX:V1,parsertemp539081 +FLOAT:range,I_i1i2 +LITERAL_FLOAT:2.0 +/(sum(*(V1,-(I_i1i2,parsertemp539081))),^(range,2.0)) +::STMT +MATRIX:surv +LITERAL_FLOAT:0.5 +<=(surv,0.5) +::STMT +MATRIX:parsertemp410070,r +FLOAT:r2 +/(cast.FLOAT(%*%(t(r),+(r,parsertemp410070))),r2) +::STMT +MATRIX:W +FLOAT:parsertemp65,parsertemp66 +LITERAL_FLOAT:1.0,2.0,3.0 +*(*(-(sum(W),1.0),-(sum(W),2.0)),^(sqrt(/(parsertemp65,parsertemp66)),3.0)) +::STMT +MATRIX:C,Xm,parsertemp265701,XtZ +FLOAT:ZtZ_sum +sum(%*%(%*%(Xm,%*%(C,parsertemp265701)),t(/(XtZ,ZtZ_sum)))) +::STMT +MATRIX:R,dsep,dssm +FLOAT:2_eAvg +LITERAL_FLOAT:1.0 +-(/(/(+(R,dsep),-(R,dssm)),2_eAvg),1.0) +::STMT +MATRIX:pred +LITERAL_FLOAT:1.0,1.0E-10 +/(1.0,+(pred,1.0E-10)) +::STMT +MATRIX:t_gp,parsertemp171332,pt_gp,parsertemp171331,Y,the_gauss_exp,parsertemp171327,parsertemp171316 +LITERAL_FLOAT:2.0,0.25,0.3989422804014327 +/(*(0.3989422804014327,+(-(Y,parsertemp171327),*(parsertemp171331,parsertemp171332))),*(*(0.25,*(t_gp,parsertemp171316)),-(2.0,*(the_gauss_exp,pt_gp)))) +::STMT +MATRIX:p,r,parsertemp503395,Z +FLOAT:norm_r2 ++(r,*(/(norm_r2,cast.FLOAT(parsertemp503395)),%*%(Z,p))) +::STMT +MATRIX:Xtest_dists +LITERAL_FLOAT:1.0 +<=(Xtest_dists,1.0) +::STMT +MATRIX:parsertemp43626 +LITERAL_FLOAT:-1.0 +*(-1.0,sum(parsertemp43626)) +::STMT +MATRIX:parsertemp415524,y +FLOAT:intercept +LITERAL_FLOAT:2.0 +sum(^(-(y,+(parsertemp415524,intercept)),2.0)) +::STMT +MATRIX:parsertemp279509 +FLOAT:int374 +LITERAL_FLOAT:1000.0,100.0 +*(/(sum(==(parsertemp279509,int374)),1000.0),100.0) +::STMT +MATRIX:X +LITERAL_FLOAT:0.0 +rowSums(!=(X,0.0)) +::STMT +MATRIX:col_nonzeros,U,parsertemp382849,V,parsertemp382852 +LITERAL_FLOAT:1.0E-6 ++(t(%*%(t(U),*(parsertemp382849,parsertemp382852))),*(*(1.0E-6,V),col_nonzeros)) +::STMT +MATRIX:R,parsertemp72406 +-(%*%(t(R),R),diag(parsertemp72406)) +::STMT +FLOAT:log_ten,parsertemp169812 +LITERAL_FLOAT:0.5 +round(-(/(parsertemp169812,log_ten),0.5)) +::STMT +MATRIX:W,X,H +FLOAT:eps +%*%(t(W),/(X,+(%*%(W,H),eps))) +::STMT +MATRIX:is_LT_infinite,Y_prob,Y,parsertemp171293,flip_pos +rowSums(*(*(Y,%*%(Y_prob,flip_pos)),+(*(Y_prob,parsertemp171293),is_LT_infinite))) +::STMT +MATRIX:prevTK2,X2 +colSums(==(%*%(X2,t(prevTK2)),t(rowSums(prevTK2)))) +::STMT +MATRIX:lambda,p_CG,shift_X,parsertemp170070,temp_CG +*(p_CG,+(+(*(lambda,p_CG),%*%(parsertemp170070,temp_CG)),%*%(shift_X,temp_CG))) +::STMT +LITERAL_FLOAT:2001.0 +sqrt(2001.0) +::STMT +MATRIX:W4_rand +LITERAL_FLOAT:0.08709382882250233 +*(0.08709382882250233,W4_rand) +::STMT +MATRIX:parsertemp414371,scale_X +LITERAL_FLOAT:0.0,200.0 +*(-(0.0,/(t(parsertemp414371),200.0)),scale_X) +::STMT +MATRIX:r,c,_sbcvar78 +LITERAL_FLOAT:2.0,10000.0 +^(-(_sbcvar78,/(%*%(r,c),10000.0)),2.0) +::STMT +MATRIX:linear_terms +FLOAT:var_power,float434 +LITERAL_FLOAT:1.0 +/(exp(*(linear_terms,-(float434,var_power))),-(1.0,var_power)) +::STMT +MATRIX:samples_vs_runs_map,centroid_placer,X_samples +*(X_samples,%*%(samples_vs_runs_map,%*%(centroid_placer,X_samples))) +::STMT +FLOAT:int463,parsertemp40812,m2 +LITERAL_FLOAT:3.0 +^(sqrt(*(/(int463,parsertemp40812),m2)),3.0) +::STMT +MATRIX:X_nonzero_ind +LITERAL_FLOAT:0.0 +sum(!=(t(colSums(X_nonzero_ind)),0.0)) +::STMT +MATRIX:CMeans,CFreqs +FLOAT:parsertemp31266,W +LITERAL_FLOAT:2.0 +*(CFreqs,^(-(CMeans,/(parsertemp31266,W)),2.0)) +::STMT +MATRIX:parsertemp386449,corePts +FLOAT:int440 +LITERAL_FLOAT:0.0,1.0 +&(==(t(corePts),0.0),>(colSums(>(parsertemp386449,int440)),1.0)) +::STMT +MATRIX:output,output1 +LITERAL_FLOAT:0.1 +>=(abs(-(output,output1)),0.1) +::STMT +MATRIX:codes,codebook +*(ncol(codes),ncol(codebook)) +::STMT +MATRIX:p_LS,parsertemp170551,X +FLOAT:lambda_LS +*(p_LS,+(%*%(%*%(parsertemp170551,X),p_LS),*(lambda_LS,p_LS))) +::STMT +FLOAT:parsertemp13703 +LITERAL_FLOAT:1.0,1000.0 +*(-(1000.0,1.0),/(*(parsertemp13703,1000.0),-(1000.0,1.0))) +::STMT +MATRIX:xs +FLOAT:252_x +LITERAL_FLOAT:1.0,10.0 ++(-(10.0,sum(>=(xs,252_x))),1.0) +::STMT +MATRIX:parsertemp1510,scale_X +FLOAT:n +LITERAL_FLOAT:-1.0 +*(*(/(t(parsertemp1510),n),-1.0),scale_X) +::STMT +MATRIX:posSampleVariances,negSampleVariances +FLOAT:int842,int402,int569,int812 +LITERAL_FLOAT:2.0,3352500.0,990000.0 +/(^(+(/(posSampleVariances,int569),/(negSampleVariances,int842)),2.0),+(/(^(posSampleVariances,int402),990000.0),/(^(negSampleVariances,int812),3352500.0))) +::STMT +LITERAL_FLOAT:0.189269 +0.189269 +::STMT +FLOAT:k,kmax,start_stepsize +LITERAL_FLOAT:1.0 +*(-(1.0,/(k,kmax)),start_stepsize) +::STMT +MATRIX:p_CG,z +FLOAT:trust_delta_sq,pp_CG +-(*(cast.FLOAT(%*%(p_CG,z)),cast.FLOAT(%*%(p_CG,z))),*(pp_CG,-(cast.FLOAT(z),trust_delta_sq))) +::STMT +FLOAT:idx +LITERAL_FLOAT:1.0,1048.0 ++(-(1048.0,idx),1.0) +::STMT +LITERAL_FLOAT:0.0,0.025 +INT:parsertemp410939,rnk +rand(parsertemp410939,rnk,0.0,0.025) +::STMT +MATRIX:P,parsertemp220844,ZERODIAG,beta +LITERAL_FLOAT:1.0E-12 +/(*(exp(*(parsertemp220844,beta)),ZERODIAG),+(rowSums(*(P,ZERODIAG)),1.0E-12)) +::STMT +MATRIX:X,y +LITERAL_FLOAT:0.0 +t(-(0.0,%*%(t(X),y))) +::STMT +MATRIX:g0_1,parsertemp410117 +t(+(g0_1,t(colSums(parsertemp410117)))) +::STMT +MATRIX:A +LITERAL_FLOAT:1.0 ++(length(A),1.0) +::STMT +MATRIX:252_Y,252_X +FLOAT:252_X,252_K,int803 +LITERAL_FLOAT:4.5 +*(+(*(-(int803,252_K),-(252_X,252_X)),-(cast.FLOAT(252_Y),cast.FLOAT(252_Y))),/(-(4.5,cast.FLOAT(252_X)),-(cast.FLOAT(252_X),cast.FLOAT(252_X)))) +::STMT +FLOAT:approx_sample_size +LITERAL_FLOAT:10.0 +round(*(10.0,sqrt(approx_sample_size))) +::STMT +MATRIX:linear_terms,Y +LITERAL_FLOAT:-0.0 +*(^(linear_terms,-0.0),-(Y,linear_terms)) +::STMT +MATRIX:linear_terms,Y +LITERAL_FLOAT:-2.0 +*(^(linear_terms,-2.0),-(Y,linear_terms)) +::STMT +MATRIX:parsertemp220863,parsertemp220864,H,betamax,Hneg,beta,Hpos +FLOAT:INF,logU +LITERAL_FLOAT:0.0 +*(*(>=(-(H,logU),0.0),!=(+(parsertemp220863,parsertemp220864),INF)),+(beta,+(*(Hpos,betamax),*(Hneg,beta)))) +::STMT +MATRIX:w,ssX_p_CG,X +%*%(t(X),*(w,%*%(X,ssX_p_CG))) +::STMT +FLOAT:j +LITERAL_FLOAT:1.0,3.0 ++(-(3.0,j),1.0) +::STMT +MATRIX:parsertemp400674,W4_rand,parsertemp400677 +LITERAL_FLOAT:0.08720414403938946 +t(%*%(*(0.08720414403938946,W4_rand),t(/(parsertemp400674,parsertemp400677)))) +::STMT +MATRIX:parsertemp496901 +FLOAT:std +*(cast.FLOAT(parsertemp496901),std) +::STMT +FLOAT:cmLabels +LITERAL_FLOAT:1.000100010001 +sqrt(*(cmLabels,1.000100010001)) +::STMT +MATRIX:parsertemp31403,classFeatureCounts +LITERAL_FLOAT:105.0,1.0 +%*%(+(rowSums(classFeatureCounts),*(105.0,1.0)),parsertemp31403) +::STMT +MATRIX:X +FLOAT:x +LITERAL_FLOAT:1.0 +*(/(-(x,X),-(X,X)),-(1.0,/(-(x,X),-(X,X)))) +::STMT +MATRIX:X,Y +FLOAT:x +*(/(-(x,cast.FLOAT(X)),-(cast.FLOAT(X),cast.FLOAT(X))),cast.FLOAT(Y)) +::STMT +MATRIX:lambda,parsertemp170067,parsertemp170065,p_CG,shift_X,w,parsertemp170066,X,parsertemp170060 ++(+(*(lambda,p_CG),*(cast.FLOAT(parsertemp170060),%*%(parsertemp170065,parsertemp170067))),*(cast.FLOAT(shift_X),%*%(t(X),*(w,parsertemp170066)))) +::STMT +MATRIX:parsertemp175076,parsertemp175080,R1 +abs(-(R1,/(exp(parsertemp175076),rowSums(parsertemp175080)))) +::STMT +MATRIX:parsertemp437190,resp,X,weight +LITERAL_FLOAT:2.22E-16 +/(*(/(%*%(parsertemp437190,X),t(weight)),%*%(t(resp),X)),t(+(colSums(resp),2.22E-16))) +::STMT +MATRIX:upd_W1,X_batch,W1_grad +FLOAT:mu,step +-(*(mu,upd_W1),*(/(step,nrow(X_batch)),W1_grad)) +::STMT +FLOAT:i,n +LITERAL_FLOAT:1.0,64.0 +-(n,-(+(i,64.0),1.0)) +::STMT +MATRIX:simplex +/(-(rowSums(simplex),simplex),nrow(simplex)) +::STMT +LITERAL_FLOAT:-0.05,0.05 +INT:parsertemp411077,rnk +rand(parsertemp411077,rnk,-0.05,0.05) +::STMT +MATRIX:oldE +FLOAT:parsertemp32107 +/(sum(oldE),parsertemp32107) +::STMT +FLOAT:norm_Grad_initial +LITERAL_FLOAT:1.0E-8 +*(1.0E-8,norm_Grad_initial) +::STMT +MATRIX:parsertemp414375,parsertemp414377 +FLOAT:int880 +LITERAL_FLOAT:0.0,199.0 +<=(/(-(t(parsertemp414375),*(int880,parsertemp414377)),199.0),0.0) +::STMT +MATRIX:R,parsertemp40216,parsertemp40226 +FLOAT:eAvg +/(/(+(R,rowSums(parsertemp40226)),+(R,rowSums(parsertemp40216))),eAvg) +::STMT +FLOAT:high,low +LITERAL_FLOAT:2.0 +/(+(low,high),2.0) +::STMT +MATRIX:45_CFreqs +LITERAL_FLOAT:1000.0 +-(1000.0,nrow(45_CFreqs)) +::STMT +LITERAL_FLOAT:0.128920512778062 +0.128920512778062 +::STMT +MATRIX:X +FLOAT:int242,n +-(/(colSums(^(X,int242)),n),*(/(colSums(X),n),/(colSums(X),n))) +::STMT +MATRIX:classFeatureCounts +LITERAL_FLOAT:1.0,50.0 +/(+(classFeatureCounts,1.0),+(rowSums(classFeatureCounts),50.0)) +::STMT +MATRIX:parsertemp43631,parsertemp43633,w +LITERAL_FLOAT:2.0 +*(+(w,*(2.0,%*%(parsertemp43631,parsertemp43633))),+(w,*(2.0,%*%(parsertemp43631,parsertemp43633)))) +::STMT +MATRIX:Y,Xd,out +FLOAT:dd,step_sz,wd +-(+(wd,*(step_sz,dd)),sum(*(*(out,Y),Xd))) +::STMT +MATRIX:Y_counts,means,parsertemp560512,parsertemp560516 +LITERAL_FLOAT:2.0 +*(Y_counts,-(rowSums(*(means,parsertemp560516)),^(rowSums(parsertemp560512),2.0))) +::STMT +MATRIX:parsertemp500608,parsertemp500604,parsertemp500605,w +FLOAT:lambda +LITERAL_FLOAT:0.0 +-(*(*(parsertemp500604,-(parsertemp500605,lambda)),>(-(parsertemp500608,lambda),0.0)),w) +::STMT +MATRIX:X,y,logisticnew +FLOAT:C,int545 +*(C,%*%(t(X),*(-(logisticnew,int545),y))) +::STMT +MATRIX:rowSums_X_sq +max(sqrt(rowSums_X_sq)) +::STMT +MATRIX:Y,parsertemp171319 +FLOAT:float554 +LITERAL_FLOAT:0.15915494309189535 +*(*(exp(/(parsertemp171319,float554)),0.15915494309189535),rowSums(Y)) +::STMT +MATRIX:mn,mx +-(mx,mn) +::STMT +MATRIX:y_corr,parsertemp171089,parsertemp171084,parsertemp171095 +FLOAT:float558,float534,float130 +LITERAL_FLOAT:0.0,1.0,2.0 +*(+(-(0.0,sqrt(parsertemp171084)),/(+(float558,parsertemp171089),+(float130,parsertemp171095))),-(1.0,*(2.0,>(y_corr,float534)))) +::STMT +MATRIX:Y +LITERAL_FLOAT:2.0 +^(rowSums(Y),2.0) +::STMT +MATRIX:parsertemp500607,w,parsertemp500610,wnew +cast.FLOAT(%*%(t(-(wnew,w)),-(*(parsertemp500607,parsertemp500610),w))) +::STMT +MATRIX:classFeatureCounts +LITERAL_FLOAT:105.0,1.0 +/(+(classFeatureCounts,1.0),+(rowSums(classFeatureCounts),105.0)) +::STMT +FLOAT:x +LITERAL_FLOAT:1.0,-1.0 ++(1.0,exp(*(x,-1.0))) +::STMT +LITERAL_FLOAT:1.0,2000.0 +/(2000.0,-(2000.0,1.0)) +::STMT +MATRIX:curr_rows_vector +LITERAL_FLOAT:0.0 +sum(>(curr_rows_vector,0.0)) +::STMT +MATRIX:parsertemp31189,parsertemp31187 +LITERAL_FLOAT:3.42951E11,2.0,6999.0 +/(^(/(-(parsertemp31187,parsertemp31189),6999.0),2.0),3.42951E11) +::STMT +MATRIX:R,dssp,dssm +FLOAT:5_n +LITERAL_FLOAT:1.0 +-(/(5_n,-(+(R,dssp),dssm)),1.0) +::STMT +MATRIX:classFeatureCounts +LITERAL_FLOAT:1.0,500.0 +/(+(classFeatureCounts,1.0),+(rowSums(classFeatureCounts),500.0)) +::STMT +MATRIX:parsertemp555744,target +/(sum(rowSums(abs(parsertemp555744))),nrow(target)) +::STMT +MATRIX:parsertemp129125,groupIndex +-(*(groupIndex,max(parsertemp129125)),max(parsertemp129125)) +::STMT +LITERAL_FLOAT:1.0,6.0,2003.0 +*(*(6.0,2003.0),-(2003.0,1.0)) +::STMT +MATRIX:M2,parsertemp553121 +%*%(rowSums(*(M2,M2)),parsertemp553121) +::STMT +MATRIX:s,d +FLOAT:parsertemp44015 +%*%(t(-(s,*(parsertemp44015,d))),d) +::STMT +FLOAT:link_power +LITERAL_FLOAT:2.0 +/(2.0,link_power) +::STMT +FLOAT:link_power +LITERAL_FLOAT:0.0 +/(0.0,link_power) +::STMT +MATRIX:Y +FLOAT:check_max,check_min +LITERAL_FLOAT:2.0 +-(*(/(2.0,-(check_max,check_min)),Y),/(+(check_min,check_max),-(check_max,check_min))) +::STMT +LITERAL_FLOAT:1.8 +1.8 +::STMT +FLOAT:parsertemp40813,m2,mu +LITERAL_FLOAT:5.0 ++(mu,*(5.0,sqrt(*(parsertemp40813,m2)))) +::STMT +MATRIX:Y,2212_fp +/(2212_fp,-(nrow(Y),sum(Y))) +::STMT +MATRIX:R,dssp,dsep,dssm,dsem +FLOAT:5_eAvg +/(/(-(+(R,dsep),dsem),-(+(R,dssp),dssm)),5_eAvg) +::STMT +MATRIX:parsertemp31188,parsertemp31186 +FLOAT:int452 +LITERAL_FLOAT:2.0,6999.0 +^(/(-(colSums(parsertemp31186),*(int452,parsertemp31188)),6999.0),2.0) +::STMT +MATRIX:scale_X,shift_X,parsertemp274137,parsertemp274138,Grad +LITERAL_FLOAT:2.0 +^(+(%*%(diag(scale_X),%*%(parsertemp274137,parsertemp274138)),%*%(shift_X,Grad)),2.0) +::STMT +MATRIX:csgaps,csmask +>(csgaps,csmask) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:1.0 +/(exp(linear_terms),+(1.0,exp(linear_terms))) +::STMT +MATRIX:ctab,parsertemp409528 +LITERAL_FLOAT:0.4 +>(/(parsertemp409528,rowSums(ctab)),0.4) +::STMT +MATRIX:y_hat,B,parsertemp503774 +LITERAL_FLOAT:2.0 +sum(^(-(-(B,parsertemp503774),y_hat),2.0)) +::STMT +MATRIX:P12,map +LITERAL_FLOAT:0.0 +!=(%*%(map,P12),0.0) +::STMT +FLOAT:run_index +LITERAL_FLOAT:24.0,1.0 +*(24.0,-(run_index,1.0)) +::STMT +MATRIX:parsertemp31029,parsertemp31031 +FLOAT:int341 +LITERAL_FLOAT:1.0,150.0 +/(/(-(colSums(parsertemp31029),*(int341,parsertemp31031)),-(150.0,1.0)),150.0) +::STMT +MATRIX:B,parsertemp410245,X_t +LITERAL_FLOAT:-1.0,2.0 +/(*(parsertemp410245,-1.0),*(2.0,exp(%*%(X_t,B)))) +::STMT +MATRIX:surv,parsertemp538706 +*(sqrt(parsertemp538706),surv) +::STMT +MATRIX:LHSthreshold +LITERAL_FLOAT:1.0 +sum(>(LHSthreshold,1.0)) +::STMT +MATRIX:parsertemp477718,parsertemp477728,t,parsertemp477715,parsertemp477737,parsertemp477725,X,parsertemp477734 +FLOAT:int376,x +LITERAL_FLOAT:1.0 +*(*(/(-(x,X),-(X,X)),-(1.0,/(parsertemp477715,parsertemp477718))),+(*(-(parsertemp477725,parsertemp477728),-(int376,t)),*(+(parsertemp477734,parsertemp477737),/(parsertemp477715,parsertemp477718)))) +::STMT +FLOAT:s,num_groups +LITERAL_FLOAT:1.0,7.0 +*(*(-(s,1.0),num_groups),7.0) +::STMT +MATRIX:X +FLOAT:val +!=(X,val) +::STMT +MATRIX:parsertemp171083 +LITERAL_FLOAT:-2.0,0.001308 +*(sqrt(*(-2.0,parsertemp171083)),0.001308) +::STMT +MATRIX:Y,linear_terms,is_y_0 +FLOAT:int410 +LITERAL_FLOAT:0.0 +/(+(Y,==(Y,0.0)),+(*(linear_terms,-(int410,is_y_0)),==(Y,0.0))) +::STMT +MATRIX:Y_counts,Y +-(Y,%*%(Y_counts,/(colSums(Y),sum(Y_counts)))) +::STMT +MATRIX:linear_terms,Y +FLOAT:int829,link_power,parsertemp286300 +/(*(^(linear_terms,-(parsertemp286300,int829)),-(Y,^(linear_terms,parsertemp286300))),link_power) +::STMT +MATRIX:out,parsertemp2798 +FLOAT:int733,int943 +LITERAL_FLOAT:2.0 +sum(^(*(>(out,int733),-(int943,parsertemp2798)),2.0)) +::STMT +MATRIX:R,dssm +FLOAT:2_n,2_alpha +LITERAL_FLOAT:1.0 +*(-(1.0,2_alpha),-(/(2_n,-(R,dssm)),1.0)) +::STMT +MATRIX:parsertemp149248,parsertemp150463,P_1K +*(P_1K,%*%(rowSums(*(P_1K,parsertemp149248)),parsertemp150463)) +::STMT +FLOAT:g +LITERAL_FLOAT:1.0,2.0 ++(*(-(g,1.0),2.0),2.0) +::STMT +MATRIX:p,V +LITERAL_FLOAT:1.0E-8 ++(%*%(t(V),%*%(V,p)),*(1.0E-8,p)) +::STMT +MATRIX:posSamples +LITERAL_FLOAT:2.0 +colSums(^(posSamples,2.0)) +::STMT +MATRIX:parsertemp175066,scores,parsertemp175069,unnorm_probs,dprobs +*(/(exp(-(scores,parsertemp175066)),rowSums(exp(scores))),rowSums(*(dprobs,/(unnorm_probs,parsertemp175069)))) +::STMT +MATRIX:F +%*%(rowSums(F),colSums(F)) +::STMT +FLOAT:42_m2X +LITERAL_FLOAT:1.0,1000.0 +*(42_m2X,/(1000.0,-(1000.0,1.0))) +::STMT +MATRIX:252_Y +FLOAT:252_X,float125,float67 +LITERAL_FLOAT:1.0 ++(*(-(1.0,/(float67,252_X)),cast.FLOAT(252_Y)),*(/(-(float125,252_X),-(252_X,252_X)),cast.FLOAT(252_Y))) +::STMT +MATRIX:CVars,CFreqs +FLOAT:float426,int601,int956,parsertemp31330,int591 +LITERAL_FLOAT:1.0,10000.0 +/(sum(*(-(CFreqs,int601),CVars)),*(-(10000.0,1.0),/(*(parsertemp31330,int956),-(int591,float426)))) +::STMT +MATRIX:parsertemp171315,t_gp,parsertemp171320,parsertemp171307,parsertemp171316 +FLOAT:float678,float19 +LITERAL_FLOAT:2.0,0.25 +*(*(0.25,*(/(float678,parsertemp171307),+(float19,parsertemp171315))),-(2.0,*(exp(parsertemp171320),*(t_gp,parsertemp171316)))) +::STMT +MATRIX:X,parsertemp115855 +FLOAT:int61,n +LITERAL_FLOAT:2.0 +-(t(colSums(^(X,int61))),*(nrow(X),^(/(parsertemp115855,n),2.0))) +::STMT +MATRIX:out2,parsertemp146942,184_dscores,maskd1,W2 +FLOAT:p,int969 +*(/(maskd1,p),%*%(*(>(out2,int969),%*%(184_dscores,parsertemp146942)),t(W2))) +::STMT +MATRIX:linear_terms +FLOAT:var_power +LITERAL_FLOAT:1.0 +exp(*(linear_terms,-(1.0,var_power))) +::STMT +FLOAT:g +LITERAL_FLOAT:1.0,2.0 ++(*(-(g,1.0),2.0),1.0) +::STMT +MATRIX:posSampleVariances,negSampleVariances +FLOAT:int593,int652,int575,int849 +LITERAL_FLOAT:7.996E9,2.0,3.37275E9 +/(^(+(/(posSampleVariances,int652),/(negSampleVariances,int849)),2.0),+(/(^(posSampleVariances,int593),7.996E9),/(^(negSampleVariances,int575),3.37275E9))) +::STMT +MATRIX:parsertemp410245,parsertemp410248 +FLOAT:int577,float40 +LITERAL_FLOAT:1.0,1.5 +min(^(/(*(parsertemp410245,int577),*(float40,parsertemp410248)),/(1.0,1.5))) +::STMT +MATRIX:parsertemp2832 +==(round(parsertemp2832),min(round(parsertemp2832))) +::STMT +MATRIX:C,Xm,parsertemp265702 +sum(%*%(%*%(%*%(Xm,parsertemp265702),t(C)),t(Xm))) +::STMT +MATRIX:parsertemp386437,neighbors +FLOAT:eps +LITERAL_FLOAT:0.0 +*(<=(-(neighbors,diag(parsertemp386437)),eps),<(0.0,-(neighbors,diag(parsertemp386437)))) +::STMT +MATRIX:neighbors +FLOAT:eps,int625 +LITERAL_FLOAT:1.0 ++(rowSums(*(<=(neighbors,eps),<(int625,neighbors))),1.0) +::STMT +MATRIX:finite_linear_terms +LITERAL_FLOAT:0.0,1.0 +-(1.0,exp(-(0.0,exp(finite_linear_terms)))) +::STMT +MATRIX:W +FLOAT:parsertemp65,parsertemp66 +LITERAL_FLOAT:1.0,2.0,3.0 +*(*(-(sum(W),1.0),-(sum(W),2.0)),^(sqrt(/(parsertemp65,parsertemp66)),3.0)) +::STMT +MATRIX:parsertemp42200,R +FLOAT:int779,meanX +LITERAL_FLOAT:1.0,2.0 +-(+(-(parsertemp42200,/(R,int779)),/(1.0,2.0)),meanX) +::STMT +MATRIX:V,parsertemp10742,H,parsertemp10738 +FLOAT:Eps +t(*(H,/(%*%(parsertemp10738,V),+(parsertemp10742,Eps)))) +::STMT +MATRIX:r_LS,parsertemp285848 +LITERAL_FLOAT:0.0 +-(0.0,cast.FLOAT(%*%(t(r_LS),t(parsertemp285848)))) +::STMT +MATRIX:X,parsertemp115854 +LITERAL_FLOAT:2.0 +*($1:nrow(X),^(/(t(parsertemp115854),$1),2.0)) +::STMT +MATRIX:W,X,parsertemp411199,parsertemp411201 +LITERAL_FLOAT:1.0E-8 +/(X,+(%*%(W,/(parsertemp411199,parsertemp411201)),1.0E-8)) +::STMT +FLOAT:parsertemp42302,parsertemp42306 +LITERAL_FLOAT:1.000100010001 +*(sqrt(*(parsertemp42302,1.000100010001)),sqrt(*(parsertemp42306,1.000100010001))) +::STMT +MATRIX:p_CG +FLOAT:parsertemp285794,parsertemp285796 +LITERAL_FLOAT:-1.0 +/(+(*(cast.FLOAT(p_CG),-1.0),sqrt(-(parsertemp285794,parsertemp285796))),cast.FLOAT(%*%(t(p_CG),p_CG))) +::STMT +MATRIX:cdf_min_distances,threshold_matrix +LITERAL_FLOAT:1.0 ++(t(colSums(<(cdf_min_distances,threshold_matrix))),1.0) +::STMT +FLOAT:dimensions +LITERAL_FLOAT:1.0,2.0 ++(^(2.0,dimensions),1.0) +::STMT +FLOAT:m2Y,sigmaX,W,parsertemp26583 +*(sigmaX,sqrt(*(m2Y,/(W,parsertemp26583)))) +::STMT +MATRIX:CVars,CFreqs +FLOAT:int381 +LITERAL_FLOAT:10000.0 +/(sum(*(-(CFreqs,int381),CVars)),-(10000.0,nrow(CFreqs))) +::STMT +MATRIX:R,dssp,dssm +FLOAT:5_n +/(5_n,-(+(R,dssp),dssm)) +::STMT +MATRIX:classFeatureCounts +LITERAL_FLOAT:1.0,50.0 ++(rowSums(classFeatureCounts),*(50.0,1.0)) +::STMT +MATRIX:parsertemp410978,W,H,parsertemp410980 +FLOAT:eps ++(%*%(W,/(*(H,parsertemp410978),t(parsertemp410980))),eps) +::STMT +MATRIX:linear_terms +FLOAT:var_power +LITERAL_FLOAT:1.0,2.0 +/(/(1.0,^(linear_terms,2.0)),-(2.0,var_power)) +::STMT +MATRIX:LT,Y,parsertemp149320,parsertemp150469 +*(Y,-(LT,%*%(parsertemp149320,parsertemp150469))) +::STMT +LITERAL_FLOAT:0.6 +0.6 +::STMT +MATRIX:C,Xm,parsertemp265701 +t(%*%(Xm,%*%(C,parsertemp265701))) +::STMT +MATRIX:parsertemp42190,X +LITERAL_FLOAT:1.0,2.0 ++(-(parsertemp42190,/(X,2.0)),/(1.0,2.0)) +::STMT +LITERAL_FLOAT:-1.0E30 +-1.0E30 +::STMT +LITERAL_FLOAT:1.0E30 +1.0E30 +::STMT +MATRIX:prec_chol,mu +LITERAL_FLOAT:2.0 +*(^(mu,2.0),^(prec_chol,2.0)) +::STMT +LITERAL_FLOAT:0.85 +0.85 +::STMT +MATRIX:X,Y +FLOAT:x +*(/(-(x,X),-(X,X)),Y) +::STMT +LITERAL_FLOAT:0.3 +0.3 +::STMT +MATRIX:p,V +%*%(V,p) +::STMT +MATRIX:dY,g +FLOAT:lr,momentum +LITERAL_FLOAT:2.0 +sum(^(-(*(momentum,dY),*(lr,g)),2.0)) +::STMT +MATRIX:ncCnts,maxsc +FLOAT:parsertemp31781 +LITERAL_FLOAT:0.0 +|(>(ncCnts,0.0),>(maxsc,parsertemp31781)) +::STMT +MATRIX:current_node +FLOAT:cur_node_index +LITERAL_FLOAT:1.0 ++(+(cur_node_index,cast.FLOAT(current_node)),1.0) +::STMT +MATRIX:_sbcvar1708 +LITERAL_FLOAT:45.0 ++(45.0,nrow(_sbcvar1708)) +::STMT +LITERAL_FLOAT:0.08146881698903526 +0.08146881698903526 +::STMT +MATRIX:cumLeftHist,parsertemp132495,parsertemp132506,leftHist,outBucket +LITERAL_FLOAT:1.0 ++(+(%*%(==(outBucket,parsertemp132495),-(cumLeftHist,leftHist)),parsertemp132506),1.0) +::STMT +LITERAL_FLOAT:0.30000000000000004 +0.30000000000000004 +::STMT +MATRIX:s,d,alpha +FLOAT:parsertemp44015 +%*%(t(-(s,*(parsertemp44015,d))),-(s,*(cast.FLOAT(alpha),d))) +::STMT +MATRIX:C,Xm,tmp,parsertemp265701 +/(%*%(t(Xm),%*%(Xm,%*%(C,parsertemp265701))),sum(tmp)) +::STMT +MATRIX:logistic,X,y +FLOAT:int215 +LITERAL_FLOAT:2.0 +*(2.0,%*%(t(X),*(-(logistic,int215),y))) +::STMT +MATRIX:q_LS,p_LS,parsertemp170551,X +FLOAT:norm_r2_LS,lambda_LS +*(/(norm_r2_LS,sum(*(p_LS,q_LS))),+(%*%(%*%(parsertemp170551,X),p_LS),*(lambda_LS,p_LS))) +::STMT +MATRIX:shift_X,parsertemp116007 +LITERAL_FLOAT:2.0,9.999999999999998E-15 +*(sum(^(+(parsertemp116007,shift_X),2.0)),9.999999999999998E-15) +::STMT +MATRIX:parsertemp10744,W,H +FLOAT:Eps ++(%*%(W,%*%(*(H,parsertemp10744),t(H))),Eps) +::STMT +MATRIX:parsertemp170277 +LITERAL_FLOAT:3.141592653589793,0.5 ++(0.5,/(parsertemp170277,3.141592653589793)) +::STMT +MATRIX:ts +FLOAT:q ++(-(q,%*%(ts,ts)),%*%(ts,ts)) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:1.0,2.0,1.0E7 +*(==(+(1.0E7,exp(linear_terms)),1.0E7),-(1.0,/(exp(linear_terms),2.0))) +::STMT +MATRIX:dY,W,Y,sumW +FLOAT:lr,momentum +-(*(momentum,dY),*(lr,-(*(Y,sumW),%*%(W,Y)))) +::STMT +MATRIX:m_err +sum(colSums(m_err)) +::STMT +MATRIX:parsertemp409058,parsertemp409054,ctab +FLOAT:threshold +*(parsertemp409058,>(/(parsertemp409054,rowSums(ctab)),threshold)) +::STMT +MATRIX:means,parsertemp560515 +LITERAL_FLOAT:2.0 +rowSums(*(means,^(parsertemp560515,2.0))) +::STMT +MATRIX:P,minD,D +t(colSums(/(<=(D,minD),rowSums(P)))) +::STMT +MATRIX:X +FLOAT:x +LITERAL_FLOAT:1.0 +-(1.0,/(cast.FLOAT(-(x,X)),-(cast.FLOAT(X),cast.FLOAT(X)))) +::STMT +MATRIX:tpr,fpr +*(-(fpr,fpr),+(tpr,tpr)) +::STMT +MATRIX:is_LT_infinite,Y_prob +LITERAL_FLOAT:1.0 +*(Y_prob,-(1.0,rowSums(is_LT_infinite))) +::STMT +MATRIX:parsertemp222327,is_row_in_samples +LITERAL_FLOAT:2001.0 +-(2001.0,*(is_row_in_samples,parsertemp222327)) +::STMT +MATRIX:surv +LITERAL_FLOAT:1.0 +*(surv,sqrt(-(1.0,surv))) +::STMT +MATRIX:ssX_p,scale_X,X +*(scale_X,%*%(t(X),%*%(X,ssX_p))) +::STMT +FLOAT:b,int247 +LITERAL_FLOAT:2.0 +sqrt(-(^(b,2.0),int247)) +::STMT +MATRIX:tab,catTotal +LITERAL_FLOAT:-1.0 +*(/(tab,catTotal),-1.0) +::STMT +MATRIX:col_nonzeros,parsertemp382954,parsertemp382951,row_nonzeros +LITERAL_FLOAT:5.0E-7 +*(5.0E-7,+(sum(*(parsertemp382951,row_nonzeros)),sum(*(parsertemp382954,col_nonzeros)))) +::STMT +FLOAT:int484,217_a22,parsertemp22450,parsertemp22451 +LITERAL_FLOAT:2.0 +*(2.0,sqrt(+(+(parsertemp22450,parsertemp22451),/(int484,217_a22)))) +::STMT +LITERAL_FLOAT:44.721359549995796 +44.721359549995796 +::STMT +MATRIX:X +FLOAT:int557,int228 +LITERAL_FLOAT:1.0 +/(-(exp(*(int557,X)),1.0),+(exp(*(int228,X)),1.0)) +::STMT +MATRIX:parsertemp31023,parsertemp31025 +FLOAT:int211,int718 +LITERAL_FLOAT:1.0,2.0,100.0 +/(^(/(-(parsertemp31023,parsertemp31025),-(int211,int718)),2.0),*(^(100.0,2.0),-(100.0,1.0))) +::STMT +MATRIX:B,X,y +LITERAL_FLOAT:2.0 +^(-(y,%*%(X,B)),2.0) +::STMT +MATRIX:prec_chol,mu +FLOAT:int510,int468 +t(rowSums(*(^(mu,int468),^(prec_chol,int510)))) +::STMT +LITERAL_FLOAT:1.0,2001.0 +-(2001.0,1.0) +::STMT +MATRIX:y_corr +LITERAL_FLOAT:1.0 +/(y_corr,-(1.0,y_corr)) +::STMT +MATRIX:ts +LITERAL_FLOAT:1.0,2.0,4.0 +-(+(-(length(ts),4.0),1.0),2.0) +::STMT +LITERAL_FLOAT:0.086386842558136 +0.086386842558136 +::STMT +MATRIX:P,scale_lambda,X,Y,parsertemp150455 +LITERAL_FLOAT:0.0,1.0E-5 ++(%*%(t(X),-(P,Y)),*(*(%*%(scale_lambda,parsertemp150455),1.0E-5),0.0)) +::STMT +MATRIX:parsertemp555613,parsertemp555615 +%*%(t(sqrt(parsertemp555613)),sqrt(parsertemp555615)) +::STMT +MATRIX:X,Y +/(abs(-(X,Y)),abs(X)) +::STMT +MATRIX:W1_rand,X,parsertemp393476,parsertemp393466 +FLOAT:float616 +LITERAL_FLOAT:0.07261134713572442 +%*%(*(0.07261134713572442,W1_rand),t(/(-(X,parsertemp393466),+(parsertemp393476,float616)))) +::STMT +MATRIX:colSD +LITERAL_FLOAT:3.0 +*(3.0,colSD) +::STMT +MATRIX:_funvar402 +LITERAL_FLOAT:1.0E-16 ++(_funvar402,1.0E-16) +::STMT +MATRIX:var_tot_Y +cast.FLOAT(sqrt(var_tot_Y)) +::STMT +MATRIX:select,d_r_rev,X_rev_agg +*(%*%(select,X_rev_agg),d_r_rev) +::STMT +FLOAT:n_features +LITERAL_FLOAT:1.0 +*(n_features,+(n_features,1.0)) +::STMT +MATRIX:r +LITERAL_FLOAT:9.999999999999998E-15 +*(cast.FLOAT(%*%(t(r),r)),9.999999999999998E-15) +::STMT +LITERAL_FLOAT:3.0,2001.0 +-(2001.0,3.0) +::STMT +MATRIX:X,Y,K +LITERAL_FLOAT:-1.0 ++(*(*(K,-1.0),-(X,X)),-(Y,Y)) +::STMT +MATRIX:sample_maps,X +LITERAL_FLOAT:2.0 +^(%*%(sample_maps,X),2.0) +::STMT +MATRIX:Y_prob,Y,parsertemp171380 +FLOAT:int58 +LITERAL_FLOAT:3.141592653589793,2.0 +*(*(*(rowSums(Y),Y_prob),Y_prob),^(*(+(int58,parsertemp171380),3.141592653589793),2.0)) +::STMT +MATRIX:y_corr +LITERAL_FLOAT:0.5 +-(y_corr,0.5) +::STMT +LITERAL_FLOAT:0.15000000000000002 +0.15000000000000002 +::STMT +FLOAT:parsertemp31330 +LITERAL_FLOAT:10000.0 +*(parsertemp31330,10000.0) +::STMT +MATRIX:A +/(*(cast.FLOAT(A),cast.FLOAT(A)),*(cast.FLOAT(A),cast.FLOAT(A))) +::STMT +MATRIX:W +FLOAT:int461,parsertemp65,parsertemp66,int339,wt +LITERAL_FLOAT:3.0,4.0 +*(*(*(-(wt,int339),-(wt,int461)),-(sum(W),3.0)),^(sqrt(/(parsertemp65,parsertemp66)),4.0)) +::STMT +FLOAT:int495,x +LITERAL_FLOAT:1.0 +/(1.0,+(1.0,exp(*(x,int495)))) +::STMT +MATRIX:distances,ksmall,parsertemp557211 +LITERAL_FLOAT:0.0 +*(<=(distances,ksmall),==(diag(parsertemp557211),0.0)) +::STMT +MATRIX:parsertemp410979,W,X,parsertemp410981 +FLOAT:eps +/(X,+(%*%(W,/(parsertemp410979,parsertemp410981)),eps)) +::STMT +MATRIX:Xtest_dists +FLOAT:eps +LITERAL_FLOAT:0.0 +rowSums(*(<=(Xtest_dists,eps),<(0.0,Xtest_dists))) +::STMT +MATRIX:_sbcvar11 +LITERAL_FLOAT:1000.0 +/(%*%(rowSums(_sbcvar11),colSums(_sbcvar11)),1000.0) +::STMT +MATRIX:resp,mean,X,weight +FLOAT:int164 +LITERAL_FLOAT:2.0 +-(/(%*%(t(resp),^(X,int164)),t(weight)),*(2.0,^(mean,2.0))) +::STMT +MATRIX:CFreqs +LITERAL_FLOAT:1.0 +-(CFreqs,1.0) +::STMT +MATRIX:parsertemp220853,parsertemp220854 +LITERAL_FLOAT:0.0,2.0,3.4011973816621555 +*(2.0,>=(-(+(parsertemp220853,parsertemp220854),3.4011973816621555),0.0)) +::STMT +MATRIX:WM +LITERAL_FLOAT:1.0 +/(sum(WM),-(sum(WM),1.0)) +::STMT +LITERAL_FLOAT:2.0,2001.0 +-(2001.0,2.0) +::STMT +MATRIX:X +FLOAT:index,int193,parsertemp129094 +LITERAL_FLOAT:2.0 ++(+(*(index,-(parsertemp129094,int193)),2.0),-(ncol(X),2.0)) +::STMT +MATRIX:t_gp,parsertemp560881,parsertemp560864,parsertemp560863,parsertemp560877 +FLOAT:int773,float843,int853 +LITERAL_FLOAT:1.0 +-(+(1.0,-(*(int773,parsertemp560863),1.0)),*(*(*(t_gp,parsertemp560877),-(parsertemp560864,int853)),exp(/(parsertemp560881,float843)))) +::STMT +FLOAT:parsertemp410218,parsertemp410219 +LITERAL_FLOAT:-1.0,50.0 +exp(/(*(-(parsertemp410218,parsertemp410219),-1.0),50.0)) +::STMT +FLOAT:rho +LITERAL_FLOAT:10000.0 +round(*(10000.0,rho)) +::STMT +FLOAT:eta,s +^(eta,s) +::STMT +MATRIX:ss_res_Y,var_tot_Y +FLOAT:df_ss_res_Y +LITERAL_FLOAT:1.0 +-(1.0,/(/(ss_res_Y,df_ss_res_Y),var_tot_Y)) +::STMT +MATRIX:tmp,parsertemp260786,parsertemp260787,parsertemp260785 +cast.FLOAT(%*%(t(-(parsertemp260787,tmp)),-(%*%(parsertemp260785,parsertemp260786),tmp))) +::STMT +MATRIX:A +LITERAL_FLOAT:1.0E-4 ++(1.0E-4,abs(t(A))) +::STMT +MATRIX:is_LT_infinite,Y_prob,Y,parsertemp171294,parsertemp171292,flip_pos,parsertemp171290 +FLOAT:float465 +*(*(Y,%*%(+(parsertemp171294,is_LT_infinite),flip_pos)),+(*(/(Y_prob,parsertemp171290),-(float465,parsertemp171292)),is_LT_infinite)) +::STMT +MATRIX:P,Y,dP +sum(&(>(P,dP),Y)) +::STMT +FLOAT:a,b,c,int863 +LITERAL_FLOAT:2.0 +sqrt(-(^(b,2.0),*(*(int863,a),c))) +::STMT +MATRIX:y_corr +FLOAT:link_power,int319 +LITERAL_FLOAT:0.0 +-(^(+(y_corr,==(y_corr,int319)),link_power),==(y_corr,0.0)) +::STMT +LITERAL_FLOAT:1.0,2.0,3.0,2003.0 +*(*(-(2003.0,2.0),+(2003.0,1.0)),+(2003.0,3.0)) +::STMT +MATRIX:g0_1,parsertemp410117 +LITERAL_FLOAT:2.0 +^(+(g0_1,t(colSums(parsertemp410117))),2.0) +::STMT +MATRIX:P,Y,dP +&(<=(P,dP),!(Y)) +::STMT +MATRIX:parsertemp274141,shift_X,Grad +LITERAL_FLOAT:2.0 +sum(^(+(%*%(parsertemp274141,Grad),%*%(shift_X,Grad)),2.0)) +::STMT +MATRIX:U,V,X +LITERAL_FLOAT:0.0 +*(!=(X,0.0),-(%*%(U,t(V)),X)) +::STMT +MATRIX:col +FLOAT:min_val,bin_width +/(-(col,min_val),bin_width) +::STMT +MATRIX:parsertemp260759,parsertemp260756,Xd +FLOAT:dd,parsertemp260753,wd +/(*(-(+(wd,parsertemp260753),sum(parsertemp260756)),-(+(wd,parsertemp260753),sum(parsertemp260756))),+(dd,sum(*(parsertemp260759,Xd)))) +::STMT +MATRIX:parsertemp254737 +FLOAT:2124_sq_root_d,parsertemp254772,parsertemp254751,float69 ++(float69,*(parsertemp254772,/(-(parsertemp254751,2124_sq_root_d),sum(parsertemp254737)))) +::STMT +MATRIX:X,Centering,ScaleFactor +FLOAT:N +LITERAL_FLOAT:1.0 +/(%*%(t(/(X,ScaleFactor)),/(-(X,Centering),ScaleFactor)),-(N,1.0)) +::STMT +FLOAT:link_power +LITERAL_FLOAT:-2.0,1.0 ++(-2.0,/(1.0,link_power)) +::STMT +LITERAL_FLOAT:0.0,1.0 +INT:int449,m +abs(rand(m,int449,0.0,1.0)) +::STMT +FLOAT:parsertemp22454,parsertemp22485 +LITERAL_FLOAT:2.0 +exp(+(parsertemp22485,*(2.0,sqrt(parsertemp22454)))) +::STMT +MATRIX:_sbcvar1007 +FLOAT:number_nans +/(number_nans,nrow(_sbcvar1007)) +::STMT +MATRIX:r,parsertemp44050 +FLOAT:norm_r2 +/(sum(*(-(r,parsertemp44050),-(r,parsertemp44050))),norm_r2) +::STMT +MATRIX:xs +LITERAL_FLOAT:1000.0,4.5 +-(1000.0,sum(>=(xs,4.5))) +::STMT +MATRIX:parsertemp397720,W1_rand,parsertemp397730,X +FLOAT:float798 +LITERAL_FLOAT:0.086386842558136 +%*%(*(0.086386842558136,W1_rand),t(/(-(X,parsertemp397720),+(parsertemp397730,float798)))) +::STMT +MATRIX:I +*(nrow(I),ncol(I)) +::STMT +MATRIX:linear_terms +FLOAT:link_power,parsertemp171228 +LITERAL_FLOAT:2.0 +/(^(linear_terms,-(/(parsertemp171228,link_power),2.0)),^(link_power,2.0)) +::STMT +MATRIX:_sbcvar96,_sbcvar95,_sbcvar98 +LITERAL_FLOAT:-1.0 +sum(*(+(%*%(_sbcvar95,_sbcvar96),-1.0),%*%(_sbcvar95,_sbcvar98))) +::STMT +MATRIX:parsertemp170136 +FLOAT:278_sq_root_d,parsertemp170150,pq_CG +LITERAL_FLOAT:0.5 +*(*(0.5,/(-(parsertemp170150,278_sq_root_d),sum(parsertemp170136))),pq_CG) +::STMT +MATRIX:V,W,parsertemp10741,H +LITERAL_FLOAT:1.0E-8 +*(H,/(%*%(t(W),V),+(%*%(parsertemp10741,H),1.0E-8))) +::STMT +FLOAT:252_Y,252_X,252_K,float711,float512,parsertemp32930,int666,parsertemp32915,float790 +LITERAL_FLOAT:1.0 +*(*(/(-(float512,252_X),-(252_X,252_X)),-(1.0,/(float790,252_X))),+(*(-(252_K,252_Y),-(int666,parsertemp32915)),*(+(parsertemp32930,252_Y),/(float711,252_X)))) +::STMT +FLOAT:int684,191_t,191_lr,int4,191_beta1,parsertemp146979 +LITERAL_FLOAT:1.0 +/(*(191_lr,sqrt(-(int684,parsertemp146979))),-(1.0,^(191_beta1,+(191_t,int4)))) +::STMT +FLOAT:rho +LITERAL_FLOAT:10000.0 +/(round(*(10000.0,rho)),10000.0) +::STMT +MATRIX:finite_linear_terms +LITERAL_FLOAT:-1.0,2.0 +/(*(^(finite_linear_terms,2.0),-1.0),2.0) +::STMT +MATRIX:ssX_newbeta +LITERAL_FLOAT:0.0 +INT:int142,int272 ++(ssX_newbeta,cast.FLOAT(rand(int142,int272,0.0,0.0))) +::STMT +MATRIX:S +LITERAL_FLOAT:2.0,799.0 +/(^(diag(S),2.0),799.0) +::STMT +MATRIX:parsertemp171314,t_gp,parsertemp171306 +FLOAT:float653 +LITERAL_FLOAT:1.0,0.25,0.254829592 +*(0.25,*(/(1.0,+(float653,parsertemp171306)),+(0.254829592,*(t_gp,parsertemp171314)))) +::STMT +FLOAT:num_hidden1,m +sqrt(+(m,num_hidden1)) +::STMT +MATRIX:parsertemp410988,parsertemp410979,parsertemp410990,parsertemp410981 +FLOAT:parsertemp410999 +-(sum(%*%(/(parsertemp410988,parsertemp410990),/(parsertemp410979,parsertemp410981))),parsertemp410999) +::STMT +MATRIX:d,parsertemp410054 +FLOAT:r2 +/(r2,sum(*(d,t(parsertemp410054)))) +::STMT +MATRIX:E,parsertemp22269 +FLOAT:int373,q +LITERAL_FLOAT:10000.0 +sqrt(/(sum(/(parsertemp22269,E)),*(10000.0,-(q,int373)))) +::STMT +FLOAT:beg +LITERAL_FLOAT:1.0,512.0 +-(+(beg,512.0),1.0) +::STMT +MATRIX:parsertemp220863,parsertemp220864,H,betamax,Hneg,beta,Hpos +FLOAT:float727 +LITERAL_FLOAT:0.0,1.0E20 +*(*(>=(-(H,float727),0.0),!=(+(parsertemp220863,parsertemp220864),1.0E20)),+(beta,+(*(Hpos,betamax),*(Hneg,beta)))) +::STMT +MATRIX:w,out +LITERAL_FLOAT:1.0,0.5,0.001 +*(0.001,+(*(0.5,cast.FLOAT(out)),*(1.0,cast.FLOAT(w)))) +::STMT +MATRIX:F +-(F,/(%*%(rowSums(F),colSums(F)),sum(F))) +::STMT +MATRIX:rowSums_X_sq +LITERAL_FLOAT:5.0 +/(5.0,max(sqrt(rowSums_X_sq))) +::STMT +MATRIX:Ileft,Iright,ig +FLOAT:min_leaf +*(&(>=(rowSums(Ileft),min_leaf),>=(rowSums(Iright),min_leaf)),ig) +::STMT +FLOAT:c +LITERAL_FLOAT:-1.0,2.0 +*(*(2.0,c),-1.0) +::STMT +MATRIX:maxscub +FLOAT:parsertemp31797 +LITERAL_FLOAT:-Infinity +|(>=(maxscub,parsertemp31797),==(maxscub,-Infinity)) +::STMT +MATRIX:vars +FLOAT:dispersion +*(dispersion,colSums(vars)) +::STMT +MATRIX:parsertemp410245,parsertemp410247 +LITERAL_FLOAT:-1.0,1.0,2.0,1.5 +^(/(*(parsertemp410245,-1.0),*(2.0,exp(parsertemp410247))),/(1.0,1.5)) +::STMT +FLOAT:e,mu +LITERAL_FLOAT:0.999,4.0 +/(-(0.999,mu),-(4.0,e)) +::STMT +LITERAL_FLOAT:105.0,1.0 +*(105.0,1.0) +::STMT +LITERAL_FLOAT:1.0,10000.0 +-(10000.0,1.0) +::STMT +MATRIX:parsertemp2781,Xd,parsertemp2785 +FLOAT:dd,step_sz,wd +/(-(+(wd,*(step_sz,dd)),sum(*(parsertemp2781,Xd))),+(dd,sum(*(parsertemp2785,Xd)))) +::STMT +MATRIX:W2_rand +LITERAL_FLOAT:0.2656844656620286 +*(0.2656844656620286,W2_rand) +::STMT +MATRIX:linear_terms +FLOAT:link_power +LITERAL_FLOAT:2.0 +^(linear_terms,/(2.0,link_power)) +::STMT +MATRIX:252_X,252_K +LITERAL_FLOAT:0.0 +*(-(0.0,cast.FLOAT(252_K)),-(cast.FLOAT(252_X),cast.FLOAT(252_X))) +::STMT +MATRIX:ytest,yhat +FLOAT:parsertemp115806,n +LITERAL_FLOAT:2.0 +-(sum(^(-(ytest,yhat),2.0)),*(nrow(ytest),^(/(parsertemp115806,n),2.0))) +::STMT +MATRIX:parsertemp31265,WM,CMeans +LITERAL_FLOAT:2.0 +^(-(CMeans,/(sum(parsertemp31265),sum(WM))),2.0) +::STMT +FLOAT:log_ten,float83,parsertemp169813 +LITERAL_FLOAT:4.0 +*(log_ten,-(4.0,round(-(parsertemp169813,float83)))) +::STMT +MATRIX:X +FLOAT:i ++(i,ncol(X)) +::STMT +MATRIX:parsertemp410978,H,parsertemp410980 +t(rowSums(/(*(H,parsertemp410978),t(parsertemp410980)))) +::STMT +MATRIX:residual_matrix +LITERAL_FLOAT:0.0 ++(nrow(residual_matrix),0.0) +::STMT +MATRIX:X_plane,parsertemp11251 +FLOAT:int665 +LITERAL_FLOAT:0.0 +rowSums(*(>(X_plane,0.0),t(^(int665,parsertemp11251)))) +::STMT +MATRIX:parsertemp178161,M +colSums(exp(-(M,parsertemp178161))) +::STMT +MATRIX:W +LITERAL_FLOAT:2.0 +-(sum(round(W)),2.0) +::STMT +MATRIX:r,d,Hd +FLOAT:r2,c +LITERAL_FLOAT:0.0 ++(-(0.0,+(r,*(c,Hd))),*(/(cast.FLOAT(r),r2),d)) +::STMT +LITERAL_FLOAT:2.0,0.5,-0.5 +INT:int121,int493 +^(rand(int121,int493,-0.5,0.5),2.0) +::STMT +MATRIX:W +LITERAL_FLOAT:3.0 +-(sum(round(W)),3.0) +::STMT +MATRIX:trees_M_offset +LITERAL_FLOAT:1.0 +-(cast.FLOAT(trees_M_offset),1.0) +::STMT +MATRIX:dataFrame,constraintsFrame +*(nrow(dataFrame),nrow(constraintsFrame)) +::STMT +MATRIX:S,parsertemp382904,V,W,row_nonzeros +FLOAT:reg ++(%*%(*(W,%*%(S,parsertemp382904)),V),*(*(reg,S),row_nonzeros)) +::STMT +MATRIX:oldX +LITERAL_FLOAT:1.0 ++(nrow(oldX),1.0) +::STMT +MATRIX:parsertemp10964,C +LITERAL_FLOAT:100.0 +/(sum(==(parsertemp10964,C)),100.0) +::STMT +MATRIX:obj,gs,parsertemp44066 +FLOAT:float664,int191,parsertemp44077,int394 +LITERAL_FLOAT:-0.5 +/(-(cast.FLOAT(obj),+(*(float664,parsertemp44077),*(int191,int394))),*(-0.5,-(cast.FLOAT(gs),cast.FLOAT(parsertemp44066)))) +::STMT +MATRIX:W +LITERAL_FLOAT:1.0 +-(sum(round(W)),1.0) +::STMT +MATRIX:_sbcvar179,_sbcvar182,237_CFreqs +FLOAT:int842 +LITERAL_FLOAT:10000.0 +/(sum(*(+(237_CFreqs,int842),%*%(_sbcvar179,_sbcvar182))),-(10000.0,nrow(_sbcvar179))) +::STMT +MATRIX:p,z +FLOAT:pp,parsertemp169870,pz +LITERAL_FLOAT:-1.0 ++(*(sum(*(p,z)),-1.0),sqrt(-(*(pz,pz),*(pp,parsertemp169870)))) +::STMT +MATRIX:parsertemp31782,err,parsertemp31769,parsertemp31768,cCnts,parsertemp31780 +FLOAT:minSup,int606 +-(sum(&(>=(cCnts,minSup),>(err,int606))),sum(&(&(parsertemp31768,parsertemp31769),|(parsertemp31780,parsertemp31782)))) +::STMT +MATRIX:V,y +LITERAL_FLOAT:0.0 +-(0.0,%*%(t(V),y)) +::STMT +MATRIX:Y,predicted_Y +LITERAL_FLOAT:0.0 +sum(==(-(predicted_Y,Y),0.0)) +::STMT +FLOAT:n_stratum_cols,n_group_cols +LITERAL_FLOAT:2.0 ++(+(2.0,n_group_cols),n_stratum_cols) +::STMT +FLOAT:sigma,alpha +LITERAL_FLOAT:0.5 +*(*(0.5,sigma),alpha) +::STMT +MATRIX:rowSums_X_sq +LITERAL_FLOAT:8.674675786448736 +/(8.674675786448736,max(sqrt(rowSums_X_sq))) +::STMT +MATRIX:G,authorities +max(%*%(t(G),%*%(G,authorities))) +::STMT +MATRIX:indexWithInGroups,selectedMatrix +rowSums(*(indexWithInGroups,selectedMatrix)) +::STMT +MATRIX:in_m_neighbor_value +FLOAT:in_i_k_min +LITERAL_FLOAT:1.0 ++(-(ncol(in_m_neighbor_value),in_i_k_min),1.0) +::STMT +MATRIX:parsertemp386440,parsertemp386441 +FLOAT:minPts +LITERAL_FLOAT:1.0 +>=(+(rowSums(*(parsertemp386440,parsertemp386441)),1.0),minPts) +::STMT +MATRIX:solution,X +*(-(X,solution),-(X,solution)) +::STMT +MATRIX:qLow,length,qUp +LITERAL_FLOAT:2.0 +<(rowSums(|(<(length,qLow),>(length,qUp))),2.0) +::STMT +MATRIX:C,parsertemp11014 +LITERAL_FLOAT:1000.0 +/(sum(==(parsertemp11014,C)),1000.0) +::STMT +MATRIX:parsertemp2832 +==(round(parsertemp2832),max(round(parsertemp2832))) +::STMT +MATRIX:parsertemp410081,d_r_rev,parsertemp410090 +FLOAT:o +LITERAL_FLOAT:-1.0 +-(+(*(cast.FLOAT(parsertemp410081),-1.0),sum(*(d_r_rev,parsertemp410090))),o) +::STMT +MATRIX:CFreqs1,present_domain_vals_mat,CMeans +FLOAT:my +LITERAL_FLOAT:2.0 +sum(*(%*%(present_domain_vals_mat,CFreqs1),^(-(CMeans,my),2.0))) +::STMT +MATRIX:linear_terms +FLOAT:link_power +LITERAL_FLOAT:-1.0 +*(^(linear_terms,/(-1.0,link_power)),-1.0) +::STMT +MATRIX:parsertemp437190,X,weight +LITERAL_FLOAT:2.0 +*(2.0,^(/(%*%(parsertemp437190,X),t(weight)),2.0)) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:1.0,2.0 +/(/(1.0,linear_terms),2.0) +::STMT +FLOAT:int252,int543 +INT:int84,int477 +diag(rand(int84,int477,int252,int543)) +::STMT +MATRIX:A,B,C,X +%*%(<=(%*%(X,A),B),C) +::STMT +MATRIX:r,d,parsertemp43999 +cast.FLOAT(/(sum(*(r,r)),%*%(t(d),+(d,parsertemp43999)))) +::STMT +MATRIX:2814_K,2814_X +LITERAL_FLOAT:0.0 +*(cast.FLOAT(-(0.0,2814_K)),-(cast.FLOAT(2814_X),cast.FLOAT(2814_X))) +::STMT +MATRIX:posSamples,posSampleMeans +LITERAL_FLOAT:2.0,7000.0 +-(colSums(^(posSamples,2.0)),*(7000.0,^(posSampleMeans,2.0))) +::STMT +MATRIX:mu +FLOAT:q +LITERAL_FLOAT:4.0 +-(q,*(4.0,*(cast.FLOAT(mu),cast.FLOAT(mu)))) +::STMT +MATRIX:linear_terms,Y +LITERAL_FLOAT:-1.0 +*(^(linear_terms,-1.0),-(Y,linear_terms)) +::STMT +MATRIX:U,X,parsertemp382850 +LITERAL_FLOAT:0.0 +%*%(t(U),*(!=(X,0.0),-(%*%(U,parsertemp382850),X))) +::STMT +LITERAL_FLOAT:0.0,1.0 +INT:int379,parsertemp12177 +rand(parsertemp12177,int379,0.0,1.0) +::STMT +MATRIX:parsertemp553122,missing +t(%*%(rowSums(*(missing,missing)),parsertemp553122)) +::STMT +MATRIX:parsertemp171314,t_gp,parsertemp171318,parsertemp171306 +FLOAT:int866,float62 +LITERAL_FLOAT:1.0,2.0,0.254829592 +*(exp(/(*(parsertemp171318,int866),2.0)),*(/(1.0,+(float62,parsertemp171306)),+(0.254829592,*(t_gp,parsertemp171314)))) +::STMT +MATRIX:grad +FLOAT:psi +*(psi,sqrt(sum(*(grad,grad)))) +::STMT +MATRIX:dX,v,X +FLOAT:lr,mu ++(X,-(*(mu,v),*(lr,dX))) +::STMT +MATRIX:R,parsertemp40219 +FLOAT:numRows,level +/(numRows,-(R,rowSums(==(parsertemp40219,level)))) +::STMT +MATRIX:d_r,parsertemp409781 +cast.FLOAT(%*%(t(rev(d_r)),parsertemp409781)) +::STMT +MATRIX:287_x,287_y +LITERAL_FLOAT:2.0 +/(+(cast.FLOAT(287_x),cast.FLOAT(287_y)),2.0) +::STMT +MATRIX:aggr_best_index_vector +LITERAL_FLOAT:0.0,1.0 ++(sum(==(aggr_best_index_vector,0.0)),1.0) +::STMT +MATRIX:id +FLOAT:parsertemp22683 +cast.FLOAT(diag(diag(==(id,parsertemp22683)))) +::STMT +MATRIX:w,X,y +*(-(%*%(X,w),y),-(%*%(X,w),y)) +::STMT +LITERAL_FLOAT:2.0 +INT:int554,int716 +rand(int716,int554,2.0,2.0) +::STMT +LITERAL_FLOAT:0.0 +INT:int87,int416 +rand(int87,int416,0.0,0.0) +::STMT +FLOAT:window_size,parsertemp180776,n +LITERAL_FLOAT:1.0 +-(+(-(n,window_size),1.0),+(parsertemp180776,1.0)) +::STMT +MATRIX:outSize +LITERAL_FLOAT:0.0 +cast.FLOAT(>(outSize,0.0)) +::STMT +MATRIX:P,I,X2 +LITERAL_FLOAT:0.0 +==(*(t(%*%(X2,P)),I),0.0) +::STMT +MATRIX:P,I +LITERAL_FLOAT:0.0 +==(%*%(P,I),0.0) +::STMT +MATRIX:R,dssp +FLOAT:4_n +/(4_n,+(R,dssp)) +::STMT +MATRIX:X +LITERAL_FLOAT:4.0 +>(X,4.0) +::STMT +MATRIX:parsertemp146930,184_unnorm_probs,184_probs,parsertemp146928,183_dpred,184_scores +FLOAT:int466,parsertemp146927 +-(*(*(*(parsertemp146927,parsertemp146928),/(int466,parsertemp146930)),/(exp(184_scores),rowSums(184_unnorm_probs))),*(/(exp(184_scores),rowSums(184_unnorm_probs)),rowSums(*(183_dpred,184_probs)))) +::STMT +MATRIX:Y +cast.MATRIX(max(Y)) +::STMT +FLOAT:run_index +LITERAL_FLOAT:1.0,48.0 ++(*(48.0,-(run_index,1.0)),1.0) +::STMT +LITERAL_FLOAT:1.0E-7 +INT:int267,m +rand(m,int267,1.0E-7,1.0E-7) +::STMT +MATRIX:F +LITERAL_FLOAT:2.0 +/(t(colSums(F)),2.0) +::STMT +FLOAT:sum_y_test,n +LITERAL_FLOAT:2.0 +^(/(sum_y_test,n),2.0) +::STMT +MATRIX:x,y +LITERAL_FLOAT:2.0 +/(+(x,y),2.0) +::STMT +MATRIX:gXY +FLOAT:lambda,parsertemp171602,beta +LITERAL_FLOAT:2.0 +sum(^(+(*(parsertemp171602,gXY),*(lambda,beta)),2.0)) +::STMT +MATRIX:X_plane +LITERAL_FLOAT:0.0 +>(X_plane,0.0) +::STMT +MATRIX:cumLens +FLOAT:i +LITERAL_FLOAT:1.0 +/(-(i,1.0),cast.FLOAT(cumLens)) +::STMT +MATRIX:err,cCnts +FLOAT:minSup +LITERAL_FLOAT:0.0 +|(<(cCnts,minSup),==(err,0.0)) +::STMT +MATRIX:parsertemp220845,ZERODIAG +LITERAL_FLOAT:1.0E-12 ++(rowSums(*(exp(parsertemp220845),ZERODIAG)),1.0E-12) +::STMT +MATRIX:parsertemp11509 +LITERAL_FLOAT:2.0 +*(2.0,parsertemp11509) +::STMT +MATRIX:intercept +LITERAL_FLOAT:0.0 +*(0.0,intercept) +::STMT +MATRIX:parsertemp171083 +LITERAL_FLOAT:-2.0 +*(-2.0,parsertemp171083) +::STMT +MATRIX:shift_X +FLOAT:lambda,p_CG,parsertemp170060,temp_CG +*(+(+(*(lambda,p_CG),*(parsertemp170060,temp_CG)),*(cast.FLOAT(shift_X),cast.FLOAT(temp_CG))),sum(p_CG)) +::STMT +MATRIX:cumHistMul,offset +cast.FLOAT(<=(offset,cumHistMul)) +::STMT +MATRIX:P,Y,Z,ZERODIAG,parsertemp220891 +FLOAT:int631,parsertemp220894 +%*%(*(-(P,/(Z,parsertemp220894)),*(/(int631,parsertemp220891),ZERODIAG)),Y) +::STMT +MATRIX:X,MSE +LITERAL_FLOAT:2.0 +/(^(max(X),2.0),MSE) +::STMT +MATRIX:parsertemp10744,V,W,H,parsertemp10748 +FLOAT:Eps +/(%*%(V,t(*(H,parsertemp10744))),+(%*%(W,%*%(H,parsertemp10748)),Eps)) +::STMT +MATRIX:parsertemp460641 +LITERAL_FLOAT:0.282842712474619 +*(parsertemp460641,0.282842712474619) +::STMT +MATRIX:P,gradients,Phi_new,Theta +FLOAT:alpha ++(Phi_new,*(alpha,%*%(t(gradients),%*%(P,Theta)))) +::STMT +MATRIX:xs +FLOAT:252_x +LITERAL_FLOAT:10.0 +-(10.0,sum(>=(xs,252_x))) +::STMT +MATRIX:Yhat_prime,H3_prime,E,W4 +*(H3_prime,%*%(*(E,Yhat_prime),W4)) +::STMT +MATRIX:means,parsertemp560530 +LITERAL_FLOAT:5.0 +/(sum(<(*(means,parsertemp560530),5.0)),*(nrow(means),ncol(means))) +::STMT +MATRIX:79_77_X_row_norm,parsertemp17178,parsertemp17180,Y_block,parsertemp17170,79_77_Y_row_norm,X_block +LITERAL_FLOAT:0.9 +*(>(/(%*%(X_block,parsertemp17180),%*%(79_77_X_row_norm,parsertemp17178)),0.9),/(%*%(X_block,t(Y_block)),%*%(+(79_77_X_row_norm,parsertemp17170),t(79_77_Y_row_norm)))) +::STMT +MATRIX:tmp,w,out +LITERAL_FLOAT:1.0,0.5 ++(*(0.5,cast.FLOAT(%*%(out,out))),*(1.0,cast.FLOAT(%*%(w,tmp)))) +::STMT +MATRIX:confusionM +min(rowSums(confusionM)) +::STMT +MATRIX:parsertemp175056,316_scores,X +-(/(exp(-(X,parsertemp175056)),rowSums(exp(316_scores))),/(exp(-(X,parsertemp175056)),rowSums(exp(316_scores)))) +::STMT +FLOAT:m2,float885,wt +LITERAL_FLOAT:5.0 +*(5.0,sqrt(/(*(m2,wt),-(wt,float885)))) +::STMT +MATRIX:validKeyMask +cast.FLOAT(colSums(validKeyMask)) +::STMT +MATRIX:classes +LITERAL_FLOAT:1.0,0.8 +*(cast.FLOAT(classes),-(1.0,0.8)) +::STMT +MATRIX:U,V,X +-(%*%(U,t(V)),X) +::STMT +MATRIX:termination_bitmap,parsertemp222665 +LITERAL_FLOAT:10.0 ++(*(10.0,max(*(parsertemp222665,termination_bitmap))),10.0) +::STMT +MATRIX:sv,s,w,X,Y,out +FLOAT:lambda,step_sz +-(%*%(t(X),*(*(sv,out),Y)),*(lambda,+(w,*(step_sz,s)))) +::STMT +MATRIX:parsertemp195898 +FLOAT:parsertemp195895,factor_up +LITERAL_FLOAT:1.0 +-(1.0,abs(-(/(parsertemp195898,factor_up),/(parsertemp195895,factor_up)))) +::STMT +FLOAT:p_CG,parsertemp170088,z,pp_CG,parsertemp170090 +LITERAL_FLOAT:-1.0 +/(-(*(*(z,p_CG),-1.0),sqrt(-(parsertemp170088,parsertemp170090))),pp_CG) +::STMT +MATRIX:parsertemp31115,parsertemp31108 +FLOAT:parsertemp31116,parsertemp31109 +LITERAL_FLOAT:1500.0,2000.0 +sqrt(+(/(/(parsertemp31108,parsertemp31109),2000.0),/(/(parsertemp31115,parsertemp31116),1500.0))) +::STMT +MATRIX:t,parsertemp171083,parsertemp171092 +FLOAT:float141 +LITERAL_FLOAT:1.0,1.432788 ++(1.0,*(sqrt(*(float141,parsertemp171083)),+(1.432788,*(t,parsertemp171092)))) +::STMT +MATRIX:y +FLOAT:beta +LITERAL_FLOAT:2.0,100.0 +/(sum(^(-(beta,y),2.0)),100.0) +::STMT +MATRIX:p,q,V,parsertemp1939 +FLOAT:norm_r2,eps +*(/(norm_r2,cast.FLOAT(%*%(parsertemp1939,q))),+(%*%(t(V),%*%(V,p)),*(eps,p))) +::STMT +MATRIX:surv,se_surv +FLOAT:z_alpha_2 +/(*(z_alpha_2,se_surv),surv) +::STMT +MATRIX:rowSums_X_sq +LITERAL_FLOAT:3.5355339059327378 +/(3.5355339059327378,max(sqrt(rowSums_X_sq))) +::STMT +MATRIX:s,w,wnew +LITERAL_FLOAT:0.5 +*(0.5,cast.FLOAT(%*%(t(wnew),+(w,s)))) +::STMT +LITERAL_FLOAT:1.0E20 +INT:int563,n +rand(n,int563,1.0E20,1.0E20) +::STMT +FLOAT:prob_true,prob_false +LITERAL_FLOAT:2.0 ++(^(prob_true,2.0),^(prob_false,2.0)) +::STMT +MATRIX:R,dsep,dssm +FLOAT:2_eAvg +/(/(+(R,dsep),-(R,dssm)),2_eAvg) +::STMT +MATRIX:is_too_small,parsertemp171346,the_exp_exp,linear_terms,the_exp +FLOAT:int95,int146,int568,int902,int805 +LITERAL_FLOAT:1.0,1.0E7 ++(/(*(-(int805,is_too_small),-(int902,the_exp_exp)),+(exp(linear_terms),==(parsertemp171346,int568))),*(==(+(int146,the_exp),1.0E7),-(1.0,/(the_exp,int95)))) +::STMT +MATRIX:T_1,parsertemp410245,event,parsertemp410248 +FLOAT:int916,float628 +LITERAL_FLOAT:1.0,1.5 +/(^(/(*(parsertemp410245,int916),*(float628,parsertemp410248)),/(1.0,1.5)),/(-(max(T_1),min(T_1)),sum(event))) +::STMT +FLOAT:obj,objnew +/(abs(-(objnew,obj)),obj) +::STMT +FLOAT:padw,padh,Hin,Win +LITERAL_FLOAT:2.0 +*(+(Hin,*(2.0,padh)),+(Win,*(2.0,padw))) +::STMT +MATRIX:LHSthreshold +LITERAL_FLOAT:1.0 +>(LHSthreshold,1.0) +::STMT +MATRIX:2707_X,2706_dX +LITERAL_FLOAT:0.0 +colSums(*(>(2707_X,0.0),2706_dX)) +::STMT +MATRIX:parsertemp220853,parsertemp220854,Hneg,beta,betamin,Hpos +FLOAT:logU +LITERAL_FLOAT:0.0 +*(<(-(+(parsertemp220853,parsertemp220854),logU),0.0),+(beta,+(*(Hneg,betamin),*(Hpos,beta)))) +::STMT +MATRIX:linear_terms,Y +LITERAL_FLOAT:0.0 +*(^(exp(linear_terms),0.0),-(Y,exp(linear_terms))) +::STMT +FLOAT:R,eta,s +LITERAL_FLOAT:-1.0 +*(R,^(eta,*(s,-1.0))) +::STMT +FLOAT:sig,q,parsertemp181039,int284 +LITERAL_FLOAT:1.0,8.0 +*(8.0,-(1.0,/(-(q,parsertemp181039),*(int284,sig)))) +::STMT +MATRIX:Y,parsertemp283552 +-(sum(Y),parsertemp283552) +::STMT +MATRIX:newbeta,lambda +LITERAL_FLOAT:2.0 +%*%(t(lambda),^(newbeta,2.0)) +::STMT +LITERAL_FLOAT:10.0,1.5,-8.0 +*(1.5,^(10.0,-8.0)) +::STMT +MATRIX:Train,2342_m_colmax,2342_m_colmin +LITERAL_FLOAT:2.0 +/(*(2.0,-(Train,2342_m_colmin)),-(2342_m_colmax,2342_m_colmin)) +::STMT +MATRIX:parsertemp143446,parsertemp143445 +&(parsertemp143445,parsertemp143446) +::STMT +MATRIX:X_batch,dout1 +FLOAT:191_beta1 +LITERAL_FLOAT:1.0 +*(-(1.0,191_beta1),%*%(t(X_batch),dout1)) +::STMT +MATRIX:std,rad +-(rad,cast.FLOAT(std)) +::STMT +MATRIX:parsertemp171315,parsertemp171307,parsertemp171319 +FLOAT:float489,float311,float639 +LITERAL_FLOAT:2.0 +-(2.0,*(exp(/(parsertemp171319,float489)),*(/(float311,parsertemp171307),+(float639,parsertemp171315)))) +::STMT +MATRIX:y_corr +LITERAL_FLOAT:0.5 +>(y_corr,0.5) +::STMT +MATRIX:s,sts,d,parsertemp44023 +FLOAT:delta2 +LITERAL_FLOAT:2.0 ++(^(%*%(t(s),d),2.0),*(cast.FLOAT(%*%(parsertemp44023,d)),-(delta2,cast.FLOAT(sts)))) +::STMT +MATRIX:t_gp,parsertemp560881,parsertemp560864,parsertemp560863,parsertemp560877 +FLOAT:int551,int310,float761 +LITERAL_FLOAT:1.0 ++(-(1.0,-(*(int310,parsertemp560863),1.0)),*(*(*(t_gp,parsertemp560877),-(parsertemp560864,int551)),exp(/(parsertemp560881,float761)))) +::STMT +MATRIX:parsertemp43620,y +FLOAT:float213 +LITERAL_FLOAT:1.0 +*(-(/(1.0,+(float213,parsertemp43620)),1.0),y) +::STMT +MATRIX:X_plane,parsertemp11251 +LITERAL_FLOAT:0.0,2.0 +*(>(X_plane,0.0),t(^(2.0,parsertemp11251))) +::STMT +MATRIX:p,parsertemp285529,g +FLOAT:pp,pq,int41,pz,parsertemp285521,parsertemp285537 +*(+(+(*(parsertemp285537,pq),sum(parsertemp285529)),sum(*(g,p))),/(+(*(pz,int41),sqrt(parsertemp285521)),pp)) +::STMT +MATRIX:W1_rand +FLOAT:num_hidden1,m +LITERAL_FLOAT:6.0 +*(/(sqrt(6.0),sqrt(+(m,num_hidden1))),W1_rand) +::STMT +FLOAT:int584,m2,float284 +LITERAL_FLOAT:2003.0 +sqrt(*(/(2003.0,-(int584,float284)),m2)) +::STMT +LITERAL_FLOAT:1.0E-7 +1.0E-7 +::STMT +MATRIX:parsertemp27746,parsertemp27872 +FLOAT:featureCorrection +LITERAL_FLOAT:0.0 ++(%*%(parsertemp27872,t(parsertemp27746)),-(0.0,featureCorrection)) +::STMT +MATRIX:scale_X,parsertemp429910 +LITERAL_FLOAT:300.0,0.0 +*(-(0.0,/(t(parsertemp429910),300.0)),scale_X) +::STMT +MATRIX:parsertemp79022 +LITERAL_FLOAT:0.5,1270.0 +round(+(0.5,/(parsertemp79022,1270.0))) +::STMT +MATRIX:prec_chol,X +LITERAL_FLOAT:2.0 +%*%(^(X,2.0),t(^(prec_chol,2.0))) +::STMT +MATRIX:t_gp,pt_gp,parsertemp171320,Y,the_gauss_exp,parsertemp171316 +LITERAL_FLOAT:2.0,0.25,0.15915494309189535 +/(*(*(exp(parsertemp171320),0.15915494309189535),rowSums(Y)),*(*(0.25,*(t_gp,parsertemp171316)),-(2.0,*(the_gauss_exp,pt_gp)))) +::STMT +MATRIX:cumHistMul,offset,parsertemp132495,histMul,outBucket +LITERAL_FLOAT:1.0 +-(-(offset,%*%(==(outBucket,parsertemp132495),-(cumHistMul,histMul))),1.0) +::STMT +MATRIX:parsertemp1904,y +LITERAL_FLOAT:-1.0 +sum(*(*(%*%(parsertemp1904,y),-1.0),*(%*%(parsertemp1904,y),-1.0))) +::STMT +MATRIX:y +FLOAT:beta +LITERAL_FLOAT:2.0,10.0 +/(sum(^(-(beta,y),2.0)),10.0) +::STMT +FLOAT:i,k +LITERAL_FLOAT:2.0,4.0 +-(+(+(i,k),4.0),2.0) +::STMT +MATRIX:X +FLOAT:M +/(ncol(X),M) +::STMT +MATRIX:X +LITERAL_FLOAT:200.0 +/(t(colSums(X)),200.0) +::STMT +FLOAT:s,num_groups +LITERAL_FLOAT:1.0 +*(-(s,1.0),-(num_groups,1.0)) +::STMT +MATRIX:id +==(id,cast.FLOAT(id)) +::STMT +MATRIX:R,svLowBnd +>(R,cast.FLOAT(svLowBnd)) +::STMT +MATRIX:X +LITERAL_FLOAT:300.0 +/(t(colSums(X)),300.0) +::STMT +FLOAT:s +LITERAL_FLOAT:-1.0,50.0,3.0 +*(50.0,^(3.0,*(s,-1.0))) +::STMT +FLOAT:var,arch_coef,xt,var_coef,int838,a0 ++(+(a0,*(arch_coef,^(xt,int838))),*(var_coef,var)) +::STMT +MATRIX:parsertemp171318 +FLOAT:int267,one_over_sqrt_two_pi +LITERAL_FLOAT:2.0 +*(exp(/(*(parsertemp171318,int267),2.0)),^(one_over_sqrt_two_pi,2.0)) +::STMT +MATRIX:ssX_V,X,parsertemp150463,P_1K +%*%(rowSums(*(P_1K,%*%(X,ssX_V))),parsertemp150463) +::STMT +MATRIX:sv,out +LITERAL_FLOAT:2.0,0.5 +*(0.5,sum(^(*(sv,out),2.0))) +::STMT +MATRIX:probs,y_batch +LITERAL_FLOAT:0.0,1.0,1.0E-10 +*(*(/(1.0,nrow(y_batch)),-(0.0,y_batch)),/(1.0,+(probs,1.0E-10))) +::STMT +FLOAT:i,cols,n +LITERAL_FLOAT:1.0 +-(n,-(+(i,cols),1.0)) +::STMT +MATRIX:parsertemp222331 +LITERAL_FLOAT:200.0,0.5 ++(0.5,/(parsertemp222331,200.0)) +::STMT +LITERAL_FLOAT:1.0,2.0,2000.0 +-(^(2000.0,2.0),1.0) +::STMT +MATRIX:parsertemp175083 +LITERAL_FLOAT:1.0E-6 +cast.MATRIX(sum(<(abs(parsertemp175083),1.0E-6))) +::STMT +MATRIX:Y,linear_terms +LITERAL_FLOAT:0.0 +-(Y,*(rowSums(Y),>=(linear_terms,0.0))) +::STMT +MATRIX:parsertemp44079 +FLOAT:C +LITERAL_FLOAT:-1.0 +*(C,sum(*(parsertemp44079,-1.0))) +::STMT +MATRIX:y_corr +LITERAL_FLOAT:0.0,1.0 +-(1.0,<=(y_corr,0.0)) +::STMT +FLOAT:qmle,var_t,int653,xq_t,parsertemp496694,n +LITERAL_FLOAT:1.0 +-(qmle,*(/(1.0,*(int653,n)),+(parsertemp496694,/(xq_t,var_t)))) +::STMT +MATRIX:b4,parsertemp389338 +LITERAL_FLOAT:2.0 +exp(*(2.0,t(+(parsertemp389338,b4)))) +::STMT +MATRIX:parsertemp397828,parsertemp397825,W3_rand +LITERAL_FLOAT:0.5107539184552492 +t(%*%(*(0.5107539184552492,W3_rand),t(/(parsertemp397825,parsertemp397828)))) +::STMT +MATRIX:wnew,parsertemp44111 +LITERAL_FLOAT:2.0 +sqrt(sum(^(+(wnew,parsertemp44111),2.0))) +::STMT +MATRIX:_sbcvar2306 +LITERAL_FLOAT:1.0 ++(max(t(_sbcvar2306)),1.0) +::STMT +MATRIX:simplex +LITERAL_FLOAT:2.0 +*(2.0,/(-(rowSums(simplex),simplex),nrow(simplex))) +::STMT +MATRIX:W1_rand,stds,parsertemp394896 +LITERAL_FLOAT:0.08146881698903526 +t(%*%(*(0.08146881698903526,W1_rand),t(/(parsertemp394896,stds)))) +::STMT +MATRIX:V,y +%*%(t(V),y) +::STMT +MATRIX:is_natural_parameter_log_zero,Y +LITERAL_FLOAT:0.0,1.0 +-(1.0,*(>(Y,0.0),is_natural_parameter_log_zero)) +::STMT +FLOAT:int143,o_init,int524,o +LITERAL_FLOAT:-1.0,50.0 +/(*(-(*(int524,o_init),*(int143,o)),-1.0),50.0) +::STMT +MATRIX:U,V_sum +/(*(U,U),sum(V_sum)) +::STMT +FLOAT:parsertemp565893,h,y_offset +LITERAL_FLOAT:1.0 +-(+(+(parsertemp565893,y_offset),h),1.0) +::STMT +LITERAL_FLOAT:0.054717579189018505 +0.054717579189018505 +::STMT +MATRIX:X_batch,dout1,mW1 +FLOAT:191_beta1 +LITERAL_FLOAT:1.0 ++(*(191_beta1,mW1),*(-(1.0,191_beta1),%*%(t(X_batch),dout1))) +::STMT +MATRIX:X_batch,parsertemp389606,parsertemp389591,2364_2361_Y,parsertemp389588,W4 +FLOAT:int318 +LITERAL_FLOAT:1.0,2.0 +*(-(1.0,^(/(parsertemp389588,parsertemp389591),2.0)),%*%(*(-(2364_2361_Y,X_batch),-(int318,parsertemp389606)),W4)) +::STMT +MATRIX:d_r_rev,Hd_1,Hd_2 +t(colSums(*(-(Hd_1,Hd_2),d_r_rev))) +::STMT +MATRIX:I,parsertemp472360 +LITERAL_FLOAT:0.0 +*(I,==(!=(*(parsertemp472360,I),0.0),0.0)) +::STMT +LITERAL_FLOAT:1.0,0.8 +-(1.0,-(1.0,0.8)) +::STMT +MATRIX:parsertemp222700,parsertemp222697,parsertemp222694 +FLOAT:int857 +t(<=(+(*(int857,parsertemp222694),t(parsertemp222697)),parsertemp222700)) +::STMT +FLOAT:int227,429_C +LITERAL_FLOAT:1.0,2.0 +sqrt(/(2.0,*(*(429_C,int227),1.0))) +::STMT +MATRIX:finite_linear_terms +LITERAL_FLOAT:-1.0,2.0 +*(^(finite_linear_terms,2.0),-1.0) +::STMT +MATRIX:X,Y,K +FLOAT:x +LITERAL_FLOAT:1.0 +*(-(*(K,-(X,X)),-(Y,Y)),-(1.0,/(-(x,X),-(X,X)))) +::STMT +MATRIX:V +-(max(V),min(V)) +::STMT +FLOAT:2690_Hin +LITERAL_FLOAT:0.0,2.0 ++(2690_Hin,*(2.0,0.0)) +::STMT +MATRIX:parsertemp386457,parsertemp386459,neighbors,parsertemp386455 +LITERAL_FLOAT:0.0 +==(-(*(*(neighbors,parsertemp386455),parsertemp386457),parsertemp386459),0.0) +::STMT +MATRIX:grad +FLOAT:int396,int927 +sqrt(sum(*(*(grad,int927),*(grad,int396)))) +::STMT +MATRIX:residuals_vector +FLOAT:lambda +/(sum(residuals_vector),+(nrow(residuals_vector),lambda)) +::STMT +MATRIX:g0_2,g0_1,g0 +LITERAL_FLOAT:1.0E-12 +*(cast.FLOAT(%*%(t(g0),+(g0_1,g0_2))),1.0E-12) +::STMT +MATRIX:Yhat_prime,E +colSums(*(E,Yhat_prime)) +::STMT +FLOAT:m2 +LITERAL_FLOAT:1.0005002501250626 +*(1.0005002501250626,m2) +::STMT +FLOAT:sim_score_left,sim_score_right,sim_score_parent +-(+(sim_score_left,sim_score_right),sim_score_parent) +::STMT +MATRIX:samples_vs_runs_map,X_samples_sq_norms,parsertemp222439,parsertemp222443,X_samples +LITERAL_FLOAT:2.0 +-(+(X_samples_sq_norms,%*%(samples_vs_runs_map,rowSums(parsertemp222439))),*(2.0,rowSums(*(X_samples,parsertemp222443)))) +::STMT +MATRIX:parsertemp500609,parsertemp500606,parsertemp500604,X,y +FLOAT:int564 +-(%*%(X,*(*(parsertemp500604,parsertemp500606),>(parsertemp500609,int564))),y) +::STMT +FLOAT:window_size,i,k +LITERAL_FLOAT:2.0 +-(+(+(i,k),window_size),2.0) +::STMT +LITERAL_FLOAT:4.890349128221754 +4.890349128221754 +::STMT +MATRIX:negSampleMeans,negSamples +FLOAT:int877,int492 +LITERAL_FLOAT:1.0,150.0 +/(-(colSums(^(negSamples,int492)),*(150.0,^(negSampleMeans,int877))),-(150.0,1.0)) +::STMT +MATRIX:y_val,preds +%*%(t(-(y_val,preds)),-(y_val,preds)) +::STMT +MATRIX:A +abs(t(A)) +::STMT +FLOAT:check_max,check_min +LITERAL_FLOAT:2.0 +/(2.0,-(check_max,check_min)) +::STMT +FLOAT:idx +LITERAL_FLOAT:1.0,253.0 ++(-(253.0,idx),1.0) +::STMT +MATRIX:ZtZ,Xm,parsertemp265719,parsertemp265718,parsertemp265714 +LITERAL_FLOAT:2.0 +-(+(sum(*(Xm,Xm)),trace(*(ZtZ,parsertemp265714))),*(2.0,sum(%*%(parsertemp265718,parsertemp265719)))) +::STMT +MATRIX:W +FLOAT:int573,parsertemp97,int148,m4,int722,wt,int371 +LITERAL_FLOAT:1.0 +-(*(*(^(wt,int371),+(wt,int148)),m4),*(*(*(int722,parsertemp97),^(wt,int573)),-(sum(W),1.0))) +::STMT +MATRIX:r_CG,p_CG +FLOAT:rr_CG,old_rr_CG +LITERAL_FLOAT:-1.0 ++(*(r_CG,-1.0),*(/(rr_CG,old_rr_CG),p_CG)) +::STMT +FLOAT:int153,float879,float406,int53 +LITERAL_FLOAT:1.0,3.0,6.0,2003.0 +/(*(*(6.0,2003.0),-(2003.0,1.0)),*(*(-(int153,float879),+(int53,float406)),+(2003.0,3.0))) +::STMT +FLOAT:429_C +LITERAL_FLOAT:1.0,2.0 +/(2.0,*(*(429_C,1.0),1.0)) +::STMT +MATRIX:S,V,parsertemp149285 +FLOAT:int503,delta2 +LITERAL_FLOAT:2.0 ++(^(sum(*(S,V)),2.0),*(sum(^(V,int503)),-(delta2,sum(parsertemp149285)))) +::STMT +MATRIX:p,q,A +FLOAT:norm_r2 +*(/(norm_r2,sum(*(p,q))),%*%(t(A),%*%(A,p))) +::STMT +MATRIX:r,Hd +FLOAT:parsertemp44049 +sum(*(-(r,*(parsertemp44049,Hd)),-(r,*(parsertemp44049,Hd)))) +::STMT +LITERAL_FLOAT:0.0,1.0 +INT:int633,int424 ++(1.0,exp(rand(int633,int424,0.0,0.0))) +::STMT +MATRIX:s,d,tau ++(s,*(cast.FLOAT(tau),d)) +::STMT +MATRIX:leaf_ids +FLOAT:boundary_left,step_size +&(>=(leaf_ids,boundary_left),<(leaf_ids,+(boundary_left,step_size))) +::STMT +MATRIX:P,Q,Y,Z,ZERODIAG +*(Y,rowSums(*(-(P,Q),*(Z,ZERODIAG)))) +::STMT +MATRIX:B,X,y +-(y,%*%(X,B)) +::STMT +MATRIX:s,d +FLOAT:norm_r2,alpha_deno +%*%(t(+(s,*(norm_r2,d))),+(s,*(/(norm_r2,alpha_deno),d))) +::STMT +MATRIX:parsertemp437192,parsertemp437191,parsertemp437237,mean,weight,avgMean +FLOAT:int874 +LITERAL_FLOAT:1.0E-9 ++(+(-(/(parsertemp437237,parsertemp437192),*(int874,avgMean)),/(*(mean,parsertemp437191),t(weight))),1.0E-9) +::STMT +MATRIX:W,X,H,parsertemp411105,parsertemp411107 +LITERAL_FLOAT:1.0E-8 +/(%*%(X,t(*(H,parsertemp411105))),+(%*%(W,%*%(H,parsertemp411107)),1.0E-8)) +::STMT +FLOAT:m2 +LITERAL_FLOAT:5.0,1.0005 +*(5.0,sqrt(*(1.0005,m2))) +::STMT +MATRIX:parsertemp129186,parsertemp129185,key_unique,key +==(%*%(key_unique,parsertemp129185),%*%(parsertemp129186,t(key))) +::STMT +MATRIX:hubs +LITERAL_FLOAT:2.0 +abs(sum(^(-(hubs,hubs),2.0))) +::STMT +MATRIX:P,N_T,X,parsertemp230442 +<=(rowSums(*(X,parsertemp230442)),%*%(P,t(N_T))) +::STMT +MATRIX:R,parsertemp497774 +LITERAL_FLOAT:0.0 +-(ncol(R),sum(==(colSums(parsertemp497774),0.0))) +::STMT +MATRIX:A +FLOAT:parsertemp22359,a21,parsertemp22358,int923 +LITERAL_FLOAT:1.0 +sqrt(+(+(+(parsertemp22358,parsertemp22359),/(int923,a21)),/(1.0,cast.FLOAT(A)))) +::STMT +LITERAL_FLOAT:8.660254037844387 +8.660254037844387 +::STMT +MATRIX:y +FLOAT:beta +LITERAL_FLOAT:2.0 +^(-(beta,y),2.0) +::STMT +MATRIX:D,parsertemp570375,classMeans +%*%(-(D,classMeans),parsertemp570375) +::STMT +FLOAT:481_Hf,481_Hin +LITERAL_FLOAT:0.0,2.0 +-(+(481_Hin,*(2.0,0.0)),481_Hf) +::STMT +MATRIX:parsertemp10964,C +sum(==(parsertemp10964,C)) +::STMT +MATRIX:parsertemp146940,184_dtemp,mW3,outr2 +FLOAT:beta1 +LITERAL_FLOAT:1.0 ++(*(beta1,mW3),*(-(1.0,beta1),%*%(t(outr2),-(184_dtemp,parsertemp146940)))) +::STMT +MATRIX:G,authorities +max(%*%(G,authorities)) +::STMT +MATRIX:nI +LITERAL_FLOAT:0.25 +*(0.25,ncol(nI)) +::STMT +FLOAT:int455,int456,o_init,N,o +LITERAL_FLOAT:-1.0 +/(*(-(*(int456,o_init),*(int455,o)),-1.0),N) +::STMT +MATRIX:confusionM +min(colSums(confusionM)) +::STMT +MATRIX:parsertemp383011,X,X_nonzero_ind +LITERAL_FLOAT:2.0 +sum(*(X_nonzero_ind,^(-(X,parsertemp383011),2.0))) +::STMT +MATRIX:parsertemp498248,m_iter_err_sum,m_err +FLOAT:int526,i_process_item +LITERAL_FLOAT:2.0 +*(*(2.0,/(-(int526,parsertemp498248),i_process_item)),+(colSums(m_err),m_iter_err_sum)) +::STMT +MATRIX:std,sts,rad +FLOAT:delta2 +/(-(delta2,sts),+(std,rad)) +::STMT +MATRIX:_sbcvar1708 +LITERAL_FLOAT:105.0 ++(105.0,nrow(_sbcvar1708)) +::STMT +MATRIX:parsertemp414375,parsertemp414377,parsertemp414379 +FLOAT:int577,int293 +LITERAL_FLOAT:0.0,1.0,199.0 +*(/(-(t(parsertemp414375),*(int577,parsertemp414377)),199.0),-(1.0,<=(/(parsertemp414379,int293),0.0))) +::STMT +MATRIX:maskNAN +LITERAL_FLOAT:0.0 +!=(rowSums(maskNAN),0.0) +::STMT +MATRIX:parsertemp171083 +LITERAL_FLOAT:-2.0,-1.0 +*(sqrt(*(-2.0,parsertemp171083)),-1.0) +::STMT +MATRIX:parsertemp170248,parsertemp170253,parsertemp170240,lt_pos_neg +FLOAT:float811,float257,float69 ++(lt_pos_neg,*(*(-(float257,lt_pos_neg),exp(parsertemp170253)),*(/(float811,parsertemp170240),+(float69,parsertemp170248)))) +::STMT +MATRIX:prec_chol,X,mu +FLOAT:int69 +%*%(X,t(*(mu,^(prec_chol,int69)))) +::STMT +MATRIX:parsertemp13624,_sbcvar11 +FLOAT:int171 +LITERAL_FLOAT:2.0,1000.0 +/(^(-(_sbcvar11,/(parsertemp13624,int171)),2.0),/(%*%(rowSums(_sbcvar11),colSums(_sbcvar11)),1000.0)) +::STMT +MATRIX:r,Hd +FLOAT:parsertemp44049 +LITERAL_FLOAT:2.0 +sum(^(-(r,*(parsertemp44049,Hd)),2.0)) +::STMT +MATRIX:tmp_Xw,parsertemp260747,Y,Xw +LITERAL_FLOAT:0.0,1.0 +*(-(1.0,*(Y,+(Xw,parsertemp260747))),>(-(1.0,*(Y,tmp_Xw)),0.0)) +::STMT +MATRIX:out2,parsertemp146942,184_dscores,maskd1,out1,W2 +FLOAT:p,int336 +LITERAL_FLOAT:0.0 +*(*(>(out1,0.0),/(maskd1,p)),%*%(*(>(out2,int336),%*%(184_dscores,parsertemp146942)),t(W2))) +::STMT +MATRIX:is_LT_infinite,parsertemp171366,p_one_m_one +LITERAL_FLOAT:3.141592653589793,1.0,0.5 +*(+(0.5,/(%*%(parsertemp171366,p_one_m_one),3.141592653589793)),-(1.0,rowSums(is_LT_infinite))) +::STMT +MATRIX:parsertemp231012 +FLOAT:parsertemp231013 +LITERAL_FLOAT:1.0,2.0 +-(1.0,sum(^(/(parsertemp231012,parsertemp231013),2.0))) +::STMT +MATRIX:V,y +LITERAL_FLOAT:0.0,2.0 +^(-(0.0,%*%(t(V),y)),2.0) +::STMT +MATRIX:c,x_r +LITERAL_FLOAT:2.0 +-(*(2.0,x_r),c) +::STMT +MATRIX:X +FLOAT:int758 +LITERAL_FLOAT:1.0 +/(1.0,+(1.0,exp(*(X,int758)))) +::STMT +MATRIX:vW1,W1,dW1 +FLOAT:2727_mu,2727_lr +LITERAL_FLOAT:1.0 ++(-(W1,*(2727_mu,vW1)),*(+(1.0,2727_mu),-(*(2727_mu,vW1),*(2727_lr,dW1)))) +::STMT +MATRIX:W +FLOAT:m2,wt,float491 +/(sqrt(/(*(m2,wt),-(wt,float491))),sqrt(sum(round(W)))) +::STMT +MATRIX:P,Q +LITERAL_FLOAT:-2.0 ++(*(-2.0,%*%(P,t(Q))),P) +::STMT +MATRIX:X,y +FLOAT:float984,float563 +LITERAL_FLOAT:-1.0 +INT:int154,int667 +exp(*(*(y,-1.0),%*%(X,rand(int154,int667,float984,float563)))) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,1024.0 +-(+(i,1024.0),1.0) +::STMT +MATRIX:y +LITERAL_FLOAT:1.0 +/(1.0,nrow(y)) +::STMT +MATRIX:X +*(nrow(X),ncol(X)) +::STMT +MATRIX:_sbcvar78 +LITERAL_FLOAT:10000.0 +-(_sbcvar78,/(%*%(rowSums(_sbcvar78),colSums(_sbcvar78)),10000.0)) +::STMT +MATRIX:parsertemp43619 +LITERAL_FLOAT:1.0 +-(1.0,/(1.0,+(1.0,exp(parsertemp43619)))) +::STMT +MATRIX:parsertemp383012,parsertemp383020,parsertemp383017,X_nonzero_ind +FLOAT:reg,int800 ++(sum(*(X_nonzero_ind,^(parsertemp383012,int800))),*(reg,+(sum(parsertemp383017),sum(parsertemp383020)))) +::STMT +MATRIX:parsertemp400673,W4_rand +FLOAT:int116,int619 +LITERAL_FLOAT:0.08720414403938946 +%*%(*(0.08720414403938946,W4_rand),t(/(-(parsertemp400673,int619),+(parsertemp400673,int116)))) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,1048.0 +-(+(i,1048.0),1.0) +::STMT +MATRIX:parsertemp570381,parsertemp570372,parsertemp570376,parsertemp570377 +FLOAT:int431,int433,int633,int645 ++(parsertemp570381,-(*(/(int433,int431),parsertemp570372),*(/(int633,int645),%*%(parsertemp570376,parsertemp570377)))) +::STMT +MATRIX:parsertemp389580,parsertemp389562,parsertemp389565,2365_delta3,W2,W3 +FLOAT:int926 +LITERAL_FLOAT:1.0,2.0 +*(-(1.0,^(/(parsertemp389562,parsertemp389565),2.0)),%*%(*(-(int926,parsertemp389580),%*%(2365_delta3,W3)),W2)) +::STMT +MATRIX:Y,parsertemp2798,Xw +FLOAT:int92 +LITERAL_FLOAT:0.0,1.0,2.0 +^(*(>(-(int92,parsertemp2798),0.0),-(1.0,*(Y,Xw))),2.0) +::STMT +MATRIX:s,d,alpha +t(-(s,*(cast.FLOAT(alpha),d))) +::STMT +MATRIX:parsertemp31189,parsertemp31187 +FLOAT:int226,int613 +LITERAL_FLOAT:1.0,2.0,7000.0 +/(^(/(-(parsertemp31187,parsertemp31189),-(int226,int613)),2.0),*(^(7000.0,2.0),-(7000.0,1.0))) +::STMT +MATRIX:col,less_than_lb,parsertemp24102,parsertemp24103 +FLOAT:int760,num_bins,int226 +LITERAL_FLOAT:1.0 ++(*(-(-(int226,less_than_lb),>(col,num_bins)),+(round(parsertemp24102),1.0)),*(>(+(parsertemp24103,int760),num_bins),num_bins)) +::STMT +FLOAT:m2Y,sigmaX,covXY,parsertemp26584 +/(covXY,*(sigmaX,sqrt(*(m2Y,parsertemp26584)))) +::STMT +MATRIX:g,parsertemp169907 +sqrt(sum(*(+(g,parsertemp169907),+(g,parsertemp169907)))) +::STMT +MATRIX:2814_K,2814_X,2814_Y +FLOAT:int302 ++(*(cast.FLOAT(-(int302,2814_K)),-(cast.FLOAT(2814_X),cast.FLOAT(2814_X))),-(cast.FLOAT(2814_Y),cast.FLOAT(2814_Y))) +::STMT +MATRIX:Y +cast.MATRIX(min(Y)) +::STMT +MATRIX:tmp_Xw,parsertemp260749,Y +FLOAT:int438 +LITERAL_FLOAT:0.0,1.0 +*(*(-(1.0,*(Y,tmp_Xw)),>(-(int438,parsertemp260749),0.0)),Y) +::STMT +MATRIX:parsertemp31732,parsertemp31734,dssm,dsem +FLOAT:5_eAvg +LITERAL_FLOAT:1.0 +-(/(/(-(parsertemp31734,dsem),-(parsertemp31732,dssm)),5_eAvg),1.0) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,133.0 +*(133.0,-(i,1.0)) +::STMT +MATRIX:out2,parsertemp146942,184_dscores,outd1 +FLOAT:beta1,int17 +LITERAL_FLOAT:1.0 +*(-(1.0,beta1),%*%(t(outd1),*(>(out2,int17),%*%(184_dscores,parsertemp146942)))) +::STMT +MATRIX:_sbcvar92 +LITERAL_FLOAT:0.0 +==(/(%*%(rowSums(_sbcvar92),colSums(_sbcvar92)),sum(_sbcvar92)),0.0) +::STMT +MATRIX:parsertemp382672,parsertemp382681,parsertemp382668,parsertemp382678 +FLOAT:reg +LITERAL_FLOAT:0.5 ++(*(0.5,sum(*(parsertemp382668,parsertemp382672))),*(*(0.5,reg),+(sum(parsertemp382678),sum(parsertemp382681)))) +::STMT +MATRIX:intercept +FLOAT:int172,int470 +INT:num_records,int150 +%*%(rand(num_records,int150,int172,int470),intercept) +::STMT +MATRIX:neighbors +FLOAT:eps +<=(-(neighbors,diag(diag(neighbors))),eps) +::STMT +MATRIX:R,w +INT:parsertemp31673,int63 ++(R,diag(rand(parsertemp31673,int63,cast.FLOAT(w),cast.FLOAT(w)))) +::STMT +MATRIX:240_elt,240_ones_ctg +%*%(rowSums(240_elt),t(240_ones_ctg)) +::STMT +MATRIX:p +FLOAT:eps +*(eps,p) +::STMT +MATRIX:sample_rec_ids +FLOAT:num_records +LITERAL_FLOAT:1.0 +*(+(num_records,1.0),-(1.0,<=(sample_rec_ids,num_records))) +::STMT +MATRIX:s,parsertemp44005,d +FLOAT:parsertemp44004 +cast.FLOAT(%*%(t(+(s,parsertemp44005)),+(s,*(parsertemp44004,d)))) +::STMT +MATRIX:X_batch,2365_delta2,W2,parsertemp389567 +FLOAT:int376 +%*%(t(*(-(int376,parsertemp389567),%*%(2365_delta2,W2))),X_batch) +::STMT +MATRIX:A,b +LITERAL_FLOAT:-1.0 +*(%*%(*(t(A),-1.0),b),-1.0) +::STMT +MATRIX:X,mu,precisions +LITERAL_FLOAT:2.0 +*(2.0,%*%(X,t(*(mu,precisions)))) +::STMT +FLOAT:var_power,link_power +LITERAL_FLOAT:2.0 +-(/(-(2.0,var_power),link_power),2.0) +::STMT +MATRIX:Y,the_exp +FLOAT:int14 +-(*(rowSums(Y),exp(-(int14,the_exp))),Y) +::STMT +MATRIX:cumHistMul,offset +<=(offset,cumHistMul) +::STMT +FLOAT:current_hash_value +LITERAL_FLOAT:1.0,33.0 +-(33.0,+(current_hash_value,1.0)) +::STMT +MATRIX:F,parsertemp27458 +FLOAT:W +LITERAL_FLOAT:0.0,1.0E-4 ++(*(==(/(parsertemp27458,W),0.0),1.0E-4),/(%*%(rowSums(F),colSums(F)),sum(F))) +::STMT +MATRIX:D,parsertemp570375,classMeans +LITERAL_FLOAT:0.5 +*(0.5,%*%(%*%(-(D,classMeans),parsertemp570375),t(-(D,classMeans)))) +::STMT +MATRIX:parsertemp393571,W3_rand,parsertemp393574 +LITERAL_FLOAT:0.128920512778062 +t(%*%(*(0.128920512778062,W3_rand),t(/(parsertemp393571,parsertemp393574)))) +::STMT +MATRIX:_sbcvar1716 +LITERAL_FLOAT:120.0 ++(120.0,nrow(_sbcvar1716)) +::STMT +MATRIX:negSampleMeans,negSamples +LITERAL_FLOAT:2.0,150.0 +-(colSums(^(negSamples,2.0)),*(150.0,^(negSampleMeans,2.0))) +::STMT +MATRIX:Mask1 +LITERAL_FLOAT:0.0 +>(colSums(Mask1),0.0) +::STMT +MATRIX:X,y +LITERAL_FLOAT:0.0,1.0 +INT:int942,m +-(%*%(X,rand(m,int942,0.0,1.0)),y) +::STMT +MATRIX:MDx,MUx,MLx ++(+(MUx,MDx),MLx) +::STMT +FLOAT:ssPrev,parsertemp265727,parsertemp265726 +LITERAL_FLOAT:1.0 +abs(-(1.0,/(/(parsertemp265726,parsertemp265727),ssPrev))) +::STMT +MATRIX:ytest +LITERAL_FLOAT:2.0 +^(/(sum(ytest),nrow(ytest)),2.0) +::STMT +MATRIX:means,Y_counts,parsertemp560529 +LITERAL_FLOAT:1.0 +sum(<(*(means,%*%(Y_counts,parsertemp560529)),1.0)) +::STMT +MATRIX:t,parsertemp171088,parsertemp171083,parsertemp171094 +FLOAT:float707 +LITERAL_FLOAT:0.0,1.0,2.515517 ++(-(0.0,sqrt(*(float707,parsertemp171083))),/(+(2.515517,*(t,parsertemp171088)),+(1.0,*(t,parsertemp171094)))) +::STMT +LITERAL_FLOAT:2000.0 +sqrt(2000.0) +::STMT +MATRIX:parsertemp31024,parsertemp31022 +FLOAT:int777 +LITERAL_FLOAT:1.0,100.0 +/(/(-(colSums(parsertemp31022),*(int777,parsertemp31024)),-(100.0,1.0)),100.0) +::STMT +MATRIX:r,d,parsertemp43999 +cast.FLOAT(/(sum(*(r,r)),%*%(t(d),+(d,parsertemp43999)))) +::STMT +MATRIX:C,parsertemp265706,parsertemp265704,Z,XtZ +FLOAT:ss,ZtZ_sum +trace(*(+(%*%(parsertemp265704,Z),*(parsertemp265706,ss)),%*%(t(C),/(XtZ,ZtZ_sum)))) +::STMT +FLOAT:sample_frac +LITERAL_FLOAT:0.0,1.0 +INT:parsertemp553005,int999 +<=(rand(parsertemp553005,int999,0.0,1.0),sample_frac) +::STMT +MATRIX:classFeatureCounts +FLOAT:laplaceCorrection ++(classFeatureCounts,laplaceCorrection) +::STMT +MATRIX:U,row_nonzeros +LITERAL_FLOAT:2.0 +*(^(U,2.0),row_nonzeros) +::STMT +MATRIX:184_probs,183_dpred,parsertemp146939,outr2 +FLOAT:beta1 +LITERAL_FLOAT:1.0 +*(-(1.0,beta1),%*%(t(outr2),-(*(183_dpred,184_probs),*(184_probs,parsertemp146939)))) +::STMT +MATRIX:X,y +LITERAL_FLOAT:0.0,-1.0 +INT:int578,int705 +*(*(y,-1.0),%*%(X,rand(int578,int705,0.0,0.0))) +::STMT +FLOAT:int625 +LITERAL_FLOAT:-1.0 +INT:int426,int191 ++(diag(rand(int426,int191,-1.0,-1.0)),int625) +::STMT +MATRIX:Bxu,Bxd +LITERAL_FLOAT:2.0 +diag(*(2.0,+(Bxd,Bxu))) +::STMT +MATRIX:45_CVars,45_CFreqs +FLOAT:float192,int474,parsertemp13703,int43,int766 +LITERAL_FLOAT:1.0,1000.0 +/(sum(*(-(45_CFreqs,int43),45_CVars)),*(-(1000.0,1.0),/(*(parsertemp13703,int766),-(int474,float192)))) +::STMT +MATRIX:d,parsertemp43996,parsertemp43997 +FLOAT:C +%*%(t(d),+(d,*(C,%*%(parsertemp43996,parsertemp43997)))) +::STMT +MATRIX:p,parsertemp1936,parsertemp1937 +FLOAT:norm_r2 +/(norm_r2,cast.FLOAT(%*%(t(p),+(parsertemp1936,parsertemp1937)))) +::STMT +MATRIX:parsertemp11741 +LITERAL_FLOAT:1.0 ++(1.0,parsertemp11741) +::STMT +MATRIX:distance_matrix,parsertemp447763,upper_triangle ++(+(distance_matrix,t(upper_triangle)),diag(parsertemp447763)) +::STMT +MATRIX:s,parsertemp44016 +FLOAT:delta2 +-(delta2,cast.FLOAT(%*%(t(s),-(s,parsertemp44016)))) +::STMT +LITERAL_FLOAT:2.225E-307 +2.225E-307 +::STMT +MATRIX:col,less_than_lb,parsertemp24102,parsertemp24103 +FLOAT:int918,num_bins,int391 +LITERAL_FLOAT:1.0 ++(*(-(-(int391,less_than_lb),>(col,num_bins)),+(round(parsertemp24102),1.0)),*(>(+(parsertemp24103,int918),num_bins),num_bins)) +::STMT +MATRIX:A,B ++(ncol(A),ncol(B)) +::STMT +FLOAT:log_l,new_log_l +LITERAL_FLOAT:1.0E-14 +*(+(abs(log_l),abs(new_log_l)),1.0E-14) +::STMT +LITERAL_FLOAT:1.0,50.0 +*(50.0,1.0) +::STMT +MATRIX:X +abs(X) +::STMT +FLOAT:step +LITERAL_FLOAT:0.95 +*(step,0.95) +::STMT +MATRIX:parsertemp415351,ytest +FLOAT:parsertemp415362,n +LITERAL_FLOAT:1.0 +sqrt(/(-(sum(parsertemp415351),*(n,parsertemp415362)),-(nrow(ytest),1.0))) +::STMT +MATRIX:t_gp,parsertemp170245,parsertemp170239 +FLOAT:float726 +LITERAL_FLOAT:1.0,-0.284496736,0.254829592 ++(0.254829592,*(/(1.0,+(float726,parsertemp170239)),+(-0.284496736,*(t_gp,parsertemp170245)))) +::STMT +FLOAT:parsertemp557354,prob_true +LITERAL_FLOAT:0.6931471805599453 +/(*(prob_true,parsertemp557354),0.6931471805599453) +::STMT +FLOAT:num_records,i +LITERAL_FLOAT:1.0 +*(num_records,-(i,1.0)) +::STMT +FLOAT:num_min,num_max ++(num_min,num_max) +::STMT +FLOAT:i,n +LITERAL_FLOAT:1.0,253.0 +-(n,-(+(i,253.0),1.0)) +::STMT +MATRIX:tmp,Y +LITERAL_FLOAT:0.0 +>(1-*(Y,tmp),0.0) +::STMT +MATRIX:b,X,sb +exp(%*%(X,+(b,sb))) +::STMT +MATRIX:parsertemp436668,X,parsertemp436672 +LITERAL_FLOAT:1.0,2.0 +INT:int254,parsertemp436666 +-(*(rand(int254,parsertemp436666,1.0,1.0),t(rowSums(parsertemp436668))),*(2.0,%*%(X,t(parsertemp436672)))) +::STMT +FLOAT:i,n +LITERAL_FLOAT:1.0,256.0 +-(n,-(+(i,256.0),1.0)) +::STMT +LITERAL_FLOAT:4.0 +INT:int785,int18 +rand(int18,int785,4.0,4.0) +::STMT +MATRIX:parsertemp115858,X,parsertemp115862,parsertemp115860 +FLOAT:parsertemp115863,n +LITERAL_FLOAT:0.0,1.0 +*(/(-(t(parsertemp115858),*(n,parsertemp115860)),-(nrow(X),1.0)),-(1.0,<=(/(parsertemp115862,parsertemp115863),0.0))) +::STMT +MATRIX:obj,objnew,gs +cast.FLOAT(-(-(objnew,obj),gs)) +::STMT +MATRIX:determinants +FLOAT:nFeats +LITERAL_FLOAT:6.283185307179586 +*(^(6.283185307179586,nFeats),determinants) +::STMT +MATRIX:R +LITERAL_FLOAT:1.0,2.0 +INT:parsertemp500303,int480 +%*%(rowSums(^(R,2.0)),rand(int480,parsertemp500303,1.0,1.0)) +::STMT +MATRIX:lambda,g,beta +*(+(g,*(lambda,beta)),+(g,*(lambda,beta))) +::STMT +MATRIX:s,d,alpha +FLOAT:parsertemp44004 +%*%(t(+(s,*(parsertemp44004,d))),+(s,*(cast.FLOAT(alpha),d))) +::STMT +MATRIX:parsertemp31023,parsertemp31025,parsertemp31030,parsertemp31032 +FLOAT:int254,int53,int315,int955 +LITERAL_FLOAT:150.0,100.0 ++(/(/(-(parsertemp31023,parsertemp31025),-(int955,int254)),100.0),/(/(-(parsertemp31030,parsertemp31032),-(int53,int315)),150.0)) +::STMT +MATRIX:parsertemp146940,184_dtemp,outr2 +FLOAT:beta2 +LITERAL_FLOAT:1.0,2.0 +*(-(1.0,beta2),^(%*%(t(outr2),-(184_dtemp,parsertemp146940)),2.0)) +::STMT +MATRIX:parsertemp40086,addedE,addedX2 +/(t(%*%(t(addedE),addedX2)),t(parsertemp40086)) +::STMT +MATRIX:lambda,parsertemp170067,scale_X,parsertemp170065,p_CG ++(*(cast.FLOAT(lambda),cast.FLOAT(p_CG)),*(cast.FLOAT(diag(scale_X)),cast.FLOAT(%*%(parsertemp170065,parsertemp170067)))) +::STMT +MATRIX:e +LITERAL_FLOAT:4.0 +*(4.0,e) +::STMT +FLOAT:m2X +LITERAL_FLOAT:1.0002795638803466 +sqrt(*(m2X,1.0002795638803466)) +::STMT +FLOAT:parsertemp170147,parsertemp170145,p_CG,z +LITERAL_FLOAT:-1.0,2.0 +/(+(*(*(z,p_CG),-1.0),sqrt(-(parsertemp170145,parsertemp170147))),sum(^(p_CG,2.0))) +::STMT +MATRIX:P,minD,D,X +%*%(t(/(<=(D,minD),rowSums(P))),X) +::STMT +MATRIX:present_domain_vals_mat,parsertemp27485 +FLOAT:my +LITERAL_FLOAT:2.0 +^(-(%*%(present_domain_vals_mat,parsertemp27485),my),2.0) +::STMT +MATRIX:D +LITERAL_FLOAT:1.0 +/(1.0,+(D,1.0)) +::STMT +MATRIX:Y +FLOAT:bernoulli_No_label +LITERAL_FLOAT:1.0 +-(1.0,==(Y,bernoulli_No_label)) +::STMT +FLOAT:window_size,k,n +LITERAL_FLOAT:2.0 +-(+(-(n,window_size),2.0),k) +::STMT +MATRIX:tmp,w,out +LITERAL_FLOAT:0.5 ++(*(0.5,cast.FLOAT(%*%(out,out))),*(0.5,cast.FLOAT(%*%(w,tmp)))) +::STMT +MATRIX:flip_neg,is_LT_infinite,Y,parsertemp171294 +rowSums(*(Y,%*%(+(parsertemp171294,is_LT_infinite),flip_neg))) +::STMT +MATRIX:lambda,B,S +LITERAL_FLOAT:2.0 +*(lambda,^(+(B,S),2.0)) +::STMT +FLOAT:n_components,cov_param,n_features +LITERAL_FLOAT:1.0 +-(+(+(cov_param,*(n_features,n_components)),n_components),1.0) +::STMT +MATRIX:X +LITERAL_FLOAT:0.0 +*(-(0.0,sum(X)),-(0.0,sum(X))) +::STMT +FLOAT:sample_block_size,num_samples +LITERAL_FLOAT:1.0 +-(*(sample_block_size,num_samples),1.0) +::STMT +MATRIX:R,parsertemp500359 +LITERAL_FLOAT:2.0 +%*%(rowSums(^(R,2.0)),parsertemp500359) +::STMT +MATRIX:intercept,X,beta +FLOAT:int198,int797 +INT:num_records,int979 ++(%*%(X,beta),%*%(rand(num_records,int979,int797,int198),intercept)) +::STMT +MATRIX:finite_linear_terms +LITERAL_FLOAT:0.0 +exp(-(0.0,exp(finite_linear_terms))) +::STMT +MATRIX:output_values +FLOAT:log_odds,learning_rate +LITERAL_FLOAT:2.7182818284 +^(2.7182818284,+(log_odds,*(learning_rate,sum(output_values)))) +::STMT +MATRIX:log_prob,X +FLOAT:parsertemp436712 +LITERAL_FLOAT:-0.5 +*(-0.5,+(*(ncol(X),parsertemp436712),log_prob)) +::STMT +MATRIX:parsertemp174552 +LITERAL_FLOAT:0.0 +sum(abs(==(parsertemp174552,0.0))) +::STMT +MATRIX:dl_matrix +FLOAT:cost ++(cast.FLOAT(dl_matrix),cost) +::STMT +MATRIX:C,Xm,parsertemp265701 +%*%(t(Xm),%*%(Xm,%*%(C,parsertemp265701))) +::STMT +LITERAL_FLOAT:6.0 +sqrt(6.0) +::STMT +MATRIX:X +LITERAL_FLOAT:-1.0 +exp(*(X,-1.0)) +::STMT +MATRIX:parsertemp72202,subspace_idx +LITERAL_FLOAT:1.0 +diag(/(1.0,<(-(subspace_idx,parsertemp72202),1.0))) +::STMT +MATRIX:y_corr +FLOAT:link_power +LITERAL_FLOAT:0.0 +^(+(y_corr,==(y_corr,0.0)),link_power) +::STMT +MATRIX:prec_chol,X,parsertemp436696,bc_matrix,parsertemp436692 +FLOAT:int149 +LITERAL_FLOAT:2.0 ++(-(*(bc_matrix,t(parsertemp436692)),*(2.0,%*%(X,parsertemp436696))),%*%(rowSums(*(X,X)),t(^(prec_chol,int149)))) +::STMT +MATRIX:mean,X,weight,parsertemp437211,parsertemp437629 ++(/(%*%(t(parsertemp437211),-(X,mean)),cast.FLOAT(weight)),diag(parsertemp437629)) +::STMT +MATRIX:parsertemp170158,parsertemp170136 +FLOAT:r_CG,g_reg,278_sq_root_d,z,parsertemp170171,parsertemp170150 +LITERAL_FLOAT:0.5 ++(*(0.5,*(cast.FLOAT(z),+(r_CG,g_reg))),*(+(+(parsertemp170171,z),sum(parsertemp170158)),/(-(parsertemp170150,278_sq_root_d),sum(parsertemp170136)))) +::STMT +MATRIX:E,O +/(*(sum(-(O,E)),sum(-(O,E))),sum(E)) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:2.0 +exp(*(linear_terms,2.0)) +::STMT +MATRIX:S,addedX2 +FLOAT:level +rowSums(==(%*%(S,t(addedX2)),level)) +::STMT +MATRIX:parsertemp410080,d_r_rev,parsertemp410079,parsertemp410090 +LITERAL_FLOAT:-1.0 ++(*(cast.FLOAT(%*%(parsertemp410079,parsertemp410080)),-1.0),cast.FLOAT(%*%(t(d_r_rev),parsertemp410090))) +::STMT +FLOAT:i,subvector_size +LITERAL_FLOAT:1.0 ++(*(-(i,1.0),subvector_size),1.0) +::STMT +MATRIX:Y_prob +/(Y_prob,rowSums(Y_prob)) +::STMT +MATRIX:scale_X,p_CG +*(cast.FLOAT(diag(scale_X)),p_CG) +::STMT +MATRIX:prob +FLOAT:threshold +LITERAL_FLOAT:0.0 +==(>(prob,threshold),0.0) +::STMT +MATRIX:288_left,291_d,288_right +LITERAL_FLOAT:0.0,2.0 ++(/(^(sum(288_left),2.0),+(sum(291_d),0.0)),/(^(sum(288_right),2.0),+(sum(291_d),0.0))) +::STMT +LITERAL_FLOAT:1.0E9 +1.0E9 +::STMT +MATRIX:y_corr +FLOAT:link_power +^(y_corr,link_power) +::STMT +MATRIX:X_batch,186_dX,parsertemp146949,parsertemp146957,parsertemp146955 +LITERAL_FLOAT:2.0 +^(%*%(t(X_batch),*(*(parsertemp146957,parsertemp146955),%*%(186_dX,parsertemp146949))),2.0) +::STMT +LITERAL_FLOAT:2.0 +sqrt(2.0) +::STMT +MATRIX:W1_rand +LITERAL_FLOAT:0.086386842558136 +*(0.086386842558136,W1_rand) +::STMT +MATRIX:S +FLOAT:level +LITERAL_FLOAT:2.0 +==(%*%(S,t(S)),-(level,2.0)) +::STMT +MATRIX:D,parsertemp220844,ZERODIAG,beta +*(*(exp(*(parsertemp220844,beta)),ZERODIAG),D) +::STMT +MATRIX:resp,mean,X +%*%(t(*(-(X,mean),resp)),-(X,mean)) +::STMT +MATRIX:K_inv,scores,Ks +cast.FLOAT(%*%(%*%(t(Ks),K_inv),scores)) +::STMT +MATRIX:parsertemp43993,s,d,alpha_deno ++(s,*(/(sum(parsertemp43993),cast.FLOAT(alpha_deno)),d)) +::STMT +MATRIX:sb +FLOAT:delta +LITERAL_FLOAT:2.0 +-(cast.FLOAT(%*%(t(sb),sb)),^(delta,2.0)) +::STMT +MATRIX:surv,se_surv +FLOAT:z_alpha_2 +exp(/(*(z_alpha_2,se_surv),surv)) +::STMT +FLOAT:m2 +LITERAL_FLOAT:1.0004995004995005 +*(1.0004995004995005,m2) +::STMT +MATRIX:X +FLOAT:val +<(X,val) +::STMT +MATRIX:parsertemp171366,p_one_m_one +LITERAL_FLOAT:3.141592653589793 +/(%*%(parsertemp171366,p_one_m_one),3.141592653589793) +::STMT +MATRIX:W +FLOAT:int98,parsertemp97,int798,m4,int820,int239,wt +LITERAL_FLOAT:1.0 +-(*(*(^(wt,int98),+(wt,int798)),m4),*(*(*(int820,parsertemp97),^(wt,int239)),-(sum(W),1.0))) +::STMT +MATRIX:184_probs,183_dpred,parsertemp146939 +LITERAL_FLOAT:2.0 +^(colSums(-(*(183_dpred,184_probs),*(184_probs,parsertemp146939))),2.0) +::STMT +MATRIX:p_CG,z +FLOAT:trust_delta_sq +LITERAL_FLOAT:2.0 +*(sum(^(p_CG,2.0)),-(^(cast.FLOAT(z),2.0),trust_delta_sq)) +::STMT +FLOAT:run_index +LITERAL_FLOAT:1.0,96.0 +*(96.0,-(run_index,1.0)) +::STMT +MATRIX:X,Y +LITERAL_FLOAT:2.0 +/(abs(-(X,Y)),/(+(abs(X),abs(Y)),2.0)) +::STMT +FLOAT:padh,Hin,Hf +LITERAL_FLOAT:2.0 +-(+(Hin,*(2.0,padh)),Hf) +::STMT +MATRIX:Grad +LITERAL_FLOAT:-1.0,2.0 +^(*(Grad,-1.0),2.0) +::STMT +MATRIX:parsertemp389604,X_batch,2364_2361_Y,W4,parsertemp389601 +FLOAT:int996 +LITERAL_FLOAT:1.0 +%*%(*(-(/(parsertemp389601,parsertemp389604),X_batch),-(1.0,^(2364_2361_Y,int996))),W4) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:0.0 +*(exp(-(0.0,exp(linear_terms))),exp(linear_terms)) +::STMT +MATRIX:parsertemp436667,precisions,bc_matrix +*(bc_matrix,t(rowSums(*(parsertemp436667,precisions)))) +::STMT +MATRIX:_sbcvar96,_sbcvar95 +LITERAL_FLOAT:-1.0 ++(%*%(_sbcvar95,_sbcvar96),-1.0) +::STMT +FLOAT:D +LITERAL_FLOAT:0.5 +*(0.5,sqrt(D)) +::STMT +MATRIX:W +LITERAL_FLOAT:1.0,2.0,3.0 +*(*(-(sum(W),1.0),-(sum(W),2.0)),-(sum(round(W)),3.0)) +::STMT +MATRIX:parsertemp496901 +FLOAT:std,arch_coef +LITERAL_FLOAT:2.0 +*(arch_coef,^(*(cast.FLOAT(parsertemp496901),std),2.0)) +::STMT +MATRIX:parsertemp183431,X,mu +FLOAT:int754,N +LITERAL_FLOAT:1.0 +-(/(%*%(t(X),X),-(N,1.0)),*(/(N,-(N,int754)),%*%(t(mu),/(parsertemp183431,N)))) +::STMT +MATRIX:ss,X2 +LITERAL_FLOAT:1.0 +-(/(nrow(X2),ss),1.0) +::STMT +MATRIX:r_CG,q_CG +FLOAT:alpha_CG +sum(*(+(r_CG,*(alpha_CG,q_CG)),+(r_CG,*(alpha_CG,q_CG)))) +::STMT +FLOAT:int528 +LITERAL_FLOAT:0.0 +INT:m,int848 +sum(abs(rand(m,int848,0.0,int528))) +::STMT +FLOAT:ytest,yhat +LITERAL_FLOAT:1.0,2.0 +*(1.0,^(/(-(ytest,yhat),1.0),2.0)) +::STMT +MATRIX:z +sqrt(cast.FLOAT(%*%(t(z),z))) +::STMT +MATRIX:X_batch,W_1 +LITERAL_FLOAT:0.0 ++(%*%(X_batch,W_1),0.0) +::STMT +FLOAT:parsertemp169812 +LITERAL_FLOAT:2.302585092994046,0.5 +-(/(parsertemp169812,2.302585092994046),0.5) +::STMT +MATRIX:finite_linear_terms +FLOAT:int375 +LITERAL_FLOAT:-1.0,2.0 +exp(/(*(^(finite_linear_terms,int375),-1.0),2.0)) +::STMT +MATRIX:ones_ctg +LITERAL_FLOAT:1.0 +-(1.0,diag(ones_ctg)) +::STMT +MATRIX:parsertemp11251 +LITERAL_FLOAT:2.0 +t(^(2.0,parsertemp11251)) +::STMT +MATRIX:means,Y_counts,parsertemp560529 +LITERAL_FLOAT:5.0 +sum(<(*(means,%*%(Y_counts,parsertemp560529)),5.0)) +::STMT +FLOAT:parsertemp83 +-(cast.MATRIX(parsertemp83),parsertemp83) +::STMT +MATRIX:ts +LITERAL_FLOAT:1.0,4.0 ++(-(length(ts),4.0),1.0) +::STMT +MATRIX:parsertemp389604,X_batch,2364_2361_Y,parsertemp389601 +FLOAT:int323 +LITERAL_FLOAT:1.0 +t(*(-(/(parsertemp389601,parsertemp389604),X_batch),-(1.0,^(2364_2361_Y,int323)))) +::STMT +MATRIX:parsertemp397824,W3_rand +FLOAT:int796,int27 +LITERAL_FLOAT:0.5107539184552492 +%*%(*(0.5107539184552492,W3_rand),t(/(-(parsertemp397824,int27),+(parsertemp397824,int796)))) +::STMT +MATRIX:X +FLOAT:x +/(-(x,cast.FLOAT(X)),-(cast.FLOAT(X),cast.FLOAT(X))) +::STMT +MATRIX:parsertemp27718,_sbcvar92 +FLOAT:220_W +LITERAL_FLOAT:0.0,1.0E-4 ++(*(==(/(parsertemp27718,220_W),0.0),1.0E-4),/(%*%(rowSums(_sbcvar92),colSums(_sbcvar92)),sum(_sbcvar92))) +::STMT +MATRIX:Y,predicted_Y +LITERAL_FLOAT:0.0,1000.0 +/(sum(==(-(predicted_Y,Y),0.0)),1000.0) +::STMT +FLOAT:x,parsertemp169817 +LITERAL_FLOAT:10000.0 +/(round(*(x,exp(parsertemp169817))),10000.0) +::STMT +FLOAT:e,decay +LITERAL_FLOAT:1.0 +/(1.0,+(1.0,*(decay,e))) +::STMT +MATRIX:z +FLOAT:trust_delta_sq +-(cast.FLOAT(%*%(t(z),z)),trust_delta_sq) +::STMT +MATRIX:W +FLOAT:m2 +LITERAL_FLOAT:1.0 +/(*(m2,sum(round(W))),-(sum(round(W)),1.0)) +::STMT +MATRIX:scale_X,beta +*(cast.FLOAT(diag(scale_X)),cast.FLOAT(beta)) +::STMT +MATRIX:X,Y,K +LITERAL_FLOAT:-1.0 ++(*(*(K,-1.0),-(X,X)),-(Y,Y)) +::STMT +MATRIX:yhat +FLOAT:mean_y_test +LITERAL_FLOAT:2.0 +sum(^(-(yhat,mean_y_test),2.0)) +::STMT +MATRIX:parsertemp436669,prec_chol,X,parsertemp436673,bc_matrix +FLOAT:int367 +LITERAL_FLOAT:2.0 ++(-(*(bc_matrix,t(parsertemp436669)),*(2.0,%*%(X,parsertemp436673))),%*%(^(X,2.0),t(^(prec_chol,int367)))) +::STMT +FLOAT:sd_Y,sd_X +abs(-(sqrt(sd_Y),sqrt(sd_X))) +::STMT +MATRIX:parsertemp231464 +FLOAT:feature_frac +t(<=(parsertemp231464,feature_frac)) +::STMT +MATRIX:m_correct +FLOAT:i,in_i_k_min +LITERAL_FLOAT:1.0 +/(rowSums(m_correct),-(+(in_i_k_min,i),1.0)) +::STMT +MATRIX:R,parsertemp503780 +FLOAT:int440,int936 +INT:int175,parsertemp503363 +%*%(t(+(R,diag(parsertemp503780))),+(R,diag(rand(parsertemp503363,int175,int440,int936)))) +::STMT +MATRIX:Q,R +LITERAL_FLOAT:2.0 +*(2.0,%*%(R,t(Q))) +::STMT +MATRIX:C,X +LITERAL_FLOAT:-2.0 +*(-2.0,%*%(X,t(C))) +::STMT +MATRIX:the_exp +FLOAT:int76,int968 +LITERAL_FLOAT:1.0,1.0E7 +*(-(1.0,==(+(int76,the_exp),1.0E7)),-(1.0,exp(*(the_exp,int968)))) +::STMT +MATRIX:X +LITERAL_FLOAT:2.0,1.0E-6 +*(1.0E-6,sum(^(X,2.0))) +::STMT +MATRIX:2701_mask,2702_X +LITERAL_FLOAT:0.0,0.5 +*(>(2702_X,0.0),/(2701_mask,0.5)) +::STMT +MATRIX:svUpBnd,R,svLowBnd +diag(*(<=(R,cast.FLOAT(svUpBnd)),>(R,cast.FLOAT(svLowBnd)))) +::STMT +MATRIX:X +rev(X) +::STMT +MATRIX:obj,parsertemp44077 +FLOAT:C,float763,parsertemp44081 +cast.FLOAT(-(obj,+(*(float763,parsertemp44077),*(C,parsertemp44081)))) +::STMT +MATRIX:finite_linear_terms +LITERAL_FLOAT:1.0,1.0E7 +-(1.0,==(+(1.0E7,exp(finite_linear_terms)),1.0E7)) +::STMT +MATRIX:ot2 +FLOAT:int160 +LITERAL_FLOAT:25.0,100.0 +/(*(sum(>(ot2,int160)),100.0),25.0) +::STMT +FLOAT:link_power +LITERAL_FLOAT:0.5 +^(0.5,link_power) +::STMT +MATRIX:prec_chol,mu +LITERAL_FLOAT:2.0 +t(*(mu,^(prec_chol,2.0))) +::STMT +LITERAL_FLOAT:192.0 +INT:int522,int415 +rand(int522,int415,192.0,192.0) +::STMT +LITERAL_FLOAT:10000.0 +10000.0 +::STMT +MATRIX:U,V_sum +rowSums(rowSums(/(*(U,U),sum(V_sum)))) +::STMT +FLOAT:num_records +LITERAL_FLOAT:1920.0,1.0 +-(1.0,/(1920.0,num_records)) +::STMT +MATRIX:Q3,IQR +LITERAL_FLOAT:1.5 ++(Q3,*(1.5,IQR)) +::STMT +MATRIX:F +LITERAL_FLOAT:1.0 +*(-(nrow(F),1.0),-(ncol(F),1.0)) +::STMT +MATRIX:ytest,yhat +FLOAT:mean_y_test +LITERAL_FLOAT:2.0 +/(sum(^(-(yhat,mean_y_test),2.0)),sum(^(-(ytest,mean_y_test),2.0))) +::STMT +FLOAT:approx_sample_size,num_records +LITERAL_FLOAT:1.0 +-(1.0,/(approx_sample_size,num_records)) +::STMT +MATRIX:valueCount,parsertemp552531,resp,Y +rowSums(*(==(+(resp,parsertemp552531),Y),valueCount)) +::STMT +MATRIX:pearson_residual_sq +FLOAT:num_records +LITERAL_FLOAT:1.0 +/(sum(pearson_residual_sq),-(num_records,1.0)) +::STMT +MATRIX:z,parsertemp285752 +FLOAT:2234_sq_root_d,parsertemp285742,parsertemp285763,pp_CG +LITERAL_FLOAT:0.5 ++(*(0.5,cast.FLOAT(%*%(z,parsertemp285752))),*(parsertemp285763,/(-(parsertemp285742,2234_sq_root_d),pp_CG))) +::STMT +MATRIX:parsertemp539203,T,event +FLOAT:int631 +LITERAL_FLOAT:2.0,0.6666666666666666 +/(^(/(*(parsertemp539203,int631),2.0),0.6666666666666666),/(-(max(T),min(T)),sum(event))) +::STMT +MATRIX:prec_chol,parsertemp438810,X,bc_matrix,parsertemp438806 +FLOAT:int230,int476 +LITERAL_FLOAT:2.0 ++(-(*(bc_matrix,t(parsertemp438806)),*(2.0,%*%(X,parsertemp438810))),%*%(rowSums(^(X,int230)),t(^(prec_chol,int476)))) +::STMT +MATRIX:finite_linear_terms +LITERAL_FLOAT:0.231641888 +*(abs(finite_linear_terms),0.231641888) +::STMT +MATRIX:r_LS +FLOAT:alpha_LS,norm_r2_LS,p_LS +LITERAL_FLOAT:2.0 +/(^(+(cast.FLOAT(r_LS),*(alpha_LS,p_LS)),2.0),norm_r2_LS) +::STMT +MATRIX:Y_counts,parsertemp560507,Y,parsertemp560512 +-(sum(rowSums(*(Y,parsertemp560507))),sum(*(Y_counts,rowSums(parsertemp560512)))) +::STMT +MATRIX:maskd1,out1,186_dX,parsertemp146949 +FLOAT:p +LITERAL_FLOAT:0.0 +colSums(*(>(out1,0.0),*(/(maskd1,p),%*%(186_dX,parsertemp146949)))) +::STMT +LITERAL_FLOAT:6.0,2003.0 +*(6.0,2003.0) +::STMT +MATRIX:col_nonzeros,U,parsertemp382849,V,parsertemp382852 +FLOAT:reg ++(t(%*%(t(U),*(parsertemp382849,parsertemp382852))),*(*(reg,V),col_nonzeros)) +::STMT +MATRIX:classFeatureCounts +LITERAL_FLOAT:1.0,500.0 ++(rowSums(classFeatureCounts),*(500.0,1.0)) +::STMT +MATRIX:X,parsertemp471907 +LITERAL_FLOAT:1.0E-14 +sum(>(abs(-(X,parsertemp471907)),1.0E-14)) +::STMT +FLOAT:42_m2Y,42_m2X +LITERAL_FLOAT:1.001001001001001 +*(sqrt(*(42_m2X,1.001001001001001)),sqrt(*(42_m2Y,1.001001001001001))) +::STMT +MATRIX:posSampleVariances,negSampleMeans,posSampleMeans,negSampleVariances +FLOAT:int673,int18 +/(-(posSampleMeans,negSampleMeans),sqrt(+(/(posSampleVariances,int673),/(negSampleVariances,int18)))) +::STMT +MATRIX:parsertemp31908,e +FLOAT:l +/(t(%*%(t(e),==(parsertemp31908,l))),t(colSums(==(parsertemp31908,l)))) +::STMT +MATRIX:scale_X,p_CG,shift_X ++(*(cast.FLOAT(diag(scale_X)),p_CG),*(cast.FLOAT(shift_X),p_CG)) +::STMT +MATRIX:posSamples,posSampleMeans +FLOAT:int970,int726 +LITERAL_FLOAT:1.0,100.0 +/(-(colSums(^(posSamples,int726)),*(100.0,^(posSampleMeans,int970))),-(100.0,1.0)) +::STMT +MATRIX:X,parsertemp129018 +LITERAL_FLOAT:1.0 +*(max(parsertemp129018),-(ncol(X),1.0)) +::STMT +MATRIX:ss,se +FLOAT:130_eAvg +LITERAL_FLOAT:1.0 +-(/(/(se,ss),130_eAvg),1.0) +::STMT +MATRIX:X,parsertemp222929 +*(cast.FLOAT(parsertemp222929),-(X,X)) +::STMT +MATRIX:yhat +FLOAT:mean_y_test +LITERAL_FLOAT:2.0 +^(-(yhat,mean_y_test),2.0) +::STMT +LITERAL_FLOAT:1.0E20 +1.0E20 +::STMT +MATRIX:Yhat_prime,E,H3 +%*%(t(*(E,Yhat_prime)),H3) +::STMT +MATRIX:ssX_p,X +%*%(t(X),%*%(X,ssX_p)) +::STMT +MATRIX:lambda,beta +LITERAL_FLOAT:0.0 +%*%(t(+(0.0,*(lambda,beta))),+(0.0,*(lambda,beta))) +::STMT +MATRIX:Xm,parsertemp265718 +abs(/(-(sum(parsertemp265718),sum(Xm)),sum(Xm))) +::STMT +MATRIX:feature +FLOAT:n_bins +/(-(max(feature),min(feature)),n_bins) +::STMT +MATRIX:2700_X,2700_W,parsertemp459178,2699_dtemp +FLOAT:lr +LITERAL_FLOAT:5.0E-4 +*(lr,+(%*%(t(2700_X),-(2699_dtemp,parsertemp459178)),*(5.0E-4,2700_W))) +::STMT +MATRIX:foffb,foffe +LITERAL_FLOAT:1.0 +-(cast.FLOAT(foffe),+(cast.FLOAT(foffb),1.0)) +::STMT +MATRIX:p_CG +FLOAT:int351,trust_delta_sq,z +-(*(*(cast.FLOAT(z),sum(p_CG)),*(cast.FLOAT(z),sum(p_CG))),*(sum(^(p_CG,int351)),-(*(z,z),trust_delta_sq))) +::STMT +MATRIX:b,H +%*%(t(b),-(+(H,t(H)),diag(diag(H)))) +::STMT +MATRIX:p_CG,z +FLOAT:trust_delta_sq +*(cast.FLOAT(%*%(t(p_CG),p_CG)),-(cast.FLOAT(%*%(z,z)),trust_delta_sq)) +::STMT +FLOAT:eps +LITERAL_FLOAT:0.5 ++(0.5,eps) +::STMT +MATRIX:z +FLOAT:pp,trust_delta_sq +*(pp,-(sum(*(z,z)),trust_delta_sq)) +::STMT +MATRIX:n_risk_stratum,n_risk_i2j,V1 +FLOAT:I_i1i2 +*(V1,-(I_i1i2,/(n_risk_i2j,n_risk_stratum))) +::STMT +MATRIX:col,missing_indicator_mat +FLOAT:global_mean ++(col,*(missing_indicator_mat,global_mean)) +::STMT +FLOAT:parsertemp539092,parsertemp539091,num_groups +LITERAL_FLOAT:1.0,2.0 +-(+(+(*(parsertemp539091,parsertemp539092),1.0),num_groups),2.0) +::STMT +MATRIX:I,parsertemp472299 +LITERAL_FLOAT:0.0 +*(==(!=(*(parsertemp472299,I),0.0),0.0),I) +::STMT +MATRIX:Y,linear_terms +LITERAL_FLOAT:0.0 +*(rowSums(Y),exp(-(0.0,exp(linear_terms)))) +::STMT +MATRIX:W +FLOAT:m4 +LITERAL_FLOAT:1.0,2.0 +*(*(^(sum(W),2.0),+(sum(W),1.0)),m4) +::STMT +MATRIX:X_batch,dout1 +LITERAL_FLOAT:2.0 +^(%*%(t(X_batch),dout1),2.0) +::STMT +MATRIX:m_err +rowSums(colSums(m_err)) +::STMT +FLOAT:int453,se_g1,int711,int305,int506,parsertemp113,wt +sqrt(/(*(*(int711,parsertemp113),^(se_g1,int453)),*(+(wt,int506),-(wt,int305)))) +::STMT +MATRIX:X +FLOAT:N +LITERAL_FLOAT:1.0 +/(%*%(t(X),X),-(N,1.0)) +::STMT +FLOAT:277_sq_root_d,parsertemp170093,pp_CG,pq_CG +LITERAL_FLOAT:0.5 +*(*(0.5,/(+(parsertemp170093,277_sq_root_d),pp_CG)),pq_CG) +::STMT +FLOAT:parsertemp191170,Wf +LITERAL_FLOAT:0.0,1.0,2.0 +INT:parsertemp191169,F +*(rand(F,parsertemp191169,0.0,1.0),sqrt(/(2.0,*(parsertemp191170,Wf)))) +::STMT +MATRIX:b,W,X ++(%*%(X,W),b) +::STMT +MATRIX:y_corr +LITERAL_FLOAT:0.0,1.0 ++(<=(y_corr,0.0),>=(y_corr,1.0)) +::STMT +MATRIX:parsertemp231461 +LITERAL_FLOAT:0.1 +<=(parsertemp231461,0.1) +::STMT +MATRIX:C,Xm,parsertemp265701,XtZ +FLOAT:ZtZ_sum +%*%(%*%(%*%(Xm,%*%(C,parsertemp265701)),t(/(XtZ,ZtZ_sum))),t(Xm)) +::STMT +MATRIX:X_batch,parsertemp146957,187_dX +FLOAT:beta2 +LITERAL_FLOAT:1.0,2.0 +*(-(1.0,beta2),^(%*%(t(X_batch),*(parsertemp146957,187_dX)),2.0)) +::STMT +MATRIX:parsertemp437305,_funvar2125,parsertemp437277,parsertemp437272 +exp(-(+(_funvar2125,parsertemp437305),+(parsertemp437272,parsertemp437277))) +::STMT +MATRIX:W +FLOAT:m3 +LITERAL_FLOAT:2.0 +*(^(sum(round(W)),2.0),m3) +::STMT +MATRIX:pearson_residual_sq +FLOAT:num_features,num_records +/(sum(pearson_residual_sq),-(num_records,num_features)) +::STMT +MATRIX:parsertemp12846,F,parsertemp12848 +FLOAT:q,int265,W +LITERAL_FLOAT:1.0 +/(sum(/(^(parsertemp12848,int265),/(parsertemp12846,W))),*(sum(F),-(q,1.0))) +::STMT +FLOAT:o_init,N +LITERAL_FLOAT:-2.0 +/(*(-2.0,o_init),N) +::STMT +LITERAL_FLOAT:1.0 ++(+(+(1.0,1.0),1.0),1.0) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:1.0 +*(^(exp(linear_terms),1.0),exp(linear_terms)) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:-1.0 +*(^(exp(linear_terms),-1.0),exp(linear_terms)) +::STMT +MATRIX:X +FLOAT:2917_split +-($1:nrow(X),round(*($1,2917_split))) +::STMT +LITERAL_FLOAT:9999.0 +9999.0 +::STMT +MATRIX:prec_chol,mu +LITERAL_FLOAT:2.0 +*(rowSums(^(mu,2.0)),^(prec_chol,2.0)) +::STMT +LITERAL_FLOAT:0.01 +0.01 +::STMT +MATRIX:cm +FLOAT:n +==(t(cm),n) +::STMT +MATRIX:parsertemp389341,X,parsertemp389344 +LITERAL_FLOAT:1.0 +-(/(-(exp(parsertemp389341),1.0),+(exp(parsertemp389344),1.0)),X) +::STMT +LITERAL_FLOAT:100.0 ++(100.0,100.0) +::STMT +MATRIX:b_cumulant,is_natural_parameter_log_zero,parsertemp560392,Y,natural_parameters +FLOAT:int562 +LITERAL_FLOAT:1.0 +-(-(*(Y,natural_parameters),b_cumulant),/(*(>(Y,int562),is_natural_parameter_log_zero),-(1.0,*(parsertemp560392,is_natural_parameter_log_zero)))) +::STMT +MATRIX:X,Y +FLOAT:x +LITERAL_FLOAT:1.0 +*(-(1.0,/(-(x,X),-(X,X))),Y) +::STMT +MATRIX:g_new,s,g_old +*(/(sum(*(g_new,g_new)),sum(*(g_old,g_old))),s) +::STMT +MATRIX:parsertemp265720,parsertemp265715,parsertemp265722 +FLOAT:m,n +LITERAL_FLOAT:2.0 +/(-(+(sum(parsertemp265722),trace(parsertemp265715)),*(2.0,sum(parsertemp265720))),*(n,m)) +::STMT +MATRIX:I,y2 +/(%*%(I,y2),rowSums(I)) +::STMT +MATRIX:A,lambda ++(A,diag(lambda)) +::STMT +MATRIX:X +abs(-(X,round(X))) +::STMT +MATRIX:C,Xm,parsertemp265702 +sum(-(%*%(%*%(Xm,parsertemp265702),t(C)),Xm)) +::STMT +MATRIX:objvals +LITERAL_FLOAT:10.0,1.5,-8.0 +*(*(1.5,^(10.0,-8.0)),cast.FLOAT(objvals)) +::STMT +FLOAT:parsertemp496689,parsertemp496690,parsertemp496694,int69,parsertemp496686,n +LITERAL_FLOAT:1.0,2.0 +*(/(1.0,*(2.0,n)),+(parsertemp496694,/(^(parsertemp496686,int69),+(parsertemp496689,parsertemp496690)))) +::STMT +MATRIX:p,lambda,X ++(%*%(t(X),%*%(X,p)),*(lambda,p)) +::STMT +MATRIX:ot2 +FLOAT:int689 +LITERAL_FLOAT:200.0,100.0 +/(*(sum(>(ot2,int689)),100.0),200.0) +::STMT +LITERAL_FLOAT:1.0,8.0 +-(8.0,1.0) +::STMT +MATRIX:parsertemp171083 +FLOAT:float666 +LITERAL_FLOAT:0.001308,0.189269 ++(0.189269,*(sqrt(*(float666,parsertemp171083)),0.001308)) +::STMT +MATRIX:r_CG,q_CG +FLOAT:alpha_CG +*(+(r_CG,*(alpha_CG,q_CG)),+(r_CG,*(alpha_CG,q_CG))) +::STMT +MATRIX:parsertemp389329,parsertemp389332,W4 +FLOAT:int14,int822 +%*%(W4,t(/(-(parsertemp389329,int14),+(parsertemp389332,int822)))) +::STMT +MATRIX:linear_terms +FLOAT:int6 +LITERAL_FLOAT:1.0 +/(1.0,-(exp(*(linear_terms,int6)),1.0)) +::STMT +MATRIX:p_LS +FLOAT:alpha_LS,r_LS,norm_r2_LS +LITERAL_FLOAT:2.0 +*(/(^(+(r_LS,alpha_LS),2.0),norm_r2_LS),cast.FLOAT(p_LS)) +::STMT +MATRIX:simplex,parsertemp503570 +LITERAL_FLOAT:2.0 +-(*(2.0,/(-(parsertemp503570,simplex),nrow(simplex))),simplex) +::STMT +MATRIX:X_cluster,_funvar62 +|(X_cluster,_funvar62) +::STMT +MATRIX:Y_counts,parsertemp560517,ent1_vec +FLOAT:int324 +sum(*(Y_counts,-(rowSums(parsertemp560517),^(ent1_vec,int324)))) +::STMT +MATRIX:parsertemp410246,parsertemp410249 +LITERAL_FLOAT:0.6666666666666666 +-(max(^(/(parsertemp410246,parsertemp410249),0.6666666666666666)),min(^(/(parsertemp410246,parsertemp410249),0.6666666666666666))) +::STMT +MATRIX:means,Y_counts,Y +/(colSums(-(Y,means)),sum(Y_counts)) +::STMT +MATRIX:p,V +LITERAL_FLOAT:1.0E-8 +%*%(t(p),+(%*%(t(V),%*%(V,p)),*(1.0E-8,p))) +::STMT +MATRIX:parsertemp171245,Y +LITERAL_FLOAT:1.0 +*(Y,/(1.0,-(exp(parsertemp171245),1.0))) +::STMT +MATRIX:parsertemp410977,W,H,parsertemp410974 +%*%(W,/(*(H,%*%(parsertemp410974,parsertemp410977)),t(colSums(W)))) +::STMT +MATRIX:parsertemp31026,parsertemp31033 +FLOAT:int225,int394 +LITERAL_FLOAT:2.0,3352500.0,990000.0 ++(/(^(/(parsertemp31026,int394),2.0),990000.0),/(^(/(parsertemp31033,int225),2.0),3352500.0)) +::STMT +MATRIX:CFreqs1,parsertemp27492,present_domain_vals_mat +FLOAT:int634 +LITERAL_FLOAT:1.0 +/(sum(*(%*%(present_domain_vals_mat,CFreqs1),^(parsertemp27492,int634))),-(nrow(present_domain_vals_mat),1.0)) +::STMT +MATRIX:log_prob,log_det_chol +FLOAT:parsertemp443052,float150 +LITERAL_FLOAT:-0.5 ++(*(-0.5,+(*(parsertemp443052,float150),log_prob)),cast.FLOAT(log_det_chol)) +::STMT +FLOAT:s,i2,n +-(n,*(s,i2)) +::STMT +FLOAT:num_records +LITERAL_FLOAT:100.0 +/(100.0,num_records) +::STMT +MATRIX:parsertemp31115,parsertemp31108 +FLOAT:int207,int915 +LITERAL_FLOAT:7.996E9,2.0,3.37275E9 ++(/(^(/(parsertemp31108,int915),2.0),7.996E9),/(^(/(parsertemp31115,int207),2.0),3.37275E9)) +::STMT +MATRIX:uniqueValues,X +cast.FLOAT(==(X,uniqueValues)) +::STMT +MATRIX:resp,X +LITERAL_FLOAT:2.22E-16 +/(%*%(t(resp),X),t(+(colSums(resp),2.22E-16))) +::STMT +MATRIX:id +diag(diag(==(id,cast.FLOAT(id)))) +::STMT +MATRIX:p_LS,X +*(cast.FLOAT(%*%(t(X),X)),p_LS) +::STMT +FLOAT:iter +LITERAL_FLOAT:5.0 +/(iter,5.0) +::STMT +FLOAT:b,rad +LITERAL_FLOAT:-1.0 +*(-(b,rad),-1.0) +::STMT +LITERAL_FLOAT:1.0,4.0 +-(4.0,1.0) +::STMT +MATRIX:X +FLOAT:val +==(X,val) +::STMT +MATRIX:W,X,H +LITERAL_FLOAT:1.0E-8 +%*%(t(W),/(X,+(%*%(W,H),1.0E-8))) +::STMT +FLOAT:sum_y_test,sum_sq_y_test,n +LITERAL_FLOAT:2.0 +-(sum_sq_y_test,*(n,^(/(sum_y_test,n),2.0))) +::STMT +FLOAT:window_size +LITERAL_FLOAT:4.0 +/(window_size,4.0) +::STMT +MATRIX:2696_mask,outr3 +LITERAL_FLOAT:0.5 +/(*(outr3,2696_mask),0.5) +::STMT +MATRIX:p,q,lambda,X +FLOAT:norm_r2 +*(/(norm_r2,sum(*(p,q))),+(%*%(t(X),%*%(X,p)),*(lambda,p))) +::STMT +MATRIX:scale_X,X,beta +*(*(cast.FLOAT(diag(scale_X)),cast.FLOAT(beta)),X) +::STMT +MATRIX:Train,2342_m_colmax,2342_m_colmin +LITERAL_FLOAT:1.0,2.0 +-(/(*(2.0,-(Train,2342_m_colmin)),-(2342_m_colmax,2342_m_colmin)),1.0) +::STMT +MATRIX:p,A +sum(*(p,%*%(t(A),%*%(A,p)))) +::STMT +MATRIX:r,d,X,Hd,parsertemp44001 +FLOAT:int656 +*(/(sum(^(r,int656)),cast.FLOAT(%*%(parsertemp44001,Hd))),%*%(X,d)) +::STMT +MATRIX:parsertemp393591,W4 +LITERAL_FLOAT:2.0 +exp(*(2.0,t(%*%(W4,parsertemp393591)))) +::STMT +MATRIX:2701_mask,2700_W,2726_dpred,parsertemp459177,2699_probs +LITERAL_FLOAT:0.5 +*(/(2701_mask,0.5),%*%(-(*(2726_dpred,2699_probs),*(2699_probs,parsertemp459177)),t(2700_W))) +::STMT +MATRIX:X,parsertemp386474 +LITERAL_FLOAT:-2.0 ++(+(*(-2.0,%*%(X,parsertemp386474)),X),t(X)) +::STMT +FLOAT:strideh,Hin,Hf +LITERAL_FLOAT:1.0 ++(/(-(Hin,Hf),strideh),1.0) +::STMT +MATRIX:P,D,Z,ZERODIAG +FLOAT:int934 +LITERAL_FLOAT:1.0 +*(-(P,/(*(Z,ZERODIAG),sum(Z))),*(/(1.0,+(D,int934)),ZERODIAG)) +::STMT +MATRIX:W2_rand +LITERAL_FLOAT:0.16823164622761327 +*(0.16823164622761327,W2_rand) +::STMT +MATRIX:parsertemp265709,Xm,tmp,Z,parsertemp265702 +%*%(t(/(%*%(parsertemp265709,Z),sum(tmp))),/(%*%(t(Xm),%*%(Xm,parsertemp265702)),sum(tmp))) +::STMT +MATRIX:curr_prediction +FLOAT:int567,282_lambda ++(sum(*(curr_prediction,-(int567,curr_prediction))),282_lambda) +::STMT +MATRIX:X2 +max(colSums(X2)) +::STMT +MATRIX:parsertemp31115,parsertemp31108 +FLOAT:parsertemp31116,parsertemp31109 +LITERAL_FLOAT:2.0,1500.0,2000.0 +^(+(/(/(parsertemp31108,parsertemp31109),2000.0),/(/(parsertemp31115,parsertemp31116),1500.0)),2.0) +::STMT +MATRIX:parsertemp106 +LITERAL_FLOAT:10.0 +*(10.0,parsertemp106) +::STMT +MATRIX:W +FLOAT:m2 +LITERAL_FLOAT:1.0 +sqrt(/(*(m2,sum(W)),-(sum(W),1.0))) +::STMT +LITERAL_FLOAT:128.0 +INT:int502,int800 +rand(int800,int502,128.0,128.0) +::STMT +MATRIX:parsertemp386440,parsertemp386441 +LITERAL_FLOAT:1.0,5.0 +>=(+(rowSums(*(parsertemp386440,parsertemp386441)),1.0),5.0) +::STMT +MATRIX:simplex +LITERAL_FLOAT:2.0 +*(2.0,/(-(rowSums(simplex),simplex),nrow(simplex))) +::STMT +MATRIX:parsertemp169867 +FLOAT:pp,zz,trust_delta_sq +sqrt(-(*(sum(parsertemp169867),sum(parsertemp169867)),*(pp,-(zz,trust_delta_sq)))) +::STMT +MATRIX:X,permut +LITERAL_FLOAT:2.0 +colSums(^(%*%(permut,X),2.0)) +::STMT +MATRIX:2697_b,parsertemp459149,2697_W,outd3 +exp(-(+(%*%(outd3,2697_W),2697_b),parsertemp459149)) +::STMT +MATRIX:b,scale_X,shift_X,X,y ++(%*%(diag(scale_X),%*%(t(X),y)),*(cast.FLOAT(b),shift_X)) +::STMT +MATRIX:parsertemp395001,W4_rand +FLOAT:int764,int842 +LITERAL_FLOAT:0.08692913816996169 +%*%(*(0.08692913816996169,W4_rand),t(/(-(parsertemp395001,int842),+(parsertemp395001,int764)))) +::STMT +MATRIX:S,addedE,parsertemp31676 +FLOAT:level +rowSums(*(==(%*%(S,parsertemp31676),level),t(addedE))) +::STMT +MATRIX:parsertemp421322 +LITERAL_FLOAT:1.0,11.0 +*(11.0,-(max(round(parsertemp421322)),1.0)) +::STMT +MATRIX:dW,parsertemp459256 +LITERAL_FLOAT:5.0E-4 ++(dW,*(5.0E-4,parsertemp459256)) +::STMT +MATRIX:R,HS +FLOAT:alpha +LITERAL_FLOAT:2.0 +^(-(R,*(alpha,HS)),2.0) +::STMT +LITERAL_FLOAT:1.0,100.0 ++(100.0,1.0) +::STMT +MATRIX:R,parsertemp40219,parsertemp40216,parsertemp40226,parsertemp40220,parsertemp40231 +FLOAT:level +/(-(+(R,rowSums(parsertemp40226)),rowSums(*(parsertemp40220,parsertemp40231))),-(+(R,rowSums(parsertemp40216)),rowSums(==(parsertemp40219,level)))) +::STMT +MATRIX:V +-(max(V),min(V)) +::STMT +MATRIX:Y,linear_terms +LITERAL_FLOAT:0.0 +*(rowSums(Y),>=(linear_terms,0.0)) +::STMT +MATRIX:parsertemp285809,p_CG,z +FLOAT:parsertemp285799,parsertemp285820,2235_sq_root_d +LITERAL_FLOAT:0.5 ++(*(0.5,cast.FLOAT(%*%(z,parsertemp285809))),*(parsertemp285820,/(-(parsertemp285799,2235_sq_root_d),cast.FLOAT(p_CG)))) +::STMT +MATRIX:Y +==(Y,min(Y)) +::STMT +FLOAT:i,num_runs,num_centroids +LITERAL_FLOAT:1.0 ++(*(num_centroids,-(num_runs,1.0)),i) +::STMT +MATRIX:E,X +LITERAL_FLOAT:0.0 +-(0.0,t(colSums(*(X,E)))) +::STMT +MATRIX:p,e,u +LITERAL_FLOAT:0.15000000000000002 +*(0.15000000000000002,%*%(%*%(e,u),p)) +::STMT +FLOAT:iter +LITERAL_FLOAT:3.0 +/(iter,3.0) +::STMT +MATRIX:t,parsertemp32854,parsertemp32848,Y,parsertemp32857,parsertemp32858 +cast.FLOAT(+(+(*(parsertemp32848,Y),*(t,Y)),*(*(t,parsertemp32854),+(parsertemp32857,parsertemp32858)))) +::STMT +MATRIX:parsertemp149335,LT,Y +LITERAL_FLOAT:-1.0 ++(*(sum(*(Y,LT)),-1.0),sum(parsertemp149335)) +::STMT +MATRIX:col +FLOAT:min_val,bin_width +LITERAL_FLOAT:0.5 +round(-(/(-(col,min_val),bin_width),0.5)) +::STMT +FLOAT:lambda +LITERAL_FLOAT:2.0 +/(lambda,2.0) +::STMT +MATRIX:diff +LITERAL_FLOAT:2.0 +sqrt(rowSums(^(diff,2.0))) +::STMT +MATRIX:F,parsertemp12916,parsertemp12915 +FLOAT:int496,int64,meanX +LITERAL_FLOAT:1.0 +*(/(F,-(sum(F),1.0)),-(+(-(parsertemp12915,parsertemp12916),/(int496,int64)),meanX)) +::STMT +MATRIX:W,X,H +FLOAT:eps +/(X,+(%*%(W,H),eps)) +::STMT +MATRIX:lambda,parsertemp149338,parsertemp149335,parsertemp149331 +LITERAL_FLOAT:-1.0,0.5 ++(+(*(sum(parsertemp149331),-1.0),sum(parsertemp149335)),*(0.5,sum(*(lambda,parsertemp149338)))) +::STMT +MATRIX:R,3_ss,dsep +FLOAT:3_eAvg +LITERAL_FLOAT:1.0 +-(/(/(+(R,dsep),3_ss),3_eAvg),1.0) +::STMT +MATRIX:S,V,W +%*%(*(W,%*%(S,t(V))),V) +::STMT +MATRIX:parsertemp389339 +LITERAL_FLOAT:1.0,2.0 +-(exp(*(2.0,t(parsertemp389339))),1.0) +::STMT +LITERAL_FLOAT:1.0,1500.0 +-(1500.0,1.0) +::STMT +LITERAL_FLOAT:-1.0,1.0 +INT:int633,n ++(diag(rand(n,int633,-1.0,-1.0)),1.0) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:0.0,0.5 +-(>=(linear_terms,0.0),0.5) +::STMT +MATRIX:posSampleMeans +LITERAL_FLOAT:2.0,2000.0 +*(2000.0,^(posSampleMeans,2.0)) +::STMT +FLOAT:run_index +LITERAL_FLOAT:1.0,96.0 ++(*(96.0,-(run_index,1.0)),1.0) +::STMT +MATRIX:dout,mask +FLOAT:p +*(/(mask,p),dout) +::STMT +MATRIX:parsertemp13725,parsertemp13720,45_CVars,45_CFreqs +LITERAL_FLOAT:1.0,1000.0 +/(/(sum(*(45_CFreqs,parsertemp13720)),-(nrow(45_CFreqs),1.0)),/(sum(*(parsertemp13725,45_CVars)),-(1000.0,nrow(45_CFreqs)))) +::STMT +MATRIX:finite_linear_terms +LITERAL_FLOAT:0.0 +>=(finite_linear_terms,0.0) +::STMT +MATRIX:parsertemp170246,parsertemp170240,parsertemp170238 +FLOAT:float336,float335,float235 +LITERAL_FLOAT:1.0,0.254829592 +*(/(1.0,+(1.0,*(parsertemp170238,float235))),+(0.254829592,*(/(float336,parsertemp170240),+(float335,parsertemp170246)))) +::STMT +MATRIX:m_iter_err_sum,parsertemp379567,m_err +FLOAT:i_process_item +LITERAL_FLOAT:2.0 +-(*(^(/(parsertemp379567,i_process_item),2.0),i_process_item),*(*(2.0,/(parsertemp379567,i_process_item)),+(colSums(m_err),m_iter_err_sum))) +::STMT +MATRIX:parsertemp539204 +FLOAT:float280,float688,int423,float881,float839,int969 +-(max(^(/(parsertemp539204,float688),/(int423,float839))),min(^(/(parsertemp539204,float881),/(int969,float280)))) +::STMT +MATRIX:parsertemp171367,is_LT_infinite +FLOAT:float643 +LITERAL_FLOAT:1.0,0.5 ++(*(+(0.5,/(parsertemp171367,float643)),-(1.0,rowSums(is_LT_infinite))),is_LT_infinite) +::STMT +MATRIX:w,yt,Xt +LITERAL_FLOAT:0.0 +>(*(yt,%*%(Xt,w)),0.0) +::STMT +MATRIX:A,parsertemp12899,CVars,CFreqs,parsertemp12904 +LITERAL_FLOAT:1.0 +/(/(sum(*(CFreqs,parsertemp12899)),-(nrow(CFreqs),1.0)),/(sum(*(parsertemp12904,CVars)),-(nrow(A),nrow(CFreqs)))) +::STMT +MATRIX:parsertemp456742,r,y +LITERAL_FLOAT:0.0 +-(0.0,cast.FLOAT(%*%(t(r),%*%(parsertemp456742,y)))) +::STMT +MATRIX:W,H +FLOAT:eps ++(%*%(%*%(t(W),W),H),eps) +::STMT +FLOAT:F1 +LITERAL_FLOAT:2.0 +*(*(*(F1,2.0),2.0),2.0) +::STMT +LITERAL_FLOAT:10000.0,0.8 +*(10000.0,0.8) +::STMT +MATRIX:grad +LITERAL_FLOAT:0.0 +-(0.0,grad) +::STMT +MATRIX:parsertemp31111,parsertemp31113 +FLOAT:int50 +LITERAL_FLOAT:1.0,2.0,1500.0 +^(/(-(colSums(parsertemp31111),*(int50,parsertemp31113)),-(1500.0,1.0)),2.0) +::STMT +MATRIX:X +LITERAL_FLOAT:4.0 +>=(X,4.0) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,11.0 +-(+(i,11.0),1.0) +::STMT +MATRIX:xs +LITERAL_FLOAT:4.5 +>=(xs,4.5) +::STMT +MATRIX:elt,ones_ctg +LITERAL_FLOAT:1.0 +%*%(/(elt,%*%(rowSums(elt),t(ones_ctg))),-(1.0,diag(ones_ctg))) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,3.0 +*(3.0,-(i,1.0)) +::STMT +MATRIX:grad +LITERAL_FLOAT:-1.0 +sum(*(*(grad,-1.0),*(grad,-1.0))) +::STMT +MATRIX:m_err_for_order,m_active_flag +LITERAL_FLOAT:0.0 +*(m_err_for_order,t(==(m_active_flag,0.0))) +::STMT +LITERAL_FLOAT:3.141592653589793 +3.141592653589793 +::STMT +MATRIX:parsertemp31111,parsertemp31113 +FLOAT:int73 +LITERAL_FLOAT:1.0,1500.0 +/(/(-(colSums(parsertemp31111),*(int73,parsertemp31113)),-(1500.0,1.0)),1500.0) +::STMT +MATRIX:R,B,parsertemp503364 +LITERAL_FLOAT:0.0 +-(0.0,%*%(t(+(R,parsertemp503364)),B)) +::STMT +MATRIX:ss,map +LITERAL_FLOAT:1.0 +*(map,/(1.0,t(ss))) +::STMT +MATRIX:w_X,z_LS,X +*(/(nrow(X),*(cast.FLOAT(w_X),cast.FLOAT(z_LS))),z_LS) +::STMT +MATRIX:parsertemp220853,W,sum_Pi,beta +LITERAL_FLOAT:3.4011973816621555 +-(+(parsertemp220853,*(beta,/(W,sum_Pi))),3.4011973816621555) +::STMT +MATRIX:X +LITERAL_FLOAT:1.0 +>=(X,1.0) +::STMT +LITERAL_FLOAT:6.283185307179586 +6.283185307179586 +::STMT +LITERAL_FLOAT:0.15915494309189535 +0.15915494309189535 +::STMT +FLOAT:arch_coef,var_coef +LITERAL_FLOAT:1.0 +-(-(1.0,arch_coef),var_coef) +::STMT +MATRIX:log_l_part_saturated,log_l_part +LITERAL_FLOAT:2.0 +-(*(2.0,sum(log_l_part_saturated)),*(2.0,sum(log_l_part))) +::STMT +MATRIX:X +FLOAT:x +-(x,cast.FLOAT(X)) +::STMT +MATRIX:parsertemp220863,parsertemp220864,Hdiff,betamax,beta +FLOAT:INF,int45 +LITERAL_FLOAT:2.0 +/(*(*(>=(Hdiff,int45),!=(betamax,INF)),+(beta,+(parsertemp220863,parsertemp220864))),2.0) +::STMT +MATRIX:parsertemp222331 +FLOAT:sample_block_size +LITERAL_FLOAT:0.5 +round(+(0.5,/(parsertemp222331,sample_block_size))) +::STMT +MATRIX:parsertemp75086 +LITERAL_FLOAT:1.0,32.0 ++(*(parsertemp75086,32.0),1.0) +::STMT +MATRIX:parsertemp496901 +FLOAT:std +LITERAL_FLOAT:2.0 +^(*(cast.FLOAT(parsertemp496901),std),2.0) +::STMT +LITERAL_FLOAT:512.0,0.8 +*(512.0,0.8) +::STMT +MATRIX:ss +LITERAL_FLOAT:1.0,20.0 +-(/(20.0,ss),1.0) +::STMT +MATRIX:R +LITERAL_FLOAT:32.0 +>=(R,32.0) +::STMT +LITERAL_FLOAT:6.144102863722254 +6.144102863722254 +::STMT +MATRIX:y_prob,parsertemp560892,linear_terms,elt +FLOAT:int566,int338,int507 +LITERAL_FLOAT:1.0 ++(*(-(1.0,==(parsertemp560892,int566)),-(1.0,y_prob)),*(*(==(parsertemp560892,int338),exp(linear_terms)),-(1.0,/(elt,int507)))) +::STMT +FLOAT:float982,parsertemp169812 +LITERAL_FLOAT:4.0,0.5 +-(4.0,round(-(/(parsertemp169812,float982),0.5))) +::STMT +MATRIX:A +LITERAL_FLOAT:1.0 ++(/(1.0,cast.FLOAT(A)),/(1.0,cast.FLOAT(A))) +::STMT +MATRIX:W,H,X,parsertemp411102 +FLOAT:eps +*(H,/(%*%(t(W),X),+(%*%(parsertemp411102,H),eps))) +::STMT +MATRIX:parsertemp170101 +FLOAT:r_CG,g_reg,z,277_sq_root_d,parsertemp170108,parsertemp170093,pp_CG +LITERAL_FLOAT:0.5 ++(*(0.5,*(cast.FLOAT(z),+(r_CG,g_reg))),*(+(+(parsertemp170108,z),sum(parsertemp170101)),/(+(parsertemp170093,277_sq_root_d),pp_CG))) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,12.0 +-(+(i,12.0),1.0) +::STMT +FLOAT:max_iter +LITERAL_FLOAT:100.0 +/(max_iter,100.0) +::STMT +FLOAT:sample_block_size +LITERAL_FLOAT:1.0,3.0 +-(*(sample_block_size,3.0),1.0) +::STMT +MATRIX:p,A +sum(*(p,%*%(t(A),%*%(A,p)))) +::STMT +FLOAT:idx +LITERAL_FLOAT:1.0,256.0 ++(-(256.0,idx),1.0) +::STMT +MATRIX:Xm,parsertemp265717,Z +LITERAL_FLOAT:2.0 +*(2.0,sum(%*%(%*%(Z,parsertemp265717),t(Xm)))) +::STMT +MATRIX:X +FLOAT:x +-(nrow(X),sum(>=(X,x))) +::STMT +MATRIX:W +sqrt(sum(round(W))) +::STMT +MATRIX:linear_terms,Y +FLOAT:var_power +LITERAL_FLOAT:1.0 +*(^(exp(linear_terms),-(1.0,var_power)),-(Y,exp(linear_terms))) +::STMT +MATRIX:surv,se_surv +FLOAT:z_alpha_2,int420 +exp(/(*(*(z_alpha_2,int420),se_surv),surv)) +::STMT +MATRIX:WM +FLOAT:m2X +LITERAL_FLOAT:1.0 +*(m2X,/(sum(WM),-(sum(WM),1.0))) +::STMT +MATRIX:X2,85_s +LITERAL_FLOAT:1.0 +*(/(1.0,85_s),nrow(X2)) +::STMT +MATRIX:X +FLOAT:eps +*(eps,nrow(X)) +::STMT +MATRIX:W +FLOAT:int246,parsertemp65,int96,parsertemp66,wt +LITERAL_FLOAT:3.0,4.0 +*(*(*(-(wt,int96),-(wt,int246)),-(sum(W),3.0)),^(sqrt(/(parsertemp65,parsertemp66)),4.0)) +::STMT +LITERAL_FLOAT:2.0,100.0 +^(100.0,2.0) +::STMT +FLOAT:parsertemp65,parsertemp66,mu +LITERAL_FLOAT:5.0 +-(mu,*(5.0,sqrt(/(parsertemp65,parsertemp66)))) +::STMT +MATRIX:F +LITERAL_FLOAT:1.0 +/(F,-(sum(F),1.0)) +::STMT +MATRIX:mat_chol +/(nrow(mat_chol),ncol(mat_chol)) +::STMT +MATRIX:g_reg,p_CG +FLOAT:parsertemp170148,parsertemp170164,q_CG,z,int13,pq_CG,int470 +*(+(+(*(parsertemp170164,pq_CG),*(z,q_CG)),sum(*(g_reg,p_CG))),/(+(*(z,int470),sqrt(parsertemp170148)),sum(^(p_CG,int13)))) +::STMT +FLOAT:num_records,i +LITERAL_FLOAT:1.0 ++(*(num_records,-(i,1.0)),1.0) +::STMT +MATRIX:R +FLOAT:int595,int353 +INT:parsertemp503361,int790 ++(R,diag(rand(parsertemp503361,int790,int595,int353))) +::STMT +MATRIX:X +LITERAL_FLOAT:0.0 +t(!=(X,0.0)) +::STMT +FLOAT:offset_x +round(offset_x) +::STMT +MATRIX:X,tS +FLOAT:l +colSums(==(%*%(X,tS),l)) +::STMT +FLOAT:C,Hf,Wf +*(*(C,Hf),Wf) +::STMT +MATRIX:f,parsertemp472177,parsertemp472179 +-(%*%(f,parsertemp472177),t(parsertemp472179)) +::STMT +MATRIX:obj,objnew +-(cast.FLOAT(objnew),cast.FLOAT(obj)) +::STMT +MATRIX:lambda,g,beta +t(+(g,*(lambda,beta))) +::STMT +MATRIX:WM,CVars,CFreqs +FLOAT:int548 +/(sum(*(-(CFreqs,int548),CVars)),-(sum(WM),nrow(CFreqs))) +::STMT +MATRIX:X,W1,b1 ++(%*%(W1,t(X)),b1) +::STMT +MATRIX:linear_terms +FLOAT:var_power +LITERAL_FLOAT:1.0 +^(exp(linear_terms),-(1.0,var_power)) +::STMT +FLOAT:current_hash_value +LITERAL_FLOAT:1.0,9.0 +-(9.0,+(current_hash_value,1.0)) +::STMT +MATRIX:z +FLOAT:trust_delta_sq,pp_CG +*(pp_CG,-(*(cast.FLOAT(z),cast.FLOAT(z)),trust_delta_sq)) +::STMT +MATRIX:W +LITERAL_FLOAT:1.0,6.0 +*(*(6.0,sum(round(W))),-(sum(round(W)),1.0)) +::STMT +MATRIX:S +FLOAT:delta2 +LITERAL_FLOAT:2.0 +-(delta2,sum(^(S,2.0))) +::STMT +MATRIX:X +FLOAT:parsertemp72162,M +LITERAL_FLOAT:1.0 +-(*(+(parsertemp72162,1.0),M),ncol(X)) +::STMT +FLOAT:s,g,int170,num_groups +LITERAL_FLOAT:1.0,7.0 ++(*(*(-(s,int170),num_groups),7.0),*(-(g,1.0),7.0)) +::STMT +MATRIX:r,Hd +FLOAT:c +%*%(t(+(r,*(c,Hd))),+(r,*(c,Hd))) +::STMT +LITERAL_FLOAT:0.0,1.0,0.282842712474619 +INT:int945,int604 +*(rand(int945,int604,0.0,1.0),0.282842712474619) +::STMT +LITERAL_FLOAT:0.08333333333333333 +0.08333333333333333 +::STMT +MATRIX:resp,mean,X,weight +/(*(mean,%*%(t(resp),X)),t(weight)) +::STMT +LITERAL_FLOAT:-1.0E30 +INT:int924,M +rand(M,int924,-1.0E30,-1.0E30) +::STMT +FLOAT:x1,x2 +LITERAL_FLOAT:2.0 +^(-(x1,x2),2.0) +::STMT +MATRIX:r,scale_X,shift_X,y,parsertemp116004 +LITERAL_FLOAT:0.0 +-(0.0,+(*(scale_X,%*%(parsertemp116004,y)),*(cast.FLOAT(r),shift_X))) +::STMT +MATRIX:R,dssp +FLOAT:4_n,4_alpha +LITERAL_FLOAT:1.0 +*(-(1.0,4_alpha),-(/(4_n,+(R,dssp)),1.0)) +::STMT +LITERAL_FLOAT:0.6666666666666666 +0.6666666666666666 +::STMT +MATRIX:xs +FLOAT:254_x +LITERAL_FLOAT:1.0,100.0 ++(-(100.0,sum(>=(xs,254_x))),1.0) +::STMT +MATRIX:parsertemp109934 +LITERAL_FLOAT:1.0,42.0 ++(*(parsertemp109934,42.0),1.0) +::STMT +MATRIX:r +FLOAT:int435,tolerance +LITERAL_FLOAT:2.0 +sqrt(*(sum(^(r,int435)),^(tolerance,2.0))) +::STMT +MATRIX:Y +-(length(Y),sum(Y)) +::STMT +MATRIX:R,parsertemp40226,parsertemp40220 +FLOAT:eAvg +/(/(+(R,rowSums(parsertemp40226)),-(R,rowSums(parsertemp40220))),eAvg) +::STMT +MATRIX:P,Y,parsertemp221025,Z,ZERODIAG +FLOAT:int525 +LITERAL_FLOAT:1.0,4.0 +*(-(*(P,4.0),/(*(Z,ZERODIAG),sum(Z))),*(/(1.0,+(Y,int525)),+(diag(parsertemp221025),1.0))) +::STMT +MATRIX:r,parsertemp1945 +FLOAT:norm_r2 +/(sum(*(+(r,parsertemp1945),+(r,parsertemp1945))),norm_r2) +::STMT +MATRIX:p,q,lambda ++(q,*(lambda,p)) +::STMT +MATRIX:r,g,z +*(z,+(r,g)) +::STMT +MATRIX:parsertemp72333 +FLOAT:int203,rows +/(colSums(rowSums(^(parsertemp72333,int203))),rows) +::STMT +FLOAT:parsertemp40813,m2,m3 +LITERAL_FLOAT:3.0 +/(m3,^(sqrt(*(parsertemp40813,m2)),3.0)) +::STMT +MATRIX:s,w +FLOAT:lambda +*(lambda,sum(*(w,s))) +::STMT +MATRIX:posSampleVariances,negSampleVariances +FLOAT:int771,parsertemp31048,int464,parsertemp31047,int798,int713,parsertemp31053,parsertemp31052 +LITERAL_FLOAT:2.0 +/(^(+(/(posSampleVariances,int464),/(negSampleVariances,int713)),2.0),+(/(^(posSampleVariances,int771),*(parsertemp31047,parsertemp31048)),/(^(negSampleVariances,int798),*(parsertemp31052,parsertemp31053)))) +::STMT +MATRIX:y_hat,b,parsertemp31748 +sum(*(-(-(b,parsertemp31748),y_hat),-(-(b,parsertemp31748),y_hat))) +::STMT +FLOAT:parsertemp40813,m2,m4 +LITERAL_FLOAT:4.0 +/(m4,^(sqrt(*(parsertemp40813,m2)),4.0)) +::STMT +MATRIX:I,y2 +LITERAL_FLOAT:2.0 +sum(^(/(%*%(I,y2),sum(I)),2.0)) +::STMT +MATRIX:parsertemp500607,X,y,parsertemp500610,wnew +%*%(t(-(%*%(X,wnew),y)),-(%*%(X,*(parsertemp500607,parsertemp500610)),y)) +::STMT +LITERAL_FLOAT:0.1093262138242341 +0.1093262138242341 +::STMT +MATRIX:linear_terms +FLOAT:int267 +LITERAL_FLOAT:1.0,2.0 +-(1.0,-(*(2.0,>=(linear_terms,int267)),1.0)) +::STMT +MATRIX:parsertemp27546 +FLOAT:labelCorrection +t(-(parsertemp27546,labelCorrection)) +::STMT +MATRIX:parsertemp16959,id +-(==(id,t(id)),diag(diag(==(id,parsertemp16959)))) +::STMT +MATRIX:A,scale_X,shift_X,X ++(%*%(diag(scale_X),%*%(t(X),X)),%*%(shift_X,A)) +::STMT +FLOAT:191_beta2,191_t +LITERAL_FLOAT:1.0 +-(1.0,^(191_beta2,+(191_t,1.0))) +::STMT +FLOAT:parsertemp557354,weight,parsertemp557358,prob_true,prob_false +LITERAL_FLOAT:-1.0,0.6931471805599453 +*(*(-1.0,weight),+(/(*(prob_true,parsertemp557354),0.6931471805599453),/(*(prob_false,parsertemp557358),0.6931471805599453))) +::STMT +MATRIX:parsertemp31188,parsertemp31186 +FLOAT:int441 +LITERAL_FLOAT:6999.0,7000.0 +/(/(-(colSums(parsertemp31186),*(int441,parsertemp31188)),6999.0),7000.0) +::STMT +FLOAT:parsertemp40936,parsertemp40941,int194 +LITERAL_FLOAT:2.0,3.0,4.0,5.0,2001.0 +/(*(*(4.0,-(parsertemp40941,int194)),^(sqrt(parsertemp40936),2.0)),*(+(2001.0,5.0),-(2001.0,3.0))) +::STMT +MATRIX:prevTK2,X2 +==(%*%(X2,t(prevTK2)),t(rowSums(prevTK2))) +::STMT +MATRIX:sv,Xd +FLOAT:dd ++(dd,sum(*(*(Xd,sv),Xd))) +::STMT +FLOAT:num_records +LITERAL_FLOAT:1.0,3840.0 +-(1.0,/(3840.0,num_records)) +::STMT +MATRIX:parsertemp389217,parsertemp389216 +FLOAT:n +LITERAL_FLOAT:1.0 +sqrt(/(*(-(parsertemp389216,parsertemp389217),n),-(n,1.0))) +::STMT +MATRIX:parsertemp171346,parsertemp171344,linear_terms,the_exp +FLOAT:int422,int41 +LITERAL_FLOAT:1.0,1.0E7 +/(*(-(1.0,==(parsertemp171346,int422)),-(1.0,exp(parsertemp171344))),+(exp(linear_terms),==(+(int41,the_exp),1.0E7))) +::STMT +FLOAT:parsertemp166531 +LITERAL_FLOAT:2.0,10.0 ++(2.0,*(10.0,parsertemp166531)) +::STMT +FLOAT:parsertemp40837,parsertemp40832,int270 +LITERAL_FLOAT:2.0,3.0,4.0,5.0,2000.0 +/(*(*(4.0,-(parsertemp40837,int270)),^(sqrt(parsertemp40832),2.0)),*(+(2000.0,5.0),-(2000.0,3.0))) +::STMT +FLOAT:num_strata,num_groups +LITERAL_FLOAT:7.0 +*(*(num_groups,num_strata),7.0) +::STMT +FLOAT:run_index +LITERAL_FLOAT:1.0,2.0 +*(2.0,-(run_index,1.0)) +::STMT +MATRIX:output,mask +LITERAL_FLOAT:0.0,1.0 +&(==(output,0.0),==(mask,1.0)) +::STMT +MATRIX:p,G +LITERAL_FLOAT:0.85 +*(0.85,%*%(G,p)) +::STMT +FLOAT:parsertemp31330 +LITERAL_FLOAT:1.0,10000.0 +*(-(10000.0,1.0),/(*(parsertemp31330,10000.0),-(10000.0,1.0))) +::STMT +MATRIX:parsertemp220867,parsertemp220866,Hdiff,parsertemp220871,parsertemp220872,beta,betamin +FLOAT:int591 +LITERAL_FLOAT:2.0 ++(+(*(*(parsertemp220866,parsertemp220867),beta),/(*(parsertemp220871,parsertemp220872),2.0)),/(*(<(Hdiff,int591),+(beta,betamin)),2.0)) +::STMT +MATRIX:parsertemp382671,X +FLOAT:int751,int347 +LITERAL_FLOAT:0.5 +*(0.5,sum(*(!=(X,int347),^(parsertemp382671,int751)))) +::STMT +MATRIX:is_LT_infinite,Y_prob +LITERAL_FLOAT:1.0 ++(*(/(Y_prob,rowSums(Y_prob)),-(1.0,rowSums(is_LT_infinite))),is_LT_infinite) +::STMT +MATRIX:M2 +LITERAL_FLOAT:0.0 +&(!(!=(M2,0.0)),!=(M2,0.0)) +::STMT +FLOAT:parsertemp41040,int116,parsertemp41045 +LITERAL_FLOAT:2.0,3.0,4.0,5.0,2003.0 +/(*(*(4.0,-(parsertemp41045,int116)),^(sqrt(parsertemp41040),2.0)),*(+(2003.0,5.0),-(2003.0,3.0))) +::STMT +MATRIX:S,col_nonzeros,parsertemp382922,parsertemp382920 +sum(*(S,+(t(parsertemp382920),*(parsertemp382922,col_nonzeros)))) +::STMT +MATRIX:r,s,grad +-(cast.FLOAT(%*%(t(s),grad)),cast.FLOAT(%*%(t(s),r))) +::STMT +FLOAT:s,num_groups +LITERAL_FLOAT:1.0 ++(*(-(s,1.0),-(num_groups,1.0)),1.0) +::STMT +MATRIX:parsertemp1904,y +LITERAL_FLOAT:0.0,2.0 +sum(^(-(0.0,%*%(parsertemp1904,y)),2.0)) +::STMT +MATRIX:A +*(cast.FLOAT(A),cast.FLOAT(A)) +::STMT +MATRIX:parsertemp42200,F +LITERAL_FLOAT:1.0,2.0 ++(-(parsertemp42200,/(rowSums(F),2.0)),/(1.0,2.0)) +::STMT +MATRIX:tmp +FLOAT:N +LITERAL_FLOAT:1.0 +/(tmp,-(N,1.0)) +::STMT +MATRIX:C,Xm,parsertemp265702 +-(sum(%*%(%*%(Xm,parsertemp265702),t(C))),sum(Xm)) +::STMT +MATRIX:sig,parsertemp181037 +FLOAT:window_size,q +/(-(q,*(window_size,cast.FLOAT(parsertemp181037))),*(window_size,cast.FLOAT(*(sig,sig)))) +::STMT +MATRIX:parsertemp163760 +FLOAT:bin_length +/(rowSums(parsertemp163760),bin_length) +::STMT +MATRIX:X +FLOAT:value +!(<(X,value)) +::STMT +MATRIX:cumHistMul,offset,parsertemp132494,histMul,outBucket +-(offset,%*%(==(outBucket,t(parsertemp132494)),-(cumHistMul,histMul))) +::STMT +MATRIX:X +LITERAL_FLOAT:1.0 +/(nrow(X),-(nrow(X),1.0)) +::STMT +MATRIX:y_hat,A,B +-(-(B,%*%(A,y_hat)),y_hat) +::STMT +FLOAT:int395,Hin,Win +LITERAL_FLOAT:2.0,64.0 +*(*(64.0,/(/(Hin,int395),2.0)),/(/(Win,2.0),2.0)) +::STMT +MATRIX:R +FLOAT:s,i8 +-(nrow(R),*(s,i8)) +::STMT +MATRIX:X +LITERAL_FLOAT:1.0,1.0E-6 +/(*(1.0E-6,cast.FLOAT(%*%(X,X))),1.0) +::STMT +MATRIX:_sbcvar12 +LITERAL_FLOAT:999.0 +/(_sbcvar12,999.0) +::STMT +MATRIX:std,rad,dtd +/(-(rad,cast.FLOAT(std)),cast.FLOAT(dtd)) +::STMT +MATRIX:parsertemp79022 +LITERAL_FLOAT:1270.0 +/(parsertemp79022,1270.0) +::STMT +MATRIX:p,V +FLOAT:eps +%*%(t(p),+(%*%(t(V),%*%(V,p)),*(eps,p))) +::STMT +MATRIX:prec,X,mu +rowSums(*(-(%*%(X,prec),%*%(mu,prec)),-(%*%(X,prec),%*%(mu,prec)))) +::STMT +MATRIX:w +FLOAT:tau +*(tau,sum(abs(w))) +::STMT +MATRIX:p_CG +FLOAT:parsertemp170148,int652,z,int229 +LITERAL_FLOAT:0.5 +*(0.5,/(+(*(z,int652),sqrt(parsertemp170148)),sum(^(p_CG,int229)))) +::STMT +FLOAT:window_size,k +LITERAL_FLOAT:1.0 +-(+(k,window_size),1.0) +::STMT +FLOAT:m2,mu +LITERAL_FLOAT:1.0004995004995005 +/(sqrt(*(1.0004995004995005,m2)),mu) +::STMT +MATRIX:tmp,X,Y,out +-(%*%(t(X),*(out,Y)),tmp) +::STMT +MATRIX:_sbcvar92,parsertemp27718,parsertemp27720 +FLOAT:220_W,float581 +LITERAL_FLOAT:2.0 +^(-(_sbcvar92,+(*(parsertemp27720,float581),/(parsertemp27718,220_W))),2.0) +::STMT +MATRIX:f +LITERAL_FLOAT:1.0,2.0 +-(1.0,rowSums(^(f,2.0))) +::STMT +MATRIX:Xm,Z,parsertemp265713 +/(-(sum(%*%(Z,parsertemp265713)),sum(Xm)),sum(Xm)) +::STMT +MATRIX:B,X,y +LITERAL_FLOAT:2.0 +sum(^(-(y,%*%(X,B)),2.0)) +::STMT +MATRIX:classFeatureCounts +LITERAL_FLOAT:1.0,750.0 +INT:int354,int611 +%*%(+(rowSums(classFeatureCounts),*(750.0,1.0)),rand(int354,int611,1.0,1.0)) +::STMT +MATRIX:curr_prediction +LITERAL_FLOAT:1.0 +sum(*(curr_prediction,-(1.0,curr_prediction))) +::STMT +LITERAL_FLOAT:1.001001001001001 +1.001001001001001 +::STMT +MATRIX:scale_X,X,y +LITERAL_FLOAT:0.0 +*(scale_X,%*%(-(0.0,t(X)),y)) +::STMT +MATRIX:vectors,pq_result +LITERAL_FLOAT:2.0 +rowSums(^(-(vectors,pq_result),2.0)) +::STMT +MATRIX:R,parsertemp72406 +LITERAL_FLOAT:2.0 +^(-(%*%(t(R),R),diag(parsertemp72406)),2.0) +::STMT +MATRIX:A,scale_X,shift_X ++(%*%(diag(scale_X),A),%*%(shift_X,A)) +::STMT +FLOAT:i7 +LITERAL_FLOAT:1.0 ++(1.0,i7) +::STMT +MATRIX:out2,parsertemp146940,184_dtemp,outd1,W3 +LITERAL_FLOAT:0.0 +%*%(t(outd1),*(>(out2,0.0),%*%(-(184_dtemp,parsertemp146940),t(W3)))) +::STMT +MATRIX:parsertemp42200,_sbcvar330 +LITERAL_FLOAT:2.0,0.5 ++(-(parsertemp42200,/(rowSums(_sbcvar330),2.0)),0.5) +::STMT +LITERAL_FLOAT:0.07261134713572442 +0.07261134713572442 +::STMT +FLOAT:int671,int784,parsertemp86,parsertemp87,int369,wt +sqrt(/(*(*(int369,wt),-(wt,int784)),*(*(parsertemp86,parsertemp87),+(wt,int671)))) +::STMT +MATRIX:U,V,X,parsertemp382840 +LITERAL_FLOAT:0.0 +%*%(*(!=(X,0.0),-(%*%(U,parsertemp382840),X)),V) +::STMT +MATRIX:P,D,beta +LITERAL_FLOAT:1.0E-12 +*(beta,/(rowSums(*(P,D)),+(rowSums(P),1.0E-12))) +::STMT +MATRIX:B,_sbcvar887 ++(%*%(_sbcvar887,B),cast.FLOAT(B)) +::STMT +MATRIX:R2,R1 +LITERAL_FLOAT:1.0E-6 +sum(<(abs(-(R1,R2)),1.0E-6)) +::STMT +MATRIX:resp,X +LITERAL_FLOAT:2.0,2.22E-16 +/(%*%(t(resp),^(X,2.0)),t(+(colSums(resp),2.22E-16))) +::STMT +MATRIX:linear_terms +FLOAT:int323 +LITERAL_FLOAT:3.141592653589793,1.0,2.0 +^(*(+(1.0,^(linear_terms,int323)),3.141592653589793),2.0) +::STMT +MATRIX:parsertemp409054,ctab +LITERAL_FLOAT:0.6 +>(/(parsertemp409054,rowSums(ctab)),0.6) +::STMT +MATRIX:parsertemp1654,A,scale_X,shift_X +%*%(diag(scale_X),t(+(%*%(parsertemp1654,A),%*%(shift_X,A)))) +::STMT +MATRIX:131_s,parsertemp115723 +FLOAT:eAvg +LITERAL_FLOAT:1.0,0.95 +*(0.95,-(/(/(parsertemp115723,131_s),eAvg),1.0)) +::STMT +MATRIX:W +LITERAL_FLOAT:6.0 +*(6.0,sum(round(W))) +::STMT +MATRIX:minD,D +t(/(<=(D,minD),rowSums(<=(D,minD)))) +::STMT +MATRIX:termination_bitmap,parsertemp222665 +FLOAT:num_successful_runs +/(sum(*(parsertemp222665,termination_bitmap)),num_successful_runs) +::STMT +MATRIX:tpr,fpr +LITERAL_FLOAT:2.0 +sum(/(*(-(fpr,fpr),+(tpr,tpr)),2.0)) +::STMT +MATRIX:d,parsertemp43998 +FLOAT:int458 +cast.FLOAT(%*%(t(d),+(d,*(int458,parsertemp43998)))) +::STMT +FLOAT:i8 +LITERAL_FLOAT:1.0,24.0 ++(1.0,*(24.0,i8)) +::STMT +MATRIX:Y,parsertemp2798,Xw +FLOAT:int484,int328 +LITERAL_FLOAT:0.0,1.0 +*(*(>(-(int328,parsertemp2798),0.0),-(1.0,*(Y,Xw))),*(>(-(int484,parsertemp2798),0.0),-(1.0,*(Y,Xw)))) +::STMT +FLOAT:m2 +LITERAL_FLOAT:1.0005002501250626,44.721359549995796 +/(sqrt(*(1.0005002501250626,m2)),44.721359549995796) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,133.0 ++(*(133.0,-(i,1.0)),1.0) +::STMT +MATRIX:X2,85_s +LITERAL_FLOAT:1.0 +-(*(/(1.0,85_s),nrow(X2)),1.0) +::STMT +FLOAT:df +LITERAL_FLOAT:4.890349128221754 +*(df,4.890349128221754) +::STMT +FLOAT:P,pIn,qIn,i8 ++(+(+(P,pIn),qIn),i8) +::STMT +FLOAT:i,k +LITERAL_FLOAT:2.0 +-(+(i,k),2.0) +::STMT +MATRIX:Y_row_norm,parsertemp16881 +FLOAT:epsilon +t(+(sqrt(rowSums(parsertemp16881)),*(<(Y_row_norm,epsilon),epsilon))) +::STMT +MATRIX:parsertemp387154,y +LITERAL_FLOAT:2.0 +cast.MATRIX(sum(^(-(y,parsertemp387154),2.0))) +::STMT +FLOAT:o_init,o +LITERAL_FLOAT:2.0 +-(*(2.0,o_init),*(2.0,o)) +::STMT +MATRIX:parsertemp149248,V,X,P_1K +-(*(P_1K,%*%(X,V)),*(P_1K,rowSums(*(P_1K,parsertemp149248)))) +::STMT +MATRIX:xs +FLOAT:254_x +LITERAL_FLOAT:100.0 +-(100.0,sum(>=(xs,254_x))) +::STMT +MATRIX:s,d,parsertemp44021 +FLOAT:delta2 +*(cast.FLOAT(%*%(t(d),d)),-(delta2,cast.FLOAT(%*%(parsertemp44021,s)))) +::STMT +MATRIX:w,ones_ns +*(ones_ns,cast.FLOAT(w)) +::STMT +MATRIX:parsertemp1511,X +FLOAT:int967,n +LITERAL_FLOAT:2.0 +-(t(colSums(^(X,int967))),*(n,^(/(parsertemp1511,n),2.0))) +::STMT +MATRIX:r,s,grad +-(%*%(t(s),grad),%*%(t(s),r)) +::STMT +MATRIX:A +LITERAL_FLOAT:1.0 ++(/(1.0,cast.FLOAT(A)),/(1.0,cast.FLOAT(A))) +::STMT +MATRIX:p,w,parsertemp1940 +FLOAT:norm_r2 ++(w,*(/(norm_r2,cast.FLOAT(parsertemp1940)),p)) +::STMT +LITERAL_FLOAT:1.0,2.0 +-(2.0,1.0) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:0.0 +sum(<(linear_terms,0.0)) +::STMT +LITERAL_FLOAT:0.0,1.0 +INT:parsertemp557199,int866 +==(diag(rand(parsertemp557199,int866,1.0,1.0)),0.0) +::STMT +MATRIX:X_row_norm,parsertemp16875,parsertemp16884,parsertemp16882 +FLOAT:epsilon +%*%(+(sqrt(rowSums(parsertemp16875)),*(<(X_row_norm,epsilon),epsilon)),t(+(sqrt(parsertemp16882),*(parsertemp16884,epsilon)))) +::STMT +MATRIX:parsertemp437548,pred,parsertemp437666 +colSums(==(*(parsertemp437666,t(parsertemp437548)),pred)) +::STMT +FLOAT:i,n +LITERAL_FLOAT:1.0,1048.0 +-(n,-(+(i,1048.0),1.0)) +::STMT +LITERAL_FLOAT:2003.0 +sqrt(2003.0) +::STMT +LITERAL_FLOAT:0.0,1.0,0.05 +INT:int670,int414 +*(rand(int670,int414,0.0,1.0),0.05) +::STMT +MATRIX:selCols2 +FLOAT:n +-(n,sum(selCols2)) +::STMT +MATRIX:ytest,yhat +FLOAT:int551,mean_y_test,int687 +LITERAL_FLOAT:2.0 +/(sum(^(-(ytest,yhat),2.0)),-(sum(^(ytest,int551)),*(nrow(ytest),^(mean_y_test,int687)))) +::STMT +MATRIX:B +FLOAT:M +*(ncol(B),M) +::STMT +MATRIX:s,w +cast.FLOAT(%*%(t(+(w,s)),+(w,s))) +::STMT +FLOAT:int99,arch_coef,var_coef,int481,a0 +INT:int876,int329 +rand(int876,int329,/(a0,-(-(int99,arch_coef),var_coef)),/(a0,-(-(int481,arch_coef),var_coef))) +::STMT +MATRIX:X +FLOAT:int675 +LITERAL_FLOAT:0.0 +sum(!=(rowSums(!=(X,int675)),0.0)) +::STMT +FLOAT:i,n +LITERAL_FLOAT:1.0,1024.0 +-(n,-(+(i,1024.0),1.0)) +::STMT +MATRIX:Y_counts +FLOAT:num_features +-(sum(Y_counts),num_features) +::STMT +MATRIX:ss,se +FLOAT:130_eAvg +/(/(se,ss),130_eAvg) +::STMT +MATRIX:adjacency +LITERAL_FLOAT:0.0 +>(rowSums(adjacency),0.0) +::STMT +MATRIX:parsertemp477718,parsertemp477715,X,Y +FLOAT:x +LITERAL_FLOAT:1.0 ++(*(-(1.0,/(parsertemp477715,parsertemp477718)),Y),*(/(-(x,X),-(X,X)),Y)) +::STMT +MATRIX:parsertemp11509 +LITERAL_FLOAT:1.0,2.0 ++(1.0,*(2.0,parsertemp11509)) +::STMT +MATRIX:finite_linear_terms,the_exp +FLOAT:int960 +LITERAL_FLOAT:1.0,2.0,1.0E7 +*(*(==(+(int960,the_exp),1.0E7),exp(finite_linear_terms)),-(1.0,/(exp(finite_linear_terms),2.0))) +::STMT +MATRIX:p,V +LITERAL_FLOAT:1.0E-8 ++(%*%(t(V),%*%(V,p)),*(1.0E-8,p)) +::STMT +MATRIX:linear_terms,Y +FLOAT:parsertemp171225,link_power,float353 +LITERAL_FLOAT:1.0 +*(^(linear_terms,-(/(parsertemp171225,link_power),1.0)),-(Y,^(linear_terms,/(float353,link_power)))) +::STMT +MATRIX:s,d +FLOAT:norm_r2,alpha_deno +t(+(s,*(/(norm_r2,alpha_deno),d))) +::STMT +MATRIX:X +FLOAT:x +-(nrow(X),sum(>=(X,x))) +::STMT +MATRIX:F,parsertemp42207 +LITERAL_FLOAT:2.0 +-(parsertemp42207,/(t(colSums(F)),2.0)) +::STMT +MATRIX:parsertemp389218 +FLOAT:int263,n +LITERAL_FLOAT:1.0E-17 ++(sqrt(/(*(parsertemp389218,n),-(n,int263))),1.0E-17) +::STMT +FLOAT:parsertemp170472,parsertemp170473,log_odds,learning_rate +LITERAL_FLOAT:1.0,2.7182818284 +/(^(2.7182818284,+(log_odds,*(learning_rate,parsertemp170472))),+(1.0,^(2.7182818284,+(log_odds,parsertemp170473)))) +::STMT +MATRIX:y_corr +LITERAL_FLOAT:0.0,1.0,0.5 +*(0.5,+(<=(y_corr,0.0),>=(y_corr,1.0))) +::STMT +FLOAT:parsertemp169814 +LITERAL_FLOAT:2.302585092994046,4.0 +exp(*(2.302585092994046,-(4.0,round(parsertemp169814)))) +::STMT +FLOAT:s +LITERAL_FLOAT:81.0,-1.0,3.0 +*(81.0,^(3.0,*(s,-1.0))) +::STMT +MATRIX:F +/(%*%(rowSums(F),colSums(F)),sum(F)) +::STMT +MATRIX:Yhat_prime,E,W4 +%*%(*(E,Yhat_prime),W4) +::STMT +MATRIX:d,parsertemp43996,parsertemp43997 +LITERAL_FLOAT:2.0 +%*%(t(d),+(d,*(2.0,%*%(parsertemp43996,parsertemp43997)))) +::STMT +LITERAL_FLOAT:1.0,10000.0 +/(10000.0,-(10000.0,1.0)) +::STMT +MATRIX:X,Y +LITERAL_FLOAT:2.0 +/(+(abs(X),abs(Y)),2.0) +::STMT +FLOAT:end_stepsize,k,kmax,start_stepsize +LITERAL_FLOAT:1.0 ++(*(-(1.0,/(k,kmax)),start_stepsize),*(/(k,kmax),end_stepsize)) +::STMT +FLOAT:int764,tau,int900 +INT:int902,m +*(tau,sum(abs(rand(m,int902,int764,int900)))) +::STMT +MATRIX:parsertemp13627,43_E +FLOAT:int232,43_q +LITERAL_FLOAT:1000.0 +sqrt(/(sum(/(parsertemp13627,43_E)),*(1000.0,-(43_q,int232)))) +::STMT +MATRIX:parsertemp31112,parsertemp31114,parsertemp31105,parsertemp31107 +LITERAL_FLOAT:1499.0,1999.0,1500.0,2000.0 ++(/(/(-(parsertemp31105,parsertemp31107),1999.0),2000.0),/(/(-(parsertemp31112,parsertemp31114),1499.0),1500.0)) +::STMT +MATRIX:l1,l2 +cast.FLOAT(<(l1,l2)) +::STMT +MATRIX:D,ZERODIAG,beta +FLOAT:int333 +*(exp(*(*(D,int333),beta)),ZERODIAG) +::STMT +MATRIX:y_hat,A,B +LITERAL_FLOAT:2.0 +^(-(-(B,%*%(A,y_hat)),y_hat),2.0) +::STMT +MATRIX:missing_indicator_mat +FLOAT:global_mean +*(missing_indicator_mat,global_mean) +::STMT +MATRIX:surv,se_surv +FLOAT:parsertemp538723 +*(surv,exp(/(*(parsertemp538723,se_surv),surv))) +::STMT +MATRIX:parsertemp31026,parsertemp31033 +FLOAT:parsertemp31034,parsertemp31027 +LITERAL_FLOAT:2.0,150.0,100.0 +^(+(/(/(parsertemp31026,parsertemp31027),100.0),/(/(parsertemp31033,parsertemp31034),150.0)),2.0) +::STMT +MATRIX:Y +LITERAL_FLOAT:1.0 ++(-(ncol(Y),1.0),1.0) +::STMT +MATRIX:parsertemp396410,parsertemp396407,W3_rand +LITERAL_FLOAT:0.16823164622761327 +t(%*%(*(0.16823164622761327,W3_rand),t(/(parsertemp396407,parsertemp396410)))) +::STMT +MATRIX:g_Y,scale_X,X +LITERAL_FLOAT:0.0 +*(cast.FLOAT(diag(scale_X)),%*%(-(0.0,t(X)),g_Y)) +::STMT +MATRIX:parsertemp429917,parsertemp429915 +LITERAL_FLOAT:0.0,1.0,299.0 +-(1.0,<=(/(-(parsertemp429915,parsertemp429917),299.0),0.0)) +::STMT +MATRIX:r,g,z +sum(*(z,+(r,g))) +::STMT +FLOAT:delta +LITERAL_FLOAT:0.25 +*(0.25,delta) +::STMT +FLOAT:arch_coef,int306,var_coef,a0 +sqrt(/(a0,-(-(int306,arch_coef),var_coef))) +::STMT +MATRIX:parsertemp149323,LT,Y +LITERAL_FLOAT:-1.0 +*(sum(*(Y,-(LT,parsertemp149323))),-1.0) +::STMT +MATRIX:p_CG,z +LITERAL_FLOAT:-1.0 +*(*(cast.FLOAT(z),sum(p_CG)),-1.0) +::STMT +MATRIX:parsertemp24101 +FLOAT:num_bins,float936 +LITERAL_FLOAT:1.0 +>(+(round(-(parsertemp24101,float936)),1.0),num_bins) +::STMT +MATRIX:parsertemp459193,2700_dX,2703_X,2703_W +FLOAT:lr +LITERAL_FLOAT:5.0E-4 +*(lr,+(%*%(t(2703_X),*(parsertemp459193,2700_dX)),*(5.0E-4,2703_W))) +::STMT +MATRIX:Y +FLOAT:parsertemp185166 +>(-(cast.MATRIX(max(Y)),parsertemp185166),-(parsertemp185166,min(Y))) +::STMT +FLOAT:m2 +LITERAL_FLOAT:1.0,2001.0 +*(/(2001.0,-(2001.0,1.0)),m2) +::STMT +MATRIX:w,parsertemp43626 +FLOAT:int200,float862,float690 +LITERAL_FLOAT:2.0,0.5 +INT:int774,int235 ++(*(0.5,%*%(t(w),rand(int774,int235,float690,float862))),*(2.0,sum(*(parsertemp43626,int200)))) +::STMT +MATRIX:R,parsertemp40216,parsertemp40225 +/(+(R,rowSums(*(parsertemp40216,parsertemp40225))),R) +::STMT +FLOAT:max_depth +LITERAL_FLOAT:1.0,2.0 +-(^(2.0,max_depth),1.0) +::STMT +LITERAL_FLOAT:1.0,2001.0 +/(2001.0,-(2001.0,1.0)) +::STMT +FLOAT:int154,i +LITERAL_FLOAT:1.0,100.0 ++(*(*(-(i,int154),100.0),100.0),1.0) +::STMT +MATRIX:131_s +FLOAT:n2,int815 +LITERAL_FLOAT:0.050000000000000044,1.0 +*(0.050000000000000044,-(*(/(int815,131_s),n2),1.0)) +::STMT +MATRIX:b,E,X,sb +%*%(colSums(*(X,E)),+(b,sb)) +::STMT +MATRIX:p,r,Z +FLOAT:norm_r2,parsertemp503396 +LITERAL_FLOAT:-1.0 +*(+(r,*(/(norm_r2,parsertemp503396),%*%(Z,p))),-1.0) +::STMT +FLOAT:obj,obj_new,gs +LITERAL_FLOAT:-0.5 +/(*(-0.5,gs),-(-(obj_new,obj),gs)) +::STMT +FLOAT:step +LITERAL_FLOAT:0.85 +*(step,0.85) +::STMT +MATRIX:w_X,z_LS,X +/(nrow(X),sum(*(w_X,z_LS))) +::STMT +MATRIX:parsertemp285531,z,parsertemp285533 +FLOAT:pp,sq_root_d,zq,parsertemp285523,parsertemp285538 +LITERAL_FLOAT:0.5 ++(*(0.5,sum(*(z,parsertemp285533))),*(+(+(parsertemp285538,zq),sum(parsertemp285531)),/(+(parsertemp285523,sq_root_d),pp))) +::STMT +FLOAT:191_beta2,191_t +LITERAL_FLOAT:1.0 +^(191_beta2,+(191_t,1.0)) +::STMT +MATRIX:2814_K +LITERAL_FLOAT:0.0 +cast.FLOAT(-(0.0,2814_K)) +::STMT +MATRIX:dw,history +FLOAT:sigma,float741,alpha +-(max(history),*(*(*(float741,sigma),alpha),sum(*(dw,dw)))) +::STMT +MATRIX:X,parsertemp222929 ++(X,*(cast.FLOAT(parsertemp222929),-(X,X))) +::STMT +MATRIX:dout1 +FLOAT:192_beta1 +LITERAL_FLOAT:1.0 +*(-(1.0,192_beta1),colSums(dout1)) +::STMT +FLOAT:lambda,beta +LITERAL_FLOAT:0.0,2.0 +sqrt(^(+(0.0,*(lambda,beta)),2.0)) +::STMT +MATRIX:C,parsertemp11064 +LITERAL_FLOAT:10000.0,100.0 +*(/(sum(==(parsertemp11064,C)),10000.0),100.0) +::STMT +FLOAT:N +LITERAL_FLOAT:1.0 +/(N,-(N,1.0)) +::STMT +MATRIX:residual_matrix +LITERAL_FLOAT:2.0 +^(sum(residual_matrix),2.0) +::STMT +MATRIX:E,F +LITERAL_FLOAT:0.001 +sum(<(-(E,F),0.001)) +::STMT +MATRIX:parsertemp170505 +LITERAL_FLOAT:-1.0,2.0 +sum(^(*(t(parsertemp170505),-1.0),2.0)) +::STMT +MATRIX:parsertemp1518,parsertemp1516,parsertemp1514 +FLOAT:parsertemp1519,n +LITERAL_FLOAT:0.0,1.0 +*(/(-(t(parsertemp1514),*(n,parsertemp1516)),-(n,1.0)),-(1.0,<=(/(parsertemp1518,parsertemp1519),0.0))) +::STMT +MATRIX:resp,X,parsertemp437188 +FLOAT:float191 +*(/(%*%(t(resp),X),t(+(parsertemp437188,float191))),%*%(t(resp),X)) +::STMT +LITERAL_FLOAT:225.0 +INT:int873,int730 +rand(int873,int730,225.0,225.0) +::STMT +MATRIX:X_batch,parsertemp389606,2364_2361_Y,parsertemp389586 +FLOAT:int440 +LITERAL_FLOAT:1.0 +%*%(t(*(-(2364_2361_Y,X_batch),-(int440,parsertemp389606))),/(-(exp(parsertemp389586),1.0),+(exp(parsertemp389586),1.0))) +::STMT +LITERAL_FLOAT:1.0 ++(+(1.0,1.0),1.0) +::STMT +MATRIX:2846_Q,X +FLOAT:int123,int579 +LITERAL_FLOAT:2.0 +-(+(rowSums(^(X,int123)),sum(^(2846_Q,int579))),*(2.0,%*%(X,t(2846_Q)))) +::STMT +MATRIX:s,w +LITERAL_FLOAT:0.5 +*(0.5,%*%(t(+(w,s)),+(w,s))) +::STMT +FLOAT:FN,FP,TN,TP +*(*(+(TP,FP),+(TP,FN)),+(TN,FP)) +::STMT +MATRIX:r,w +FLOAT:tau +LITERAL_FLOAT:0.5 ++(*(0.5,sum(*(r,r))),*(tau,sum(abs(w)))) +::STMT +MATRIX:parsertemp31190,parsertemp31197 +FLOAT:parsertemp31191,parsertemp31198 +LITERAL_FLOAT:2.0,1500.0,7000.0 +^(+(/(/(parsertemp31190,parsertemp31191),7000.0),/(/(parsertemp31197,parsertemp31198),1500.0)),2.0) +::STMT +MATRIX:flip_neg,is_LT_infinite,Y_prob,parsertemp171292,parsertemp171290 +FLOAT:float877 +%*%(+(*(/(Y_prob,parsertemp171290),-(float877,parsertemp171292)),is_LT_infinite),flip_neg) +::STMT +MATRIX:parsertemp171090,is_one_y_corr,t,parsertemp171096,parsertemp171080 +FLOAT:int787,float950 +LITERAL_FLOAT:1.0 ++(*(+(-(int787,t),/(parsertemp171090,parsertemp171096)),-(1.0,*(float950,parsertemp171080))),/(is_one_y_corr,-(1.0,is_one_y_corr))) +::STMT +MATRIX:simplex +FLOAT:num_func_invoc +LITERAL_FLOAT:1.0 ++(num_func_invoc,-(ncol(simplex),1.0)) +::STMT +MATRIX:parsertemp220848,parsertemp220853,parsertemp220850,beta +FLOAT:float768 ++(parsertemp220853,*(beta,/(rowSums(parsertemp220850),+(parsertemp220848,float768)))) +::STMT +MATRIX:W +FLOAT:m2 +LITERAL_FLOAT:1.0 +sqrt(/(*(m2,sum(W)),-(sum(W),1.0))) +::STMT +MATRIX:neighbors +diag(diag(neighbors)) +::STMT +MATRIX:X,y +LITERAL_FLOAT:0.0,2.0 +^(%*%(-(0.0,t(X)),y),2.0) +::STMT +MATRIX:S,addedX2 +FLOAT:level +==(%*%(S,t(addedX2)),level) +::STMT +MATRIX:p,e,u,G +LITERAL_FLOAT:0.15000000000000002,0.85 ++(*(0.85,%*%(G,p)),*(0.15000000000000002,%*%(%*%(e,u),p))) +::STMT +MATRIX:C,tmp,parsertemp265713 +FLOAT:Xm ++(Xm,trace(*(tmp,%*%(parsertemp265713,C)))) +::STMT +MATRIX:parsertemp42190,X +LITERAL_FLOAT:2.0 +-(parsertemp42190,/(X,2.0)) +::STMT +MATRIX:s +LITERAL_FLOAT:2.0 +sum(^(s,2.0)) +::STMT +MATRIX:lambda,g,beta +%*%(t(+(g,*(lambda,beta))),+(g,*(lambda,beta))) +::STMT +MATRIX:dW2 +FLOAT:193_beta2 +LITERAL_FLOAT:1.0,2.0 +*(-(1.0,193_beta2),^(dW2,2.0)) +::STMT +MATRIX:parsertemp163717,p_gaps_vector +t(*(parsertemp163717,p_gaps_vector)) +::STMT +MATRIX:img_in1 +FLOAT:weight +LITERAL_FLOAT:1.0 +*(-(1.0,weight),img_in1) +::STMT +MATRIX:dout1,mb1 +FLOAT:parsertemp147007,192_t,192_lr,192_beta1,int736 +LITERAL_FLOAT:1.0 +*(/(*(192_lr,sqrt(parsertemp147007)),-(1.0,^(192_beta1,192_t))),+(*(192_beta1,mb1),*(-(int736,192_beta1),colSums(dout1)))) +::STMT +FLOAT:parsertemp169812 +LITERAL_FLOAT:2.302585092994046 +/(parsertemp169812,2.302585092994046) +::STMT +MATRIX:residuals_vector +LITERAL_FLOAT:0.0 +/(sum(residuals_vector),+(nrow(residuals_vector),0.0)) +::STMT +MATRIX:ZtZ,parsertemp265709,Xm,parsertemp265707,parsertemp265705,Z,parsertemp265702 +%*%(t(/(%*%(parsertemp265709,Z),sum(ZtZ))),/(%*%(t(Xm),%*%(Xm,parsertemp265702)),sum(+(parsertemp265705,parsertemp265707)))) +::STMT +FLOAT:dd,step_sz +*(step_sz,dd) +::STMT +MATRIX:WM,CVars,CFreqs +FLOAT:parsertemp31268,int795,W,float277 +LITERAL_FLOAT:1.0 +/(sum(*(-(CFreqs,int795),CVars)),*(-(sum(WM),1.0),/(*(parsertemp31268,W),-(W,float277)))) +::STMT +MATRIX:ss,se +/(se,ss) +::STMT +MATRIX:g_Y,scale_X,X +LITERAL_FLOAT:-1.0 +%*%(diag(scale_X),%*%(*(t(X),-1.0),g_Y)) +::STMT +MATRIX:maskd1,out1 +FLOAT:p +LITERAL_FLOAT:0.0 +*(>(out1,0.0),/(maskd1,p)) +::STMT +MATRIX:V,W,parsertemp10741,H +LITERAL_FLOAT:1.0E-8 +*(H,/(%*%(t(W),V),+(%*%(parsertemp10741,H),1.0E-8))) +::STMT +MATRIX:parsertemp386448,withinEps +LITERAL_FLOAT:0.0,1.0 +>(colSums(>(*(parsertemp386448,withinEps),0.0)),1.0) +::STMT +MATRIX:finite_linear_terms +LITERAL_FLOAT:-1.0 +*(exp(finite_linear_terms),-1.0) +::STMT +MATRIX:X +FLOAT:x +LITERAL_FLOAT:1.0 ++(-(nrow(X),sum(>=(X,x))),1.0) +::STMT +MATRIX:parsertemp122290,X2 +LITERAL_FLOAT:0.0,4.0 +|(<(t(colSums(X2)),4.0),==(t(%*%(parsertemp122290,X2)),0.0)) +::STMT +MATRIX:classFeatureCounts +LITERAL_FLOAT:1.0,50.0 +INT:int258,int270 +%*%(+(rowSums(classFeatureCounts),*(50.0,1.0)),rand(int270,int258,1.0,1.0)) +::STMT +MATRIX:f,parsertemp472177,I,parsertemp472179 +LITERAL_FLOAT:2.0 +^(*(I,-(%*%(f,parsertemp472177),t(parsertemp472179))),2.0) +::STMT +MATRIX:parsertemp387552 +LITERAL_FLOAT:10.0 +^(10.0,parsertemp387552) +::STMT +MATRIX:parsertemp72182 +FLOAT:subvector_size +LITERAL_FLOAT:1.0 ++(*(parsertemp72182,subvector_size),1.0) +::STMT +MATRIX:Y,parsertemp282723 +==(Y,cast.FLOAT(parsertemp282723)) +::STMT +MATRIX:Xm,parsertemp265733 +abs(/(sum(-(parsertemp265733,Xm)),sum(Xm))) +::STMT +FLOAT:end_stepsize,k,kmax +*(/(k,kmax),end_stepsize) +::STMT +MATRIX:parsertemp271862,parsertemp271860 +FLOAT:obj,parsertemp271888 +LITERAL_FLOAT:-0.5 +/(-(obj,parsertemp271888),*(-0.5,-(sum(parsertemp271860),sum(parsertemp271862)))) +::STMT +MATRIX:parsertemp500606,parsertemp500604,w +FLOAT:int50 +t(-(*(*(parsertemp500604,parsertemp500606),>(parsertemp500606,int50)),w)) +::STMT +MATRIX:binary_array +LITERAL_FLOAT:1.0 ++(1.0,cast.FLOAT(binary_array)) +::STMT +MATRIX:R,dssp,dsep,parsertemp40232,parsertemp40220 +FLOAT:eAvg +/(/(-(+(R,dsep),rowSums(parsertemp40232)),-(+(R,dssp),rowSums(parsertemp40220))),eAvg) +::STMT +MATRIX:parsertemp386457,parsertemp386459,neighbors,corePts,withinEps,parsertemp386456 +LITERAL_FLOAT:0.0 +*(>(*(*(neighbors,corePts),withinEps),0.0),==(-(*(parsertemp386456,parsertemp386457),parsertemp386459),0.0)) +::STMT +MATRIX:parsertemp222331 +LITERAL_FLOAT:200.0,0.5 +round(+(0.5,/(parsertemp222331,200.0))) +::STMT +MATRIX:CFreqs1,present_domain_vals_mat,parsertemp27487 +LITERAL_FLOAT:1.0 +*(-(%*%(present_domain_vals_mat,CFreqs1),1.0),%*%(present_domain_vals_mat,parsertemp27487)) +::STMT +MATRIX:p,e,u +%*%(%*%(e,u),p) +::STMT +MATRIX:parsertemp220853,parsertemp220854,betamin +FLOAT:logU +LITERAL_FLOAT:0.0 +*(<(-(+(parsertemp220853,parsertemp220854),logU),0.0),betamin) +::STMT +MATRIX:b,E,X,sb +cast.FLOAT(%*%(colSums(*(X,E)),+(b,sb))) +::STMT +MATRIX:sb +FLOAT:delta +LITERAL_FLOAT:2.0 +-(sum(^(sb,2.0)),^(delta,2.0)) +::STMT +MATRIX:parsertemp171084,parsertemp171083 +LITERAL_FLOAT:0.010328,-2.0,0.802853 +*(sqrt(*(-2.0,parsertemp171083)),+(0.802853,*(sqrt(parsertemp171084),0.010328))) +::STMT +MATRIX:c,G +*(G,t(c)) +::STMT +MATRIX:parsertemp399242,W3_rand +FLOAT:int741,int312 +LITERAL_FLOAT:0.6546536707079771 +%*%(*(0.6546536707079771,W3_rand),t(/(-(parsertemp399242,int741),+(parsertemp399242,int312)))) +::STMT +FLOAT:parsertemp164939 +LITERAL_FLOAT:2.0,100.0 ++(2.0,*(100.0,parsertemp164939)) +::STMT +MATRIX:p,p2 +LITERAL_FLOAT:1.0E8 +>(abs(-(p2,p)),1.0E8) +::STMT +MATRIX:ytest,yhat +sum(-(ytest,yhat)) +::STMT +MATRIX:parsertemp221021 +LITERAL_FLOAT:1.0 ++(diag(parsertemp221021),1.0) +::STMT +MATRIX:CFreqs1,present_domain_vals_mat +LITERAL_FLOAT:1.0 +-(%*%(present_domain_vals_mat,CFreqs1),1.0) +::STMT +MATRIX:G,authorities +/(%*%(G,authorities),max(%*%(G,authorities))) +::STMT +LITERAL_FLOAT:1.0,2003.0 +-(2003.0,1.0) +::STMT +MATRIX:parsertemp137847,keyPos1 +*(t(parsertemp137847),keyPos1) +::STMT +MATRIX:s,w,wnew,parsertemp44079 +LITERAL_FLOAT:-1.0,2.0,0.5 ++(*(0.5,%*%(t(wnew),+(w,s))),*(2.0,*(-1.0,sum(parsertemp44079)))) +::STMT +MATRIX:m_iter_err_sum_squared,parsertemp379562,parsertemp379571,m_iter_err_sum,parsertemp379569 +FLOAT:i_process_item +LITERAL_FLOAT:1.0 +/(+(-(*(parsertemp379569,i_process_item),*(parsertemp379571,m_iter_err_sum)),+(colSums(parsertemp379562),m_iter_err_sum_squared)),-(i_process_item,1.0)) +::STMT +MATRIX:p,r,Z +FLOAT:norm_r2,parsertemp503396 +LITERAL_FLOAT:2.0 +^(+(r,*(/(norm_r2,parsertemp503396),%*%(Z,p))),2.0) +::STMT +MATRIX:dX,v +FLOAT:lr,mu +-(*(mu,v),*(lr,dX)) +::STMT +FLOAT:246_AIC_best,246_thr +abs(*(246_thr,246_AIC_best)) +::STMT +MATRIX:X,Centering,ScaleFactor +%*%(t(/(-(X,Centering),ScaleFactor)),/(-(X,Centering),ScaleFactor)) +::STMT +MATRIX:d,X,logisticD +LITERAL_FLOAT:2.0 +*(2.0,%*%(t(X),*(logisticD,%*%(X,d)))) +::STMT +MATRIX:U,row_nonzeros +LITERAL_FLOAT:2.0 +sum(*(^(U,2.0),row_nonzeros)) +::STMT +MATRIX:s,w +LITERAL_FLOAT:0.5 +*(0.5,%*%(t(+(w,s)),+(w,s))) +::STMT +MATRIX:parsertemp410979,W,X,H,parsertemp410981,parsertemp410984 +/(*(W,%*%(/(X,parsertemp410984),t(H))),t(rowSums(/(parsertemp410979,parsertemp410981)))) +::STMT +MATRIX:S,parsertemp382904,V,W,row_nonzeros +LITERAL_FLOAT:1.0E-6 ++(%*%(*(W,%*%(S,parsertemp382904)),V),*(*(1.0E-6,S),row_nonzeros)) +::STMT +FLOAT:parsertemp31330 +LITERAL_FLOAT:9999.0,10000.0 +*(9999.0,/(*(parsertemp31330,10000.0),9999.0)) +::STMT +LITERAL_FLOAT:3.0,2003.0 +-(2003.0,3.0) +::STMT +MATRIX:out +LITERAL_FLOAT:0.5 +*(0.5,cast.FLOAT(%*%(t(out),out))) +::STMT +MATRIX:upd_W1,W1_grad,W1 +FLOAT:parsertemp389637,mu,step ++(W1,-(*(mu,upd_W1),*(/(step,parsertemp389637),W1_grad))) +::STMT +MATRIX:ones,classFeatureCounts +FLOAT:float714,int456 +LITERAL_FLOAT:1.0 +/(+(classFeatureCounts,1.0),%*%(+(rowSums(classFeatureCounts),*(int456,float714)),ones)) +::STMT +LITERAL_FLOAT:2.0,2001.0 +^(2001.0,2.0) +::STMT +MATRIX:W1_rand,X,parsertemp400556,parsertemp400566 +FLOAT:float936 +LITERAL_FLOAT:0.08333333333333333 +%*%(*(0.08333333333333333,W1_rand),t(/(-(X,parsertemp400556),+(parsertemp400566,float936)))) +::STMT +FLOAT:avg_res,ytest,mean_y_test,int765,yhat,int958 +LITERAL_FLOAT:1.0,2.0 +/(-(^(-(ytest,yhat),2.0),*(1.0,^(avg_res,int958))),-(^(cast.FLOAT(ytest),2.0),*(1.0,^(mean_y_test,int765)))) +::STMT +MATRIX:X +FLOAT:x +/(-(x,X),-(cast.FLOAT(X),cast.FLOAT(X))) +::STMT +MATRIX:col_nonzeros,U,V,row_nonzeros +FLOAT:int917,int674 ++(sum(*(^(U,int917),row_nonzeros)),sum(*(^(V,int674),col_nonzeros))) +::STMT +MATRIX:parsertemp24102 +FLOAT:num_bins +LITERAL_FLOAT:1.0 +*(>(+(round(parsertemp24102),1.0),num_bins),num_bins) +::STMT +MATRIX:parsertemp539204 +FLOAT:float276,float683 +LITERAL_FLOAT:0.6666666666666666 +-(max(^(/(parsertemp539204,float276),0.6666666666666666)),min(^(/(parsertemp539204,float683),0.6666666666666666))) +::STMT +MATRIX:r,d,Hd,parsertemp44001 +FLOAT:int112 +*(/(sum(^(r,int112)),cast.FLOAT(%*%(parsertemp44001,Hd))),d) +::STMT +MATRIX:m_active_flag +LITERAL_FLOAT:0.0 +t(==(m_active_flag,0.0)) +::STMT +LITERAL_FLOAT:1.0005002501250626 +1.0005002501250626 +::STMT +MATRIX:parsertemp170242,parsertemp170240,parsertemp170238 +FLOAT:float516,float545,float457 +LITERAL_FLOAT:1.0,1.421413741 +*(/(1.0,+(1.0,*(parsertemp170238,float545))),+(1.421413741,*(/(float457,parsertemp170240),+(float516,parsertemp170242)))) +::STMT +LITERAL_FLOAT:2.0,2003.0 +-(2003.0,2.0) +::STMT +MATRIX:t_gp,parsertemp170243,parsertemp170239 +FLOAT:float433 +LITERAL_FLOAT:1.0,1.421413741,-0.284496736 ++(-0.284496736,*(/(1.0,+(float433,parsertemp170239)),+(1.421413741,*(t_gp,parsertemp170243)))) +::STMT +MATRIX:X +FLOAT:int432 +LITERAL_FLOAT:1.0E-6 +<(sqrt(rowSums(^(X,int432))),1.0E-6) +::STMT +MATRIX:A +LITERAL_FLOAT:1.0 +>=(rowSums(abs(A)),1.0) +::STMT +FLOAT:int709,b +LITERAL_FLOAT:2.0 +-(^(b,2.0),int709) +::STMT +MATRIX:B +FLOAT:M +/(nrow(B),M) +::STMT +MATRIX:simplex +LITERAL_FLOAT:0.0 ++(0.0,ncol(simplex)) +::STMT +MATRIX:minD +FLOAT:sumXsq ++(sumXsq,sum(minD)) +::STMT +MATRIX:H,parsertemp220860,parsertemp220861,beta +FLOAT:logU +LITERAL_FLOAT:0.0,2.0 +/(*(<(-(H,logU),0.0),+(beta,+(parsertemp220860,parsertemp220861))),2.0) +::STMT +MATRIX:output1,dataset +LITERAL_FLOAT:0.16 +<(abs(-(output1,dataset)),0.16) +::STMT +MATRIX:r,s,grad +LITERAL_FLOAT:-0.5 +*(-0.5,-(%*%(t(s),grad),%*%(t(s),r))) +::STMT +MATRIX:parsertemp131907,cumHistMul,offset,parsertemp132092,histMul,outBucket +-(offset,%*%(==(outBucket,%*%(parsertemp132092,parsertemp131907)),-(cumHistMul,histMul))) +::STMT +LITERAL_FLOAT:-1.0,0.001 +*(0.001,-1.0) +::STMT +MATRIX:centroid_placer,X_samples +%*%(centroid_placer,%*%(centroid_placer,X_samples)) +::STMT +LITERAL_FLOAT:0.0,1.0 +/(1.0,0.0) +::STMT +LITERAL_FLOAT:1.0,2.0 +/(1.0,2.0) +::STMT +LITERAL_FLOAT:-1.0,2.0 +/(-1.0,2.0) +::STMT +MATRIX:X +LITERAL_FLOAT:1.0,2.0 +/(*($1:ncol(X),+($1,1.0)),2.0) +::STMT +MATRIX:parsertemp165076,X,y +LITERAL_FLOAT:2.0 +/(sum(^(-(y,parsertemp165076),2.0)),nrow(X)) +::STMT +MATRIX:parsertemp170277 +LITERAL_FLOAT:3.141592653589793 +/(parsertemp170277,3.141592653589793) +::STMT +MATRIX:parsertemp403497,parsertemp403500,W3_rand +LITERAL_FLOAT:0.1651445647689541 +t(%*%(*(0.1651445647689541,W3_rand),t(/(parsertemp403497,parsertemp403500)))) +::STMT +MATRIX:parsertemp286536,parsertemp286535 +FLOAT:float220 +sqrt(cast.FLOAT(%*%(t(parsertemp286536),+(float220,parsertemp286535)))) +::STMT +MATRIX:n_event_stratum,n_risk_stratum,n_risk +/(*(n_risk,n_event_stratum),n_risk_stratum) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,80656.0 +*(-(i,1.0),80656.0) +::STMT +MATRIX:y_corr +LITERAL_FLOAT:1.0 +-(1.0,>=(y_corr,1.0)) +::STMT +MATRIX:means,parsertemp389215 +FLOAT:n +LITERAL_FLOAT:1.0 +/(*(-(/(parsertemp389215,n),*(means,means)),n),-(n,1.0)) +::STMT +FLOAT:max_depth +LITERAL_FLOAT:1.0,2.0 +*(2.0,-(^(2.0,max_depth),1.0)) +::STMT +LITERAL_FLOAT:1.0,1.5 +/(1.0,1.5) +::STMT +FLOAT:e,mu +LITERAL_FLOAT:0.999,4.0 ++(mu,/(-(0.999,mu),-(4.0,e))) +::STMT +MATRIX:B2,ytest,Xtest,parsertemp387577 +cast.FLOAT(%*%(t(-(ytest,parsertemp387577)),-(ytest,%*%(Xtest,B2)))) +::STMT +MATRIX:r,obj,parsertemp44063,parsertemp44077,parsertemp44065,grad +FLOAT:float27,C,parsertemp44081 +LITERAL_FLOAT:-0.5 +/(-(obj,+(*(float27,parsertemp44077),*(C,parsertemp44081))),*(-0.5,-(%*%(parsertemp44063,grad),%*%(parsertemp44065,r)))) +::STMT +MATRIX:LT,parsertemp149320,parsertemp150469 +exp(-(LT,%*%(parsertemp149320,parsertemp150469))) +::STMT +FLOAT:i +LITERAL_FLOAT:80656.0 +*(i,80656.0) +::STMT +MATRIX:r,d,parsertemp43998 +FLOAT:C +/(sum(*(r,r)),%*%(t(d),+(d,*(C,parsertemp43998)))) +::STMT +MATRIX:X,parsertemp220785 +FLOAT:int457,int358 +LITERAL_FLOAT:-2.0 ++(+(*(-2.0,%*%(X,parsertemp220785)),rowSums(^(X,int457))),t(rowSums(^(X,int358)))) +::STMT +MATRIX:D,parsertemp10961,parsertemp10958 ++(%*%(D,t(parsertemp10958)),t(parsertemp10961)) +::STMT +LITERAL_FLOAT:1.0,10.0 +/(1.0,10.0) +::STMT +MATRIX:classFeatureCounts +LITERAL_FLOAT:105.0,1.0 +INT:int623,int718 +%*%(+(rowSums(classFeatureCounts),*(105.0,1.0)),rand(int623,int718,1.0,1.0)) +::STMT +MATRIX:237_present_domain_vals_mat +LITERAL_FLOAT:10000.0 +-(10000.0,nrow(237_present_domain_vals_mat)) +::STMT +MATRIX:F +LITERAL_FLOAT:1.0 +/(F,-(sum(F),1.0)) +::STMT +MATRIX:Q +FLOAT:int677 +LITERAL_FLOAT:1.0 +INT:int897,parsertemp500306 +%*%(rand(parsertemp500306,int897,1.0,1.0),t(rowSums(^(Q,int677)))) +::STMT +LITERAL_FLOAT:1.0,2001.0 ++(2001.0,1.0) +::STMT +MATRIX:r,g,z +LITERAL_FLOAT:0.5 +*(0.5,sum(*(z,+(r,g)))) +::STMT +LITERAL_FLOAT:1.0E-14 +1.0E-14 +::STMT +LITERAL_FLOAT:9.999999999999998E-15 +9.999999999999998E-15 +::STMT +MATRIX:pearson_residual_sq +LITERAL_FLOAT:9950.0 +/(sum(pearson_residual_sq),9950.0) +::STMT +FLOAT:parsertemp72162,M +LITERAL_FLOAT:1.0 +*(+(parsertemp72162,1.0),M) +::STMT +MATRIX:g_Y,lambda,parsertemp171599,scale_X,beta +FLOAT:int223 ++(*(cast.FLOAT(diag(scale_X)),%*%(-(int223,parsertemp171599),g_Y)),*(cast.FLOAT(lambda),cast.FLOAT(beta))) +::STMT +MATRIX:S +LITERAL_FLOAT:2.0 +^(diag(S),2.0) +::STMT +MATRIX:R,ones +%*%(t(+(R,diag(ones))),+(R,diag(ones))) +::STMT +MATRIX:scale_X,shift_X,X +LITERAL_FLOAT:2.0 +%*%(X,*(*(2.0,scale_X),shift_X)) +::STMT +MATRIX:P +LITERAL_FLOAT:1.0 +<=(rowSums(P),1.0) +::STMT +MATRIX:ytest +LITERAL_FLOAT:1.0,2.0 +*(1.0,^(/(cast.FLOAT(ytest),1.0),2.0)) +::STMT +LITERAL_FLOAT:5.0,2001.0 ++(2001.0,5.0) +::STMT +MATRIX:out1,187_dX,parsertemp146955 +FLOAT:beta1,int533 +LITERAL_FLOAT:1.0 +*(-(1.0,beta1),colSums(*(>(out1,int533),*(parsertemp146955,187_dX)))) +::STMT +LITERAL_FLOAT:3.0,2001.0 ++(2001.0,3.0) +::STMT +MATRIX:d,od,X,logisticD +FLOAT:C ++(d,*(C,%*%(t(X),*(logisticD,od)))) +::STMT +MATRIX:M +LITERAL_FLOAT:2.0 +/(ncol(M),2.0) +::STMT +MATRIX:X_batch,maskd1,out1,185_dX,parsertemp146947,W2 +FLOAT:p,int850 +%*%(t(X_batch),*(*(>(out1,int850),/(maskd1,p)),%*%(*(parsertemp146947,185_dX),t(W2)))) +::STMT +MATRIX:M +sum(exp(-(M,max(M)))) +::STMT +FLOAT:int134,z,pp_CG,parsertemp170091 +LITERAL_FLOAT:0.5 +*(0.5,/(-(*(z,int134),sqrt(parsertemp170091)),pp_CG)) +::STMT +MATRIX:Grad +LITERAL_FLOAT:-1.0,2.0 +sum(^(*(Grad,-1.0),2.0)) +::STMT +MATRIX:sums +LITERAL_FLOAT:4.0 +/(sums,4.0) +::STMT +MATRIX:parsertemp221417 +FLOAT:float22 +LITERAL_FLOAT:0.1,2.0 +*(sum(^(-(float22,parsertemp221417),2.0)),0.1) +::STMT +MATRIX:t,parsertemp32854,parsertemp32848,Y,parsertemp32857,parsertemp32858 +cast.FLOAT(+(+(*(parsertemp32848,Y),*(t,Y)),*(*(t,parsertemp32854),+(parsertemp32857,parsertemp32858)))) +::STMT +MATRIX:lambda,parsertemp286549 +FLOAT:new_log_l +LITERAL_FLOAT:0.5 +-(new_log_l,*(0.5,cast.FLOAT(%*%(lambda,parsertemp286549)))) +::STMT +MATRIX:parsertemp220786,parsertemp220783 +FLOAT:int927 +sqrt(+(+(*(int927,parsertemp220786),rowSums(parsertemp220783)),t(rowSums(parsertemp220783)))) +::STMT +MATRIX:parsertemp500607,X,y,parsertemp500610 +t(-(%*%(X,*(parsertemp500607,parsertemp500610)),y)) +::STMT +MATRIX:s,parsertemp44016,d +LITERAL_FLOAT:2.0 +^(%*%(t(-(s,parsertemp44016)),d),2.0) +::STMT +MATRIX:output_values +FLOAT:log_odds,learning_rate ++(log_odds,*(learning_rate,sum(output_values))) +::STMT +MATRIX:X +LITERAL_FLOAT:2.0 +/(X,2.0) +::STMT +MATRIX:parsertemp561025 +LITERAL_FLOAT:0.0 +/(parsertemp561025,0.0) +::STMT +MATRIX:y_corr +LITERAL_FLOAT:1.0,2.0,0.5 +*(-(1.0,*(2.0,y_corr)),>(y_corr,0.5)) +::STMT +MATRIX:prec_chol +LITERAL_FLOAT:2.0 +t(^(prec_chol,2.0)) +::STMT +MATRIX:g_reg,p_CG +FLOAT:parsertemp170113,q_CG,int940,z,pq_CG,pp_CG,parsertemp170091 +*(+(+(*(parsertemp170113,pq_CG),*(z,q_CG)),sum(*(g_reg,p_CG))),/(-(*(z,int940),sqrt(parsertemp170091)),pp_CG)) +::STMT +MATRIX:parsertemp171083 +LITERAL_FLOAT:0.010328,-2.0 +*(sqrt(*(-2.0,parsertemp171083)),0.010328) +::STMT +FLOAT:parsertemp22454,parsertemp22485 +LITERAL_FLOAT:2.0 +exp(-(parsertemp22485,*(2.0,sqrt(parsertemp22454)))) +::STMT +MATRIX:sig_sq +sqrt(sig_sq) +::STMT +MATRIX:parsertemp171083 +LITERAL_FLOAT:-2.0 +sqrt(*(-2.0,parsertemp171083)) +::STMT +MATRIX:parsertemp31910,X +FLOAT:alpha +LITERAL_FLOAT:1.0 +*(-(1.0,alpha),-(/(nrow(X),t(parsertemp31910)),1.0)) +::STMT +MATRIX:252_Y,252_X,252_K +LITERAL_FLOAT:0.0 ++(*(-(0.0,cast.FLOAT(252_K)),-(cast.FLOAT(252_X),cast.FLOAT(252_X))),-(cast.FLOAT(252_Y),cast.FLOAT(252_Y))) +::STMT +LITERAL_FLOAT:1.0,2.0,3.0,2000.0 +*(*(-(2000.0,2.0),+(2000.0,1.0)),+(2000.0,3.0)) +::STMT +MATRIX:R,S,Grad +-(sum(*(S,Grad)),sum(*(S,R))) +::STMT +MATRIX:p,e,u,G +LITERAL_FLOAT:0.15000000000000002,0.85 ++(*(0.85,%*%(G,p)),*(0.15000000000000002,%*%(%*%(e,u),p))) +::STMT +MATRIX:f,parsertemp472172 +LITERAL_FLOAT:0.0 +rowSums(*(-(0.0,f),parsertemp472172)) +::STMT +FLOAT:int780,ss2,ssPrev,Xm,parsertemp265718 +LITERAL_FLOAT:4000.0 +/(/(-(+(Xm,ss2),*(int780,parsertemp265718)),4000.0),ssPrev) +::STMT +MATRIX:parsertemp107030 +LITERAL_FLOAT:1.0,7.0 ++(*(parsertemp107030,7.0),1.0) +::STMT +MATRIX:X,K +*(cast.FLOAT(K),-(cast.FLOAT(X),cast.FLOAT(X))) +::STMT +MATRIX:Xm +rowSums(t(Xm)) +::STMT +MATRIX:parsertemp436659 +t(rowSums(parsertemp436659)) +::STMT +LITERAL_FLOAT:1.0E-5 +1.0E-5 +::STMT +MATRIX:r,d,parsertemp43998 +FLOAT:int550 +LITERAL_FLOAT:2.0 +/(sum(^(r,2.0)),%*%(t(d),+(d,*(int550,parsertemp43998)))) +::STMT +LITERAL_FLOAT:32.0 +INT:int197,int136 +rand(int197,int136,32.0,32.0) +::STMT +MATRIX:X,tS +FLOAT:l +t(colSums(==(%*%(X,tS),l))) +::STMT +MATRIX:Y_prob,Y +LITERAL_FLOAT:0.0 +sum(*(<=(Y_prob,0.0),abs(Y))) +::STMT +MATRIX:jaccardDist,adjacency +FLOAT:threshold +&(adjacency,>=(jaccardDist,threshold)) +::STMT +FLOAT:m2 +LITERAL_FLOAT:1.0005002501250626,4.0 +^(sqrt(*(1.0005002501250626,m2)),4.0) +::STMT +MATRIX:classFeatureCounts +LITERAL_FLOAT:1.0,750.0 +/(+(classFeatureCounts,1.0),+(rowSums(classFeatureCounts),750.0)) +::STMT +MATRIX:_sbcvar11,43_r,43_c +LITERAL_FLOAT:2.0,1000.0 +^(-(_sbcvar11,/(%*%(43_r,43_c),1000.0)),2.0) +::STMT +MATRIX:G,authorities,hubs +LITERAL_FLOAT:2.0 +^(-(/(%*%(G,authorities),max(hubs)),hubs),2.0) +::STMT +FLOAT:m2 +LITERAL_FLOAT:1.0005002501250626,3.0 +^(sqrt(*(1.0005002501250626,m2)),3.0) +::STMT +MATRIX:surv,n_risk +FLOAT:int594 +/(*(surv,sqrt(-(int594,surv))),sqrt(n_risk)) +::STMT +FLOAT:so_linear_approx +LITERAL_FLOAT:-0.5 +*(-0.5,so_linear_approx) +::STMT +FLOAT:delta +LITERAL_FLOAT:0.5 +*(0.5,delta) +::STMT +MATRIX:se_surv +FLOAT:z_alpha_2 +LITERAL_FLOAT:-1.0 +*(*(z_alpha_2,-1.0),se_surv) +::STMT +FLOAT:m2 +LITERAL_FLOAT:1.0004995004995005 +sqrt(*(1.0004995004995005,m2)) +::STMT +MATRIX:classCounts +LITERAL_FLOAT:80.0 +/(classCounts,80.0) +::STMT +MATRIX:parsertemp379565,m_iter_err_sum +FLOAT:i_process_item +LITERAL_FLOAT:-1.0,2.0 +*(2.0,/(*(-(parsertemp379565,m_iter_err_sum),-1.0),i_process_item)) +::STMT +MATRIX:X +LITERAL_FLOAT:1.0,2.0 +/(colSums(^(X,2.0)),-(nrow(X),1.0)) +::STMT +MATRIX:posSampleVariances,negSampleVariances +FLOAT:int681,int584,int92,int34 +LITERAL_FLOAT:7.996E9,2.0,3.37275E9 +/(^(+(/(posSampleVariances,int681),/(negSampleVariances,int34)),2.0),+(/(^(posSampleVariances,int584),7.996E9),/(^(negSampleVariances,int92),3.37275E9))) +::STMT +LITERAL_FLOAT:1.0,20.0 +-(20.0,1.0) +::STMT +MATRIX:scores,parsertemp145878 +/(exp(-(scores,parsertemp145878)),rowSums(exp(-(scores,parsertemp145878)))) +::STMT +MATRIX:t_gp,parsertemp171332,pt_gp,parsertemp171331,Y,the_gauss_exp,parsertemp171327,parsertemp171316 +FLOAT:one_over_sqrt_two_pi +LITERAL_FLOAT:2.0,0.25 +/(*(one_over_sqrt_two_pi,+(-(Y,parsertemp171327),*(parsertemp171331,parsertemp171332))),*(*(0.25,*(t_gp,parsertemp171316)),-(2.0,*(the_gauss_exp,pt_gp)))) +::STMT +MATRIX:ss +FLOAT:alpha +LITERAL_FLOAT:1.0,40.0 +*(-(1.0,alpha),-(/(40.0,ss),1.0)) +::STMT +LITERAL_FLOAT:0.3989422804014327 +0.3989422804014327 +::STMT +LITERAL_FLOAT:0.1 +0.1 +::STMT +LITERAL_FLOAT:-0.1 +-0.1 +::STMT +MATRIX:X +FLOAT:var_lag,parsertemp496688,parsertemp496694,var_coef,a0 +LITERAL_FLOAT:2.0 ++(parsertemp496694,/(^(cast.FLOAT(X),2.0),+(+(a0,parsertemp496688),*(var_coef,var_lag)))) +::STMT +MATRIX:parsertemp222331 +LITERAL_FLOAT:200.0 +/(parsertemp222331,200.0) +::STMT +MATRIX:parsertemp220903 +FLOAT:float857 +LITERAL_FLOAT:2.0,1.0E-5 +*(sum(^(-(float857,parsertemp220903),2.0)),1.0E-5) +::STMT +MATRIX:parsertemp399255,W4_rand +FLOAT:int818,int687 +LITERAL_FLOAT:0.08725945907447251 +%*%(*(0.08725945907447251,W4_rand),t(/(-(parsertemp399255,int687),+(parsertemp399255,int818)))) +::STMT +MATRIX:tmp +FLOAT:N,parsertemp274090 +LITERAL_FLOAT:0.0,1.0 +*(/(tmp,-(N,1.0)),-(1.0,<=(/(tmp,parsertemp274090),0.0))) +::STMT +MATRIX:W,H,parsertemp411105 +LITERAL_FLOAT:1.0E-8 ++(%*%(W,%*%(*(H,parsertemp411105),t(H))),1.0E-8) +::STMT +MATRIX:log_prob,X +LITERAL_FLOAT:1.8378770664093453,-0.5 +*(-0.5,+(*(ncol(X),1.8378770664093453),log_prob)) +::STMT +LITERAL_FLOAT:1.5000000000000002E-8 +1.5000000000000002E-8 +::STMT +MATRIX:parsertemp539203 +FLOAT:int993 +LITERAL_FLOAT:1.0,2.0,1.5 +max(^(/(*(parsertemp539203,int993),2.0),/(1.0,1.5))) +::STMT +FLOAT:width,x1,x2 +LITERAL_FLOAT:-1.0,2.0 +/(*(-1.0,^(-(x1,x2),2.0)),*(2.0,^(width,2.0))) +::STMT +MATRIX:images +LITERAL_FLOAT:255.0 +/(images,255.0) +::STMT +MATRIX:W,parsertemp411110,X,H +FLOAT:eps +*(W,/(%*%(X,t(H)),+(%*%(W,parsertemp411110),eps))) +::STMT +MATRIX:ytest +LITERAL_FLOAT:1.0 +/(cast.FLOAT(ytest),1.0) +::STMT +LITERAL_FLOAT:1.0,2.0,4.0,2003.0 +*(4.0,-(^(2003.0,2.0),1.0)) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:0.0,1.0 +-(exp(-(0.0,linear_terms)),1.0) +::STMT +FLOAT:m2X +LITERAL_FLOAT:1.0002795638803466 +*(m2X,1.0002795638803466) +::STMT +MATRIX:classCounts +LITERAL_FLOAT:100.0 +/(classCounts,100.0) +::STMT +MATRIX:parsertemp1904,y +LITERAL_FLOAT:-1.0 +sum(*(*(%*%(parsertemp1904,y),-1.0),*(%*%(parsertemp1904,y),-1.0))) +::STMT +MATRIX:means,Y,vars +LITERAL_FLOAT:2.0 +sum(/(^(-(Y,means),2.0),vars)) +::STMT +MATRIX:parsertemp409788,parsertemp409797 +LITERAL_FLOAT:0.0 +t(+(-(0.0,t(parsertemp409788)),t(colSums(parsertemp409797)))) +::STMT +MATRIX:parsertemp386438,neighbors +FLOAT:eps +LITERAL_FLOAT:0.0 +rowSums(*(<=(-(neighbors,parsertemp386438),eps),<(0.0,-(neighbors,parsertemp386438)))) +::STMT +MATRIX:obj,parsertemp44077 +FLOAT:int642,parsertemp44079 +LITERAL_FLOAT:2.0,0.5 +-(cast.FLOAT(obj),+(*(0.5,cast.FLOAT(parsertemp44077)),*(2.0,*(int642,parsertemp44079)))) +::STMT +MATRIX:weight +LITERAL_FLOAT:133.0 +/(weight,133.0) +::STMT +MATRIX:F +/(%*%(rowSums(F),colSums(F)),sum(F)) +::STMT +LITERAL_FLOAT:0.025 +0.025 +::STMT +FLOAT:42_m2X +LITERAL_FLOAT:1.001001001001001 +sqrt(*(42_m2X,1.001001001001001)) +::STMT +MATRIX:Y_Train,Y_Test +FLOAT:sumY,sumX,parsertemp251796,parsertemp251795 +abs(-(-(+(sumX,sumY),+(parsertemp251795,parsertemp251796)),+(sum(Y_Train),sum(Y_Test)))) +::STMT +MATRIX:V +FLOAT:var,mu +LITERAL_FLOAT:5.0 +>(V,+(mu,*(5.0,sqrt(var)))) +::STMT +MATRIX:V +FLOAT:var,mu +LITERAL_FLOAT:5.0 +<(V,-(mu,*(5.0,sqrt(var)))) +::STMT +MATRIX:p_CG +FLOAT:trust_delta_sq,z,pp_CG +sqrt(-(*(cast.FLOAT(p_CG),cast.FLOAT(p_CG)),*(pp_CG,-(z,trust_delta_sq)))) +::STMT +MATRIX:parsertemp171326,is_lt_pos,parsertemp171323,Y +FLOAT:float940 +LITERAL_FLOAT:0.3989422804014327 +*(0.3989422804014327,+(-(Y,*(parsertemp171326,is_lt_pos)),*(*(parsertemp171323,parsertemp171326),-(is_lt_pos,float940)))) +::STMT +FLOAT:vicinity,target_a0,a0 +LITERAL_FLOAT:1.0 ++(*(vicinity,target_a0),*(-(1.0,vicinity),a0)) +::STMT +MATRIX:_sbcvar92,220_r,220_c +LITERAL_FLOAT:0.0,1.0E-4 +*(==(/(%*%(220_r,220_c),sum(_sbcvar92)),0.0),1.0E-4) +::STMT +MATRIX:p,q,lambda +FLOAT:norm_r2 +*(/(norm_r2,cast.FLOAT(%*%(p,q))),+(q,*(lambda,p))) +::STMT +MATRIX:r +FLOAT:int383 +LITERAL_FLOAT:2.0,9.999999999999998E-15 +*(sum(^(-(int383,r),2.0)),9.999999999999998E-15) +::STMT +LITERAL_FLOAT:1.0,2.0,1500.0 +*(^(1500.0,2.0),-(1500.0,1.0)) +::STMT +MATRIX:B,parsertemp410245,X_t +LITERAL_FLOAT:0.0,2.0 +/(-(0.0,parsertemp410245),*(2.0,exp(%*%(X_t,B)))) +::STMT +MATRIX:r,Hd +FLOAT:c +LITERAL_FLOAT:0.0 +-(0.0,+(r,*(c,Hd))) +::STMT +MATRIX:Y +FLOAT:class +LITERAL_FLOAT:2.0 +*(2.0,==(Y,class)) +::STMT +MATRIX:qLow,length,qUp +LITERAL_FLOAT:0.0 +>(rowSums(|(<(length,qLow),>(length,qUp))),0.0) +::STMT +MATRIX:var_X_cols,parsertemp429917,parsertemp429915 +FLOAT:int636 +LITERAL_FLOAT:0.0,1.0,299.0 ++(*(/(-(parsertemp429915,parsertemp429917),299.0),-(1.0,<=(var_X_cols,int636))),<=(/(-(parsertemp429915,parsertemp429917),299.0),0.0)) +::STMT +MATRIX:parsertemp43635 +FLOAT:float100 +LITERAL_FLOAT:2.0 +sqrt(sum(^(+(float100,parsertemp43635),2.0))) +::STMT +FLOAT:window_size,n +LITERAL_FLOAT:2.0 ++(-(n,window_size),2.0) +::STMT +MATRIX:R,w +FLOAT:int794,int742 +INT:parsertemp31673,int163 ++(R,diag(*(rand(parsertemp31673,int163,int742,int794),cast.FLOAT(w)))) +::STMT +MATRIX:cumLeftHist,parsertemp132494,parsertemp132506,leftHist,outBucket ++(%*%(==(outBucket,t(parsertemp132494)),-(cumLeftHist,leftHist)),parsertemp132506) +::STMT +MATRIX:parsertemp72182 +LITERAL_FLOAT:1.0,8.0 ++(*(parsertemp72182,8.0),1.0) +::STMT +FLOAT:idx +LITERAL_FLOAT:1048.0 +-(1048.0,idx) +::STMT +MATRIX:parsertemp13626,parsertemp13624 +FLOAT:int992,43_q,int581 +LITERAL_FLOAT:1.0,1000.0 +/(sum(/(^(parsertemp13626,int581),/(parsertemp13624,int992))),*(1000.0,-(43_q,1.0))) +::STMT +MATRIX:subspace_idx,parsertemp72201 +FLOAT:subvector_size +LITERAL_FLOAT:1.0 +<(-(subspace_idx,*(parsertemp72201,subvector_size)),1.0) +::STMT +MATRIX:os,y,o +LITERAL_FLOAT:-1.0 +exp(*(*(y,-1.0),+(o,os))) +::STMT +MATRIX:atan_linear_terms +LITERAL_FLOAT:3.141592653589793,0.5 +-(0.5,/(atan_linear_terms,3.141592653589793)) +::STMT +MATRIX:linear_terms,Y +FLOAT:var_power +LITERAL_FLOAT:-1.0 +*(^(linear_terms,*(var_power,-1.0)),-(Y,linear_terms)) +::STMT +MATRIX:w,X,y +%*%(t(-(%*%(X,w),y)),-(%*%(X,w),y)) +::STMT +MATRIX:H,betamax,Hneg,Hpos,beta +FLOAT:INF,logU +LITERAL_FLOAT:0.0,2.0 +*(*(2.0,>=(-(H,logU),0.0)),==(+(*(Hpos,betamax),*(Hneg,beta)),INF)) +::STMT +LITERAL_FLOAT:1.0E-4 +1.0E-4 +::STMT +MATRIX:X,parsertemp16876 +FLOAT:epsilon,int288 ++(sqrt(rowSums(^(X,int288))),*(<(sqrt(parsertemp16876),epsilon),epsilon)) +::STMT +LITERAL_FLOAT:1400.0,20.0 +*(1400.0,20.0) +::STMT +MATRIX:lt_pos_neg +LITERAL_FLOAT:0.5 +-(0.5,lt_pos_neg) +::STMT +MATRIX:parsertemp389219,tmp,X,parsertemp389212 +FLOAT:int464 +LITERAL_FLOAT:1.0E-17 +/(-(%*%(tmp,X),parsertemp389212),+(sqrt(/(parsertemp389219,int464)),1.0E-17)) +::STMT +MATRIX:Y,linear_terms,vec1,is_y_0,parsertemp171270 +LITERAL_FLOAT:0.0 +-(-(/(+(Y,is_y_0),+(parsertemp171270,is_y_0)),==(Y,0.0)),*(*(Y,vec1),linear_terms)) +::STMT +MATRIX:Bx,Yd,Yu +/(-(Yu,Yd),*(Bx,Bx)) +::STMT +MATRIX:W +LITERAL_FLOAT:1.0,2.0,3.0 +*(*(-(sum(W),2.0),+(sum(W),1.0)),+(sum(round(W)),3.0)) +::STMT +MATRIX:cm,FD +FLOAT:n +LITERAL_FLOAT:1.0 ++(+(FD,==(cm,1.0)),==(t(cm),n)) +::STMT +MATRIX:r,alpha,Hd +*(-(r,*(cast.FLOAT(alpha),Hd)),-(r,*(cast.FLOAT(alpha),Hd))) +::STMT +MATRIX:X +LITERAL_FLOAT:2.0 +exp(*(2.0,X)) +::STMT +MATRIX:g,parsertemp169907 +FLOAT:parsertemp169913 +*(sum(*(+(g,parsertemp169907),+(g,parsertemp169907))),parsertemp169913) +::STMT +MATRIX:is_zero_y_corr,is_one_y_corr,parsertemp171113 +LITERAL_FLOAT:1.0,-0.36651292058166435 ++(-(parsertemp171113,*(-0.36651292058166435,+(is_zero_y_corr,is_one_y_corr))),/(is_one_y_corr,-(1.0,is_one_y_corr))) +::STMT +FLOAT:int112 +LITERAL_FLOAT:2.0 +INT:int809,parsertemp282730 +rand(parsertemp282730,int809,int112,2.0) +::STMT +MATRIX:vI +FLOAT:beg +LITERAL_FLOAT:1.0 +-(+(cast.FLOAT(vI),beg),1.0) +::STMT +MATRIX:parsertemp557211 +LITERAL_FLOAT:0.0 +==(diag(parsertemp557211),0.0) +::STMT +FLOAT:var,m4 +LITERAL_FLOAT:3.0,4.0 +-(/(m4,^(sqrt(var),4.0)),3.0) +::STMT +MATRIX:lambda,B_new +FLOAT:int37 +LITERAL_FLOAT:0.5 +*(0.5,sum(*(lambda,^(B_new,int37)))) +::STMT +MATRIX:parsertemp413082 +LITERAL_FLOAT:1.0 +-(max(round(parsertemp413082)),1.0) +::STMT +MATRIX:parsertemp410190,b,parsertemp410188 +cast.FLOAT(%*%(%*%(t(b),-(parsertemp410188,parsertemp410190)),b)) +::STMT +MATRIX:_sbcvar96,_sbcvar95,221_CMeans +FLOAT:int455 +LITERAL_FLOAT:2.0 +sum(*(%*%(_sbcvar95,_sbcvar96),^(+(221_CMeans,int455),2.0))) +::STMT +MATRIX:X_Xd_exp_Xb_rev_agg,parsertemp410050,d_r_rev,Hd_2_num,D_r_rev +colSums(*(-(/(X_Xd_exp_Xb_rev_agg,D_r_rev),/(Hd_2_num,parsertemp410050)),d_r_rev)) +::STMT +MATRIX:scale_X,w,ssX_p_CG,X +%*%(diag(scale_X),%*%(t(X),*(w,%*%(X,ssX_p_CG)))) +::STMT +MATRIX:lambda,parsertemp148883,parsertemp148882 +FLOAT:int12 +LITERAL_FLOAT:2.0 +sum(^(+(%*%(parsertemp148882,parsertemp148883),*(lambda,int12)),2.0)) +::STMT +MATRIX:img_in +FLOAT:h +LITERAL_FLOAT:2.0 +/(-(nrow(img_in),h),2.0) +::STMT +FLOAT:var_power,link_power +LITERAL_FLOAT:2.0 +/(-(2.0,var_power),link_power) +::STMT +FLOAT:dummy_coding_beg_col,dummy_coding_end_col +LITERAL_FLOAT:1.0 ++(-(dummy_coding_end_col,dummy_coding_beg_col),1.0) +::STMT +MATRIX:y_batch,parsertemp146892 +FLOAT:int243 +/(sum(*(-(int243,y_batch),parsertemp146892)),nrow(y_batch)) +::STMT +LITERAL_FLOAT:1.421413741 +1.421413741 +::STMT +MATRIX:P,parsertemp220889,Z,parsertemp220891 +FLOAT:int562,int464,int63,parsertemp220894 +rowSums(*(-(*(P,int562),/(Z,parsertemp220894)),*(/(int464,parsertemp220891),+(parsertemp220889,int63)))) +::STMT +MATRIX:316_unnorm_probs,316_scores +abs(-(/(exp(316_scores),rowSums(316_unnorm_probs)),/(exp(316_scores),rowSums(316_unnorm_probs)))) +::STMT +MATRIX:ss +LITERAL_FLOAT:1.0,40.0 +-(/(40.0,ss),1.0) +::STMT +FLOAT:idx +LITERAL_FLOAT:1024.0 +-(1024.0,idx) +::STMT +FLOAT:current_hash_value +LITERAL_FLOAT:1.0,3.0 +-(3.0,+(current_hash_value,1.0)) +::STMT +MATRIX:tmp +FLOAT:int239,N +LITERAL_FLOAT:0.0,1.0 +-(1.0,<=(/(tmp,-(N,int239)),0.0)) +::STMT +MATRIX:r,d,parsertemp43999 +LITERAL_FLOAT:2.0 +/(sum(^(r,2.0)),cast.FLOAT(%*%(t(d),+(d,parsertemp43999)))) +::STMT +MATRIX:parsertemp387501 +LITERAL_FLOAT:1.0 +cast.FLOAT(+(parsertemp387501,1.0)) +::STMT +FLOAT:cvk +LITERAL_FLOAT:300.0 +/(300.0,cvk) +::STMT +MATRIX:classFeatureCounts +LITERAL_FLOAT:105.0,1.0 ++(rowSums(classFeatureCounts),*(105.0,1.0)) +::STMT +MATRIX:w_X,z_LS +LITERAL_FLOAT:1000.0 +*(/(1000.0,cast.FLOAT(%*%(w_X,z_LS))),z_LS) +::STMT +MATRIX:addedE,addedX +/(sum(addedE),nrow(addedX)) +::STMT +MATRIX:X_Train,X_Test,X,Y,Y_Train,Y_Test +-(-(+(sum(X),sum(Y)),+(sum(X_Train),sum(X_Test))),+(sum(Y_Train),sum(Y_Test))) +::STMT +MATRIX:resp +LITERAL_FLOAT:2.22E-16 ++(colSums(resp),2.22E-16) +::STMT +MATRIX:X_train +LITERAL_FLOAT:2.0 +sqrt(/(2.0,ncol(X_train))) +::STMT +FLOAT:b,c,rad +LITERAL_FLOAT:-1.0,2.0 +/(*(*(2.0,c),-1.0),+(b,rad)) +::STMT +LITERAL_FLOAT:0.802853 +0.802853 +::STMT +MATRIX:parsertemp394992,parsertemp394989,W3_rand +LITERAL_FLOAT:0.21483446221182986 +t(%*%(*(0.21483446221182986,W3_rand),t(/(parsertemp394989,parsertemp394992)))) +::STMT +MATRIX:Y +FLOAT:check_max,check_min +LITERAL_FLOAT:2.0 +-(*(/(2.0,-(check_max,check_min)),Y),/(+(check_min,check_max),-(check_max,check_min))) +::STMT +MATRIX:2434_2432_Y,W4_rand +FLOAT:float108 +LITERAL_FLOAT:2.0 +*(2.0,t(%*%(*(float108,W4_rand),t(2434_2432_Y)))) +::STMT +MATRIX:inactive_set,w +LITERAL_FLOAT:0.0 +abs(-(inactive_set,!=(w,0.0))) +::STMT +MATRIX:p,e,u,G +FLOAT:alpha +LITERAL_FLOAT:1.0 ++(*(alpha,%*%(G,p)),*(-(1.0,alpha),%*%(%*%(e,u),p))) +::STMT +LITERAL_FLOAT:80.0,1200.0 +*(1200.0,80.0) +::STMT +FLOAT:n +LITERAL_FLOAT:1.0,2.0,4.0 +-(+(-(n,4.0),1.0),2.0) +::STMT +MATRIX:parsertemp443564,parsertemp443530,parsertemp443567,mean,parsertemp443973,X +FLOAT:float834 ++(/(-(%*%(parsertemp443564,X),%*%(parsertemp443567,mean)),sum(+(parsertemp443530,float834))),diag(parsertemp443973)) +::STMT +MATRIX:2701_mask +LITERAL_FLOAT:0.5 +/(2701_mask,0.5) +::STMT +MATRIX:X,mask +FLOAT:p +/(*(X,mask),p) +::STMT +MATRIX:X,parsertemp382984 +LITERAL_FLOAT:0.0 +-(ncol(X),sum(!=(t(parsertemp382984),0.0))) +::STMT +MATRIX:parsertemp2782,parsertemp2786 +FLOAT:dd,parsertemp2779,step_sz,wd +-(step_sz,/(-(+(wd,parsertemp2779),sum(parsertemp2782)),+(dd,sum(parsertemp2786)))) +::STMT +MATRIX:parsertemp410245,parsertemp410247 +LITERAL_FLOAT:0.0,2.0,0.6666666666666666 +^(/(-(0.0,parsertemp410245),*(2.0,exp(parsertemp410247))),0.6666666666666666) +::STMT +FLOAT:factor_up,parsertemp195892,int529 +LITERAL_FLOAT:1.0,2.0 +/(-(-(*(int529,factor_up),parsertemp195892),1.0),2.0) +::STMT +MATRIX:p,q,V,parsertemp1939 +FLOAT:norm_r2 +LITERAL_FLOAT:1.0E-8 +*(/(norm_r2,cast.FLOAT(%*%(parsertemp1939,q))),+(%*%(t(V),%*%(V,p)),*(1.0E-8,p))) +::STMT +MATRIX:upd_W1 +LITERAL_FLOAT:0.8 +*(0.8,upd_W1) +::STMT +MATRIX:X +FLOAT:x +/(-(x,X),-(X,X)) +::STMT +MATRIX:classFeatureCounts +LITERAL_FLOAT:1.0,50.0 +/(+(classFeatureCounts,1.0),+(rowSums(classFeatureCounts),*(50.0,1.0))) +::STMT +FLOAT:C,Hf,Wf +LITERAL_FLOAT:2.0 +sqrt(/(2.0,*(*(C,Hf),Wf))) +::STMT +MATRIX:R +FLOAT:int548 +LITERAL_FLOAT:0.0 +sum(==(colSums(!=(R,int548)),0.0)) +::STMT +MATRIX:parsertemp539203 +LITERAL_FLOAT:-1.0,1.0,2.0,1.5 +^(/(*(parsertemp539203,-1.0),2.0),/(1.0,1.5)) +::STMT +MATRIX:X,Y,K +FLOAT:x,int118 +*(+(*(*(K,int118),-(X,X)),-(Y,Y)),/(-(x,X),-(X,X))) +::STMT +MATRIX:cdf_min_distances +LITERAL_FLOAT:0.0,1.0 +INT:int159,num_runs +*(rand(int159,num_runs,0.0,1.0),cdf_min_distances) +::STMT +FLOAT:m2X,m2Y +LITERAL_FLOAT:1.000010000100001 +*(sqrt(*(m2X,1.000010000100001)),sqrt(*(m2Y,1.000010000100001))) +::STMT +MATRIX:Y,linear_terms +LITERAL_FLOAT:-1.0 +*(rowSums(Y),exp(*(exp(linear_terms),-1.0))) +::STMT +FLOAT:cmLabels +LITERAL_FLOAT:1.000100010001 +*(cmLabels,1.000100010001) +::STMT +MATRIX:sv,out +LITERAL_FLOAT:0.5 +*(0.5,sum(*(*(sv,out),*(sv,out)))) +::STMT +MATRIX:y +LITERAL_FLOAT:1.0,-1.0 +*(/(1.0,nrow(y)),*(y,-1.0)) +::STMT +MATRIX:current_node +FLOAT:cur_node_index ++(cur_node_index,cast.FLOAT(current_node)) +::STMT +MATRIX:Kss,parsertemp387410 +sqrt(abs(-(cast.FLOAT(Kss),cast.FLOAT(parsertemp387410)))) +::STMT +MATRIX:resp +LITERAL_FLOAT:2.22E-16 +sum(+(colSums(resp),2.22E-16)) +::STMT +MATRIX:xs +LITERAL_FLOAT:100.0,4.5 +-(100.0,sum(>=(xs,4.5))) +::STMT +MATRIX:parsertemp410978,W,H +rowSums(/(*(H,t(parsertemp410978)),t(colSums(W)))) +::STMT +MATRIX:z,beta ++(beta,cast.FLOAT(z)) +::STMT +MATRIX:X +FLOAT:int902 +t(sqrt(rowSums(^(X,int902)))) +::STMT +MATRIX:R +FLOAT:minSup +LITERAL_FLOAT:0.0 +&(>=(R,minSup),>(R,0.0)) +::STMT +MATRIX:parsertemp12898,CFreqs +FLOAT:int517 +LITERAL_FLOAT:1.0 +/(sum(*(CFreqs,^(parsertemp12898,int517))),-(nrow(CFreqs),1.0)) +::STMT +FLOAT:float605,int893,float444,int152 +LITERAL_FLOAT:1.0,3.0,6.0,2001.0 +/(*(*(6.0,2001.0),-(2001.0,1.0)),*(*(-(int893,float444),+(int152,float605)),+(2001.0,3.0))) +::STMT +MATRIX:parsertemp150380 +LITERAL_FLOAT:0.0,0.16 +sum(==(<(abs(parsertemp150380),0.16),0.0)) +::STMT +MATRIX:237_CVars,parsertemp29525,237_CFreqs,parsertemp29520 +LITERAL_FLOAT:1.0,10000.0 +/(/(sum(*(237_CFreqs,parsertemp29520)),-(nrow(237_CFreqs),1.0)),/(sum(*(parsertemp29525,237_CVars)),-(10000.0,nrow(237_CFreqs)))) +::STMT +LITERAL_FLOAT:96.0 +INT:int523,int607 +rand(int523,int607,96.0,96.0) +::STMT +MATRIX:colDuplicates,adjacency +LITERAL_FLOAT:0.0 +*(colDuplicates,>(rowSums(adjacency),0.0)) +::STMT +MATRIX:cdf_min_distances,random_row +t(colSums(<(cdf_min_distances,*(random_row,cdf_min_distances)))) +::STMT +MATRIX:s,d,alpha +t(+(s,*(cast.FLOAT(alpha),d))) +::STMT +MATRIX:parsertemp472298,I +LITERAL_FLOAT:0.0 +==(!=(*(t(parsertemp472298),I),0.0),0.0) +::STMT +MATRIX:parsertemp171318 +FLOAT:int591 +LITERAL_FLOAT:2.0,0.15915494309189535 +*(exp(/(-(int591,parsertemp171318),2.0)),0.15915494309189535) +::STMT +FLOAT:m2X +LITERAL_FLOAT:1.0005 +*(m2X,1.0005) +::STMT +MATRIX:H,betamax,beta +FLOAT:logU +LITERAL_FLOAT:0.0 ++(*(>=(-(H,logU),0.0),betamax),*(<(-(H,logU),0.0),beta)) +::STMT +MATRIX:key_unique,key +==(key_unique,t(key)) +::STMT +MATRIX:e_r_rev_agg,parsertemp409787,parsertemp409796 +LITERAL_FLOAT:0.0 ++(-(0.0,t(colSums(parsertemp409787))),t(colSums(/(parsertemp409796,e_r_rev_agg)))) +::STMT +MATRIX:parsertemp132498,offset,parsertemp132494,rightHist,mask,outBucket +LITERAL_FLOAT:1.0 +/(-(-(offset,%*%(mask,parsertemp132498)),1.0),%*%(==(outBucket,t(parsertemp132494)),rightHist)) +::STMT +MATRIX:r,parsertemp44050 +FLOAT:norm_r2 +LITERAL_FLOAT:2.0 +/(sum(^(-(r,parsertemp44050),2.0)),norm_r2) +::STMT +MATRIX:y_prob,ones_ctg +LITERAL_FLOAT:1.0 +*(y_prob,%*%(y_prob,-(1.0,diag(ones_ctg)))) +::STMT +MATRIX:tmp +LITERAL_FLOAT:1.0 +*(1.0,cast.FLOAT(%*%(t(tmp),tmp))) +::STMT +MATRIX:X +LITERAL_FLOAT:1.0,2.0 +-(exp(*(2.0,X)),1.0) +::STMT +MATRIX:X +FLOAT:x +LITERAL_FLOAT:1.0 +-(1.0,/(-(x,X),-(X,X))) +::STMT +MATRIX:p,A,r,parsertemp51660 +FLOAT:norm_r2 ++(r,*(/(norm_r2,sum(parsertemp51660)),%*%(t(A),%*%(A,p)))) +::STMT +MATRIX:dout1,mb1 +FLOAT:192_beta1 +LITERAL_FLOAT:1.0 ++(*(192_beta1,mb1),*(-(1.0,192_beta1),colSums(dout1))) +::STMT +FLOAT:parsertemp31330 +LITERAL_FLOAT:9999.0,10000.0 +/(*(parsertemp31330,10000.0),9999.0) +::STMT +MATRIX:X +LITERAL_FLOAT:0.0,2.0 +^(-(0.0,sum(X)),2.0) +::STMT +MATRIX:w +sum(abs(w)) +::STMT +MATRIX:ytest,yhat +LITERAL_FLOAT:1.0 +/(-(cast.FLOAT(ytest),cast.FLOAT(yhat)),1.0) +::STMT +LITERAL_FLOAT:1.00001 +1.00001 +::STMT +MATRIX:ss,se,e +LITERAL_FLOAT:1.0,20.0 +-(/(/(se,ss),/(sum(e),20.0)),1.0) +::STMT +FLOAT:parsertemp65,parsertemp66,mu +LITERAL_FLOAT:5.0 ++(mu,*(5.0,sqrt(/(parsertemp65,parsertemp66)))) +::STMT +MATRIX:BLOCKS +FLOAT:current_hash_value +LITERAL_FLOAT:1.0 +-(nrow(BLOCKS),+(current_hash_value,1.0)) +::STMT +MATRIX:parsertemp170158,parsertemp170136 +FLOAT:r_CG,g_reg,parsertemp170165,278_sq_root_d,z,parsertemp170150 +LITERAL_FLOAT:0.5 ++(*(0.5,*(cast.FLOAT(z),+(r_CG,g_reg))),*(+(+(parsertemp170165,z),sum(parsertemp170158)),/(+(parsertemp170150,278_sq_root_d),sum(parsertemp170136)))) +::STMT +MATRIX:W,parsertemp411198,X,H,parsertemp411200 +LITERAL_FLOAT:1.0E-8 +%*%(/(X,+(%*%(W,H),1.0E-8)),t(/(*(H,parsertemp411198),t(parsertemp411200)))) +::STMT +FLOAT:num_records +LITERAL_FLOAT:1.0,960.0 +-(1.0,/(960.0,num_records)) +::STMT +MATRIX:H2_prime,H3_prime,W2,W3,parsertemp389610 +%*%(*(H2_prime,%*%(*(H3_prime,parsertemp389610),W3)),W2) +::STMT +MATRIX:R,dssp +FLOAT:4_n +LITERAL_FLOAT:1.0 +-(/(4_n,+(R,dssp)),1.0) +::STMT +FLOAT:neg_log_l_change_predicted,log_l_change +LITERAL_FLOAT:-1.0 +/(*(log_l_change,-1.0),neg_log_l_change_predicted) +::STMT +MATRIX:tmp_c +FLOAT:i +LITERAL_FLOAT:1.0,12.0 ++(tmp_c,*(-(i,1.0),12.0)) +::STMT +LITERAL_FLOAT:300.0,1.0 ++(300.0,1.0) +::STMT +MATRIX:s,sts,d,parsertemp44023 +FLOAT:delta2 +LITERAL_FLOAT:2.0 ++(^(%*%(t(s),d),2.0),*(cast.FLOAT(%*%(parsertemp44023,d)),-(delta2,cast.FLOAT(sts)))) +::STMT +MATRIX:U,V,X,parsertemp382841,row_nonzeros +FLOAT:reg,int524 ++(%*%(*(!=(X,int524),-(parsertemp382841,X)),V),*(*(reg,U),row_nonzeros)) +::STMT +MATRIX:C,Xm,parsertemp265702 +%*%(colSums(%*%(%*%(Xm,parsertemp265702),t(C))),rowSums(t(Xm))) +::STMT +MATRIX:Y +FLOAT:parsertemp185166 +-(cast.MATRIX(max(Y)),parsertemp185166) +::STMT +MATRIX:V,y +LITERAL_FLOAT:-1.0 +*(*(%*%(t(V),y),-1.0),-1.0) +::STMT +MATRIX:n_event_stratum,n_risk_stratum,n_risk +LITERAL_FLOAT:2.0 +*(*(^(n_risk_stratum,2.0),*(n_risk,n_event_stratum)),-(n_risk_stratum,n_event_stratum)) +::STMT +MATRIX:A,scale_lambda,scale_X,shift_X,parsertemp115882 +LITERAL_FLOAT:0.001 ++(+(%*%(diag(scale_X),t(parsertemp115882)),%*%(shift_X,A)),diag(*(scale_lambda,0.001))) +::STMT +MATRIX:parsertemp286680,lambda,scale_X,gXY,beta +cast.FLOAT(%*%(t(+(scale_X,parsertemp286680)),+(*(scale_X,gXY),*(lambda,beta)))) +::STMT +MATRIX:parsertemp443534,resp,parsertemp443566,parsertemp443533,X,weight +LITERAL_FLOAT:2.22E-16 +/(-(%*%(t(X),X),%*%(*(parsertemp443566,weight),/(parsertemp443533,parsertemp443534))),sum(+(colSums(resp),2.22E-16))) +::STMT +MATRIX:_funvar2124,parsertemp437267,parsertemp437277,parsertemp437272 +-(+(_funvar2124,parsertemp437267),+(parsertemp437272,parsertemp437277)) +::STMT +LITERAL_FLOAT:0.19999999999999996 +0.19999999999999996 +::STMT +MATRIX:m_err +/(colSums(m_err),sum(colSums(m_err))) +::STMT +FLOAT:check_max,check_min +/(+(check_min,check_max),-(check_max,check_min)) +::STMT +MATRIX:p_CG,z +*(*(cast.FLOAT(z),sum(p_CG)),*(cast.FLOAT(z),sum(p_CG))) +::STMT +LITERAL_FLOAT:100.0,0.8 +*(100.0,0.8) +::STMT +MATRIX:s,w +cast.FLOAT(%*%(t(+(w,s)),+(w,s))) +::STMT +FLOAT:int443,int775,weight,prob_true,prob_false +LITERAL_FLOAT:1.0 +*(weight,-(1.0,+(^(prob_true,int443),^(prob_false,int775)))) +::STMT +MATRIX:prec_chol,X +LITERAL_FLOAT:2.0 +%*%(rowSums(*(X,X)),t(^(prec_chol,2.0))) +::STMT +MATRIX:tmp,leftIdx +%*%(tmp,%*%(t(tmp),leftIdx)) +::STMT +LITERAL_FLOAT:0.2 +0.2 +::STMT +MATRIX:w,X,y +LITERAL_FLOAT:-1.0 +*(*(y,-1.0),%*%(X,w)) +::STMT +MATRIX:parsertemp220844,ZERODIAG,beta +rowSums(*(exp(*(parsertemp220844,beta)),ZERODIAG)) +::STMT +MATRIX:scale_X,w,ssX_p_CG,X +*(scale_X,%*%(t(X),*(w,%*%(X,ssX_p_CG)))) +::STMT +MATRIX:newbeta,lambda +FLOAT:int214 +LITERAL_FLOAT:0.5 +*(0.5,cast.FLOAT(%*%(t(lambda),^(newbeta,int214)))) +::STMT +MATRIX:79_77_X_row_norm,Y_block,parsertemp17170,79_77_Y_row_norm,X_block +LITERAL_FLOAT:0.9 +>(/(%*%(X_block,t(Y_block)),%*%(+(79_77_X_row_norm,parsertemp17170),t(79_77_Y_row_norm))),0.9) +::STMT +LITERAL_FLOAT:0.0 +INT:int576,int409 +cast.FLOAT(rand(int409,int576,0.0,0.0)) +::STMT +MATRIX:var_X_cols,parsertemp1517,parsertemp1515 +FLOAT:int932,int191,int490,n +LITERAL_FLOAT:0.0,1.0 ++(*(/(-(parsertemp1515,parsertemp1517),-(n,int932)),-(1.0,<=(var_X_cols,int191))),<=(/(-(parsertemp1515,parsertemp1517),-(n,int490)),0.0)) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:0.0 +*(^(exp(linear_terms),0.0),exp(linear_terms)) +::STMT +MATRIX:parsertemp42200,parsertemp42201,_sbcvar330 +FLOAT:meanX +LITERAL_FLOAT:1.0,0.5 +*(/(_sbcvar330,-(sum(_sbcvar330),1.0)),-(+(-(parsertemp42200,parsertemp42201),0.5),meanX)) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:-1.0 +*(exp(*(linear_terms,-1.0)),-1.0) +::STMT +MATRIX:parsertemp552345,tab,catTotal +LITERAL_FLOAT:-1.0 +*(*(/(tab,catTotal),-1.0),parsertemp552345) +::STMT +MATRIX:m_active_flag_tmp,m_active_flag +LITERAL_FLOAT:1.0 +-(>=(+(m_active_flag,m_active_flag_tmp),1.0),1.0) +::STMT +FLOAT:n_false,n_true,n_vars +/(+(n_true,n_false),n_vars) +::STMT +MATRIX:G,minDist +LITERAL_FLOAT:0.0 +*(!=(G,0.0),minDist) +::STMT +LITERAL_FLOAT:0.05 +0.05 +::STMT +LITERAL_FLOAT:-0.05 +-0.05 +::STMT +MATRIX:id +diag(==(id,cast.FLOAT(id))) +::STMT +MATRIX:grad +LITERAL_FLOAT:2.0 +sqrt(sum(^(grad,2.0))) +::STMT +MATRIX:select,d_r_rev,X_exp_Xb_rev_agg,D_r_rev +colSums(*(/(%*%(select,X_exp_Xb_rev_agg),D_r_rev),d_r_rev)) +::STMT +MATRIX:parsertemp43993,d,X,Hd,parsertemp44001 +*(cast.FLOAT(/(sum(parsertemp43993),%*%(parsertemp44001,Hd))),%*%(X,d)) +::STMT +MATRIX:parsertemp10964,C +LITERAL_FLOAT:100.0 +*(/(sum(==(parsertemp10964,C)),100.0),100.0) +::STMT +MATRIX:parsertemp31024,parsertemp31022 +FLOAT:int113 +LITERAL_FLOAT:2.0,99.0 +^(/(-(colSums(parsertemp31022),*(int113,parsertemp31024)),99.0),2.0) +::STMT +MATRIX:r,parsertemp44050 +LITERAL_FLOAT:2.0 +sqrt(sum(^(-(r,parsertemp44050),2.0))) +::STMT +MATRIX:Y_counts,Y +%*%(Y_counts,/(colSums(Y),sum(Y_counts))) +::STMT +MATRIX:prec_chol,mu +LITERAL_FLOAT:2.0 +*(mu,^(prec_chol,2.0)) +::STMT +LITERAL_FLOAT:0.4 +0.4 +::STMT +MATRIX:classFeatureCounts +rowSums(classFeatureCounts) +::STMT +MATRIX:parsertemp116065,p,r,lambda,shift_X,parsertemp116069 +FLOAT:norm_r2 ++(r,*(/(norm_r2,sum(parsertemp116069)),+(+(parsertemp116065,shift_X),*(lambda,p)))) +::STMT +MATRIX:TKC +cast.FLOAT(/(TKC,TKC)) +::STMT +MATRIX:p_LS,X +%*%(%*%(t(X),X),p_LS) +::STMT +FLOAT:m2,wt,float608 +LITERAL_FLOAT:3.0 +^(sqrt(/(*(m2,wt),-(wt,float608))),3.0) +::STMT +MATRIX:p_LS,tmp +FLOAT:norm_r2_LS +/(norm_r2_LS,cast.FLOAT(%*%(t(p_LS),tmp))) +::STMT +LITERAL_FLOAT:0.6546536707079771 +0.6546536707079771 +::STMT +FLOAT:parsertemp149336,obj,parsertemp149333,float101,qk,parsertemp149340 +/(-(obj,+(+(parsertemp149333,parsertemp149336),*(float101,parsertemp149340))),qk) +::STMT +MATRIX:d_r_rev,X_exp_Xb_rev_agg,D_r_rev +t(colSums(*(/(X_exp_Xb_rev_agg,D_r_rev),d_r_rev))) +::STMT +FLOAT:log_l,new_log_l ++(abs(log_l),abs(new_log_l)) +::STMT +MATRIX:d,parsertemp410053 +cast.FLOAT(%*%(t(d),t(colSums(parsertemp410053)))) +::STMT +MATRIX:Y_counts,means,parsertemp560511 +sum(*(Y_counts,rowSums(*(means,parsertemp560511)))) +::STMT +MATRIX:Y,Xd,Xw +FLOAT:step_sz +*(Y,+(Xw,*(step_sz,Xd))) +::STMT +MATRIX:2697_b,parsertemp459149,2697_W,outd3 +-(+(%*%(outd3,2697_W),2697_b),parsertemp459149) +::STMT +MATRIX:B,X_t +LITERAL_FLOAT:2.0 +*(2.0,exp(%*%(X_t,B))) +::STMT +MATRIX:D,beta +LITERAL_FLOAT:0.0 +exp(*(-(0.0,D),beta)) +::STMT +MATRIX:r,s,grad +LITERAL_FLOAT:-0.5 +*(-0.5,-(%*%(t(s),grad),%*%(t(s),r))) +::STMT +MATRIX:p,lambda,parsertemp1590,parsertemp1589 +sum(*(p,+(%*%(parsertemp1589,parsertemp1590),*(lambda,p)))) +::STMT +LITERAL_FLOAT:0.050000000000000044 +0.050000000000000044 +::STMT +FLOAT:m2,float774,wt +LITERAL_FLOAT:4.0 +^(sqrt(/(*(m2,wt),-(wt,float774))),4.0) +::STMT +MATRIX:d,parsertemp43996,parsertemp43997 +FLOAT:C +%*%(t(d),+(d,*(C,%*%(parsertemp43996,parsertemp43997)))) +::STMT +LITERAL_FLOAT:750.0 +*(750.0,750.0) +::STMT +MATRIX:X_Xd_exp_Xb_rev_agg,select,d_r_rev,X_exp_Xb_rev_agg,D_r_rev,Xd_exp_Xb_rev_agg +FLOAT:int929 +*(-(/(%*%(select,X_Xd_exp_Xb_rev_agg),D_r_rev),/(*(X_exp_Xb_rev_agg,Xd_exp_Xb_rev_agg),^(D_r_rev,int929))),d_r_rev) +::STMT +MATRIX:dout1 +LITERAL_FLOAT:2.0 +^(colSums(dout1),2.0) +::STMT +MATRIX:X +FLOAT:int174 +max(sqrt(rowSums(^(X,int174)))) +::STMT +MATRIX:p_LS,parsertemp170552 +FLOAT:lambda_LS +*(cast.FLOAT(p_LS),+(*(cast.FLOAT(parsertemp170552),cast.FLOAT(p_LS)),*(lambda_LS,cast.FLOAT(p_LS)))) +::STMT +LITERAL_FLOAT:2.0,0.5,-0.5 +INT:int818,int737 +sum(^(rand(int737,int818,-0.5,0.5),2.0)) +::STMT +FLOAT:nFeats +LITERAL_FLOAT:3.141592653589793,2.0 +^(*(2.0,3.141592653589793),nFeats) +::STMT +MATRIX:2701_mask,2700_W,parsertemp459178,2699_dtemp,2703_X,2702_X +FLOAT:float56,int493 +%*%(t(2703_X),*(*(>(2702_X,int493),/(2701_mask,float56)),%*%(-(2699_dtemp,parsertemp459178),t(2700_W)))) +::STMT +FLOAT:i +LITERAL_FLOAT:1.0,64.0 +-(+(i,64.0),1.0) +::STMT +LITERAL_FLOAT:0.8 +0.8 +::STMT +MATRIX:rowSums_X_sq +LITERAL_FLOAT:6.144102863722254 +/(6.144102863722254,max(sqrt(rowSums_X_sq))) +::STMT +MATRIX:parsertemp44025,s,d +FLOAT:delta2 ++(*(%*%(t(s),d),%*%(t(s),d)),*(%*%(t(d),d),-(delta2,%*%(parsertemp44025,s)))) +::STMT +FLOAT:sample_block_size,num_samples +LITERAL_FLOAT:1.0 ++(*(sample_block_size,num_samples),1.0) +::STMT +MATRIX:b4,W4,parsertemp389337 +LITERAL_FLOAT:2.0 +*(2.0,t(+(%*%(W4,parsertemp389337),b4))) +::STMT +MATRIX:g_Y,scale_X,X +LITERAL_FLOAT:0.0 +*(scale_X,-(0.0,%*%(t(X),g_Y))) +::STMT +MATRIX:features,beta_unscaled +FLOAT:intercept +LITERAL_FLOAT:3.0 +*(3.0,+(%*%(features,beta_unscaled),intercept)) +::STMT +MATRIX:R +LITERAL_FLOAT:0.0,4.0 +&(>(R,0.0),>=(R,4.0)) +::STMT +MATRIX:tmp,X,parsertemp389212 +-(%*%(tmp,X),parsertemp389212) +::STMT +LITERAL_FLOAT:0.16 +0.16 +::STMT +MATRIX:vectors,pq_result +LITERAL_FLOAT:2.0 +colSums(rowSums(^(-(vectors,pq_result),2.0))) +::STMT +MATRIX:parsertemp285516 +FLOAT:pp,parsertemp285518,parsertemp285520 +LITERAL_FLOAT:-1.0 +/(-(*(sum(parsertemp285516),-1.0),sqrt(-(parsertemp285518,parsertemp285520))),pp) +::STMT +MATRIX:221_present_domain_vals_mat,parsertemp27770 +t(sqrt(%*%(221_present_domain_vals_mat,parsertemp27770))) +::STMT +MATRIX:WM,Y +/(sum(*(Y,WM)),sum(WM)) +::STMT +MATRIX:X_nonzero_ind +LITERAL_FLOAT:0.0,6.0 +-(6.0,sum(!=(rowSums(X_nonzero_ind),0.0))) +::STMT +MATRIX:m_active_flag_tmp +LITERAL_FLOAT:1.0 +sum(!=(m_active_flag_tmp,1.0)) +::STMT +MATRIX:d,parsertemp410052,d_r_rev +%*%(t(d),t(colSums(*(parsertemp410052,d_r_rev)))) +::STMT +MATRIX:p,q +FLOAT:norm_r2 +/(norm_r2,sum(*(p,+(q,q)))) +::STMT +MATRIX:X +LITERAL_FLOAT:0.0 +-(0.0,t(colSums(X))) +::STMT +MATRIX:Xtest_dists +LITERAL_FLOAT:0.0,1.0 +*(<=(Xtest_dists,1.0),<(0.0,Xtest_dists)) +::STMT +MATRIX:parsertemp393595,tmp,X,parsertemp393475,parsertemp393466 +LITERAL_FLOAT:1.0,1.0E-17 +-(/(-(exp(parsertemp393595),1.0),+(exp(parsertemp393595),1.0)),/(-(%*%(tmp,X),parsertemp393466),+(sqrt(parsertemp393475),1.0E-17))) +::STMT +FLOAT:m2 +LITERAL_FLOAT:1.0005002501250626,5.0 +*(5.0,sqrt(*(1.0005002501250626,m2))) +::STMT +MATRIX:parsertemp31104,parsertemp31106 +FLOAT:int713 +LITERAL_FLOAT:1.0,2000.0 +/(/(-(colSums(parsertemp31104),*(int713,parsertemp31106)),-(2000.0,1.0)),2000.0) +::STMT +FLOAT:K +LITERAL_FLOAT:301.0 +*(301.0,K) +::STMT +MATRIX:lambda,g,parsertemp285556,beta +cast.FLOAT(%*%(t(+(g,parsertemp285556)),+(g,*(lambda,beta)))) +::STMT +MATRIX:distT +LITERAL_FLOAT:0.0 +sum(!=(distT,0.0)) +::STMT +MATRIX:parsertemp137844 +rev(t(parsertemp137844)) +::STMT +MATRIX:d_r +t(rev(d_r)) +::STMT +FLOAT:B,R,s +LITERAL_FLOAT:1.0 +/(/(B,R),+(s,1.0)) +::STMT +LITERAL_FLOAT:2.0,150.0 +^(150.0,2.0) +::STMT +MATRIX:n_risk_stratum,n_risk_i2j,V1 +FLOAT:I_i1i2 +sum(*(V1,-(I_i1i2,/(n_risk_i2j,n_risk_stratum)))) +::STMT +FLOAT:float246,d_eee,x +LITERAL_FLOAT:2.302585092994046 +*(x,exp(*(2.302585092994046,-(float246,d_eee)))) +::STMT +MATRIX:flip_neg,is_LT_infinite,Y_prob,Y,parsertemp171293 +*(Y,%*%(+(*(Y_prob,parsertemp171293),is_LT_infinite),flip_neg)) +::STMT +MATRIX:classFeatureCounts +LITERAL_FLOAT:1.0,500.0 +/(+(classFeatureCounts,1.0),+(rowSums(classFeatureCounts),*(500.0,1.0))) +::STMT +MATRIX:g_Y,parsertemp171599,scale_X,shift_X,gXY +FLOAT:int545 ++(%*%(diag(scale_X),%*%(*(parsertemp171599,int545),g_Y)),%*%(shift_X,gXY)) +::STMT +MATRIX:cdf_min_distances,random_row +<(cdf_min_distances,*(random_row,cdf_min_distances)) +::STMT +MATRIX:parsertemp1532,y +LITERAL_FLOAT:2.0,9.999999999999998E-15 +*(sum(^(%*%(parsertemp1532,y),2.0)),9.999999999999998E-15) +::STMT +MATRIX:clusterMembers,adjacency +LITERAL_FLOAT:0.0 +>(*(clusterMembers,>(rowSums(adjacency),0.0)),0.0) +::STMT +MATRIX:ts +FLOAT:q +-(q,*(cast.FLOAT(ts),cast.FLOAT(ts))) +::STMT +FLOAT:max_features,n +/(^(n,max_features),n) +::STMT +LITERAL_FLOAT:1.000010000100001 +1.000010000100001 +::STMT +LITERAL_FLOAT:0.02 +0.02 +::STMT +FLOAT:i +LITERAL_FLOAT:100.0 +*(*(i,100.0),100.0) +::STMT +MATRIX:parsertemp410118,g0_1 +LITERAL_FLOAT:2.0 +sum(^(+(g0_1,t(parsertemp410118)),2.0)) +::STMT +MATRIX:d,dtd,parsertemp44021 +FLOAT:sts,delta2 +LITERAL_FLOAT:2.0 +sqrt(+(^(%*%(parsertemp44021,d),2.0),*(cast.FLOAT(dtd),-(delta2,sts)))) +::STMT +LITERAL_FLOAT:64.0 +INT:int753,int690 +rand(int690,int753,64.0,64.0) +::STMT +MATRIX:287_x,287_y,one_featureX +LITERAL_FLOAT:2.0 +<(one_featureX,/(+(cast.FLOAT(287_x),cast.FLOAT(287_y)),2.0)) +::STMT +MATRIX:Ileft +FLOAT:min_leaf +>=(rowSums(Ileft),min_leaf) +::STMT +MATRIX:parsertemp472315,parsertemp472326 +FLOAT:beg ++(-(nrow(parsertemp472315),cast.FLOAT(parsertemp472326)),beg) +::STMT +MATRIX:parsertemp402078,W3_rand +FLOAT:int259,int106 +LITERAL_FLOAT:0.1092173494617922 +%*%(*(0.1092173494617922,W3_rand),t(/(-(parsertemp402078,int106),+(parsertemp402078,int259)))) +::STMT +MATRIX:X +LITERAL_FLOAT:-1.0 +*(t(X),-1.0) +::STMT +LITERAL_FLOAT:0.10940797384659613 +0.10940797384659613 +::STMT +FLOAT:m2 +LITERAL_FLOAT:1.0005 +*(1.0005,m2) +::STMT +MATRIX:_sbcvar11 +LITERAL_FLOAT:1000.0 +-(_sbcvar11,/(%*%(rowSums(_sbcvar11),colSums(_sbcvar11)),1000.0)) +::STMT +LITERAL_FLOAT:200.0,1.0 ++(200.0,1.0) +::STMT +FLOAT:i2,n +LITERAL_FLOAT:24.0 +-(n,*(24.0,i2)) +::STMT +MATRIX:mb1,parsertemp146957,188_dX +FLOAT:beta1 +LITERAL_FLOAT:1.0 ++(*(beta1,mb1),*(-(1.0,beta1),colSums(*(parsertemp146957,188_dX)))) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:1.0 +^(exp(linear_terms),1.0) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:-1.0 +^(exp(linear_terms),-1.0) +::STMT +MATRIX:tmp,Y +1-*(Y,tmp) +::STMT +MATRIX:p_CG +FLOAT:trust_delta_sq,int444,z,pp_CG +-(*(*(cast.FLOAT(z),sum(p_CG)),*(cast.FLOAT(z),sum(p_CG))),*(pp_CG,-(^(z,int444),trust_delta_sq))) +::STMT +MATRIX:linear_terms +FLOAT:var_power +LITERAL_FLOAT:-1.0 +^(linear_terms,*(var_power,-1.0)) +::STMT +MATRIX:y_hat,X_adapted +FLOAT:parsertemp176421,k,parsertemp176418 +|(<(X_adapted,-(sqrt(parsertemp176421),*(k,y_hat))),>(X_adapted,+(sqrt(parsertemp176418),*(k,y_hat)))) +::STMT +MATRIX:X_adapted,yhat +FLOAT:int587,int291,parsertemp176418 +|(<(X_adapted,-(sqrt(parsertemp176418),*(int587,yhat))),>(X_adapted,+(sqrt(parsertemp176418),*(int291,yhat)))) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:0.0 +^(exp(linear_terms),0.0) +::STMT +MATRIX:parsertemp413082 +LITERAL_FLOAT:1.0,21.0 +*(21.0,-(max(round(parsertemp413082)),1.0)) +::STMT +MATRIX:y_train,prediction +LITERAL_FLOAT:0.5 +sum(==(y_train,>(prediction,0.5))) +::STMT +MATRIX:r,d,parsertemp43998 +FLOAT:C +/(sum(*(r,r)),%*%(t(d),+(d,*(C,parsertemp43998)))) +::STMT +MATRIX:parsertemp31029,parsertemp31031 +FLOAT:int586 +LITERAL_FLOAT:149.0,2.0 +^(/(-(colSums(parsertemp31029),*(int586,parsertemp31031)),149.0),2.0) +::STMT +MATRIX:_sbcvar264,_sbcvar262 +FLOAT:int495,int563,parsertemp31330 +LITERAL_FLOAT:9999.0 +/(sum(*(-(_sbcvar262,int495),_sbcvar264)),*(9999.0,/(*(parsertemp31330,int563),9999.0))) +::STMT +MATRIX:p,r +FLOAT:norm_r2,int58 +*(/(sum(^(r,int58)),norm_r2),p) +::STMT +MATRIX:A,CVars,CFreqs +FLOAT:W,int623,parsertemp12882,float120 +LITERAL_FLOAT:1.0 +/(sum(*(-(CFreqs,int623),CVars)),*(-(nrow(A),1.0),/(*(parsertemp12882,W),-(W,float120)))) +::STMT +MATRIX:X +LITERAL_FLOAT:1.0 +-(nrow(X),1.0) +::STMT +MATRIX:is_zero_y_corr,is_one_y_corr,parsertemp171100,parsertemp171086,parsertemp171097 +FLOAT:float279,float397 +LITERAL_FLOAT:1.0 +-(+(*(+(parsertemp171086,parsertemp171097),-(float279,parsertemp171100)),/(is_one_y_corr,-(float397,is_one_y_corr))),/(is_zero_y_corr,-(1.0,is_zero_y_corr))) +::STMT +FLOAT:502_strideh,502_padh,parsertemp193094,int645,502_Hf +LITERAL_FLOAT:0.0 ++(+(-(*(502_strideh,parsertemp193094),*(int645,502_padh)),502_Hf),0.0) +::STMT +MATRIX:posSamples,posSampleMeans +FLOAT:int928,int149 +LITERAL_FLOAT:1.0,7000.0 +/(-(colSums(^(posSamples,int149)),*(7000.0,^(posSampleMeans,int928))),-(7000.0,1.0)) +::STMT +MATRIX:R +LITERAL_FLOAT:0.0,32.0 +&(>(R,0.0),>=(R,32.0)) +::STMT +MATRIX:posSampleVariances,negSampleVariances +FLOAT:int366,int542,int961,int998 +LITERAL_FLOAT:3.42951E11,2.0,3.37275E9 +/(^(+(/(posSampleVariances,int961),/(negSampleVariances,int366)),2.0),+(/(^(posSampleVariances,int542),3.42951E11),/(^(negSampleVariances,int998),3.37275E9))) +::STMT +MATRIX:d,od,X,logisticD +LITERAL_FLOAT:2.0 ++(d,*(2.0,%*%(t(X),*(logisticD,od)))) +::STMT +MATRIX:_sbcvar78 +LITERAL_FLOAT:10000.0 +-(_sbcvar78,/(%*%(rowSums(_sbcvar78),colSums(_sbcvar78)),10000.0)) +::STMT +MATRIX:is_zero_y_corr,is_one_y_corr,parsertemp171113 +LITERAL_FLOAT:-0.36651292058166435 +-(parsertemp171113,*(-0.36651292058166435,+(is_zero_y_corr,is_one_y_corr))) +::STMT +FLOAT:C,K +LITERAL_FLOAT:2.0 +^(*(C,K),2.0) +::STMT +FLOAT:link_power +LITERAL_FLOAT:-1.0 +*(-1.0,link_power) +::STMT +MATRIX:parsertemp44025,s,d +FLOAT:delta2 ++(*(%*%(t(s),d),%*%(t(s),d)),*(%*%(t(d),d),-(delta2,%*%(parsertemp44025,s)))) +::STMT +MATRIX:Y_prob,Y +LITERAL_FLOAT:0.0 +*(<=(Y_prob,0.0),abs(Y)) +::STMT +FLOAT:approx_sample_size +LITERAL_FLOAT:10.0 +*(10.0,sqrt(approx_sample_size)) +::STMT +MATRIX:is_row_in_samples,parsertemp77566 +LITERAL_FLOAT:7075.0 +-(7075.0,*(is_row_in_samples,parsertemp77566)) +::STMT +MATRIX:dout1 +FLOAT:192_beta2 +LITERAL_FLOAT:1.0,2.0 +*(-(1.0,192_beta2),^(colSums(dout1),2.0)) +::STMT +FLOAT:parsertemp170147,parsertemp170145,p_CG,z +LITERAL_FLOAT:-1.0,2.0 +/(-(*(*(z,p_CG),-1.0),sqrt(-(parsertemp170145,parsertemp170147))),sum(^(p_CG,2.0))) +::STMT +FLOAT:m2,float248,mu,wt +/(sqrt(/(*(m2,wt),-(wt,float248))),mu) +::STMT +FLOAT:x,parsertemp169816,float183 +round(*(x,exp(*(float183,parsertemp169816)))) +::STMT +MATRIX:scale_lambda,X +LITERAL_FLOAT:1.0E-7 ++(%*%(t(X),X),diag(*(scale_lambda,1.0E-7))) +::STMT +MATRIX:cdf_min_distances +LITERAL_FLOAT:0.0,1.0 +INT:int795,num_runs +<(cdf_min_distances,*(rand(int795,num_runs,0.0,1.0),cdf_min_distances)) +::STMT +FLOAT:trust_delta_sq,p_CG,z,pp_CG +sqrt(-(*(*(z,p_CG),*(z,p_CG)),*(pp_CG,-(z,trust_delta_sq)))) +::STMT +MATRIX:A +LITERAL_FLOAT:1.0E-4 +<=(abs(-(A,t(A))),+(1.0E-4,abs(t(A)))) +::STMT +MATRIX:linear_terms +LITERAL_FLOAT:1.0 ++(1.0,exp(linear_terms)) +::STMT +MATRIX:parsertemp220889,Y,parsertemp221025,parsertemp220891 +FLOAT:int275,int359,int130 +LITERAL_FLOAT:1.0 +/(*(/(1.0,+(Y,int130)),+(diag(parsertemp221025),1.0)),sum(*(/(int275,parsertemp220891),+(parsertemp220889,int359)))) +::STMT +FLOAT:int242,lratio_t +LITERAL_FLOAT:1.0,50.0 +-(1.0,exp(/(*(lratio_t,int242),50.0))) +::STMT +MATRIX:Y_prob +FLOAT:int909 +LITERAL_FLOAT:0.0,1.0 ++(*(Y_prob,-(1.0,<=(Y_prob,int909))),<=(Y_prob,0.0)) +::STMT +MATRIX:m_err +/(colSums(m_err),cast.FLOAT(rowSums(colSums(m_err)))) +::STMT +LITERAL_FLOAT:1.0,1000.0 +/(1000.0,-(1000.0,1.0)) +::STMT +MATRIX:parsertemp389186,parsertemp389189 +LITERAL_FLOAT:1.0,2.0 +^(/(-(exp(parsertemp389186),1.0),+(exp(parsertemp389189),1.0)),2.0) +::STMT +MATRIX:logisticnew +LITERAL_FLOAT:1.0 +-(1.0,logisticnew) +::STMT +MATRIX:W1_rand,stds,parsertemp397732 +LITERAL_FLOAT:0.086386842558136 +t(%*%(*(0.086386842558136,W1_rand),t(/(parsertemp397732,stds)))) +::STMT +MATRIX:parsertemp183431,X +FLOAT:N +LITERAL_FLOAT:1.0 +*(/(N,-(N,1.0)),%*%(t(/(parsertemp183431,N)),/(colSums(X),N))) +::STMT +MATRIX:s,w +t(+(w,s)) +::STMT +FLOAT:norm_grad +LITERAL_FLOAT:0.1 +*(0.1,norm_grad) +::STMT +MATRIX:I1 +LITERAL_FLOAT:2.0 +*(2.0,cast.FLOAT(I1)) +::STMT +MATRIX:Nc +==(Nc,max(Nc)) +::STMT +MATRIX:parsertemp175077,parsertemp175081,R1 +LITERAL_FLOAT:1.0E-6 +<(abs(-(R1,/(parsertemp175077,parsertemp175081))),1.0E-6) +::STMT +FLOAT:parsertemp386966 +sum(cast.MATRIX(parsertemp386966)) +::STMT +FLOAT:n_components,cov_param,n_features ++(+(cov_param,*(n_features,n_components)),n_components) +::STMT +LITERAL_FLOAT:0.282842712474619 +0.282842712474619 +::STMT +LITERAL_FLOAT:1.0,0.8 +-(1.0,0.8) +::STMT +MATRIX:X,K +LITERAL_FLOAT:-1.0 +*(*(K,-1.0),-(X,X)) +::STMT +MATRIX:parsertemp397837,W4_rand +FLOAT:int375,int658 +LITERAL_FLOAT:0.0873148795050037 +%*%(*(0.0873148795050037,W4_rand),t(/(-(parsertemp397837,int658),+(parsertemp397837,int375)))) +::STMT +MATRIX:parsertemp42200,parsertemp42201,F +FLOAT:int25,int329,meanX +LITERAL_FLOAT:1.0 +*(/(F,-(sum(F),1.0)),-(+(-(parsertemp42200,parsertemp42201),/(int329,int25)),meanX)) +::STMT +MATRIX:parsertemp170136 +FLOAT:trust_delta_sq,p_CG,z +sqrt(-(*(*(z,p_CG),*(z,p_CG)),*(sum(parsertemp170136),-(z,trust_delta_sq)))) +::STMT +MATRIX:parsertemp220853,parsertemp220854,betamax,Hneg,Hpos,beta +LITERAL_FLOAT:0.0,3.4011973816621555,1.0E20 +*(>=(-(+(parsertemp220853,parsertemp220854),3.4011973816621555),0.0),!=(+(*(Hpos,betamax),*(Hneg,beta)),1.0E20)) +::STMT +MATRIX:X,Y,K +-(*(K,-(X,X)),-(Y,Y)) +::STMT +MATRIX:R +FLOAT:minSup +LITERAL_FLOAT:0.0 +&(>(R,0.0),>=(R,minSup)) +::STMT +MATRIX:std,rad,dtd +/(-(rad,std),dtd) +::STMT +MATRIX:lambda,B,S +LITERAL_FLOAT:2.0 +sum(*(lambda,^(+(B,S),2.0))) +::STMT +MATRIX:R,S,parsertemp40218,parsertemp40215 +FLOAT:level +-(+(R,rowSums(==(parsertemp40215,level))),rowSums(==(%*%(S,parsertemp40218),level))) +::STMT +LITERAL_FLOAT:1.0,2.0,2000.0 +*(^(2000.0,2.0),-(2000.0,1.0)) +::STMT +FLOAT:a,x +LITERAL_FLOAT:2.0 +*(a,^(x,2.0)) +::STMT +MATRIX:hubs +LITERAL_FLOAT:2.0 +abs(sum(^(-(hubs,hubs),2.0))) +::STMT +MATRIX:is_unsafe,parsertemp1518 +FLOAT:parsertemp1519,int493 +LITERAL_FLOAT:0.0 +sqrt(+(*(/(parsertemp1518,parsertemp1519),-(int493,is_unsafe)),<=(/(parsertemp1518,parsertemp1519),0.0))) +::STMT +MATRIX:diff,mask +LITERAL_FLOAT:0.0 +*(diff,==(mask,0.0)) +::STMT +MATRIX:parsertemp73634 +LITERAL_FLOAT:16.0,1.0 ++(*(parsertemp73634,16.0),1.0) +::STMT +LITERAL_FLOAT:0.16823164622761327 +0.16823164622761327 +::STMT +MATRIX:Y +LITERAL_FLOAT:0.0 ++(rowSums(Y),==(rowSums(Y),0.0)) +::STMT +MATRIX:simplex +LITERAL_FLOAT:4.0 +/(-(rowSums(simplex),simplex),4.0) +::STMT +MATRIX:s,parsertemp44016,d +FLOAT:parsertemp44015 +cast.FLOAT(%*%(t(-(s,parsertemp44016)),-(s,*(parsertemp44015,d)))) +::STMT +FLOAT:eta,s +LITERAL_FLOAT:-1.0 +^(eta,*(s,-1.0)) +::STMT +MATRIX:2814_t +FLOAT:parsertemp477829,parsertemp477814,2814_K,int626,2814_X,2814_Y,inp_x +*(cast.FLOAT(2814_t),+(*(-(2814_K,2814_Y),-(int626,parsertemp477814)),*(+(parsertemp477829,2814_Y),/(inp_x,2814_X)))) +::STMT +MATRIX:X +LITERAL_FLOAT:0.0 +-(0.0,t(X)) +::STMT +MATRIX:parsertemp409789,parsertemp409798 +FLOAT:int986 +LITERAL_FLOAT:2.0 +sum(^(+(*(parsertemp409789,int986),t(parsertemp409798)),2.0)) +::STMT +MATRIX:scale_X,X,parsertemp115854 +LITERAL_FLOAT:0.0 +*(-(0.0,/(t(parsertemp115854),nrow(X))),scale_X) +::STMT +MATRIX:s,parsertemp44016,d +FLOAT:parsertemp44015,delta2 +-(delta2,%*%(t(-(s,parsertemp44016)),-(s,*(parsertemp44015,d)))) +::STMT +MATRIX:CFreqs1,present_domain_vals_mat,parsertemp27487 +LITERAL_FLOAT:1.0 +sum(*(-(%*%(present_domain_vals_mat,CFreqs1),1.0),%*%(present_domain_vals_mat,parsertemp27487))) +::STMT +MATRIX:W,X,H,parsertemp411101 +FLOAT:eps +/(%*%(t(W),X),+(%*%(%*%(parsertemp411101,W),H),eps)) +::STMT +MATRIX:classFeatureCounts +FLOAT:float640,int90 +LITERAL_FLOAT:1.0 +INT:int694,int227 +/(+(classFeatureCounts,1.0),%*%(+(rowSums(classFeatureCounts),*(int90,float640)),rand(int227,int694,1.0,1.0))) +::STMT +MATRIX:tmp +FLOAT:parsertemp477715,X,x,Y,K +LITERAL_FLOAT:1.0 ++(*(-(*(K,X),-(Y,Y)),-(1.0,/(parsertemp477715,X))),*(cast.FLOAT(tmp),/(-(x,X),-(X,X)))) +::STMT +MATRIX:t,parsertemp171083 +FLOAT:float208,float321 +LITERAL_FLOAT:0.189269,1.432788 ++(1.432788,*(sqrt(*(float321,parsertemp171083)),+(0.189269,*(t,float208)))) +::STMT +LITERAL_FLOAT:0.05469029540078189 +0.05469029540078189 +::STMT +MATRIX:parsertemp22268,parsertemp22266 +FLOAT:q,int631,int578 +LITERAL_FLOAT:1.0,10000.0 +/(sum(/(^(parsertemp22268,int578),/(parsertemp22266,int631))),*(10000.0,-(q,1.0))) +::STMT +MATRIX:b2,176_mask,W2,175_out +FLOAT:p ++(%*%(/(*(175_out,176_mask),p),W2),b2) +::STMT +FLOAT:window_size,q,parsertemp181039,parsertemp181046 +LITERAL_FLOAT:1.0,2.0 +*(*(2.0,window_size),-(1.0,/(-(q,parsertemp181039),*(window_size,parsertemp181046)))) +::STMT +FLOAT:std,arch_coef,noise,a0 +LITERAL_FLOAT:2.0 ++(a0,*(arch_coef,^(*(noise,std),2.0))) \ No newline at end of file From 4be1bd980b083c999a8062c4155be58586e8d3a6 Mon Sep 17 00:00:00 2001 From: Jaybit0 Date: Mon, 3 Feb 2025 09:37:32 +0100 Subject: [PATCH 2/9] Remove wildcard import --- .../org/apache/sysds/hops/rewriter/RewriterFramework.java | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/main/java/org/apache/sysds/hops/rewriter/RewriterFramework.java b/src/main/java/org/apache/sysds/hops/rewriter/RewriterFramework.java index ce57eafa0f5..2cc784ba1c2 100644 --- a/src/main/java/org/apache/sysds/hops/rewriter/RewriterFramework.java +++ b/src/main/java/org/apache/sysds/hops/rewriter/RewriterFramework.java @@ -33,7 +33,10 @@ import scala.Tuple2; import scala.Tuple4; -import java.io.*; +import java.io.BufferedReader; +import java.io.FileReader; +import java.io.FileWriter; +import java.io.IOException; import java.nio.file.Files; import java.nio.file.Paths; import java.util.ArrayList; From a0e1a173a4e1f577ac9919ab0465503c67a9a92a Mon Sep 17 00:00:00 2001 From: Jaybit0 Date: Mon, 3 Feb 2025 11:57:14 +0100 Subject: [PATCH 3/9] Bugfix in CodeGen To properly remove references to created operators if broadcasting checks fail --- .../rewriter/codegen/RewriterCodeGen.java | 26 +++- .../generated/GeneratedRewriteClass.java | 134 +++++++++++++----- 2 files changed, 118 insertions(+), 42 deletions(-) diff --git a/src/main/java/org/apache/sysds/hops/rewriter/codegen/RewriterCodeGen.java b/src/main/java/org/apache/sysds/hops/rewriter/codegen/RewriterCodeGen.java index 7af6984660c..1dfe573e890 100644 --- a/src/main/java/org/apache/sysds/hops/rewriter/codegen/RewriterCodeGen.java +++ b/src/main/java/org/apache/sysds/hops/rewriter/codegen/RewriterCodeGen.java @@ -545,7 +545,7 @@ private static void buildCostFnRecursively(RewriterStatement costFn, Map buildRewrite(RewriterStatement newRoot, StringBuilder sb, RewriterAssertions assertions, Map vars, final RuleContext ctx, int indentation) { Set visited = new HashSet<>(); - recursivelyBuildNewHop(sb, newRoot, assertions, vars, ctx, indentation, 1, visited, newRoot.getResultingDataType(ctx).equals("FLOAT")); + recursivelyBuildNewHop(sb, newRoot, assertions, vars, ctx, indentation, 1, visited, newRoot.getResultingDataType(ctx).equals("FLOAT"), new ArrayList<>()); return visited; } @@ -561,13 +561,13 @@ private static void removeUnreferencedHops(RewriterStatement oldRoot, Set vars, final RuleContext ctx, int indentation, int varCtr, Set visited, boolean enforceRootDataType) { + private static int recursivelyBuildNewHop(StringBuilder sb, RewriterStatement cur, RewriterAssertions assertions, Map vars, final RuleContext ctx, int indentation, int varCtr, Set visited, boolean enforceRootDataType, List createdOps) { visited.add(cur); if (vars.containsKey(cur)) return varCtr; for (RewriterStatement child : cur.getOperands()) - varCtr = recursivelyBuildNewHop(sb, child, assertions, vars, ctx, indentation, varCtr, visited, false); + varCtr = recursivelyBuildNewHop(sb, child, assertions, vars, ctx, indentation, varCtr, visited, false, createdOps); if (cur instanceof RewriterDataType) { if (cur.isLiteral()) { @@ -610,6 +610,7 @@ private static int recursivelyBuildNewHop(StringBuilder sb, RewriterStatement cu sb.append("LiteralOp " + name + " = new LiteralOp( " + literalStr + " );\n"); } vars.put(cur, name); + createdOps.add(name); } return varCtr; @@ -620,17 +621,31 @@ private static int recursivelyBuildNewHop(StringBuilder sb, RewriterStatement cu if (CodeGenUtils.opRequiresBinaryBroadcastingMatch(cur, ctx)) { // Then we need to validate that broadcasting still works after rearranging indent(indentation, sb); - sb.append("if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(" + operandRefs[0] + ", " + operandRefs[1] + ") )\n"); + sb.append("if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(" + operandRefs[0] + ", " + operandRefs[1] + ") ) {\n"); + for (String createdOp : createdOps) { + // Properly remove the references to the newly constructed ops + indent(indentation+1, sb); + sb.append("HopRewriteUtils.removeAllChildReferences(" + createdOp + ");\n"); + } indent(indentation+1, sb); sb.append("return hi;\n"); + indent(indentation, sb); + sb.append("}\n"); } else { List matchingDims = CodeGenUtils.matchingDimRequirement(cur, ctx); if (!matchingDims.isEmpty()) { // Then we need to validate that broadcasting still works after rearranging - sb.append("if ( !RewriterRuntimeUtils.hasMatchingDims(" + matchingDims.stream().map(idx -> operandRefs[idx]).collect(Collectors.joining(", ")) + ") )\n"); + sb.append("if ( !RewriterRuntimeUtils.hasMatchingDims(" + matchingDims.stream().map(idx -> operandRefs[idx]).collect(Collectors.joining(", ")) + ") ) {\n"); + for (String createdOp : createdOps) { + // Properly remove the references to the newly constructed ops + indent(indentation+1, sb); + sb.append("HopRewriteUtils.removeAllChildReferences(" + createdOp + ");\n"); + } indent(indentation+1, sb); sb.append("return hi;\n"); + indent(indentation, sb); + sb.append("}\n"); } } @@ -640,6 +655,7 @@ private static int recursivelyBuildNewHop(StringBuilder sb, RewriterStatement cu sb.append(opClass + " " + name + " = " + constructor + ";\n"); vars.put(cur, name); + createdOps.add(name); } return varCtr; diff --git a/src/main/java/org/apache/sysds/hops/rewriter/generated/GeneratedRewriteClass.java b/src/main/java/org/apache/sysds/hops/rewriter/generated/GeneratedRewriteClass.java index c239e6c6ef4..09aea53048d 100644 --- a/src/main/java/org/apache/sysds/hops/rewriter/generated/GeneratedRewriteClass.java +++ b/src/main/java/org/apache/sysds/hops/rewriter/generated/GeneratedRewriteClass.java @@ -629,8 +629,9 @@ private static Hop _applyRewrite7(Hop hi) { // Now, we start building the new HOP-DAG: -(B,A) - if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_1, hi_0_1) ) + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_1, hi_0_1) ) { return hi; + } BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_1, hi_0_1, Types.OpOp2.MINUS); Hop newRoot = v1; @@ -701,8 +702,9 @@ private static Hop _applyRewrite8(Hop hi) { // Now, we start building the new HOP-DAG: -(A,B) - if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_1_1, hi_1_0) ) + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_1_1, hi_1_0) ) { return hi; + } BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_1_1, hi_1_0, Types.OpOp2.MINUS); Hop newRoot = v1; @@ -773,8 +775,9 @@ private static Hop _applyRewrite9(Hop hi) { // Now, we start building the new HOP-DAG: /(A,B) - if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_0, hi_1_1) ) + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_0, hi_1_1) ) { return hi; + } BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_0, hi_1_1, Types.OpOp2.DIV); Hop newRoot = v1; @@ -1401,19 +1404,33 @@ private static Hop _applyRewrite18(Hop hi) { // Now, we start building the new HOP-DAG: +(*(*(tmp8608,y_corr),-(float599,is_zero_y_corr)),*(tmp20367,+(tmp55180,tmp23071))) - if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_0_1, hi_0_0_0) ) + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_0_1, hi_0_0_0) ) { return hi; + } BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_0_1, hi_0_0_0, Types.OpOp2.MULT); BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(hi_0_0_1_0, hi_0_0_1_1, Types.OpOp2.MINUS); - if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(v1, v2) ) + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(v1, v2) ) { + HopRewriteUtils.removeAllChildReferences(v1); + HopRewriteUtils.removeAllChildReferences(v2); return hi; + } BinaryOp v3 = HopRewriteUtils.createAutoGeneratedBinary(v1, v2, Types.OpOp2.MULT); - if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_1_1_1, hi_1_1_0) ) + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_1_1_1, hi_1_1_0) ) { + HopRewriteUtils.removeAllChildReferences(v1); + HopRewriteUtils.removeAllChildReferences(v2); + HopRewriteUtils.removeAllChildReferences(v3); return hi; + } BinaryOp v4 = HopRewriteUtils.createAutoGeneratedBinary(hi_1_1_1, hi_1_1_0, Types.OpOp2.PLUS); BinaryOp v5 = HopRewriteUtils.createAutoGeneratedBinary(hi_1_0, v4, Types.OpOp2.MULT); - if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(v3, v5) ) + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(v3, v5) ) { + HopRewriteUtils.removeAllChildReferences(v1); + HopRewriteUtils.removeAllChildReferences(v2); + HopRewriteUtils.removeAllChildReferences(v3); + HopRewriteUtils.removeAllChildReferences(v4); + HopRewriteUtils.removeAllChildReferences(v5); return hi; + } BinaryOp v6 = HopRewriteUtils.createAutoGeneratedBinary(v3, v5, Types.OpOp2.PLUS); Hop newRoot = v6; @@ -1549,8 +1566,9 @@ private static Hop _applyRewrite20(Hop hi) { // Now, we start building the new HOP-DAG: -(+(tmp63699,tmp80035),f12880) - if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_1, hi_0_0) ) + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_1, hi_0_0) ) { return hi; + } BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_1, hi_0_0, Types.OpOp2.PLUS); BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(v1, hi_0_1, Types.OpOp2.MINUS); @@ -1678,8 +1696,10 @@ private static Hop _applyRewrite22(Hop hi) { // Now, we start building the new HOP-DAG: ==(key,t(key_unique)) ReorgOp v1 = HopRewriteUtils.createTranspose(hi_0_0); - if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_0_1_0, v1) ) + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_0_1_0, v1) ) { + HopRewriteUtils.removeAllChildReferences(v1); return hi; + } BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(hi_0_1_0, v1, Types.OpOp2.EQUAL); Hop newRoot = v2; @@ -1886,12 +1906,16 @@ private static Hop _applyRewrite25(Hop hi) { // Now, we start building the new HOP-DAG: *(*(y_corr,tmp8608),-(float599,is_zero_y_corr)) - if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_0_0, hi_1) ) + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_0_0, hi_1) ) { return hi; + } BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_0_0, hi_1, Types.OpOp2.MULT); BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(hi_0_1_0, hi_0_1_1, Types.OpOp2.MINUS); - if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(v1, v2) ) + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(v1, v2) ) { + HopRewriteUtils.removeAllChildReferences(v1); + HopRewriteUtils.removeAllChildReferences(v2); return hi; + } BinaryOp v3 = HopRewriteUtils.createAutoGeneratedBinary(v1, v2, Types.OpOp2.MULT); Hop newRoot = v3; @@ -2040,8 +2064,9 @@ private static Hop _applyRewrite27(Hop hi) { // Now, we start building the new HOP-DAG: *(tmp30390,sum(*(tmp97178,tmp8790))) - if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_0_1, hi_0_0_0) ) + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_0_1, hi_0_0_0) ) { return hi; + } BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_0_1, hi_0_0_0, Types.OpOp2.MULT); AggUnaryOp v2 = HopRewriteUtils.createAggUnaryOp(v1, Types.AggOp.SUM, Types.Direction.RowCol); BinaryOp v3 = HopRewriteUtils.createAutoGeneratedBinary(hi_0_0_1, v2, Types.OpOp2.MULT); @@ -2106,8 +2131,9 @@ private static Hop _applyRewrite28(Hop hi) { // Now, we start building the new HOP-DAG: +(-(tmp82242,tmp98488),a) - if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_0_1, hi_1) ) + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_0_1, hi_1) ) { return hi; + } BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_0_1, hi_1, Types.OpOp2.MINUS); BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(v1, hi_0_0, Types.OpOp2.PLUS); @@ -2170,8 +2196,9 @@ private static Hop _applyRewrite29(Hop hi) { // Now, we start building the new HOP-DAG: -(-(obj,tmp6500),tmp26035) - if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_0, hi_1_0) ) + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_0, hi_1_0) ) { return hi; + } BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_0, hi_1_0, Types.OpOp2.MINUS); BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(v1, hi_1_1, Types.OpOp2.MINUS); @@ -2234,8 +2261,9 @@ private static Hop _applyRewrite30(Hop hi) { // Now, we start building the new HOP-DAG: -(tmp68530,+(tmp73960,tmp29113)) - if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_0_1, hi_1) ) + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_0_1, hi_1) ) { return hi; + } BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_0_1, hi_1, Types.OpOp2.PLUS); BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(hi_0_0, v1, Types.OpOp2.MINUS); @@ -2298,8 +2326,9 @@ private static Hop _applyRewrite31(Hop hi) { // Now, we start building the new HOP-DAG: +(-(tmp82242,tmp98488),a) - if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_1, hi_0_1) ) + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_1, hi_0_1) ) { return hi; + } BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_1, hi_0_1, Types.OpOp2.MINUS); BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(v1, hi_0_0, Types.OpOp2.PLUS); @@ -2714,8 +2743,9 @@ private static Hop _applyRewrite37(Hop hi) { // Now, we start building the new HOP-DAG: -(+*(A,b,C),d) - if ( !RewriterRuntimeUtils.hasMatchingDims(hi_1, hi_0_0_0) ) + if ( !RewriterRuntimeUtils.hasMatchingDims(hi_1, hi_0_0_0) ) { return hi; + } TernaryOp v1 = HopRewriteUtils.createTernary(hi_1, hi_0_0_1, hi_0_0_0,Types.OpOp3.PLUS_MULT); BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(v1, hi_0_1, Types.OpOp2.MINUS); @@ -2796,11 +2826,14 @@ private static Hop _applyRewrite38(Hop hi) { // Now, we start building the new HOP-DAG: -(A,-*(B,c,D)) - if ( !RewriterRuntimeUtils.hasMatchingDims(hi_0_1, hi_0_0_0) ) + if ( !RewriterRuntimeUtils.hasMatchingDims(hi_0_1, hi_0_0_0) ) { return hi; + } TernaryOp v1 = HopRewriteUtils.createTernary(hi_0_1, hi_0_0_1, hi_0_0_0,Types.OpOp3.MINUS_MULT); - if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_1, v1) ) + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_1, v1) ) { + HopRewriteUtils.removeAllChildReferences(v1); return hi; + } BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(hi_1, v1, Types.OpOp2.MINUS); Hop newRoot = v2; @@ -2880,11 +2913,14 @@ private static Hop _applyRewrite39(Hop hi) { // Now, we start building the new HOP-DAG: +*(M9347,f32765,*(K,M40316)) - if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_1_0, hi_1_1_0) ) + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_1_0, hi_1_1_0) ) { return hi; + } BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_1_0, hi_1_1_0, Types.OpOp2.MULT); - if ( !RewriterRuntimeUtils.hasMatchingDims(hi_0, v1) ) + if ( !RewriterRuntimeUtils.hasMatchingDims(hi_0, v1) ) { + HopRewriteUtils.removeAllChildReferences(v1); return hi; + } TernaryOp v2 = HopRewriteUtils.createTernary(hi_0, hi_1_1_1, v1,Types.OpOp3.PLUS_MULT); Hop newRoot = v2; @@ -2960,8 +2996,10 @@ private static Hop _applyRewrite40(Hop hi) { // Now, we start building the new HOP-DAG: -(-(y,%*%(X,B)),intercept) AggBinaryOp v1 = HopRewriteUtils.createMatrixMultiply(hi_1_0_0, hi_1_0_1); - if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_0, v1) ) + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_0, v1) ) { + HopRewriteUtils.removeAllChildReferences(v1); return hi; + } BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(hi_0, v1, Types.OpOp2.MINUS); BinaryOp v3 = HopRewriteUtils.createAutoGeneratedBinary(v2, hi_1_1, Types.OpOp2.MINUS); @@ -3025,8 +3063,9 @@ private static Hop _applyRewrite41(Hop hi) { // Now, we start building the new HOP-DAG: +(f45081,-(B,A)) - if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_1, hi_0_1) ) + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_1, hi_0_1) ) { return hi; + } BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_1, hi_0_1, Types.OpOp2.MINUS); BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(hi_0_0, v1, Types.OpOp2.PLUS); @@ -3106,11 +3145,14 @@ private static Hop _applyRewrite42(Hop hi) { // Now, we start building the new HOP-DAG: +*(M9347,f32765,*(K,M40316)) - if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_1_0_1, hi_1_1) ) + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_1_0_1, hi_1_1) ) { return hi; + } BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_1_0_1, hi_1_1, Types.OpOp2.MULT); - if ( !RewriterRuntimeUtils.hasMatchingDims(hi_0, v1) ) + if ( !RewriterRuntimeUtils.hasMatchingDims(hi_0, v1) ) { + HopRewriteUtils.removeAllChildReferences(v1); return hi; + } TernaryOp v2 = HopRewriteUtils.createTernary(hi_0, hi_1_0_0, v1,Types.OpOp3.PLUS_MULT); Hop newRoot = v2; @@ -3190,11 +3232,14 @@ private static Hop _applyRewrite43(Hop hi) { // Now, we start building the new HOP-DAG: +*(M9347,f32765,*(K,M40316)) - if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_0_0_0, hi_0_1) ) + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_0_0_0, hi_0_1) ) { return hi; + } BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_0_0_0, hi_0_1, Types.OpOp2.MULT); - if ( !RewriterRuntimeUtils.hasMatchingDims(hi_1, v1) ) + if ( !RewriterRuntimeUtils.hasMatchingDims(hi_1, v1) ) { + HopRewriteUtils.removeAllChildReferences(v1); return hi; + } TernaryOp v2 = HopRewriteUtils.createTernary(hi_1, hi_0_0_1, v1,Types.OpOp3.PLUS_MULT); Hop newRoot = v2; @@ -3265,8 +3310,9 @@ private static Hop _applyRewrite44(Hop hi) { // Now, we start building the new HOP-DAG: /(A,M13119) - if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_1, hi_0_1) ) + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_1, hi_0_1) ) { return hi; + } BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_1, hi_0_1, Types.OpOp2.DIV); Hop newRoot = v1; @@ -3395,8 +3441,10 @@ private static Hop _applyRewrite46(Hop hi) { // Now, we start building the new HOP-DAG: +(b,-(A,%*%(C,D))) AggBinaryOp v1 = HopRewriteUtils.createMatrixMultiply(hi_0_1_0, hi_0_1_1); - if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_1, v1) ) + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_1, v1) ) { + HopRewriteUtils.removeAllChildReferences(v1); return hi; + } BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(hi_1, v1, Types.OpOp2.MINUS); BinaryOp v3 = HopRewriteUtils.createAutoGeneratedBinary(hi_0_0, v2, Types.OpOp2.PLUS); @@ -3583,8 +3631,9 @@ private static Hop _applyRewrite49(Hop hi) { // Now, we start building the new HOP-DAG: /(A,M13119) - if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_0, hi_1_1) ) + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_0, hi_1_1) ) { return hi; + } BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_0, hi_1_1, Types.OpOp2.DIV); Hop newRoot = v1; @@ -3734,11 +3783,14 @@ private static Hop _applyRewrite51(Hop hi) { // Now, we start building the new HOP-DAG: -*(M22650,f97734,*(M97683,M67673)) - if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_1_1, hi_1_0_1) ) + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_1_1, hi_1_0_1) ) { return hi; + } BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_1_1, hi_1_0_1, Types.OpOp2.MULT); - if ( !RewriterRuntimeUtils.hasMatchingDims(hi_0, v1) ) + if ( !RewriterRuntimeUtils.hasMatchingDims(hi_0, v1) ) { + HopRewriteUtils.removeAllChildReferences(v1); return hi; + } TernaryOp v2 = HopRewriteUtils.createTernary(hi_0, hi_1_0_0, v1,Types.OpOp3.MINUS_MULT); Hop newRoot = v2; @@ -3818,11 +3870,14 @@ private static Hop _applyRewrite52(Hop hi) { // Now, we start building the new HOP-DAG: -(f75306,+(*(A,M350),M67233)) - if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_1_0, hi_1_1) ) + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_1_0, hi_1_1) ) { return hi; + } BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_1_0, hi_1_1, Types.OpOp2.MULT); - if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(v1, hi_0_1) ) + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(v1, hi_0_1) ) { + HopRewriteUtils.removeAllChildReferences(v1); return hi; + } BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(v1, hi_0_1, Types.OpOp2.PLUS); BinaryOp v3 = HopRewriteUtils.createAutoGeneratedBinary(hi_0_0, v2, Types.OpOp2.MINUS); @@ -3903,11 +3958,14 @@ private static Hop _applyRewrite53(Hop hi) { // Now, we start building the new HOP-DAG: -(f75306,+(*(A,M350),M67233)) - if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_0_1_0, hi_0_1_1) ) + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_0_1_0, hi_0_1_1) ) { return hi; + } BinaryOp v1 = HopRewriteUtils.createAutoGeneratedBinary(hi_0_1_0, hi_0_1_1, Types.OpOp2.MULT); - if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(v1, hi_1) ) + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(v1, hi_1) ) { + HopRewriteUtils.removeAllChildReferences(v1); return hi; + } BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(v1, hi_1, Types.OpOp2.PLUS); BinaryOp v3 = HopRewriteUtils.createAutoGeneratedBinary(hi_0_0, v2, Types.OpOp2.MINUS); @@ -3984,8 +4042,10 @@ private static Hop _applyRewrite54(Hop hi) { // Now, we start building the new HOP-DAG: -(+(C,%*%(A,B)),d) AggBinaryOp v1 = HopRewriteUtils.createMatrixMultiply(hi_1_0, hi_1_1); - if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_0_0, v1) ) + if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(hi_0_0, v1) ) { + HopRewriteUtils.removeAllChildReferences(v1); return hi; + } BinaryOp v2 = HopRewriteUtils.createAutoGeneratedBinary(hi_0_0, v1, Types.OpOp2.PLUS); BinaryOp v3 = HopRewriteUtils.createAutoGeneratedBinary(v2, hi_0_1, Types.OpOp2.MINUS); From 1e9e9967f50f1682e8f9fc5c7970fa6d4f430211 Mon Sep 17 00:00:00 2001 From: Jaybit0 Date: Mon, 3 Feb 2025 17:42:19 +0100 Subject: [PATCH 4/9] Forgot to forward sample size --- .../java/org/apache/sysds/hops/rewriter/RewriterFramework.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/org/apache/sysds/hops/rewriter/RewriterFramework.java b/src/main/java/org/apache/sysds/hops/rewriter/RewriterFramework.java index 2cc784ba1c2..074a6d1952f 100644 --- a/src/main/java/org/apache/sysds/hops/rewriter/RewriterFramework.java +++ b/src/main/java/org/apache/sysds/hops/rewriter/RewriterFramework.java @@ -470,7 +470,7 @@ private List, Long, Boolean>> for (int i = 1; i < mEq.size(); i++) RewriterAssertionUtils.buildImplicitAssertions(mEq.get(i), assertions, ctx); - List, List>> costs = RewriterCostEstimator.compareCosts(mEq, assertions, ctx, true, 0); + List, List>> costs = RewriterCostEstimator.compareCosts(mEq, assertions, ctx, true, sample_size); Set> rewriteProposals = RewriterCostEstimator.findOptima(costs); long mId = idCtr.incrementAndGet(); From 8a86047c1cbeec97c5fbaf4d280c8879b001b9d7 Mon Sep 17 00:00:00 2001 From: Jaybit0 Date: Tue, 4 Feb 2025 12:19:11 +0100 Subject: [PATCH 5/9] Minor Bugfixes, Added Some Operators to CodeGen, Improved Default RuleSet Ordering --- .../sysds/hops/rewriter/RewriterStatement.java | 2 +- .../estimators/RewriterSparsityEstimator.java | 6 ++++++ .../hops/rewriter/rule/RewriterRuleSet.java | 7 ++++++- .../sysds/hops/rewriter/utils/CodeGenUtils.java | 16 +++++++++++++++- 4 files changed, 28 insertions(+), 3 deletions(-) diff --git a/src/main/java/org/apache/sysds/hops/rewriter/RewriterStatement.java b/src/main/java/org/apache/sysds/hops/rewriter/RewriterStatement.java index faf2dbaea21..e6e633100a7 100644 --- a/src/main/java/org/apache/sysds/hops/rewriter/RewriterStatement.java +++ b/src/main/java/org/apache/sysds/hops/rewriter/RewriterStatement.java @@ -1041,7 +1041,7 @@ public int countInstructions() { MutableInt i = new MutableInt(); forEachPreOrder(cur -> { if (!cur.isDataOrigin() || cur.isLiteral()) { - i.increment(); + i.add(1 + cur.getOperands().size()); } return true; }, false); diff --git a/src/main/java/org/apache/sysds/hops/rewriter/estimators/RewriterSparsityEstimator.java b/src/main/java/org/apache/sysds/hops/rewriter/estimators/RewriterSparsityEstimator.java index 06bd446bd9f..c33d7737928 100644 --- a/src/main/java/org/apache/sysds/hops/rewriter/estimators/RewriterSparsityEstimator.java +++ b/src/main/java/org/apache/sysds/hops/rewriter/estimators/RewriterSparsityEstimator.java @@ -129,6 +129,12 @@ public static RewriterStatement estimateNNZ(RewriterStatement stmt, final RuleCo if (stmt.getChild(1).isLiteral() && ConstantFoldingUtils.isNeutralElement(stmt.getChild(1).getLiteral(), "+")) return RewriterStatement.nnz(stmt.getChild(0), ctx); return StatementUtils.min(ctx, RewriterStatement.multiArgInstr(ctx, "+", RewriterStatement.nnz(stmt.getChild(0), ctx), RewriterStatement.nnz(stmt.getChild(2), ctx)), StatementUtils.length(ctx, stmt)); + case "const(MATRIX,FLOAT)": + if (stmt.getChild(1).isLiteral() && ConstantFoldingUtils.isNeutralElement(stmt.getChild(1).getLiteral(), "+")) + return RewriterStatement.literal(ctx, 0L); + case "rowSums(MATRIX)": + case "colSums(MATRIX)": + StatementUtils.min(ctx, RewriterStatement.nnz(stmt.getChild(0), ctx), StatementUtils.length(ctx, stmt)); } return StatementUtils.length(ctx, stmt); diff --git a/src/main/java/org/apache/sysds/hops/rewriter/rule/RewriterRuleSet.java b/src/main/java/org/apache/sysds/hops/rewriter/rule/RewriterRuleSet.java index d64de719c8d..468f7fd2ad2 100644 --- a/src/main/java/org/apache/sysds/hops/rewriter/rule/RewriterRuleSet.java +++ b/src/main/java/org/apache/sysds/hops/rewriter/rule/RewriterRuleSet.java @@ -291,7 +291,12 @@ public static RewriterRuleSet deserialize(String[] data, final RuleContext ctx) for (int i = 0; i < data.length; i++) { if (data[i].equals("::RULE")) { if (!currentLines.isEmpty()) { - rules.add(RewriterUtils.parseRule(String.join("\n", currentLines), ctx)); + try { + rules.add(RewriterUtils.parseRule(String.join("\n", currentLines), ctx)); + } catch (Exception e) { + System.err.println("An error occurred while parsing the rule:\n" + String.join("\n", currentLines)); + e.printStackTrace(); + } currentLines.clear(); } } else { diff --git a/src/main/java/org/apache/sysds/hops/rewriter/utils/CodeGenUtils.java b/src/main/java/org/apache/sysds/hops/rewriter/utils/CodeGenUtils.java index 533f646a7ed..2c1cba45bb4 100644 --- a/src/main/java/org/apache/sysds/hops/rewriter/utils/CodeGenUtils.java +++ b/src/main/java/org/apache/sysds/hops/rewriter/utils/CodeGenUtils.java @@ -20,6 +20,7 @@ package org.apache.sysds.hops.rewriter.utils; import org.apache.commons.lang3.NotImplementedException; +import org.apache.sysds.hops.rewrite.HopRewriteUtils; import org.apache.sysds.hops.rewriter.assertions.RewriterAssertions; import org.apache.sysds.hops.rewriter.RewriterStatement; import org.apache.sysds.hops.rewriter.RuleContext; @@ -60,7 +61,6 @@ public static String getAdditionalCheck(RewriterStatement stmt, final RuleContex public static String getOpCode(RewriterStatement stmt, final RuleContext ctx) { if (stmt.getOperands().size() == 1) { // Handle unary ops - // TODO: nrow, ncol, length switch (stmt.trueInstruction()) { case "t": return "Types.ReOrgOp.TRANS"; @@ -74,14 +74,20 @@ public static String getOpCode(RewriterStatement stmt, final RuleContext ctx) { return "Types.OpOp1.POW2"; case "log": return "Types.OpOp1.LOG"; + case "log_nz": + return "Types.OpOp1.LOG_NZ"; case "abs": return "Types.OpOp1.ABS"; case "round": return "Types.OpOp1.ROUND"; + case "exp": + return "Types.OpOp1.EXP"; case "rowSums": case "colSums": case "sum": return "Types.AggOp.SUM"; + case "sumSq": + return "Types.AggOp.SUM_SQ"; case "trace": return "Types.AggOp.TRACE"; case "*2": @@ -248,6 +254,7 @@ public static String getOpClass(RewriterStatement stmt, final RuleContext ctx) { case "!": case "sqrt": case "log": + case "log_nz": case "abs": case "round": case "*2": @@ -257,11 +264,13 @@ public static String getOpClass(RewriterStatement stmt, final RuleContext ctx) { case "ncol": case "length": case "sq": + case "exp": return "UnaryOp"; case "rowSums": case "colSums": case "sum": + case "sumSq": case "trace": return "AggUnaryOp"; @@ -386,6 +395,11 @@ public static String getHopConstructor(RewriterStatement cur, RewriterAssertions return "HopRewriteUtils.createAggUnaryOp(" + children[0] + ", Types.AggOp.SUM, Types.Direction.RowCol)"; + case "sumSq": + if (children.length != 1) + throw new IllegalArgumentException(); + + return "HopRewriteUtils.createAggUnaryOp(" + children[0] + ", Types.AggOp.SUM_SQ, Types.Direction.RowCol)"; case "trace": if (children.length != 1) throw new IllegalArgumentException(); From a315af9d9ef95aa13fc3588aae5211bc87fca959 Mon Sep 17 00:00:00 2001 From: Jaybit0 Date: Tue, 4 Feb 2025 12:23:36 +0100 Subject: [PATCH 6/9] Sparsity estimation for rev(MATRIX) --- .../hops/rewriter/estimators/RewriterSparsityEstimator.java | 1 + 1 file changed, 1 insertion(+) diff --git a/src/main/java/org/apache/sysds/hops/rewriter/estimators/RewriterSparsityEstimator.java b/src/main/java/org/apache/sysds/hops/rewriter/estimators/RewriterSparsityEstimator.java index c33d7737928..22de98abcb2 100644 --- a/src/main/java/org/apache/sysds/hops/rewriter/estimators/RewriterSparsityEstimator.java +++ b/src/main/java/org/apache/sysds/hops/rewriter/estimators/RewriterSparsityEstimator.java @@ -121,6 +121,7 @@ public static RewriterStatement estimateNNZ(RewriterStatement stmt, final RuleCo case "*2(MATRIX)": case "sq(MATRIX)": case "t(MATRIX)": + case "rev(MATRIX)": return RewriterStatement.nnz(stmt.getChild(0), ctx); case "1-*(MATRIX,MATRIX)": return StatementUtils.length(ctx, stmt); From edff946f655dcf70190bb0f8184bb7282fd90a99 Mon Sep 17 00:00:00 2001 From: Jaybit0 Date: Tue, 4 Feb 2025 15:17:33 +0100 Subject: [PATCH 7/9] Extending support for constant matrices --- .../rewriter/codegen/CodeGenCondition.java | 13 ++++++--- .../rewriter/codegen/RewriterCodeGen.java | 25 ++++++++++++++--- .../hops/rewriter/utils/CodeGenUtils.java | 28 +++++++++++++++++-- 3 files changed, 56 insertions(+), 10 deletions(-) diff --git a/src/main/java/org/apache/sysds/hops/rewriter/codegen/CodeGenCondition.java b/src/main/java/org/apache/sysds/hops/rewriter/codegen/CodeGenCondition.java index e8e30f4105e..f97f6360235 100644 --- a/src/main/java/org/apache/sysds/hops/rewriter/codegen/CodeGenCondition.java +++ b/src/main/java/org/apache/sysds/hops/rewriter/codegen/CodeGenCondition.java @@ -434,10 +434,15 @@ private boolean matchesDataTypeCondition(RewriterRule rule, RewriterStatement st } private boolean matchesOpClassCondition(RewriterRule rule, RewriterStatement stmt, final RuleContext ctx) { - String opClass = (String) conditionValue; - String actualClass = CodeGenUtils.getOpClass(stmt, ctx); - - return opClass.equals(actualClass); + try { + String opClass = (String) conditionValue; + String actualClass = CodeGenUtils.getOpClass(stmt, ctx); + + return opClass.equals(actualClass); + } catch (Exception e) { + System.err.println(rule.toParsableString(ctx)); + throw e; + } } private boolean matchesOpCodeCondition(RewriterRule rule, RewriterStatement stmt, final RuleContext ctx) { diff --git a/src/main/java/org/apache/sysds/hops/rewriter/codegen/RewriterCodeGen.java b/src/main/java/org/apache/sysds/hops/rewriter/codegen/RewriterCodeGen.java index 1dfe573e890..99da6a3b0a8 100644 --- a/src/main/java/org/apache/sysds/hops/rewriter/codegen/RewriterCodeGen.java +++ b/src/main/java/org/apache/sysds/hops/rewriter/codegen/RewriterCodeGen.java @@ -146,11 +146,13 @@ public static String generateClass(String className, List implemented = new HashSet<>(); + int implementedRules = 0; for (Tuple2 appliedRewrites : rewrites) { String mRewriteFn; if (ignoreErrors) { try { mRewriteFn = generateRewriteFunction(appliedRewrites._2, appliedRewrites._1, 1, maintainRewriteStats, ctx); + implementedRules++; } catch (Exception e) { if (printErrors) e.printStackTrace(); @@ -159,6 +161,7 @@ public static String generateClass(String className, List { + if (el.isInstruction() && el.trueInstruction().equals("const") && vars.get(el.getChild(0)) == null) { + vars.put(el.getChild(0), vars.get(el)); + } + + return true; + }, false); + if (fromCost != null) { List msb = new ArrayList<>(); msb.add(new StringBuilder()); @@ -804,16 +816,21 @@ private static void recursivelyBuildMatchingSequence(RewriterStatement cur, Stri // Build the variable definition String name = resolveOperand(cur, i, sb, curVar, ctx, indentation); - map.put(stmt, name); - sb.append('\n'); - recursivelyBuildMatchingSequence(stmt, sb, name, ctx, indentation, map, allowedMultiRefs, allowCombinations); + if (name != null) { + map.put(stmt, name); + sb.append('\n'); + recursivelyBuildMatchingSequence(stmt, sb, name, ctx, indentation, map, allowedMultiRefs, allowCombinations); + } } } private static String resolveOperand(RewriterStatement stmt, int idx, StringBuilder sb, String curVar, final RuleContext ctx, int indentation) { + String accessor = CodeGenUtils.getChildAccessor(curVar, stmt, idx); + if (accessor == null) + return null; // Then we do not need to traverse the sub-dag further String name = curVar + "_" + idx; indent(indentation, sb); - sb.append("Hop " + name + " = " + curVar + ".getInput(" + idx + ");\n"); + sb.append("Hop " + name + " = " + accessor + ";\n"); return name; } diff --git a/src/main/java/org/apache/sysds/hops/rewriter/utils/CodeGenUtils.java b/src/main/java/org/apache/sysds/hops/rewriter/utils/CodeGenUtils.java index 2c1cba45bb4..59978923ef4 100644 --- a/src/main/java/org/apache/sysds/hops/rewriter/utils/CodeGenUtils.java +++ b/src/main/java/org/apache/sysds/hops/rewriter/utils/CodeGenUtils.java @@ -20,6 +20,7 @@ package org.apache.sysds.hops.rewriter.utils; import org.apache.commons.lang3.NotImplementedException; +import org.apache.sysds.hops.DataGenOp; import org.apache.sysds.hops.rewrite.HopRewriteUtils; import org.apache.sysds.hops.rewriter.assertions.RewriterAssertions; import org.apache.sysds.hops.rewriter.RewriterStatement; @@ -31,12 +32,34 @@ import java.util.Optional; public class CodeGenUtils { + // Function to access child statement (which are not neccessarily through .getInput(n)) + public static String getChildAccessor(String parentVar, RewriterStatement stmt, int childIdx) { + switch (stmt.trueInstruction()) { + case "const": + if (childIdx != 1) + return null; + + if (stmt.getChild(1).isLiteral() && Math.abs(stmt.getChild(1).floatLiteral()) == 0.0) + return "new LiteralOp(0.0D)"; // as this might be nnz = 0 and not DataGenOp + return "((DataGenOp)" + parentVar + ").getConstantValue()"; + } + + return parentVar + ".getInput(" + childIdx + ")"; + } + public static String getSpecialOpCheck(RewriterStatement stmt, final RuleContext ctx, String hopVar) { if (!stmt.isInstruction()) return null; switch (stmt.trueInstruction()) { case "%*%": return "HopRewriteUtils.isMatrixMultiply(" + hopVar + ")"; + case "const": + if (stmt.getChild(1).isLiteral()) { + if (Math.abs(stmt.getChild(1).floatLiteral()) == 0.0) // Then this also holds for nnz=0 + return "HopRewriteUtils.isDataGenOpWithConstantValue(" + hopVar + ", " + stmt.getChild(1).floatLiteral() + ") || " + hopVar + ".getNnz() == 0"; + return "HopRewriteUtils.isDataGenOpWithConstantValue(" + hopVar + ", " + stmt.getChild(1).floatLiteral() + ")"; + } else + return "HopRewriteUtils.isDataGenOpWithConstantValue(" + hopVar + ")"; } return null; @@ -448,12 +471,13 @@ public static String getHopConstructor(RewriterStatement cur, RewriterAssertions if (mappedName != null) { nrowContent = getHopConstructor(stmt, assertions, varNameMapping, ctx, mappedName); - break; + if (nrowContent != null) + break; } } if (nrowContent == null) - throw new IllegalArgumentException(); + throw new IllegalArgumentException(nrowAssertion.toString()); } if (ncolLiteral.isPresent()) { From c4979c25a086b5979c276f5b27f097ed0a380e21 Mon Sep 17 00:00:00 2001 From: Jaybit0 Date: Tue, 4 Feb 2025 15:58:46 +0100 Subject: [PATCH 8/9] Disabling ^2 rewrites due to unexpected behavior --- .../org/apache/sysds/hops/rewriter/utils/CodeGenUtils.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/main/java/org/apache/sysds/hops/rewriter/utils/CodeGenUtils.java b/src/main/java/org/apache/sysds/hops/rewriter/utils/CodeGenUtils.java index 59978923ef4..3a61cc5c7f8 100644 --- a/src/main/java/org/apache/sysds/hops/rewriter/utils/CodeGenUtils.java +++ b/src/main/java/org/apache/sysds/hops/rewriter/utils/CodeGenUtils.java @@ -93,8 +93,8 @@ public static String getOpCode(RewriterStatement stmt, final RuleContext ctx) { return "Types.OpOp1.NOT"; case "sqrt": return "Types.OpOp1.SQRT"; - case "sq": - return "Types.OpOp1.POW2"; + //case "sq": + // return "Types.OpOp1.POW2"; // POW2 does not seem to work in all cases when applying the rewrite (e.g., LinearLogRegTest) case "log": return "Types.OpOp1.LOG"; case "log_nz": @@ -286,7 +286,7 @@ public static String getOpClass(RewriterStatement stmt, final RuleContext ctx) { case "nrow": case "ncol": case "length": - case "sq": + //case "sq": // SQ does not appear to work in some cases case "exp": return "UnaryOp"; From 49114d6e5d9e285304ffba474d26a77e7ded55c0 Mon Sep 17 00:00:00 2001 From: Jaybit0 Date: Tue, 4 Feb 2025 16:32:27 +0100 Subject: [PATCH 9/9] Fix: Invalidate null arguments --- .../apache/sysds/hops/rewriter/codegen/RewriterCodeGen.java | 1 + .../org/apache/sysds/hops/rewriter/utils/CodeGenUtils.java | 4 ++++ 2 files changed, 5 insertions(+) diff --git a/src/main/java/org/apache/sysds/hops/rewriter/codegen/RewriterCodeGen.java b/src/main/java/org/apache/sysds/hops/rewriter/codegen/RewriterCodeGen.java index 99da6a3b0a8..2274542d1e6 100644 --- a/src/main/java/org/apache/sysds/hops/rewriter/codegen/RewriterCodeGen.java +++ b/src/main/java/org/apache/sysds/hops/rewriter/codegen/RewriterCodeGen.java @@ -648,6 +648,7 @@ private static int recursivelyBuildNewHop(StringBuilder sb, RewriterStatement cu if (!matchingDims.isEmpty()) { // Then we need to validate that broadcasting still works after rearranging + indent(indentation, sb); sb.append("if ( !RewriterRuntimeUtils.hasMatchingDims(" + matchingDims.stream().map(idx -> operandRefs[idx]).collect(Collectors.joining(", ")) + ") ) {\n"); for (String createdOp : createdOps) { // Properly remove the references to the newly constructed ops diff --git a/src/main/java/org/apache/sysds/hops/rewriter/utils/CodeGenUtils.java b/src/main/java/org/apache/sysds/hops/rewriter/utils/CodeGenUtils.java index 3a61cc5c7f8..5351cecdd68 100644 --- a/src/main/java/org/apache/sysds/hops/rewriter/utils/CodeGenUtils.java +++ b/src/main/java/org/apache/sysds/hops/rewriter/utils/CodeGenUtils.java @@ -382,6 +382,10 @@ public static String getHopConstructor(RewriterStatement cur, RewriterAssertions String opClass = getOpClass(cur, ctx); String opCode = null; + for (int i = 0; i < children.length; i++) + if (children[i] == null) + throw new IllegalArgumentException("The argument " + i + " is null: " + cur.toParsableString(ctx)); + // Special instructions switch (cur.trueInstruction()) { case "%*%":