33from Core .Utils import FACE_MESH_POINTS
44
55import numpy as np
6+ import tensorflow as tf
67
78'''
89This sampler are sample N frames from the dataset, where N is the number of timesteps.
2425 - The target point.
2526'''
2627class CDataSamplerInpainting (CBaseDataSampler ):
27- def __init__ (self , storage , batch_size , minFrames , defaults = {}, maxT = 1.0 , cumulative_time = True ):
28+ def __init__ (self , storage , batch_size , minFrames , keys , defaults = {}, maxT = 1.0 , cumulative_time = True ):
2829 super ().__init__ (storage , batch_size , minFrames , defaults , maxT , cumulative_time )
30+ self ._keys = keys
2931
3032 def _stepsFor (self , mainInd , steps , stepsSampling = 'uniform' , ** _ ):
3133 if (steps is None ) or (1 == steps ): return [(mainInd , 0.0 )]
@@ -46,7 +48,7 @@ def _stepsFor(self, mainInd, steps, stepsSampling='uniform', **_):
4648 def sample (self , ** kwargs ):
4749 kwargs = {** self ._defaults , ** kwargs }
4850 timesteps = kwargs .get ('timesteps' , None )
49- N = kwargs .get ('N' , self ._batchSize )
51+ N = kwargs .get ('N' , self ._batchSize ) // len ( self . _keys )
5052 indexes = []
5153 for _ in range (N ):
5254 added = False
@@ -101,6 +103,7 @@ def sampleByIds(self, ids, **kwargs):
101103 def _indexes2XY (self , indexesAndTime , kwargs ):
102104 timesteps = kwargs .get ('timesteps' , None )
103105 assert timesteps is not None , 'The number of timesteps must be defined.'
106+ B = len (indexesAndTime ) // timesteps
104107 samples = [self ._storage [i ] for i , _ in indexesAndTime ]
105108 ##############
106109 userIds = np .unique ([x ['userId' ] for x in samples ])
@@ -130,7 +133,25 @@ def _indexes2XY(self, indexesAndTime, kwargs):
130133 ),
131134 userIds [0 ], placeIds [0 ], screenIds [0 ]
132135 )
133- X = X ['clean' ]
136+ for k in X .keys ():
137+ # add the target point to the X
138+ targets = np .array ([x ['goal' ] for x in samples ], np .float32 ).reshape ((B , timesteps , 2 ))
139+ X [k ]['target' ] = tf .constant (targets , dtype = tf .float32 )
140+
141+ if 1 == len (self ._keys ):
142+ X = X [self ._keys [0 ]]
143+ else :
144+ res = {}
145+ k = self ._keys [0 ]
146+ subkeys = list (X [k ].keys ())
147+ for k in subkeys :
148+ values = [X [key ][k ] for key in self ._keys ]
149+ res [k ] = tf .concat (values , axis = 0 )
150+ continue
151+ X = res
152+ indexesAndTime = indexesAndTime * len (self ._keys )
153+ B = len (self ._keys ) * B
154+
134155 ###############
135156 # generate the target data
136157 targets = kwargs .get ('targets' , {'keypoints' : timesteps , 'total' : timesteps })
@@ -141,7 +162,7 @@ def _indexes2XY(self, indexesAndTime, kwargs):
141162
142163 samples_indexes = np .array ([ i for i , _ in indexesAndTime ], np .int32 )
143164 samples_indexes = samples_indexes .reshape ((- 1 , timesteps ))
144- B = samples_indexes .shape [0 ]
165+ assert samples_indexes . shape [ 0 ] == B , 'Invalid number of samples: %d != %d' % ( samples_indexes .shape [0 ], B )
145166 targetsIdx = np .zeros ((B , T ), np .int32 )
146167 for i in range (B ):
147168 # sample K frames from the X
@@ -186,20 +207,20 @@ def _indexes2XY(self, indexesAndTime, kwargs):
186207 Y ['right eye' ][i , j ] = data ['right eye' ][p :p + 32 , p :p + 32 ]
187208 Y ['time' ][i , j ] = (data ['time' ] - startT ) / duration
188209 Y ['target' ][i , j ] = data ['goal' ]
189-
210+ # eyes in 0..255, so we need to normalize them
211+ Y ['left eye' ] /= 255.0
212+ Y ['right eye' ] /= 255.0
190213 # check that time is between 0 and 1
191214 assert np .all ((0 <= Y ['time' ]) & (Y ['time' ] <= 1 )), 'Invalid time: ' + str (Y ['time' ])
192- B = Y ['points' ].shape [0 ]
193215 for k , v in X .items ():
194216 assert B == v .shape [0 ], f'Invalid batch size for X[{ k } ]: { v .shape [0 ]} != { B } ({ v .shape } )'
195217 for k , v in Y .items ():
196218 assert B == v .shape [0 ], f'Invalid batch size for Y[{ k } ]: { v .shape [0 ]} != { B } ({ v .shape } )'
197219 return (X , Y )
198220
199221 def merge (self , samples , expected_batch_size ):
200- # each dictionary contains the subkeys: points, left eye, right eye, time, userId, placeId, screenId
201222 X = {}
202- for subkey in ['points' , 'left eye' , 'right eye' , 'time' , 'userId' , 'placeId' , 'screenId' ]:
223+ for subkey in ['points' , 'left eye' , 'right eye' , 'time' , 'userId' , 'placeId' , 'screenId' , 'target' ]:
203224 data = [x [subkey ] for x , _ in samples ]
204225 X [subkey ] = np .concatenate (data , axis = 0 )
205226 assert X [subkey ].shape [0 ] == expected_batch_size , 'Invalid batch size: %d != %d' % (X [subkey ].shape [0 ], expected_batch_size )
@@ -211,4 +232,4 @@ def merge(self, samples, expected_batch_size):
211232 Y [subkey ] = np .concatenate (data , axis = 0 )
212233 assert Y [subkey ].shape [0 ] == expected_batch_size , 'Invalid batch size: %d != %d' % (Y [subkey ].shape [0 ], expected_batch_size )
213234 continue
214- return (X , ( Y , ) )
235+ return (X , Y )
0 commit comments