From 9cf3bf2adf1f5ee4b78c74641567d5f68071df88 Mon Sep 17 00:00:00 2001 From: Nagakawa Yuno Date: Fri, 12 Jul 2024 12:26:14 +0800 Subject: [PATCH 01/10] fix: fix random insertion for ATSP --- utils/insertion/src/randomInsertion.cpp | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/utils/insertion/src/randomInsertion.cpp b/utils/insertion/src/randomInsertion.cpp index c529a97..a39fd7a 100644 --- a/utils/insertion/src/randomInsertion.cpp +++ b/utils/insertion/src/randomInsertion.cpp @@ -68,7 +68,7 @@ unsigned *Insertion::randomInsertion(unsigned *order) // get target list and distances // and get insert position with minimum cost Node *thisnode = route, *nextnode = thisnode->next; - float thisdist = tspi->getdist(thisnode->value, city), nextdist = 0; + float thisdist = 0, nextdist = 0; Node *minnode = thisnode; float mindelta = INFINITY; float td = 0.0, nd = 0.0; @@ -76,14 +76,15 @@ unsigned *Insertion::randomInsertion(unsigned *order) for (unsigned j = 0; j < i; j++) { nextnode = thisnode->next; - nextdist = tspi->getdist(nextnode->value, city); + thisdist = tspi->getdist(thisnode->value, city); + nextdist = tspi->getdist(city, nextnode->value); float delta = thisdist + nextdist - nextnode->length; if (delta < mindelta) { mindelta = delta, minnode = thisnode; td = thisdist, nd = nextdist; } - thisnode = nextnode, thisdist = nextdist; + thisnode = nextnode; } // insert the selected node @@ -135,4 +136,4 @@ Node::~Node() { if (next != nullptr) delete next; -} \ No newline at end of file +} From e3e57fc84a14b54e795e2e0a150cd8873822cabd Mon Sep 17 00:00:00 2001 From: Furffico Date: Sat, 13 Jul 2024 11:13:41 +0800 Subject: [PATCH 02/10] fix: compile fixed random insertion --- eval_atsp/ATSProblemDef.py | 6 ++++-- eval_atsp/test_glop.py | 23 ++++++++++++++++++----- utils/insertion/insertion.so | Bin 24848 -> 29248 bytes utils/insertion/makefile | 9 +++++---- utils/insertion/src/randomInsertion.cpp | 3 ++- 5 files changed, 29 insertions(+), 12 deletions(-) diff --git a/eval_atsp/ATSProblemDef.py b/eval_atsp/ATSProblemDef.py index a9ced86..9e3a75a 100644 --- a/eval_atsp/ATSProblemDef.py +++ b/eval_atsp/ATSProblemDef.py @@ -110,6 +110,8 @@ def load_single_problem_from_file(filename, node_cnt, scaler): torch.manual_seed(1234) dataset_size = 30 - for scale in [150, 200, 1000]: + for scale in [150, 250, 1000]: problems = get_random_problems(dataset_size, scale, problem_gen_params) - torch.save(problems, "../data/atsp/ATSP{}.pt".format(scale)) \ No newline at end of file + print("instances generated") + torch.save(problems, "../data/atsp/ATSP{}.pt".format(scale)) + print(f"created ../data/atsp/ATSP{scale}.pt") \ No newline at end of file diff --git a/eval_atsp/test_glop.py b/eval_atsp/test_glop.py index 38f8d68..0a297a2 100644 --- a/eval_atsp/test_glop.py +++ b/eval_atsp/test_glop.py @@ -132,7 +132,10 @@ def revision(tour, inst, tester): improved[improved < 0] = 0 return improved.sum().item() - + +def calc_len(tour, dist): + cost = dist[tour, torch.roll(tour, -1, -1)].sum() + return cost def main(n): dataset = torch.load('../data/atsp/ATSP{}.pt'.format(n), map_location='cuda:0') @@ -143,21 +146,31 @@ def main(n): model_params=model_params, tester_params=tester_params) - order = torch.randperm(n) + + torch.random.manual_seed(1) + order = torch.randperm(n, device='cpu').numpy() original_costs = [] revised_costs = [] + true_cost = [] start = time.time() for inst in dataset: tour, cost = random_insertion_non_euclidean(inst, order) original_costs.append(cost) - improved_cost = 
revision(torch.tensor(tour.astype(np.int64)), inst, tester)
+
+        tour = torch.tensor(tour.astype(np.int64))
+        cost = calc_len(tour, inst).item()
+        true_cost.append(cost)
+
+        improved_cost = revision(tour, inst, tester)
         revised_costs.append(cost - improved_cost)
+
     total_duration = time.time() - start
 
-    print("initial costs: ", sum(original_costs) / len(original_costs))
-    print("revised costs: ", sum(revised_costs) / len(revised_costs))
+    print("insertion costs: ", sum(original_costs) / len(original_costs))
+    print("initial true costs:", sum(true_cost) / len(true_cost))
+    print("revised true costs:", sum(revised_costs) / len(revised_costs))
 
     print("total duration: ", total_duration)
 
diff --git a/utils/insertion/insertion.so b/utils/insertion/insertion.so
index eaa1b70aeff2757604b4b9b01aa618316d80007c..5ad219c88d132a395107404cec8ff9af56cc9e60 100755
GIT binary patch
literal 29248
[binary data omitted]

diff --git a/utils/insertion/src/randomInsertion.cpp b/utils/insertion/src/randomInsertion.cpp
--- a/utils/insertion/src/randomInsertion.cpp
+++ b/utils/insertion/src/randomInsertion.cpp
         node2->next = node1;
         node1->next = node2;
         route = node1;
-        node1->length = node2->length = tspi->getdist(node1->value, node2->value);
+        node2->length = tspi->getdist(node1->value, node2->value);
+        node1->length = tspi->getdist(node2->value, node1->value);
     }
 
     for (unsigned i = 2; i < cc; i++)

From fb4c99801183e857b98e3781b11ad1ef386f21a5 Mon Sep 17 00:00:00 2001
From: henry-yeh
Date: Tue, 16 Jul 2024 17:03:37 +0800
Subject: [PATCH 03/10] init ASHPPEnv

---
 eval_atsp/ASHPPEnv.py | 162 ++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 162 insertions(+)
 create mode 100644 eval_atsp/ASHPPEnv.py

diff --git a/eval_atsp/ASHPPEnv.py b/eval_atsp/ASHPPEnv.py
new file mode 100644
index 0000000..4195679
--- /dev/null
+++ b/eval_atsp/ASHPPEnv.py
@@ -0,0 +1,162 @@
+
+""" +The MIT License + +Copyright (c) 2021 MatNet + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + + + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +from dataclasses import dataclass +import torch + +from ATSProblemDef import get_random_problems + + +@dataclass +class Reset_State: + problems: torch.Tensor + # shape: (batch, node, node) + + +@dataclass +class Step_State: + BATCH_IDX: torch.Tensor + POMO_IDX: torch.Tensor + # shape: (batch, pomo) + current_node: torch.Tensor = None + # shape: (batch, pomo) + ninf_mask: torch.Tensor = None + # shape: (batch, pomo, node) + + +class ASHPPEnv: + def __init__(self, **env_params): + + # Const @INIT + #################################### + self.env_params = env_params + self.node_cnt = env_params['node_cnt'] + self.pomo_size = env_params['pomo_size'] + + # Const @Load_Problem + #################################### + self.batch_size = None + self.BATCH_IDX = None + self.POMO_IDX = None + # IDX.shape: (batch, pomo) + self.problems = None + # shape: (batch, node, node) + + # Dynamic + #################################### + self.selected_count = None + self.current_node = None + # shape: (batch, pomo) + self.selected_node_list = None + # shape: (batch, pomo, 0~) + + # STEP-State + #################################### + self.step_state = None + + def load_problems(self, batch_size): + self.batch_size = batch_size + self.BATCH_IDX = torch.arange(self.batch_size)[:, None].expand(self.batch_size, self.pomo_size) + self.POMO_IDX = torch.arange(self.pomo_size)[None, :].expand(self.batch_size, self.pomo_size) + + problem_gen_params = self.env_params['problem_gen_params'] + self.problems = get_random_problems(batch_size, self.node_cnt, problem_gen_params) + # shape: (batch, node, node) + + def load_problems_manual(self, problems): + # problems.shape: (batch, node, node) + + self.batch_size = problems.size(0) + self.BATCH_IDX = torch.arange(self.batch_size)[:, None].expand(self.batch_size, self.pomo_size) + self.POMO_IDX = torch.arange(self.pomo_size)[None, :].expand(self.batch_size, self.pomo_size) + self.problems = problems + # shape: (batch, node, node) + + def reset(self): + self.selected_count = 0 + self.current_node = None + # shape: (batch, pomo) + self.selected_node_list = torch.empty((self.batch_size, self.pomo_size, 0), dtype=torch.long) + # shape: (batch, pomo, 0~) + + self._create_step_state() + + reward = None + done = False + return Reset_State(problems=self.problems), reward, done + + def _create_step_state(self): + self.step_state = Step_State(BATCH_IDX=self.BATCH_IDX, POMO_IDX=self.POMO_IDX) + 
self.step_state.ninf_mask = torch.zeros((self.batch_size, self.pomo_size, self.node_cnt)) + # shape: (batch, pomo, node) + + def pre_step(self): + reward = None + done = False + return self.step_state, reward, done + + def step(self, node_idx): + # node_idx.shape: (batch, pomo) + + self.selected_count += 1 + self.current_node = node_idx + # shape: (batch, pomo) + self.selected_node_list = torch.cat((self.selected_node_list, self.current_node[:, :, None]), dim=2) + # shape: (batch, pomo, 0~node) + + self._update_step_state() + + # returning values + done = (self.selected_count == self.node_cnt) + if done: + reward = -self._get_total_distance() # Note the MINUS Sign ==> We MAXIMIZE reward + # shape: (batch, pomo) + else: + reward = None + return self.step_state, reward, done + + def _update_step_state(self): + self.step_state.current_node = self.current_node + # shape: (batch, pomo) + self.step_state.ninf_mask[self.BATCH_IDX, self.POMO_IDX, self.current_node] = float('-inf') + # shape: (batch, pomo, node) + + def _get_total_distance(self): + + node_from = self.selected_node_list + # shape: (batch, pomo, node) + node_to = self.selected_node_list.roll(dims=2, shifts=-1) + # shape: (batch, pomo, node) + batch_index = self.BATCH_IDX[:, :, None].expand(self.batch_size, self.pomo_size, self.node_cnt) + # shape: (batch, pomo, node) + + selected_cost = self.problems[batch_index, node_from, node_to] + # shape: (batch, pomo, node) + total_distance = selected_cost.sum(2) + # shape: (batch, pomo) + + return total_distance From 4fd5528c2a5b72eb16e2a6cabea9fa890758c55b Mon Sep 17 00:00:00 2001 From: henry-yeh Date: Tue, 16 Jul 2024 20:00:09 +0800 Subject: [PATCH 04/10] add ASHPP model --- eval_atsp/ASHPPModel.py | 355 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 355 insertions(+) create mode 100644 eval_atsp/ASHPPModel.py diff --git a/eval_atsp/ASHPPModel.py b/eval_atsp/ASHPPModel.py new file mode 100644 index 0000000..09114a5 --- /dev/null +++ b/eval_atsp/ASHPPModel.py @@ -0,0 +1,355 @@ + +""" +The MIT License + +Copyright (c) 2021 MatNet + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + + + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. 
+""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +from ATSPModel_LIB import AddAndInstanceNormalization, FeedForward, MixedScore_MultiHeadAttention + + +class ATSPModel(nn.Module): + + def __init__(self, **model_params): + super().__init__() + self.model_params = model_params + + self.encoder = ATSP_Encoder(**model_params) + self.decoder = ATSP_Decoder(**model_params) + + self.encoded_row = None + self.encoded_col = None + # shape: (batch, node, embedding) + + def pre_forward(self, reset_state): + + problems = reset_state.problems + # problems.shape: (batch, node, node) + + batch_size = problems.size(0) + node_cnt = problems.size(1) + embedding_dim = self.model_params['embedding_dim'] + + row_emb = torch.zeros(size=(batch_size, node_cnt, embedding_dim)) + # emb.shape: (batch, node, embedding) + col_emb = torch.zeros(size=(batch_size, node_cnt, embedding_dim)) + # shape: (batch, node, embedding) + + seed_cnt = self.model_params['one_hot_seed_cnt'] + rand = torch.rand(batch_size, seed_cnt) + batch_rand_perm = rand.argsort(dim=1) + rand_idx = batch_rand_perm[:, :node_cnt] + + b_idx = torch.arange(batch_size)[:, None].expand(batch_size, node_cnt) + n_idx = torch.arange(node_cnt)[None, :].expand(batch_size, node_cnt) + col_emb[b_idx, n_idx, rand_idx] = 1 + # shape: (batch, node, embedding) + + self.encoded_row, self.encoded_col = self.encoder(row_emb, col_emb, problems) + # encoded_nodes.shape: (batch, node, embedding) + + self.decoder.set_kv(self.encoded_col) + + def forward(self, state): + + batch_size = state.BATCH_IDX.size(0) + pomo_size = state.BATCH_IDX.size(1) + + encoded_current_row = _get_encoding(self.encoded_row, state.current_node) + + if (state.current_node == 0).all(): + # selected = torch.arange(pomo_size)[None, :].expand(batch_size, pomo_size) + # prob = torch.ones(size=(batch_size, pomo_size)) + + # encoded_rows_mean = self.encoded_row.mean(dim=1, keepdim=True) + # encoded_cols_mean = self.encoded_col.mean(dim=1, keepdim=True) + # # shape: (batch, 1, embedding) + # encoded_first_row = _get_encoding(self.encoded_row, state.current_node) + # shape: (batch, pomo, embedding) + self.decoder.set_q1(encoded_current_row) + + + # shape: (batch, pomo, embedding) + all_job_probs = self.decoder(encoded_current_row, ninf_mask=state.ninf_mask) + # shape: (batch, pomo, job) + + if self.training or self.model_params['eval_type'] == 'softmax': + while True: # to fix pytorch.multinomial bug on selecting 0 probability elements + with torch.no_grad(): + selected = all_job_probs.reshape(batch_size * pomo_size, -1).multinomial(1) \ + .squeeze(dim=1).reshape(batch_size, pomo_size) + # shape: (batch, pomo) + + prob = all_job_probs[state.BATCH_IDX, state.POMO_IDX, selected] \ + .reshape(batch_size, pomo_size) + # shape: (batch, pomo) + + if (prob != 0).all(): + break + else: + selected = all_job_probs.argmax(dim=2) + # shape: (batch, pomo) + prob = None + + return selected, prob + + +def _get_encoding(encoded_nodes, node_index_to_pick): + # encoded_nodes.shape: (batch, problem, embedding) + # node_index_to_pick.shape: (batch, pomo) + + batch_size = node_index_to_pick.size(0) + pomo_size = node_index_to_pick.size(1) + embedding_dim = encoded_nodes.size(2) + + gathering_index = node_index_to_pick[:, :, None].expand(batch_size, pomo_size, embedding_dim) + # shape: (batch, pomo, embedding) + + picked_nodes = encoded_nodes.gather(dim=1, index=gathering_index) + # shape: (batch, pomo, embedding) + + return picked_nodes + + +######################################## +# ENCODER 
+######################################## +class ATSP_Encoder(nn.Module): + def __init__(self, **model_params): + super().__init__() + encoder_layer_num = model_params['encoder_layer_num'] + self.layers = nn.ModuleList([EncoderLayer(**model_params) for _ in range(encoder_layer_num)]) + + def forward(self, row_emb, col_emb, cost_mat): + # col_emb.shape: (batch, col_cnt, embedding) + # row_emb.shape: (batch, row_cnt, embedding) + # cost_mat.shape: (batch, row_cnt, col_cnt) + + for layer in self.layers: + row_emb, col_emb = layer(row_emb, col_emb, cost_mat) + + return row_emb, col_emb + + +class EncoderLayer(nn.Module): + def __init__(self, **model_params): + super().__init__() + self.row_encoding_block = EncodingBlock(**model_params) + self.col_encoding_block = EncodingBlock(**model_params) + + def forward(self, row_emb, col_emb, cost_mat): + # row_emb.shape: (batch, row_cnt, embedding) + # col_emb.shape: (batch, col_cnt, embedding) + # cost_mat.shape: (batch, row_cnt, col_cnt) + row_emb_out = self.row_encoding_block(row_emb, col_emb, cost_mat) + col_emb_out = self.col_encoding_block(col_emb, row_emb, cost_mat.transpose(1, 2)) + + return row_emb_out, col_emb_out + + +class EncodingBlock(nn.Module): + def __init__(self, **model_params): + super().__init__() + self.model_params = model_params + embedding_dim = self.model_params['embedding_dim'] + head_num = self.model_params['head_num'] + qkv_dim = self.model_params['qkv_dim'] + + self.Wq = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False) + self.Wk = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False) + self.Wv = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False) + self.mixed_score_MHA = MixedScore_MultiHeadAttention(**model_params) + self.multi_head_combine = nn.Linear(head_num * qkv_dim, embedding_dim) + + self.add_n_normalization_1 = AddAndInstanceNormalization(**model_params) + self.feed_forward = FeedForward(**model_params) + self.add_n_normalization_2 = AddAndInstanceNormalization(**model_params) + + def forward(self, row_emb, col_emb, cost_mat): + # NOTE: row and col can be exchanged, if cost_mat.transpose(1,2) is used + # input1.shape: (batch, row_cnt, embedding) + # input2.shape: (batch, col_cnt, embedding) + # cost_mat.shape: (batch, row_cnt, col_cnt) + head_num = self.model_params['head_num'] + + q = reshape_by_heads(self.Wq(row_emb), head_num=head_num) + # q shape: (batch, head_num, row_cnt, qkv_dim) + k = reshape_by_heads(self.Wk(col_emb), head_num=head_num) + v = reshape_by_heads(self.Wv(col_emb), head_num=head_num) + # kv shape: (batch, head_num, col_cnt, qkv_dim) + + out_concat = self.mixed_score_MHA(q, k, v, cost_mat) + # shape: (batch, row_cnt, head_num*qkv_dim) + + multi_head_out = self.multi_head_combine(out_concat) + # shape: (batch, row_cnt, embedding) + + out1 = self.add_n_normalization_1(row_emb, multi_head_out) + out2 = self.feed_forward(out1) + out3 = self.add_n_normalization_2(out1, out2) + + return out3 + # shape: (batch, row_cnt, embedding) + + +######################################## +# Decoder +######################################## + +class ATSP_Decoder(nn.Module): + def __init__(self, **model_params): + super().__init__() + self.model_params = model_params + embedding_dim = self.model_params['embedding_dim'] + head_num = self.model_params['head_num'] + qkv_dim = self.model_params['qkv_dim'] + + self.Wq_0 = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False) + self.Wq_1 = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False) + self.Wk = nn.Linear(embedding_dim, head_num * qkv_dim, 
bias=False) + self.Wv = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False) + + self.multi_head_combine = nn.Linear(head_num * qkv_dim, embedding_dim) + + self.k = None # saved key, for multi-head attention + self.v = None # saved value, for multi-head_attention + self.single_head_key = None # saved key, for single-head attention + self.q1 = None # saved q1, for multi-head attention + + def set_kv(self, encoded_jobs): + # encoded_jobs.shape: (batch, job, embedding) + head_num = self.model_params['head_num'] + + self.k = reshape_by_heads(self.Wk(encoded_jobs), head_num=head_num) + self.v = reshape_by_heads(self.Wv(encoded_jobs), head_num=head_num) + # shape: (batch, head_num, job, qkv_dim) + self.single_head_key = encoded_jobs.transpose(1, 2) + # shape: (batch, embedding, job) + + def set_q1(self, encoded_q1): + # encoded_q.shape: (batch, n, embedding) # n can be 1 or pomo + head_num = self.model_params['head_num'] + + self.q1 = reshape_by_heads(self.Wq_1(encoded_q1), head_num=head_num) + # shape: (batch, head_num, n, qkv_dim) + + def forward(self, encoded_q0, ninf_mask): + # encoded_q4.shape: (batch, pomo, embedding) + # ninf_mask.shape: (batch, pomo, job) + + head_num = self.model_params['head_num'] + + # Multi-Head Attention + ####################################################### + q0 = reshape_by_heads(self.Wq_0(encoded_q0), head_num=head_num) + # shape: (batch, head_num, pomo, qkv_dim) + + q = self.q1 + q0 + # shape: (batch, head_num, pomo, qkv_dim) + + out_concat = self._multi_head_attention(q, self.k, self.v, rank3_ninf_mask=ninf_mask) + # shape: (batch, pomo, head_num*qkv_dim) + + mh_atten_out = self.multi_head_combine(out_concat) + # shape: (batch, pomo, embedding) + + # Single-Head Attention, for probability calculation + ####################################################### + score = torch.matmul(mh_atten_out, self.single_head_key) + # shape: (batch, pomo, job) + + sqrt_embedding_dim = self.model_params['sqrt_embedding_dim'] + logit_clipping = self.model_params['logit_clipping'] + + score_scaled = score / sqrt_embedding_dim + # shape: (batch, pomo, job) + + score_clipped = logit_clipping * torch.tanh(score_scaled) + + score_masked = score_clipped + ninf_mask + + probs = F.softmax(score_masked, dim=2) + # shape: (batch, pomo, job) + + return probs + + def _multi_head_attention(self, q, k, v, rank2_ninf_mask=None, rank3_ninf_mask=None): + # q shape: (batch, head_num, n, key_dim) : n can be either 1 or pomo + # k,v shape: (batch, head_num, node, key_dim) + # rank2_ninf_mask.shape: (batch, node) + # rank3_ninf_mask.shape: (batch, group, node) + + batch_s = q.size(0) + n = q.size(2) + node_cnt = k.size(2) + + head_num = self.model_params['head_num'] + qkv_dim = self.model_params['qkv_dim'] + sqrt_qkv_dim = self.model_params['sqrt_qkv_dim'] + + score = torch.matmul(q, k.transpose(2, 3)) + # shape: (batch, head_num, n, node) + + score_scaled = score / sqrt_qkv_dim + if rank2_ninf_mask is not None: + score_scaled = score_scaled + rank2_ninf_mask[:, None, None, :].expand(batch_s, head_num, n, node_cnt) + if rank3_ninf_mask is not None: + score_scaled = score_scaled + rank3_ninf_mask[:, None, :, :].expand(batch_s, head_num, n, node_cnt) + + weights = nn.Softmax(dim=3)(score_scaled) + # shape: (batch, head_num, n, node) + + out = torch.matmul(weights, v) + # shape: (batch, head_num, n, key_dim) + + out_transposed = out.transpose(1, 2) + # shape: (batch, n, head_num, key_dim) + + out_concat = out_transposed.reshape(batch_s, n, head_num * qkv_dim) + # shape: (batch, n, 
head_num*key_dim) + + return out_concat + + +######################################## +# NN SUB FUNCTIONS +######################################## + +def reshape_by_heads(qkv, head_num): + # q.shape: (batch, n, head_num*key_dim) : n can be either 1 or PROBLEM_SIZE + + batch_s = qkv.size(0) + n = qkv.size(1) + + q_reshaped = qkv.reshape(batch_s, n, head_num, -1) + # shape: (batch, n, head_num, key_dim) + + q_transposed = q_reshaped.transpose(1, 2) + # shape: (batch, head_num, n, key_dim) + + return q_transposed From 722263549d4b914b90d39e9b60416d018a82dc74 Mon Sep 17 00:00:00 2001 From: henry-yeh Date: Tue, 16 Jul 2024 20:00:32 +0800 Subject: [PATCH 05/10] update ASHPP env --- eval_atsp/ASHPPEnv.py | 40 +++++++++++++++++++++++++++------------- 1 file changed, 27 insertions(+), 13 deletions(-) diff --git a/eval_atsp/ASHPPEnv.py b/eval_atsp/ASHPPEnv.py index 4195679..857b7be 100644 --- a/eval_atsp/ASHPPEnv.py +++ b/eval_atsp/ASHPPEnv.py @@ -27,6 +27,7 @@ from dataclasses import dataclass import torch +import warnings from ATSProblemDef import get_random_problems @@ -55,7 +56,7 @@ def __init__(self, **env_params): #################################### self.env_params = env_params self.node_cnt = env_params['node_cnt'] - self.pomo_size = env_params['pomo_size'] + self.pomo_size = env_params['pomo_size'] # pomo size if sample size here # Const @Load_Problem #################################### @@ -97,14 +98,16 @@ def load_problems_manual(self, problems): # shape: (batch, node, node) def reset(self): - self.selected_count = 0 - self.current_node = None + self.selected_count = 2 # Add starting and terminating ndoes + # Set current nodes as 0 + self.current_node = torch.zeros((self.batch_size, self.pomo_size), dtype=torch.long) + # shape: (batch, pomo) - self.selected_node_list = torch.empty((self.batch_size, self.pomo_size, 0), dtype=torch.long) + self.selected_node_list = self.current_node[:, :, None] # shape: (batch, pomo, 0~) - + self._create_step_state() - + reward = None done = False return Reset_State(problems=self.problems), reward, done @@ -117,6 +120,14 @@ def _create_step_state(self): def pre_step(self): reward = None done = False + + # Set the starting and terminating nodes to -inf + self.step_state.ninf_mask[self.BATCH_IDX, self.POMO_IDX, 0] = float('-inf') + self.step_state.ninf_mask[self.BATCH_IDX, self.POMO_IDX, -1] = float('-inf') + + # Set current node to 0 + self.step_state.current_node = self.current_node + return self.step_state, reward, done def step(self, node_idx): @@ -133,6 +144,9 @@ def step(self, node_idx): # returning values done = (self.selected_count == self.node_cnt) if done: + # Concat the terminating node (the last node) to the selected node list + self.current_node = torch.ones((self.batch_size, self.pomo_size), dtype=torch.long) * (self.node_cnt - 1) + self.selected_node_list = torch.cat((self.selected_node_list, self.current_node[:, :, None]), dim=2) reward = -self._get_total_distance() # Note the MINUS Sign ==> We MAXIMIZE reward # shape: (batch, pomo) else: @@ -147,15 +161,15 @@ def _update_step_state(self): def _get_total_distance(self): - node_from = self.selected_node_list - # shape: (batch, pomo, node) - node_to = self.selected_node_list.roll(dims=2, shifts=-1) - # shape: (batch, pomo, node) - batch_index = self.BATCH_IDX[:, :, None].expand(self.batch_size, self.pomo_size, self.node_cnt) - # shape: (batch, pomo, node) + node_from = self.selected_node_list[:, :, :-1] + # shape: (batch, pomo, node - 1) + node_to = self.selected_node_list.roll(dims=2, 
shifts=-1)[:, :, :-1] + # shape: (batch, pomo, node - 1) + batch_index = self.BATCH_IDX[:, :, None].expand(self.batch_size, self.pomo_size, self.node_cnt - 1) + # shape: (batch, pomo, node - 1) selected_cost = self.problems[batch_index, node_from, node_to] - # shape: (batch, pomo, node) + # shape: (batch, pomo, node - 1) total_distance = selected_cost.sum(2) # shape: (batch, pomo) From af52dbde4bc3e808f590ab747308aaa608e85539 Mon Sep 17 00:00:00 2001 From: henry-yeh Date: Wed, 17 Jul 2024 01:07:42 +0800 Subject: [PATCH 06/10] =?UTF-8?q?[=F0=9F=90=9B]=20Bugfix;=20[feat]=20retur?= =?UTF-8?q?n=20tour?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- eval_atsp/ASHPPModel.py | 3 +- eval_atsp/ATSPTester_glop.py | 44 +++++++++++------ eval_atsp/test_glop.py | 94 ++++++++++++++++++++++++------------ 3 files changed, 93 insertions(+), 48 deletions(-) diff --git a/eval_atsp/ASHPPModel.py b/eval_atsp/ASHPPModel.py index 09114a5..6eec377 100644 --- a/eval_atsp/ASHPPModel.py +++ b/eval_atsp/ASHPPModel.py @@ -33,7 +33,7 @@ from ATSPModel_LIB import AddAndInstanceNormalization, FeedForward, MixedScore_MultiHeadAttention -class ATSPModel(nn.Module): +class ASHPPModel(nn.Module): def __init__(self, **model_params): super().__init__() @@ -112,6 +112,7 @@ def forward(self, state): if (prob != 0).all(): break else: + assert self.model_params['eval_type'] == 'greedy' selected = all_job_probs.argmax(dim=2) # shape: (batch, pomo) prob = None diff --git a/eval_atsp/ATSPTester_glop.py b/eval_atsp/ATSPTester_glop.py index c3e318e..550e3f1 100644 --- a/eval_atsp/ATSPTester_glop.py +++ b/eval_atsp/ATSPTester_glop.py @@ -30,8 +30,8 @@ import os from logging import getLogger -from ATSPEnv import ATSPEnv as Env -from ATSPModel import ATSPModel as Model +from ASHPPEnv import ASHPPEnv as Env +from ASHPPModel import ASHPPModel as Model from utils_atsp.utils import get_result_folder, AverageMeter, TimeEstimator @@ -91,16 +91,20 @@ def run(self, insts): test_num_episode = insts.size(0) episode = 0 - ret = [] + scores = [] + + solutions = [] while episode < test_num_episode: remaining = test_num_episode - episode batch_size = min(self.tester_params['test_batch_size'], remaining) - aug_score = self._test_one_batch(episode, episode+batch_size, insts) + aug_score, batch_solutions = self._test_one_batch(episode, episode+batch_size, insts) + + scores.append(aug_score) - ret.append(aug_score) + solutions.append(batch_solutions) episode += batch_size @@ -117,7 +121,11 @@ def run(self, insts): # self.logger.info(" *** Test Done *** ") # self.logger.info(" NO-AUG SCORE: {:.4f} ".format(score_AM.avg)) # self.logger.info(" AUGMENTATION SCORE: {:.4f} ".format(aug_score_AM.avg)) - return ret + + scores = torch.cat(scores, dim=0) + solutions = torch.cat(solutions, dim=0) + + return scores, solutions def _test_one_batch(self, idx_start, idx_end, insts): @@ -127,6 +135,7 @@ def _test_one_batch(self, idx_start, idx_end, insts): # Augmentation ############################################### if self.tester_params['augmentation_enable']: + assert False, "Augmentation is not supported" aug_factor = self.tester_params['aug_factor'] batch_size = aug_factor*batch_size @@ -152,16 +161,19 @@ def _test_one_batch(self, idx_start, idx_end, insts): # Return ############################################### - batch_size = batch_size//aug_factor - aug_reward = reward.reshape(aug_factor, batch_size, self.env.pomo_size) - # shape: (augmentation, batch, pomo) - - max_pomo_reward, _ = aug_reward.max(dim=2) # get best 
results from pomo - # shape: (augmentation, batch) - no_aug_score = -max_pomo_reward[0, :].float().mean() # negative sign to make positive value + aug_reward = reward.reshape(batch_size, self.env.pomo_size) + # shape: (batch, pomo) + + # Get solutions + solutions = self.env.selected_node_list + # shape: (batch, pomo, node_cnt) + + max_pomo_reward, max_pomo_reward_idx = aug_reward.max(dim=1) # get best results from pomo + # shape: (batch) - max_aug_pomo_reward, _ = max_pomo_reward.max(dim=0) # get best results from augmentation - # shape: (batch,) - return -max_aug_pomo_reward.float() # negative sign to make positive value + optimal_solution = solutions[torch.arange(batch_size), max_pomo_reward_idx] + # shape: (batch, node_cnt) + + return max_pomo_reward.float(), optimal_solution # negative sign to make positive value diff --git a/eval_atsp/test_glop.py b/eval_atsp/test_glop.py index 0a297a2..3975a78 100644 --- a/eval_atsp/test_glop.py +++ b/eval_atsp/test_glop.py @@ -55,14 +55,21 @@ ########################################################################################## # parameters +##### GLOP parameters ##### +N_REVISER = 50 # We only test on Reviser-50; using more revisers requires code modifications +N_REVISIONS = 3 +N_SAMPLES = 100 # for sampling decoding during revision + + + env_params = { - 'node_cnt': 50, + 'node_cnt': N_REVISER, 'problem_gen_params': { 'int_min': 0, 'int_max': 1000*1000, 'scaler': 1000*1000 }, - 'pomo_size': 50 # same as node_cnt + 'pomo_size': N_SAMPLES } model_params = { @@ -77,7 +84,7 @@ 'ms_hidden_dim': 16, 'ms_layer1_init': (1/2)**(1/2), 'ms_layer2_init': (1/16)**(1/2), - 'eval_type': 'softmax', + 'eval_type': 'softmax', # note here, can be greedy 'one_hot_seed_cnt': 20, # must be >= node_cnt } @@ -90,8 +97,8 @@ }, 'saved_problem_folder': "../data/n20", 'saved_problem_filename': 'problem_20_0_1000000_{}.atsp', - 'test_batch_size': 1, - 'augmentation_enable': False, + 'test_batch_size': 999999, # Note this batch size is for revision + 'augmentation_enable': False, # No augementation for GLOP; requiring code modifications to enable 'aug_factor': 1, 'aug_batch_size': 1, } @@ -107,35 +114,55 @@ } -########################################################################################## -# main -L = 1.5 +########################################################################################## +# main def revision(tour, inst, tester): - revision_len = env_params['node_cnt'] - assert revision_len == 50 - sub_tours = tour.reshape(-1, revision_len) + sub_tours = tour.reshape(-1, N_REVISER) # shape: (batch, revision_len) sub_insts = [inst[sub_tour][:, sub_tour] for sub_tour in sub_tours] - original_scores = torch.stack([inst[sub_tour[:-1], torch.roll(sub_tour, shifts=-1)[:-1]].sum() for sub_tour in sub_tours]) - for sub_inst in sub_insts: # equivalent ATSP of each ASHPP - sub_inst[:, 0] += L - sub_inst[:, -1] += L - sub_inst[0, :] += L - sub_inst[-1, :] += L - sub_inst[0, 0] = sub_inst[0, -1] = sub_inst[-1, 0] = sub_inst[-1, -1] = 0 + original_scores = torch.stack([inst[sub_tour[:-1], torch.roll(sub_tour, shifts=-1)[:-1]].sum() for sub_tour in sub_tours]) # note that original_scores are positive values + # Scale the sub_insts to make the largest value 1 + scale_coef = [sub_inst.max() for sub_inst in sub_insts] sub_insts = torch.stack(sub_insts) + sub_insts_scaled = sub_insts / torch.tensor(scale_coef)[:, None, None] + + # Main part of the revision + revised_scores, solutions = tester.run(sub_insts_scaled) # solutions shape: (batch, revision_len) + + # Scale back 
the revised scores + revised_scores = - revised_scores * torch.tensor(scale_coef) # shape: (batch,); add negative sign to make positive value - revised_scores = torch.stack(tester.run(sub_insts)) - 2 * L + # TODO: unmcomment to validate the subtours + for i in range(len(sub_insts)): + validate_subtour(solutions[i], sub_insts[i], revised_scores[i]) - improved = original_scores - revised_scores - improved[improved < 0] = 0 + # Gather the subtours according to the solutions + revised_tours = sub_tours.gather(1, solutions) - return improved.sum().item() + # Compare the original scores and the revised scores + improved_scores = original_scores - revised_scores + # subtours should be aranged in the same order as the original tours, if the improved_scores <= 0 + kept_subtour_idx = improved_scores <= 0 + revised_tours[kept_subtour_idx] = sub_tours[kept_subtour_idx] + + return revised_tours + +def validate_subtour(subtour, dist, cost): + truth_cost = cal_len_shpp(subtour, dist) + assert truth_cost - cost < 1e-5 + # Assert subtour is a valid tour: (1) the starting node is 0 and the terminal node is len(subtour) - 1; (2) all nodes are visited exactly once. + assert subtour[0] == 0 and subtour[-1] == len(subtour) - 1 + for i in range(1, len(subtour) - 1): + assert i in subtour def calc_len(tour, dist): cost = dist[tour, torch.roll(tour, -1, -1)].sum() - return cost + return cost.item() + +def cal_len_shpp(tour, dist): + cost = dist[tour[:-1], tour[1:]].sum() + return cost.item() def main(n): dataset = torch.load('../data/atsp/ATSP{}.pt'.format(n), map_location='cuda:0') @@ -152,7 +179,9 @@ def main(n): original_costs = [] revised_costs = [] - true_cost = [] + # true_cost = [] + + N_SHIFTS = N_REVISER // N_REVISIONS start = time.time() for inst in dataset: @@ -160,17 +189,20 @@ def main(n): original_costs.append(cost) tour = torch.tensor(tour.astype(np.int64)) - cost = calc_len(tour, inst).item() - true_cost.append(cost) - - improved_cost = revision(tour, inst, tester) - revised_costs.append(cost - improved_cost) + for revision_iter in range(N_REVISIONS): + tour = revision(tour, inst, tester) + # Shift the tour to the right by N_SHIFTS + tour = torch.roll(tour, shifts=N_SHIFTS, dims=-1) + # cost = calc_len(tour, inst) + # print(f"cost after revision {revision_iter}: {cost}") + cost = calc_len(tour, inst) + revised_costs.append(cost) + total_duration = time.time() - start print("insertion costs: ", sum(original_costs) / len(original_costs)) - print("initial true costs:", sum(true_cost) / len(true_cost)) - print("revised true costs:", sum(revised_costs) / len(revised_costs)) + print("revised costs:", sum(revised_costs) / len(revised_costs)) print("total duration: ", total_duration) From 55c5f8a0f0b4fed7cd7bf4a2fecc179ef7f175c6 Mon Sep 17 00:00:00 2001 From: henry-yeh Date: Sun, 21 Jul 2024 10:26:18 +0800 Subject: [PATCH 07/10] update context query for SHPP model --- eval_atsp/ASHPPEnv.py | 5 +++++ eval_atsp/ASHPPModel.py | 14 +++----------- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/eval_atsp/ASHPPEnv.py b/eval_atsp/ASHPPEnv.py index 857b7be..7bc9863 100644 --- a/eval_atsp/ASHPPEnv.py +++ b/eval_atsp/ASHPPEnv.py @@ -101,6 +101,8 @@ def reset(self): self.selected_count = 2 # Add starting and terminating ndoes # Set current nodes as 0 self.current_node = torch.zeros((self.batch_size, self.pomo_size), dtype=torch.long) + # Set the last node as node - 1 + self.last_node = torch.ones((self.batch_size, self.pomo_size), dtype=torch.long) * (self.node_cnt - 1) # shape: (batch, pomo) 
self.selected_node_list = self.current_node[:, :, None] @@ -127,6 +129,9 @@ def pre_step(self): # Set current node to 0 self.step_state.current_node = self.current_node + # Set last node to node - 1 + self.step_state.last_node = self.last_node + return self.step_state, reward, done diff --git a/eval_atsp/ASHPPModel.py b/eval_atsp/ASHPPModel.py index 6eec377..a07903c 100644 --- a/eval_atsp/ASHPPModel.py +++ b/eval_atsp/ASHPPModel.py @@ -80,19 +80,11 @@ def forward(self, state): batch_size = state.BATCH_IDX.size(0) pomo_size = state.BATCH_IDX.size(1) - encoded_current_row = _get_encoding(self.encoded_row, state.current_node) - if (state.current_node == 0).all(): - # selected = torch.arange(pomo_size)[None, :].expand(batch_size, pomo_size) - # prob = torch.ones(size=(batch_size, pomo_size)) - - # encoded_rows_mean = self.encoded_row.mean(dim=1, keepdim=True) - # encoded_cols_mean = self.encoded_col.mean(dim=1, keepdim=True) - # # shape: (batch, 1, embedding) - # encoded_first_row = _get_encoding(self.encoded_row, state.current_node) - # shape: (batch, pomo, embedding) - self.decoder.set_q1(encoded_current_row) + encoded_last_row = _get_encoding(self.encoded_row, state.last_node) + self.decoder.set_q1(encoded_last_row) + encoded_current_row = _get_encoding(self.encoded_row, state.current_node) # shape: (batch, pomo, embedding) all_job_probs = self.decoder(encoded_current_row, ninf_mask=state.ninf_mask) From 3163f439f6b3860d999e14302b2129218dd1805a Mon Sep 17 00:00:00 2001 From: henry-yeh Date: Sun, 21 Jul 2024 10:26:33 +0800 Subject: [PATCH 08/10] add training --- eval_atsp/ASHPPTrainer.py | 222 ++++++++++++++++++++++++++++++++++++++ eval_atsp/train_glop.py | 169 +++++++++++++++++++++++++++++ 2 files changed, 391 insertions(+) create mode 100644 eval_atsp/ASHPPTrainer.py create mode 100644 eval_atsp/train_glop.py diff --git a/eval_atsp/ASHPPTrainer.py b/eval_atsp/ASHPPTrainer.py new file mode 100644 index 0000000..95b61dd --- /dev/null +++ b/eval_atsp/ASHPPTrainer.py @@ -0,0 +1,222 @@ + +""" +The MIT License + +Copyright (c) 2021 MatNet + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + + + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. 
+""" + +import torch +from logging import getLogger + +from ASHPPEnv import ASHPPEnv as Env +from ASHPPModel import ASHPPModel as Model + +from torch.optim import Adam as Optimizer +from torch.optim.lr_scheduler import MultiStepLR as Scheduler + +from utils_atsp.utils import * + + +class ASHPPTrainer: + def __init__(self, + env_params, + model_params, + optimizer_params, + trainer_params): + + # save arguments + self.env_params = env_params + self.model_params = model_params + self.optimizer_params = optimizer_params + self.trainer_params = trainer_params + + # result folder, logger + self.logger = getLogger(name='trainer') + self.result_folder = get_result_folder() + self.result_log = LogData() + + # cuda + USE_CUDA = self.trainer_params['use_cuda'] + if USE_CUDA: + cuda_device_num = self.trainer_params['cuda_device_num'] + torch.cuda.set_device(cuda_device_num) + device = torch.device('cuda', cuda_device_num) + torch.set_default_tensor_type('torch.cuda.FloatTensor') + else: + device = torch.device('cpu') + torch.set_default_tensor_type('torch.FloatTensor') + + # Main Components + self.model = Model(**self.model_params) + self.env = Env(**self.env_params) + self.optimizer = Optimizer(self.model.parameters(), **self.optimizer_params['optimizer']) + self.scheduler = Scheduler(self.optimizer, **self.optimizer_params['scheduler']) + + # Restore + self.start_epoch = 1 + model_load = trainer_params['model_load'] + if model_load['enable']: + checkpoint_fullname = '{path}/checkpoint-{epoch}.pt'.format(**model_load) + checkpoint = torch.load(checkpoint_fullname, map_location=device) + self.model.load_state_dict(checkpoint['model_state_dict']) + self.start_epoch = 1 + model_load['epoch'] + self.result_log.set_raw_data(checkpoint['result_log']) + self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) + self.scheduler.last_epoch = model_load['epoch']-1 + self.logger.info('Saved Model Loaded !!') + + # utility + self.time_estimator = TimeEstimator() + + def run(self): + self.time_estimator.reset(self.start_epoch) + for epoch in range(self.start_epoch, self.trainer_params['epochs']+1): + self.logger.info('=================================================================') + + # LR Decay + self.scheduler.step() + + # Train + train_score, train_loss = self._train_one_epoch(epoch) + self.result_log.append('train_score', epoch, train_score) + self.result_log.append('train_loss', epoch, train_loss) + + ############################ + # Logs & Checkpoint + ############################ + elapsed_time_str, remain_time_str = self.time_estimator.get_est_string(epoch, self.trainer_params['epochs']) + self.logger.info("Epoch {:3d}/{:3d}: Time Est.: Elapsed[{}], Remain[{}]".format( + epoch, self.trainer_params['epochs'], elapsed_time_str, remain_time_str)) + + all_done = (epoch == self.trainer_params['epochs']) + model_save_interval = self.trainer_params['logging']['model_save_interval'] + img_save_interval = self.trainer_params['logging']['img_save_interval'] + + if epoch > 1: # save latest images, every epoch + self.logger.info("Saving log_image") + image_prefix = '{}/latest'.format(self.result_folder) + util_save_log_image_with_label(image_prefix, self.trainer_params['logging']['log_image_params_1'], + self.result_log, labels=['train_score']) + util_save_log_image_with_label(image_prefix, self.trainer_params['logging']['log_image_params_2'], + self.result_log, labels=['train_loss']) + + if all_done or (epoch % model_save_interval) == 0: + self.logger.info("Saving trained_model") + checkpoint_dict = { 
+ 'epoch': epoch, + 'model_state_dict': self.model.state_dict(), + 'optimizer_state_dict': self.optimizer.state_dict(), + 'scheduler_state_dict': self.scheduler.state_dict(), + 'result_log': self.result_log.get_raw_data() + } + torch.save(checkpoint_dict, '{}/checkpoint-{}.pt'.format(self.result_folder, epoch)) + + if all_done or (epoch % img_save_interval) == 0: + image_prefix = '{}/img/checkpoint-{}'.format(self.result_folder, epoch) + util_save_log_image_with_label(image_prefix, self.trainer_params['logging']['log_image_params_1'], + self.result_log, labels=['train_score']) + util_save_log_image_with_label(image_prefix, self.trainer_params['logging']['log_image_params_2'], + self.result_log, labels=['train_loss']) + + if all_done: + self.logger.info(" *** Training Done *** ") + self.logger.info("Now, printing log array...") + util_print_log_array(self.logger, self.result_log) + + def _train_one_epoch(self, epoch): + + score_AM = AverageMeter() + loss_AM = AverageMeter() + + train_num_episode = self.trainer_params['train_episodes'] + episode = 0 + loop_cnt = 0 + while episode < train_num_episode: + + remaining = train_num_episode - episode + batch_size = min(self.trainer_params['train_batch_size'], remaining) + + avg_score, avg_loss = self._train_one_batch(batch_size) + score_AM.update(avg_score, batch_size) + loss_AM.update(avg_loss, batch_size) + + episode += batch_size + + # Log First 10 Batch, only at the first epoch + if epoch == self.start_epoch: + loop_cnt += 1 + if loop_cnt <= 10: + self.logger.info('Epoch {:3d}: Train {:3d}/{:3d}({:1.1f}%) Score: {:.4f}, Loss: {:.4f}' + .format(epoch, episode, train_num_episode, 100. * episode / train_num_episode, + score_AM.avg, loss_AM.avg)) + + # Log Once, for each epoch + self.logger.info('Epoch {:3d}: Train ({:3.0f}%) Score: {:.4f}, Loss: {:.4f}' + .format(epoch, 100. 
* episode / train_num_episode, + score_AM.avg, loss_AM.avg)) + + return score_AM.avg, loss_AM.avg + + def _train_one_batch(self, batch_size): + + # Prep + ############################################### + self.model.train() + self.env.load_problems(batch_size) + reset_state, _, _ = self.env.reset() + self.model.pre_forward(reset_state) + + prob_list = torch.zeros(size=(batch_size, self.env.pomo_size, 0)) + # shape: (batch, pomo, 0~) + + # POMO Rollout + ############################################### + state, reward, done = self.env.pre_step() + while not done: + selected, prob = self.model(state) + # shape: (batch, pomo) + state, reward, done = self.env.step(selected) + + prob_list = torch.cat((prob_list, prob[:, :, None]), dim=2) + + # Loss + ############################################### + advantage = reward - reward.float().mean(dim=1, keepdims=True) + # shape: (batch, pomo) + log_prob = prob_list.log().sum(dim=2) + # size = (batch, pomo) + loss = -advantage * log_prob # Minus Sign: To Increase REWARD + # shape: (batch, pomo) + loss_mean = loss.mean() + + # Score + ############################################### + max_pomo_reward, _ = reward.max(dim=1) # get best results from pomo + score_mean = -max_pomo_reward.float().mean() # negative sign to make positive value + + # Step & Return + ############################################### + self.model.zero_grad() + loss_mean.backward() + self.optimizer.step() + return score_mean.item(), loss_mean.item() \ No newline at end of file diff --git a/eval_atsp/train_glop.py b/eval_atsp/train_glop.py new file mode 100644 index 0000000..df423f4 --- /dev/null +++ b/eval_atsp/train_glop.py @@ -0,0 +1,169 @@ + +""" +The MIT License + +Copyright (c) 2021 MatNet + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + + + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. 
+""" + +########################################################################################## +# Machine Environment Config + +DEBUG_MODE = False +USE_CUDA = not DEBUG_MODE +CUDA_DEVICE_NUM = 0 + +########################################################################################## +# Path Config + +import os +import sys +os.environ["CUDA_VISIBLE_DEVICES"] = "3" + +os.chdir(os.path.dirname(os.path.abspath(__file__))) +sys.path.insert(0, "..") # for problem_def +sys.path.insert(0, "../..") # for utils + + +########################################################################################## +# import + +import logging + +from utils_atsp.utils import create_logger, copy_all_src +from ASHPPTrainer import ASHPPTrainer as Trainer + + +########################################################################################## +# parameters + +env_params = { + 'node_cnt': 50, + 'problem_gen_params': { + 'int_min': 0, + 'int_max': 1000*1000, + 'scaler': 1000*1000 + }, + 'pomo_size': 50 # same as node_cnt +} + +model_params = { + 'embedding_dim': 256, + 'sqrt_embedding_dim': 256**(1/2), + 'encoder_layer_num': 5, + 'qkv_dim': 16, + 'sqrt_qkv_dim': 16**(1/2), + 'head_num': 16, + 'logit_clipping': 10, + 'ff_hidden_dim': 512, + 'ms_hidden_dim': 16, + 'ms_layer1_init': (1/2)**(1/2), + 'ms_layer2_init': (1/16)**(1/2), + 'eval_type': 'argmax', + 'one_hot_seed_cnt': 50, # must be >= node_cnt +} + +optimizer_params = { + 'optimizer': { + 'lr': 3*1e-4, + 'weight_decay': 1e-6 + }, + 'scheduler': { + 'milestones': [2001, 2101], # if further training is needed + 'gamma': 0.1 + } +} + +trainer_params = { + 'use_cuda': USE_CUDA, + 'cuda_device_num': CUDA_DEVICE_NUM, + 'epochs': 1000, + 'train_episodes': 10*1000, + 'train_batch_size': 200, + 'logging': { + 'model_save_interval': 100, + 'img_save_interval': 200, + 'log_image_params_1': { + 'json_foldername': 'log_image_style', + 'filename': 'style.json' + }, + 'log_image_params_2': { + 'json_foldername': 'log_image_style', + 'filename': 'style_loss.json' + }, + }, + 'model_load': { + 'enable': True, # enable loading pre-trained model + 'path': './result/20240720_153647_matnet_train', # directory path of pre-trained model and log files saved. + 'epoch': 600, # epoch version of pre-trained model to laod. 
+ } +} + +logger_params = { + 'log_file': { + 'desc': 'matnet_train', + 'filename': 'log.txt' + } +} + + +########################################################################################## +# main + +def main(): + if DEBUG_MODE: + _set_debug_mode() + + create_logger(**logger_params) + _print_config() + + trainer = Trainer(env_params=env_params, + model_params=model_params, + optimizer_params=optimizer_params, + trainer_params=trainer_params) + + copy_all_src(trainer.result_folder) + + trainer.run() + + +def _set_debug_mode(): + + global trainer_params + trainer_params['epochs'] = 2 + trainer_params['train_episodes'] = 4 + trainer_params['train_batch_size'] = 2 + trainer_params['validate_episodes'] = 4 + trainer_params['validate_batch_size'] = 2 + + +def _print_config(): + logger = logging.getLogger('root') + logger.info('DEBUG_MODE: {}'.format(DEBUG_MODE)) + logger.info('USE_CUDA: {}, CUDA_DEVICE_NUM: {}'.format(USE_CUDA, CUDA_DEVICE_NUM)) + [logger.info(g_key + "{}".format(globals()[g_key])) for g_key in globals().keys() if g_key.endswith('params')] + + +########################################################################################## + +if __name__ == "__main__": + main() \ No newline at end of file From 08a6e67434ac2bb9eecac22fbb42e44abf44e37a Mon Sep 17 00:00:00 2001 From: henry-yeh Date: Sun, 21 Jul 2024 12:54:32 +0800 Subject: [PATCH 09/10] data generation for poor man --- eval_atsp/ATSProblemDef.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/eval_atsp/ATSProblemDef.py b/eval_atsp/ATSProblemDef.py index 9e3a75a..e4f7075 100644 --- a/eval_atsp/ATSProblemDef.py +++ b/eval_atsp/ATSProblemDef.py @@ -26,6 +26,7 @@ """ import torch +from tqdm import tqdm def get_random_problems(batch_size, node_cnt, problem_gen_params): @@ -107,11 +108,14 @@ def load_single_problem_from_file(filename, node_cnt, scaler): if not os.path.exists("../data/atsp"): os.mkdir("../data/atsp") - torch.manual_seed(1234) - + torch.manual_seed(1234) + dataset_size = 30 for scale in [150, 250, 1000]: - problems = get_random_problems(dataset_size, scale, problem_gen_params) - print("instances generated") + problems = [] + for inst_id in tqdm(range(dataset_size)): + problem = get_random_problems(1, scale, problem_gen_params) + problems.append(problem) + problems = torch.cat(problems, dim=0) torch.save(problems, "../data/atsp/ATSP{}.pt".format(scale)) print(f"created ../data/atsp/ATSP{scale}.pt") \ No newline at end of file From 38ddcb97c36a77d9d8191f88f0f324f8cca76e00 Mon Sep 17 00:00:00 2001 From: henry-yeh Date: Sun, 21 Jul 2024 12:55:25 +0800 Subject: [PATCH 10/10] bugfix --- eval_atsp/test_glop.py | 47 ++++++++++++++++++++++++++---------------- 1 file changed, 29 insertions(+), 18 deletions(-) diff --git a/eval_atsp/test_glop.py b/eval_atsp/test_glop.py index 3975a78..9fcc8ac 100644 --- a/eval_atsp/test_glop.py +++ b/eval_atsp/test_glop.py @@ -57,8 +57,12 @@ ##### GLOP parameters ##### N_REVISER = 50 # We only test on Reviser-50; using more revisers requires code modifications -N_REVISIONS = 3 -N_SAMPLES = 100 # for sampling decoding during revision +N_REVISIONS = 3 # number of revision iterations +N_SAMPLES = { + 150: 2000, + 250: 1000, + 1000: 500 + } # for sampling decoding during revision @@ -69,7 +73,7 @@ 'int_max': 1000*1000, 'scaler': 1000*1000 }, - 'pomo_size': N_SAMPLES + 'pomo_size': 500, } model_params = { @@ -85,7 +89,7 @@ 'ms_layer1_init': (1/2)**(1/2), 'ms_layer2_init': (1/16)**(1/2), 'eval_type': 'softmax', # note here, can be greedy - 
'one_hot_seed_cnt': 20, # must be >= node_cnt + 'one_hot_seed_cnt': N_REVISER, # must be >= node_cnt } tester_params = { @@ -121,7 +125,7 @@ def revision(tour, inst, tester): sub_tours = tour.reshape(-1, N_REVISER) # shape: (batch, revision_len) sub_insts = [inst[sub_tour][:, sub_tour] for sub_tour in sub_tours] - original_scores = torch.stack([inst[sub_tour[:-1], torch.roll(sub_tour, shifts=-1)[:-1]].sum() for sub_tour in sub_tours]) # note that original_scores are positive values + original_scores = torch.tensor([cal_len_shpp(sub_tour, inst) for sub_tour in sub_tours]) # note that original_scores are positive values # Scale the sub_insts to make the largest value 1 scale_coef = [sub_inst.max() for sub_inst in sub_insts] sub_insts = torch.stack(sub_insts) @@ -136,16 +140,15 @@ def revision(tour, inst, tester): # TODO: unmcomment to validate the subtours for i in range(len(sub_insts)): validate_subtour(solutions[i], sub_insts[i], revised_scores[i]) - - # Gather the subtours according to the solutions - revised_tours = sub_tours.gather(1, solutions) # Compare the original scores and the revised scores improved_scores = original_scores - revised_scores # subtours should be aranged in the same order as the original tours, if the improved_scores <= 0 - kept_subtour_idx = improved_scores <= 0 - revised_tours[kept_subtour_idx] = sub_tours[kept_subtour_idx] - + solutions[improved_scores <= 0] = torch.arange(sub_tours.shape[1]) + # Gather the subtours according to the solutions + revised_tours = sub_tours.gather(1, solutions) + # Flatten the revised_tours + revised_tours = revised_tours.reshape(-1) # shape: (batch * revision_len) i.e. (node_cnt,) return revised_tours def validate_subtour(subtour, dist, cost): @@ -156,7 +159,11 @@ def validate_subtour(subtour, dist, cost): for i in range(1, len(subtour) - 1): assert i in subtour -def calc_len(tour, dist): +def validate_tour(tour): + for i in range(1, len(tour) - 1): + assert i in tour + +def cal_len(tour, dist): cost = dist[tour, torch.roll(tour, -1, -1)].sum() return cost.item() @@ -187,16 +194,16 @@ def main(n): for inst in dataset: tour, cost = random_insertion_non_euclidean(inst, order) original_costs.append(cost) - tour = torch.tensor(tour.astype(np.int64)) - + for revision_iter in range(N_REVISIONS): tour = revision(tour, inst, tester) # Shift the tour to the right by N_SHIFTS tour = torch.roll(tour, shifts=N_SHIFTS, dims=-1) - # cost = calc_len(tour, inst) - # print(f"cost after revision {revision_iter}: {cost}") - cost = calc_len(tour, inst) + + # TODO: unmcomment to validate the solution + # validate_tour(tour) + cost = cal_len(tour, inst) revised_costs.append(cost) total_duration = time.time() - start @@ -207,4 +214,8 @@ def main(n): if __name__ == "__main__": - main(int(sys.argv[1])) + N = int(sys.argv[1]) + env_params['pomo_size'] = N_SAMPLES.get(N, 500) + + main(N) + \ No newline at end of file
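
Reviewer note on the reworked revision() in the final patch: the tour is split into disjoint subpaths of length N_REVISER, each subpath is re-solved as an open path (SHPP) with both endpoints fixed, and the re-solved ordering is kept only where it shortens the open-path length; because the endpoints never move, the edges joining consecutive subpaths are unaffected, so every subpath can be accepted or rejected independently. The sketch below is illustrative only and is not part of the patches: it swaps the MatNet reviser for a greedy nearest-neighbour stand-in, and the helper names (revise_once, greedy_shpp, shpp_len) are invented for the example.

    import torch

    def shpp_len(path, dist):
        # open-path (SHPP) cost: consecutive edges only, no closing edge
        return dist[path[:-1], path[1:]].sum()

    def revise_once(tour, dist, reviser_len, solve_shpp):
        # cut the tour into disjoint subpaths of length reviser_len
        sub_tours = tour.reshape(-1, reviser_len)
        revised = []
        for sub in sub_tours:
            sub_dist = dist[sub][:, sub]        # sub-instance distance matrix
            perm = solve_shpp(sub_dist)         # candidate reordering, endpoints fixed
            identity = torch.arange(len(sub))
            if shpp_len(perm, sub_dist) < shpp_len(identity, sub_dist):
                sub = sub[perm]                 # keep only improving subpaths
            revised.append(sub)
        return torch.cat(revised)

    def greedy_shpp(sub_dist):
        # stand-in reviser: keeps node 0 and node n-1 as endpoints,
        # greedily appends the nearest remaining interior node
        n = sub_dist.shape[0]
        remaining = set(range(1, n - 1))
        path = [0]
        while remaining:
            nxt = min(remaining, key=lambda j: sub_dist[path[-1], j].item())
            path.append(nxt)
            remaining.remove(nxt)
        path.append(n - 1)
        return torch.tensor(path)

    if __name__ == "__main__":
        torch.manual_seed(0)
        n, reviser_len = 20, 5
        dist = torch.rand(n, n)
        dist.fill_diagonal_(0.0)
        tour = torch.randperm(n)
        new_tour = revise_once(tour, dist, reviser_len, greedy_shpp)
        assert sorted(new_tour.tolist()) == list(range(n))   # still visits every node once

In the patch itself the reordering comes from the MatNet-based tester (sampling decoding, with pomo_size set from N_SAMPLES), and non-improving solutions are overwritten with the identity permutation before the gather; that is the role of solutions[improved_scores <= 0] = torch.arange(sub_tours.shape[1]).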