1+ #!/usr/bin/env python
2+ # -*- coding: utf-8 -*-.
3+ '''
4+ This script loads the datasets one by one and checks how many unique samples each of them contains.
5+ '''
6+ import argparse , os , sys
7+ # add the root folder of the project to the path
8+ ROOT_FOLDER = os .path .abspath (os .path .dirname (__file__ ) + '/../' )
9+ sys .path .append (ROOT_FOLDER )
10+
11+ from Core .CDataSamplerInpainting import CDataSamplerInpainting
12+ from Core .CDataSampler import CDataSampler
13+ import Core .Utils as Utils
14+ import json
15+ from Core .CSamplesStorage import CSamplesStorage
16+
def samplesStream(params, filename, ID, batch_size, is_inpainting):
    '''
    Generator over the samples of a single dataset file.

    A sampler is built on top of a fresh CSamplesStorage (CDataSamplerInpainting
    when `is_inpainting` is true, CDataSampler otherwise), the dataset loaded
    from `filename` is added to it, and all sample ids are then requested in
    consecutive chunks of at most `batch_size` ids.

    params: sampler defaults; must contain the 'timesteps' key (used as minFrames)
    filename: dataset file passed to Utils.datasetFrom
    ID: (placeId, userId, screenId) triple forwarded to the storage
    batch_size: number of ids requested per sampleByIds call
    is_inpainting: selects the sampler class and the layout of the batch inputs

    Yields the position (0-based, local to the current batch) of every entry of
    'points' in each batch that was actually returned. Chunks for which the
    sampler returns None as the batch are skipped.
    '''
    place_id, user_id, screen_id = ID
    storage = CSamplesStorage(placeId=place_id, userId=user_id, screenId=screen_id)

    # arguments shared by both sampler flavours
    sampler_kwargs = dict(
        defaults=params,
        batch_size=batch_size,
        minFrames=params['timesteps'],
    )
    if is_inpainting:
        sampler = CDataSamplerInpainting(storage, **sampler_kwargs, keys=['clean'])
    else:
        sampler = CDataSampler(storage, **sampler_kwargs)
    sampler.addBlock(Utils.datasetFrom(filename))

    total = sampler.totalSamples
    for start in range(0, total, batch_size):
        stop = min(start + batch_size, total)
        chunk_ids = list(range(start, stop))
        batch, _rejected, _accepted = sampler.sampleByIds(chunk_ids)
        if batch is None:
            continue

        # main batch
        inputs, _targets = batch
        if not is_inpainting:
            # the regular sampler nests the inputs under the 'clean' key
            inputs = inputs['clean']
        for position in range(len(inputs['points'])):
            yield position
48+
def main(args):
    '''
    Count the valid samples of every dataset and rebuild the blacklist.

    Reads <args.folder>/Data/remote/stats.json, clears its 'blacklist' entry and
    walks over the datasets reported by Utils.dataset_from_stats. For each
    dataset that has a train.npz file, the samples produced by samplesStream
    are counted; datasets with `args.min_samples` samples or fewer get their ID
    appended to the blacklist. The updated stats are written back to the same
    stats.json file.

    args: parsed command line arguments providing `steps`, `folder`,
          `min_samples` and `inpainting`
    '''
    params = {
        'timesteps': args.steps,
        'stepsSampling': 'uniform',
        # no augmentations by default
        'pointsNoise': 0.0,
        'pointsDropout': 0.0,
        'eyesDropout': 0.0,
        'eyesAdditiveNoise': 0.0,
        'brightnessFactor': 1.0,
        'lightBlobFactor': 1.0,
        'targets': {'keypoints': 3, 'total': 10},
    }
    folder = os.path.join(args.folder, 'Data', 'remote')
    stats_path = os.path.join(folder, 'stats.json')

    with open(stats_path, 'r') as src:
        stats = json.load(src)

    # enable all disabled datasets
    stats['blacklist'] = []
    for dataset_dir, ID in Utils.dataset_from_stats(stats, folder):
        train_path = os.path.join(dataset_dir, 'train.npz')
        if not os.path.exists(train_path):
            continue
        print('Processing', train_path)

        # exhaust the stream; only the number of yielded items matters
        valid_count = sum(
            1 for _ in samplesStream(
                params, train_path,
                ID=ID, batch_size=64, is_inpainting=args.inpainting
            )
        )
        print(f'Dataset has {valid_count} valid samples')
        if args.min_samples < valid_count:
            continue
        print(f'Warning: dataset has less or equal to {args.min_samples} samples and will be disabled')
        stats['blacklist'].append(ID)

    with open(stats_path, 'w') as dst:
        json.dump(stats, dst, indent=2, sort_keys=True, default=str)
84+
if __name__ == '__main__':
    # Command line entry point: declare the options, parse them and run main().
    cli = argparse.ArgumentParser()
    # value forwarded to the samplers as 'timesteps'
    cli.add_argument('--steps', type=int, default=5)
    # project root; the datasets are looked up in <folder>/Data/remote
    cli.add_argument('--folder', type=str, default=ROOT_FOLDER)
    # datasets with this many samples or fewer are blacklisted
    cli.add_argument('--min-samples', type=int, default=0)
    # use the inpainting sampler instead of the regular one
    cli.add_argument('--inpainting', action='store_true', default=False)

    cli_args = cli.parse_args()
    main(cli_args)
0 commit comments