Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: enzyme autodiff helpers #954

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft

feat: enzyme autodiff helpers #954

wants to merge 6 commits into from

Conversation

avik-pal
Copy link
Member

@avik-pal avik-pal commented Sep 23, 2024

  • Features
    • Batched Jacobian
      • Forward Mode
      • Reverse Mode
    • vector_jacobian_product
    • jacobian_vector_product
    • Reactant support
    • Remove enzyme HO testing and use reactant instead

@avik-pal avik-pal changed the base branch from main to ap/up_enzyme September 23, 2024 00:39
Copy link
Contributor

github-actions bot commented Sep 23, 2024

Benchmark Results (ASV)

main 3ed672d... main/3ed672d8cfa19e...
basics/overhead 0.128 ± 0.0016 μs 0.124 ± 0.0011 μs 1.03
time_to_load 0.891 ± 0.0063 s 0.898 ± 0.0015 s 0.991

Benchmark Plots

A plot of the benchmark results have been uploaded as an artifact to the workflow run for this PR.
Go to "Actions"->"Benchmark a pull request"->[the most recent run]->"Artifacts" (at the bottom).

@avik-pal avik-pal force-pushed the ap/up_enzyme branch 4 times, most recently from 04ef36a to 10f7272 Compare September 23, 2024 01:50
Base automatically changed from ap/up_enzyme to main September 23, 2024 03:09
An error occurred while trying to automatically change base from ap/up_enzyme to main September 23, 2024 03:09
Copy link
Contributor

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lux Benchmarks

Benchmark suite Current: 746d416 Previous: 5cb86b3 Ratio
Dense(512 => 512, identity)(512 x 128)/forward/CPU/2 thread(s) 414875 ns 411270.5 ns 1.01
Dense(512 => 512, identity)(512 x 128)/forward/CPU/4 thread(s) 322583 ns 321459 ns 1.00
Dense(512 => 512, identity)(512 x 128)/forward/CPU/8 thread(s) 322854 ns 243229 ns 1.33
Dense(512 => 512, identity)(512 x 128)/forward/CPU/1 thread(s) 739833 ns 739583 ns 1.00
Dense(512 => 512, identity)(512 x 128)/forward/GPU/CUDA 41063 ns 41187 ns 1.00
Dense(512 => 512, identity)(512 x 128)/zygote/CPU/2 thread(s) 1280812.5 ns 1293854.5 ns 0.99
Dense(512 => 512, identity)(512 x 128)/zygote/CPU/4 thread(s) 2407854.5 ns 2409166.5 ns 1.00
Dense(512 => 512, identity)(512 x 128)/zygote/CPU/8 thread(s) 14059417 ns 16158416 ns 0.87
Dense(512 => 512, identity)(512 x 128)/zygote/CPU/1 thread(s) 926625 ns 2244124.5 ns 0.41
Dense(512 => 512, identity)(512 x 128)/zygote/GPU/CUDA 183518 ns 186717.5 ns 0.98
Dense(512 => 512, identity)(512 x 128)/enzyme/CPU/2 thread(s) 1361375 ns 1386416 ns 0.98
Dense(512 => 512, identity)(512 x 128)/enzyme/CPU/4 thread(s) 2592542 ns 2592167 ns 1.00
Dense(512 => 512, identity)(512 x 128)/enzyme/CPU/8 thread(s) 13891333 ns 16442917 ns 0.84
Dense(512 => 512, identity)(512 x 128)/enzyme/CPU/1 thread(s) 1002791.5 ns 2224250 ns 0.45
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/forward/CPU/2 thread(s) 1690459 ns 1760437.5 ns 0.96
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/forward/CPU/4 thread(s) 1094291.5 ns 1084209 ns 1.01
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/forward/CPU/8 thread(s) 1507042 ns 1521520.5 ns 0.99
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/forward/CPU/1 thread(s) 2336166.5 ns 2926125 ns 0.80
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/forward/GPU/CUDA 204580.5 ns 205511.5 ns 1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/zygote/CPU/2 thread(s) 12153584 ns 12138333 ns 1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/zygote/CPU/4 thread(s) 8807770.5 ns 8825083.5 ns 1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/zygote/CPU/8 thread(s) 9160167 ns 9173333 ns 1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/zygote/CPU/1 thread(s) 18007292 ns 18597250 ns 0.97
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/zygote/GPU/CUDA 1493413.5 ns 1482675 ns 1.01
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/enzyme/CPU/2 thread(s) 17314292 ns 17291625 ns 1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/enzyme/CPU/4 thread(s) 13955979.5 ns 13944583.5 ns 1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/enzyme/CPU/8 thread(s) 14505167 ns 14521250 ns 1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/enzyme/CPU/1 thread(s) 21116187.5 ns 21811145.5 ns 0.97
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/forward/CPU/2 thread(s) 250669312.5 ns 249976666.5 ns 1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/forward/CPU/4 thread(s) 148517042 ns 148133250 ns 1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/forward/CPU/8 thread(s) 116505167 ns 116316853.5 ns 1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/forward/CPU/1 thread(s) 106687521 ns 449124167 ns 0.24
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/forward/GPU/CUDA 5476055 ns 5482800 ns 1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/zygote/CPU/2 thread(s) 1221329917 ns 1230523917 ns 0.99
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/zygote/CPU/4 thread(s) 931447042 ns 928555292 ns 1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/zygote/CPU/8 thread(s) 832356062.5 ns 832972542 ns 1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/zygote/CPU/1 thread(s) 630480021 ns 1627591875 ns 0.39
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/zygote/GPU/CUDA 35434181.5 ns 35593150.5 ns 1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/enzyme/CPU/2 thread(s) 1036811166 ns 1137967916 ns 0.91
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/enzyme/CPU/4 thread(s) 1002735417 ns 992382854.5 ns 1.01
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/enzyme/CPU/8 thread(s) 1320899271 ns 1336473333 ns 0.99
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/enzyme/CPU/1 thread(s) 746540333.5 ns 1743275104 ns 0.43
lenet(28, 28, 1, 32)/forward/CPU/2 thread(s) 1117542 ns 1092042 ns 1.02
lenet(28, 28, 1, 32)/forward/CPU/4 thread(s) 1594208.5 ns 1598250 ns 1.00
lenet(28, 28, 1, 32)/forward/CPU/8 thread(s) 3782499.5 ns 3369875 ns 1.12
lenet(28, 28, 1, 32)/forward/CPU/1 thread(s) 955292 ns 781500 ns 1.22
lenet(28, 28, 1, 32)/forward/GPU/CUDA 258209 ns 252534 ns 1.02
lenet(28, 28, 1, 32)/zygote/CPU/2 thread(s) 2999833 ns 2977458 ns 1.01
lenet(28, 28, 1, 32)/zygote/CPU/4 thread(s) 4135687.5 ns 4116250 ns 1.00
lenet(28, 28, 1, 32)/zygote/CPU/8 thread(s) 10521916.5 ns 9642292 ns 1.09
lenet(28, 28, 1, 32)/zygote/CPU/1 thread(s) 3277124.5 ns 3142958 ns 1.04
lenet(28, 28, 1, 32)/zygote/GPU/CUDA 1045672 ns 1029320 ns 1.02
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/forward/CPU/2 thread(s) 2268916 ns 2326083 ns 0.98
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/forward/CPU/4 thread(s) 1423708 ns 1424667 ns 1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/forward/CPU/8 thread(s) 1658833 ns 1560084 ns 1.06
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/forward/CPU/1 thread(s) 3547125 ns 4056750 ns 0.87
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/forward/GPU/CUDA 207087 ns 208748 ns 0.99
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/zygote/CPU/2 thread(s) 19446041 ns 19404250 ns 1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/zygote/CPU/4 thread(s) 16092583 ns 16063062.5 ns 1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/zygote/CPU/8 thread(s) 17282208.5 ns 17250542 ns 1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/zygote/CPU/1 thread(s) 25278271 ns 25830521 ns 0.98
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/zygote/GPU/CUDA 1572150 ns 1568775 ns 1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/enzyme/CPU/2 thread(s) 34384124.5 ns 34513208 ns 1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/enzyme/CPU/4 thread(s) 31011854 ns 30762479.5 ns 1.01
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/enzyme/CPU/8 thread(s) 31238187.5 ns 31225958 ns 1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/enzyme/CPU/1 thread(s) 36016000 ns 36930167 ns 0.98
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/forward/CPU/2 thread(s) 4527333.5 ns 4524979.5 ns 1.00
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/forward/CPU/4 thread(s) 2756667 ns 2763937.5 ns 1.00
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/forward/CPU/8 thread(s) 2904250 ns 2673854.5 ns 1.09
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/forward/CPU/1 thread(s) 7232542 ns 8377291.5 ns 0.86
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/forward/GPU/CUDA 417797 ns 421945 ns 0.99
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/zygote/CPU/2 thread(s) 38954000 ns 39044209 ns 1.00
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/zygote/CPU/4 thread(s) 32165354.5 ns 32104791.5 ns 1.00
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/zygote/CPU/8 thread(s) 32261958 ns 32250208 ns 1.00
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/zygote/CPU/1 thread(s) 49333875 ns 51814146 ns 0.95
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/zygote/GPU/CUDA 2609765 ns 2619508 ns 1.00
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/enzyme/CPU/2 thread(s) 89031416 ns 88649667 ns 1.00
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/enzyme/CPU/4 thread(s) 113154000 ns 113984812.5 ns 0.99
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/enzyme/CPU/8 thread(s) 225760250 ns 225519896 ns 1.00
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/enzyme/CPU/1 thread(s) 72186959 ns 74452916.5 ns 0.97
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/forward/CPU/2 thread(s) 269236500 ns 268719375 ns 1.00
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/forward/CPU/4 thread(s) 159569292 ns 159201750 ns 1.00
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/forward/CPU/8 thread(s) 126879500 ns 123545124.5 ns 1.03
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/forward/CPU/1 thread(s) 146515666.5 ns 491504917 ns 0.30
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/forward/GPU/CUDA 6869883 ns 6875777.5 ns 1.00
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/zygote/CPU/2 thread(s) 1472680583.5 ns 1477656896 ns 1.00
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/zygote/CPU/4 thread(s) 1142068417 ns 1178240000 ns 0.97
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/zygote/CPU/8 thread(s) 1076495521 ns 1062550813 ns 1.01
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/zygote/CPU/1 thread(s) 1003986146 ns 2026628729 ns 0.50
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/zygote/GPU/CUDA 33100972 ns 33057912 ns 1.00
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/enzyme/CPU/2 thread(s) 1695661542 ns 1716757416 ns 0.99
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/enzyme/CPU/4 thread(s) 1475377937.5 ns 1531392041.5 ns 0.96
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/enzyme/CPU/8 thread(s) 1873531416 ns 1872548250 ns 1.00
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/enzyme/CPU/1 thread(s) 1234411479 ns 2230422791 ns 0.55
lenet(28, 28, 1, 128)/forward/CPU/2 thread(s) 2051291.5 ns 2018375 ns 1.02
lenet(28, 28, 1, 128)/forward/CPU/4 thread(s) 2584729 ns 3022624.5 ns 0.86
lenet(28, 28, 1, 128)/forward/CPU/8 thread(s) 8034437.5 ns 7958917 ns 1.01
lenet(28, 28, 1, 128)/forward/CPU/1 thread(s) 2158812.5 ns 2489958 ns 0.87
lenet(28, 28, 1, 128)/forward/GPU/CUDA 259524 ns 253828.5 ns 1.02
lenet(28, 28, 1, 128)/zygote/CPU/2 thread(s) 9511666 ns 9615917 ns 0.99
lenet(28, 28, 1, 128)/zygote/CPU/4 thread(s) 11479666 ns 11905479 ns 0.96
lenet(28, 28, 1, 128)/zygote/CPU/8 thread(s) 25743562 ns 24859333 ns 1.04
lenet(28, 28, 1, 128)/zygote/CPU/1 thread(s) 10432292 ns 11285437.5 ns 0.92
lenet(28, 28, 1, 128)/zygote/GPU/CUDA 1105869 ns 1089672 ns 1.01
vgg16(32, 32, 3, 32)/forward/CPU/2 thread(s) 380071250 ns 382104291.5 ns 0.99
vgg16(32, 32, 3, 32)/forward/CPU/4 thread(s) 288235396 ns 288697958.5 ns 1.00
vgg16(32, 32, 3, 32)/forward/CPU/8 thread(s) 234569375 ns 263937500 ns 0.89
vgg16(32, 32, 3, 32)/forward/CPU/1 thread(s) 180534000 ns 453025562.5 ns 0.40
vgg16(32, 32, 3, 32)/forward/GPU/CUDA 4960129 ns 4955655.5 ns 1.00
vgg16(32, 32, 3, 32)/zygote/CPU/2 thread(s) 1151615584 ns 1159602875 ns 0.99
vgg16(32, 32, 3, 32)/zygote/CPU/4 thread(s) 943298834 ns 937068625 ns 1.01
vgg16(32, 32, 3, 32)/zygote/CPU/8 thread(s) 1035530459 ns 1116269958 ns 0.93
vgg16(32, 32, 3, 32)/zygote/CPU/1 thread(s) 836320375 ns 1586148458 ns 0.53
vgg16(32, 32, 3, 32)/zygote/GPU/CUDA 18005559 ns 18262229 ns 0.99
lenet(28, 28, 1, 64)/forward/CPU/2 thread(s) 1062271 ns 1053583 ns 1.01
lenet(28, 28, 1, 64)/forward/CPU/4 thread(s) 1957208 ns 2074479.5 ns 0.94
lenet(28, 28, 1, 64)/forward/CPU/8 thread(s) 5208209 ns 6685541 ns 0.78
lenet(28, 28, 1, 64)/forward/CPU/1 thread(s) 1335083 ns 1294959 ns 1.03
lenet(28, 28, 1, 64)/forward/GPU/CUDA 261443.5 ns 256727 ns 1.02
lenet(28, 28, 1, 64)/zygote/CPU/2 thread(s) 6520854 ns 6501125 ns 1.00
lenet(28, 28, 1, 64)/zygote/CPU/4 thread(s) 13810395.5 ns 12392708 ns 1.11
lenet(28, 28, 1, 64)/zygote/CPU/8 thread(s) 19287187.5 ns 19126833.5 ns 1.01
lenet(28, 28, 1, 64)/zygote/CPU/1 thread(s) 5707229 ns 6062209 ns 0.94
lenet(28, 28, 1, 64)/zygote/GPU/CUDA 1152646 ns 1151411.5 ns 1.00
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/forward/CPU/2 thread(s) 70489208 ns 70475667 ns 1.00
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/forward/CPU/4 thread(s) 43453750 ns 43577354.5 ns 1.00
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/forward/CPU/8 thread(s) 39549104.5 ns 39785333 ns 0.99
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/forward/CPU/1 thread(s) 35097750 ns 132525000 ns 0.26
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/forward/GPU/CUDA 1869963 ns 1859935 ns 1.01
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/zygote/CPU/2 thread(s) 355046958 ns 356597312.5 ns 1.00
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/zygote/CPU/4 thread(s) 270615583 ns 270033292 ns 1.00
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/zygote/CPU/8 thread(s) 254393917 ns 254165791.5 ns 1.00
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/zygote/CPU/1 thread(s) 272932271 ns 541858104.5 ns 0.50
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/zygote/GPU/CUDA 12159123.5 ns 12225780 ns 0.99
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/enzyme/CPU/2 thread(s) 395987166 ns 395590417 ns 1.00
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/enzyme/CPU/4 thread(s) 404320146.5 ns 407040083 ns 0.99
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/enzyme/CPU/8 thread(s) 728272291.5 ns 686687708 ns 1.06
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/enzyme/CPU/1 thread(s) 389487770.5 ns 711343459 ns 0.55
vgg16(32, 32, 3, 128)/forward/CPU/2 thread(s) 1190137625 ns 1189721417 ns 1.00
vgg16(32, 32, 3, 128)/forward/CPU/4 thread(s) 695477458.5 ns 694763792 ns 1.00
vgg16(32, 32, 3, 128)/forward/CPU/8 thread(s) 634960083 ns 639415416.5 ns 0.99
vgg16(32, 32, 3, 128)/forward/CPU/1 thread(s) 769841416 ns 1863138250 ns 0.41
vgg16(32, 32, 3, 128)/forward/GPU/CUDA 12538007.5 ns 12305225 ns 1.02
vgg16(32, 32, 3, 128)/zygote/CPU/2 thread(s) 3607518604 ns 3693716416.5 ns 0.98
vgg16(32, 32, 3, 128)/zygote/CPU/4 thread(s) 2913645208 ns 2822411042 ns 1.03
vgg16(32, 32, 3, 128)/zygote/CPU/8 thread(s) 2692374000 ns 2715785292 ns 0.99
vgg16(32, 32, 3, 128)/zygote/CPU/1 thread(s) 2202975896 ns 5075752792 ns 0.43
vgg16(32, 32, 3, 128)/zygote/GPU/CUDA 50257639 ns 50096815 ns 1.00
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/forward/CPU/2 thread(s) 3424417 ns 3427000 ns 1.00
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/forward/CPU/4 thread(s) 2074625 ns 2064167 ns 1.01
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/forward/CPU/8 thread(s) 2543708 ns 2526875 ns 1.01
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/forward/CPU/1 thread(s) 4468229 ns 6023271 ns 0.74
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/forward/GPU/CUDA 341646 ns 343351 ns 1.00
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/zygote/CPU/2 thread(s) 25873333 ns 26112416 ns 0.99
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/zygote/CPU/4 thread(s) 19047917 ns 19044125 ns 1.00
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/zygote/CPU/8 thread(s) 19661709 ns 19200250 ns 1.02
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/zygote/CPU/1 thread(s) 36420541 ns 39317542 ns 0.93
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/zygote/GPU/CUDA 2449376.5 ns 2472756 ns 0.99
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/enzyme/CPU/2 thread(s) 54505083 ns 53212375 ns 1.02
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/enzyme/CPU/4 thread(s) 78853166.5 ns 86527562.5 ns 0.91
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/enzyme/CPU/8 thread(s) 171442208.5 ns 174565833 ns 0.98
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/enzyme/CPU/1 thread(s) 42923625 ns 45612312 ns 0.94
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/forward/CPU/2 thread(s) 1781417 ns 1781334 ns 1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/forward/CPU/4 thread(s) 1098125.5 ns 1100708 ns 1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/forward/CPU/8 thread(s) 1592875 ns 1559041 ns 1.02
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/forward/CPU/1 thread(s) 2400250 ns 3031000 ns 0.79
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/forward/GPU/CUDA 209375 ns 209992.5 ns 1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/zygote/CPU/2 thread(s) 12557583.5 ns 12526708.5 ns 1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/zygote/CPU/4 thread(s) 9199833 ns 9202062.5 ns 1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/zygote/CPU/8 thread(s) 9662666 ns 9594042 ns 1.01
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/zygote/CPU/1 thread(s) 18408229.5 ns 18995083.5 ns 0.97
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/zygote/GPU/CUDA 1503069 ns 1520295 ns 0.99
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/enzyme/CPU/2 thread(s) 17674791.5 ns 17654187.5 ns 1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/enzyme/CPU/4 thread(s) 14289625 ns 14319770.5 ns 1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/enzyme/CPU/8 thread(s) 14610292 ns 14542209 ns 1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/enzyme/CPU/1 thread(s) 21384854.5 ns 22178916 ns 0.96
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/forward/CPU/2 thread(s) 70541666 ns 70529958 ns 1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/forward/CPU/4 thread(s) 43476792 ns 43518812.5 ns 1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/forward/CPU/8 thread(s) 39661083 ns 39678750 ns 1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/forward/CPU/1 thread(s) 35217125 ns 132567854.5 ns 0.27
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/forward/GPU/CUDA 1870713 ns 1864543.5 ns 1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/zygote/CPU/2 thread(s) 360157416 ns 362025167 ns 0.99
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/zygote/CPU/4 thread(s) 347148521 ns 346626458 ns 1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/zygote/CPU/8 thread(s) 304138292 ns 304297792 ns 1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/zygote/CPU/1 thread(s) 460278000 ns 726309042 ns 0.63
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/zygote/GPU/CUDA 13219820 ns 13257700 ns 1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/enzyme/CPU/2 thread(s) 421393604.5 ns 417934646 ns 1.01
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/enzyme/CPU/4 thread(s) 439355042 ns 420406333 ns 1.05
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/enzyme/CPU/8 thread(s) 698336792 ns 710093000 ns 0.98
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/enzyme/CPU/1 thread(s) 392625708.5 ns 717186500 ns 0.55
mlp7layer_bn(gelu)(32 x 256)/forward/CPU/2 thread(s) 1589459 ns 1662000 ns 0.96
mlp7layer_bn(gelu)(32 x 256)/forward/CPU/4 thread(s) 1343812.5 ns 1348708 ns 1.00
mlp7layer_bn(gelu)(32 x 256)/forward/CPU/8 thread(s) 1320270.5 ns 1039209 ns 1.27
mlp7layer_bn(gelu)(32 x 256)/forward/CPU/1 thread(s) 2438209 ns 2446333 ns 1.00
mlp7layer_bn(gelu)(32 x 256)/forward/GPU/CUDA 547982 ns 549327.5 ns 1.00
mlp7layer_bn(gelu)(32 x 256)/zygote/CPU/2 thread(s) 8822916.5 ns 8831979 ns 1.00
mlp7layer_bn(gelu)(32 x 256)/zygote/CPU/4 thread(s) 12890000 ns 12827437 ns 1.00
mlp7layer_bn(gelu)(32 x 256)/zygote/CPU/8 thread(s) 31203500 ns 32684875 ns 0.95
mlp7layer_bn(gelu)(32 x 256)/zygote/CPU/1 thread(s) 5678583.5 ns 9840916 ns 0.58
mlp7layer_bn(gelu)(32 x 256)/zygote/GPU/CUDA 1245270 ns 1223627 ns 1.02
mlp7layer_bn(gelu)(32 x 256)/enzyme/CPU/2 thread(s) 16553395.5 ns 16482375 ns 1.00
mlp7layer_bn(gelu)(32 x 256)/enzyme/CPU/4 thread(s) 22651458 ns 22352396 ns 1.01
mlp7layer_bn(gelu)(32 x 256)/enzyme/CPU/8 thread(s) 42591000 ns 48470938 ns 0.88
mlp7layer_bn(gelu)(32 x 256)/enzyme/CPU/1 thread(s) 9863145.5 ns 13131854 ns 0.75
Dense(512 => 512, relu)(512 x 128)/forward/CPU/2 thread(s) 827125 ns 786625 ns 1.05
Dense(512 => 512, relu)(512 x 128)/forward/CPU/4 thread(s) 550208 ns 549645.5 ns 1.00
Dense(512 => 512, relu)(512 x 128)/forward/CPU/8 thread(s) 1035958.5 ns 1064854.5 ns 0.97
Dense(512 => 512, relu)(512 x 128)/forward/CPU/1 thread(s) 69083 ns 725104.5 ns 0.09527316407497126
Dense(512 => 512, relu)(512 x 128)/forward/GPU/CUDA 45920.5 ns 45187 ns 1.02
Dense(512 => 512, relu)(512 x 128)/zygote/CPU/2 thread(s) 1547042 ns 1494083 ns 1.04
Dense(512 => 512, relu)(512 x 128)/zygote/CPU/4 thread(s) 1016083 ns 1045666.5 ns 0.97
Dense(512 => 512, relu)(512 x 128)/zygote/CPU/8 thread(s) 1442542 ns 1424458 ns 1.01
Dense(512 => 512, relu)(512 x 128)/zygote/CPU/1 thread(s) 271791.5 ns 2291291 ns 0.12
Dense(512 => 512, relu)(512 x 128)/zygote/GPU/CUDA 210286.5 ns 207948.5 ns 1.01
Dense(512 => 512, relu)(512 x 128)/enzyme/CPU/2 thread(s) 1521770.5 ns 1496875 ns 1.02
Dense(512 => 512, relu)(512 x 128)/enzyme/CPU/4 thread(s) 1078937.5 ns 1011416 ns 1.07
Dense(512 => 512, relu)(512 x 128)/enzyme/CPU/8 thread(s) 1413312.5 ns 1769209 ns 0.80
Dense(512 => 512, relu)(512 x 128)/enzyme/CPU/1 thread(s) 329625.5 ns 2257000 ns 0.15
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/forward/CPU/2 thread(s) 3416458 ns 3409312.5 ns 1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/forward/CPU/4 thread(s) 2061937.5 ns 2052875 ns 1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/forward/CPU/8 thread(s) 2494000 ns 2516667 ns 0.99
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/forward/CPU/1 thread(s) 4566062.5 ns 5998083 ns 0.76
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/forward/GPU/CUDA 280470 ns 281525 ns 1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/zygote/CPU/2 thread(s) 24061979 ns 24109729 ns 1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/zygote/CPU/4 thread(s) 17256209 ns 17182959 ns 1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/zygote/CPU/8 thread(s) 17237167 ns 17113146 ns 1.01
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/zygote/CPU/1 thread(s) 34902562.5 ns 37468750.5 ns 0.93
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/zygote/GPU/CUDA 2387533 ns 2398593 ns 1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/enzyme/CPU/2 thread(s) 52964687 ns 52554812 ns 1.01
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/enzyme/CPU/4 thread(s) 86734042 ns 84392000 ns 1.03
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/enzyme/CPU/8 thread(s) 167475520.5 ns 170326000 ns 0.98
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/enzyme/CPU/1 thread(s) 41736979 ns 44580125 ns 0.94
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/forward/CPU/2 thread(s) 249717041.5 ns 250367042 ns 1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/forward/CPU/4 thread(s) 147920542 ns 147848500 ns 1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/forward/CPU/8 thread(s) 115821625 ns 116105042 ns 1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/forward/CPU/1 thread(s) 106813458 ns 454852833 ns 0.23
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/forward/GPU/CUDA 5478800 ns 5326699.5 ns 1.03
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/zygote/CPU/2 thread(s) 1102272500 ns 1103578584 ns 1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/zygote/CPU/4 thread(s) 857453958.5 ns 855911416.5 ns 1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/zygote/CPU/8 thread(s) 824612250 ns 831258666.5 ns 0.99
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/zygote/CPU/1 thread(s) 740289062.5 ns 1772502666 ns 0.42
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/zygote/GPU/CUDA 33158767 ns 33341175 ns 0.99
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/enzyme/CPU/2 thread(s) 1030066416.5 ns 1010431750 ns 1.02
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/enzyme/CPU/4 thread(s) 926322833 ns 965660000 ns 0.96
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/enzyme/CPU/8 thread(s) 1298874583 ns 1276218416 ns 1.02
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/enzyme/CPU/1 thread(s) 731812000 ns 1718974833.5 ns 0.43
mlp7layer_bn(relu)(32 x 256)/forward/CPU/2 thread(s) 1233917 ns 1245354 ns 0.99
mlp7layer_bn(relu)(32 x 256)/forward/CPU/4 thread(s) 902125 ns 938208 ns 0.96
mlp7layer_bn(relu)(32 x 256)/forward/CPU/8 thread(s) 950771 ns 685500 ns 1.39
mlp7layer_bn(relu)(32 x 256)/forward/CPU/1 thread(s) 2008792 ns 2004042 ns 1.00
mlp7layer_bn(relu)(32 x 256)/forward/GPU/CUDA 543147.5 ns 548493.5 ns 0.99
mlp7layer_bn(relu)(32 x 256)/zygote/CPU/2 thread(s) 5799729 ns 5771604 ns 1.00
mlp7layer_bn(relu)(32 x 256)/zygote/CPU/4 thread(s) 6771145.5 ns 6597750 ns 1.03
mlp7layer_bn(relu)(32 x 256)/zygote/CPU/8 thread(s) 24594250 ns 25936583 ns 0.95
mlp7layer_bn(relu)(32 x 256)/zygote/CPU/1 thread(s) 3704333 ns 7098375 ns 0.52
mlp7layer_bn(relu)(32 x 256)/zygote/GPU/CUDA 1200925 ns 1220210 ns 0.98
mlp7layer_bn(relu)(32 x 256)/enzyme/CPU/2 thread(s) 9688312.5 ns 9431791 ns 1.03
mlp7layer_bn(relu)(32 x 256)/enzyme/CPU/4 thread(s) 13008854.5 ns 13114104.5 ns 0.99
mlp7layer_bn(relu)(32 x 256)/enzyme/CPU/8 thread(s) 31739166.5 ns 33204521 ns 0.96
mlp7layer_bn(relu)(32 x 256)/enzyme/CPU/1 thread(s) 4429458.5 ns 7606042 ns 0.58
Dense(128 => 128, gelu)(128 x 128)/forward/CPU/2 thread(s) 516437 ns 430541 ns 1.20
Dense(128 => 128, gelu)(128 x 128)/forward/CPU/4 thread(s) 406667 ns 381021 ns 1.07
Dense(128 => 128, gelu)(128 x 128)/forward/CPU/8 thread(s) 1920083 ns 3043792 ns 0.63
Dense(128 => 128, gelu)(128 x 128)/forward/CPU/1 thread(s) 52584 ns 89542 ns 0.59
Dense(128 => 128, gelu)(128 x 128)/forward/GPU/CUDA 25170 ns 25679 ns 0.98
Dense(128 => 128, gelu)(128 x 128)/zygote/CPU/2 thread(s) 382104 ns 354583 ns 1.08
Dense(128 => 128, gelu)(128 x 128)/zygote/CPU/4 thread(s) 396375 ns 443833 ns 0.89
Dense(128 => 128, gelu)(128 x 128)/zygote/CPU/8 thread(s) 4741312.5 ns 4158750 ns 1.14
Dense(128 => 128, gelu)(128 x 128)/zygote/CPU/1 thread(s) 176000 ns 258750 ns 0.68
Dense(128 => 128, gelu)(128 x 128)/zygote/GPU/CUDA 183031.5 ns 188553 ns 0.97
Dense(128 => 128, gelu)(128 x 128)/enzyme/CPU/2 thread(s) 412270.5 ns 385584 ns 1.07
Dense(128 => 128, gelu)(128 x 128)/enzyme/CPU/4 thread(s) 427209 ns 474625 ns 0.90
Dense(128 => 128, gelu)(128 x 128)/enzyme/CPU/8 thread(s) 4373812.5 ns 4412750 ns 0.99
Dense(128 => 128, gelu)(128 x 128)/enzyme/CPU/1 thread(s) 189937.5 ns 271208 ns 0.70
Dense(128 => 128, relu)(128 x 128)/forward/CPU/2 thread(s) 463562.5 ns 376687.5 ns 1.23
Dense(128 => 128, relu)(128 x 128)/forward/CPU/4 thread(s) 346104 ns 325000 ns 1.06
Dense(128 => 128, relu)(128 x 128)/forward/CPU/8 thread(s) 737667 ns 771479 ns 0.96
Dense(128 => 128, relu)(128 x 128)/forward/CPU/1 thread(s) 13958 ns 54583 ns 0.26
Dense(128 => 128, relu)(128 x 128)/forward/GPU/CUDA 25472 ns 26029 ns 0.98
Dense(128 => 128, relu)(128 x 128)/zygote/CPU/2 thread(s) 339750 ns 303687.5 ns 1.12
Dense(128 => 128, relu)(128 x 128)/zygote/CPU/4 thread(s) 269625.5 ns 341166.5 ns 0.79
Dense(128 => 128, relu)(128 x 128)/zygote/CPU/8 thread(s) 863583 ns 893375 ns 0.97
Dense(128 => 128, relu)(128 x 128)/zygote/CPU/1 thread(s) 26167 ns 151583.5 ns 0.17
Dense(128 => 128, relu)(128 x 128)/zygote/GPU/CUDA 177009 ns 180458.5 ns 0.98
Dense(128 => 128, relu)(128 x 128)/enzyme/CPU/2 thread(s) 354500 ns 316479 ns 1.12
Dense(128 => 128, relu)(128 x 128)/enzyme/CPU/4 thread(s) 283979.5 ns 355958.5 ns 0.80
Dense(128 => 128, relu)(128 x 128)/enzyme/CPU/8 thread(s) 395375 ns 833687.5 ns 0.47
Dense(128 => 128, relu)(128 x 128)/enzyme/CPU/1 thread(s) 29792 ns 150917 ns 0.20
vgg16(32, 32, 3, 64)/forward/CPU/2 thread(s) 601672834 ns 603139000 ns 1.00
vgg16(32, 32, 3, 64)/forward/CPU/4 thread(s) 430355271 ns 430379625 ns 1.00
vgg16(32, 32, 3, 64)/forward/CPU/8 thread(s) 372252417 ns 380417500 ns 0.98
vgg16(32, 32, 3, 64)/forward/CPU/1 thread(s) 321473083 ns 876424750 ns 0.37
vgg16(32, 32, 3, 64)/forward/GPU/CUDA 7021643 ns 7025391.5 ns 1.00
vgg16(32, 32, 3, 64)/zygote/CPU/2 thread(s) 1998881708.5 ns 2008872896 ns 1.00
vgg16(32, 32, 3, 64)/zygote/CPU/4 thread(s) 1624143313 ns 1619900021 ns 1.00
vgg16(32, 32, 3, 64)/zygote/CPU/8 thread(s) 1593156750 ns 1577697458 ns 1.01
vgg16(32, 32, 3, 64)/zygote/CPU/1 thread(s) 1330130709 ns 2622523542 ns 0.51
vgg16(32, 32, 3, 64)/zygote/GPU/CUDA 26731028 ns 26902392.5 ns 0.99
Dense(512 => 512, gelu)(512 x 128)/forward/CPU/2 thread(s) 521125 ns 535729 ns 0.97
Dense(512 => 512, gelu)(512 x 128)/forward/CPU/4 thread(s) 430792 ns 431416.5 ns 1.00
Dense(512 => 512, gelu)(512 x 128)/forward/CPU/8 thread(s) 1769625 ns 2478833.5 ns 0.71
Dense(512 => 512, gelu)(512 x 128)/forward/CPU/1 thread(s) 218542 ns 866124.5 ns 0.25
Dense(512 => 512, gelu)(512 x 128)/forward/GPU/CUDA 44935 ns 44614 ns 1.01
Dense(512 => 512, gelu)(512 x 128)/zygote/CPU/2 thread(s) 1870834 ns 1911229 ns 0.98
Dense(512 => 512, gelu)(512 x 128)/zygote/CPU/4 thread(s) 2784708 ns 2468667 ns 1.13
Dense(512 => 512, gelu)(512 x 128)/zygote/CPU/8 thread(s) 14884459 ns 16401666 ns 0.91
Dense(512 => 512, gelu)(512 x 128)/zygote/CPU/1 thread(s) 1451104 ns 2768395.5 ns 0.52
Dense(512 => 512, gelu)(512 x 128)/zygote/GPU/CUDA 211690 ns 210772.5 ns 1.00
Dense(512 => 512, gelu)(512 x 128)/enzyme/CPU/2 thread(s) 1907520.5 ns 1986854 ns 0.96
Dense(512 => 512, gelu)(512 x 128)/enzyme/CPU/4 thread(s) 5045750 ns 5052500 ns 1.00
Dense(512 => 512, gelu)(512 x 128)/enzyme/CPU/8 thread(s) 14628417 ns 16457750 ns 0.89
Dense(512 => 512, gelu)(512 x 128)/enzyme/CPU/1 thread(s) 1531167 ns 2773062.5 ns 0.55
mlp7layer_bn(tanh)(32 x 256)/forward/CPU/2 thread(s) 1352937.5 ns 1594542 ns 0.85
mlp7layer_bn(tanh)(32 x 256)/forward/CPU/4 thread(s) 1191146 ns 1175979 ns 1.01
mlp7layer_bn(tanh)(32 x 256)/forward/CPU/8 thread(s) 1186729.5 ns 932458.5 ns 1.27
mlp7layer_bn(tanh)(32 x 256)/forward/CPU/1 thread(s) 2208291.5 ns 2307417 ns 0.96
mlp7layer_bn(tanh)(32 x 256)/forward/GPU/CUDA 546860 ns 543745 ns 1.01
mlp7layer_bn(tanh)(32 x 256)/zygote/CPU/2 thread(s) 5898041 ns 5990542 ns 0.98
mlp7layer_bn(tanh)(32 x 256)/zygote/CPU/4 thread(s) 6574479 ns 5767687 ns 1.14
mlp7layer_bn(tanh)(32 x 256)/zygote/CPU/8 thread(s) 23986125 ns 25963208 ns 0.92
mlp7layer_bn(tanh)(32 x 256)/zygote/CPU/1 thread(s) 3911334 ns 7322042 ns 0.53
mlp7layer_bn(tanh)(32 x 256)/zygote/GPU/CUDA 1157393 ns 1137388.5 ns 1.02
mlp7layer_bn(tanh)(32 x 256)/enzyme/CPU/2 thread(s) 11645291.5 ns 11669000 ns 1.00
mlp7layer_bn(tanh)(32 x 256)/enzyme/CPU/4 thread(s) 14054375 ns 16638833.5 ns 0.84
mlp7layer_bn(tanh)(32 x 256)/enzyme/CPU/8 thread(s) 36069500 ns 38505541 ns 0.94
mlp7layer_bn(tanh)(32 x 256)/enzyme/CPU/1 thread(s) 6334792 ns 9523103.5 ns 0.67
Dense(16 => 16, relu)(16 x 128)/forward/CPU/2 thread(s) 3187.5 ns 2770.5 ns 1.15
Dense(16 => 16, relu)(16 x 128)/forward/CPU/4 thread(s) 2875 ns 2541 ns 1.13
Dense(16 => 16, relu)(16 x 128)/forward/CPU/8 thread(s) 3000 ns 3542 ns 0.85
Dense(16 => 16, relu)(16 x 128)/forward/CPU/1 thread(s) 2396 ns 2167 ns 1.11
Dense(16 => 16, relu)(16 x 128)/forward/GPU/CUDA 22081 ns 21552 ns 1.02
Dense(16 => 16, relu)(16 x 128)/zygote/CPU/2 thread(s) 7125 ns 7167 ns 0.99
Dense(16 => 16, relu)(16 x 128)/zygote/CPU/4 thread(s) 7292 ns 7083 ns 1.03
Dense(16 => 16, relu)(16 x 128)/zygote/CPU/8 thread(s) 7333 ns 7250 ns 1.01
Dense(16 => 16, relu)(16 x 128)/zygote/CPU/1 thread(s) 7292 ns 7250 ns 1.01
Dense(16 => 16, relu)(16 x 128)/zygote/GPU/CUDA 175518 ns 173171.5 ns 1.01
Dense(16 => 16, relu)(16 x 128)/enzyme/CPU/2 thread(s) 8292 ns 8167 ns 1.02
Dense(16 => 16, relu)(16 x 128)/enzyme/CPU/4 thread(s) 8375 ns 8208 ns 1.02
Dense(16 => 16, relu)(16 x 128)/enzyme/CPU/8 thread(s) 8458 ns 8584 ns 0.99
Dense(16 => 16, relu)(16 x 128)/enzyme/CPU/1 thread(s) 5834 ns 5979.5 ns 0.98
Dense(16 => 16, gelu)(16 x 128)/forward/CPU/2 thread(s) 10291 ns 10959 ns 0.94
Dense(16 => 16, gelu)(16 x 128)/forward/CPU/4 thread(s) 16375 ns 13437.5 ns 1.22
Dense(16 => 16, gelu)(16 x 128)/forward/CPU/8 thread(s) 10500 ns 10250 ns 1.02
Dense(16 => 16, gelu)(16 x 128)/forward/CPU/1 thread(s) 7208 ns 7208 ns 1
Dense(16 => 16, gelu)(16 x 128)/forward/GPU/CUDA 22018 ns 21706 ns 1.01
Dense(16 => 16, gelu)(16 x 128)/zygote/CPU/2 thread(s) 20000 ns 19916 ns 1.00
Dense(16 => 16, gelu)(16 x 128)/zygote/CPU/4 thread(s) 20167 ns 20000 ns 1.01
Dense(16 => 16, gelu)(16 x 128)/zygote/CPU/8 thread(s) 20000 ns 20209 ns 0.99
Dense(16 => 16, gelu)(16 x 128)/zygote/CPU/1 thread(s) 19750 ns 19854.5 ns 0.99
Dense(16 => 16, gelu)(16 x 128)/zygote/GPU/CUDA 188547 ns 188017 ns 1.00
Dense(16 => 16, gelu)(16 x 128)/enzyme/CPU/2 thread(s) 23666.5 ns 23666 ns 1.00
Dense(16 => 16, gelu)(16 x 128)/enzyme/CPU/4 thread(s) 23667 ns 23625 ns 1.00
Dense(16 => 16, gelu)(16 x 128)/enzyme/CPU/8 thread(s) 23792 ns 23708 ns 1.00
Dense(16 => 16, gelu)(16 x 128)/enzyme/CPU/1 thread(s) 24125 ns 21334 ns 1.13
Dense(128 => 128, identity)(128 x 128)/forward/CPU/2 thread(s) 28708 ns 28625 ns 1.00
Dense(128 => 128, identity)(128 x 128)/forward/CPU/4 thread(s) 28417 ns 28875 ns 0.98
Dense(128 => 128, identity)(128 x 128)/forward/CPU/8 thread(s) 28709 ns 28208 ns 1.02
Dense(128 => 128, identity)(128 x 128)/forward/CPU/1 thread(s) 46333 ns 45875 ns 1.01
Dense(128 => 128, identity)(128 x 128)/forward/GPU/CUDA 23191.5 ns 23317 ns 0.99
Dense(128 => 128, identity)(128 x 128)/zygote/CPU/2 thread(s) 224083.5 ns 233812.5 ns 0.96
Dense(128 => 128, identity)(128 x 128)/zygote/CPU/4 thread(s) 271625 ns 277666 ns 0.98
Dense(128 => 128, identity)(128 x 128)/zygote/CPU/8 thread(s) 3938083 ns 3990583 ns 0.99
Dense(128 => 128, identity)(128 x 128)/zygote/CPU/1 thread(s) 62750 ns 145083 ns 0.43
Dense(128 => 128, identity)(128 x 128)/zygote/GPU/CUDA 177225.5 ns 191945 ns 0.92
Dense(128 => 128, identity)(128 x 128)/enzyme/CPU/2 thread(s) 241354 ns 250666.5 ns 0.96
Dense(128 => 128, identity)(128 x 128)/enzyme/CPU/4 thread(s) 288791.5 ns 295459 ns 0.98
Dense(128 => 128, identity)(128 x 128)/enzyme/CPU/8 thread(s) 4188708 ns 4148750 ns 1.01
Dense(128 => 128, identity)(128 x 128)/enzyme/CPU/1 thread(s) 68020.5 ns 145562.5 ns 0.47
Dense(16 => 16, identity)(16 x 128)/forward/CPU/2 thread(s) 3938 ns 2041 ns 1.93
Dense(16 => 16, identity)(16 x 128)/forward/CPU/4 thread(s) 2500 ns 1916 ns 1.30
Dense(16 => 16, identity)(16 x 128)/forward/CPU/8 thread(s) 2375 ns 2584 ns 0.92
Dense(16 => 16, identity)(16 x 128)/forward/CPU/1 thread(s) 1917 ns 1625 ns 1.18
Dense(16 => 16, identity)(16 x 128)/forward/GPU/CUDA 20311 ns 20024 ns 1.01
Dense(16 => 16, identity)(16 x 128)/zygote/CPU/2 thread(s) 5125 ns 5375 ns 0.95
Dense(16 => 16, identity)(16 x 128)/zygote/CPU/4 thread(s) 5375 ns 5125 ns 1.05
Dense(16 => 16, identity)(16 x 128)/zygote/CPU/8 thread(s) 5459 ns 5250 ns 1.04
Dense(16 => 16, identity)(16 x 128)/zygote/CPU/1 thread(s) 5312.5 ns 5084 ns 1.04
Dense(16 => 16, identity)(16 x 128)/zygote/GPU/CUDA 204801 ns 238397 ns 0.86
Dense(16 => 16, identity)(16 x 128)/enzyme/CPU/2 thread(s) 7500 ns 7541 ns 0.99
Dense(16 => 16, identity)(16 x 128)/enzyme/CPU/4 thread(s) 7500 ns 7416 ns 1.01
Dense(16 => 16, identity)(16 x 128)/enzyme/CPU/8 thread(s) 7666 ns 7750 ns 0.99
Dense(16 => 16, identity)(16 x 128)/enzyme/CPU/1 thread(s) 5833 ns 5250 ns 1.11
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/forward/CPU/2 thread(s) 80011959 ns 79842084 ns 1.00
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/forward/CPU/4 thread(s) 49059208 ns 49100250 ns 1.00
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/forward/CPU/8 thread(s) 44947166 ns 43191750 ns 1.04
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/forward/CPU/1 thread(s) 54049125 ns 151456000 ns 0.36
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/forward/GPU/CUDA 2711019.5 ns 2712652 ns 1.00
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/zygote/CPU/2 thread(s) 496782875 ns 472190041 ns 1.05
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/zygote/CPU/4 thread(s) 413969042 ns 413693042 ns 1.00
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/zygote/CPU/8 thread(s) 403019709 ns 397758813 ns 1.01
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/zygote/CPU/1 thread(s) 472422687.5 ns 737522187.5 ns 0.64
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/zygote/GPU/CUDA 17006439 ns 16943151 ns 1.00
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/enzyme/CPU/2 thread(s) 717618208 ns 710270771 ns 1.01
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/enzyme/CPU/4 thread(s) 668630292 ns 668321833 ns 1.00
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/enzyme/CPU/8 thread(s) 1013911916 ns 1002011792 ns 1.01
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/enzyme/CPU/1 thread(s) 735911666.5 ns 997156208 ns 0.74

This comment was automatically generated by workflow using github-action-benchmark.

@avik-pal
Copy link
Member Author

Testcase from Discourse: https://discourse.julialang.org/t/speed-of-nested-ad-in-enzyme/123018/2

using Enzyme, Lux, Random, LinearAlgebra

n = 16
x_batch = randn(Float32, 2, n)
y_batch = randn(Float32, 2, n)

model = Chain(Parallel(vcat, Dense(2, 1, tanh), Dense(2, 1, tanh)), Dense(2, 1, tanh))

rng = Random.default_rng()
Random.seed!(rng, 0)

ps, st = Lux.setup(Xoshiro(0), model);

function compute_batched_error(model, x, y, ps, st)
    smodel = StatefulLuxLayer{true}(model, ps, st)
    fn = x_in -> smodel((x_in, y))
    J = batched_jacobian(
        fn, AutoEnzyme(;
            mode=set_runtime_activity(Enzyme.Forward), function_annotation=Enzyme.Const
        ),
        x
    )
    return sum(sqrt.(sum(abs2, J; dims=2)))
end

compute_batched_error(model, x_batch, y_batch, ps, st)

@avik-pal
Copy link
Member Author

Simpler test case:

using Lux, Random, Enzyme

model = Dense(2 => 3)
x = randn(Float32, 2, 4)
ps, st = Lux.setup(Random.default_rng(), model)

smodel = StatefulLuxLayer{true}(model, ps, st)

batched_jacobian(smodel, AutoEnzyme(; mode=Enzyme.Reverse), x)

function enzyme_test(model, x, ps, st)
    smodel = StatefulLuxLayer{true}(model, ps, st)
    J = batched_jacobian(
        smodel, AutoEnzyme(; mode=Enzyme.Forward, function_annotation=Enzyme.Const), x)
    return sum(abs2, J)
end

enzyme_test(model, x, ps, st)

Enzyme.gradient(Reverse, enzyme_test, Const(model), x, ps, Const(st))

But this leads to

ERROR: Enzyme compilation failed.
Current scope: 
; Function Attrs: mustprogress willreturn
define "enzyme_type"="{[-1]:Pointer, [-1,0]:Pointer, [-1,0,-1]:Float@float, [-1,8]:Pointer, [-1,8,0]:Integer, [-1,8,1]:Integer, [-1,8,2]:Integer, [-1,8,3]:Integer, [-1,8,4]:Integer, [-1,8,5]:Integer, [-1,8,6]:Integer, [-1,8,7]:Integer, [-1,8,8]:Pointer, [-1,8,8,-1]:Float@float, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer, [-1,24]:Integer, [-1,25]:Integer, [-1,26]:Integer, [-1,27]:Integer, [-1,28]:Integer, [-1,29]:Integer, [-1,30]:Integer, [-1,31]:Integer}" "enzymejl_parmtype"="137156409846672" "enzymejl_parmtype_ref"="1" [1 x [2 x {} addrspace(10)*]] @preprocess_julia_autodiff_146865_inner.1([1 x { {} addrspace(10)*, [2 x {} addrspace(10)*] }] "enzyme_type"="{[-1]:Pointer, [0,0]:Integer, [0,1]:Integer, [0,2]:Integer, [0,3]:Integer, [0,4]:Integer, [0,5]:Integer, [0,6]:Integer, [0,7]:Integer, [0,8]:Integer, [0,9]:Integer, [0,10]:Integer, [0,11]:Integer, [0,12]:Integer, [0,13]:Integer, [0,14]:Integer, [0,15]:Integer, [0,16]:Pointer, [0,16,0]:Pointer, [0,16,0,-1]:Float@float, [0,16,8]:Pointer, [0,16,8,0]:Integer, [0,16,8,1]:Integer, [0,16,8,2]:Integer, [0,16,8,3]:Integer, [0,16,8,4]:Integer, [0,16,8,5]:Integer, [0,16,8,6]:Integer, [0,16,8,7]:Integer, [0,16,8,8]:Pointer, [0,16,8,8,-1]:Float@float, [0,16,16]:Integer, [0,16,17]:Integer, [0,16,18]:Integer, [0,16,19]:Integer, [0,16,20]:Integer, [0,16,21]:Integer, [0,16,22]:Integer, [0,16,23]:Integer, [0,16,24]:Integer, [0,16,25]:Integer, [0,16,26]:Integer, [0,16,27]:Integer, [0,16,28]:Integer, [0,16,29]:Integer, [0,16,30]:Integer, [0,16,31]:Integer, [0,24]:Pointer, [0,24,0]:Pointer, [0,24,0,-1]:Float@float, [0,24,8]:Pointer, [0,24,8,0]:Integer, [0,24,8,1]:Integer, [0,24,8,2]:Integer, [0,24,8,3]:Integer, [0,24,8,4]:Integer, [0,24,8,5]:Integer, [0,24,8,6]:Integer, [0,24,8,7]:Integer, [0,24,8,8]:Pointer, [0,24,8,8,-1]:Float@float, [0,24,16]:Integer, [0,24,17]:Integer, [0,24,18]:Integer, [0,24,19]:Integer, [0,24,20]:Integer, [0,24,21]:Integer, [0,24,22]:Integer, [0,24,23]:Integer, [0,32]:Pointer, [8,0]:Pointer, [8,0,-1]:Float@float, [8,8]:Pointer, [8,8,0]:Integer, [8,8,1]:Integer, [8,8,2]:Integer, [8,8,3]:Integer, [8,8,4]:Integer, [8,8,5]:Integer, [8,8,6]:Integer, [8,8,7]:Integer, [8,8,8]:Pointer, [8,8,8,-1]:Float@float, [8,16]:Integer, [8,17]:Integer, [8,18]:Integer, [8,19]:Integer, [8,20]:Integer, [8,21]:Integer, [8,22]:Integer, [8,23]:Integer, [8,24]:Integer, [8,25]:Integer, [8,26]:Integer, [8,27]:Integer, [8,28]:Integer, [8,29]:Integer, [8,30]:Integer, [8,31]:Integer, [16,0]:Pointer, [16,0,-1]:Float@float, [16,8]:Pointer, [16,8,0]:Integer, [16,8,1]:Integer, [16,8,2]:Integer, [16,8,3]:Integer, [16,8,4]:Integer, [16,8,5]:Integer, [16,8,6]:Integer, [16,8,7]:Integer, [16,8,8]:Pointer, [16,8,8,-1]:Float@float, [16,16]:Integer, [16,17]:Integer, [16,18]:Integer, [16,19]:Integer, [16,20]:Integer, [16,21]:Integer, [16,22]:Integer, [16,23]:Integer}" "enzymejl_parmtype"="137163075826640" "enzymejl_parmtype_ref"="0" %0, { {} addrspace(10)*, [2 x {} addrspace(10)*] } "enzyme_type"="{[-1]:Pointer, [-1,0]:Pointer, [-1,0,-1]:Float@float, [-1,8]:Pointer, [-1,8,0]:Integer, [-1,8,1]:Integer, [-1,8,2]:Integer, [-1,8,3]:Integer, [-1,8,4]:Integer, [-1,8,5]:Integer, [-1,8,6]:Integer, [-1,8,7]:Integer, [-1,8,8]:Pointer, [-1,8,8,-1]:Float@float, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer, [-1,24]:Integer, [-1,25]:Integer, [-1,26]:Integer, [-1,27]:Integer, [-1,28]:Integer, [-1,29]:Integer, [-1,30]:Integer, [-1,31]:Integer}" "enzymejl_parmtype"="137156110361296" "enzymejl_parmtype_ref"="0" %1) local_unnamed_addr #3 !dbg !28 {
entry:
  %.fca.0.0.extract = extractvalue [1 x { {} addrspace(10)*, [2 x {} addrspace(10)*] }] %0, 0, 0, !dbg !29
  %.fca.0.1.0.extract = extractvalue [1 x { {} addrspace(10)*, [2 x {} addrspace(10)*] }] %0, 0, 1, 0, !dbg !29
  %.fca.0.1.1.extract = extractvalue [1 x { {} addrspace(10)*, [2 x {} addrspace(10)*] }] %0, 0, 1, 1, !dbg !29
  %pgcstack.i = call {}*** @julia.get_pgcstack() #6, !noalias !30
  %.fca.0.0.extract11 = extractvalue { {} addrspace(10)*, [2 x {} addrspace(10)*] } %1, 0
  %.fca.0.1.0.extract13 = extractvalue { {} addrspace(10)*, [2 x {} addrspace(10)*] } %1, 1, 0
  %.fca.0.1.1.extract15 = extractvalue { {} addrspace(10)*, [2 x {} addrspace(10)*] } %1, 1, 1
  %ptls_field.i22 = getelementptr inbounds {}**, {}*** %pgcstack.i, i64 2
  %2 = bitcast {}*** %ptls_field.i22 to i64***
  %ptls_load.i2324 = load i64**, i64*** %2, align 8, !tbaa !12, !noalias !30
  %3 = getelementptr inbounds i64*, i64** %ptls_load.i2324, i64 2
  %safepoint.i = load i64*, i64** %3, align 8, !tbaa !16, !noalias !30
  fence syncscope("singlethread") seq_cst
  call void @julia.safepoint(i64* %safepoint.i) #6, !dbg !33, !noalias !30
  fence syncscope("singlethread") seq_cst
  %.fca.0.insert5 = insertvalue { {} addrspace(10)*, [2 x {} addrspace(10)*] } poison, {} addrspace(10)* %.fca.0.0.extract, 0, !dbg !35
  %.fca.1.0.insert6 = insertvalue { {} addrspace(10)*, [2 x {} addrspace(10)*] } %.fca.0.insert5, {} addrspace(10)* %.fca.0.1.0.extract, 1, 0, !dbg !35
  %.fca.1.1.insert7 = insertvalue { {} addrspace(10)*, [2 x {} addrspace(10)*] } %.fca.1.0.insert6, {} addrspace(10)* %.fca.0.1.1.extract, 1, 1, !dbg !35
  %.fca.0.insert17 = insertvalue [2 x {} addrspace(10)*] poison, {} addrspace(10)* %.fca.0.1.0.extract13, 0, !dbg !35
  %.fca.1.insert = insertvalue [2 x {} addrspace(10)*] %.fca.0.insert17, {} addrspace(10)* %.fca.0.1.1.extract15, 1, !dbg !35
  %4 = call [1 x [2 x {} addrspace(10)*]] inttoptr (i64 137160996173840 to [1 x [2 x {} addrspace(10)*]] ({ {} addrspace(10)*, [2 x {} addrspace(10)*] }, {} addrspace(10)*, [2 x {} addrspace(10)*])*)({ {} addrspace(10)*, [2 x {} addrspace(10)*] } %.fca.1.1.insert7, {} addrspace(10)* %.fca.0.0.extract11, [2 x {} addrspace(10)*] %.fca.1.insert) #6, !dbg !35, !noalias !30
  ret [1 x [2 x {} addrspace(10)*]] %4, !dbg !29
}

Did not have return index set when differentiating function
 call  %4 = call [1 x [2 x {} addrspace(10)*]] inttoptr (i64 137160996173840 to [1 x [2 x {} addrspace(10)*]] ({ {} addrspace(10)*, [2 x {} addrspace(10)*] }, {} addrspace(10)*, [2 x {} addrspace(10)*])*)({ {} addrspace(10)*, [2 x {} addrspace(10)*] } %.fca.1.1.insert7, {} addrspace(10)* %.fca.0.0.extract11, [2 x {} addrspace(10)*] %.fca.1.insert) #6, !dbg !20, !noalias !9
 augmentcall  %_augmented = call { i8*, [1 x [2 x {} addrspace(10)*]] } %11({ {} addrspace(10)*, [2 x {} addrspace(10)*] } %.fca.1.1.insert7, { {} addrspace(10)*, [2 x {} addrspace(10)*] } %".fca.1.1.insert7'ipiv", {} addrspace(10)* %.fca.0.0.extract11, {} addrspace(10)* %".fca.0.0.extract11'ipev", [2 x {} addrspace(10)*] %.fca.1.insert, [2 x {} addrspace(10)*] %".fca.1.insert'ipiv"), !dbg !20


Stacktrace:
 [1] macro expansion
   @ /mnt/.julia/packages/Enzyme/RTS5U/src/compiler.jl:8398
 [2] enzyme_call
   @ /mnt/.julia/packages/Enzyme/RTS5U/src/compiler.jl:7950
 [3] ForwardModeThunk
   @ /mnt/.julia/packages/Enzyme/RTS5U/src/compiler.jl:7739
 [4] autodiff
   @ /mnt/.julia/packages/Enzyme/RTS5U/src/Enzyme.jl:647
 [5] autodiff
   @ /mnt/.julia/packages/Enzyme/RTS5U/src/Enzyme.jl:0

Stacktrace:
  [1] julia_error(cstr::Cstring, val::Ptr{…}, errtype::Enzyme.API.ErrorType, data::Ptr{…}, data2::Ptr{…}, B::Ptr{…})
    @ Enzyme.Compiler /mnt/.julia/packages/Enzyme/RTS5U/src/compiler.jl:1612
  [2] EnzymeCreateAugmentedPrimal(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{…}, TA::Enzyme.TypeAnalysis, returnUsed::Bool, shadowReturnUsed::Bool, typeInfo::Enzyme.FnTypeInfo, uncacheable_args::Vector{…}, forceAnonymousTape::Bool, runtimeActivity::Bool, width::Int64, atomicAdd::Bool)
    @ Enzyme.API /mnt/.julia/packages/Enzyme/RTS5U/src/api.jl:389
  [3] enzyme!(job::GPUCompiler.CompilerJob{…}, mod::LLVM.Module, primalf::LLVM.Function, TT::Type, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, wrap::Bool, modifiedBetween::NTuple{…}, returnPrimal::Bool, expectedTapeType::Type, loweredArgs::Set{…}, boxedArgs::Set{…})
    @ Enzyme.Compiler /mnt/.julia/packages/Enzyme/RTS5U/src/compiler.jl:4095
  [4] codegen(output::Symbol, job::GPUCompiler.CompilerJob{…}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler /mnt/.julia/packages/Enzyme/RTS5U/src/compiler.jl:7338
  [5] codegen
    @ /mnt/.julia/packages/Enzyme/RTS5U/src/compiler.jl:6146 [inlined]
  [6] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool)
    @ Enzyme.Compiler /mnt/.julia/packages/Enzyme/RTS5U/src/compiler.jl:8468
  [7] _thunk
    @ /mnt/.julia/packages/Enzyme/RTS5U/src/compiler.jl:8468 [inlined]
  [8] cached_compilation
    @ /mnt/.julia/packages/Enzyme/RTS5U/src/compiler.jl:8509 [inlined]
  [9] thunkbase(ctx::LLVM.Context, mi::Core.MethodInstance, ::Val{…}, ::Type{…}, ::Type{…}, tt::Type{…}, ::Val{…}, ::Val{…}, ::Val{…}, ::Val{…}, ::Val{…}, ::Type{…}, ::Val{…}, ::Val{…})
    @ Enzyme.Compiler /mnt/.julia/packages/Enzyme/RTS5U/src/compiler.jl:8641
 [10] #s2105#19135
    @ /mnt/.julia/packages/Enzyme/RTS5U/src/compiler.jl:8778 [inlined]
 [11] 
    @ Enzyme.Compiler ./none:0
 [12] (::Core.GeneratedFunctionStub)(::UInt64, ::LineNumberNode, ::Any, ::Vararg{Any})
    @ Core ./boot.jl:707
 [13] runtime_generic_augfwd(activity::Type{…}, runtimeActivity::Val{…}, width::Val{…}, ModifiedBetween::Val{…}, RT::Val{…}, f::typeof(autodiff), df::Nothing, primal_1::ForwardMode{…}, shadow_1_1::Nothing, primal_2::Const{…}, shadow_2_1::Const{…}, primal_3::Type{…}, shadow_3_1::Nothing, primal_4::BatchDuplicated{…}, shadow_4_1::BatchDuplicated{…})
    @ Enzyme.Compiler /mnt/.julia/packages/Enzyme/RTS5U/src/rules/jitrules.jl:469
 [14] autodiff
    @ /mnt/.julia/packages/Enzyme/RTS5U/src/Enzyme.jl:537 [inlined]
 [15] batched_enzyme_jacobian_impl
    @ /mnt/software/lux/Lux.jl/ext/LuxEnzymeExt/batched_autodiff.jl:29
 [16] batched_jacobian_impl
    @ /mnt/software/lux/Lux.jl/ext/LuxEnzymeExt/batched_autodiff.jl:4 [inlined]
 [17] batched_jacobian_internal
    @ /mnt/software/lux/Lux.jl/src/autodiff/batched_autodiff.jl:74 [inlined]
 [18] batched_jacobian_internal
    @ /mnt/software/lux/Lux.jl/src/autodiff/batched_autodiff.jl:14 [inlined]
 [19] batched_jacobian
    @ /mnt/software/lux/Lux.jl/src/autodiff/batched_autodiff.jl:8 [inlined]
 [20] batched_jacobian
    @ /mnt/software/lux/Lux.jl/src/autodiff/api.jl:123 [inlined]
 [21] enzyme_test
    @ /mnt/software/lux/Lux.jl/envs/enzyme/test1.jl:13 [inlined]
 [22] enzyme_test
    @ /mnt/software/lux/Lux.jl/envs/enzyme/test1.jl:0 [inlined]
 [23] diffejulia_enzyme_test_142033_inner_18wrap
    @ /mnt/software/lux/Lux.jl/envs/enzyme/test1.jl:0
 [24] macro expansion
    @ /mnt/.julia/packages/Enzyme/RTS5U/src/compiler.jl:8398 [inlined]
 [25] enzyme_call
    @ /mnt/.julia/packages/Enzyme/RTS5U/src/compiler.jl:7950 [inlined]
 [26] CombinedAdjointThunk
    @ /mnt/.julia/packages/Enzyme/RTS5U/src/compiler.jl:7723 [inlined]
 [27] autodiff
    @ /mnt/.julia/packages/Enzyme/RTS5U/src/Enzyme.jl:491 [inlined]
 [28] autodiff
    @ /mnt/.julia/packages/Enzyme/RTS5U/src/Enzyme.jl:512 [inlined]
 [29] macro expansion
    @ /mnt/.julia/packages/Enzyme/RTS5U/src/Enzyme.jl:1677 [inlined]
 [30] gradient(::ReverseMode{…}, ::typeof(enzyme_test), ::Const{…}, ::Matrix{…}, ::@NamedTuple{}, ::Const{…})
    @ Enzyme /mnt/.julia/packages/Enzyme/RTS5U/src/Enzyme.jl:1660
 [31] top-level scope
    @ /mnt/software/lux/Lux.jl/envs/enzyme/test1.jl:20
Some type information was truncated. Use `show(err)` to see complete types.

@wsmoses what does the above mean 😅?

Copy link

codecov bot commented Nov 25, 2024

Codecov Report

Attention: Patch coverage is 16.66667% with 65 lines in your changes missing coverage. Please review.

Project coverage is 73.37%. Comparing base (6f9f8d6) to head (52e4241).

Files with missing lines Patch % Lines
ext/LuxEnzymeExt/batched_autodiff.jl 0.00% 58 Missing ⚠️
ext/LuxEnzymeExt/LuxEnzymeExt.jl 0.00% 6 Missing ⚠️
src/autodiff/api.jl 92.30% 1 Missing ⚠️

❗ There is a different number of reports uploaded between BASE (6f9f8d6) and HEAD (52e4241). Click for more details.

HEAD has 26 uploads less than BASE
Flag BASE (6f9f8d6) HEAD (52e4241)
47 21
Additional details and impacted files
@@             Coverage Diff             @@
##             main     #954       +/-   ##
===========================================
- Coverage   83.47%   73.37%   -10.10%     
===========================================
  Files         147      146        -1     
  Lines        6062     6104       +42     
===========================================
- Hits         5060     4479      -581     
- Misses       1002     1625      +623     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.


🚨 Try these New Features:

@@ -1,5 +1,8 @@
Utils.vec(x::AnyTracedRArray) = Reactant.TracedUtils.materialize_traced_array(vec(x))

# XXX: remove once EnzymeJAX supports batched AD
Utils.max_enzyme_batched_chunk_size(x::AnyTracedRArray) = 1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in theory this should be in now btw [thanks ofc to @jumerckx ]

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we are probably missing this on Julia end then? I got the width must be 1 error

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants