diff --git a/.gitignore b/.gitignore index 3cd1a01..34c57b6 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,5 @@ __pycache__ out.* trace.json *.egg-info +images/ +results/ diff --git a/checkpoint/100.pth b/checkpoint/100.pth new file mode 100644 index 0000000..44d4996 Binary files /dev/null and b/checkpoint/100.pth differ diff --git a/checkpoint/150.pth b/checkpoint/150.pth new file mode 100644 index 0000000..44d4996 Binary files /dev/null and b/checkpoint/150.pth differ diff --git a/checkpoint/200.pth b/checkpoint/200.pth new file mode 100644 index 0000000..44d4996 Binary files /dev/null and b/checkpoint/200.pth differ diff --git a/checkpoint/250.pth b/checkpoint/250.pth new file mode 100644 index 0000000..44d4996 Binary files /dev/null and b/checkpoint/250.pth differ diff --git a/checkpoint/300.pth b/checkpoint/300.pth new file mode 100644 index 0000000..44d4996 Binary files /dev/null and b/checkpoint/300.pth differ diff --git a/checkpoint/350.pth b/checkpoint/350.pth new file mode 100644 index 0000000..44d4996 Binary files /dev/null and b/checkpoint/350.pth differ diff --git a/checkpoint/400.pth b/checkpoint/400.pth new file mode 100644 index 0000000..44d4996 Binary files /dev/null and b/checkpoint/400.pth differ diff --git a/checkpoint/450.pth b/checkpoint/450.pth new file mode 100644 index 0000000..44d4996 Binary files /dev/null and b/checkpoint/450.pth differ diff --git a/checkpoint/50.pth b/checkpoint/50.pth new file mode 100644 index 0000000..44d4996 Binary files /dev/null and b/checkpoint/50.pth differ diff --git a/checkpoint/500.pth b/checkpoint/500.pth new file mode 100644 index 0000000..44d4996 Binary files /dev/null and b/checkpoint/500.pth differ diff --git a/cli.py b/cli.py new file mode 100644 index 0000000..0d319af --- /dev/null +++ b/cli.py @@ -0,0 +1,519 @@ +"""Neural style transfer (https://arxiv.org/abs/1508.06576) in PyTorch.""" + +import argparse +import atexit +from dataclasses import asdict +import io +import json +from pathlib import Path +import platform +import sys +import webbrowser +import PIL +import numpy as np +from PIL import Image, ImageCms +from numpy.core.fromnumeric import argsort +from tifffile import TIFF, TiffWriter +import torch +import torch.multiprocessing as mp +from tqdm import tqdm +from style_transfer import srgb_profile, StyleTransfer, WebInterface, HRNet +from style_transfer.style_transfer_HRNet import * +import os + + + + +def prof_to_prof(image, src_prof, dst_prof, **kwargs): + src_prof = io.BytesIO(src_prof) + dst_prof = io.BytesIO(dst_prof) + return ImageCms.profileToProfile(image, src_prof, dst_prof, **kwargs) + + +def load_image(path, proof_prof=None): + src_prof = dst_prof = srgb_profile + try: + image = Image.open(path) + if 'icc_profile' in image.info: + src_prof = image.info['icc_profile'] + else: + image = image.convert('RGB') + if proof_prof is None: + if src_prof == dst_prof: + return image.convert('RGB') + return prof_to_prof(image, src_prof, dst_prof, outputMode='RGB') + proof_prof = Path(proof_prof).read_bytes() + cmyk = prof_to_prof(image, src_prof, proof_prof, outputMode='CMYK') + return prof_to_prof(cmyk, proof_prof, dst_prof, outputMode='RGB') + except OSError as err: + print_error(err) + sys.exit(1) + + +def save_pil(path, image): + try: + kwargs = {'icc_profile': srgb_profile} + if path.suffix.lower() in {'.jpg', '.jpeg'}: + kwargs['quality'] = 95 + kwargs['subsampling'] = 0 + elif path.suffix.lower() == '.webp': + kwargs['quality'] = 95 + image.save(path, **kwargs) + except (OSError, ValueError) as err: + 
print_error(err) + sys.exit(1) + + +def save_tiff(path, image): + tag = ('InterColorProfile', TIFF.DATATYPES.BYTE, + len(srgb_profile), srgb_profile, False) + try: + with TiffWriter(path) as writer: + writer.save(image, photometric='rgb', + resolution=(72, 72), extratags=[tag]) + except OSError as err: + print_error(err) + sys.exit(1) + + +def save_image(path, image): + path = Path(path) + tqdm.write(f'Writing image to {path}.') + if isinstance(image, Image.Image): + save_pil(path, image) + elif isinstance(image, np.ndarray) and path.suffix.lower() in {'.tif', '.tiff'}: + save_tiff(path, image) + else: + raise ValueError('Unsupported combination of image type and extension') + + +def get_safe_scale(w, h, dim): + """Given a w x h content image and that a dim x dim square does not + exceed GPU memory, compute a safe end_scale for that content image.""" + return int(pow(w / h if w > h else h / w, 1/2) * dim) + + +def setup_exceptions(): + try: + from IPython.core.ultratb import FormattedTB + sys.excepthook = FormattedTB(mode='Plain', color_scheme='Neutral') + except ImportError: + pass + + +def fix_start_method(): + if platform.system() == 'Darwin': + mp.set_start_method('spawn') + + +def print_error(err): + print('\033[31m{}:\033[0m {}'.format( + type(err).__name__, err), file=sys.stderr) + + +class Callback: + def __init__(self, st, args, image_type='pil', web_interface=None): + self.st = st + self.args = args + self.image_type = image_type + self.web_interface = web_interface + self.iterates = [] + self.progress = None + + def __call__(self, iterate): + # print(iterate.i) + self.iterates.append(asdict(iterate)) + if iterate.i == 1: + self.progress = tqdm(total=iterate.i_max, dynamic_ncols=True) + msg = 'Size: {}x{}, iteration: {}, loss: {:g}' + tqdm.write(msg.format(iterate.w, iterate.h, iterate.i, iterate.loss)) + # print(self.progress) + self.progress.update() + if self.web_interface is not None: + self.web_interface.put_iterate(iterate, self.st.get_image_tensor()) + if iterate.i == iterate.i_max: + self.progress.close() + if max(iterate.w, iterate.h) != self.args.end_scale: + save_image(self.args.output, + self.st.get_image(self.image_type)) + else: + if self.web_interface is not None: + self.web_interface.put_done() + elif iterate.i % self.args.save_every == 0: + save_image(self.args.output, self.st.get_image( + self.image_type)) # save intermediate results + + def close(self): + if self.progress is not None: + self.progress.close() + + def get_trace(self): + return {'args': self.args.__dict__, 'iterates': self.iterates} + +# import os +# os.environ['CUDA_VISIBLE_DEVICES'] = '3' + +loader = transforms.Compose([ + transforms.ToTensor()]) + +def PIL_to_tensor(image, device): + image = loader(image).unsqueeze(0) + return image.to(device, torch.float) + +def tensor_to_image(tensor): + tensor = tensor*255 + tensor = np.array(tensor, dtype=np.uint8) + if np.ndim(tensor)>3: + assert tensor.shape[0] == 1 + tensor = tensor[0] + return PIL.Image.fromarray(tensor) + +def train(model, st_model, + content_image, sky_mask, style_images, *, + style_weights=None, + content_weight: float = 0.04, + grad_weight: float = 20, + sky_weight: float = 1, + tv_weight: float = 2., + min_scale: int = 128, + end_scale: int = 512, + iterations: int = 500, + initial_iterations: int = 1000, + step_size: float = 0.02, + avg_decay: float = 0.99, + init: str = 'content', + style_scale_fac: float = 1., + style_size: int = None, + callback=None,): + model.train() + min_scale = min(min_scale, end_scale) + content_weights = 
[content_weight / + len(st_model.content_layers)] * len(st_model.content_layers) + + # style weights among multiple style images + if style_weights is None: + style_weights = [1 / len(style_images)] * len(style_images) + else: + weight_sum = sum(abs(w) for w in style_weights) + style_weights = [weight / weight_sum for weight in style_weights] + if len(style_images) != len(style_weights): + raise ValueError( + 'style_images and style_weights must have the same length') + + # add TVloss -> the sum of the absolute differences for neighboring pixel-values in the result image + tv_loss = Scale(LayerApply(TVLoss(), 'input'), tv_weight) + + # get a sequence of scales, from small to large + scales = gen_scales(min_scale, end_scale) + + # set the initial image and load it to device + cw, ch = size_to_fit(content_image.size, scales[0], scale_up=True) + if init == 'content': + st_model.image = TF.to_tensor( + content_image.resize((cw, ch), Image.LANCZOS))[None] + elif init == 'gray': + st_model.image = torch.rand([1, 3, ch, cw]) / 255 + 0.5 + elif init == 'uniform': + st_model.image = torch.rand([1, 3, ch, cw]) + elif init == 'style_mean': + means = [] + for i, image in enumerate(style_images): + means.append(TF.to_tensor(image).mean( + dim=(1, 2)) * style_weights[i]) + st_model.image = torch.rand([1, 3, ch, cw]) / \ + 255 + sum(means)[None, :, None, None] + else: + raise ValueError( + "init must be one of 'content', 'gray', 'uniform', 'style_mean'") + + # Stylize the image at successively finer scales, each greater by a factor of sqrt(2). + # This differs from the scheme given in Gatys et al. (2016). + for scale in scales: + if st_model.devices[0].type == 'cuda': + torch.cuda.empty_cache() + + # resize the content image to be smaller than [scale * scale] -> target size + cw, ch = size_to_fit(content_image.size, scale, scale_up=True) + content = TF.to_tensor(content_image.resize( + (cw, ch), Image.LANCZOS))[None] + content = content.to(st_model.devices[0]) + + # resize the mask along with the content iamge + mask = TF.to_tensor(sky_mask.resize((cw, ch), Image.LANCZOS))[None] + mask = mask.to(st_model.devices[0]) + + grad_loss = Scale(LayerApply(GradientLoss( + content, mask, sky_weight), 'input'), grad_weight) + + # add ContentLoss + content_feats = st_model.model(content, layers=st_model.content_layers) + content_losses = [] + for layer, weight in zip(st_model.content_layers, content_weights): + target = content_feats[layer] # target content feature + # how to calculate content loss? + content_losses.append( + Scale(LayerApply(ContentLoss(target), layer), weight)) + + style_targets, style_losses = {}, [] + # add StyleLoss + for i, image in enumerate(style_images): + # resize the image and load it to GPU + if style_size is None: + sw, sh = size_to_fit( + image.size, round(scale * style_scale_fac)) + else: + sw, sh = size_to_fit(image.size, style_size) + style = TF.to_tensor(image.resize( + (sw, sh), Image.LANCZOS))[None] + style = style.to(st_model.devices[0]) + + print(f'Processing style image ({sw}x{sh})...') + style_feats = st_model.model(style, layers=st_model.style_layers) + # Take the weighted average of multiple style targets (Gram matrices). 
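+            # A note on what StyleLoss.get_target computes: the (1, C, H, W) feature map is
+            # flattened to (1, C, H*W) and the Gram matrix is mat @ mat.transpose(-2, -1) / (H*W).
+            # The per-layer targets from each style image are accumulated with their
+            # style_weights below, so the final target is a weighted average of Gram matrices.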
+ for layer in st_model.style_layers: + target = StyleLoss.get_target( + style_feats[layer]) * style_weights[i] + if layer not in style_targets: + style_targets[layer] = target + else: + style_targets[layer] += target + for layer, weight in zip(st_model.style_layers, st_model.style_weights): + target = style_targets[layer] + style_losses.append( + Scale(LayerApply(StyleLoss(target), layer), weight)) + + # Construct a list of losses + crit = SumLoss( + [*content_losses, *style_losses, tv_loss, grad_loss], verbose=False) + + # Warm-start the Adam optimizer if this is not the first scale. (load the previous optimizer state) + opt2 = optim.Adam([st_model.image], lr=step_size) + if scale != scales[0]: + opt_state = scale_adam(opt.state_dict(), (ch, cw)) + opt2.load_state_dict(opt_state) + opt = opt2 + + # empty GPU cache + if st_model.devices[0].type == 'cuda': + torch.cuda.empty_cache() + + for batch_count in range(iterations): + + # load content image + batch_data = TF.to_tensor( + content_image.resize((cw, ch), Image.LANCZOS))[None] + batch_data = batch_data.to('cuda:0') + batch_data = model(batch_data) + + # load output of HRNet into ST + st_model.image = batch_data + # the original input + # interpolate the initial image to the target size + st_model.image = interpolate( + st_model.image.detach(), (ch, cw), mode='bicubic').clamp(0, 1) + # averaging across the time?? + st_model.average = EMA(st_model.image, avg_decay) + st_model.image.requires_grad_() + + # extract feature + feats = st_model.model(st_model.image) + loss = crit(feats).cuda() # calculate all the losses at the same time + opt.zero_grad() + + # back + loss.backward() + opt.step() + with torch.no_grad(): + st_model.image.clamp_(0, 1) + + st_model.average.update(st_model.image) + + if callback is not None: + gpu_ram = 0 + for device in st_model.devices: + if device.type == 'cuda': + gpu_ram = max( + gpu_ram, torch.cuda.max_memory_allocated(device)) + callback(STIterate(w=cw, h=ch, i=batch_count+1, i_max=iterations, loss=loss.item(), + time=time.time(), gpu_ram=gpu_ram)) + + # save model + if ((batch_count+1)%50==0 or (batch_count+1)==iterations): + print("========Iteration {}/{}========".format(batch_count, iterations)) + checkpoint_path = os.path.join("checkpoint", str(batch_count+1) + ".pth") + torch.save(model.state_dict(), checkpoint_path) + print("Saved HRNet checkpoint file at {}".format(checkpoint_path)) + image_type = 'pil' + sample_img = st_model.get_image(image_type) + sample_img_path = os.path.join("results", "batch_result/"+str(batch_count+1)+'.jpg') + save_image(sample_img_path, sample_img) + +def main(): + setup_exceptions() + fix_start_method() + + p = argparse.ArgumentParser(description=__doc__, + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + + def arg_info(arg): + defaults = StyleTransfer.stylize.__kwdefaults__ + default_types = StyleTransfer.stylize.__annotations__ + return {'default': defaults[arg], 'type': default_types[arg]} + + content_dir = './images/content/' + style_dir = './images/styles/' + my_content_image = 'ust1.jpg' + my_style_images = ['neon1.jpg'] + output_name = './results/neon1/neon1_ust1_cw0.04_gw20_sw1.5_mask_laplacian.png' + + # files + p.add_argument('--content', type=str, default=(content_dir + + my_content_image), help='the content image') + p.add_argument('--sky_mask', type=str, default=(content_dir + + my_content_image.split('.')[0]+'_skymask.jpg')) + p.add_argument('--styles', type=str, default=[(style_dir+i) + for i in my_style_images], nargs='+', metavar='style', help='the 
style images') + p.add_argument('--output', '-o', type=str, + default=output_name, help='the output image') + + # training param + p.add_argument('--style-weights', '-sw', type=float, nargs='+', default=None, + metavar='STYLE_WEIGHT', help='the relative weights for each style image') + p.add_argument('--content-weight', '-cw', ** + arg_info('content_weight'), help='the content weight') + p.add_argument('--grad-weight', '-gw', ** + arg_info('grad_weight'), help='the grad weight') + p.add_argument('--sky-weight', '-sky', ** + arg_info('sky_weight'), help='the sky weight') + + p.add_argument('--tv-weight', '-tw', **arg_info('tv_weight'), + help='the smoothing weight') + p.add_argument('--step-size', '-ss', **arg_info('step_size'), + help='the step size (learning rate)') + p.add_argument('--avg-decay', '-ad', **arg_info('avg_decay'), + help='the EMA decay rate for iterate averaging') + p.add_argument('--pooling', type=str, default='average', + choices=['max', 'average', 'l2'], help='the model\'s pooling mode') + p.add_argument('--devices', type=str, + default=['cuda:0'], nargs='+', help='the device names to use (omit for auto)') + + p.add_argument('--min-scale', '-ms', **arg_info('min_scale'), + help='the minimum scale (max image dim), in pixels') + p.add_argument('--end-scale', '-s', type=str, default='512', + help='the final scale (max image dim), in pixels') + + p.add_argument('--random-seed', '-r', type=int, + default=0, help='the random seed') + p.add_argument('--iterations', '-i', **arg_info('iterations'), + help='the number of iterations per scale') + p.add_argument('--initial-iterations', '-ii', **arg_info('initial_iterations'), + help='the number of iterations on the first scale') + p.add_argument('--save-every', type=int, default=50, + help='save the image every SAVE_EVERY iterations') + + p.add_argument('--init', **arg_info('init'), + choices=['content', 'gray', 'uniform', 'style_mean'], help='the initial image') + p.add_argument('--style-scale-fac', **arg_info('style_scale_fac'), + help='the relative scale of the style to the content') + p.add_argument('--style-size', **arg_info('style_size'), + help='the fixed scale of the style at different content scales') + + p.add_argument('--proof', type=str, default=None, + help='the ICC color profile (CMYK) for soft proofing the content and styles') + p.add_argument('--web', default=False, action='store_true', + help='enable the web interface') + p.add_argument('--host', type=str, default='0.0.0.0', + help='the host the web interface binds to') + p.add_argument('--port', type=int, default=8080, + help='the port the web interface binds to') + p.add_argument('--browser', type=str, default='', nargs='?', + help='open a web browser (specify the browser if not system default)') + + args = p.parse_args() + + # load images + content_img = load_image(args.content, args.proof) + sky_mask = load_image(args.sky_mask, args.proof) + style_imgs = [load_image(img, args.proof) for img in args.styles] + image_type = 'pil' + if Path(args.output).suffix.lower() in {'.tif', '.tiff'}: + image_type = 'np_uint16' + + # find device + devices = [torch.device(device) for device in args.devices] + if not devices: + devices = [torch.device( + 'cuda:0' if torch.cuda.is_available() else 'cpu')] + if len(set(device.type for device in devices)) != 1: + print('Devices must all be the same type.') + sys.exit(1) + if not 1 <= len(devices) <= 2: + print('Only 1 or 2 devices are supported.') + sys.exit(1) + print('Using devices:', ' '.join(str(device) for device in devices)) + 
+ # print device information + if devices[0].type == 'cpu': + print('CPU threads:', torch.get_num_threads()) + if devices[0].type == 'cuda': + for i, device in enumerate(devices): + props = torch.cuda.get_device_properties(device) + print( + f'GPU {i} type: {props.name} (compute {props.major}.{props.minor})') + print(f'GPU {i} RAM:', round( + props.total_memory / 1024 / 1024), 'MB') + + # verify end scale + end_scale = int(args.end_scale.rstrip('+')) + if args.end_scale.endswith('+'): + end_scale = get_safe_scale(*content_img.size, end_scale) + args.end_scale = end_scale + + web_interface = None + if args.web: + web_interface = WebInterface(args.host, args.port) + atexit.register(web_interface.close) + + for device in devices: + torch.tensor(0).to(device) + torch.manual_seed(args.random_seed) + + print('Loading model...') + + # load the model + st = StyleTransfer(devices=devices, pooling=args.pooling) + feedforward_net = HRNet.HRNet().cuda() + + callback = Callback(st, args, image_type=image_type, + web_interface=web_interface) + atexit.register(callback.close) + + # setup online monitor + url = f'http://{args.host}:{args.port}/' + if args.web: + if args.browser: + webbrowser.get(args.browser).open(url) + elif args.browser is None: + webbrowser.open(url) + + # do style transfer + # get the default keyword dictionary + defaults = StyleTransfer.stylize.__kwdefaults__ + # find modified args and put them into an array + st_kwargs = {k: v for k, v in args.__dict__.items() if k in defaults} + try: + train(feedforward_net, st, content_img, sky_mask, style_imgs, ** + st_kwargs, callback=callback) # training + except KeyboardInterrupt: + pass + + # get the result image + output_image = st.get_image(image_type) + if output_image is not None: + save_image(args.output, output_image) + with open('trace.json', 'w') as fp: + json.dump(callback.get_trace(), fp, indent=4) + + +if __name__ == '__main__': + main() diff --git a/data/color150.mat b/data/color150.mat new file mode 100644 index 0000000..c518b64 Binary files /dev/null and b/data/color150.mat differ diff --git a/images/content/ust1.jpg b/images/content/ust1.jpg new file mode 100644 index 0000000..18c1e2f Binary files /dev/null and b/images/content/ust1.jpg differ diff --git a/images/content/ust1_skymask.jpg b/images/content/ust1_skymask.jpg new file mode 100644 index 0000000..d6ab406 Binary files /dev/null and b/images/content/ust1_skymask.jpg differ diff --git a/images/content/ust39.jpg b/images/content/ust39.jpg new file mode 100644 index 0000000..0e997cb Binary files /dev/null and b/images/content/ust39.jpg differ diff --git a/images/content/ust4.jpg b/images/content/ust4.jpg new file mode 100644 index 0000000..34bb6ae Binary files /dev/null and b/images/content/ust4.jpg differ diff --git a/images/content/ust4_mask.jpg b/images/content/ust4_mask.jpg new file mode 100644 index 0000000..9f73d00 Binary files /dev/null and b/images/content/ust4_mask.jpg differ diff --git a/images/content/ust99.jpg b/images/content/ust99.jpg new file mode 100644 index 0000000..5962167 Binary files /dev/null and b/images/content/ust99.jpg differ diff --git a/images/styles/neon1.jpg b/images/styles/neon1.jpg new file mode 100644 index 0000000..50aa0da Binary files /dev/null and b/images/styles/neon1.jpg differ diff --git a/results/1024.png b/results/1024.png new file mode 100644 index 0000000..3f7f541 Binary files /dev/null and b/results/1024.png differ diff --git a/results/512.png b/results/512.png new file mode 100644 index 0000000..98834c2 Binary files /dev/null 
and b/results/512.png differ diff --git a/results/average_pool.png b/results/average_pool.png new file mode 100644 index 0000000..eef4877 Binary files /dev/null and b/results/average_pool.png differ diff --git a/results/snow1_cw0.04.png b/results/snow1_cw0.04.png new file mode 100644 index 0000000..662ed85 Binary files /dev/null and b/results/snow1_cw0.04.png differ diff --git a/results/snow1_cw0.08.png b/results/snow1_cw0.08.png new file mode 100644 index 0000000..840c850 Binary files /dev/null and b/results/snow1_cw0.08.png differ diff --git a/results/snow2/snow2_cw0.04.png b/results/snow2/snow2_cw0.04.png new file mode 100644 index 0000000..58e7b92 Binary files /dev/null and b/results/snow2/snow2_cw0.04.png differ diff --git a/results/snow2/snow2_cw0.04_average-style-weight.png b/results/snow2/snow2_cw0.04_average-style-weight.png new file mode 100644 index 0000000..963eea2 Binary files /dev/null and b/results/snow2/snow2_cw0.04_average-style-weight.png differ diff --git a/results/snow2/snow2_cw0.04_gradloss=2.png b/results/snow2/snow2_cw0.04_gradloss=2.png new file mode 100644 index 0000000..2d19480 Binary files /dev/null and b/results/snow2/snow2_cw0.04_gradloss=2.png differ diff --git a/results/snow2/snow2_cw0.04_gradloss=6.png b/results/snow2/snow2_cw0.04_gradloss=6.png new file mode 100644 index 0000000..3061855 Binary files /dev/null and b/results/snow2/snow2_cw0.04_gradloss=6.png differ diff --git a/results/snow2/snow2_cw0.04_layer-shift.png b/results/snow2/snow2_cw0.04_layer-shift.png new file mode 100644 index 0000000..e2baf3c Binary files /dev/null and b/results/snow2/snow2_cw0.04_layer-shift.png differ diff --git a/results/style2_0.015.png b/results/style2_0.015.png new file mode 100644 index 0000000..99dd842 Binary files /dev/null and b/results/style2_0.015.png differ diff --git a/results/style2_cw0.03.png b/results/style2_cw0.03.png new file mode 100644 index 0000000..594fd8d Binary files /dev/null and b/results/style2_cw0.03.png differ diff --git a/results/style2_cw0.1.png b/results/style2_cw0.1.png new file mode 100644 index 0000000..13aea79 Binary files /dev/null and b/results/style2_cw0.1.png differ diff --git a/results/style3_cw0.05.png b/results/style3_cw0.05.png new file mode 100644 index 0000000..de58873 Binary files /dev/null and b/results/style3_cw0.05.png differ diff --git a/results/style3_cw0.08.png b/results/style3_cw0.08.png new file mode 100644 index 0000000..0562e7a Binary files /dev/null and b/results/style3_cw0.08.png differ diff --git a/style_transfer/HRNet.py b/style_transfer/HRNet.py new file mode 100644 index 0000000..553df73 --- /dev/null +++ b/style_transfer/HRNet.py @@ -0,0 +1,134 @@ +import torch +import torch.nn as nn +import torchvision +import numpy as np + +IN_MOMENTUM = 0.1 + + +class ReflectionConv(nn.Module): + ''' + Reflection padding convolution + ''' + def __init__(self, in_channels, out_channels, kernel_size, stride): + super(ReflectionConv, self).__init__() + reflection_padding = int(np.floor(kernel_size / 2)) + self.reflection_pad = nn.ReflectionPad2d(reflection_padding) + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride) + + def forward(self, x): + out = self.reflection_pad(x) + out = self.conv(out) + return out + + +class ConvLayer(nn.Module): + ''' + zero-padding convolution + ''' + def __init__(self, in_channels, out_channels, kernel_size, stride): + super(ConvLayer, self).__init__() + conv_padding = int(np.floor(kernel_size / 2)) + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, 
padding=conv_padding) + def forward(self, x): + return self.conv(x) + + +class BasicBlock(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride): + super(BasicBlock, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.relu = nn.ReLU(inplace=True) # 1 + + self.identity_block = nn.Sequential( + ConvLayer(in_channels, out_channels // 4, kernel_size=1, stride=1), + nn.InstanceNorm2d(out_channels // 4, momentum=IN_MOMENTUM), + nn.ReLU(), + ConvLayer(out_channels // 4, out_channels // 4, kernel_size, stride=stride), + nn.InstanceNorm2d(out_channels // 4, momentum=IN_MOMENTUM), + nn.ReLU(), + ConvLayer(out_channels // 4, out_channels, kernel_size=1, stride=1), + nn.InstanceNorm2d(out_channels, momentum=IN_MOMENTUM), + nn.ReLU(), + ) + self.shortcut = nn.Sequential( + ConvLayer(in_channels, out_channels, 1, stride), + nn.InstanceNorm2d(out_channels) + ) + + def forward(self, x): + out = self.identity_block(x) + if self.in_channels == self.out_channels: + residual = x + else: + residual = self.shortcut(x) + out += residual + out = self.relu(out) + + return out + + +class Upsample(nn.Module): + ''' + Since the number of channels of the feature map changes after upsampling in HRNet. + we have to write a new Upsample class. + ''' + def __init__(self, in_channels, out_channels, scale_factor, mode): + super(Upsample, self).__init__() + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1) + self.upsample = nn.Upsample(scale_factor=scale_factor, mode='nearest') + self.instance = nn.InstanceNorm2d(out_channels) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + out = self.conv(x) + out = self.upsample(out) + out = self.instance(out) + out = self.relu(out) + + return out + + +class HRNet(nn.Module): + def __init__(self): + super(HRNet, self).__init__() + + self.pass1_1 = BasicBlock(3, 16, kernel_size=3, stride=1) + self.pass1_2 = BasicBlock(16, 32, kernel_size=3, stride=1) + self.pass1_3 = BasicBlock(32, 32, kernel_size=3, stride=1) + self.pass1_4 = BasicBlock(64, 64, kernel_size=3, stride=1) + self.pass1_5 = BasicBlock(192, 64, kernel_size=3, stride=1) + self.pass1_6 = BasicBlock(64, 32, kernel_size=3, stride=1) + self.pass1_7 = BasicBlock(32, 16, kernel_size=3, stride=1) + self.pass1_8 = nn.Conv2d(16, 3, kernel_size=3, stride=1, padding=1) + self.pass2_1 = BasicBlock(32, 32, kernel_size=3, stride=1) + self.pass2_2 = BasicBlock(64, 64, kernel_size=3, stride=1) + + self.downsample1_1 = nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1) + self.downsample1_2 = nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=1) + self.downsample1_3 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1) + self.downsample1_4 = nn.Conv2d(32, 32, kernel_size=3, stride=4, padding=1) + self.downsample1_5 = nn.Conv2d(64, 64, kernel_size=3, stride=4, padding=1) + self.downsample2_1 = nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=1) + self.downsample2_2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1) + + self.upsample1_1 = nn.Upsample(scale_factor=2, mode='bilinear') + self.upsample1_2 = nn.Upsample(scale_factor=2, mode='bilinear') + self.upsample2_1 = nn.Upsample(scale_factor=4, mode='bilinear') + self.upsample2_2 = nn.Upsample(scale_factor=2, mode='bilinear') + + def forward(self, x): + map1 = self.pass1_1(x) + map2 = self.pass1_2(map1) + map3 = self.downsample1_1(map1) + map4 = torch.cat((self.pass1_3(map2), self.upsample1_1(map3)), 1) + map5 = torch.cat((self.downsample1_2(map2), self.pass2_1(map3)), 1) 
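+        # map4-map7 implement the HRNet-style multi-resolution fusion: each branch is kept at
+        # its own scale, resampled (down- or upsampled) to the other branches' scales, and the
+        # results are concatenated along the channel dimension before the final pass1_* blocks.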
+ map6 = torch.cat((self.downsample1_4(map2), self.downsample2_1(map3)), 1) + map7 = torch.cat((self.pass1_4(map4), self.upsample1_2(map5), self.upsample2_1(map6)), 1) + out = self.pass1_5(map7) + out = self.pass1_6(out) + out = self.pass1_7(out) + out = self.pass1_8(out) + + return out \ No newline at end of file diff --git a/style_transfer/cli.py b/style_transfer/cli.py deleted file mode 100644 index c7644d9..0000000 --- a/style_transfer/cli.py +++ /dev/null @@ -1,271 +0,0 @@ -"""Neural style transfer (https://arxiv.org/abs/1508.06576) in PyTorch.""" - -import argparse -import atexit -from dataclasses import asdict -import io -import json -from pathlib import Path -import platform -import sys -import webbrowser - -import numpy as np -from PIL import Image, ImageCms -from tifffile import TIFF, TiffWriter -import torch -import torch.multiprocessing as mp -from tqdm import tqdm - -from . import srgb_profile, StyleTransfer, WebInterface - - -def prof_to_prof(image, src_prof, dst_prof, **kwargs): - src_prof = io.BytesIO(src_prof) - dst_prof = io.BytesIO(dst_prof) - return ImageCms.profileToProfile(image, src_prof, dst_prof, **kwargs) - - -def load_image(path, proof_prof=None): - src_prof = dst_prof = srgb_profile - try: - image = Image.open(path) - if 'icc_profile' in image.info: - src_prof = image.info['icc_profile'] - else: - image = image.convert('RGB') - if proof_prof is None: - if src_prof == dst_prof: - return image.convert('RGB') - return prof_to_prof(image, src_prof, dst_prof, outputMode='RGB') - proof_prof = Path(proof_prof).read_bytes() - cmyk = prof_to_prof(image, src_prof, proof_prof, outputMode='CMYK') - return prof_to_prof(cmyk, proof_prof, dst_prof, outputMode='RGB') - except OSError as err: - print_error(err) - sys.exit(1) - - -def save_pil(path, image): - try: - kwargs = {'icc_profile': srgb_profile} - if path.suffix.lower() in {'.jpg', '.jpeg'}: - kwargs['quality'] = 95 - kwargs['subsampling'] = 0 - elif path.suffix.lower() == '.webp': - kwargs['quality'] = 95 - image.save(path, **kwargs) - except (OSError, ValueError) as err: - print_error(err) - sys.exit(1) - - -def save_tiff(path, image): - tag = ('InterColorProfile', TIFF.DATATYPES.BYTE, len(srgb_profile), srgb_profile, False) - try: - with TiffWriter(path) as writer: - writer.save(image, photometric='rgb', resolution=(72, 72), extratags=[tag]) - except OSError as err: - print_error(err) - sys.exit(1) - - -def save_image(path, image): - path = Path(path) - tqdm.write(f'Writing image to {path}.') - if isinstance(image, Image.Image): - save_pil(path, image) - elif isinstance(image, np.ndarray) and path.suffix.lower() in {'.tif', '.tiff'}: - save_tiff(path, image) - else: - raise ValueError('Unsupported combination of image type and extension') - - -def get_safe_scale(w, h, dim): - """Given a w x h content image and that a dim x dim square does not - exceed GPU memory, compute a safe end_scale for that content image.""" - return int(pow(w / h if w > h else h / w, 1/2) * dim) - - -def setup_exceptions(): - try: - from IPython.core.ultratb import FormattedTB - sys.excepthook = FormattedTB(mode='Plain', color_scheme='Neutral') - except ImportError: - pass - - -def fix_start_method(): - if platform.system() == 'Darwin': - mp.set_start_method('spawn') - - -def print_error(err): - print('\033[31m{}:\033[0m {}'.format(type(err).__name__, err), file=sys.stderr) - - -class Callback: - def __init__(self, st, args, image_type='pil', web_interface=None): - self.st = st - self.args = args - self.image_type = image_type - 
self.web_interface = web_interface - self.iterates = [] - self.progress = None - - def __call__(self, iterate): - self.iterates.append(asdict(iterate)) - if iterate.i == 1: - self.progress = tqdm(total=iterate.i_max, dynamic_ncols=True) - msg = 'Size: {}x{}, iteration: {}, loss: {:g}' - tqdm.write(msg.format(iterate.w, iterate.h, iterate.i, iterate.loss)) - self.progress.update() - if self.web_interface is not None: - self.web_interface.put_iterate(iterate, self.st.get_image_tensor()) - if iterate.i == iterate.i_max: - self.progress.close() - if max(iterate.w, iterate.h) != self.args.end_scale: - save_image(self.args.output, self.st.get_image(self.image_type)) - else: - if self.web_interface is not None: - self.web_interface.put_done() - elif iterate.i % self.args.save_every == 0: - save_image(self.args.output, self.st.get_image(self.image_type)) - - def close(self): - if self.progress is not None: - self.progress.close() - - def get_trace(self): - return {'args': self.args.__dict__, 'iterates': self.iterates} - - -def main(): - setup_exceptions() - fix_start_method() - - p = argparse.ArgumentParser(description=__doc__, - formatter_class=argparse.ArgumentDefaultsHelpFormatter) - - def arg_info(arg): - defaults = StyleTransfer.stylize.__kwdefaults__ - default_types = StyleTransfer.stylize.__annotations__ - return {'default': defaults[arg], 'type': default_types[arg]} - - p.add_argument('content', type=str, help='the content image') - p.add_argument('styles', type=str, nargs='+', metavar='style', help='the style images') - p.add_argument('--output', '-o', type=str, default='out.png', - help='the output image') - p.add_argument('--style-weights', '-sw', type=float, nargs='+', default=None, - metavar='STYLE_WEIGHT', help='the relative weights for each style image') - p.add_argument('--devices', type=str, default=[], nargs='+', - help='the device names to use (omit for auto)') - p.add_argument('--random-seed', '-r', type=int, default=0, - help='the random seed') - p.add_argument('--content-weight', '-cw', **arg_info('content_weight'), - help='the content weight') - p.add_argument('--tv-weight', '-tw', **arg_info('tv_weight'), - help='the smoothing weight') - p.add_argument('--min-scale', '-ms', **arg_info('min_scale'), - help='the minimum scale (max image dim), in pixels') - p.add_argument('--end-scale', '-s', type=str, default='512', - help='the final scale (max image dim), in pixels') - p.add_argument('--iterations', '-i', **arg_info('iterations'), - help='the number of iterations per scale') - p.add_argument('--initial-iterations', '-ii', **arg_info('initial_iterations'), - help='the number of iterations on the first scale') - p.add_argument('--save-every', type=int, default=50, - help='save the image every SAVE_EVERY iterations') - p.add_argument('--step-size', '-ss', **arg_info('step_size'), - help='the step size (learning rate)') - p.add_argument('--avg-decay', '-ad', **arg_info('avg_decay'), - help='the EMA decay rate for iterate averaging') - p.add_argument('--init', **arg_info('init'), - choices=['content', 'gray', 'uniform', 'style_mean'], - help='the initial image') - p.add_argument('--style-scale-fac', **arg_info('style_scale_fac'), - help='the relative scale of the style to the content') - p.add_argument('--style-size', **arg_info('style_size'), - help='the fixed scale of the style at different content scales') - p.add_argument('--pooling', type=str, default='max', choices=['max', 'average', 'l2'], - help='the model\'s pooling mode') - p.add_argument('--proof', type=str, 
default=None, - help='the ICC color profile (CMYK) for soft proofing the content and styles') - p.add_argument('--web', default=False, action='store_true', help='enable the web interface') - p.add_argument('--host', type=str, default='0.0.0.0', - help='the host the web interface binds to') - p.add_argument('--port', type=int, default=8080, - help='the port the web interface binds to') - p.add_argument('--browser', type=str, default='', nargs='?', - help='open a web browser (specify the browser if not system default)') - - args = p.parse_args() - - content_img = load_image(args.content, args.proof) - style_imgs = [load_image(img, args.proof) for img in args.styles] - - image_type = 'pil' - if Path(args.output).suffix.lower() in {'.tif', '.tiff'}: - image_type = 'np_uint16' - - devices = [torch.device(device) for device in args.devices] - if not devices: - devices = [torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')] - if len(set(device.type for device in devices)) != 1: - print('Devices must all be the same type.') - sys.exit(1) - if not 1 <= len(devices) <= 2: - print('Only 1 or 2 devices are supported.') - sys.exit(1) - print('Using devices:', ' '.join(str(device) for device in devices)) - - if devices[0].type == 'cpu': - print('CPU threads:', torch.get_num_threads()) - if devices[0].type == 'cuda': - for i, device in enumerate(devices): - props = torch.cuda.get_device_properties(device) - print(f'GPU {i} type: {props.name} (compute {props.major}.{props.minor})') - print(f'GPU {i} RAM:', round(props.total_memory / 1024 / 1024), 'MB') - - end_scale = int(args.end_scale.rstrip('+')) - if args.end_scale.endswith('+'): - end_scale = get_safe_scale(*content_img.size, end_scale) - args.end_scale = end_scale - - web_interface = None - if args.web: - web_interface = WebInterface(args.host, args.port) - atexit.register(web_interface.close) - - for device in devices: - torch.tensor(0).to(device) - torch.manual_seed(args.random_seed) - - print('Loading model...') - st = StyleTransfer(devices=devices, pooling=args.pooling) - callback = Callback(st, args, image_type=image_type, web_interface=web_interface) - atexit.register(callback.close) - - url = f'http://{args.host}:{args.port}/' - if args.web: - if args.browser: - webbrowser.get(args.browser).open(url) - elif args.browser is None: - webbrowser.open(url) - - defaults = StyleTransfer.stylize.__kwdefaults__ - st_kwargs = {k: v for k, v in args.__dict__.items() if k in defaults} - try: - st.stylize(content_img, style_imgs, **st_kwargs, callback=callback) - except KeyboardInterrupt: - pass - - output_image = st.get_image(image_type) - if output_image is not None: - save_image(args.output, output_image) - with open('trace.json', 'w') as fp: - json.dump(callback.get_trace(), fp, indent=4) - - -if __name__ == '__main__': - main() diff --git a/style_transfer/my_test.py b/style_transfer/my_test.py new file mode 100644 index 0000000..5d3baba --- /dev/null +++ b/style_transfer/my_test.py @@ -0,0 +1,16 @@ +import copy +from dataclasses import dataclass +from functools import partial +import time +import warnings + +import numpy as np +from PIL import Image +import torch +from torch import optim, nn +from torch.nn import functional as F +from torchvision import models, transforms +from torchvision.transforms import functional as TF + +s = 'ust4.jpg' +print(s.split('.')[0]) diff --git a/style_transfer/style_transfer.py b/style_transfer/style_transfer.py index fdabd15..48f4973 100644 --- a/style_transfer/style_transfer.py +++ 
b/style_transfer/style_transfer.py @@ -16,7 +16,8 @@ class VGGFeatures(nn.Module): - poolings = {'max': nn.MaxPool2d, 'average': nn.AvgPool2d, 'l2': partial(nn.LPPool2d, 2)} + poolings = {'max': nn.MaxPool2d, 'average': nn.AvgPool2d, + 'l2': partial(nn.LPPool2d, 2)} pooling_scales = {'max': 1., 'average': 2., 'l2': 0.78} def __init__(self, layers, pooling='max'): @@ -30,7 +31,8 @@ def __init__(self, layers, pooling='max'): # The PyTorch pre-trained VGG-19 has different parameters from Simonyan et al.'s original # model. - self.model = models.vgg19(pretrained=True).features[:self.layers[-1] + 1] + self.model = models.vgg19( + pretrained=True).features[:self.layers[-1] + 1] self.devices = [torch.device('cpu')] * len(self.model) # Reduces edge artifacts. @@ -74,14 +76,20 @@ def distribute_layers(self, devices): self.devices[i] = device def forward(self, input, layers=None): - layers = self.layers if layers is None else sorted(set(layers)) + layers = self.layers if layers is None else sorted( + set(layers)) # an array of layer numbers h, w = input.shape[2:4] - min_size = self._get_min_size(layers) + + # check min_size to reach the set of layers (make sure the feature map doesn't shrink to 0 by 0) + min_size = self._get_min_size(layers) # what is this doing? if min(h, w) < min_size: - raise ValueError(f'Input is {h}x{w} but must be at least {min_size}x{min_size}') + raise ValueError( + f'Input is {h}x{w} but must be at least {min_size}x{min_size}') + feats = {'input': input} + # normalize the input image with the mean and std of ImageNet input = self.normalize(input) - for i in range(max(layers) + 1): + for i in range(max(layers) + 1): # put the input through each layer of the model input = self.model[i](input.to(self.devices[i])) if i in layers: feats[i] = input @@ -130,16 +138,66 @@ def forward(self, input): return self.loss(self.get_target(input), self.target) -class TVLoss(nn.Module): +class TVLoss(nn.Module): # calculate sum of local differences on feature map?? 
"""L2 total variation loss, as in Mahendran et al.""" def forward(self, input): + # (left,right,top,bottom) input = F.pad(input, (0, 1, 0, 1), 'replicate') - x_diff = input[..., :-1, 1:] - input[..., :-1, :-1] - y_diff = input[..., 1:, :-1] - input[..., :-1, :-1] + x_diff = input[:, :-1, 1:] - input[:, :-1, :-1] + y_diff = input[:, 1:, :-1] - input[:, :-1, :-1] return (x_diff**2 + y_diff**2).mean() +class GradientLoss(nn.Module): + def __init__(self, content_image, s_mask=None, s_weight=1): + super().__init__() + content_image = F.pad(content_image, (0, 1, 0, 1), 'replicate') + # print(content_image.shape) + content_grayscale = 0.2989 * \ + content_image[:, 0, :, :] + 0.5870*content_image[:, + 1, :, :] + 0.1140*content_image[:, 2, :, :] + self.register_buffer( + 'content_x_diff', content_grayscale[..., :-1, 1:] - content_grayscale[..., :-1, :-1]) + self.register_buffer( + 'content_y_diff', content_grayscale[..., 1:, :-1] - content_grayscale[..., :-1, :-1]) + self.register_buffer('sky_mask', s_mask*(s_weight*s_weight)) + # self.register_buffer('sky_weight', s_weight) + + def forward(self, input): + # (left,right,top,bottom) + input = F.pad(input, (0, 1, 0, 1), 'replicate') + input_grayscale = 0.2989 * \ + input[:, 0, :, :] + 0.5870 * \ + input[:, 1, :, :] + 0.1140*input[:, 2, :, :] + x_diff = input_grayscale[..., :-1, 1:] - input_grayscale[..., :-1, :-1] + y_diff = input_grayscale[..., 1:, :-1] - input_grayscale[..., :-1, :-1] + x_dist, y_dist = x_diff-self.content_x_diff, y_diff-self.content_y_diff + global_dist = (x_dist**2 + y_dist**2).mean() + sky_dist = 0 + if self.sky_mask != None: + sky_dist = ((x_dist*self.sky_mask)**2 + + (y_dist*self.sky_mask)**2).mean() + return global_dist + sky_dist + +# class GradientLoss(nn.Module): +# def __init__(self, content_image, s_mask = None, s_weight = 1): +# super().__init__() +# content_grayscale = 0.2989*content_image[:,0,:,:] + 0.5870*content_image[:,1,:,:] + 0.1140*content_image[:,2,:,:] +# D = torch.tensor([[[0,-1,0],[-1,4,-1], [0,-1,0]]]) +# self.register_buffer('content_gradmap', F.conv2d(content_grayscale, D, padding=1)) +# self.register_buffer('sky_mask', s_mask*(s_weight*s_weight)) + +# def forward(self, input): +# input_grayscale = 0.2989*input[:,0,:,:] + 0.5870*input[:,1,:,:] + 0.1140*input[:,2,:,:] +# D = torch.tensor([[[0,-1,0],[-1,4,-1], [0,-1,0]]]) +# global_dist = self.content_gradmap - F.conv2d(input_grayscale, D, padding=1) +# sky_dist = 0 +# if self.sky_mask != None: +# sky_dist = (global_dist*self.sky_mask).mean() +# return global_dist + sky_dist + + class SumLoss(nn.ModuleList): def __init__(self, losses, verbose=False): super().__init__(losses) @@ -166,7 +224,7 @@ def forward(self, *args, **kwargs): return self.module(*args, **kwargs) * self.scale -class LayerApply(nn.Module): +class LayerApply(nn.Module): # apply the loss function to some speficied layers def __init__(self, module, layer): super().__init__() self.module = module @@ -198,6 +256,7 @@ def update(self, input): self.value += (1 - self.decay) * input +# warp the image to be inside [max_dim * max_dim] square def size_to_fit(size, max_dim, scale_up=False): w, h = size if not scale_up and max(h, w) <= max_dim: @@ -210,10 +269,10 @@ def size_to_fit(size, max_dim, scale_up=False): return new_w, new_h -def gen_scales(start, end): +def gen_scales(start, end): # return an array of scales, each greater by a factor of sqrt(2) scale = end i = 0 - scales = set() + scales = set() # create an empty set while scale >= start: scales.add(scale) i += 1 @@ -233,10 +292,12 @@ def 
scale_adam(state, shape): for group in state['state'].values(): exp_avg, exp_avg_sq = group['exp_avg'], group['exp_avg_sq'] group['exp_avg'] = interpolate(exp_avg, shape, mode='bicubic') - group['exp_avg_sq'] = interpolate(exp_avg_sq, shape, mode='bilinear').relu_() + group['exp_avg_sq'] = interpolate( + exp_avg_sq, shape, mode='bilinear').relu_() if 'max_exp_avg_sq' in group: max_exp_avg_sq = group['max_exp_avg_sq'] - group['max_exp_avg_sq'] = interpolate(max_exp_avg_sq, shape, mode='bilinear').relu_() + group['max_exp_avg_sq'] = interpolate( + max_exp_avg_sq, shape, mode='bilinear').relu_() return state @@ -254,33 +315,37 @@ class STIterate: class StyleTransfer: def __init__(self, devices=['cpu'], pooling='max'): self.devices = [torch.device(device) for device in devices] - self.image = None - self.average = None + self.image = None # the output at each iteration + self.average = None # the final result is an average among outputs of each iteration # The default content and style layers follow Gatys et al. (2015). self.content_layers = [22] self.style_layers = [1, 6, 11, 20, 29] # The weighting of the style layers differs from Gatys et al. (2015) and Johnson et al. - style_weights = [256, 64, 16, 4, 1] + style_weights = [256, 64, 16, 4, 1] # default + # style_weights = [1, 1, 1, 1, 1] # average -> trial weight_sum = sum(abs(w) for w in style_weights) + # the normalized style weights for each style_layers self.style_weights = [w / weight_sum for w in style_weights] - self.model = VGGFeatures(self.style_layers + self.content_layers, pooling=pooling) + # the vgg model + self.model = VGGFeatures( + self.style_layers + self.content_layers, pooling=pooling) + # distribute model to two devices if possible if len(self.devices) == 1: device_plan = {0: self.devices[0]} elif len(self.devices) == 2: device_plan = {0: self.devices[0], 5: self.devices[1]} else: raise ValueError('Only 1 or 2 devices are supported.') - self.model.distribute_layers(device_plan) def get_image_tensor(self): return self.average.get().detach()[0].clamp(0, 1) - def get_image(self, image_type='pil'): + def get_image(self, image_type='pil'): # output the average image (but what's that?) 
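+        # "The average" is the bias-corrected exponential moving average of the iterates (see
+        # the EMA class): EMA.update() is called once per optimizer step and EMA.get() divides
+        # by (1 - decay**n) so that early iterates are not under-weighted.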
if self.average is not None: image = self.get_image_tensor() if image_type.lower() == 'pil': @@ -291,9 +356,11 @@ def get_image(self, image_type='pil'): else: raise ValueError("image_type must be 'pil' or 'np_uint16'") - def stylize(self, content_image, style_images, *, + def stylize(self, content_image, sky_mask, style_images, *, style_weights=None, - content_weight: float = 0.015, + content_weight: float = 0.04, + grad_weight: float = 20, + sky_weight: float = 1.5, tv_weight: float = 2., min_scale: int = 128, end_scale: int = 512, @@ -307,23 +374,30 @@ def stylize(self, content_image, style_images, *, callback=None): min_scale = min(min_scale, end_scale) - content_weights = [content_weight / len(self.content_layers)] * len(self.content_layers) + content_weights = [content_weight / + len(self.content_layers)] * len(self.content_layers) + # style weights among multiple style images if style_weights is None: style_weights = [1 / len(style_images)] * len(style_images) else: weight_sum = sum(abs(w) for w in style_weights) style_weights = [weight / weight_sum for weight in style_weights] if len(style_images) != len(style_weights): - raise ValueError('style_images and style_weights must have the same length') + raise ValueError( + 'style_images and style_weights must have the same length') + # add TVloss -> the sum of the absolute differences for neighboring pixel-values in the result image tv_loss = Scale(LayerApply(TVLoss(), 'input'), tv_weight) + # get a sequence of scales, from small to large scales = gen_scales(min_scale, end_scale) + # set the initial image and load it to device cw, ch = size_to_fit(content_image.size, scales[0], scale_up=True) if init == 'content': - self.image = TF.to_tensor(content_image.resize((cw, ch), Image.LANCZOS))[None] + self.image = TF.to_tensor( + content_image.resize((cw, ch), Image.LANCZOS))[None] elif init == 'gray': self.image = torch.rand([1, 3, ch, cw]) / 255 + 0.5 elif init == 'uniform': @@ -331,84 +405,119 @@ def stylize(self, content_image, style_images, *, elif init == 'style_mean': means = [] for i, image in enumerate(style_images): - means.append(TF.to_tensor(image).mean(dim=(1, 2)) * style_weights[i]) - self.image = torch.rand([1, 3, ch, cw]) / 255 + sum(means)[None, :, None, None] + means.append(TF.to_tensor(image).mean( + dim=(1, 2)) * style_weights[i]) + self.image = torch.rand([1, 3, ch, cw]) / \ + 255 + sum(means)[None, :, None, None] else: - raise ValueError("init must be one of 'content', 'gray', 'uniform', 'style_mean'") + raise ValueError( + "init must be one of 'content', 'gray', 'uniform', 'style_mean'") self.image = self.image.to(self.devices[0]) opt = None - # Stylize the image at successively finer scales, each greater by a factor of sqrt(2). # This differs from the scheme given in Gatys et al. (2016). 
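+        # e.g. gen_scales(128, 512) yields [128, 181, 256, 362, 512]; the loop below
+        # re-optimizes the image at each of these sizes, warm-starting the image and the
+        # Adam state from the previous scale.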
for scale in scales: if self.devices[0].type == 'cuda': torch.cuda.empty_cache() + # resize the content image to be smaller than [scale * scale] -> target size cw, ch = size_to_fit(content_image.size, scale, scale_up=True) - content = TF.to_tensor(content_image.resize((cw, ch), Image.LANCZOS))[None] + content = TF.to_tensor(content_image.resize( + (cw, ch), Image.LANCZOS))[None] content = content.to(self.devices[0]) - self.image = interpolate(self.image.detach(), (ch, cw), mode='bicubic').clamp(0, 1) + # resize the mask along with the content iamge + mask = TF.to_tensor(sky_mask.resize((cw, ch), Image.LANCZOS))[None] + mask = mask.to(self.devices[0]) + + grad_loss = Scale(LayerApply(GradientLoss( + content, mask, sky_weight), 'input'), grad_weight) + + # interpolate the initial image to the target size + self.image = interpolate( + self.image.detach(), (ch, cw), mode='bicubic').clamp(0, 1) + # averaging across the time?? self.average = EMA(self.image, avg_decay) self.image.requires_grad_() print(f'Processing content image ({cw}x{ch})...') + # add ContentLoss content_feats = self.model(content, layers=self.content_layers) content_losses = [] for layer, weight in zip(self.content_layers, content_weights): - target = content_feats[layer] - content_losses.append(Scale(LayerApply(ContentLoss(target), layer), weight)) + target = content_feats[layer] # target content feature + # how to calculate content loss? + content_losses.append( + Scale(LayerApply(ContentLoss(target), layer), weight)) style_targets, style_losses = {}, [] + # add StyleLoss for i, image in enumerate(style_images): + # resize the image and load it to GPU if style_size is None: - sw, sh = size_to_fit(image.size, round(scale * style_scale_fac)) + sw, sh = size_to_fit( + image.size, round(scale * style_scale_fac)) else: sw, sh = size_to_fit(image.size, style_size) - style = TF.to_tensor(image.resize((sw, sh), Image.LANCZOS))[None] + style = TF.to_tensor(image.resize( + (sw, sh), Image.LANCZOS))[None] style = style.to(self.devices[0]) + print(f'Processing style image ({sw}x{sh})...') style_feats = self.model(style, layers=self.style_layers) # Take the weighted average of multiple style targets (Gram matrices). for layer in self.style_layers: - target = StyleLoss.get_target(style_feats[layer]) * style_weights[i] + target = StyleLoss.get_target( + style_feats[layer]) * style_weights[i] if layer not in style_targets: style_targets[layer] = target else: style_targets[layer] += target for layer, weight in zip(self.style_layers, self.style_weights): target = style_targets[layer] - style_losses.append(Scale(LayerApply(StyleLoss(target), layer), weight)) + style_losses.append( + Scale(LayerApply(StyleLoss(target), layer), weight)) - crit = SumLoss([*content_losses, *style_losses, tv_loss]) + # Construct a list of losses + crit = SumLoss( + [*content_losses, *style_losses, tv_loss, grad_loss]) + # Warm-start the Adam optimizer if this is not the first scale. (load the previous optimizer state) opt2 = optim.Adam([self.image], lr=step_size) - # Warm-start the Adam optimizer if this is not the first scale. 
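+            # scale_adam() interpolates the exp_avg / exp_avg_sq moment tensors to the new
+            # (ch, cw) resolution (bicubic and bilinear respectively), so Adam's momentum
+            # carries over when the image grows instead of being reset at every scale.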
if scale != scales[0]: opt_state = scale_adam(opt.state_dict(), (ch, cw)) opt2.load_state_dict(opt_state) opt = opt2 + # empty GPU cache if self.devices[0].type == 'cuda': torch.cuda.empty_cache() + # forward & backward propagation + # first scale: 1000, others: 500 actual_its = initial_iterations if scale == scales[0] else iterations for i in range(1, actual_its + 1): feats = self.model(self.image) - loss = crit(feats) + loss = crit(feats) # calculate all the losses at the same time opt.zero_grad() loss.backward() opt.step() + # Enforce box constraints. with torch.no_grad(): self.image.clamp_(0, 1) + + # do averaging along time (to be investigated) self.average.update(self.image) + + # what does a callback function do? (to be investigated) if callback is not None: gpu_ram = 0 for device in self.devices: if device.type == 'cuda': - gpu_ram = max(gpu_ram, torch.cuda.max_memory_allocated(device)) + gpu_ram = max( + gpu_ram, torch.cuda.max_memory_allocated(device)) callback(STIterate(w=cw, h=ch, i=i, i_max=actual_its, loss=loss.item(), time=time.time(), gpu_ram=gpu_ram)) diff --git a/style_transfer/style_transfer_HRNet.py b/style_transfer/style_transfer_HRNet.py new file mode 100644 index 0000000..2140436 --- /dev/null +++ b/style_transfer/style_transfer_HRNet.py @@ -0,0 +1,539 @@ +"""Neural style transfer (https://arxiv.org/abs/1508.06576) in PyTorch.""" + +import copy +from dataclasses import dataclass +from functools import partial +import time +import warnings + +import numpy as np +from PIL import Image +import torch +from torch import optim, nn +from torch._C import device +from torch.nn import functional as F +from torchvision import models, transforms +from torchvision.transforms import functional as TF +from scipy.io import loadmat + + +colors = loadmat('data/color150.mat')['colors'] + +class VGGFeatures(nn.Module): + poolings = {'max': nn.MaxPool2d, 'average': nn.AvgPool2d, + 'l2': partial(nn.LPPool2d, 2)} + pooling_scales = {'max': 1., 'average': 2., 'l2': 0.78} + + def __init__(self, layers, pooling='max'): + super().__init__() + self.layers = sorted(set(layers)) + + # The PyTorch pre-trained VGG-19 expects sRGB inputs in the range [0, 1] which are then + # normalized according to this transform, unlike Simonyan et al.'s original model. + self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + + # The PyTorch pre-trained VGG-19 has different parameters from Simonyan et al.'s original + # model. + self.model = models.vgg19( + pretrained=True).features[:self.layers[-1] + 1] + self.devices = [torch.device('cpu')] * len(self.model) + + # Reduces edge artifacts. + self.model[0] = self._change_padding_mode(self.model[0], 'replicate') + + pool_scale = self.pooling_scales[pooling] + for i, layer in enumerate(self.model): + if pooling != 'max' and isinstance(layer, nn.MaxPool2d): + # Changing the pooling type from max results in the scale of activations + # changing, so rescale them. Gatys et al. (2015) do not do this. 
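+                # per pooling_scales above: 'average' outputs are multiplied by 2.0 and
+                # 'l2' outputs by 0.78 to keep activation magnitudes comparable to max pooling.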
+ self.model[i] = Scale(self.poolings[pooling](2), pool_scale) + + self.model.eval() + self.model.requires_grad_(False) + + @staticmethod + def _change_padding_mode(conv, padding_mode): + new_conv = nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size, + stride=conv.stride, padding=conv.padding, + padding_mode=padding_mode) + with torch.no_grad(): + new_conv.weight.copy_(conv.weight) + new_conv.bias.copy_(conv.bias) + return new_conv + + @staticmethod + def _get_min_size(layers): + last_layer = max(layers) + min_size = 1 + for layer in [4, 9, 18, 27, 36]: + if last_layer < layer: + break + min_size *= 2 + return min_size + + def distribute_layers(self, devices): + for i, layer in enumerate(self.model): + if i in devices: + device = torch.device(devices[i]) + self.model[i] = layer.to(device) + self.devices[i] = device + + def forward(self, input, layers=None): + layers = self.layers if layers is None else sorted( + set(layers)) # an array of layer numbers + h, w = input.shape[2:4] + + # check min_size to reach the set of layers (make sure the feature map doesn't shrink to 0 by 0) + min_size = self._get_min_size(layers) # what is this doing? + if min(h, w) < min_size: + raise ValueError( + f'Input is {h}x{w} but must be at least {min_size}x{min_size}') + + feats = {'input': input} + # normalize the input image with the mean and std of ImageNet + input = self.normalize(input) + for i in range(max(layers) + 1): # put the input through each layer of the model + input = self.model[i](input.to(self.devices[i])) + if i in layers: + feats[i] = input + return feats + + +class ScaledMSELoss(nn.Module): + """Computes MSE scaled such that its gradient L1 norm is approximately 1. + This differs from Gatys at al. (2015) and Johnson et al.""" + + def __init__(self, eps=1e-8): + super().__init__() + self.register_buffer('eps', torch.tensor(eps)) + + def extra_repr(self): + return f'eps={self.eps:g}' + + def forward(self, input, target): + diff = input - target + return diff.pow(2).sum() / diff.abs().sum().add(self.eps) + + +class ContentLoss(nn.Module): + def __init__(self, target, eps=1e-8): + super().__init__() + self.register_buffer('target', target) + self.loss = ScaledMSELoss(eps=eps) + + def forward(self, input): + return self.loss(input, self.target) + + +class StyleLoss(nn.Module): + def __init__(self, target, eps=1e-8): + super().__init__() + self.register_buffer('target', target) + self.loss = ScaledMSELoss(eps=eps) + + @staticmethod + def get_target(target): + mat = target.flatten(-2) + # The Gram matrix normalization differs from Gatys et al. (2015) and Johnson et al. + return mat @ mat.transpose(-2, -1) / mat.shape[-1] + + def forward(self, input): + return self.loss(self.get_target(input), self.target) + + +class TVLoss(nn.Module): # calculate sum of local differences on feature map?? 
+ """L2 total variation loss, as in Mahendran et al.""" + + def forward(self, input): + # (left,right,top,bottom) + input = F.pad(input, (0, 1, 0, 1), 'replicate') + x_diff = input[:, :-1, 1:] - input[:, :-1, :-1] + y_diff = input[:, 1:, :-1] - input[:, :-1, :-1] + return (x_diff**2 + y_diff**2).mean() + + +class GradientLoss(nn.Module): + def __init__(self, content_image, s_mask=None, s_weight=1): + super().__init__() + content_image = F.pad(content_image, (0, 1, 0, 1), 'replicate') + # print(content_image.shape) + content_grayscale = 0.2989 * \ + content_image[:, 0, :, :] + 0.5870*content_image[:, + 1, :, :] + 0.1140*content_image[:, 2, :, :] + self.register_buffer( + 'content_x_diff', content_grayscale[..., :-1, 1:] - content_grayscale[..., :-1, :-1]) + self.register_buffer( + 'content_y_diff', content_grayscale[..., 1:, :-1] - content_grayscale[..., :-1, :-1]) + self.register_buffer('sky_mask', s_mask*(s_weight*s_weight)) + # self.register_buffer('sky_weight', s_weight) + + def forward(self, input): + # (left,right,top,bottom) + input = input.to('cuda:0') + input = F.pad(input, (0, 1, 0, 1), 'replicate') + input_grayscale = 0.2989 * \ + input[:, 0, :, :] + 0.5870 * \ + input[:, 1, :, :] + 0.1140*input[:, 2, :, :] + x_diff = input_grayscale[..., :-1, 1:] - input_grayscale[..., :-1, :-1] + y_diff = input_grayscale[..., 1:, :-1] - input_grayscale[..., :-1, :-1] + # print(x_diff.get_device(), y_diff.get_device()) + # print(self.content_x_diff.get_device(), self.content_y_diff.get_device()) + x_dist, y_dist = x_diff-self.content_x_diff, y_diff-self.content_y_diff + global_dist = (x_dist**2 + y_dist**2).mean() + sky_dist = 0 + if self.sky_mask != None: + sky_dist = ((x_dist*self.sky_mask)**2 + + (y_dist*self.sky_mask)**2).mean() + return global_dist + sky_dist + + +class SumLoss(nn.ModuleList): + def __init__(self, losses, verbose=False): + super().__init__(losses) + self.verbose = verbose + + def forward(self, *args, **kwargs): + losses = [loss(*args, **kwargs) for loss in self] + if self.verbose: + for i, loss in enumerate(losses): + print(f'({i}): {loss.item():g}') + return sum(loss.to(losses[-1].device) for loss in losses) + + +class Scale(nn.Module): + def __init__(self, module, scale): + super().__init__() + self.module = module + self.register_buffer('scale', torch.tensor(scale)) + + def extra_repr(self): + return f'(scale): {self.scale.item():g}' + + def forward(self, *args, **kwargs): + return self.module(*args, **kwargs) * self.scale + + +class LayerApply(nn.Module): # apply the loss function to some speficied layers + def __init__(self, module, layer): + super().__init__() + self.module = module + self.layer = layer + + def extra_repr(self): + return f'(layer): {self.layer!r}' + + def forward(self, input): + return self.module(input[self.layer]) + + +class EMA(nn.Module): + """A bias-corrected exponential moving average, as in Kingma et al. 
(Adam).""" + + def __init__(self, input, decay): + super().__init__() + self.register_buffer('value', torch.zeros_like(input)) + self.register_buffer('decay', torch.tensor(decay)) + self.register_buffer('accum', torch.tensor(1.)) + self.update(input) + + def get(self): + return self.value / (1 - self.accum) + + def update(self, input): + self.accum *= self.decay + self.value *= self.decay + self.value += (1 - self.decay) * input + + +# warp the image to be inside [max_dim * max_dim] square +def size_to_fit(size, max_dim, scale_up=False): + w, h = size + if not scale_up and max(h, w) <= max_dim: + return w, h + new_w, new_h = max_dim, max_dim + if h > w: + new_w = round(max_dim * w / h) + else: + new_h = round(max_dim * h / w) + return new_w, new_h + + +def gen_scales(start, end): # return an array of scales, each greater by a factor of sqrt(2) + scale = end + i = 0 + scales = set() # create an empty set + while scale >= start: + scales.add(scale) + i += 1 + scale = round(end / pow(2, i/2)) + return sorted(scales) + + +def interpolate(*args, **kwargs): + with warnings.catch_warnings(): + warnings.simplefilter('ignore', UserWarning) + return F.interpolate(*args, **kwargs) + + +def scale_adam(state, shape): + """Prepares a state dict to warm-start the Adam optimizer at a new scale.""" + state = copy.deepcopy(state) + for group in state['state'].values(): + exp_avg, exp_avg_sq = group['exp_avg'], group['exp_avg_sq'] + group['exp_avg'] = interpolate(exp_avg, shape, mode='bicubic') + group['exp_avg_sq'] = interpolate( + exp_avg_sq, shape, mode='bilinear').relu_() + if 'max_exp_avg_sq' in group: + max_exp_avg_sq = group['max_exp_avg_sq'] + group['max_exp_avg_sq'] = interpolate( + max_exp_avg_sq, shape, mode='bilinear').relu_() + return state + + +@dataclass +class STIterate: + w: int + h: int + i: int + i_max: int + loss: float + time: float + gpu_ram: int + + +class StyleTransfer: + def __init__(self, devices=['cpu'], pooling='max'): + self.devices = [torch.device(device) for device in devices] + self.image = None # the output at each iteration + self.average = None # the final result is an average among outputs of each iteration + + # The default content and style layers follow Gatys et al. (2015). + self.content_layers = [22] + self.style_layers = [1, 6, 11, 20, 29] + + # The weighting of the style layers differs from Gatys et al. (2015) and Johnson et al. + style_weights = [256, 64, 16, 4, 1] # default + # style_weights = [1, 1, 1, 1, 1] # average -> trial + weight_sum = sum(abs(w) for w in style_weights) + # the normalized style weights for each style_layers + self.style_weights = [w / weight_sum for w in style_weights] + + # the vgg model + self.model = VGGFeatures( + self.style_layers + self.content_layers, pooling=pooling) + + # distribute model to two devices if possible + if len(self.devices) == 1: + device_plan = {0: self.devices[0]} + elif len(self.devices) == 2: + device_plan = {0: self.devices[0], 5: self.devices[1]} + else: + raise ValueError('Only 1 or 2 devices are supported.') + self.model.distribute_layers(device_plan) + + def get_image_tensor(self): + return self.average.get().detach()[0].clamp(0, 1) + + def get_image(self, image_type='pil'): # output the average image (but what's that?) 
+ if self.average is not None: + image = self.get_image_tensor() + if image_type.lower() == 'pil': + return TF.to_pil_image(image) + elif image_type.lower() == 'np_uint16': + arr = image.cpu().movedim(0, 2).numpy() + return np.uint16(np.round(arr * 65535)) + else: + raise ValueError("image_type must be 'pil' or 'np_uint16'") + + + def stylize(self, content_image, sky_mask, style_images, *, + style_weights=None, + content_weight: float = 0.04, + grad_weight: float = 20, + sky_weight: float = 1, + tv_weight: float = 2., + min_scale: int = 128, + end_scale: int = 512, + iterations: int = 500, + initial_iterations: int = 1000, + step_size: float = 0.02, + avg_decay: float = 0.99, + init: str = 'content', + style_scale_fac: float = 1., + style_size: int = None, + callback=None): + + min_scale = min(min_scale, end_scale) + content_weights = [content_weight / + len(self.content_layers)] * len(self.content_layers) + + # style weights among multiple style images + if style_weights is None: + style_weights = [1 / len(style_images)] * len(style_images) + else: + weight_sum = sum(abs(w) for w in style_weights) + style_weights = [weight / weight_sum for weight in style_weights] + if len(style_images) != len(style_weights): + raise ValueError( + 'style_images and style_weights must have the same length') + + # add TVloss -> the sum of the absolute differences for neighboring pixel-values in the result image + tv_loss = Scale(LayerApply(TVLoss(), 'input'), tv_weight) + + # get a sequence of scales, from small to large + scales = gen_scales(min_scale, end_scale) + + # set the initial image and load it to device + cw, ch = size_to_fit(content_image.size, scales[0], scale_up=True) + if init == 'content': + self.image = TF.to_tensor( + content_image.resize((cw, ch), Image.LANCZOS))[None] + elif init == 'gray': + self.image = torch.rand([1, 3, ch, cw]) / 255 + 0.5 + elif init == 'uniform': + self.image = torch.rand([1, 3, ch, cw]) + elif init == 'style_mean': + means = [] + for i, image in enumerate(style_images): + means.append(TF.to_tensor(image).mean( + dim=(1, 2)) * style_weights[i]) + self.image = torch.rand([1, 3, ch, cw]) / \ + 255 + sum(means)[None, :, None, None] + else: + raise ValueError( + "init must be one of 'content', 'gray', 'uniform', 'style_mean'") + self.image = self.image.to(self.devices[0]) # the original input + + + if self.devices[0].type == 'cuda': + torch.cuda.empty_cache() + + content = content_image.to(self.devices[0]) + style = style_images[0].to(self.devices[0]) + mask = sky_mask.to(self.devices[0]) + + grad_loss = Scale(LayerApply(GradientLoss( + content, mask, sky_weight), 'input'), grad_weight) + + # add ContentLoss + content_feats = self.model(content, layers=self.content_layers) + content_losses = [] + for layer, weight in zip(self.content_layers, content_weights): + target = content_feats[layer] # target content feature + # how to calculate content loss? + content_losses.append( + Scale(LayerApply(ContentLoss(target), layer), weight)) + + opt = None + # Stylize the image at successively finer scales, each greater by a factor of sqrt(2). + # This differs from the scheme given in Gatys et al. (2016). 
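stylize builds its coarse-to-fine schedule with gen_scales, each scale roughly sqrt(2) larger than the last, and fits the content image inside each square with size_to_fit. A usage sketch with the default min/end scales and a hypothetical 1200x800 content image (importing this module also requires data/color150.mat to be present):

from style_transfer.style_transfer_HRNet import gen_scales, size_to_fit

print(gen_scales(128, 512))
# [128, 181, 256, 362, 512]

for scale in gen_scales(128, 512):
    # longer side is capped at `scale`, aspect ratio preserved
    print(scale, size_to_fit((1200, 800), scale, scale_up=True))
# 128 (128, 85)
# 181 (181, 121)
# 256 (256, 171)
# 362 (362, 241)
# 512 (512, 341)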
+ for scale in scales: + if self.devices[0].type == 'cuda': + torch.cuda.empty_cache() + + # resize the content image to be smaller than [scale * scale] -> target size + cw, ch = size_to_fit(content_image.size, scale, scale_up=True) + content = TF.to_tensor(content_image.resize( + (cw, ch), Image.LANCZOS))[None] + content = content.to(self.devices[0]) + + # resize the mask along with the content iamge + mask = TF.to_tensor(sky_mask.resize((cw, ch), Image.LANCZOS))[None] + mask = mask.to(self.devices[0]) + + grad_loss = Scale(LayerApply(GradientLoss( + content, mask, sky_weight), 'input'), grad_weight) + + # interpolate the initial image to the target size + self.image = interpolate( + self.image.detach(), (ch, cw), mode='bicubic').clamp(0, 1) + # averaging across the time?? + self.average = EMA(self.image, avg_decay) + self.image.requires_grad_() + + print(f'Processing content image ({cw}x{ch})...') + # add ContentLoss + content_feats = self.model(content, layers=self.content_layers) + content_losses = [] + for layer, weight in zip(self.content_layers, content_weights): + target = content_feats[layer] # target content feature + # how to calculate content loss? + content_losses.append( + Scale(LayerApply(ContentLoss(target), layer), weight)) + + style_targets, style_losses = {}, [] + # add StyleLoss + for i, image in enumerate(style_images): + # resize the image and load it to GPU + if style_size is None: + sw, sh = size_to_fit( + image.size, round(scale * style_scale_fac)) + else: + sw, sh = size_to_fit(image.size, style_size) + style = TF.to_tensor(image.resize( + (sw, sh), Image.LANCZOS))[None] + style = style.to(self.devices[0]) + + print(f'Processing style image ({sw}x{sh})...') + style_feats = self.model(style, layers=self.style_layers) + # Take the weighted average of multiple style targets (Gram matrices). + for layer in self.style_layers: + target = StyleLoss.get_target( + style_feats[layer]) * style_weights[i] + if layer not in style_targets: + style_targets[layer] = target + else: + style_targets[layer] += target + for layer, weight in zip(self.style_layers, self.style_weights): + target = style_targets[layer] + style_losses.append( + Scale(LayerApply(StyleLoss(target), layer), weight)) + + # Construct a list of losses + crit = SumLoss( + [*content_losses, *style_losses, tv_loss, grad_loss], verbose=False) + + # Warm-start the Adam optimizer if this is not the first scale. (load the previous optimizer state) + opt2 = optim.Adam([self.image], lr=step_size) + if scale != scales[0]: + opt_state = scale_adam(opt.state_dict(), (ch, cw)) + opt2.load_state_dict(opt_state) + opt = opt2 + + # empty GPU cache + if self.devices[0].type == 'cuda': + torch.cuda.empty_cache() + + # forward & backward propagation + # first scale: 1000, others: 500 + actual_its = initial_iterations if scale == scales[0] else iterations + for i in range(1, actual_its + 1): + feats = self.model(self.image) + loss = crit(feats) # calculate all the losses at the same time + opt.zero_grad() + loss.backward() + opt.step() + + # Enforce box constraints. + with torch.no_grad(): + self.image.clamp_(0, 1) + + # do averaging along time (to be investigated) + self.average.update(self.image) + + # what does a callback function do? 
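With several style images, the loop above does not blend the images themselves; it blends their Gram matrices per layer, weighted by the normalized style_weights. A compact sketch of that accumulation, with random tensors standing in for each style image's VGG activations and a gram helper that mirrors StyleLoss.get_target:

import torch

def gram(feats):
    mat = feats.flatten(-2)
    return mat @ mat.transpose(-2, -1) / mat.shape[-1]

style_layers = [1, 6, 11, 20, 29]
style_weights = [0.7, 0.3]                  # already normalized to sum to 1
per_style_feats = [
    {layer: torch.rand(1, 64, 32, 32) for layer in style_layers},
    {layer: torch.rand(1, 64, 32, 32) for layer in style_layers},
]

style_targets = {}
for weight, feats in zip(style_weights, per_style_feats):
    for layer in style_layers:
        target = gram(feats[layer]) * weight
        style_targets[layer] = style_targets.get(layer, 0) + target
# Each style_targets[layer] is then frozen as the target of that layer's StyleLoss.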
+ # The callback, if given, receives an STIterate with the current size, iteration, loss, and peak GPU memory. + if callback is not None: + gpu_ram = 0 + for device in self.devices: + if device.type == 'cuda': + gpu_ram = max( + gpu_ram, torch.cuda.max_memory_allocated(device)) + callback(STIterate(w=cw, h=ch, i=i, i_max=actual_its, loss=loss.item(), + time=time.time(), gpu_ram=gpu_ram)) + + # Initialize each new scale with the previous scale's averaged iterate. + with torch.no_grad(): + self.image.copy_(self.average.get()) + + return self.get_image()
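For reference, a minimal driver for this class might look like the sketch below. Paths are placeholders and the real entry point is cli.py. Note also that the block just above the scale loop calls .to() on the content, style, and mask inputs, which assumes tensors, while the loop itself calls PIL's .resize() on them, so that block may need adjusting before a PIL-based call like this runs as-is.

from PIL import Image
import torch
from style_transfer.style_transfer_HRNet import StyleTransfer

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
st = StyleTransfer(devices=[device])

content = Image.open('content.jpg').convert('RGB')    # placeholder paths
styles = [Image.open('style.jpg').convert('RGB')]
sky_mask = Image.open('sky_mask.png').convert('L')

def report(it):
    print(f'{it.w}x{it.h}  iter {it.i}/{it.i_max}  loss {it.loss:g}')

result = st.stylize(content, sky_mask, styles,
                    end_scale=512, iterations=500, initial_iterations=1000,
                    callback=report)
result.save('out.png')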