pvnieo / cp2p-pfarm-benchmark
Benchmark for non-rigid part-to-part shape matching
License: MIT License
Glad to see that you have updated the code! You're doing a very good job with your research. Wonderful work~
However, when I train the model on the FAUST and SCAPE datasets, I run into some errors.
In train_shrec_partial.py, I added the following code to use the FAUST and SCAPE datasets.
# original code
if cfg["dataset"]["name"] == "shrec16":
    train_dataset = ShrecPartialDataset(dataset_path, name=cfg["dataset"]["subset"], k_eig=cfg["fmap"]["k_eig"],
                                        n_fmap=cfg["fmap"]["n_fmap"], use_cache=True, op_cache_dir=op_cache_dir)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=None, shuffle=True)
# what I added~
# from faust_scape_dataset import FaustScapeDataset, shape_to_device
if cfg["dataset"]["name"] == "scape":
    train_dataset = FaustScapeDataset(dataset_path, name="scape", k_eig=cfg["fmap"]["k_eig"],
                                      n_fmap=cfg["fmap"]["n_fmap"], use_cache=True, op_cache_dir=op_cache_dir)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=None, shuffle=True)
if cfg["dataset"]["name"] == "faust":
    train_dataset = FaustScapeDataset(dataset_path, name="faust", k_eig=cfg["fmap"]["k_eig"],
                                      n_fmap=cfg["fmap"]["n_fmap"], use_cache=True, op_cache_dir=op_cache_dir)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=None, shuffle=True)
With this change the program starts, and the operator precomputation step completes without problems.
However, once precomputation had finished, I hit the following error:
Traceback (most recent call last):
  File "train_shrec_partial.py", line 105, in <module>
    train_net(cfg)
  File "train_shrec_partial.py", line 66, in train_net
    for i, data in enumerate(train_loader):
  File "/home/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 517, in __next__
    data = self._next_data()
  File "/home/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 557, in _next_data
    data = self.dataset[possibly_batched_index]
  File "/home/xxx/dpfm/faust_scape_dataset.py", line 202, in __getitem__
    evec_1_a, evec2_a = evec_1[vts1], evec_2[vts2]
IndexError: index 5000 is out of bounds for dimension 0 with size 5000
I downloaded the FAUST and SCAPE datasets from https://nuage.lix.polytechnique.fr/index.php/s/LJFXrsTG22wYCXx
Both datasets produce the same error, and I have not been able to fix it.
Could you give me some suggestions on how to solve this problem?
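One thing that may be worth checking, offered as a guess rather than a confirmed diagnosis: an index of 5000 against a dimension of size 5000 is the pattern that appears when 1-based vertex indices are used without subtracting 1. The sketch below uses a hypothetical helper, to_zero_based, to detect and correct that case. The helper name and the assumption that the vts indices are stored 1-based are illustrative and are not taken from faust_scape_dataset.py.

```python
import numpy as np


def to_zero_based(indices, n_vertices):
    """Return 0-based vertex indices, shifting by one if the input looks 1-based.

    Hypothetical helper: assumes `indices` holds vertex ids for a mesh with
    `n_vertices` vertices and is either entirely 0-based or entirely 1-based.
    """
    indices = np.asarray(indices, dtype=np.int64)
    lo, hi = int(indices.min()), int(indices.max())
    if lo >= 0 and hi < n_vertices:
        return indices          # already valid 0-based indices
    if lo >= 1 and hi <= n_vertices:
        return indices - 1      # 1-based: shift down by one
    raise ValueError(
        f"indices span [{lo}, {hi}], which fits neither convention "
        f"for a mesh with {n_vertices} vertices"
    )


# Toy example: a 5000-vertex mesh whose indices run from 1 to 5000.
vts = np.arange(1, 5001)
fixed = to_zero_based(vts, n_vertices=5000)
print(fixed.min(), fixed.max())
```

Note that the check cannot tell the two conventions apart when a 1-based list happens not to contain the last vertex, so it is a sanity check rather than a guarantee.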
Hope you have a nice day~
To use the cp2p dataset in DPFM, I wrote two .py files.
In cp2p_dataset.py, I defined the class "Cp2pDataset", which inherits from the "Dataset" class imported as follows:
from torch.utils.data import Dataset
The full implementation of this file is as follows:
import os
from pathlib import Path
import numpy as np
import potpourri3d as pp3d
import torch
from torch.utils.data import Dataset
import diffusion_net as dfn
from tqdm import tqdm
from itertools import permutations
from utils import farthest_point_sample, square_distance
class Cp2pDataset(Dataset):
    def __init__(self, root_dir, name="cp2p", k_eig=128, n_fmap=30, use_cache=True, op_cache_dir=None):
        self.k_eig = k_eig
        self.n_fmap = n_fmap
        self.root_dir = root_dir
        self.cache_dir = root_dir
        self.op_cache_dir = op_cache_dir
        if use_cache:
            train_cache = os.path.join(self.cache_dir, "train.pt")
            load_cache = train_cache
            print("using dataset cache path: " + str(load_cache))
            if os.path.exists(load_cache):
                print(" --> loading dataset from cache")
                (
                    self.verts_list,
                    self.faces_list,
                    self.frames_list,
                    self.massvec_list,
                    self.L_list,
                    self.evals_list,
                    self.evecs_list,
                    self.gradX_list,
                    self.gradY_list,
                    self.hks_list,
                    self.vts_list,
                    self.names_list,
                    self.sample_list
                ) = torch.load(load_cache)
                self.combinations = list(self.corres_dict.keys())
                return
            print(" --> dataset not in cache, repopulating")
        # Load the meshes and labels
        # define files and order
        train = True
        if train:
            path = "./data/cp2p/splits/train.txt"
            with open(path, 'r') as f:
                mesh_lists = f.read().strip().split()
        else:
            path = "./data/cp2p/splits/test.txt"
            with open(path, 'r') as f:
                mesh_lists = f.read().strip().split()
        self.used_shapes = sorted(x[:-4] for x in mesh_lists)
        corres_path = Path(root_dir) / "maps"
        all_combs = [x.stem for x in corres_path.iterdir()]
        self.corres_dict = {}
        for x, y in map(lambda x: (x[:x.rfind("_")], x[x.rfind("_") + 1:]), all_combs):
            if x in self.used_shapes and y in self.used_shapes:
                map_ = torch.from_numpy(np.loadtxt(corres_path / f"{x}_{y}.map", dtype=np.int32)).long() - 1
                self.corres_dict[(self.used_shapes.index(y), self.used_shapes.index(x))] = map_
        # set combinations
        self.combinations = list(self.corres_dict.keys())
        mesh_dirpath = Path(root_dir) / "shapes"
        # Get all the files
        self.verts_list = []
        self.faces_list = []
        self.sample_list = []
        # Load the actual files
        for shape_name in self.used_shapes:
            print("loading mesh " + str(shape_name))
            verts, faces = pp3d.read_mesh(str(mesh_dirpath / f"{shape_name}.off"))
            # to torch
            verts = torch.tensor(np.ascontiguousarray(verts)).float()
            faces = torch.tensor(np.ascontiguousarray(faces))
            self.verts_list.append(verts)
            self.faces_list.append(faces)
            idx0 = farthest_point_sample(verts.t(), ratio=0.9)
            dists, idx1 = square_distance(verts.unsqueeze(0), verts[idx0].unsqueeze(0)).sort(dim=-1)
            dists, idx1 = dists[:, :, :130].clone(), idx1[:, :, :130].clone()
            self.sample_list.append((idx0, idx1, dists))
        # Precompute operators
        (
            self.frames_list,
            self.massvec_list,
            self.L_list,
            self.evals_list,
            self.evecs_list,
            self.gradX_list,
            self.gradY_list,
        ) = dfn.geometry.get_all_operators(
            self.verts_list,
            self.faces_list,
            k_eig=self.k_eig,
            op_cache_dir=self.op_cache_dir,
        )
        # save to cache
        if use_cache:
            dfn.utils.ensure_dir_exists(self.cache_dir)
            torch.save(
                (
                    self.verts_list,
                    self.faces_list,
                    self.frames_list,
                    self.massvec_list,
                    self.L_list,
                    self.evals_list,
                    self.evecs_list,
                    self.gradX_list,
                    self.gradY_list,
                    self.used_shapes,
                    self.corres_dict,
                    self.sample_list,
                ),
                load_cache,
            )
    def __len__(self):
        return len(self.combinations)

    def __getitem__(self, item):
        idx1, idx2 = self.combinations[item]
        shape1 = {
            "xyz": self.verts_list[idx1],
            "faces": self.faces_list[idx1],
            "frames": self.frames_list[idx1],
            "mass": self.massvec_list[idx1],
            "L": self.L_list[idx1],
            "evals": self.evals_list[idx1],
            "evecs": self.evecs_list[idx1],
            "gradX": self.gradX_list[idx1],
            "gradY": self.gradY_list[idx1],
            "name": self.used_shapes[idx1],
            "sample_idx": self.sample_list[idx1],
        }
        shape2 = {
            "xyz": self.verts_list[idx2],
            "faces": self.faces_list[idx2],
            "frames": self.frames_list[idx2],
            "mass": self.massvec_list[idx2],
            "L": self.L_list[idx2],
            "evals": self.evals_list[idx2],
            "evecs": self.evecs_list[idx2],
            "gradX": self.gradX_list[idx2],
            "gradY": self.gradY_list[idx2],
            "name": self.used_shapes[idx2],
            "sample_idx": self.sample_list[idx2],
        }
        # Compute fmap
        map21 = self.corres_dict[(idx1, idx2)]
        evec_1, evec_2, mass2 = shape1["evecs"][:, :self.n_fmap], shape2["evecs"][:, :self.n_fmap], shape2["mass"]
        trans_evec2 = evec_2.t() @ torch.diag(mass2)
        P = torch.zeros(evec_2.size(0), evec_1.size(0))
        P[range(evec_2.size(0)), map21.flatten()] = 1
        C_gt = trans_evec2 @ P @ evec_1
        # compute region labels
        gt_partiality_mask12 = torch.zeros(shape1["xyz"].size(0)).long().detach()
        gt_partiality_mask12[map21[map21 != -1]] = 1
        gt_partiality_mask21 = torch.zeros(shape2["xyz"].size(0)).long().detach()
        gt_partiality_mask21[map21 != -1] = 1
        return {"shape1": shape1, "shape2": shape2, "C_gt": C_gt,
                "map21": map21, "gt_partiality_mask12": gt_partiality_mask12, "gt_partiality_mask21": gt_partiality_mask21}
def shape_to_device(dict_shape, device):
    names_to_device = ["xyz", "faces", "mass", "evals", "evecs", "gradX", "gradY"]
    for k, v in dict_shape.items():
        if "shape" in k:
            for name in names_to_device:
                v[name] = v[name].to(device)
            dict_shape[k] = v
        else:
            dict_shape[k] = v.to(device)
    return dict_shape
However, I encountered the following error:
Traceback (most recent call last):
  File "train_cp2p.py", line 90, in <module>
    train_net(cfg)
  File "train_cp2p.py", line 55, in train_net
    for i, data in enumerate(train_loader):
  File "/home/anaconda3/envs/fm2023/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 517, in __next__
    data = self._next_data()
  File "/home/anaconda3/envs/fm2023/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 557, in _next_data
    data = self._dataset_fetcher.fetch(index) # may raise StopIteration
  File "/home/anaconda3/envs/fm2023/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 46, in fetch
    data = self.dataset[possibly_batched_index]
  File "/home/FM_Code/dpfm/Cp2p_dataset.py", line 174, in __getitem__
    P[range(evec_2.size(0)), map21.flatten()] = 1
IndexError: shape mismatch: indexing tensors could not be broadcast together with shapes [1777], [5830]
Fortunately, loading the shapes and the get_all_operators step both work fine.
Judging from the error above, I think the problem lies in my implementation of the "__getitem__" function.
I don't know why the error occurs, and I hope you can help me find the answer~
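Reading the message literally, the row index has 1777 entries (the vertex count of shape 2) while map21 has 5830 entries, so the loaded map does not have one entry per vertex of shape 2. One possibility, which I have not verified against the benchmark's documentation, is that the key order used when filling corres_dict is the reverse of what __getitem__ expects. A second, separate issue is that entries equal to -1 would be written into the last column of P instead of being skipped. The sketch below is a hypothetical helper, build_gt_fmap, that checks the length up front and ignores unmatched vertices. It assumes map21[i] is the vertex of shape 1 matched to vertex i of shape 2, or -1 when there is no match.

```python
import torch


def build_gt_fmap(evec_1, evec_2, mass2, map21):
    """Ground-truth functional map from a partial vertex map.

    Assumption (not confirmed by the benchmark docs): map21[i] is the vertex
    of shape 1 matched to vertex i of shape 2, or -1 when there is no match.
    """
    n1, n2 = evec_1.size(0), evec_2.size(0)
    map21 = map21.flatten()
    if map21.numel() != n2:
        hint = " (it matches shape 1, so the pair may be swapped)" if map21.numel() == n1 else ""
        raise ValueError(
            f"map21 has {map21.numel()} entries but shape 2 has {n2} vertices{hint}"
        )

    valid = map21 != -1                    # vertices of shape 2 that have a match
    rows = torch.nonzero(valid).flatten()
    P = torch.zeros(n2, n1)
    P[rows, map21[valid]] = 1              # rows of unmatched vertices stay all-zero

    # Scaling the rows of evec_2 by mass2 and transposing gives the same result
    # as evec_2.t() @ torch.diag(mass2) without building a dense n2 x n2 matrix.
    return (evec_2 * mass2[:, None]).t() @ P @ evec_1


# Toy example: shape 1 has 4 vertices, shape 2 has 3, and vertex 1 of shape 2 is unmatched.
evec_1 = torch.eye(4)[:, :2]
evec_2 = torch.eye(3)[:, :2]
mass2 = torch.ones(3)
map21 = torch.tensor([1, -1, 0])
C_gt = build_gt_fmap(evec_1, evec_2, mass2, map21)
print(C_gt.shape)
```

If the length check fails with the "pair may be swapped" hint, that would point at the orientation of the map rather than at the matrix construction itself.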
Have a nice day :)
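A side observation on the caching branch of __init__ in the listing above, unrelated to the reported traceback: the tuple unpacked from torch.load has 13 fields (including hks_list, vts_list and names_list), while the tuple passed to torch.save has 12 (including used_shapes and corres_dict), and self.corres_dict is read on the cached path before it is assigned. A second run that finds train.pt would therefore likely fail. One way to keep the two sides in sync is to store a dict with named fields; the helper names and the key list below are illustrative assumptions, not part of DPFM.

```python
import os
import tempfile

import torch

# Illustrative subset of the fields the dataset would need to restore.
CACHE_KEYS = ("verts_list", "used_shapes", "corres_dict")


def save_cache(path, **fields):
    """Store the cached fields under explicit names instead of tuple positions."""
    torch.save(fields, path)


def load_cache(path, required=CACHE_KEYS):
    """Load the cache and fail loudly if an expected field is missing."""
    payload = torch.load(path)
    missing = [key for key in required if key not in payload]
    if missing:
        raise KeyError(f"cache at {path} is missing fields: {missing}")
    return payload


# Round trip through a temporary file with toy data.
with tempfile.TemporaryDirectory() as tmp:
    cache_path = os.path.join(tmp, "train.pt")
    save_cache(
        cache_path,
        verts_list=[torch.zeros(3, 3)],
        used_shapes=["shape_a"],
        corres_dict={(0, 0): torch.tensor([0, 1, -1])},
    )
    cache = load_cache(cache_path)
    print(sorted(cache))
```

With named fields, adding or removing a cached item changes one place on each side, and a stale cache file produces a readable error instead of an unpacking failure.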