
tmralmeida / clust-based-trajpred


PyTorch(v1.12.0) implementation of Likely, Light, and Accurate Context-Free Clusters-based Trajectory Prediction.

Home Page: https://ieeexplore.ieee.org/abstract/document/10422479

License: MIT License



motion-pred-clustbased

The augmentation branch can be updated at any time.


Current status

  • Data sets

    • Argoverse
      • Extraction
      • Preprocessing
    • THÖR - Preprocessing
    • Benchmark - Preprocessing
  • Inputs

    • X, Y
    • dX, dY
    • R, $\theta$
    • dR, d$\theta$
  • Data Analysis

    • Transforms visualization
    • 2D Histograms -> inputs range
    • Histograms -> inputs by interval
    • Histograms -> time step
    • Trip Length
    • Statistics
  • Clustering based on full trajectory

    • k-means
      • visualization on TB
      • metrics (SSE and silhouette score)
      • Elbow method
    • DBSCAN
      • visualization on TB
      • metrics (silhouette score, avg_distance based on knn)
      • Fréchet-distance
    • HDBSCAN
      • visualization on TB
      • metrics (silhouette score)
      • Fréchet-distance
    • Self-Conditioned GAN
      • Discriminator based on predicted trajectory
      • Discriminator based on full trajectory
      • Elbow
        • silhouette coef. based
  • Forecasters

  • Augmentation

    • Plot displacements
    • Plot translated trajectories
    • sc-GAN generating synthetic data
      • Tested in THÖR
      • Tested in Argoverse
      • Tested in Benchmark
    • FT-GAN generating synthetic data
      • Tested in THÖR
      • Tested in Argoverse
      • Tested in Benchmark
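
As a minimal sketch of the four input representations listed under Inputs (X, Y; dX, dY; R, θ; dR, dθ), the NumPy snippet below derives them from one trajectory of absolute positions. It assumes dR, dθ are the polar form of each step displacement; the repository may instead define them as consecutive differences of R and θ. The function name input_representations is hypothetical.

```python
import numpy as np


def input_representations(traj):
    """Derive the four input variants from one trajectory.

    traj: array-like of shape (T, 2) with absolute (x, y) positions.
    Displacement-based outputs have T - 1 rows.
    """
    # X, Y: absolute Cartesian positions.
    xy = np.asarray(traj, dtype=float)
    # dX, dY: displacement between consecutive time steps.
    dxy = np.diff(xy, axis=0)
    # R, theta: polar coordinates of each absolute position.
    polar = np.stack([np.hypot(xy[:, 0], xy[:, 1]),
                      np.arctan2(xy[:, 1], xy[:, 0])], axis=1)
    # dR, dtheta (assumption): polar form of each displacement vector.
    dpolar = np.stack([np.hypot(dxy[:, 0], dxy[:, 1]),
                       np.arctan2(dxy[:, 1], dxy[:, 0])], axis=1)
    return {"xy": xy, "dxy": dxy, "polar": polar, "dpolar": dpolar}


feats = input_representations([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0]])
print(feats["dpolar"])  # rows: (step length, heading in radians)
```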

Installation

Install Miniconda. Then, install all required packages by running:

conda env create -f environment.yml
conda activate mpc
pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
cd motion-pred-clustbased && pip install .

Datasets extraction and preprocessing

  • Benchmark

Define the cfg file settings. Then, run:

python -m motion_pred_clustbased.datasets.utils.preprocess --cfg motion_pred_clustbased/cfg/preprocessing/benchmark.yaml

  • Argoverse

Define the cfg file settings. First, run:

python -m motion_pred_clustbased.datasets.utils.extract_argoverse

Then, run:

python -m motion_pred_clustbased.datasets.utils.preprocess --cfg motion_pred_clustbased/cfg/preprocessing/argoverse.yaml

  • THÖR

First, run pythor-tools: define its cfg file and extract the dataset as described in its README.

Then, run:

python -m motion_pred_clustbased.datasets.utils.preprocess --cfg motion_pred_clustbased/cfg/preprocessing/thor.yaml

Note: ETH has a different frame rate (the video is accelerated).

Data Analysis

The Data Analysis branch comprises three types of analysis: transformation assertion (normalization and polar/Cartesian coordinates), input-feature analysis through histograms, and statistical analysis.

Assert Transformations

To run the transformation assertion on UNIV:

python -m motion_pred_clustbased.data_analysis.assert_transforms --cfg motion_pred_clustbased/cfg/data_analysis/benchmark.yaml

This outputs a figure in which the first row shows the original spaces, the second row depicts the normalization procedure (see the axes), and the third row shows the inverse transformation back to the original space.
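
The round-trip check behind the transformation assertion can be sketched as follows. This is not the repository's code: it assumes per-feature min-max scaling to [0, 1] (the actual normalization may differ), and fit_minmax, cart2polar, polar2cart, and assert_round_trip are hypothetical names.

```python
import numpy as np


def fit_minmax(data):
    """Per-feature min and max over all trajectories and time steps."""
    flat = data.reshape(-1, data.shape[-1])
    return flat.min(axis=0), flat.max(axis=0)


def normalize(data, lo, hi):
    # Assumes no feature is constant (hi > lo).
    return (data - lo) / (hi - lo)


def denormalize(data, lo, hi):
    return data * (hi - lo) + lo


def cart2polar(xy):
    return np.stack([np.hypot(xy[..., 0], xy[..., 1]),
                     np.arctan2(xy[..., 1], xy[..., 0])], axis=-1)


def polar2cart(rt):
    return np.stack([rt[..., 0] * np.cos(rt[..., 1]),
                     rt[..., 0] * np.sin(rt[..., 1])], axis=-1)


def assert_round_trip(trajs, atol=1e-9):
    """trajs: (N, T, 2). Checks scaling range and both inverse transforms."""
    lo, hi = fit_minmax(trajs)
    norm = normalize(trajs, lo, hi)
    assert norm.min() >= -atol and norm.max() <= 1 + atol
    assert np.allclose(denormalize(norm, lo, hi), trajs, atol=atol)
    assert np.allclose(polar2cart(cart2polar(trajs)), trajs, atol=atol)
    return norm
```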

Analysis through histograms

This analysis produces frequency maps of the input features. Example (hist_type: input_frequency in the cfg file):

python -m motion_pred_clustbased.data_analysis.histograms --cfg motion_pred_clustbased/cfg/data_analysis/benchmark.yaml

Visual examples from THÖR (upper) and UNIV (bottom) data sets:

Another example counts the input features per interval. Visual example (hist_type: cnt_per_interval in the cfg file):

Analogously, we compute the same statistics per time step. Visual example (hist_type: step_wise in the cfg file):

Finally, we also compute trip length statistics (hist_type: trip_length in the cfg file):
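
The input_frequency and trip_length histograms can be approximated with plain NumPy. A minimal sketch, assuming trajectories stored as an (N, T, 2) array of absolute positions and treating trip length as distance traveled (if it instead means the number of time steps, count frames); both function names are hypothetical.

```python
import numpy as np


def input_frequency_map(trajs, bins=20):
    """2D frequency map of (dX, dY) pairs pooled over a dataset.

    trajs: array of shape (N, T, 2) with absolute positions.
    """
    disp = np.diff(np.asarray(trajs, dtype=float), axis=1).reshape(-1, 2)
    counts, xedges, yedges = np.histogram2d(disp[:, 0], disp[:, 1], bins=bins)
    return counts, xedges, yedges


def trip_lengths(trajs):
    """Distance traveled per trajectory: sum of step-wise Euclidean norms."""
    steps = np.linalg.norm(np.diff(np.asarray(trajs, dtype=float), axis=1), axis=-1)
    return steps.sum(axis=1)


rng = np.random.default_rng(0)
demo = rng.normal(size=(100, 20, 2)).cumsum(axis=1)  # synthetic random walks
counts, _, _ = input_frequency_map(demo)
lengths = trip_lengths(demo)
print(counts.shape, lengths.mean(), lengths.std())
```

The resulting counts can be passed to any image plot; the trip lengths feed an ordinary 1D histogram.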

Statistical Analysis

To run the statistical analysis (Straightness Index, average and standard deviation of input features) on Argoverse:

python -m motion_pred_clustbased.data_analysis.stats --cfg motion_pred_clustbased/cfg/data_analysis/argoverse.yaml
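
The Straightness Index is commonly defined as the net displacement between the first and last positions divided by the path length, giving 1 for a straight walk and values near 0 for tortuous ones. A minimal sketch of that definition plus the mean and standard deviation of the (dX, dY) features, assuming (N, T, 2) input; the helper names are hypothetical, and the repository may compute these statistics differently.

```python
import numpy as np


def straightness_index(traj, eps=1e-12):
    """Net displacement / path length; NaN for an agent that never moves."""
    traj = np.asarray(traj, dtype=float)
    path = np.linalg.norm(np.diff(traj, axis=0), axis=1).sum()
    if path <= eps:
        return float("nan")
    return float(np.linalg.norm(traj[-1] - traj[0]) / path)


def dataset_stats(trajs):
    """Mean/std of the straightness index and of the (dX, dY) features."""
    trajs = np.asarray(trajs, dtype=float)
    si = np.array([straightness_index(t) for t in trajs])
    disp = np.diff(trajs, axis=1).reshape(-1, 2)
    return {
        "si_mean": float(np.nanmean(si)),  # NaN entries (static agents) ignored
        "si_std": float(np.nanstd(si)),
        "disp_mean": disp.mean(axis=0),
        "disp_std": disp.std(axis=0),
    }
```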

Clustering

This repo provides clustering methods in two data spaces: the input space (trajectories) and the feature space (via a Self-Conditioned GAN adapted to trajectory generation).

Config files can be found in the respective directories under motion_pred_clustbased/cfg/clustering/.

Trajectory Space

First, edit the config file with the clustering settings. In the default.yaml file, select the clustering algorithm and its hyperparameters: k-means (type: normal), k-shape implemented via tslearn, or k-means (type: time_series). Note: the elbow method can be run for each algorithm (elbow key).

Then, to run the selected clustering method on, e.g., Argoverse:

python -m motion_pred_clustbased.clustering.raw_data.run --cfg motion_pred_clustbased/cfg/clustering/raw/argoverse.yaml 
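
For the k-means (type: normal) option, the sketch below flattens each (T, 2) trajectory into a 2T-dimensional vector, runs Lloyd's algorithm with k-means++ seeding, and returns the SSE used by the elbow method. It is a NumPy stand-in for illustration, not the repository's implementation; kmeans and elbow_curve are hypothetical names.

```python
import numpy as np


def _kmeans_pp_init(X, k, rng):
    """k-means++ seeding: sample each new centroid with probability ~ D^2."""
    centroids = [X[rng.integers(len(X))]]
    for _ in range(1, k):
        d2 = ((X[:, None, :] - np.array(centroids)[None]) ** 2).sum(-1).min(axis=1)
        if d2.sum() == 0:  # every point coincides with a chosen centroid
            centroids.append(X[rng.integers(len(X))])
        else:
            centroids.append(X[rng.choice(len(X), p=d2 / d2.sum())])
    return np.array(centroids, dtype=float)


def kmeans(X, k, n_iter=100, seed=0):
    """Lloyd's k-means on X of shape (N, D). Returns labels, centroids, SSE."""
    X = np.asarray(X, dtype=float)
    rng = np.random.default_rng(seed)
    centroids = _kmeans_pp_init(X, k, rng)
    for _ in range(n_iter):
        d2 = ((X[:, None, :] - centroids[None]) ** 2).sum(-1)  # (N, k)
        labels = d2.argmin(axis=1)
        # Keep the old centroid if a cluster ends up empty.
        new = np.array([X[labels == j].mean(axis=0) if np.any(labels == j)
                        else centroids[j] for j in range(k)])
        if np.allclose(new, centroids):
            break
        centroids = new
    d2 = ((X[:, None, :] - centroids[None]) ** 2).sum(-1)
    labels = d2.argmin(axis=1)
    sse = float(d2[np.arange(len(X)), labels].sum())
    return labels, centroids, sse


def elbow_curve(trajs, ks, seed=0):
    """SSE per k; trajs of shape (N, T, 2) are flattened to (N, 2T)."""
    X = np.asarray(trajs, dtype=float).reshape(len(trajs), -1)
    return {k: kmeans(X, k, seed=seed)[2] for k in ks}
```

Plotting the returned SSE against k and picking the point where the curve flattens gives the elbow.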

Feature Space

First, edit the config file with the Self-Conditioned GAN settings, specifying the dataset and hyperparameters. The best_metric key defines the stopping criterion.

Then, to run the sc-GAN (Self-Conditioned GAN) method on, e.g., the benchmark, run:

python -m motion_pred_clustbased.clustering.feature_space.run --cfg motion_pred_clustbased/cfg/clustering/feature_space/benchmark.yaml

The elbow method can also be run through:

python -m motion_pred_clustbased.clustering.feature_space.run_elbow --cfg motion_pred_clustbased/cfg/clustering/feature_space/benchmark.yaml
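
The sc-GAN elbow option is listed as silhouette-coefficient based. A minimal sketch of that selection: compute the mean silhouette coefficient s(i) = (b(i) - a(i)) / max(a(i), b(i)) for each candidate k and keep the best one. The cluster assignments are taken as given (for example, from clustering discriminator features); how the repository obtains them is not shown here, and the function names are hypothetical.

```python
import numpy as np


def silhouette_score(X, labels):
    """Mean silhouette coefficient; needs at least two distinct labels."""
    X = np.asarray(X, dtype=float)
    labels = np.asarray(labels)
    D = np.sqrt(((X[:, None, :] - X[None, :, :]) ** 2).sum(-1))  # pairwise distances
    clusters = np.unique(labels)
    s = np.zeros(len(X))
    for i in range(len(X)):
        same = labels == labels[i]
        n_same = same.sum()
        if n_same == 1:
            continue  # singleton cluster: s(i) = 0 by convention
        a = D[i, same].sum() / (n_same - 1)  # mean intra-cluster distance, self excluded
        b = min(D[i, labels == c].mean() for c in clusters if c != labels[i])
        denom = max(a, b)
        s[i] = (b - a) / denom if denom > 0 else 0.0
    return float(s.mean())


def best_k_by_silhouette(X, labels_per_k):
    """labels_per_k: {k: cluster labels for X}. Returns (best k, all scores)."""
    scores = {k: silhouette_score(X, lab) for k, lab in labels_per_k.items()}
    return max(scores, key=scores.get), scores
```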

Predictors

To run the predictors (three variants of the Constant Velocity Model, CVM) on the test set, first configure the respective model's settings.

Then, to run the respective CVM:

python -m motion_pred_clustbased.predictors.cvm --cfg motion_pred_clustbased/cfg/predict/argoverse/cvm.yaml
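
The basic Constant Velocity Model extrapolates the last observed velocity. The three variants are not described here, so the sketch below adds a hypothetical n_last parameter (averaging the last n_last displacements) only to show one plausible way variants can differ.

```python
import numpy as np


def cvm_predict(obs, pred_len, n_last=1):
    """Constant Velocity Model.

    obs: (T_obs, 2) positions at a fixed frame rate, T_obs >= 2.
    n_last (hypothetical): number of most recent displacements averaged
    into the velocity estimate; 1 uses only the last step.
    Returns (pred_len, 2) future positions.
    """
    obs = np.asarray(obs, dtype=float)
    vel = np.diff(obs, axis=0)[-n_last:].mean(axis=0)
    steps = np.arange(1, pred_len + 1)[:, None]  # 1, 2, ..., pred_len
    return obs[-1] + steps * vel


print(cvm_predict([[0, 0], [1, 0], [2, 0]], pred_len=3))  # positions (3, 0), (4, 0), (5, 0)
```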

Trainers

To train the RED predictor on THÖR:

python -m motion_pred_clustbased.train --cfg motion_pred_clustbased/cfg/training/thor/van_det.yaml

To train the vanilla GAN on THÖR:

python -m motion_pred_clustbased.train --cfg motion_pred_clustbased/cfg/training/thor/van_gan.yaml

  • Vanilla VAE

To train the vanilla VAE on THÖR:

python -m motion_pred_clustbased.train --cfg motion_pred_clustbased/cfg/training/thor/van_vae.yaml

  • cVAE conditioned on supervised classes (when available)

To train the supervised conditioned VAE on THÖR:

python -m motion_pred_clustbased.train --cfg motion_pred_clustbased/cfg/training/thor/cvae.yaml

To train the supervised conditioned GAN on THÖR:

python -m motion_pred_clustbased.train --cfg motion_pred_clustbased/cfg/training/thor/cgan.yaml

  • Ours (dist-based)

To train a conditional context-free deep generative trajectory forecaster on THÖR:

python -m motion_pred_clustbased.train --cfg motion_pred_clustbased/cfg/training/thor/gan_clust_based/dist_based.yaml

This trains a model that outputs trajectories together with their likelihoods, computed from either the distance to the cluster centroids or the distance to the n_neighbors closest neighbors.
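
The text does not specify how distances become likelihoods. As an assumption, the sketch below applies a softmax over negative distances, either to the cluster centroids or, per cluster, to the mean distance of the n_neighbors closest samples; temperature and the function names are hypothetical.

```python
import numpy as np


def _softmax(logits):
    z = np.exp(logits - logits.max())  # shift for numerical stability
    return z / z.sum()


def centroid_likelihoods(traj, centroids, temperature=1.0):
    """Softmax over negative Euclidean distances to each cluster centroid.

    traj: (D,) flattened trajectory; centroids: (K, D).
    """
    d = np.linalg.norm(np.asarray(centroids, float) - np.asarray(traj, float), axis=1)
    return _softmax(-d / temperature)


def knn_likelihoods(traj, samples, labels, n_neighbors=5, temperature=1.0):
    """Per cluster, mean distance to its n_neighbors closest samples, then softmax."""
    d = np.linalg.norm(np.asarray(samples, float) - np.asarray(traj, float), axis=1)
    labels = np.asarray(labels)
    clusters = np.unique(labels)
    mean_d = np.array([np.sort(d[labels == c])[:n_neighbors].mean() for c in clusters])
    return clusters, _softmax(-mean_d / temperature)
```

A lower temperature sharpens the distribution toward the nearest cluster; the kNN variant is less sensitive to clusters whose centroid lies far from any real sample.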

  • Ours (anet-based)

To train a conditional context-free deep generative trajectory forecaster on THÖR based on an auxiliary network:

python -m motion_pred_clustbased.train_anet --cfg motion_pred_clustbased/cfg/training/thor/gan_clust_based/anet_based.yaml

This trains a model that outputs trajectories together with their likelihoods, estimated by the auxiliary network. This network learns which clusters should be activated from a set of trajectories drawn from each cluster.
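
As a stand-in for the auxiliary network, the sketch below trains a multinomial logistic regression (softmax classifier) with plain gradient descent to map flattened trajectories to cluster ids, so its output probabilities can serve as cluster likelihoods. The repository's anet is presumably a neural network in PyTorch; this NumPy version only illustrates the classification role, and all names are hypothetical.

```python
import numpy as np


def _softmax_rows(z):
    z = z - z.max(axis=1, keepdims=True)  # row-wise shift for stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def _with_bias(X):
    return np.hstack([np.asarray(X, dtype=float), np.ones((len(X), 1))])


def train_aux_classifier(X, y, n_clusters, lr=0.01, epochs=500):
    """X: (N, D) flattened trajectories; y: (N,) integer cluster ids.

    Returns weights W of shape (D + 1, n_clusters), bias in the last row.
    """
    Xb = _with_bias(X)
    W = np.zeros((Xb.shape[1], n_clusters))
    Y = np.eye(n_clusters)[np.asarray(y)]  # one-hot targets
    for _ in range(epochs):
        P = _softmax_rows(Xb @ W)
        W -= lr * Xb.T @ (P - Y) / len(Xb)  # gradient of mean cross-entropy
    return W


def cluster_probs(W, X):
    """Per-cluster probabilities for each row of X."""
    return _softmax_rows(_with_bias(X) @ W)
```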

Augmentation

To create a synthetic THÖR dataset based on sc-GAN, and train an anet-conditioned deep generative trajectory forecaster:

python -m motion_pred_clustbased.augmentation.train --model_cfg motion_pred_clustbased/cfg/training/thor/ours_anet_based.yaml --synt_gen_cfg motion_pred_clustbased/cfg/clustering/feature_space/thor.yaml

To create a synthetic THÖR dataset based on FT-GAN, and train an anet-conditioned deep generative trajectory forecaster:

python -m motion_pred_clustbased.augmentation.train --model_cfg motion_pred_clustbased/cfg/training/thor/ours_anet_based.yaml --synt_gen_cfg motion_pred_clustbased/cfg/augmentation/thor_ft_gan.yaml

To plot the alignment between the original training and test sets and the synthetic dataset, set the plots variable to True in the respective configuration files.

Citation

@INPROCEEDINGS{10422479,
  author={de Almeida, Tiago Rodrigues and Mozos, Oscar Martinez},
  booktitle={2023 IEEE 26th International Conference on Intelligent Transportation Systems (ITSC)}, 
  title={Likely, Light, and Accurate Context-Free Clusters-based Trajectory Prediction}, 
  year={2023},
  volume={},
  number={},
  pages={1269-1276},
  keywords={Uncertainty;Clustering methods;Roads;Probabilistic logic;Generative adversarial networks;Trajectory;Proposals},
  doi={10.1109/ITSC57777.2023.10422479}}
