Skip to content

Commit 32b33cf

Browse files
wip
1 parent 66ab19d commit 32b33cf

File tree

10 files changed

+802
-343
lines changed

10 files changed

+802
-343
lines changed

Core/CBaseDataSampler.py

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
import numpy as np
2+
import random
3+
from math import ceil
4+
from functools import lru_cache
5+
6+
class CBaseDataSampler:
7+
def __init__(self, storage, batch_size, minFrames, defaults={}, maxT=1.0, cumulative_time=True):
8+
'''
9+
Base class for data sampling.
10+
11+
Parameters:
12+
- storage: The storage object containing the samples.
13+
- batch_size: The number of samples per batch.
14+
- minFrames: The minimum number of frames required in a trajectory.
15+
- defaults: Default parameters for sampling.
16+
- maxT: Maximum time window for sampling frames.
17+
- cumulative_time: If True, time is cumulative; otherwise, it's time deltas.
18+
'''
19+
self._storage = storage
20+
self._defaults = defaults
21+
self._batchSize = batch_size
22+
self._maxT = maxT
23+
self._minFrames = minFrames
24+
self._samples = []
25+
self._currentSample = None
26+
self._cumulative_time = cumulative_time
27+
return
28+
29+
def reset(self):
30+
random.shuffle(self._samples)
31+
self._currentSample = 0
32+
return
33+
34+
def __len__(self):
35+
return ceil(len(self._samples) / self._batchSize)
36+
37+
def _storeSample(self, idx):
38+
# Store sample if it has enough frames
39+
minInd = self._getTrajectoryBefore(idx)
40+
if self._minFrames <= (idx - minInd):
41+
self._samples.append(idx)
42+
return
43+
44+
def add(self, sample):
45+
idx = self._storage.add(sample)
46+
self._storeSample(idx)
47+
return idx
48+
49+
def addBlock(self, samples):
50+
indexes = self._storage.addBlock(samples)
51+
for idx in indexes:
52+
self._storeSample(idx)
53+
continue
54+
return
55+
56+
def _getTrajectoryBefore(self, mainInd):
57+
mainT = self._storage[mainInd]['time']
58+
minT = mainT - self._maxT
59+
60+
minInd = mainInd
61+
for ind in range(mainInd - 1, -1, -1):
62+
if self._storage[ind]['time'] < minT: break
63+
minInd = ind
64+
continue
65+
return minInd
66+
67+
@lru_cache(None)
68+
def _trajectoryRange(self, mainInd):
69+
'''
70+
Returns indexes of samples that are within maxT from mainInd.
71+
Returns (minInd, maxInd) where minInd <= mainInd <= maxInd
72+
'''
73+
mainT = self._storage[mainInd]['time']
74+
maxT = mainT + self._maxT
75+
maxInd = mainInd
76+
for ind in range(mainInd, len(self._storage)):
77+
if maxT < self._storage[ind]['time']: break
78+
maxInd = ind
79+
continue
80+
81+
minInd = self._getTrajectoryBefore(mainInd)
82+
return minInd, maxInd
83+
84+
def _trajectory(self, mainInd):
85+
minInd, maxInd = self._trajectoryRange(mainInd)
86+
return list(range(minInd, mainInd + 1)), list(range(mainInd + 1, maxInd + 1))
87+
88+
def _prepareT(self, res):
89+
T = np.array([self._storage[ind]['time'] for ind in res])
90+
T -= T[0]
91+
diff = np.diff(T, 1)
92+
idx = np.nonzero(diff)[0]
93+
if len(idx) < 1: return None # All frames have the same time
94+
if len(diff) == len(idx):
95+
T = diff
96+
else:
97+
return None # Time is not consistent
98+
T = np.insert(T, 0, 0.0)
99+
assert len(res) == len(T)
100+
# Convert to cumulative time if required
101+
if self._cumulative_time:
102+
T = np.cumsum(T)
103+
return T
104+
105+
def _reshapeSteps(self, values, steps):
106+
if steps is None:
107+
return values
108+
109+
res = []
110+
for x in values:
111+
B, *s = x.shape
112+
newShape = (B // steps, steps, *s)
113+
res.append(x.reshape(newShape))
114+
continue
115+
return tuple(res)
116+
117+
@property
118+
def totalSamples(self):
119+
return len(self._storage)
120+
121+
def validSamples(self):
122+
return list(sorted(self._samples))
123+
124+
def _framesFor(self, mainInd, samples, steps, stepsSampling):
125+
if 'uniform' == stepsSampling:
126+
samples = random.sample(samples, steps - 1)
127+
elif 'last' == stepsSampling:
128+
samples = samples[-(steps - 1):]
129+
elif isinstance(stepsSampling, dict):
130+
candidates = list(samples)
131+
maxFrames = stepsSampling['max frames']
132+
candidates = candidates[::-1]
133+
samples = []
134+
left = steps - 1
135+
for _ in range(left):
136+
avl = min((maxFrames, 1 + len(candidates) - left))
137+
ind = random.randint(0, avl - 1)
138+
samples.append(candidates[ind])
139+
candidates = candidates[ind+1:]
140+
left -= 1
141+
continue
142+
pass
143+
else:
144+
raise ValueError('Unknown sampling method: ' + str(stepsSampling))
145+
146+
res = list(sorted(samples + [mainInd]))
147+
assert len(res) == steps
148+
return res

0 commit comments

Comments
 (0)