@@ -239,31 +239,52 @@ def main(args):
239239 model = trainer (** model )
240240 model ._model .summary ()
241241
242- if args .average :
243- averageModels (folder , model )
244242 # find folders with the name "/test-*/"
245243 evalDatasets = [
246244 CTestLoader (nm )
247245 for nm in glob .glob (os .path .join (folder , 'test-main' , 'test-*/' ))
248246 ]
249247 eval = evaluator (evalDatasets , model , folder , args )
250- bestLoss = eval ()
248+ bestLoss = eval () # evaluate loaded model
251249 bestEpoch = 0
250+ # wrapper for the evaluation function. It saves the model if it is better
251+ def evalWrapper (eval ):
252+ def f (epoch , onlyImproved = False ):
253+ nonlocal bestLoss , bestEpoch
254+ newLoss = eval (onlyImproved = onlyImproved )
255+ if newLoss < bestLoss :
256+ print ('Improved %.5f => %.5f' % (bestLoss , newLoss ))
257+ bestLoss = newLoss
258+ bestEpoch = epoch
259+ model .save (folder , postfix = 'best' )
260+ return
261+ return f
262+
263+ eval = evalWrapper (eval )
264+
def performRandomSearch(epoch=0):
    """Evaluate the averaged model, then ``args.restarts`` noisy variants.

    Each candidate is produced by ``averageModels(folder, model, noiseStd=...)``
    and immediately scored with ``eval(epoch=epoch, onlyImproved=True)``.
    The first candidate uses ``noiseStd=0.0``; every restart uses
    ``noiseStd=args.noise``.  Nothing is returned; any bookkeeping of the best
    loss is left to the ``eval`` callable.

    NOTE(review): this function never assigns ``bestEpoch`` itself.  The
    inline restart loop it replaces set ``bestEpoch = epoch`` before the
    restarts -- confirm that dropping that reset is intended, otherwise the
    patience check may keep firing on every following epoch.
    """
    def averageAndEvaluate(noiseStd):
        # presumably averageModels loads the (optionally perturbed) averaged
        # weights into `model` -- its definition is not visible from here
        averageModels(folder, model, noiseStd=noiseStd)
        # re-evaluate the model with the new weights
        eval(epoch=epoch, onlyImproved=True)

    # evaluate the plain averaged model first (no noise)
    averageAndEvaluate(0.0)
    for _restart in range(args.restarts):
        # and add some noise
        averageAndEvaluate(args.noise)
276+
277+ if args .average :
278+ performRandomSearch ()
279+
252280 trainStep = _modelTrainingLoop (model , trainDataset )
253281 for epoch in range (args .epochs ):
254282 trainStep (
255283 desc = 'Epoch %.*d / %d' % (len (str (args .epochs )), epoch , args .epochs ),
256284 sampleParams = getSampleParams (epoch )
257285 )
258286 model .save (folder , postfix = 'latest' )
259-
260- testLoss = eval ()
261- if testLoss < bestLoss :
262- print ('Improved %.5f => %.5f' % (bestLoss , testLoss ))
263- bestLoss = testLoss
264- bestEpoch = epoch
265- model .save (folder , postfix = 'best' )
266- continue
287+ eval (epoch )
267288
268289 print ('Passed %d epochs since the last improvement (best: %.5f)' % (epoch - bestEpoch , bestLoss ))
269290 if args .patience <= (epoch - bestEpoch ):
@@ -272,19 +293,7 @@ def main(args):
272293 break
273294 if 'reset' == args .on_patience :
274295 print ('Resetting the model to the average of the best models' )
275- bestEpoch = epoch # reset the best epoch
276- for _ in range (args .restarts ):
277- # and add some noise
278- averageModels (folder , model , noiseStd = args .noise )
279- # re-evaluate the model with the new weights
280- testLoss = eval (onlyImproved = True )
281- if testLoss < bestLoss :
282- print ('Improved %.5f => %.5f' % (bestLoss , testLoss ))
283- bestLoss = testLoss
284- bestEpoch = epoch
285- model .save (folder , postfix = 'best' )
286- continue
287- continue
296+ performRandomSearch (epoch = epoch )
288297 continue
289298 return
290299
0 commit comments