diff --git a/README.md b/README.md
index 4fe6027..80bc9b5 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,98 @@
+# SceneRF (Fork Edition)
+**Self-Supervised Monocular 3D Scene Reconstruction with Radiance Fields**
+_Original Authors:_ [Anh-Quan Cao](https://anhquancao.github.io), [Raoul de Charette](https://team.inria.fr/rits/membres/raoul-de-charette/)
+_Inria, Paris, France._
+
+
+
+**This repository is a personal fork of the official [SceneRF](https://github.com/astra-vision/SceneRF) repository.**
+Please note that the changes described below are _not_ part of the official SceneRF repository and are _not_ endorsed by the original authors.
+
+---
+## Fork Changelog
+
+This project was completed as part of the **Machine Learning for 3D Geometry (IN2392)** course at **TUM**.
+
+### **Enhancements in SceneRF Performance**
+
+- Implemented **Random Fourier Features (RFF) positional encoding** and **Hierarchical Sampling** (alongside the existing sampling techniques) to significantly enhance **novel depths synthesis, novel views synthesis, and scene reconstruction** in SceneRF.
+- We also tried **multi-head self-attention** in the Spherical U-Net, but it did not improve the results, most likely because we lacked the data and compute to train it for longer.
+- Please check out the project report [here](docs/BetterSceNeRF.pdf).
+- These improvements yield better performance, as shown in the following table:
+
+![Results comparison](assets/outputResults.png)
+
+- The **best results** are highlighted in **bold**.
+- **Original results** are taken from the SceneRF paper.
+- **Scaled-down results** correspond to a scaled-down model using the configuration in `train_eval_bash_scripts/train_bundlefusion_scaled_down.sh`.
+
+### **Additional Modifications**
+Below is a summary of the modifications introduced in **this fork** to support additional features and datasets. **All credit for the original work goes to the original authors.**
+
+1. **Dataset Argument for TUM RGB-D**
+ - A new `--dataset` argument has been introduced to:
+ - `scenerf/scripts/train_bundlefusion.py`
+     - `scenerf/data/bundlefusion/bundlefusion_dm.py`
+     - `scenerf/data/bundlefusion/bundlefusion_dataset.py`
+ - This allows for selecting between **BundleFusion** (`bf`) and **TUM RGB-D** (`tum_rgbd`) during training and data loading.
+
+2. **Modified Evaluation and Reconstruction Scripts**
+ - Added a `--dataset` argument to:
+ - `scenerf/scripts/evaluation/save_depth_metrics_bf.py`
+ - `scenerf/scripts/evaluation/agg_depth_metrics_bf.py`
+ - `scenerf/scripts/evaluation/render_colors_bf.py`
+ - `scenerf/scripts/reconstruction/generate_novel_depths_bf.py`
+ - `scenerf/scripts/reconstruction/depth2tsdf_bf.py`
+ - `scenerf/scripts/reconstruction/generate_sc_gt_bf.py`
+     - `scenerf/scripts/evaluation/eval_sc_bf.py`
+     - `scenerf/scripts/evaluation/eval_color_bf.py`
+ - This makes it possible to perform the same depth/TSDF/color metrics evaluations on the TUM RGB-D dataset using a BundleFusion-like format.
+
+3. **TUM RGB-D to BundleFusion Conversion**
+   - **New File:** `convert_tum_to_bf/tum_to_bf.py`
+ - Script to convert the **TUM RGB-D** dataset into a BundleFusion-like directory structure, including:
+ - Pose conversion
+ - Depth scaling
+ - Converting `color.png` to `color.jpg`
+
+4. **Random Fourier Features Positional Encoding**
+ - **New File:** `scenerf/models/pe_rff.py`
+   - **Modified Files:** `scenerf/models/scenerf_bf.py` (to use RFF positional encoding in place of the standard positional encoding).
+ - Implements **Random Fourier Features** for positional encoding, providing an alternative to standard positional encodings.
+
+5. **Hierarchical Sampling**
+ - **Modified Files:** `scenerf/scripts/train_bundlefusion.py` (added a `--n_pts_hier` argument) and `scenerf/models/scenerf_bf.py` (to implement hierarchical sampling).
+ - Implements **Hierarchical Sampling** alongside uniform and probabilistic sampling.
+ - Allows specifying the number of points for hierarchical sampling directly from the command line.
+   - Probabilistic sampling can over-concentrate points on specific surface areas, leading to an imbalanced focus. Hierarchical sampling refines the uniform samples using the coarse-pass rendering weights, giving a more even distribution of points near surfaces and improving overall reconstruction quality (see the sketch after this list).
+
+6. **Self Attention**
+   - **Modified Files:** `scenerf/models/unet2d_sphere.py` (to implement multi-head self-attention in the U-Net bottleneck).
+
+7. **Training and Evaluation Bash Scripts**
+ - **New File:** `train_eval_bash_scripts/train_bundlefusion_scaled_down.sh` (to train the model with scaled down configuration)
+ - **New File:** `train_eval_bash_scripts/eval_bundlefusion_scaled_down.sh` (to evaluate the model)
+   - Adjust the paths in the bash scripts to match your setup.
+   - Train on either the **BundleFusion** (`bf`) or **TUM RGB-D** (`tum_rgbd`) dataset by setting `DATASET` to `bf` or `tum_rgbd` in the bash scripts.
+
+8. **Assets**
+ - **New Directory:** `assets` (to save evaluation results)
+
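+**Illustrative sketch (hierarchical sampling):** the fine pass reuses the rendering weights of the coarse pass to decide where to place the extra `--n_pts_hier` samples. Below is a minimal sketch of that idea using inverse-transform sampling; it is *not* the actual `sample_rays_viewdir` implementation, and the function and tensor names are hypothetical.
+
+```python
+import torch
+
+def sample_depths_from_weights(coarse_depths, coarse_weights, n_new, eps=1e-5):
+    """Draw n_new extra depths per ray from the piecewise-constant PDF
+    defined by the coarse rendering weights (inverse-transform sampling)."""
+    # coarse_depths:  (n_rays, n_pts) depths of the coarse samples, sorted along dim 1
+    # coarse_weights: (n_rays, n_pts) rendering weights from the coarse pass
+    pdf = (coarse_weights + eps) / (coarse_weights + eps).sum(dim=-1, keepdim=True)
+    cdf = torch.cumsum(pdf, dim=-1)                                   # (n_rays, n_pts)
+    u = torch.rand(coarse_weights.shape[0], n_new, device=coarse_weights.device)
+    idx = torch.searchsorted(cdf, u, right=True).clamp(max=cdf.shape[-1] - 1)
+    return torch.gather(coarse_depths, 1, idx)                        # (n_rays, n_new)
+```
+
+In this fork, the equivalent step is delegated to `sample_rays_viewdir` via its `weights` argument (see `batchify_depth_and_color` in `scenerf/models/scenerf_bf.py`).
+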
+---
+
+
+
+# Original SceneRF README
+
+
+
+Please refer to the original [SceneRF repository](https://github.com/astra-vision/SceneRF) for the most up-to-date official code and instructions. The following sections are from the original SceneRF README (with minor adaptations to reflect the presence of the fork).
+
+----
+
+
+
# SceneRF: Self-Supervised Monocular 3D Scene Reconstruction with Radiance Fields
ICCV 2023
@@ -14,7 +107,7 @@ Inria, Paris, France.
If you find this work or code useful, please cite our [paper](https://arxiv.org/abs/2212.02501) and [give this repo a star](https://github.com/astra-vision/SceneRF/stargazers):
-```
+```bibtex
@InProceedings{cao2023scenerf,
author = {Cao, Anh-Quan and de Charette, Raoul},
title = {SceneRF: Self-Supervised Monocular 3D Scene Reconstruction with Radiance Fields},
diff --git a/assets/outputResults.png b/assets/outputResults.png
new file mode 100644
index 0000000..6b3dda4
Binary files /dev/null and b/assets/outputResults.png differ
diff --git a/convert_tum_to_bf/tum_to_bf.py b/convert_tum_to_bf/tum_to_bf.py
new file mode 100644
index 0000000..04b39b5
--- /dev/null
+++ b/convert_tum_to_bf/tum_to_bf.py
@@ -0,0 +1,256 @@
+import os
+import numpy as np
+from PIL import Image
+from scipy.spatial.transform import Rotation as R
+import argparse
+
+
+def combine_and_rename_files(
+ tum_folder: str,
+ output_folder: str,
+ margin: float = 0.02
+) -> None:
+ """
+ Match and rename RGB/depth files and write them to the output folder with
+ the BundleFusion naming convention. Additionally, load poses from TUM's
+ groundtruth.txt, match them by timestamp within a given margin, and save
+ them as 4x4 transformation matrices. Color images are re-encoded as JPG at
+    maximum quality, and depth images are divided by 5 because BundleFusion
+    stores depth at 1 unit per millimeter while TUM RGB-D stores 5 units per
+    millimeter.
+
+ Parameters
+ ----------
+ tum_folder : str
+ Path to the TUM scene folder containing 'rgb', 'depth', and
+ 'groundtruth.txt'.
+ output_folder : str
+ Path to the output folder where converted data will be stored.
+ margin : float, optional
+ Maximum allowed time difference (in seconds) for matching frames
+ between RGB, depth, and pose data. Defaults to 0.02.
+
+ Returns
+ -------
+ None
+ """
+
+ # Create output folder if it doesn't exist
+ os.makedirs(output_folder, exist_ok=True)
+
+ # Paths to subfolders and files
+ rgb_path = os.path.join(tum_folder, "rgb")
+ depth_path = os.path.join(tum_folder, "depth")
+ pose_file = os.path.join(tum_folder, "groundtruth.txt")
+
+ # If any required directory or file doesn't exist, return early
+ if not (os.path.isdir(rgb_path) and os.path.isdir(depth_path) and
+ os.path.exists(pose_file)):
+ print(f"Skipping {tum_folder} because it lacks required folders/files.")
+ return
+
+ # Get sorted lists of RGB and depth files
+ rgb_files = sorted(f for f in os.listdir(rgb_path)
+ if f.lower().endswith((".png", ".jpg")))
+ depth_files = sorted(f for f in os.listdir(depth_path)
+ if f.lower().endswith(".png"))
+
+ # Extract timestamps from filenames (assuming `timestamp.ext`)
+ rgb_entries = [(float(f.rsplit(".", 1)[0]), f) for f in rgb_files]
+ depth_entries = [(float(f.rsplit(".", 1)[0]), f) for f in depth_files]
+
+ # Load pose entries (timestamp tx ty tz qx qy qz qw) ignoring commented lines
+ with open(pose_file, "r") as f:
+ pose_lines = [
+ line.strip() for line in f
+ if line.strip() and not line.startswith("#")
+ ]
+ pose_entries = []
+ for line in pose_lines:
+ parts = line.split()
+ ts = float(parts[0])
+ data = parts[1:] # [tx, ty, tz, qx, qy, qz, qw]
+ pose_entries.append((ts, data))
+
+ frame_counter = 0
+
+ # Iterate over RGB frames and find matching depth and pose
+ for rgb_ts, rgb_filename in rgb_entries:
+ frame_id = f"frame-{frame_counter:06d}"
+
+ # Find closest depth frame
+ if not depth_entries:
+ break
+ closest_depth = min(depth_entries, key=lambda x: abs(rgb_ts - x[0]))
+ if abs(rgb_ts - closest_depth[0]) > margin:
+ continue
+
+ # Find closest pose
+ if not pose_entries:
+ break
+ closest_pose = min(pose_entries, key=lambda x: abs(rgb_ts - x[0]))
+ if abs(rgb_ts - closest_pose[0]) > margin:
+ continue
+
+ # We have matched depth and pose; remove them from the pool
+ depth_entries.remove(closest_depth)
+ pose_entries.remove(closest_pose)
+
+ # Increment the frame counter now that we have a valid match
+ frame_counter += 1
+
+ # -- Process and save color image as JPG with max quality --
+ rgb_src = os.path.join(rgb_path, rgb_filename)
+ rgb_dst = os.path.join(output_folder, f"{frame_id}.color.jpg")
+
+ rgb_img = Image.open(rgb_src).convert("RGB")
+ rgb_img.save(rgb_dst, "JPEG", quality=100)
+
+ # -- Process and save depth image (divide by 5) as PNG --
+ depth_src = os.path.join(depth_path, closest_depth[1])
+ depth_dst = os.path.join(output_folder, f"{frame_id}.depth.png")
+
+ depth_img = np.array(Image.open(depth_src))
+ # Convert to uint16 and divide by 5 (integer division)
+ depth_img = depth_img.astype(np.uint16)
+ depth_img //= 5
+
+ depth_img_pil = Image.fromarray(depth_img)
+ depth_img_pil.save(depth_dst)
+
+ # -- Process and save pose as .pose.txt --
+ tx, ty, tz, qx, qy, qz, qw = map(float, closest_pose[1])
+ rotation = R.from_quat([qx, qy, qz, qw]).as_matrix()
+
+ pose_matrix = np.eye(4, dtype=np.float64)
+ pose_matrix[:3, :3] = rotation
+ pose_matrix[:3, 3] = [tx, ty, tz]
+
+ pose_dst = os.path.join(output_folder, f"{frame_id}.pose.txt")
+ np.savetxt(pose_dst, pose_matrix, fmt="%.6f")
+
+
+def generate_info_txt(output_folder: str, folder_name: str) -> None:
+ """
+ Generate an 'info.txt' file suitable for BundleFusion. This file includes
+ camera intrinsics and extrinsics for the color and depth sensors. The
+ intrinsics are selected based on the TUM dataset prefix.
+
+ Parameters
+ ----------
+ output_folder : str
+ Path to the output folder where 'info.txt' will be saved.
+ folder_name : str
+ Name of the scene folder. Used to determine if the dataset is
+ 'freiburg1', 'freiburg2', or 'freiburg3'.
+
+ Returns
+ -------
+ None
+ """
+
+ # Intrinsics for different TUM prefixes
+ intrinsics = {
+ "freiburg1": "517.3 0 318.6 0 0 516.5 255.3 0 0 0 1 0 0 0 0 1",
+ "freiburg2": "520.9 0 325.1 0 0 521.0 249.7 0 0 0 1 0 0 0 0 1",
+ "freiburg3": "535.4 0 320.1 0 0 539.2 247.6 0 0 0 1 0 0 0 0 1"
+ }
+
+ # Default values
+ color_intrinsic = "525.0 0 319.5 0 0 525.0 239.5 0 0 0 1 0 0 0 0 1"
+ depth_intrinsic = "525.0 0 319.5 0 0 525.0 239.5 0 0 0 1 0 0 0 0 1"
+
+ folder_name_lower = folder_name.lower()
+ if "freiburg1" in folder_name_lower:
+ color_intrinsic = intrinsics["freiburg1"]
+ depth_intrinsic = intrinsics["freiburg1"]
+ elif "freiburg2" in folder_name_lower:
+ color_intrinsic = intrinsics["freiburg2"]
+ depth_intrinsic = intrinsics["freiburg2"]
+ elif "freiburg3" in folder_name_lower:
+ color_intrinsic = intrinsics["freiburg3"]
+ depth_intrinsic = intrinsics["freiburg3"]
+
+ info_path = os.path.join(output_folder, "info.txt")
+ with open(info_path, "w") as f:
+ f.write("m_versionNumber = 4\n")
+ f.write("m_sensorName = Kinect\n")
+ f.write("m_colorWidth = 640\n")
+ f.write("m_colorHeight = 480\n")
+ f.write("m_depthWidth = 640\n")
+ f.write("m_depthHeight = 480\n")
+ f.write("m_depthShift = 5000\n")
+ f.write(f"m_calibrationColorIntrinsic = {color_intrinsic}\n")
+ f.write("m_calibrationColorExtrinsic = "
+ "1 0 0 0 0 1 0 0 0 0 1 0 0 0 0 1\n")
+ f.write(f"m_calibrationDepthIntrinsic = {depth_intrinsic}\n")
+ f.write("m_calibrationDepthExtrinsic = "
+ "1 0 0 0 0 1 0 0 0 0 1 0 0 0 0 1\n")
+
+
+def main(source_dir: str, dest_dir: str) -> None:
+ """
+ Main function that processes multiple TUM scene folders within a source
+ directory and saves the converted data to a destination directory.
+
+ For each scene in `source_dir`, this function:
+ 1. Calls `combine_and_rename_files` to match and convert images/poses.
+ 2. Calls `generate_info_txt` to create the BundleFusion 'info.txt'.
+
+ Parameters
+ ----------
+ source_dir : str
+ Path to the directory containing multiple TUM scenes (subdirectories).
+ dest_dir : str
+ Path to the directory where the converted scenes will be stored.
+
+ Returns
+ -------
+ None
+ """
+
+ if not os.path.isdir(source_dir):
+ print(f"Source directory '{source_dir}' is not valid.")
+ return
+
+ os.makedirs(dest_dir, exist_ok=True)
+
+ # Process each subdirectory (scene) within the source directory
+ for scene_name in sorted(os.listdir(source_dir)):
+ scene_path = os.path.join(source_dir, scene_name)
+ if not os.path.isdir(scene_path):
+ # Skip files; only process directories
+ continue
+
+ # Create corresponding directory in destination
+ output_scene_path = os.path.join(dest_dir, scene_name)
+ os.makedirs(output_scene_path, exist_ok=True)
+
+ print(f"Processing scene: {scene_name}")
+
+ # Perform matching, renaming, and saving
+ combine_and_rename_files(scene_path, output_scene_path)
+
+ # Generate info.txt
+ generate_info_txt(output_scene_path, scene_name)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(
+ description="Convert multiple TUM RGB-D dataset scenes to BundleFusion format."
+ )
+ parser.add_argument(
+ "--source_dir",
+ type=str,
+ required=True,
+ help="Path to the directory containing multiple TUM scene folders."
+ )
+ parser.add_argument(
+ "--dest_dir",
+ type=str,
+ required=True,
+ help="Path to the directory where converted scenes will be stored."
+ )
+
+ args = parser.parse_args()
+ main(args.source_dir, args.dest_dir)
+
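+
+# Example invocation (paths are hypothetical):
+#   python convert_tum_to_bf/tum_to_bf.py \
+#       --source_dir /path/to/tum_rgbd_scenes \
+#       --dest_dir /path/to/converted_bf_scenes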
diff --git a/docs/BetterSceNeRF.pdf b/docs/BetterSceNeRF.pdf
new file mode 100644
index 0000000..91498d4
Binary files /dev/null and b/docs/BetterSceNeRF.pdf differ
diff --git a/scenerf/data/bundlefusion/bundlefusion_dataset.py b/scenerf/data/bundlefusion/bundlefusion_dataset.py
index a7ab0d4..fde8104 100644
--- a/scenerf/data/bundlefusion/bundlefusion_dataset.py
+++ b/scenerf/data/bundlefusion/bundlefusion_dataset.py
@@ -14,6 +14,7 @@ class BundlefusionDataset(Dataset):
def __init__(
self,
split,
+ dataset,
root,
n_sources=1,
frame_interval=4,
@@ -21,13 +22,53 @@ def __init__(
infer_frame_interval=2,
color_jitter=None,
select_scans=None,
+ tum_rgbd=False,
):
self.root = root
- splits = {
- "train": ["apt0", "apt1", "apt2", "office0", "office1", "office2", "office3"],
- "val": ["copyroom"],
- "all": ["apt0", "apt1", "apt2", "office0", "office1", "office2", "office3", "copyroom"]
- }
+
+ print(dataset)
+ # Select a split based on training dataset being either bf or tum_rgbd
+ if dataset == "bf":
+ splits = {
+ "train": ["apt0", "apt1", "apt2", "office0", "office1", "office2", "office3"],
+ "val": ["copyroom"],
+ "all": ["apt0", "apt1", "apt2", "office0", "office1", "office2", "office3", "copyroom"]
+ }
+
+ elif dataset == "tum_rgbd":
+ splits = {
+ "train": [
+ "rgbd_dataset_freiburg1_360",
+ "rgbd_dataset_freiburg1_desk",
+ "rgbd_dataset_freiburg1_floor",
+ "rgbd_dataset_freiburg1_room",
+ "rgbd_dataset_freiburg1_xyz",
+ "rgbd_dataset_freiburg2_360_hemisphere",
+ "rgbd_dataset_freiburg2_desk",
+ "rgbd_dataset_freiburg2_large_no_loop",
+ "rgbd_dataset_freiburg2_pioneer_360",
+ "rgbd_dataset_freiburg2_xyz",
+ "rgbd_dataset_freiburg3_structure_texture_far"
+ ],
+ "val": [
+ "rgbd_dataset_freiburg3_long_office_household"
+ ],
+ "all": [
+ "rgbd_dataset_freiburg1_360",
+ "rgbd_dataset_freiburg1_desk",
+ "rgbd_dataset_freiburg1_floor",
+ "rgbd_dataset_freiburg1_room",
+ "rgbd_dataset_freiburg1_xyz",
+ "rgbd_dataset_freiburg2_360_hemisphere",
+ "rgbd_dataset_freiburg2_desk",
+ "rgbd_dataset_freiburg2_large_no_loop",
+ "rgbd_dataset_freiburg2_pioneer_360",
+ "rgbd_dataset_freiburg2_xyz",
+ "rgbd_dataset_freiburg3_structure_texture_far",
+ "rgbd_dataset_freiburg3_long_office_household"
+ ]
+ }
+
+        else:
+            raise ValueError(
+                f"Unknown dataset '{dataset}': expected 'bf' or 'tum_rgbd'.")
+
self.sequences = splits[split]
self.n_sources = n_sources
self.frame_interval = frame_interval
@@ -251,7 +292,7 @@ def _read_depth(depth_filename):
and save the depth values (in millimeters) into a 2d numpy array.
The depth image file is assumed to be in 16-bit PNG format, depth in millimeters.
"""
- depth = imageio.imread(depth_filename) / 1000.0 # numpy.float64
+ depth = imageio.imread(depth_filename)
depth = np.asarray(depth)
return depth
diff --git a/scenerf/data/bundlefusion/bundlefusion_dm.py b/scenerf/data/bundlefusion/bundlefusion_dm.py
index 7c33262..27f396f 100644
--- a/scenerf/data/bundlefusion/bundlefusion_dm.py
+++ b/scenerf/data/bundlefusion/bundlefusion_dm.py
@@ -8,6 +8,7 @@
class BundlefusionDM(pl.LightningDataModule):
def __init__(
self,
+ dataset,
root,
train_n_frames=16,
val_n_frames=8,
@@ -20,6 +21,7 @@ def __init__(
n_sources=1,
):
super().__init__()
+ self.dataset = dataset
self.root = root
self.batch_size = batch_size
self.num_workers = num_workers
@@ -34,6 +36,7 @@ def __init__(
def setup(self, stage=None):
self.train_ds = BundlefusionDataset(
split="train",
+ dataset=self.dataset,
root=self.root,
n_frames=self.train_n_frames,
frame_interval=self.train_frame_interval,
@@ -46,6 +49,7 @@ def setup(self, stage=None):
def setup_val_ds(self, select_scans=None):
self.val_ds = BundlefusionDataset(
split="val",
+ dataset=self.dataset,
root=self.root,
n_frames=self.val_n_frames,
frame_interval=self.val_frame_interval,
diff --git a/scenerf/models/pe_rff.py b/scenerf/models/pe_rff.py
new file mode 100644
index 0000000..b2c3b90
--- /dev/null
+++ b/scenerf/models/pe_rff.py
@@ -0,0 +1,75 @@
+"""
+Code taken from https://github.com/sxyu/pixel-nerf/blob/91a044bdd62aebe0ed3a5685ca37cb8a9dc8e8ee/src/model/code.py#L6
+"""
+import torch
+import numpy as np
+
+class RFFEncoding(torch.nn.Module):
+ """
+ Implementation of Random Fourier Features (RFF) encoding with random frequencies and biases.
+ """
+
+ def __init__(self, num_freqs=6, d_in=3, sigma=1.0, include_input=True):
+ super().__init__()
+        # Use 2 * num_freqs random frequencies (cosine-only features), matching the
+        # output size of a sin/cos encoding with num_freqs frequencies
+        self.num_freqs = 2 * num_freqs
+ self.d_in = d_in
+ self.include_input = include_input
+ # Output dimensions: (cos) per frequency for each input dimension
+ self.d_out = d_in * (self.num_freqs)
+ if include_input:
+ self.d_out += d_in
+
+ # Randomly sample frequencies w' ~ N(0, sigma^2)
+ freqs = 2 * np.pi * torch.randn(d_in, self.num_freqs) * sigma
+ self.register_buffer("_freqs", freqs)
+
+ # Randomly sample biases b' ~ U[0, 2pi]
+ biases = 2 * np.pi * torch.rand(self.num_freqs)
+ self.register_buffer("_biases", biases)
+
+
+
+ def forward(self, x):
+ """
+ Apply RFF encoding.
+ :param x: Input tensor of shape (batch_size, d_in)
+ :return: Encoded tensor of shape (batch_size, d_out)
+ """
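+        # RFF feature (applied per input dimension): sqrt(2) * cos(w * x + b),
+        # where w ~ N(0, (2*pi*sigma)^2) and b ~ U[0, 2*pi) are sampled once in __init__.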
+        # Modulate each input dimension with the random frequencies and biases
+ # Expand the input to match the number of frequencies
+ embed = x.unsqueeze(1).repeat(1, self.num_freqs, 1) # Shape: [64, 12, 3]
+ # Combine with frequencies and biases
+
+ bias = self._biases
+ bias = bias.unsqueeze(0)
+ bias = bias.unsqueeze(-1) # 1 12 1
+
+ # Transpose frequencies to match the shape of embed
+ freqs = self._freqs.unsqueeze(0).transpose(1, 2) # Shape: [1, 12, 3]
+        # projected = embed * freqs + bias (elementwise)
+        projected = torch.addcmul(bias, embed, freqs)  # Shape: [64, 12, 3]
+
+
+ # Compute cosine embeddings with sqrt(2) scaling
+ cos_enc = torch.sqrt(torch.tensor(2.0)) * torch.cos(projected) # Shape: [64, 12, 3]
+
+ # Reshape to [batch_size, d_in * num_freqs]
+ cos_enc = cos_enc.reshape(cos_enc.shape[0], -1) # Shape: [64, 36]
+
+ # Include the raw input if specified
+ if self.include_input:
+ encoding = torch.cat([x, cos_enc], dim=-1) # Shape: [64, 39]
+ else:
+ encoding = cos_enc # Shape: [64, 36]
+
+        return encoding
+
+    @classmethod
+ def from_conf(cls, conf, d_in=3):
+ # PyHocon construction
+ return cls(
+ conf.get_int("num_freqs", 6),
+ d_in,
+ conf.get_float("sigma", 1.0),
+ conf.get_bool("include_input", True),
+ )
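+
+
+# Minimal usage sketch (values are illustrative, not part of the training code):
+#   pe = RFFEncoding(num_freqs=6, d_in=3, sigma=1.0, include_input=True)
+#   pts = torch.rand(64, 3)   # e.g. 64 sampled 3D points
+#   enc = pe(pts)             # -> (64, pe.d_out) = (64, 3 + 3 * 2 * 6) = (64, 39)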
diff --git a/scenerf/models/scenerf_bf.py b/scenerf/models/scenerf_bf.py
index e626af4..291e6ac 100644
--- a/scenerf/models/scenerf_bf.py
+++ b/scenerf/models/scenerf_bf.py
@@ -10,7 +10,8 @@
from scenerf.loss.depth_metrics import compute_depth_errors
from scenerf.loss.ss_loss import compute_l1_loss
-from scenerf.models.pe import PositionalEncoding
+# from scenerf.models.pe import PositionalEncoding
+from scenerf.models.pe_rff import RFFEncoding as PositionalEncoding
from scenerf.models.ray_som_kl import RaySOM
@@ -38,11 +39,12 @@ def __init__(
std=0.2,
n_gaussians=4,
n_pts_uni=32,
+ n_pts_hier=32,
n_pts_per_gaussian=8,
smooth_loss_weight=0,
sampling_method="uniform",
batch_size=1,
- net_2d="b7",
+ net_2d="b7",
add_fov_hor=0, add_fov_ver=0,
sphere_H=480, sphere_W=640,
use_color=True,
@@ -63,6 +65,7 @@ def __init__(
self.smooth_loss_weight = smooth_loss_weight
self.n_pts_uni = n_pts_uni
+ self.n_pts_hier = n_pts_hier
self.n_gaussians = n_gaussians
self.n_pts_per_gaussian = n_pts_per_gaussian
self.std = std
@@ -74,14 +77,13 @@ def __init__(
self.max_sample_depth = max_sample_depth
self.eval_depth = eval_depth
-
self.net_2d = net_2d
if net_2d == "b7":
feature = 256
self.out_img_W = sphere_W
self.out_img_H = sphere_H
-
+
self.spherical_mapping = SphericalMapping(
v_angle_max=112.2911 + add_fov_ver,
v_angle_min=67.6248 - add_fov_ver,
@@ -110,7 +112,6 @@ def __init__(
d_latent=2480
)
-
self.mlp_gaussian = ResnetFC(
d_in=39 + 3,
d_out=2,
@@ -119,8 +120,6 @@ def __init__(
d_latent=2480
)
self.ray_som = RaySOM(som_sigma=som_sigma)
-
-
def forward(self, batch, step_type):
"""
@@ -133,13 +132,11 @@ def forward(self, batch, step_type):
cam_K = batch['cam_K_depth'][0]
inv_K = torch.inverse(cam_K)
-
pix_coords, out_pix_coords, _ = self.spherical_mapping.from_pixels(inv_K=inv_K)
-
+
x_rgbs = self.net_rgb(img_input, pix=pix_coords,
pix_sphere=out_pix_coords)
-
total_loss_reprojection = 0
total_loss_color = 0
@@ -157,23 +154,21 @@ def forward(self, batch, step_type):
img_sources = batch['img_sources'][i]
img_targets = batch['img_targets'][i]
source_depths = batch['source_depths'][i]
-
+
n_sources = len(img_sources)
-
+
x_rgb = {}
for k in x_rgbs:
x_rgb[k] = x_rgbs[k][i]
-
n_grids = self.n_rays // (self.sample_grid_size ** 2)
-
+
for sid in range(n_sources):
# for sid in [self.current_epoch]:
img_target = img_targets[sid]
- img_source = img_sources[sid]
+ img_source = img_sources[sid]
T_source2infer = T_source2infers[sid]
T_source2target = T_source2targets[sid]
-
ret = self.process_single_source(
n_grids,
@@ -204,7 +199,6 @@ def forward(self, batch, step_type):
mask = depth_gt > 0
if mask.sum() > 0:
self.evaluate_depth(step_type, depth_gt[mask], depth_source_rendered[mask])
-
# ==== Combine all the losses
total_loss = 0
@@ -214,17 +208,16 @@ def forward(self, batch, step_type):
if total_loss_reprojection > 0:
total_loss += total_loss_reprojection * 5.0
self.log(step_type + "/loss_reprojection",
- total_loss_reprojection.detach(), on_epoch=True, sync_dist=True)
-
+ total_loss_reprojection.detach(), on_epoch=True, sync_dist=True)
total_loss_color /= bs
if self.use_color:
total_loss += total_loss_color
self.log(step_type + "/loss_color",
- total_loss_color.detach(), on_epoch=True, sync_dist=True)
+ total_loss_color.detach(), on_epoch=True, sync_dist=True)
# SOM loss
- total_loss_kl /= bs
+ total_loss_kl /= bs
total_loss += total_loss_kl
self.log(step_type + "/loss_som_kl",
total_loss_kl.detach(), on_epoch=True, sync_dist=True)
@@ -247,8 +240,7 @@ def forward(self, batch, step_type):
self.log(step_type + "/total_loss", total_loss.detach(),
on_epoch=True, sync_dist=True)
-
-
+
return {
"total_loss": total_loss
}
@@ -256,7 +248,7 @@ def forward(self, batch, step_type):
def process_single_source(self,
n_grids,
x_rgb,
- # x_sphere,
+ # x_sphere,
cam_K, inv_K,
img_source, img_target,
# pix_source, pix_sky_source,
@@ -283,7 +275,7 @@ def process_single_source(self,
# x_sphere,
ray_batch_size=pix_source.shape[0],
sampled_pixels=pix_source)
-
+
depth_source_rendered = render_out_dict['depth']
color_rendered = render_out_dict['color']
loss_kl = render_out_dict['loss_kl']
@@ -338,14 +330,11 @@ def process_single_source(self,
sampled_color_source=sampled_color_source)
ret["loss_smooth"] = loss_smooth
-
-
return ret
def evaluate_depth(self, step_type, gt_depth, pred_depth, log=True, ret_mean=True):
depth_errors = []
-
depth_error = compute_depth_errors(
gt_depth.reshape(-1),
pred_depth.reshape(-1).detach().cpu().numpy(),
@@ -360,17 +349,16 @@ def evaluate_depth(self, step_type, gt_depth, pred_depth, log=True, ret_mean=Tru
agg_depth_errors = np.array(depth_errors).sum(0)
metric_list = ["abs_rel", "sq_rel",
"rmse", "rmse_log", "a1", "a2", "a3"]
-
+
if not log:
return agg_depth_errors
for i_metric, metric in enumerate(metric_list):
key = step_type + "depth/{}".format(metric)
-
+
self.log(key, agg_depth_errors[i_metric],
- on_epoch=True, sync_dist=True)
+ on_epoch=True, sync_dist=True)
-
def compute_reprojection_loss(
self,
pix_source, sampled_color_source,
@@ -408,15 +396,13 @@ def compute_reprojection_loss(
loss_reprojections = torch.stack(loss_reprojections)
loss_reprojections = torch.min(loss_reprojections, dim=0)[0]
-
- return loss_reprojections
+ return loss_reprojections
def step(self, batch, step_type):
out_dict = self.forward(batch, step_type)
return out_dict['total_loss']
-
def render_rays_batch(self, cam_K,
T_source2infer,
x_rgb,
@@ -437,7 +423,6 @@ def render_rays_batch(self, cam_K,
depth_volumes = []
color_rendereds = []
-
cnt = 0
loss_kl = []
@@ -523,36 +508,37 @@ def batchify_density(self,
'density': torch.cat(densities, dim=0),
}
- def predict(self, mlp,
- cam_pts, x_rgb,
- cam_K, viewdir, output_type="density"):
+ def predict(self, mlp,
+ cam_pts, x_rgb,
+ cam_K, viewdir, output_type="density"):
saved_shape = cam_pts.shape
cam_pts = cam_pts.reshape(-1, 3)
projected_pix = cam_pts_2_pix(cam_pts, cam_K)
-
pix_coords, pix_sphere_coords, _ = self.spherical_mapping.from_pixels(
inv_K=torch.inverse(cam_K),
pix_coords=projected_pix)
-
+
pe = self.pe(cam_pts)
- feats_2d_sphere = [sample_feats_2d(x_rgb["1_1"].unsqueeze(0), pix_sphere_coords, (self.out_img_W, self.out_img_H))]
+ feats_2d_sphere = [
+ sample_feats_2d(x_rgb["1_1"].unsqueeze(0), pix_sphere_coords, (self.out_img_W, self.out_img_H))]
for scale in [2, 4, 8, 16]:
key = "1_{}".format(scale)
- feats_2d_sphere.append(sample_feats_2d(x_rgb[key].unsqueeze(0), pix_sphere_coords, (self.out_img_W//scale, self.out_img_H//scale)))
-
+ feats_2d_sphere.append(sample_feats_2d(x_rgb[key].unsqueeze(0), pix_sphere_coords,
+ (self.out_img_W // scale, self.out_img_H // scale)))
+
feats_2d_sphere = torch.cat(feats_2d_sphere, dim=-1)
-
+
viewdir = viewdir.unsqueeze(1).expand(-1, saved_shape[1], -1).reshape(-1, 3)
x_in = torch.cat([feats_2d_sphere, pe, viewdir], dim=-1)
- if output_type == "density":
+ if output_type == "density":
mlp_output = mlp(x_in)
color = torch.sigmoid(mlp_output[..., :3])
density = self.density_activation(mlp_output[..., 3:4])
-
+
if len(saved_shape) == 3:
density = density.reshape(saved_shape[0], saved_shape[1])
color = color.reshape(saved_shape[0], saved_shape[1], 3)
@@ -564,9 +550,9 @@ def predict(self, mlp,
residual = residual.reshape(saved_shape[0], saved_shape[1], 2)
return residual
- def predict_gaussian_means_and_stds(self, T_source2infer, unit_direction, n_gaussians,
- x_rgb,
- cam_K, base_std, viewdir):
+ def predict_gaussian_means_and_stds(self, T_source2infer, unit_direction, n_gaussians,
+ x_rgb,
+ cam_K, base_std, viewdir):
n_rays = unit_direction.shape[0]
step = self.max_sample_depth * 1.0 / self.n_gaussians
@@ -606,30 +592,34 @@ def predict_gaussian_means_and_stds(self, T_source2infer, unit_direction, n_gaus
gaussian_means_sensor_distance) + 0.5 # avoid negative distance
gaussian_stds_sensor_distance = torch.relu(
gaussian_stds_offset + base_std) + 0.5
-
+
return gaussian_means_sensor_distance, gaussian_stds_sensor_distance
def batchify_depth_and_color(
- self, T_source2infer, x_rgb,
- # x_sphere,
+ self, T_source2infer, x_rgb,
+ # x_sphere,
batch_sampled_pixels,
cam_K, inv_K):
+ hierarchical_sampling = (self.n_pts_hier > 0)
+
depths = []
ret = {}
n_rays = batch_sampled_pixels.shape[0]
unit_direction = compute_direction_from_pixels(
batch_sampled_pixels, inv_K)
-
+
if self.n_pts_uni > 0:
n_pts_uni = self.n_pts_uni
else:
n_pts_uni = 2
+
+ # first uniform sampling (coarse sampling)
cam_pts_uni, depth_volume_uni, sensor_distance_uni, viewdir = sample_rays_viewdir(
inv_K, T_source2infer,
self.img_size,
sampling_method="uniform",
sampled_pixels=batch_sampled_pixels,
- n_pts_per_ray=n_pts_uni,
+ n_pts_per_ray=self.n_pts_uni,
max_sample_depth=self.max_sample_depth)
gaussian_means_sensor_distance, gaussian_stds_sensor_distance = self.predict_gaussian_means_and_stds(
@@ -637,10 +627,10 @@ def batchify_depth_and_color(
unit_direction, self.n_gaussians,
x_rgb=x_rgb,
cam_K=cam_K,
- base_std=self.std,
+ base_std=self.std,
viewdir=viewdir)
-
+ # gaussian sampling
cam_pts_gauss, depth_volume_gauss, sensor_distance_gauss = sample_rays_gaussian(
T_cam2cam=T_source2infer,
n_rays=n_rays,
@@ -650,38 +640,66 @@ def batchify_depth_and_color(
n_gaussians=self.n_gaussians, n_pts_per_gaussian=self.n_pts_per_gaussian,
max_sample_depth=self.max_sample_depth)
- if self.n_pts_uni > 0:
- cam_pts = torch.cat([cam_pts_uni, cam_pts_gauss],
- dim=1) # n_rays, n_pts 3
- depth_volume = torch.cat(
- [depth_volume_uni, depth_volume_gauss], dim=1) # n_rays, n_pts
- sensor_distance = torch.cat(
- [sensor_distance_uni, sensor_distance_gauss], dim=1) # n_rays, n_pts
- elif self.n_pts_per_gaussian == 1:
- cam_pts = cam_pts_uni
- depth_volume = depth_volume_uni
- sensor_distance = sensor_distance_uni
- else:
- cam_pts = cam_pts_gauss
- depth_volume = depth_volume_gauss
- sensor_distance = sensor_distance_gauss
-
- sorted_indices = torch.argsort(sensor_distance, dim=1)
- sensor_distance = torch.gather(
- sensor_distance, dim=1, index=sorted_indices) # n_rays, n_pts
- depth_volume = torch.gather(
- depth_volume, dim=1, index=sorted_indices) # n_rays, n_pts
- cam_pts = torch.gather(
- cam_pts, dim=1, index=sorted_indices.unsqueeze(-1).expand(-1, -1, 3))
-
- density, colors = self.predict(mlp=self.mlp,
- cam_pts=cam_pts.detach(),
- viewdir=viewdir,
- x_rgb=x_rgb,
- cam_K=cam_K)
- rendered_out = self.render_depth_and_color(
- density, sensor_distance, depth_volume,
- colors=colors)
+ sample_phases = ["coarse", "fine"] if hierarchical_sampling else ["uniform"]
+ weights_temp = None
+
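+        # Hierarchical (coarse-to-fine) sampling: the "coarse" pass renders using only
+        # the uniform samples; its rendering weights are then passed to
+        # sample_rays_viewdir so the "fine" pass can draw n_pts_hier additional samples
+        # guided by those weights, and all samples are rendered together.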
+ for sample_phase in sample_phases:
+ if sample_phase == "coarse":
+ cam_pts = cam_pts_uni
+ depth_volume = depth_volume_uni
+ sensor_distance = sensor_distance_uni
+
+ elif sample_phase == "fine":
+ cam_pts_hier, depth_volume_hier, sensor_distance_hier, viewdir = sample_rays_viewdir(
+ inv_K, T_source2infer,
+ self.img_size,
+ sampling_method="uniform",
+ sampled_pixels=batch_sampled_pixels,
+ n_pts_per_ray=self.n_pts_hier,
+ max_sample_depth=self.max_sample_depth,
+ weights=weights_temp)
+ cam_pts = torch.cat([cam_pts_uni, cam_pts_gauss, cam_pts_hier],
+ dim=1) # n_rays, n_pts 3
+ depth_volume = torch.cat([depth_volume_uni, depth_volume_gauss, depth_volume_hier],
+ dim=1)
+ sensor_distance = torch.cat([sensor_distance_uni, sensor_distance_gauss, sensor_distance_hier],
+ dim=1)
+ elif sample_phase == "uniform":
+ if self.n_pts_uni > 0:
+ cam_pts = torch.cat([cam_pts_uni, cam_pts_gauss],
+ dim=1) # n_rays, n_pts 3
+ depth_volume = torch.cat(
+ [depth_volume_uni, depth_volume_gauss], dim=1) # n_rays, n_pts
+ sensor_distance = torch.cat(
+ [sensor_distance_uni, sensor_distance_gauss], dim=1) # n_rays, n_pts
+ elif self.n_pts_per_gaussian == 1:
+ cam_pts = cam_pts_uni
+ depth_volume = depth_volume_uni
+ sensor_distance = sensor_distance_uni
+ else:
+ cam_pts = cam_pts_gauss
+ depth_volume = depth_volume_gauss
+ sensor_distance = sensor_distance_gauss
+
+ sorted_indices = torch.argsort(sensor_distance, dim=1)
+ sensor_distance = torch.gather(
+ sensor_distance, dim=1, index=sorted_indices) # n_rays, n_pts
+ depth_volume = torch.gather(
+ depth_volume, dim=1, index=sorted_indices) # n_rays, n_pts
+ cam_pts = torch.gather(
+ cam_pts, dim=1, index=sorted_indices.unsqueeze(-1).expand(-1, -1, 3))
+
+ density, colors = self.predict(mlp=self.mlp,
+ cam_pts=cam_pts.detach(),
+ viewdir=viewdir,
+ x_rgb=x_rgb,
+ cam_K=cam_K)
+
+ rendered_out = self.render_depth_and_color(
+ density, sensor_distance, depth_volume,
+ colors=colors)
+ if sample_phase == "coarse":
+ weights_temp = rendered_out['weights']
depths = rendered_out['depth_rendered']
colors = rendered_out['color']
@@ -713,10 +731,8 @@ def batchify_depth_and_color(
return ret
-
-
def render_depth_and_color(self,
- density, sensor_distance, depth_volume, colors):
+ density, sensor_distance, depth_volume, colors):
sensor_distance[sensor_distance < 0] = 0
deltas = torch.zeros_like(sensor_distance)
@@ -724,7 +740,6 @@ def render_depth_and_color(self,
deltas[:, 1:] = sensor_distance[:, 1:] - sensor_distance[:, :-1]
alphas = 1 - torch.exp(-deltas * density)
-
ret = {
"alphas": alphas
}
@@ -737,9 +752,8 @@ def render_depth_and_color(self,
weights = alphas * T_alphas[:, :-1] # (B, K)
depth_rendered = torch.sum(weights * depth_volume, -1)
- color_rendered = torch.sum(weights.unsqueeze(-1) * colors, -2) # (B, 3)
+ color_rendered = torch.sum(weights.unsqueeze(-1) * colors, -2) # (B, 3)
-
diff = depth_rendered.unsqueeze(-1) - depth_volume
abs_diff = torch.abs(diff)
closest_pts_to_depth, weights_at_depth_idx = torch.min(abs_diff, dim=1)
@@ -747,7 +761,7 @@ def render_depth_and_color(self,
weights, dim=1, index=weights_at_depth_idx.unsqueeze(-1)).squeeze()
ret['color'] = color_rendered
- ret['weights_at_depth'] = weights_at_depth
+ ret['weights_at_depth'] = weights_at_depth
ret['closest_pts_to_depth'] = closest_pts_to_depth
ret['weights'] = weights
# t = {}
@@ -764,8 +778,6 @@ def training_step(self, batch, batch_idx):
def validation_step(self, batch, batch_idx):
self.step(batch, "val")
-
-
def configure_optimizers(self):
optimizer = torch.optim.AdamW(
self.parameters(), lr=self.lr, weight_decay=self.weight_decay
diff --git a/scenerf/models/unet2d_sphere.py b/scenerf/models/unet2d_sphere.py
index a073de8..fae61b7 100644
--- a/scenerf/models/unet2d_sphere.py
+++ b/scenerf/models/unet2d_sphere.py
@@ -1,270 +1,270 @@
-"""
-Code adapted from https://github.com/cv-rits/MonoScene/blob/master/monoscene/models/unet2d.py
-"""
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-
-class BasicBlock(nn.Module):
- def __init__(self, channel_num, dilations):
- super(BasicBlock, self).__init__()
-
- self.conv_block1 = nn.Sequential(
- nn.Conv2d(channel_num, channel_num, 3,
- padding=dilations[0], dilation=dilations[0]),
- nn.BatchNorm2d(channel_num),
- nn.LeakyReLU(),
- # nn.ReLU(),
- )
- self.conv_block2 = nn.Sequential(
- nn.Conv2d(channel_num, channel_num, 3,
- padding=dilations[1], dilation=dilations[1]),
- nn.BatchNorm2d(channel_num),
- )
- self.lrelu = nn.LeakyReLU()
- # self.lrelu = nn.ReLU()
-
- def forward(self, x):
- residual = x
- x = self.conv_block1(x)
- x = self.conv_block2(x)
- x = x + residual
- out = self.lrelu(x)
- return out
-
-
-class UpSampleBN(nn.Module):
- def __init__(self, skip_input, output_features):
- super(UpSampleBN, self).__init__()
- self._net = nn.Sequential(
- nn.Conv2d(skip_input, output_features,
- kernel_size=3, stride=1, padding=1),
- BasicBlock(output_features, dilations=[1, 1]),
- BasicBlock(output_features, dilations=[2, 2]),
- BasicBlock(output_features, dilations=[3, 3]),
- )
-
- def forward(self, x, concat_with):
- up_x = F.interpolate(
- x,
- size=(concat_with.shape[2], concat_with.shape[3]),
- mode="bilinear",
- align_corners=True,
- )
- f = torch.cat([up_x, concat_with], dim=1)
- return self._net(f)
-
-
-class DecoderSphere(nn.Module):
- def __init__(
- self,
- num_features,
- bottleneck_features,
- out_feature,
- out_img_W,
- out_img_H
- ):
- super(DecoderSphere, self).__init__()
-
- self.out_img_W = out_img_W
- self.out_img_H = out_img_H
-
- features = int(num_features)
-
- self.conv2 = nn.Conv2d(
- bottleneck_features, features, kernel_size=1, stride=1, padding=1
- )
-
- self.out_feature_1_1 = out_feature
- self.out_feature_1_2 = out_feature
- self.out_feature_1_4 = out_feature
- self.out_feature_1_8 = out_feature
- self.out_feature_1_16 = out_feature
- self.feature_1_16 = features // 2
- self.feature_1_8 = features // 4
- self.feature_1_4 = features // 8
- self.feature_1_2 = features // 16
- self.feature_1_1 = features // 32
-
- self.resize_1_1 = nn.Conv2d(
- 3, self.feature_1_1, kernel_size=1
- )
- self.resize_1_2 = nn.Conv2d(
- 32, self.feature_1_2, kernel_size=1
- )
- self.resize_1_4 = nn.Conv2d(
- 48, self.feature_1_4, kernel_size=1
- )
- self.resize_1_8 = nn.Conv2d(
- 80, self.feature_1_8, kernel_size=1
- )
- self.resize_1_16 = nn.Conv2d(
- 224, self.feature_1_16, kernel_size=1
- )
-
- self.resize_output_1_1 = nn.Conv2d(
- self.feature_1_1, self.out_feature_1_1, kernel_size=1
- )
- self.resize_output_1_2 = nn.Conv2d(
- self.feature_1_2, self.out_feature_1_2, kernel_size=1
- )
- self.resize_output_1_4 = nn.Conv2d(
- self.feature_1_4, self.out_feature_1_4, kernel_size=1
- )
- self.resize_output_1_8 = nn.Conv2d(
- self.feature_1_8, self.out_feature_1_8, kernel_size=1
- )
- self.resize_output_1_16 = nn.Conv2d(
- self.feature_1_16, self.out_feature_1_16, kernel_size=1
- )
-
- self.up16 = UpSampleBN(
- skip_input=features + 224, output_features=self.feature_1_16
- )
- self.up8 = UpSampleBN(
- skip_input=self.feature_1_16 + 80, output_features=self.feature_1_8
- )
- self.up4 = UpSampleBN(
- skip_input=self.feature_1_8 + 48, output_features=self.feature_1_4,
- )
- self.up2 = UpSampleBN(
- skip_input=self.feature_1_4 + 32, output_features=self.feature_1_2
- )
- self.up1 = UpSampleBN(
- skip_input=self.feature_1_2 + 3, output_features=self.feature_1_1
- )
-
- def get_sphere_feature(self, x, pix, pix_sphere, scale):
- out_W, out_H = round(self.out_img_W/scale), round(self.out_img_H/scale)
- map_sphere = torch.zeros((out_W, out_H, 2)).type_as(x) - 10.0
- pix_sphere_scale = torch.round(pix_sphere / scale).long()
- pix_scale = pix // scale
- pix_sphere_scale[:, 0] = pix_sphere_scale[:, 0].clamp(0, out_W-1)
- pix_sphere_scale[:, 1] = pix_sphere_scale[:, 1].clamp(0, out_H-1)
-
- map_sphere[pix_sphere_scale[:, 0],
- pix_sphere_scale[:, 1], :] = pix_scale
- map_sphere = map_sphere.reshape(-1, 2)
-
-
- map_sphere[:, 0] /= x.shape[3]
- map_sphere[:, 1] /= x.shape[2]
- map_sphere = map_sphere * 2 - 1
- map_sphere = map_sphere.reshape(1, 1, -1, 2)
-
- feats = F.grid_sample(
- x,
- map_sphere,
- align_corners=False,
- mode='bilinear'
- )
- feats = feats.reshape(feats.shape[0], feats.shape[1], out_W, out_H)
- feats = feats.permute(0, 1, 3, 2)
-
- return feats
-
- def forward(self, features, pix, pix_sphere):
- x_block1, x_block2, x_block4, x_block8, x_block16, x_block32 = (
- features[0],
- features[4],
- features[5],
- features[6],
- features[8],
- features[11],
- )
- bs = x_block32.shape[0]
- x_block32 = self.conv2(x_block32)
-
-
- x_sphere_32 = self.get_sphere_feature(x_block32, pix, pix_sphere, 32)
-
- x_sphere_16 = self.get_sphere_feature(x_block16, pix, pix_sphere, 16)
-
- x_sphere_8 = self.get_sphere_feature(x_block8, pix, pix_sphere, 8)
-
- x_sphere_4 = self.get_sphere_feature(x_block4, pix, pix_sphere, 4)
-
- x_sphere_2 = self.get_sphere_feature(x_block2, pix, pix_sphere, 2)
-
- x_sphere_1 = self.get_sphere_feature(x_block1, pix, pix_sphere, 1)
-
- x_1_16 = self.up16(x_sphere_32, x_sphere_16)
- x_1_8 = self.up8(x_1_16, x_sphere_8)
- x_1_4 = self.up4(x_1_8, x_sphere_4)
- x_1_2 = self.up2(x_1_4, x_sphere_2)
- x_1_1 = self.up1(x_1_2, x_sphere_1)
-
-
-
- return {
- "1_1": x_1_1,
- "1_2": x_1_2,
- "1_4": x_1_4,
- "1_8": x_1_8,
- "1_16": x_1_16,
- }
-
-
-class Encoder(nn.Module):
- def __init__(self, backend):
- super(Encoder, self).__init__()
- self.original_model = backend
-
- def forward(self, x):
- features = [x]
- for k, v in self.original_model._modules.items():
- if k == "blocks":
- for ki, vi in v._modules.items():
- features.append(vi(features[-1]))
- else:
- features.append(v(features[-1]))
- return features
-
-
-class UNet2DSphere(nn.Module):
- def __init__(self, backend, num_features, out_feature, out_img_H, out_img_W):
- super(UNet2DSphere, self).__init__()
- self.encoder = Encoder(backend)
- self.out_img_H = out_img_H
- self.out_img_W = out_img_W
- self.decoder = DecoderSphere(
- out_feature=out_feature,
- bottleneck_features=num_features,
- num_features=num_features,
- out_img_W=out_img_W,
- out_img_H=out_img_H
- )
-
- def forward(self, x, pix, pix_sphere):
- encoded_feats = self.encoder(x)
- unet_out = self.decoder(encoded_feats, pix, pix_sphere)
- return unet_out
-
- def get_encoder_params(self):
- return self.encoder.parameters()
-
- def get_decoder_params(self):
- return self.decoder.parameters()
-
- @classmethod
- def build(cls, **kwargs):
- basemodel_name = "tf_efficientnet_b7_ns"
- num_features = 2560
-
- print("Loading base model ()...".format(basemodel_name), end="")
- basemodel = torch.hub.load(
- "rwightman/gen-efficientnet-pytorch", basemodel_name, pretrained=True
- )
- print("Done.")
-
- # Remove last layer
- print("Removing last two layers (global_pool & classifier).")
- basemodel.global_pool = nn.Identity()
- basemodel.classifier = nn.Identity()
-
- # Building Encoder-Decoder model
- print("Building Encoder-Decoder model..", end="")
- m = cls(basemodel, num_features=num_features, **kwargs)
- print("Done.")
+"""
+Code adapted from https://github.com/cv-rits/MonoScene/blob/master/monoscene/models/unet2d.py
+"""
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class BasicBlock(nn.Module):
+ def __init__(self, channel_num, dilations):
+ super(BasicBlock, self).__init__()
+
+ self.conv_block1 = nn.Sequential(
+ nn.Conv2d(channel_num, channel_num, 3,
+ padding=dilations[0], dilation=dilations[0]),
+ nn.BatchNorm2d(channel_num),
+ nn.LeakyReLU(),
+ # nn.ReLU(),
+ )
+ self.conv_block2 = nn.Sequential(
+ nn.Conv2d(channel_num, channel_num, 3,
+ padding=dilations[1], dilation=dilations[1]),
+ nn.BatchNorm2d(channel_num),
+ )
+ self.lrelu = nn.LeakyReLU()
+ # self.lrelu = nn.ReLU()
+
+ def forward(self, x):
+ residual = x
+ x = self.conv_block1(x)
+ x = self.conv_block2(x)
+ x = x + residual
+ out = self.lrelu(x)
+ return out
+
+
+class UpSampleBN(nn.Module):
+ def __init__(self, skip_input, output_features):
+ super(UpSampleBN, self).__init__()
+ self._net = nn.Sequential(
+ nn.Conv2d(skip_input, output_features,
+ kernel_size=3, stride=1, padding=1),
+ BasicBlock(output_features, dilations=[1, 1]),
+ BasicBlock(output_features, dilations=[2, 2]),
+ BasicBlock(output_features, dilations=[3, 3]),
+ )
+
+ def forward(self, x, concat_with):
+ up_x = F.interpolate(
+ x,
+ size=(concat_with.shape[2], concat_with.shape[3]),
+ mode="bilinear",
+ align_corners=True,
+ )
+ f = torch.cat([up_x, concat_with], dim=1)
+ return self._net(f)
+
+
+class DecoderSphere(nn.Module):
+ def __init__(
+ self,
+ num_features,
+ bottleneck_features,
+ out_feature,
+ out_img_W,
+ out_img_H
+ ):
+ super(DecoderSphere, self).__init__()
+
+ self.out_img_W = out_img_W
+ self.out_img_H = out_img_H
+
+ features = int(num_features)
+
+ self.conv2 = nn.Conv2d(
+ bottleneck_features, features, kernel_size=1, stride=1, padding=1
+ )
+
+ self.out_feature_1_1 = out_feature
+ self.out_feature_1_2 = out_feature
+ self.out_feature_1_4 = out_feature
+ self.out_feature_1_8 = out_feature
+ self.out_feature_1_16 = out_feature
+ self.feature_1_16 = features // 2
+ self.feature_1_8 = features // 4
+ self.feature_1_4 = features // 8
+ self.feature_1_2 = features // 16
+ self.feature_1_1 = features // 32
+
+ self.resize_1_1 = nn.Conv2d(
+ 3, self.feature_1_1, kernel_size=1
+ )
+ self.resize_1_2 = nn.Conv2d(
+ 32, self.feature_1_2, kernel_size=1
+ )
+ self.resize_1_4 = nn.Conv2d(
+ 48, self.feature_1_4, kernel_size=1
+ )
+ self.resize_1_8 = nn.Conv2d(
+ 80, self.feature_1_8, kernel_size=1
+ )
+ self.resize_1_16 = nn.Conv2d(
+ 224, self.feature_1_16, kernel_size=1
+ )
+
+ self.resize_output_1_1 = nn.Conv2d(
+ self.feature_1_1, self.out_feature_1_1, kernel_size=1
+ )
+ self.resize_output_1_2 = nn.Conv2d(
+ self.feature_1_2, self.out_feature_1_2, kernel_size=1
+ )
+ self.resize_output_1_4 = nn.Conv2d(
+ self.feature_1_4, self.out_feature_1_4, kernel_size=1
+ )
+ self.resize_output_1_8 = nn.Conv2d(
+ self.feature_1_8, self.out_feature_1_8, kernel_size=1
+ )
+ self.resize_output_1_16 = nn.Conv2d(
+ self.feature_1_16, self.out_feature_1_16, kernel_size=1
+ )
+
+ self.up16 = UpSampleBN(
+ skip_input=features + 224, output_features=self.feature_1_16
+ )
+ self.up8 = UpSampleBN(
+ skip_input=self.feature_1_16 + 80, output_features=self.feature_1_8
+ )
+ self.up4 = UpSampleBN(
+ skip_input=self.feature_1_8 + 48, output_features=self.feature_1_4,
+ )
+ self.up2 = UpSampleBN(
+ skip_input=self.feature_1_4 + 32, output_features=self.feature_1_2
+ )
+ self.up1 = UpSampleBN(
+ skip_input=self.feature_1_2 + 3, output_features=self.feature_1_1
+ )
+
+ def get_sphere_feature(self, x, pix, pix_sphere, scale):
+ out_W, out_H = round(self.out_img_W/scale), round(self.out_img_H/scale)
+ map_sphere = torch.zeros((out_W, out_H, 2)).type_as(x) - 10.0
+ pix_sphere_scale = torch.round(pix_sphere / scale).long()
+ pix_scale = pix // scale
+ pix_sphere_scale[:, 0] = pix_sphere_scale[:, 0].clamp(0, out_W-1)
+ pix_sphere_scale[:, 1] = pix_sphere_scale[:, 1].clamp(0, out_H-1)
+
+ map_sphere[pix_sphere_scale[:, 0],
+ pix_sphere_scale[:, 1], :] = pix_scale
+ map_sphere = map_sphere.reshape(-1, 2)
+
+
+ map_sphere[:, 0] /= x.shape[3]
+ map_sphere[:, 1] /= x.shape[2]
+ map_sphere = map_sphere * 2 - 1
+ map_sphere = map_sphere.reshape(1, 1, -1, 2)
+
+ feats = F.grid_sample(
+ x,
+ map_sphere,
+ align_corners=False,
+ mode='bilinear'
+ )
+ feats = feats.reshape(feats.shape[0], feats.shape[1], out_W, out_H)
+ feats = feats.permute(0, 1, 3, 2)
+
+ return feats
+
+ def forward(self, features, pix, pix_sphere):
+ x_block1, x_block2, x_block4, x_block8, x_block16, x_block32 = (
+ features[0],
+ features[4],
+ features[5],
+ features[6],
+ features[8],
+ features[11],
+ )
+ bs = x_block32.shape[0]
+ x_block32 = self.conv2(x_block32)
+
+
+ x_sphere_32 = self.get_sphere_feature(x_block32, pix, pix_sphere, 32)
+
+ x_sphere_16 = self.get_sphere_feature(x_block16, pix, pix_sphere, 16)
+
+ x_sphere_8 = self.get_sphere_feature(x_block8, pix, pix_sphere, 8)
+
+ x_sphere_4 = self.get_sphere_feature(x_block4, pix, pix_sphere, 4)
+
+ x_sphere_2 = self.get_sphere_feature(x_block2, pix, pix_sphere, 2)
+
+ x_sphere_1 = self.get_sphere_feature(x_block1, pix, pix_sphere, 1)
+
+ x_1_16 = self.up16(x_sphere_32, x_sphere_16)
+ x_1_8 = self.up8(x_1_16, x_sphere_8)
+ x_1_4 = self.up4(x_1_8, x_sphere_4)
+ x_1_2 = self.up2(x_1_4, x_sphere_2)
+ x_1_1 = self.up1(x_1_2, x_sphere_1)
+
+
+
+ return {
+ "1_1": x_1_1,
+ "1_2": x_1_2,
+ "1_4": x_1_4,
+ "1_8": x_1_8,
+ "1_16": x_1_16,
+ }
+
+
+class Encoder(nn.Module):
+ def __init__(self, backend):
+ super(Encoder, self).__init__()
+ self.original_model = backend
+
+ def forward(self, x):
+ features = [x]
+ for k, v in self.original_model._modules.items():
+ if k == "blocks":
+ for ki, vi in v._modules.items():
+ features.append(vi(features[-1]))
+ else:
+ features.append(v(features[-1]))
+ return features
+
+
+class UNet2DSphere(nn.Module):
+ def __init__(self, backend, num_features, out_feature, out_img_H, out_img_W):
+ super(UNet2DSphere, self).__init__()
+ self.encoder = Encoder(backend)
+ self.out_img_H = out_img_H
+ self.out_img_W = out_img_W
+ self.decoder = DecoderSphere(
+ out_feature=out_feature,
+ bottleneck_features=num_features,
+ num_features=num_features,
+ out_img_W=out_img_W,
+ out_img_H=out_img_H
+ )
+
+ def forward(self, x, pix, pix_sphere):
+ encoded_feats = self.encoder(x)
+ unet_out = self.decoder(encoded_feats, pix, pix_sphere)
+ return unet_out
+
+ def get_encoder_params(self):
+ return self.encoder.parameters()
+
+ def get_decoder_params(self):
+ return self.decoder.parameters()
+
+ @classmethod
+ def build(cls, **kwargs):
+ basemodel_name = "tf_efficientnet_b7_ns"
+ num_features = 2560
+
+ print("Loading base model ()...".format(basemodel_name), end="")
+ basemodel = torch.hub.load(
+ "rwightman/gen-efficientnet-pytorch", basemodel_name, pretrained=True
+ )
+ print("Done.")
+
+ # Remove last layer
+ print("Removing last two layers (global_pool & classifier).")
+ basemodel.global_pool = nn.Identity()
+ basemodel.classifier = nn.Identity()
+
+ # Building Encoder-Decoder model
+ print("Building Encoder-Decoder model..", end="")
+ m = cls(basemodel, num_features=num_features, **kwargs)
+ print("Done.")
return m
\ No newline at end of file
diff --git a/scenerf/scripts/evaluation/agg_depth_metrics_bf.py b/scenerf/scripts/evaluation/agg_depth_metrics_bf.py
index 436f007..1d840ea 100644
--- a/scenerf/scripts/evaluation/agg_depth_metrics_bf.py
+++ b/scenerf/scripts/evaluation/agg_depth_metrics_bf.py
@@ -13,15 +13,17 @@
@click.option('--n_gpus', default=1, help='number of GPUs')
@click.option('--bs', default=1, help='Batch size')
@click.option('--n_workers_per_gpu', default=3, help='number of workers per GPU')
+@click.option('--dataset', default='bf', help='bf or tum_rgbd dataset to eval on')
@click.option('--root', default="", help='path to dataset folder')
@click.option('--eval_save_dir', default="")
def main(
- root,
+ root, dataset,
bs, n_gpus, n_workers_per_gpu,
eval_save_dir
):
data_module = BundlefusionDM(
+ dataset=dataset,
root=root,
batch_size=int(bs / n_gpus),
num_workers=int(n_workers_per_gpu),
@@ -66,7 +68,7 @@ def main(
print("=================")
print("====== End ======")
print("=================")
- print(model_name)
+ # print(model_name)
print_metrics(agg_depth_errors, agg_n_frames)
diff --git a/scenerf/scripts/evaluation/eval_color_bf.py b/scenerf/scripts/evaluation/eval_color_bf.py
index 4f5256d..c4b81f0 100644
--- a/scenerf/scripts/evaluation/eval_color_bf.py
+++ b/scenerf/scripts/evaluation/eval_color_bf.py
@@ -60,9 +60,13 @@ def print_metrics(psnr_accum, ssim_accum, lpips_accum, cnt_accum):
@click.command()
@click.option('--eval_save_dir', default="")
-def main(eval_save_dir):
+@click.option('--dataset', default='bf', help='bf or tum_rgbd dataset to evaluate on')
+def main(eval_save_dir, dataset):
- sequence = "copyroom"
+ if dataset == "bf":
+ sequence = "copyroom"
+ elif dataset == "tum_rgbd":
+ sequence = "rgbd_dataset_freiburg3_long_office_household"
rgb_save_dir = os.path.join(eval_save_dir, "rgb", sequence)
render_rgb_save_dir = os.path.join(eval_save_dir, "render_rgb", sequence)
rgb_paths = glob.glob(os.path.join(rgb_save_dir, "*.png"))
diff --git a/scenerf/scripts/evaluation/eval_sc_bf.py b/scenerf/scripts/evaluation/eval_sc_bf.py
index b7bd8aa..c411c1a 100644
--- a/scenerf/scripts/evaluation/eval_sc_bf.py
+++ b/scenerf/scripts/evaluation/eval_sc_bf.py
@@ -52,12 +52,14 @@ def evaluate_depth(gt_depth, pred_depth):
@click.option('--model_name', default="", help='model name')
@click.option('--bs', default=1, help='Batch size')
@click.option('--n_workers_per_gpu', default=3, help='number of workers per GPU')
+@click.option('--dataset', default='bf', help='bf or tum_rgbd dataset to eval on')
@click.option('--root', default="/gpfsdswork/dataset/bundlefusion", help='path to dataset folder')
@click.option('--recon_save_dir')
-def main(root, bs, n_gpus, n_workers_per_gpu, model_name, recon_save_dir):
+def main(root, dataset, bs, n_gpus, n_workers_per_gpu, model_name, recon_save_dir):
data_module = BundlefusionDM(
+ dataset,
root=root,
batch_size=int(bs / n_gpus),
num_workers=int(n_workers_per_gpu),
diff --git a/scenerf/scripts/evaluation/render_colors_bf.py b/scenerf/scripts/evaluation/render_colors_bf.py
index 27e37bb..db19b1b 100644
--- a/scenerf/scripts/evaluation/render_colors_bf.py
+++ b/scenerf/scripts/evaluation/render_colors_bf.py
@@ -38,15 +38,17 @@ def disparity_normalization_vis(disparity):
@click.option('--save_depth', default=True)
@click.option('--model_path', default="", help='model path')
@click.option('--n_workers_per_gpu', default=10, help='number of workers per GPU')
+@click.option('--dataset', default='bf', help='bf or tum_rgbd dataset to eval on')
@click.option('--root', default="/gpfsdswork/dataset/bundlefusion", help='path to dataset folder')
@click.option('--eval_save_dir', default="")
def main(
- root, bs, n_gpus, n_workers_per_gpu,
+ root, dataset, bs, n_gpus, n_workers_per_gpu,
model_path, save_depth, eval_save_dir
):
torch.set_grad_enabled(False)
data_module = BundlefusionDM(
+ dataset=dataset,
root=root,
batch_size=int(bs / n_gpus),
num_workers=int(n_workers_per_gpu),
diff --git a/scenerf/scripts/evaluation/save_depth_metrics_bf.py b/scenerf/scripts/evaluation/save_depth_metrics_bf.py
index c2f2be0..392d398 100644
--- a/scenerf/scripts/evaluation/save_depth_metrics_bf.py
+++ b/scenerf/scripts/evaluation/save_depth_metrics_bf.py
@@ -35,17 +35,20 @@ def evaluate_depth(gt_depth, pred_depth):
@click.option('--n_gpus', default=1, help='number of GPUs')
@click.option('--bs', default=1, help='Batch size')
@click.option('--n_workers_per_gpu', default=3, help='number of workers per GPU')
+@click.option('--dataset', default='bf', help='bf or tum_rgbd dataset to eval on')
@click.option('--root', default="", help='path to dataset folder')
@click.option('--model_path', default="", help='model path')
@click.option('--eval_save_dir', default="")
def main(
- root,
+ root, dataset,
bs, n_gpus, n_workers_per_gpu,
model_path, eval_save_dir):
data_module = BundlefusionDM(
+ dataset=dataset,
root=root,
batch_size=int(bs / n_gpus),
num_workers=int(n_workers_per_gpu),
diff --git a/scenerf/scripts/reconstruction/depth2tsdf_bf.py b/scenerf/scripts/reconstruction/depth2tsdf_bf.py
index 0087995..4aef74c 100644
--- a/scenerf/scripts/reconstruction/depth2tsdf_bf.py
+++ b/scenerf/scripts/reconstruction/depth2tsdf_bf.py
@@ -52,14 +52,16 @@ def evaluate_depth(gt_depth, pred_depth):
@click.option('--n_gpus', default=1, help='number of GPUs')
@click.option('--bs', default=1, help='Batch size')
@click.option('--n_workers_per_gpu', default=3, help='number of workers per GPU')
+@click.option('--dataset', default='bf', help='bf or tum_rgbd dataset to eval on')
@click.option('--root', default="/gpfsdswork/dataset/bundlefusion", help='path to dataset folder')
@click.option('--recon_save_dir', default="")
@click.option('--angle', default=30)
@click.option('--step', default=0.2)
@click.option('--max_distance', default=2.1, help='max pose sample distance')
-def main(root, bs, n_gpus, n_workers_per_gpu, recon_save_dir, max_distance, step, angle):
+def main(root, dataset, bs, n_gpus, n_workers_per_gpu, recon_save_dir, max_distance, step, angle):
data_module = BundlefusionDM(
+ dataset,
root=root,
batch_size=int(bs / n_gpus),
num_workers=int(n_workers_per_gpu),
diff --git a/scenerf/scripts/reconstruction/generate_novel_depths_bf.py b/scenerf/scripts/reconstruction/generate_novel_depths_bf.py
index b86f7a0..8172db5 100644
--- a/scenerf/scripts/reconstruction/generate_novel_depths_bf.py
+++ b/scenerf/scripts/reconstruction/generate_novel_depths_bf.py
@@ -34,17 +34,19 @@ def disparity_normalization_vis(disparity):
@click.option('--n_gpus', default=1, help='number of GPUs')
@click.option('--bs', default=1, help='Batch size')
@click.option('--n_workers_per_gpu', default=10, help='number of workers per GPU')
+@click.option('--dataset', default='bf', help='bf or tum_rgbd dataset to eval on')
@click.option('--root', default="/gpfsdswork/dataset/bundlefusion", help='path to dataset folder')
@click.option('--model_path', default="", help='model path')
@click.option('--recon_save_dir', default="")
@click.option('--angle', default=30)
@click.option('--step', default=0.2)
@click.option('--max_distance', default=2.1, help='max pose sample distance')
-def main(root, bs, n_gpus, n_workers_per_gpu, model_path,
+def main(root, dataset, bs, n_gpus, n_workers_per_gpu, model_path,
recon_save_dir, max_distance, step, angle):
torch.set_grad_enabled(False)
data_module = BundlefusionDM(
+ dataset=dataset,
root=root,
batch_size=int(bs / n_gpus),
num_workers=int(n_workers_per_gpu),
diff --git a/scenerf/scripts/reconstruction/generate_sc_gt_bf.py b/scenerf/scripts/reconstruction/generate_sc_gt_bf.py
index 2fe4cbd..2393d68 100644
--- a/scenerf/scripts/reconstruction/generate_sc_gt_bf.py
+++ b/scenerf/scripts/reconstruction/generate_sc_gt_bf.py
@@ -15,11 +15,13 @@
@click.option('--n_gpus', default=1, help='number of GPUs')
@click.option('--bs', default=1, help='Batch size')
@click.option('--n_workers_per_gpu', default=3, help='number of workers per GPU')
+@click.option('--dataset', default='bf', help='bf or tum_rgbd dataset to eval on')
@click.option('--root', default="", help='path to dataset folder')
@click.option('--recon_save_dir')
-def main(root, bs, n_gpus, n_workers_per_gpu, recon_save_dir):
+def main(root, dataset, bs, n_gpus, n_workers_per_gpu, recon_save_dir):
data_module = BundlefusionDM(
+ dataset,
root=root,
batch_size=int(bs / n_gpus),
num_workers=int(n_workers_per_gpu),
diff --git a/scenerf/scripts/train_bundlefusion.py b/scenerf/scripts/train_bundlefusion.py
index 7efdc22..178b38c 100644
--- a/scenerf/scripts/train_bundlefusion.py
+++ b/scenerf/scripts/train_bundlefusion.py
@@ -19,6 +19,7 @@
@click.command()
+@click.option('--dataset', default='bf', help='bf or tum_rgbd dataset to train on')
@click.option('--logdir', default='', help='log directory')
@click.option('--root', default='', help='path to dataset folder')
@click.option('--bs', default=1, help='Batch size')
@@ -39,6 +40,7 @@
@click.option('--n_pts_per_gaussian', default=8, help='number of points sampled for each gaussian')
@click.option('--n_gaussians', default=4, help='number of gaussians')
@click.option('--n_pts_uni', default=32, help='number of points sampled uniformly')
+@click.option('--n_pts_hier', default=32, help='number of points sampled hierarchically')
@click.option('--std', default=0.1, help='std of each gaussian')
@click.option('--add_fov_hor', default=14, help='angle added to left and right of the horizontal FOV')
@@ -50,7 +52,7 @@
@click.option('--som_sigma', default=0.02, help='sigma parameter for SOM')
@click.option('--net_2d', default="b7", help='')
-@click.option('--max_epochs', default=50, help='max training epochs')
+@click.option('--max_epochs', default=30, help='max training epochs')
@click.option('--use_color', default=True, help='use color loss')
@click.option('--use_reprojection', default=True, help='use reprojection loss')
@@ -58,7 +60,7 @@
@click.option('--frame_interval', default=2, help='interval between frames in a sequence')
def main(
- root,
+ dataset, root,
bs, n_gpus, n_workers_per_gpu,
exp_prefix, pretrained_exp_name,
logdir, enable_log,
@@ -66,7 +68,7 @@ def main(
n_rays, sample_grid_size,
smooth_loss_weight,
max_sample_depth, eval_depth,
- n_pts_uni,
+ n_pts_uni, n_pts_hier,
n_pts_per_gaussian, n_gaussians, std, som_sigma,
add_fov_hor, add_fov_ver,
use_color, use_reprojection,
@@ -92,6 +94,7 @@ def main(
# max_epochs = 20
data_module = BundlefusionDM(
+ dataset=dataset,
root=root,
batch_size=int(bs / n_gpus),
num_workers=int(n_workers_per_gpu),
@@ -106,6 +109,7 @@ def main(
model = SceneRF(
lr=lr,
n_pts_uni=n_pts_uni,
+ n_pts_hier=n_pts_hier,
weight_decay=wd,
n_rays=n_rays,
smooth_loss_weight=smooth_loss_weight,
diff --git a/train_eval_bash_scripts/eval_bundlefusion_scaled_down.sh b/train_eval_bash_scripts/eval_bundlefusion_scaled_down.sh
new file mode 100755
index 0000000..f9e8c2d
--- /dev/null
+++ b/train_eval_bash_scripts/eval_bundlefusion_scaled_down.sh
@@ -0,0 +1,55 @@
+#!/bin/bash
+
+# Set environment variables
+export DATASET=tum_rgbd
+export BF_ROOT=/root/dataset/tum_rgbd
+export BF_LOG=/root/SceneRFGroupProject/logs/tum_rgbd
+export EVAL_SAVE_DIR=/root/SceneRFGroupProject/evaluation/tum_rgbd/eval
+export RECON_SAVE_DIR=/root/SceneRFGroupProject/evaluation/tum_rgbd/recon
+export MODEL_PATH=/root/SceneRFGroupProject/logs/tum_rgbd/vanilla_exp/vanilla_tum_last.ckpt
+
+# Novel depths synthesis on the selected dataset ($DATASET)
+python scenerf/scripts/evaluation/save_depth_metrics_bf.py \
+ --eval_save_dir=$EVAL_SAVE_DIR \
+ --dataset=$DATASET \
+ --root=$BF_ROOT \
+ --model_path=$MODEL_PATH
+
+python scenerf/scripts/evaluation/agg_depth_metrics_bf.py \
+ --eval_save_dir=$EVAL_SAVE_DIR \
+ --dataset=$DATASET \
+ --root=$BF_ROOT
+
+
+# Novel views synthesis on the selected dataset ($DATASET)
+python scenerf/scripts/evaluation/render_colors_bf.py \
+ --eval_save_dir=$EVAL_SAVE_DIR \
+ --dataset=$DATASET \
+ --root=$BF_ROOT \
+ --model_path=$MODEL_PATH
+
+python scenerf/scripts/evaluation/eval_color_bf.py --eval_save_dir=$EVAL_SAVE_DIR --dataset=$DATASET
+
+# Scene reconstruction on the selected dataset ($DATASET)
+python scenerf/scripts/reconstruction/generate_novel_depths_bf.py \
+ --recon_save_dir=$RECON_SAVE_DIR \
+ --dataset=$DATASET \
+ --root=$BF_ROOT \
+ --model_path=$MODEL_PATH \
+ --angle=30 --step=0.2 --max_distance=2.1
+
+python scenerf/scripts/reconstruction/depth2tsdf_bf.py \
+ --recon_save_dir=$RECON_SAVE_DIR \
+ --dataset=$DATASET \
+ --root=$BF_ROOT \
+ --angle=30 --step=0.2 --max_distance=2.1
+
+python scenerf/scripts/reconstruction/generate_sc_gt_bf.py \
+ --recon_save_dir=$RECON_SAVE_DIR \
+ --dataset=$DATASET \
+ --root=$BF_ROOT
+
+python scenerf/scripts/evaluation/eval_sc_bf.py \
+ --recon_save_dir=$RECON_SAVE_DIR \
+ --dataset=$DATASET \
+ --root=$BF_ROOT
\ No newline at end of file
diff --git a/train_eval_bash_scripts/train_bundlefusion_scaled_down.sh b/train_eval_bash_scripts/train_bundlefusion_scaled_down.sh
new file mode 100755
index 0000000..e3c2048
--- /dev/null
+++ b/train_eval_bash_scripts/train_bundlefusion_scaled_down.sh
@@ -0,0 +1,28 @@
+#!/bin/bash
+
+# Set environment variables
+export DATASET=tum_rgbd
+export BF_ROOT=/root/dataset/tum_rgbd
+export BF_LOG=/root/SceneRFGroupProject/logs/tum_rgbd
+
+# Run the training script
+python scenerf/scripts/train_bundlefusion.py --bs=1 --n_gpus=1 --n_workers_per_gpu=4 \
+ --n_rays=1024 \
+ --lr=2e-5 \
+ --enable_log=True \
+ --dataset=$DATASET \
+ --root=$BF_ROOT \
+ --logdir=$BF_LOG \
+ --sample_grid_size=2 \
+ --n_gaussians=2 \
+ --n_pts_per_gaussian=4 \
+ --n_pts_uni=8 \
+ --n_pts_hier=8 \
+ --add_fov_hor=7 \
+ --add_fov_ver=5 \
+ --sphere_h=480 \
+ --sphere_w=640 \
+ --max_sample_depth=8 \
+ --n_frames=16 \
+ --frame_interval=2 \
+ --max_epochs=30