Skip to content

Commit bf91715

Browse files
authored
Fix torch.no_grad decorator in VLMS (#41888)
Fix decorator
1 parent 77e8b9f commit bf91715

File tree

4 files changed

+4
-4
lines changed

4 files changed

+4
-4
lines changed

src/transformers/models/emu3/modeling_emu3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1392,7 +1392,7 @@ def get_image_features(self, pixel_values: torch.FloatTensor, image_sizes: torch
13921392
image_features = torch.split(image_features, split_sizes)
13931393
return image_features
13941394

1395-
@torch.no_grad
1395+
@torch.no_grad()
13961396
def decode_image_tokens(self, image_tokens: torch.LongTensor, height: int, width: int):
13971397
"""
13981398
Decodes generated image tokens from language model to continuous pixel values

src/transformers/models/emu3/modular_emu3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -946,7 +946,7 @@ def get_image_features(self, pixel_values: torch.FloatTensor, image_sizes: torch
946946
image_features = torch.split(image_features, split_sizes)
947947
return image_features
948948

949-
@torch.no_grad
949+
@torch.no_grad()
950950
def decode_image_tokens(self, image_tokens: torch.LongTensor, height: int, width: int):
951951
"""
952952
Decodes generated image tokens from language model to continuous pixel values

src/transformers/models/janus/modeling_janus.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1283,7 +1283,7 @@ def decode_image_tokens(self, image_tokens: torch.Tensor):
12831283
decoded_image = decoded_image.permute(0, 2, 3, 1)
12841284
return decoded_image
12851285

1286-
@torch.no_grad
1286+
@torch.no_grad()
12871287
def generate(
12881288
self,
12891289
inputs: Optional[torch.Tensor] = None,

src/transformers/models/janus/modular_janus.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1099,7 +1099,7 @@ def decode_image_tokens(self, image_tokens: torch.Tensor):
10991099
decoded_image = decoded_image.permute(0, 2, 3, 1)
11001100
return decoded_image
11011101

1102-
@torch.no_grad
1102+
@torch.no_grad()
11031103
def generate(
11041104
self,
11051105
inputs: Optional[torch.Tensor] = None,

0 commit comments

Comments
 (0)