
combustion_nn_model's Introduction

PyTorch Template Project

PyTorch deep learning project made easy.

Requirements

  • Python >= 3.5
  • PyTorch >= 0.4
  • tqdm (Optional for test.py)
  • tensorboard >= 1.7.0 (Optional for TensorboardX)
  • tensorboardX >= 1.2 (Optional for TensorboardX)

Features

  • Clear folder structure which is suitable for many deep learning projects.
  • .json config file support for more convenient parameter tuning.
  • Checkpoint saving and resuming.
  • Abstract base classes for faster development:
    • BaseTrainer handles checkpoint saving/resuming, training process logging, and more.
    • BaseDataLoader handles batch generation, data shuffling, and validation data splitting.
    • BaseModel provides basic model summary.

Folder Structure

pytorch-template/
│
├── train.py - main script to start training
├── test.py - evaluation of trained model
├── config.json - config file
│
├── base/ - abstract base classes
│   ├── base_data_loader.py - abstract base class for data loaders
│   ├── base_model.py - abstract base class for models
│   └── base_trainer.py - abstract base class for trainers
│
├── data_loader/ - anything about data loading goes here
│   └── data_loaders.py
│
├── data/ - default directory for storing input data
│
├── model/ - models, losses, and metrics
│   ├── loss.py
│   ├── metric.py
│   └── model.py
│
├── saved/ - default checkpoints folder
│   └── runs/ - default logdir for tensorboardX
│
├── trainer/ - trainers
│   └── trainer.py
│
└── utils/
    ├── util.py
    ├── logger.py - class for train logging
    ├── visualization.py - class for tensorboardX visualization support
    └── ...

Usage

The code in this repo is an MNIST example of the template. To run it, use python3 train.py -c config.json.

Config file format

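Config files are in .json format. The // comments below only annotate this README; standard JSON does not allow comments, so leave them out of a real config file: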

{
  "name": "Mnist_LeNet",        // training session name
  "n_gpu": 1,                   // number of GPUs to use for training.
  
  "arch": {
    "type": "MnistModel",       // name of model architecture to train
    "args": {

    }                
  },
  "data_loader": {
    "type": "MnistDataLoader",  // selecting data loader
    "args":{
      "data_dir": "data/",      // dataset path
      "batch_size": 64,         // batch size
      "shuffle": true,          // shuffle training data before splitting
      "validation_split": 0.1,  // validation data ratio
      "num_workers": 2          // number of cpu processes to be used for data loading
    }
  },
  "optimizer": {
    "type": "Adam",
    "args":{
      "lr": 0.001,              // learning rate
      "weight_decay": 0,        // (optional) weight decay
      "amsgrad": true
    }
  },
  "loss": "nll_loss",           // loss
  "metrics": [
    "my_metric", "my_metric2"   // list of metrics to evaluate
  ],                         
  "lr_scheduler": {
    "type":"StepLR",            // learning rate scheduler
    "args":{
      "step_size":50,          
      "gamma":0.1
    }
  },
  "trainer": {
    "epochs": 1000,             // number of training epochs
    "save_dir": "saved/",       // checkpoints are saved in save_dir/name
    "save_freq": 1,             // save checkpoints every save_freq epochs
    "verbosity": 2,             // 0: quiet, 1: per epoch, 2: full
    "monitor": "val_loss",      // evaluation metric for finding best model
    "monitor_mode": "min"       // "min" if monitor value the lower the better, otherwise "max". "off" to disable
  },
  "visualization":{
    "tensorboardX": true,       // enable tensorboardX visualization support
    "log_dir": "saved/runs"     // directory to save log files for visualization
  }
}

Add additional configurations if you need them.
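Each "type"/"args" pair in the config names a class (or function) and the keyword arguments used to build it. A minimal sketch of resolving such a pair by name, assuming a helper called get_instance and a stand-in MnistModel with a made-up num_classes argument (both hypothetical here; the template's train.py may do this differently). The config text is plain JSON because json.loads rejects // comments:

```python
import json
from types import SimpleNamespace


# Hypothetical stand-in for a class defined in model/model.py.
class MnistModel:
    def __init__(self, num_classes=10):
        self.num_classes = num_classes


# Hypothetical stand-in for the imported model module.
module_arch = SimpleNamespace(MnistModel=MnistModel)


def get_instance(module, name, config, *args):
    """Build config[name]['type'] from `module`, passing config[name]['args'] as kwargs."""
    entry = config[name]
    factory = getattr(module, entry["type"])  # look the class up by its config name
    return factory(*args, **entry["args"])


config_text = """
{
  "name": "Mnist_LeNet",
  "arch": {"type": "MnistModel", "args": {"num_classes": 10}}
}
"""
config = json.loads(config_text)
model = get_instance(module_arch, "arch", config)
print(type(model).__name__, model.num_classes)  # MnistModel 10
```

Passing the module in as a parameter keeps the lookup generic: the same helper could build the data loader, optimizer, or scheduler from their own sections, with extra positional arguments (such as model parameters for an optimizer) going through *args.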

Using config files

Modify the configurations in .json config files, then run:

python train.py --config config.json

Resuming from checkpoints

You can resume from a previously saved checkpoint by:

python train.py --resume path/to/checkpoint
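Resuming relies on the fields stored in each checkpoint (listed under Checkpoints). The sketch covers only the bookkeeping side, assuming training continues from the epoch after the saved one and that the best monitored value is restored. pickle stands in for torch.save/torch.load so it runs without PyTorch; save_checkpoint and resume_checkpoint are hypothetical names, and a real trainer would also reload the model and optimizer state_dicts:

```python
import os
import pickle
import tempfile


def save_checkpoint(path, epoch, monitor_best, config):
    # pickle stands in for torch.save so this sketch has no PyTorch dependency.
    state = {"epoch": epoch, "monitor_best": monitor_best, "config": config}
    with open(path, "wb") as f:
        pickle.dump(state, f)


def resume_checkpoint(path):
    """Return (start_epoch, monitor_best, config) read from a saved checkpoint."""
    with open(path, "rb") as f:
        checkpoint = pickle.load(f)
    # Assumption: training picks up with the epoch after the one that was saved.
    start_epoch = checkpoint["epoch"] + 1
    return start_epoch, checkpoint["monitor_best"], checkpoint["config"]


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "checkpoint_epoch_5.pth")
    save_checkpoint(path, epoch=5, monitor_best=0.12, config={"name": "Mnist_LeNet"})
    start_epoch, best, cfg = resume_checkpoint(path)

print(start_epoch, best, cfg["name"])  # 6 0.12 Mnist_LeNet
```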

Using Multiple GPU

You can enable multi-GPU training by setting the n_gpu argument in the config file to a larger number. If it is configured to use fewer GPUs than are available, the first n devices are used by default. To use specific GPUs, pass their indices with the --device option or set them in the CUDA_VISIBLE_DEVICES environment variable.

python train.py --device 2,3 -c config.json

This is equivalent to

CUDA_VISIBLE_DEVICES=2,3 python train.py -c config.json
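The GPU rule above (use the first n visible devices, never more than exist) and the --device shortcut can be sketched with the standard library alone. select_devices is a hypothetical helper, and the number of available GPUs is passed in rather than queried from torch.cuda.device_count() so the sketch runs anywhere:

```python
import argparse
import os


def select_devices(n_gpu_requested, n_gpu_available):
    """Return (device_name, gpu_ids): the first n visible GPUs, capped at what exists."""
    n_gpu_use = min(n_gpu_requested, n_gpu_available)
    if n_gpu_use <= 0:
        return "cpu", []
    return "cuda:0", list(range(n_gpu_use))


parser = argparse.ArgumentParser()
parser.add_argument("-c", "--config", default=None, help="path to config.json")
parser.add_argument("--device", default=None, help="indices of GPUs to enable, e.g. 2,3")
args = parser.parse_args(["--device", "2,3", "-c", "config.json"])

if args.device:
    # Same effect as launching with CUDA_VISIBLE_DEVICES=2,3 in the shell.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.device

print(os.environ["CUDA_VISIBLE_DEVICES"])  # 2,3
print(select_devices(4, 2))                # ('cuda:0', [0, 1])
```

CUDA renumbers the devices it exposes, so GPUs 2 and 3 appear as cuda:0 and cuda:1 inside the process, which is why the ids start at 0. The variable must be set before CUDA is initialized in the process to take effect.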

Customization

Data Loader

  • Writing your own data loader
  1. Inherit BaseDataLoader

    BaseDataLoader is a subclass of torch.utils.data.DataLoader, so you can use either of them.

    BaseDataLoader handles:

    • Generating next batch
    • Data shuffling
    • Generating validation data loader by calling BaseDataLoader.split_validation()
  • DataLoader Usage

    BaseDataLoader is an iterable; to iterate through batches:

    for batch_idx, (x_batch, y_batch) in enumerate(data_loader):
        pass
  • Example

    Please refer to data_loader/data_loaders.py for an MNIST data loading example.
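To show what "generating the next batch" and "data shuffling" amount to, and why the loop needs enumerate() to get a batch index, here is a stdlib-only stand-in. SimpleBatchLoader is hypothetical and is not BaseDataLoader; the real class delegates this work to torch.utils.data.DataLoader:

```python
import random


class SimpleBatchLoader:
    """Stdlib stand-in for a DataLoader: iterating yields (x_batch, y_batch) tuples."""

    def __init__(self, xs, ys, batch_size, shuffle=True, seed=0):
        if len(xs) != len(ys):
            raise ValueError("xs and ys must have the same length")
        self.xs, self.ys = xs, ys
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.rng = random.Random(seed)

    def __iter__(self):
        indices = list(range(len(self.xs)))
        if self.shuffle:
            self.rng.shuffle(indices)  # a fresh order on every pass over the data
        for start in range(0, len(indices), self.batch_size):
            chunk = indices[start:start + self.batch_size]
            yield [self.xs[i] for i in chunk], [self.ys[i] for i in chunk]

    def __len__(self):
        # Number of batches, counting a final partial batch.
        return (len(self.xs) + self.batch_size - 1) // self.batch_size


loader = SimpleBatchLoader(list(range(10)), [x % 2 for x in range(10)], batch_size=4)
for batch_idx, (x_batch, y_batch) in enumerate(loader):
    print(batch_idx, x_batch, y_batch)
```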

Trainer

  • Writing your own trainer
  1. Inherit BaseTrainer

    BaseTrainer handles:

    • Training process logging
    • Checkpoint saving
    • Checkpoint resuming
    • Reconfigurable monitored value for saving current best
      • Controlled by the config options monitor and monitor_mode: if monitor_mode == 'min', the trainer saves a checkpoint model_best.pth whenever the monitored value reaches a new minimum
  2. Implementing abstract methods

    You need to implement _train_epoch() for your training process. If you need validation, you can also implement _valid_epoch(), as in trainer/trainer.py.

  • Example

    Please refer to trainer/trainer.py for MNIST training.
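The split between BaseTrainer and your subclass, including the monitor/monitor_mode rule for model_best.pth, can be illustrated with the abc module. TrainerSketch and ToyTrainer are hypothetical, much-reduced stand-ins (no model, no files written); only the control flow is meant to match the description above:

```python
from abc import ABC, abstractmethod


class TrainerSketch(ABC):
    """Hypothetical reduction of BaseTrainer: epoch loop plus best-value tracking."""

    def __init__(self, epochs, monitor="val_loss", monitor_mode="min"):
        self.epochs = epochs
        self.monitor = monitor
        self.monitor_mode = monitor_mode
        self.monitor_best = float("inf") if monitor_mode == "min" else -float("inf")
        self.best_epoch = None

    @abstractmethod
    def _train_epoch(self, epoch):
        """Train one epoch and return a dict of logged values."""

    def _improved(self, value):
        if self.monitor_mode == "min":
            return value < self.monitor_best
        if self.monitor_mode == "max":
            return value > self.monitor_best
        return False  # "off": never record a best model

    def train(self):
        for epoch in range(1, self.epochs + 1):
            log = self._train_epoch(epoch)
            if self._improved(log[self.monitor]):
                self.monitor_best = log[self.monitor]
                self.best_epoch = epoch  # a real trainer would save model_best.pth here
        return self.best_epoch


class ToyTrainer(TrainerSketch):
    """Subclass that only supplies _train_epoch, replaying fixed validation losses."""

    def __init__(self, val_losses, **kwargs):
        super().__init__(epochs=len(val_losses), **kwargs)
        self.val_losses = val_losses

    def _train_epoch(self, epoch):
        return {"loss": 1.0 / epoch, "val_loss": self.val_losses[epoch - 1]}


trainer = ToyTrainer([0.9, 0.5, 0.7, 0.4, 0.6])
print(trainer.train(), trainer.monitor_best)  # 4 0.4
```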

Model

  • Writing your own model
  1. Inherit BaseModel

    BaseModel handles:

    • Inherited from torch.nn.Module
    • summary(): Model summary
  2. Implementing abstract methods

    Implement the forward pass method forward()

  • Example

    Please refer to model/model.py for a LeNet example.

Loss

Custom loss functions can be implemented in 'model/loss.py'. To use one, set the "loss" entry in the config file to the function's name.
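Selecting a loss by name amounts to an attribute lookup on the loss module. A minimal sketch, assuming a pure-Python nll_loss in place of the template's own (which presumably wraps PyTorch's implementation) and a SimpleNamespace standing in for the imported module; module_loss is a hypothetical name:

```python
import math
from types import SimpleNamespace


# Hypothetical, PyTorch-free contents of model/loss.py.
def nll_loss(log_probs, targets):
    """Mean negative log-likelihood of each sample's target class."""
    return -sum(row[t] for row, t in zip(log_probs, targets)) / len(targets)


module_loss = SimpleNamespace(nll_loss=nll_loss)

config = {"loss": "nll_loss"}
loss_fn = getattr(module_loss, config["loss"])  # resolve the name given in the config

log_probs = [[math.log(0.7), math.log(0.3)], [math.log(0.2), math.log(0.8)]]
print(round(loss_fn(log_probs, [0, 1]), 4))  # 0.2899
```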

Metrics

Metric functions are located in 'model/metric.py'.

You can monitor multiple metrics by providing a list in the configuration file, e.g.:

"metrics": ["my_metric", "my_metric2"],
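Each name in the list is resolved the same way as the loss, and every metric is evaluated on the same predictions. A sketch under the assumption that my_metric is accuracy and my_metric2 is error rate (definitions made up for illustration; module_metric is a hypothetical stand-in for model/metric.py):

```python
from types import SimpleNamespace


# Hypothetical metric functions standing in for model/metric.py.
def my_metric(predictions, targets):
    """Accuracy: fraction of predictions that equal the target."""
    return sum(p == t for p, t in zip(predictions, targets)) / len(targets)


def my_metric2(predictions, targets):
    """Error rate: fraction of predictions that miss the target."""
    return 1.0 - my_metric(predictions, targets)


module_metric = SimpleNamespace(my_metric=my_metric, my_metric2=my_metric2)

config = {"metrics": ["my_metric", "my_metric2"]}
metric_fns = [getattr(module_metric, name) for name in config["metrics"]]

predictions = [1, 0, 2, 2]
targets = [1, 0, 1, 2]
results = {fn.__name__: fn(predictions, targets) for fn in metric_fns}
print(results)  # {'my_metric': 0.75, 'my_metric2': 0.25}
```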

Additional logging

If you have additional information to log, merge it into log in the _train_epoch() method of your trainer class before returning:

additional_log = {"gradient_norm": g, "sensitivity": s}
log = {**log, **additional_log}
return log

Testing

You can test a trained model by running test.py and passing the path to the trained checkpoint with the --resume argument:

python test.py --resume path/to/checkpoint

Validation data

To split validation data from a data loader, call BaseDataLoader.split_validation(). It returns a validation data loader whose number of samples follows the "validation_split" ratio specified in your config file.

Note: the split_validation() method will modify the original data loader.

Note: split_validation() will return None if "validation_split" is set to 0.
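The index bookkeeping behind such a split can be sketched as follows: shuffle the sample indices once with a fixed seed, carve off the validation share, and keep the rest for training. Presumably this is also why the original loader is modified: afterwards it should draw only from the training indices. split_indices is a hypothetical helper, not the template's API:

```python
import random


def split_indices(n_samples, validation_split, seed=0):
    """Return (train_idx, valid_idx); valid_idx is None when validation_split is 0."""
    indices = list(range(n_samples))
    if validation_split == 0:
        return indices, None
    rng = random.Random(seed)  # fixed seed so the split is reproducible
    rng.shuffle(indices)
    n_valid = int(n_samples * validation_split)
    return indices[n_valid:], indices[:n_valid]


train_idx, valid_idx = split_indices(100, 0.1)
print(len(train_idx), len(valid_idx))  # 90 10
```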

Checkpoints

You can specify the name of the training session in config files:

"name": "MNIST_LeNet",

The checkpoints will be saved in save_dir/name/timestamp/checkpoint_epoch_n, with timestamp in mmdd_HHMMSS format.

A copy of config file will be saved in the same folder.

Note: checkpoints contain:

{
  'arch': arch,
  'epoch': epoch,
  'logger': self.train_logger,
  'state_dict': self.model.state_dict(),
  'optimizer': self.optimizer.state_dict(),
  'monitor_best': self.monitor_best,
  'config': self.config
}
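A small sketch of assembling the checkpoint location described above, assuming each epoch file gets a .pth extension like model_best.pth (this README does not say). Presumably the timestamp is taken once when a run starts, so all checkpoints and the config copy of that run share one folder; checkpoint_path is a hypothetical helper:

```python
import os
from datetime import datetime


def checkpoint_path(save_dir, name, epoch, now=None):
    """Build save_dir/name/timestamp/checkpoint_epoch_n with an mmdd_HHMMSS timestamp."""
    now = now or datetime.now()
    timestamp = now.strftime("%m%d_%H%M%S")
    # The '.pth' suffix is an assumption made for this sketch.
    return os.path.join(save_dir, name, timestamp, "checkpoint_epoch_{}.pth".format(epoch))


# On POSIX this prints saved/Mnist_LeNet/0307_140509/checkpoint_epoch_3.pth
print(checkpoint_path("saved/", "Mnist_LeNet", 3, now=datetime(2019, 3, 7, 14, 5, 9)))
```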

TensorboardX Visualization

This template supports TensorboardX visualization.

  • TensorboardX Usage
  1. Install

    Follow the installation guide in TensorboardX.

  2. Run training

    Set the tensorboardX option in the config file to true.

  3. Open tensorboard server

    Run tensorboard --logdir saved/runs/ at the project root; the server will then open at http://localhost:6006

By default, the loss, the metrics specified in the config file, and input images are logged. If you need more visualizations, call add_scalar('tag', data), add_image('tag', image), etc. in the trainer's _train_epoch method. The add_*() methods in this template are basically wrappers around those of tensorboardX.SummaryWriter.

Note: You don't have to specify the current step, since the WriterTensorboardX class defined in utils/visualization.py tracks it.
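The step-tracking idea can be sketched as a thin wrapper that forwards add_* calls to an underlying writer and fills in the step itself. StepTrackingWriter, set_step, and the "mode/tag" prefix are hypothetical choices for this sketch; RecordingWriter replaces tensorboardX.SummaryWriter so it runs without TensorboardX, but it uses the same (tag, value, global_step) argument order as SummaryWriter.add_scalar:

```python
class StepTrackingWriter:
    """Hypothetical wrapper that injects the current step into add_* calls."""

    def __init__(self, writer):
        self.writer = writer
        self.step = 0
        self.mode = "train"

    def set_step(self, step, mode="train"):
        self.step = step
        self.mode = mode

    def __getattr__(self, name):
        # Called only for attributes not found normally; forward add_* methods only.
        if not name.startswith("add_"):
            raise AttributeError(name)
        add_data = getattr(self.writer, name)

        def wrapper(tag, data, *args, **kwargs):
            # Prefix the tag with the mode and pass the tracked step as global_step.
            add_data("{}/{}".format(self.mode, tag), data, self.step, *args, **kwargs)

        return wrapper


class RecordingWriter:
    """Stand-in for SummaryWriter that only records what it is given."""

    def __init__(self):
        self.calls = []

    def add_scalar(self, tag, value, global_step=None):
        self.calls.append((tag, value, global_step))


writer = StepTrackingWriter(RecordingWriter())
writer.set_step(10)
writer.add_scalar("loss", 0.25)
writer.set_step(3, mode="valid")
writer.add_scalar("loss", 0.31)
print(writer.writer.calls)  # [('train/loss', 0.25, 10), ('valid/loss', 0.31, 3)]
```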

Contributing

Feel free to contribute any kind of function or enhancement. The coding style here follows PEP8.

Code should pass the Flake8 check before committing.

TODOs

  • Iteration-based training (instead of epoch-based)
  • Multiple optimizers
  • Configurable logging layout, checkpoint naming
  • visdom logger support
  • tensorboardX logger support
  • Adding command line option for fine-tuning
  • Multi-GPU support
  • Update the example to PyTorch 0.4
  • Learning rate scheduler
  • Deprecate BaseDataLoader, use torch.utils.data instead
  • Load settings from config files

License

This project is licensed under the MIT License. See LICENSE for more details.

Acknowledgments

This project is inspired by the project Tensorflow-Project-Template by Mahmoud Gemy.

combustion_nn_model's People

Contributors

arslansadiq
