I created two scripts that enable FFmpeg pipe encoding and decoding directly from the original video source, instead of first dumping it to a YUV file, which can take up terabytes of storage for some videos and makes that workflow impractical.
Testing on one computer with an RTX 3050 Mobile 4GB and on another with an RTX 4070, the videos encoded on each computer decode correctly on the same machine where the .bin file was created, but when trying to decode them on the other machine, the decoded frames are pure noise.
It's also worth noting that, for some videos, the RTX 4070 failed to encode with CUDA inference enabled, but works when it is set to false (inside src/layers/cuda_inference.py); the same was not observed on the 3050. Videos encoded with the PyTorch backend also show the cross-machine problem described above.
This raises the question of whether the .bin file produced on each computer is different, which could be related to the GPU architecture or to some deeper configuration of the software running it (including my scripts). For reference, both computers are using:
NVIDIA-SMI 575.64.03 Driver Version: 575.64.03 CUDA Version: 12.9
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2025 NVIDIA Corporation
Built on Tue_May_27_02:21:03_PDT_2025
Cuda compilation tools, release 12.9, V12.9.86
Build cuda_12.9.r12.9/compiler.36037853_0
print(torch.__version__)
2.7.1+cu126
ffmpeg version 6.1.1-3ubuntu5 Copyright (c) 2000-2023 the FFmpeg developers
built with gcc 13 (Ubuntu 13.2.0-23ubuntu3)
On both machines I compiled the libraries as described in the README.
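To confirm whether the bitstreams actually diverge, the .bin files from both machines can be hashed and compared byte-for-byte; a minimal sketch of that check (the file paths are just hypothetical examples, not my actual outputs):

import hashlib

def sha256_of(path, chunk_size=1 << 20):
    # Stream the file in chunks so large bitstreams don't need to fit in memory.
    h = hashlib.sha256()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            h.update(chunk)
    return h.hexdigest()

# Hypothetical paths: the same video encoded on each machine with identical QPs.
h_3050 = sha256_of('out_3050/video_1920x1080_qI32_qP32.bin')
h_4070 = sha256_of('out_4070/video_1920x1080_qI32_qP32.bin')
print('identical' if h_3050 == h_4070 else 'bitstreams differ')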
My code base structure looks like this:
.
├── checkpoints
│ ├── cvpr2025_image.pth.tar
│ └── cvpr2025_video.pth.tar
├── decode.py
├── encode.py
├── README.md
├── src
│ ├── cpp
│ │ ├── py_rans
│ │ │ ├── py_rans.cpp
│ │ │ ├── py_rans.h
│ │ │ ├── rans_byte.h
│ │ │ ├── rans.cpp
│ │ │ └── rans.h
│ │ └── setup.py
│ ├── layers
│ │ ├── cuda_inference.py
│ │ ├── extensions
│ │ │ └── inference
│ │ │ ├── bind.cpp
│ │ │ ├── common.h
│ │ │ ├── def.h
│ │ │ ├── impl.cpp
│ │ │ ├── kernel.cu
│ │ │ └── setup.py
│ │ └── layers.py
│ ├── models
│ │ ├── common_model.py
│ │ ├── entropy_models.py
│ │ ├── image_model.py
│ │ └── video_model.py
│ └── utils
│ ├── common.py
│ ├── metrics.py
│ ├── stream_helper.py
│ ├── transforms.py
│ ├── video_reader.py
│ └── video_writer.py
└── test_conditions.md
10 directories, 30 files
I used the root DCVC code base from this git repo, which means I'm not using the files inside DCVC-family; I'm using the DCVC-RT model files. Below I append both encode.py and decode.py; all the other code remains unmodified.
# encode.py
import argparse
import concurrent.futures
import io
import json
import logging
import multiprocessing
import os
import signal
import sys
import time
import subprocess
import torch
import numpy as np
from tqdm import tqdm
from src.layers.cuda_inference import replicate_pad
from src.models.video_model import DMC
from src.models.image_model import DMCI
from src.utils.common import str2bool, create_folder, get_state_dict, set_torch_env
from src.utils.stream_helper import SPSHelper, write_sps, write_ip
from src.utils.transforms import ycbcr420_to_444_np
def signal_handler(sig, frame):
"""Handles Ctrl+C, terminating all child processes."""
print('\nCtrl+C detected! Shutting down all processes...')
# Terminate all active child processes created by multiprocessing
for p in multiprocessing.active_children():
p.terminate()
p.join() # Wait for termination to complete
sys.exit(1)
def extract_video_data(video_path):
"""Extracts video width and height using ffprobe."""
command = [
'ffprobe', '-v', 'error', '-select_streams', 'v:0',
'-show_entries', 'stream=width,height', '-of', 'json', video_path
]
try:
result = subprocess.run(command, capture_output=True, text=True, check=True)
data = json.loads(result.stdout)['streams'][0]
width = int(data['width'])
height = int(data['height'])
return width, height
except (subprocess.CalledProcessError, ValueError, KeyError, IndexError) as e:
print(f"FATAL: Could not extract video dimensions for {os.path.basename(video_path)}. Error: {e}")
return 0, 0
def parse_args():
parser = argparse.ArgumentParser(description="Video Encoding Script")
parser.add_argument('--model_path_i', type=str, default="./checkpoints/cvpr2025_image.pth.tar")
parser.add_argument('--model_path_p', type=str, default="./checkpoints/cvpr2025_video.pth.tar")
parser.add_argument('--qp_i', type=int, required=True, help="QP for I-frames.")
parser.add_argument('--qp_p', type=int, required=True, help="QP for P-frames.")
parser.add_argument("--input_folder", type=str, required=True)
parser.add_argument("--output_folder", type=str, required=True)
parser.add_argument("-res", "--resolution", type=int, default=None,
help="Scale the video's smallest dimension to this resolution, preserving aspect ratio.")
parser.add_argument("--worker", "-w", type=int, default=1)
parser.add_argument("--cuda", type=str2bool, default=True)
parser.add_argument('--cuda_idx', type=int, nargs="+")
parser.add_argument("--force_intra_period", type=int, default=-1)
parser.add_argument('--reset_interval', type=int, default=64)
parser.add_argument('--force_zero_thres', type=float, default=None)
return parser.parse_args()
def get_src_reader(video_path, target_width=None, target_height=None):
command = ['ffmpeg', '-i', video_path]
if target_width is not None and target_height is not None:
command.extend(['-vf', f'scale={target_width}:{target_height}'])
command.extend(['-f', 'rawvideo', '-pix_fmt', 'yuv420p', '-'])
return subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL)
def np_image_to_tensor(img, device):
image = torch.from_numpy(img).to(device=device, dtype=torch.float32) / 255.0
return image.unsqueeze(0).half()
def run_encoding(p_frame_net, i_frame_net, args):
device = next(i_frame_net.parameters()).device
pic_width = args['process_width']
pic_height = args['process_height']
seq_name = os.path.basename(args['src_path'])
if pic_width == 0 or pic_height == 0:
print(f"Skipping {seq_name} due to earlier error in reading video dimensions.")
return
padding_r, padding_b = DMCI.get_padding_size(pic_height, pic_width, 16)
src_pipe = get_src_reader(args['src_path'], pic_width, pic_height)
use_two_entropy_coders = pic_height * pic_width > 1280 * 720
i_frame_net.set_use_two_entropy_coders(use_two_entropy_coders)
p_frame_net.set_use_two_entropy_coders(use_two_entropy_coders)
output_buff = io.BytesIO()
sps_helper = SPSHelper()
p_frame_net.set_curr_poc(0)
with torch.no_grad():
last_qp = 0
frame_idx = 0
with tqdm(desc=f"Encoding {seq_name}", leave=False, unit="fr") as pbar:
while True:
y_data_size = pic_width * pic_height
uv_data_size = (pic_width // 2) * (pic_height // 2)
y_data = src_pipe.stdout.read(y_data_size)
if not y_data:
break
u_data = src_pipe.stdout.read(uv_data_size)
v_data = src_pipe.stdout.read(uv_data_size)
if not u_data or not v_data:
break
y = np.frombuffer(y_data, dtype=np.uint8).reshape(1, pic_height, pic_width)
uv = np.stack([
np.frombuffer(u_data, dtype=np.uint8).reshape(pic_height // 2, pic_width // 2),
np.frombuffer(v_data, dtype=np.uint8).reshape(pic_height // 2, pic_width // 2)
])
x = np_image_to_tensor(ycbcr420_to_444_np(y, uv), device)
x_padded = replicate_pad(x, padding_b, padding_r)
is_i_frame = (frame_idx == 0) or \
(args['intra_period'] > 0 and frame_idx % args['intra_period'] == 0)
if is_i_frame:
curr_qp = args['qp_i']
use_ada_i = 0
encoded = i_frame_net.compress(x_padded, curr_qp)
p_frame_net.clear_dpb()
p_frame_net.add_ref_frame(None, encoded['x_hat'])
else:
use_ada_i = 1 if args['reset_interval'] > 0 and frame_idx % args['reset_interval'] == 1 else 0
if use_ada_i: p_frame_net.prepare_feature_adaptor_i(last_qp)
curr_qp = p_frame_net.shift_qp(args['qp_p'], [0, 1, 0, 2, 0, 2, 0, 2][frame_idx % 8])
encoded = p_frame_net.compress(x_padded, curr_qp)
last_qp = curr_qp
sps = {'height': pic_height, 'width': pic_width, 'ec_part': int(use_two_entropy_coders), 'use_ada_i': use_ada_i}
sps_id, is_new_sps = sps_helper.get_sps_id(sps)
sps['sps_id'] = sps_id
if is_new_sps: write_sps(output_buff, sps)
write_ip(output_buff, is_i_frame, sps_id, curr_qp, encoded['bit_stream'])
frame_idx += 1
pbar.update(1)
src_pipe.stdout.close()
src_pipe.wait()
with open(args['bin_path'], "wb") as f:
f.write(output_buff.getbuffer())
def worker(args):
bin_folder = args['output_folder']
create_folder(bin_folder, True)
args['src_path'] = os.path.join(args['input_folder'], args['seq'])
try:
original_width, original_height = extract_video_data(args['src_path'])
# --- Aspect Ratio Scaling Logic ---
if args['resolution'] is not None and original_width > 0 and original_height > 0:
R = args['resolution']
if original_height < original_width: # Landscape
new_height = R
new_width = int(R * original_width / original_height)
elif original_width < original_height: # Portrait
new_width = R
new_height = int(R * original_height / original_width)
else: # Square
new_width = R
new_height = R
# Ensure final dimensions are even for compatibility
args['process_width'] = (new_width // 2) * 2
args['process_height'] = (new_height // 2) * 2
else:
args['process_width'] = original_width
args['process_height'] = original_height
# --- End of Logic ---
base_name = f"{os.path.splitext(args['seq'])[0]}_{args['process_width']}x{args['process_height']}_qI{args['qp_i']}_qP{args['qp_p']}"
args['bin_path'] = os.path.join(bin_folder, f"{base_name}.bin")
run_encoding(p_frame_net, i_frame_net, args)
except Exception as e:
print(f"A critical error occurred while processing {args['seq']}: {e}")
raise
def init_func(args, gpu_num):
# Make child processes ignore the interrupt signal; the main process will handle it
signal.signal(signal.SIGINT, signal.SIG_IGN)
set_torch_env()
process_idx = int(multiprocessing.current_process().name.split('-')[-1]) - 1
gpu_id = -1
if gpu_num > 0:
gpu_id = (args.cuda_idx[process_idx % len(args.cuda_idx)] if args.cuda_idx else process_idx % gpu_num)
device = f"cuda:{gpu_id}" if gpu_id != -1 else "cpu"
if gpu_id != -1: os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)
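    # Note: CUDA_VISIBLE_DEVICES is set above *after* the "cuda:{gpu_id}" device
    # string was built; once visibility is remapped, the only visible device is
    # "cuda:0", so this pairing only behaves as intended when gpu_id is 0
    # (as on these single-GPU test machines).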
global i_frame_net, p_frame_net
i_frame_net = DMCI().to(device).eval().half()
i_frame_net.load_state_dict(get_state_dict(args.model_path_i))
i_frame_net.update(args.force_zero_thres)
p_frame_net = DMC().to(device).eval().half()
p_frame_net.load_state_dict(get_state_dict(args.model_path_p))
p_frame_net.update(args.force_zero_thres)
def main():
# Register the signal handler for the main process
signal.signal(signal.SIGINT, signal_handler)
args = parse_args()
worker_num = args.worker
gpu_num = torch.cuda.device_count() if args.cuda else 0
if gpu_num == 0 and args.cuda:
print("Warning: --cuda specified but no devices found. Running on CPU.")
processed_basenames = set()
if os.path.exists(args.output_folder):
for f in os.listdir(args.output_folder):
if f.lower().endswith('.bin'):
parts = os.path.splitext(f)[0].split('_')
if len(parts) > 3:
basename = '_'.join(parts[:-3])
processed_basenames.add(basename)
all_video_files = [f for f in os.listdir(args.input_folder) if f.lower().endswith(('.mp4', '.mkv', '.avi', '.mov', '.webm'))]
videos_to_process = []
for video_file in all_video_files:
input_basename = os.path.splitext(video_file)[0]
if input_basename not in processed_basenames:
videos_to_process.append(video_file)
print(f"Found {len(all_video_files)} total videos in '{args.input_folder}'.")
skipped_count = len(all_video_files) - len(videos_to_process)
if skipped_count > 0:
print(f"Skipping {skipped_count} videos that already have a corresponding '.bin' file in the output folder.")
if not videos_to_process:
print("All videos are already processed. Exiting.")
return
multiprocessing.set_start_method("spawn", force=True)
try:
with concurrent.futures.ProcessPoolExecutor(
max_workers=worker_num, initializer=init_func, initargs=(args, gpu_num)
) as executor:
futures = []
for seq in videos_to_process:
cur_args = vars(args).copy()
cur_args['seq'] = seq
cur_args['intra_period'] = args.force_intra_period
futures.append(executor.submit(worker, cur_args))
for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="Encoding videos"):
try:
future.result()
except Exception as e:
print(f"\nA worker failed with an exception: {e}")
except KeyboardInterrupt:
print("\nMain loop interrupted. The signal handler will now exit.")
pass
print(f'Encoding finished for {len(videos_to_process)} sequences.')
if __name__ == "__main__":
main()
# decode.py
import argparse
import concurrent.futures
import io
import json
import multiprocessing
import os
import subprocess
import struct
import torch
import numpy as np
from tqdm import tqdm
from src.models.video_model import DMC
from src.models.image_model import DMCI
from src.utils.common import str2bool, create_folder, get_state_dict, set_torch_env
from src.utils.stream_helper import SPSHelper, NalType, read_header, read_sps_remaining, read_ip_remaining
from src.utils.video_writer import YUV420Writer
from src.utils.transforms import yuv_444_to_420
def parse_args():
parser = argparse.ArgumentParser(description="Video Decoding Script")
parser.add_argument('--model_path_i', type=str, default="./checkpoints/cvpr2025_image.pth.tar")
parser.add_argument('--model_path_p', type=str, default="./checkpoints/cvpr2025_video.pth.tar")
parser.add_argument("--input_folder", type=str, required=True, help="Folder with .bin files.")
parser.add_argument("--output_folder", type=str, required=True, help="Folder for decoded .yuv files.")
parser.add_argument("--original_folder", type=str, default=None, help="[Optional] Folder with original videos for bitrate/frame count.")
parser.add_argument("--worker", "-w", type=int, default=1)
parser.add_argument("--cuda", type=str2bool, default=True)
parser.add_argument('--cuda_idx', type=int, nargs="+")
parser.add_argument('--force_zero_thres', type=float, default=None)
return parser.parse_args()
def extract_video_data(video_path):
"""Extracts frame rate and frame count from a video file."""
command = [
'ffprobe', '-v', 'error', '-select_streams', 'v:0',
'-show_entries', 'stream=r_frame_rate,nb_frames', '-of', 'json', video_path
]
result = subprocess.run(command, capture_output=True, text=True, check=True)
data = json.loads(result.stdout)['streams'][0]
    rate_str = data.get('r_frame_rate', '30/1')  # e.g. "30000/1001"
    num, _, den = rate_str.partition('/')
    frame_rate = float(num) / float(den or '1')  # safer than eval() on ffprobe output
frame_count = int(data.get('nb_frames', 0))
if frame_count == 0:
count_command = ['ffprobe', '-v', 'error', '-count_frames', '-select_streams', 'v:0', '-show_entries', 'stream=nb_read_frames', '-of', 'default=nokey=1:noprint_wrappers=1', video_path]
count_result = subprocess.run(count_command, capture_output=True, text=True, check=True)
frame_count = int(count_result.stdout.strip())
return frame_rate, frame_count
def run_decoding(p_frame_net, i_frame_net, args):
sps_helper = SPSHelper()
with open(args['bin_path'], "rb") as f:
input_buff = io.BytesIO(f.read())
bin_size = f.tell()
header = read_header(input_buff)
while header['nal_type'] != NalType.NAL_SPS:
_, _ = read_ip_remaining(input_buff)
header = read_header(input_buff)
sps = read_sps_remaining(input_buff, header['sps_id'])
pic_height, pic_width = sps['height'], sps['width']
# Calculate kbps and frame number if original video is provided
total_kbps = 0
frame_num = 0
original_path = args.get('original_path')
if original_path and os.path.exists(original_path):
try:
frame_rate, frame_num = extract_video_data(original_path)
if frame_num > 0 and frame_rate > 0:
total_kbps = int(bin_size * 8 / (frame_num / frame_rate) / 1000)
except (subprocess.CalledProcessError, json.JSONDecodeError, IndexError) as e:
print(f"Warning: Could not process original video {original_path}: {e}. Will decode until end of stream.")
frame_num = 0 # Reset frame_num to trigger fallback
if frame_num == 0:
if original_path:
# This message is for when the file was provided but failed to be read
print(f"Warning: Failed to get frame count for {args['base_name']}.")
else:
# This message is for when the original folder was not provided at all
print(f"Info: Original video not provided for {args['base_name']}.")
print("Decoding will proceed until the end of the bitstream.")
frame_num = 1_000_000 # Use a large number and let the loop break on EOF
output_yuv_name = f"{args['base_name']}_{total_kbps}kbps.yuv" if total_kbps > 0 else f"{args['base_name']}.yuv"
output_yuv_path = os.path.join(args['output_folder'], output_yuv_name)
recon_writer = YUV420Writer(output_yuv_path, pic_width, pic_height)
p_frame_net.set_curr_poc(0)
input_buff.seek(0)
for _ in tqdm(range(frame_num), desc=f"Decoding {args['base_name']}", leave=False):
try:
header = read_header(input_buff)
while header['nal_type'] == NalType.NAL_SPS:
sps = read_sps_remaining(input_buff, header['sps_id'])
sps_helper.add_sps_by_id(sps)
header = read_header(input_buff)
sps = sps_helper.get_sps_by_id(header['sps_id'])
qp, bit_stream = read_ip_remaining(input_buff)
if header['nal_type'] == NalType.NAL_I:
decoded = i_frame_net.decompress(bit_stream, sps, qp)
p_frame_net.clear_dpb()
p_frame_net.add_ref_frame(None, decoded['x_hat'])
else:
if sps['use_ada_i']: p_frame_net.reset_ref_feature()
decoded = p_frame_net.decompress(bit_stream, sps, qp)
x_hat = decoded['x_hat'][:, :, :pic_height, :pic_width]
y_rec, uv_rec = yuv_444_to_420(x_hat)
y_rec_np = torch.clamp(y_rec * 255, 0, 255).byte().squeeze().cpu().numpy()
uv_rec_np = torch.clamp(uv_rec * 255, 0, 255).byte().squeeze().cpu().numpy()
recon_writer.write_one_frame(y_rec_np, uv_rec_np)
except (struct.error, IndexError): # Catches errors when reading past the end of the buffer
break
recon_writer.close()
def worker(args):
# original_path is now prepared in main() and passed within args
run_decoding(p_frame_net, i_frame_net, args)
def init_func(args, gpu_num):
set_torch_env()
process_idx = int(multiprocessing.current_process().name.split('-')[-1]) - 1
gpu_id = -1
if gpu_num > 0:
gpu_id = (args.cuda_idx[process_idx % len(args.cuda_idx)] if args.cuda_idx else process_idx % gpu_num)
device = f"cuda:{gpu_id}" if gpu_id != -1 else "cpu"
if gpu_id != -1: os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)
global i_frame_net, p_frame_net
i_frame_net = DMCI().to(device).eval().half()
i_frame_net.load_state_dict(get_state_dict(args.model_path_i))
i_frame_net.update(args.force_zero_thres)
p_frame_net = DMC().to(device).eval().half()
p_frame_net.load_state_dict(get_state_dict(args.model_path_p))
p_frame_net.update(args.force_zero_thres)
def main():
args = parse_args()
create_folder(args.output_folder)
worker_num = args.worker
gpu_num = torch.cuda.device_count() if args.cuda else 0
if gpu_num == 0 and args.cuda:
print("Warning: --cuda specified but no devices found. Running on CPU.")
bin_files = [f for f in os.listdir(args.input_folder) if f.lower().endswith('.bin')]
original_videos = {}
if args.original_folder:
try:
original_videos = {os.path.splitext(f)[0].lower(): f for f in os.listdir(args.original_folder)}
except FileNotFoundError:
print(f"Warning: Original folder '{args.original_folder}' not found. Continuing without it.")
args.original_folder = None # Ensure it is treated as not provided
multiprocessing.set_start_method("spawn", force=True)
with concurrent.futures.ProcessPoolExecutor(
max_workers=worker_num, initializer=init_func, initargs=(args, gpu_num)
) as executor:
futures = []
for bin_file in bin_files:
base_name = os.path.splitext(bin_file)[0]
cur_args = vars(args).copy()
cur_args['bin_path'] = os.path.join(args.input_folder, bin_file)
cur_args['base_name'] = base_name
cur_args['original_path'] = None # Default to None
if args.original_folder:
                # bin names are "{name}_{W}x{H}_qI{..}_qP{..}"; drop the last three
                # underscore-separated parts to recover the original base name
                original_base_name = '_'.join(base_name.split('_')[:-3]).lower()
# Find the corresponding original video, assuming .mp4 as a fallback extension
original_seq_name = original_videos.get(original_base_name, f"{original_base_name}.mp4")
cur_args['original_path'] = os.path.join(args.original_folder, original_seq_name)
futures.append(executor.submit(worker, cur_args))
for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="Decoding videos"):
try:
future.result()
except Exception as e:
print(f"A worker failed with an exception: {e}")
print(f'Decoding finished for {len(futures)} files.')
if __name__ == "__main__":
main()
Edit 1: The paper mentions a lookup table for the non-linear functions; however, I can't find it in the code. Is it available anywhere?
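From my understanding, such a lookup table would quantize the input range of the non-linear function into integer indices so that every device computes bit-identical values; a minimal sketch of that idea (my own illustration with assumed parameters N, X_MIN, X_MAX, not code from this repo):

import torch

N = 4096                     # table resolution (assumption)
X_MIN, X_MAX = -8.0, 8.0     # clamp range (assumption)
grid = torch.linspace(X_MIN, X_MAX, N)
TABLE = torch.sigmoid(grid)  # precomputed once, in float32

def sigmoid_lut(x):
    # Map x to an integer index; every device then performs the same table lookup,
    # so the values feeding the entropy coder are bit-identical across GPUs.
    idx = (x.float().clamp(X_MIN, X_MAX) - X_MIN) * (N - 1) / (X_MAX - X_MIN)
    idx = idx.round().long()
    return TABLE.to(x.device)[idx]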
Please, can you tell me if there is anything wrong in my scripts regarding how the model should be used?
Thank you for this great work!