mushr_push_mstp

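Setting up the Conda environment and packages

Create a new conda environment, go to the project's root directory (the mushr_push_sim folder, referred to as $HOME throughout this README; in the commands below, replace $HOME with that path), and install the requirements: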

conda create -n my_env python=3.7
conda activate my_env
cd $HOME
pip install -r requirements.txt

Data Generation

To generate data, you need a valid MuJoCo license and must be able to run the Figure 8 tutorial on the MuSHR website (https://mushr.io/tutorials/mujoco_figure8/).

rm -f $HOME/datacol/json_files/*
python collect_data.py --cpu_count 3

These commands clear any previously generated files and then generate train.csv.gz, test_seen.csv.gz and test_unseen.csv.gz in $HOME/datacol/json_files/.
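To sanity-check the generated files before training, a small script like the one below can report the row and column counts of each gzipped CSV. It is a sketch that only assumes the files are gzip-compressed CSVs with a header row, as the .csv.gz extension suggests; the actual column names written by collect_data.py are not documented here, and summarize_csv_gz is a hypothetical helper, not part of the repository.

```python
import csv
import gzip
import os
import sys


def summarize_csv_gz(path):
    """Return (header, number_of_data_rows) for a gzip-compressed CSV file.

    Assumes the first row is a header; the real columns produced by
    collect_data.py are not documented in this README.
    """
    with gzip.open(path, mode="rt", newline="") as f:
        reader = csv.reader(f)
        header = next(reader, [])
        n_rows = sum(1 for _ in reader)
    return header, n_rows


if __name__ == "__main__":
    # Usage (hypothetical script name):
    #   python check_data.py /absolute/path/to/datacol/json_files
    data_dir = sys.argv[1] if len(sys.argv) > 1 else "."
    for name in ("train.csv.gz", "test_seen.csv.gz", "test_unseen.csv.gz"):
        path = os.path.join(data_dir, name)
        if not os.path.exists(path):
            print(f"{name}: missing")
            continue
        header, n_rows = summarize_csv_gz(path)
        print(f"{name}: {n_rows} rows, {len(header)} columns")
```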

Training the Model

Note: You may want to train the model on a server, since the dataloader consumes a lot of memory and can hang a laptop or desktop. A GPU is also preferred (the code has not been tested on CPU, but it should work). First set the data_dir, home_dir, traj_save_addr, model_path and train variables in $HOME/train/main.py, either by passing them as command-line arguments or by editing them in the code itself.

data_dir is the folder where the data is stored ($HOME/datacol/json_files).

home_dir is the folder containing the main script ($HOME/train/main.py).

traj_save_addr is the folder where the trajectory plots will be saved.

model_path is the folder where the PyTorch model will be stored.

train should be True when you want to train a model. If you only want to evaluate and visualize an already trained LSTM or simple regression model, set train to False; this lets you evaluate the model and visualize trajectories without training all over again.

Note: Don't use the '~/' notation for paths; give the complete path instead, for example '/home/user/'.
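The argument handling described above might look roughly like the sketch below. Only --data_dir, --home_dir and --traj_save_addr appear in the commands in this README; the --model_path and --train flags, the str2bool and absolute_path helpers, and the path check are assumptions for illustration, not code taken from main.py. The check also shows why '~/' causes trouble: Python does not expand '~' inside a plain string, so '~/data' is treated as a relative path.

```python
import argparse
import os


def str2bool(value):
    """Parse 'True'/'False'-style strings.

    argparse's type=bool would treat any non-empty string, including
    'False', as True, so a helper like this is needed for a --train flag.
    """
    if value.lower() in ("true", "1", "yes"):
        return True
    if value.lower() in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def absolute_path(value):
    """Reject paths that are not absolute, such as '~/data'."""
    if not os.path.isabs(value):
        raise argparse.ArgumentTypeError(
            f"{value!r} is not an absolute path; use e.g. '/home/user/...'")
    return value


def build_parser():
    # Hypothetical parser. Replacing the None defaults with your own
    # absolute paths would let you run main.py without arguments.
    parser = argparse.ArgumentParser(description="Train or evaluate the model")
    parser.add_argument("--data_dir", type=absolute_path, default=None,
                        help="folder with the .csv.gz data files")
    parser.add_argument("--home_dir", type=absolute_path, default=None,
                        help="folder containing main.py")
    parser.add_argument("--traj_save_addr", type=absolute_path, default=None,
                        help="folder for trajectory plots")
    parser.add_argument("--model_path", type=absolute_path, default=None,
                        help="folder for the PyTorch model (assumed flag)")
    parser.add_argument("--train", type=str2bool, default=True,
                        help="False to only evaluate/visualize (assumed flag)")
    return parser
```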

Now run the training code as follows:

cd $HOME/train/
python main.py --data_dir '{data_dir}' --home_dir '{home_dir}' --traj_save_addr '{traj_save_addr}'

If you have instead set these values in the argument parser in main.py, simply run:

cd $HOME/train/
python main.py

Visualizations

There are three kinds of visuals: the trajectory plots, in which red marks the predicted positions and blue the actual positions of the block, stored in {traj_save_addr}; the TensorBoard plots of the train and test losses, stored in $HOME/train/TensorboardVisuals/; and the average trajectory MSE losses, stored in the directory from which the code is run.
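As a rough illustration of what an average trajectory MSE measures, the sketch below computes the mean squared error between predicted and actual (x, y) block positions for each trajectory and then averages over trajectories, weighting each trajectory equally. The exact definition used in main.py (for example, whether orientation is included or how trajectories of different lengths are weighted) is not stated in this README, so treat this as an assumption.

```python
from typing import List, Sequence, Tuple

Point = Tuple[float, float]


def trajectory_mse(pred: Sequence[Point], actual: Sequence[Point]) -> float:
    """Mean squared Euclidean error over the points of one trajectory."""
    if not pred or len(pred) != len(actual):
        raise ValueError("trajectories must be non-empty and of equal length")
    total = 0.0
    for (px, py), (ax, ay) in zip(pred, actual):
        # Squared distance between predicted and actual position.
        total += (px - ax) ** 2 + (py - ay) ** 2
    return total / len(pred)


def average_trajectory_mse(preds: List[Sequence[Point]],
                           actuals: List[Sequence[Point]]) -> float:
    """Average of the per-trajectory MSEs (each trajectory weighted equally)."""
    if not preds or len(preds) != len(actuals):
        raise ValueError("need the same non-zero number of trajectories")
    per_traj = [trajectory_mse(p, a) for p, a in zip(preds, actuals)]
    return sum(per_traj) / len(per_traj)
```

Dividing by the number of coordinates as well (2 here) would instead give the per-element mean that a PyTorch-style MSE loss over (x, y) pairs reports; which convention the repository follows is not stated.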

Figure 1: Trajectories. Red trajectories are the model predictions and blue trajectories are the ground truth.

To view the TensorBoard plots, run the following in a terminal:

cd $HOME/train/
tensorboard --logdir=TensorboardVisuals/ --bind_all

Figure 2: Training and testing plots of the LSTM model with sequence length 1000.

To visualize the model (model.onnx) with Netron, run:

cd $HOME/train/
python -c "import netron; netron.start('model.onnx');"

The trajectory error plots can be found in the $HOME/train/ folder.

Figure 3: The absolute error is the absolute difference between the targets/labels and the model predictions. The index variable is the index of a point in a trajectory.
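A minimal sketch of how a per-index absolute error like the one in Figure 3 could be computed, assuming equal-length trajectories and one scalar value per point (for example a single coordinate). Averaging over trajectories at each index is an assumption; the caption only defines the error for a single prediction/label pair, and abs_error_by_index is a hypothetical name.

```python
from typing import List, Sequence


def abs_error_by_index(preds: List[Sequence[float]],
                       labels: List[Sequence[float]]) -> List[float]:
    """Mean |label - prediction| at each point index across trajectories."""
    if not preds or len(preds) != len(labels):
        raise ValueError("need the same non-zero number of trajectories")
    length = len(preds[0])
    sums = [0.0] * length
    for p, y in zip(preds, labels):
        if len(p) != length or len(y) != length:
            raise ValueError("all trajectories must have the same length")
        for i in range(length):
            # Absolute difference between target and prediction at index i.
            sums[i] += abs(y[i] - p[i])
    return [s / len(preds) for s in sums]
```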

Contributors

alrick11, schmittlema
