
Commit 5562fbe

basic model for frame inpainting
1 parent 653b162 commit 5562fbe

2 files changed: +222 additions, -10 deletions


NN/LagrangianInterpolation.py

Lines changed: 64 additions & 0 deletions
@@ -0,0 +1,64 @@
+import tensorflow as tf
+
+def lagrange_interpolation(x_values, y_values, x_targets):
+  """
+  Perform Lagrange polynomial interpolation using TensorFlow, with batch support
+  and multidimensional y_values.
+
+  Parameters:
+  - x_values: Tensor of shape (batch_size, n), original x-values for each batch.
+  - y_values: Tensor of shape (batch_size, n, d), original y-values for each batch.
+  - x_targets: Tensor of shape (batch_size, m), x-values where the interpolation is evaluated.
+
+  Returns:
+  - interpolated_values: Tensor of shape (batch_size, m, d), interpolated y-values for each batch.
+  """
+  # keepdims=True keeps the bounds at shape (batch_size, 1) so they broadcast against x_targets (batch_size, m)
+  minX = tf.reduce_min(x_values, axis=1, keepdims=True)
+  maxX = tf.reduce_max(x_values, axis=1, keepdims=True)
+  # Check that x_targets lie within the range of x_values
+  tf.debugging.assert_greater_equal(x_targets, minX, message="x_targets out of range")
+  tf.debugging.assert_less_equal(x_targets, maxX, message="x_targets out of range")
+
+  batch_size = tf.shape(x_values)[0]
+  n = tf.shape(x_values)[1]
+  m = tf.shape(x_targets)[-1]
+  d = tf.shape(y_values)[2]
+
+  tf.assert_equal(tf.shape(x_values), (batch_size, n))
+  tf.assert_equal(tf.shape(y_values), (batch_size, n, d))
+  tf.assert_equal(tf.shape(x_targets), (batch_size, m))
+  # Reshape tensors for broadcasting
+  x_values_i = tf.reshape(x_values, (batch_size, n, 1, 1)) # Shape: (batch_size, n, 1, 1)
+  x_values_j = tf.reshape(x_values, (batch_size, 1, n, 1)) # Shape: (batch_size, 1, n, 1)
+
+  x_targets_k = tf.reshape(x_targets, (batch_size, 1, 1, m)) # Shape: (batch_size, 1, 1, m)
+
+  # Compute the denominators (x_i - x_j)
+  denominators = x_values_i - x_values_j # Shape: (batch_size, n, n, 1)
+  # Replace zeros on the diagonal with ones to avoid division by zero
+  denominators = tf.where(tf.equal(denominators, 0.0), tf.ones_like(denominators), denominators)
+
+  # Compute the numerators (x_k - x_j)
+  numerators = x_targets_k - x_values_j # Shape: (batch_size, 1, n, m)
+
+  # Compute the terms (x_k - x_j) / (x_i - x_j)
+  terms = numerators / denominators # Shape: (batch_size, n, n, m)
+
+  # Exclude the terms where i == j by setting them to 1
+  identity_matrix = tf.eye(n, batch_shape=[batch_size], dtype=tf.float64) # Shape: (batch_size, n, n)
+  identity_matrix = tf.reshape(identity_matrix, (batch_size, n, n, 1)) # Shape: (batch_size, n, n, 1)
+  terms = tf.where(tf.equal(identity_matrix, 1.0), tf.ones_like(terms), terms)
+
+  # Compute the product over j for each i and x_k
+  basis_polynomials = tf.reduce_prod(terms, axis=2) # Shape: (batch_size, n, m)
+
+  # Multiply each basis polynomial by the corresponding y_i
+  # Adjust shapes for broadcasting
+  basis_polynomials_expanded = tf.expand_dims(basis_polynomials, axis=-1) # Shape: (batch_size, n, m, 1)
+  y_values_expanded = tf.expand_dims(y_values, axis=2) # Shape: (batch_size, n, 1, d)
+  products = basis_polynomials_expanded * y_values_expanded # Shape: (batch_size, n, m, d)
+
+  # Sum over i to get the interpolated values
+  interpolated_values = tf.reduce_sum(products, axis=1) # Shape: (batch_size, m, d)
+
+  return interpolated_values
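
A minimal usage sketch of lagrange_interpolation (illustrative, not part of the commit; assumes the repository root is on PYTHONPATH). With n = 3 nodes the interpolation is exact for a quadratic, so the printed values should match x**2 at the query points:

import tensorflow as tf
from NN.LagrangianInterpolation import lagrange_interpolation

# batch of 2 curves, each sampled at n=3 nodes, with d=1 value per node
x = tf.constant([[0.0, 0.5, 1.0],
                 [0.0, 0.5, 1.0]])
y = (x ** 2)[..., None]                      # y = x^2, shape (2, 3, 1)
targets = tf.constant([[0.25, 0.75],
                       [0.10, 0.90]])        # shape (2, 2)

out = lagrange_interpolation(x, y, targets)  # shape (2, 2, 1)
print(out[..., 0])                           # approx. [[0.0625, 0.5625], [0.01, 0.81]]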

NN/networks.py

Lines changed: 158 additions & 10 deletions
@@ -1,3 +1,8 @@
+import argparse, os, sys
+# add the root folder of the project to the path
+ROOT_FOLDER = os.path.abspath(os.path.dirname(__file__) + '/../')
+sys.path.append(ROOT_FOLDER)
+
 from Core.Utils import setupGPU
 setupGPU() # dirty hack to setup GPU memory limit on startup
 
@@ -7,7 +12,7 @@
 from NN.Utils import *
 from NN.EyeEncoder import eyeEncoder
 from NN.FaceMeshEncoder import FaceMeshEncoder
-import numpy as np
+from NN.LagrangianInterpolation import lagrange_interpolation
 
 class CTimeEncoderLayer(tf.keras.layers.Layer):
   def __init__(self, **kwargs):
@@ -219,14 +224,157 @@ def Face2LatentModel(
     'inputs specification': _InputSpec()
   }
 
-if __name__ == '__main__':
-  X = Face2LatentModel(steps=5, latentSize=64,
-    embeddings={
-      'userId': 1, 'placeId': 1, 'screenId': 1, 'size': 64
+##########################
+def InpaintingEncoderModel(latentSize, embeddings, steps=5, pointsN=478, eyeSize=32, KP=5):
+  points = L.Input((steps, pointsN, 2))
+  eyeL = L.Input((steps, eyeSize, eyeSize, 1))
+  eyeR = L.Input((steps, eyeSize, eyeSize, 1))
+  T = L.Input((steps, 1)) # accumulative time
+  target = L.Input((steps, 2))
+  userIdEmb = L.Input((steps, embeddings['size']))
+  placeIdEmb = L.Input((steps, embeddings['size']))
+  screenIdEmb = L.Input((steps, embeddings['size']))
+
+  emb = L.Concatenate(-1)([userIdEmb, placeIdEmb, screenIdEmb])
+
+  Face2Step = Face2StepModel(pointsN, eyeSize, latentSize, embeddingsSize=emb.shape[-1])
+  stepsData = Face2Step({
+    'embeddings': emb,
+    'points': points,
+    'left eye': eyeL,
+    'right eye': eyeR,
+  })
+
+  diffT = T[:, 1:] - T[:, :-1]
+  diffT = L.Concatenate(-2)([tf.zeros_like(diffT[:, :1]), diffT])
+  combinedT = L.Concatenate(-1)([T, diffT])
+  encodedT = CRolloutTimesteps(CCoordsEncodingLayer(32), name='Time')(combinedT[..., None, :])[..., 0, :]
+
+  latent = stepsData['latent']
+  # add time encoding and target position
+  targetEncoded = CRolloutTimesteps(CCoordsEncodingLayer(32), name='Target')(target[..., None, :])[..., 0, :]
+  latent = L.Concatenate(-1)([latent, encodedT, targetEncoded])
+  # flatten the latent
+  latent = L.Reshape((-1,))(latent)
+
+  # compress the latent
+  latent_N = latent.shape[-1]
+  sizes = []
+  for i in range(1, 4):
+    for _ in range(i):
+      sizes.append(max(latent_N // i, latentSize))
+
+  sizes.append(latentSize)
+  latent = sMLP(sizes=sizes, activation='relu', name='Compress')(latent)
+  keyT = tf.linspace(0.0, 1.0, KP)[None, :]
+
+  # keyT shape: (1, KP) -> (B, KP, 1)
+  def transformKeyT(x):
+    t, x = x
+    B = tf.shape(x)[0]
+    return tf.tile(t, (B, 1))[..., None]
+  keyT = L.Lambda(transformKeyT)([keyT, latent])
+  # keyT shape: (B, KP, 1)
+  maxT = T[:, -1, None]
+  keyT = L.Concatenate(-1)([keyT, maxT * keyT]) # fractional time and absolute time
+  encodedKeyT = CRolloutTimesteps(CCoordsEncodingLayer(32), name='KeyTime')(keyT[..., None, :])[..., 0, :]
+
+  def combineKeys(x):
+    latent, keyT = x
+    latent = tf.tile(latent[..., None, :], (1, KP, 1))
+    return L.Concatenate(-1)([latent, keyT])
+  latent = L.Lambda(combineKeys)([latent, encodedKeyT])
+
+  latent = sMLP(sizes=[latentSize] * 3, activation='relu', name='CombineKeys')(latent)
+
+  main = tf.keras.Model(
+    inputs={
+      'points': points,
+      'left eye': eyeL,
+      'right eye': eyeR,
+      'time': T,
+      'target': target,
+      'userId': userIdEmb,
+      'placeId': placeIdEmb,
+      'screenId': screenIdEmb,
+    },
+    outputs={
+      'latent': latent,
+    }
+  )
+  return main
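
For readability, here is how the 'Compress' size schedule above unfolds. The numbers are hypothetical (the real latent_N depends on the widths produced by Face2StepModel and the time/target encodings); the point is the pattern of one layer at latent_N, two at latent_N // 2, three at latent_N // 3, followed by the final latentSize:

latent_N, latentSize = 2048, 256  # hypothetical values for illustration
sizes = []
for i in range(1, 4):
  for _ in range(i):
    sizes.append(max(latent_N // i, latentSize))
sizes.append(latentSize)
print(sizes)  # [2048, 1024, 1024, 682, 682, 682, 256]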
+
+def InpaintingDecoderModel(latentSize, embeddings, pointsN=478, eyeSize=32, KP=5):
+  latentKeyPoints = L.Input((KP, latentSize))
+  T = L.Input((None, 1))
+  userIdEmb = L.Input((embeddings['size']))
+  placeIdEmb = L.Input((embeddings['size']))
+  screenIdEmb = L.Input((embeddings['size']))
+
+  emb = L.Concatenate(-1)([userIdEmb, placeIdEmb, screenIdEmb])[..., None, :]
+  # emb shape: (B, 1, 3 * embSize)
+  def interpolateKeys(x):
+    latents, T = x
+    B = tf.shape(latents)[0]
+    keyT = tf.linspace(0.0, 1.0, KP)[None, :]
+    keyT = tf.tile(keyT, (B, 1))
+    return lagrange_interpolation(x_values=keyT, y_values=latents, x_targets=T[..., 0])
+  latents = L.Lambda(interpolateKeys, name='InterpolateKeys')([latentKeyPoints, T])
+  # latents shape: (B, N, latentSize)
+  def transformLatents(x):
+    latents, emb = x
+    N = tf.shape(latents)[1]
+    emb = tf.tile(emb, (1, N, 1)) # (B, 1, 3 * embSize) -> (B, N, 3 * embSize)
+    return L.Concatenate(-1)([latents, emb])
+  latents = L.Lambda(transformLatents, name='CombineEmb')([latents, emb])
+  # process the latents
+  latents = sMLP(sizes=[latentSize] * 3, activation='relu', name='CombineEmb/MLP')(latents)
+  # decode the latents to the face points (478, 2), two eyes (32, 32, 2) and the target (2)
+  target = IntermediatePredictor(shift=0.5)(latents)
+  # two eyes
+  eyesN = eyeSize * eyeSize
+  eyes = sMLP(sizes=[eyesN] * 2, activation='relu')(latents)
+  eyes = L.Dense(eyesN * 2)(eyes)
+  eyes = L.Reshape((-1, eyeSize, eyeSize, 2))(eyes)
+  # face points
+  face = sMLP(sizes=[pointsN] * 2, activation='relu')(latents)
+  face = L.Dense(pointsN * 2)(face)
+  face = L.Reshape((-1, pointsN, 2))(face)
+
+  model = tf.keras.Model(
+    inputs={
+      'keyPoints': latentKeyPoints,
+      'time': T,
+      'userId': userIdEmb,
+      'placeId': placeIdEmb,
+      'screenId': screenIdEmb,
+    },
+    outputs={
+      'target': target,
+      'left eye': eyes[..., 0], # split the last channel axis into the two eyes
+      'right eye': eyes[..., 1],
+      'face': face,
     }
   )
-  X['main'].summary(expand_nested=True)
-  X['Face2Step'].summary(expand_nested=False)
-  X['Step2Latent'].summary(expand_nested=False)
-  print(X['main'].outputs)
-  pass
+  return model
+
+
+if __name__ == '__main__':
+  # X = InpaintingEncoderModel(latentSize=256, embeddings={
+  #   'size': 64
+  # })
+  X = InpaintingDecoderModel(latentSize=256, embeddings={
+    'size': 64
+  })
+  X.summary(expand_nested=False)
+
+  # X = Face2LatentModel(steps=5, latentSize=64,
+  #   embeddings={
+  #     'userId': 1, 'placeId': 1, 'screenId': 1, 'size': 64
+  #   }
+  # )
+  # X['main'].summary(expand_nested=True)
+  # X['Face2Step'].summary(expand_nested=False)
+  # X['Step2Latent'].summary(expand_nested=False)
+  # print(X['main'].outputs)
+  # pass
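
Taken together, the two models sketch the frame-inpainting pipeline named in the commit message: the encoder compresses an observed window of steps frames into KP key latents pinned to fractional times in [0, 1], and the decoder evaluates those latents at arbitrary query times via lagrange_interpolation before reconstructing face points, eye patches and the target. A rough end-to-end sketch with dummy float32 inputs; it assumes the project's sub-models (Face2StepModel, sMLP, CCoordsEncodingLayer, IntermediatePredictor) accept these default shapes, and the stated output shapes are inferred, not verified:

import numpy as np

enc = InpaintingEncoderModel(latentSize=256, embeddings={'size': 64})
dec = InpaintingDecoderModel(latentSize=256, embeddings={'size': 64})

B, steps, E = 2, 5, 64
frames = {
  'points': np.zeros((B, steps, 478, 2), np.float32),
  'left eye': np.zeros((B, steps, 32, 32, 1), np.float32),
  'right eye': np.zeros((B, steps, 32, 32, 1), np.float32),
  'time': np.tile(np.linspace(0.0, 1.0, steps, dtype=np.float32)[None, :, None], (B, 1, 1)),
  'target': np.zeros((B, steps, 2), np.float32),
  'userId': np.zeros((B, steps, E), np.float32),
  'placeId': np.zeros((B, steps, E), np.float32),
  'screenId': np.zeros((B, steps, E), np.float32),
}
keyLatents = enc(frames)['latent']  # expected shape: (B, KP, 256)

# query the decoder at 7 fractional times in [0, 1], i.e. "inpaint" frames between the key points
queryT = np.tile(np.linspace(0.0, 1.0, 7, dtype=np.float32)[None, :, None], (B, 1, 1))
out = dec({
  'keyPoints': keyLatents,
  'time': queryT,
  'userId': np.zeros((B, E), np.float32),
  'placeId': np.zeros((B, E), np.float32),
  'screenId': np.zeros((B, E), np.float32),
})
# expected: out['face'] (B, 7, 478, 2), out['left eye'] (B, 7, 32, 32), out['target'] (B, 7, 2)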
