This codebase implements HyperNeRF using JAX,
building on JaxNeRF.
Demo
We provide an easy-to-get-started demo using Google Colab!
These Colabs will allow you to train a basic version of our method using
Cloud TPUs (or GPUs) on Google Colab.
Note that, due to the limited compute resources available, these are not the fully
featured models: they will train quite slowly, and the quality will likely not be that great.
If you would like to train a fully featured model, please refer to the instructions below
on how to train on your own machine.
The Colabs cover three steps: processing a video into a dataset, training HyperNeRF, and rendering HyperNeRF videos.
Setup
The code can be run under any environment with Python 3.8 and above.
(It may run with lower versions, but we have not tested it.)
We recommend using Miniconda and setting up an environment:
If you find our work useful, please consider citing:
@article{park2021hypernerf,
author = {Park, Keunhong and Sinha, Utkarsh and Hedman, Peter and Barron, Jonathan T. and Bouaziz, Sofien and Goldman, Dan B and Martin-Brualla, Ricardo and Seitz, Steven M.},
title = {HyperNeRF: A Higher-Dimensional Representation for Topologically Varying Neural Radiance Fields},
journal = {ACM Trans. Graph.},
issue_date = {December 2021},
publisher = {ACM},
volume = {40},
number = {6},
month = {dec},
year = {2021},
articleno = {238},
}
Thanks for your great work. I have an issue: using the same hyperparameters as the config in the code, I cannot reach 0.156 LPIPS on interp_chickchicken; my LPIPS is much higher than 0.156. @keunhong
I was wondering if you could provide some rules of thumb on how to capture a good selfie video so that HyperNeRF can converge to a good optimum. (My first attempt wasn't very successful, despite calibrating my camera extrinsics with COLMAP using only the background.)
Tips from others who have successfully trained on their own selfie videos are welcome!
Hello, I have a small question. Does the observation space refer to the camera space or the space after the transformation of the camera pose matrix (world space in NeRF)?
To experiment on your dataset with a PyTorch framework,
I'm first trying to reimplement the Nerfies data loader (HyperNeRF data loader) in PyTorch.
However, I have some questions about the camera parameters and scene parameters in the dataset.
What do scale, scene_to_metric, and bbox in the scene.json file mean?
What is the difference between the scale parameter above and the 'scale' of the images, such as 4x? (scale of the image vs. scale of the scene)
[ scene_scale is applied ]
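A minimal numpy sketch of the distinction between the two scales, under two assumptions: that scene normalization recenters world coordinates by the scene.json center and multiplies them by scale, and that a downsampled image (e.g. 4x) only changes the intrinsics. The helpers project, to_scene, and downsample_intrinsics are hypothetical, not functions from the repo, and scene_to_metric and bbox are not modeled. Under these assumptions the scene transform leaves pixel coordinates unchanged (the scale cancels in the perspective division), while image scaling leaves the pose unchanged and divides pixel coordinates by the factor.

```python
import numpy as np


def project(point, orientation, position, focal, principal_point):
    """Pinhole projection; `orientation` is assumed to be a world-to-camera rotation."""
    local = orientation @ (np.asarray(point, dtype=float) - position)
    return focal * local[:2] / local[2] + principal_point


def to_scene(x, center, scale):
    """Assumed scene normalization: recenter world coordinates, then scale them."""
    return (np.asarray(x, dtype=float) - center) * scale


def downsample_intrinsics(focal, principal_point, factor):
    """Image scaling (e.g. a 4x downsampled image): only the intrinsics change."""
    return focal / factor, np.asarray(principal_point, dtype=float) / factor


orientation = np.eye(3)
position = np.array([0.0, 0.0, -4.0])
point = np.array([0.5, -0.25, 1.0])
focal, pp = 1000.0, np.array([640.0, 360.0])
center, scale = np.array([0.1, 0.2, 0.3]), 0.25

pixel = project(point, orientation, position, focal, pp)
# Scene scale: transform both the camera position and the point; the pixel is unchanged.
pixel_scene = project(to_scene(point, center, scale), orientation,
                      to_scene(position, center, scale), focal, pp)
# Image scale: the pose is unchanged; pixel coordinates shrink by the factor.
f4, pp4 = downsample_intrinsics(focal, pp, 4.0)
pixel_4x = project(point, orientation, position, f4, pp4)
print(pixel, pixel_scene, pixel_4x)
```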
For the LLFF data format, we generate rays using this code.
[llff data format - get ray function]
I think it is a simplified version of yours (it ignores distortion and assumes that the principal point is approximately at the center of the image).
[ yours ]
But you undistort the points when generating the ray directions, using the principal point, tangential distortion, radial distortion, and pixel_aspect_ratio.
Is there a big difference?
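The difference can be checked numerically. The numpy sketch below compares LLFF-style directions (principal point at the image center, no distortion, square pixels) with directions that use the full intrinsics. It assumes the OpenCV-style axes (+x right, +y down, +z forward; NeRF's LLFF get_rays instead negates y and z, so both are kept in OpenCV axes here to compare them directly), the Brown-Conrady radial/tangential model, and focal_y = focal * pixel_aspect_ratio. The undistortion uses a simple fixed-point iteration, which is not necessarily how the repo inverts the model. With zero distortion, a centered principal point, and unit aspect ratio the two agree exactly; otherwise the gap grows toward the image corners, where radial distortion is strongest.

```python
import numpy as np


def pixel_grid(h, w):
    return np.meshgrid(np.arange(w, dtype=np.float64), np.arange(h, dtype=np.float64), indexing="xy")


def llff_style_dirs(h, w, focal):
    """Principal point at the image center, no distortion, square pixels."""
    x, y = pixel_grid(h, w)
    return np.stack([(x - 0.5 * w) / focal, (y - 0.5 * h) / focal, np.ones_like(x)], axis=-1)


def distort(x, y, k1, k2, k3, p1, p2):
    """Forward Brown-Conrady model on normalized image coordinates."""
    r2 = x * x + y * y
    radial = 1.0 + r2 * (k1 + r2 * (k2 + r2 * k3))
    xd = x * radial + 2.0 * p1 * x * y + p2 * (r2 + 2.0 * x * x)
    yd = y * radial + p1 * (r2 + 2.0 * y * y) + 2.0 * p2 * x * y
    return xd, yd


def undistort(xd, yd, k1, k2, k3, p1, p2, iters=50):
    """Invert `distort` by fixed-point iteration (adequate for mild distortion)."""
    x, y = xd.copy(), yd.copy()
    for _ in range(iters):
        r2 = x * x + y * y
        radial = 1.0 + r2 * (k1 + r2 * (k2 + r2 * k3))
        dx = 2.0 * p1 * x * y + p2 * (r2 + 2.0 * x * x)
        dy = p1 * (r2 + 2.0 * y * y) + 2.0 * p2 * x * y
        x, y = (xd - dx) / radial, (yd - dy) / radial
    return x, y


def full_model_dirs(h, w, focal, cx, cy, pixel_aspect_ratio=1.0,
                    radial=(0.0, 0.0, 0.0), tangential=(0.0, 0.0)):
    x, y = pixel_grid(h, w)
    xd = (x - cx) / focal
    yd = (y - cy) / (focal * pixel_aspect_ratio)
    xu, yu = undistort(xd, yd, *radial, *tangential)
    return np.stack([xu, yu, np.ones_like(xu)], axis=-1)


h, w, focal = 48, 64, 60.0
simple = llff_style_dirs(h, w, focal)
same = full_model_dirs(h, w, focal, cx=0.5 * w, cy=0.5 * h)
distorted = full_model_dirs(h, w, focal, cx=0.5 * w + 1.5, cy=0.5 * h - 0.8,
                            radial=(-0.05, 0.01, 0.0), tangential=(1e-3, -5e-4))
print("max |difference| with distortion:", np.abs(distorted - simple).max())
```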
Are HyperNeRF and Nerfies implemented in NDC (Normalized Device Coordinates)?
Many NeRF-based models have been implemented in NDC.
It would be really helpful for me if you could answer these questions!
Thanks again for your awesome work!
Hi,
Thanks for the code. :)
I have been unable to run the example code due to the following error.
python train.py --base_folder $EXPERIMENT_PATH --gin_bindings="data_dir='$DATASET_PATH'" --gin_configs configs/test_local.gin
2021-10-16 22:37:32.827267: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudnn.so.8'; dlerror: libcudnn.so.8: cannot open shared object file: No such file or directory
2021-10-16 22:37:32.827288: W tensorflow/core/common_runtime/gpu/gpu_device.cc:1835] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...
I1016 22:37:32.828192 140583580991872 train.py:135] *** Starting experiment
I1016 22:37:32.828281 140583580991872 train.py:139] *** Loading Gin configs from: ['configs/test_local.gin']
I1016 22:37:32.840747 140583580991872 xla_bridge.py:231] Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker:
I1016 22:37:32.949953 140583580991872 xla_bridge.py:231] Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.
I1016 22:37:32.950088 140583580991872 train.py:159] exp_dir = /disk2/nerf-review/hypernerf/test01
I1016 22:37:32.950333 140583580991872 train.py:163] summary_dir = /disk2/nerf-review/hypernerf/test01/summaries/train
I1016 22:37:32.950404 140583580991872 train.py:167] checkpoint_dir = /disk2/nerf-review/hypernerf/test01/checkpoints
I1016 22:37:32.950486 140583580991872 train.py:171] Starting process 0. There are 1 processes.
I1016 22:37:32.950564 140583580991872 train.py:173] Found 2 accelerator devices: [GpuDevice(id=0, process_index=0), GpuDevice(id=1, process_index=0)].
I1016 22:37:32.950626 140583580991872 train.py:175] Found 2 total devices: [GpuDevice(id=0, process_index=0), GpuDevice(id=1, process_index=0)].
I1016 22:37:33.311031 140583580991872 train.py:187] Creating datasource
I1016 22:37:33.311344 140583580991872 nerfies.py:72] *** Loading dataset IDs from /disk2/nerf-review/datasets/hypernerf/achen_000000_nerfies_04/capture_upright_01/dataset.json
I1016 22:37:33.312014 140583580991872 core.py:237] Creating datasource of type NerfiesDataSource with use_appearance_id=True, use_camera_id=False, use_warp_id=True, use_depth=False, use_time=False, train_stride=1, val_stride=1
I1016 22:37:33.312376 140583580991872 train.py:200] Initializing models.
Traceback (most recent call last):
File "train.py", line 370, in <module>
app.run(main)
File "/home/kurt/anaconda3/envs/hypernerf/lib/python3.8/site-packages/absl/app.py", line 312, in run
_run_main(main, args)
File "/home/kurt/anaconda3/envs/hypernerf/lib/python3.8/site-packages/absl/app.py", line 258, in _run_main
sys.exit(main(argv))
File "train.py", line 203, in main
model, params['model'] = models.construct_nerf(
File "/disk2/nerf-review/hypernerf/hypernerf/models.py", line 701, in construct_nerf
params = model.init({
File "/home/kurt/anaconda3/envs/hypernerf/lib/python3.8/site-packages/flax/linen/module.py", line 998, in init
_, v_out = self.init_with_output(
File "/home/kurt/anaconda3/envs/hypernerf/lib/python3.8/site-packages/flax/linen/module.py", line 968, in init_with_output
return self.apply(
File "/home/kurt/anaconda3/envs/hypernerf/lib/python3.8/site-packages/flax/linen/module.py", line 936, in apply
return apply(
File "/home/kurt/anaconda3/envs/hypernerf/lib/python3.8/site-packages/flax/core/scope.py", line 686, in wrapper
with bind(variables, rngs=rngs, mutable=mutable).temporary() as root:
File "/home/kurt/anaconda3/envs//hypernerf/lib/python3.8/site-packages/flax/core/scope.py", line 663, in bind
raise errors.InvalidRngError(
flax.errors.InvalidRngError: rngs should be a dictionary mapping strings to `jax.PRNGKey`. (https://flax.readthedocs.io/en/latest/flax.errors.html#flax.errors.InvalidRngError)
I followed the installation instructions, except that I changed "jaxlib==0.1.71+cuda111" to "jaxlib==0.1.71" to allow installation (followed by installing jax from GitHub). The same dataset was previously trained successfully with Nerfies.
I have been testing the Geman-McClure loss implemented in the repo, specifically model_utils.general_loss_with_squared_residual() with alpha = -2.
However, the output, shown below, is very different from what I expected (alpha = -2, scale = 0.1).
Shouldn't it be quadratic when |x| < scale?
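For reference, Barron's general robust loss with alpha = -2 reduces to rho(x) = 2(x/c)^2 / ((x/c)^2 + 4), which is approximately (x/c)^2 / 2 when |x| is much smaller than c and saturates at 2 for large |x|. Given the function's name, one thing worth checking is that the squared residual x^2, not x itself, is passed in. The numpy sketch below reimplements only the generic-alpha branch of the formula (it is not the repo's function, which also handles special cases such as alpha = 0 and alpha = 2) and compares it with the closed form.

```python
import numpy as np


def general_loss_from_squared(squared_x, alpha, scale):
    """Generic-alpha branch of Barron's loss, written in terms of x**2.

    Valid for alpha not in {0, 2}; the special cases need separate formulas.
    """
    squared_scaled = squared_x / scale**2
    beta = abs(alpha - 2.0)
    return (beta / alpha) * ((squared_scaled / beta + 1.0) ** (0.5 * alpha) - 1.0)


def geman_mcclure(x, scale):
    """Closed form for alpha = -2."""
    s = (x / scale) ** 2
    return 2.0 * s / (s + 4.0)


x = np.linspace(-1.0, 1.0, 201)
scale = 0.1
loss = general_loss_from_squared(x**2, alpha=-2.0, scale=scale)
print(np.allclose(loss, geman_mcclure(x, scale)))  # True

# Passing |x| where x**2 is expected changes the curve a lot inside |x| < scale.
print(general_loss_from_squared(0.05**2, -2.0, scale), general_loss_from_squared(0.05, -2.0, scale))
```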
Thanks a lot for sharing the code! I want to know how to use your code for dynamic video training and rendering, like the video below.
Currently I can only render a free-viewpoint video with the object frozen, like the first video on your project page.
When I try to execute the following command, I get the error below:
python train.py --base_folder $EXPERIMENT_PATH --gin_bindings="data_dir='$DATASET_PATH'" --gin_configs configs/hypernerf_interp_ds_2d.gin
I0925 16:29:32.415290 140117065011584 train.py:135] *** Starting experiment
I0925 16:29:32.415382 140117065011584 train.py:139] *** Loading Gin configs from: ['configs/hypernerf_interp_ds_2d.gin']
I0925 16:29:32.430606 140117065011584 xla_bridge.py:236] Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker:
I0925 16:29:32.494084 140117065011584 xla_bridge.py:236] Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.
I0925 16:29:32.494200 140117065011584 train.py:159] exp_dir = /data/TinyMe/chickchickenout
I0925 16:29:32.494389 140117065011584 train.py:163] summary_dir = /data/TinyMe/chickchickenout/summaries/train
I0925 16:29:32.494438 140117065011584 train.py:167] checkpoint_dir = /data/TinyMe/chickchickenout/checkpoints
I0925 16:29:32.494497 140117065011584 train.py:171] Starting process 0. There are 1 processes.
I0925 16:29:32.494542 140117065011584 train.py:173] Found 1 accelerator devices: [GpuDevice(id=0, process_index=0)].
I0925 16:29:32.494578 140117065011584 train.py:175] Found 1 total devices: [GpuDevice(id=0, process_index=0)].
I0925 16:29:32.698378 140117065011584 train.py:187] Creating datasource
I0925 16:29:32.698573 140117065011584 interp.py:71] *** Loading dataset IDs from /data/TinyMe/chickchicken/dataset.json
I0925 16:29:32.699014 140117065011584 core.py:237] Creating datasource of type InterpDataSource with use_appearance_id=True, use_camera_id=False, use_warp_id=True, use_depth=False, use_time=False, train_stride=1, val_stride=1
I0925 16:29:32.699630 140117065011584 train.py:200] Initializing models.
2022-09-25 16:29:45.751242: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2085] Execution of replica 0 failed: INTERNAL: Failed to load in-memory CUBIN: CUDA_ERROR_ILLEGAL_ADDRESS: an illegal memory access was encountered
Traceback (most recent call last):
File "train.py", line 370, in <module>
app.run(main)
File "/home/hy/miniconda3/envs/hypernerf/lib/python3.8/site-packages/absl/app.py", line 312, in run
_run_main(main, args)
File "/home/hy/miniconda3/envs/hypernerf/lib/python3.8/site-packages/absl/app.py", line 258, in _run_main
sys.exit(main(argv))
File "train.py", line 203, in main
model, params['model'] = models.construct_nerf(
File "/data/TinyMe/hypernerf/hypernerf/models.py", line 701, in construct_nerf
params = model.init({
File "/home/hy/miniconda3/envs/hypernerf/lib/python3.8/site-packages/flax/linen/module.py", line 998, in init
_, v_out = self.init_with_output(
File "/home/hy/miniconda3/envs/hypernerf/lib/python3.8/site-packages/flax/linen/module.py", line 968, in init_with_output
return self.apply(
File "/home/hy/miniconda3/envs/hypernerf/lib/python3.8/site-packages/flax/linen/module.py", line 936, in apply
return apply(
File "/home/hy/miniconda3/envs/hypernerf/lib/python3.8/site-packages/flax/core/scope.py", line 687, in wrapper
y = fn(root, *args, **kwargs)
File "/home/hy/miniconda3/envs/hypernerf/lib/python3.8/site-packages/flax/linen/module.py", line 1178, in scope_fn
return fn(module.clone(parent=scope), *args, **kwargs)
File "/home/hy/miniconda3/envs/hypernerf/lib/python3.8/site-packages/flax/linen/module.py", line 275, in wrapped_module_method
y = fun(self, *args, **kwargs)
File "/data/TinyMe/hypernerf/hypernerf/models.py", line 631, in __call__
z_vals, points = model_utils.sample_pdf(
File "/data/TinyMe/hypernerf/hypernerf/model_utils.py", line 216, in sample_pdf
z_samples = piecewise_constant_pdf(key, bins, weights, num_coarse_samples,
File "/data/TinyMe/hypernerf/hypernerf/model_utils.py", line 184, in piecewise_constant_pdf
bins_g0, bins_g1 = minmax(bins)
File "/data/TinyMe/hypernerf/hypernerf/model_utils.py", line 179, in minmax
x1 = jnp.min(jnp.where(~mask, x[..., None], x[..., -1:, None]), -2)
File "/home/hy/miniconda3/envs/hypernerf/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/hy/miniconda3/envs/hypernerf/lib/python3.8/site-packages/jax/_src/api.py", line 412, in cache_miss
out_flat = xla.xla_call(
File "/home/hy/miniconda3/envs/hypernerf/lib/python3.8/site-packages/jax/core.py", line 1616, in bind
return call_bind(self, fun, *args, **params)
File "/home/hy/miniconda3/envs/hypernerf/lib/python3.8/site-packages/jax/core.py", line 1607, in call_bind
outs = primitive.process(top_trace, fun, tracers, params)
File "/home/hy/miniconda3/envs/hypernerf/lib/python3.8/site-packages/jax/core.py", line 1619, in process
return trace.process_call(self, fun, tracers, params)
File "/home/hy/miniconda3/envs/hypernerf/lib/python3.8/site-packages/jax/core.py", line 615, in process_call
return primitive.impl(f, *tracers, **params)
File "/home/hy/miniconda3/envs/hypernerf/lib/python3.8/site-packages/jax/interpreters/xla.py", line 623, in _xla_call_impl
out = compiled_fun(*args)
File "/home/hy/miniconda3/envs/hypernerf/lib/python3.8/site-packages/jax/interpreters/xla.py", line 913, in _execute_compiled
out_bufs = compiled.execute(input_bufs)
jax._src.traceback_util.UnfilteredStackTrace: RuntimeError: INTERNAL: Failed to load in-memory CUBIN: CUDA_ERROR_ILLEGAL_ADDRESS: an illegal memory access was encountered
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "train.py", line 370, in <module>
app.run(main)
File "/home/hy/miniconda3/envs/hypernerf/lib/python3.8/site-packages/absl/app.py", line 312, in run
_run_main(main, args)
File "/home/hy/miniconda3/envs/hypernerf/lib/python3.8/site-packages/absl/app.py", line 258, in _run_main
sys.exit(main(argv))
File "train.py", line 203, in main
model, params['model'] = models.construct_nerf(
File "/data/TinyMe/hypernerf/hypernerf/models.py", line 701, in construct_nerf
params = model.init({
File "/home/hy/miniconda3/envs/hypernerf/lib/python3.8/site-packages/flax/linen/module.py", line 998, in init
_, v_out = self.init_with_output(
File "/home/hy/miniconda3/envs/hypernerf/lib/python3.8/site-packages/flax/linen/module.py", line 968, in init_with_output
return self.apply(
File "/home/hy/miniconda3/envs/hypernerf/lib/python3.8/site-packages/flax/linen/module.py", line 936, in apply
return apply(
File "/home/hy/miniconda3/envs/hypernerf/lib/python3.8/site-packages/flax/core/scope.py", line 687, in wrapper
y = fn(root, *args, **kwargs)
File "/home/hy/miniconda3/envs/hypernerf/lib/python3.8/site-packages/flax/linen/module.py", line 1178, in scope_fn
return fn(module.clone(parent=scope), *args, **kwargs)
File "/home/hy/miniconda3/envs/hypernerf/lib/python3.8/site-packages/flax/linen/module.py", line 275, in wrapped_module_method
y = fun(self, *args, **kwargs)
File "/data/TinyMe/hypernerf/hypernerf/models.py", line 631, in __call__
z_vals, points = model_utils.sample_pdf(
File "/data/TinyMe/hypernerf/hypernerf/model_utils.py", line 216, in sample_pdf
z_samples = piecewise_constant_pdf(key, bins, weights, num_coarse_samples,
File "/data/TinyMe/hypernerf/hypernerf/model_utils.py", line 184, in piecewise_constant_pdf
bins_g0, bins_g1 = minmax(bins)
File "/data/TinyMe/hypernerf/hypernerf/model_utils.py", line 179, in minmax
x1 = jnp.min(jnp.where(~mask, x[..., None], x[..., -1:, None]), -2)
File "/home/hy/miniconda3/envs/hypernerf/lib/python3.8/site-packages/jax/interpreters/xla.py", line 913, in _execute_compiled
out_bufs = compiled.execute(input_bufs)
RuntimeError: INTERNAL: Failed to load in-memory CUBIN: CUDA_ERROR_ILLEGAL_ADDRESS: an illegal memory access was encountered
2022-09-25 16:29:45.849669: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:1047] could not synchronize on CUDA context: CUDA_ERROR_ILLEGAL_ADDRESS: an illegal memory access was encountered :: *** Begin stack trace ***
This is really nice work, and I am wondering whether you have tested the code that generates the figures in the notebooks (e.g. hypernerf_ap_ds_figure.ipynb). I ran this code and it does not seem to work. Could you give some hints on how to use the code in these notebooks?
Thank you so much!
I really appreciate your project!
Questions about the use_time mode of eval.py:
When I changed the value of 'warp_embed_key' to 'time', I got this error.
It seems 'get_time_id' returns a tuple, and a tuple cannot be divided by an int.
So how should I generate the dataset to use time_id?
Also, how can I render time-varying images with fixed camera parameters?
Thanks a lot!
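Rendering time-varying images from a fixed viewpoint amounts to holding the camera fixed while sweeping the time (or warp) index the model is conditioned on; freezing the index while moving the camera gives the static free-viewpoint videos. The plain-Python sketch below only builds such (camera, time) schedules; render_schedule and its modes are hypothetical, and how each pair is fed to the model depends on eval.py and on warp_embed_key.

```python
from typing import List, Sequence, Tuple


def render_schedule(camera_ids: Sequence[int], time_ids: Sequence[int],
                    mode: str) -> List[Tuple[int, int]]:
    """Pair camera indices with time indices for rendering.

    mode = "fixed_camera": first camera, every time step (time-varying view).
    mode = "fixed_time":   every camera, first time step (frozen scene).
    mode = "both":         camera and time advance together.
    """
    if not camera_ids or not time_ids:
        raise ValueError("need at least one camera id and one time id")
    if mode == "fixed_camera":
        return [(camera_ids[0], t) for t in time_ids]
    if mode == "fixed_time":
        return [(c, time_ids[0]) for c in camera_ids]
    if mode == "both":
        return list(zip(camera_ids, time_ids))
    raise ValueError(f"unknown mode: {mode}")


print(render_schedule([3, 4, 5], [0, 1, 2, 4], "fixed_camera"))
```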
I have some questions about the camera pose matrices in the JSON files of the camera folder.
I'm reimplementing your paper on top of a PyTorch-based NeRF codebase.
However, I am confused about the camera coordinate system in the Nerfies dataset.
In load_llff.py of the nerf-pytorch code, the camera poses are loaded from the given poses_bounds.npy file and their rotation matrices are corrected as in the following code (the poses array has shape (3, 5, frame_num)).
However, I could not find any such correction in your code when you build the dataset poses after loading them from the camera JSON files.
I guess you use a different coordinate system from the original NeRF (with the LLFF dataset).
If I want to use your dataset with the LLFF format and ray function, should I correct the pose matrices as in the code above?
If you don't mind, could you explain the camera axis conventions of COLMAP results, NeRF (LLFF), and Nerfies (HyperNeRF)?
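For reference, COLMAP uses the OpenCV camera axes (+x right, +y down, +z forward), while NeRF's LLFF pipeline uses OpenGL-style axes (+x right, +y up, +z backward); the reordering in load_llff.py converts the [down, right, back] rotation columns stored in poses_bounds.npy into [right, up, back]. The numpy sketch below builds a camera-to-world matrix from a world-to-camera rotation and a camera position (an assumption about how the Nerfies camera JSON stores orientation and position) and converts it to the OpenGL convention by negating the camera's y and z axis columns.

```python
import numpy as np


def camera_to_world(orientation, position):
    """4x4 camera-to-world matrix, assuming `orientation` maps world -> camera axes."""
    c2w = np.eye(4)
    c2w[:3, :3] = np.asarray(orientation, dtype=float).T
    c2w[:3, 3] = position
    return c2w


def opencv_to_opengl(c2w):
    """Flip the camera's y and z axes: (right, down, forward) -> (right, up, backward)."""
    return c2w @ np.diag([1.0, -1.0, -1.0, 1.0])


# A camera at (0, 0, 5) looking toward the origin: its OpenCV +z (forward) axis is world -z.
orientation = np.diag([1.0, -1.0, -1.0])
c2w_cv = camera_to_world(orientation, [0.0, 0.0, 5.0])
c2w_gl = opencv_to_opengl(c2w_cv)
print(c2w_cv[:3, 2], -c2w_gl[:3, 2])  # both point along world -z
```

Only the rotation columns change; the camera center (last column) is the same in both conventions.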
For the dynamic scene dataset (the so-called NVIDIA dataset), there is no undistortion step for the cameras, unlike the following part of your code.
I wonder whether that part is necessary, because the Nerfies dataset images do not seem to be distorted.
Also, are the dynamic scene dataset images already undistorted?
Did you implement your code in NDC coordinates?
I wonder whether I should use ndc_rays.
I really hope to hear your answer, and thanks again for your great work!!
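For context, the ndc_rays function in the original NeRF code is meant for forward-facing (LLFF-style) scenes: rays are expressed in a frame where the camera looks down -z, and depths from the near plane to infinity are mapped to [-1, 1]. Below is a numpy transcription of that formula for experimentation; it is taken from the original NeRF formulation, not from this repository, and says nothing about whether the Nerfies/HyperNeRF loaders use NDC.

```python
import numpy as np


def ndc_rays(h, w, focal, near, rays_o, rays_d):
    """Transcription of NeRF's ndc_rays for forward-facing scenes (camera looks down -z)."""
    # Shift ray origins onto the near plane z = -near.
    t = -(near + rays_o[..., 2]) / rays_d[..., 2]
    rays_o = rays_o + t[..., None] * rays_d

    o0 = -1.0 / (w / (2.0 * focal)) * rays_o[..., 0] / rays_o[..., 2]
    o1 = -1.0 / (h / (2.0 * focal)) * rays_o[..., 1] / rays_o[..., 2]
    o2 = 1.0 + 2.0 * near / rays_o[..., 2]

    d0 = -1.0 / (w / (2.0 * focal)) * (rays_d[..., 0] / rays_d[..., 2] - rays_o[..., 0] / rays_o[..., 2])
    d1 = -1.0 / (h / (2.0 * focal)) * (rays_d[..., 1] / rays_d[..., 2] - rays_o[..., 1] / rays_o[..., 2])
    d2 = -2.0 * near / rays_o[..., 2]

    return np.stack([o0, o1, o2], -1), np.stack([d0, d1, d2], -1)


rays_o = np.zeros((1, 3))
rays_d = np.array([[0.0, 0.0, -1.0]])
o_ndc, d_ndc = ndc_rays(h=480, w=640, focal=500.0, near=1.0, rays_o=rays_o, rays_d=rays_d)
# The ray starts on the near plane (z_ndc = -1) and reaches z_ndc = +1 at infinite depth.
print(o_ndc[0], o_ndc[0] + d_ndc[0])
```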
cv.gapi.wip.GStreamerPipeline = cv.gapi_wip_gst_GStreamerPipeline
AttributeError: partially initialized module 'cv2' has no attribute 'gapi_wip_gst_GStreamerPipeline' (most likely due to a circular import)
I got this segmentation fault when I started training HyperNeRF. I thought it was an OOM error, but it still shows up after I reduced the batch_size. I reinstalled numpy as well, but nothing changed. Any ideas would be appreciated.
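A segmentation fault usually kills Python without a traceback. Python's standard faulthandler module can dump the Python-level stack of every thread when the process crashes, which at least shows which call (for example a JAX compilation or a data-loading step) was running. Enabling it at the top of train.py, as sketched below, has the same effect as running python -X faulthandler train.py with the usual arguments.

```python
import faulthandler
import sys

# Dump the Python traceback of every thread on SIGSEGV, SIGFPE, SIGABRT, SIGBUS, or SIGILL.
# sys.__stderr__ is the process's original stderr, which has a real file descriptor.
faulthandler.enable(file=sys.__stderr__, all_threads=True)
```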
Hello author, thank you very much for such great work. I have a request. One of my papers is being revised, and a reviewer asked me to add results on hypernerf_vrig, which is somewhat difficult for me. So I was wondering if you could send me your reproduction code for NSFF. Thank you very much.
Like other NeRF tools, can I get mesh or texture data from your tool?
I cannot find any similar functions in your source code.
Thanks in advance for your help.
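A common approach with NeRF-style models is to query the density on a 3D grid and threshold it; for a deformable model like HyperNeRF, each query would also need a fixed warp/time code and hyper-space (ambient) coordinate. The numpy sketch below shows only the grid-query-and-threshold step, with a hypothetical density_fn standing in for the model and a toy sphere density for illustration; a mesh could then be extracted from the same density grid with a marching-cubes implementation such as skimage.measure.marching_cubes.

```python
import numpy as np


def density_point_cloud(density_fn, bbox_min, bbox_max, resolution=64, threshold=10.0):
    """Return grid points whose density exceeds `threshold`.

    `density_fn` maps an (N, 3) array of points to N densities; for HyperNeRF it would
    have to wrap the model with a fixed warp code and hyper-space coordinate (hypothetical).
    """
    axes = [np.linspace(lo, hi, resolution) for lo, hi in zip(bbox_min, bbox_max)]
    grid = np.stack(np.meshgrid(*axes, indexing="ij"), axis=-1).reshape(-1, 3)
    sigma = density_fn(grid)
    return grid[sigma > threshold]


def sphere_density(points, radius=0.5):
    """Toy density: high inside a sphere of the given radius, zero outside."""
    return np.where(np.linalg.norm(points, axis=-1) < radius, 100.0, 0.0)


cloud = density_point_cloud(sphere_density, (-1, -1, -1), (1, 1, 1), resolution=32)
print(cloud.shape)
```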
I processed the video and trained the model, with the dataset saved to /content/gdrive/My Drive/nerfies/captures/capture1 and config.gin and the checkpoints saved to /content/gdrive/My Drive/nerfies/experiments/capture1/exp1.
In the "Model and dataset configuration" step of rendering, the following error occurs:
Loading config from /content/gdrive/My Drive/nerfies/experiments/capture1/exp1/config.gin
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-13-573aff68880d> in <module>()
27 logging.info('Loading config from %s', config_path)
28 config_str = f.read()
---> 29 gin.parse_config(config_str)
30
31 config_path = Path(train_dir, 'config.gin')
/usr/local/lib/python3.7/dist-packages/gin/config.py in parse(cls, binding_key)
917 if not _might_have_parameter(configurable_.wrapper, arg_name):
918 err_str = "Configurable '{}' doesn't have a parameter named '{}'."
--> 919 raise ValueError(err_str.format(selector, arg_name))
920
921 if configurable_.allowlist and arg_name not in configurable_.allowlist:
ValueError: Configurable 'ExperimentConfig' doesn't have a parameter named 'datasource_spec'.
In bindings string line 3
ExperimentConfig.datasource_spec = {
I checked the config.gin file, and ExperimentConfig.datasource_spec is there.
I checked configs.py and see datasource_cls instead of datasource_spec. Should this be changed?
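The error means that the saved config.gin binds a parameter that the installed ExperimentConfig does not have, which suggests the config and the code come from different versions. A quick way to list every such binding is to scan the config text against the parameter names in the installed configs.py. The sketch below is a hypothetical helper (not part of gin), and the example string, including some_param, is illustrative only.

```python
import re
from typing import Iterable, List


def unexpected_bindings(config_str: str, configurable: str, known_params: Iterable[str]) -> List[str]:
    """Return parameter names bound to `configurable` in a gin string but not in `known_params`."""
    pattern = re.compile(rf"^\s*{re.escape(configurable)}\.(\w+)\s*=", re.MULTILINE)
    known = set(known_params)
    return sorted({name for name in pattern.findall(config_str) if name not in known})


config_str = """
ExperimentConfig.some_param = 1
ExperimentConfig.datasource_spec = {
    ...
}
"""
print(unexpected_bindings(config_str, "ExperimentConfig", ["some_param", "datasource_cls"]))
# ['datasource_spec']
```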
Hi, thanks for your great work!
I'm trying to run HyperNeRF on my own data captured with a mobile phone.
But after running COLMAP on those frames, I don't know how to obtain all the parameters in your data class.
Could you please offer a brief guide or something to refer to? I worry that I won't get the best performance from your method without correct data preprocessing.
Thanks!
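COLMAP stores, for each image, a quaternion qvec = (qw, qx, qy, qz) and a translation tvec such that x_cam = R x_world + t, so the camera center in world coordinates is -R^T t. The numpy sketch below converts one COLMAP image pose plus an OPENCV camera model into a camera dictionary. The key names are our understanding of the Nerfies camera JSON and, like the assumption focal_y = focal_length * pixel_aspect_ratio, should be verified against a camera file from the released datasets; the "Process a video into a dataset" Colab remains the reference for how those datasets were produced.

```python
import numpy as np


def qvec_to_rotmat(qvec):
    """COLMAP quaternion (qw, qx, qy, qz) -> world-to-camera rotation matrix."""
    w, x, y, z = qvec
    return np.array([
        [1 - 2 * y * y - 2 * z * z, 2 * x * y - 2 * w * z, 2 * x * z + 2 * w * y],
        [2 * x * y + 2 * w * z, 1 - 2 * x * x - 2 * z * z, 2 * y * z - 2 * w * x],
        [2 * x * z - 2 * w * y, 2 * y * z + 2 * w * x, 1 - 2 * x * x - 2 * y * y],
    ])


def colmap_to_camera_dict(qvec, tvec, opencv_params, width, height):
    """Build a camera dict from a COLMAP image pose and an OPENCV camera model.

    opencv_params = (fx, fy, cx, cy, k1, k2, p1, p2), as in COLMAP's OPENCV model.
    Key names are assumed to mirror the Nerfies camera JSON; verify before use.
    """
    fx, fy, cx, cy, k1, k2, p1, p2 = opencv_params
    rotation = qvec_to_rotmat(qvec)            # x_cam = R @ x_world + t
    position = -rotation.T @ np.asarray(tvec)  # camera center in world coordinates
    return {
        "orientation": rotation.tolist(),
        "position": position.tolist(),
        "focal_length": fx,
        "principal_point": [cx, cy],
        "skew": 0.0,
        "pixel_aspect_ratio": fy / fx,  # assumes focal_y = focal_length * pixel_aspect_ratio
        "radial_distortion": [k1, k2, 0.0],
        "tangential_distortion": [p1, p2],
        "image_size": [width, height],
    }
```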
Hello, I would like to know how to process my own dataset. My dataset consists of medical videos in which the scene moves slowly; the data I processed does not meet the data requirements described in the paper, and the training results are also poor. What should I pay the most attention to when processing other video data?
Is it possible to get a mesh, a point cloud, or the vertices of the resulting mesh?
I saw that similar questions are still open, so I am sorry to ask again.
But since there were no answers, I hope somebody will respond.
While reading the code base, I found that HyperNerf.use_nerf_embed is set to false in all of the given configurations. As I understand it, this flag determines whether an additional embedding, such as the appearance code, is fed as input to the NeRF network. This seems to contradict the paper, which says that a per-frame appearance code is fed into the network along with the view directions. May I ask for which types of experiments you use the appearance code? And does it consistently give better performance? Many thanks.
def generate_video_path(self):
    self.select_video_cams = [item for i, item in enumerate(self.all_cam_params) if i % 1 == 0 ]
    self.video_path, self.video_time = smooth_camera_poses(self.select_video_cams,10)
    self.video_path = self.video_path[:500]
    self.video_time = self.video_time[:500]

def load_video(self, idx):
    if idx in self.map.keys():
        return self.map[idx]
    camera = self.all_cam_params[idx]
    w = self.image_one.size[0]
    h = self.image_one.size[1]
    time = self.video_time[idx]
    R = camera.orientation.T
    T = - camera.position @ R
    FovY = focal2fov(camera.focal_length, self.h)
    FovX = focal2fov(camera.focal_length, self.w)
    image_path = "/".join(self.all_img[idx].split("/")[:-1])
    image_name = self.all_img[idx].split("/")[-1]
    caminfo = CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, image=self.image_one_torch,
                         image_path=image_path, image_name=image_name, width=w, height=h, time=time, mask=None
                         )
    self.map[idx] = caminfo
    return caminfo
These are functions in hyper_loader.py, and it seems that only the interpolated timestamps are used for the video, while the interpolated poses are discarded. Could you explain the reason for this?
An error occurred when I ran pip install -r requirements.txt,
saying: No matching distribution found for jaxlib==0.1.71+cuda111
Should I change the version of jaxlib? Thanks a lot for your generous help.
Hi authors,
I am getting this error while running the train.py script. I have not made any changes to the code.
If I may ask, what was your build setup and configuration? Also, does this code require a specific JAX version, or is it forward-compatible with newer versions of JAX?
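When reporting a setup or a version-related failure, it helps to list the exact versions involved. The stdlib sketch below uses importlib.metadata (available from Python 3.8) to report the Python version and the installed versions of a few packages; the package list is our choice for illustration, not the repo's requirements file.

```python
import sys
from importlib import metadata


def environment_report(packages=("jax", "jaxlib", "flax", "gin-config", "numpy")):
    """Return the Python version and installed versions (None if missing) of `packages`."""
    report = {"python": ".".join(map(str, sys.version_info[:3]))}
    for name in packages:
        try:
            report[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = None
    return report


if __name__ == "__main__":
    for name, version in environment_report().items():
        print(f"{name}: {version}")
```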
When I run the training script, I encounter the following problem.
By the way, running the training script from Nerfies with a similar config works fine.
train.py:302] Starting training
Traceback (most recent call last):
File "/root/anaconda3/lib/python3.9/site-packages/jax/api_util.py", line 146, in argnums_partial_except
hash(static_arg)
TypeError: unhashable type: 'NerfModel'
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/bfs/sz/Research/hypernerf/train.py", line 370, in <module>
app.run(main)
File "/root/anaconda3/lib/python3.9/site-packages/absl/app.py", line 312, in run
_run_main(main, args)
File "/root/anaconda3/lib/python3.9/site-packages/absl/app.py", line 258, in _run_main
sys.exit(main(argv))
File "/bfs/sz/Research/hypernerf/train.py", line 330, in main
state, stats, keys, model_out = ptrain_step(
File "/root/anaconda3/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 183, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/root/anaconda3/lib/python3.9/site-packages/jax/_src/api.py", line 1669, in f_pmapped
out = pxla.xla_pmap(
File "/root/anaconda3/lib/python3.9/site-packages/jax/core.py", line 1620, in bind
return call_bind(self, fun, *args, **params)
File "/root/anaconda3/lib/python3.9/site-packages/jax/core.py", line 1551, in call_bind
outs = primitive.process(top_trace, fun, tracers, params)
File "/root/anaconda3/lib/python3.9/site-packages/jax/core.py", line 1623, in process
return trace.process_map(self, fun, tracers, params)
File "/root/anaconda3/lib/python3.9/site-packages/jax/core.py", line 606, in process_call
return primitive.impl(f, *tracers, **params)
File "/root/anaconda3/lib/python3.9/site-packages/jax/interpreters/pxla.py", line 624, in xla_pmap_impl
compiled_fun, fingerprint = parallel_callable(fun, backend, axis_name, axis_size,
File "/root/anaconda3/lib/python3.9/site-packages/jax/linear_util.py", line 262, in memoized_fun
ans = call(fun, *args)
File "/root/anaconda3/lib/python3.9/site-packages/jax/interpreters/pxla.py", line 712, in parallel_callable
jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final(
File "/root/anaconda3/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 1284, in trace_to_jaxpr_final
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
File "/root/anaconda3/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 1262, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers)
File "/root/anaconda3/lib/python3.9/site-packages/jax/linear_util.py", line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/root/anaconda3/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 183, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/root/anaconda3/lib/python3.9/site-packages/jax/_src/api.py", line 413, in cache_miss
f, args = argnums_partial_except(f, static_argnums, args, allow_invalid=True)
File "/root/anaconda3/lib/python3.9/site-packages/jax/api_util.py", line 148, in argnums_partial_except
raise ValueError(
jax._src.traceback_util.UnfilteredStackTrace: ValueError: Non-hashable static arguments are not supported, as this can lead to unexpected cache-misses. Static argument (index 0) of type <class 'hypernerf.models.NerfModel'> for function train_step is non-hashable.
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/bfs/sz/Research/hypernerf/train.py", line 370, in <module>
app.run(main)
File "/root/anaconda3/lib/python3.9/site-packages/absl/app.py", line 312, in run
_run_main(main, args)
File "/root/anaconda3/lib/python3.9/site-packages/absl/app.py", line 258, in _run_main
sys.exit(main(argv))
File "/bfs/sz/Research/hypernerf/train.py", line 330, in main
state, stats, keys, model_out = ptrain_step(
File "/root/anaconda3/lib/python3.9/site-packages/jax/api_util.py", line 148, in argnums_partial_except
raise ValueError(
ValueError: Non-hashable static arguments are not supported, as this can lead to unexpected cache-misses. Static argument (index 0) of type <class 'hypernerf.models.NerfModel'> for function train_step is non-hashable.
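The error says that train_step receives the NerfModel as static argument 0 and that the installed JAX rejects non-hashable static arguments. One common workaround, sketched here only under the assumption that the model is passed positionally as the first argument, is to bind the model with functools.partial before wrapping the step in jax.pmap, so the model is captured by the closure instead of being a static argument. The stdlib-only example below does not call JAX; it illustrates the hashability issue with hypothetical FakeModel classes (a dataclass with eq=True and frozen=False gets __hash__ = None).

```python
import dataclasses
import functools


@dataclasses.dataclass
class FakeModel:
    """Hypothetical stand-in for a model object; eq=True, frozen=False => unhashable."""
    num_layers: int = 8


@dataclasses.dataclass(frozen=True)
class FrozenFakeModel:
    """frozen=True (with eq=True) makes dataclasses generate __hash__."""
    num_layers: int = 8


def is_hashable(obj) -> bool:
    try:
        hash(obj)
    except TypeError:
        return False
    return True


def train_step(model, state, batch):
    """Toy step: the model is only read, never traced."""
    return state + batch * model.num_layers


# Bind the model up front; the resulting callable takes only array-like arguments,
# so it could be passed to jax.pmap without declaring the model as static.
step = functools.partial(train_step, FakeModel())
print(is_hashable(FakeModel()), is_hashable(FrozenFakeModel()), step(1, 2))  # False True 17
```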
Hello!
I want to produce a hyper-space template with my own dataset, but I have no idea how to do it.
Could you please give me some instructions on it?
Hello everyone,
Can someone help me run this code locally on my computer? I ran train.py and eval.py on my own dataset, and my render folder is empty. Actually, there is no code for rendering videos in eval.py!