File tree Expand file tree Collapse file tree 2 files changed +3
-3
lines changed
Expand file tree Collapse file tree 2 files changed +3
-3
lines changed Original file line number Diff line number Diff line change @@ -2578,7 +2578,7 @@ def forward(
25782578 lang = self .language_embedding (lang_id ).transpose (1 , 2 )
25792579
25802580 log_dur_pred = self .dur_predictor (hidden_states .transpose (1 , 2 ))
2581- dur_out = torch .clamp (torch .round ((torch .exp (log_dur_pred ) - 1 )).long (), min = 1 )
2581+ dur_out = torch .clamp (torch .round ((torch .expm1 (log_dur_pred ))).long (), min = 1 )
25822582 # B x C x T
25832583 if hidden_states .size (0 ) == 1 :
25842584 hidden_states = torch .repeat_interleave (hidden_states , dur_out .view (- 1 ), dim = 2 )
Original file line number Diff line number Diff line change @@ -2292,7 +2292,7 @@ def forward(
22922292
22932293 # predict duration
22942294 log_dur_pred = self .duration_predictor (char_hidden_states , padding_mask = char_padding_mask )
2295- dur_out = torch .clamp (torch .round ((torch .exp (log_dur_pred ) - 1 )).long (), min = 1 )
2295+ dur_out = torch .clamp (torch .round ((torch .expm1 (log_dur_pred ))).long (), min = 1 )
22962296 dur_out = dur_out .masked_fill (~ char_padding_mask .bool (), 0.0 )
22972297
22982298 # upsample char hidden states according to predicted duration
@@ -2854,7 +2854,7 @@ def forward(
28542854 lang = self .language_embedding (lang_id ).transpose (1 , 2 )
28552855
28562856 log_dur_pred = self .dur_predictor (hidden_states .transpose (1 , 2 ))
2857- dur_out = torch .clamp (torch .round ((torch .exp (log_dur_pred ) - 1 )).long (), min = 1 )
2857+ dur_out = torch .clamp (torch .round ((torch .expm1 (log_dur_pred ))).long (), min = 1 )
28582858 # B x C x T
28592859 if hidden_states .size (0 ) == 1 :
28602860 hidden_states = torch .repeat_interleave (hidden_states , dur_out .view (- 1 ), dim = 2 )
You can’t perform that action at this time.
0 commit comments