Add Dia model #38405
Merged
Commits
217 commits
86d38f1
add dia model
buttercrab ad7302e
add tokenizer files
ArthurZucker 3aefa5f
cleanup some stuff
ArthurZucker 782af0e
Merge branch 'main' into add-dia
ArthurZucker c61f885
brut copy paste code
ArthurZucker 810a8df
Merge branch 'add-dia' of github.com:huggingface/transformers into ad…
ArthurZucker 4ac684e
rough cleanup of the modeling code
ArthurZucker 9cc0d6b
nuke some stuff
ArthurZucker d7491ab
more nuking
ArthurZucker a447adf
more cleanups
ArthurZucker 5f3b0c3
updates
ArthurZucker 6427323
add mulitLayerEmbedding vectorization
ArthurZucker 79e4f03
nits
ArthurZucker df780fd
more modeling simplifications
ArthurZucker f7b2c08
updates
ArthurZucker fdabeb5
update rope
ArthurZucker 9861ab5
update rope
ArthurZucker 007d480
just fixup
ArthurZucker 14a502e
update configuration files
ArthurZucker d9e9585
more cleanup!
ArthurZucker 677481b
default config values
ArthurZucker 919ef03
update
ArthurZucker 73acbdd
forgotten comma
ArthurZucker 87375ef
another comma!
ArthurZucker f1dfefd
update, more cleanups
ArthurZucker 1311792
just more nits
ArthurZucker f795327
more config cleanups
ArthurZucker 68650ef
time for the encoder
ArthurZucker 10066b6
fix
ArthurZucker 738b858
sa=mall nit
ArthurZucker 233db05
nits
ArthurZucker 9f2608c
n
ArthurZucker 87efaf0
refacto a bit
ArthurZucker 525a40d
cleanup
ArthurZucker 4d35808
update cv scipt
ArthurZucker 3dfe8cd
fix last issues
ArthurZucker 0ee2ac6
fix last nits
ArthurZucker 8ac8a9e
styling
ArthurZucker 8e63b28
small fixes
ArthurZucker 43e1082
just run 1 generation
ArthurZucker ab43e72
fixes
ArthurZucker 83ce1f8
nits
ArthurZucker 35a317f
fix conversion
ArthurZucker a12a766
fix
ArthurZucker ed0d00c
more fixes
ArthurZucker 2af2418
full generate
ArthurZucker 4199000
ouf!
ArthurZucker 5f39655
fixes!
ArthurZucker f09666c
updates
ArthurZucker badc5d0
fix
ArthurZucker 463f366
fix cvrt
ArthurZucker 37b1943
fixup
ArthurZucker 81497ed
nits
ArthurZucker afb3915
delete wrong test
ArthurZucker 0d60608
update
buttercrab 0691228
update
buttercrab 41c4f39
test tokenization
buttercrab 465d92f
let's start changing things bit by bit - fix encoder step
vasqu 10893cb
removing custom generation, moving to GenerationMixin
buttercrab 30007d9
add encoder decoder attention masks for generation
buttercrab 6ecf953
merge upstream main
buttercrab daaa83d
mask changes, correctness checked against ad29837 in dia repo
vasqu f88ea86
refactor a bit already --> next cache
vasqu a72e6e4
too important not to push :)
vasqu 7c3c230
minimal cleanup + more todos
vasqu 02ce881
Merge branch 'main' into add-dia
vasqu b153426
make main overwrite modeling utils
vasqu 6891a65
add cfg filter & eos filter
buttercrab 93670a3
add eos countdown & delay pattern
buttercrab cd3d95c
update eos countdown
buttercrab 47de688
add max step eos countdown
buttercrab a4fdcab
fix tests
buttercrab 601143e
merge
buttercrab 40b9c64
fix some things
buttercrab 00d15a6
fix generation with testing
buttercrab a4a750d
move cfg & eos stuff to logits processor
buttercrab 616e70a
make RepetitionPenaltyLogitsProcessor flexible
buttercrab 0837c73
fix input_ids concatenation dimension in GenerationMixin for flexibility
buttercrab 2053574
Add DiaHangoverLogitsProcessor and DiaExponentialDecayLengthPenalty c…
buttercrab 3d33951
merge main
buttercrab 35f564d
Add stopping criteria
buttercrab baa5677
refactor
buttercrab fa47036
move delay pattern from processor to modeling like musicgen.
buttercrab 5fe7b82
fix processor & fix tests
buttercrab 52628a5
refactor types
buttercrab bd2ac7e
refactor imports
buttercrab 7e22d24
format code
buttercrab f793f0d
fix docstring to pass ci
buttercrab c1016a1
Merge branch 'main' into add-dia
buttercrab efce662
add docstring to DiaConfig & add DiaModel to test
buttercrab 093ad3b
fix docstring
buttercrab b612125
add docstring
buttercrab 607fa7a
fix some bugs
buttercrab 4e14097
Merge branch 'main' into add-dia
buttercrab 85febf3
check
vasqu 3429985
porting / merging results from other branch - IMPORTANT: it very like…
vasqu 2357447
experimental testing of left padding for first channel
vasqu 036c9ae
whoops
vasqu b8e648c
Fix merge to make generation work
buttercrab 15b43ad
fix cfg filter
buttercrab fc6e4c5
add position ids
vasqu 68b75d8
add todos, break things
vasqu d8b892e
revert changes to generation --> we will force 2d but go 3d on custom…
vasqu 19a9dff
refactor a lot, change prepare decoder ids to work with left padding …
vasqu 4e7c550
some first fixes to get to 10. in generation
vasqu 7d08de5
some more generation fixes / adjustment
vasqu e988dda
style + rope fixes
vasqu 7f9a0c8
move cfg out, simplify a few things, more todos
vasqu 1a274d7
nit
vasqu d0ab693
start working on custom logit processors
vasqu 5804eed
nit
vasqu 8d0979e
quick fixes
vasqu 5dc7b3f
cfg top k
vasqu 686e754
more refactor of logits processing, needs a decision if gen config ge…
vasqu 52d983a
lets keep changes to core code minimal, only eos scaling is questiona…
vasqu 2e7c6c7
simpler eos delay logits processor
vasqu e636bde
that was for debugging :D
vasqu d46de54
proof of concept rope
vasqu 2db8338
small fix on device mismatch
buttercrab a9d9520
cfg fixes + delay logits max len
vasqu 371b953
transformers rope
vasqu 5cfac34
modular dia
vasqu 9382926
more cleanup
vasqu 07f2720
keep modeling consistently 3D, generate handles 2D internally
vasqu 3f8c1be
decoder starts with bos if nothing
vasqu a0f61ed
post processing prototype
vasqu 59e2bcc
style
vasqu 282de2d
lol
vasqu 0f9171f
force sample / greedy + fixes on padding
vasqu db69ca3
style
vasqu 4d3805a
fixup tokenization
vasqu 0b26d2b
nits
vasqu 5e104a4
revert
vasqu 210faf5
start working on dia tests
vasqu 291fe54
fix a lot of tests
vasqu ef9ab75
more test fixes
vasqu da1ad0d
nit
vasqu cedd8a5
Merge branch 'main' into add-dia
vasqu 8d93cd3
more test fixes + some features to simplify code more
vasqu d75899c
more cleanup
vasqu a31ae4b
forgot that one
vasqu 456a9cd
autodocs
vasqu 1a1953e
small consistency fixes
vasqu 0a41da2
fix regression
vasqu a1bf0f7
small fixes
vasqu 9eb506a
dia feature extraction
vasqu 6d33db8
docs
vasqu af9b65f
wip processor
vasqu 2cf7eda
fix processor order
vasqu 12bf32f
processing goes brrr
vasqu f8354bd
transpose before
vasqu 57e687b
small fix
vasqu 402f60f
fix major bug but needs now a closer look into the custom processors …
vasqu 737dcc9
small thing on logits
vasqu 7e81f45
nits
vasqu 9c69257
simplify indices and shifts
vasqu 9937e32
add simpler version of padding tests back (temporarily)
vasqu d3e522e
add logit processor tests
vasqu fc96004
starting tests on processor
vasqu 0d5a683
fix mask application during generation
vasqu 15ab87f
some fixes on the weights conversion
vasqu 889e323
style + fixup logits order
vasqu 2a8d0d5
simplify conversion
vasqu 0ccfea8
nit
vasqu 368d6b3
remove padding tests
vasqu 5340979
nits on modeling
vasqu e5a5dc3
hmm
vasqu 584cde2
fix tests
vasqu f53f43c
trigger
vasqu 5beae32
probably gonna be reverted, just a quick design around audio tokenizer
vasqu 1c90b37
fixup typing
vasqu adc6e6c
Merge branch 'main' into add-dia
vasqu 2acbe81
post merge + more typing
vasqu c35bc88
initial design for audio tokenizer
vasqu 31dea2a
more design changes
vasqu 3572cad
nit
vasqu 1003a52
more processor tests and style related things
vasqu 6a61da2
add to init
vasqu c95defd
protect import
vasqu b647afa
not sure why tbh
vasqu ce98e6d
add another protect
vasqu 980fa37
more fixes
vasqu ac08063
wow
vasqu b291f25
it aint stopping :D
vasqu 1e198b1
another missed type issue
vasqu 135ed15
...
vasqu caac9e7
change design around audio tokenizer to prioritize init and go for au…
vasqu 06a30c8
change to new causal mask function + docstrings
vasqu 623547a
Merge branch 'main' into add-dia
vasqu 16add9b
change ternary
vasqu 828c55a
docs
vasqu 0f11408
remove todo, i dont think its essential tbh
vasqu 1895c41
remove pipeline as current pipelines do not fit in the current scheme…
vasqu 7598dbd
closer to wrapping up the processor
vasqu 148f1fd
text to audio, just for demo purposes (will likely be reverted)
vasqu aa684dd
check if it's this
vasqu 9ec4224
save audio function
vasqu aab87b0
ensure no grad
vasqu 32237e3
fixes on prefixed audio, hop length is used via preprocess dac, devic…
vasqu 2705710
integration tests (tested locally on a100) + some processor utils / f…
vasqu 7ca3d9b
style
vasqu 005224b
nits
vasqu 756b408
another round of smaller things
vasqu 6afb932
Merge branch 'main' into add-dia
vasqu a144382
docs + some fixes (generate one might be big)
vasqu 3427824
msytery solved
vasqu 8e9daf6
small fix on conversion
vasqu d55130b
add abstract audio tokenizer, change init check to abstract class
vasqu d2597ae
nits
vasqu 75ed1ac
update docs + fix some processing :D
vasqu 4d87181
change inheritance scheme for audio tokenizer
vasqu 96ca4e2
delete dead / unnecessary code in copied generate loop
vasqu 5828d1b
last nits on new pipeline behavior (+ todo on tests) + style
vasqu 279944d
Merge branch 'main' into add-dia
vasqu 798330e
trigger
vasqu 9d1ea00
Merge branch 'main' into add-dia
vasqu fa59c1e
fixup loss
vasqu
@@ -0,0 +1,162 @@

<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->
# Dia

<div style="float: right;">
    <div class="flex flex-wrap space-x-1">
        <img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
        <img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
        <img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
    </div>
</div>

## Overview

Dia is an open-source text-to-speech (TTS) model (1.6B parameters) developed by [Nari Labs](https://huggingface.co/nari-labs).
It can generate highly realistic dialogue from a transcript, including nonverbal communication such as laughter and coughing.
Furthermore, emotion and tone control is also possible via audio conditioning (voice cloning).

**Model Architecture:**
Dia is an encoder-decoder transformer based on the original transformer architecture, extended with more modern features such as
rotary positional embeddings (RoPE). For the text portion (encoder), a byte tokenizer is used, while for the audio portion (decoder),
a pretrained codec model, [DAC](./dac.md), encodes speech into discrete codebook tokens and decodes them back into audio.
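To make the codec step concrete, the snippet below sketches a standalone DAC roundtrip (waveform to discrete codebook tokens and back) using the separate [DAC](./dac.md) integration in `transformers`. Dia handles this codec internally through its processor and model, so none of this is required to use Dia itself; the `descript/dac_44khz` checkpoint and output attribute names are assumptions taken from the DAC docs, not part of the Dia API.

```python
from datasets import load_dataset, Audio
from transformers import AutoProcessor, DacModel

model = DacModel.from_pretrained("descript/dac_44khz")
processor = AutoProcessor.from_pretrained("descript/dac_44khz")

ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
ds = ds.cast_column("audio", Audio(sampling_rate=processor.sampling_rate))
audio = ds[-1]["audio"]["array"]

inputs = processor(raw_audio=audio, sampling_rate=processor.sampling_rate, return_tensors="pt")

# encode the waveform into discrete codebook tokens (the kind of tokens the Dia decoder predicts)
encoder_outputs = model.encode(inputs["input_values"])
audio_codes = encoder_outputs.audio_codes

# reconstruct the waveform (encode + decode in a single forward pass)
audio_values = model(inputs["input_values"]).audio_values
```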
## Usage Tips

### Generation with Text

```python
from transformers import AutoProcessor, DiaForConditionalGeneration

torch_device = "cuda"
model_checkpoint = "buttercrab/dia-v1-1.6b"

text = ["[S1] Dia is an open weights text to dialogue model."]
processor = AutoProcessor.from_pretrained(model_checkpoint)
inputs = processor(text=text, padding=True, return_tensors="pt").to(torch_device)

model = DiaForConditionalGeneration.from_pretrained(model_checkpoint).to(torch_device)
outputs = model.generate(**inputs, max_new_tokens=256)  # corresponds to roughly 2s of audio

# save audio to a file
outputs = processor.batch_decode(outputs)
processor.save_audio(outputs, "example.wav")
```
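Multiple transcripts can also be generated in a single batch. The sketch below reuses the API shown above; the per-entry loop over `processor.save_audio` (one decoded output and one path per call) is an assumption for illustration, not documented behavior, and the second transcript is placeholder text.

```python
from transformers import AutoProcessor, DiaForConditionalGeneration

torch_device = "cuda"
model_checkpoint = "buttercrab/dia-v1-1.6b"

# two independent transcripts generated in one batch
text = [
    "[S1] Dia is an open weights text to dialogue model.",
    "[S1] I have a surprise for you. [S2] Oh really, what is it?",
]
processor = AutoProcessor.from_pretrained(model_checkpoint)
inputs = processor(text=text, padding=True, return_tensors="pt").to(torch_device)

model = DiaForConditionalGeneration.from_pretrained(model_checkpoint).to(torch_device)
outputs = model.generate(**inputs, max_new_tokens=256)

# decode the whole batch, then save each entry to its own file
decoded = processor.batch_decode(outputs)
for i, audio in enumerate(decoded):
    processor.save_audio([audio], f"example_{i}.wav")
```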
### Generation with Text and Audio (Voice Cloning)

```python
from datasets import load_dataset, Audio
from transformers import AutoProcessor, DiaForConditionalGeneration

torch_device = "cuda"
model_checkpoint = "buttercrab/dia-v1-1.6b"

ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
ds = ds.cast_column("audio", Audio(sampling_rate=44100))
audio = ds[-1]["audio"]["array"]
# text is a transcript of the audio + additional text you want as new audio
text = ["[S1] I know. It's going to save me a lot of money, I hope. [S2] I sure hope so for you."]

processor = AutoProcessor.from_pretrained(model_checkpoint)
inputs = processor(text=text, audio=audio, padding=True, return_tensors="pt").to(torch_device)
prompt_len = processor.get_audio_prompt_len(inputs["decoder_attention_mask"])

model = DiaForConditionalGeneration.from_pretrained(model_checkpoint).to(torch_device)
outputs = model.generate(**inputs, max_new_tokens=256)  # corresponds to roughly 2s of audio

# retrieve only the newly generated audio (excluding the audio prompt) and save it to a file
outputs = processor.batch_decode(outputs, audio_prompt_len=prompt_len)
processor.save_audio(outputs, "example_with_audio.wav")
```
### Training

```python
from datasets import load_dataset, Audio
from transformers import AutoProcessor, DiaForConditionalGeneration

torch_device = "cuda"
model_checkpoint = "buttercrab/dia-v1-1.6b"

ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
ds = ds.cast_column("audio", Audio(sampling_rate=44100))
audio = ds[-1]["audio"]["array"]
# text is a transcript of the audio
text = ["[S1] I know. It's going to save me a lot of money, I hope."]

processor = AutoProcessor.from_pretrained(model_checkpoint)
inputs = processor(
    text=text,
    audio=audio,
    generation=False,
    output_labels=True,
    padding=True,
    return_tensors="pt",
).to(torch_device)

model = DiaForConditionalGeneration.from_pretrained(model_checkpoint).to(torch_device)
out = model(**inputs)
out.loss.backward()
```
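The backward pass above can be paired with a standard PyTorch optimizer for a full fine-tuning step. A minimal sketch, assuming `model` and `inputs` are prepared exactly as in the training example; the optimizer choice and learning rate are placeholders, not recommendations from the model authors.

```python
import torch

# hypothetical hyperparameters, for illustration only
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

model.train()
optimizer.zero_grad()
out = model(**inputs)  # labels are built by the processor via output_labels=True
out.loss.backward()
optimizer.step()
```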
This model was contributed by [Jaeyong Sung](https://huggingface.co/buttercrab), [Arthur Zucker](https://huggingface.co/ArthurZ),
and [Anton Vlasjuk](https://huggingface.co/AntonV). The original code can be found [here](https://github.com/nari-labs/dia/).

## DiaConfig

[[autodoc]] DiaConfig

## DiaDecoderConfig

[[autodoc]] DiaDecoderConfig

## DiaEncoderConfig

[[autodoc]] DiaEncoderConfig

## DiaTokenizer

[[autodoc]] DiaTokenizer
    - __call__

## DiaFeatureExtractor

[[autodoc]] DiaFeatureExtractor
    - __call__

## DiaProcessor

[[autodoc]] DiaProcessor
    - __call__
    - batch_decode
    - decode

## DiaModel

[[autodoc]] DiaModel
    - forward

## DiaForConditionalGeneration

[[autodoc]] DiaForConditionalGeneration
    - forward
    - generate