3 changes: 1 addition & 2 deletions src/diffusers/models/attention.py
@@ -634,7 +634,6 @@ def __init__(
         if inner_dim is None:
             inner_dim = int(dim * mult)
         dim_out = dim_out if dim_out is not None else dim
-        linear_cls = nn.Linear
 
         if activation_fn == "gelu":
             act_fn = GELU(dim, inner_dim, bias=bias)
@@ -651,7 +650,7 @@ def __init__(
         # project dropout
         self.net.append(nn.Dropout(dropout))
         # project out
-        self.net.append(linear_cls(inner_dim, dim_out, bias=bias))
+        self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))
         # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
         if final_dropout:
             self.net.append(nn.Dropout(dropout))
19 changes: 8 additions & 11 deletions src/diffusers/models/attention_processor.py
@@ -181,25 +181,22 @@ def __init__(
                 f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
             )
 
-        linear_cls = nn.Linear
-
-        self.linear_cls = linear_cls
-        self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)
+        self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
 
         if not self.only_cross_attention:
             # only relevant for the `AddedKVProcessor` classes
-            self.to_k = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
-            self.to_v = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
+            self.to_k = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias)
+            self.to_v = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias)
         else:
             self.to_k = None
             self.to_v = None
 
         if self.added_kv_proj_dim is not None:
-            self.add_k_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
-            self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
+            self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
+            self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
 
         self.to_out = nn.ModuleList([])
-        self.to_out.append(linear_cls(self.inner_dim, self.out_dim, bias=out_bias))
+        self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
         self.to_out.append(nn.Dropout(dropout))
 
         # set attention processor
@@ -706,7 +703,7 @@ def fuse_projections(self, fuse=True):
             out_features = concatenated_weights.shape[0]
 
             # create a new single projection layer and copy over the weights.
-            self.to_qkv = self.linear_cls(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
+            self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
             self.to_qkv.weight.copy_(concatenated_weights)
             if self.use_bias:
                 concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
@@ -717,7 +714,7 @@ def fuse_projections(self, fuse=True):
             in_features = concatenated_weights.shape[1]
             out_features = concatenated_weights.shape[0]
 
-            self.to_kv = self.linear_cls(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
+            self.to_kv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
             self.to_kv.weight.copy_(concatenated_weights)
             if self.use_bias:
                 concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
3 changes: 1 addition & 2 deletions src/diffusers/models/downsampling.py
@@ -102,7 +102,6 @@ def __init__(
         self.padding = padding
         stride = 2
         self.name = name
-        conv_cls = nn.Conv2d
 
         if norm_type == "ln_norm":
             self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
@@ -114,7 +113,7 @@ def __init__(
             raise ValueError(f"unknown norm_type: {norm_type}")
 
         if use_conv:
-            conv = conv_cls(
+            conv = nn.Conv2d(
                 self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias
             )
         else:
5 changes: 2 additions & 3 deletions src/diffusers/models/embeddings.py
@@ -199,9 +199,8 @@ def __init__(
         sample_proj_bias=True,
     ):
         super().__init__()
-        linear_cls = nn.Linear
 
-        self.linear_1 = linear_cls(in_channels, time_embed_dim, sample_proj_bias)
+        self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)
 
         if cond_proj_dim is not None:
             self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
@@ -214,7 +213,7 @@ def __init__(
             time_embed_dim_out = out_dim
         else:
             time_embed_dim_out = time_embed_dim
-        self.linear_2 = linear_cls(time_embed_dim, time_embed_dim_out, sample_proj_bias)
+        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias)
 
         if post_act_fn is None:
             self.post_act = None
21 changes: 8 additions & 13 deletions src/diffusers/models/resnet.py
@@ -101,8 +101,6 @@ def __init__(
         self.output_scale_factor = output_scale_factor
         self.time_embedding_norm = time_embedding_norm
 
-        conv_cls = nn.Conv2d
-
         if groups_out is None:
             groups_out = groups
 
@@ -113,7 +111,7 @@ def __init__(
         else:
             raise ValueError(f" unsupported time_embedding_norm: {self.time_embedding_norm}")
 
-        self.conv1 = conv_cls(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
 
         if self.time_embedding_norm == "ada_group":  # ada_group
             self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
@@ -125,7 +123,7 @@ def __init__(
         self.dropout = torch.nn.Dropout(dropout)
 
         conv_2d_out_channels = conv_2d_out_channels or out_channels
-        self.conv2 = conv_cls(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
+        self.conv2 = nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
 
         self.nonlinearity = get_activation(non_linearity)
 
@@ -139,7 +137,7 @@ def __init__(
 
         self.conv_shortcut = None
         if self.use_in_shortcut:
-            self.conv_shortcut = conv_cls(
+            self.conv_shortcut = nn.Conv2d(
                 in_channels,
                 conv_2d_out_channels,
                 kernel_size=1,
@@ -263,21 +261,18 @@ def __init__(
         self.time_embedding_norm = time_embedding_norm
         self.skip_time_act = skip_time_act
 
-        linear_cls = nn.Linear
-        conv_cls = nn.Conv2d
-
         if groups_out is None:
             groups_out = groups
 
         self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
 
-        self.conv1 = conv_cls(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
 
         if temb_channels is not None:
             if self.time_embedding_norm == "default":
-                self.time_emb_proj = linear_cls(temb_channels, out_channels)
+                self.time_emb_proj = nn.Linear(temb_channels, out_channels)
             elif self.time_embedding_norm == "scale_shift":
-                self.time_emb_proj = linear_cls(temb_channels, 2 * out_channels)
+                self.time_emb_proj = nn.Linear(temb_channels, 2 * out_channels)
             else:
                 raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
         else:
@@ -287,7 +282,7 @@ def __init__(
 
         self.dropout = torch.nn.Dropout(dropout)
         conv_2d_out_channels = conv_2d_out_channels or out_channels
-        self.conv2 = conv_cls(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
+        self.conv2 = nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
 
         self.nonlinearity = get_activation(non_linearity)
 
@@ -313,7 +308,7 @@ def __init__(
 
         self.conv_shortcut = None
         if self.use_in_shortcut:
-            self.conv_shortcut = conv_cls(
+            self.conv_shortcut = nn.Conv2d(
                 in_channels,
                 conv_2d_out_channels,
                 kernel_size=1,
11 changes: 4 additions & 7 deletions src/diffusers/models/transformers/transformer_2d.py
@@ -117,9 +117,6 @@ def __init__(
         self.attention_head_dim = attention_head_dim
         inner_dim = num_attention_heads * attention_head_dim
 
-        conv_cls = nn.Conv2d
-        linear_cls = nn.Linear
-
         # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
         # Define whether input is continuous or discrete depending on configuration
         self.is_input_continuous = (in_channels is not None) and (patch_size is None)
@@ -159,9 +156,9 @@ def __init__(
 
             self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
             if use_linear_projection:
-                self.proj_in = linear_cls(in_channels, inner_dim)
+                self.proj_in = nn.Linear(in_channels, inner_dim)
             else:
-                self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
+                self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
         elif self.is_input_vectorized:
             assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
             assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
@@ -222,9 +219,9 @@ def __init__(
         if self.is_input_continuous:
             # TODO: should use out_channels for continuous projections
             if use_linear_projection:
-                self.proj_out = linear_cls(inner_dim, in_channels)
+                self.proj_out = nn.Linear(inner_dim, in_channels)
             else:
-                self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
+                self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
         elif self.is_input_vectorized:
             self.norm_out = nn.LayerNorm(inner_dim)
             self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
9 changes: 4 additions & 5 deletions src/diffusers/models/unets/unet_stable_cascade.py
@@ -41,11 +41,11 @@ def forward(self, x):
 class SDCascadeTimestepBlock(nn.Module):
     def __init__(self, c, c_timestep, conds=[]):
         super().__init__()
-        linear_cls = nn.Linear
-        self.mapper = linear_cls(c_timestep, c * 2)
+
+        self.mapper = nn.Linear(c_timestep, c * 2)
         self.conds = conds
         for cname in conds:
-            setattr(self, f"mapper_{cname}", linear_cls(c_timestep, c * 2))
+            setattr(self, f"mapper_{cname}", nn.Linear(c_timestep, c * 2))
 
     def forward(self, x, t):
         t = t.chunk(len(self.conds) + 1, dim=1)
@@ -94,12 +94,11 @@ def forward(self, x):
 class SDCascadeAttnBlock(nn.Module):
     def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0):
         super().__init__()
-        linear_cls = nn.Linear
 
         self.self_attn = self_attn
         self.norm = SDCascadeLayerNorm(c, elementwise_affine=False, eps=1e-6)
         self.attention = Attention(query_dim=c, heads=nhead, dim_head=c // nhead, dropout=dropout, bias=True)
-        self.kv_mapper = nn.Sequential(nn.SiLU(), linear_cls(c_cond, c))
+        self.kv_mapper = nn.Sequential(nn.SiLU(), nn.Linear(c_cond, c))
 
     def forward(self, x, kv):
         kv = self.kv_mapper(kv)
3 changes: 1 addition & 2 deletions src/diffusers/models/upsampling.py
@@ -110,7 +110,6 @@ def __init__(
         self.use_conv_transpose = use_conv_transpose
         self.name = name
         self.interpolate = interpolate
-        conv_cls = nn.Conv2d
 
         if norm_type == "ln_norm":
             self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
@@ -131,7 +130,7 @@ def __init__(
         elif use_conv:
             if kernel_size is None:
                 kernel_size = 3
-            conv = conv_cls(self.channels, self.out_channels, kernel_size=kernel_size, padding=padding, bias=bias)
+            conv = nn.Conv2d(self.channels, self.out_channels, kernel_size=kernel_size, padding=padding, bias=bias)
 
         # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
         if name == "conv":
15 changes: 5 additions & 10 deletions src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py
@@ -17,8 +17,8 @@ def forward(self, x):
 class TimestepBlock(nn.Module):
     def __init__(self, c, c_timestep):
         super().__init__()
-        linear_cls = nn.Linear
-        self.mapper = linear_cls(c_timestep, c * 2)
+
+        self.mapper = nn.Linear(c_timestep, c * 2)
 
     def forward(self, x, t):
         a, b = self.mapper(t)[:, :, None, None].chunk(2, dim=1)
@@ -29,13 +29,10 @@ class ResBlock(nn.Module):
     def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0):
         super().__init__()
 
-        conv_cls = nn.Conv2d
-        linear_cls = nn.Linear
-
-        self.depthwise = conv_cls(c + c_skip, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c)
+        self.depthwise = nn.Conv2d(c + c_skip, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c)
         self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6)
         self.channelwise = nn.Sequential(
-            linear_cls(c, c * 4), nn.GELU(), GlobalResponseNorm(c * 4), nn.Dropout(dropout), linear_cls(c * 4, c)
+            nn.Linear(c, c * 4), nn.GELU(), GlobalResponseNorm(c * 4), nn.Dropout(dropout), nn.Linear(c * 4, c)
         )
 
     def forward(self, x, x_skip=None):
@@ -64,12 +61,10 @@ class AttnBlock(nn.Module):
     def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0):
         super().__init__()
 
-        linear_cls = nn.Linear
-
         self.self_attn = self_attn
         self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6)
         self.attention = Attention(query_dim=c, heads=nhead, dim_head=c // nhead, dropout=dropout, bias=True)
-        self.kv_mapper = nn.Sequential(nn.SiLU(), linear_cls(c_cond, c))
+        self.kv_mapper = nn.Sequential(nn.SiLU(), nn.Linear(c_cond, c))
 
     def forward(self, x, kv):
         kv = self.kv_mapper(kv)
10 changes: 4 additions & 6 deletions src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py
@@ -40,15 +40,13 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
     @register_to_config
     def __init__(self, c_in=16, c=1280, c_cond=1024, c_r=64, depth=16, nhead=16, dropout=0.1):
         super().__init__()
-        conv_cls = nn.Conv2d
-        linear_cls = nn.Linear
 
         self.c_r = c_r
-        self.projection = conv_cls(c_in, c, kernel_size=1)
+        self.projection = nn.Conv2d(c_in, c, kernel_size=1)
         self.cond_mapper = nn.Sequential(
-            linear_cls(c_cond, c),
+            nn.Linear(c_cond, c),
             nn.LeakyReLU(0.2),
-            linear_cls(c, c),
+            nn.Linear(c, c),
         )
 
         self.blocks = nn.ModuleList()
@@ -58,7 +56,7 @@ def __init__(self, c_in=16, c=1280, c_cond=1024, c_r=64, depth=16, nhead=16, dro
             self.blocks.append(AttnBlock(c, c, nhead, self_attn=True, dropout=dropout))
         self.out = nn.Sequential(
             WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6),
-            conv_cls(c, c_in * 2, kernel_size=1),
+            nn.Conv2d(c, c_in * 2, kernel_size=1),
         )
 
         self.gradient_checkpointing = False
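Every file in this diff applies the same simplification: the local `linear_cls` / `conv_cls` aliases (apparently left over from an earlier indirection over the linear and convolution classes) are dropped, and the modules construct `nn.Linear` / `nn.Conv2d` directly. A minimal before/after sketch of the pattern, using a hypothetical `ToyBlock` module rather than any of the real diffusers classes:

import torch
import torch.nn as nn


class ToyBlock(nn.Module):
    # Illustrative only: previously a constructor in this style would assign
    #     linear_cls = nn.Linear
    #     conv_cls = nn.Conv2d
    # and build its layers through those aliases; after the change the layers
    # are created directly from torch.nn.
    def __init__(self, channels: int, time_embed_dim: int):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.time_emb_proj = nn.Linear(time_embed_dim, channels)

    def forward(self, x: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
        # Project the time embedding and add it as a per-channel bias.
        return self.conv(x) + self.time_emb_proj(temb)[:, :, None, None]


# Quick smoke test of the sketch.
block = ToyBlock(channels=8, time_embed_dim=32)
out = block(torch.randn(1, 8, 16, 16), torch.randn(1, 32))
print(out.shape)  # torch.Size([1, 8, 16, 16])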