From 1555cd80ab6c8c64a2781af4798c0016adf08562 Mon Sep 17 00:00:00 2001 From: 1hb6s7t <120760709+1hb6s7t@users.noreply.github.com> Date: Sat, 8 Feb 2025 21:13:25 +0800 Subject: [PATCH] Add files via upload --- .../transformers/models/jukebox/__init__.py | 10 - .../__pycache__/__init__.cpython-39.pyc | Bin 0 -> 174 bytes ...deling_jukebox.cpython-39-pytest-7.2.0.pyc | Bin 0 -> 14269 bytes ...zation_jukebox.cpython-39-pytest-7.2.0.pyc | Bin 0 -> 10226 bytes .../models/jukebox/test_modeling_jukebox.py | 395 ++++++++++++++++++ .../jukebox/test_tokenization_jukebox.py | 210 ++++++++++ 6 files changed, 605 insertions(+), 10 deletions(-) create mode 100644 mindnlp/transformers/models/jukebox/__pycache__/__init__.cpython-39.pyc create mode 100644 mindnlp/transformers/models/jukebox/__pycache__/test_modeling_jukebox.cpython-39-pytest-7.2.0.pyc create mode 100644 mindnlp/transformers/models/jukebox/__pycache__/test_tokenization_jukebox.cpython-39-pytest-7.2.0.pyc create mode 100644 mindnlp/transformers/models/jukebox/test_modeling_jukebox.py create mode 100644 mindnlp/transformers/models/jukebox/test_tokenization_jukebox.py diff --git a/mindnlp/transformers/models/jukebox/__init__.py b/mindnlp/transformers/models/jukebox/__init__.py index 786980182..e69de29bb 100644 --- a/mindnlp/transformers/models/jukebox/__init__.py +++ b/mindnlp/transformers/models/jukebox/__init__.py @@ -1,10 +0,0 @@ -"""Jukebox model""" -from . import configuration_jukebox, modeling_jukebox, tokenization_jukebox -from .configuration_jukebox import * -from .modeling_jukebox import * -from .tokenization_jukebox import * - -__all__ = [] -__all__.extend(configuration_jukebox.__all__) -__all__.extend(modeling_jukebox.__all__) -__all__.extend(tokenization_jukebox.__all__) diff --git a/mindnlp/transformers/models/jukebox/__pycache__/__init__.cpython-39.pyc b/mindnlp/transformers/models/jukebox/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..825224ba96f877394ee5f56cd9d3a76e2d04b4d5 GIT binary patch literal 174 zcmYe~<>g`k0)_1nX(0MBh(HF6K#l_t7qb9~6oz01O-8?!3`HPe1o11_*(xTqIJKxa z#>mvtz$C^cwK%&ZzaS@7o0MQ0W?KZ(@H{>Pfy0{ zQB)=4OZZi{B#;Q8#FGwW>JoJl=S>GQp+qPXPJ}b{iF%3ir6ZY!L_?`gV`7RP+NFry zl4nz*2~U5zDbt*2mK1?>OJ-_fYNj>OnwgfErYd^)WJRmfg1Z$hs0Vtd@AB?a6Eh?> zq=kVESFkfRb(0dSFQVs4M@)5m);2RA^E%T_E15B~nw8H9J(=2(GSjK8Y29fS`u08( zs7fKKq;oq_HE8K}($aMeX>TrXflK6+zHR!}+}Ue$nx3xY)(bNyoTgG)XKtIGH3xLj z#Xi*j`H(pQU+ZZAUCK1!(Y=sFtM2PovmQ-N_%#nC>eYOjUkiNRo2WA|@*C7d5YHf< zAv{BPhVcyJS&wHuo)J7FcsAhKfM+9~jd)JMa|)hKcsAkLT*9~D*^K;DJX`Q=#d9j2 z)9`G?bGkN7o4#8~%+O|NGx40M&C+J$IZJEP4#0D^_8x5xo^9G($UEw^mPWsHtCUTr zZrQpRRR(6}@_II99@gs$?$D)M4^8(Av&(w3F_G$uen8}QW}_cSWi_X6MXoPvi~gT| z)uSk@qZ!?frlPh;?aF zvQROrp=w`tTUxiQXj(U{Xs$bI*}1GPt8B7KVP>Otj~>-@OBZ%~)cthmDM55*aWvJvN69r|JnHkHYvNp}#al{UL|DGs~G%o>*4yH1S1hq90jW^!HIP{kd+i$yU6 zV!P$tx^|pwkAo^n2W?=kQ|ZO5hGur=gl0uk`8=u}D~Z?kN3*Gn9yP7#x`BQU)J#EA zA3`rE0X{94?V4O}8~JAXqo?J>HW2me(Y2=0WB0paLsR*bDGro;J92U)(XF~=YC43G zNts!zJ-QKkBnN?+?Mn4a+I8LT2q5m9Wu?;5(}do&ZEIcP=hO)@H zO6yQTXc(Laq$+e9qV9&0q;{5=k9rdU$jG$pM8MFqLbnnDSIJ`wiyy30l%f()6!QVt zzuCd;T6kE|0w-YBLD$W^7hn-^<}{{`1zmAE3;qq+;$9YPVBKDVCrAP>tQ4bc3l1wMb2n#H-dYq{+Ie!@aGlxj`7F)Xwsp8W} zVI0yZF9slioC|Hf08lNUJ&r%U&A|BxSgObca%gZ; zs<nt=cX8ENMO!1eH zL($?qzD4Q8NhHDXOufGhipPVbnw>PYoWum?Nt}>Q%Im~3PS!Ag1zEee3TKy%NpVoW({`j zz2w5I3s?;!xr!V&u$xJgn|lc^Lv6E*I-X&{)10gUf@9EtnP>Btu;45*<9wULB#wYY ze2BQU0LANA;04ZH%9hs=mm=rG$>sNK=uW1NB*?HEi%A?qN_4R9nViavOnsjP_mXZK zaK*=f6LVP5PJSOlUt$Z5JqurNcyP&R!FcFb*DYWC!p#Naz}rR76|aA#V7xW@+vnGA zy}e*4xeI)ced3OS@spp{Upaaq5@zl<7d_dD#3we{^FRAGXwG`&*?BMi>DGdAO2o71 zlIdW$=6xSM@vK*d3$IT5_$j~i9dSjO`22!#_neDf+i>@}1>@3-65qT3WxHUU(|YU4 zKO4~EMzZcZvo<)F#f_n7?>yu8_goV<_I8{wHTB4iapTB;KYG>4Vl-~d@}1gv=-0j) zH$MBxDFa8{aw~{$y=&`vKL+L#KmU)p&*yH78<(YD{jbsAel2cnT>XXf7kulkxN*Q> zYT4!`_r{H5KlRjkA1pi&H@1J}vKM=lLfp8}I_awa3?Ol7*GK;OB=19UPv2eDz1k z9p8-`z=v!yWhX7V1(MPJ2UpwCkw`tJ084#_F-R$8<|rnquUC`8<%s+koW~FHGaHcNH(9o zs$l#IOg+f^D%;{dg+PJ-S?T{vJ}rLlxl4(V?G==PXF?e z&m7lrN5S~xCU*DsxDkQcZdkpiV4QLB(LZ=Bb5nWBua6rKK0IT=CjXsr(FxNq?kLa< zZYmgxqJK+)vi+`E2%61*7A+E4KV)#k~dNPfLGs>%l+2vtYciNT2o6>0g2zq?;GV z&^V7JIF5j>zeq<{Tug8naV#w!Ot64pE&&%vv4wy&M1tVM0L25DsuUEt;uWLBNjh0@ zDzj%2$aZBldJsv{1s75IgC9}gc$5|RTJr&HWmp|ldOaiRkUFS#V>Lt0dy{&(s;RS- zA~{h(X^99pNbu%aD~{c=8HY@3%d#VnS#rddmAbXf&gHl8&LNx5ADY3LRTo=q z9Hgu*Xew)U=R}6|vN>x@>GUiqk`jkQQ?i6@&-Xix=wlwIm!uqmbtQ)mAFC*8L~T_E z4w~$A-afX@7#UK2@VKrF=Rmr)(ko@CshVexXV|NGr8Jdc-v}gV`!#i!ipv{ZTkKZs zfEL)Lu2goZL*7BPw{Floh^q}l)#{K7Qs%v(Rt1+F!99LZD8u1FFR}a%*N-T}krB>N z1HO%hX9Vk9nSX^_sKPBYDK}upahn2eP5a;$u5b%a$_?0Y+?v6y#pR|cym0D}`=n7^ z#9&gBTAxVSIHI34MyOcUcn}0VX6FC^1H=4C>yS#}YKraYg&6jDuJCO4Q6XcVBGp(FeTkzctLi6uSRuSTjRPuf8uG_; z(S8G8W#Fhu^@4hsv~YcphL*K2Pv^Q)Y3q0-$_$D$F=>Zru0c}j$+Vs|>>ej#;l?J3 zGjLzpwnQ3)MUH!PstB^uR)Q`7C&cSDX;%W z6G)d%Bwe}>(j^l~mq=2~k1huyt<(wNmQhbzj%T38AGj0JFh7_1v^9FeAk8{g1H9#Bt_rrty&2ao2#z+l~o zXNz|zI20NRqfU!GwYRl*TJQ8h={VQ|>V{?A86%24b9fd^arTI+)#DS{<14jQtrwDY zCa>4FpL$`qI@wY;`&n)4fU$L?`8T-cFM1*GcVY8l8|sS;L7t$pf=G=K7I3!{&x63g zvHRACt{m>z=zsQMzoHKeVjjnji#uUEUOlUc_W@Ppq9>R1SKaO3?D(iB;Y-1`5E^jTyw_8 z)hi@2RxfQ|e2i>)J`^7ZNMLJ_#~uQHOkB=2<)u)*`i7wAzVp_R$ zEIK0Z7)#O}rz=jg^e(lEa0}YD76QP-z^}l=c*3d=-exAA^E@rsDa?O6=qkAQX0e5!uLzEmHw~d)`V;@4r?+0(Z(xtRMS%4i z-XR}00+B%esGK)T*WY%I8~@PXSDyQ#Oi*FW4ch}PUC92Q+i-(gQ3ZiU~N zHlmJTTgZOE!$>OhXoiNUlAA;8^|O?j%7~}iBmIkKcqTSH^nPlu+sDwvK91=jw<_fx zaR0h?{9#k%)yzz+ZBnPi{e)}qK5Qy9$H)DFbXAUj$Mzj5Y%cU{7q;_ogFUIdE{0fv zzDt}(z!h0sLvSs@=LzU5#1{Z+HV^AJu3n4n#uCTdm9j;ag+6#QVh~%$BbyKz`<;mg`4Yd52uKB8LN1G^8qwQ+#($Tb`muc_T$SdPo11uJ+EpOCv-ka`=g zUBShT7>3kUaT8PW#CZi%p8|-D-^#j+U`5}=7fJbLf>DB-3HGl|(<&>>sT)o`aaD3v z$BK0;S8;PI6?A;@%2h1JA(dNAx!gGF*-~A`D@Rx=hgA2Uf!1mSD5->Jrn&&%fw}vh zc3H0@c7PL1CK#|!S4PDFjUbvE^bG)3g76fE@2wP&^kpP<%f&KE+r3kj8zK z##Dj_0E*HrtU~JVP+IRoF(9or$}mbGQQ~R)zm{{$1#q}*q#>8_WZ4T`TH>(^J=VD^6ewR`$jPXp!gz{o$Ffh z8LnpECU}*lj}rG5Q+#E}UFyGul(`Wk=H@!!PGSLBnI}Mx%qt15L#beEg2bYk;BkUS z2*}EO4t2|3elk}dCFNwktV%3N{(eAgx%nFBN!sy45KnO}WbJ0D!RAk~2NzGU;9mi1 zES5O=f|7WcDXF*n#8tG~+zW-4{M;H;^<-41lWYBD#hT!wK2FcyY6=!gX3y6 z`N0|oUR-v`=;lA-aHTlh|G`E4nV1mwBEB3%8#}RMqC1lX(c~q3aXHmx2SJ*kpJ0IC zqXgsBif7$3s8;m0JjO@%7bx0Ke+CZ1%d=5Eo9i=rq9?E4jJK^bInmE|Z&$32;>GDK zg3tTD<S<4U-x!EI!)Oj!yPhsuFs zd9%|n+*orjG-bF6>fc;*FH~xyCTJYc5W+=-3E^ki1_Qos8`XgCWNnlJzqZ>3?v31= zPT4!jaZDSyDXwXwzS0Kp`xysHDU(tgQ_F40Td7Fxt<;$LpmT$_cVo#sjLTSZkS{p6 z@#CuQ#5`g<2prX{;;I3qY(?WN<=Y2?c&$V8;>8XRqcOu?zSJSD7)zvvcnNN0J;Q!j9d2ti4dB7O<@#*AEHojiKSR!HcYf(nN&bweVl=xr zweIaJuHHS3Y7-)GP3pkneNW>=CLOUE;x6|8n*?_g&e-Be~Tf%+HU%6+c0K{_y zPZK;tAgx7Sr_#BI?+`psP+f&&O73v(MZxR%${@p->(ptuNDbC}Lr8~kk=hba!!rcy z#b!Cdt$k+N#?LjJ$O-xWWF$IoYM-3e7t4q_%QE{w`j;3?q* z`?4l}k%JfyyPs$)MSk1+Y%^`Ovw-hR$($=S?g<69s!F1%Kyy8t;?O6Pr4wB4Ua23?n^kKz cz&xz%t^NpJ1PMS*@CujLqqd%>to5k>4V)s=#PQ^!l`W%NN9nI<_7@+Btc-5%8Ilad&c(Q*`4Li zY+Of)R0*E@7xW>iDph&q59lAVuT>wW5B&pLRqZ*SvFA90<7RoR6JOtZ&+qR!_s;Cm zhR2T0dGfRN@t-%B7d`KJObhbaYI_G=udeKxg{iPR8$DPlf_@cc!ANikn3!PH5de;{G@l)@Qw&S(z zw=3CpJVW0XcP6~Wxf>_MeO1zxRvDLrZ|_x`C_oMmF2NIzB4dSFNnb%m#ABj@C_Gn zG#|tgqxiR;rI8H%{zJxA1OssEq$##XYv z73>O>`N-W$ZpDiNK-eFO=_)9ZOp}cmh>ERbu$h~uD}u?E;)CISw!ST56(ZHHkOico zqHn~}o3&jHat}4JCD{(sMHLluJDH2}ARG=w^o?5J>TWOy`*D!u!S&m_8hjE8f**)S zii-ZtP5p;iwypa6yTOgjZ42XW99&H{w~E~gPBA$Q6Zb-`|5`?L!CIUrQ7j0GNuWwXQNjkDT57Xd=9FOhAU`_c!OCeYboz;@Ul}d4u&oq{@m5P*; zC5XGYkTh&aL4`X9j4ZUr9lW+V}v*|^hKl+>g zeSdE?bJ%LLeRukfcW>rS>t6Scob!|3ooOBKx9)35qO`qp2k!AW7Z}sgYnODYk-a7!NL7shO9Tm+QGxRIWYn$L(yG-+op(!{p`9 zs;h13%-T`~-gq_{3`a#TiSn2JE_VrHHYEP{JTLi|=qhJLU-_X9KhxodIxNfJj_K_A z{h518mvVo5?rEK!*4dLf`>D=O>@VH7bcSr8GhLPoIyu^o?(EPuLQVo?as>FBq#mXg}q4`;z)j^l<%Cc0upJ<0@(L! zmdAjEV9-*dYZy2nK{Kl0(<(WT8A8P_4lyHMWiEeR`=S+AIey5jq(G3xjk=I3{LCbL zQonnuEb>SC7btThHWW+*Q@7SadKy`%Q|o?QyM{?i6d|w&hzy1mGbCH zD|8~VQX!Z+GbQdRM^ih6)B z3_W`>3;B*2fe3d=Vv zBy$ZLEC~?x7M8^U)j$LYf(CJpUjq`8g@A+sd$ctq%r%xskTP?_Y&85cU^STXq6}`L zy5hEIpmB7Pako($I8LJQ6er2F0bwB+vxUzY#4eGO8T4TdeV8=6HB9A&NrNs?x};*v z2p@eU&%B}%Yd1Rxxt63WmPaqYRh?* zKs|#!Qd#0N9uNl+vE-(Kg;XGh2X!!2_y*HyHH6O`SWX$~MyZK4u)-RdMin;5FMQ}z z56Vlbev-%?@5O{*RoB2&jhDD)#$hps@DB`Nv2>d%e6HWaN|>>0Z!+jZmHA@OnKkYg zB43&~-ZLs;HDrd+#}W^)fqxuMW`w{1GX$8a0}J2oKTGIqYWJ{s(PF@&9;Mi@c z@w7C~LTMxv@gUFKcux94($!jVqj=m6({x?>8S#V8cq)(64fm$%QDx)VV!t&gqB_l2 zpS=8=k1k(Yx%_tT`ll-&eY*0%bMFZ2Kch=;-27CO<~LtC_xkCxH{Xu)?IIiA-1Bg$ zFWtA;-PDit`AuE@){lJt|5&ZaaIwSy8|xDu?5iXh5@K~_k^@0N69_S-U=IN=L(*^<_rcDLkD7=^<^ zdx068@4X`()FR$==OV4d<|xa03=(;WJ#QSWJQtRQQl>98Q=ipUH7KnIN*;Y=(8{M#@)oMap^v1h zb%T?Ql(9z~R_KNw)W-Q}yziMwBT!8Y4QrDF!9{^XM8tqsR=}2h2!I2Xuq;6QqomLR zF7Oec9Sv|k+T>x`M<2dmmjPd323>OqYuG>n5GOh~cI9G&*eFcm<)shx1ZWr#KE~}V z*ki{E9Z!+PKXOTcD7BnhY}i1;E;6(r0o7`flLCP~W-uE=e7PtN{Cb#+Q58fYAy zWZZ4k29A?UJjF>eZ9rHE#%$qp2C+-z1Y@;c^~Q6(o-|}esm|wnz0XHsT2b~oOLa%6x1LHD$-Sj+ zw{_t2P7Nx3uo^jtnoh=A|3;Ur07^ AM*si- literal 0 HcmV?d00001 diff --git a/mindnlp/transformers/models/jukebox/test_modeling_jukebox.py b/mindnlp/transformers/models/jukebox/test_modeling_jukebox.py new file mode 100644 index 000000000..d37ab8905 --- /dev/null +++ b/mindnlp/transformers/models/jukebox/test_modeling_jukebox.py @@ -0,0 +1,395 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest +from unittest import skip + +from mindnlp.utils.testing_utils import ( + is_mindspore_available, + require_mindspore, + slow, +) +from mindnlp.engine import set_seed +import mindnlp.core + +if is_mindspore_available(): + import mindspore + from mindspore import ops + from mindnlp.transformers import JukeboxModel, JukeboxPrior, JukeboxTokenizer + + +@require_mindspore +class Jukebox1bModelTester(unittest.TestCase): + all_model_classes = (JukeboxModel,) if is_mindspore_available() else () + model_id = "openai/jukebox-1b-lyrics" + metas = { + "artist": "Zac Brown Band", + "genres": "Country", + "lyrics": """I met a traveller from an antique land, + Who said "Two vast and trunkless legs of stone + Stand in the desert. . . . Near them, on the sand, + Half sunk a shattered visage lies, whose frown, + And wrinkled lip, and sneer of cold command, + Tell that its sculptor well those passions read + Which yet survive, stamped on these lifeless things, + The hand that mocked them, and the heart that fed; + And on the pedestal, these words appear: + My name is Ozymandias, King of Kings; + Look on my Works, ye Mighty, and despair! + Nothing beside remains. Round the decay + Of that colossal Wreck, boundless and bare + The lone and level sands stretch far away + """, + } + # fmt: off + EXPECTED_OUTPUT_2 = [ + 1864, 1536, 1213, 1870, 1357, 1536, 519, 880, 1323, 789, 1082, 534, + 1000, 1445, 1105, 1130, 967, 515, 1434, 1620, 534, 1495, 283, 1445, + 333, 1307, 539, 1631, 1528, 375, 1434, 673, 627, 710, 778, 1883, + 1405, 1276, 1455, 1228 + ] + + EXPECTED_OUTPUT_2_PT_2 = [ + 1489, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, + 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, + 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, + 653, 653, 653, 653 + ] + + EXPECTED_OUTPUT_1 = [ + 1125, 1751, 697, 1776, 1141, 1476, 391, 697, 1125, 684, 867, 416, + 844, 1372, 1274, 717, 1274, 844, 1299, 1419, 697, 1370, 317, 1125, + 191, 1440, 1370, 1440, 1370, 282, 1621, 1370, 368, 349, 867, 1872, + 1262, 869, 1728, 747 + ] + EXPECTED_OUTPUT_1_PT_2 = [ + 416, 416, 1125, 1125, 416, 416, 416, 416, 416, 416, 416, 416, + 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, + 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, + 416, 416, 416, 416 + ] + + EXPECTED_OUTPUT_0 = [ + 1755, 842, 307, 1843, 1022, 1395, 234, 1554, 806, 739, 1022, 442, + 616, 556, 268, 1499, 933, 457, 1440, 1837, 755, 985, 308, 902, + 293, 1443, 1671, 1141, 1533, 555, 1562, 1061, 287, 417, 1022, 2008, + 1186, 1015, 1777, 268 + ] + EXPECTED_OUTPUT_0_PT_2 = [ + 854, 842, 1353, 114, 1353, 842, 185, 842, 185, 114, 591, 842, + 185, 417, 185, 842, 307, 842, 591, 842, 185, 842, 307, 842, + 591, 842, 1353, 842, 185, 842, 591, 842, 591, 114, 591, 842, + 185, 842, 591, 89 + ] + + EXPECTED_Y_COND = [1058304, 0, 786432, 7169, 507, 76, 27, 40, 30, 76] + + EXPECTED_PRIMED_0 = [ + 390, 1160, 1002, 1907, 1788, 1788, 1788, 1907, 1002, 1002, 1854, 1002, + 1002, 1002, 1002, 1002, 1002, 1160, 1160, 1606, 596, 596, 1160, 1002, + 1516, 596, 1002, 1002, 1002, 1907, 1788, 1788, 1788, 1854, 1788, 1907, + 1907, 1788, 596, 1626 + ] + EXPECTED_PRIMED_1 = [ + 1236, 1668, 1484, 1920, 1848, 1409, 139, 864, 1828, 1272, 1599, 824, + 1672, 139, 555, 1484, 824, 1920, 555, 596, 1579, 1599, 1231, 1599, + 1637, 1407, 212, 824, 1599, 116, 1433, 824, 258, 1599, 1433, 1895, + 1063, 1433, 1433, 1599 + ] + EXPECTED_PRIMED_2 = [ + 1684, 1873, 1119, 1189, 395, 611, 1901, 972, 890, 1337, 1392, 1927, + 96, 972, 672, 780, 1119, 890, 158, 771, 1073, 1927, 353, 1331, + 1269, 1459, 1333, 1645, 812, 1577, 1337, 606, 353, 981, 1466, 619, + 197, 391, 302, 1930 + ] + EXPECTED_VQVAE_ENCODE = [ + 390, 1160, 1002, 1907, 1788, 1788, 1788, 1907, 1002, 1002, 1854, 1002, + 1002, 1002, 1002, 1002, 1002, 1160, 1160, 1606, 596, 596, 1160, 1002, + 1516, 596, 1002, 1002, 1002, 1907, 1788, 1788, 1788, 1854, 1788, 1907, + 1907, 1788, 596, 1626 + ] + EXPECTED_VQVAE_DECODE = [ + -0.0492, -0.0524, -0.0565, -0.0640, -0.0686, -0.0684, -0.0677, -0.0664, + -0.0605, -0.0490, -0.0330, -0.0168, -0.0083, -0.0075, -0.0051, 0.0025, + 0.0136, 0.0261, 0.0386, 0.0497, 0.0580, 0.0599, 0.0583, 0.0614, + 0.0740, 0.0889, 0.1023, 0.1162, 0.1211, 0.1212, 0.1251, 0.1336, + 0.1502, 0.1686, 0.1883, 0.2148, 0.2363, 0.2458, 0.2507, 0.2531 + ] + EXPECTED_AUDIO_COND = [ + 0.0256, -0.0544, 0.1600, -0.0032, 0.1066, 0.0825, -0.0013, 0.3440, + 0.0210, 0.0412, -0.1777, -0.0892, -0.0164, 0.0285, -0.0613, -0.0617, + -0.0137, -0.0201, -0.0175, 0.0215, -0.0627, 0.0520, -0.0730, 0.0970, + -0.0100, 0.0442, -0.0586, 0.0207, -0.0015, -0.0082 + ] + EXPECTED_META_COND = [ + 0.0415, 0.0877, 0.0022, -0.0055, 0.0751, 0.0334, 0.0324, -0.0068, + 0.0011, 0.0017, -0.0676, 0.0655, -0.0143, 0.0399, 0.0303, 0.0743, + -0.0168, -0.0394, -0.1113, 0.0124, 0.0442, 0.0267, -0.0003, -0.1536, + -0.0116, -0.1837, -0.0180, -0.1026, -0.0777, -0.0456 + ] + EXPECTED_LYRIC_COND = [ + 76, 27, 40, 30, 76, 46, 44, 47, 40, 37, 38, 31, 45, 45, 76, 38, 31, 33, + 45, 76, 41, 32, 76, 45, 46, 41, 40, 31, 78, 76 + ] + # fmt: on + + def prepare_inputs(self): + tokenizer = JukeboxTokenizer.from_pretrained(self.model_id) + tokens = tokenizer(**self.metas)["input_ids"] + return tokens + + #@slow + def test_sampling(self): + model = JukeboxModel.from_pretrained(self.model_id, min_duration=0).set_train(False) + labels = self.prepare_inputs() + + set_seed(0) + zs = [ops.zeros((1, 0), dtype=mindspore.int64) for _ in range(3)] + zs = model._sample(zs, labels, [0], sample_length=40 * model.priors[0].raw_to_tokens, save_results=False) + self.assertIn(zs[0][0].detach().tolist(), [self.EXPECTED_OUTPUT_2, self.EXPECTED_OUTPUT_2_PT_2]) + + set_seed(0) + zs = model._sample(zs, labels, [1], sample_length=40 * model.priors[1].raw_to_tokens, save_results=False) + self.assertIn(zs[1][0].detach().tolist(), [self.EXPECTED_OUTPUT_1, self.EXPECTED_OUTPUT_1_PT_2]) + + set_seed(0) + zs = model._sample(zs, labels, [2], sample_length=40 * model.priors[2].raw_to_tokens, save_results=False) + self.assertIn(zs[2][0].detach().tolist(), [self.EXPECTED_OUTPUT_0, self.EXPECTED_OUTPUT_0_PT_2]) + + #@slow + def test_conditioning(self): + model = JukeboxModel.from_pretrained(self.model_id, min_duration=0).set_train(False) + + labels = self.prepare_inputs() + set_seed(0) + zs = [ops.zeros((1, 0), dtype=mindspore.int64) for _ in range(3)] + + top_prior = model.priors[0] + start = 0 + music_token_conds = top_prior.get_music_tokens_conds(zs, start=start, end=start + top_prior.n_ctx) + metadata = top_prior.get_metadata(labels[0].clone(), start, 1058304, 0) + + self.assertIsNone(music_token_conds) + self.assertListEqual(metadata.numpy()[0][:10].tolist(), self.EXPECTED_Y_COND) + + audio_conditioning, metadata_conditioning, lyric_tokens = top_prior.get_cond(music_token_conds, metadata) + self.assertTrue(mindnlp.core.ops.allclose( + audio_conditioning[0][0][:30].detach(), mindspore.tensor(self.EXPECTED_AUDIO_COND), atol=1e-4, rtol=1e-4 + )) + self.assertTrue(mindnlp.core.ops.allclose( + metadata_conditioning[0][0][:30].detach(), mindspore.tensor(self.EXPECTED_META_COND), atol=1e-4, rtol=1e-4 + )) + self.assertTrue(mindnlp.core.ops.allclose( + lyric_tokens[0, :30].detach(), mindspore.tensor(self.EXPECTED_LYRIC_COND), atol=1e-4, rtol=1e-4 + )) + + #@slow + def test_primed_sampling(self): + model = JukeboxModel.from_pretrained(self.model_id, min_duration=0).set_train(False) + set_seed(0) + waveform = ops.rand((1, 5120, 1)) + tokens = list(self.prepare_inputs()) + + zs = [model.vqvae.encode(waveform, start_level=2, bs_chunks=waveform.shape[0])[0], None, None] + zs = model._sample( + zs, tokens, sample_levels=[0], save_results=False, sample_length=40 * model.priors[0].raw_to_tokens + ) + self.assertTrue(mindnlp.core.ops.allclose(zs[0][0][:40], mindspore.tensor(self.EXPECTED_PRIMED_0))) + + upper_2 = ops.cat((zs[0], ops.zeros(1, 2048 - zs[0].shape[-1])), dim=-1).long() + zs = [upper_2, model.vqvae.encode(waveform, start_level=1, bs_chunks=waveform.shape[0])[0], None] + zs = model._sample( + zs, tokens, sample_levels=[1], save_results=False, sample_length=40 * model.priors[1].raw_to_tokens + ) + self.assertTrue(mindnlp.core.ops.allclose(zs[1][0][:40], mindspore.tensor(self.EXPECTED_PRIMED_1))) + + upper_1 = ops.cat((zs[1], ops.zeros(1, 2048 - zs[1].shape[-1])), dim=-1).long() + zs = [upper_2, upper_1, model.vqvae.encode(waveform, start_level=0, bs_chunks=waveform.shape[0])[0]] + zs = model._sample( + zs, tokens, sample_levels=[2], save_results=False, sample_length=40 * model.priors[2].raw_to_tokens + ) + self.assertTrue(mindnlp.core.ops.allclose(zs[2][0][:40], mindspore.tensor(self.EXPECTED_PRIMED_2))) + + #@slow + def test_vqvae(self): + model = JukeboxModel.from_pretrained(self.model_id, min_duration=0).set_train(False) + set_seed(0) + x = ops.rand((1, 5120, 1)) + + zs = model.vqvae.encode(x, start_level=2, bs_chunks=x.shape[0]) + self.assertTrue(mindnlp.core.ops.allclose(zs[0][0], mindspore.tensor(self.EXPECTED_VQVAE_ENCODE))) + + x = model.vqvae.decode(zs, start_level=2, bs_chunks=x.shape[0]) + self.assertTrue(mindnlp.core.ops.allclose(x[0, :40, 0], mindspore.tensor(self.EXPECTED_VQVAE_DECODE), atol=1e-4, rtol=1e-4)) + + +@require_mindspore +class Jukebox5bModelTester(unittest.TestCase): + all_model_classes = (JukeboxModel,) if is_mindspore_available() else () + model_id = "openai/jukebox-5b-lyrics" + metas = { + "artist": "Zac Brown Band", + "genres": "Country", + "lyrics": """I met a traveller from an antique land, + Who said "Two vast and trunkless legs of stone + Stand in the desert. . . . Near them, on the sand, + Half sunk a shattered visage lies, whose frown, + And wrinkled lip, and sneer of cold command, + Tell that its sculptor well those passions read + Which yet survive, stamped on these lifeless things, + The hand that mocked them, and the heart that fed; + And on the pedestal, these words appear: + My name is Ozymandias, King of Kings; + Look on my Works, ye Mighty, and despair! + Nothing beside remains. Round the decay + Of that colossal Wreck, boundless and bare + The lone and level sands stretch far away + """, + } + + # fmt: off + EXPECTED_OUTPUT_2 = [ + 1489, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, + 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, + 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, + 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, + 1489, 1489, 1489, 1489, 1150, 1853, 1509, 1150, 1357, 1509, 6, 1272 + ] + EXPECTED_OUTPUT_2_PT_2 = [ + 1489, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, + 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, + 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, + 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, + 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653 + ] + + EXPECTED_OUTPUT_1 = [ + 1125, 416, 1125, 1125, 1125, 1125, 1125, 416, 416, 416, 416, 416, + 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, + 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, + 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, + 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, 416 + ] + EXPECTED_OUTPUT_1_PT_2 = [ + 416, 416, 1125, 1125, 416, 416, 416, 416, 416, 416, 416, 416, 416, + 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, + 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, + 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, + 416, 416, 416, 416, 416, 416, 416, 416 + ] + + EXPECTED_OUTPUT_0 = [ + 1755, 1061, 234, 1755, 1061, 1755, 185, 290, 307, 307, 616, 616, + 616, 616, 616, 616, 307, 290, 417, 1755, 234, 1755, 185, 290, + 290, 290, 307, 616, 616, 616, 616, 616, 290, 234, 234, 1755, + 234, 234, 1755, 234, 185, 185, 307, 616, 616, 616, 616, 290, + 1755, 1755, 1755, 234, 234, 1755, 1572, 290, 307, 616, 34, 616 + ] + EXPECTED_OUTPUT_0_PT_2 = [ + 854, 842, 1353, 114, 1353, 842, 185, 842, 185, 114, 591, 842, 185, + 417, 185, 842, 307, 842, 591, 842, 185, 842, 185, 842, 591, 842, + 1353, 842, 185, 842, 591, 842, 591, 114, 591, 842, 185, 842, 591, + 89, 591, 842, 591, 842, 591, 417, 1372, 842, 1372, 842, 34, 842, + 185, 89, 591, 842, 185, 842, 591, 632 + ] + + EXPECTED_GPU_OUTPUTS_2 = [ + 1489, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, + 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, + 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, + 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, + 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653 + ] + EXPECTED_GPU_OUTPUTS_2_PT_2 = [ + 1489, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, + 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, + 653, 653, 653, 653, 653, 653, 653, 1853, 1177, 1536, 1228, + 710, 475, 1489, 1229, 1224, 231, 1224, 252, 1434, 653, 475, + 1106, 1877, 1599, 1228, 1600, 1683, 1182, 1853, 475, 1864, + 252, 1229, 1434, 2001 + ] + + EXPECTED_GPU_OUTPUTS_1 = [ + 1125, 1125, 416, 1125, 1125, 416, 1125, 1125, 416, 416, 1125, 416, + 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, + 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, + 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, + 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, 416 + ] + EXPECTED_GPU_OUTPUTS_0 = [ + 491, 1755, 34, 1613, 1755, 417, 992, 1613, 222, 842, 1353, 1613, + 844, 632, 185, 1613, 844, 632, 185, 1613, 185, 842, 677, 1613, + 185, 114, 1353, 1613, 307, 89, 844, 1613, 307, 1332, 234, 1979, + 307, 89, 1353, 616, 34, 842, 185, 842, 34, 842, 185, 842, + 307, 114, 185, 89, 34, 1268, 185, 89, 34, 842, 185, 89 + ] + # fmt: on + + def prepare_inputs(self, model_id): + tokenizer = JukeboxTokenizer.from_pretrained(model_id) + tokens = tokenizer(**self.metas)["input_ids"] + return tokens + + #@slow + def test_sampling(self): + model = JukeboxModel.from_pretrained(self.model_id, min_duration=0).set_train(False) + labels = self.prepare_inputs(self.model_id) + + set_seed(0) + zs = [ops.zeros((1, 0), dtype=mindspore.int64) for _ in range(3)] + zs = model._sample(zs, labels, [0], sample_length=60 * model.priors[0].raw_to_tokens, save_results=False) + self.assertIn(zs[0][0].detach().tolist(), [self.EXPECTED_OUTPUT_2, self.EXPECTED_OUTPUT_2_PT_2]) + + set_seed(0) + zs = model._sample(zs, labels, [1], sample_length=60 * model.priors[1].raw_to_tokens, save_results=False) + self.assertIn(zs[1][0].detach().tolist(), [self.EXPECTED_OUTPUT_1, self.EXPECTED_OUTPUT_1_PT_2]) + + set_seed(0) + zs = model._sample(zs, labels, [2], sample_length=60 * model.priors[2].raw_to_tokens, save_results=False) + self.assertIn(zs[2][0].detach().tolist(), [self.EXPECTED_OUTPUT_0, self.EXPECTED_OUTPUT_0_PT_2]) + + #@slow + @skip("Not enough GPU memory on CI runners") + def test_slow_sampling(self): + model = JukeboxModel.from_pretrained(self.model_id, min_duration=0).set_train(False) + labels = [i for i in self.prepare_inputs(self.model_id)] + + set_seed(0) + model.priors[0] + zs = [ops.zeros((1, 0), dtype=mindspore.int64) for _ in range(3)] + zs = model._sample(zs, labels, [0], sample_length=60 * model.priors[0].raw_to_tokens, save_results=False) + self.assertTrue(mindnlp.core.ops.allclose(zs[0][0], mindspore.tensor(self.EXPECTED_GPU_OUTPUTS_2))) + model.priors[0] + + set_seed(0) + model.priors[1] + zs = model._sample(zs, labels, [1], sample_length=60 * model.priors[1].raw_to_tokens, save_results=False) + self.assertTrue(mindnlp.core.ops.allclose(zs[1][0], mindspore.tensor(self.EXPECTED_GPU_OUTPUTS_1))) + model.priors[1] + + set_seed(0) + model.priors[2] + zs = model._sample(zs, labels, [2], sample_length=60 * model.priors[2].raw_to_tokens, save_results=False) + self.assertTrue(mindnlp.core.ops.allclose(zs[2][0], mindspore.tensor(self.EXPECTED_GPU_OUTPUTS_0))) + + #@slow + def test_fp16_slow_sampling(self): + prior_id = "ArthurZ/jukebox_prior_0" + model = JukeboxPrior.from_pretrained(prior_id, min_duration=0).set_train(False).half() + + labels = self.prepare_inputs(prior_id)[0] + metadata = model.get_metadata(labels, 0, 7680, 0) + set_seed(0) + outputs = model.sample(1, metadata=metadata, sample_tokens=60) + self.assertIn(outputs[0].tolist(), [self.EXPECTED_GPU_OUTPUTS_2, self.EXPECTED_GPU_OUTPUTS_2_PT_2]) diff --git a/mindnlp/transformers/models/jukebox/test_tokenization_jukebox.py b/mindnlp/transformers/models/jukebox/test_tokenization_jukebox.py new file mode 100644 index 000000000..e971dde44 --- /dev/null +++ b/mindnlp/transformers/models/jukebox/test_tokenization_jukebox.py @@ -0,0 +1,210 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from mindnlp.transformers import JukeboxTokenizer +from mindnlp.utils.testing_utils import require_mindspore + +class JukeboxTokenizationTest(unittest.TestCase): + tokenizer_class = JukeboxTokenizer + metas = { + "artist": "Zac Brown Band", + "genres": "Country", + "lyrics": """I met a traveller from an antique land, + Who said "Two vast and trunkless legs of stone + Stand in the desert. . . . Near them, on the sand, + Half sunk a shattered visage lies, whose frown, + And wrinkled lip, and sneer of cold command, + Tell that its sculptor well those passions read + Which yet survive, stamped on these lifeless things, + The hand that mocked them, and the heart that fed; + And on the pedestal, these words appear: + My name is Ozymandias, King of Kings; + Look on my Works, ye Mighty, and despair! + Nothing beside remains. Round the decay + Of that colossal Wreck, boundless and bare + The lone and level sands stretch far away + """, + } + + @require_mindspore + def test_1b_lyrics_tokenizer(self): + """ + how to run the same test with openAI + ... + """ + import mindspore + from mindnlp.core import ops + + tokenizer = JukeboxTokenizer.from_pretrained("openai/jukebox-1b-lyrics") + tokens = tokenizer(**self.metas)["input_ids"] + # fmt: off + EXPECTED_OUTPUT = [ + mindspore.tensor([[ + 0, 0, 0, 7169, 507, 9, 76, 39, 31, 46, 76, 27, + 76, 46, 44, 27, 48, 31, 38, 38, 31, 44, 76, 32, + 44, 41, 39, 76, 27, 40, 76, 27, 40, 46, 35, 43, + 47, 31, 76, 38, 27, 40, 30, 64, 78, 76, 76, 76, + 76, 76, 76, 76, 76, 23, 34, 41, 76, 45, 27, 35, + 30, 76, 71, 20, 49, 41, 76, 48, 27, 45, 46, 76, + 27, 40, 30, 76, 46, 44, 47, 40, 37, 38, 31, 45, + 45, 76, 38, 31, 33, 45, 76, 41, 32, 76, 45, 46, + 41, 40, 31, 78, 76, 76, 76, 76, 76, 76, 76, 76, + 19, 46, 27, 40, 30, 76, 35, 40, 76, 46, 34, 31, + 76, 30, 31, 45, 31, 44, 46, 63, 76, 63, 76, 63, + 76, 63, 76, 14, 31, 27, 44, 76, 46, 34, 31, 39, + 64, 76, 41, 40, 76, 46, 34, 31, 76, 45, 27, 40, + 30, 64, 78, 76, 76, 76, 76, 76, 76, 76, 76, 8, + 27, 38, 32, 76, 45, 47, 40, 37, 76, 27, 76, 45, + 34, 27, 46, 46, 31, 44, 31, 30, 76, 48, 35, 45, + 27, 33, 31, 76, 38, 35, 31, 45, 64, 76, 49, 34, + 41, 45, 31, 76, 32, 44, 41, 49, 40, 64, 78, 76, + 76, 76, 76, 76, 76, 76, 76, 1, 40, 30, 76, 49, + 44, 35, 40, 37, 38, 31, 30, 76, 38, 35, 42, 64, + 76, 27, 40, 30, 76, 45, 40, 31, 31, 44, 76, 41, + 32, 76, 29, 41, 38, 30, 76, 29, 41, 39, 39, 27, + 40, 30, 64, 78, 76, 76, 76, 76, 76, 76, 76, 76, + 20, 31, 38, 38, 76, 46, 34, 27, 46, 76, 35, 46, + 45, 76, 45, 29, 47, 38, 42, 46, 41, 44, 76, 49, + 31, 38, 38, 76, 46, 34, 41, 45, 31, 76, 42, 27, + 45, 45, 35, 41, 40, 45, 76, 44, 31, 27, 30, 78, + 76, 76, 76, 76, 76, 76, 76, 76, 23, 34, 35, 29, + 34, 76, 51, 31, 46, 76, 45, 47, 44, 48, 35, 48, + 31, 64, 76, 45, 46, 27, 39, 42, 31, 30, 76, 41, + 40, 76, 46, 34, 31, 45, 31, 76, 38, 35, 32, 31, + 38, 31, 45, 45, 76, 46, 34, 35, 40, 33, 45, 64, + 78, 76, 76, 76, 76, 76, 76, 76, 76, 20, 34, 31, + 76, 34, 27, 40, 30, 76, 46, 34, 27, 46, 76, 39, + 41, 29, 37, 31, 30, 76, 46, 34, 31, 39, 64, 76, + 27, 40, 30, 76, 46, 34, 31, 76, 34, 31, 27, 44, + 46, 76, 46, 34, 27, 46, 76, 32, 31, 30, 66, 78, + 76, 76, 76, 76, 76, 76, 76, 76, 1, 40, 30, 76, + 41, 40, 76, 46, 34, 31, 76, 42, 31, 30, 31, 45, + 46, 27, 38, 64, 76, 46, 34, 31, 45, 31, 76, 49, + 41, 44, 30, 45, 76, 27, 42, 42, 31, 27, 44, 65, + 78, 76, 76, 76, 76, 76, 76, 76, 76, 13, 51, 76, + 40, 27, 39, 31, 76, 35, 45, 76, 15, 52, 51, 39, + 27, 40, 30, 35, 27, 45, 64, 76, 11, 35, 40, 33, + 76, 41, 32, 76, 11, 35, 40, 33, 45, 66, 78, 76, + 76, 76, 76, 76, 76, 76, 76, 12, 41, 41, 37, 76, + 41, 40, 76, 39, 51, 76, 23, 41, 44, 37, 45, 64, + 76, 51, 31, 76, 13, 35, 33, 34, 46, 51, 64, 76, + 27, 40, 30, 76, 30, 31, 45, 42, 27, 35, 44, 67, + 78, 76, 76, 76, 76, 76, 76, 76, 76, 14, 41, 46, + 34, 35, 40, 33, 76, 28, 31, 45, 35, 30, 31, 76, + 44, 31, 39, 27, 35, 40, 45, 63, 76, 18, 41, 47, + 40, 30, 76, 46, 34, 31, 76, 30, 31, 29, 27, 51, + 78, 76, 76, 76, 76, 76, 76, 76, 76, 15, 32, 76, + 46, 34, 27, 46, 76, 29, 41, 38, 41, 45, 45, 27, + 38, 76, 23, 44, 31, 29, 37, 64, 76, 28, 41, 47, + 40, 30, 38, 31, 45, 45, 76, 27, 40, 30, 76, 28, + 27, 44, 31, 78, 76, 76, 76, 76, 76, 76, 76, 76, + 20, 34, 31, 76, 38, 41, 40, 31, 76, 27, 40, 30, + 76, 38, 31, 48, 31, 38, 76, 45, 27, 40, 30, 45, + 76, 45, 46, 44, 31, 46, 29, 34, 76, 32, 27, 44, + 76, 27, 49, 27, 51, 78, 76, 76, 76, 76, 76, 76, + 76, 76]]), + mindspore.tensor([[0, 0, 0, 1069, 11]]), + mindspore.tensor([[0, 0, 0, 1069, 11]]), + ] + # fmt: on + self.assertTrue(ops.allclose(tokens[0], EXPECTED_OUTPUT[0])) + self.assertTrue(ops.allclose(tokens[1], EXPECTED_OUTPUT[1])) + self.assertTrue(ops.allclose(tokens[2], EXPECTED_OUTPUT[2])) + + @require_mindspore + def test_5b_lyrics_tokenizer(self): + """ + The outputs are similar that open AI but do not have the same format as this one is adapted to the HF integration. + """ + import mindspore + from mindnlp.core import ops + + tokenizer = JukeboxTokenizer.from_pretrained("openai/jukebox-5b-lyrics") + tokens = tokenizer(**self.metas)["input_ids"] + # fmt: off + EXPECTED_OUTPUT = [ + mindspore.tensor([[ + 0, 0, 0, 1069, 11, -1, -1, -1, -1, 9, 77, 39, + 31, 46, 77, 27, 77, 46, 44, 27, 48, 31, 38, 38, + 31, 44, 77, 32, 44, 41, 39, 77, 27, 40, 77, 27, + 40, 46, 35, 43, 47, 31, 77, 38, 27, 40, 30, 64, + 79, 77, 77, 77, 77, 77, 77, 77, 77, 23, 34, 41, + 77, 45, 27, 35, 30, 77, 72, 20, 49, 41, 77, 48, + 27, 45, 46, 77, 27, 40, 30, 77, 46, 44, 47, 40, + 37, 38, 31, 45, 45, 77, 38, 31, 33, 45, 77, 41, + 32, 77, 45, 46, 41, 40, 31, 79, 77, 77, 77, 77, + 77, 77, 77, 77, 19, 46, 27, 40, 30, 77, 35, 40, + 77, 46, 34, 31, 77, 30, 31, 45, 31, 44, 46, 63, + 77, 63, 77, 63, 77, 63, 77, 14, 31, 27, 44, 77, + 46, 34, 31, 39, 64, 77, 41, 40, 77, 46, 34, 31, + 77, 45, 27, 40, 30, 64, 79, 77, 77, 77, 77, 77, + 77, 77, 77, 8, 27, 38, 32, 77, 45, 47, 40, 37, + 77, 27, 77, 45, 34, 27, 46, 46, 31, 44, 31, 30, + 77, 48, 35, 45, 27, 33, 31, 77, 38, 35, 31, 45, + 64, 77, 49, 34, 41, 45, 31, 77, 32, 44, 41, 49, + 40, 64, 79, 77, 77, 77, 77, 77, 77, 77, 77, 1, + 40, 30, 77, 49, 44, 35, 40, 37, 38, 31, 30, 77, + 38, 35, 42, 64, 77, 27, 40, 30, 77, 45, 40, 31, + 31, 44, 77, 41, 32, 77, 29, 41, 38, 30, 77, 29, + 41, 39, 39, 27, 40, 30, 64, 79, 77, 77, 77, 77, + 77, 77, 77, 77, 20, 31, 38, 38, 77, 46, 34, 27, + 46, 77, 35, 46, 45, 77, 45, 29, 47, 38, 42, 46, + 41, 44, 77, 49, 31, 38, 38, 77, 46, 34, 41, 45, + 31, 77, 42, 27, 45, 45, 35, 41, 40, 45, 77, 44, + 31, 27, 30, 79, 77, 77, 77, 77, 77, 77, 77, 77, + 23, 34, 35, 29, 34, 77, 51, 31, 46, 77, 45, 47, + 44, 48, 35, 48, 31, 64, 77, 45, 46, 27, 39, 42, + 31, 30, 77, 41, 40, 77, 46, 34, 31, 45, 31, 77, + 38, 35, 32, 31, 38, 31, 45, 45, 77, 46, 34, 35, + 40, 33, 45, 64, 79, 77, 77, 77, 77, 77, 77, 77, + 77, 20, 34, 31, 77, 34, 27, 40, 30, 77, 46, 34, + 27, 46, 77, 39, 41, 29, 37, 31, 30, 77, 46, 34, + 31, 39, 64, 77, 27, 40, 30, 77, 46, 34, 31, 77, + 34, 31, 27, 44, 46, 77, 46, 34, 27, 46, 77, 32, + 31, 30, 66, 79, 77, 77, 77, 77, 77, 77, 77, 77, + 1, 40, 30, 77, 41, 40, 77, 46, 34, 31, 77, 42, + 31, 30, 31, 45, 46, 27, 38, 64, 77, 46, 34, 31, + 45, 31, 77, 49, 41, 44, 30, 45, 77, 27, 42, 42, + 31, 27, 44, 65, 79, 77, 77, 77, 77, 77, 77, 77, + 77, 13, 51, 77, 40, 27, 39, 31, 77, 35, 45, 77, + 15, 52, 51, 39, 27, 40, 30, 35, 27, 45, 64, 77, + 11, 35, 40, 33, 77, 41, 32, 77, 11, 35, 40, 33, + 45, 66, 79, 77, 77, 77, 77, 77, 77, 77, 77, 12, + 41, 41, 37, 77, 41, 40, 77, 39, 51, 77, 23, 41, + 44, 37, 45, 64, 77, 51, 31, 77, 13, 35, 33, 34, + 46, 51, 64, 77, 27, 40, 30, 77, 30, 31, 45, 42, + 27, 35, 44, 67, 79, 77, 77, 77, 77, 77, 77, 77, + 77, 14, 41, 46, 34, 35, 40, 33, 77, 28, 31, 45, + 35, 30, 31, 77, 44, 31, 39, 27, 35, 40, 45, 63, + 77, 18, 41, 47, 40, 30, 77, 46, 34, 31, 77, 30, + 31, 29, 27, 51, 79, 77, 77, 77, 77, 77, 77, 77, + 77, 15, 32, 77, 46, 34, 27, 46, 77, 29, 41, 38, + 41, 45, 45, 27, 38, 77, 23, 44, 31, 29, 37, 64, + 77, 28, 41, 47, 40, 30, 38, 31, 45, 45, 77, 27, + 40, 30, 77, 28, 27, 44, 31, 79, 77, 77, 77, 77, + 77, 77, 77, 77, 20, 34, 31, 77, 38, 41, 40, 31, + 77, 27, 40, 30, 77, 38, 31, 48, 31, 38, 77, 45, + 27, 40, 30, 45, 77, 45, 46, 44, 31, 46, 29, 34, + 77, 32, 27, 44, 77, 27, 49, 27, 51, 79, 77, 77, + 77, 77, 77, 77, 77, 77]]), + mindspore.tensor([[0, 0, 0, 1069, 11, -1, -1, -1, -1]]), + mindspore.tensor([[0, 0, 0, 1069, 11, -1, -1, -1, -1]]), + ] + # fmt: on + self.assertTrue(ops.allclose(tokens[0], EXPECTED_OUTPUT[0])) + self.assertTrue(ops.allclose(tokens[1], EXPECTED_OUTPUT[1])) + self.assertTrue(ops.allclose(tokens[2], EXPECTED_OUTPUT[2]))