Skip to content

Commit 2bfcada

Browse files
wip
1 parent e99e991 commit 2bfcada

File tree

10 files changed

+217
-123
lines changed

10 files changed

+217
-123
lines changed

Core/CBaseModel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def replaceByEmbeddings(self, data):
2222

2323
def _modelFilename(self, folder, postfix=''):
2424
postfix = '-' + postfix if postfix else ''
25-
return os.path.join(folder, '%s%s.h5' % (self._modelID, postfix))
25+
return os.path.join(folder, '%s%s.h5' % (self._model, postfix))
2626

2727
def save(self, folder=None, postfix=''):
2828
path = self._modelFilename(folder, postfix)

Core/CDatasetLoader.py

Lines changed: 26 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,50 +1,47 @@
11
import Core.Utils as Utils
2-
import os, glob
2+
import os
33
from Core.CSamplesStorage import CSamplesStorage
4-
from Core.CDataSampler import CDataSampler
54
import numpy as np
6-
import tensorflow as tf
75
from enum import Enum
86

97
class ESampling(Enum):
108
AS_IS = 'as_is'
119
UNIFORM = 'uniform'
1210

1311
class CDatasetLoader:
14-
def __init__(self, folder, samplerArgs, sampling, stats, sampler_class):
15-
# recursively find all 'train.npz' files
16-
trainFiles = glob.glob(os.path.join(folder, '**', 'train.npz'), recursive=True)
17-
if 0 == len(trainFiles):
18-
raise Exception('No training dataset found in "%s"' % (folder, ))
19-
exit(1)
20-
21-
print('Found %d training datasets' % (len(trainFiles), ))
22-
12+
def __init__(self, folder, samplerArgs, sampling, stats, sampler_class, test_folders):
2313
self._datasets = []
24-
for trainFile in trainFiles:
25-
print('Loading %s' % (trainFile, ))
26-
# extract the placeId, userId, and screenId
27-
parts = os.path.split(trainFile)[0].split(os.path.sep)
28-
placeId, userId, screenId = parts[-3], parts[-2], parts[-1]
29-
ds = sampler_class(
30-
CSamplesStorage(
31-
placeId=stats['placeId'].index(placeId),
32-
userId=stats['userId'].index(userId),
33-
screenId=stats['screenId'].index('%s/%s' % (placeId, screenId))
34-
),
35-
**samplerArgs
36-
)
37-
ds.addBlock(Utils.datasetFrom(trainFile))
38-
self._datasets.append(ds)
39-
continue
14+
for datasetFolder, ID in Utils.dataset_from_stats(stats, folder):
15+
(place_id_index, user_id_index, screen_id_index) = ID
16+
for test_folder in test_folders:
17+
dataset = os.path.join(datasetFolder, test_folder)
18+
if not os.path.exists(dataset):
19+
continue
20+
print('Loading %s' % (dataset, ))
21+
print(f'ID: {ID}. Index: {1 + len(self._datasets)}')
22+
ds = sampler_class(
23+
CSamplesStorage(
24+
placeId=place_id_index,
25+
userId=user_id_index,
26+
screenId=screen_id_index,
27+
),
28+
**samplerArgs
29+
)
30+
ds.addBlock(Utils.datasetFrom(dataset))
31+
self._datasets.append(ds)
32+
33+
if 0 == len(self._datasets):
34+
raise Exception('No training dataset found in "%s"' % (folder, ))
4035

41-
print('Loaded %d datasets' % (len(self._datasets), ))
4236
validSamples = {
4337
i: len(ds.validSamples())
4438
for i, ds in enumerate(self._datasets)
4539
}
4640
# ignore datasets with no valid samples
4741
validSamples = {k: v for k, v in validSamples.items() if 0 < v}
42+
43+
print('Loaded %d datasets with %d valid samples' % (len(self._datasets), sum(validSamples.values())))
44+
4845
dtype = np.uint8 if len(self._datasets) < 256 else np.uint32
4946
# create an array of dataset indices to sample from
5047
sampling = ESampling(sampling)

Core/CInpaintingTrainer.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@ def __init__(self, timesteps, model='simple', KP=5, **kwargs):
4040
self._eval,
4141
input_signature=[specification]
4242
)
43+
44+
if 'weights' in kwargs:
45+
self.load(**kwargs['weights'])
4346
return
4447

4548
def compile(self):
@@ -94,4 +97,10 @@ def _eval(self, xy):
9497

9598
def eval(self, data):
9699
loss = self._eval(data)
97-
return loss.numpy()
100+
return loss.numpy()
101+
102+
def save(self, folder=None, postfix=''):
103+
self._model.save(folder=folder, postfix=postfix)
104+
105+
def load(self, folder=None, postfix='', embeddings=False):
106+
self._model.load(folder=folder, postfix=postfix, embeddings=embeddings)

Core/CTestInpaintingLoader.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,14 @@ def __init__(self, testFolder):
1111
self.on_epoch_end()
1212
return
1313

14+
@lru_cache(maxsize=1)
15+
def parametersIDs(self):
16+
batch, _ = self[0]
17+
userId = batch['userId'][0, 0, 0]
18+
placeId = batch['placeId'][0, 0, 0]
19+
screenId = batch['screenId'][0, 0, 0]
20+
return placeId, userId, screenId
21+
1422
def on_epoch_end(self):
1523
return
1624

Core/CTestLoader.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,15 @@ def __init__(self, testFolder):
1010
]
1111
self.on_epoch_end()
1212
return
13-
13+
14+
@lru_cache(maxsize=1)
15+
def parametersIDs(self):
16+
batch, _ = self[0]
17+
userId = batch['userId'][0, 0, 0]
18+
placeId = batch['placeId'][0, 0, 0]
19+
screenId = batch['screenId'][0, 0, 0]
20+
return placeId, userId, screenId
21+
1422
def on_epoch_end(self):
1523
return
1624

Core/Utils.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -297,4 +297,25 @@ def countSamplesIn(folder):
297297
with np.load(fn) as data:
298298
res += len(data['time'])
299299
continue
300-
return res
300+
return res
301+
302+
def dataset_from_stats(stats, folder):
303+
userId = stats['userId']
304+
placeId = stats['placeId']
305+
screenId = stats['screenId']
306+
# screenId is a concatenation of placeId and screenId, to make the pair unique
307+
PlaceAndScreenId = [x.split('/') for x in screenId]
308+
309+
blackList = set(stats.get('blacklist', []))
310+
known = set([tuple(x) for x in blackList])
311+
for screen_id_index, (place_id, screen_id) in enumerate(PlaceAndScreenId):
312+
place_id_index = placeId.index(place_id)
313+
# find user_id among all
314+
for user_id_index, user_id in enumerate(userId):
315+
datasetFolder = os.path.join(folder, place_id, user_id, screen_id)
316+
if not os.path.exists(datasetFolder): continue
317+
ID = (place_id_index, user_id_index, screen_id_index)
318+
if ID in known: continue
319+
known.add(ID)
320+
321+
yield (datasetFolder, ID)

NN/networks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -351,7 +351,7 @@ def transformLatents(x):
351351
# two eyes
352352
eyesN = eyeSize * eyeSize
353353
eyes = sMLP(sizes=[eyesN] * 2, activation='relu')(latents)
354-
eyes = L.Dense(eyesN * 2)(eyes)
354+
eyes = L.Dense(eyesN * 2, 'sigmoid')(eyes)
355355
eyes = L.Reshape((-1, eyeSize, eyeSize, 2))(eyes)
356356
# face points
357357
face = sMLP(sizes=[pointsN] * 2, activation='relu')(latents)

scripts/check-dataset.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
'''
4+
This script loads the datasets one by one and checks how many valid samples each contains
5+
'''
6+
import argparse, os, sys
7+
# add the root folder of the project to the path
8+
ROOT_FOLDER = os.path.abspath(os.path.dirname(__file__) + '/../')
9+
sys.path.append(ROOT_FOLDER)
10+
11+
from Core.CDataSamplerInpainting import CDataSamplerInpainting
12+
from Core.CDataSampler import CDataSampler
13+
import Core.Utils as Utils
14+
import json
15+
from Core.CSamplesStorage import CSamplesStorage
16+
17+
def samplesStream(params, filename, ID, batch_size, is_inpainting):
18+
placeId, userId, screenId = ID
19+
storage = CSamplesStorage(placeId=placeId, userId=userId, screenId=screenId)
20+
if is_inpainting:
21+
ds = CDataSamplerInpainting(
22+
storage,
23+
defaults=params,
24+
batch_size=batch_size, minFrames=params['timesteps'],
25+
keys=['clean']
26+
)
27+
else:
28+
ds = CDataSampler(
29+
storage,
30+
defaults=params,
31+
batch_size=batch_size, minFrames=params['timesteps'],
32+
)
33+
ds.addBlock(Utils.datasetFrom(filename))
34+
35+
N = ds.totalSamples
36+
for i in range(0, N, batch_size):
37+
indices = list(range(i, min(i + batch_size, N)))
38+
batch, rejected, accepted = ds.sampleByIds(indices)
39+
if batch is None: continue
40+
41+
# main batch
42+
x, y = batch
43+
if not is_inpainting:
44+
x = x['clean']
45+
for idx in range(len(x['points'])):
46+
yield idx
47+
return
48+
49+
def main(args):
50+
params = dict(
51+
timesteps=args.steps,
52+
stepsSampling='uniform',
53+
# no augmentations by default
54+
pointsNoise=0.0, pointsDropout=0.0,
55+
eyesDropout=0.0, eyesAdditiveNoise=0.0, brightnessFactor=1.0, lightBlobFactor=1.0,
56+
targets=dict(keypoints=3, total=10),
57+
)
58+
folder = os.path.join(args.folder, 'Data', 'remote')
59+
60+
stats = None
61+
with open(os.path.join(folder, 'stats.json'), 'r') as f:
62+
stats = json.load(f)
63+
64+
# enable all disabled datasets
65+
stats['blacklist'] = []
66+
for datasetFolder, ID in Utils.dataset_from_stats(stats, folder):
67+
trainFile = os.path.join(datasetFolder, 'train.npz')
68+
if not os.path.exists(trainFile):
69+
continue
70+
print('Processing', trainFile)
71+
72+
stream = samplesStream(params, trainFile, ID=ID, batch_size=64, is_inpainting=args.inpainting)
73+
samplesN = 0
74+
for _ in stream:
75+
samplesN += 1
76+
continue
77+
print(f'Dataset has {samplesN} valid samples')
78+
if samplesN <= args.min_samples:
79+
print(f'Warning: dataset has less or equal to {args.min_samples} samples and will be disabled')
80+
stats['blacklist'].append(ID)
81+
82+
with open(os.path.join(folder, 'stats.json'), 'w') as f:
83+
json.dump(stats, f, indent=2, sort_keys=True, default=str)
84+
85+
if __name__ == '__main__':
86+
parser = argparse.ArgumentParser()
87+
parser.add_argument('--steps', type=int, default=5)
88+
parser.add_argument('--folder', type=str, default=ROOT_FOLDER)
89+
parser.add_argument('--min-samples', type=int, default=0)
90+
parser.add_argument('--inpainting', action='store_true', default=False)
91+
main(parser.parse_args())
92+
pass

0 commit comments

Comments
 (0)