@@ -350,6 +350,20 @@ def _sanitize_parameters(
350350
351351 return preprocess_params , forward_params , postprocess_params
352352
353+ @property
354+ def _align_to (self ):
355+ """Sample stride per output."""
356+ # XXX: Careful, `inputs_to_logits_ratio` will not exist on the config in the `seq2seq` setting,
357+ # hence the `getattr` fallback to 1. Chunking is currently not possible at this level for
358+ # `seq2seq`, so the fallback is ok.
359+ align_to = getattr (self .model .config , "inputs_to_logits_ratio" , 1 )
360+ if self .model .config .model_type == "lasr_ctc" :
361+ # TODO: find a standard for this; it is not easy because the input length -> mel length mapping
362+ # depends on the feature extractor's specific way of doing it.
363+ # This model takes mel features as input, so we additionally align according to the hop length.
364+ align_to *= self .feature_extractor .hop_length
365+ return align_to
366+
353367 def preprocess (self , inputs , chunk_length_s = 0 , stride_length_s = None ):
354368 if isinstance (inputs , str ):
355369 if inputs .startswith ("http://" ) or inputs .startswith ("https://" ):
@@ -444,10 +458,7 @@ def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None):
444458 if isinstance (stride_length_s , (int , float )):
445459 stride_length_s = [stride_length_s , stride_length_s ]
446460
447- # XXX: Carefully, this variable will not exist in `seq2seq` setting.
448- # Currently chunking is not possible at this level for `seq2seq` so
449- # it's ok.
450- align_to = getattr (self .model .config , "inputs_to_logits_ratio" , 1 )
461+ align_to = self ._align_to
451462 chunk_len = int (round (chunk_length_s * self .feature_extractor .sampling_rate / align_to ) * align_to )
452463 stride_left = int (round (stride_length_s [0 ] * self .feature_extractor .sampling_rate / align_to ) * align_to )
453464 stride_right = int (round (stride_length_s [1 ] * self .feature_extractor .sampling_rate / align_to ) * align_to )
@@ -567,7 +578,7 @@ def _forward(self, model_inputs, return_timestamps=False, **generate_kwargs):
567578 # Send stride to `postprocess`.
568579 # it needs to be handled there where
569580 # the pieces are to be concatenated.
570- ratio = 1 / self .model . config . inputs_to_logits_ratio
581+ ratio = 1 / self ._align_to
571582 if isinstance (stride , tuple ):
572583 out ["stride" ] = rescale_stride ([stride ], ratio )[0 ]
573584 else :
@@ -650,11 +661,12 @@ def postprocess(
650661
651662 if return_timestamps and self .type not in {"seq2seq" , "seq2seq_whisper" }:
652663 chunks = []
664+ align_to = self ._align_to
653665 for item in offsets :
654- start = item ["start_offset" ] * self . model . config . inputs_to_logits_ratio
666+ start = item ["start_offset" ] * align_to
655667 start /= self .feature_extractor .sampling_rate
656668
657- stop = item ["end_offset" ] * self . model . config . inputs_to_logits_ratio
669+ stop = item ["end_offset" ] * align_to
658670 stop /= self .feature_extractor .sampling_rate
659671
660672 chunks .append ({"text" : item [return_timestamps ], "timestamp" : (start , stop )})
0 commit comments