interface's Introduction

InterFace

InterFace is a novel loss function for face recognition that enhances discriminative power by adding margin penalties between deep features and all weights, not just between features and their corresponding weights. This approach increases class separability and reduces intra-class variations, making models more robust in real-world scenarios.
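
To make the idea concrete, the sketch below contrasts a margin applied only to the target-class cosine logit with one applied to every class. It is an illustration in plain Python, not the InterFace formulation from the cited paper: it assumes a CosFace-style additive cosine margin, and the names `margin_logits`, `target_margin`, `inter_margin`, and `scale`, as well as their default values, are hypothetical.

```python
import math


def margin_logits(cosines, target, target_margin=0.35, inter_margin=0.0, scale=64.0):
    """Scaled logits with additive cosine margins (illustrative only).

    cosines: cos(theta_j) between one feature and every class weight.
    The target cosine is lowered by target_margin; every other cosine is
    raised by inter_margin, so both changes make classification harder.
    """
    return [
        scale * (c - target_margin if j == target else c + inter_margin)
        for j, c in enumerate(cosines)
    ]


def cross_entropy(logits, target):
    """Numerically stable softmax cross-entropy for a single sample."""
    peak = max(logits)
    log_sum = peak + math.log(sum(math.exp(z - peak) for z in logits))
    return log_sum - logits[target]


if __name__ == "__main__":
    cosines = [0.8, 0.3, -0.1]  # made-up similarities; the target is class 0
    only_target = cross_entropy(margin_logits(cosines, 0), 0)
    all_classes = cross_entropy(margin_logits(cosines, 0, inter_margin=0.1), 0)
    print(only_target, all_classes)
```

In this toy setting, penalizing the non-target cosines as well raises the loss for the same feature, which is the pressure toward larger gaps between classes that the description above refers to.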

Distributed InterFace Training in PyTorch

This is a deep learning library for efficient and effective face recognition training; it can train on tens of millions of identities on a single server.

Requirements

How to Train

To train a model, run train.py with the path to the configs.
The example commands below show how to run distributed training.

1. To run on a machine with 8 GPUs:

python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=12581 train.py configs/ms1mv3_r50_lr02

2. To run on 2 machines with 8 GPUs each:

Node 0:

python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr="ip1" --master_port=12581 train.py configs/webface42m_r100_lr01_pfc02_bs4k_16gpus

Node 1:

python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr="ip1" --master_port=12581 train.py configs/webface42m_r100_lr01_pfc02_bs4k_16gpus
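
The per-node commands above differ only in `--node_rank`, so they can be generated rather than typed by hand. The sketch below is a hypothetical helper (the function `build_launch_command` and its parameter names are not part of the repository) that assembles the same flags shown above into an argument list; it does not launch anything.

```python
import shlex


def build_launch_command(config, nnodes, node_rank, master_addr,
                         nproc_per_node=8, master_port=12581):
    """Assemble the torch.distributed.launch command for one node.

    Hypothetical helper: it only reproduces the flags used in the
    example commands and does not start any process.
    """
    if not 0 <= node_rank < nnodes:
        raise ValueError("node_rank must be in [0, nnodes)")
    return [
        "python", "-m", "torch.distributed.launch",
        f"--nproc_per_node={nproc_per_node}",
        f"--nnodes={nnodes}",
        f"--node_rank={node_rank}",
        f"--master_addr={master_addr}",
        f"--master_port={master_port}",
        "train.py", config,
    ]


if __name__ == "__main__":
    config = "configs/webface42m_r100_lr01_pfc02_bs4k_16gpus"
    for rank in range(2):
        # "ip1" is the placeholder address used in the commands above.
        print(shlex.join(build_launch_command(config, 2, rank, "ip1")))
```

Returning a list instead of a string keeps the result usable with `subprocess.run` without shell quoting concerns.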

Download Datasets or Prepare Datasets

References

Citing

@article{sang2022interface,
  title={InterFace: Adjustable Angular Margin Inter-class Loss for Deep Face Recognition},
  author={Sang, Meng and Chen, Jiaxuan and Li, Mengzhen and Tan, Pan and Pan, Anning and Zhao, Shan and Yang, Yang},
  journal={arXiv preprint arXiv:2210.02018},
  year={2022}
}

interface's Issues

_sample_center_embeddings

I modified how _sample_center_embeddings is computed and believed the change was equivalent, but the loss becomes NaN after a few iterations, and I cannot see why.
My modified code is as follows:

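        # One row per gathered sample; rows for classes held by other ranks stay zero.
        _sample_center_embeddings = torch.zeros(
            (labels.shape[0], self.embedding_size),
            dtype=self.weight_activated.dtype,
            device=self.weight_activated.device,
        )
        # Row indices of the samples whose class is held by this rank.
        idx = torch.where(labels != -1)[0]
        _sample_center_embeddings[idx] = self.weight_activated[labels[idx].view(-1)]
        # Sum across ranks so that every row receives its class center.
        distributed.all_reduce(_sample_center_embeddings, op=distributed.ReduceOp.SUM)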

The full forward function in partial_fc is:

        """
        Parameters:
        ----------
        local_embeddings: torch.Tensor
            feature embeddings on each GPU(Rank).
        local_labels: torch.Tensor
            labels on each GPU(Rank).

        Returns:
        -------
        loss: torch.Tensor
            pass
        """
        local_labels.squeeze_()
        local_labels = local_labels.long()
        self.update()

        batch_size = local_embeddings.size(0)
        if self.last_batch_size == 0:
            self.last_batch_size = batch_size
        assert self.last_batch_size == batch_size, (
            "last batch size do not equal current batch size: {} vs {}".format(
                self.last_batch_size, batch_size))

        _gather_embeddings = [
            torch.zeros((batch_size, self.embedding_size)).cuda()
            for _ in range(self.world_size)
        ]
        _gather_labels = [
            torch.zeros(batch_size).long().cuda() for _ in range(self.world_size)
        ]
       
        _list_embeddings = AllGather(local_embeddings, *_gather_embeddings)
        distributed.all_gather(_gather_labels, local_labels)

        embeddings = torch.cat(_list_embeddings)
        labels = torch.cat(_gather_labels)
  

        labels = labels.view(-1, 1)
        index_positive = (self.class_start <= labels) & (
                labels < self.class_start + self.num_local
        )
        # print(self.rank, labels)

        labels[~index_positive] = -1

        labels[index_positive] -= self.class_start

        if self.sample_rate < 1:
            self.sample(labels, index_positive, optimizer)
        
        _sample_center_embeddings = torch.zeros(
            (labels.shape[0], self.embedding_size),
            dtype=self.weight_activated.dtype,
            device=self.weight_activated.device,
        )
        idx = torch.where(labels != -1)[0]
        _sample_center_embeddings[idx] = self.weight_activated[labels[idx].view(-1)]
        distributed.all_reduce(_sample_center_embeddings, op=distributed.ReduceOp.SUM)

        with torch.cuda.amp.autocast(self.fp16):
            norm_embeddings = normalize(embeddings) # (B*n, C)
            norm_sample_center_embeddings = normalize(_sample_center_embeddings) # (B*n, C)
            norm_weight_activated = normalize(self.weight_activated)
            logits = linear(norm_embeddings, norm_weight_activated)
            center_logits = linear(norm_sample_center_embeddings, norm_weight_activated)
        if self.fp16:
            logits = logits.float()
            center_logits = center_logits.float()

        logits = logits.clamp(-1, 1)
        center_logits = center_logits.clamp(-1, 1)

        logits = self.margin_softmax(logits, labels, center_logits)
        loss = self.dist_cross_entropy(logits, labels)
        return loss
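
The label handling in the forward function above can be checked in isolation. The sketch below is a pure-Python toy (no GPUs, no autograd, no fp16, and no class sampling) under the assumption that each rank owns a contiguous, non-overlapping range of classes `[class_start, class_start + num_local)`; the names `localize_labels` and `gather_centers` are hypothetical. It shows that every sample is positive on exactly one rank, so summing the zero-filled per-rank rows, which is what the all-reduce with `ReduceOp.SUM` does, yields one class center per sample in the forward pass. It says nothing about gradients or numerical stability, so it does not by itself explain the NaN.

```python
def localize_labels(labels, class_start, num_local):
    """Map global labels to local indices on one rank; -1 if not owned."""
    return [
        label - class_start if class_start <= label < class_start + num_local else -1
        for label in labels
    ]


def gather_centers(labels, shards, embedding_size):
    """Emulate zero-fill followed by a SUM all-reduce over weight shards.

    shards: list of (class_start, weights), where weights[i] is the
    center of global class class_start + i on that rank.
    """
    total = [[0.0] * embedding_size for _ in labels]
    for class_start, weights in shards:
        local = localize_labels(labels, class_start, len(weights))
        for row, idx in enumerate(local):
            if idx != -1:  # this rank owns the sample's class
                for k in range(embedding_size):
                    total[row][k] += weights[idx][k]
    return total


if __name__ == "__main__":
    # Two toy ranks holding classes 0-1 and 2-3, with 2-dimensional centers.
    shards = [
        (0, [[1.0, 0.0], [0.0, 1.0]]),
        (2, [[2.0, 2.0], [3.0, -1.0]]),
    ]
    labels = [3, 0, 2]
    print(localize_labels(labels, 2, 2))
    print(gather_centers(labels, shards, 2))
```

If a label fell outside every shard, its row would stay all zeros after the sum, which is one case worth ruling out when comparing the modified code against the original.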

How to integrate InterFace into CosFace

Hello, have you tried integrating InterFace into CosFace? If so, how well did it work, and could you provide a code snippet?
(On my dataset, CosFace converges better than ArcFace, so I wonder whether InterFace could improve my loss.)

Thanks in advance!
