@@ -73,29 +73,35 @@ def _eval(dataset, model, plotFilename, args):
7373
7474def evaluator (datasets , model , folder , args ):
7575 losses = [np .inf ] * len (datasets ) # initialize with infinity
76+ dists = [np .inf ] * len (datasets ) # initialize with infinity
7677 def evaluate (onlyImproved = False ):
7778 totalLoss = totalDist = 0.0
79+ losses_dist = []
7880 for i , dataset in enumerate (datasets ):
7981 loss , dist , T = _eval (dataset , model , os .path .join (folder , 'pred-%d.png' % i ), args )
82+ losses_dist .append ((loss , losses [i ], dist , dists [i ]))
8083 isImproved = loss < losses [i ]
8184 if (not onlyImproved ) or isImproved :
82- print ('Test %d / %d | %.2f sec | Loss: %.5f (%.5f). Distance: %.5f' % (
83- i + 1 , len (datasets ), T , loss , losses [i ], dist
85+ print ('Test %d / %d | %.2f sec | Loss: %.5f (%.5f). Distance: %.5f (%.5f) ' % (
86+ i + 1 , len (datasets ), T , loss , losses [i ], dist , dists [ i ]
8487 ))
8588 if isImproved :
86- print ('Test %d / %d | Improved %.5f => %.5f' % (i + 1 , len (datasets ), losses [i ], loss ))
89+ print ('Test %d / %d | Improved %.5f => %.5f, Distance: %.5f => %.5f' % (
90+ i + 1 , len (datasets ), losses [i ], loss , dists [i ], dist
91+ ))
8792 model .save (folder , postfix = 'best-%d' % i ) # save the model separately
8893 losses [i ] = loss
8994 pass
9095
96+ dists [i ] = min (dist , dists [i ]) # track the best distance
9197 totalLoss += loss
9298 totalDist += dist
9399 continue
94100 if not onlyImproved :
95101 print ('Mean loss: %.5f | Mean distance: %.5f' % (
96102 totalLoss / len (datasets ), totalDist / len (datasets )
97103 ))
98- return totalLoss / len (datasets )
104+ return totalLoss / len (datasets ), losses_dist
99105 return evaluate
100106
101107def _modelTrainingLoop (model , dataset ):
@@ -247,15 +253,20 @@ def main(args):
247253 for nm in glob .glob (os .path .join (folder , 'test-main' , 'test-*/' ))
248254 ]
249255 eval = evaluator (evalDatasets , model , folder , args )
250- bestLoss = eval () # evaluate loaded model
256+ bestLoss , _ = eval () # evaluate loaded model
251257 bestEpoch = 0
252258 # wrapper for the evaluation function. It saves the model if it is better
253259 def evalWrapper (eval ):
254260 def f (epoch , onlyImproved = False ):
255261 nonlocal bestLoss , bestEpoch
256- newLoss = eval (onlyImproved = onlyImproved )
262+ newLoss , losses = eval (onlyImproved = onlyImproved )
257263 if newLoss < bestLoss :
258264 print ('Improved %.5f => %.5f' % (bestLoss , newLoss ))
265+ if onlyImproved : #details
266+ for i , (loss , bestLoss_ , dist , bestDist ) in enumerate (losses ):
267+ print ('Test %d | Loss: %.5f (%.5f). Distance: %.5f (%.5f)' % (i + 1 , loss , bestLoss_ , dist , bestDist ))
268+ continue
269+ print ('-' * 80 )
259270 bestLoss = newLoss
260271 bestEpoch = epoch
261272 model .save (folder , postfix = 'best' )
0 commit comments