moskomule / dda

Differentiable Data Augmentation Library

License: MIT License


dda's Introduction

Differentiable Data Augmentation Library

This library is the core of Faster AutoAugment and its descendants. It is research oriented, and its API may change in the near future.

Requirements and Installation

Requirements

Python>=3.8
PyTorch>=1.5.0
torchvision>=0.6
kornia>=0.2

Installation

pip install -U git+https://github.com/moskomule/dda

APIs

dda.functional

Basic operations that are differentiable w.r.t. the magnitude parameter mag. When mag=0, no augmentation is applied, and when mag=1 (or mag=-1, where supported), the most severe augmentation is applied. As introduced in Faster AutoAugment, some operations use a straight-through estimator to be differentiable w.r.t. their magnitude parameters.

def operation(img: torch.Tensor,
              mag: Optional[torch.Tensor]) -> torch.Tensor:
    ...

dda.pil contains similar APIs that use PIL (not differentiable).
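
As a rough illustration of the operation(img, mag) contract in dda.functional, the following sketch defines a hypothetical brightness-style operation. toy_brightness is not part of dda; it assumes images in [0, 1] with shape (B, C, H, W) and mag in [0, 1]. It is the identity at mag=0 and differentiable w.r.t. mag through ordinary autograd.

```python
import torch


def toy_brightness(img: torch.Tensor, mag: torch.Tensor) -> torch.Tensor:
    # Hypothetical operation, not dda's implementation.
    # mag == 0 leaves the image unchanged; mag == 1 blends fully to white.
    mag = mag.view(-1, 1, 1, 1)
    return img + mag * (1 - img)


img = torch.rand(2, 3, 4, 4)
mag = torch.tensor([0.3], requires_grad=True)

out = toy_brightness(img, mag)
out.sum().backward()  # the gradient flows back to mag

identity = toy_brightness(img, torch.zeros(1))  # mag = 0: no augmentation
```

Operations whose forward pass is not smooth in mag need a different treatment, which is where the straight-through estimator comes in.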

dda.operations

class Operation(nn.Module):
   
    def __init__(self,
                 initial_magnitude: Optional[float] = None,
                 initial_probability: float = 0.5,
                 magnitude_range: Optional[Tuple[float, float]] = None,
                 probability_range: Optional[Tuple[float, float]] = None,
                 temperature: float = 0.1,
                 flip_magnitude: bool = False,
                 magnitude_scale: float = 1,
                 debug: bool = False):
        ...

If magnitude_range=None, then magnitude is registered as a Buffer instead of a Parameter. Likewise, if probability_range=None, then probability is a Buffer instead of a Parameter.

magnitude moves within magnitude_scale * magnitude_range. For example, dda.operations.Rotation has magnitude_range=[0, 1] and magnitude_scale=30, so its magnitude is between 0 and 30 degrees.
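
The relation between the raw magnitude, magnitude_range and magnitude_scale can be sketched in plain Python. effective_magnitude is a hypothetical helper, not part of dda, and clamping the raw value into the range is an assumption about how it is kept inside magnitude_range.

```python
from typing import Tuple


def effective_magnitude(raw: float,
                        magnitude_range: Tuple[float, float],
                        magnitude_scale: float) -> float:
    # Hypothetical helper: clamp the raw value into magnitude_range,
    # then multiply by magnitude_scale.
    low, high = magnitude_range
    clamped = min(max(raw, low), high)
    return magnitude_scale * clamped


# Rotation-like settings from the text: range [0, 1], scale 30
print(effective_magnitude(0.5, (0.0, 1.0), 30))  # 15.0
print(effective_magnitude(1.2, (0.0, 1.0), 30))  # 30.0 (clamped to the upper bound)
```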

To differentiate w.r.t. the probability parameter, RelaxedBernoulli is used.
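
A minimal sketch of this idea, using torch.distributions.RelaxedBernoulli directly. How dda combines the relaxed sample with the augmented image is not stated here, so the blending step below is an assumption, and 1 - img is only a stand-in for a real augmentation.

```python
import torch
from torch.distributions import RelaxedBernoulli

probability = torch.tensor([0.5], requires_grad=True)
temperature = torch.tensor([0.1])

# rsample() is reparameterized, so gradients can reach `probability`
dist = RelaxedBernoulli(temperature, probs=probability)
mask = dist.rsample((8,))  # shape (8, 1), values between 0 and 1

img = torch.rand(8, 3)
augmented = 1 - img  # stand-in for an augmented image

# Assumed blending: a soft choice between augmented and original
mixed = mask * augmented + (1 - mask) * img
mixed.sum().backward()
```

With a low temperature the samples concentrate near 0 or 1, which approximates a hard Bernoulli choice while keeping the probability trainable.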

Examples

Citation

dda (except RandAugment) is developed as a core library of the following research projects.

If you use dda in your academic research, please cite hataya2020a.

@inproceedings{hataya2020a,
    title={{Faster AutoAugment: Learning Augmentation Strategies using Backpropagation}},
    author={Ryuichiro Hataya and Jan Zdenek and Kazuki Yoshizoe and Hideki Nakayama},
    year={2020},
    booktitle={ECCV}
}

...

dda's Issues

Questions about the implementation of the building blocks of your Faster AA system

Hello,

I want to compare the operations your Faster AA system selects after training with other systems. I am referring to the following file faster_autoaugment\policy.py.

I tried to print the policy with all its learned parameters: W (Selection), Probability, and Magnitude using the following code.

    path = pathlib.Path(cfg.output_dir) / f'K{cfg.model.operation_count}/policy_weights/{cfg.optim.epochs}.pt'
    assert path.exists()

    policy_weight = torch.load(path, map_location="cpu")
    print('epoch', policy_weight['epoch'])
    print('step', policy_weight['step'])
    
    policy = Policy.faster_auto_augment_policy(num_chunks=cfg.model.num_chunks, 
                                                **policy_weight["policy_kwargs"])

    policy.load_state_dict(policy_weight["policy"])

    # print the trainable parameters
    for name, param in policy.named_parameters():
        if param.requires_grad:
            print(name, param.data)

For example, with the Rotation operation, the print statement returns

sub_policies.1.stages.0.operations.4._magnitude tensor([0.7193])
sub_policies.1.stages.0.operations.4._probability tensor([0.0831])

Next, I visualize the augmented images using the following code

    # visualization
    with torch.no_grad():
        # input: [-1, 1]
        # input = policy(policy.denormalize_(input))
        aug = policy(policy.denormalize_(input))
    print(f'aug {aug.shape}')
    batch = aug[:10]
    #save the images
    show(batch, 'aug.png')

At the same time, I also modified the Policy class defined in faster_autoaugment\policy.py to print the selected sub-policies.

def _forward(self,
             input: Tensor
             ) -> Tensor:
    index = random.randrange(self.num_sub_policies)
    print(f'Used sub-policies {index}')
    return self.sub_policies[index](input)

However, for the Rotation operation, this print statement returns different information.

Used sub-policies 1
Used operations Rotate(probability=0.500 (learnable), magnitude=21.579 (learnable), temperature=0.100)

I can't reconcile the output of the first print statement with that of the second. Can you suggest how to resolve this discrepancy? Many thanks.
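
One possible reading, based on the README statement that Rotation uses magnitude_scale=30: the two magnitude values agree once the scale is applied. The check below is only arithmetic and does not explain the difference between the two probability values.

```python
raw_magnitude = 0.7193   # value printed from named_parameters()
magnitude_scale = 30     # Rotation's scale according to the README

scaled = round(raw_magnitude * magnitude_scale, 3)
print(scaled)  # 21.579, the magnitude shown in the second print statement
```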

why (0.5,0.5,0.5) for policy's mean and std?

Hello Mr. Hataya

In the search stage, the policy is initialized with a mean and std of (0.5, 0.5, 0.5), rather than the dataset's own statistics, for de-normalization and re-normalization. I replaced (0.5, 0.5, 0.5) with the default mean and std of CIFAR-10 and found that the searched policy performed poorly. Could you please explain why using (0.5, 0.5, 0.5) might help the policy search? Thank you.

def faster_auto_augment_policy(num_sub_policies: int,
                               temperature: float,
                               operation_count: int,
                               num_chunks: int,
                               mean: Optional[torch.Tensor] = None,
                               std: Optional[torch.Tensor] = None,
                               ) -> Policy:
    if mean is None or std is None:
        mean = torch.ones(3) * 0.5
        std = torch.ones(3) * 0.5

Bolian
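
For context, code comments quoted elsewhere on this page note that the policy input lies in [-1, 1]. Under that assumption, mean = std = 0.5 is the pair that maps [-1, 1] back onto [0, 1] on de-normalization, whereas dataset-style statistics do not. A small sketch with plain floats; denormalize is a hypothetical helper, and 0.49 / 0.25 are illustrative values rather than numbers taken from the source.

```python
def denormalize(x: float, mean: float, std: float) -> float:
    # Inverse of the usual normalization (x - mean) / std
    return x * std + mean


for x in (-1.0, 0.0, 1.0):
    with_half = denormalize(x, 0.5, 0.5)      # maps [-1, 1] onto [0, 1]
    with_stats = denormalize(x, 0.49, 0.25)   # illustrative dataset-style statistics
    print(x, with_half, with_stats)
```

With the illustrative statistics the end points land strictly inside [0, 1], so operations that assume a full [0, 1] image range would see a compressed input.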

Cannot install homura

Hello, I cannot install homura. Also, train.py has "from homura import trainers, TensorMap, optim, callbacks", but I found that the homura folder has no TensorMap.py.

Can you share the found policy weights for the image classification benchmarks?

Hi,

I am trying to replicate your work. I can't find the weights of your found policies for these benchmarks yet. I looked into Fig. 5 in the Faster AutoAugment paper. For confirmation, does the figure show the complete found policies and their "weights"? In Fig. 5, are the found policies based on the full CIFAR-10 training set or on the subset?

I also looked into the train.py program. The code shows
path = Path(hydra.utils.get_original_cwd()) / cfg.path
What does Path(hydra.utils.get_original_cwd()) return? I can't find the path cfg.path to the found policies yet.
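
As background: hydra.utils.get_original_cwd() returns the directory from which the script was launched, which matters in Hydra versions that switch the working directory to a per-run output directory. The sketch below mimics the path join with pathlib only; original_cwd and cfg_path are hypothetical stand-ins, not values from the repository.

```python
from pathlib import PurePosixPath

# Stand-in for hydra.utils.get_original_cwd(): the launch directory
original_cwd = "/path/to/launch_dir"
# Stand-in for cfg.path: a path relative to the launch directory
cfg_path = "relative/policy.pt"

path = PurePosixPath(original_cwd) / cfg_path
print(path)  # /path/to/launch_dir/relative/policy.pt
```

So cfg.path would be resolved relative to where the command was run, not relative to Hydra's output directory.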

Can you help point me to the files with such information? For example, in the Fast AA, readers can find their policies at this link https://github.com/kakaobrain/fast-autoaugment/blob/master/FastAutoAugment/archive.py

Any comments are appreciated. Thanks.

Some questions about straight through estimator?

Hello! Thank you for providing a new trick on data augmentation!
I have a question about straight through estimator.

@tensor_function
def solarize(img: torch.Tensor,
             mag: torch.Tensor) -> torch.Tensor:
    mag = mag.view(-1, 1, 1, 1)
    return ste(torch.where(img < mag, img, 1 - img), mag)

@tensor_function
def posterize(img: torch.Tensor,
              mag: torch.Tensor) -> torch.Tensor:
    # mag: 0 to 1
    mag = mag.view(-1, 1, 1, 1)
    with torch.no_grad():
        shift = ((1 - mag) * 8).long()
        shifted = (img.mul(255).long() << shift) >> shift
    return ste(shifted.float() / 255, mag)

For the solarize and posterize operations, does the gradient of the loss w.r.t. the input img vanish? If so, when one of these two operations is applied second, is there no way to propagate the gradient back to the first operation?
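
To make the question concrete, here is a sketch using one common straight-through formulation. The ste below is an assumption (forward value taken from the first argument, gradient routed only to the second) and is not necessarily identical to dda's implementation.

```python
import torch


def ste(forward_value: torch.Tensor, backward_value: torch.Tensor) -> torch.Tensor:
    # Assumed STE: use forward_value in the forward pass,
    # but send the gradient to backward_value only.
    return forward_value.detach() + (backward_value - backward_value.detach())


def solarize(img: torch.Tensor, mag: torch.Tensor) -> torch.Tensor:
    mag = mag.view(-1, 1, 1, 1)
    return ste(torch.where(img < mag, img, 1 - img), mag)


img = torch.rand(2, 3, 4, 4, requires_grad=True)
mag = torch.tensor([0.5], requires_grad=True)
solarize(img, mag).sum().backward()

print(img.grad)  # None: no gradient reaches img under this formulation
print(mag.grad)  # tensor([96.]): one unit per output element (2 * 3 * 4 * 4)
```

Under this assumed formulation, an operation applied before solarize would indeed receive no gradient through img.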

I hope you can upload the entire code of Faster AutoAugment as soon as possible. Thank you!

About the torchvision.io import problem

Hi:

I also use the homura library, but I ran into this error:
ImportError: cannot import name 'read_image' from 'torchvision.io' (/home/csgrad/ychan/anaconda3/envs/py38/lib/python3.8/site-packages/torchvision/io/__init__.py)

Have you seen this problem before? I tried PyTorch 1.5.0 and 1.8.1, and neither of them can import read_image.

Update: with PyTorch 1.8.1, running "from homura import trainers, TensorMap, optim, callbacks" gives the following instead:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/csgrad/ychan/anaconda3/envs/py38/lib/python3.8/site-packages/homura/__init__.py", line 8, in <module>
    Registry.import_modules('homura.vision')
  File "/home/csgrad/ychan/anaconda3/envs/py38/lib/python3.8/site-packages/homura/register.py", line 97, in import_modules
    module = importlib.import_module(package_name)
  File "/home/csgrad/ychan/anaconda3/envs/py38/lib/python3.8/importlib/__init__.py", line 127, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
  File "/home/csgrad/ychan/anaconda3/envs/py38/lib/python3.8/site-packages/homura/vision/__init__.py", line 2, in <module>
    from .data import DATASET_REGISTRY
  File "/home/csgrad/ychan/anaconda3/envs/py38/lib/python3.8/site-packages/homura/vision/data/__init__.py", line 2, in <module>
    from .datasets import VisionSet, DATASET_REGISTRY
  File "/home/csgrad/ychan/anaconda3/envs/py38/lib/python3.8/site-packages/homura/vision/data/datasets.py", line 230, in <module>
    {'cifar10': VisionSet(datasets.CIFAR10, "~/.torch/data/cifar10", 10,
  File "<string>", line 10, in __init__
  File "/home/csgrad/ychan/anaconda3/envs/py38/lib/python3.8/site-packages/homura/vision/data/datasets.py", line 86, in __post_init__
    raise RuntimeError(f"dataset DataSet(root, train, transform, download) is expected, "
RuntimeError: dataset DataSet(root, train, transform, download) is expected, but <class 'torchvision.datasets.cifar.CIFAR10'> has arguments of set() instead.

Thanks

Iteration in search.py divides batch into two groups

Hi,

In faster_autoaugment/search.py, there's the iteration function that divides the batch into two groups.

class AdvTrainer(trainers.TrainerBase):
    # acknowledge https://github.com/caogang/wgan-gp/blob/master/gan_cifar10.py
    def iteration(self, data: Tuple[Tensor, Tensor]) -> Mapping[str, Tensor]:
        # input: [-1, 1]
        input, target = data
        b = input.size(0) // 2
        a_input, a_target = input[:b], target[:b]
        n_input, n_target = input[b:], target[b:]
        loss, d_loss, a_loss = self.wgan_loss(n_input, n_target, a_input, a_target)

        return TensorMap(loss=loss, d_loss=d_loss, a_loss=a_loss)

However, I think the train_loader doesn't really load it this way.

In my understanding, shouldn't n_input and a_input be the same set of images? Then, in self.wgan_loss, a_input would be augmented and compared with n_input to see whether the classification loss changes.
If I'm right, this is a bug, and the code should instead copy the input and the target. Is that correct?
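
For reference, the slicing in iteration amounts to the following, shown with a plain list instead of tensors. Reading a_input as the half that is presumably augmented and n_input as the half presumably left as is follows the variable names only. Whether two disjoint halves are intended is a question for the author; a WGAN-style critic generally compares two distributions rather than paired samples, so the split may be deliberate.

```python
batch = list(range(8))   # stand-in for a batch of 8 images
b = len(batch) // 2

a_input = batch[:b]      # first half (presumably the augmented side)
n_input = batch[b:]      # second half (presumably the non-augmented side)

print(a_input, n_input)  # [0, 1, 2, 3] [4, 5, 6, 7]
```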

Training the classification model when using faster_autoaugment/search.py

Hi,

I wonder if there's absolutely no need to pre-train the classification WideResnet model before searching for autoaugment policies.
It looks like in the code it's using randomly initialised weights.

  • Should you not use imagenet pretrained or even CIFAR fine-tuned model weights for searching policies in search.py?
  • If the searching involves updating classification model weights, then should you not load this weight when you train with the autoaugment policies in train.py?

Thank you!

Search policy on my custom dataset

Hi,

Thank you for such good research work, both in writing and in implementation. As a reader, I found your work easy to follow, specifically the Faster AutoAugment and MADAO systems.

Now I need to use both systems on my new custom datasets (none of which are among the default torchvision.datasets). To make this easier for me, can you suggest which Python files and yaml config files I should change to make the search.py program run with my custom datasets?

I ask this because I am not yet familiar with Homura (https://github.com/moskomule/homura/blob/master/examples/cifar10.py) and Chika (link).

Any information is appreciated.

Thanks

test_loader in train.py

Is the test_loader used in train.py actually validation data?
I'm asking because it would be data leakage if test data were used for validation.

Debug the policy search process for a custom dataset

Hello,

I saw you keep track of the total training loss using the following code at
https://github.com/moskomule/dda/blob/3ffe718e253a77ecb8b4e638d851f0d3d248c111/faster_autoaugment/search.py#L176.

However, when I reviewed the output file, the information related to the classification loss was trimmed. Please refer to the following screenshot. Can you help suggest a fix for this issue?

image

Also, can you help suggest additional debugging strategies for the search process besides monitoring the classification loss, the critic's training loss and the policy's loss? Many thanks.

How to load imagenet?

I have downloaded the imagenet data set from the official website, but I don't know how to process it to run the code.
