Skip to content

Commit 7900487

Browse files
wip
1 parent 5818241 commit 7900487

File tree

9 files changed

+255
-221
lines changed

9 files changed

+255
-221
lines changed

Core/CBaseModel.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
import os
2+
import numpy as np
3+
from tensorflow.keras import layers as L
4+
5+
class CBaseModel:
6+
def __init__(self, model, embeddings, submodels):
7+
self._model = model
8+
self._embeddings = {
9+
'userId': L.Embedding(embeddings['userId'], embeddings['size']),
10+
'placeId': L.Embedding(embeddings['placeId'], embeddings['size']),
11+
'screenId': L.Embedding(embeddings['screenId'], embeddings['size']),
12+
}
13+
self._submodels = submodels
14+
return
15+
16+
def replaceByEmbeddings(self, data):
17+
data = dict(**data) # copy
18+
for name, emb in self._embeddings.items():
19+
data[name] = emb(data[name][..., 0])
20+
continue
21+
return data
22+
23+
def _modelFilename(self, folder, postfix=''):
24+
postfix = '-' + postfix if postfix else ''
25+
return os.path.join(folder, '%s%s.h5' % (self._modelID, postfix))
26+
27+
def save(self, folder=None, postfix=''):
28+
path = self._modelFilename(folder, postfix)
29+
if 1 < len(self._submodels):
30+
for i, model in enumerate(self._submodels):
31+
model.save_weights(path.replace('.h5', '-%d.h5' % i))
32+
else:
33+
self._submodels[0].save_weights(path)
34+
35+
embeddings = {}
36+
for nm in self._embeddings.keys():
37+
weights = self._embeddings[nm].get_weights()[0]
38+
embeddings[nm] = weights
39+
40+
np.savez_compressed(path.replace('.h5', '-embeddings.npz'), **embeddings)
41+
42+
def load(self, folder=None, postfix='', embeddings=False):
43+
path = self._modelFilename(folder, postfix) if not os.path.isfile(folder) else folder
44+
if 1 < len(self._submodels):
45+
for i, model in enumerate(self._submodels):
46+
model.load_weights(path.replace('.h5', '-%d.h5' % i))
47+
else:
48+
self._submodels[0].load_weights(path)
49+
50+
if embeddings:
51+
embeddings = np.load(path.replace('.h5', '-embeddings.npz'))
52+
for nm, emb in self._embeddings.items():
53+
w = embeddings[nm]
54+
if not emb.built: emb.build((None, w.shape[0]))
55+
emb.set_weights([w]) # replace embeddings
56+
57+
def trainable_variables(self):
58+
parts = list(self._embeddings.values()) + self._submodels
59+
return sum([p.trainable_variables for p in parts], [])

Core/CDataSamplerInpainting.py

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from Core.Utils import FACE_MESH_POINTS
44

55
import numpy as np
6+
import tensorflow as tf
67

78
'''
89
This sampler are sample N frames from the dataset, where N is the number of timesteps.
@@ -24,8 +25,9 @@
2425
- The target point.
2526
'''
2627
class CDataSamplerInpainting(CBaseDataSampler):
27-
def __init__(self, storage, batch_size, minFrames, defaults={}, maxT=1.0, cumulative_time=True):
28+
def __init__(self, storage, batch_size, minFrames, keys, defaults={}, maxT=1.0, cumulative_time=True):
2829
super().__init__(storage, batch_size, minFrames, defaults, maxT, cumulative_time)
30+
self._keys = keys
2931

3032
def _stepsFor(self, mainInd, steps, stepsSampling='uniform', **_):
3133
if (steps is None) or (1 == steps): return [(mainInd, 0.0)]
@@ -46,7 +48,7 @@ def _stepsFor(self, mainInd, steps, stepsSampling='uniform', **_):
4648
def sample(self, **kwargs):
4749
kwargs = {**self._defaults, **kwargs}
4850
timesteps = kwargs.get('timesteps', None)
49-
N = kwargs.get('N', self._batchSize)
51+
N = kwargs.get('N', self._batchSize) // len(self._keys)
5052
indexes = []
5153
for _ in range(N):
5254
added = False
@@ -101,6 +103,7 @@ def sampleByIds(self, ids, **kwargs):
101103
def _indexes2XY(self, indexesAndTime, kwargs):
102104
timesteps = kwargs.get('timesteps', None)
103105
assert timesteps is not None, 'The number of timesteps must be defined.'
106+
B = len(indexesAndTime) // timesteps
104107
samples = [self._storage[i] for i, _ in indexesAndTime]
105108
##############
106109
userIds = np.unique([x['userId'] for x in samples])
@@ -130,7 +133,25 @@ def _indexes2XY(self, indexesAndTime, kwargs):
130133
),
131134
userIds[0], placeIds[0], screenIds[0]
132135
)
133-
X = X['clean']
136+
for k in X.keys():
137+
# add the target point to the X
138+
targets = np.array([x['goal'] for x in samples], np.float32).reshape((B, timesteps, 2))
139+
X[k]['target'] = tf.constant(targets, dtype=tf.float32)
140+
141+
if 1 == len(self._keys):
142+
X = X[self._keys[0]]
143+
else:
144+
res = {}
145+
k = self._keys[0]
146+
subkeys = list(X[k].keys())
147+
for k in subkeys:
148+
values = [X[key][k] for key in self._keys]
149+
res[k] = tf.concat(values, axis=0)
150+
continue
151+
X = res
152+
indexesAndTime = indexesAndTime * len(self._keys)
153+
B = len(self._keys) * B
154+
134155
###############
135156
# generate the target data
136157
targets = kwargs.get('targets', {'keypoints': timesteps, 'total': timesteps})
@@ -141,7 +162,7 @@ def _indexes2XY(self, indexesAndTime, kwargs):
141162

142163
samples_indexes = np.array([ i for i, _ in indexesAndTime], np.int32)
143164
samples_indexes = samples_indexes.reshape((-1, timesteps))
144-
B = samples_indexes.shape[0]
165+
assert samples_indexes.shape[0] == B, 'Invalid number of samples: %d != %d' % (samples_indexes.shape[0], B)
145166
targetsIdx = np.zeros((B, T), np.int32)
146167
for i in range(B):
147168
# sample K frames from the X
@@ -186,20 +207,20 @@ def _indexes2XY(self, indexesAndTime, kwargs):
186207
Y['right eye'][i, j] = data['right eye'][p:p+32, p:p+32]
187208
Y['time'][i, j] = (data['time'] - startT) / duration
188209
Y['target'][i, j] = data['goal']
189-
210+
# eyes in 0..255, so we need to normalize them
211+
Y['left eye'] /= 255.0
212+
Y['right eye'] /= 255.0
190213
# check that time is between 0 and 1
191214
assert np.all((0 <= Y['time']) & (Y['time'] <= 1)), 'Invalid time: ' + str(Y['time'])
192-
B = Y['points'].shape[0]
193215
for k, v in X.items():
194216
assert B == v.shape[0], f'Invalid batch size for X[{k}]: {v.shape[0]} != {B} ({v.shape})'
195217
for k, v in Y.items():
196218
assert B == v.shape[0], f'Invalid batch size for Y[{k}]: {v.shape[0]} != {B} ({v.shape})'
197219
return (X, Y)
198220

199221
def merge(self, samples, expected_batch_size):
200-
# each dictionary contains the subkeys: points, left eye, right eye, time, userId, placeId, screenId
201222
X = {}
202-
for subkey in ['points', 'left eye', 'right eye', 'time', 'userId', 'placeId', 'screenId']:
223+
for subkey in ['points', 'left eye', 'right eye', 'time', 'userId', 'placeId', 'screenId', 'target']:
203224
data = [x[subkey] for x, _ in samples]
204225
X[subkey] = np.concatenate(data, axis=0)
205226
assert X[subkey].shape[0] == expected_batch_size, 'Invalid batch size: %d != %d' % (X[subkey].shape[0], expected_batch_size)
@@ -211,4 +232,4 @@ def merge(self, samples, expected_batch_size):
211232
Y[subkey] = np.concatenate(data, axis=0)
212233
assert Y[subkey].shape[0] == expected_batch_size, 'Invalid batch size: %d != %d' % (Y[subkey].shape[0], expected_batch_size)
213234
continue
214-
return (X, (Y, ))
235+
return (X, Y)

Core/CInpaintingTrainer.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
import tensorflow as tf
2+
import time
3+
import NN.Utils as NNU
4+
import NN.networks as networks
5+
from Core.CBaseModel import CBaseModel
6+
7+
class CInpaintingTrainer:
8+
def __init__(self, timesteps, model='simple', KP=5, **kwargs):
9+
stats = kwargs.get('stats', None)
10+
embeddingsSize = kwargs.get('embeddingsSize', 64)
11+
latentSize = kwargs.get('latentSize', 64)
12+
embeddings = {
13+
'userId': len(stats['userId']),
14+
'placeId': len(stats['placeId']),
15+
'screenId': len(stats['screenId']),
16+
'size': embeddingsSize,
17+
}
18+
19+
self._encoder = networks.InpaintingEncoderModel(
20+
steps=timesteps, latentSize=latentSize,
21+
embeddingsSize=embeddingsSize,
22+
KP=KP,
23+
)
24+
self._decoder = networks.InpaintingDecoderModel(
25+
latentSize=latentSize,
26+
embeddingsSize=embeddingsSize,
27+
KP=KP,
28+
)
29+
self._model = CBaseModel(
30+
model=model, embeddings=embeddings, submodels=[self._encoder, self._decoder]
31+
)
32+
self.compile()
33+
# add signatures to help tensorflow optimize the graph
34+
specification = networks.InpaintingInputSpec()
35+
self._trainStep = tf.function(
36+
self._trainStep,
37+
input_signature=[specification]
38+
)
39+
self._eval = tf.function(
40+
self._eval,
41+
input_signature=[specification]
42+
)
43+
return
44+
45+
def compile(self):
46+
self._optimizer = NNU.createOptimizer()
47+
48+
def _trainStep(self, Data):
49+
print('Instantiate _trainStep')
50+
###############
51+
x, y = Data
52+
losses = {}
53+
with tf.GradientTape() as tape:
54+
x = self._model.replaceByEmbeddings(x)
55+
latents = self._encoder(x, training=True)['latent']
56+
decoderArgs = {
57+
'keyPoints': latents,
58+
'time': y['time'],
59+
'userId': x['userId'],
60+
'placeId': x['placeId'],
61+
'screenId': x['screenId'],
62+
}
63+
predictions = self._decoder(decoderArgs, training=True)
64+
losses = {}
65+
for k in predictions.keys():
66+
pred = predictions[k]
67+
gt = y[k]
68+
tf.assert_equal(tf.shape(pred), tf.shape(gt))
69+
loss = tf.losses.mse(gt, pred)
70+
losses[f"loss-{k}"] = tf.reduce_mean(loss)
71+
72+
# calculate total loss and final loss
73+
losses['loss'] = loss = sum(losses.values())
74+
75+
self._optimizer.minimize(loss, tape.watched_variables(), tape=tape)
76+
###############
77+
return losses
78+
79+
def fit(self, data):
80+
t = time.time()
81+
losses = self._trainStep(data)
82+
losses = {k: v.numpy() for k, v in losses.items()}
83+
return {'time': int((time.time() - t) * 1000), 'losses': losses}
84+
85+
def _eval(self, xy):
86+
print('Instantiate _eval')
87+
x, (y,) = xy
88+
x = self._replaceByEmbeddings(x)
89+
y = y[:, :, 0]
90+
predictions = self._model(x, training=False)
91+
points = predictions['result'][:, :, :]
92+
tf.assert_equal(tf.shape(points), tf.shape(y))
93+
94+
loss = self._pointLoss(y, points)
95+
tf.assert_equal(tf.shape(loss), tf.shape(y)[:2])
96+
_, dist = NNU.normVec(points - y)
97+
return loss, points, dist
98+
99+
def eval(self, data):
100+
loss, sampled, dist = self._eval(data)
101+
return loss.numpy(), sampled.numpy(), dist.numpy()

0 commit comments

Comments
 (0)