Skip to content

Commit d2478b5

Browse files
averaging model weights and other small changes
1 parent d5a42aa commit d2478b5

File tree

2 files changed

+63
-18
lines changed

2 files changed

+63
-18
lines changed

Core/CModelWrapper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def save(self, folder=None, postfix=''):
8383
return
8484

8585
def load(self, folder=None, postfix='', embeddings=False):
86-
path = self._modelFilename(folder, postfix)
86+
path = self._modelFilename(folder, postfix) if not os.path.isfile(folder) else folder
8787
self._model.load_weights(path)
8888
if embeddings:
8989
embeddings = np.load(path.replace('.h5', '-embeddings.npz'))

scripts/train.py

Lines changed: 62 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -71,18 +71,29 @@ def _eval(dataset, model, plotFilename, args):
7171
T = time.time() - T
7272
return loss, dist, T
7373

74-
def evaluate(datasets, model, folder, args):
75-
totalLoss = totalDist = 0.0
76-
for i, dataset in enumerate(datasets):
77-
loss, dist, T = _eval(dataset, model, os.path.join(folder, 'pred-%d.png' % i), args)
78-
print('Test %d / %d | %.2f sec | Loss: %.5f. Distance: %.5f' % (i + 1, len(datasets), T, loss, dist))
79-
totalLoss += loss
80-
totalDist += dist
81-
continue
82-
print('Mean loss: %.5f | Mean distance: %.5f' % (
83-
totalLoss / len(datasets), totalDist / len(datasets)
84-
))
85-
return totalLoss / len(datasets)
74+
def evaluator(datasets, model, folder, args):
  """Build a closure that evaluates `model` on `datasets` and checkpoints per-dataset bests.

  Tracks the best (lowest) loss seen so far for each dataset across calls; whenever
  a dataset improves, the model is saved with a `best-<i>` postfix so the best
  checkpoint per test set is kept separately.

  Args:
    datasets: list of test datasets, each consumable by `_eval`.
    model: model wrapper exposing `save(folder, postfix=...)`.
    folder: output folder for prediction plots and checkpoints.
    args: forwarded verbatim to `_eval`.

  Returns:
    A zero-argument callable that runs one full evaluation pass and returns the
    mean loss over all datasets (`inf` if `datasets` is empty).
  """
  losses = [np.inf] * len(datasets) # best loss per dataset, starts at infinity
  def evaluate():
    # Guard: an empty glob upstream yields no datasets; avoid ZeroDivisionError
    # and report "no improvement possible" via +inf.
    if not datasets:
      print('No evaluation datasets found, skipping evaluation.')
      return np.inf

    totalLoss = totalDist = 0.0
    for i, dataset in enumerate(datasets):
      loss, dist, T = _eval(dataset, model, os.path.join(folder, 'pred-%d.png' % i), args)
      print('Test %d / %d | %.2f sec | Loss: %.5f (%.5f). Distance: %.5f' % (
        i + 1, len(datasets), T, loss, losses[i], dist
      ))
      if loss < losses[i]:
        print('Improved %.5f => %.5f' % (losses[i], loss))
        model.save(folder, postfix='best-%d' % i) # save the model separately
        losses[i] = loss

      totalLoss += loss
      totalDist += dist

    print('Mean loss: %.5f | Mean distance: %.5f' % (
      totalLoss / len(datasets), totalDist / len(datasets)
    ))
    return totalLoss / len(datasets)
  return evaluate
8697

8798
def _modelTrainingLoop(model, dataset):
8899
def F(desc, sampleParams):
@@ -167,6 +178,25 @@ def _trainer_from(args):
167178
if args.trainer == 'default': return CModelTrainer
168179
raise Exception('Unknown trainer: %s' % (args.trainer, ))
169180

181+
def averageModels(folder, model, noiseStd=0.0):
  """Load every 'best' checkpoint in `folder`, average their weights into `model`.

  Iterates all `*.h5` files whose name contains 'best', accumulates their weight
  tensors, and sets `model` to the element-wise mean, optionally perturbed with
  Gaussian noise (std `noiseStd`) to escape a local minimum after a reset.

  Args:
    folder: directory searched (non-recursively) for '*best*.h5' checkpoints.
    model: model wrapper exposing `load(path, embeddings=...)` and `_model`
      with Keras-style `get_weights()` / `set_weights()`.
    noiseStd: standard deviation of additive Gaussian noise; 0.0 adds none.
  """
  totals = [np.zeros_like(w) for w in model._model.get_weights()]
  N = 0
  for nm in glob.glob(os.path.join(folder, '*.h5')):
    if 'best' not in nm: continue # only the best models
    model.load(nm, embeddings=True)
    # add this checkpoint's weights to the running totals
    for i, w in enumerate(model._model.get_weights()):
      totals[i] += w
    N += 1

  if N == 0:
    # Without this guard, x / 0 silently fills every tensor with NaN and
    # destroys the model's current weights via set_weights.
    print('averageModels: no "best" checkpoints found in %s; model left unchanged' % folder)
    return

  # average the weights (plus optional exploration noise)
  averaged = [(x / N) + np.random.normal(0.0, noiseStd, x.shape) for x in totals]
  model._model.set_weights(averaged)
  return
199+
170200
def main(args):
171201
timesteps = args.steps
172202
folder = os.path.join(args.folder, 'Data')
@@ -205,12 +235,15 @@ def main(args):
205235
model = trainer(**model)
206236
model._model.summary()
207237

238+
if args.average:
239+
averageModels(folder, model)
208240
# find folders with the name "/test-*/"
209241
evalDatasets = [
210242
CTestLoader(nm)
211243
for nm in glob.glob(os.path.join(folder, 'test-main', 'test-*/'))
212244
]
213-
bestLoss = evaluate(evalDatasets, model, folder, args)
245+
eval = evaluator(evalDatasets, model, folder, args)
246+
bestLoss = eval()
214247
bestEpoch = 0
215248
trainStep = _modelTrainingLoop(model, trainDataset)
216249
for epoch in range(args.epochs):
@@ -220,7 +253,7 @@ def main(args):
220253
)
221254
model.save(folder, postfix='latest')
222255

223-
testLoss = evaluate(evalDatasets, model, folder, args)
256+
testLoss = eval()
224257
if testLoss < bestLoss:
225258
print('Improved %.5f => %.5f' % (bestLoss, testLoss))
226259
bestLoss = testLoss
@@ -230,19 +263,31 @@ def main(args):
230263

231264
print('Passed %d epochs since the last improvement (best: %.5f)' % (epoch - bestEpoch, bestLoss))
232265
if args.patience <= (epoch - bestEpoch):
233-
print('Early stopping')
234-
break
266+
if 'stop' == args.on_patience:
267+
print('Early stopping')
268+
break
269+
if 'reset' == args.on_patience:
270+
print('Resetting the model to the average of the best models')
271+
# and add some noise
272+
averageModels(folder, model, noiseStd=0.01)
273+
bestEpoch = epoch
274+
continue
235275
continue
236276
return
237277

238278
if __name__ == '__main__':
239279
parser = argparse.ArgumentParser()
240280
parser.add_argument('--epochs', type=int, default=1000)
241281
parser.add_argument('--batch-size', type=int, default=64)
242-
parser.add_argument('--patience', type=int, default=15)
282+
parser.add_argument('--patience', type=int, default=5)
283+
parser.add_argument('--on-patience', type=str, default='stop', choices=['stop', 'reset'])
243284
parser.add_argument('--steps', type=int, default=5)
244285
parser.add_argument('--model', type=str)
245286
parser.add_argument('--embeddings', default=False, action='store_true')
287+
parser.add_argument(
288+
'--average', default=False, action='store_true',
289+
help='Load each model from the folder and average them weights'
290+
)
246291
parser.add_argument('--folder', type=str, default=ROOT_FOLDER)
247292
parser.add_argument('--modelId', type=str)
248293
parser.add_argument(

0 commit comments

Comments
 (0)