import argparse, os, sys
# Add the root folder of the project (the parent of this file's folder) to the import path.
# __file__ is made absolute first: when it has no directory part (e.g. a script started
# from its own folder on Python < 3.9) dirname() returns '' and the previous
# `'' + '/../'` resolved to the filesystem root instead of the project root.
ROOT_FOLDER = os.path.abspath(
  os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir)
)
sys.path.append(ROOT_FOLDER)
5+
16from Core .Utils import setupGPU
27setupGPU () # dirty hack to setup GPU memory limit on startup
38
712from NN .Utils import *
813from NN .EyeEncoder import eyeEncoder
914from NN .FaceMeshEncoder import FaceMeshEncoder
10- import numpy as np
15+ from NN . LagrangianInterpolation import lagrange_interpolation
1116
1217class CTimeEncoderLayer (tf .keras .layers .Layer ):
1318 def __init__ (self , ** kwargs ):
@@ -219,14 +224,157 @@ def Face2LatentModel(
219224 'inputs specification' : _InputSpec ()
220225 }
221226
222- if __name__ == '__main__' :
223- X = Face2LatentModel (steps = 5 , latentSize = 64 ,
224- embeddings = {
225- 'userId' : 1 , 'placeId' : 1 , 'screenId' : 1 , 'size' : 64
227+ ##########################
def InpaintingEncoderModel(latentSize, embeddings, steps=5, pointsN=478, eyeSize=32, KP=5):
  '''
  Build the encoder of the inpainting model: a fixed window of `steps` frames
  is compressed into `KP` latent key points.

  Arguments:
    latentSize: width passed to Face2StepModel and used as the MLP layer size.
    embeddings: dict; only embeddings['size'] is read (width of each id embedding).
    steps: number of frames in the input window.
    pointsN: number of 2D face points per frame.
    eyeSize: side of the square single-channel eye crops.
    KP: number of key points produced per sample.

  Returns a tf.keras.Model with
    inputs (dict): 'points' (steps, pointsN, 2), 'left eye' and 'right eye'
      (steps, eyeSize, eyeSize, 1), 'time' (steps, 1), 'target' (steps, 2),
      'userId', 'placeId', 'screenId' (steps, embeddings['size']) each;
    outputs (dict): 'latent' - one vector per key point, i.e. (B, KP, ...);
      the last dimension is presumably latentSize (depends on sMLP - TODO confirm).
  '''
  # model inputs (the batch axis is implicit)
  points = L.Input((steps, pointsN, 2))
  eyeL = L.Input((steps, eyeSize, eyeSize, 1))
  eyeR = L.Input((steps, eyeSize, eyeSize, 1))
  T = L.Input((steps, 1)) # accumulative time
  target = L.Input((steps, 2))
  embSize = embeddings['size']
  userIdEmb = L.Input((steps, embSize))
  placeIdEmb = L.Input((steps, embSize))
  screenIdEmb = L.Input((steps, embSize))

  # a single context vector per frame
  emb = L.Concatenate(-1)([userIdEmb, placeIdEmb, screenIdEmb])

  # per-frame features
  Face2Step = Face2StepModel(pointsN, eyeSize, latentSize, embeddingsSize=emb.shape[-1])
  perFrame = Face2Step({
    'embeddings': emb,
    'points': points,
    'left eye': eyeL,
    'right eye': eyeR,
  })

  # time since the previous frame; the first frame gets a zero delta
  deltaT = T[:, 1:] - T[:, :-1]
  zeroDelta = tf.zeros_like(deltaT[:, :1])
  deltaT = L.Concatenate(-2)([zeroDelta, deltaT])
  timeFeatures = L.Concatenate(-1)([T, deltaT])
  timeEncoder = CRolloutTimesteps(CCoordsEncodingLayer(32), name='Time')
  encodedT = timeEncoder(timeFeatures[..., None, :])[..., 0, :]

  # attach the encoded time and the encoded target position to every frame
  targetEncoder = CRolloutTimesteps(CCoordsEncodingLayer(32), name='Target')
  encodedTarget = targetEncoder(target[..., None, :])[..., 0, :]
  latent = L.Concatenate(-1)([perFrame['latent'], encodedT, encodedTarget])
  # merge all frames into a single vector per sample
  latent = L.Reshape((-1,))(latent)

  # compress: one layer of full width, two of half, three of a third
  # (never narrower than latentSize), then a final latentSize layer
  flatSize = latent.shape[-1]
  widths = [
    max(flatSize // divisor, latentSize)
    for divisor in range(1, 4)
    for _ in range(divisor)
  ]
  widths.append(latentSize)
  latent = sMLP(sizes=widths, activation='relu', name='Compress')(latent)

  # fractional positions of the key points, shape (1, KP) at this point
  keyT = tf.linspace(0.0, 1.0, KP)[None, :]

  def transformKeyT(x):
    # repeat the positions for every sample of the batch: (1, KP) -> (B, KP, 1)
    fractions, reference = x
    batchSize = tf.shape(reference)[0]
    perSample = tf.tile(fractions, (batchSize, 1))
    return perSample[..., None]
  keyT = L.Lambda(transformKeyT)([keyT, latent])
  # keyT shape: (B, KP, 1)
  lastT = T[:, -1, None] # time of the last frame, shape (B, 1, 1)
  keyT = L.Concatenate(-1)([keyT, lastT * keyT]) # fractional time and absolute time
  keyTimeEncoder = CRolloutTimesteps(CCoordsEncodingLayer(32), name='KeyTime')
  encodedKeyT = keyTimeEncoder(keyT[..., None, :])[..., 0, :]

  def combineKeys(x):
    # pair a copy of the compressed latent with each encoded key time
    vector, keyTimes = x
    repeated = tf.tile(vector[..., None, :], (1, KP, 1))
    return L.Concatenate(-1)([repeated, keyTimes])
  latent = L.Lambda(combineKeys)([latent, encodedKeyT])

  latent = sMLP(sizes=[latentSize] * 3, activation='relu', name='CombineKeys')(latent)

  return tf.keras.Model(
    inputs={
      'points': points,
      'left eye': eyeL,
      'right eye': eyeR,
      'time': T,
      'target': target,
      'userId': userIdEmb,
      'placeId': placeIdEmb,
      'screenId': screenIdEmb,
    },
    outputs={
      'latent': latent,
    }
  )
306+
def InpaintingDecoderModel(latentSize, embeddings, pointsN=478, eyeSize=32, KP=5):
  '''
  Build the decoder of the inpainting model: `KP` latent key points are
  interpolated at the requested time positions and every interpolated latent
  is decoded into a target, two eye images and the face points.

  Arguments:
    latentSize: size of each latent key point and of the MLP layers.
    embeddings: dict; only embeddings['size'] is read (width of each id embedding).
    pointsN: number of 2D face points to decode.
    eyeSize: side of the square eye images to decode.
    KP: number of latent key points, placed uniformly on [0, 1].

  Returns a tf.keras.Model with
    inputs (dict): 'keyPoints' (KP, latentSize), 'time' (N, 1) with any N,
      'userId', 'placeId', 'screenId' (embeddings['size'],) each;
    outputs (dict): 'target' - output of IntermediatePredictor(shift=0.5),
      'left eye' and 'right eye' (N, eyeSize, eyeSize) each,
      'face' (N, pointsN, 2).

  NOTE(review): the key points sit at linspace(0, 1, KP), so 'time' is presumably
  expected as a fraction in [0, 1] - confirm against lagrange_interpolation.
  '''
  latentKeyPoints = L.Input((KP, latentSize))
  T = L.Input((None, 1))
  # the shape must be a one-element tuple: `(size)` without a comma is just an int
  embShape = (embeddings['size'],)
  userIdEmb = L.Input(embShape)
  placeIdEmb = L.Input(embShape)
  screenIdEmb = L.Input(embShape)

  emb = L.Concatenate(-1)([userIdEmb, placeIdEmb, screenIdEmb])[..., None, :]
  # emb shape: (B, 1, 3 * embSize)
  def interpolateKeys(x):
    latents, T = x
    B = tf.shape(latents)[0]
    # positions of the key points, repeated for every sample: (B, KP)
    keyT = tf.linspace(0.0, 1.0, KP)[None, :]
    keyT = tf.tile(keyT, (B, 1))
    return lagrange_interpolation(x_values=keyT, y_values=latents, x_targets=T[..., 0])
  latents = L.Lambda(interpolateKeys, name='InterpolateKeys')([latentKeyPoints, T])
  # latents shape: (B, N, latentSize)
  def transformLatents(x):
    latents, emb = x
    N = tf.shape(latents)[1]
    emb = tf.tile(emb, (1, N, 1)) # (B, 1, 3 * embSize) -> (B, N, 3 * embSize)
    return L.Concatenate(-1)([latents, emb])
  latents = L.Lambda(transformLatents, name='CombineEmb')([latents, emb])
  # process the latents
  latents = sMLP(sizes=[latentSize] * 3, activation='relu', name='CombineEmb/MLP')(latents)
  # decode the latents to the face points (478, 2), two eyes (32, 32, 2) and the target (2)
  target = IntermediatePredictor(shift=0.5)(latents)
  # two eyes, stored in the last axis: (B, N, eyeSize, eyeSize, 2)
  eyesN = eyeSize * eyeSize
  eyes = sMLP(sizes=[eyesN] * 2, activation='relu')(latents)
  eyes = L.Dense(eyesN * 2)(eyes)
  eyes = L.Reshape((-1, eyeSize, eyeSize, 2))(eyes)
  # face points: (B, N, pointsN, 2)
  face = sMLP(sizes=[pointsN] * 2, activation='relu')(latents)
  face = L.Dense(pointsN * 2)(face)
  face = L.Reshape((-1, pointsN, 2))(face)

  model = tf.keras.Model(
    inputs={
      'keyPoints': latentKeyPoints,
      'time': T,
      'userId': userIdEmb,
      'placeId': placeIdEmb,
      'screenId': screenIdEmb,
    },
    outputs={
      'target': target,
      # select the eye by the LAST axis; `eyes[:, :, 0]` indexed the first image
      # axis and returned a (B, N, eyeSize, 2) slice of image rows instead of an eye
      'left eye': eyes[..., 0],
      'right eye': eyes[..., 1],
      'face': face,
    }
  )
  return model
360+
361+
if __name__ == '__main__':
  # Manual smoke check: build the decoder with demo sizes and print its layers.
  demoEmbeddings = {'size': 64}
  X = InpaintingDecoderModel(latentSize=256, embeddings=demoEmbeddings)
  X.summary(expand_nested=False)

  # Alternative checks, kept for reference:
  # X = InpaintingEncoderModel(latentSize=256, embeddings={
  #   'size': 64
  # })
  #
  # X = Face2LatentModel(steps=5, latentSize=64,
  #   embeddings={
  #     'userId': 1, 'placeId': 1, 'screenId': 1, 'size': 64
  #   }
  # )
  # X['main'].summary(expand_nested=True)
  # X['Face2Step'].summary(expand_nested=False)
  # X['Step2Latent'].summary(expand_nested=False)
  # print(X['main'].outputs)
0 commit comments