231 changes: 0 additions & 231 deletions autoencoding.ipynb

This file was deleted.

975 changes: 0 additions & 975 deletions experiment.py

This file was deleted.

Binary file modified imgs_manipulated/compare.png
Binary file modified imgs_manipulated/output.png
245 changes: 0 additions & 245 deletions interpolate.ipynb

This file was deleted.

270 changes: 0 additions & 270 deletions manipulate.ipynb

This file was deleted.

116 changes: 116 additions & 0 deletions model/unet_model.py
@@ -0,0 +1,116 @@
import copy
import torch
from torch import nn
from torch.cuda import amp

class UNetModel:
    """Core model architecture implementation for diffusion models."""
    def __init__(self, conf):
        """
        Initialize the UNet model.

        Args:
            conf: Configuration object containing model parameters
        """
        self.conf = conf
        self.model = conf.make_model_conf().make_model()
        self.ema_model = copy.deepcopy(self.model)
        self.ema_model.requires_grad_(False)
        self.ema_model.eval()

        # Calculate model size
        model_size = 0
        for param in self.model.parameters():
            model_size += param.data.nelement()
        print('Model params: %.2f M' % (model_size / 1024 / 1024))

        # Initialize samplers
        self.sampler = conf.make_diffusion_conf().make_sampler()
        self.eval_sampler = conf.make_eval_diffusion_conf().make_sampler()
        self.T_sampler = conf.make_T_sampler()

        # Initialize latent samplers if needed
        if conf.train_mode.use_latent_net():
            self.latent_sampler = conf.make_latent_diffusion_conf().make_sampler()
            self.eval_latent_sampler = conf.make_latent_eval_diffusion_conf().make_sampler()
        else:
            self.latent_sampler = None
            self.eval_latent_sampler = None

    def update_ema(self, decay):
        """
        Update the exponential moving average model.

        Args:
            decay: EMA decay rate
        """
        self._ema(self.model, self.ema_model, decay)

    def _ema(self, source, target, decay):
        """
        Apply the exponential moving average update:
        target = decay * target + (1 - decay) * source.

        Args:
            source: Source model
            target: Target model (EMA)
            decay: EMA decay rate
        """
        source_dict = source.state_dict()
        target_dict = target.state_dict()
        for key in source_dict.keys():
            target_dict[key].data.copy_(target_dict[key].data * decay +
                                        source_dict[key].data * (1 - decay))

    def encode(self, x):
        """
        Encode input using the EMA model's semantic encoder.

        Args:
            x: Input tensor

        Returns:
            Encoded representation (semantic latent)
        """
        assert self.conf.model_type.has_autoenc()
        cond = self.ema_model.encoder.forward(x)
        return cond

    def encode_stochastic(self, x, cond, T=None):
        """
        Stochastically encode input by running the DDIM reverse
        (inversion) loop, conditioned on the semantic latent.

        Args:
            x: Input tensor
            cond: Conditioning tensor (semantic latent)
            T: Number of diffusion steps

        Returns:
            Stochastically encoded sample x_T
        """
        if T is None:
            sampler = self.eval_sampler
        else:
            sampler = self.conf._make_diffusion_conf(T).make_sampler()
        out = sampler.ddim_reverse_sample_loop(self.ema_model,
                                               x,
                                               model_kwargs={'cond': cond})
        return out['sample']

    def forward(self, noise=None, x_start=None, use_ema=False):
        """
        Forward pass through the model.

        Args:
            noise: Input noise
            x_start: Starting point for diffusion
            use_ema: Whether to use the EMA model

        Returns:
            Generated sample
        """
        with amp.autocast(False):
            model = self.ema_model if use_ema else self.model
            gen = self.eval_sampler.sample(model=model,
                                           noise=noise,
                                           x_start=x_start)
        return gen
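
Taken together, `encode`, `encode_stochastic`, and `forward` implement the DiffAE-style autoencoding loop: a semantic latent from the EMA encoder, a stochastic code from DDIM inversion, and sampling to decode. A minimal usage sketch follows; the `ffhq256_autoenc` config, the input normalization, and the EMA decay value are assumptions, and pretrained weights would need to be loaded for meaningful outputs.

```python
# Hedged usage sketch for UNetModel; assumes a DiffAE-style config from
# templates.py (e.g. ffhq256_autoenc) and pretrained weights in self.model.
import torch
from templates import ffhq256_autoenc
from model.unet_model import UNetModel

conf = ffhq256_autoenc()
wrapper = UNetModel(conf)

x = torch.randn(1, 3, 256, 256)                # stand-in for a normalized image batch
cond = wrapper.encode(x)                       # semantic latent from the EMA encoder
xT = wrapper.encode_stochastic(x, cond, T=20)  # DDIM inversion to the stochastic code
recon = wrapper.forward(noise=xT, x_start=x)   # decode with the eval sampler

# During training, the EMA weights would be refreshed after each optimizer step:
wrapper.update_ema(0.9999)                     # decay value is an assumption
```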
297 changes: 297 additions & 0 deletions notebook/autoencoding.ipynb

Large diffs are not rendered by default.

316 changes: 316 additions & 0 deletions notebook/interpolate.ipynb

Large diffs are not rendered by default.

399 changes: 399 additions & 0 deletions notebook/manipulate.ipynb

Large diffs are not rendered by default.

File renamed without changes.
181 changes: 181 additions & 0 deletions notebook/sample.ipynb

Large diffs are not rendered by default.

63 changes: 63 additions & 0 deletions notebook/xsem_from_image.ipynb
@@ -0,0 +1,63 @@
{
 "cells": [
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "from PIL import Image\n",
    "\n",
    "import torch\n",
    "\n",
    "from torchvision.transforms import functional as VF\n",
    "\n",
    "from templates import ffhq256_autoenc, LitModel"
   ],
   "id": "da96b18b5c1ba66"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "device = 'cuda'\n",
    "\n",
    "conf = ffhq256_autoenc()\n",
    "\n",
    "model = LitModel(conf)"
   ],
   "id": "4d66f421dbc620fd"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# Load Image\n",
    "\n",
    "img = Image.open('example.jpg').resize((256, 256)).convert('RGB')"
   ],
   "id": "66599ce24f785418"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# Convert to Tensor (encode expects a batched tensor on the model's device)\n",
    "x = VF.to_tensor(img).unsqueeze(0).to(device)\n",
    "\n",
    "# Encode\n",
    "xsem = model.encode(x)"
   ],
   "id": "89b2012572f7b91"
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 5
}
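
`xsem` is the semantic half of the DiffAE latent: one vector per image, while the stochastic detail lives in `x_T` (see `xt_from_image.ipynb` below). Note that the cells above never load pretrained weights or move the model to `device`, so both steps are assumed to happen elsewhere. A quick hedged shape check; the 512-dimensional latent is an assumption based on the standard diffae FFHQ-256 setup.

```python
# Hedged shape check after running the cells above.
print(x.shape)     # torch.Size([1, 3, 256, 256])
print(xsem.shape)  # expected torch.Size([1, 512]) -- dimension is an assumption
```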
101 changes: 101 additions & 0 deletions notebook/xt_from_image.ipynb
@@ -0,0 +1,101 @@
{
 "cells": [
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-03-03T02:32:59.986408Z",
     "start_time": "2025-03-03T02:32:33.741561Z"
    }
   },
   "cell_type": "code",
   "source": [
    "from PIL import Image\n",
    "\n",
    "import torch\n",
    "\n",
    "from torchvision.transforms import functional as VF\n",
    "\n",
    "from templates import ffhq256_autoenc, LitModel"
   ],
   "id": "53da58752a83d6df",
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "E:\\enanalytica_shanghai\\diffae\\metrics.py:10: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from tqdm.autonotebook import tqdm, trange\n"
     ]
    }
   ],
   "execution_count": 1
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-03-03T02:33:26.229021Z",
     "start_time": "2025-03-03T02:33:25.170628Z"
    }
   },
   "cell_type": "code",
   "source": [
    "device = 'cuda:0'\n",
    "\n",
    "conf = ffhq256_autoenc()\n",
    "\n",
    "model = LitModel(conf)"
   ],
   "id": "9e78bb2404a0eaf8",
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Seed set to 0\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model params: 160.69 M\n"
     ]
    }
   ],
   "execution_count": 2
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# Load Image\n",
    "img = Image.open('example.jpg').resize((256, 256)).convert('RGB')\n",
    "\n",
    "# Convert to Tensor\n",
    "x = VF.to_tensor(img).unsqueeze(0).to(device)\n",
    "\n",
    "# The stochastic encoder is conditioned on the semantic latent\n",
    "cond = model.encode(x)\n",
    "\n",
    "# Encode\n",
    "xt = model.encode_stochastic(x, cond, T=250)"
   ],
   "id": "fec9d058ecbb3ad6"
  }
 ],
 "metadata": {
  "kernelspec": {
   "name": "python3",
   "language": "python",
   "display_name": "Python 3 (ipykernel)"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
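
Unlike `xsem`, `xt` keeps the spatial shape of the input and encodes the residual stochastic detail; together, `(cond, xt)` should reconstruct the image near-exactly. A hedged decode sketch follows, assuming `LitModel` exposes the same `forward` as `UNetModel` above and that outputs follow diffae's `[-1, 1]` convention (both assumptions; note also that diffae conventionally feeds the encoder `(x - 0.5) * 2` rather than raw `[0, 1]` tensors, which may be worth double-checking in the cell above).

```python
# Hedged round-trip sketch: decode the image from (cond, xt).
# Assumes model.forward matches UNetModel.forward in model/unet_model.py.
recon = model.forward(noise=xt, x_start=x)   # x_start supplies the semantic cond

# If outputs are in [-1, 1] (assumption), map back to [0, 1] before saving.
out = ((recon[0] + 1) / 2).clamp(0, 1).cpu()
VF.to_pil_image(out).save('reconstruction.png')
```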
115 changes: 115 additions & 0 deletions preprocessing/unet_preprocessing.py
@@ -0,0 +1,115 @@
import torch
from torch.utils.data import DataLoader, TensorDataset, ConcatDataset
from dataset import *
from dist_utils import get_world_size, get_rank
import numpy as np

class UNetPreprocessor:
    """Handles data preprocessing and dataset creation for UNet models."""
    def __init__(self, conf):
        """
        Initialize the preprocessor.

        Args:
            conf: Configuration object
        """
        self.conf = conf
        self.train_data = None
        self.val_data = None

    def setup(self, seed=None, global_rank=0):
        """
        Set up datasets with proper seeding.

        Args:
            seed: Random seed
            global_rank: Current process rank
        """
        # Derive a distinct seed for each distributed process
        if seed is not None:
            seed_worker = seed * get_world_size() + global_rank
            np.random.seed(seed_worker)
            torch.manual_seed(seed_worker)
            torch.cuda.manual_seed(seed_worker)
            print('local seed:', seed_worker)

        # Create datasets (validation reuses the training set)
        self.train_data = self.conf.make_dataset()
        print('train data:', len(self.train_data))
        self.val_data = self.train_data
        print('val data:', len(self.val_data))

    def create_train_dataloader(self, batch_size, drop_last=True, shuffle=True):
        """
        Create training dataloader.

        Args:
            batch_size: Batch size
            drop_last: Whether to drop the last incomplete batch
            shuffle: Whether to shuffle the data

        Returns:
            DataLoader for training
        """
        if not hasattr(self, "train_data") or self.train_data is None:
            self.setup()

        # Create a DataLoader directly
        dataloader = DataLoader(
            self.train_data,
            batch_size=batch_size,
            shuffle=shuffle,
            drop_last=drop_last,
            num_workers=0,  # Use 0 to avoid pickling issues
            persistent_workers=False
        )
        return SizedIterableWrapper(dataloader, len(self.train_data))

    def create_val_dataloader(self, batch_size, drop_last=False):
        """
        Create validation dataloader.

        Args:
            batch_size: Batch size
            drop_last: Whether to drop the last incomplete batch

        Returns:
            DataLoader for validation
        """
        if not hasattr(self, "val_data") or self.val_data is None:
            self.setup()

        dataloader = DataLoader(
            self.val_data,
            batch_size=batch_size,
            shuffle=False,
            drop_last=drop_last,
            num_workers=0,
            persistent_workers=False
        )
        return dataloader

    def create_latent_dataset(self, conds):
        """
        Create a dataset from latent conditions.

        Args:
            conds: Latent conditions tensor

        Returns:
            TensorDataset containing the conditions
        """
        return TensorDataset(conds)


class SizedIterableWrapper:
    """Wraps an iterable with a fixed __len__ (here, the dataset size)."""
    def __init__(self, dataloader, length):
        self.dataloader = dataloader
        self._length = length

    def __iter__(self):
        return iter(self.dataloader)

    def __len__(self):
        return self._length
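
The seeding scheme gives every distributed process its own stream: with world size W and base seed s, rank r seeds with s * W + r (for example, seed 42 on 2 GPUs yields 84 and 85). One caveat worth flagging: `SizedIterableWrapper` is constructed with `len(self.train_data)`, so `len(train_loader)` reports the dataset size rather than the batch count a plain `DataLoader` would report. A hedged usage sketch; the `ffhq256_autoenc` config and batch size are assumptions.

```python
# Hedged usage sketch; conf is assumed to come from templates.py.
from templates import ffhq256_autoenc
from preprocessing.unet_preprocessing import UNetPreprocessor

conf = ffhq256_autoenc()
pre = UNetPreprocessor(conf)
pre.setup(seed=42, global_rank=0)          # on 2 GPUs: rank 0 -> 84, rank 1 -> 85

train_loader = pre.create_train_dataloader(batch_size=16)
print(len(train_loader))                   # dataset size, not number of batches

for batch in train_loader:
    break                                  # batches come from conf.make_dataset()
```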
Empty file added report.md
Empty file.