yyzharry / subpopbench

89 stars · 5 watchers · 11 forks · 233 KB

[ICML 2023] Change is Hard: A Closer Look at Subpopulation Shift

Home Page: https://subpopbench.csail.mit.edu

License: MIT License

Python 100.00%
benchmark domain-generalization out-of-distribution subgroup subpopulation subpopulation-shift class-imbalance ood-generalization ood-robustness spurious-correlations

subpopbench's People

Contributors

hzhang0, yyzharry


subpopbench's Issues

Clarification on Appendix E

Could you clarify which backbone featurizer was used for the results in Appendix E, or whether they are averaged over all the different architectures?
If it is the latter, could you share the results for the individual architectures?

Thanks a lot

Error loading MIMICNotes

I successfully downloaded MIMICNotes following the instructions provided here: https://github.com/YyzHarry/SubpopBench/blob/main/MedicalData.md#mimicnotes

When I try to train a model on MIMICNotes, I get the following error when loading the features.npy file, triggered by this line:

self.x_array = np.load(os.path.join(data_path, "mimic_notes", 'features.npy'))

raise ValueError("Object arrays cannot be loaded when "
ValueError: Object arrays cannot be loaded when allow_pickle=False

I then added allow_pickle=True to the np.load() call above, which fixed this error. But then I get a different error from this line:

return self.x_array[int(x), :].astype('float32')

IndexError: too many indices for array: array is 0-dimensional, but 2 were indexed

Upon inspection, self.x_array does not appear to be a standard NumPy ndarray but rather a SciPy sparse matrix in Compressed Sparse Row (CSR) format.

Could you please advise how to correctly load and index this dataset?

Thanks!
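For reference, a minimal loading sketch that matches these symptoms (assuming features.npy stores a scipy.sparse.csr_matrix pickled via np.save; the data_path here is illustrative):

    import os
    import numpy as np

    data_path = "data"  # hypothetical root
    # np.save wraps non-ndarray objects in a 0-d object array, hence the
    # "array is 0-dimensional" IndexError above; .item() unwraps the CSR matrix.
    arr = np.load(os.path.join(data_path, "mimic_notes", "features.npy"),
                  allow_pickle=True)
    x_array = arr.item()
    # CSR rows can then be densified per sample before .astype('float32'):
    row = np.asarray(x_array[0].todense()).squeeze().astype('float32')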

Error Downloading CelebA Dataset Due to Access Restrictions on Google Drive

While attempting to download the CelebA dataset using the provided script (subpopbench.scripts.download), I encountered an "access denied" error from Google Drive, followed by a FileNotFoundError. This prevents the download script from completing.

Steps to Reproduce

  1. Activate the subpop_bench environment.
  2. Run the dataset download script with the command:
python -m subpopbench.scripts.download --data_path data --download celeba waterbirds civilcomments multinli imagenetbg metashift nico++ breeds cmnist

Error

The script begins the download process for the CelebA dataset but fails with the following error message:

INFO:root:Downloading CelebA...

Access denied with the following error:

    Cannot retrieve the public link of the file. You may need to change
    the permission to 'Anyone with the link', or have had many accesses. 

You may still be able to access the file from the browser:

     https://drive.google.com/uc?id=1mb1R6dXfWbvk3DnlWOBO8pDeoBKOcLE6 

... [Followed by a FileNotFoundError for img_align_celeba.zip]

I suspect the Google Drive link for the CelebA dataset may have been switched to restricted access, or has exceeded its download quota?
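As a possible workaround (untested; the output path is illustrative, and I am assuming gdown is the underlying downloader), the file can be retried directly, or img_align_celeba.zip downloaded manually in a browser from the URL above:

    import gdown

    # Retry the same public Drive file outside the download script; if the
    # quota error persists, fall back to a manual browser download and place
    # the zip where the script expects it.
    url = "https://drive.google.com/uc?id=1mb1R6dXfWbvk3DnlWOBO8pDeoBKOcLE6"
    gdown.download(url, output="data/celeba/img_align_celeba.zip", quiet=False)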

Unable to reproduce results

I installed the conda environment using the provided environment.yml. I am using RedHat OS.

Environment:
    Python: 3.9.7
    PyTorch: 1.13.0+cu117
    Torchvision: 0.14.0+cu117
    CUDA: 11.7
    CUDNN: 8500
    NumPy: 1.19.5
    PIL: 10.0.0

When I run ERM on CelebA without training attributes (CelebA_ERM_attrNo) using hparams_seed=0 and seeds={0,1,2}, I get the following results, which differ substantially from those reported in the paper.
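For completeness, each run was launched with an invocation of roughly this shape (argument names follow the args dump in the IRM issue further below; treat the exact flags as illustrative):

    python -m subpopbench.train \
           --algorithm ERM \
           --dataset CelebA \
           --train_attr no \
           --hparams_seed 0 \
           --seed 0        # and likewise for seeds 1 and 2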

Total records: [93]

-------- Dataset: CelebA, model selection method: test set worst accuracy (oracle)
Algorithm     Avg           Worst         AvgPrec       WorstPrec     AvgF1         WorstF1       Adjusted      Balanced      AUROC         ECE
ERM           94.0 +/- 0.2  67.6 +/- 2.4  85.0 +/- 0.5  71.3 +/- 1.1  88.4 +/- 0.2  80.4 +/- 0.4  87.7 +/- 0.6  93.2 +/- 0.3  98.1 +/- 0.1  4.5 +/- 0.2

-------- Worst-case accuracy, model selection method: test set worst accuracy (oracle)
Algorithm     CelebA        Avg
ERM           67.6 +/- 2.4  67.6

-------- Dataset: CelebA, model selection method: validation set worst accuracy (with attributes)
Algorithm     Avg           Worst         AvgPrec       WorstPrec     AvgF1         WorstF1       Adjusted      Balanced      AUROC         ECE
ERM           94.0 +/- 0.2  67.6 +/- 2.4  85.0 +/- 0.5  71.3 +/- 1.1  88.4 +/- 0.2  80.4 +/- 0.4  87.7 +/- 0.6  93.2 +/- 0.3  98.1 +/- 0.1  4.5 +/- 0.2

-------- Worst-case accuracy, model selection method: validation set worst accuracy (with attributes)
Algorithm     CelebA        Avg
ERM           67.6 +/- 2.4  67.6

-------- Dataset: CelebA, model selection method: validation set worst accuracy (without attributes)
Algorithm     Avg           Worst         AvgPrec       WorstPrec     AvgF1         WorstF1       Adjusted      Balanced      AUROC         ECE
ERM           94.0 +/- 0.2  67.6 +/- 2.4  85.0 +/- 0.5  71.3 +/- 1.1  88.4 +/- 0.2  80.4 +/- 0.4  87.7 +/- 0.6  93.2 +/- 0.3  98.1 +/- 0.1  4.5 +/- 0.2

-------- Worst-case accuracy, model selection method: validation set worst accuracy (without attributes)
Algorithm     CelebA        Avg
ERM           67.6 +/- 2.4  67.6


I am likewise unable to reproduce Waterbirds_ERM_attrNo or the results on the other datasets. I did not modify any code in subpopbench.

Appreciate any help in reproducing the results!

Question about sweep.py

Based on the README, the command to launch a sweep with the best hyperparameters fixed, over different seeds (with unknown attributes), is:

python -m subpopbench.sweep launch \
       --algorithms <...> \
       --dataset <...> \
       --train_attr no \
       --best_hp \
       --input_folder <...> \
       --n_trials <num_of_trials>

May I know how the best hparams are chosen in this step? Is it using ValWorstAccAttributeYes, as shown in the script below? Should we modify the selection_method in this script based on our experimental setup (ValWorstAccAttributeYes vs. ValWorstAccAttributeNo)?

selection_method = model_selection.ValWorstAccAttributeYes
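For concreteness, the modification I am asking about would presumably be the one-line change below (whether this is the intended usage is exactly my question):

    # hypothetical edit for the attributes-unknown (--train_attr no) setup:
    selection_method = model_selection.ValWorstAccAttributeNo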

Thanks

Early stopping using min_class:accuracy

Can I check if this is the right way to modify early stopping to use --es_metric min_class:accuracy, i.e. worst-class accuracy?

in eval_helper.py

res['min_class'] = {'accuracy': min([res['per_class'][y]['recall'] for y in np.unique(targets)])}
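For intuition, here is a self-contained check (toy data only) that per-class recall equals accuracy restricted to that class, so its minimum over classes is the worst-class accuracy:

    import numpy as np
    from sklearn.metrics import recall_score

    targets = np.array([0, 0, 1, 1, 1])
    preds = np.array([0, 1, 1, 1, 0])

    # Recall for class c counts correct predictions among examples whose true
    # label is c, i.e. the accuracy restricted to class c.
    per_class_recall = recall_score(targets, preds, average=None)
    print(per_class_recall)        # [0.5, 0.667] (approx.)
    print(per_class_recall.min())  # 0.5 -> the min_class:accuracy value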

IRM for CheXpertNoFinding generates NaN output

After a few training steps, the algorithm produces NaN outputs.

More specifically, in the "_compute_loss()" function in "algorithm.py", "self.network(x)" generates NaN values, even though neither the parameters of "self.network" nor "x" contain any NaNs.
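As a debugging aid (generic PyTorch, nothing SubpopBench-specific), forward hooks can locate the first submodule that emits non-finite values, e.g. by calling add_nan_hooks(self.network) before training:

    import torch

    def add_nan_hooks(network):
        # Raise on the first submodule whose forward output contains NaN/Inf.
        def make_hook(name):
            def hook(module, inputs, output):
                if isinstance(output, torch.Tensor) and not torch.isfinite(output).all():
                    raise RuntimeError(f"non-finite output in submodule {name!r}")
            return hook
        for name, module in network.named_modules():
            module.register_forward_hook(make_hook(name))

(Possibly relevant: irm_penalty_anneal_iters=500 in the HParams below is when DomainBed-style IRM implementations jump the penalty weight up to irm_lambda=100.0, which is a common source of loss blow-ups.)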

GPU: Nvidia A100

	Python: 3.9.7
	PyTorch: 1.13.0+cu117
	Torchvision: 0.14.0+cu117
	CUDA: 11.7
	CUDNN: 8500
	NumPy: 1.19.5
	PIL: 9.5.0
Args:
	algorithm: IRM
	checkpoint_freq: None
	cmnist_attr_prob: 0.5
	cmnist_flip_prob: 0.25
	cmnist_label_prob: 0.5
	cmnist_spur_prob: 0.2
	data_dir: /root/data
	dataset: CheXpertNoFinding
	es_metric: min_group:accuracy
	es_patience: 5
	es_strategy: metric
	hparams: None
	hparams_seed: 0
	image_arch: resnet_sup_in1k
	output_dir: ./output/test_attrNo/CheXpertNoFinding_IRM_hparams0_seed0
	output_folder_name: test_attrNo
	pretrained: 
	resume: 
	seed: 0
	skip_model_save: False
	stage1_algo: ERM
	stage1_folder: vanilla
	steps: None
	store_name: CheXpertNoFinding_IRM_hparams0_seed0
	tb_log_all: False
	text_arch: bert-base-uncased
	train_attr: no
	use_es: False
HParams:
	batch_size: 108
	group_balanced: False
	image_arch: resnet_sup_in1k
	irm_lambda: 100.0
	irm_penalty_anneal_iters: 500
	last_layer_dropout: 0.0
	lr: 0.001
	nonlinear_classifier: False
	optimizer: sgd
	pretrained: True
	resnet18: False
	text_arch: bert-base-uncased
	weight_decay: 0.0001
Dataset:
	[train]	167093 (without attributes)
	[val]	22280
	[test]	33419
/root/anaconda3/envs/subpop/lib/python3.9/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
  warnings.warn(
/root/anaconda3/envs/subpop/lib/python3.9/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet50_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet50_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)


step          epoch         loss          te_avg_acc    te_worst_acc  va_avg_acc    va_worst_acc 
0             0.0000        0.5206        0.9046        0.0000        0.8984        0.0000       
Traceback (most recent call last):
  File "/root/anaconda3/envs/subpop/lib/python3.9/runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/root/anaconda3/envs/subpop/lib/python3.9/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/root/workspace/SubpopBench/subpopbench/train.py", line 252, in <module>
    curr_metrics = {split: eval_helper.eval_metrics(algorithm, loader, device)
  File "/root/workspace/SubpopBench/subpopbench/train.py", line 252, in <dictcomp>
    curr_metrics = {split: eval_helper.eval_metrics(algorithm, loader, device)
  File "/root/workspace/SubpopBench/subpopbench/utils/eval_helper.py", line 41, in eval_metrics
    **prob_metrics(targets, preds, label_set)
  File "/root/workspace/SubpopBench/subpopbench/utils/eval_helper.py", line 128, in prob_metrics
    'AUROC_ovo': roc_auc_score(targets, preds, multi_class='ovo', labels=label_set),
  File "/root/anaconda3/envs/subpop/lib/python3.9/site-packages/sklearn/utils/validation.py", line 63, in inner_f
    return f(*args, **kwargs)
  File "/root/anaconda3/envs/subpop/lib/python3.9/site-packages/sklearn/metrics/_ranking.py", line 524, in roc_auc_score
    y_score = check_array(y_score, ensure_2d=False)
  File "/root/anaconda3/envs/subpop/lib/python3.9/site-packages/sklearn/utils/validation.py", line 63, in inner_f
    return f(*args, **kwargs)
  File "/root/anaconda3/envs/subpop/lib/python3.9/site-packages/sklearn/utils/validation.py", line 663, in check_array
    _assert_all_finite(array,
  File "/root/anaconda3/envs/subpop/lib/python3.9/site-packages/sklearn/utils/validation.py", line 103, in _assert_all_finite
    raise ValueError(
ValueError: Input contains NaN, infinity or a value too large for dtype('float32').

Using NICO benchmark

Howdy! I am interested in using the NICO benchmark in my own codebase, but I am unsure how to write a dataset class for it. Your code appears to rely on a metadata.csv, but when I download the dataset I don't see such a file. Any help would be greatly appreciated!
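In case it helps, here is a hedged sketch of how I imagine such a metadata.csv could be generated by walking the downloaded images (the <context>/<class> directory layout and the column schema are my assumptions, not necessarily SubpopBench's actual format):

    import csv
    from pathlib import Path

    root = Path("data/nicopp")  # hypothetical download location
    rows = []
    # Assumed layout: <root>/<context>/<class>/<image>.jpg, where the class
    # is the label y and the context is the spurious attribute a.
    for img in sorted(root.glob("*/*/*.jpg")):
        context, cls = img.parts[-3], img.parts[-2]
        rows.append([str(img.relative_to(root)), cls, context])

    with open(root / "metadata.csv", "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["filename", "y", "a"])  # assumed column names
        writer.writerows(rows)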
