
Conversation

@patil-suraj commented Jul 15, 2022

What does this PR do?

Adds the VQGAN model, the first step towards adding the DALL-E Mega model to Transformers.

  • This model is different from most of the models available in Transformers: it's a U-Net-like encoder-decoder architecture with a vector-quantizer bottleneck.
  • This is only the generator part of the GAN and is intended only for inference (see the rough usage sketch below).
  • It does not have the common transformer-style embeddings, blocks, and other attributes.
  • Currently it does not support output_hidden_states and output_attentions, since this is a complex architecture and it's not clear which hidden_states to return. Would love to hear your thoughts on whether we should support this.
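Rough sketch of the intended usage (the final class name is not settled yet, so `VQGANModel` and the exact call signature below are placeholders; only `VQGANFeatureExtractor` and the `valhalla/vqgan_imagenet_f16_16384` checkpoint appear in this PR):

```python
import torch
from PIL import Image

# Placeholder import: `VQGANModel` is a hypothetical name for the model class
# added here; `VQGANFeatureExtractor` is the feature extractor from this PR.
from transformers import VQGANFeatureExtractor, VQGANModel

checkpoint = "valhalla/vqgan_imagenet_f16_16384"  # to be moved to the CompVis org
feature_extractor = VQGANFeatureExtractor.from_pretrained(checkpoint)
model = VQGANModel.from_pretrained(checkpoint)
model.eval()

image = Image.open("cat.png").convert("RGB")
pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values

with torch.no_grad():
    # encoder -> vector-quantizer bottleneck -> decoder; inference only,
    # since only the generator part of the GAN is ported.
    outputs = model(pixel_values)
```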

@patil-suraj mentioned this pull request Jul 25, 2022

## Usage

TODO (patil-suraj): add some tips here

Suggested change (remove this line):
TODO (patil-suraj): add some tips here

@@ -0,0 +1,763 @@
# coding=utf-8
# Copyright 2022 The Tamin Transformers authors and The HuggingFace Inc. team. All rights reserved.

Suggested change
# Copyright 2022 The Tamin Transformers authors and The HuggingFace Inc. team. All rights reserved.
# Copyright 2022 The Taming Transformers authors and The HuggingFace Inc. team. All rights reserved.

logger = logging.get_logger(__name__)


class VQGANFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):

Should we keep the name "...FeatureExtractor" here or not? cc @sgugger @LysandreJik

The number of channels of the hidden representation.
channel_mult (`tuple`, *optional*, defaults to (1, 1, 2, 2, 4)):
The channel multipliers for the hidden representation.
num_res_blocks (`int`, *optional*, defaults to 2):

this is num_res_layers_per_block no?

num_res_blocks (`int`, *optional*, defaults to 2):
The number of residual blocks.
attn_resolutions (`tuple`, *optional*, defaults to (16,)):
The resolutions of the attention heads.

attn_resolutions is a bit misleading IMO, I'd prefer something like resolutions_with_attention


Suggested change
The resolutions of the attention heads.
The resolutions at which an attention layer is used.
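To make the renaming concrete: attention is only inserted at the spatial resolutions listed in this tuple, which is why a name like `resolutions_with_attention` reads more naturally. A small self-contained illustration (the helper function and the example values are mine, not code from this diff):

```python
def levels_with_attention(resolution, channel_mult, attn_resolutions):
    """Return (level, spatial resolution, has attention) for every down-block level."""
    curr_res = resolution
    out = []
    for level in range(len(channel_mult)):
        out.append((level, curr_res, curr_res in attn_resolutions))
        if level != len(channel_mult) - 1:
            curr_res //= 2  # each Downsample halves the spatial size
    return out


print(levels_with_attention(resolution=256, channel_mult=(1, 1, 2, 2, 4), attn_resolutions=(16,)))
# [(0, 256, False), (1, 128, False), (2, 64, False), (3, 32, False), (4, 16, True)]
```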

The dimension of the quantized (latent) embedding vectors.
dropout (`float`, *optional*, defaults to 0.0):
The dropout probability.
resample_with_conv (`bool`, *optional*, defaults to True):

If I remember correctly this is always True no? Should we maybe just remove this parameter and default it to True?

_CONFIG_FOR_DOC = "VQGANConfig"

VQGAN_PRETRAINED_MODEL_ARCHIVE_LIST = [
"valhalla/vqgan_imagenet_f16_16384", # TODO: upload this to CompVis org.

Let's indeed change this to CompVis

def __init__(self, in_channels: int, with_conv: bool):
super().__init__()

self.with_conv = with_conv

If I remember correctly this is always true

def __init__(self, in_channels: int, with_conv: bool):
super().__init__()

self.with_conv = with_conv

Same here, I think this is always true, no?
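If it really is always True, the module could collapse to something like this (just a sketch with the flag dropped; class and attribute names are placeholders):

```python
import torch
from torch import nn


class VQGANDownsample(nn.Module):
    """Sketch of a Downsample block with `with_conv` removed: only the strided-conv path remains."""

    def __init__(self, in_channels: int):
        super().__init__()
        # taming-transformers pads asymmetrically (right/bottom) before the stride-2 conv.
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = nn.functional.pad(hidden_states, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(hidden_states)
```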

self,
in_channels: int,
out_channels: int = None,
use_conv_shortcut: bool = False,

this param is never used and always defaults to False -> let's remove it and also remove the corresponding use_conv_short_cut param
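With the flag gone, only the 1x1 shortcut that is actually exercised remains; roughly (sketch of just the shortcut branch, not the full residual block, and the names are placeholders):

```python
import torch
from torch import nn


class VQGANResnetShortcut(nn.Module):
    """Sketch: the 1x1 ("nin") shortcut is only needed when the channel count changes."""

    def __init__(self, in_channels: int, out_channels: int = None):
        super().__init__()
        out_channels = in_channels if out_channels is None else out_channels
        self.nin_shortcut = (
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
            if in_channels != out_channels
            else nn.Identity()
        )

    def forward(self, residual: torch.Tensor) -> torch.Tensor:
        return self.nin_shortcut(residual)
```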

super().__init__()

self.in_channels = in_channels
self.out_channels = out_channels

Suggested change (remove this line):
self.out_channels = out_channels


self.in_channels = in_channels
self.out_channels = out_channels
self.out_channels_ = self.in_channels if self.out_channels is None else self.out_channels

Suggested change
self.out_channels_ = self.in_channels if self.out_channels is None else self.out_channels
self.out_channels_ = self.in_channels if out_channels is None else out_channels

conv = partial(nn.Conv2d, self.in_channels, self.in_channels, kernel_size=1, stride=1, padding=0)

self.norm = nn.GroupNorm(num_groups=32, num_channels=self.in_channels, eps=1e-6, affine=True)
self.q, self.k, self.v = conv(), conv(), conv()

No convolution layers anymore for attention blocks please. We have a working solution in diffusers that makes use of nn.Linear -> let's use this instead


self.norm = nn.GroupNorm(num_groups=32, num_channels=self.in_channels, eps=1e-6, affine=True)
self.q, self.k, self.v = conv(), conv(), conv()
self.proj_out = conv()

same here, we should use nn.Linear
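For reference, roughly what the block could look like with nn.Linear over the flattened spatial positions, similar in spirit to the diffusers implementation (a sketch, names and details are placeholders; a 1x1 conv and a linear layer over the channel dimension are mathematically equivalent):

```python
import torch
from torch import nn


class VQGANAttnBlock(nn.Module):
    """Sketch: q/k/v/proj_out as nn.Linear instead of 1x1 Conv2d layers."""

    def __init__(self, channels: int):
        super().__init__()
        self.norm = nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6, affine=True)
        self.q = nn.Linear(channels, channels)
        self.k = nn.Linear(channels, channels)
        self.v = nn.Linear(channels, channels)
        self.proj_out = nn.Linear(channels, channels)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        residual = hidden_states
        batch, channels, height, width = hidden_states.shape

        hidden_states = self.norm(hidden_states)
        # (batch, channels, height, width) -> (batch, height * width, channels)
        hidden_states = hidden_states.view(batch, channels, height * width).transpose(1, 2)

        query, key, value = self.q(hidden_states), self.k(hidden_states), self.v(hidden_states)
        attn_weights = torch.softmax(query @ key.transpose(1, 2) * channels**-0.5, dim=-1)
        hidden_states = self.proj_out(attn_weights @ value)

        # back to (batch, channels, height, width), then add the residual
        hidden_states = hidden_states.transpose(1, 2).reshape(batch, channels, height, width)
        return hidden_states + residual
```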

@patrickvonplaten left a comment

This is too much of a simple copy-paste of the original code for me.

  • Some parameters and flags are never used, so we should remove them
  • I'm not a big fan of the original config naming such as attn_resolutions -> this is extremely hard to understand
  • Let's not use conv layers for attention projection layers

I'm currently refactoring what I think is the exact same model. How about you wait 1-2 days and then copy-paste my refactor + conversion script?
See: huggingface/diffusers#137

VQGAN from Taming Transformers is IMO too important to be a simple copy-paste

@patrickvonplaten

@patil-suraj note that you can use the current main version of diffusers as a reference for how the code should look, and you can use the conversion script to convert the official weights
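For anyone following along, the class to look at in diffusers is VQModel; for example (the checkpoint below is just an example of a VQ autoencoder in the diffusers format, not the DALL-E Mega weights):

```python
from diffusers import VQModel

# Example checkpoint with a "vqvae" subfolder in the diffusers format; the
# DALL-E Mega weights would still need to go through the conversion script.
model = VQModel.from_pretrained("CompVis/ldm-celebahq-256", subfolder="vqvae")
```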

@huggingface deleted a comment from the github-actions bot Aug 31, 2022
@patrickvonplaten

Taking over this PR

@huggingface deleted a comment from the github-actions bot Sep 27, 2022
@patrickvonplaten added the WIP label Sep 27, 2022