
Commit 4c88b60

integrate w&b and other changes
1 parent 07a01d3 commit 4c88b60
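
The diff below wires Weights & Biases (wandb) logging into the training script: a run is initialized from the CLI arguments, per-epoch training statistics are logged, and per-dataset evaluation losses are reported. For reference, a minimal sketch of the wandb pattern being adopted, with a purely illustrative project name and metric (not taken from this repository):

    import wandb

    # Start a run and record the configuration in the W&B dashboard (placeholder values).
    wandb.init(project='example-project', config={'epochs': 3, 'batch_size': 64})

    for epoch in range(3):
      train_loss = 1.0 / (epoch + 1)  # placeholder metric
      wandb.log({'train_loss': train_loss}, step=epoch + 1)

    wandb.finish()  # flush and close the run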

File tree

1 file changed: +161 -153 lines


scripts/train-reconstruction.py

Lines changed: 161 additions & 153 deletions
@@ -1,8 +1,7 @@
 #!/usr/bin/env python
-# -*- coding: utf-8 -*-.
-# TODO: add the W&B integration
+# -*- coding: utf-8 -*-
 import argparse, os, sys
-# add the root folder of the project to the path
+# Add the root folder of the project to the path
 ROOT_FOLDER = os.path.abspath(os.path.dirname(__file__) + '/../')
 sys.path.append(ROOT_FOLDER)
@@ -16,171 +15,180 @@
 from Core.CInpaintingTrainer import CInpaintingTrainer
 import tqdm
 import json
+import wandb

 def _eval(dataset, model):
-  T = time.time()
-  # evaluate the model on the val dataset
-  loss = []
-  for batchId in range(len(dataset)):
-    batch = dataset[batchId]
-    loss_value = model.eval(batch)
-    loss.append(loss_value)
-    continue
-
-  loss = np.mean(loss)
-  T = time.time() - T
-  return loss, T
+  T = time.time()
+  # Evaluate the model on the validation dataset
+  loss = []
+  for batchId in range(len(dataset)):
+    batch = dataset[batchId]
+    loss_value = model.eval(batch)
+    loss.append(loss_value)
+  loss = np.mean(loss)
+  T = time.time() - T
+  return loss, T

 def evaluator(datasets, model, folder, args):
-  losses = [np.inf] * len(datasets) # initialize with infinity
-  def evaluate(onlyImproved=False):
-    totalLoss = []
-    for i, dataset in enumerate(datasets):
-      loss, T = _eval(dataset, model)
-      isImproved = loss < losses[i]
-      if (not onlyImproved) or isImproved:
-        dataset_id = ', '.join([str(x) for x in dataset.parametersIDs()])
-        print('Test %d / %d (%s) | %.2f sec | Loss: %.5f (%.5f).' % (
-          i + 1, len(datasets), dataset_id, T, loss, losses[i],
-        ))
-      if isImproved:
-        print('Test %d / %d | Improved %.5f => %.5f,' % (
-          i + 1, len(datasets), losses[i], loss,
-        ))
-        model.save(folder, postfix='best-%d' % i) # save the model separately
-        losses[i] = loss
-        pass
+  losses = [np.inf] * len(datasets) # Initialize with infinity
+  def evaluate(onlyImproved=False, step=None):
+    totalLoss = []
+    eval_metrics = {}
+    for i, dataset in enumerate(datasets):
+      loss, T = _eval(dataset, model)
+      dataset_id = ', '.join([str(x) for x in dataset.parametersIDs()])
+      isImproved = loss < losses[i]
+      if (not onlyImproved) or isImproved:
+        print('Test %d / %d (%s) | %.2f sec | Loss: %.5f (%.5f).' % (
+          i + 1, len(datasets), dataset_id, T, loss, losses[i],
+        ))
+      if isImproved:
+        print('Test %d / %d | Improved %.5f => %.5f,' % (
+          i + 1, len(datasets), losses[i], loss,
+        ))
+        modelFolder = os.path.join(folder, f"model-{dataset_id}")
+        os.makedirs(modelFolder, exist_ok=True)
+        # keep only the best model across all runs
+        # name format: {model id}-{loss:.5f}-*.*
+        all_files = os.listdir(modelFolder)
+        all_losses = [f.split('-')[1] for f in all_files]
+        all_losses = list(set(all_losses))
+        print(f"Found losses: {all_losses}")
+        for loss_file in all_losses:
+          if float(loss_file) > loss:
+            # remove all files with this loss in folder
+            to_remove = [os.path.join(modelFolder, f) for f in all_files if loss_file in f]
+            for f in to_remove:
+              os.remove(f)

-      totalLoss.append(loss)
-      continue
-    if not onlyImproved:
-      print('Mean loss: %.5f' % (np.mean(totalLoss), ))
-    return np.mean(totalLoss)
-  return evaluate
+        model.save(modelFolder, postfix='%.5f' % loss)
+        losses[i] = loss
+      totalLoss.append(loss)
+      eval_metrics['eval_loss_(%s)' % dataset_id] = loss
+    mean_loss = np.mean(totalLoss)
+    if not onlyImproved:
+      print('Mean loss: %.5f' % mean_loss)
+    # Log evaluation metrics to wandb
+    if step is not None:
+      wandb.log(eval_metrics, step=step)
+    return mean_loss
+  return evaluate

 def _modelTrainingLoop(model, dataset):
-  def F(desc):
-    history = defaultdict(list)
-    # use the tqdm progress bar
-    with tqdm.tqdm(total=len(dataset), desc=desc) as pbar:
-      dataset.on_epoch_start()
-      for _ in range(len(dataset)):
-        sampled = dataset.sample()
-        stats = model.fit(sampled)
-        history['time'].append(stats['time'])
-        for k in stats['losses'].keys():
-          history[k].append(stats['losses'][k])
-        # add stats to the progress bar (mean of each history)
-        pbar.set_postfix({k: '%.5f' % np.mean(v) for k, v in history.items()})
-        pbar.update(1)
-        continue
-      dataset.on_epoch_end()
-    return
-  return F
+  def F(desc):
+    history = defaultdict(list)
+    # Use the tqdm progress bar
+    with tqdm.tqdm(total=len(dataset), desc=desc) as pbar:
+      dataset.on_epoch_start()
+      for step in range(len(dataset)):
+        sampled = dataset.sample()
+        stats = model.fit(sampled)
+        history['time'].append(stats['time'])
+        for k in stats['losses'].keys():
+          history[k].append(stats['losses'][k])
+        # Add stats to the progress bar (mean of each history)
+        pbar.set_postfix({k: '%.5f' % np.mean(v) for k, v in history.items()})
+        pbar.update(1)
+      dataset.on_epoch_end()
+    return {k: np.mean(v) for k, v in history.items()}
+  return F

 def _trainer_from(args):
-  if args.trainer == 'default': return CInpaintingTrainer
-  raise Exception('Unknown trainer: %s' % (args.trainer, ))
+  if args.trainer == 'default': return CInpaintingTrainer
+  raise Exception('Unknown trainer: %s' % (args.trainer, ))

 def main(args):
-  timesteps = args.steps
-  folder = os.path.join(args.folder, 'Data')
-
-  stats = None
-  with open(os.path.join(folder, 'remote', 'stats.json'), 'r') as f:
-    stats = json.load(f)
+  wandb.init(project=args.wandb_project, config=vars(args)) # Initialize wandb
+  timesteps = args.steps
+  folder = os.path.join(args.folder, 'Data')

-  trainer = _trainer_from(args)
-  trainDataset = CDatasetLoader(
-    os.path.join(folder, 'remote'),
-    stats=stats,
-    sampling=args.sampling,
-    samplerArgs=dict(
-      batch_size=args.batch_size,
-      minFrames=timesteps,
-      maxT=1.0,
-      defaults=dict(
-        timesteps=timesteps,
-        stepsSampling={'max frames': 10},
-        # no augmentations by default
-        pointsNoise=0.01, pointsDropout=0.01,
-        eyesDropout=0.1, eyesAdditiveNoise=0.01, brightnessFactor=1.5, lightBlobFactor=1.5,
-        targets=dict(keypoints=3, total=10),
-      ),
-      keys=['clean'],
-    ),
-    sampler_class=CDataSamplerInpainting,
-    test_folders=['train.npz'],
-  )
-  model = dict(timesteps=timesteps, stats=stats)
-  if args.model is not None:
-    model['weights'] = dict(folder=folder, postfix=args.model, embeddings=args.embeddings)
-  if args.modelId is not None:
-    model['model'] = args.modelId
+  stats = None
+  with open(os.path.join(folder, 'remote', 'stats.json'), 'r') as f:
+    stats = json.load(f)

-  model = trainer(**model)
-  # model._model.summary()
-
-  evalDatasets = [
-    CTestInpaintingLoader(os.path.join(folderName, 'test-inpainting'))
-    for folderName, _ in Utils.dataset_from_stats(stats, os.path.join(folder, 'remote'))
-    if os.path.exists(os.path.join(folderName, 'test-inpainting'))
-  ]
-  eval = evaluator(evalDatasets, model, folder, args)
-  bestLoss = eval() # evaluate loaded model
-  bestEpoch = 0
-  # wrapper for the evaluation function. It saves the model if it is better
-  def evalWrapper(eval):
-    def f(epoch, onlyImproved=False):
-      nonlocal bestLoss, bestEpoch
-      newLoss = eval(onlyImproved=onlyImproved)
-      if newLoss < bestLoss:
-        print('Improved %.5f => %.5f' % (bestLoss, newLoss))
-        if onlyImproved: #details
-          for i, (loss, bestLoss_, dist, bestDist) in enumerate(losses):
-            print('Test %d | Loss: %.5f (%.5f). Distance: %.5f (%.5f)' % (i + 1, loss, bestLoss_, dist, bestDist))
-            continue
-          print('-' * 80)
-        bestLoss = newLoss
-        bestEpoch = epoch
-        model.save(folder, postfix='best')
-      return
-    return f
-
-  eval = evalWrapper(eval)
-  trainStep = _modelTrainingLoop(model, trainDataset)
-  for epoch in range(args.epochs):
-    trainStep(
-      desc='Epoch %.*d / %d' % (len(str(args.epochs)), epoch, args.epochs),
+  trainer = _trainer_from(args)
+  trainDataset = CDatasetLoader(
+    os.path.join(folder, 'remote'),
+    stats=stats,
+    sampling=args.sampling,
+    samplerArgs=dict(
+      batch_size=args.batch_size,
+      minFrames=timesteps,
+      maxT=1.0,
+      defaults=dict(
+        timesteps=timesteps,
+        stepsSampling={'max frames': 10},
+        # No augmentations by default
+        pointsNoise=0.01, pointsDropout=0.01,
+        eyesDropout=0.1, eyesAdditiveNoise=0.01, brightnessFactor=1.5, lightBlobFactor=1.5,
+        targets=dict(keypoints=3, total=10),
+      ),
+      keys=['clean'],
+    ),
+    sampler_class=CDataSamplerInpainting,
+    test_folders=['train.npz'],
   )
-    model.save(folder, postfix='latest')
-    eval(epoch)
+  model = dict(timesteps=timesteps, stats=stats)
+  if args.model is not None:
+    model['weights'] = dict(folder=folder, postfix=args.model, embeddings=args.embeddings)
+  if args.modelId is not None:
+    model['model'] = args.modelId
+
+  model = trainer(**model)
+
+  evalDatasets = [
+    CTestInpaintingLoader(os.path.join(folderName, 'test-inpainting'))
+    for folderName, _ in Utils.dataset_from_stats(stats, os.path.join(folder, 'remote'))
+    if os.path.exists(os.path.join(folderName, 'test-inpainting'))
+  ]
+  eval_fn = evaluator(evalDatasets, model, folder, args)
+  bestLoss = eval_fn() # Evaluate loaded model
+  bestEpoch = 0

-    print('Passed %d epochs since the last improvement (best: %.5f)' % (epoch - bestEpoch, bestLoss))
-    if args.patience <= (epoch - bestEpoch):
-      print('Early stopping')
-      break
-    continue
-  return
+  def evalWrapper(eval_fn):
+    def f(epoch, onlyImproved=False, step=None):
+      nonlocal bestLoss, bestEpoch
+      newLoss = eval_fn(onlyImproved=onlyImproved, step=step)
+      if newLoss < bestLoss:
+        print('Improved %.5f => %.5f' % (bestLoss, newLoss))
+        bestLoss = newLoss
+        bestEpoch = epoch
+        model.save(folder, postfix='%.5f' % newLoss)
+      return
+    return f
+
+  eval_fn = evalWrapper(eval_fn)
+  trainStep = _modelTrainingLoop(model, trainDataset)
+  for epoch in range(args.epochs):
+    metrics = trainStep(
+      desc='Epoch %.*d / %d' % (len(str(args.epochs)), epoch, args.epochs),
+    )
+    wandb.log(metrics, step=epoch + 1)
+    model.save(folder, postfix='latest')
+    eval_fn(epoch, step=epoch + 1)
+    print('Passed %d epochs since the last improvement (best: %.5f)' % (epoch - bestEpoch, bestLoss))
+    if args.patience <= (epoch - bestEpoch):
+      print('Early stopping')
+      break

 if __name__ == '__main__':
-  parser = argparse.ArgumentParser()
-  parser.add_argument('--epochs', type=int, default=1000)
-  parser.add_argument('--batch-size', type=int, default=64)
-  parser.add_argument('--patience', type=int, default=5)
-  parser.add_argument('--steps', type=int, default=5)
-  parser.add_argument('--model', type=str)
-  parser.add_argument('--embeddings', default=False, action='store_true')
-  parser.add_argument('--folder', type=str, default=ROOT_FOLDER)
-  parser.add_argument('--modelId', type=str)
-  parser.add_argument(
-    '--trainer', type=str, default='default',
-    choices=['default']
-  )
-  parser.add_argument(
-    '--sampling', type=str, default='uniform',
-    choices=['uniform', 'as_is'],
-  )
+  parser = argparse.ArgumentParser()
+  parser.add_argument('--epochs', type=int, default=1000)
+  parser.add_argument('--batch-size', type=int, default=64)
+  parser.add_argument('--patience', type=int, default=5)
+  parser.add_argument('--steps', type=int, default=5)
+  parser.add_argument('--model', type=str)
+  parser.add_argument('--embeddings', default=False, action='store_true')
+  parser.add_argument('--folder', type=str, default=ROOT_FOLDER)
+  parser.add_argument('--modelId', type=str)
+  parser.add_argument(
+    '--trainer', type=str, default='default',
+    choices=['default']
+  )
+  parser.add_argument(
+    '--sampling', type=str, default='uniform',
+    choices=['uniform', 'as_is'],
+  )
+  parser.add_argument('--wandb-project', type=str, default='alternative-input-reconstruction')

-  main(parser.parse_args())
-  pass
+  main(parser.parse_args())
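
Taken together, the script now opens a W&B run in main() via wandb.init(project=args.wandb_project, config=vars(args)), logs the per-epoch training statistics returned by the training loop, and logs per-dataset evaluation losses from the evaluator, while checkpoints are saved per test dataset with the loss value as a postfix. A hypothetical invocation (the flags come from the argparse definitions above; the values are examples only):

    python scripts/train-reconstruction.py --epochs 100 --batch-size 32 --wandb-project alternative-input-reconstruction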
