
Commit e99e991

Can evaluate the inpainting model
1 parent 7900487 commit e99e991

File tree

8 files changed (+241, -209 lines)


Core/CDataSamplerInpainting.py

Lines changed: 4 additions & 3 deletions
@@ -50,6 +50,7 @@ def sample(self, **kwargs):
     timesteps = kwargs.get('timesteps', None)
     N = kwargs.get('N', self._batchSize) // len(self._keys)
     indexes = []
+    added = False
     for _ in range(N):
       added = False
       while not added:
@@ -62,7 +63,7 @@ def sample(self, **kwargs):
         indexes.extend(sampledSteps)
         added = True
         continue
-
+    if not added: return None, 0
     return self._indexes2XY(indexes, kwargs)

   def sampleById(self, idx, **kwargs):
@@ -97,7 +98,7 @@ def sampleByIds(self, ids, **kwargs):

     res = None
     if 0 < len(sampledSteps):
-      res = self._indexes2XY(sampledSteps, kwargs)
+      res, _ = self._indexes2XY(sampledSteps, kwargs)
     return res, rejected, accepted

   def _indexes2XY(self, indexesAndTime, kwargs):
@@ -216,7 +217,7 @@ def _indexes2XY(self, indexesAndTime, kwargs):
       assert B == v.shape[0], f'Invalid batch size for X[{k}]: {v.shape[0]} != {B} ({v.shape})'
     for k, v in Y.items():
       assert B == v.shape[0], f'Invalid batch size for Y[{k}]: {v.shape[0]} != {B} ({v.shape})'
-    return (X, Y)
+    return (X, Y), B

   def merge(self, samples, expected_batch_size):
     X = {}
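
This changes the sampler's return contract: `_indexes2XY` now reports the true batch size `B` alongside the tensors, and a draw where nothing could be sampled yields `(None, 0)` from `sample`. A minimal caller sketch under those assumptions (`sampler` stands in for an already configured `CDataSamplerInpainting`; the argument values are illustrative):

batch, realN = sampler.sample(N=16, timesteps=5)
if realN == 0:
  # fully rejected draw: batch is None, so the caller retries or skips
  pass
else:
  X, Y = batch  # dicts of tensors, each with realN rows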

Core/CDatasetLoader.py

Lines changed: 11 additions & 7 deletions
@@ -43,6 +43,8 @@ def __init__(self, folder, samplerArgs, sampling, stats, sampler_class):
       i: len(ds.validSamples())
       for i, ds in enumerate(self._datasets)
     }
+    # ignore datasets with no valid samples
+    validSamples = {k: v for k, v in validSamples.items() if 0 < v}
     dtype = np.uint8 if len(self._datasets) < 256 else np.uint32
     # create an array of dataset indices to sample from
     sampling = ESampling(sampling)
@@ -101,13 +103,15 @@ def sample(self, **kwargs):
     samples = []
     totalSamples = 0
     # find the datasets ids and the number of samples to take from each dataset
-    datasetIds, counts = self._getBatchStats(batchSize)
-    for datasetId, N in zip(datasetIds, counts):
-      dataset = self._datasets[datasetId]
-      sampled = dataset.sample(N=N, **kwargs)
-      samples.append(sampled)
-      totalSamples += N
-      continue
+    while totalSamples < batchSize:
+      datasetIds, counts = self._getBatchStats(batchSize - totalSamples)
+      for datasetId, N in zip(datasetIds, counts):
+        dataset = self._datasets[datasetId]
+        sampled, N = dataset.sample(N=N, **kwargs)
+        if 0 < N:
+          samples.append(sampled)
+          totalSamples += N
+        continue

     first_dataset = self._datasets[0]
     return first_dataset.merge(samples, batchSize)
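
Because a dataset may now return fewer samples than requested (or none at all), `sample` loops until the batch is actually filled, re-querying `_getBatchStats` for only the missing remainder. A self-contained sketch of that fill-until-full pattern, with a stubbed `draw` standing in for the real `dataset.sample`:

import random

def draw(n):
  # stand-in for dataset.sample(N=n, ...): may return fewer than n samples
  got = random.randint(0, n)
  return ['sample'] * got, got

def fillBatch(batchSize):
  samples, total = [], 0
  while total < batchSize:
    chunk, n = draw(batchSize - total)
    if 0 < n:
      samples.extend(chunk)
      total += n
  return samples

assert len(fillBatch(32)) == 32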

Core/CInpaintingTrainer.py

Lines changed: 30 additions & 34 deletions
@@ -45,32 +45,36 @@ def __init__(self, timesteps, model='simple', KP=5, **kwargs):
   def compile(self):
     self._optimizer = NNU.createOptimizer()

+  def _calcLoss(self, x, y, training):
+    losses = {}
+    x = self._model.replaceByEmbeddings(x)
+    latents = self._encoder(x, training=training)['latent']
+    decoderArgs = {
+      'keyPoints': latents,
+      'time': y['time'],
+      'userId': x['userId'],
+      'placeId': x['placeId'],
+      'screenId': x['screenId'],
+    }
+    predictions = self._decoder(decoderArgs, training=training)
+    losses = {}
+    for k in predictions.keys():
+      pred = predictions[k]
+      gt = y[k]
+      tf.assert_equal(tf.shape(pred), tf.shape(gt))
+      loss = tf.losses.mse(gt, pred)
+      losses[f"loss-{k}"] = tf.reduce_mean(loss)
+
+    # calculate total loss and final loss
+    losses['loss'] = sum(losses.values())
+    return losses, losses['loss']
+
   def _trainStep(self, Data):
     print('Instantiate _trainStep')
     ###############
     x, y = Data
-    losses = {}
     with tf.GradientTape() as tape:
-      x = self._model.replaceByEmbeddings(x)
-      latents = self._encoder(x, training=True)['latent']
-      decoderArgs = {
-        'keyPoints': latents,
-        'time': y['time'],
-        'userId': x['userId'],
-        'placeId': x['placeId'],
-        'screenId': x['screenId'],
-      }
-      predictions = self._decoder(decoderArgs, training=True)
-      losses = {}
-      for k in predictions.keys():
-        pred = predictions[k]
-        gt = y[k]
-        tf.assert_equal(tf.shape(pred), tf.shape(gt))
-        loss = tf.losses.mse(gt, pred)
-        losses[f"loss-{k}"] = tf.reduce_mean(loss)
-
-      # calculate total loss and final loss
-      losses['loss'] = loss = sum(losses.values())
+      losses, loss = self._calcLoss(x, y, training=True)

     self._optimizer.minimize(loss, tape.watched_variables(), tape=tape)
     ###############
@@ -84,18 +88,10 @@ def fit(self, data):

   def _eval(self, xy):
     print('Instantiate _eval')
-    x, (y,) = xy
-    x = self._replaceByEmbeddings(x)
-    y = y[:, :, 0]
-    predictions = self._model(x, training=False)
-    points = predictions['result'][:, :, :]
-    tf.assert_equal(tf.shape(points), tf.shape(y))
-
-    loss = self._pointLoss(y, points)
-    tf.assert_equal(tf.shape(loss), tf.shape(y)[:2])
-    _, dist = NNU.normVec(points - y)
-    return loss, points, dist
+    x, y = xy
+    losses, loss = self._calcLoss(x, y, training=False)
+    return loss

   def eval(self, data):
-    loss, sampled, dist = self._eval(data)
-    return loss.numpy(), sampled.numpy(), dist.numpy()
+    loss = self._eval(data)
+    return loss.numpy()
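
On the reduction semantics in `_calcLoss`: `tf.losses.mse` averages over the last axis only, so the `tf.reduce_mean` is what collapses each prediction key to a scalar, and the total loss is the plain sum of those scalars. A tiny self-contained check (the shapes are assumptions for illustration):

import tensorflow as tf

pred = tf.random.normal((4, 5, 2))  # (batch, timesteps, coords), assumed shape
gt = tf.random.normal((4, 5, 2))

perStep = tf.losses.mse(gt, pred)      # MSE over the last axis -> shape (4, 5)
scalar = tf.reduce_mean(perStep)       # one scalar per prediction key

losses = {'loss-keypoints': scalar}
losses['loss'] = sum(losses.values())  # total loss: sum of per-key means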

Core/CTestInpaintingLoader.py

Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
+import tensorflow as tf
+import numpy as np
+import os, glob
+from functools import lru_cache
+
+class CTestInpaintingLoader(tf.keras.utils.Sequence):
+  def __init__(self, testFolder):
+    self._batchesNpz = [
+      f for f in glob.glob(os.path.join(testFolder, 'test-*.npz'))
+    ]
+    self.on_epoch_end()
+    return
+
+  def on_epoch_end(self):
+    return
+
+  def __len__(self):
+    return len(self._batchesNpz)
+
+  def __getitem__(self, idx):
+    with np.load(self._batchesNpz[idx]) as res:
+      res = {k: v for k, v in res.items()}
+
+    X = {k.replace('X_', ''): v for k, v in res.items() if 'X_' in k}
+    Y = {k.replace('Y_', ''): v for k, v in res.items() if 'Y_' in k}
+    return (X, Y)
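
A quick usage sketch for the new loader (the folder path is an assumption, matching the default output of the dataset-generation script below):

import os
from Core.CTestInpaintingLoader import CTestInpaintingLoader

ds = CTestInpaintingLoader(os.path.join('Data', 'test-inpainting', 'test-0'))
print(len(ds), 'batches')
X, Y = ds[0]
for k, v in X.items(): print('X', k, v.shape)
for k, v in Y.items(): print('Y', k, v.shape)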

Core/CTestLoader.py

Lines changed: 1 addition & 25 deletions
@@ -5,25 +5,12 @@

 class CTestLoader(tf.keras.utils.Sequence):
   def __init__(self, testFolder):
-    self._folder = testFolder
     self._batchesNpz = [
       f for f in glob.glob(os.path.join(testFolder, 'test-*.npz'))
     ]
     self.on_epoch_end()
     return

-  @property
-  def folder(self):
-    return self._folder
-
-  @lru_cache(maxsize=1)
-  def parametersIDs(self):
-    batch, _ = self[0]
-    userId = batch['userId'][0, 0, 0]
-    placeId = batch['placeId'][0, 0, 0]
-    screenId = batch['screenId'][0, 0, 0]
-    return userId, placeId, screenId
-
   def on_epoch_end(self):
     return

@@ -35,15 +22,4 @@ def __getitem__(self, idx):
       res = {k: v for k, v in res.items()}

     Y = res.pop('y')
-    return (res, (Y, ))
-
-if __name__ == '__main__':
-  folder = os.path.dirname(__file__)
-  ds = CTestLoader(os.path.join(folder, 'test'))
-  print(len(ds))
-  batch, (y,) = ds[0]
-  for k, v in batch.items():
-    print(k, v.shape)
-  print()
-  print(y.shape)
-  pass
+    return (res, (Y, ))
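
With the inline `__main__` smoke test removed, an equivalent check can live in a scratch script instead; a sketch, assuming a folder of `test-*.npz` batches next to the module:

import os
from Core.CTestLoader import CTestLoader

ds = CTestLoader(os.path.join('Core', 'test'))  # path is an assumption
print(len(ds))
batch, (y,) = ds[0]
for k, v in batch.items():
  print(k, v.shape)
print(y.shape)
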
Lines changed: 148 additions & 0 deletions
@@ -0,0 +1,148 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-.
+import argparse, os, sys
+# add the root folder of the project to the path
+ROOT_FOLDER = os.path.abspath(os.path.dirname(__file__) + '/../')
+sys.path.append(ROOT_FOLDER)
+
+import numpy as np
+import Core.Utils as Utils
+from Core.CSamplesStorage import CSamplesStorage
+from Core.CDataSamplerInpainting import CDataSamplerInpainting
+from collections import defaultdict
+import glob
+import json
+import shutil
+import tensorflow as tf
+
+BATCH_SIZE = 128 * 4
+
+def samplesStream(params, take, filename, stats):
+  if not isinstance(take, list): take = [take]
+  # filename is "{placeId}/{userId}/{screenId}/train.npz"
+  # extract the placeId, userId, and screenId
+  parts = os.path.split(filename)[0].split(os.path.sep)
+  placeId, userId, screenId = parts[-3], parts[-2], parts[-1]
+  # use the stats to get the numeric values of the placeId, userId, and screenId
+  ds = CDataSamplerInpainting(
+    CSamplesStorage(
+      placeId=stats['placeId'].index(placeId),
+      userId=stats['userId'].index(userId),
+      screenId=stats['screenId'].index('%s/%s' % (placeId, screenId))
+    ),
+    defaults=params,
+    batch_size=BATCH_SIZE, minFrames=params['timesteps'],
+    keys=take
+  )
+  ds.addBlock(Utils.datasetFrom(filename))
+
+  N = ds.totalSamples
+  for i in range(0, N, BATCH_SIZE):
+    indices = list(range(i, min(i + BATCH_SIZE, N)))
+    batch, rejected, accepted = ds.sampleByIds(indices)
+    if batch is None: continue
+
+    # main batch
+    x, y = batch
+    for idx in range(len(x['points'])):
+      resX = {}
+      for k, v in x.items():
+        item = v[idx, None]
+        if tf.is_tensor(item): item = item.numpy()
+        resX[f'X_{k}'] = item
+        continue
+
+      resY = {}
+      for k, v in y.items():
+        item = v[idx, None]
+        if tf.is_tensor(item): item = item.numpy()
+        resY[f'Y_{k}'] = item
+        continue
+
+      yield dict(**resX, **resY)
+      continue
+    continue
+  return
+
+def batches(*params):
+  data = defaultdict(list)
+  for sample in samplesStream(*params):
+    for k, v in sample.items():
+      data[k].append(v)
+      continue
+
+    if BATCH_SIZE <= len(data['X_points']):
+      yield data
+      data = defaultdict(list)
+    continue
+
+  if 0 < len(data['X_points']):
+    # copy data to match batch size
+    for k, v in data.items():
+      while len(v) < BATCH_SIZE: v.extend(v)
+      data[k] = v[:BATCH_SIZE]
+      continue
+    yield data
+  return
+############################################
+def generateTestDataset(params, filename, stats, outputFolder):
+  # generate test dataset
+  ONE_MB = 1024 * 1024
+  totalSize = 0
+  if not os.path.exists(outputFolder):
+    os.makedirs(outputFolder, exist_ok=True)
+  for bIndex, batch in enumerate(batches(params, ['clean'], filename, stats)):
+    fname = os.path.join(outputFolder, 'test-%d.npz' % bIndex)
+    # concatenate all arrays
+    batch = {k: np.concatenate(v, axis=0) for k, v in batch.items()}
+    np.savez_compressed(fname, **batch)
+    # get fname size
+    size = os.path.getsize(fname)
+    totalSize += size
+    print('%d | Size: %.1f MB | Total: %.1f MB' % (bIndex + 1, size / ONE_MB, totalSize / ONE_MB))
+    continue
+  print('Done')
+  return
+
+def main(args):
+  PARAMS = [
+    dict(
+      timesteps=args.steps,
+      stepsSampling='uniform',
+      # no augmentations by default
+      pointsNoise=0.01, pointsDropout=0.0,
+      eyesDropout=0.1, eyesAdditiveNoise=0.01, brightnessFactor=1.5, lightBlobFactor=1.5,
+      targets=dict(keypoints=3, total=10),
+    ),
+  ]
+  folder = os.path.join(ROOT_FOLDER, 'Data', 'remote')
+
+  stats = None
+  with open(os.path.join(folder, 'stats.json'), 'r') as f:
+    stats = json.load(f)
+
+  # remove all content from the output folder
+  shutil.rmtree(args.output, ignore_errors=True)
+  # recursively find the train file
+  trainFilename = glob.glob(os.path.join(folder, '**', 'test.npz'), recursive=True)
+  print('Found test files:', len(trainFilename))
+  for idx, filename in enumerate(trainFilename):
+    print('Processing', filename)
+    for params in PARAMS:
+      targetFolder = os.path.join(args.output, 'test-%d' % idx)
+      generateTestDataset(params, filename, stats, outputFolder=targetFolder)
+      continue
+  return
+
+if __name__ == '__main__':
+  parser = argparse.ArgumentParser()
+  parser.add_argument('--steps', type=int, default=5, help='Number of timesteps')
+  parser.add_argument('--batch-size', type=int, default=512, help='Batch size of the test dataset')
+  parser.add_argument(
+    '--output', type=str, help='Output folder',
+    default=os.path.join(ROOT_FOLDER, 'Data', 'test-inpainting')
+  )
+  args = parser.parse_args()
+  BATCH_SIZE = args.batch_size  # TODO: fix this hack
+  main(args)
+  pass
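
The `X_`/`Y_` key prefixes are the contract between this script and `CTestInpaintingLoader`: every sample's input and target dicts are flattened into one npz with prefixed keys, and the loader splits them back apart. A minimal round-trip sketch (dummy arrays and shapes are assumptions):

import numpy as np

X = {'points': np.zeros((1, 5, 2))}
Y = {'keypoints': np.zeros((1, 3, 2))}

flat = {f'X_{k}': v for k, v in X.items()}
flat.update({f'Y_{k}': v for k, v in Y.items()})
np.savez_compressed('test-0.npz', **flat)

with np.load('test-0.npz') as res:
  res = {k: v for k, v in res.items()}
X2 = {k.replace('X_', ''): v for k, v in res.items() if 'X_' in k}
Y2 = {k.replace('Y_', ''): v for k, v in res.items() if 'Y_' in k}
assert X2['points'].shape == (1, 5, 2)
assert Y2['keypoints'].shape == (1, 3, 2)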
