
tfzhou / protoseg


CVPR2022 (Oral) - Rethinking Semantic Segmentation: A Prototype View

Home Page: https://arxiv.org/abs/2203.15102

License: MIT License

Languages: Python 70.80%, Shell 18.76%, Cuda 5.58%, C 3.01%, C++ 1.75%, Cython 0.10%
Topics: semantic-segmentation, prototype, nonparametric, fcn, transformer, softmax, nearest-neighbours-classifier, clustering, deep-learning, metric-learning

protoseg's Introduction

Rethinking Semantic Segmentation: A Prototype View

Rethinking Semantic Segmentation: A Prototype View,
Tianfei Zhou, Wenguan Wang, Ender Konukoglu and Luc Van Gool
CVPR 2022 (Oral) (arXiv 2203.15102)

News

  • [2022-04-19] Release the code based on openseg.pytorch!
  • [2022-03-31] Paper link updated!
  • [2022-03-12] Repo created. Paper and code will come soon.

Abstract

Prevalent semantic segmentation solutions, despite their different network designs (FCN based or attention based) and mask decoding strategies (parametric softmax based or pixel-query based), can be placed in one category, by considering the softmax weights or query vectors as learnable class prototypes. In light of this prototype view, this study uncovers several limitations of such a parametric segmentation regime, and proposes a nonparametric alternative based on non-learnable prototypes. Instead of learning a single weight/query vector for each class in a fully parametric manner, as prior methods do, our model represents each class as a set of non-learnable prototypes, relying solely on the mean features of several training pixels within that class. The dense prediction is thus achieved by nonparametric nearest-prototype retrieval. This allows our model to directly shape the pixel embedding space by optimizing the arrangement between embedded pixels and anchored prototypes. It is able to handle an arbitrary number of classes with a constant amount of learnable parameters. We empirically show that, with FCN based and attention based segmentation models (i.e., HRNet, Swin, SegFormer) and backbones (i.e., ResNet, HRNet, Swin, MiT), our nonparametric framework yields compelling results over several datasets (i.e., ADE20K, Cityscapes, COCO-Stuff), and performs well in the large-vocabulary situation. We expect this work will provoke a rethink of the current de facto semantic segmentation model design.
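To make the nearest-prototype idea concrete, here is a minimal sketch of nonparametric nearest-prototype retrieval; the shapes and tensor names are illustrative and not the repository's exact implementation:

import torch
import torch.nn.functional as F

# Illustrative shapes: C classes, K prototypes per class, D-dim embeddings.
C, K, D = 19, 10, 720
prototypes = F.normalize(torch.randn(C, K, D), dim=-1)   # non-learnable class prototypes
pixels = F.normalize(torch.randn(1024, D), dim=-1)        # l2-normalized pixel embeddings

# Cosine similarity between every pixel and every prototype: (N, C, K).
sim = torch.einsum('nd,ckd->nck', pixels, prototypes)

# A pixel takes the class of its nearest prototype: max over K, then argmax over C.
class_scores, _ = sim.max(dim=2)      # (N, C): best prototype score per class
pred = class_scores.argmax(dim=1)     # (N,): predicted class per pixel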

Installation

This implementation is built on openseg.pytorch. Many thanks to the authors for their efforts.

Please follow the Getting Started for installation and dataset preparation.

Performance

Cityscapes

Method | Train Set | Val Set | Iters | Batch Size | mIoU | Log | CKPT | Script
HRNet  | train     | val     | 80K   | 8          | 79.0 | log | ckpt | scripts/cityscapes/hrnet/run_h_48_d_4.sh
Ours   | train     | val     | 80K   | 8          | 80.1 | log | ckpt | scripts/cityscapes/hrnet/run_h_48_d_4_proto.sh

More results will come soon.

Citation

@inproceedings{zhou2022rethinking,
    author    = {Zhou, Tianfei and Wang, Wenguan and Konukoglu, Ender and Van Gool, Luc},
    title     = {Rethinking Semantic Segmentation: A Prototype View},
    booktitle = {CVPR},
    year      = {2022}
}

Relevant Projects

Please also see our works [1] for a novel training paradigm with a cross-image, pixel-to-pixel contrastive loss, and [2] for a novel hierarchy-aware segmentation learning scheme for structured scene parsing.

[1] Exploring Cross-Image Pixel Contrast for Semantic Segmentation - ICCV 2021 (Oral) [arXiv][code]

[2] Deep Hierarchical Semantic Segmentation - CVPR 2022 [arXiv][code]

protoseg's People

Contributors

tfzhou


protoseg's Issues

parameter numbers of the entire model?

Thanks for the great work!
I want to ask a question that confuses me. In the paper, Table 4 shows that the number of parameters of the entire model does not increase. However, I see that in the code the prototype tensor grows with class_num, as follows:
self.prototypes = nn.Parameter(torch.zeros(self.num_classes, self.num_prototype, in_channels), requires_grad=True)
If class_num increases, the number of prototype parameters also increases. Did I get it wrong?
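For a concrete sense of the scaling, here is a quick sketch of the prototype tensor's element count; in_channels = 720 and num_prototype = 10 are assumed values for illustration, not taken from a specific config:

import torch
import torch.nn as nn

in_channels, num_prototype = 720, 10   # assumed values for illustration

def prototype_numel(num_classes):
    # Same construction as in the snippet above: one (num_classes, num_prototype, in_channels) tensor.
    protos = nn.Parameter(torch.zeros(num_classes, num_prototype, in_channels))
    return protos.numel()

print(prototype_numel(19))    # e.g. Cityscapes:  19 * 10 * 720 = 136,800
print(prototype_numel(171))   # e.g. COCO-Stuff: 171 * 10 * 720 = 1,231,200

So the stored prototype buffer does grow linearly with the number of classes; whether it should be counted among the learnable parameters is exactly what this question asks.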

Can't run the code because of dims

Your work is excellent and I want to reproduce its results.
But I'm not very good at coding: when I run the code with the H_48_D_4_proto.json setting on the Cityscapes dataset, I hit a problem when the code steps into the val part.
Here is the reported error:

Traceback (most recent call last):
  File "main.py", line 229, in <module>
    model.train()
  File "/data/qc/ProtoSeg/segmentor/trainer.py", line 373, in train
    self.__train()
  File "/data/qc/ProtoSeg/segmentor/trainer.py", line 267, in __train
    self.__val()
  File "/data/qc/ProtoSeg/segmentor/trainer.py", line 336, in __val
    self.evaluator.update_score(outputs, data_dict['meta'])
  File "/data/qc/ProtoSeg/segmentor/tools/evaluator/standard.py", line 88, in update_score
    item = outputs[idx].permute(0, 2, 3, 1)
RuntimeError: number of dims don't match in permute

and here is the code

def update_score(self, outputs, metas):
        if isinstance(outputs, torch.Tensor):
            outputs = [outputs]
            
        for i in range(len(outputs[0])):

            ori_img_size = metas[i]['ori_img_size']
            border_size = metas[i]['border_size']

            outputs_numpy = {}
            for name, idx in self.output_indices.items():
                item = outputs[idx].permute(0, 2, 3, 1)
                if self.configer.get('dataset') == 'celeba':
                    # the celeba image is of size 1024x1024
                    item = cv2.resize(
                        item[i, :border_size[1], :border_size[0]].cpu().numpy(),
                        tuple(x // 2 for x in ori_img_size), interpolation=cv2.INTER_CUBIC
                    )
                else:
                    item = cv2.resize(
                        item[i, :border_size[1], :border_size[0]].cpu().numpy(),
                        tuple(ori_img_size), interpolation=cv2.INTER_CUBIC
                    )
                outputs_numpy[name] = item

            for name in outputs_numpy:
                tasks.task_mapping[name].eval(
                    outputs_numpy, metas[i], self.running_scores
                )

I tried printing the shape of the variable outputs[idx]; the result is [38, 256, 512], but permute(0, 2, 3, 1) requires a tensor with 4 dims. I'm stuck here.

I would appreciate it if you could answer this question for me.
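For illustration, here is a minimal reproduction of the error together with a hedged workaround (restoring a missing batch dimension); this is only a guess at the cause, not an official fix:

import torch

out = torch.randn(38, 256, 512)        # 3-D tensor, as printed above
# out.permute(0, 2, 3, 1)              # RuntimeError: number of dims don't match in permute

# Hedged workaround: restore the missing batch dimension before permuting.
if out.dim() == 3:
    out = out.unsqueeze(0)             # (1, 38, 256, 512)
item = out.permute(0, 2, 3, 1)         # (1, 256, 512, 38), i.e. (B, H, W, C)
print(item.shape)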

awesome work!

Awesome work! But could you provide the per-class mIoU of the original model and the improved model? That would help me think about which problems can still be improved. Could you also provide this for the work "Exploring Cross-Image Pixel Contrast for Semantic Segmentation"?
Best.

'ce_weight' parameter in loss

Thanks for your great work!!
I'm planning to run your model on the COCO-Stuff dataset. I'm curious how the 'ce_weight' parameter in the H_48_D_4_proto.json file was configured for the loss function. Could you provide details?

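For reference, per-class cross-entropy weights are often derived from training-set label frequencies. Below is a generic median-frequency-balancing sketch; it is only an assumption about how such weights could be computed, not necessarily how the authors configured ce_weight:

import numpy as np

def median_frequency_weights(pixel_counts):
    """pixel_counts: per-class pixel counts over the training set (hypothetical helper)."""
    freq = pixel_counts / pixel_counts.sum()
    return np.median(freq) / freq        # rare classes get weights > 1

# Hypothetical counts for a 4-class toy example.
counts = np.array([1e8, 5e7, 1e6, 5e5], dtype=np.float64)
print(median_frequency_weights(counts))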

intra-class prototypes are the same

Hi, thanks for your great work, but I have one problem here.
I downloaded the model file hrnet_w48_proto_lr1x_hrnet_proto_80k_latest.pth you provided and used the code below to compute the self-similarity of the prototypes.

import torch
import seaborn as sns

# load the released checkpoint
model = torch.load('hrnet_w48_proto_lr1x_hrnet_proto_80k_latest.pth', map_location='cpu')
protos = model['state_dict']['module.prototypes']
feat = protos.view(-1, protos.shape[-1])
simi = feat @ feat.t()
simi = simi.cpu().numpy()
sns.heatmap(simi)
print(simi[0, :10])
print(simi[100, 100:110])

The resulting heatmap shows that the prototypes within one class are all identical! I'm confused by this.
Could you please give me some help? Thank you.

Layernorm in Prototype learning

Thanks for your great work!
I notice you use LayerNorm on the final features before the classifier and also on the predictions. I think this is quite uncommon in prototype learning (correct me if I am wrong).

Could you please provide some explanation for this? And if the two LayerNorms are removed, will the performance degrade?

self.feat_norm = nn.LayerNorm(in_channels)
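For context, here is a minimal sketch of the kind of pipeline being asked about: LayerNorm on the features, then l2 normalization, then cosine similarity to the prototypes. Shapes and tensor names are illustrative, not the repository's exact code:

import torch
import torch.nn as nn
import torch.nn.functional as F

in_channels, C, K = 720, 19, 10               # assumed shapes for illustration
feat_norm = nn.LayerNorm(in_channels)         # as in the line quoted above
prototypes = F.normalize(torch.randn(C, K, in_channels), dim=-1)

feats = torch.randn(4096, in_channels)        # flattened pixel features (N, D)
feats = feat_norm(feats)                      # LayerNorm over the channel dimension
feats = F.normalize(feats, dim=-1)            # l2 norm, so dot products become cosine similarities

logits = torch.einsum('nd,ckd->nck', feats, prototypes).amax(dim=2)  # (N, C) class scores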

Questions about code

Dear author, I'm very interested in this wonderful work of yours, but because of my limited coding ability I can't find the online-clustering part of your code. Could you please tell me which part of the code I should pay attention to?
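For orientation, within-class online clustering of pixels to prototypes is commonly implemented with a few Sinkhorn-Knopp normalization iterations over the pixel-prototype similarity matrix. The sketch below is generic; the function and variable names are illustrative rather than the repository's:

import torch

@torch.no_grad()
def sinkhorn_assign(scores, n_iters=3, eps=0.05):
    """scores: (N, K) similarities between N pixels of one class and its K prototypes.
    Returns a soft assignment whose rows sum to 1 and whose columns are balanced."""
    Q = torch.exp(scores / eps).t()          # (K, N)
    Q /= Q.sum()
    K, N = Q.shape
    for _ in range(n_iters):
        Q /= Q.sum(dim=1, keepdim=True)      # normalize over prototypes (rows)
        Q /= K
        Q /= Q.sum(dim=0, keepdim=True)      # normalize over pixels (columns)
        Q /= N
    return (Q * N).t()                       # (N, K), each row a soft assignment

scores = torch.randn(128, 10).clamp(-1, 1)   # cosine-like similarities in [-1, 1]
assign = sinkhorn_assign(scores)
print(assign.sum(dim=1)[:5])                 # rows sum to (approximately) 1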

Question about the existence of the normalization, and how to run it

I'm impressed with your work.

  1. About normalization
    I wonder whether the l2 normalization of the pixel embeddings $i \in I$ and the prototypes $p_{c,k}$ exists in your code. If it does, I'd be glad to know where it is. (I am referring to the normalization equation in your paper.)

  2. how to run it

  • Basically, I checked README.md but I cannot find how to run it. Simply running main.py fails because it needs a config path:
TypeError: stat: path should be string, bytes, os.PathLike or integer, not NoneType

It seems the config files are needed to set things up. Could you let me know the basic command to run the code?

  3. The difference between main.py and main_contrastive.py
    What is the difference between them?

Thank you for reading

Question about seed

if args_parser.seed is not None:
	random.seed(args_parser.seed)
	torch.manual_seed(args_parser.seed)

Each GPU is set to the same seed. Shouldn't the seed instead be offset by the process rank, as below?

# fix the seed for reproducibility
if args_parser.seed is not None:
	from lib.utils.distributed import get_rank
	seed = args_parser.seed + get_rank()
	torch.manual_seed(seed)
	np.random.seed(seed)
	random.seed(seed)

Reference

Question about loss

Hi, I'm interested in your work. After reading the paper, I'm confused: the paper describes the PPC loss as a contrastive learning objective, but according to the code, the PPC loss is implemented with cross-entropy loss. Hope to receive your reply, thanks.
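As background on why the two can coincide: an InfoNCE-style contrastive objective over pixel-prototype similarities is, mathematically, a cross-entropy over the temperature-scaled similarity logits with the positive prototype as the target class. A generic sketch (not the repository's exact loss; the temperature and shapes are assumed):

import torch
import torch.nn.functional as F

temperature = 0.1                              # assumed value for illustration
sim = torch.randn(4096, 190)                   # (pixels, C*K) similarities to all prototypes
pos_index = torch.randint(0, 190, (4096,))     # index of each pixel's positive prototype

# The contrastive (InfoNCE) loss equals cross-entropy over the scaled logits.
ppc_like_loss = F.cross_entropy(sim / temperature, pos_index)
print(ppc_like_loss)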

Question about code

May I ask when the core code will be released? I am very interested in your work!

Question about the prototype initialization?

Hi, thanks for the impressive work.

After reading the paper, I have a question about how the micro-prototypes are initialized. They seem to need to be properly initialized in order to obtain a reasonable solution to Eq. (10).

Cheers

Initial Prototype

Thank you for your amazing work.

The paper clearly outlines an iterative process for updating the prototype.

However, I'm unclear about the criteria or process used for selecting the initial prototype.

Question about paper [# model parameter]

Dear author,
Thank you so much for your work and code.
I have a question about the number of model parameters.

As I understand the paper, at inference time each pixel is classified by its closest prototype among the C×K prototypes. In the end, we have to store the C×K prototypes, so I wonder why they are not counted as model parameters. Also, the number of prototypes to be stored is proportional to the number of classes. Is it just convention?

Thank you.

Question about Within-Class Online Clustering

Hi, I'm interested in your work. After reading the paper, I'm confused: the goal of within-class online clustering is to map the pixels I_c to the K prototypes of class c. But how do you know which pixels I_c belong to class c? Did you use the ground truth in this step? If so, how is it handled at test time?

Hope to receive your reply, thanks!

Questions about K prototypes

Hi, I'm interested in your work. I wonder whether the K prototypes of some classes become similar after training. For example, when visualizing them following Figure 3 of the paper, I find that the activation area of each prototype is roughly the same. I noticed this problem while running your code, and I suspect it's caused by some classes in my dataset not having meaningful parts.

Hope to receive your reply, thanks!

Question regarding IoUs of pretrained HRNet Proto

Hi, I downloaded the checkpoint, prepared the data and ran the evaluation script :

bash scripts/cityscapes/hrnet/run_h_48_d_4_proto.sh val hrnet_proto_80k

I had to include a tiny fix, label_img_ = Image.fromarray(label_img_) instead of label_img_ = Image.fromarray(label_img_, 'P') in tester.py, because the labels were all black in the output directory. With that fix, I end up with an mIoU of 85.7, much better than the 81.1 reported in Table 2 of your paper for HRNet. This is the output:

classes          IoU      nIoU
--------------------------------
road          : 0.978194      nan
sidewalk      : 0.817676      nan
building      : 0.966470      nan
wall          : 0.586037      nan
fence         : 0.650901      nan
pole          : 0.858629      nan
traffic light : 0.865478      nan
traffic sign  : 0.904490      nan
vegetation    : 0.977709      nan
terrain       : 0.664823      nan
sky           : 0.973433      nan
person        : 0.950269    0.000000
rider         : 0.802941    0.000000
car           : 0.985926    0.000000
truck         : 0.818478    0.000000
bus           : 0.939901    0.000000
train         : 0.811736    0.000000
motorcycle    : 0.806717    0.000000
bicycle       : 0.919648    0.000000
--------------------------------
Score Average : 0.856814    0.000000
--------------------------------

I also used my own evaluation script on the generated labels that are in the label directory and I get exactly the same results. Could you check?
