Skip to content

Commit a61aba5

Browse files
authored
Improve BatchFeature: stack list and lists of torch tensors (#42750)
* stack lists of tensors in BatchFeature, improve error messages, add tests
* remove unnecessary stack in fast image processors and video processors
* make style
* fix tests
1 parent 5b710c7 commit a61aba5

File tree

62 files changed

+206
-111
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below for content that may be hidden.

62 files changed

+206
-111
lines changed

src/transformers/feature_extraction_utils.py

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,18 @@ class BatchFeature(UserDict):
6767
tensor_type (`Union[None, str, TensorType]`, *optional*):
6868
You can give a tensor_type here to convert the lists of integers into PyTorch/NumPy tensors at
6969
initialization.
70+
skip_tensor_conversion (`list[str]` or `set[str]`, *optional*):
71+
List or set of keys that should NOT be converted to tensors, even when `tensor_type` is specified.
7072
"""
7173

72-
def __init__(self, data: Optional[dict[str, Any]] = None, tensor_type: Union[None, str, TensorType] = None):
74+
def __init__(
75+
self,
76+
data: Optional[dict[str, Any]] = None,
77+
tensor_type: Union[None, str, TensorType] = None,
78+
skip_tensor_conversion: Optional[Union[list[str], set[str]]] = None,
79+
):
7380
super().__init__(data)
74-
self.convert_to_tensors(tensor_type=tensor_type)
81+
self.convert_to_tensors(tensor_type=tensor_type, skip_tensor_conversion=skip_tensor_conversion)
7582

7683
def __getitem__(self, item: str) -> Any:
7784
"""
@@ -110,6 +117,14 @@ def _get_is_as_tensor_fns(self, tensor_type: Optional[Union[str, TensorType]] =
110117
import torch
111118

112119
def as_tensor(value):
120+
if torch.is_tensor(value):
121+
return value
122+
123+
# stack a list of tensors if tensor_type is PyTorch (torch.tensor() does not support lists of tensors)
124+
if isinstance(value, (list, tuple)) and len(value) > 0 and torch.is_tensor(value[0]):
125+
return torch.stack(value)
126+
127+
# convert list of numpy arrays to numpy array (stack) if tensor_type is Numpy
113128
if isinstance(value, (list, tuple)) and len(value) > 0:
114129
if isinstance(value[0], np.ndarray):
115130
value = np.array(value)
@@ -138,14 +153,20 @@ def as_tensor(value, dtype=None):
138153
is_tensor = is_numpy_array
139154
return is_tensor, as_tensor
140155

141-
def convert_to_tensors(self, tensor_type: Optional[Union[str, TensorType]] = None):
156+
def convert_to_tensors(
157+
self,
158+
tensor_type: Optional[Union[str, TensorType]] = None,
159+
skip_tensor_conversion: Optional[Union[list[str], set[str]]] = None,
160+
):
142161
"""
143162
Convert the inner content to tensors.
144163
145164
Args:
146165
tensor_type (`str` or [`~utils.TensorType`], *optional*):
147166
The type of tensors to use. If `str`, should be one of the values of the enum [`~utils.TensorType`]. If
148167
`None`, no modification is done.
168+
skip_tensor_conversion (`list[str]` or `set[str]`, *optional*):
169+
List or set of keys that should NOT be converted to tensors, even when `tensor_type` is specified.
149170
"""
150171
if tensor_type is None:
151172
return self
@@ -154,18 +175,26 @@ def convert_to_tensors(self, tensor_type: Optional[Union[str, TensorType]] = Non
154175

155176
# Do the tensor conversion in batch
156177
for key, value in self.items():
178+
# Skip keys explicitly marked for no conversion
179+
if skip_tensor_conversion and key in skip_tensor_conversion:
180+
continue
181+
157182
try:
158183
if not is_tensor(value):
159184
tensor = as_tensor(value)
160-
161185
self[key] = tensor
162-
except: # noqa E722
186+
except Exception as e:
163187
if key == "overflowing_values":
164-
raise ValueError("Unable to create tensor returning overflowing values of different lengths. ")
188+
raise ValueError(
189+
f"Unable to create tensor for '{key}' with overflowing values of different lengths. "
190+
f"Original error: {str(e)}"
191+
) from e
165192
raise ValueError(
166-
"Unable to create tensor, you should probably activate padding "
167-
"with 'padding=True' to have batched tensors with the same length."
168-
)
193+
f"Unable to convert output '{key}' (type: {type(value).__name__}) to tensor: {str(e)}\n"
194+
f"You can try:\n"
195+
f" 1. Use padding=True to ensure all outputs have the same shape\n"
196+
f" 2. Set return_tensors=None to return Python objects instead of tensors"
197+
) from e
169198

170199
return self
171200

src/transformers/image_processing_utils_fast.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -932,7 +932,6 @@ def _preprocess(
932932
if do_pad:
933933
processed_images = self.pad(processed_images, pad_size=pad_size, disable_grouping=disable_grouping)
934934

935-
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
936935
return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
937936

938937
def to_dict(self):

src/transformers/models/beit/image_processing_beit_fast.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,6 @@ def _preprocess(
163163
processed_images_grouped[shape] = stacked_images
164164

165165
processed_images = reorder_images(processed_images_grouped, grouped_images_index)
166-
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
167166

168167
return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
169168

src/transformers/models/bridgetower/image_processing_bridgetower_fast.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -251,10 +251,8 @@ def _preprocess(
251251
processed_images, processed_masks = self.pad(
252252
processed_images, return_mask=True, disable_grouping=disable_grouping
253253
)
254-
processed_masks = torch.stack(processed_masks, dim=0) if return_tensors else processed_masks
255254
data["pixel_mask"] = processed_masks
256255

257-
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
258256
data["pixel_values"] = processed_images
259257

260258
return BatchFeature(data=data, tensor_type=return_tensors)

src/transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,6 @@ def _preprocess(
263263
processed_images_grouped[shape] = stacked_images
264264

265265
processed_images = reorder_images(processed_images_grouped, grouped_images_index)
266-
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
267266

268267
return BatchFeature(
269268
data={"pixel_values": processed_images, "num_patches": num_patches}, tensor_type=return_tensors

src/transformers/models/convnext/image_processing_convnext_fast.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,6 @@ def _preprocess(
162162
processed_images_grouped[shape] = stacked_images
163163

164164
processed_images = reorder_images(processed_images_grouped, grouped_images_index)
165-
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
166165

167166
return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
168167

src/transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,6 @@ def _preprocess(
171171
processed_images_grouped[shape] = stacked_images
172172

173173
processed_images = reorder_images(processed_images_grouped, grouped_images_index)
174-
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
175174

176175
return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
177176

src/transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -207,9 +207,6 @@ def _preprocess(
207207
)
208208
high_res_processed_images_grouped[shape] = stacked_high_res_images
209209
high_res_processed_images = reorder_images(high_res_processed_images_grouped, grouped_high_res_images_index)
210-
high_res_processed_images = (
211-
torch.stack(high_res_processed_images, dim=0) if return_tensors else high_res_processed_images
212-
)
213210

214211
resized_images_grouped = {}
215212
for shape, stacked_high_res_padded_images in high_res_padded_images.items():
@@ -233,7 +230,6 @@ def _preprocess(
233230
)
234231
processed_images_grouped[shape] = stacked_images
235232
processed_images = reorder_images(processed_images_grouped, grouped_resized_images_index)
236-
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
237233

238234
return BatchFeature(
239235
data={"pixel_values": processed_images, "high_res_pixel_values": high_res_processed_images},

src/transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -888,9 +888,6 @@ def _preprocess(
888888
)
889889
high_res_processed_images_grouped[shape] = stacked_high_res_images
890890
high_res_processed_images = reorder_images(high_res_processed_images_grouped, grouped_high_res_images_index)
891-
high_res_processed_images = (
892-
torch.stack(high_res_processed_images, dim=0) if return_tensors else high_res_processed_images
893-
)
894891

895892
resized_images_grouped = {}
896893
for shape, stacked_high_res_padded_images in high_res_padded_images.items():
@@ -914,7 +911,6 @@ def _preprocess(
914911
)
915912
processed_images_grouped[shape] = stacked_images
916913
processed_images = reorder_images(processed_images_grouped, grouped_resized_images_index)
917-
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
918914

919915
return BatchFeature(
920916
data={"pixel_values": processed_images, "high_res_pixel_values": high_res_processed_images},

src/transformers/models/depth_pro/image_processing_depth_pro_fast.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,6 @@ def _preprocess(
9494
processed_images_grouped[shape] = stacked_images
9595

9696
processed_images = reorder_images(processed_images_grouped, grouped_images_index)
97-
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
9897

9998
return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
10099

0 commit comments

Comments
 (0)