Fix SDXL VAE decode latents dtype mismatch on non-MPS #12847
## Summary

On non-MPS platforms (CUDA/CPU), `StableDiffusionXLPipeline` can call `vae.decode()` with fp16 latents while the VAE (or parts of it) is fp32, which causes a hard runtime error in normalization/linear layers (e.g. `GroupNorm`): `expected scalar type Half but found Float`. This happens because the pipeline currently only aligns the `latents` dtype when `needs_upcasting` is `True` (fp16 VAE + `force_upcast`), and the `elif latents.dtype != self.vae.dtype:` branch only handles MPS, by casting the VAE to the latents dtype. On CUDA/CPU there is no dtype/device alignment, so mixed dtypes can reach VAE decode.

## Reproduction
- `diffusers==0.36.0.dev0` (observed), CUDA or CPU (non-MPS)
- `StableDiffusionXLPipeline.__call__` with `output_type != "latent"` reaches `vae.decode(latents, ...)` and errors inside the VAE decoder's `GroupNorm`/`Linear` layers due to fp16 input + fp32 weights.

A concrete regression test is included that reproduces this without a GPU:

1. Cast `pipe.vae` to fp32.
2. Use `callback_on_step_end` to force `latents` to fp16.
3. Let the pipeline call `vae.decode`.
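The core failure is also reproducible with plain torch, outside the pipeline (a minimal sketch; the `Linear` layer stands in for the VAE decoder's fp32 weights):

```python
import torch

# fp32 layer, standing in for the VAE decoder's fp32 weights
proj = torch.nn.Linear(4, 4)

# fp16 latents, as an fp16 pipeline / callback would produce
latents = torch.randn(1, 4, dtype=torch.float16)

try:
    proj(latents)  # fp16 input x fp32 weight -> hard runtime error
except RuntimeError as err:
    # e.g. "expected scalar type Half but found Float" (exact message varies by backend)
    print(f"mixed-dtype decode fails: {err}")
```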
## Fix

When `needs_upcasting` is `False` but `latents.dtype != self.vae.dtype`, we now align the `latents` dtype/device to the VAE decode dtype/device (preferring the `vae.post_quant_conv` parameters when available) on non-MPS platforms. This prevents mixed dtypes from reaching `vae.decode()` and matches the intent of the upcast path.
## Tests

- `test_vae_decode_aligns_latents_dtype_when_vae_is_fp32` in `tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py`
## Why this is a bug

Users can legitimately end up with an fp32 VAE (for numerical stability) while the latents are fp16 (for performance, or via callbacks/schedulers). The pipeline should not crash with a dtype mismatch in this scenario; it should deterministically align the latents to the VAE decode dtype.