
val-iisc / sdat

[ICML 2022] Source code for "A Closer Look at Smoothness in Domain Adversarial Training"

License: MIT License

Python 100.00%
adversarial-training dann domain-adaptation icml-2022 pytorch sharpness-aware-minimization

sdat's Introduction

Smooth Domain Adversarial Training

Harsh Rangwani*, Sumukh K Aithal*, Mayank Mishra, Arihant Jain, R. Venkatesh Babu

This is the official PyTorch implementation of our ICML'22 paper: A Closer Look at Smoothness in Domain Adversarial Training.


Introduction

In recent times, methods converging to smooth optima have shown improved generalization for supervised learning tasks like classification. In this work, we analyze the effect of smoothness-enhancing formulations on domain adversarial training, whose objective combines a task loss (e.g., classification or regression) with adversarial terms. We find that converging to a smooth minimum with respect to (w.r.t.) the task loss stabilizes adversarial training, leading to better performance on the target domain. In contrast, our analysis shows that converging to a smooth minimum w.r.t. the adversarial loss leads to sub-optimal generalization on the target domain. Based on this analysis, we introduce the Smooth Domain Adversarial Training (SDAT) procedure, which effectively enhances the performance of existing domain adversarial methods for both classification and object detection tasks.
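Schematically, the idea can be written as the objective below. The notation is an assumption introduced here, not taken from this README: θ denotes the feature extractor and classifier parameters, R_S the task loss on the source domain, d_{S,T} the adversarial (domain discrepancy) term, and ρ the neighborhood radius. Only the task term is maximized over the perturbation ϵ, and ϵ̂ is the standard first-order SAM approximation of that inner maximization:

```latex
\min_{\theta} \; \max_{\|\epsilon\|_2 \le \rho} R_S(\theta + \epsilon) \;+\; d_{S,T}(\theta),
\qquad
\hat{\epsilon}(\theta) = \rho \, \frac{\nabla_\theta R_S(\theta)}{\|\nabla_\theta R_S(\theta)\|_2}
```

Perturbing d_{S,T} as well would correspond to seeking smoothness w.r.t. the adversarial loss, which the analysis above finds leads to sub-optimal target generalization.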

TL;DR: Change just a few lines of code to convert your adversarial domain adaptation algorithm into its smooth variant and improve its performance.

Why use SDAT?

  • Can be combined with any DAT algorithm.
  • Easy to integrate with a few lines of code.
  • Leads to significant improvement in target-domain accuracy.

DAT Based Method w/ SDAT

We provide the details of the changes required to convert any DAT algorithm (e.g., CDAN, DANN, CDAN+MCC) to its Smooth DAT version.

import torch
from sam import SAM  # SAM implementation from https://github.com/davda54/sam (module path assumed)

# optimizer is the smooth (SAM) optimizer over the parameters of the feature extractor and classifier.
optimizer = SAM(classifier.get_parameters(), torch.optim.SGD, rho=args.rho, adaptive=False,
                lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)
# ad_optimizer is a standard SGD optimizer over the parameters of the domain classifier.

optimizer.zero_grad()
ad_optimizer.zero_grad()

# Calculate the task loss at the current weights w
class_prediction, feature = model(x)
task_loss = task_loss_fn(class_prediction, label)
task_loss.backward()

# Calculate ϵ̂(w) from the task-loss gradient only, add it to the weights,
# and clear the gradients so they do not accumulate into the second pass
optimizer.first_step(zero_grad=True)

# Calculate task loss and domain loss at the perturbed weights w + ϵ̂(w)
class_prediction, feature = model(x)
task_loss = task_loss_fn(class_prediction, label)
domain_loss = domain_classifier(feature)  # assumed to return the adversarial (domain) loss
loss = task_loss + domain_loss
loss.backward()

# Restore w and update it with the gradient from the perturbed point (Sharpness-Aware update)
optimizer.second_step(zero_grad=True)
# Update parameters of domain classifier
ad_optimizer.step()
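To make the two-step update concrete without a training setup, here is a minimal pure-Python sketch of one non-adaptive SAM step on a toy quadratic loss: compute ϵ̂(w) = ρ·g/‖g‖, re-evaluate the gradient at w + ϵ̂(w), then apply a plain SGD update from the original w. The functions `toy_loss`, `toy_grad` and `sam_step` and the values of `lr` and `rho` are hypothetical stand-ins; in SDAT the perturbation is computed from the task loss alone, whereas here the toy loss plays that role.

```python
import math

def toy_loss(w):
    # Hypothetical task loss: a quadratic bowl, sharper along w[0] than along w[1].
    return 5.0 * w[0] ** 2 + 0.5 * w[1] ** 2

def toy_grad(w):
    # Analytic gradient of toy_loss.
    return [10.0 * w[0], 1.0 * w[1]]

def sam_step(w, lr=0.01, rho=0.05):
    """One non-adaptive SAM update on a list of weights."""
    g = toy_grad(w)
    norm = math.sqrt(sum(gi * gi for gi in g)) + 1e-12
    # "first step": epsilon_hat = rho * g / ||g||, added to the weights
    eps = [rho * gi / norm for gi in g]
    w_perturbed = [wi + ei for wi, ei in zip(w, eps)]
    # second forward/backward pass: gradient at the perturbed weights
    g_perturbed = toy_grad(w_perturbed)
    # "second step": go back to w and take a plain SGD step with g_perturbed
    return [wi - lr * gi for wi, gi in zip(w, g_perturbed)]

w = [1.0, 1.0]
for _ in range(100):
    w = sam_step(w)
print(w, toy_loss(w))
```

The only difference from plain SGD is where the gradient is evaluated: at the worst-case nearby point rather than at w itself, which biases the iterates toward flatter regions.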

Getting started

  • Requirements

    • pytorch 1.9.1
    • torchvision 0.10.1
    • wandb 0.12.2
    • timm 0.5.5
    • prettytable 2.2.0
    • scikit-learn
  • Installation

git clone https://github.com/val-iisc/SDAT.git
cd SDAT
pip install -r requirements.txt

We use Weights and Biases (wandb) to track our experiments and results. To track your experiments with wandb, create a new project with your account and change the project and entity arguments in wandb.init accordingly. wandb tracking is controlled by the --log_results flag used in the sample training commands; omit it to disable tracking.

  • Datasets

    The datasets used in the repository are automatically downloaded to the data/ folder if they are not already available.

Training

We report our results primarily for two domain adaptation methods: CDAN w/ SDAT and CDAN+MCC w/ SDAT. The training scripts can be found under the examples subdirectory.

Domain Adversarial Training (DAT)

To train using standard CDAN and CDAN+MCC, use the cdan.py and cdan_mcc.py files, respectively. A sample command to train CDAN+MCC with a ViT B-16 backbone on the Office-Home dataset (with Art as the source domain and Clipart as the target domain) is given below.

python cdan_mcc.py data/office-home -d OfficeHome -s Ar -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_mcc_vit/OfficeHome_Ar2Cl --log_name Ar2Cl_cdan_mcc_vit --gpu 0 --lr 0.002 --log_results

Smooth Domain Adversarial Training (SDAT)

To train using our proposed CDAN w/ SDAT and CDAN+MCC w/ SDAT, use the cdan_sdat.py and cdan_mcc_sdat.py files, respectively.

A sample command to run CDAN+MCC w/ SDAT with a ViT B-16 backbone on the Office-Home dataset (with Art as the source domain and Clipart as the target domain) is given below.

python cdan_mcc_sdat.py data/office-home -d OfficeHome -s Ar -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_mcc_sdat_vit/OfficeHome_Ar2Cl --log_name Ar2Cl_cdan_mcc_sdat_vit --gpu 0 --rho 0.02 --lr 0.002 --log_results

Additional commands to reproduce the results can be found in run_office_home.sh and run_visda.sh under examples.
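Office-Home has four domains, so there are 12 source/target tasks. A minimal sketch that expands the sample command above to all of them, assuming the domain codes "Pr" (Product) and "Rw" (Real World) follow the same convention as "Ar" and "Cl", and reusing the sample's hyperparameters (`rho=0.02`, `lr=0.002`) for every task. The helper `sdat_command` is hypothetical; the shipped run_office_home.sh remains the reference, since per-task settings there may differ.

```python
from itertools import permutations

# "Ar" and "Cl" appear in the README commands; "Pr" and "Rw" are assumed codes.
DOMAINS = ["Ar", "Cl", "Pr", "Rw"]

def sdat_command(src, tgt, rho=0.02, lr=0.002, seed=0, gpu=0):
    """Build a cdan_mcc_sdat.py command (ViT B-16) for one Office-Home task."""
    task = f"{src}2{tgt}"
    return (
        "python cdan_mcc_sdat.py data/office-home -d OfficeHome "
        f"-s {src} -t {tgt} -a vit_base_patch16_224 --epochs 30 --seed {seed} -b 24 --no-pool "
        f"--log logs/cdan_mcc_sdat_vit/OfficeHome_{task} --log_name {task}_cdan_mcc_sdat_vit "
        f"--gpu {gpu} --rho {rho} --lr {lr} --log_results"
    )

# permutations() yields every ordered (source, target) pair with source != target.
commands = [sdat_command(s, t) for s, t in permutations(DOMAINS, 2)]
for cmd in commands:
    print(cmd)
```

For the Ar to Cl task this reproduces the sample command exactly.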

Results

The following table reports the target-domain accuracy across the various splits of the Office-Home and VisDA-2017 datasets using CDAN+MCC w/ SDAT with a ViT B-16 backbone. We also provide downloadable weights for the corresponding trained classifiers.

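| Dataset | Source | Target | Accuracy (%) | Checkpoint |
|---|---|---|---|---|
| Office-Home | Art | Clipart | 70.8 | ckpt |
| Office-Home | Art | Product | 80.7 | ckpt |
| Office-Home | Art | Real World | 90.5 | ckpt |
| Office-Home | Clipart | Art | 85.2 | ckpt |
| Office-Home | Clipart | Product | 87.3 | ckpt |
| Office-Home | Clipart | Real World | 89.7 | ckpt |
| Office-Home | Product | Art | 84.1 | ckpt |
| Office-Home | Product | Clipart | 70.7 | ckpt |
| Office-Home | Product | Real World | 90.6 | ckpt |
| Office-Home | Real World | Art | 88.3 | ckpt |
| Office-Home | Real World | Clipart | 75.5 | ckpt |
| Office-Home | Real World | Product | 92.1 | ckpt |
| VisDA-2017 | Synthetic | Real | 89.8 | ckpt |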
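Office-Home results are commonly summarized by the mean over the 12 transfer tasks. A short sketch computing that average from the numbers in the table above (the task labels use the Ar/Cl/Pr/Rw abbreviations purely as dictionary keys):

```python
# Target accuracies (%) of CDAN+MCC w/ SDAT (ViT B-16), copied from the table above.
office_home_acc = {
    "Ar->Cl": 70.8, "Ar->Pr": 80.7, "Ar->Rw": 90.5,
    "Cl->Ar": 85.2, "Cl->Pr": 87.3, "Cl->Rw": 89.7,
    "Pr->Ar": 84.1, "Pr->Cl": 70.7, "Pr->Rw": 90.6,
    "Rw->Ar": 88.3, "Rw->Cl": 75.5, "Rw->Pr": 92.1,
}

# Arithmetic mean over the 12 tasks (sum is 1005.5).
average = sum(office_home_acc.values()) / len(office_home_acc)
print(f"Office-Home average: {average:.1f}")  # → Office-Home average: 83.8
```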

Evaluation

To evaluate a classifier with pretrained weights, use eval.py under examples, setting the --weight_path argument to the path of the weights to be evaluated.

A sample run to evaluate the pretrained ViT B-16 with CDAN+MCC w/ SDAT on Office-Home (with Art as the source domain and Clipart as the target domain) is given below.

python eval.py data/office-home -d OfficeHome -s Ar -t Cl -a vit_base_patch16_224 -b 24 --no-pool --weight_path path_to_weight.pth --log_name Ar2Cl_cdan_mcc_sdat_vit_eval --gpu 0 --phase test

A sample run to evaluate the pretrained ViT B-16 with CDAN+MCC w/ SDAT on VisDA-2017 (with Synthetic as the source domain and Real as the target domain) is given below.

python eval.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a vit_base_patch16_224 --per-class-eval --train-resizing cen.crop --weight_path path_to_weight.pth --log_name visda_cdan_mcc_sdat_vit_eval --gpu 0 --no-pool --phase test
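The VisDA-2017 command passes --per-class-eval, and the resulting log distinguishes a global accuracy ("global correct") from a mean of per-class accuracies ("mean correct"), as in the VisDA issue further below. A minimal sketch of the two quantities computed from a confusion matrix; the function name and the 3-class counts are hypothetical, and this is an illustration rather than the repository's evaluation code:

```python
def global_and_mean_class_accuracy(confusion):
    """confusion[i][j] = number of samples of true class i predicted as class j."""
    total = sum(sum(row) for row in confusion)
    correct = sum(confusion[i][i] for i in range(len(confusion)))
    global_acc = 100.0 * correct / total
    # Per-class accuracy: correct predictions of class i / number of samples of class i.
    per_class = [100.0 * confusion[i][i] / sum(row) for i, row in enumerate(confusion)]
    mean_class_acc = sum(per_class) / len(per_class)
    return global_acc, mean_class_acc, per_class

# Hypothetical 3-class example where class 0 dominates the test set.
confusion = [
    [80, 10, 10],  # class 0: 100 samples, 80 correct
    [2, 8, 0],     # class 1: 10 samples, 8 correct
    [1, 0, 9],     # class 2: 10 samples, 9 correct
]
g, m, per = global_and_mean_class_accuracy(confusion)
print(round(g, 1), round(m, 1))  # → 80.8 83.3
```

With imbalanced classes the two numbers differ, so it matters which one is compared against a reported result.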

Overview of the arguments

Generally, all scripts in the project take the following flags:

  • -a: Architecture of the backbone (resnet50|vit_base_patch16_224).
  • -d: Dataset (OfficeHome|DomainNet|VisDA2017).
  • -s: Source domain.
  • -t: Target domain.
  • --epochs: Number of training epochs.
  • --no-pool: Use --no-pool for all experiments with a ViT backbone.
  • --log_name: Name of the run on wandb.
  • --gpu: ID of the GPU to use.
  • --rho: ρ value (SAM neighborhood radius) in SDAT (applicable only to SDAT runs).
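A hypothetical argparse reconstruction of the flags listed above, useful for seeing how they parse (for example, --no-pool becomes the attribute no_pool, which the ViT issue below refers to as args.no_pool). The `dest` names, defaults, and `choices` are illustrative assumptions; the real scripts define more options and may differ:

```python
import argparse

def build_parser():
    # Hypothetical sketch of the common flags; defaults are illustrative only.
    parser = argparse.ArgumentParser(description="SDAT training (sketch)")
    parser.add_argument("root", help="dataset root, e.g. data/office-home")
    parser.add_argument("-a", dest="arch", default="resnet50",
                        choices=["resnet50", "vit_base_patch16_224"])
    parser.add_argument("-d", dest="data", default="OfficeHome",
                        choices=["OfficeHome", "DomainNet", "VisDA2017"])
    parser.add_argument("-s", dest="source", required=True)
    parser.add_argument("-t", dest="target", required=True)
    parser.add_argument("--epochs", type=int, default=30)
    parser.add_argument("--no-pool", action="store_true")  # stored as args.no_pool
    parser.add_argument("--log_name", default="run")
    parser.add_argument("--gpu", default="0")
    parser.add_argument("--rho", type=float, default=0.05)  # only used by SDAT runs
    return parser

args = build_parser().parse_args(
    ["data/office-home", "-d", "OfficeHome", "-s", "Ar", "-t", "Cl",
     "-a", "vit_base_patch16_224", "--no-pool", "--rho", "0.02"]
)
print(args.source, args.target, args.no_pool, args.rho)  # → Ar Cl True 0.02
```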

Acknowledgement

Our implementation is based on the Transfer Learning Library. We use the PyTorch implementation of SAM from https://github.com/davda54/sam.

Citation

If you find our paper or codebase useful, please consider citing us as:

@InProceedings{rangwani2022closer,
  title={A Closer Look at Smoothness in Domain Adversarial Training},
  author={Rangwani, Harsh and Aithal, Sumukh K and Mishra, Mayank and Jain, Arihant and Babu, R. Venkatesh},
  booktitle={Proceedings of the 39th International Conference on Machine Learning},
  year={2022}
}

sdat's People

Contributors

mmayank74567, rangwani-harsh, sumukhaithal6


sdat's Issues

about domain acc

Thanks for sharing. I am trying to apply this method (cdan_sdat.py) to UDA for an image regression task, but the domain accuracy stays at 100% and never declines. Can you give me some guidance? Thanks a lot.
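For context, the "domain acc" logged during training measures how well the domain discriminator tells source features from target features. A minimal sketch of that quantity, assuming the discriminator outputs the probability that a sample comes from the source domain (source labeled 1, target 0) and a 0.5 decision threshold; `domain_accuracy` is a hypothetical helper, not the repository's function:

```python
def domain_accuracy(source_probs, target_probs, threshold=0.5):
    """Percentage of samples whose domain the discriminator predicts correctly.

    source_probs / target_probs: discriminator outputs read as P(sample is from source).
    """
    correct = sum(p >= threshold for p in source_probs) + sum(p < threshold for p in target_probs)
    return 100.0 * correct / (len(source_probs) + len(target_probs))

# Perfect separation: the features still carry clear domain information.
print(domain_accuracy([0.9, 0.8, 0.99], [0.1, 0.05, 0.2]))  # → 100.0
# Outputs hovering around 0.5: chance level, which adversarial alignment pushes toward.
print(domain_accuracy([0.55, 0.45, 0.6, 0.4], [0.52, 0.48, 0.3, 0.7]))  # → 50.0
```

A value pinned at 100% therefore indicates that the discriminator is winning the min-max game and the feature extractor is not producing domain-invariant features.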

Learning rate

Hi, very solid work. I have a question:

Can the learning rate of the discriminator be updated? In class DomainDiscriminator(nn.Sequential), get_parameters sets "lr": 1.:

def get_parameters(self) -> List[Dict]:
    return [{"params": self.parameters(), "lr": 1.}]

lr_scheduler_ad = LambdaLR(
    ad_optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))

Does it work?

We ask because in our experiments the discriminator loss stayed flat from very early epochs, even though the task loss was still decreasing.

We look forward to your reply.
Thank you!
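For reference, PyTorch's `LambdaLR` sets each parameter group's learning rate to that group's initial learning rate multiplied by the value the lambda returns. With the discriminator group's base `"lr": 1.`, the effective rate is therefore exactly the schedule value, so the discriminator's learning rate does decay. A pure-Python sketch of that rule (no PyTorch needed); `lambda_lr` is a hypothetical helper, and `lr_gamma=0.001`, `lr_decay=0.75` are assumed example values, not necessarily the repository's defaults:

```python
def lambda_lr(initial_lr, lr_lambda, step):
    # Mirrors torch.optim.lr_scheduler.LambdaLR: lr = initial_lr * lr_lambda(step)
    return initial_lr * lr_lambda(step)

lr, lr_gamma, lr_decay = 0.002, 0.001, 0.75  # assumed example hyperparameters
schedule = lambda x: lr * (1.0 + lr_gamma * float(x)) ** (-lr_decay)

# Discriminator group: base "lr" of 1.0, so the effective LR equals the schedule value.
for step in (0, 1000, 10000):
    print(step, lambda_lr(1.0, schedule, step))
```

A group whose base "lr" were 0.1 would instead follow one tenth of the same curve, which is why "lr": 1. acts as a multiplier rather than a fixed learning rate.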

Question about correctness of Domain Accuracy

Hello everybody, we are currently examining different DA methods and tried to reproduce the SDAT paper results on our system.
So far I have run two trainings on DomainNet with all the settings and arguments I could gather from the paper. The val_acc curves look quite promising:

real -> sketch: [image]

sketch -> clipart: [image]

However, the domain accuracy (not smoothed) is somewhat confusing to us:

real -> sketch: [image]

sketch -> clipart: [image]

We would have expected a graph similar to this one from the paper for the domain accuracy: [image]

However, it seems that graph was produced on the Office-Home dataset with different settings. Our question now: can we assume that our results are correct?

Different behaviours on VisDA-2017 using different pretrained models from timm

Thanks for the great work. I ran into two problems when running the experiment with ViT on VisDA-2017.

  1. It seems that the ViT backbone doesn't match the bottleneck when no_pool is set. The output of the ViT backbone is a sequence of tokens instead of a single class token, so the BatchNorm1d layer complains about the dimension.
  2. I fixed the first problem by adding a pool layer that extracts the class token:
    pool_layer = lambda _x: _x[:, 0] if args.no_pool else None
    Then I used the exact command in examples/run_visda.sh to run CDAN_MCC_SDAT:
    python cdan_mcc_sdat.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a vit_base_patch16_224 --epochs 15 --seed 0 --lr 0.002 --per-class-eval --train-resizing cen.crop --log logs/cdan_mcc_sdat_vit/VisDA2017 --log_name visda_cdan_mcc_sdat_vit --gpu 0 --no-pool --rho 0.02 --log_results
    In the end I got a slightly lower accuracy:
    global correct: 86.0
    mean correct:88.3
    mean IoU: 78.5
    +------------+-------------------+--------------------+
    | class | acc | iou |
    +------------+-------------------+--------------------+
    | aeroplane | 97.83323669433594 | 96.3012924194336 |
    | bicycle | 88.43165588378906 | 81.25331115722656 |
    | bus | 81.79104614257812 | 72.69281768798828 |
    | car | 78.06941986083984 | 67.53160095214844 |
    | horse | 97.31400299072266 | 92.78455352783203 |
    | knife | 96.91566467285156 | 82.31681823730469 |
    | motorcycle | 94.9102783203125 | 83.37374877929688 |
    | person | 81.3499984741211 | 58.12790298461914 |
    | plant | 94.04264831542969 | 89.68553161621094 |
    | skateboard | 95.87899780273438 | 81.48286437988281 |
    | train | 94.05099487304688 | 87.69535064697266 |
    | truck | 59.04830551147461 | 48.311458587646484 |
    +------------+-------------------+--------------------+
    test_acc1 = 86.0
    I noticed that epochs is set to 15 in the script. Is this experimental setting correct? How can I get the reported accuracy? Many thanks.

Unavailable dataset

Hello! I tried run_office_home.sh for CDAN_MCC_SDAT and ran into the error below. Is it possible that the download list is outdated? I tried downloading OfficeHome manually and setting the download boolean in utils.py to False, but it doesn't seem to work with the raw OfficeHome download. Could you provide more details on the dataset setup?

Downloading image_list
Fail to download image_list.zip from url link https://cloud.tsinghua.edu.cn/f/ca3a3b6a8d554905b4cd/?dl=1
Please check you internet connection.Simply trying again may be fine.

Newcomers ask for help

Sorry to interrupt. I cloned the project, made sure the environment was suitable, and then ran the following command from your README:
python cdan_mcc.py data/office-home -d OfficeHome -s Ar -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_mcc_vit/OfficeHome_Ar2Cl --log_name Ar2Cl_cdan_mcc_vit --gpu 0 --lr 0.002 --log_results
I got this error:
Traceback (most recent call last):
  File "/usr/data_disk1/ar/SDAT/examples/cdan_mcc.py", line 10, in <module>
    import wandb
ModuleNotFoundError: No module named 'wandb'
I only started learning deep learning two weeks ago, so I don't know how to fix this error because the code structure is too complex for me. Can you tell me how to deal with it? Thanks a lot.

How to get correct accuracy?

Hello, I tried the source code of your method on VisDA-2017 with a ResNet-101 backbone, but I only got 71.7%. How can I get 84.3%? Thank you very much.

How to get the accuracy reported in Paper

Hello, I tried the source code of your method on VisDA-2017 with a ViT backbone, but I only got 87.8%. How can I get 89.8%? Thank you very much.
Besides, I would like to know how to get the actual accuracy on these datasets with a ResNet-50 backbone. Thank you very much.
