Skip to content

Commit 65dc261

Browse files
khoeu and eustlb authored
Add inputs_to_logits_ratio to LasrCTCConfig (#42720)
* Add inputs_to_logits_ratio to LasrCTCConfig
* changes
* nit
* update
* Add an _align_to property to unify stride computation in AutomaticSpeechRecognition pipeline

---------

Co-authored-by: Eustache Le Bihan <eulebihan@gmail.com>
Co-authored-by: eustlb <94853470+eustlb@users.noreply.github.com>
1 parent 64a7cc8 commit 65dc261

File tree

3 files changed

+27
-7
lines changed

3 files changed

+27
-7
lines changed

src/transformers/models/lasr/configuration_lasr.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -240,5 +240,9 @@ def from_encoder_config(cls, encoder_config: LasrEncoderConfig, **kwargs):
240240

241241
return cls(encoder_config=encoder_config.to_dict(), **kwargs)
242242

243+
@property
244+
def inputs_to_logits_ratio(self):
245+
return self.encoder_config.subsampling_conv_stride**2
246+
243247

244248
__all__ = ["LasrEncoderConfig", "LasrCTCConfig"]

src/transformers/models/lasr/modular_lasr.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -291,6 +291,10 @@ def __init__(
291291
**kwargs,
292292
)
293293

294+
@property
295+
def inputs_to_logits_ratio(self):
296+
return self.encoder_config.subsampling_conv_stride**2
297+
294298

295299
class LasrEncoderSubsampling(nn.Module):
296300
def __init__(self, config: LasrEncoderConfig):

src/transformers/pipelines/automatic_speech_recognition.py

Lines changed: 19 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -350,6 +350,20 @@ def _sanitize_parameters(
350350

351351
return preprocess_params, forward_params, postprocess_params
352352

353+
@property
354+
def _align_to(self):
355+
"""Sample stride per output."""
356+
# XXX: Carefully, this variable will not exist in `seq2seq` setting.
357+
# Currently chunking is not possible at this level for `seq2seq` so
358+
# it's ok.
359+
align_to = getattr(self.model.config, "inputs_to_logits_ratio", 1)
360+
if self.model.config.model_type == "lasr_ctc":
361+
# TODO: find a standard for that but not easy because input length -> mel length depends on the feature extractor
362+
# specific way of doing it
363+
# means the model take mel features as input, we align according to the hop length
364+
align_to *= self.feature_extractor.hop_length
365+
return align_to
366+
353367
def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None):
354368
if isinstance(inputs, str):
355369
if inputs.startswith("http://") or inputs.startswith("https://"):
@@ -444,10 +458,7 @@ def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None):
444458
if isinstance(stride_length_s, (int, float)):
445459
stride_length_s = [stride_length_s, stride_length_s]
446460

447-
# XXX: Carefully, this variable will not exist in `seq2seq` setting.
448-
# Currently chunking is not possible at this level for `seq2seq` so
449-
# it's ok.
450-
align_to = getattr(self.model.config, "inputs_to_logits_ratio", 1)
461+
align_to = self._align_to
451462
chunk_len = int(round(chunk_length_s * self.feature_extractor.sampling_rate / align_to) * align_to)
452463
stride_left = int(round(stride_length_s[0] * self.feature_extractor.sampling_rate / align_to) * align_to)
453464
stride_right = int(round(stride_length_s[1] * self.feature_extractor.sampling_rate / align_to) * align_to)
@@ -567,7 +578,7 @@ def _forward(self, model_inputs, return_timestamps=False, **generate_kwargs):
567578
# Send stride to `postprocess`.
568579
# it needs to be handled there where
569580
# the pieces are to be concatenated.
570-
ratio = 1 / self.model.config.inputs_to_logits_ratio
581+
ratio = 1 / self._align_to
571582
if isinstance(stride, tuple):
572583
out["stride"] = rescale_stride([stride], ratio)[0]
573584
else:
@@ -650,11 +661,12 @@ def postprocess(
650661

651662
if return_timestamps and self.type not in {"seq2seq", "seq2seq_whisper"}:
652663
chunks = []
664+
align_to = self._align_to
653665
for item in offsets:
654-
start = item["start_offset"] * self.model.config.inputs_to_logits_ratio
666+
start = item["start_offset"] * align_to
655667
start /= self.feature_extractor.sampling_rate
656668

657-
stop = item["end_offset"] * self.model.config.inputs_to_logits_ratio
669+
stop = item["end_offset"] * align_to
658670
stop /= self.feature_extractor.sampling_rate
659671

660672
chunks.append({"text": item[return_timestamps], "timestamp": (start, stop)})

0 commit comments

Comments (0)