This project forked from kaiyangzhou/deep-person-reid

PyTorch implementation of deep person re-identification models.

License: MIT License

deep-person-reid

This repo contains PyTorch implementations of deep person re-identification models.

Trained models can be found here.

Updates

  • Apr 2018: Added iLIDS-VID and PRID-2011. Models are available.
  • Mar 2018: Added argument --htri-only to train_img_model_xent_htri.py and train_vid_model_xent_htri.py. If this argument is true, only htri [4] is used for training. See here for detailed changes.
  • Mar 2018: Added Multi-scale Deep CNN (ICCV'17) [10] with slight modifications: (a) the input size is (256, 128) instead of (160, 60); (b) we add an average pooling layer after the last conv feature maps; (c) we train the network with our strategy. A model trained from scratch on Market1501 is available.
  • Mar 2018: Added center loss (ECCV'16) [9] and the trained model weights.

Dependencies

Install

  1. cd to the folder where you want to download this repo.
  2. run git clone https://github.com/KaiyangZhou/deep-person-reid.

Prepare data

Create a directory to store reid datasets under this repo via

cd deep-person-reid/
mkdir data/

Please follow the instructions below to prepare each dataset.

Market1501 [7]:

  1. Download dataset to data/ from http://www.liangzheng.org/Project/project_reid.html.
  2. Extract dataset and rename to market1501. The data structure would look like:
market1501/
    bounding_box_test/
    bounding_box_train/
    ...
  3. Use -d market1501 when running the training code.

MARS [8]:

  1. Create a directory named mars/ under data/.
  2. Download dataset to data/mars/ from http://www.liangzheng.com.cn/Project/project_mars.html.
  3. Extract bbox_train.zip and bbox_test.zip.
  4. Download split information from https://github.com/liangzheng06/MARS-evaluation/tree/master/info and put info/ in data/mars (we want to follow the standard split in [8]). The data structure would look like:
mars/
    bbox_test/
    bbox_train/
    info/
  5. Use -d mars when running the training code.

iLIDS-VID [11]:

  1. The code supports automatic download and formatting. Simply use -d ilidsvid when running the training code. The data structure would look like:
ilids-vid/
    i-LIDS-VID/
    train-test people splits/
    splits.json

PRID [12]:

  1. Under data/, do mkdir prid2011 to create a directory.
  2. Download dataset from https://www.tugraz.at/institute/icg/research/team-bischof/lrs/downloads/PRID11/ and extract it under data/prid2011.
  3. Download the split created by iLIDS-VID from here, and put it in data/prid2011/. We follow [11] and use the 178 persons whose sequences are longer than a threshold, so that results on this dataset can be fairly compared with other approaches. The data structure would look like:
prid2011/
    splits_prid2011.json
    prid_2011/
        multi_shot/
        single_shot/
        readme.txt
  4. Use -d prid when running the training code.
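
Before training, it can help to confirm that the manually prepared datasets match the directory listings above. The following is a minimal sketch, not part of the repo: the `EXPECTED_LAYOUT` table and the `missing_entries` helper are hypothetical names, and the entries are copied from the listings in this section. iLIDS-VID is left out because the code downloads and formats it automatically.

```python
from pathlib import Path

# Expected entries per dataset directory, copied from the listings above.
# This table and the helper below are illustrative, not part of the repo.
EXPECTED_LAYOUT = {
    "market1501": ["bounding_box_test", "bounding_box_train"],
    "mars": ["bbox_test", "bbox_train", "info"],
    "prid2011": [
        "splits_prid2011.json",
        "prid_2011/multi_shot",
        "prid_2011/single_shot",
    ],
}


def missing_entries(data_root, dataset_dir):
    """Return the expected paths that do not exist under data_root/dataset_dir."""
    base = Path(data_root) / dataset_dir
    return [rel for rel in EXPECTED_LAYOUT[dataset_dir] if not (base / rel).exists()]


# Report the state of each dataset under data/ (run from the repo root).
for name in EXPECTED_LAYOUT:
    missing = missing_entries("data", name)
    print(name, "ok" if not missing else "missing: " + ", ".join(missing))
```

An empty list from `missing_entries` means every listed file or folder was found; it does not verify the contents of those folders.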

Dataset loaders

These are implemented in dataset_loader.py, where two main classes subclass torch.utils.data.Dataset: ImageDataset and VideoDataset.

These two classes are used with torch.utils.data.DataLoader to provide batched data. A data loader with ImageDataset outputs batches of shape (batch, channel, height, width), while a data loader with VideoDataset outputs batches of shape (batch, sequence, channel, height, width).
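
The two batch shapes can be illustrated with toy stand-ins. This is only a sketch: `ToyImageDataset` and `ToyVideoDataset` are hypothetical classes filled with zeros, not the repo's ImageDataset and VideoDataset, and the image size, sequence length and batch sizes are arbitrary.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyImageDataset(Dataset):
    """Each item is one image of shape (channel, height, width) plus a label."""

    def __init__(self, num_items=8, channel=3, height=64, width=32):
        self.data = torch.zeros(num_items, channel, height, width)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index], index % 4  # (image, dummy person id)


class ToyVideoDataset(Dataset):
    """Each item is one tracklet of shape (sequence, channel, height, width)."""

    def __init__(self, num_items=8, seq_len=15, channel=3, height=64, width=32):
        self.data = torch.zeros(num_items, seq_len, channel, height, width)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index], index % 4  # (tracklet, dummy person id)


# The default collate function stacks items along a new leading batch axis.
image_batch, _ = next(iter(DataLoader(ToyImageDataset(), batch_size=4)))
video_batch, _ = next(iter(DataLoader(ToyVideoDataset(), batch_size=2)))
print(tuple(image_batch.shape))  # (batch, channel, height, width)
print(tuple(video_batch.shape))  # (batch, sequence, channel, height, width)
```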

Models

  • models/ResNet.py: ResNet50 [1], ResNet50M [2].
  • models/DenseNet.py: DenseNet121 [3].
  • models/MuDeep.py: MuDeep [10].

Loss functions

  • xent: cross entropy + label smoothing regularizer [5].
  • htri: triplet loss with hard positive/negative mining [4].
  • cent: center loss [9].
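
To make the first two losses concrete, here is a dependency-free sketch of cross entropy with label smoothing [5] and a batch-hard triplet loss in the spirit of [4]. It is not the repo's implementation: the function names, the smoothing factor `epsilon=0.1` and the `margin=0.3` are assumptions chosen for illustration, and the repo's versions operate on PyTorch tensors.

```python
import math


def smoothed_cross_entropy(logits, target, epsilon=0.1):
    """Cross entropy of one sample against a label-smoothed target distribution."""
    num_classes = len(logits)
    # Numerically stable log-softmax.
    max_logit = max(logits)
    log_norm = max_logit + math.log(sum(math.exp(z - max_logit) for z in logits))
    loss = 0.0
    for k, z in enumerate(logits):
        # Smoothed target: (1 - epsilon) on the true class, epsilon spread uniformly.
        q = (1.0 - epsilon) * (1.0 if k == target else 0.0) + epsilon / num_classes
        loss -= q * (z - log_norm)
    return loss


def hard_triplet_loss(features, labels, margin=0.3):
    """Per anchor, use the farthest positive and the closest negative in the batch.

    Assumes every anchor has at least one sample with a different label.
    """
    def dist(a, b):
        return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

    losses = []
    for i, anchor in enumerate(features):
        pos = [dist(anchor, f) for j, f in enumerate(features) if labels[j] == labels[i]]
        neg = [dist(anchor, f) for j, f in enumerate(features) if labels[j] != labels[i]]
        losses.append(max(0.0, margin + max(pos) - min(neg)))
    return sum(losses) / len(losses)


print(smoothed_cross_entropy([2.0, 0.5, -1.0], target=0))
print(hard_triplet_loss([[0.0], [2.0], [1.0]], labels=[0, 0, 1]))
```

In the xent+htri setting the two terms are combined into one training objective; how they are weighted is defined by the training scripts and is not reproduced here.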

We use Adam [6] everywhere, which turned out to be the most effective optimizer in our experiments.

Train

The training code is implemented mainly in:

  • train_img_model_xent.py: train image model with cross entropy loss.
  • train_img_model_xent_htri.py: train image model with combination of cross entropy loss and hard triplet loss.
  • train_img_model_cent.py: train image model with center loss.
  • train_vid_model_xent.py: train video model with cross entropy loss.
  • train_vid_model_xent_htri.py: train video model with combination of cross entropy loss and hard triplet loss.

For example, to train an image reid model using ResNet50 and cross entropy loss, run

python train_img_model_xent.py -d market1501 -a resnet50 --max-epoch 60 --train-batch 32 --test-batch 32 --stepsize 20 --eval-step 20 --save-dir log/resnet50-xent-market1501 --gpu-devices 0

To use multiple GPUs, you can set --gpu-devices 0,1,2,3.
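
One common way a script can honour a comma-separated device list like this is to export it as `CUDA_VISIBLE_DEVICES` before CUDA is initialised. The sketch below shows that pattern under the assumption that the flag is a plain string; `configure_gpus` is a hypothetical helper, and the training scripts may handle the flag differently.

```python
import argparse
import os


def configure_gpus(argv):
    """Parse --gpu-devices and expose only those GPUs to the current process."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--gpu-devices", default="0", type=str,
                        help="comma-separated GPU ids, e.g. 0,1,2,3")
    args, _ = parser.parse_known_args(argv)
    # Must be set before any CUDA context is created to take effect.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_devices
    return [int(i) for i in args.gpu_devices.split(",")]


gpu_ids = configure_gpus(["--gpu-devices", "0,1,2,3"])
print(gpu_ids)
```

With several devices visible, a PyTorch model is typically wrapped in torch.nn.DataParallel so that each batch is split across them.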

Please run any training script with -h (e.g. python train_img_model_xent.py -h) for more details regarding arguments.

Results

Image person reid

Market1501

| Model | Param Size (M) | Loss | Rank-1/5/10 (%) | mAP (%) | Model weights | Published Rank | Published mAP |
| --- | --- | --- | --- | --- | --- | --- | --- |
| DenseNet121 | 7.72 | xent | 86.5/93.6/95.7 | 67.8 | download | | |
| DenseNet121 | 7.72 | xent+htri | 89.5/96.3/97.5 | 72.6 | download | | |
| ResNet50 | 25.05 | cent | 85.1/93.8/96.2 | 69.1 | download | | |
| ResNet50 | 25.05 | xent | 85.4/94.1/95.9 | 68.8 | download | 87.3/-/- | 67.6 |
| ResNet50 | 25.05 | xent+htri | 87.5/95.3/97.3 | 72.3 | download | | |
| ResNet50M | 30.01 | xent | 89.0/95.5/97.3 | 75.0 | download | 89.9/-/- | 75.6 |
| ResNet50M | 30.01 | xent+htri | 90.4/96.7/98.0 | 76.6 | download | | |
| MuDeep | 138.02 | xent+htri | 71.5/89.3/96.3 | 47.0 | download | | |

Video person reid

MARS

| Model | Param Size (M) | Loss | Rank-1/5/10 (%) | mAP (%) | Model weights | Published Rank | Published mAP |
| --- | --- | --- | --- | --- | --- | --- | --- |
| DenseNet121 | 7.59 | xent | 65.2/81.1/86.3 | 52.1 | download | | |
| DenseNet121 | 7.59 | xent+htri | 82.6/93.2/95.4 | 74.6 | download | | |
| ResNet50 | 24.79 | xent | 74.5/88.8/91.8 | 64.0 | download | | |
| ResNet50 | 24.79 | xent+htri | 80.8/92.1/94.3 | 74.0 | download | | |
| ResNet50M | 29.63 | xent | 77.8/89.8/92.8 | 67.5 | download | | |
| ResNet50M | 29.63 | xent+htri | 82.3/93.8/95.3 | 75.4 | download | | |

iLIDS-VID

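| Model | Param Size (M) | Loss | Rank-1/5/10 (%) | mAP (%) | Model weights | Published Rank | Published mAP |
| --- | --- | --- | --- | --- | --- | --- | --- |
| ResNet50 | 23.82 | xent | 62.7/82.7/90.7 | 72.6 | download | | |
| ResNet50M | 28.17 | xent | 63.3/85.3/92.7 | 73.6 | download | | |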

PRID-2011

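| Model | Param Size (M) | Loss | Rank-1/5/10 (%) | mAP (%) | Model weights | Published Rank | Published mAP |
| --- | --- | --- | --- | --- | --- | --- | --- |
| ResNet50 | 23.69 | xent | 75.3/96.6/97.8 | 84.3 | download | | |
| ResNet50M | 27.98 | xent | 85.4/96.6/98.9 | 90.1 | download | | |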
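
For readers unfamiliar with the Rank-k and mAP columns, the sketch below computes both from ranked gallery lists. It is a simplified illustration, not the repo's evaluation code: each query is assumed to come as a list of booleans (gallery sorted by ascending distance, True where the identity matches), and dataset-specific filtering such as discarding gallery images from the query's own camera is omitted. The function names are hypothetical.

```python
def rank_k_hit(ranked_match_flags, k):
    """1 if a correct gallery match appears within the top-k results, else 0."""
    return int(any(ranked_match_flags[:k]))


def average_precision(ranked_match_flags):
    """Mean of precision@i over the positions i where a correct match occurs."""
    hits, precisions = 0, []
    for position, is_match in enumerate(ranked_match_flags, start=1):
        if is_match:
            hits += 1
            precisions.append(hits / position)
    return sum(precisions) / len(precisions) if precisions else 0.0


def evaluate(all_query_flags, ranks=(1, 5, 10)):
    """Return ({k: Rank-k accuracy}, mAP) averaged over all queries."""
    num_queries = len(all_query_flags)
    cmc = {k: sum(rank_k_hit(f, k) for f in all_query_flags) / num_queries
           for k in ranks}
    mean_ap = sum(average_precision(f) for f in all_query_flags) / num_queries
    return cmc, mean_ap


# Two toy queries over a gallery of four images.
queries = [
    [True, False, True, False],   # first match at rank 1
    [False, False, True, False],  # first match at rank 3
]
cmc, mean_ap = evaluate(queries)
print(cmc, mean_ap)
```

In this toy case Rank-1 is 0.5 because only the first query has a correct match at the top position, while Rank-5 is 1.0 because both queries find a match within the first five results.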

Test

Say you have downloaded ResNet50 trained with xent on market1501, and the path to this model is 'saved-models/resnet50_xent_market1501.pth.tar' (create a directory to store model weights with mkdir saved-models/). Then run the following command to test:

python train_img_model_xent.py -d market1501 -a resnet50 --evaluate --resume saved-models/resnet50_xent_market1501.pth.tar --save-dir log/resnet50-xent-market1501 --test-batch 32

Likewise, to test a video reid model, you should have a pretrained model saved under saved-models/, e.g. saved-models/resnet50_xent_mars.pth.tar, then run

python train_vid_model_xent.py -d mars -a resnet50 --evaluate --resume saved-models/resnet50_xent_mars.pth.tar --save-dir log/resnet50-xent-mars --test-batch 2

Note that --test-batch in video reid represents the number of tracklets. If we set this argument to 2 and sample 15 images per tracklet, the resulting number of images per batch is 2*15=30. Adjust this argument according to your GPU memory. Currently, please set --test-batch to 1 for prid and ilidsvid due to this error.
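
The arithmetic above can be turned into a small helper for picking --test-batch. This is only a sketch: both function names and the notion of an "image budget" (the number of images you have found to fit on your GPU in one forward pass) are assumptions for illustration, not options of the scripts.

```python
def images_per_batch(test_batch, seq_len):
    """Number of images in one video batch: tracklets times images per tracklet."""
    return test_batch * seq_len


def max_test_batch(image_budget, seq_len):
    """Largest number of tracklets whose images fit in the budget (at least 1)."""
    return max(1, image_budget // seq_len)


print(images_per_batch(2, 15))   # the example from the text: 2 * 15
print(max_test_batch(100, 15))   # tracklets per batch for a 100-image budget
```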

Q&A

  1. How do I set different learning rates for different components in my model?

A: Instead of passing model.parameters() to the optimizer, you can pass an iterable of dicts, as described here. Please see the example below:

# First, comment out the original optimizer line:
# optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
# Then build per-component parameter groups. A group without its own 'lr'
# falls back to the default passed to the optimizer, i.e. args.lr.
param_groups = [
    {'params': model.base.parameters(), 'lr': 0},  # lr=0 effectively freezes model.base
    {'params': model.classifier.parameters()},     # trained with args.lr
]
# This example only applies to a model with two components (base and classifier).
# Modify the code to adapt it to your model.
optimizer = torch.optim.Adam(param_groups, lr=args.lr, weight_decay=args.weight_decay)

Of course, you can pass only model.classifier.parameters() to the optimizer if you only need to train the classifier. In that case, setting requires_grad to False for the base model's parameters is more efficient, since no gradients are computed for them.
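
The classifier-only variant can be sketched as follows. `ToyModel` is a hypothetical two-part model standing in for a real architecture with `base` and `classifier` attributes, and the learning rate and weight decay are placeholder values rather than the scripts' defaults.

```python
import torch
from torch import nn


class ToyModel(nn.Module):
    """Hypothetical model with a `base` feature extractor and a `classifier`."""

    def __init__(self, in_dim=16, feat_dim=8, num_classes=4):
        super().__init__()
        self.base = nn.Linear(in_dim, feat_dim)
        self.classifier = nn.Linear(feat_dim, num_classes)

    def forward(self, x):
        return self.classifier(self.base(x))


model = ToyModel()

# Freeze the base: autograd neither computes nor stores gradients for it.
for param in model.base.parameters():
    param.requires_grad = False

# Hand only the classifier's parameters to the optimizer.
optimizer = torch.optim.Adam(model.classifier.parameters(), lr=3e-4, weight_decay=5e-4)

base_before = model.base.weight.clone()
loss = model(torch.randn(2, 16)).sum()  # dummy loss on random input
loss.backward()
optimizer.step()

# The base weights are unchanged; only the classifier received gradients.
print(torch.equal(base_before, model.base.weight))
```

Compared with the lr=0 parameter group, this avoids the backward pass through the base entirely, which is where the efficiency gain comes from.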

References

[1] He et al. Deep Residual Learning for Image Recognition. CVPR 2016.
[2] Yu et al. The Devil is in the Middle: Exploiting Mid-level Representations for Cross-Domain Instance Matching. arXiv:1711.08106.
[3] Huang et al. Densely Connected Convolutional Networks. CVPR 2017.
[4] Hermans et al. In Defense of the Triplet Loss for Person Re-Identification. arXiv:1703.07737.
[5] Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016.
[6] Kingma and Ba. Adam: A Method for Stochastic Optimization. ICLR 2015.
[7] Zheng et al. Scalable Person Re-identification: A Benchmark. ICCV 2015.
[8] Zheng et al. MARS: A Video Benchmark for Large-Scale Person Re-identification. ECCV 2016.
[9] Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.
[10] Qian et al. Multi-scale Deep Learning Architectures for Person Re-identification. ICCV 2017.
[11] Wang et al. Person Re-Identification by Video Ranking. ECCV 2014.
[12] Hirzer et al. Person Re-Identification by Descriptive and Discriminative Classification. SCIA 2011.
