Skip to content

Do GPU architectures interfere with file encoding/decoding? #118

@QLaHPD

Description

@QLaHPD

I created two scripts that pipe frames through ffmpeg for encoding and decoding, so the codec reads the original video source directly instead of an intermediate YUV file. Such a YUV file can take up terabytes of storage for some videos, which makes that workflow impractical.

Testing on a computer with an RTX 3050 Mobile 4GB and on another computer with an RTX 4070, the videos encoded on each computer are decoded correctly on the same computer where the .bin file was created, but when trying to decode it on the other computer, the decoded frames are “pure noise”.

It's also worth noting that, for some videos, the RTX 4070 failed to encode when using the CUDA inference extension, but encoding worked after setting it to false (inside src/layers/cuda_inference.py). The same failure was not observed on the 3050. Videos encoded with the PyTorch backend also show the cross-machine decoding problem described above.

This raises the question of whether the .bin files produced on the two computers differ from each other, which could be related to the GPU architecture or to some deep configuration of the software running it (including my scripts). For reference, both computers are using:

NVIDIA-SMI 575.64.03              Driver Version: 575.64.03      CUDA Version: 12.9


nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2025 NVIDIA Corporation
Built on Tue_May_27_02:21:03_PDT_2025
Cuda compilation tools, release 12.9, V12.9.86
Build cuda_12.9.r12.9/compiler.36037853_0


print(torch.__version__)
2.7.1+cu126


ffmpeg version 6.1.1-3ubuntu5 Copyright (c) 2000-2023 the FFmpeg developers
built with gcc 13 (Ubuntu 13.2.0-23ubuntu3)

On both machines I compiled the libraries as described in the README.

My code base structure looks like this:

.
├── checkpoints
│   ├── cvpr2025_image.pth.tar
│   └── cvpr2025_video.pth.tar
├── decode.py
├── encode.py
├── README.md
├── src
│   ├── cpp
│   │   ├── py_rans
│   │   │   ├── py_rans.cpp
│   │   │   ├── py_rans.h
│   │   │   ├── rans_byte.h
│   │   │   ├── rans.cpp
│   │   │   └── rans.h
│   │   └── setup.py
│   ├── layers
│   │   ├── cuda_inference.py
│   │   ├── extensions
│   │   │   └── inference
│   │   │       ├── bind.cpp
│   │   │       ├── common.h
│   │   │       ├── def.h
│   │   │       ├── impl.cpp
│   │   │       ├── kernel.cu
│   │   │       └── setup.py
│   │   └── layers.py
│   ├── models
│   │   ├── common_model.py
│   │   ├── entropy_models.py
│   │   ├── image_model.py
│   │   └── video_model.py
│   └── utils
│       ├── common.py
│       ├── metrics.py
│       ├── stream_helper.py
│       ├── transforms.py
│       ├── video_reader.py
│       └── video_writer.py
└── test_conditions.md

10 directories, 30 files

I used the root DCVC code base from this git repo, which means I'm not using the files inside DCVC-family. I'm using the DCVC-RT model files. Below I append both encode.py and decode.py; all other code remains unmodified.

# encode.py
import argparse
import concurrent.futures
import io
import json
import logging
import multiprocessing
import os
import signal
import sys
import time
import subprocess

import torch
import numpy as np
from tqdm import tqdm

from src.layers.cuda_inference import replicate_pad
from src.models.video_model import DMC
from src.models.image_model import DMCI
from src.utils.common import str2bool, create_folder, get_state_dict, set_torch_env
from src.utils.stream_helper import SPSHelper, write_sps, write_ip
from src.utils.transforms import ycbcr420_to_444_np


def signal_handler(sig, frame):
    """SIGINT handler: stop every live child process, then exit with status 1.

    Args:
        sig: Signal number delivered to the handler (unused).
        frame: Interrupted stack frame (unused).

    Raises:
        SystemExit: Always, with exit code 1.
    """
    print('\nCtrl+C detected! Shutting down all processes...')
    live_children = multiprocessing.active_children()
    for child in live_children:
        # Ask the child to stop, then block until it has actually exited
        # before moving on to the next one.
        child.terminate()
        child.join()
    raise SystemExit(1)


def extract_video_data(video_path):
    """Return the (width, height) of the first video stream, using ffprobe.

    Args:
        video_path: Path of the video file to probe.

    Returns:
        A ``(width, height)`` tuple of ints, or ``(0, 0)`` when the
        dimensions cannot be determined for any reason (ffprobe missing,
        ffprobe failing, unparsable or incomplete output).
    """
    command = [
        'ffprobe', '-v', 'error', '-select_streams', 'v:0',
        '-show_entries', 'stream=width,height', '-of', 'json', video_path
    ]
    try:
        result = subprocess.run(command, capture_output=True, text=True, check=True)
        stream = json.loads(result.stdout)['streams'][0]
        return int(stream['width']), int(stream['height'])
    except (
        OSError,                        # ffprobe not installed / not executable
        subprocess.CalledProcessError,  # ffprobe exited with a non-zero status
        ValueError,                     # invalid JSON or non-numeric dimension
        TypeError,                      # dimension reported as null
        KeyError,                       # 'streams', 'width' or 'height' missing
        IndexError,                     # file has no video stream
    ) as e:
        print(f"FATAL: Could not extract video dimensions for {os.path.basename(video_path)}. Error: {e}")
        return 0, 0

def parse_args():
    """Build the encoder's command-line interface and parse the process arguments.

    Returns:
        argparse.Namespace with the checkpoint paths, I/P-frame QPs, input and
        output folders, optional target resolution, worker/GPU selection and
        the intra-period / reset-interval / zero-threshold settings.
    """
    cli = argparse.ArgumentParser(description="Video Encoding Script")
    add = cli.add_argument

    # Model checkpoints.
    add('--model_path_i', type=str, default="./checkpoints/cvpr2025_image.pth.tar")
    add('--model_path_p', type=str, default="./checkpoints/cvpr2025_video.pth.tar")

    # Quantisation parameters (both mandatory).
    add('--qp_i', type=int, required=True, help="QP for I-frames.")
    add('--qp_p', type=int, required=True, help="QP for P-frames.")

    # Input / output locations and optional rescaling.
    add("--input_folder", type=str, required=True)
    add("--output_folder", type=str, required=True)
    add(
        "-res", "--resolution",
        type=int,
        default=None,
        help="Scale the video's smallest dimension to this resolution, preserving aspect ratio.",
    )

    # Parallelism and device selection.
    add("--worker", "-w", type=int, default=1)
    add("--cuda", type=str2bool, default=True)
    add('--cuda_idx', type=int, nargs="+")

    # Codec behaviour.
    add("--force_intra_period", type=int, default=-1)
    add('--reset_interval', type=int, default=64)
    add('--force_zero_thres', type=float, default=None)

    return cli.parse_args()

def get_src_reader(video_path, target_width=None, target_height=None):
    """Start ffmpeg decoding ``video_path`` to raw yuv420p frames on its stdout.

    Args:
        video_path: Input video handed to ffmpeg via ``-i``.
        target_width: Optional output width; a scale filter is added only when
            both this and ``target_height`` are given.
        target_height: Optional output height (see ``target_width``).

    Returns:
        The ``subprocess.Popen`` handle; frames are read from its ``stdout``
        pipe and ffmpeg's stderr is discarded.
    """
    scale_args = []
    if not (target_width is None or target_height is None):
        scale_args = ['-vf', f'scale={target_width}:{target_height}']

    ffmpeg_cmd = [
        'ffmpeg', '-i', video_path,
        *scale_args,
        '-f', 'rawvideo', '-pix_fmt', 'yuv420p', '-',
    ]
    return subprocess.Popen(ffmpeg_cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL)


def np_image_to_tensor(img, device):
    """Turn a NumPy image into a batched float16 tensor on ``device``.

    The array is moved to ``device`` as float32, divided by 255, given a
    leading batch axis and finally cast to half precision.

    Args:
        img: NumPy array holding the image samples.
        device: Target torch device (string or ``torch.device``).

    Returns:
        A ``torch.float16`` tensor with one extra leading dimension of size 1.
    """
    as_float = torch.from_numpy(img).to(device=device, dtype=torch.float32)
    normalised = as_float.div(255.0)
    batched = normalised[None]
    return batched.half()

def run_encoding(p_frame_net, i_frame_net, args):
    """Encode one video into a bitstream file, reading frames from an ffmpeg pipe.

    Frames are pulled as raw yuv420p planes from ffmpeg, converted to a padded
    4:4:4 half-precision tensor and compressed with the I-frame or P-frame
    model. The whole bitstream is accumulated in memory and written to
    ``args['bin_path']`` once the input is exhausted.

    Args:
        p_frame_net: P-frame model (``compress``, ``shift_qp``, DPB handling).
        i_frame_net: I-frame model; its parameters determine the device used.
        args: Job dict. Reads ``process_width``, ``process_height``,
            ``src_path``, ``intra_period``, ``reset_interval``, ``qp_i``,
            ``qp_p`` and ``bin_path``.

    Returns:
        None. Nothing is written when the dimensions are reported as 0.
    """
    device = next(i_frame_net.parameters()).device
    pic_width = args['process_width']
    pic_height = args['process_height']
    seq_name = os.path.basename(args['src_path'])

    if pic_width == 0 or pic_height == 0:
        print(f"Skipping {seq_name} due to earlier error in reading video dimensions.")
        return

    padding_r, padding_b = DMCI.get_padding_size(pic_height, pic_width, 16)

    use_two_entropy_coders = pic_height * pic_width > 1280 * 720
    i_frame_net.set_use_two_entropy_coders(use_two_entropy_coders)
    p_frame_net.set_use_two_entropy_coders(use_two_entropy_coders)

    output_buff = io.BytesIO()
    sps_helper = SPSHelper()
    p_frame_net.set_curr_poc(0)

    # Plane sizes are constant for the whole sequence, so compute them once.
    y_data_size = pic_width * pic_height
    uv_data_size = (pic_width // 2) * (pic_height // 2)
    # Per-frame index passed to shift_qp, cycling with a period of 8 frames.
    qp_shift_pattern = [0, 1, 0, 2, 0, 2, 0, 2]

    src_pipe = get_src_reader(args['src_path'], pic_width, pic_height)
    try:
        with torch.no_grad():
            last_qp = 0
            frame_idx = 0
            with tqdm(desc=f"Encoding {seq_name}", leave=False, unit="fr") as pbar:
                while True:
                    # A short read means end of stream (possibly in the middle
                    # of a frame). Stop cleanly instead of failing in reshape
                    # and losing every frame encoded so far.
                    y_data = src_pipe.stdout.read(y_data_size)
                    if len(y_data) < y_data_size:
                        break

                    u_data = src_pipe.stdout.read(uv_data_size)
                    v_data = src_pipe.stdout.read(uv_data_size)
                    if len(u_data) < uv_data_size or len(v_data) < uv_data_size:
                        break

                    y = np.frombuffer(y_data, dtype=np.uint8).reshape(1, pic_height, pic_width)
                    uv = np.stack([
                        np.frombuffer(u_data, dtype=np.uint8).reshape(pic_height // 2, pic_width // 2),
                        np.frombuffer(v_data, dtype=np.uint8).reshape(pic_height // 2, pic_width // 2)
                    ])
                    x = np_image_to_tensor(ycbcr420_to_444_np(y, uv), device)
                    x_padded = replicate_pad(x, padding_b, padding_r)

                    is_i_frame = (frame_idx == 0) or \
                                 (args['intra_period'] > 0 and frame_idx % args['intra_period'] == 0)

                    if is_i_frame:
                        curr_qp = args['qp_i']
                        use_ada_i = 0
                        encoded = i_frame_net.compress(x_padded, curr_qp)
                        # An I-frame restarts prediction: drop old references
                        # and seed the buffer with this reconstruction.
                        p_frame_net.clear_dpb()
                        p_frame_net.add_ref_frame(None, encoded['x_hat'])
                    else:
                        use_ada_i = 1 if args['reset_interval'] > 0 and frame_idx % args['reset_interval'] == 1 else 0
                        if use_ada_i:
                            p_frame_net.prepare_feature_adaptor_i(last_qp)
                        curr_qp = p_frame_net.shift_qp(args['qp_p'], qp_shift_pattern[frame_idx % 8])
                        encoded = p_frame_net.compress(x_padded, curr_qp)
                        last_qp = curr_qp

                    sps = {'height': pic_height, 'width': pic_width, 'ec_part': int(use_two_entropy_coders), 'use_ada_i': use_ada_i}
                    sps_id, is_new_sps = sps_helper.get_sps_id(sps)
                    sps['sps_id'] = sps_id
                    # The SPS is only emitted the first time this parameter set is seen.
                    if is_new_sps:
                        write_sps(output_buff, sps)
                    write_ip(output_buff, is_i_frame, sps_id, curr_qp, encoded['bit_stream'])

                    frame_idx += 1
                    pbar.update(1)
    finally:
        # Always release the pipe and reap ffmpeg, even when encoding fails,
        # so a failed job does not leave a decoder process behind.
        src_pipe.stdout.close()
        src_pipe.wait()

    with open(args['bin_path'], "wb") as f:
        f.write(output_buff.getbuffer())


def worker(args):
    """Encode a single input video inside a pool worker process.

    Probes the source dimensions, derives the processing resolution and the
    output file name, then hands over to ``run_encoding`` with the models
    stored in this process's module-level globals.

    Args:
        args: Job dict. Reads ``output_folder``, ``input_folder``, ``seq``,
            ``resolution``, ``qp_i`` and ``qp_p``; fills in ``src_path``,
            ``process_width``, ``process_height`` and ``bin_path``.

    Raises:
        Exception: Any failure is reported on stdout and re-raised so that the
            submitting process sees it through the future.
    """
    bin_folder = args['output_folder']
    create_folder(bin_folder, True)
    args['src_path'] = os.path.join(args['input_folder'], args['seq'])

    try:
        original_width, original_height = extract_video_data(args['src_path'])

        # --- Aspect Ratio Scaling Logic ---
        new_width, new_height = original_width, original_height
        if args['resolution'] is not None and original_width > 0 and original_height > 0:
            target = args['resolution']
            if original_height < original_width:  # Landscape
                new_height = target
                new_width = int(target * original_width / original_height)
            elif original_width < original_height:  # Portrait
                new_width = target
                new_height = int(target * original_height / original_width)
            else:  # Square
                new_width = target
                new_height = target

        # Always round down to even dimensions, not only when rescaling: the
        # frame reader sizes each chroma plane as (w // 2) * (h // 2), which
        # only matches ffmpeg's yuv420p output when both dimensions are even.
        # A (0, 0) probe failure stays (0, 0) and is skipped downstream.
        args['process_width'] = (new_width // 2) * 2
        args['process_height'] = (new_height // 2) * 2
        # --- End of Logic ---

        base_name = f"{os.path.splitext(args['seq'])[0]}_{args['process_width']}x{args['process_height']}_qI{args['qp_i']}_qP{args['qp_p']}"
        args['bin_path'] = os.path.join(bin_folder, f"{base_name}.bin")

        run_encoding(p_frame_net, i_frame_net, args)
    except Exception as e:
        print(f"A critical error occurred while processing {args['seq']}: {e}")
        raise

def init_func(args, gpu_num):
    """Pool initializer: pick a device and load both models into globals.

    Runs once in every worker process. The worker's index is taken from the
    numeric suffix of its process name and mapped round-robin onto either the
    explicit ``--cuda_idx`` list or the range of detected GPUs.

    Args:
        args: Parsed command-line namespace (``cuda_idx``, ``model_path_i``,
            ``model_path_p``, ``force_zero_thres``).
        gpu_num: Number of usable GPUs; 0 selects the CPU.
    """
    # Make child processes ignore the interrupt signal; the main process will handle it
    signal.signal(signal.SIGINT, signal.SIG_IGN)
    set_torch_env()
    process_idx = int(multiprocessing.current_process().name.split('-')[-1]) - 1
    gpu_id = -1
    if gpu_num > 0:
        gpu_id = (args.cuda_idx[process_idx % len(args.cuda_idx)] if args.cuda_idx else process_idx % gpu_num)

    if gpu_id != -1:
        # Restrict this process to its assigned GPU. With the mask in place the
        # chosen GPU is the only visible device and is therefore addressed as
        # cuda:0; using cuda:<gpu_id> would be an invalid ordinal for any
        # gpu_id >= 1.
        # NOTE(review): relies on CUDA not being initialised in this process
        # before this point — confirm set_torch_env() makes no CUDA calls.
        os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)
        device = "cuda:0"
    else:
        device = "cpu"

    global i_frame_net, p_frame_net
    i_frame_net = DMCI().to(device).eval().half()
    i_frame_net.load_state_dict(get_state_dict(args.model_path_i))
    i_frame_net.update(args.force_zero_thres)

    p_frame_net = DMC().to(device).eval().half()
    p_frame_net.load_state_dict(get_state_dict(args.model_path_p))
    p_frame_net.update(args.force_zero_thres)

def main():
    """Encode every not-yet-processed video of the input folder with a process pool.

    A video counts as processed when the output folder already holds a
    ``.bin`` file whose name, minus its last three ``_``-separated fields,
    equals the video's base name. Worker failures are reported but do not
    stop the remaining jobs.
    """
    # Register the signal handler for the main process
    signal.signal(signal.SIGINT, signal_handler)

    args = parse_args()
    worker_num = args.worker
    gpu_num = torch.cuda.device_count() if args.cuda else 0
    if args.cuda and gpu_num == 0:
        print("Warning: --cuda specified but no devices found. Running on CPU.")

    # Collect the base names that already have a bitstream in the output folder.
    processed_basenames = set()
    if os.path.exists(args.output_folder):
        for bin_name in os.listdir(args.output_folder):
            if not bin_name.lower().endswith('.bin'):
                continue
            fields = os.path.splitext(bin_name)[0].split('_')
            if len(fields) > 3:
                processed_basenames.add('_'.join(fields[:-3]))

    video_extensions = ('.mp4', '.mkv', '.avi', '.mov', '.webm')
    all_video_files = [
        name for name in os.listdir(args.input_folder)
        if name.lower().endswith(video_extensions)
    ]
    videos_to_process = [
        name for name in all_video_files
        if os.path.splitext(name)[0] not in processed_basenames
    ]

    print(f"Found {len(all_video_files)} total videos in '{args.input_folder}'.")
    skipped_count = len(all_video_files) - len(videos_to_process)
    if skipped_count > 0:
        print(f"Skipping {skipped_count} videos that already have a corresponding '.bin' file in the output folder.")

    if not videos_to_process:
        print("All videos are already processed. Exiting.")
        return

    multiprocessing.set_start_method("spawn", force=True)
    try:
        pool = concurrent.futures.ProcessPoolExecutor(
            max_workers=worker_num, initializer=init_func, initargs=(args, gpu_num)
        )
        with pool as executor:
            # One job dict per video: the CLI options plus the per-video fields.
            futures = [
                executor.submit(
                    worker,
                    {**vars(args), 'seq': seq, 'intra_period': args.force_intra_period},
                )
                for seq in videos_to_process
            ]

            completed = concurrent.futures.as_completed(futures)
            for future in tqdm(completed, total=len(futures), desc="Encoding videos"):
                try:
                    future.result()
                except Exception as e:
                    print(f"\nA worker failed with an exception: {e}")

    except KeyboardInterrupt:
        print("\nMain loop interrupted. The signal handler will now exit.")

    print(f'Encoding finished for {len(videos_to_process)} sequences.')

# Script entry point. The guard matters here because main() uses the "spawn"
# start method: spawned workers re-import this module and must not re-run main().
if __name__ == "__main__":
    main()
# decode.py
import argparse
import concurrent.futures
import io
import json
import multiprocessing
import os
import subprocess
import struct

import torch
import numpy as np
from tqdm import tqdm

from src.models.video_model import DMC
from src.models.image_model import DMCI
from src.utils.common import str2bool, create_folder, get_state_dict, set_torch_env
from src.utils.stream_helper import SPSHelper, NalType, read_header, read_sps_remaining, read_ip_remaining
from src.utils.video_writer import YUV420Writer
from src.utils.transforms import yuv_444_to_420

def parse_args():
    """Build the decoder's command-line interface and parse the process arguments.

    Returns:
        argparse.Namespace with the checkpoint paths, the bitstream input
        folder, the YUV output folder, the optional folder of original videos
        and the worker/GPU/zero-threshold settings.
    """
    cli = argparse.ArgumentParser(description="Video Decoding Script")
    add = cli.add_argument

    # Model checkpoints.
    add('--model_path_i', type=str, default="./checkpoints/cvpr2025_image.pth.tar")
    add('--model_path_p', type=str, default="./checkpoints/cvpr2025_video.pth.tar")

    # Input / output locations.
    add("--input_folder", type=str, required=True, help="Folder with .bin files.")
    add("--output_folder", type=str, required=True, help="Folder for decoded .yuv files.")
    add(
        "--original_folder",
        type=str,
        default=None,
        help="[Optional] Folder with original videos for bitrate/frame count.",
    )

    # Parallelism and device selection.
    add("--worker", "-w", type=int, default=1)
    add("--cuda", type=str2bool, default=True)
    add('--cuda_idx', type=int, nargs="+")

    add('--force_zero_thres', type=float, default=None)

    return cli.parse_args()

def extract_video_data(video_path):
    """Return ``(frame_rate, frame_count)`` of the first video stream, using ffprobe.

    The frame count comes from the container's ``nb_frames`` field; when that
    is missing, zero or not a number, ffprobe is run a second time with
    ``-count_frames`` to count the frames.

    Args:
        video_path: Path of the video file to probe.

    Returns:
        A tuple ``(frame_rate, frame_count)``. ``frame_rate`` is a float in
        frames per second, or 0.0 when ffprobe reports an unusable rate.
        ``frame_count`` is an int, or 0 when no count could be obtained.

    Raises:
        subprocess.CalledProcessError: ffprobe exited with a non-zero status.
        OSError: ffprobe could not be started.
        KeyError, IndexError: The ffprobe output contains no video stream.
    """
    command = [
        'ffprobe', '-v', 'error', '-select_streams', 'v:0',
        '-show_entries', 'stream=r_frame_rate,nb_frames', '-of', 'json', video_path
    ]
    result = subprocess.run(command, capture_output=True, text=True, check=True)
    data = json.loads(result.stdout)['streams'][0]

    # r_frame_rate is a "num/den" string derived from the media file. Parse it
    # arithmetically rather than with eval(): eval() would execute arbitrary
    # text taken from an external file and raises ZeroDivisionError on "0/0".
    num_text, _, den_text = str(data.get('r_frame_rate', '30/1')).partition('/')
    try:
        numerator = float(num_text)
        denominator = float(den_text) if den_text else 1.0
        frame_rate = numerator / denominator if denominator else 0.0
    except ValueError:
        frame_rate = 0.0

    # nb_frames can be absent or a non-numeric placeholder; treat both as unknown.
    try:
        frame_count = int(data.get('nb_frames', 0))
    except (TypeError, ValueError):
        frame_count = 0

    if frame_count == 0:
        count_command = ['ffprobe', '-v', 'error', '-count_frames', '-select_streams', 'v:0', '-show_entries', 'stream=nb_read_frames', '-of', 'default=nokey=1:noprint_wrappers=1', video_path]
        count_result = subprocess.run(count_command, capture_output=True, text=True, check=True)
        try:
            frame_count = int(count_result.stdout.strip())
        except ValueError:
            frame_count = 0
    return frame_rate, frame_count

def run_decoding(p_frame_net, i_frame_net, args):
    """Decode one bitstream file into a raw YUV 4:2:0 file.

    The picture size is taken from the first SPS found in the stream. When the
    original video is available its frame count bounds the decoding loop and
    its duration is used to put the bitrate into the output file name;
    otherwise decoding simply runs until the bitstream is exhausted.

    Args:
        p_frame_net: Model used to decompress non-I frames.
        i_frame_net: Model used to decompress I frames.
        args: Job dict. Reads ``bin_path``, ``base_name``, ``output_folder``
            and the optional ``original_path``.
    """
    sps_helper = SPSHelper()

    with open(args['bin_path'], "rb") as f:
        input_buff = io.BytesIO(f.read())
        bin_size = f.tell()  # position after reading everything == file size

    # Scan forward to the first SPS to learn the picture dimensions.
    header = read_header(input_buff)
    while header['nal_type'] != NalType.NAL_SPS:
        _, _ = read_ip_remaining(input_buff)
        header = read_header(input_buff)
    sps = read_sps_remaining(input_buff, header['sps_id'])
    pic_height, pic_width = sps['height'], sps['width']

    # Calculate kbps and frame number if original video is provided
    total_kbps = 0
    frame_num = 0
    original_path = args.get('original_path')

    if original_path and os.path.exists(original_path):
        try:
            frame_rate, frame_num = extract_video_data(original_path)
            if frame_num > 0 and frame_rate > 0:
                total_kbps = int(bin_size * 8 / (frame_num / frame_rate) / 1000)
        except (
            subprocess.CalledProcessError,  # ffprobe failed
            OSError,                        # ffprobe could not be started
            ValueError,                     # invalid JSON / non-numeric field
            KeyError,
            IndexError,                     # no video stream in the output
            ZeroDivisionError,              # degenerate frame rate
        ) as e:
            # The original video is optional metadata: fall back, do not abort.
            print(f"Warning: Could not process original video {original_path}: {e}. Will decode until end of stream.")
            frame_num = 0 # Reset frame_num to trigger fallback

    if frame_num == 0:
        if original_path:
             # This message is for when the file was provided but failed to be read
            print(f"Warning: Failed to get frame count for {args['base_name']}.")
        else:
            # This message is for when the original folder was not provided at all
            print(f"Info: Original video not provided for {args['base_name']}.")
        print("Decoding will proceed until the end of the bitstream.")
        frame_num = 1_000_000 # Use a large number and let the loop break on EOF

    output_yuv_name = f"{args['base_name']}_{total_kbps}kbps.yuv" if total_kbps > 0 else f"{args['base_name']}.yuv"
    output_yuv_path = os.path.join(args['output_folder'], output_yuv_name)
    recon_writer = YUV420Writer(output_yuv_path, pic_width, pic_height)

    try:
        p_frame_net.set_curr_poc(0)

        # Restart from the beginning so the SPS units are registered in order.
        input_buff.seek(0)

        # Inference only: disable autograd bookkeeping, as the encoder does.
        with torch.no_grad():
            for _ in tqdm(range(frame_num), desc=f"Decoding {args['base_name']}", leave=False):
                try:
                    header = read_header(input_buff)
                    while header['nal_type'] == NalType.NAL_SPS:
                        sps = read_sps_remaining(input_buff, header['sps_id'])
                        sps_helper.add_sps_by_id(sps)
                        header = read_header(input_buff)

                    sps = sps_helper.get_sps_by_id(header['sps_id'])
                    qp, bit_stream = read_ip_remaining(input_buff)

                    if header['nal_type'] == NalType.NAL_I:
                        decoded = i_frame_net.decompress(bit_stream, sps, qp)
                        p_frame_net.clear_dpb()
                        p_frame_net.add_ref_frame(None, decoded['x_hat'])
                    else:
                        if sps['use_ada_i']:
                            p_frame_net.reset_ref_feature()
                        decoded = p_frame_net.decompress(bit_stream, sps, qp)

                    # Crop away the padding that was added before encoding.
                    x_hat = decoded['x_hat'][:, :, :pic_height, :pic_width]

                    y_rec, uv_rec = yuv_444_to_420(x_hat)
                    # Round to the nearest code value; a bare .byte() cast
                    # truncates and would bias every sample downwards.
                    y_rec_np = torch.clamp(y_rec * 255, 0, 255).round().byte().squeeze().cpu().numpy()
                    uv_rec_np = torch.clamp(uv_rec * 255, 0, 255).round().byte().squeeze().cpu().numpy()
                    recon_writer.write_one_frame(y_rec_np, uv_rec_np)
                except (struct.error, IndexError): # Catches errors when reading past the end of the buffer
                    break
    finally:
        # Close the output file even when decoding fails part-way through.
        recon_writer.close()

def worker(args):
    """Decode one bitstream inside a pool worker process.

    ``p_frame_net`` and ``i_frame_net`` are the module-level globals that
    ``init_func`` assigns when the worker process starts.

    Args:
        args: Per-file job dict, forwarded unchanged to ``run_decoding``.
    """
    # original_path is now prepared in main() and passed within args
    run_decoding(p_frame_net, i_frame_net, args)

def init_func(args, gpu_num):
    """Pool initializer: pick a device and load both models into globals.

    Runs once in every worker process. The worker's index is taken from the
    numeric suffix of its process name and mapped round-robin onto either the
    explicit ``--cuda_idx`` list or the range of detected GPUs.

    Args:
        args: Parsed command-line namespace (``cuda_idx``, ``model_path_i``,
            ``model_path_p``, ``force_zero_thres``).
        gpu_num: Number of usable GPUs; 0 selects the CPU.
    """
    set_torch_env()
    process_idx = int(multiprocessing.current_process().name.split('-')[-1]) - 1
    gpu_id = -1
    if gpu_num > 0:
        gpu_id = (args.cuda_idx[process_idx % len(args.cuda_idx)] if args.cuda_idx else process_idx % gpu_num)

    if gpu_id != -1:
        # Restrict this process to its assigned GPU. With the mask in place the
        # chosen GPU is the only visible device and is therefore addressed as
        # cuda:0; using cuda:<gpu_id> would be an invalid ordinal for any
        # gpu_id >= 1.
        # NOTE(review): relies on CUDA not being initialised in this process
        # before this point — confirm set_torch_env() makes no CUDA calls.
        os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)
        device = "cuda:0"
    else:
        device = "cpu"

    global i_frame_net, p_frame_net
    i_frame_net = DMCI().to(device).eval().half()
    i_frame_net.load_state_dict(get_state_dict(args.model_path_i))
    i_frame_net.update(args.force_zero_thres)

    p_frame_net = DMC().to(device).eval().half()
    p_frame_net.load_state_dict(get_state_dict(args.model_path_p))
    p_frame_net.update(args.force_zero_thres)

def _strip_resolution_suffix(name):
    """Return ``name`` without a trailing ``_<width>x<height>`` tag, if it has one."""
    head, sep, tail = name.rpartition('_')
    width, cross, height = tail.partition('x')
    if sep and cross and width.isdigit() and height.isdigit():
        return head
    return name


def main():
    """Decode every ``.bin`` file of the input folder with a process pool.

    When ``--original_folder`` is given, each bitstream is paired with the
    original video of the same base name so that the decoder can derive the
    frame count and bitrate. Worker failures are reported but do not stop the
    remaining jobs.
    """
    args = parse_args()
    create_folder(args.output_folder)
    worker_num = args.worker
    gpu_num = torch.cuda.device_count() if args.cuda else 0
    if gpu_num == 0 and args.cuda:
        print("Warning: --cuda specified but no devices found. Running on CPU.")

    bin_files = [f for f in os.listdir(args.input_folder) if f.lower().endswith('.bin')]

    # Lower-cased base name -> actual file name of each original video.
    original_videos = {}
    if args.original_folder:
        try:
            original_videos = {os.path.splitext(f)[0].lower(): f for f in os.listdir(args.original_folder)}
        except FileNotFoundError:
            print(f"Warning: Original folder '{args.original_folder}' not found. Continuing without it.")
            args.original_folder = None # Ensure it is treated as not provided

    multiprocessing.set_start_method("spawn", force=True)
    with concurrent.futures.ProcessPoolExecutor(
        max_workers=worker_num, initializer=init_func, initargs=(args, gpu_num)
    ) as executor:
        futures = []
        for bin_file in bin_files:
            base_name = os.path.splitext(bin_file)[0]

            cur_args = vars(args).copy()
            cur_args['bin_path'] = os.path.join(args.input_folder, bin_file)
            cur_args['base_name'] = base_name
            cur_args['original_path'] = None # Default to None

            if args.original_folder:
                # Bitstream names look like "<video>_<W>x<H>_qI<..>_qP<..>".
                # Cut the QP part first, then the resolution tag; without the
                # second step the key can never equal an original base name.
                tagged_name = base_name.rsplit('_qI', 1)[0].lower()
                plain_name = _strip_resolution_suffix(tagged_name)
                # Prefer the tag-free name, still accept the tagged one, and
                # assume .mp4 as a fallback extension when nothing matches.
                original_seq_name = (
                    original_videos.get(plain_name)
                    or original_videos.get(tagged_name)
                    or f"{plain_name}.mp4"
                )
                cur_args['original_path'] = os.path.join(args.original_folder, original_seq_name)

            futures.append(executor.submit(worker, cur_args))

        for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="Decoding videos"):
            try:
                future.result()
            except Exception as e:
                print(f"A worker failed with an exception: {e}")

    print(f'Decoding finished for {len(futures)} files.')

# Script entry point. The guard matters here because main() uses the "spawn"
# start method: spawned workers re-import this module and must not re-run main().
if __name__ == "__main__":
    main()

Edit 1: The paper mentions a lookup table for the nonlinear functions; however, I can't find it in the code. Is it available anywhere?

Could you please tell me whether there is anything wrong in my scripts regarding how the model should be used?
Thank you for this great work!

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions