diff --git a/taming/modules/vqvae/quantize.py b/taming/modules/vqvae/quantize.py index d75544e4..a0d14052 100644 --- a/taming/modules/vqvae/quantize.py +++ b/taming/modules/vqvae/quantize.py @@ -362,8 +362,8 @@ class EMAVectorQuantizer(nn.Module): def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5, remap=None, unknown_index="random"): super().__init__() - self.codebook_dim = codebook_dim - self.num_tokens = num_tokens + self.codebook_dim = embedding_dim + self.num_tokens = n_embed self.beta = beta self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps)