Skip to content

Commit 6b1f0db

Browse files
Update to accept 48x48 eye images
1 parent 53c8f36 commit 6b1f0db

File tree

3 files changed

+68
-18
lines changed

3 files changed

+68
-18
lines changed

Core/CDataSampler_utils.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@ def toTensor(data, params, userId, placeId, screenId):
7474
N = tf.shape(points)[0]
7575
imgA = tf.cast(imgA, tf.float32) / 255.
7676
imgB = tf.cast(imgB, tf.float32) / 255.
77+
tf.assert_equal(tf.shape(imgA), (N, 48, 48))
78+
tf.assert_equal(tf.shape(imgA), tf.shape(imgB))
7779
userId = tf.fill((N, 1), userId)
7880
placeId = tf.fill((N, 1), placeId)
7981
screenId = tf.fill((N, 1), screenId)
@@ -82,16 +84,45 @@ def toTensor(data, params, userId, placeId, screenId):
8284
x,
8385
tf.concat([(N // timesteps, timesteps), tf.shape(x)[1:]], axis=-1)
8486
)
87+
# apply center crop
88+
fraction = 32.0 / 48.0
89+
pos = tf.constant(
90+
[[0.5 - fraction / 2, 0.5 - fraction / 2, 0.5 + fraction / 2, 0.5 + fraction / 2]],
91+
dtype=tf.float32
92+
)
93+
pos = tf.tile(pos, [N, 1])
94+
withCrop = lambda x: tf.image.crop_and_resize(
95+
x[..., None],
96+
boxes=pos,
97+
box_indices=tf.range(N), crop_size=(32, 32),
98+
)[..., 0]
99+
85100
clean = {
86101
'time': reshape(T),
87102
'points': reshape(points),
88-
'left eye': reshape(imgA),
89-
'right eye': reshape(imgB),
103+
'left eye': reshape(withCrop(imgA)),
104+
'right eye': reshape(withCrop(imgB)),
90105
'userId': reshape(userId),
91106
'placeId': reshape(placeId),
92107
'screenId': reshape(screenId),
93108
}
94109
##########################
110+
# random crop 32x32 eyes
111+
fraction = 32.0 / 48.0
112+
pos = tf.random.uniform((N, 2), minval=0.0, maxval=1.0 - fraction)
113+
boxes = tf.concat([pos, pos + fraction], axis=-1)
114+
tf.assert_equal(tf.shape(boxes), (N, 4))
115+
imgA = tf.image.crop_and_resize(
116+
imgA[..., None],
117+
boxes=boxes,
118+
box_indices=tf.range(N), crop_size=(32, 32),
119+
)[..., 0]
120+
imgB = tf.image.crop_and_resize(
121+
imgB[..., None],
122+
boxes=boxes,
123+
box_indices=tf.range(N), crop_size=(32, 32),
124+
)[..., 0]
125+
##########################
95126
def clip(x): return tf.clip_by_value(x, 0., 1.)
96127

97128
def sampleBrightness(a, b, mid=1.0):

Core/CEyeTracker.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def _extract(self, image, pts, isBGR):
9292
if len(pts) < 1: return EMPTY
9393

9494
HW = np.array(image.shape[:2][::-1])
95-
roi = self._circleROI(pts, padding=1.25)
95+
roi = self._circleROI(pts, padding=1.5)
9696
if roi is None: return EMPTY
9797
A, B = roi
9898
A = A.clip(min=0, max=HW)
@@ -103,8 +103,11 @@ def _extract(self, image, pts, isBGR):
103103
if np.min(crop.shape[:2]) < 8:
104104
return np.zeros(sz, np.uint8), rect
105105

106-
crop = cv2.resize(crop, sz)
106+
crop = cv2.resize(crop, (48, 48)) # 48x48, not 32x32
107107
crop = cv2.cvtColor(crop, cv2.COLOR_BGR2GRAY if isBGR else cv2.COLOR_RGB2GRAY)
108+
# center crop 32x32
109+
d = (48 - 32) // 2
110+
crop = crop[d:d+32, d:d+32]
108111
return crop.astype(np.uint8), rect
109112

110113
def _processFace(self, pose, image):

scripts/download-remote.py

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ def deserialize(buffer):
2727
offset = 0
2828
samples = []
2929
# read header (uint8)
30-
version = np.frombuffer(buffer, dtype=np.uint8, count=1, offset=offset)
31-
if version[0] != 1:
30+
version = np.frombuffer(buffer, dtype=np.uint8, count=1, offset=offset)[0]
31+
if not (version in [2]): # only version 2 is supported
3232
raise ValueError('Invalid version %d' % version[0])
3333
offset += 1
3434

@@ -38,6 +38,8 @@ def deserialize(buffer):
3838
offset += 36
3939
screenId = buffer[offset:offset+36].decode('utf-8')
4040
offset += 36
41+
42+
EYE_SIZE = 32 if 1 == version else 48
4143
# read samples
4244
while offset < len(buffer):
4345
sample = {
@@ -51,15 +53,16 @@ def deserialize(buffer):
5153
sample['time'] = time_data[0]
5254
offset += 4
5355

54-
# Read leftEye (1024 uint8)
55-
sample['leftEye'] = np.frombuffer(buffer, dtype=np.uint8, count=32*32, offset=offset) \
56-
.reshape(32, 32)
57-
offset += 32 * 32
56+
# Read leftEye (uint8)
57+
EYE_COUNT = EYE_SIZE * EYE_SIZE
58+
sample['leftEye'] = np.frombuffer(buffer, dtype=np.uint8, count=EYE_COUNT, offset=offset) \
59+
.reshape(EYE_SIZE, EYE_SIZE)
60+
offset += EYE_COUNT
5861

59-
# Read rightEye (1024 uint8)
60-
sample['rightEye'] = np.frombuffer(buffer, dtype=np.uint8, count=32*32, offset=offset) \
61-
.reshape(32, 32)
62-
offset += 32 * 32
62+
# Read rightEye (uint8)
63+
sample['rightEye'] = np.frombuffer(buffer, dtype=np.uint8, count=EYE_COUNT, offset=offset) \
64+
.reshape(EYE_SIZE, EYE_SIZE)
65+
offset += EYE_COUNT
6366

6467
# Read points (float32)
6568
sample['points'] = np.frombuffer(buffer, dtype='>f4', count=2*478, offset=offset) \
@@ -88,6 +91,18 @@ def deserialize(buffer):
8891
# rename "leftEye" and "rightEye" to "left eye" and "right eye"
8992
res['left eye'] = res.pop('leftEye')
9093
res['right eye'] = res.pop('rightEye')
94+
if 1 == version: # upscale to 48x48
95+
import cv2
96+
res['left eye'] = np.stack(
97+
[cv2.resize(img[..., None], (48, 48)) for img in res['left eye']]
98+
)
99+
res['right eye'] = np.stack([
100+
cv2.resize(img[..., None], (48, 48)) for img in res['right eye']
101+
])
102+
pass
103+
104+
assert res['left eye'].shape[1:] == (48, 48), 'Invalid shape for left eye. Got %s' % str(res['left eye'].shape)
105+
assert res['right eye'].shape[1:] == (48, 48), 'Invalid shape for right eye. Got %s' % str(res['right eye'].shape)
91106
return res
92107

93108
def find_free_name(folder, base_name, extension=".npz"):
@@ -158,15 +173,17 @@ def main(args):
158173
shutil.rmtree(os.path.join(folder, 'remote'), ignore_errors=True)
159174
# get the list of files from the remote server
160175
urls = requests.get(args.url).json()
161-
print('Found %d files on the remote server' % len(urls))
162-
for file in urls:
176+
N = len(urls)
177+
L = len(str(N))
178+
print('Found %d files on the remote server' % N)
179+
for i, file in enumerate(urls):
163180
response = requests.get(file)
164181
content = IO.BytesIO(response.content)
165182
# read first file in the gz archive
166183
with gzip.open(content, 'rb') as f:
167184
first_file = f.read()
168185
samples = deserialize(first_file)
169-
print('Read %d samples from %s' % (len(samples['time']), file))
186+
print(f'[{i:0{L}d}/{N:0{L}d}] Read {len(samples["time"])} samples from {file}')
170187

171188
# don't want to mess with such cases
172189
userId = np.unique(samples['userId'])
@@ -182,7 +199,6 @@ def main(args):
182199
saveChunk(chunk, os.path.join(folder, 'remote'))
183200
continue
184201
pass
185-
186202
return
187203

188204
if __name__ == '__main__':

0 commit comments

Comments
 (0)