From 222fd7d5f13e3369a7e9b2a629000ffffb47a2f0 Mon Sep 17 00:00:00 2001
From: HDCharles
Date: Tue, 7 Nov 2023 00:10:55 -0800
Subject: [PATCH] Adding gpu quantization workflows and apis

Summary:

APIs and workflows used for quantization and pruning in the
segment-anything-fast and gpt-fast repos.

Test Plan:

python /home/cdhernandez/local/ao/ao/quantization/test.py

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
---
 ao/quantization/__init__.py                   |   39 +
 .../__pycache__/dynamic_quant.cpython-310.pyc |  Bin 0 -> 2871 bytes
 .../__pycache__/quant_api.cpython-310.pyc     |  Bin 0 -> 2443 bytes
 .../quant_primitives.cpython-310.pyc          |  Bin 0 -> 6132 bytes
 .../__pycache__/smoothquant.cpython-310.pyc   |  Bin 0 -> 6921 bytes
 .../__pycache__/subclass.cpython-310.pyc      |  Bin 0 -> 3940 bytes
 .../__pycache__/utils.cpython-310.pyc         |  Bin 0 -> 2791 bytes
 .../__pycache__/weight_only.cpython-310.pyc   |  Bin 0 -> 1592 bytes
 ao/quantization/dynamic_quant.py              |   96 ++
 ao/quantization/quant_api.py                  |   74 ++
 ao/quantization/quant_primitives.py           |  384 +++++++
 ao/quantization/smoothquant.py                |  237 ++++
 ao/quantization/subclass.py                   |  117 ++
 ao/quantization/test.py                       | 1024 +++++++++++++++++
 ao/quantization/utils.py                      |   98 ++
 ao/quantization/weight_only.py                |   49 +
 setup.py                                      |   14 +
 17 files changed, 2132 insertions(+)
 create mode 100644 ao/quantization/__init__.py
 create mode 100644 ao/quantization/__pycache__/dynamic_quant.cpython-310.pyc
 create mode 100644 ao/quantization/__pycache__/quant_api.cpython-310.pyc
 create mode 100644 ao/quantization/__pycache__/quant_primitives.cpython-310.pyc
 create mode 100644 ao/quantization/__pycache__/smoothquant.cpython-310.pyc
 create mode 100644 ao/quantization/__pycache__/subclass.cpython-310.pyc
 create mode 100644 ao/quantization/__pycache__/utils.cpython-310.pyc
 create mode 100644 ao/quantization/__pycache__/weight_only.cpython-310.pyc
 create mode 100644 ao/quantization/dynamic_quant.py
 create mode 100644 ao/quantization/quant_api.py
 create mode 100644 ao/quantization/quant_primitives.py
 create mode 100644 ao/quantization/smoothquant.py
 create mode 100644 ao/quantization/subclass.py
 create mode 100644 ao/quantization/test.py
 create mode 100644 ao/quantization/utils.py
 create mode 100644 ao/quantization/weight_only.py
 create mode 100644 setup.py

diff --git a/ao/quantization/__init__.py b/ao/quantization/__init__.py
new file mode 100644
index 0000000000..a18cbaa43a
--- /dev/null
+++ b/ao/quantization/__init__.py
@@ -0,0 +1,39 @@
+from smoothquant import *  # noqa: F403
+from quant_api import *  # noqa: F403
+from subclass import *  # noqa: F403
+from quant_primitives import *  # noqa: F403
+from utils import *  # noqa: F403
+from weight_only import *  # noqa: F403
+
+__all__ = [
+    "DynamicallyPerAxisQuantizedLinear",
+    "replace_with_custom_fn_if_matches_filter",
+    "apply_weight_only_int8_quant",
+    "apply_dynamic_quant",
+    "change_linear_weights_to_dqtensors",
+    "insert_subclass",
+    "safe_int_mm",
+    "dynamically_quantize_per_tensor",
+    "quantize_activation_per_token_absmax",
+    "dynamically_quantize_per_channel",
+    "dequantize_per_tensor",
+    "dequantize_per_channel",
+    "quant_int8_dynamic_linear",
+    "quant_int8_matmul",
+    "quant_int8_dynamic_per_token_linear",
+    "quant_int8_per_token_matmul",
+    "get_scale",
+    "SmoothFakeDynQuantMixin",
+    "SmoothFakeDynamicallyQuantizedLinear",
+    "swap_linear_with_smooth_fq_linear",
+    "smooth_fq_linear_to_inference",
+    "set_smooth_fq_attribute",
+    "DynamicallyQuantizedLinearWeight",
+    "log_with_rank",
+    "clear_logs",
+    "compute_error",
+    "forward_hook",
+    "apply_logging_hook",
+    "get_model_size_in_bytes",
+    "WeightOnlyInt8QuantLinear",
+]

[GIT binary patch hunks for the seven compiled-bytecode files under
ao/quantization/__pycache__/ (dynamic_quant, quant_api, quant_primitives,
smoothquant, subclass, utils and weight_only, all .cpython-310.pyc) are
omitted; their sizes appear in the diffstat above.]
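For orientation, here is a minimal usage sketch of the top-level APIs this patch adds (`apply_dynamic_quant`, `apply_weight_only_int8_quant`, and `change_linear_weights_to_dqtensors` from `quant_api.py`, defined below). The toy model, shapes, and the assumption of a CUDA device are illustrative and not part of the patch; the flat `from quant_api import ...` style matches how `test.py` in this patch imports the modules.

```python
# Illustrative sketch only: quantize a toy model three ways with the APIs
# added in this patch. Assumes a CUDA device and that ao/quantization is on
# the path (the modules use flat imports, as test.py does).
import copy

import torch
import torch.nn as nn

from quant_api import (
    apply_dynamic_quant,
    apply_weight_only_int8_quant,
    change_linear_weights_to_dqtensors,
)

model = nn.Sequential(nn.Linear(1024, 4096), nn.ReLU(), nn.Linear(4096, 1024)).cuda()
x = torch.randn(32, 1024, device="cuda")

# int8 dynamic quantization: per-channel int8 weights, activations quantized
# per token at runtime (swaps nn.Linear for DynamicallyPerAxisQuantizedLinear)
m_dynamic = copy.deepcopy(model)
apply_dynamic_quant(m_dynamic)

# int8 weight-only quantization (swaps nn.Linear for WeightOnlyInt8QuantLinear)
m_weight_only = copy.deepcopy(model)
apply_weight_only_int8_quant(m_weight_only)

# the same dynamic quantization expressed as a weight tensor subclass, which
# leaves the module structure untouched
m_subclass = copy.deepcopy(model)
change_linear_weights_to_dqtensors(m_subclass)

with torch.no_grad():
    y = m_dynamic(x)  # or torch.compile(m_dynamic, mode="max-autotune")(x)
```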
diff --git a/ao/quantization/dynamic_quant.py b/ao/quantization/dynamic_quant.py
new file mode 100644
--- /dev/null
+++ b/ao/quantization/dynamic_quant.py
@@ -0,0 +1,96 @@
+import torch
+import torch.nn as nn
+
+from quant_primitives import (
+    dynamically_quantize_per_channel,
+    quant_int8_dynamic_per_token_linear,
+)
+
+
+class DynamicallyPerAxisQuantizedLinear(torch.nn.Linear):
+    """
+    This class is a replacement for `torch.nn.Linear`. It dynamically quantizes
+    the input across all axes except the last axis and uses a per-channel
+    int8 quantized weight.
+    """
+
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        bias: bool = True,
+        use_fused_int_mm=False,
+    ) -> None:
+        super().__init__(in_features, out_features, bias)
+        self.use_fused_int_mm = use_fused_int_mm
+        # note: enabling use_fused_int_mm = True has best perf when additionally setting
+        # torch._inductor.config.force_fuse_int_mm_with_mul = True
+
+    def forward(self, X: torch.Tensor) -> torch.Tensor:
+        """
+        Performs the forward pass of the quantized linear layer.
+
+        This method applies dynamic quantization to the input tensor across all axes except
+        the last axis using the `quant_int8_dynamic_per_token_linear` function.
+
+        Args:
+            X (torch.Tensor): The input tensor to the quantized linear layer.
+
+        Returns:
+            torch.Tensor: The output tensor after the quantized matmul and rescale.
+
+        """
+        # The following line mimics the behavior of SmoothFakeDynamicallyQuantizedLinear
+        if not self.use_fused_int_mm:
+            X = X / self.fake_rescale
+            # The inductor fusion that occurs for most transformer models when this
+            # module has the additional div op is faster than when it doesn't have it,
+            # although the memory usage is slightly higher. fake_rescale is the scalar 1,
+            # so it doesn't affect accuracy.
+        Y = quant_int8_dynamic_per_token_linear(
+            X, self.W_int_repr_t, self.W_scales, self.bias, X.dtype
+        )
+        return Y
+
+    @classmethod
+    def from_float(
+        cls, mod: torch.nn.Linear, use_fused_int_mm=False
+    ) -> "DynamicallyPerAxisQuantizedLinear":
+        """
+        Converts a `mod` of class `torch.nn.Linear` to the dynamically quantized version of it.
+
+        Note: this class does not require calibration.
+
+        Args:
+            mod (torch.nn.Linear): The original `torch.nn.Linear` module to convert.
+
+        Returns:
+            DynamicallyPerAxisQuantizedLinear: The converted quantized linear module.
+ + """ + + # create the new module with a toy size to ensure initialization is fast + fake_in_features, fake_out_features = 8, 8 + new_mod = cls( + fake_in_features, + fake_out_features, + bias=mod.bias is not None, + use_fused_int_mm=use_fused_int_mm, + ) + new_mod.in_features = mod.in_features + new_mod.out_features = mod.out_features + W_int_repr, W_scales, _W_zps = dynamically_quantize_per_channel( + mod.weight, -128, 127, torch.int8 + ) + new_mod.register_buffer("W_int_repr_t", W_int_repr.contiguous().t()) + new_mod.W_scales = nn.Parameter(W_scales) + new_mod.bias = mod.bias + if not use_fused_int_mm: + new_mod.fake_rescale = torch.tensor( + [1.0], dtype=mod.weight.dtype, device=mod.weight.device + ) + del new_mod.weight + + device_to_use = next(mod.parameters()).device + new_mod.to(device_to_use) + return new_mod diff --git a/ao/quantization/quant_api.py b/ao/quantization/quant_api.py new file mode 100644 index 0000000000..20b09b5cb3 --- /dev/null +++ b/ao/quantization/quant_api.py @@ -0,0 +1,74 @@ +""" +Quantization API stuff which is not specific to SmoothQuant + +Note: this is throwaway code for fast results on Blueberry, this is not +intended to be the actual long term quantization API for server GPUs. +""" + +import torch +from dynamic_quant import ( + DynamicallyPerAxisQuantizedLinear, +) +from subclass import ( + DynamicallyQuantizedLinearWeight, +) +from weight_only import ( + WeightOnlyInt8QuantLinear, +) + +__all__ = [ + "replace_with_custom_fn_if_matches_filter", + "apply_weight_only_int8_quant", + "apply_dynamic_quant", + "change_linear_weights_to_dqtensors", +] + + +def replace_with_custom_fn_if_matches_filter( + model, replacement_fn, filter_fn, cur_fqn="" +) -> None: + """ + For each `child` in `model`, replaces it with `replacement_fn(child)` + if `filter_fn(child)` is `True` + """ + name_to_child = dict(model.named_children()) + for name, child in name_to_child.items(): + if cur_fqn == "": + new_fqn = name + else: + new_fqn = f"{cur_fqn}.{name}" + if filter_fn(child, new_fqn): + new_child = replacement_fn(child) + setattr(model, name, new_child) + else: + replace_with_custom_fn_if_matches_filter( + child, replacement_fn, filter_fn, new_fqn + ) + + +def apply_weight_only_int8_quant(model): + replace_with_custom_fn_if_matches_filter( + model, + WeightOnlyInt8QuantLinear.from_float, + lambda mod, fqn: isinstance(mod, torch.nn.Linear), + ) + + +def apply_dynamic_quant(model, use_fused_int_mm=0): + replace_with_custom_fn_if_matches_filter( + model, + lambda mod: DynamicallyPerAxisQuantizedLinear.from_float(mod, use_fused_int_mm), + lambda mod, fqn: isinstance(mod, torch.nn.Linear), + ) + + +def change_linear_weights_to_dqtensors(model): + def insert_subclass(lin): + lin.weight = torch.nn.Parameter( + DynamicallyQuantizedLinearWeight.from_float(lin.weight), requires_grad=False + ) + return lin + + replace_with_custom_fn_if_matches_filter( + model, insert_subclass, lambda mod, fqn: isinstance(mod, torch.nn.Linear) + ) diff --git a/ao/quantization/quant_primitives.py b/ao/quantization/quant_primitives.py new file mode 100644 index 0000000000..e8a53e3135 --- /dev/null +++ b/ao/quantization/quant_primitives.py @@ -0,0 +1,384 @@ +import torch +from torch._dynamo import is_compiling as dynamo_is_compiling +from torch._higher_order_ops.out_dtype import out_dtype + +__all__ = [ + "safe_int_mm", + "dynamically_quantize_per_tensor", + "quantize_activation_per_token_absmax", + "dynamically_quantize_per_channel", + "dequantize_per_tensor", + "dequantize_per_channel", + 
"quant_int8_dynamic_linear", + "quant_int8_matmul", + "quant_int8_dynamic_per_token_linear", + "quant_int8_per_token_matmul", +] + + +def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor: + r""" + This function wraps torch._int_mm and avoids several undesirable behaviors of the function for certain inputs while still + returning correct results and being torch.compiled in a performant way. + + Assumes both tensors have dimension of 2. + + Note: no error checking for torch.compiled path, if input.shape = [i, j] and j<=16 then the triton kernel + will error. + + Args: + input (Tensor, int8): the first tensor to be multiplied + mat2 (Tensor, int8): the second tensor to be multiplied + + Return: + out (Tensor, int32): the result of the matmul with device matching that of the inputs + """ + + # torch.compile path + if dynamo_is_compiling() or "FakeTensor" in input.__repr__(): + return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2) + + # error checking for cublas path + assert ( + mat2.device == input.device + ), f"need both tensors to be on the same device but got {mat2.device} and {input.device}" + device_cpu = "cpu" in [mat2.device.type, input.device.type] + # with input.shape = [i,j] and mat2.shape = [j,k] + i_is_strictly_greater_than_16 = input.shape[0] > 16 + j_is_nonzero_multiple_of_8 = (input.shape[1] % 8 == 0) and (input.shape[1] > 0) + k_is_nonzero_multiple_of_8 = (mat2.shape[1] % 8 == 0) and (mat2.shape[1] > 0) + bad_dimensions_for_cublas = not ( + i_is_strictly_greater_than_16 + and j_is_nonzero_multiple_of_8 + and k_is_nonzero_multiple_of_8 + ) + + if device_cpu or bad_dimensions_for_cublas: + # fallback path + return torch.matmul(input.cpu().to(torch.int32), mat2.cpu().to(torch.int32)).to( + input.device.type + ) + + # cublas paths + if not mat2.is_contiguous(): # silently gives incorrect result without this + mat2 = mat2.contiguous() + if (not input.is_contiguous()) and ( + input.shape[0] % 8 != 0 + ): # gives cryptic error without this + input = ( + input.contiguous() + ) # (it seems the transpose makes cublas check the above j constraint on i) + return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2) + + +# copy-pasta of https://www.internalfb.com/intern/anp/view/?id=3350736 +def dynamically_quantize_per_tensor( + x, + quant_min, + quant_max, + target_dtype, + qscheme=torch.per_tensor_affine, # for now, reuse existing qscheme enum +): + # assumes affine quantization + + # default setup for affine quantization of activations + eps = torch.finfo(torch.float32).eps + + if qscheme == torch.per_tensor_affine: + # get min and max + # TODO(future): make torch.aminmax work on cpu-half + # min_val, max_val = torch.aminmax(x) + min_val = torch.min(x) + max_val = torch.max(x) + + # calculate scale and zero point based on min and max + # reference: https://fburl.com/code/srbiybme + min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) + max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) + device = min_val_neg.device + + scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min) + # TODO(future): make torch.clamp with scalar work on cpu-half + scale = torch.clamp(scale, min=eps).reshape(1) + zero_point = quant_min - torch.round(min_val_neg / scale).to(torch.int) + zero_point = torch.clamp(zero_point, quant_min, quant_max) + + # quantize based on qmin/qmax/scale/zp + # reference: torch/ao/quantization/fx/_decomposed.py?lines=63 + quant = torch.clamp( + torch.round(x / scale) + zero_point, quant_min, quant_max + 
).to(target_dtype) + + else: + assert qscheme == torch.per_tensor_symmetric, f"unsupported qscheme {qscheme}" + # assert quant_min == -1 * quant_max, "unsupported quant_min/quant_max" + amax = torch.max(torch.abs(x)) + scale = amax / (float(quant_max - quant_min) / 2) + scale = torch.clamp(scale, min=eps).reshape(1) + quant = torch.clamp(torch.round(x / scale), quant_min, quant_max).to( + target_dtype + ) + # do not create a tensor for zero_point as this is expensive + zero_point = None + + return quant, scale, zero_point + + +# taken from +# https://github.com/mit-han-lab/smoothquant/blob/2f87951dacfb9238d8d657f52ae83a82a3c9ba0c/smoothquant/fake_quant.py#L26 +# and slightly modified +def quantize_activation_per_token_absmax(t): + n_bits = 8 + # if the shape of t is [B, N, K], the shape of scales will be [B, N, 1] + + scales = t.abs().amax(dim=-1, keepdim=True) + if scales.dtype == torch.float16: + scales = ( + scales.float() + ) # want float scales to avoid overflows for fp16, (bf16 has wide enough range) + q_max = 2 ** (n_bits - 1) - 1 + scales = scales.clamp(min=1e-5).div(q_max) + # Note: the original smoothquant does not clamp to qmin/qmax here, + # but some of the tests with bfloat16 ended up with a flipped sign + # if we don't clamp. TODO(future) look into this further. + t = torch.round(t / scales).clamp(-127, 127).to(torch.int8) + return t, scales + + +def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype): + # assumes symmetric quantization + # assumes axis == 0 + # assumes dense memory format + # TODO(future): relax ^ as needed + + # default setup for affine quantization of activations + eps = torch.finfo(torch.float32).eps + + # get min and max + min_val, max_val = torch.aminmax(x, dim=1) + + # calculate scale and zero point based on min and max + # reference: https://fburl.com/code/srbiybme + min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) + max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) + device = min_val_neg.device + + # reference: https://fburl.com/code/4wll53rk + max_val_pos = torch.max(-min_val_neg, max_val_pos) + scale = max_val_pos / (float(quant_max - quant_min) / 2) + # ensure scale is the same dtype as the original tensor + scale = torch.clamp(scale, min=eps).to(x.dtype) + zero_point = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device) + + # quantize based on qmin/qmax/scale/zp + # reference: torch/ao/quantization/fx/_decomposed.py?lines=63 + x_div = x.transpose(0, 1) / scale + x_round = torch.round(x_div) + x_zp = x_round + zero_point + x_zp = x_zp.transpose(0, 1) + quant = torch.clamp(x_zp, quant_min, quant_max).to(target_dtype) + + return quant, scale, zero_point + + +# reference: https://fburl.com/code/vfsygwd0 +def dequantize_per_tensor(int_repr, scale, zero_point, out_dtype=torch.float32): + y = int_repr.to(out_dtype) + if zero_point is not None: + y -= zero_point + return y * scale + + +# reference: https://fburl.com/code/org0fmi3 +def dequantize_per_channel(int_repr, scales, zero_points, out_dtype=torch.float32): + # assumes axis is 0 + y = int_repr.transpose(0, 1) + y = y.to(out_dtype) + y = y - zero_points + y = y * scales + y = y.transpose(0, 1) + return y + + +def quant_int8_dynamic_linear( + x, + x_quant_min, + x_quant_max, + x_q_dtype, + w_vals_int8_t, + w_scales, + w_vals_int8_t_sums_int64, + bias, + out_dtype=torch.float32, +): + # like F.linear, but with int8 dynamic quantization of activation, + # and a quantized weight + x_vals_int8, x_scale, x_zp = dynamically_quantize_per_tensor( + 
x, x_quant_min, x_quant_max, x_q_dtype + ) + # w_vals_int8_t_sums_int64 = w_vals_int8_t.sum(dim=0) + mm_out = quant_int8_matmul( + x_vals_int8, + x_scale, + x_zp, + w_vals_int8_t, + w_vals_int8_t_sums_int64, + w_scales, + out_dtype, + ) + if bias is not None: + mm_out += bias + return mm_out + + +def quant_int8_matmul( + x_vals_int8, + x_scale, + x_zp, + w_vals_int8_t, + w_vals_int8_t_sums_int64, + w_scales, + out_dtype=torch.float32, +): + # Quantized matmul of int8 operands that accumulates to int32 and returns + # out_dtype. For now, this is written for approximate numerical + # correctness, and things like aligning accumulation behaviors and + # performance optimizations are left for a future PR. + # Assumes that weight quantization is symmetric, i.e. w_zp is 0. + # Assumes that weight quantization is per-channel. + + # see + # https://github.com/google/gemmlowp/blob/master/doc/quantization.md + # for an overview of quantized matmul compute + + # in scalar form, assuming out_dtype is fp32 and zw == 0: + # + # Y_i_j_fp32 = sx * sw (dot(X_i, W_j) - zx * sum(W_j)) + # + + assert x_vals_int8.dtype in ( + torch.uint8, + torch.int8, + ), f"x dtype {x_vals_int8.dtype} not yet supported" + assert ( + w_vals_int8_t.dtype == torch.int8 + ), f"w dtype {w_vals_int8_t.dtype} not yet supported" + assert w_scales.dtype == out_dtype, f"{w_scales.dtype} does not match {out_dtype}" + + # + # 1. do the matrix form of dot(X_i, W_j) + # + + # TODO(before land): add test case for input with bsz + tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]) + y_dot_int32 = safe_int_mm(tmp, w_vals_int8_t) + y_dot_int32 = y_dot_int32.reshape(*x_vals_int8.shape[:-1], -1) + + # TODO(future): consider using integer arithmetic throughout, although + # TBD if that is actually faster on GPUs + # need to use 32 bits here to prevent overflow for large shapes, + # 16 bits is not enough + y_dot_float32 = y_dot_int32.to(torch.float32) + + # + # 2. connect it all together + # + + # mm_unscaled has to stay in float32 for the next two lines to prevent overflow + mm_unscaled_float32 = y_dot_float32 - (x_zp * w_vals_int8_t_sums_int64) + y = x_scale * w_scales * mm_unscaled_float32 + # can downcast only at the very end + y = y.to(out_dtype) + return y + + +def quant_int8_dynamic_per_token_linear( + x, + w_vals_int8_t, + w_scales, + bias, + out_dtype=torch.float32, + use_fused_int_mm=0, +): + # like F.linear, but with int8 dynamic quantization of activation, + # and a quantized weight + x_vals_int8, x_scales = quantize_activation_per_token_absmax(x) + mm_out = quant_int8_per_token_matmul( + x_vals_int8, x_scales, w_vals_int8_t, w_scales, out_dtype, use_fused_int_mm + ) + if bias is not None: + mm_out += bias + return mm_out + + +def quant_int8_per_token_matmul( + x_vals_int8, + x_scales, + w_vals_int8_t, + w_scales, + output_dtype=torch.float32, + use_fused_int_mm=0, +): + # Quantized matmul of int8 operands that accumulates to int32 and returns + # output_dtype. For now, this is written for approximate numerical + # Assumes that activation and weight quantization are symmetric, + # i.e. act_zp and w_zp is 0. + # Assumes that weight quantization is per-channel. 
+ + # see + # https://github.com/google/gemmlowp/blob/master/doc/quantization.md + # for an overview of quantized matmul compute + + # in scalar form, assuming output_dtype is fp32 and zw == 0: + # + # Y_i_j_fp32 = sx * sw dot(X_i, W_j) + # + + assert ( + x_vals_int8.dtype == torch.int8 + ), f"x dtype {x_vals_int8.dtype} not yet supported" + assert ( + w_vals_int8_t.dtype == torch.int8 + ), f"w dtype {w_vals_int8_t.dtype} not yet supported" + assert ( + w_scales.dtype == output_dtype + ), f"{w_scales.dtype} does not match {output_dtype}" + + # + # 1. do the matrix form of dot(X_i, W_j) + # + + tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]) + # these branches use external triton fused_int_mm kernel's which fuse either 1 or 2 mul operations + if use_fused_int_mm == 2: + y = torch.ops.custom_int_mm.int_mm_dequant( + tmp, w_vals_int8_t, x_scales.view(-1, 1), w_scales, output_dtype + ).reshape(*x_vals_int8.shape[:-1], -1) + return y + elif use_fused_int_mm == 1: + y = torch.ops.custom_int_mm.int_mm_one_mul( + tmp, w_vals_int8_t, x_scales.view(-1, 1), output_dtype + ).reshape(*x_vals_int8.shape[:-1], -1) + y = y * w_scales + return y.to(output_dtype) + y_dot_int32 = safe_int_mm(tmp, w_vals_int8_t) + + # + # 2. rescale the output + # + # in cases with large matrices, y_dot_int32 can grow sufficiently + # large that y_dot_int32 * a float16 scale is greater than the maximum + # value of a float 16, (which results in a value of inf even if multiplying + # by the other scale would bring it within the expected range) + + assert x_scales.dtype in [ + torch.float, + torch.bfloat16, + ], f"x_scales needs to be a torch.float32 or torch.bfloat16 but got {x_scales.dtype}" + y = (y_dot_int32 * x_scales.view(-1, 1) * w_scales).reshape( + *x_vals_int8.shape[:-1], -1 + ) + + # can downcast only at the very end + y = y.to(output_dtype) + return y diff --git a/ao/quantization/smoothquant.py b/ao/quantization/smoothquant.py new file mode 100644 index 0000000000..80ce7893fd --- /dev/null +++ b/ao/quantization/smoothquant.py @@ -0,0 +1,237 @@ +""" +Testing out accuracy-only implementation of SmoothQuant +(https://arxiv.org/pdf/2211.10438.pdf) +Note: this is an application of input-weight equalization, with the addition that the +multiplication by scale is fused into the preceding layer, specifically for relevant +parts of transformer blocks. +""" + +import torch +import torch.nn.functional as F +import quant_api + +from quant_primitives import ( + dynamically_quantize_per_channel, + quant_int8_dynamic_per_token_linear, +) + +__all__ = [ + "get_scale", + "SmoothFakeDynQuantMixin", + "SmoothFakeDynamicallyQuantizedLinear", + "swap_linear_with_smooth_fq_linear", + "smooth_fq_linear_to_inference", + "set_smooth_fq_attribute", +] + +def get_scale(X_absmax, W_absmax, alpha=0.5): + """ + Calculate the scale based on abs(max(X)), abs(max(W)) and alpha + If X is of dimension `b*n*k` and W is dimension `k*m`, the returned + scale is of dimension `k`. + Note: X_absmax is calculated outside of this function because we + need to keep a running version of it during calibration. W_absmax + is calculated outside of this function for consistency with X_absmax. 
+ """ + X_pow = torch.pow(X_absmax, alpha) + W_pow = torch.pow(W_absmax, 1.0 - alpha) + div = X_pow / W_pow + return div.reshape(-1) + + +class SmoothFakeDynQuantMixin(torch.nn.Module): + def init_smoothquant_variables(self, alpha): + self.calibrating = True + self.x_running_abs_max = None + self.register_buffer("smooth_scale", None) + self.alpha = alpha + # debug only + self.debug_skip_scaling = False + # self.debug_skip_scaling = True + + # Currently torch._int_mm cuBLAS underlying kernel does not work with + # non-contiguous weight. However, torch.compil'ing through + # torch._int_mm leads to triton code which is ~2x faster if the weight + # is transposed. So, for now we have a debug flag to toggle whether + # we store the quantized weight transposed, so that we can get correct + # numerics both in eager mode and after torch.compile. + # The default is True for cuBLAS / eager mode, set to False for + # torch.compile. + # self.store_w_int_repr_t = True + self.store_w_int_repr_t = False + + def update_x_running_abs_max(self, X): + # update the running max of incoming activations + all_dims_except_last = tuple(range(len(X.shape) - 1)) + cur_abs_max = torch.amax(torch.abs(X), dim=all_dims_except_last) + if self.x_running_abs_max is None: + self.x_running_abs_max = cur_abs_max + else: + self.x_running_abs_max = torch.max(cur_abs_max, self.x_running_abs_max) + + def get_scaled_quantized_w(self): + # inference + assert ( + self.smooth_scale is not None + ), "self.smooth_scale is None, did you turn on inference?" + W = self.weight + + # scale weight + # in the future, this can be done ahead of time instead of + # during inference + if not self.debug_skip_scaling: + # TODO(future): do below in `to_inference` instead of here + W = torch.matmul( + torch.diag(self.smooth_scale), W.transpose(0, 1) + ).transpose(0, 1) + + # fake quantize input and weight, and then do matmul in fp32/fp16 + # in the future, this should be replaced with quantized kernels which + # work on NVIDIA GPUs (such as protoquant's implementation) + W_dq_dtype = W.dtype + W_int_repr, W_scales, W_zps = dynamically_quantize_per_channel( + W, -128, 127, torch.int8 + ) + W_int_repr = W_int_repr.contiguous() + return W_int_repr, W_scales, W_zps + + def to_inference(self): + raise NotImplementedError() + + def fold_weight(self): + # note: _W_zps are zeroes and they are ignored + # TODO(future PR): set up serialization for this + W_int_repr, self.W_scales, _W_zps = self.get_scaled_quantized_w() + # need to store transposed weights to make eager mode matmul + # op work in cuBlas, or non-transposed to make it fast in torch.compile + if self.store_w_int_repr_t: + self.register_buffer("W_int_repr", W_int_repr.transpose(0, 1).contiguous()) + else: + self.register_buffer("W_int_repr", W_int_repr.contiguous()) + del self.weight + + def set_debug_x_absmax(self): + """ + Sets `self.x_running_abs_max` to a value which will lead to smooth scale + of all ones if `alpha=0.5`, to enable performance benchmarking without + calibration. + """ + raise NotImplementedError() + + +class SmoothFakeDynamicallyQuantizedLinear(SmoothFakeDynQuantMixin, torch.nn.Linear): + """ + This is a replacement for `torch.nn.Linear` which implements fake quantization + based on Smoothquant scaling. 
+ """ + + def __init__(self, *args, **kwargs): + alpha = kwargs.pop("alpha") + super().__init__(*args, **kwargs) + self.init_smoothquant_variables(alpha) + + def forward(self, X): + if self.calibrating: + self.update_x_running_abs_max(X) + Y = F.linear(X, self.weight, self.bias) + else: + if not self.debug_skip_scaling: + # TODO(future): fuse this into previous layer (LayerNorm, + # RMSNorm, etc) where appropriate + X = X / self.smooth_scale + W_int_repr_t = ( + self.W_int_repr if self.store_w_int_repr_t else self.W_int_repr.t() + ) + Y = quant_int8_dynamic_per_token_linear( + X, W_int_repr_t, self.W_scales, self.bias, X.dtype + ) + return Y + + @classmethod + def from_float(cls, mod, alpha=0.5): + """ + Converts a `mod` of class `torch.nn.Linear` to the smooth fake quantized + version of it. Note: requires calibration. + """ + # create the new module with a toy size to ensure initialization is fast + fake_in_features, fake_out_features = 8, 8 + new_mod = cls( + fake_in_features, fake_out_features, bias=mod.bias is not None, alpha=alpha + ) + new_mod.in_features = mod.in_features + new_mod.out_features = mod.out_features + new_mod.weight = mod.weight + new_mod.bias = mod.bias + # TODO: test when creation is on cuda + device_to_use = next(mod.parameters()).device + new_mod.to(device_to_use) + return new_mod + + def to_inference(self): + """ + Calculates the smoothquant scale based on calibration + in preparation for inference + """ + assert self.x_running_abs_max is not None, "no calibration data found" + self.calibrating = False + self.smooth_scale = get_scale( + self.x_running_abs_max, + torch.max(torch.abs(self.weight.transpose(0, 1)), dim=1).values, + alpha=self.alpha, + ) + self.fold_weight() + + def set_debug_x_absmax(self): + w_absmax = torch.max(torch.abs(self.weight.transpose(0, 1)), dim=1).values + self.x_running_abs_max = w_absmax + +# +# utils to use the smooth linear on real models +# + +source_cls_to_target_cls = { + torch.nn.Linear: SmoothFakeDynamicallyQuantizedLinear, +} + + +def swap_linear_with_smooth_fq_linear( + model, skip_fqn_list=None, cur_fqn="", alpha=0.5 +) -> None: + name_to_child = dict(model.named_children()) + for name, child in name_to_child.items(): + if cur_fqn == "": + new_fqn = name + else: + new_fqn = f"{cur_fqn}.{name}" + if ((skip_fqn_list is None) or (new_fqn not in skip_fqn_list)) and isinstance( + child, tuple(source_cls_to_target_cls.keys()) + ): + target_cls = source_cls_to_target_cls[type(child)] + new_child = target_cls.from_float(child, alpha=alpha) + setattr(model, name, new_child) + else: + swap_linear_with_smooth_fq_linear(child, skip_fqn_list, new_fqn, alpha) + + +# code moved, avoid breaking callsites +# TODO clean this up +replace_with_custom_fn_if_matches_filter = ( + quant_api.replace_with_custom_fn_if_matches_filter +) + + +def smooth_fq_linear_to_inference(model, debug_skip_calibration=False) -> None: + for _, mod in model.named_modules(): + if isinstance(mod, tuple(source_cls_to_target_cls.values())): + if debug_skip_calibration: + mod.set_debug_x_absmax() + mod.to_inference() + + +# useful for quickly toggling smoothquant debug settings on all smoothquant +# modules in a model +def set_smooth_fq_attribute(model, attribute_name, new_attribute_val): + for _, mod in model.named_modules(): + if isinstance(mod, tuple(source_cls_to_target_cls.values())): + if hasattr(mod, attribute_name): + setattr(mod, attribute_name, new_attribute_val) diff --git a/ao/quantization/subclass.py b/ao/quantization/subclass.py new file mode 100644 index 
0000000000..48959df922 --- /dev/null +++ b/ao/quantization/subclass.py @@ -0,0 +1,117 @@ +import torch +from quant_primitives import ( + dequantize_per_channel, + dynamically_quantize_per_channel, + quant_int8_dynamic_per_token_linear, +) +from torch.utils._python_dispatch import return_and_correct_aliasing + +__all__ = ["DynamicallyQuantizedLinearWeight"] + + +class DynamicallyQuantizedLinearWeight(torch.Tensor): + @staticmethod + def __new__(cls, input_data, q_scales, transposed=False, **kwargs): + # input data is assumed to be input so that q_axis is the 1th axis + # also assumes input is non contiguous + kwargs["device"] = input_data.device + kwargs["dtype"] = kwargs.get("dtype", torch.int8) + if input_data is not None: + kwargs["dtype"] = input_data.dtype + size = input_data.shape[::-1] if transposed else input_data.shape + kwargs["layout"] = ( + kwargs.get("layout") if kwargs.get("layout", False) else input_data.layout + ) + return torch.Tensor._make_wrapper_subclass(cls, size, **kwargs) # type: ignore[attr-defined] + + def __init__(self, input_data, q_scales, transposed=False): + self.transposed = transposed + self.int_data = input_data + self.q_scales = q_scales + + def __repr__(self): + return f"DynamicallyQuantizedLinearWeight(shape={self.shape}, data={self.dequantize()})" + + def dequantize(self, dtype=None): + out = dequantize_per_channel( + self.int_data.t(), self.q_scales, 0, self.dtype if dtype is None else dtype + ) + return out if self.transposed else out.t() # already transposedd for dequantize + + def int_repr(self): + return self.int_data.t() if self.transposed else self.int_data + + def _detach(self): + return DynamicallyQuantizedLinearWeight( + self.int_data, self.q_scales, transposed=self.transposed + ) + + def _transposed(self): + return DynamicallyQuantizedLinearWeight( + self.int_data, self.q_scales, transposed=(not self.transposed) + ) + + def __tensor_flatten__(self): + return ["int_data", "q_scales"], self.transposed + + @staticmethod + def __tensor_unflatten__(tensor_data, transposed): + int_data, q_scales = tensor_data["int_data"], tensor_data["q_scales"] + return DynamicallyQuantizedLinearWeight( + int_data, q_scales, transposed=transposed + ) + + __torch_function__ = torch._C._disabled_torch_function_impl + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + # two scenarios where we currently fall back to vanilla mm: + # 1 - when tensor is on CPU: we are missing qmm for CPU, but we should have a CPU implementation + # for consistency and to allow people to test + # 2 - we need to define what happens when we're given non-floats - quantizing long to int8 is probs craxy + if ( + func in [torch.ops.aten.mm.default, torch.ops.aten.addmm.default] + and args[0].is_floating_point() + and args[0].is_cuda + ): + if func == torch.ops.aten.addmm.default: + assert ( + args[1].shape[-1] == args[2].shape[0] + ), f"need mat1 shape: {args[1].shape} final dim to match mat2 shape: {args[2].shape} first dim " + mat1, mat2, scales, bias = ( + args[1], + args[2].int_data, + args[2].q_scales, + args[0], + ) + else: + assert ( + args[0].shape[-1] == args[1].shape[0] + ), f"need mat1 shape: {args[0].shape} final dim to match mat2 shape: {args[1].shape} first dim " + mat1, mat2, scales, bias = ( + args[0], + args[1].int_data, + args[1].q_scales, + None, + ) + return quant_int8_dynamic_per_token_linear( + mat1, mat2, scales, bias, mat1.dtype + ) + + if func is torch.ops.aten.detach.default: + return return_and_correct_aliasing(func, args, kwargs, 
args[0]._detach()) + + if func is torch.ops.aten.t.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._transposed() + ) + breakpoint() + return NotImplemented + + @classmethod + def from_float(cls, input_float, qmin=-128, qmax=127, dtype=torch.int8): + w_int_repr, w_scales, _ = dynamically_quantize_per_channel( + input_float, qmin, qmax, dtype + ) + # always store with quantized axis in dim=1 for fast matmul + return cls(w_int_repr.contiguous().t(), w_scales, transposed=True) diff --git a/ao/quantization/test.py b/ao/quantization/test.py new file mode 100644 index 0000000000..53fdffda3f --- /dev/null +++ b/ao/quantization/test.py @@ -0,0 +1,1024 @@ +# mypy: ignore-errors +import copy +import unittest + +import torch +import torch.nn as nn +from torch._inductor.utils import run_and_get_code + +from torch.ao.quantization import MinMaxObserver, QConfigMapping + +from dynamic_quant import ( + DynamicallyPerAxisQuantizedLinear, +) +from quant_api import ( + apply_dynamic_quant, + apply_weight_only_int8_quant, + change_linear_weights_to_dqtensors, +) +from quant_primitives import ( + dequantize_per_channel, + dequantize_per_tensor, + dynamically_quantize_per_channel, + dynamically_quantize_per_tensor, + quant_int8_dynamic_linear, + quant_int8_dynamic_per_token_linear, + quantize_activation_per_token_absmax, + safe_int_mm, +) + +from smoothquant import ( + get_scale, + replace_with_custom_fn_if_matches_filter, + smooth_fq_linear_to_inference, + SmoothFakeDynamicallyQuantizedLinear, + swap_linear_with_smooth_fq_linear, +) +from subclass import ( + DynamicallyQuantizedLinearWeight, +) +from utils import ( + apply_logging_hook, + compute_error, + compute_error as SQNR, + fqn_to_op_to_shape_to_count, + LoggingTensorMode, +) +from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx + +torch.manual_seed(0) + + +class SmoothquantUnitTest(unittest.TestCase): + # first, let's reproduce the graphic from the paper, Figure 4, to ensure + # we are calculating the scales correctly + def test_figure_4(self): + X = torch.FloatTensor([1, -16, 2, 6, -2, 8, -1, -9]).reshape(1, 2, 4) + W = torch.FloatTensor([2, 1, -2, 1, -1, -1, 2, -1, -2, -1, -1, 1]).reshape(4, 3) + X_mul_W = torch.matmul(X, W) + + smoothquant_scale = get_scale( + torch.amax(torch.abs(X), dim=(0, 1)), + torch.amax(torch.abs(W), dim=1), + alpha=0.5, + ) + + # reproduce scaled calculation + X_scaled = X / smoothquant_scale.reshape(1, 1, -1) + W_scaled = torch.matmul(torch.diag(smoothquant_scale), W) + X_scaled_mul_scaled_W = torch.matmul(X_scaled, W_scaled) + assert torch.allclose(X_mul_W, X_scaled_mul_scaled_W), "not close!" + assert X_mul_W.shape == X_scaled_mul_scaled_W.shape + + # next, run the above test on a sample of representative inputs + def test_tensors(self): + x_shape = (1, 5, 7) + w_shape = (7, 9) + for i in range(3): + X = torch.randn(x_shape) * 10 + W = torch.randn(w_shape) + s = get_scale( + torch.amax(torch.abs(X), dim=(0, 1)), + torch.amax(torch.abs(W), dim=1), + alpha=0.5, + ) + + Y = torch.matmul(X, W) + Y_ref = torch.matmul( + X / s.reshape(1, 1, -1), + torch.matmul(torch.diag(s), W), + ) + assert torch.allclose(Y, Y_ref, atol=1e-3, rtol=1e-3), "not close!" 
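Before the linear-level tests that follow, it may help to see the end-to-end flow they exercise. This is an illustrative sketch, not part of the patch: the toy module, calibration data, and alpha value are made up, but the calls (`swap_linear_with_smooth_fq_linear`, calibration forward passes, `smooth_fq_linear_to_inference`) are the API defined in `smoothquant.py` above, and the CPU float32 path shown here is the same one `test_smooth_linear_cpu` runs.

```python
# Illustrative sketch of the SmoothQuant calibrate-then-convert flow that the
# tests below exercise; the toy model, data, and alpha are made up.
import torch
import torch.nn as nn

from smoothquant import (
    smooth_fq_linear_to_inference,
    swap_linear_with_smooth_fq_linear,
)

model = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 64)).eval()

# 1. swap every nn.Linear for SmoothFakeDynamicallyQuantizedLinear
swap_linear_with_smooth_fq_linear(model, alpha=0.5)

# 2. calibrate: forward passes record a running abs-max of the activations
with torch.no_grad():
    for _ in range(8):
        model(torch.randn(4, 16, 64))

# 3. convert: compute the smoothing scale from the calibration stats and fold
#    it into the per-channel int8 quantized weight
smooth_fq_linear_to_inference(model)

# 4. inference now runs the smoothed, dynamically quantized path
with torch.no_grad():
    y = model(torch.randn(4, 16, 64))
```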
+ + def _test_smooth_linear_impl(self, x_shape, lin_shape, device): + # so we can use the full range + torch.backends.quantized.engine = "qnnpack" + + x = torch.randn(*x_shape, device=device) * 9 + 10 + + lin_fp32 = nn.Linear(*lin_shape, device=device) # misc: ignore + lin_smooth = SmoothFakeDynamicallyQuantizedLinear.from_float( + copy.deepcopy(lin_fp32), alpha=0.25 + ) + lin_smooth_skip_scaling = SmoothFakeDynamicallyQuantizedLinear.from_float( + copy.deepcopy(lin_fp32), alpha=0.25 + ) + + lin_fp32_copy = copy.deepcopy(lin_fp32) # assignment: ignore + lin_fp32_copy.qconfig = torch.ao.quantization.QConfig( # assignment: ignore + activation=None, + weight=torch.ao.quantization.default_per_channel_weight_observer, + ) + lin_dynamic_q = torch.ao.nn.quantized.dynamic.Linear.from_float( + lin_fp32_copy.cpu() + ) + + y_ref = lin_fp32(x) + + # calibrate the smoothquant versions + y_smooth_nocalib = lin_smooth(x) + _ = lin_smooth_skip_scaling(x) + lin_smooth.to_inference() + lin_smooth_skip_scaling.debug_skip_scaling = True + lin_smooth_skip_scaling.to_inference() + + # verify that with scaling turned off, numerics match quantized version + y_smooth_fq_only = lin_smooth_skip_scaling(x) + y_smooth_fq = lin_smooth(x) + y_dynamic_q = lin_dynamic_q(x.cpu()).to(device) + + # print('y_ref', y_ref) + # print('y_smooth_nocalib', y_smooth_nocalib) + # print('y_smooth_fq', y_smooth_fq) + # print('y_smooth_fq_only', y_smooth_fq_only) + # print('y_dynamic_q', y_dynamic_q) + + sqnr_smooth_fq = compute_error(y_ref, y_smooth_fq) + sqnr_dynamic_q = compute_error(y_ref, y_dynamic_q) + sqnr_fq = compute_error(y_smooth_fq_only, y_dynamic_q) + # print('sqnr_smooth', sqnr_smooth_fq, 'sqnr_dynamic', sqnr_dynamic_q, 'sqnr_fq', sqnr_fq) + + assert torch.allclose( + y_ref, y_smooth_nocalib + ), "y_ref not close to y_smooth_nocalib" + # after https://github.com/pytorch-labs/ao_benchmarks/pull/32, + # numerics do not match exactly between production c++ code + # and this Python code + # assert torch.allclose( + # y_smooth_fq_only, y_dynamic_q, + # atol=torch.max(y_smooth_fq_only).item()*0.01, + # rtol=0.00001), \ + # 'y_smooth_fq_only not close to y_dynamic_q' + + self.assertTrue(sqnr_smooth_fq.item() >= 40.0) + self.assertTrue(sqnr_dynamic_q.item() >= 40.0) + self.assertTrue(sqnr_fq.item() >= 40.0) + + def test_smooth_linear_cpu(self): + self._test_smooth_linear_impl((1, 5, 3), (3, 4), "cpu") + + def test_smooth_linear_cuda(self): + if not torch.cuda.is_available(): + print("no cuda, skip") + return + self._test_smooth_linear_impl((1, 32, 32), (32, 16), "cuda") + + def test_smooth_linear_edge_cases(self): + # so we can use the full range + torch.backends.quantized.engine = "qnnpack" + lin_fp32 = nn.Linear(3, 4) + lin_smooth = SmoothFakeDynamicallyQuantizedLinear.from_float( + lin_fp32, alpha=0.25 + ) + + # test different ranks + x0 = torch.randn(4, 5, 3) + x1 = torch.randn(1, 8, 5, 3) + x2 = torch.randn(2, 3, 7, 5, 3) + + # calibrate + _ = lin_smooth(x0) + _ = lin_smooth(x1) + _ = lin_smooth(x2) + + # inference + lin_smooth.to_inference() + _ = lin_smooth(x0) + _ = lin_smooth(x1) + _ = lin_smooth(x2) + + def test_swap(self): + m = nn.Sequential( + nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 4)), + nn.Linear(4, 4), + ) + m_copy = copy.deepcopy(m) + swap_linear_with_smooth_fq_linear(m_copy, skip_fqn_list=["0.2"]) + + # verify all linears are swapped + assert isinstance(m_copy[0][0], SmoothFakeDynamicallyQuantizedLinear) + assert isinstance(m_copy[0][1], nn.ReLU) + # this one was skipped + assert 
isinstance(m_copy[0][2], nn.Linear) + assert isinstance(m_copy[1], SmoothFakeDynamicallyQuantizedLinear) + + # verify results do not change without smoothing + x = torch.randn(4, 4) + y_ref = m(x) + y = m_copy(x) + assert torch.allclose(y_ref, y) + + def test_weight_t_and_non_t_numerics_match(self): + # verify that numerics match whether weight is stored + # in transposed format (for cuBLAS) vs non-transposed format + # (for torch.compile) + if not torch.cuda.is_available(): + print("no cuda, skip") + return + dtype = torch.half + device = "cuda" + lin_ref = nn.Linear(32, 16, dtype=dtype, device=device) + lin_eager_t = copy.deepcopy(lin_ref) + lin_opt_t = copy.deepcopy(lin_eager_t) + lin_opt = copy.deepcopy(lin_eager_t) + lin_eager_t = SmoothFakeDynamicallyQuantizedLinear.from_float(lin_eager_t) + lin_opt_t = SmoothFakeDynamicallyQuantizedLinear.from_float(lin_opt_t) + lin_opt = SmoothFakeDynamicallyQuantizedLinear.from_float(lin_opt) + lin_opt.store_w_int_repr_t = False + + x = torch.randn(32, 32, dtype=dtype, device=device) + + y_calib_eager_t = lin_eager_t(x) + y_calib_opt_t = lin_opt_t(x) + y_calib_opt = lin_opt(x) + torch.testing.assert_close(y_calib_eager_t, y_calib_opt_t) + torch.testing.assert_close(y_calib_eager_t, y_calib_opt) + + lin_eager_t.to_inference() + lin_opt_t.to_inference() + lin_opt.to_inference() + + torch.testing.assert_close(lin_eager_t.W_int_repr, lin_opt_t.W_int_repr) + torch.testing.assert_close(lin_eager_t.W_int_repr, lin_opt.W_int_repr) + + lin_opt_t = torch.compile(lin_opt_t, mode="max-autotune") + lin_opt = torch.compile(lin_opt, mode="max-autotune") + + y_ref = lin_ref(x) + y_eager = lin_eager_t(x) + y_opt_t = lin_opt_t(x) + y_opt = lin_opt(x) + + if not torch.any(torch.isinf(y_ref)) and torch.any(torch.isinf(y_eager)): + # eager mode torch._int_mm is sometimes buggy, when this happens + # we can't really compare the compiled version against it properly + print("eager mode torch._int_mm known bad, test is inconclusive") + return + + sqnr_ref_eager = compute_error(y_ref, y_eager) + sqnr_eager_opt_t = compute_error(y_eager, y_opt_t) + sqnr_eager_opt = compute_error(y_eager, y_opt) + # since torch.compile for a torch.half model can + # change numerics significantly, we can only test for a high SQNR here + # and not for closeness + self.assertTrue(sqnr_eager_opt_t >= 45.0) + self.assertTrue(sqnr_eager_opt >= 45.0) + # y_opt_t and y_opt should be equivalent + torch.testing.assert_close(y_opt_t, y_opt) + + def test_selective_torch_compile(self): + m = nn.Sequential( + nn.Linear(4, 4), + nn.Sequential( + nn.Linear(4, 4), + nn.Linear(4, 4), + ), + nn.Linear(4, 4), + ) + x = torch.randn(4, 4) + y_ref = m(x) + + replace_with_custom_fn_if_matches_filter( + m, + lambda mod: torch.compile(mod), + lambda mod, fqn: isinstance(mod, nn.Linear) and fqn != "1.0", + ) + + self.assertTrue(isinstance(m[0], torch._dynamo.eval_frame.OptimizedModule)) + self.assertTrue(isinstance(m[1][0], nn.Linear)) + self.assertTrue(isinstance(m[1][1], torch._dynamo.eval_frame.OptimizedModule)) + self.assertTrue(isinstance(m[2], torch._dynamo.eval_frame.OptimizedModule)) + + y = m(x) + torch.testing.assert_close(y, y_ref) + + def test_debug_x_absmax(self): + m = nn.Sequential(nn.Linear(3, 4)) + x0 = torch.randn(4, 5, 3) + y0 = m(x0) + swap_linear_with_smooth_fq_linear(m) + # no calibration, straight to inference, should not crash + smooth_fq_linear_to_inference(m, debug_skip_calibration=True) + y1 = m(x0) + + +class PythonQuantPrimitivesUnitTest(unittest.TestCase): + def 
_test_dynamic_quant_per_tensor_numerics_impl( + self, qmin, qmax, int_dtype, qint_dtype, float_dtype, device, qscheme + ): + x = torch.randn(256, dtype=float_dtype, device=device) + y_vals, y_scale, y_zero_point = dynamically_quantize_per_tensor( + x, qmin, qmax, int_dtype, qscheme + ) + + # reference + # quantize_per_tensor_dynamic doesn't work for half, so we cast there and back + x_for_ref = x.half().float() if float_dtype == torch.float16 else x + + # quantize_per_tensor_dynamic doesn't support qscheme, so we just do dynamic + # quant manually with observers + static quant + obs = MinMaxObserver( + dtype=qint_dtype, qscheme=qscheme, quant_min=qmin, quant_max=qmax + ).to(device) + obs(x_for_ref) + ref_scale, ref_zero_point = obs.calculate_qparams() + y_ref = torch.quantize_per_tensor( + x_for_ref, ref_scale, ref_zero_point, qint_dtype + ) + + # y_ref = torch.quantize_per_tensor_dynamic(x_for_ref, qint_dtype, False) + # print(y_ref) + if float_dtype == torch.float: + assert torch.equal(y_vals, y_ref.int_repr()) + else: + # numerics are not exactly aligned yet, off-by-one probably due + # to rounding + assert torch.max(torch.abs(y_vals - y_ref.int_repr())).item() <= 1 + torch.testing.assert_close( + y_scale, torch.tensor([y_ref.q_scale()], device=device, dtype=float_dtype) + ) + if y_zero_point is not None: + assert torch.equal( + y_zero_point, torch.tensor([y_ref.q_zero_point()], device=device) + ) + else: + self.assertTrue(y_ref.q_zero_point() == 0) + + # dequantize and check again + x_dq = dequantize_per_tensor(y_vals, y_scale, y_zero_point, float_dtype) + y_ref_dq = y_ref.dequantize().to(float_dtype) + if float_dtype == torch.float: + torch.testing.assert_close(x_dq, y_ref_dq) + else: + sqnr = compute_error(x_dq, y_ref_dq) + self.assertTrue(sqnr.item() > 45.0) + + def test_dynamic_quant_per_tensor_numerics_cpu(self): + # verifies that dynamic quant per tensor in plain pytorch matches + # numerics of production AO code + # TODO(future): test this on cpu-half, need to first make + # torch.aminmax support half on cpu + test_cases = ( + ( + 0, + 255, + torch.uint8, + torch.quint8, + torch.float32, + "cpu", + torch.per_tensor_affine, + ), + ( + -128, + 127, + torch.int8, + torch.qint8, + torch.float32, + "cpu", + torch.per_tensor_affine, + ), + ( + -128, + 127, + torch.int8, + torch.qint8, + torch.float32, + "cpu", + torch.per_tensor_symmetric, + ), + ( + -127, + 127, + torch.int8, + torch.qint8, + torch.float32, + "cpu", + torch.per_tensor_symmetric, + ), + ) + for row in test_cases: + self._test_dynamic_quant_per_tensor_numerics_impl(*row) + + def test_dynamic_quant_per_tensor_numerics_cuda(self): + # verifies that dynamic quant per tensor in plain pytorch matches + # numerics of production AO code + if not torch.cuda.is_available(): + print("no cuda, skip") + return + test_cases = ( + ( + -128, + 127, + torch.int8, + torch.qint8, + torch.float32, + "cuda", + torch.per_tensor_affine, + ), + ( + -128, + 127, + torch.int8, + torch.qint8, + torch.float16, + "cuda", + torch.per_tensor_affine, + ), + ( + -128, + 127, + torch.int8, + torch.qint8, + torch.float32, + "cuda", + torch.per_tensor_symmetric, + ), + ( + -128, + 127, + torch.int8, + torch.qint8, + torch.float16, + "cuda", + torch.per_tensor_symmetric, + ), + ( + -127, + 127, + torch.int8, + torch.qint8, + torch.float32, + "cuda", + torch.per_tensor_symmetric, + ), + ( + -127, + 127, + torch.int8, + torch.qint8, + torch.float16, + "cuda", + torch.per_tensor_symmetric, + ), + ) + for row in test_cases: + 
self._test_dynamic_quant_per_tensor_numerics_impl(*row) + + def _test_dynamic_quant_per_channel_numerics_impl( + self, qmin, qmax, int_dtype, qint_dtype, float_dtype, device + ): + # verifies that dynamic quant per channel in plain pytorch matches + # numerics of production AO code + # TODO(future): test this on cpu-half, need to first make + # torch.aminmax support half on cpu + + x = torch.randn(16, 32, device=device, dtype=float_dtype) + y_vals, y_scale, y_zero_point = dynamically_quantize_per_channel( + x, qmin, qmax, int_dtype + ) + + min_val, max_val = torch.aminmax(x, dim=1) + + # reference + weight_obs = torch.ao.quantization.MovingAveragePerChannelMinMaxObserver( + dtype=qint_dtype, + quant_min=qmin, + quant_max=qmax, + qscheme=torch.per_channel_symmetric, + averaging_constant=1.0, # make it ignore previous iterations + ) + weight_obs(x) + y_ref_scale, y_ref_zp = weight_obs.calculate_qparams() + y_ref_scale = y_ref_scale.to(device) + y_ref_zp = y_ref_zp.to(device) + # quantize_per_channel doesn't work for half, so we cast there and back + x_for_ref = x.half().float() if float_dtype == torch.float16 else x + y_ref = torch.quantize_per_channel( + x_for_ref, y_ref_scale, y_ref_zp, 0, qint_dtype + ) + + torch.testing.assert_close( + y_scale, y_ref.q_per_channel_scales().to(float_dtype) + ) + assert torch.equal(y_zero_point, y_ref.q_per_channel_zero_points()) + # this test case has one element where the rounding is off by one + # from Python-only code vs the c++ code, it's easy to repro with + # various shapes. + # Discussion here is relevant: https://github.com/pytorch/pytorch/issues/16498 + # TODO(future): figure out what to do about this + # assert torch.equal(int_vals, q_reference.int_repr()) + assert torch.max(torch.abs(y_vals - y_ref.int_repr())) <= 1 + + # dequantize + x_dq = dequantize_per_channel(y_vals, y_scale, y_zero_point) + x_ref_dq = y_ref.dequantize() + # off-by-one for scale is okay + torch.testing.assert_close( + x_dq, x_ref_dq, atol=torch.max(y_scale).item() * 1.01, rtol=0.0001 + ) + + def test_dynamic_quant_per_channel_numerics_cpu(self): + test_cases = ((-128, 127, torch.int8, torch.qint8, torch.float32, "cpu"),) + for row in test_cases: + self._test_dynamic_quant_per_channel_numerics_impl(*row) + + def test_dynamic_quant_per_channel_numerics_cuda(self): + if not torch.cuda.is_available(): + print("no cuda, skip") + return + test_cases = ( + (-128, 127, torch.int8, torch.qint8, torch.float32, "cuda"), + (-128, 127, torch.int8, torch.qint8, torch.float16, "cuda"), + ) + for row in test_cases: + self._test_dynamic_quant_per_channel_numerics_impl(*row) + + def _test_quantize_per_token_impl(self, device, dtype): + x = torch.randn(3, 3, 3, device=device, dtype=dtype) + xq, scales = quantize_activation_per_token_absmax(x) + x_dq = dequantize_per_tensor(xq, scales, None).to(x.dtype) + sqnr = compute_error(x, x_dq) + self.assertTrue(sqnr >= 45.0) + + def test_quantize_per_token_cpu(self): + for dtype in (torch.float32, torch.float16, torch.bfloat16): + self._test_quantize_per_token_impl("cpu", dtype) + + def test_quantize_per_token_cuda(self): + if not torch.cuda.is_available(): + print("no cuda, skip") + return + for dtype in (torch.float32, torch.float16, torch.bfloat16): + self._test_quantize_per_token_impl("cuda", dtype) + + def _test_per_token_linear_impl(self, device, dtype): + x = torch.randn(2, 16, 8, device=device, dtype=dtype) + w = torch.randn(16, 8, device=device, dtype=dtype) + wq, w_scales, _w_zp = dynamically_quantize_per_channel(w, -127, 127, 
torch.int8) + # Note: need to make the weight contiguous because we are + # testing in eager mode and cuBlas will not give correct results + # for a transposed weight + y = quant_int8_dynamic_per_token_linear( + x, wq.t().contiguous(), w_scales, None, dtype + ) + y_ref = torch.matmul(x, w.t()) + sqnr = compute_error(y_ref, y) + self.assertTrue(sqnr >= 42.0) + + def test_per_token_linear_cpu(self): + for dtype in (torch.float32,): + self._test_per_token_linear_impl("cpu", dtype) + + def test_per_token_linear_cuda(self): + if not torch.cuda.is_available(): + print("no cuda, skip") + return + for dtype in (torch.float32, torch.float16, torch.bfloat16): + self._test_per_token_linear_impl("cuda", dtype) + + def test__int_mm(self): + # TODO(future): figure out what here needs to move to PT core, + # if it's not already tested there + if not torch.cuda.is_available(): + print("no cuda, skip") + return + + m, k, n = 32, 32, 16 + x = torch.randint(-128, 127, (m, k), dtype=torch.int8, device="cuda") + w = torch.randint(-128, 127, (k, n), dtype=torch.int8, device="cuda") + + y_ref = torch.matmul(x.float(), w.float()).to(torch.int32) + y_raw = safe_int_mm(x, w) + + wrap_in_mm_opt = torch.compile(safe_int_mm, mode="max-autotune") + # note: triton chokes on the line below on k == 8 and n == 8 with + # https://www.internalfb.com/phabricator/paste/view/P683467944 + # TODO(future): file an issue + y_opt = wrap_in_mm_opt(x, w) + + torch.testing.assert_close(y_ref, y_raw, atol=0, rtol=0) + torch.testing.assert_close(y_ref, y_opt, atol=0, rtol=0) + + def test__int_mm_eager_and_torch_compile_numerics(self): + if not torch.cuda.is_available(): + print("no cuda, skip") + return + + def __int_mm_ref(x, w): + x = x.cpu().to(torch.int32) + w = w.cpu().to(torch.int32) + y = torch.matmul(x, w) + return y.cuda() + + shapes = ( + # minimal test shape + ((1, 32, 32), (32, 16)), + # paste of real linear shapes from LLaMa 1.5b + ((17, 1, 1536), (1536, 1536)), + ((17, 8, 4096), (4096, 1536)), + ((17, 1, 1536), (1536, 4096)), + ((17, 8, 1536), (1536, 1536)), + ((17, 1, 4096), (4096, 1536)), + ((17, 8, 1536), (1536, 4096)), + ) + + for x_shape, w_shape in shapes: + + def wrap_torch_int_mm(x, w): + b, n, k = x.shape + k, m = w.shape + x = x.reshape(b * n, k) + res = safe_int_mm(x, w) + res = res.reshape(b, n, m) + return res + + wrap_torch_int_mm_opt = torch.compile( + wrap_torch_int_mm, mode="max-autotune" + ) + + x = torch.randint(-128, 127, x_shape, dtype=torch.int8, device="cuda") + w = torch.randint(-128, 127, w_shape, dtype=torch.int8, device="cuda") + + z_ref = __int_mm_ref(x, w) + z_eager = wrap_torch_int_mm(x, w) + z_torch_compile = wrap_torch_int_mm_opt(x, w) + # print(z_ref) + # print(z_eager) + # print(z_torch_compile) + + torch.testing.assert_close(z_ref, z_eager, atol=0, rtol=0) + torch.testing.assert_close(z_ref, z_torch_compile, atol=0, rtol=0) + + def _test_qlinear_per_channel_numerics( + self, x_shape, lin_shape, qmin, qmax, int_dtype, qint_dtype, float_dtype, device + ): + qconfig = torch.ao.quantization.per_channel_dynamic_qconfig + + x = torch.randn(*x_shape, device=device, dtype=float_dtype) + + # TODO: test bias true and false + # Note: reference path only works on float because lack of aten quant primitives + # support of half, so we cast back and forth to emulate + lin_ref = ( + nn.Sequential(nn.Linear(*lin_shape)) + .eval() + .to(float_dtype) + .float() + .to(device) + ) + y_ref = lin_ref(x.float()) + weight = lin_ref[0].weight + bias = lin_ref[0].bias + + qconfig_mapping = 
QConfigMapping().set_global(qconfig) + lin_ref_p = prepare_fx(lin_ref, qconfig_mapping, (torch.randn(1, 1),)) + lin_ref_q = convert_to_reference_fx(lin_ref_p) + y_q_ref = lin_ref_q(x.float()) + + # scale, zp of weight (get from reference model) + w_obs = qconfig.weight() + w_obs(weight) + lin_ref_w_scale, lin_ref_w_zp = w_obs.calculate_qparams() + lin_ref_w_scale = lin_ref_w_scale.to(device).to(float_dtype) + # print('lin_ref_w', 'scale', lin_ref_w_scale, 'zp', lin_ref_w_zp) + + w_vals, _s, _z = dynamically_quantize_per_channel( + getattr(lin_ref_q, "0").weight.to(float_dtype), -128, 127, torch.int8 + ) + w_vals = w_vals.t().contiguous() + w_vals_sums = w_vals.sum(dim=0) + + # do our version of the quantized linear operator + y = quant_int8_dynamic_linear( + x, + qmin, + qmax, + int_dtype, + w_vals, + lin_ref_w_scale, + w_vals_sums, + bias, + float_dtype, + ) + + # print('y', y) + # print('y_q_ref', y_q_ref) + # print('y_ref', y_ref) + + sqnr_ref = compute_error(y_ref, y_q_ref) + sqnr_our = compute_error(y_ref, y) + # print('sqnr_ref', sqnr_ref, 'sqnr_our', sqnr_our) + # for large shapes, sqnr can be in the high 30s for float32 and float16 + self.assertTrue(sqnr_our.item() >= 37.5) + + def test_qlinear_per_channel_numerics_cpu(self): + # Note: the AO codebase doesn't easily support qint8 activations, + # so the test cases below are for the quant primitives defined in + # this file only. The AO reference is using quint8 here. + test_cases = ( + ((2, 3), (3, 4), 0, 255, torch.uint8, torch.quint8, torch.float32, "cpu"), + ((2, 3), (3, 4), -128, 127, torch.int8, torch.qint8, torch.float32, "cpu"), + ) + for test_case in test_cases: + self._test_qlinear_per_channel_numerics(*test_case) + + def test_qlinear_per_channel_numerics_cuda(self): + if not torch.cuda.is_available(): + print("no cuda, skip") + return + test_cases = ( + # Note: torch._int_mm needs int8 activations, so we don't test uint8 + # activations on CUDA at all + ( + (32, 32), + (32, 16), + -128, + 127, + torch.int8, + torch.qint8, + torch.float32, + "cuda", + ), + ( + (32, 32), + (32, 16), + -128, + 127, + torch.int8, + torch.qint8, + torch.float16, + "cuda", + ), + # a large shape from LLaMa 1.5B - currently fails for float16 + ( + (17, 4096), + (4096, 1536), + -128, + 127, + torch.int8, + torch.qint8, + torch.float32, + "cuda", + ), + ( + (17, 4096), + (4096, 1536), + -128, + 127, + torch.int8, + torch.qint8, + torch.float16, + "cuda", + ), + ) + for test_case in test_cases: + self._test_qlinear_per_channel_numerics(*test_case) + + +class TestSubclass(unittest.TestCase): + def test_dq_lin_weight_subclass_aot(self): + m, k, n = 32, 64, 32 + x = torch.randn(m, k, device="cuda", dtype=torch.float32) + lin = torch.nn.Linear(k, n, device="cuda") + + import copy + + linq = DynamicallyPerAxisQuantizedLinear.from_float(copy.deepcopy(lin)) + + ref_f = lin(x) + ref_q = linq(x) + + print(SQNR(ref_f, ref_q), "float to dq") + + lin.weight = torch.nn.Parameter( + DynamicallyQuantizedLinearWeight.from_float(lin.weight), requires_grad=False + ) + test = lin(x) + print(SQNR(ref_f, test), "float to dq class") + print(SQNR(ref_q, test), "dq to dq class") + assert SQNR(ref_f, test) > 35 + assert SQNR(ref_q, test) > 35 + + lin_comp = torch.compile(lin, backend="aot_eager") + linq_comp = torch.compile(linq, backend="aot_eager") + test_comp = lin_comp(x) + ref_q_comp = linq_comp(x) + print(SQNR(ref_f, test_comp), "float to dq class compiled") + print(SQNR(ref_q_comp, test_comp), "dq compiled to dq class compiled") + assert SQNR(ref_f, test_comp) > 35 
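+        # compiled weight-subclass output should also closely match the
+        # compiled DynamicallyPerAxisQuantizedLinear reference output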
+ assert SQNR(ref_q_comp, test_comp) > 35 + + def test_dq_lin_weight_subclass_max_autotune(self): + m, k, n = 32, 64, 32 + x = torch.randn(m, k, device="cuda", dtype=torch.float32) + lin = torch.nn.Linear(k, n, device="cuda") + + import copy + + linq = DynamicallyPerAxisQuantizedLinear.from_float(copy.deepcopy(lin)) + + ref_f = lin(x) + ref_q = linq(x) + + print(SQNR(ref_f, ref_q), "float to dq") + + lin.weight = torch.nn.Parameter( + DynamicallyQuantizedLinearWeight.from_float(lin.weight), requires_grad=False + ) + test = lin(x) + print(SQNR(ref_f, test), "float to dq class") + print(SQNR(ref_q, test), "dq to dq class") + assert SQNR(ref_f, test) > 35 + assert SQNR(ref_q, test) > 35 + + lin_comp = torch.compile(lin, mode="max-autotune") + linq_comp = torch.compile(linq, mode="max-autotune") + + test_comp = lin_comp(x) + ref_q_comp = linq_comp(x) + print(SQNR(ref_f, test_comp), "float to dq class compiled") + print(SQNR(ref_q_comp, test_comp), "dq compiled to dq class compiled") + assert SQNR(ref_f, test_comp) > 35 + assert SQNR(ref_q_comp, test_comp) > 35 + + @torch.no_grad() + def test_dq_lin_weight_subclass_max_autotune_api(self): + m, k, n = 32, 64, 32 + x = torch.randn(m, k, device="cuda", dtype=torch.float32) + + mod = nn.Sequential( + nn.Linear(k, n, device="cuda"), nn.ReLU(), nn.Linear(n, n, device="cuda") + ) + change_linear_weights_to_dqtensors(mod) + mod_qc = torch.compile(mod, mode="max-autotune") + mod_qc(x) + mod_qc(x) + + +class TestDynamicQuant(unittest.TestCase): + def test_dynamic_quant(self): + M, K, N = 8, 16, 8 + x = torch.randn(M, K) + m = nn.Sequential(nn.Linear(K, N)) + + y_ref = m(x) + apply_dynamic_quant(m) + y_test = m(x) + + sqnr = compute_error(y_ref, y_test) + self.assertGreater(sqnr, 40.0) + self.assertTrue(isinstance(m[0], DynamicallyPerAxisQuantizedLinear)) + + +class TestWeightOnlyInt8Quant(unittest.TestCase): + def test_weight_only_quant(self): + for x_shape in [[2, 4], [5, 5, 5, 4], [1, 4, 4]]: + x = torch.randn(*x_shape) + m = nn.Sequential(nn.Linear(4, 5)) + y_ref = m(x) + apply_weight_only_int8_quant(m) + y_wo = m(x) + sqnr = compute_error(y_ref, y_wo) + self.assertGreater(sqnr, 44.0) + + @torch.no_grad() + def test_weight_only_quant_force_mixed_mm(self): + torch._inductor.config.epilogue_fusion = True + torch._inductor.config.force_mixed_mm = True + for x_dtype in [torch.float16, torch.bfloat16, torch.float32]: + for x_shape in [[2, 4], [5, 5, 5, 4], [1, 4, 4]]: + torch._dynamo.reset() + x = torch.randn(*x_shape).to("cuda").to(x_dtype) + m = nn.Sequential(nn.Linear(4, 5)).to("cuda").to(x_dtype) + y_ref = m(x) + apply_weight_only_int8_quant(m) + m(x) + m_c = torch.compile(m, mode="max-autotune") + y_wo, (code,) = run_and_get_code(m_c, x) + sqnr = compute_error(y_ref, y_wo) + self.assertGreater(sqnr, 43.0) + self.assertTrue("mixed_mm" in code) + + def test_weight_only_quant_use_mixed_mm(self): + torch._inductor.config.epilogue_fusion = False + torch._inductor.config.use_mixed_mm = True + for x_dtype in [torch.float32, torch.float16, torch.bfloat16]: + for x_shape in [[2, 4], [5, 5, 5, 4], [1, 4, 4]]: + torch._dynamo.reset() + x = torch.randn(*x_shape).to("cuda").to(x_dtype) + m = nn.Sequential(nn.Linear(4, 5)).to("cuda").to(x_dtype) + y_ref = m(x) + apply_weight_only_int8_quant(m) + m_c = torch.compile(m, mode="max-autotune") + y_wo, (code,) = run_and_get_code(m_c, x) + sqnr = compute_error(y_ref, y_wo) + self.assertGreater(sqnr, 43.0) + + +class TorchCompileUnitTest(unittest.TestCase): + def test_fullgraph(self): + if not torch.cuda.is_available(): + 
print("no cuda, skip") + return + lin_fp16 = nn.Linear(32, 16, device="cuda", dtype=torch.float16) + lin_smooth = SmoothFakeDynamicallyQuantizedLinear.from_float( + lin_fp16, alpha=0.25 + ) + + x0 = torch.randn(17, 1, 32, device="cuda", dtype=torch.float16) + + # calibrate + _ = lin_smooth(x0) + + # inference + lin_smooth.to_inference() + + # torch.compile + lin_smooth_opt = torch.compile(lin_smooth, fullgraph=True) + # print(lin_smooth_opt) + + y = lin_smooth_opt(x0) + # print(y) + + +class UtilsUnitTest(unittest.TestCase): + def test_shape_logger(self): + x = torch.randn(4, 4) + + m = nn.Sequential( + nn.Linear(4, 4), + nn.Sequential( + nn.Linear(4, 4), + ), + ) + + apply_logging_hook(m) + with LoggingTensorMode(): + m(x) + m(x) + + for fqn, d1 in fqn_to_op_to_shape_to_count.items(): # noqa: PERF102 + for op, d2 in d1.items(): # noqa: PERF102 + for shape, count in d2.items(): # noqa: PERF102 + # print(fqn, op, shape, count) + pass + + +class SmoothquantIntegrationTest(unittest.TestCase): + @torch.inference_mode() + def test_on_dummy_distilbert(self): + # https://huggingface.co/distilbert-base-uncased#how-to-use + from transformers import ( # type: ignore[import-untyped] + DistilBertModel, + DistilBertTokenizer, + ) + + tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased") + model = DistilBertModel.from_pretrained("distilbert-base-uncased") + # print(model) + text = "Replace me by any text you'd like." + encoded_input = tokenizer(text, return_tensors="pt") + output_ref = model(**encoded_input) + # print(output_ref) + + # + # smooth_quant + # + model_copy = copy.deepcopy(model) + swap_linear_with_smooth_fq_linear(model_copy, alpha=0.75) + # calibrate + output_1_1 = model_copy(**encoded_input) + # inference + smooth_fq_linear_to_inference(model_copy) + output_1_2 = model_copy(**encoded_input) + # print(output_1_1) + # print(output_1_2) + sqnr_sq = compute_error( + output_ref.last_hidden_state, output_1_2.last_hidden_state + ) + print("sqnr_sq", sqnr_sq) + self.assertTrue(sqnr_sq >= 20.0) + + # + # reference - dynamic linear quant + # + model_copy2 = copy.deepcopy(model) + qconfig = torch.ao.quantization.QConfig( + activation=None, + weight=torch.ao.quantization.default_per_channel_weight_observer, + ) + model_copy2 = torch.ao.quantization.quantize_dynamic( + model_copy2, + {torch.nn.Linear: qconfig}, + ) + output_2_2 = model_copy2(**encoded_input) + # print(output_2_2) + sqnr_pt_quant = compute_error( + output_ref.last_hidden_state, output_2_2.last_hidden_state + ) + print("sqnr_pt_quant", sqnr_pt_quant) + self.assertTrue(sqnr_sq >= 8.0) + + +if __name__ == "__main__": + unittest.main() diff --git a/ao/quantization/utils.py b/ao/quantization/utils.py new file mode 100644 index 0000000000..c53395579a --- /dev/null +++ b/ao/quantization/utils.py @@ -0,0 +1,98 @@ +import os +from typing import Dict, Optional + +import torch +from torch.utils._python_dispatch import TorchDispatchMode + +__all__ = [ + "log_with_rank", + "clear_logs", + "compute_error", + "apply_logging_hook", + "get_model_size_in_bytes", +] + + +def log_with_rank(*args): + # append + # + # {thing_to_log} + # + # to {file}_{rank}.txt, for printing stuff from multiple GPUs + if not os.path.exists(log_dir): + os.makedirs(log_dir) + with open(log_fname, "a") as f: + f.write(" ".join([str(s) for s in args]) + "\n") + if local_rank == 0: + print(*args) + + +def clear_logs(): + if os.path.isfile(log_fname): + os.remove(log_fname) + + +# basic SQNR +def compute_error(x, y): + Ps = torch.norm(x) + Pn = torch.norm(x - 
y) + return 20 * torch.log10(Ps / Pn) + + +# logger for fqn + op + shape +# note: not safe for any kind of multithreading +_cur_fqn: Optional[str] = None + + +def _get_logging_hook(fqn): + def forward_hook(module, input): + global _cur_fqn + _cur_fqn = fqn + + return forward_hook + + +def apply_logging_hook(model): + for name, mod in model.named_modules(): + mod.register_forward_pre_hook(_get_logging_hook(name)) + + +# collections.defaultdict printing is weird with lambdas, so hand writing for now +fqn_to_op_to_shape_to_count: Dict[ + Optional[str], Dict[Optional[str], Dict[Optional[str], int]] +] = {} + + +class LoggingTensorMode(TorchDispatchMode): + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + rs = func(*args, **kwargs) + global _cur_fqn + op_name: str = f"{func.__module__}.{func.__name__}" + shape_str = "" + for arg in args: + if isinstance(arg, torch.Tensor): + shape_str += str(list(arg.shape)) + ", " + if shape_str != "": + shape_str = shape_str[:-2] + + if _cur_fqn not in fqn_to_op_to_shape_to_count: + fqn_to_op_to_shape_to_count[_cur_fqn] = {} + if op_name not in fqn_to_op_to_shape_to_count[_cur_fqn]: + fqn_to_op_to_shape_to_count[_cur_fqn][op_name] = {} + if shape_str not in fqn_to_op_to_shape_to_count[_cur_fqn][op_name]: + fqn_to_op_to_shape_to_count[_cur_fqn][op_name][shape_str] = 0 + fqn_to_op_to_shape_to_count[_cur_fqn][op_name][shape_str] += 1 + + return rs + + +# https://discuss.pytorch.org/t/finding-model-size/130275 +def get_model_size_in_bytes(model): + s = 0 + for p in model.parameters(): + s += p.nelement() * p.element_size() + for b in model.buffers(): + s += b.nelement() * b.element_size() + return s diff --git a/ao/quantization/weight_only.py b/ao/quantization/weight_only.py new file mode 100644 index 0000000000..0be9c8867b --- /dev/null +++ b/ao/quantization/weight_only.py @@ -0,0 +1,49 @@ +import torch +from quant_primitives import ( + dynamically_quantize_per_channel, +) + +__all__ = ["WeightOnlyInt8QuantLinear"] + + +class WeightOnlyInt8QuantLinear(torch.nn.Linear): + def __init__(self, *args, **kwargs): + w_int8 = kwargs.pop("w_int8") + scales = kwargs.pop("scales") + super().__init__(*args, **kwargs) + self.w_int8 = w_int8 + self.scales = scales + + def forward(self, x): + # if len(x.shape)<=2: + # y = torch.mm(x, self.w_int8.to(x.dtype)) * self.scales + # else: # turn x into 2d tensor, then undo it for y + x_view = x.view(-1, x.shape[-1]) + y = torch.mm(x_view, self.w_int8.to(x.dtype)) * self.scales + y = y.reshape(*x.shape[:-1], -1) + if self.bias is not None: + y += self.bias + return y + + @classmethod + def from_float(cls, mod): + w_fp32 = mod.weight + w_int8, scales, _zp = dynamically_quantize_per_channel( + w_fp32, -128, 127, torch.int8 + ) + # create the new module with a toy size to ensure initialization is fast + fake_in_features, fake_out_features = 8, 8 + new_mod = cls( + fake_in_features, + fake_out_features, + bias=mod.bias is not None, + w_int8=w_int8.t().contiguous(), + scales=scales, + ) + new_mod.in_features = mod.in_features + new_mod.out_features = mod.out_features + del new_mod.weight + new_mod.bias = mod.bias + device_to_use = next(mod.parameters()).device + new_mod.to(device_to_use) + return new_mod diff --git a/setup.py b/setup.py new file mode 100644 index 0000000000..102c90b1da --- /dev/null +++ b/setup.py @@ -0,0 +1,14 @@ +from setuptools import setup, find_packages + +setup( + name='ao', + version='0.1', + packages=find_packages(), + install_requires=[ + 'torch', + ], + 
description='Package for applying ao techniques to GPU models', + long_description=open('README.md').read(), + long_description_content_type='text/markdown', + url='https://github.com/pytorch-labs/ao', +)
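
Usage sketch (not part of the diff): a minimal example of how the three
workflows added in this patch compose. The API names come from
ao/quantization/__init__.py; which module each symbol lives in (quant_api vs.
smoothquant vs. utils) is assumed from the file layout, and the flat imports
assume the ao/quantization directory is on sys.path, mirroring how test.py is
invoked in the test plan.

import copy

import torch
import torch.nn as nn

from quant_api import apply_dynamic_quant, apply_weight_only_int8_quant
from smoothquant import (
    smooth_fq_linear_to_inference,
    swap_linear_with_smooth_fq_linear,
)
from utils import compute_error

model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 16))
x = torch.randn(8, 64)
y_ref = model(x)

# int8 dynamic quantization: weights quantized per channel, activations
# quantized dynamically at runtime
m_dq = copy.deepcopy(model)
apply_dynamic_quant(m_dq)
print("dynamic quant SQNR:", compute_error(y_ref, m_dq(x)))

# int8 weight-only quantization: weights stored as int8 plus scales and cast
# back to the activation dtype at matmul time
m_wo = copy.deepcopy(model)
apply_weight_only_int8_quant(m_wo)
print("weight-only SQNR:", compute_error(y_ref, m_wo(x)))

# SmoothQuant: swap the linears, run calibration batches, then freeze the
# smoothing scales for inference
m_sq = copy.deepcopy(model)
swap_linear_with_smooth_fq_linear(m_sq, alpha=0.5)
m_sq(x)  # calibration pass
smooth_fq_linear_to_inference(m_sq)
print("smoothquant SQNR:", compute_error(y_ref, m_sq(x)))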