Allow for >1 batch size in Splatfacto #3582
Conversation
Hey Alex! This is super cool, especially in MCMC, which doesn't require gradient thresholds at all. #3216 might have mildly broken parts of this PR since it merged in parallel dataloading, but it shouldn't be too bad; let us know if you want any help fixing conflicts!
@akristoffersen I think you might want to modify
Force-pushed from 2d95d1e to cbb5ceb
Works with masks now. As expected, I noticed an almost 2x increase in rays/s with a batch size of two, and a very slight performance drop with a batch size of 1 compared to baseline (50.1 M rays/sec -> 48 M rays/sec).
@hardikdava do you mean that the tuning might be different for the thresholds? Yeah, I don't know exactly what to do there; maybe someone else has an opinion?

Some quick stats on the poster dataset: the splitting / densification outcomes are affected by batch size. Similarly, train rays/sec starts higher due to the larger batch size, but goes down as you'd expect with the higher number of gaussians. Some good news: with a higher batch I do see the training loss hitting better values quicker as the batch size increases.
@akristoffersen currently, densification, splitting, and culling are implemented inside the strategy, and the logic is based on the step count. In simple words, suppose the batch size is 2 and opacity reset needs to be applied at every 3000th step. According to the batch size, it should then happen at every 1500th step. But according to your current implementation it will be applied at every 3000th step, which is actually the 6000th step (batch size * step).
@hardikdava I think dividing those parameters by the batch size assumes that every image produces gradients for a unique set of gaussians. If there's any overlap, then the gaussians seen by 2 images would just be getting a single gradient descent update applied to them (albeit of possibly better quality because of the signal from both images), while if it were a single-image batch those gaussians would have gotten 2 gradient descent updates applied to them. I think that dividing those params by the batch size could still be a good approximation; I'll try it and see how the losses look.
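To make that approximation concrete, here is a minimal sketch of dividing the step-based strategy intervals by the batch size. The parameter names follow gsplat's DefaultStrategy but should be treated as placeholders; this is an illustration, not the PR's implementation.

```python
def scale_strategy_intervals(strategy_kwargs: dict, batch_size: int) -> dict:
    """Divide step-count hyperparameters by the batch size so roughly the same number
    of *images* elapse between refine/reset events. Only an approximation: it assumes
    the images in a batch touch mostly disjoint sets of gaussians."""
    scaled = dict(strategy_kwargs)
    for key in ("refine_every", "reset_every", "refine_start_iter", "refine_stop_iter"):
        if key in scaled:
            scaled[key] = max(1, scaled[key] // batch_size)
    return scaled

# e.g. with batch_size=2, an opacity reset every 3000 steps becomes every 1500 steps,
# i.e. roughly the same ~3000 images seen between resets.
print(scale_strategy_intervals({"reset_every": 3000, "refine_every": 100}, batch_size=2))
```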
I tested these changes on 2 of my datasets with the following commands:
- ns-train
- ns-render
- ns-eval
- ns-export
These all worked well! @jeffreyhuparallel do you have any comments about the hyperparameter strategy stuff?
How would batch size work with a dataset of several thousand images? With opacity reset in mind etc... :) I suspect that scenes with lots of images would inherently benefit from a larger batch size because more images are part of the training per step. How is the memory use? Are we able to train with, say, 20, 50 or 100 images per step? How would that affect training speed and memory use? Is it linear?
All good questions @abrahamezzeddine , unfortunately I haven't had cycles as of late to finish up the implementation and do the necessary benchmarking. @AntonioMacaronio has been helping on that front. On the large dataset question, I think you're right. Something that has always bothered me about gs is that a batch isn't representative of the full objective-- with NeRFs the batch is a random collection of rays from all images so that helps. I do suspect that at the moment, splatfacto is memory bound, so increasing the batch size may not improve things as you'd expect. But I think it should make the scene converge better / faster.
Thanks for the quick response! Just food for thought: once we reach a satisfactory loss at the initial stage, an upsampling session starts and splatfacto progressively reduces the batch size while increasing the image resolution, iteratively refining until we achieve the desired quality at full resolution. Essentially, we start with the maximum batch size on lower-quality images and gradually trade batch size for higher resolution during training (down to the desired batch size). Do you think this would help complex scenes converge initially as we upsample and reduce the batch size?
We sort of already do this-- we initially train on downsampled images and then increase the resolution as training continues. But inversely scaling the batch size as training goes on also sounds like a good idea. Something I've also wanted to try (probably in a separate PR) is to load in a large batch of patches, so the training acts more like NeRFs. You could do this by keeping the focal lengths the same, but augmenting the principal points for each patch.
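For reference, a minimal sketch of that patch idea under the assumption of a simple pinhole model: cropping a window out of an image leaves the focal lengths unchanged and only shifts the principal point by the crop origin. This is just an illustration, not part of the PR.

```python
import torch

def sample_patch(image: torch.Tensor, cx: float, cy: float, patch: int):
    """Crop a random square patch from an HxWx3 image and return the adjusted
    principal point. fx/fy stay the same; cx/cy shift by the crop origin."""
    H, W, _ = image.shape
    y0 = int(torch.randint(0, H - patch + 1, (1,)))
    x0 = int(torch.randint(0, W - patch + 1, (1,)))
    return image[y0 : y0 + patch, x0 : x0 + patch], cx - x0, cy - y0

# A batch could then hold many such patches from different cameras, so one optimizer
# step sees rays from many viewpoints, closer to NeRF-style ray batching.
```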
Thanks. Dynamic batching is indeed an interesting possibility. Two thoughts: 💭

How would batching process images, randomly or sequentially? One option is to simply process images in the order they were captured—for example, taking the first n images as one batch, then the next n, and so on. The idea here is that sequential ordering might naturally preserve temporal continuity, so adjacent frames (which are likely to have similar viewpoints) get processed together. But that can be difficult to know, depending on how the user matched the images: exhaustive, sequential, or vocabulary tree.

Another approach is to order the sparse point cloud along a Hilbert curve. Since a Hilbert curve is a space-filling curve that preserves locality, it essentially divides your scene into "patches" of points that are spatially close together. For instance, if you select 10,000 consecutive points from this 1D Hilbert index, you're effectively picking a coherent patch of the scene. If you divide the points into groups of n, you essentially create patches of local regions. You can then choose the images that see these points for your batch, based on the COLMAP input data. Since the images are already ordered according to the Hilbert curve, it's easy to keep track of which images belong to which patch. This strategy explicitly enforces spatial coherence, ensuring that each batch is focused on a local region of the scene as it is training.

Would love to hear your thoughts about this.
You might want to check out https://arxiv.org/abs/2501.13975; they have a similar "locality" heuristic that they use to pull in multiple images seeing the same region of the scene. They say that this helps prevent overshoot/overfitting to a single image, which could happen as they are using a second-order optimization algorithm. My take is that with a suitably large and diverse batch, this might not be a problem? But I agree that with smaller batches, a local neighborhood might work out better. I imagine the batch-building heuristic doesn't have to be super complicated to get the behavior you'd want.
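As one concrete flavor of such a heuristic (not from this PR or the linked paper), a batch could be built from images that share many SfM points with an anchor image. The `image_points` mapping below is a hypothetical input you would derive from the COLMAP reconstruction.

```python
from collections import Counter

def build_local_batch(anchor_id: int, image_points: dict, batch_size: int) -> list:
    """Rank the other images by how many 3D points they share with the anchor
    (covisibility) and fill the batch with the most covisible ones."""
    anchor_pts = image_points[anchor_id]
    covis = Counter(
        {img: len(anchor_pts & pts) for img, pts in image_points.items() if img != anchor_id}
    )
    neighbors = [img for img, _ in covis.most_common(batch_size - 1)]
    return [anchor_id] + neighbors

# Toy example: images 1 and 2 see mostly the same points, image 3 sees a different region.
pts = {1: {10, 11, 12, 13}, 2: {11, 12, 13, 14}, 3: {20, 21}}
print(build_local_batch(anchor_id=1, image_points=pts, batch_size=2))  # [1, 2]
```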
Trying this out now myself, and it seems to converge much faster initially with a batch size of 50. Using 2K resolution images and around 2500 images; 18 GB of 48 GB VRAM is used. Step 750 (2.50%), train iter 699.992 ms, ETA 5 h 41 m 14 s, 494.09 M train rays/sec. Not the fastest, but as long as it produces a high quality output, it's fine I guess. =)
I am not seeing the linear increase in rays/s with larger batch sizes. Is there a diminishing effect after a certain batch?
Yes, please see the initial wandb results in an earlier comment. Initially the ray throughput scales, but I think because the splitting behavior currently assumes a single image per batch, we are getting many more gaussians with higher batch sizes.
Ok, thanks. Regarding the learning rates, should one consider square-root batch scaling due to the larger batch size? Bilagrid was also not working, but I made these changes to have it work again with batched training. Not sure, however, if this is "compatible" with bilagrid.
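For reference, the square-root rule mentioned above would look something like the sketch below. Whether it is the right rule for Adam-style optimizers in splatfacto is exactly the open question here, so the numbers are illustrative only.

```python
import math

def scale_lr(base_lr: float, batch_size: int, base_batch_size: int = 1) -> float:
    """Square-root batch scaling: growing the batch by a factor k multiplies the
    learning rate by sqrt(k), rather than by k as in linear scaling."""
    return base_lr * math.sqrt(batch_size / base_batch_size)

print(scale_lr(1.6e-4, batch_size=8))  # ~4.5e-4, using an illustrative base LR
```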
Recently I tested this. Why are the metrics (PSNR, SSIM, LPIPS) fluctuating up and down for the blue line (Nerfstudio commit "fix alpha compositing"), while the pink line (Nerfstudio commit 194b5d4) is more stable?
Found another bug: batch-size > 1 is not working with a multi-camera setup (for example, when the image resolutions are not the same: some images are landscape and others are portrait). You can test my dataset here: https://drive.google.com/file/d/1NWZSDU9tEmrAtpKxntTw6YBge_AZ66mf/view?usp=sharing

And yeah, this code works, but the TV_loss is still zero when I activated bilagrid.
@ichsan2895 thank you for the beautiful testing! The batching with cameras of different resolutions is concerning, and I suspect this is something that just can't be supported until PyTorch supports jagged tensors. Afaik, this is called NestedTensor and it's currently in beta, but it will likely be some time before it is fully supported. Perhaps the current best solution is to just not allow batching when images of varying resolutions are given.
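For context, a minimal example of the beta nested-tensor API being referred to; op coverage is still limited, so this is a possible future direction rather than something the datamanager can rely on today.

```python
import torch

# Two images with different resolutions held in one "jagged" batch object.
landscape = torch.rand(1080, 1920, 3)
portrait = torch.rand(1920, 1080, 3)
batch = torch.nested.nested_tensor([landscape, portrait])

for image in batch.unbind():  # unbind() recovers the individual, differently-sized tensors
    print(image.shape)
```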
I can also mention that camera optim does not work with batch size over 1. With a few modifications, I made it work again.
How do you activate the train validation loss in the console log output? Maybe I can check and see what I find.
Preliminary Benchmark

Just benchmarking the Mip360 dataset with various values of batch-size. This time, for each scene, I only ran 1000 steps; maybe when I have more free time, I will set it to 30k steps. FYI, mip-360's downscaled images are not compatible with nerfstudio, since nerfstudio downscales with floor rounding. So I resized them manually; see #1438 for discussion.

ns-train splatfacto --pipeline.datamanager.batch-size 1, Nerfstudio commit d5bdd45, Python 3.10, RTX4090
ns-train splatfacto --pipeline.datamanager.batch-size 2, Nerfstudio commit d5bdd45, Python 3.10, RTX4090
ns-train splatfacto --pipeline.datamanager.batch-size 3, Nerfstudio commit d5bdd45, Python 3.10, RTX4090
SPLATFACTO-BIG

ns-train splatfacto-big --pipeline.datamanager.batch-size 1, Nerfstudio commit d5bdd45, Python 3.10, RTX4090
ns-train splatfacto-big --pipeline.datamanager.batch-size 2, Nerfstudio commit d5bdd45, Python 3.10, RTX4090
ns-train splatfacto-big --pipeline.datamanager.batch-size 3, Nerfstudio commit d5bdd45, Python 3.10, RTX4090
MCMC

ns-train splatfacto-mcmc --pipeline.datamanager.batch-size 1, Nerfstudio commit d5bdd45, Python 3.10, RTX4090
ns-train splatfacto-mcmc --pipeline.datamanager.batch-size 2, Nerfstudio commit d5bdd45, Python 3.10, RTX4090
ns-train splatfacto-mcmc --pipeline.datamanager.batch-size 2 --pipeline.datamanager.train-cameras-sampling-strategy fps, Nerfstudio commit d5bdd45, Python 3.10, RTX4090
@abrahamezzeddine I use wandb for logging.
Sorry for the neglect of this PR, all. @ichsan2895, thank you so much for the benchmarking, it's a real help. I will try to crush the bugs re: bilagrid this weekend. Sorry again for the delay here. Regarding multi-res camera support, I think I'm okay limiting that ability at the moment, though if I go through with the "large patch sampling" technique described above, that limit can go away.
Force-pushed from d5bdd45 to 7c7b859
Bilagrid + RGBA dataset does not work:

>> ns-train splatfacto --vis viewer+wandb \
--pipeline.model.use-bilateral-grid True --pipeline.model.color-corrected-metrics True \
--pipeline.datamanager.batch-size 2 \
nerfstudio-data \
--data path/to/scene --downscale-factor 1 .
.
.
Step (% Done) Train Iter (time) ETA (time)
--------------------------------------------------------------
0 (0.00%) 1 m, 7 s 23 d, 6 h, 36 m, 54 s
----------------------------------------------------------------------------------------------------
Viewer running locally at: http://localhost:7007/ (listening on 0.0.0.0)
[06:48:33] Caching / undistorting eval images                      full_images_datamanager.py:241
Caching / undistorting eval images ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:06
Printing profiling stats, from longest to shortest duration in seconds
VanillaPipeline.get_eval_image_metrics_and_images: 7.3795
Trainer.train_iteration: 0.7054
VanillaPipeline.get_train_loss_dict: 0.6988
Trainer.eval_iteration: 0.0731
Traceback (most recent call last):
File "/usr/local/bin/ns-train", line 8, in <module>
sys.exit(entrypoint())
File "/workspace/NERFSTUDIO_v115a2/nerfstudio/nerfstudio/scripts/train.py", line 272, in entrypoint
main(
File "/workspace/NERFSTUDIO_v115a2/nerfstudio/nerfstudio/scripts/train.py", line 257, in main
launch(
File "/workspace/NERFSTUDIO_v115a2/nerfstudio/nerfstudio/scripts/train.py", line 190, in launch
main_func(local_rank=0, world_size=world_size, config=config)
File "/workspace/NERFSTUDIO_v115a2/nerfstudio/nerfstudio/scripts/train.py", line 101, in train_loop
trainer.train()
File "/workspace/NERFSTUDIO_v115a2/nerfstudio/nerfstudio/engine/trainer.py", line 304, in train
self.eval_iteration(step)
File "/workspace/NERFSTUDIO_v115a2/nerfstudio/nerfstudio/utils/decorators.py", line 71, in wrapper
ret = func(self, *args, **kwargs)
File "/workspace/NERFSTUDIO_v115a2/nerfstudio/nerfstudio/utils/profiler.py", line 111, in inner
out = func(*args, **kwargs)
File "/workspace/NERFSTUDIO_v115a2/nerfstudio/nerfstudio/engine/trainer.py", line 551, in eval_iteration
metrics_dict, images_dict = self.pipeline.get_eval_image_metrics_and_images(step=step)
File "/workspace/NERFSTUDIO_v115a2/nerfstudio/nerfstudio/utils/profiler.py", line 111, in inner
out = func(*args, **kwargs)
File "/workspace/NERFSTUDIO_v115a2/nerfstudio/nerfstudio/pipelines/base_pipeline.py", line 339, in get_eval_image_metrics_and_images
metrics_dict, images_dict = self.model.get_image_metrics_and_images(outputs, batch)
File "/workspace/NERFSTUDIO_v115a2/nerfstudio/nerfstudio/models/splatfacto.py", line 763, in get_image_metrics_and_images
combined_rgb = torch.cat([gt_rgb, predicted_rgb], dim=1)
RuntimeError: Tensors must have same number of dimensions: got 4 and 3
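The traceback suggests the eval path still receives a batched (4-D) ground-truth image while the rendered output is a single 3-D image. A hedged sketch of a guard that would avoid the crash is shown below; the names mirror the traceback, and this is not the PR's actual fix.

```python
import torch

def match_eval_dims(gt_rgb: torch.Tensor, predicted_rgb: torch.Tensor) -> torch.Tensor:
    """Drop a leftover batch dimension from the ground truth before building the
    side-by-side eval image, since eval renders a single camera at a time."""
    if gt_rgb.dim() == 4 and predicted_rgb.dim() == 3:
        gt_rgb = gt_rgb[0]  # or gt_rgb.squeeze(0) when the leading dim is 1
    return torch.cat([gt_rgb, predicted_rgb], dim=1)
```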
Another error: images and masks do not work well.

>> ns-train splatfacto --vis viewer+wandb \
--pipeline.model.use-bilateral-grid True --pipeline.model.color-corrected-metrics True \
--pipeline.datamanager.batch-size 2 \
colmap \
--data path/to/scene --downscale-factor 1 --colmap-path "sparse/0" \
--images-path "images" --masks-path "masks"

[08:31:21] Caching / undistorting train images                     full_images_datamanager.py:241
Caching / undistorting train images ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:34
/usr/local/lib/python3.10/dist-packages/torch/_inductor/compile_fx.py:135: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
warnings.warn(
Printing profiling stats, from longest to shortest duration in seconds
Trainer.train_iteration: 47.2053
VanillaPipeline.get_train_loss_dict: 47.2035
Traceback (most recent call last):
File "/usr/local/bin/ns-train", line 8, in <module>
sys.exit(entrypoint())
File "/workspace/NERFSTUDIO_v115a2/nerfstudio/nerfstudio/scripts/train.py", line 272, in entrypoint
main(
File "/workspace/NERFSTUDIO_v115a2/nerfstudio/nerfstudio/scripts/train.py", line 257, in main
launch(
File "/workspace/NERFSTUDIO_v115a2/nerfstudio/nerfstudio/scripts/train.py", line 190, in launch
main_func(local_rank=0, world_size=world_size, config=config)
File "/workspace/NERFSTUDIO_v115a2/nerfstudio/nerfstudio/scripts/train.py", line 101, in train_loop
trainer.train()
File "/workspace/NERFSTUDIO_v115a2/nerfstudio/nerfstudio/engine/trainer.py", line 266, in train
loss, loss_dict, metrics_dict = self.train_iteration(step)
File "/workspace/NERFSTUDIO_v115a2/nerfstudio/nerfstudio/utils/profiler.py", line 111, in inner
out = func(*args, **kwargs)
File "/workspace/NERFSTUDIO_v115a2/nerfstudio/nerfstudio/engine/trainer.py", line 502, in train_iteration
_, loss_dict, metrics_dict = self.pipeline.get_train_loss_dict(step=step)
File "/workspace/NERFSTUDIO_v115a2/nerfstudio/nerfstudio/utils/profiler.py", line 111, in inner
out = func(*args, **kwargs)
File "/workspace/NERFSTUDIO_v115a2/nerfstudio/nerfstudio/pipelines/base_pipeline.py", line 301, in get_train_loss_dict
loss_dict = self.model.get_loss_dict(model_outputs, batch, metrics_dict)
File "/workspace/NERFSTUDIO_v115a2/nerfstudio/nerfstudio/models/splatfacto.py", line 687, in get_loss_dict
mask = self._downscale_if_required(batch["mask"])
File "/workspace/NERFSTUDIO_v115a2/nerfstudio/nerfstudio/models/splatfacto.py", line 452, in _downscale_if_required
return resize_image(image, d)
File "/workspace/NERFSTUDIO_v115a2/nerfstudio/nerfstudio/models/splatfacto.py", line 66, in resize_image
downscaled = tf.conv2d(image, weight, stride=d)
RuntimeError: Input type (CUDABoolType) and weight type (torch.cuda.FloatTensor) should be the same
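This one looks like the boolean mask being fed to the float conv2d used for average-pool downscaling. A hedged workaround sketch, assuming an HxWx1 boolean mask and not the PR's actual fix, would cast to float before the convolution and re-binarize afterwards.

```python
import torch
import torch.nn.functional as F

def resize_bool_mask(mask: torch.Tensor, d: int) -> torch.Tensor:
    """Downscale an HxWx1 boolean mask by a factor d via average pooling, then
    threshold back to bool. conv2d rejects bool inputs, hence the float round-trip."""
    weight = torch.ones(1, 1, d, d, device=mask.device) / (d * d)
    m = mask.float().permute(2, 0, 1).unsqueeze(0)           # HWC -> NCHW
    downscaled = F.conv2d(m, weight, stride=d)
    return downscaled.squeeze(0).permute(1, 2, 0) > 0.5      # NCHW -> HWC bool
```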
I have tried padding the images that do not have the same resolution as the first image in the batch with an alpha channel. It worked, but the result is not good.

Now I have an idea, @akristoffersen, for using multi-cameras with a batch: only stack the images that have the same resolution as the first image of the batch. If there are no other images with a matching resolution in the batch, return the first image itself. To preserve the batch shape, I create dummy clones of the first image in the batch. Add this code:

import torch

def stacked_batches(batch, dim=0, out=None):
    if not batch:
        raise ValueError("Batch cannot be empty")
    # Reference size from the first tensor
    ref_h, ref_w, ref_c = batch[0].shape
    # Collect tensors that match the reference size
    matching = [tensor for tensor in batch if tensor.shape == (ref_h, ref_w, ref_c)]
    if not matching:
        raise ValueError("No tensors with matching resolution found")
    # Create output list, starting with matching tensors
    result = matching.copy()
    # Fill remaining slots with duplicates of the first tensor to match the original batch length
    while len(result) < len(batch):
        result.append(batch[0])
    # Stack all tensors
    return torch.stack(result, dim=dim, out=out)

Then, inside the tensor branch of the collate function (where elem is batch[0]), use stacked_batches in place of the plain torch.stack:

if isinstance(elem, torch.Tensor):
    out = None
    if torch.utils.data.get_worker_info() is not None:
        # If in a background process, use shared memory
        numel = sum(x.numel() for x in batch)
        storage = elem.untyped_storage()._new_shared(numel, device=str(elem.device))
        out = elem.new(storage).resize_(len(batch), *list(elem.size()))
    return stacked_batches(batch, 0, out=out)
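For what it's worth, a quick sanity check of the helper above (shapes are arbitrary): the portrait image is dropped and replaced by a duplicate of the first landscape image, so the stacked batch keeps its expected shape.

```python
import torch

images = [torch.rand(1080, 1920, 3), torch.rand(1920, 1080, 3)]  # landscape + portrait
stacked = stacked_batches(images)
print(stacked.shape)  # torch.Size([2, 1080, 1920, 3])
```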
WIP, preliminary testing makes it look like it's working but I would want to make sure.