
[BMVC 2021] The official implementation of "DomainMix: Learning Generalizable Person Re-Identification Without Human Annotations"

License: MIT License

Python 100.00%

domainmix's Introduction

DomainMix

[paper] [demo] [Chinese blog]

DomainMix works fine on both PaddlePaddle and PyTorch.

Framework:

Requirements

  • Python 3.7
  • PyTorch 1.7.0
  • scikit-learn 0.23.2
  • Pillow (PIL) 5.4.1
  • NumPy 1.19.4
  • Torchvision 0.8.1

Reproduction Environment

  • Test our models: 1 Tesla V100 GPU.
  • Train new models: 4 Tesla V100 GPUs.
  • Note that the GPU requirement is not strict; 6 GB of memory per GPU is the minimum.

Preparation

  1. Dataset

We evaluate our algorithm on RandPerson, Market-1501, CUHK03-NP and MSMT17. Download the datasets yourself and arrange the directory structure as follows:

*DATA_PATH
      *data
         *randperson_subset
             *randperson_subset
                 ...
         *market1501
             *Market-1501-v15.09.15
                 *bounding_box_test
                 ...
         *cuhk03_np
             *detected
             *labeled
         *msmt17
             *MSMT17_V1
                 *test
                 *train
                 ...
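
Before training, it can help to confirm that the layout above is in place. The following is a minimal sketch, not part of the repository: the helper name `missing_dataset_dirs` and the constant `EXPECTED_DIRS` are hypothetical, and only the directories that are spelled out in the listing above are checked (entries shown as `...` are not guessed).

```python
from pathlib import Path

# Relative directories copied from the layout listed above.
# Entries elided with "..." in the listing are intentionally not included.
EXPECTED_DIRS = [
    "data/randperson_subset/randperson_subset",
    "data/market1501/Market-1501-v15.09.15/bounding_box_test",
    "data/cuhk03_np/detected",
    "data/cuhk03_np/labeled",
    "data/msmt17/MSMT17_V1/test",
    "data/msmt17/MSMT17_V1/train",
]


def missing_dataset_dirs(data_path):
    """Return the expected sub-directories that do not exist under data_path."""
    root = Path(data_path)
    return [rel for rel in EXPECTED_DIRS if not (root / rel).is_dir()]


if __name__ == "__main__":
    # "DATA_PATH" is a placeholder; replace it with your own root directory.
    for rel in missing_dataset_dirs("DATA_PATH"):
        print(f"missing: {rel}")
```

An empty result only means the listed directories exist; it does not verify the dataset contents.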
  2. Pretrained Models

We use ResNet-50 and IBN-ResNet-50 as backbones. The pretrained model for ResNet-50 is downloaded automatically. To train with the IBN-ResNet-50 backbone, download the pretrained model from here and save it as follows:

*DATA_PATH
      *logs
         *pretrained
             resnet50_ibn_a.pth.tar
  3. Our Trained Models

We provide our trained models as follows. They should be saved in ./logs/trained.

Market1501:

DomainMix(43.5% mAP) DomainMix-IBN(45.7% mAP)

CUHK03-NP:

DomainMix(16.7% mAP) DomainMix-IBN(18.3% mAP)

MSMT17:

DomainMix(9.3% mAP) DomainMix-IBN(12.1% mAP)

Train

We use RandPerson+MSMT->Market as an example; other DG tasks follow a similar pipeline.

CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py \
-dsy randperson_subset -dre msmt17 -dun market1501 \
-a resnet50 --margin 0.0 --num-instances 4 -b 64 -j 4 --warmup-step 5 \
--lr 0.00035 --milestones 10 15 30 40 50 --iters 2000 \
--epochs 60 --eval-step 1 --logs-dir logs/randperson_subsetmsTOm/domainmix
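
To adapt the command above to another DG task, only the three dataset flags and the log directory need to change. The sketch below assembles the same command from Python; it is an illustration rather than part of the repository. The helper names `build_train_command` and `format_command` are hypothetical, the reading of `-dsy`, `-dre` and `-dun` as the synthetic, real and unseen datasets is an assumption based on the example, and `logs/example_run` is a made-up directory name.

```python
import shlex


def build_train_command(synthetic, real, unseen, logs_dir, arch="resnet50"):
    """Assemble the train.py argument list using the flags from the example above.

    Assumption: -dsy / -dre / -dun select the synthetic, real and unseen
    (evaluation) datasets. All hyperparameters are copied from the example.
    """
    return [
        "python", "train.py",
        "-dsy", synthetic, "-dre", real, "-dun", unseen,
        "-a", arch, "--margin", "0.0", "--num-instances", "4",
        "-b", "64", "-j", "4", "--warmup-step", "5",
        "--lr", "0.00035", "--milestones", "10", "15", "30", "40", "50",
        "--iters", "2000", "--epochs", "60", "--eval-step", "1",
        "--logs-dir", logs_dir,
    ]


def format_command(args):
    """Render an argument list as a single shell-safe string."""
    return " ".join(shlex.quote(a) for a in args)


if __name__ == "__main__":
    # Hypothetical task: RandPerson+Market -> MSMT, logged to a made-up directory.
    cmd = build_train_command("randperson_subset", "market1501", "msmt17",
                              "logs/example_run")
    print(format_command(cmd))
```

The printed string can be prefixed with `CUDA_VISIBLE_DEVICES=0,1,2,3` as in the example, or the list can be passed to `subprocess.run` with the variable set in `env`.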

Test

We use RandPerson+MSMT->Market as an example; other DG tasks follow a similar pipeline.

CUDA_VISIBLE_DEVICES=0 python test.py -b 256 -j 8 --dataset-target market1501 -a resnet50 \
--resume logs/trained/model_best_435.pth.tar

Acknowledgement

Some parts of our code are from MMT and SpCL. Thanks to Yixiao Ge for her contribution.

Citation

If you find this code useful for your research, please cite our paper:

@inproceedings{wang2021domainmix,
  title={DomainMix: Learning Generalizable Person Re-Identification Without Human Annotations},
  author={Wenhao Wang and Shengcai Liao and Fang Zhao and Kangkang Cui and Ling Shao},
  booktitle={British Machine Vision Conference},
  year={2021}
}


domainmix's Issues

about real dataset (labeled or unlabeled)

Thanks for your work. I noticed that you validated the importance of using unlabeled real datasets.
How do I control, in the code, whether the real dataset is treated as labeled or unlabeled during training?

Thank you!

same features for different images

I wrote the below code:

import os.path as osp
import time

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from PIL import Image
from torchvision import transforms as T

from dg import models
from dg.evaluators import Evaluator
from dg.utils.serialization import load_checkpoint, copy_state_dict

from sklearn.metrics.pairwise import cosine_similarity


def normalize_embedding(embedding):
    norm = np.linalg.norm(embedding)
    return embedding / norm


def preprocess_single_image(image_path, height, width):
    normalizer = T.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    test_transformer = T.Compose([
        T.Resize((height, width), interpolation=3),
        T.ToTensor(),
        normalizer
    ])
    img = Image.open(image_path).convert('RGB')
    img = test_transformer(img)
    img = img.unsqueeze(0)  # add a batch dimension: (C, H, W) -> (1, C, H, W)
    return img


def main():
    image_paths = [
        r"D:\Saeed\Data\Re-ID\market1501\bounding_box_train\0002_c1s1_000451_03.jpg",
        r"D:\Saeed\Data\Re-ID\market1501\bounding_box_train\0002_c1s1_000551_01.jpg",
        r"D:\Saeed\Data\Re-ID\market1501\bounding_box_train\0011_c1s6_027271_01.jpg",
        # Add more image paths here
    ]

    try:
        args = {
            'arch': 'resnet50',
            'resume': r'D:\Saeed\DEVS\DomainMix\logs\trained\model_best_435.pth.tar',
            'height': 256,
            'width': 128,
        }

        print(models.names())

        # Create model
        model = models.create(args['arch'], pretrained=False, num_features=0, dropout=0, num_classes=0)
        model.cuda()

        # Load from checkpoint
        checkpoint = load_checkpoint(args['resume'])
        copy_state_dict(checkpoint['state_dict'], model)

        model = model.eval()

        features = []
        for image_path in image_paths:
            # Preprocess the image
            img = preprocess_single_image(image_path, args['height'], args['width']).cuda()
            feature = normalize_embedding(model(img).cpu().data.numpy())
            features.append(feature)
            print(f"Processing {image_path}")

        # cross similarity check
        for i in range(len(features)):
            for j in range(len(features)):
                if j > i:
                    similarity = cosine_similarity(features[i].reshape(1, -1), features[j].reshape(1, -1))
                    img_name1 = osp.basename(image_paths[i])
                    img_name2 = osp.basename(image_paths[j])
                    print(f'{img_name1} - {img_name2}: {similarity}')

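    except Exception as exc:
        # Report the actual error; `model` may not exist yet if creation failed.
        print(f'Error while extracting features: {exc}')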


if __name__ == '__main__':
    main()

It always returns nearly the same features for different images:

Processing D:\Saeed\Data\Re-ID\market1501\bounding_box_train\0002_c1s1_000451_03.jpg
Processing D:\Saeed\Data\Re-ID\market1501\bounding_box_train\0002_c1s1_000551_01.jpg
Processing D:\Saeed\Data\Re-ID\market1501\bounding_box_train\0011_c1s6_027271_01.jpg
0002_c1s1_000451_03.jpg - 0002_c1s1_000551_01.jpg: [[0.99988645]]
0002_c1s1_000451_03.jpg - 0011_c1s6_027271_01.jpg: [[0.9997024]]
0002_c1s1_000551_01.jpg - 0011_c1s6_027271_01.jpg: [[0.99968207]]
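
One thing that may be worth checking in a case like this (an assumption, not a confirmed diagnosis) is whether the parameter names in the checkpoint actually match those of the created model, since weights that are not loaded leave the network randomly initialized. The sketch below compares the two sets of names in plain Python. The helper `report_key_coverage` is hypothetical, and stripping a `module.` prefix is an assumption about how the checkpoint may have been saved.

```python
def report_key_coverage(model_keys, checkpoint_keys, prefix="module."):
    """Compare parameter names expected by a model with those in a checkpoint.

    A leading `prefix` (assumed here to be "module.") is stripped from the
    checkpoint names before the comparison.
    """
    def strip(name):
        return name[len(prefix):] if name.startswith(prefix) else name

    model_set = set(model_keys)
    ckpt_set = {strip(k) for k in checkpoint_keys}
    return {
        "matched": sorted(model_set & ckpt_set),
        "missing_in_checkpoint": sorted(model_set - ckpt_set),
        "unused_in_checkpoint": sorted(ckpt_set - model_set),
    }


if __name__ == "__main__":
    # Toy names for illustration only; with the script above one could pass
    # model.state_dict().keys() and checkpoint['state_dict'].keys() instead.
    report = report_key_coverage(
        ["conv1.weight", "bn1.weight"],
        ["module.conv1.weight", "module.fc.weight"],
    )
    for group, names in report.items():
        print(group, names)
```

A long `missing_in_checkpoint` list would suggest that most of the model was never overwritten by the checkpoint.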

Synthetic data

Hello, may I ask which tool was used to generate the synthetic data?
