@@ -1199,7 +1199,7 @@ def __init__(
             config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
         )

-        self.dtype = dtype
+        self._dtype = dtype
         self.num_key_value_heads = (
             config.num_attention_heads
             if getattr(config, "num_key_value_heads", None) is None
@@ -1216,8 +1216,8 @@ def __init__(
                 layer_device = layer_device_map[idx]
             else:
                 layer_device = device
-            new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
-            new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
+            new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device)
+            new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device)
             # Note: `mark_static_address` is used to tag the cache as a fixed data pointer,
             # preventing compiled graph breaks when updating the cache.
             torch._dynamo.mark_static_address(new_layer_key_cache)
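The hunks above make the storage dtype private. If downstream code reads `cache.dtype` directly, a read-only property could keep that access working; the following is a minimal sketch under that assumption, and the property is not something this hunk adds.

import torch

class _CacheDtypeSketch:
    # Illustrative stand-in for a cache class; only the dtype handling is shown.
    def __init__(self, dtype: torch.dtype = torch.float32):
        self._dtype = dtype  # private storage dtype, mirroring the rename above

    @property
    def dtype(self) -> torch.dtype:
        # Hypothetical accessor for callers that still use `.dtype`; assumed, not in the diff.
        return self._dtype

cache = _CacheDtypeSketch(torch.bfloat16)
assert cache.dtype == torch.bfloat16  # old attribute-style access keeps working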
@@ -1680,7 +1680,7 @@ def __init__(
             config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
         )

-        self.dtype = dtype
+        self._dtype = dtype
         self.num_key_value_heads = (
             config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
         )
@@ -1707,8 +1707,8 @@ def __init__(
             # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, preventing cuda graph
             # breaks when updating the cache.
             cache_shape = global_cache_shape if not self.is_sliding[i] else sliding_cache_shape
-            new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
-            new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
+            new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device)
+            new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device)
             torch._dynamo.mark_static_address(new_layer_key_cache)
             torch._dynamo.mark_static_address(new_layer_value_cache)
             self.key_cache.append(new_layer_key_cache)
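For context, this hunk allocates each layer with either the full-length shape or the sliding-window shape, depending on `self.is_sliding[i]`, and tags every buffer as a static address. A standalone sketch of that per-layer selection with assumed dimensions (batch 1, 4 KV heads, head dim 64, max length 128, window 32) and an assumed alternating layer pattern:

import torch

num_layers = 4
is_sliding = [i % 2 == 1 for i in range(num_layers)]  # assumed pattern; set by the config in practice
global_cache_shape = (1, 4, 128, 64)   # (batch, kv_heads, max_cache_len, head_dim)
sliding_cache_shape = (1, 4, 32, 64)   # sequence axis capped at the window size

key_cache = []
for i in range(num_layers):
    cache_shape = global_cache_shape if not is_sliding[i] else sliding_cache_shape
    layer_key = torch.zeros(cache_shape, dtype=torch.bfloat16)
    torch._dynamo.mark_static_address(layer_key)  # same static-address tagging as above
    key_cache.append(layer_key)

assert key_cache[0].shape[2] == 128 and key_cache[1].shape[2] == 32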
@@ -1853,8 +1853,8 @@ def __init__(
         dtype: torch.dtype = torch.float16,
         device: Union[torch.device, str, None] = None,
     ):
-        self.dtype = dtype
         self.max_batch_size = max_batch_size
+        self._dtype = dtype
         self.intermediate_size = config.intermediate_size
         self.ssm_state_size = config.state_size
         self.conv_kernel_size = config.conv_kernel
@@ -1868,14 +1868,14 @@ def __init__(
             self.intermediate_size,
             self.conv_kernel_size,
             device=device,
-            dtype=dtype,
+            dtype=self._dtype,
         )
         ssm_state: torch.Tensor = torch.zeros(
             self.max_batch_size,
             self.intermediate_size,
             self.ssm_state_size,
             device=device,
-            dtype=dtype,
+            dtype=self._dtype,
         )

         torch._dynamo.mark_static_address(conv_state)
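Both recurrent state buffers end up in the configured dtype. A minimal standalone sketch of the allocation with assumed sizes (the real values come from the model config), mirroring this hunk:

import torch

max_batch_size, intermediate_size = 1, 16  # assumed illustrative sizes
conv_kernel_size, ssm_state_size = 4, 8    # assumed illustrative sizes
dtype = torch.float16                      # matches this constructor's default

conv_state = torch.zeros(max_batch_size, intermediate_size, conv_kernel_size, dtype=dtype)
ssm_state = torch.zeros(max_batch_size, intermediate_size, ssm_state_size, dtype=dtype)

torch._dynamo.mark_static_address(conv_state)
torch._dynamo.mark_static_address(ssm_state)
assert conv_state.dtype == ssm_state.dtype == dtype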
@@ -1972,7 +1972,7 @@ def __init__(
         self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
         self.device = torch.device(device) if layer_device_map is None else torch.device(layer_device_map[0])
         self.offload_device = torch.device(offload_device)
-        self.dtype = dtype if dtype is not None else torch.float32
+        self._dtype = dtype if dtype is not None else torch.float32

         # Some models define a custom `head_dim` != config.hidden_size // config.num_attention_heads
         head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
@@ -2144,8 +2144,8 @@ def _create_key_value_cache_tensors(

         is_cpu_device = device == torch.device("cpu")

-        key_cache = torch.zeros(shape, dtype=self.dtype, device=device, pin_memory=is_cpu_device)
-        value_cache = torch.zeros(shape, dtype=self.dtype, device=device, pin_memory=is_cpu_device)
+        key_cache = torch.zeros(shape, dtype=self._dtype, device=device, pin_memory=is_cpu_device)
+        value_cache = torch.zeros(shape, dtype=self._dtype, device=device, pin_memory=is_cpu_device)

         # Note: `mark_static_address` is used to tag the cache as a fixed data pointer,
         # preventing compiled graph breaks when updating the cache.
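The `pin_memory=is_cpu_device` argument matters for the offloaded cache: page-locked host tensors are what make later host-to-device copies asynchronous. A minimal sketch of that pattern with assumed shapes and a guarded CUDA check (not taken from this code):

import torch

shape = (1, 4, 128, 64)  # assumed (batch, kv_heads, max_cache_len, head_dim)
# Pin only when CUDA is present; pinning requires a CUDA-capable build.
on_cpu = torch.zeros(shape, dtype=torch.float16, pin_memory=torch.cuda.is_available())

if torch.cuda.is_available():
    # `non_blocking=True` overlaps the copy with compute only when the source is pinned.
    on_gpu = on_cpu.to("cuda", non_blocking=True)
    torch.cuda.synchronize()  # wait for the async copy before using the tensor
    assert on_gpu.device.type == "cuda"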