
philjd / generalize_spatial_relations


Code for the paper "Optimization Beyond the Convolution: Generalizing Spatial Relations with End-to-End Metric Learning" (ICRA 2018)

License: Apache License 2.0

Python 100.00%

generalize_spatial_relations's Introduction

Code for: Optimization Beyond the Convolution: Generalizing Spatial Relations with End-to-End Metric Learning

This repository contains the code for our paper on generalizing spatial relations with end-to-end metric learning, published at ICRA 2018.

News

  • 24.06.2018: We won the ICRA 2018 Best Paper Award in Robot Vision!

Setup:

Install this repository (in development mode) using pip:

git clone https://github.com/PhilJd/generalize_spatial_relations
cd generalize_spatial_relations
pip install -e .

If you'd like to extend the code, I highly recommend also installing pandas: it reads the point clouds in about 3 seconds, whereas numpy takes over 30 seconds.
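To illustrate the pandas recommendation, here is a minimal sketch of a point-cloud loader that uses pandas when it is installed and falls back to NumPy otherwise. It assumes the point clouds are ASCII `.pcd` files whose header ends with a `DATA ascii` line; the function names `pcd_header_length` and `load_ascii_pcd` are hypothetical and not part of this repository.

```python
import numpy as np

try:
    import pandas as pd  # optional: parses large text files much faster
except ImportError:  # fall back to NumPy if pandas is not installed
    pd = None


def pcd_header_length(path):
    """Return the number of header lines, including the final 'DATA ...' line."""
    with open(path) as f:
        for line_no, line in enumerate(f, start=1):
            if line.startswith("DATA"):
                if line.split()[1:2] != ["ascii"]:
                    raise ValueError(f"{path}: only ASCII PCD files are handled here")
                return line_no
    raise ValueError(f"{path}: no DATA line found in header")


def load_ascii_pcd(path):
    """Load the point data of an ASCII .pcd file as an (N, C) float32 array."""
    skip = pcd_header_length(path)
    if pd is not None:
        # Whitespace-separated values, no column-name row after the PCD header.
        frame = pd.read_csv(path, sep=r"\s+", header=None, skiprows=skip)
        return frame.to_numpy(dtype=np.float32)
    # Slower pure-NumPy path; ndmin=2 keeps single-point files two-dimensional.
    return np.loadtxt(path, skiprows=skip, dtype=np.float32, ndmin=2)
```

Both paths return the same array, so the rest of the code does not need to know which parser was used.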

  • Download the dataset
  • Navigate to the relations dataset directory and create the point clouds from the .obj files (requires PCL): cd scripts; ./create_uniform_pcd.sh

Training

To train on all 15 splits, run train.py:

CUDA_VISIBLE_DEVICES=0 python train.py --logdir=$STORE_WEIGHTS_HERE --data_dir=$OBJECT_MODELS_ARE_HERE --more_augmentation=True

Adding the flag --more_augmentation applies stronger augmentation: each scene is cloned three times, and the third clone is augmented more strongly. This improves metric performance but may lead to less realistic generalizations.
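The cloning scheme described for --more_augmentation can be sketched as follows. This is an illustration under stated assumptions, not the repository's implementation: a scene is assumed to be a flat list of (x, y, z) points, augmentation is assumed to be a random rigid rotation about the vertical axis plus a horizontal shift, and "stronger" is assumed to mean wider sampling ranges. The helper names and the `mild`/`strong` ranges are hypothetical.

```python
import math
import random
from typing import List, Tuple

Point = Tuple[float, float, float]
Scene = List[Point]  # hypothetical: all points of all objects in one scene


def rigid_augment(scene: Scene, max_angle: float, max_shift: float,
                  rng: random.Random) -> Scene:
    """Rotate the scene about the z axis and shift it in the x-y plane.

    One rigid transform for the whole scene preserves the distances between
    points, so the spatial relation between the objects is unchanged.
    """
    angle = rng.uniform(-max_angle, max_angle)
    dx = rng.uniform(-max_shift, max_shift)
    dy = rng.uniform(-max_shift, max_shift)
    c, s = math.cos(angle), math.sin(angle)
    return [(c * x - s * y + dx, s * x + c * y + dy, z) for x, y, z in scene]


def clone_and_augment(scene: Scene, rng: random.Random,
                      mild=(math.radians(10), 0.05),
                      strong=(math.pi, 0.5)) -> List[Scene]:
    """Return three augmented clones; only the third uses the strong ranges."""
    clones = [rigid_augment(scene, *mild, rng=rng) for _ in range(2)]
    clones.append(rigid_augment(scene, *strong, rng=rng))
    return clones
```

Passing an explicit `random.Random` instance keeps the augmentation reproducible for a given seed.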

To train a model on all the data, add the flag --train_on_all_data.

Experiments (generalize relations)

To generalize relations from one scene to another, take a look at generalize.py. Here we pick a random subset of the scenes and use each scene as a reference to generalize the relation to all other scenes in this subset. The 3D visualization requires Mayavi and is extremely slow (about 12 hours to generate all scenes).
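The reference/target loop described for generalize.py can be sketched as below. This is a hypothetical outline, not the script itself: scene identifiers are arbitrary hashable values, and `generalize` stands in for whatever routine transfers the relation from the reference scene to the target scene. Because every scene in the subset serves as a reference for every other scene, a subset of k scenes yields k * (k - 1) ordered pairs, so the workload grows quadratically with the subset size.

```python
import itertools
import random
from typing import Callable, Dict, Hashable, List, Sequence, Tuple


def generalization_pairs(scene_ids: Sequence[Hashable], subset_size: int,
                         seed: int = 0) -> List[Tuple[Hashable, Hashable]]:
    """Pick a random subset and pair every reference scene with every other scene."""
    rng = random.Random(seed)
    subset = rng.sample(list(scene_ids), subset_size)
    # Ordered pairs: (a, b) and (b, a) are distinct tasks because the
    # reference and target roles are swapped.
    return list(itertools.permutations(subset, 2))


def run_generalization(scene_ids: Sequence[Hashable], subset_size: int,
                       generalize: Callable, seed: int = 0) -> Dict:
    """Call a (hypothetical) generalize(reference, target) for every pair."""
    return {(ref, tgt): generalize(ref, tgt)
            for ref, tgt in generalization_pairs(scene_ids, subset_size, seed)}
```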

Integrate the model into your code

For a simple example of how you could use this model in your code, see SpatialRelationCNN/inference_example.py. Please note that the code currently runs on GPU only.

generalize_spatial_relations's People

Contributors

philjd


generalize_spatial_relations's Issues

ValueError: could not broadcast input array from shape (135,1) into shape (135,)

My environment:
Ubuntu 20.04, NVIDIA GeForce RTX 3090
CUDA==11.6
tensorflow==2.9.1

I also changed

import tensorflow as tf

to

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

I ran

CUDA_VISIBLE_DEVICES=0 python3 train.py --logdir=./log --data_dir=~/桌面/dataset/FreiburgSpatialRelationsDataset/object-models -- more_augmention=True --train_on_all_data=True

and

CUDA_VISIBLE_DEVICES=0 python3 train.py --logdir=./log --data_dir=~/桌面/dataset/FreiburgSpatialRelationsDataset/object-models -- more_augmention=True

to train the model, but I get the following error:

/home/darkplume/桌面/git/generalize_spatial_relations/SpatialRelationCNN/model/evaluation_metrics.py:84: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray.
  knn_labels = similarity_matrix[index][:, :k]
Traceback (most recent call last):
  File "train.py", line 211, in <module>
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
  File "/home/darkplume/.local/lib/python3.8/site-packages/tensorflow/python/platform/app.py", line 36, in run
    _run(main=main, argv=argv, flags_parser=_parse_flags_tolerate_undef)
  File "/home/darkplume/.local/lib/python3.8/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/darkplume/.local/lib/python3.8/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "train.py", line 164, in main
    evaluate(model, sess, input_layer, labels, split_index, "test",
  File "train.py", line 126, in evaluate
    acc = metrics.knn_accuracy(dist_mat, similarity_mat, k, x_of_k)
  File "/home/darkplume/桌面/git/generalize_spatial_relations/SpatialRelationCNN/model/evaluation_metrics.py", line 84, in knn_accuracy
    knn_labels = similarity_matrix[index][:, :k]
ValueError: could not broadcast input array from shape (135,1) into shape (135,)

What is wrong with my setup? Can anybody help me with this?

Many thanks!
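A possible cause, offered as a hedged suggestion rather than a confirmed fix: the warning and the error together look like NumPy trying to build one array out of a list of differently shaped index arrays. Older NumPy versions treated indexing with a list of arrays, `arr[[rows, cols]]`, as the multidimensional index `arr[(rows, cols)]`; since NumPy 1.23 such a list is treated as a single array index instead. The body of `knn_accuracy` is not shown here, so the sketch below assumes `index` is a list like `[rows, order]` with shapes (n, 1) and (n, n), which matches the (135, 1) shape in the traceback. The function name `knn_labels` is hypothetical. If the assumption holds, writing `similarity_matrix[tuple(index)]` in evaluation_metrics.py, or installing a NumPy version older than 1.23, may resolve the error.

```python
import numpy as np


def knn_labels(dist_mat: np.ndarray, similarity_matrix: np.ndarray,
               k: int) -> np.ndarray:
    """Return the similarity labels of the k nearest neighbours of every row.

    Assumes both inputs are (n, n) arrays indexed the same way. The query itself
    (distance 0 on the diagonal) is counted as a neighbour in this sketch.
    """
    n = dist_mat.shape[0]
    order = np.argsort(dist_mat, axis=1)  # (n, n): column indices sorted by distance
    rows = np.arange(n)[:, None]          # (n, 1): broadcasts against `order`
    index = [rows, order]
    # A list here is only read as a multidimensional index by old NumPy
    # releases; a tuple is read that way on every version.
    return similarity_matrix[tuple(index)][:, :k]
```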
