jinhangzhu / hand-classifier


GPU-support and PyTorch-based Convolutional Neural Network for Hand Classification.

License: MIT License

Shell 5.67% Python 94.33%
object-detection hand-classification pytorch cnn python

Hand classifier

This repo shows how we developed a 5-layer convolutional neural network (CNN) in PyTorch. The code runs on both CPU and GPU.

Our approaches

Our CNN architecture is based on the R-Net in MTCNN [1].

Our dataset is based on EPIC-KITCHENS-100 [2], an egocentric video dataset. The image below shows four examples from our dataset, where c0 and c1 denote left_hand and right_hand, respectively.

We also apply PyTorch transforms for data augmentation during training. The transforms we use are listed below. Note: do NOT enable flipping during training. Left and right hands are mirror images of each other, so a horizontal flip would turn a sample of one class into something that looks like the other.

transforms.RandomCrop(20),
transforms.RandomAffine(degrees=(30),translate=(0.1, 0.2)),
transforms.Resize(28)
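
To make the warning about flipping concrete, here is a toy illustration in plain Python. It is not part of our training code: the 3x4 binary "image", the `hflip` helper and the template-matching `classify` function are all hypothetical stand-ins, chosen only to show that mirroring a left-hand sample produces something a classifier would reasonably call a right hand.

```python
# Toy 3x4 binary "crop" of a left hand: the 1 in the top-right corner stands in for the thumb.
LEFT_HAND = [
    [0, 0, 0, 1],
    [1, 1, 1, 1],
    [1, 1, 1, 0],
]


def hflip(image):
    """Mirror an image, given as a list of rows, left to right."""
    return [row[::-1] for row in image]


# Assumption for this toy setup: a right hand is the mirror image of a left hand.
TEMPLATES = {"left_hand": LEFT_HAND, "right_hand": hflip(LEFT_HAND)}


def classify(image):
    """Return the label of the template that agrees with the image on the most pixels."""
    def overlap(a, b):
        return sum(x == y for row_a, row_b in zip(a, b) for x, y in zip(row_a, row_b))

    return max(TEMPLATES, key=lambda name: overlap(image, TEMPLATES[name]))


print(classify(LEFT_HAND))         # the original sample matches the left-hand template
print(classify(hflip(LEFT_HAND)))  # the flipped sample matches the right-hand template
```

A flipped c0 sample that kept its c0 label would therefore teach the network the wrong thing, which is why the transform list above contains no flip.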

Requirements

Python 3.7 or later with all requirements.txt dependencies installed, including torch >= 1.6. To install, run:

$ pip install -U -r requirements.txt

Tutorials

$ python .\hand_cnn.py -h                                 
usage: hand_cnn.py [-h] [--mode MODE] [--dataset DATASET]
                   [--save-txt SAVE_TXT] [--batch-size BATCH_SIZE]
                   [--epochs EPOCHS] [--lr LR] [--augment] [--weight WEIGHT]
                   [--source SOURCE]
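
For readers who want to reuse the same command-line interface, the usage line above could be produced by a parser roughly like the following. This is a sketch under assumptions, not the parser in hand_cnn.py: the option names come from the usage text, but the argument types and every default value here are guesses.

```python
import argparse


def build_parser():
    """Build a parser whose options mirror the usage text (types and defaults are assumptions)."""
    parser = argparse.ArgumentParser(prog="hand_cnn.py")
    parser.add_argument("--mode", type=str, default="train",
                        help="Running mode: train, test, detect")
    parser.add_argument("--dataset", type=str, default="datasets/handcrops",
                        help="Path to dataset")
    parser.add_argument("--save-txt", type=str, default="results",
                        help="Path to results")
    parser.add_argument("--batch-size", type=int, default=100)
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--lr", type=float, default=0.0001, help="Learning rate")
    parser.add_argument("--augment", action="store_true", help="Data augmentation")
    parser.add_argument("--weight", type=str, default=None, help="Weight file to load.")
    parser.add_argument("--source", type=str, default=None,
                        help="Source images for detection.")
    return parser


# Parse a training-style command line instead of sys.argv so the sketch runs anywhere.
args = build_parser().parse_args(
    "--mode train --dataset datasets/handcrops --save-txt results "
    "--batch-size 100 --epochs 3 --lr 0.0001 --augment".split()
)
print(args.mode, args.batch_size, args.epochs, args.lr, args.augment)
```

Note that argparse turns dashes into underscores, so `--batch-size` and `--save-txt` are read back as `args.batch_size` and `args.save_txt`.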

Training

optional arguments:
  -h, --help            show this help message and exit
  --mode MODE           Runing mode: train, test, detect
  --dataset DATASET     Path to dataset
  --save-txt SAVE_TXT   Path to results
  --batch-size BATCH_SIZE
  --epochs EPOCHS
  --lr LR               Learning rate
  --augment             Data augmentation
  --weight WEIGHT       Weight file to load.

Example training command:

python hand_cnn.py --mode train --dataset datasets/handcrops --save-txt results --batch-size 100 --epochs 3 --lr 0.0001 --augment

Example command for resuming training. The number of epochs must be set to a value larger than the one saved in the checkpoint.

python hand_cnn.py --mode train --dataset datasets/handcrops --weight weights/handcnn.pt --save-txt results --batch-size 100 --epochs 10 --lr 0.0001 --augment
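
The epoch requirement for resuming can be pictured as a simple check. The snippet below is a minimal sketch, not the logic in hand_cnn.py: `epochs_left` is a hypothetical helper, and it assumes the checkpoint stores the number of epochs already completed.

```python
def epochs_left(saved_epoch: int, target_epochs: int) -> int:
    """Return how many epochs remain when resuming from a checkpoint.

    saved_epoch: epochs already completed, as stored in the checkpoint (assumed field).
    target_epochs: the value passed via --epochs.
    """
    if target_epochs <= saved_epoch:
        raise ValueError(
            f"--epochs ({target_epochs}) must exceed the checkpoint epoch ({saved_epoch})"
        )
    return target_epochs - saved_epoch


# Resuming a checkpoint saved after 3 epochs with --epochs 10 leaves 7 epochs to run.
print(epochs_left(saved_epoch=3, target_epochs=10))
```

Under this reading, `--epochs` is the total number of epochs to reach, not the number of additional epochs to train.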

To plot the training process:

from utils import plot_process
plot_process('results')

Evaluation

optional arguments:
  -h, --help            show this help message and exit
  --mode MODE           Runing mode: train, test, detect
  --dataset DATASET     Path to dataset
  --save-txt SAVE_TXT   Path to results
  --batch-size BATCH_SIZE
  --epochs EPOCHS
  --lr LR               Learning rate
  --augment             Data augmentation
  --weight WEIGHT       Weight file to load.
  --source SOURCE       Source images for detection.

Example command for evaluating trained weights on the test set:

python hand_cnn.py --mode test --batch-size 100 --weight weights/handcnn.pt

Detection

Example command for running detection on sample images:

python hand_cnn.py --mode detect --weight weights/handcnn.pt --source samples

Easy integration

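Our implementation should be easy to integrate into your own code.

import glob

import cv2
import torch

from hand_cnn import HandCropCNN  # assumption: the model class is defined in hand_cnn.py

# Build the model and move it to the GPU when one is available.
model = HandCropCNN()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)
weight_file = 'weights/handcnn.pt'

# Read every sample image and convert OpenCV's BGR channel order to RGB.
patches = []
imgs = sorted(glob.glob('samples/' + '*jpg'))
for im_path in imgs:
    img = cv2.imread(im_path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    patches.append(img)
patches = tuple(patches)

# Load the class names, one per line.
with open('datasets/handcrops/classes.names', 'r') as f:
    classes = [line.rstrip() for line in f]

# Classify the patches with the trained weights.
cls_conf, labels = model.detect(device, weight_file, patches, classes)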
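
Once `detect` has returned, you will usually want to line the results up with the input images. The helper below is one possible way to do that, written under assumptions: `summarise` is a hypothetical function, and it presumes `labels` and `cls_conf` are parallel sequences with one entry per patch whose confidences can be converted with `float`. The paths and values in the demo are placeholders, not real model output.

```python
def summarise(paths, labels, confidences, min_conf=0.5):
    """Pair each image path with its predicted label and confidence.

    Returns a list of (path, label, confidence, is_confident) tuples, where
    is_confident is True when the confidence reaches min_conf (an assumed threshold).
    """
    if not (len(paths) == len(labels) == len(confidences)):
        raise ValueError("paths, labels and confidences must have the same length")
    rows = []
    for path, label, conf in zip(paths, labels, confidences):
        conf = float(conf)  # also covers single-element tensors or NumPy scalars
        rows.append((path, label, conf, conf >= min_conf))
    return rows


# Placeholder inputs standing in for imgs, labels and cls_conf from the snippet above.
demo_paths = ["samples/a.jpg", "samples/b.jpg"]
demo_labels = ["left_hand", "right_hand"]
demo_conf = [0.93, 0.41]

for path, label, conf, ok in summarise(demo_paths, demo_labels, demo_conf):
    print(f"{path}: {label} ({conf:.2f}){'' if ok else ' [low confidence]'}")
```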

References

  1. Zhang, K. et al. (2016) ‘Joint Face Detection and Alignment Using Multitask Cascaded Convolutional Networks’, IEEE Signal Processing Letters. Institute of Electrical and Electronics Engineers Inc., 23(10), pp. 1499–1503. doi: 10.1109/LSP.2016.2603342.
  2. Damen, D. et al. (2020) ‘The EPIC-KITCHENS Dataset: Collection, Challenges and Baselines’, IEEE Transactions on Pattern Analysis and Machine Intelligence. Institute of Electrical and Electronics Engineers (IEEE), pp. 1–1. Available at: http://arxiv.org/abs/2005.00343 (Accessed: 16 August 2020).
