Skip to content

Commit 597e159

Browse files
ArthurZucker, LysandreJik, ydshieh
committed
Protect ParallelInterface (#38262)
Co-authored-by: Lysandre <hi@lysand.re> Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
1 parent 237c7c3 commit 597e159

File tree

1 file changed

+19
-14
lines changed

1 file changed

+19
-14
lines changed

src/transformers/integrations/tensor_parallel.py

Lines changed: 19 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -729,23 +729,24 @@ class ParallelInterface(MutableMapping):
729729

730730
# Class instance object, so that a call to `register` can be reflected into all other files correctly, even if
731731
# a new instance is created (in order to locally override a given function)
732-
_global_mapping = {
733-
"colwise": ColwiseParallel(),
734-
"rowwise": RowwiseParallel(),
735-
"colwise_rep": ColwiseParallel(output_layouts=Replicate()),
736-
"rowwise_rep": RowwiseParallel(input_layouts=Replicate()),
737-
"local_colwise": ColwiseParallel(use_dtensor=False),
738-
"local_rowwise": RowwiseParallel(use_dtensor=False),
739-
"local": IsolatedParallel(),
740-
"gather": GatherParallel(),
741-
"local_packed_rowwise": PackedRowwiseParallel(use_dtensor=False),
742-
"sequence_parallel": SequenceParallel(),
743-
"replicate": ReplicateParallel(),
744-
}
745732

746733
def __init__(self):
747734
self._local_mapping = {}
748735

736+
ParallelInterface._global_mapping = {
737+
"colwise": ColwiseParallel(),
738+
"rowwise": RowwiseParallel(),
739+
"colwise_rep": ColwiseParallel(output_layouts=Replicate()),
740+
"rowwise_rep": RowwiseParallel(input_layouts=Replicate()),
741+
"local_colwise": ColwiseParallel(use_dtensor=False),
742+
"local_rowwise": RowwiseParallel(use_dtensor=False),
743+
"local": IsolatedParallel(),
744+
"gather": GatherParallel(),
745+
"local_packed_rowwise": PackedRowwiseParallel(use_dtensor=False),
746+
"sequence_parallel": SequenceParallel(),
747+
"replicate": ReplicateParallel(),
748+
}
749+
749750
def __getitem__(self, key):
750751
# First check if instance has a local override
751752
if key in self._local_mapping:
@@ -775,7 +776,11 @@ def valid_keys(self) -> List[str]:
775776

776777

777778
# Global ParallelInterface shared by all models which do not need to overwrite any of the existing ones
778-
ALL_PARALLEL_STYLES: ParallelInterface = ParallelInterface()
779+
780+
if is_torch_greater_or_equal("2.5") and _torch_distributed_available:
781+
ALL_PARALLEL_STYLES: ParallelInterface = ParallelInterface()
782+
else:
783+
ALL_PARALLEL_STYLES = None
779784

780785

781786
def convert_local_tensor_to_dtensor(

0 commit comments

Comments
 (0)