11#!/usr/bin/env python
2- # -*- coding: utf-8 -*-.
3- # TODO: add the W&B integration
2+ # -*- coding: utf-8 -*-
43import argparse , os , sys
5- # add the root folder of the project to the path
4+ # Add the root folder of the project to the path
65ROOT_FOLDER = os .path .abspath (os .path .dirname (__file__ ) + '/../' )
76sys .path .append (ROOT_FOLDER )
87
1615from Core .CInpaintingTrainer import CInpaintingTrainer
1716import tqdm
1817import json
18+ import wandb
1919
2020def _eval (dataset , model ):
21- T = time .time ()
22- # evaluate the model on the val dataset
23- loss = []
24- for batchId in range (len (dataset )):
25- batch = dataset [batchId ]
26- loss_value = model .eval (batch )
27- loss .append (loss_value )
28- continue
29-
30- loss = np .mean (loss )
31- T = time .time () - T
32- return loss , T
21+ T = time .time ()
22+ # Evaluate the model on the validation dataset
23+ loss = []
24+ for batchId in range (len (dataset )):
25+ batch = dataset [batchId ]
26+ loss_value = model .eval (batch )
27+ loss .append (loss_value )
28+ loss = np .mean (loss )
29+ T = time .time () - T
30+ return loss , T
3331
3432def evaluator (datasets , model , folder , args ):
35- losses = [np .inf ] * len (datasets ) # initialize with infinity
36- def evaluate (onlyImproved = False ):
37- totalLoss = []
38- for i , dataset in enumerate (datasets ):
39- loss , T = _eval (dataset , model )
40- isImproved = loss < losses [i ]
41- if (not onlyImproved ) or isImproved :
42- dataset_id = ', ' .join ([str (x ) for x in dataset .parametersIDs ()])
43- print ('Test %d / %d (%s) | %.2f sec | Loss: %.5f (%.5f).' % (
44- i + 1 , len (datasets ), dataset_id , T , loss , losses [i ],
45- ))
46- if isImproved :
47- print ('Test %d / %d | Improved %.5f => %.5f,' % (
48- i + 1 , len (datasets ), losses [i ], loss ,
49- ))
50- model .save (folder , postfix = 'best-%d' % i ) # save the model separately
51- losses [i ] = loss
52- pass
33+ losses = [np .inf ] * len (datasets ) # Initialize with infinity
34+ def evaluate (onlyImproved = False , step = None ):
35+ totalLoss = []
36+ eval_metrics = {}
37+ for i , dataset in enumerate (datasets ):
38+ loss , T = _eval (dataset , model )
39+ dataset_id = ', ' .join ([str (x ) for x in dataset .parametersIDs ()])
40+ isImproved = loss < losses [i ]
41+ if (not onlyImproved ) or isImproved :
42+ print ('Test %d / %d (%s) | %.2f sec | Loss: %.5f (%.5f).' % (
43+ i + 1 , len (datasets ), dataset_id , T , loss , losses [i ],
44+ ))
45+ if isImproved :
46+ print ('Test %d / %d | Improved %.5f => %.5f,' % (
47+ i + 1 , len (datasets ), losses [i ], loss ,
48+ ))
49+ modelFolder = os .path .join (folder , f"model-{ dataset_id } " )
50+ os .makedirs (modelFolder , exist_ok = True )
51+ # keep only the best model across all runs
52+ # name format: {model id}-{loss:.5f}-*.*
53+ all_files = os .listdir (modelFolder )
54+ all_losses = [f .split ('-' )[1 ] for f in all_files ]
55+ all_losses = list (set (all_losses ))
56+ print (f"Found losses: { all_losses } " )
57+ for loss_file in all_losses :
58+ if float (loss_file ) > loss :
59+ # remove all files with this loss in folder
60+ to_remove = [os .path .join (modelFolder , f ) for f in all_files if loss_file in f ]
61+ for f in to_remove :
62+ os .remove (f )
5363
54- totalLoss .append (loss )
55- continue
56- if not onlyImproved :
57- print ('Mean loss: %.5f' % (np .mean (totalLoss ), ))
58- return np .mean (totalLoss )
59- return evaluate
64+ model .save (modelFolder , postfix = '%.5f' % loss )
65+ losses [i ] = loss
66+ totalLoss .append (loss )
67+ eval_metrics ['eval_loss_(%s)' % dataset_id ] = loss
68+ mean_loss = np .mean (totalLoss )
69+ if not onlyImproved :
70+ print ('Mean loss: %.5f' % mean_loss )
71+ # Log evaluation metrics to wandb
72+ if step is not None :
73+ wandb .log (eval_metrics , step = step )
74+ return mean_loss
75+ return evaluate
6076
6177def _modelTrainingLoop (model , dataset ):
62- def F (desc ):
63- history = defaultdict (list )
64- # use the tqdm progress bar
65- with tqdm .tqdm (total = len (dataset ), desc = desc ) as pbar :
66- dataset .on_epoch_start ()
67- for _ in range (len (dataset )):
68- sampled = dataset .sample ()
69- stats = model .fit (sampled )
70- history ['time' ].append (stats ['time' ])
71- for k in stats ['losses' ].keys ():
72- history [k ].append (stats ['losses' ][k ])
73- # add stats to the progress bar (mean of each history)
74- pbar .set_postfix ({k : '%.5f' % np .mean (v ) for k , v in history .items ()})
75- pbar .update (1 )
76- continue
77- dataset .on_epoch_end ()
78- return
79- return F
78+ def F (desc ):
79+ history = defaultdict (list )
80+ # Use the tqdm progress bar
81+ with tqdm .tqdm (total = len (dataset ), desc = desc ) as pbar :
82+ dataset .on_epoch_start ()
83+ for step in range (len (dataset )):
84+ sampled = dataset .sample ()
85+ stats = model .fit (sampled )
86+ history ['time' ].append (stats ['time' ])
87+ for k in stats ['losses' ].keys ():
88+ history [k ].append (stats ['losses' ][k ])
89+ # Add stats to the progress bar (mean of each history)
90+ pbar .set_postfix ({k : '%.5f' % np .mean (v ) for k , v in history .items ()})
91+ pbar .update (1 )
92+ dataset .on_epoch_end ()
93+ return {k : np .mean (v ) for k , v in history .items ()}
94+ return F
8095
8196def _trainer_from (args ):
82- if args .trainer == 'default' : return CInpaintingTrainer
83- raise Exception ('Unknown trainer: %s' % (args .trainer , ))
97+ if args .trainer == 'default' : return CInpaintingTrainer
98+ raise Exception ('Unknown trainer: %s' % (args .trainer , ))
8499
85100def main (args ):
86- timesteps = args .steps
87- folder = os .path .join (args .folder , 'Data' )
88-
89- stats = None
90- with open (os .path .join (folder , 'remote' , 'stats.json' ), 'r' ) as f :
91- stats = json .load (f )
101+ wandb .init (project = args .wandb_project , config = vars (args )) # Initialize wandb
102+ timesteps = args .steps
103+ folder = os .path .join (args .folder , 'Data' )
92104
93- trainer = _trainer_from (args )
94- trainDataset = CDatasetLoader (
95- os .path .join (folder , 'remote' ),
96- stats = stats ,
97- sampling = args .sampling ,
98- samplerArgs = dict (
99- batch_size = args .batch_size ,
100- minFrames = timesteps ,
101- maxT = 1.0 ,
102- defaults = dict (
103- timesteps = timesteps ,
104- stepsSampling = {'max frames' : 10 },
105- # no augmentations by default
106- pointsNoise = 0.01 , pointsDropout = 0.01 ,
107- eyesDropout = 0.1 , eyesAdditiveNoise = 0.01 , brightnessFactor = 1.5 , lightBlobFactor = 1.5 ,
108- targets = dict (keypoints = 3 , total = 10 ),
109- ),
110- keys = ['clean' ],
111- ),
112- sampler_class = CDataSamplerInpainting ,
113- test_folders = ['train.npz' ],
114- )
115- model = dict (timesteps = timesteps , stats = stats )
116- if args .model is not None :
117- model ['weights' ] = dict (folder = folder , postfix = args .model , embeddings = args .embeddings )
118- if args .modelId is not None :
119- model ['model' ] = args .modelId
105+ stats = None
106+ with open (os .path .join (folder , 'remote' , 'stats.json' ), 'r' ) as f :
107+ stats = json .load (f )
120108
121- model = trainer (** model )
122- # model._model.summary()
123-
124- evalDatasets = [
125- CTestInpaintingLoader (os .path .join (folderName , 'test-inpainting' ))
126- for folderName , _ in Utils .dataset_from_stats (stats , os .path .join (folder , 'remote' ))
127- if os .path .exists (os .path .join (folderName , 'test-inpainting' ))
128- ]
129- eval = evaluator (evalDatasets , model , folder , args )
130- bestLoss = eval () # evaluate loaded model
131- bestEpoch = 0
132- # wrapper for the evaluation function. It saves the model if it is better
133- def evalWrapper (eval ):
134- def f (epoch , onlyImproved = False ):
135- nonlocal bestLoss , bestEpoch
136- newLoss = eval (onlyImproved = onlyImproved )
137- if newLoss < bestLoss :
138- print ('Improved %.5f => %.5f' % (bestLoss , newLoss ))
139- if onlyImproved : #details
140- for i , (loss , bestLoss_ , dist , bestDist ) in enumerate (losses ):
141- print ('Test %d | Loss: %.5f (%.5f). Distance: %.5f (%.5f)' % (i + 1 , loss , bestLoss_ , dist , bestDist ))
142- continue
143- print ('-' * 80 )
144- bestLoss = newLoss
145- bestEpoch = epoch
146- model .save (folder , postfix = 'best' )
147- return
148- return f
149-
150- eval = evalWrapper (eval )
151- trainStep = _modelTrainingLoop (model , trainDataset )
152- for epoch in range (args .epochs ):
153- trainStep (
154- desc = 'Epoch %.*d / %d' % (len (str (args .epochs )), epoch , args .epochs ),
109+ trainer = _trainer_from (args )
110+ trainDataset = CDatasetLoader (
111+ os .path .join (folder , 'remote' ),
112+ stats = stats ,
113+ sampling = args .sampling ,
114+ samplerArgs = dict (
115+ batch_size = args .batch_size ,
116+ minFrames = timesteps ,
117+ maxT = 1.0 ,
118+ defaults = dict (
119+ timesteps = timesteps ,
120+ stepsSampling = {'max frames' : 10 },
121+ # No augmentations by default
122+ pointsNoise = 0.01 , pointsDropout = 0.01 ,
123+ eyesDropout = 0.1 , eyesAdditiveNoise = 0.01 , brightnessFactor = 1.5 , lightBlobFactor = 1.5 ,
124+ targets = dict (keypoints = 3 , total = 10 ),
125+ ),
126+ keys = ['clean' ],
127+ ),
128+ sampler_class = CDataSamplerInpainting ,
129+ test_folders = ['train.npz' ],
155130 )
156- model .save (folder , postfix = 'latest' )
157- eval (epoch )
131+ model = dict (timesteps = timesteps , stats = stats )
132+ if args .model is not None :
133+ model ['weights' ] = dict (folder = folder , postfix = args .model , embeddings = args .embeddings )
134+ if args .modelId is not None :
135+ model ['model' ] = args .modelId
136+
137+ model = trainer (** model )
138+
139+ evalDatasets = [
140+ CTestInpaintingLoader (os .path .join (folderName , 'test-inpainting' ))
141+ for folderName , _ in Utils .dataset_from_stats (stats , os .path .join (folder , 'remote' ))
142+ if os .path .exists (os .path .join (folderName , 'test-inpainting' ))
143+ ]
144+ eval_fn = evaluator (evalDatasets , model , folder , args )
145+ bestLoss = eval_fn () # Evaluate loaded model
146+ bestEpoch = 0
158147
159- print ('Passed %d epochs since the last improvement (best: %.5f)' % (epoch - bestEpoch , bestLoss ))
160- if args .patience <= (epoch - bestEpoch ):
161- print ('Early stopping' )
162- break
163- continue
164- return
148+ def evalWrapper (eval_fn ):
149+ def f (epoch , onlyImproved = False , step = None ):
150+ nonlocal bestLoss , bestEpoch
151+ newLoss = eval_fn (onlyImproved = onlyImproved , step = step )
152+ if newLoss < bestLoss :
153+ print ('Improved %.5f => %.5f' % (bestLoss , newLoss ))
154+ bestLoss = newLoss
155+ bestEpoch = epoch
156+ model .save (folder , postfix = '%.5f' % newLoss )
157+ return
158+ return f
159+
160+ eval_fn = evalWrapper (eval_fn )
161+ trainStep = _modelTrainingLoop (model , trainDataset )
162+ for epoch in range (args .epochs ):
163+ metrics = trainStep (
164+ desc = 'Epoch %.*d / %d' % (len (str (args .epochs )), epoch , args .epochs ),
165+ )
166+ wandb .log (metrics , step = epoch + 1 )
167+ model .save (folder , postfix = 'latest' )
168+ eval_fn (epoch , step = epoch + 1 )
169+ print ('Passed %d epochs since the last improvement (best: %.5f)' % (epoch - bestEpoch , bestLoss ))
170+ if args .patience <= (epoch - bestEpoch ):
171+ print ('Early stopping' )
172+ break
165173
166174if __name__ == '__main__' :
167- parser = argparse .ArgumentParser ()
168- parser .add_argument ('--epochs' , type = int , default = 1000 )
169- parser .add_argument ('--batch-size' , type = int , default = 64 )
170- parser .add_argument ('--patience' , type = int , default = 5 )
171- parser .add_argument ('--steps' , type = int , default = 5 )
172- parser .add_argument ('--model' , type = str )
173- parser .add_argument ('--embeddings' , default = False , action = 'store_true' )
174- parser .add_argument ('--folder' , type = str , default = ROOT_FOLDER )
175- parser .add_argument ('--modelId' , type = str )
176- parser .add_argument (
177- '--trainer' , type = str , default = 'default' ,
178- choices = ['default' ]
179- )
180- parser .add_argument (
181- '--sampling' , type = str , default = 'uniform' ,
182- choices = ['uniform' , 'as_is' ],
183- )
175+ parser = argparse .ArgumentParser ()
176+ parser .add_argument ('--epochs' , type = int , default = 1000 )
177+ parser .add_argument ('--batch-size' , type = int , default = 64 )
178+ parser .add_argument ('--patience' , type = int , default = 5 )
179+ parser .add_argument ('--steps' , type = int , default = 5 )
180+ parser .add_argument ('--model' , type = str )
181+ parser .add_argument ('--embeddings' , default = False , action = 'store_true' )
182+ parser .add_argument ('--folder' , type = str , default = ROOT_FOLDER )
183+ parser .add_argument ('--modelId' , type = str )
184+ parser .add_argument (
185+ '--trainer' , type = str , default = 'default' ,
186+ choices = ['default' ]
187+ )
188+ parser .add_argument (
189+ '--sampling' , type = str , default = 'uniform' ,
190+ choices = ['uniform' , 'as_is' ],
191+ )
192+ parser .add_argument ('--wandb-project' , type = str , default = 'alternative-input-reconstruction' )
184193
185- main (parser .parse_args ())
186- pass
194+ main (parser .parse_args ())
0 commit comments