1+ import numpy as np
2+ import random
3+ from math import ceil
4+ from functools import lru_cache
5+
class CBaseDataSampler:
    '''
    Base class for data sampling.

    Keeps a list of "valid" storage indexes (those preceded by at least
    `minFrames` frames inside the `maxT` time window) and provides helpers
    for building time-ordered trajectories around such indexes.

    The storage is only required to support `len(storage)`,
    `storage[i]['time']`, `storage.add(sample)` and
    `storage.addBlock(samples)`.
    '''

    def __init__(self, storage, batch_size, minFrames, defaults=None, maxT=1.0, cumulative_time=True):
        '''
        Parameters:
        - storage: The storage object containing the samples.
        - batch_size: The number of samples per batch.
        - minFrames: The minimum number of frames required in a trajectory.
        - defaults: Default parameters for sampling. A fresh dict is created
          when omitted, so instances never share one default object.
        - maxT: Maximum time window for sampling frames.
        - cumulative_time: If True, time is cumulative; otherwise, it's time deltas.
        '''
        self._storage = storage
        self._defaults = {} if defaults is None else defaults
        self._batchSize = batch_size
        self._maxT = maxT
        self._minFrames = minFrames
        self._samples = []
        self._currentSample = None
        self._cumulative_time = cumulative_time
        # Per-instance memo for _trajectoryRange, valid only for the storage
        # length recorded in _rangeCacheSize.
        self._rangeCache = {}
        self._rangeCacheSize = 0

    def reset(self):
        '''Shuffle the valid samples in place and rewind the current position to 0.'''
        random.shuffle(self._samples)
        self._currentSample = 0

    def __len__(self):
        '''Number of batches needed to cover all valid samples (last one may be partial).'''
        return ceil(len(self._samples) / self._batchSize)

    def _storeSample(self, idx):
        '''Register `idx` as valid if at least `minFrames` frames precede it within `maxT`.'''
        minInd = self._getTrajectoryBefore(idx)
        if self._minFrames <= (idx - minInd):
            self._samples.append(idx)

    def add(self, sample):
        '''Add one sample to the storage, register it if valid, and return its index.'''
        idx = self._storage.add(sample)
        # New data can extend the forward window of earlier indexes.
        self._rangeCache.clear()
        self._storeSample(idx)
        return idx

    def addBlock(self, samples):
        '''Add several samples to the storage and register each valid index.'''
        indexes = self._storage.addBlock(samples)
        # New data can extend the forward window of earlier indexes.
        self._rangeCache.clear()
        for idx in indexes:
            self._storeSample(idx)

    def _getTrajectoryBefore(self, mainInd):
        '''
        Returns the smallest index reachable by walking backwards from
        mainInd while the stored time stays >= time(mainInd) - maxT.
        Returns mainInd itself when no earlier frame qualifies.
        '''
        minT = self._storage[mainInd]['time'] - self._maxT

        minInd = mainInd
        for ind in range(mainInd - 1, -1, -1):
            if self._storage[ind]['time'] < minT:
                break
            minInd = ind
        return minInd

    def _trajectoryRange(self, mainInd):
        '''
        Returns indexes of samples that are within maxT from mainInd.
        Returns (minInd, maxInd) where minInd <= mainInd <= maxInd

        Results are memoized per instance. The memo is dropped whenever the
        storage length changes, because the forward bound depends on frames
        that may be appended after the first call.
        '''
        total = len(self._storage)
        if self._rangeCacheSize != total:
            self._rangeCache.clear()
            self._rangeCacheSize = total

        cached = self._rangeCache.get(mainInd)
        if cached is not None:
            return cached

        maxT = self._storage[mainInd]['time'] + self._maxT
        maxInd = mainInd
        for ind in range(mainInd, total):
            if maxT < self._storage[ind]['time']:
                break
            maxInd = ind

        result = (self._getTrajectoryBefore(mainInd), maxInd)
        self._rangeCache[mainInd] = result
        return result

    def _trajectory(self, mainInd):
        '''
        Returns two index lists: frames up to and including mainInd, and
        frames strictly after mainInd, both limited to the maxT window.
        '''
        minInd, maxInd = self._trajectoryRange(mainInd)
        return list(range(minInd, mainInd + 1)), list(range(mainInd + 1, maxInd + 1))

    def _prepareT(self, res):
        '''
        Builds the time vector for the frames listed in `res`.

        Returns None when `res` is empty, holds a single frame, or any two
        consecutive frames share the same time. Otherwise returns an array
        of len(res) values starting with 0: cumulative times when
        `cumulative_time` is set, per-step deltas otherwise.
        '''
        if len(res) < 1:
            return None  # Nothing to measure time against

        T = np.array([self._storage[ind]['time'] for ind in res])
        T -= T[0]
        diff = np.diff(T, 1)
        # Every consecutive pair must differ in time, and there must be at least one pair.
        if len(diff) < 1 or len(np.nonzero(diff)[0]) != len(diff):
            return None

        T = np.insert(diff, 0, 0.0)
        assert len(res) == len(T)
        # Convert to cumulative time if required
        if self._cumulative_time:
            T = np.cumsum(T)
        return T

    def _reshapeSteps(self, values, steps):
        '''
        Splits the leading axis of every array in `values` into
        (B // steps, steps). Returns `values` untouched when steps is None,
        otherwise a tuple of reshaped arrays.
        '''
        if steps is None:
            return values

        return tuple(
            x.reshape((x.shape[0] // steps, steps, *x.shape[1:]))
            for x in values
        )

    @property
    def totalSamples(self):
        '''Total number of samples held by the storage.'''
        return len(self._storage)

    def validSamples(self):
        '''Sorted copy of the indexes registered as valid.'''
        return sorted(self._samples)

    def _framesFor(self, mainInd, samples, steps, stepsSampling):
        '''
        Picks `steps - 1` indexes from `samples` and returns them together
        with mainInd as a sorted list of exactly `steps` indexes.

        stepsSampling:
        - 'uniform': random choice without replacement.
        - 'last': the trailing `steps - 1` entries of `samples`.
        - dict with 'max frames': walk `samples` from the end towards the
          start, each time skipping a random number of entries (fewer than
          'max frames') while leaving enough entries for the remaining picks.

        Raises ValueError for an unknown stepsSampling.
        '''
        extra = steps - 1
        if 'uniform' == stepsSampling:
            chosen = random.sample(samples, extra)
        elif 'last' == stepsSampling:
            # samples[-0:] would return everything, so zero extras is handled explicitly.
            chosen = list(samples[-extra:]) if 0 < extra else []
        elif isinstance(stepsSampling, dict):
            candidates = list(samples)[::-1]
            maxFrames = stepsSampling['max frames']
            chosen = []
            start = 0  # candidates before this offset are already consumed
            for left in range(extra, 0, -1):
                # Keep at least `left - 1` candidates for the following picks.
                avl = min((maxFrames, 1 + (len(candidates) - start) - left))
                ind = random.randint(0, avl - 1)
                chosen.append(candidates[start + ind])
                start += ind + 1
        else:
            raise ValueError('Unknown sampling method: ' + str(stepsSampling))

        res = sorted(chosen + [mainInd])
        assert len(res) == steps
        return res
0 commit comments