From bda114151b9e407575a375fca7e6484e3222a5f5 Mon Sep 17 00:00:00 2001
From: hemingkx
Date: Wed, 8 Nov 2023 17:06:15 +0800
Subject: [PATCH] SpecDec updated

---
 GAD.gif => SpecDec.gif                         | Bin
 .../__pycache__/__init__.cpython-37.pyc        | Bin 236 -> 0 bytes
 .../__pycache__/__init__.cpython-37.pyc        | Bin 173 -> 0 bytes
 .../__pycache__/glat_loss.cpython-37.pyc       | Bin 6595 -> 0 bytes
 .../__pycache__/BlockNAT.cpython-37.pyc        | Bin 5600 -> 0 bytes
 .../models/__pycache__/GAD.cpython-37.pyc      | Bin 5555 -> 0 bytes
 .../__pycache__/__init__.cpython-37.pyc        | Bin 168 -> 0 bytes
 .../tasks/__pycache__/__init__.cpython-37.pyc  | Bin 183 -> 0 bytes
 .../translation_lev_modified.cpython-37.pyc    | Bin 8699 -> 0 bytes
 data/.DS_Store                                 | Bin 0 -> 6148 bytes
 encoder_initial.py                             |  8 +--
 inference.py                                   | 32 ++++++------
 inference.sh                                   | 16 +++---
 inference_drafter.py                           | 14 +++---
 inference_paper.py                             | 34 ++++++-------
 readme.md                                      | 47 ++++++++----------
 specdec_plugins/.DS_Store                      | Bin 0 -> 6148 bytes
 .../__init__.py                                |  2 +-
 .../criterions/__init__.py                     |  0
 .../criterions/glat_loss.py                    |  0
 .../models/BlockNAT.py                         |  0
 .../models/__init__.py                         |  0
 .../tasks/__init__.py                          |  0
 .../tasks/translation_lev_modified.py          |  0
 25 files changed, 75 insertions(+), 80 deletions(-)
 rename GAD.gif => SpecDec.gif (100%)
 delete mode 100644 block_plugins/__pycache__/__init__.cpython-37.pyc
 delete mode 100644 block_plugins/criterions/__pycache__/__init__.cpython-37.pyc
 delete mode 100644 block_plugins/criterions/__pycache__/glat_loss.cpython-37.pyc
 delete mode 100644 block_plugins/models/__pycache__/BlockNAT.cpython-37.pyc
 delete mode 100644 block_plugins/models/__pycache__/GAD.cpython-37.pyc
 delete mode 100644 block_plugins/models/__pycache__/__init__.cpython-37.pyc
 delete mode 100644 block_plugins/tasks/__pycache__/__init__.cpython-37.pyc
 delete mode 100644 block_plugins/tasks/__pycache__/translation_lev_modified.cpython-37.pyc
 create mode 100644 data/.DS_Store
 create mode 100644 specdec_plugins/.DS_Store
 rename {block_plugins => specdec_plugins}/__init__.py (67%)
 rename {block_plugins => specdec_plugins}/criterions/__init__.py (100%)
 rename {block_plugins => specdec_plugins}/criterions/glat_loss.py (100%)
 rename {block_plugins => specdec_plugins}/models/BlockNAT.py (100%)
 rename {block_plugins => specdec_plugins}/models/__init__.py (100%)
 rename {block_plugins => specdec_plugins}/tasks/__init__.py (100%)
 rename {block_plugins => specdec_plugins}/tasks/translation_lev_modified.py (100%)

diff --git a/GAD.gif b/SpecDec.gif
similarity index 100%
rename from GAD.gif
rename to SpecDec.gif
z2nis`4F#?AsDDN+^z6?%H0fS|_k)kiZ<05y{CZ?|RUBKyZxQ}S7SS3i&wI=>tNQ-{ De?LhS diff --git a/data/.DS_Store b/data/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..ecf949bb6385597ddf74afe0c215687fddd668b2 GIT binary patch literal 6148 zcmeHKO;5r=5S>Mo5@NzZqsJy*5dn?Sc&Vrduh!^64N}`^W9=F!_CQE_)j#9Wf8yWK zncYQ2vG~ELpM1=6rX`Ok)H61q3IW8h z3jSC&H~d8f=--tfg*!-~4qyLYa2UxeFOnW+(Dc|vuf8e6Xd=r?yZ46m;@wAKH>uTL zMQ$cPJ6Es@R>`_>Z)EBw?WEUjwmWB3JC-u?hH=|F34^p%Uf7dy(hlRGqX?lNV9NPv z82d79%3kbu71z}rR?#Z9%1izJR-RRwCk!S?^5AyjNXFNgLF~`0&5w!S0r^e1XUzO699G#_0akz&n63h50klff zeGML+6<`JapaQf$*yx0g!AzrCIY^vqvP{!l@7wu$R#Vl3XCc+qpM}w{|_eL|3^t&V+B}& zf2DxP@4Nd={3Nrtj(r^MwKn=CIvM3<8b2r)sI3@tX)9hvSBC921&EHpOe40S@j*bz Lzy&Mts|vgWpb}>} literal 0 HcmV?d00001 diff --git a/encoder_initial.py b/encoder_initial.py index 2256194..2e3f500 100644 --- a/encoder_initial.py +++ b/encoder_initial.py @@ -2,7 +2,7 @@ def model_preparing(model_path, save_path): - """only save the AT encoder params to initialize the NAT drafter's encoder""" + """only save the AR encoder params to initialize the NAR drafter's encoder""" key_l = [] raw_model = torch.load(model_path) for key in raw_model['model']: @@ -17,7 +17,7 @@ def model_preparing(model_path, save_path): def param_checking(model1, model2): - """check the parameters of the AT verifier and the NAT drafter""" + """check the parameters of the AR verifier and the NAR drafter""" key_l1 = [] key_l2 = [] raw_model1 = torch.load(model1) @@ -37,6 +37,6 @@ def param_checking(model1, model2): if __name__ == "__main__": - AR_path = './checkpoints/wmt14-en-de-base-at-verifier.pt' # the dir that contains AT verifier checkpoint - save_path = './checkpoints/initial_checkpoint.pt' # the save dir of your fairseq NAT drafter checkpoints + AR_path = './checkpoints/wmt14-en-de-base-at-verifier.pt' # the dir that contains AR verifier checkpoint + save_path = './checkpoints/initial_checkpoint.pt' # the save dir of your fairseq NAR drafter checkpoints model_preparing(AR_path, save_path) diff --git a/inference.py b/inference.py index d149191..6a8b753 100644 --- a/inference.py +++ b/inference.py @@ -142,13 +142,13 @@ def forward_decoder(model, input_tokens, encoder_out, incremental_state=None, @torch.no_grad() -def gad_generate(data_lines, model, AR_model, task, block_size, batch_size, device, beta=1, tau=0, max_len=200): - # Generalized Aggressive Decoding +def specdec_generate(data_lines, model, AR_model, task, block_size, batch_size, device, beta=1, tau=0, max_len=200): + # Speculative Decoding src_dict = task.source_dictionary tgt_dict = task.target_dictionary data_size = len(data_lines) all_results = [] - logger.info(f'GAD generate') + logger.info(f'SpecDec generate') start = time.perf_counter() for start_idx in tqdm(range(0, data_size, batch_size)): batch_size = min(data_size - start_idx, batch_size) @@ -167,9 +167,9 @@ def gad_generate(data_lines, model, AR_model, task, block_size, batch_size, devi start_pos_list = [0] * batch_size finish_list = [] for step in range(0, max_len): - prev_output_tokens, start_pos_list = gad_forward(start_pos_list, block_size, batch_size, - tgt_dict, prev_output_tokens, - encoder_out, AR_encoder_out, model, AR_model, beta, tau) + prev_output_tokens, start_pos_list = specdec_forward(start_pos_list, block_size, batch_size, + tgt_dict, prev_output_tokens, encoder_out, + AR_encoder_out, model, AR_model, beta, tau) for i, start_pos in enumerate(start_pos_list): if i not in finish_list: if start_pos == -1: @@ -187,8 +187,8 @@ def gad_generate(data_lines, model, 
AR_model, task, block_size, batch_size, devi @torch.no_grad() -def gad_forward(start_pos_list, block_size, batch_size, tgt_dict, prev_output_tokens, - encoder_out, AR_encoder_out, model, AR_model, beta, tau, max_len=200): +def specdec_forward(start_pos_list, block_size, batch_size, tgt_dict, prev_output_tokens, + encoder_out, AR_encoder_out, model, AR_model, beta, tau, max_len=200): pad_tokens = [[tgt_dict.pad()] * (max_len + block_size) for _ in range(batch_size)] for i in range(batch_size): pad_tokens[i][:len(prev_output_tokens[i])] = prev_output_tokens[i] @@ -258,9 +258,9 @@ def gad_forward(start_pos_list, block_size, batch_size, tgt_dict, prev_output_to parser.add_argument('--output-path', type=str, default=None, help='path to output file') parser.add_argument('--AR-path', type=str, default=None, - help='path to AR model') + help='path to autoregressive model (to be accelerated)') parser.add_argument('--strategy', type=str, default='fairseq', - help='decoding strategy, choose from: fairseq, AR, gad') + help='decoding strategy, choose from: fairseq, AR, specdec') parser.add_argument('--batch', type=int, default=None, help='batch size') parser.add_argument('--block-size', type=int, default=5, @@ -283,7 +283,7 @@ def gad_forward(start_pos_list, block_size, batch_size, tgt_dict, prev_output_to device = torch.device('cuda') # NAR drafter - if cmd_args.strategy == 'gad': + if cmd_args.strategy == 'specdec': logger.info("loading model(s) from {}".format(cfg.common_eval.path)) models, _model_args, _model_task = load_model_ensemble_and_task(filenames=[cfg.common_eval.path], task=task) model = models[0].to(device).eval() @@ -308,11 +308,11 @@ def gad_forward(start_pos_list, block_size, batch_size, tgt_dict, prev_output_to logger.info("Decoding Strategy: Simplified AR") remove_bpe_results, delta = baseline_generate(bpe_sents, AR_model, _AR_model_task, cmd_args.batch, device) logger.info(f'Simplified AR generate: {delta}') - elif cmd_args.strategy == 'gad': - logger.info("Decoding Strategy: GAD") - remove_bpe_results, delta = gad_generate(bpe_sents, model, AR_model, task, cmd_args.block_size, cmd_args.batch, - device, beta=cmd_args.beta, tau=cmd_args.tau) - logger.info(f'GAD generate: {delta}') + elif cmd_args.strategy == 'specdec': + logger.info("Decoding Strategy: SpecDec") + remove_bpe_results, delta = specdec_generate(bpe_sents, model, AR_model, task, cmd_args.block_size, cmd_args.batch, + device, beta=cmd_args.beta, tau=cmd_args.tau) + logger.info(f'SpecDec generate: {delta}') else: logger.info("Decoding Strategy: fairseq") remove_bpe_results, delta = fairseq_generate(bpe_sents, cfg, AR_models, _AR_model_task, cmd_args.batch, device) diff --git a/inference.sh b/inference.sh index b227758..36ffb41 100644 --- a/inference.sh +++ b/inference.sh @@ -1,11 +1,11 @@ -data_dir=./data # the dir that contains dict files -checkpoint_path=./checkpoints/wmt14-en-de-base-nat-drafter-checkpoint.avg10.pt # the dir that contains NAT drafter checkpoint -AR_checkpoint_path=./checkpoints/wmt14-en-de-base-at-verifier.pt # the dir that contains AT verifier checkpoint -input_path=./test.en # the dir that contains bpe test files +data_dir=./data/wmt14.en-de # the dir that contains dict files +checkpoint_path=/home/xiaheming/data/SpecDec/wmt14-en-de-base-nat-drafter-checkpoint.avg10.pt # the dir that contains NAT drafter checkpoint +AR_checkpoint_path=/home/xiaheming/data/SpecDec/wmt14-en-de-base-at-verifier.pt # the dir that contains AT verifier checkpoint +input_path=./data/wmt14.en-de/test.en # the dir that 
contains bpe test files output_path=./output/block.out # the dir for outputs -strategy='gad' # fairseq, AR, gad -batch=32 +strategy='specdec' # fairseq, AR, specdec +batch=1 beam=5 beta=5 @@ -16,8 +16,8 @@ src=en tgt=de -python inference.py ${data_dir} --path ${checkpoint_path} \ - --user-dir block_plugins --task translation_lev_modified --remove-bpe --max-sentences 20 --source-lang ${src} \ +python inference_paper.py ${data_dir} --path ${checkpoint_path} \ + --user-dir specdec_plugins --task translation_lev_modified --remove-bpe --max-sentences 20 --source-lang ${src} \ --target-lang ${tgt} --iter-decode-max-iter 0 --iter-decode-eos-penalty 0 --iter-decode-with-beam 1 \ --gen-subset test --AR-path ${AR_checkpoint_path} --input-path ${input_path} --output-path ${output_path} \ --block-size ${block_size} --beta ${beta} --tau ${tau} --batch ${batch} --beam ${beam} --strategy ${strategy} diff --git a/inference_drafter.py b/inference_drafter.py index 9b5f738..f9f3cc0 100644 --- a/inference_drafter.py +++ b/inference_drafter.py @@ -25,12 +25,12 @@ def write_result(results, output_file): @torch.no_grad() -def gad_generate(data_lines, model, task, block_size, device, max_len=200): +def drafter_generate(data_lines, model, task, block_size, device, max_len=200): src_dict = task.source_dictionary tgt_dict = task.target_dictionary data_size = len(data_lines) all_results = [] - logger.info(f'GAD generate') + logger.info(f'Spec-Drafter generate') pass_tokens = [0] * max_len sent_nums = [0] * max_len start = time.perf_counter() @@ -44,8 +44,8 @@ def gad_generate(data_lines, model, task, block_size, device, max_len=200): prev_output_tokens = [tgt_dict.unk()] * block_size start_pos = 0 for step in range(0, max_len): - start_pos, prev_output_tokens, pass_token = gad_forward(start_pos, block_size, tgt_dict, - prev_output_tokens, encoder_out, model) + start_pos, prev_output_tokens, pass_token = drafter_forward(start_pos, block_size, tgt_dict, + prev_output_tokens, encoder_out, model) pass_tokens[step] += pass_token sent_nums[step] += 1 if start_pos == -1: @@ -74,7 +74,7 @@ def gad_generate(data_lines, model, task, block_size, device, max_len=200): @torch.no_grad() -def gad_forward(start_pos, block_size, tgt_dict, prev_output_tokens, encoder_out, model, max_len=200): +def drafter_forward(start_pos, block_size, tgt_dict, prev_output_tokens, encoder_out, model, max_len=200): output_tokens = torch.tensor([prev_output_tokens]).to(device) block_mask = torch.zeros_like(output_tokens).to(output_tokens) block_mask[0][start_pos:start_pos + block_size] = 1 @@ -136,8 +136,8 @@ def gad_forward(start_pos, block_size, tgt_dict, prev_output_tokens, encoder_out with open(cmd_args.input_path, 'r') as f: bpe_sents = [l.strip() for l in f.readlines()] - logger.info("Decoding Strategy: GAD") - remove_bpe_results, delta = gad_generate(bpe_sents, model, task, cmd_args.block_size, device) + logger.info("Decoding Strategy: Spec-Drafter") + remove_bpe_results, delta = drafter_generate(bpe_sents, model, task, cmd_args.block_size, device) logger.info(f'GAD generate: {delta}') if cmd_args.output_path is not None: diff --git a/inference_paper.py b/inference_paper.py index 31449f3..a003fac 100644 --- a/inference_paper.py +++ b/inference_paper.py @@ -146,8 +146,8 @@ def forward_decoder(model, @torch.no_grad() -def gad_generate(data_lines, model, AR_model, task, block_size, device, beta=1, tau=0, max_len=200): - # Generalized Aggressive Decoding +def specdec_generate(data_lines, model, AR_model, task, block_size, device, beta=1, tau=0, 
max_len=200): + # Speculative Decoding src_dict = task.source_dictionary tgt_dict = task.target_dictionary encoder_state_ids = [] @@ -155,7 +155,7 @@ def gad_generate(data_lines, model, AR_model, task, block_size, device, beta=1, encoder_state_ids.append(AR_model.decoder.layers[i].encoder_attn._incremental_state_id) data_size = len(data_lines) all_results = [] - logger.info(f'GAD generate') + logger.info(f'SpecDec generate') pass_tokens = [0] * max_len sent_nums = [0] * max_len start = time.perf_counter() @@ -171,11 +171,11 @@ def gad_generate(data_lines, model, AR_model, task, block_size, device, beta=1, prev_output_tokens = [tgt_dict.unk()] * block_size start_pos = 0 for step in range(0, max_len): - start_pos, prev_output_tokens, pass_token = gad_forward(incremental_state, encoder_state_ids, - start_pos, block_size, tgt_dict, - prev_output_tokens, - encoder_out, AR_encoder_out, model, - AR_model, beta, tau) + start_pos, prev_output_tokens, pass_token = specdec_forward(incremental_state, encoder_state_ids, + start_pos, block_size, tgt_dict, + prev_output_tokens, + encoder_out, AR_encoder_out, model, + AR_model, beta, tau) pass_tokens[step] += pass_token sent_nums[step] += 1 if start_pos == -1: @@ -204,8 +204,8 @@ def gad_generate(data_lines, model, AR_model, task, block_size, device, beta=1, @torch.no_grad() -def gad_forward(incremental_state, encoder_state_ids, start_pos, block_size, tgt_dict, prev_output_tokens, - encoder_out, AR_encoder_out, model, AR_model, beta, tau, max_len=200): +def specdec_forward(incremental_state, encoder_state_ids, start_pos, block_size, tgt_dict, prev_output_tokens, + encoder_out, AR_encoder_out, model, AR_model, beta, tau, max_len=200): output_tokens = torch.tensor([prev_output_tokens]).to(device) block_mask = torch.zeros_like(output_tokens).to(output_tokens) block_mask[0][start_pos:start_pos + block_size] = 1 @@ -263,7 +263,7 @@ def gad_forward(incremental_state, encoder_state_ids, start_pos, block_size, tgt parser.add_argument('--AR-path', type=str, default=None, help='path to AR model') parser.add_argument('--strategy', type=str, default='fairseq', - help='decoding strategy, choose from: fairseq, AR, gad') + help='decoding strategy, choose from: fairseq, AR, specdec') parser.add_argument('--batch', type=int, default=None, help='batch size') parser.add_argument('--block-size', type=int, default=5, @@ -286,7 +286,7 @@ def gad_forward(incremental_state, encoder_state_ids, start_pos, block_size, tgt device = torch.device('cuda') # NAR drafter - if cmd_args.strategy == 'gad': + if cmd_args.strategy == 'specdec': logger.info("loading model(s) from {}".format(cfg.common_eval.path)) models, _model_args, _model_task = load_model_ensemble_and_task(filenames=[cfg.common_eval.path], task=task) model = models[0].to(device).eval() @@ -304,11 +304,11 @@ def gad_forward(incremental_state, encoder_state_ids, start_pos, block_size, tgt logger.info("Decoding Strategy: Simplified AR") remove_bpe_results, delta = baseline_generate(bpe_sents, AR_model, _AR_model_task, device) logger.info(f'Simplified AR generate: {delta}') - elif cmd_args.strategy == 'gad': - logger.info("Decoding Strategy: GAD") - remove_bpe_results, delta = gad_generate(bpe_sents, model, AR_model, task, cmd_args.block_size, device, - beta=cmd_args.beta, tau=cmd_args.tau) - logger.info(f'GAD generate: {delta}') + elif cmd_args.strategy == 'specdec': + logger.info("Decoding Strategy: SpecDec") + remove_bpe_results, delta = specdec_generate(bpe_sents, model, AR_model, task, cmd_args.block_size, device, + 
beta=cmd_args.beta, tau=cmd_args.tau) + logger.info(f'SpecDec generate: {delta}') else: logger.info("Decoding Strategy: fairseq") remove_bpe_results, delta = fairseq_generate(bpe_sents, cfg, AR_models, _AR_model_task, cmd_args.batch, device) diff --git a/readme.md b/readme.md index a311070..075e7f0 100644 --- a/readme.md +++ b/readme.md @@ -1,24 +1,19 @@ -# Generalized Aggressive Decoding +# Speculative Decoding ## Introduction -This repository contains all the code and checkpoints used to reimplement our paper: [Lossless Speedup of Autoregressive Translation with Generalized Aggressive Decoding](https://arxiv.org/pdf/2203.16487.pdf). +This repository contains the code used to reimplement our paper: [Speculative Decoding: Exploiting Speculative Execution for Accelerating Seq2seq Generation](https://arxiv.org/pdf/2203.16487.pdf). -![GAD](./GAD.gif) - -## News - -- 2022.09.20 Update💥: the memory cost of GAD is optimized. Now you can obtain **3x~5x speedup** using GAD with only **~300MiB of extra memory cost** (~240 MiB for model states), compared to Transformer's greedy decoding. -- 2022.09.21 Update💥: the inference codes for the summarization task are released. +![SpecDec](./SpecDec.gif) ## Download model | Description | Model | | ----------- | ------------------------------------------------------------ | -| wmt14.en-de | [at-verifier-base](https://drive.google.com/file/d/1L9z0Y5rked_tYn7Fllh-0VsRdgBHN1Mp/view?usp=sharing), [nat-drafter-base (k=25)](https://drive.google.com/file/d/1fPYt1QGgIrNfk78XvGnrx_TeDRYePr2e/view?usp=sharing) | -| wmt14.de-en | [at-verifier-base](https://drive.google.com/file/d/1h5EdTEt2PMqvAqCq2G5bRhCeWk8LzwoG/view?usp=sharing), [nat-drafter-base (k=25)](https://drive.google.com/file/d/1IEX2K65rgv5SUHWxiowXYaS--Zqr3GvT/view?usp=sharing) | -| wmt16.en-ro | [at-verifier-base](https://drive.google.com/file/d/1WocmZ9iw_OokYZY_BtzNAjGsgRXB-Aft/view?usp=sharing), [nat-drafter-base (k=25)](https://drive.google.com/file/d/1V_WbPRbgmIy-4oZDkws9mdFSw8n8KOGm/view?usp=sharing) | -| wmt16.ro-en | [at-verifier-base](https://drive.google.com/file/d/1LWHC56HvTtvs58EMwoYMT6jKByuMW1dB/view?usp=sharing), [nat-drafter-base (k=25)](https://drive.google.com/file/d/1P21nU3u4WdJueEl4nqAY-cwUKAvzPu8A/view?usp=sharing) | +| wmt14.en-de | [ar-verifier-base](https://drive.google.com/file/d/1L9z0Y5rked_tYn7Fllh-0VsRdgBHN1Mp/view?usp=sharing), [nar-drafter-base (k=25)](https://drive.google.com/file/d/1fPYt1QGgIrNfk78XvGnrx_TeDRYePr2e/view?usp=sharing) | +| wmt14.de-en | [ar-verifier-base](https://drive.google.com/file/d/1h5EdTEt2PMqvAqCq2G5bRhCeWk8LzwoG/view?usp=sharing), [nar-drafter-base (k=25)](https://drive.google.com/file/d/1IEX2K65rgv5SUHWxiowXYaS--Zqr3GvT/view?usp=sharing) | +| wmt16.en-ro | [ar-verifier-base](https://drive.google.com/file/d/1WocmZ9iw_OokYZY_BtzNAjGsgRXB-Aft/view?usp=sharing), [nar-drafter-base (k=25)](https://drive.google.com/file/d/1V_WbPRbgmIy-4oZDkws9mdFSw8n8KOGm/view?usp=sharing) | +| wmt16.ro-en | [ar-verifier-base](https://drive.google.com/file/d/1LWHC56HvTtvs58EMwoYMT6jKByuMW1dB/view?usp=sharing), [nar-drafter-base (k=25)](https://drive.google.com/file/d/1P21nU3u4WdJueEl4nqAY-cwUKAvzPu8A/view?usp=sharing) | ## Requirements @@ -28,8 +23,8 @@ This repository contains all the code and checkpoints used to reimplement our pa ## Installation ``` -conda create -n gad python=3.7 -cd GAD +conda create -n specdec python=3.7 +cd SpecDec pip install --editable . 
``` @@ -54,11 +49,11 @@ fairseq-preprocess --source-lang ${src} --target-lang ${tgt} \ ## Encoder Initialization -We recommend using the AT verifier's encoder to initialize the weights of the NAT drafter. For preparing the initialization checkpoints, check `encoder_initial.py`. +We recommend using the AR verifier's encoder to initialize the weights of the NAR drafter. For preparing the initialization checkpoints, check `encoder_initial.py`. ## Train -**The AT verifier** of GAD is a standard Transformer that can be trained with [fairseq](https://github.com/facebookresearch/fairseq/tree/main/examples/translation): +**The AR verifier** of SpecDec is a standard Transformer that can be trained with [fairseq](https://github.com/facebookresearch/fairseq/tree/main/examples/translation): ``` fairseq-train ${bin_path} --arch transformer --share-all-embeddings \ @@ -78,7 +73,7 @@ fairseq-train ${bin_path} --arch transformer --share-all-embeddings \ --best-checkpoint-metric bleu --maximize-best-checkpoint-metric ``` -For training **the NAT drafter** of GAD (check `train.sh`): +For training **the NAR drafter** of SpecDec (check `train.sh`): ``` python train.py ${bin_path} --arch block --noise block_mask --share-all-embeddings \ @@ -90,7 +85,7 @@ python train.py ${bin_path} --arch block --noise block_mask --share-all-embeddin --decoder-embed-dim 512 --fp16 --max-source-positions 1000 \ --max-target-positions 1000 --max-update ${update} --seed ${seed} --clip-norm 5 \ --save-dir ./checkpoints --src-embedding-copy --log-interval 1000 \ - --user-dir block_plugins --block-size ${size} --total-up ${update} \ + --user-dir specdec_plugins --block-size ${size} --total-up ${update} \ --update-freq ${update_freq} --decoder-learned-pos --encoder-learned-pos \ --apply-bert-init --activation-fn gelu \ --restore-file ./checkpoints/initial_checkpoint.pt \ @@ -99,7 +94,7 @@ python train.py ${bin_path} --arch block --noise block_mask --share-all-embeddin ## Hyperparameters -The hyperparameters of the NAT drafter are shown as follows: +The hyperparameters of the NAR drafter are shown as follows: | Hyperparameters \ Datasets | WMT14 EN-DE | WMT16 EN-RO | | -------------------------- | :---------: | :---------: | @@ -113,10 +108,10 @@ The hyperparameters of the NAT drafter are shown as follows: ## Inference -For GAD++ (check `inference.sh`, set `beta=1` for vanilla GAD): +For SpecDec (check `inference.sh`, set `beta=1` for identical results to AR greedy decoding): ``` -python inference.py ${data_dir} --path ${checkpoint_path} --user-dir block_plugins \ +python inference.py ${data_dir} --path ${checkpoint_path} --user-dir specdec_plugins \ --task translation_lev_modified --remove-bpe --max-sentences 20 \ --source-lang ${src} --target-lang ${tgt} --iter-decode-max-iter 0 \ --iter-decode-eos-penalty 0 --iter-decode-with-beam 1 --gen-subset test \ @@ -125,9 +120,9 @@ python inference.py ${data_dir} --path ${checkpoint_path} --user-dir block_plugi --batch ${batch} --beam ${beam} --strategy ${strategy} ``` -> We test the inference latency of GAD with batch 1 implementation, check `inference_paper.py` for details. +> We test the inference latency of SpecDec with batch 1 implementation, check `inference_paper.py` for details. > -> check `inference_drafter.py` for inference with our NAT drafter only. +> check `inference_drafter.py` for inference with our NAR drafter only. Calculating compound split bleu: @@ -150,13 +145,13 @@ You can find the translation results in `./output`. 
 
 ## Extra Memory Cost
 
-Since there is no need to save intermediate variables during inference, GAD can achieve **3x~5x decoding speedup** (by alternating NAT and AT decoding) with only **~300MiB of extra memory cost**. Below is the `nvidia-smi` memory cost comparison of AT and GAD, tested on WMT14 EN-DE:
+Since there is no need to save intermediate variables during inference, SpecDec can achieve **3x~5x decoding speedup** (by alternating NAR and AR decoding) with only **~300MiB of extra memory cost**. Below is the `nvidia-smi` memory cost comparison of AR and SpecDec, tested on WMT14 EN-DE:
 
 | Model \ Batch Size | Model States (Params) |  1   |  4   |  8   |  16  |  32  |
 | ------------------ | :-------------------: | :--: | :--: | :--: | :--: | :--: |
 | Fairseq (beam1)    |        232.38         | 1670 | 1712 | 1758 | 1844 | 2028 |
-| GAD++              |   469.75 (AT + NAT)   | 1902 | 1938 | 2012 | 2108 | 2298 |
-| Extra Memory       |     237.38 (NAT)      | 232  | 226  | 254  | 264  | 270  |
+| SpecDec            |   469.75 (AR + NAR)   | 1902 | 1938 | 2012 | 2108 | 2298 |
+| Extra Memory       |     237.38 (NAR)      | 232  | 226  | 254  | 264  | 270  |
 
 ## Note
diff --git a/block_plugins/__init__.py b/specdec_plugins/__init__.py
similarity index 67%
rename from block_plugins/__init__.py
rename to specdec_plugins/__init__.py
index 142aa57..ed35e2f 100644
--- a/block_plugins/__init__.py
+++ b/specdec_plugins/__init__.py
@@ -2,4 +2,4 @@
 from .models import *
 from .tasks import *
 
-print("GAD plugins loaded...")
\ No newline at end of file
+print("SpecDec plugins loaded...")
\ No newline at end of file
diff --git a/block_plugins/criterions/__init__.py b/specdec_plugins/criterions/__init__.py
similarity index 100%
rename from block_plugins/criterions/__init__.py
rename to specdec_plugins/criterions/__init__.py
diff --git a/block_plugins/criterions/glat_loss.py b/specdec_plugins/criterions/glat_loss.py
similarity index 100%
rename from block_plugins/criterions/glat_loss.py
rename to specdec_plugins/criterions/glat_loss.py
diff --git a/block_plugins/models/BlockNAT.py b/specdec_plugins/models/BlockNAT.py
similarity index 100%
rename from block_plugins/models/BlockNAT.py
rename to specdec_plugins/models/BlockNAT.py
diff --git a/block_plugins/models/__init__.py b/specdec_plugins/models/__init__.py
similarity index 100%
rename from block_plugins/models/__init__.py
rename to specdec_plugins/models/__init__.py
diff --git a/block_plugins/tasks/__init__.py b/specdec_plugins/tasks/__init__.py
similarity index 100%
rename from block_plugins/tasks/__init__.py
rename to specdec_plugins/tasks/__init__.py
diff --git a/block_plugins/tasks/translation_lev_modified.py b/specdec_plugins/tasks/translation_lev_modified.py
similarity index 100%
rename from block_plugins/tasks/translation_lev_modified.py
rename to specdec_plugins/tasks/translation_lev_modified.py
diff --git a/train.sh b/train.sh index 3acc93b..3a15b04 100644 --- a/train.sh +++ b/train.sh @@ -17,7 +17,7 @@ python train.py ${bin_path} --arch block \ --weight-decay 0.01 --dropout ${dropout} --encoder-layers 6 --encoder-embed-dim 512 --decoder-layers 6 \ --decoder-embed-dim 512 --fp16 --max-source-positions 1000 --max-target-positions 1000 --max-update ${update}\ --seed ${seed} --clip-norm 5 --save-dir ./checkpoints \ - --src-embedding-copy --log-interval 1000 --user-dir block_plugins --block-size ${size} --total-up ${update} \ + --src-embedding-copy --log-interval 1000 --user-dir specdec_plugins --block-size ${size} --total-up ${update} \ --update-freq ${update_freq} --decoder-learned-pos --encoder-learned-pos --apply-bert-init --activation-fn gelu \ --restore-file ./checkpoints/initial_checkpoint.pt \ --reset-optimizer --reset-meters --reset-lr-scheduler --reset-dataloader
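
A minimal, self-contained sketch (not part of the patch) of the block-wise draft-then-verify loop that the renamed `specdec_generate`/`specdec_forward` entry points implement: the NAR drafter proposes `block_size` tokens in parallel, the AR verifier scores the whole block in one forward pass, the longest prefix the verifier agrees with is accepted, and decoding continues from there. `draft_block` and `verifier_logits` below are toy random stand-ins for the BlockNAT drafter and the fairseq Transformer verifier (illustrative assumptions, not the repository's API); `beta` and `tau` mirror the `--beta`/`--tau` options, where `beta=1` and `tau=0` reduce verification to exact agreement with AR greedy decoding.

```python
import torch

VOCAB = 32  # toy vocabulary size for the stand-in models


def draft_block(prefix: torch.Tensor, k: int) -> torch.Tensor:
    """Stand-in NAR drafter: propose k draft tokens in parallel."""
    return torch.randint(0, VOCAB, (k,))


def verifier_logits(prefix: torch.Tensor, draft: torch.Tensor) -> torch.Tensor:
    """Stand-in AR verifier: one parallel pass over the draft block.

    Row i is the verifier's distribution over the token that should follow
    prefix + draft[:i]; a real implementation would run the Transformer
    decoder once over the whole block.
    """
    return torch.randn(draft.numel(), VOCAB)


def specdec_decode(bos: int, eos: int, k: int = 5, beta: int = 1,
                   tau: float = 0.0, max_len: int = 50) -> torch.Tensor:
    prefix = torch.tensor([bos])
    while prefix.numel() < max_len:
        draft = draft_block(prefix, k)                      # NAR draft of k tokens
        log_probs = verifier_logits(prefix, draft).log_softmax(-1)
        top_scores, top_ids = log_probs.topk(beta, dim=-1)  # top-beta verifier candidates
        # Accept a draft token if it is among the verifier's top-beta candidates
        # and its score is within tau of the verifier's best score; beta=1 and
        # tau=0 is exact top-1 (greedy) agreement.
        in_top = (top_ids == draft.unsqueeze(-1)).any(-1)
        draft_scores = log_probs.gather(-1, draft.unsqueeze(-1)).squeeze(-1)
        accept = in_top & (draft_scores >= top_scores[:, 0] - tau)
        # Keep the longest accepted prefix of the draft; the first rejected
        # position is replaced by the verifier's own top-1 token.
        n_accept = int(accept.long().cumprod(dim=0).sum())
        if n_accept == k:
            new_tokens = draft
        else:
            new_tokens = torch.cat([draft[:n_accept], top_ids[n_accept, :1]])
        prefix = torch.cat([prefix, new_tokens])
        if (new_tokens == eos).any():
            break
    return prefix


if __name__ == "__main__":
    print(specdec_decode(bos=0, eos=1))
```

With real models, each loop iteration emits at least one verified token and up to a full block, which is where the 3x~5x speedup over token-by-token AR decoding reported in the readme comes from.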