Merged
Commits
217 commits
86d38f1
add dia model
buttercrab Apr 28, 2025
ad7302e
add tokenizer files
ArthurZucker May 3, 2025
3aefa5f
cleanup some stuff
ArthurZucker May 3, 2025
782af0e
Merge branch 'main' into add-dia
ArthurZucker May 3, 2025
c61f885
brut copy paste code
ArthurZucker May 3, 2025
810a8df
Merge branch 'add-dia' of github.com:huggingface/transformers into ad…
ArthurZucker May 3, 2025
4ac684e
rough cleanup of the modeling code
ArthurZucker May 3, 2025
9cc0d6b
nuke some stuff
ArthurZucker May 3, 2025
d7491ab
more nuking
ArthurZucker May 3, 2025
a447adf
more cleanups
ArthurZucker May 3, 2025
5f3b0c3
updates
ArthurZucker May 5, 2025
6427323
add mulitLayerEmbedding vectorization
ArthurZucker May 5, 2025
79e4f03
nits
ArthurZucker May 5, 2025
df780fd
more modeling simplifications
ArthurZucker May 5, 2025
f7b2c08
updates
ArthurZucker May 5, 2025
fdabeb5
update rope
ArthurZucker May 5, 2025
9861ab5
update rope
ArthurZucker May 5, 2025
007d480
just fixup
ArthurZucker May 5, 2025
14a502e
update configuration files
ArthurZucker May 5, 2025
d9e9585
more cleanup!
ArthurZucker May 5, 2025
677481b
default config values
ArthurZucker May 5, 2025
919ef03
update
ArthurZucker May 5, 2025
73acbdd
forgotten comma
ArthurZucker May 5, 2025
87375ef
another comma!
ArthurZucker May 5, 2025
f1dfefd
update, more cleanups
ArthurZucker May 5, 2025
1311792
just more nits
ArthurZucker May 5, 2025
f795327
more config cleanups
ArthurZucker May 5, 2025
68650ef
time for the encoder
ArthurZucker May 5, 2025
10066b6
fix
ArthurZucker May 5, 2025
738b858
sa=mall nit
ArthurZucker May 5, 2025
233db05
nits
ArthurZucker May 6, 2025
9f2608c
n
ArthurZucker May 6, 2025
87efaf0
refacto a bit
ArthurZucker May 6, 2025
525a40d
cleanup
ArthurZucker May 6, 2025
4d35808
update cv scipt
ArthurZucker May 6, 2025
3dfe8cd
fix last issues
ArthurZucker May 6, 2025
0ee2ac6
fix last nits
ArthurZucker May 6, 2025
8ac8a9e
styling
ArthurZucker May 6, 2025
8e63b28
small fixes
ArthurZucker May 7, 2025
43e1082
just run 1 generation
ArthurZucker May 7, 2025
ab43e72
fixes
ArthurZucker May 7, 2025
83ce1f8
nits
ArthurZucker May 7, 2025
35a317f
fix conversion
ArthurZucker May 7, 2025
a12a766
fix
ArthurZucker May 7, 2025
ed0d00c
more fixes
ArthurZucker May 7, 2025
2af2418
full generate
ArthurZucker May 7, 2025
4199000
ouf!
ArthurZucker May 7, 2025
5f39655
fixes!
ArthurZucker May 8, 2025
f09666c
updates
ArthurZucker May 8, 2025
badc5d0
fix
ArthurZucker May 8, 2025
463f366
fix cvrt
ArthurZucker May 8, 2025
37b1943
fixup
ArthurZucker May 8, 2025
81497ed
nits
ArthurZucker May 8, 2025
afb3915
delete wrong test
ArthurZucker May 8, 2025
0d60608
update
buttercrab May 15, 2025
0691228
update
buttercrab May 19, 2025
41c4f39
test tokenization
buttercrab May 22, 2025
465d92f
let's start changing things bit by bit - fix encoder step
vasqu May 26, 2025
10893cb
removing custom generation, moving to GenerationMixin
buttercrab May 27, 2025
30007d9
add encoder decoder attention masks for generation
buttercrab May 27, 2025
6ecf953
merge upstream main
buttercrab May 27, 2025
daaa83d
mask changes, correctness checked against ad29837 in dia repo
vasqu May 27, 2025
f88ea86
refactor a bit already --> next cache
vasqu May 27, 2025
a72e6e4
too important not to push :)
vasqu May 28, 2025
7c3c230
minimal cleanup + more todos
vasqu May 28, 2025
02ce881
Merge branch 'main' into add-dia
vasqu May 28, 2025
b153426
make main overwrite modeling utils
vasqu May 28, 2025
6891a65
add cfg filter & eos filter
buttercrab May 29, 2025
93670a3
add eos countdown & delay pattern
buttercrab May 29, 2025
cd3d95c
update eos countdown
buttercrab May 29, 2025
47de688
add max step eos countdown
buttercrab May 29, 2025
a4fdcab
fix tests
buttercrab May 30, 2025
601143e
merge
buttercrab May 30, 2025
40b9c64
fix some things
buttercrab May 30, 2025
00d15a6
fix generation with testing
buttercrab May 31, 2025
a4a750d
move cfg & eos stuff to logits processor
buttercrab May 31, 2025
616e70a
make RepetitionPenaltyLogitsProcessor flexible
buttercrab May 31, 2025
0837c73
fix input_ids concatenation dimension in GenerationMixin for flexibility
buttercrab Jun 1, 2025
2053574
Add DiaHangoverLogitsProcessor and DiaExponentialDecayLengthPenalty c…
buttercrab Jun 1, 2025
3d33951
merge main
buttercrab Jun 1, 2025
35f564d
Add stopping criteria
buttercrab Jun 1, 2025
baa5677
refactor
buttercrab Jun 1, 2025
fa47036
move delay pattern from processor to modeling like musicgen.
buttercrab Jun 1, 2025
5fe7b82
fix processor & fix tests
buttercrab Jun 1, 2025
52628a5
refactor types
buttercrab Jun 1, 2025
bd2ac7e
refactor imports
buttercrab Jun 1, 2025
7e22d24
format code
buttercrab Jun 1, 2025
f793f0d
fix docstring to pass ci
buttercrab Jun 1, 2025
c1016a1
Merge branch 'main' into add-dia
buttercrab Jun 1, 2025
efce662
add docstring to DiaConfig & add DiaModel to test
buttercrab Jun 1, 2025
093ad3b
fix docstring
buttercrab Jun 1, 2025
b612125
add docstring
buttercrab Jun 1, 2025
607fa7a
fix some bugs
buttercrab Jun 2, 2025
4e14097
Merge branch 'main' into add-dia
buttercrab Jun 3, 2025
85febf3
check
vasqu Jun 3, 2025
3429985
porting / merging results from other branch - IMPORTANT: it very like…
vasqu Jun 3, 2025
2357447
experimental testing of left padding for first channel
vasqu Jun 3, 2025
036c9ae
whoops
vasqu Jun 3, 2025
b8e648c
Fix merge to make generation work
buttercrab Jun 4, 2025
15b43ad
fix cfg filter
buttercrab Jun 4, 2025
fc6e4c5
add position ids
vasqu Jun 4, 2025
68b75d8
add todos, break things
vasqu Jun 4, 2025
d8b892e
revert changes to generation --> we will force 2d but go 3d on custom…
vasqu Jun 4, 2025
19a9dff
refactor a lot, change prepare decoder ids to work with left padding …
vasqu Jun 4, 2025
4e7c550
some first fixes to get to 10. in generation
vasqu Jun 5, 2025
7d08de5
some more generation fixes / adjustment
vasqu Jun 5, 2025
e988dda
style + rope fixes
vasqu Jun 5, 2025
7f9a0c8
move cfg out, simplify a few things, more todos
vasqu Jun 5, 2025
1a274d7
nit
vasqu Jun 5, 2025
d0ab693
start working on custom logit processors
vasqu Jun 5, 2025
5804eed
nit
vasqu Jun 5, 2025
8d0979e
quick fixes
vasqu Jun 5, 2025
5dc7b3f
cfg top k
vasqu Jun 6, 2025
686e754
more refactor of logits processing, needs a decision if gen config ge…
vasqu Jun 6, 2025
52d983a
lets keep changes to core code minimal, only eos scaling is questiona…
vasqu Jun 6, 2025
2e7c6c7
simpler eos delay logits processor
vasqu Jun 6, 2025
e636bde
that was for debugging :D
vasqu Jun 6, 2025
d46de54
proof of concept rope
vasqu Jun 6, 2025
2db8338
small fix on device mismatch
buttercrab Jun 7, 2025
a9d9520
cfg fixes + delay logits max len
vasqu Jun 9, 2025
371b953
transformers rope
vasqu Jun 9, 2025
5cfac34
modular dia
vasqu Jun 9, 2025
9382926
more cleanup
vasqu Jun 9, 2025
07f2720
keep modeling consistently 3D, generate handles 2D internally
vasqu Jun 9, 2025
3f8c1be
decoder starts with bos if nothing
vasqu Jun 9, 2025
a0f61ed
post processing prototype
vasqu Jun 9, 2025
59e2bcc
style
vasqu Jun 9, 2025
282de2d
lol
vasqu Jun 9, 2025
0f9171f
force sample / greedy + fixes on padding
vasqu Jun 10, 2025
db69ca3
style
vasqu Jun 10, 2025
4d3805a
fixup tokenization
vasqu Jun 10, 2025
0b26d2b
nits
vasqu Jun 10, 2025
5e104a4
revert
vasqu Jun 10, 2025
210faf5
start working on dia tests
vasqu Jun 10, 2025
291fe54
fix a lot of tests
vasqu Jun 11, 2025
ef9ab75
more test fixes
vasqu Jun 11, 2025
da1ad0d
nit
vasqu Jun 11, 2025
cedd8a5
Merge branch 'main' into add-dia
vasqu Jun 11, 2025
8d93cd3
more test fixes + some features to simplify code more
vasqu Jun 11, 2025
d75899c
more cleanup
vasqu Jun 11, 2025
a31ae4b
forgot that one
vasqu Jun 11, 2025
456a9cd
autodocs
vasqu Jun 11, 2025
1a1953e
small consistency fixes
vasqu Jun 11, 2025
0a41da2
fix regression
vasqu Jun 11, 2025
a1bf0f7
small fixes
vasqu Jun 12, 2025
9eb506a
dia feature extraction
vasqu Jun 12, 2025
6d33db8
docs
vasqu Jun 12, 2025
af9b65f
wip processor
vasqu Jun 12, 2025
2cf7eda
fix processor order
vasqu Jun 13, 2025
12bf32f
processing goes brrr
vasqu Jun 13, 2025
f8354bd
transpose before
vasqu Jun 13, 2025
57e687b
small fix
vasqu Jun 13, 2025
402f60f
fix major bug but needs now a closer look into the custom processors …
vasqu Jun 13, 2025
737dcc9
small thing on logits
vasqu Jun 14, 2025
7e81f45
nits
vasqu Jun 16, 2025
9c69257
simplify indices and shifts
vasqu Jun 16, 2025
9937e32
add simpler version of padding tests back (temporarily)
vasqu Jun 16, 2025
d3e522e
add logit processor tests
vasqu Jun 16, 2025
fc96004
starting tests on processor
vasqu Jun 16, 2025
0d5a683
fix mask application during generation
vasqu Jun 17, 2025
15ab87f
some fixes on the weights conversion
vasqu Jun 17, 2025
889e323
style + fixup logits order
vasqu Jun 17, 2025
2a8d0d5
simplify conversion
vasqu Jun 17, 2025
0ccfea8
nit
vasqu Jun 17, 2025
368d6b3
remove padding tests
vasqu Jun 17, 2025
5340979
nits on modeling
vasqu Jun 17, 2025
e5a5dc3
hmm
vasqu Jun 17, 2025
584cde2
fix tests
vasqu Jun 17, 2025
f53f43c
trigger
vasqu Jun 17, 2025
5beae32
probably gonna be reverted, just a quick design around audio tokenizer
vasqu Jun 17, 2025
1c90b37
fixup typing
vasqu Jun 18, 2025
adc6e6c
Merge branch 'main' into add-dia
vasqu Jun 18, 2025
2acbe81
post merge + more typing
vasqu Jun 18, 2025
c35bc88
initial design for audio tokenizer
vasqu Jun 18, 2025
31dea2a
more design changes
vasqu Jun 18, 2025
3572cad
nit
vasqu Jun 18, 2025
1003a52
more processor tests and style related things
vasqu Jun 18, 2025
6a61da2
add to init
vasqu Jun 18, 2025
c95defd
protect import
vasqu Jun 18, 2025
b647afa
not sure why tbh
vasqu Jun 18, 2025
ce98e6d
add another protect
vasqu Jun 18, 2025
980fa37
more fixes
vasqu Jun 18, 2025
ac08063
wow
vasqu Jun 18, 2025
b291f25
it aint stopping :D
vasqu Jun 18, 2025
1e198b1
another missed type issue
vasqu Jun 18, 2025
135ed15
...
vasqu Jun 18, 2025
caac9e7
change design around audio tokenizer to prioritize init and go for au…
vasqu Jun 20, 2025
06a30c8
change to new causal mask function + docstrings
vasqu Jun 20, 2025
623547a
Merge branch 'main' into add-dia
vasqu Jun 20, 2025
16add9b
change ternary
vasqu Jun 20, 2025
828c55a
docs
vasqu Jun 20, 2025
0f11408
remove todo, i dont think its essential tbh
vasqu Jun 20, 2025
1895c41
remove pipeline as current pipelines do not fit in the current scheme…
vasqu Jun 20, 2025
7598dbd
closer to wrapping up the processor
vasqu Jun 20, 2025
148f1fd
text to audio, just for demo purposes (will likely be reverted)
vasqu Jun 20, 2025
aa684dd
check if it's this
vasqu Jun 20, 2025
9ec4224
save audio function
vasqu Jun 23, 2025
aab87b0
ensure no grad
vasqu Jun 23, 2025
32237e3
fixes on prefixed audio, hop length is used via preprocess dac, devic…
vasqu Jun 23, 2025
2705710
integration tests (tested locally on a100) + some processor utils / f…
vasqu Jun 23, 2025
7ca3d9b
style
vasqu Jun 23, 2025
005224b
nits
vasqu Jun 24, 2025
756b408
another round of smaller things
vasqu Jun 24, 2025
6afb932
Merge branch 'main' into add-dia
vasqu Jun 24, 2025
a144382
docs + some fixes (generate one might be big)
vasqu Jun 24, 2025
3427824
msytery solved
vasqu Jun 24, 2025
8e9daf6
small fix on conversion
vasqu Jun 24, 2025
d55130b
add abstract audio tokenizer, change init check to abstract class
vasqu Jun 25, 2025
d2597ae
nits
vasqu Jun 25, 2025
75ed1ac
update docs + fix some processing :D
vasqu Jun 25, 2025
4d87181
change inheritance scheme for audio tokenizer
vasqu Jun 25, 2025
96ca4e2
delete dead / unnecessary code in copied generate loop
vasqu Jun 25, 2025
5828d1b
last nits on new pipeline behavior (+ todo on tests) + style
vasqu Jun 25, 2025
279944d
Merge branch 'main' into add-dia
vasqu Jun 25, 2025
798330e
trigger
vasqu Jun 25, 2025
9d1ea00
Merge branch 'main' into add-dia
vasqu Jun 26, 2025
fa59c1e
fixup loss
vasqu Jun 26, 2025
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
@@ -839,6 +839,8 @@
title: CSM
- local: model_doc/dac
title: dac
- local: model_doc/dia
title: Dia
- local: model_doc/encodec
title: EnCodec
- local: model_doc/fastspeech2_conformer
4 changes: 4 additions & 0 deletions docs/source/en/model_doc/auto.md
@@ -350,6 +350,10 @@ The following auto classes are available for the following audio tasks.

[[autodoc]] AutoModelForTextToWaveform

### AutoModelForAudioTokenization

[[autodoc]] AutoModelForAudioTokenization

## Multimodal

The following auto classes are available for the following multimodal tasks.
162 changes: 162 additions & 0 deletions docs/source/en/model_doc/dia.md
@@ -0,0 +1,162 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# Dia

<div style="float: right;">
<div class="flex flex-wrap space-x-1">
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
</div>
</div>

## Overview

Dia is an open-source text-to-speech (TTS) model (1.6B parameters) developed by [Nari Labs](https://huggingface.co/nari-labs).
It can generate highly realistic dialogue from a transcript, including nonverbal communication such as laughter and coughing.
Emotion and tone can also be controlled via audio conditioning (voice cloning).

**Model Architecture:**
Dia is an encoder-decoder transformer based on the original transformer architecture, with some more modern features such as
rotary positional embeddings (RoPE). The text portion (encoder) uses a byte tokenizer, while the audio portion (decoder) uses
a pretrained codec model, [DAC](./dac.md), which encodes speech into discrete codebook tokens and decodes them back into audio.
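
To make "byte tokenizer" concrete, here is a minimal sketch of byte-level tokenization in plain Python. It is an illustration only, not `DiaTokenizer`'s implementation: the helper names `byte_tokenize` and `byte_detokenize` are hypothetical, and any special handling of speaker tags such as `[S1]` is deliberately left out.

```python
def byte_tokenize(text: str) -> list[int]:
    """Map text to one integer ID (0-255) per UTF-8 byte. Hypothetical helper."""
    return list(text.encode("utf-8"))


def byte_detokenize(ids: list[int]) -> str:
    """Invert byte_tokenize; invalid byte sequences are replaced instead of raising."""
    return bytes(ids).decode("utf-8", errors="replace")


ids = byte_tokenize("Dia é")
print(ids)                   # [68, 105, 97, 32, 195, 169]
print(byte_detokenize(ids))  # Dia é
```

Note that the non-ASCII character `é` becomes two IDs, so sequence length tracks the byte count rather than the character count.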

## Usage Tips

### Generation with Text

```python
from transformers import AutoProcessor, DiaForConditionalGeneration

torch_device = "cuda"
model_checkpoint = "buttercrab/dia-v1-1.6b"

text = ["[S1] Dia is an open weights text to dialogue model."]
processor = AutoProcessor.from_pretrained(model_checkpoint)
inputs = processor(text=text, padding=True, return_tensors="pt").to(torch_device)

model = DiaForConditionalGeneration.from_pretrained(model_checkpoint).to(torch_device)
outputs = model.generate(**inputs, max_new_tokens=256)  # corresponds to roughly 2s of audio

# save audio to a file
outputs = processor.batch_decode(outputs)
processor.save_audio(outputs, "example.wav")

```

### Generation with Text and Audio (Voice Cloning)

```python
from datasets import load_dataset, Audio
from transformers import AutoProcessor, DiaForConditionalGeneration

torch_device = "cuda"
model_checkpoint = "buttercrab/dia-v1-1.6b"

ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
ds = ds.cast_column("audio", Audio(sampling_rate=44100))
audio = ds[-1]["audio"]["array"]
# text is a transcript of the audio + additional text you want as new audio
text = ["[S1] I know. It's going to save me a lot of money, I hope. [S2] I sure hope so for you."]

processor = AutoProcessor.from_pretrained(model_checkpoint)
inputs = processor(text=text, audio=audio, padding=True, return_tensors="pt").to(torch_device)
prompt_len = processor.get_audio_prompt_len(inputs["decoder_attention_mask"])

model = DiaForConditionalGeneration.from_pretrained(model_checkpoint).to(torch_device)
outputs = model.generate(**inputs, max_new_tokens=256)  # corresponds to roughly 2s of audio

# retrieve actually generated audio and save to a file
outputs = processor.batch_decode(outputs, audio_prompt_len=prompt_len)
processor.save_audio(outputs, "example_with_audio.wav")
```
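
Conceptually, passing `audio_prompt_len` lets decoding skip the frames that came from the audio prompt, so only newly generated audio is saved. The sketch below shows that idea on plain lists. It assumes the prompt occupies the first `prompt_len` positions of every sequence, and the helper `strip_audio_prompt` is hypothetical. The actual `batch_decode` does more than this, since it returns audio rather than token frames.

```python
def strip_audio_prompt(sequences: list[list[int]], prompt_len: int) -> list[list[int]]:
    """Drop the first prompt_len frames from each sequence. Hypothetical helper."""
    return [seq[prompt_len:] for seq in sequences]


generated = [
    [11, 12, 13, 21, 22],  # three prompt frames followed by two new frames
    [31, 32, 33, 41, 42],
]
new_frames = strip_audio_prompt(generated, prompt_len=3)
print(new_frames)  # [[21, 22], [41, 42]]
```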

### Training

```python
from datasets import load_dataset, Audio
from transformers import AutoProcessor, DiaForConditionalGeneration

torch_device = "cuda"
model_checkpoint = "buttercrab/dia-v1-1.6b"

ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
ds = ds.cast_column("audio", Audio(sampling_rate=44100))
audio = ds[-1]["audio"]["array"]
# text is a transcript of the audio
text = ["[S1] I know. It's going to save me a lot of money, I hope."]

processor = AutoProcessor.from_pretrained(model_checkpoint)
inputs = processor(
    text=text,
    audio=audio,
    generation=False,
    output_labels=True,
    padding=True,
    return_tensors="pt",
).to(torch_device)

model = DiaForConditionalGeneration.from_pretrained(model_checkpoint).to(torch_device)
out = model(**inputs)
out.loss.backward()
```


This model was contributed by [Jaeyong Sung](https://huggingface.co/buttercrab), [Arthur Zucker](https://huggingface.co/ArthurZ),
and [Anton Vlasjuk](https://huggingface.co/AntonV). The original code can be found [here](https://github.com/nari-labs/dia/).


## DiaConfig

[[autodoc]] DiaConfig

## DiaDecoderConfig

[[autodoc]] DiaDecoderConfig

## DiaEncoderConfig

[[autodoc]] DiaEncoderConfig

## DiaTokenizer

[[autodoc]] DiaTokenizer
    - __call__

## DiaFeatureExtractor

[[autodoc]] DiaFeatureExtractor
    - __call__

## DiaProcessor

[[autodoc]] DiaProcessor
    - __call__
    - batch_decode
    - decode

## DiaModel

[[autodoc]] DiaModel
    - forward

## DiaForConditionalGeneration

[[autodoc]] DiaForConditionalGeneration
    - forward
    - generate
1 change: 0 additions & 1 deletion src/transformers/configuration_utils.py
@@ -271,7 +271,6 @@ def __init__(self, **kwargs):
self.pad_token_id = kwargs.pop("pad_token_id", None)
self.eos_token_id = kwargs.pop("eos_token_id", None)
self.sep_token_id = kwargs.pop("sep_token_id", None)

self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)

# task specific arguments