cp2p-pfarm-benchmark's People

Contributors

pvnieo

cp2p-pfarm-benchmark's Issues

Problem with the training code for full shape matching

Glad to see that you have updated the code! You're doing a very good job with your research. Wonderful work~

However, when I train the model on the FAUST and SCAPE datasets, I run into an error.
In train_shrec_partial.py, I added the following code to load the FAUST and SCAPE datasets:

# added at the top of train_shrec_partial.py
from faust_scape_dataset import FaustScapeDataset, shape_to_device

# original code
if cfg["dataset"]["name"] == "shrec16":
    train_dataset = ShrecPartialDataset(dataset_path, name=cfg["dataset"]["subset"], k_eig=cfg["fmap"]["k_eig"],
                                        n_fmap=cfg["fmap"]["n_fmap"], use_cache=True, op_cache_dir=op_cache_dir)
# what I added: FAUST and SCAPE share one dataset class, selected by name
elif cfg["dataset"]["name"] in ("faust", "scape"):
    train_dataset = FaustScapeDataset(dataset_path, name=cfg["dataset"]["name"], k_eig=cfg["fmap"]["k_eig"],
                                      n_fmap=cfg["fmap"]["n_fmap"], use_cache=True, op_cache_dir=op_cache_dir)
else:
    raise ValueError(f"unknown dataset: {cfg['dataset']['name']}")

# batch_size=None: each item is already one shape pair, so no collation
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=None, shuffle=True)

With this change the script starts, and the operator precomputation completes without problems.
However, as soon as precomputation finishes, training crashes with the following error:

Traceback (most recent call last):
  File "train_shrec_partial.py", line 105, in <module>
    train_net(cfg)
  File "train_shrec_partial.py", line 66, in train_net
    for i, data in enumerate(train_loader):
  File "/home/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 517, in __next__
    data = self._next_data()
  File "/home/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 557, in _next_data
    data = self.dataset[possibly_batched_index]
  File "/home/xxx/dpfm/faust_scape_dataset.py", line 202, in __getitem__
    evec_1_a, evec2_a = evec_1[vts1], evec_2[vts2]
IndexError: index 5000 is out of bounds for dimension 0 with size 5000

My FAUST and SCAPE datasets were downloaded from https://nuage.lix.polytechnique.fr/index.php/s/LJFXrsTG22wYCXx
Both datasets raise the same error, and I can't figure out how to fix it.
Can you give me some suggestions on how to solve this problem?
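An index equal to the tensor size (5000 into a dimension of size 5000) is the typical signature of an off-by-one between 1-based and 0-based indexing. One plausible cause, not confirmed here, is that the `.vts` correspondence files store 1-based indices that are used without subtracting 1. The helpers below are hypothetical diagnostics (the names `guess_index_base` and `to_zero_based` are invented for this sketch) that infer the index base from the values alone.

```python
def guess_index_base(vts, n_verts):
    """Guess whether correspondence indices are 0-based or 1-based.

    Hypothetical diagnostic helper: `vts` is a sequence of integer vertex indices
    (e.g. read from a .vts file) and `n_verts` is the vertex count of the mesh
    they index into. Returns 0, 1, or None when the values alone cannot decide.
    """
    lo, hi = min(vts), max(vts)
    # Index 0 only exists in 0-based numbering; index n_verts only in 1-based numbering.
    if lo < 0 or hi > n_verts or (lo == 0 and hi == n_verts):
        raise ValueError(f"indices span [{lo}, {hi}], invalid for {n_verts} vertices in either base")
    if hi == n_verts:
        return 1
    if lo == 0:
        return 0
    return None


def to_zero_based(vts, n_verts):
    """Convert to 0-based indices, shifting only when the base can be inferred."""
    base = guess_index_base(vts, n_verts)
    if base is None:
        raise ValueError("cannot infer the index base from the values; check the dataset's conventions")
    return [i - base for i in vts]
```

If the check reports 1-based indices for these meshes, subtracting 1 when the `.vts` files are loaded (for example `np.loadtxt(path, dtype=int) - 1`) is the first thing to try. If it raises instead, the `.vts` files probably do not belong to the meshes being loaded.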

Hope you have a nice day~

Error in my implementation of the cp2p dataset

To use the cp2p dataset with DPFM, I wrote two Python files:

  • cp2p_dataset.py
  • train_cp2p.py

In cp2p_dataset.py, I defined a class "Cp2pDataset" that inherits from the "Dataset" class, imported with:

from torch.utils.data import Dataset

The full implementation is as follows:

import os
from pathlib import Path
import numpy as np
import potpourri3d as pp3d
import torch
from torch.utils.data import Dataset
import diffusion_net as dfn
from tqdm import tqdm
from itertools import permutations
from utils import farthest_point_sample, square_distance

class Cp2pDataset(Dataset):
    def __init__(self, root_dir, name="cp2p", k_eig=128, n_fmap=30, use_cache=True, op_cache_dir=None):

        self.k_eig = k_eig
        self.n_fmap = n_fmap
        self.root_dir = root_dir
        self.cache_dir = root_dir
        self.op_cache_dir = op_cache_dir

        if use_cache:
            train_cache = os.path.join(self.cache_dir, "train.pt")
            load_cache = train_cache
            print("using dataset cache path: " + str(load_cache))
            if os.path.exists(load_cache):
                print("  --> loading dataset from cache")
                # unpack in exactly the order used by torch.save() at the end of __init__,
                # so that used_shapes and corres_dict are restored as well
                (
                    self.verts_list,
                    self.faces_list,
                    self.frames_list,
                    self.massvec_list,
                    self.L_list,
                    self.evals_list,
                    self.evecs_list,
                    self.gradX_list,
                    self.gradY_list,
                    self.used_shapes,
                    self.corres_dict,
                    self.sample_list,
                ) = torch.load(load_cache)
                self.combinations = list(self.corres_dict.keys())
                return
            print("  --> dataset not in cache, repopulating")

        # Load the meshes and labels
        # define files and order
        train = True
        if train:
            path = "./data/cp2p/splits/train.txt"
            with open(path, 'r') as f:
                mesh_lists = f.read().strip().split()
        else:
            path = "./data/cp2p/splits/test.txt"
            with open(path, 'r') as f:
                mesh_lists = f.read().strip().split()
        self.used_shapes = sorted(x[:-4] for x in mesh_lists)

        corres_path = Path(root_dir) / "maps"
        all_combs = [x.stem for x in corres_path.iterdir()]
        self.corres_dict = {}
        for x, y in map(lambda x: (x[:x.rfind("_")], x[x.rfind("_") + 1:]), all_combs):
            if x in self.used_shapes and y in self.used_shapes:
                map_ = torch.from_numpy(np.loadtxt(corres_path / f"{x}_{y}.map", dtype=np.int32)).long() - 1
                self.corres_dict[(self.used_shapes.index(y), self.used_shapes.index(x))] = map_

        # set combinations
        self.combinations = list(self.corres_dict.keys())
        mesh_dirpath = Path(root_dir) / "shapes"

        # Get all the files
        self.verts_list = []
        self.faces_list = []
        self.sample_list = []

        # Load the actual files
        for shape_name in self.used_shapes:
            print("loading mesh " + str(shape_name))
            verts, faces = pp3d.read_mesh(str(mesh_dirpath / f"{shape_name}.off"))

            # to torch
            verts = torch.tensor(np.ascontiguousarray(verts)).float()
            faces = torch.tensor(np.ascontiguousarray(faces))
            self.verts_list.append(verts)
            self.faces_list.append(faces)
            idx0 = farthest_point_sample(verts.t(), ratio=0.9)
            dists, idx1 = square_distance(verts.unsqueeze(0), verts[idx0].unsqueeze(0)).sort(dim=-1)
            dists, idx1 = dists[:, :, :130].clone(), idx1[:, :, :130].clone()
            self.sample_list.append((idx0, idx1, dists))

        # Precompute operators
        (
            self.frames_list,
            self.massvec_list,
            self.L_list,
            self.evals_list,
            self.evecs_list,
            self.gradX_list,
            self.gradY_list,
        ) = dfn.geometry.get_all_operators(
            self.verts_list,
            self.faces_list,
            k_eig=self.k_eig,
            op_cache_dir=self.op_cache_dir,
        )

        # save to cache
        if use_cache:
            dfn.utils.ensure_dir_exists(self.cache_dir)
            torch.save(
                (
                    self.verts_list,
                    self.faces_list,
                    self.frames_list,
                    self.massvec_list,
                    self.L_list,
                    self.evals_list,
                    self.evecs_list,
                    self.gradX_list,
                    self.gradY_list,
                    self.used_shapes,
                    self.corres_dict,
                    self.sample_list,
                ),
                load_cache,
            )

    def __len__(self):
        return len(self.combinations)

    def __getitem__(self, item):
        idx1, idx2 = self.combinations[item]

        shape1 = {
            "xyz": self.verts_list[idx1],
            "faces": self.faces_list[idx1],
            "frames": self.frames_list[idx1],
            "mass": self.massvec_list[idx1],
            "L": self.L_list[idx1],
            "evals": self.evals_list[idx1],
            "evecs": self.evecs_list[idx1],
            "gradX": self.gradX_list[idx1],
            "gradY": self.gradY_list[idx1],
            "name": self.used_shapes[idx1],
            "sample_idx": self.sample_list[idx1],
        }

        shape2 = {
            "xyz": self.verts_list[idx2],
            "faces": self.faces_list[idx2],
            "frames": self.frames_list[idx2],
            "mass": self.massvec_list[idx2],
            "L": self.L_list[idx2],
            "evals": self.evals_list[idx2],
            "evecs": self.evecs_list[idx2],
            "gradX": self.gradX_list[idx2],
            "gradY": self.gradY_list[idx2],
            "name": self.used_shapes[idx2],
            "sample_idx": self.sample_list[idx2],
        }
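        # Compute the ground-truth fmap
        # map21[i] is the shape1 vertex matched to vertex i of shape2 (-1 = no match),
        # so map21 must contain exactly one entry per vertex of shape2
        map21 = self.corres_dict[(idx1, idx2)]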

        evec_1, evec_2, mass2 = shape1["evecs"][:, :self.n_fmap], shape2["evecs"][:, :self.n_fmap], shape2["mass"]
        trans_evec2 = evec_2.t() @ torch.diag(mass2)

        P = torch.zeros(evec_2.size(0), evec_1.size(0))
        P[range(evec_2.size(0)), map21.flatten()] = 1
        C_gt = trans_evec2 @ P @ evec_1

        # compute region labels
        gt_partiality_mask12 = torch.zeros(shape1["xyz"].size(0)).long().detach()
        gt_partiality_mask12[map21[map21 != -1]] = 1
        gt_partiality_mask21 = torch.zeros(shape2["xyz"].size(0)).long().detach()
        gt_partiality_mask21[map21 != -1] = 1

        return {"shape1": shape1, "shape2": shape2, "C_gt": C_gt,
                "map21": map21, "gt_partiality_mask12": gt_partiality_mask12, "gt_partiality_mask21": gt_partiality_mask21}


def shape_to_device(dict_shape, device):
    names_to_device = ["xyz", "faces", "mass", "evals", "evecs", "gradX", "gradY"]
    for k, v in dict_shape.items():
        if "shape" in k:
            for name in names_to_device:
                v[name] = v[name].to(device)
            dict_shape[k] = v
        else:
            dict_shape[k] = v.to(device)

    return dict_shape
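The map-loading loop in `__init__` splits each map file name `x_y` at its last underscore and stores the map under the key `(index of y, index of x)`; `__getitem__` unpacks that key as `(idx1, idx2)`, so y becomes shape1, x becomes shape2, and the rows of `x_y.map` are assumed to follow the vertices of x. The stand-alone sketch below reproduces that bookkeeping with hypothetical helpers and invented shape names, so the map direction can be checked against vertex counts. It also exposes a side effect of the `rfind("_")` split: a second name that itself contains an underscore is split at the wrong place.

```python
def parse_map_stem(stem):
    """Split a map-file stem "x_y" at its LAST underscore, as the loader does."""
    cut = stem.rfind("_")
    return stem[:cut], stem[cut + 1:]


def combination_key(stem, used_shapes):
    """Rebuild the key the loader stores: (index of y, index of x).

    __getitem__ unpacks this key as (idx1, idx2), so y becomes shape1,
    x becomes shape2, and map21 is the content of the file "x_y.map".
    """
    x, y = parse_map_stem(stem)
    return used_shapes.index(y), used_shapes.index(x)


def map_rows_follow(map_len, n_verts_x, n_verts_y):
    """Report which shape the map file has one row per vertex of.

    Returns "x" (what __getitem__ expects), "y" (the roles are swapped), or
    None (neither count matches). If both shapes have the same vertex count,
    "x" is returned and the check is inconclusive.
    """
    if map_len == n_verts_x:
        return "x"
    if map_len == n_verts_y:
        return "y"
    return None
```

For example, `combination_key("partA_fullB", ["fullB", "partA"])` gives `(0, 1)`, while `parse_map_stem("shape_a_shape_b")` gives `("shape_a_shape", "b")`. With the numbers from the traceback below (shape2 has 1777 vertices, the map has 5830 rows), `map_rows_follow(5830, 1777, n_verts_y)` would return `"y"` if shape1 has 5830 vertices.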

However, I encountered the following error:

Traceback (most recent call last):
  File "train_cp2p.py", line 90, in <module>
    train_net(cfg)
  File "train_cp2p.py", line 55, in train_net
    for i, data in enumerate(train_loader):
  File "/home/anaconda3/envs/fm2023/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 517, in __next__
    data = self._next_data()
  File "/home/anaconda3/envs/fm2023/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 557, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/home/anaconda3/envs/fm2023/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 46, in fetch
    data = self.dataset[possibly_batched_index]
  File "/home/FM_Code/dpfm/Cp2p_dataset.py", line 174, in __getitem__
    P[range(evec_2.size(0)), map21.flatten()] = 1
IndexError: shape mismatch: indexing tensors could not be broadcast together with shapes [1777], [5830]

Loading the shapes and running get_all_operators both work fine.
Judging from the traceback, I think the problem is in my implementation of the "__getitem__" function.
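For reference, the ground-truth functional map that `__getitem__` builds (`C_gt = trans_evec2 @ P @ evec_1`) can be written as follows, with k = n_fmap; the notation is introduced here, but each matrix corresponds directly to a variable in the code (`evec_1`, `evec_2`, `mass2`, `P`).

```latex
C_{\mathrm{gt}} = \Phi_2^{\top} M_2 \, P_{21} \, \Phi_1 \in \mathbb{R}^{k \times k},
\qquad M_2 = \operatorname{diag}(m_2)

\Phi_1 \in \mathbb{R}^{|V_1| \times k}, \quad
\Phi_2 \in \mathbb{R}^{|V_2| \times k}, \quad
M_2 \in \mathbb{R}^{|V_2| \times |V_2|}, \quad
P_{21} \in \{0, 1\}^{|V_2| \times |V_1|}

(P_{21})_{ij} = 1 \iff \mathtt{map21}[i] = j
```

The product is only defined when P21 has |V2| rows, and `P[range(evec_2.size(0)), map21.flatten()] = 1` fills row i from `map21[i]`, so `map21` needs exactly |V2| entries. In the traceback, |V2| = 1777 while `map21` has 5830 entries, so the loaded map is indexed by the vertices of a different shape (possibly shape1, which would mean the key or the shape roles are swapped for this dataset).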

I don't understand why this error occurs and hope you can help me~
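A defensive way to fill P is to check the map length first and skip unmatched vertices. The sketch below is a hypothetical helper written with plain Python lists so it runs without torch; it assumes the convention already used by the masks in `__getitem__` (`map21[i]` is a 0-based shape1 vertex, `-1` means no match).

```python
def p21_nonzeros(map21, n_verts2, n_verts1):
    """Return the (row, col) indices of the ones in P21, skipping unmatched vertices.

    Assumes the convention of __getitem__: map21[i] is the 0-based index of
    the shape1 vertex matched to vertex i of shape2, and -1 means "no match".
    """
    if len(map21) != n_verts2:
        raise ValueError(
            f"map21 has {len(map21)} entries but shape2 has {n_verts2} vertices; "
            "the map file probably indexes the vertices of the other shape"
        )
    rows, cols = [], []
    for i, j in enumerate(map21):
        if j == -1:
            continue  # partial shape: vertex i of shape2 has no partner on shape1
        if not 0 <= j < n_verts1:
            raise ValueError(f"map21[{i}] = {j} is out of range for {n_verts1} shape1 vertices")
        rows.append(i)
        cols.append(j)
    return rows, cols
```

In the dataset itself, the same idea in torch would be `valid = map21 != -1` followed by `P[torch.arange(evec_2.size(0))[valid], map21[valid]] = 1`. The masking matters for partial shapes: in the original line, a `-1` entry would be read as a negative index and put a one in the last column of P.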
Have a nice day :)
