@@ -71,18 +71,29 @@ def _eval(dataset, model, plotFilename, args):
7171 T = time .time () - T
7272 return loss , dist , T
7373
74- def evaluate (datasets , model , folder , args ):
75- totalLoss = totalDist = 0.0
76- for i , dataset in enumerate (datasets ):
77- loss , dist , T = _eval (dataset , model , os .path .join (folder , 'pred-%d.png' % i ), args )
78- print ('Test %d / %d | %.2f sec | Loss: %.5f. Distance: %.5f' % (i + 1 , len (datasets ), T , loss , dist ))
79- totalLoss += loss
80- totalDist += dist
81- continue
82- print ('Mean loss: %.5f | Mean distance: %.5f' % (
83- totalLoss / len (datasets ), totalDist / len (datasets )
84- ))
85- return totalLoss / len (datasets )
74+ def evaluator (datasets , model , folder , args ):
75+ losses = [np .inf ] * len (datasets ) # initialize with infinity
76+ def evaluate ():
77+ totalLoss = totalDist = 0.0
78+ for i , dataset in enumerate (datasets ):
79+ loss , dist , T = _eval (dataset , model , os .path .join (folder , 'pred-%d.png' % i ), args )
80+ print ('Test %d / %d | %.2f sec | Loss: %.5f (%.5f). Distance: %.5f' % (
81+ i + 1 , len (datasets ), T , loss , losses [i ], dist
82+ ))
83+ if loss < losses [i ]:
84+ print ('Improved %.5f => %.5f' % (losses [i ], loss ))
85+ model .save (folder , postfix = 'best-%d' % i ) # save the model separately
86+ losses [i ] = loss
87+ pass
88+
89+ totalLoss += loss
90+ totalDist += dist
91+ continue
92+ print ('Mean loss: %.5f | Mean distance: %.5f' % (
93+ totalLoss / len (datasets ), totalDist / len (datasets )
94+ ))
95+ return totalLoss / len (datasets )
96+ return evaluate
8697
8798def _modelTrainingLoop (model , dataset ):
8899 def F (desc , sampleParams ):
@@ -167,6 +178,25 @@ def _trainer_from(args):
167178 if args .trainer == 'default' : return CModelTrainer
168179 raise Exception ('Unknown trainer: %s' % (args .trainer , ))
169180
181+ def averageModels (folder , model , noiseStd = 0.0 ):
182+ TV = [np .zeros_like (x ) for x in model ._model .get_weights ()]
183+ N = 0
184+ for nm in glob .glob (os .path .join (folder , '*.h5' )):
185+ if not ('best' in nm ): continue # only the best models
186+ model .load (nm , embeddings = True )
187+ # add the weights to the total
188+ weights = model ._model .get_weights ()
189+ for i in range (len (TV )):
190+ TV [i ] += weights [i ]
191+ continue
192+ N += 1
193+ continue
194+
195+ # average the weights
196+ TV = [(x / N ) + np .random .normal (0.0 , noiseStd , x .shape ) for x in TV ]
197+ model ._model .set_weights (TV )
198+ return
199+
170200def main (args ):
171201 timesteps = args .steps
172202 folder = os .path .join (args .folder , 'Data' )
@@ -205,12 +235,15 @@ def main(args):
205235 model = trainer (** model )
206236 model ._model .summary ()
207237
238+ if args .average :
239+ averageModels (folder , model )
208240 # find folders with the name "/test-*/"
209241 evalDatasets = [
210242 CTestLoader (nm )
211243 for nm in glob .glob (os .path .join (folder , 'test-main' , 'test-*/' ))
212244 ]
213- bestLoss = evaluate (evalDatasets , model , folder , args )
245+ eval = evaluator (evalDatasets , model , folder , args )
246+ bestLoss = eval ()
214247 bestEpoch = 0
215248 trainStep = _modelTrainingLoop (model , trainDataset )
216249 for epoch in range (args .epochs ):
@@ -220,7 +253,7 @@ def main(args):
220253 )
221254 model .save (folder , postfix = 'latest' )
222255
223- testLoss = evaluate ( evalDatasets , model , folder , args )
256+ testLoss = eval ( )
224257 if testLoss < bestLoss :
225258 print ('Improved %.5f => %.5f' % (bestLoss , testLoss ))
226259 bestLoss = testLoss
@@ -230,19 +263,31 @@ def main(args):
230263
231264 print ('Passed %d epochs since the last improvement (best: %.5f)' % (epoch - bestEpoch , bestLoss ))
232265 if args .patience <= (epoch - bestEpoch ):
233- print ('Early stopping' )
234- break
266+ if 'stop' == args .on_patience :
267+ print ('Early stopping' )
268+ break
269+ if 'reset' == args .on_patience :
270+ print ('Resetting the model to the average of the best models' )
271+ # and add some noise
272+ averageModels (folder , model , noiseStd = 0.01 )
273+ bestEpoch = epoch
274+ continue
235275 continue
236276 return
237277
238278if __name__ == '__main__' :
239279 parser = argparse .ArgumentParser ()
240280 parser .add_argument ('--epochs' , type = int , default = 1000 )
241281 parser .add_argument ('--batch-size' , type = int , default = 64 )
242- parser .add_argument ('--patience' , type = int , default = 15 )
282+ parser .add_argument ('--patience' , type = int , default = 5 )
283+ parser .add_argument ('--on-patience' , type = str , default = 'stop' , choices = ['stop' , 'reset' ])
243284 parser .add_argument ('--steps' , type = int , default = 5 )
244285 parser .add_argument ('--model' , type = str )
245286 parser .add_argument ('--embeddings' , default = False , action = 'store_true' )
287+ parser .add_argument (
288+ '--average' , default = False , action = 'store_true' ,
289+ help = 'Load each model from the folder and average them weights'
290+ )
246291 parser .add_argument ('--folder' , type = str , default = ROOT_FOLDER )
247292 parser .add_argument ('--modelId' , type = str )
248293 parser .add_argument (
0 commit comments