yyzharry / subpopbench
[ICML 2023] Change is Hard: A Closer Look at Subpopulation Shift
Home Page: https://subpopbench.csail.mit.edu
License: MIT License
After a few learning steps, the algorithm is producing NaN outputs.
More specifically, in "algorithm.py", inside the "_compute_loss()" function, "self.network(x)" produces NaN values, even though neither the parameters of "self.network" nor "x" contain any NaNs.
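For what it's worth, here is the debugging sketch I've been using to localize the problem (my own code, not part of SubpopBench): it registers forward hooks that raise on the first submodule emitting a non-finite output.

```python
import torch

def install_nan_hooks(network: torch.nn.Module) -> None:
    """Raise on the first submodule whose forward output contains NaN/Inf."""
    def make_hook(name):
        def hook(module, inputs, output):
            if isinstance(output, torch.Tensor) and not torch.isfinite(output).all():
                raise RuntimeError(f"non-finite output first produced by module '{name}'")
        return hook
    for name, module in network.named_modules():
        module.register_forward_hook(make_hook(name))

# e.g. in _compute_loss(), before the forward pass:
#   install_nan_hooks(self.network)
#   out = self.network(x)
```

torch.autograd.set_detect_anomaly(True) can similarly flag the operation where NaNs first appear in the backward pass.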
GPU: Nvidia A100
Python: 3.9.7
PyTorch: 1.13.0+cu117
Torchvision: 0.14.0+cu117
CUDA: 11.7
CUDNN: 8500
NumPy: 1.19.5
PIL: 9.5.0
Args:
algorithm: IRM
checkpoint_freq: None
cmnist_attr_prob: 0.5
cmnist_flip_prob: 0.25
cmnist_label_prob: 0.5
cmnist_spur_prob: 0.2
data_dir: /root/data
dataset: CheXpertNoFinding
es_metric: min_group:accuracy
es_patience: 5
es_strategy: metric
hparams: None
hparams_seed: 0
image_arch: resnet_sup_in1k
output_dir: ./output/test_attrNo/CheXpertNoFinding_IRM_hparams0_seed0
output_folder_name: test_attrNo
pretrained:
resume:
seed: 0
skip_model_save: False
stage1_algo: ERM
stage1_folder: vanilla
steps: None
store_name: CheXpertNoFinding_IRM_hparams0_seed0
tb_log_all: False
text_arch: bert-base-uncased
train_attr: no
use_es: False
HParams:
batch_size: 108
group_balanced: False
image_arch: resnet_sup_in1k
irm_lambda: 100.0
irm_penalty_anneal_iters: 500
last_layer_dropout: 0.0
lr: 0.001
nonlinear_classifier: False
optimizer: sgd
pretrained: True
resnet18: False
text_arch: bert-base-uncased
weight_decay: 0.0001
Dataset:
[train] 167093 (without attributes)
[val] 22280
[test] 33419
/root/anaconda3/envs/subpop/lib/python3.9/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
warnings.warn(
/root/anaconda3/envs/subpop/lib/python3.9/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet50_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet50_Weights.DEFAULT` to get the most up-to-date weights.
warnings.warn(msg)
step epoch loss te_avg_acc te_worst_acc va_avg_acc va_worst_acc
0 0.0000 0.5206 0.9046 0.0000 0.8984 0.0000
Traceback (most recent call last):
File "/root/anaconda3/envs/subpop/lib/python3.9/runpy.py", line 197, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/root/anaconda3/envs/subpop/lib/python3.9/runpy.py", line 87, in _run_code
exec(code, run_globals)
File "/root/workspace/SubpopBench/subpopbench/train.py", line 252, in <module>
curr_metrics = {split: eval_helper.eval_metrics(algorithm, loader, device)
File "/root/workspace/SubpopBench/subpopbench/train.py", line 252, in <dictcomp>
curr_metrics = {split: eval_helper.eval_metrics(algorithm, loader, device)
File "/root/workspace/SubpopBench/subpopbench/utils/eval_helper.py", line 41, in eval_metrics
**prob_metrics(targets, preds, label_set)
File "/root/workspace/SubpopBench/subpopbench/utils/eval_helper.py", line 128, in prob_metrics
'AUROC_ovo': roc_auc_score(targets, preds, multi_class='ovo', labels=label_set),
File "/root/anaconda3/envs/subpop/lib/python3.9/site-packages/sklearn/utils/validation.py", line 63, in inner_f
return f(*args, **kwargs)
File "/root/anaconda3/envs/subpop/lib/python3.9/site-packages/sklearn/metrics/_ranking.py", line 524, in roc_auc_score
y_score = check_array(y_score, ensure_2d=False)
File "/root/anaconda3/envs/subpop/lib/python3.9/site-packages/sklearn/utils/validation.py", line 63, in inner_f
return f(*args, **kwargs)
File "/root/anaconda3/envs/subpop/lib/python3.9/site-packages/sklearn/utils/validation.py", line 663, in check_array
_assert_all_finite(array,
File "/root/anaconda3/envs/subpop/lib/python3.9/site-packages/sklearn/utils/validation.py", line 103, in _assert_all_finite
raise ValueError(
ValueError: Input contains NaN, infinity or a value too large for dtype('float32').
I successfully downloaded MIMICNotes following the instructions provided here: https://github.com/YyzHarry/SubpopBench/blob/main/MedicalData.md#mimicnotes
When I try to train a model on MIMICNotes, I get the following error when loading the features.npy file, caused by this line:
SubpopBench/subpopbench/dataset/datasets.py
Line 473 in 4d3dbbe
raise ValueError("Object arrays cannot be loaded when "
ValueError: Object arrays cannot be loaded when allow_pickle=False
I then added allow_pickle=True to the np.load() call above, which fixed this error. But then I get a different error from this line:
SubpopBench/subpopbench/dataset/datasets.py
Line 478 in 4d3dbbe
return self.x_array[int(x), :].astype('float32')
IndexError: too many indices for array: array is 0-dimensional, but 2 were indexed
Upon inspection, self.x_array is not a standard NumPy ndarray but a SciPy sparse matrix in Compressed Sparse Row (CSR) format.
Could you please advise how to correctly load and index this dataset?
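In case it's useful, here is the workaround I'm using for now — a sketch assuming features.npy stores a single pickled scipy.sparse.csr_matrix (the path and helper name are illustrative):

```python
import numpy as np
from scipy.sparse import csr_matrix

# np.save() wraps arbitrary Python objects (like a CSR matrix) in a
# 0-d object array -- hence the "array is 0-dimensional" IndexError.
# .item() unwraps the stored object.
x_array = np.load("features.npy", allow_pickle=True).item()
assert isinstance(x_array, csr_matrix)

def get_row(i: int) -> np.ndarray:
    # Slice one row of the sparse matrix, then densify just that row.
    return x_array[i].toarray().ravel().astype("float32")
```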
Thanks!
May I know the difference between CivilComments and CivilCommentsFine in datasets.py?
Which one should we use and which dataset is reported in the paper?
Thanks!
Can I check whether this is the right way to modify early stopping to use --es_metric min_class:accuracy, i.e., worst-class accuracy? In eval_helper.py:
res['min_class'] = {'accuracy': min([res['per_class'][y]['recall'] for y in np.unique(targets)])}
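My understanding is that per-class accuracy equals per-class recall, so taking the min over classes gives worst-class accuracy. A small standalone sanity check of that logic (my own sketch, independent of eval_helper.py):

```python
import numpy as np

def worst_class_accuracy(targets: np.ndarray, preds: np.ndarray) -> float:
    # Per-class accuracy is the recall of that class:
    # (# correct within class y) / (# samples with label y).
    return min(
        np.mean(preds[targets == y] == y) for y in np.unique(targets)
    )

targets = np.array([0, 0, 0, 1, 1])
preds   = np.array([0, 0, 1, 1, 0])
print(worst_class_accuracy(targets, preds))  # -> 0.5 (class 1: 1/2 correct)
```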
While attempting to download the CelebA dataset using the provided script (subpopbench.scripts.download), I encountered an access denied error from Google Drive followed by a FileNotFoundError. This issue prevents the successful execution of the dataset download script.
Steps to Reproduce
python -m subpopbench.scripts.download --data_path data --download celeba waterbirds civilcomments multinli imagenetbg metashift nico++ breeds cmnist
Error
The script begins the download process for the CelebA dataset but fails with the following error message:
INFO:root:Downloading CelebA...
Access denied with the following error:
Cannot retrieve the public link of the file. You may need to change
the permission to 'Anyone with the link', or have had many accesses.
You may still be able to access the file from the browser:
https://drive.google.com/uc?id=1mb1R6dXfWbvk3DnlWOBO8pDeoBKOcLE6
... [Followed by a FileNotFoundError for img_align_celeba.zip]
I suspect the Google Drive link for the CelebA dataset may have been switched to restricted access?
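In the meantime, the manual fallback I'm trying is to download img_align_celeba.zip (and the annotation files) by hand from the official CelebA project page and unpack it where the script would have put it. A sketch — the target path is my assumption, not something verified against subpopbench.scripts.download:

```python
import zipfile
from pathlib import Path

# Assumed layout: <data_path>/celeba/img_align_celeba/*.jpg
data_dir = Path("data") / "celeba"
data_dir.mkdir(parents=True, exist_ok=True)

# The official zip already contains a top-level img_align_celeba/ folder.
with zipfile.ZipFile("img_align_celeba.zip") as zf:
    zf.extractall(data_dir)

n_images = len(list((data_dir / "img_align_celeba").glob("*.jpg")))
print(n_images)  # CelebA should have 202,599 aligned images
```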
Howdy! I was interested in using the NICO++ benchmark in my own codebase, but I am unsure how to write a dataset class for it. In your code it looks like there is a metadata.csv, but when downloading the dataset I don't see such a file. Any help would be greatly appreciated!
I want to know which backbone featurizer is used for the results in Appendix E, or whether they are averaged over all the different architectures.
If it is the latter, could you share the results for the individual architectures?
Thanks a lot!
Based on the README, the command to launch a sweep that fixes the best hparams and varies the seed (with unknown attributes) is:
python -m subpopbench.sweep launch \
--algorithms <...> \
--dataset <...> \
--train_attr no \
--best_hp \
--input_folder <...> \
--n_trials <num_of_trials>
May I know how the best hparams are chosen in this step? Is it using ValWorstAccAttributeYes, as shown in the script below? Should we modify the selection_method in this script based on our experimental setup (ValWorstAccAttributeYes vs. ValWorstAccAttributeNo)?
SubpopBench/subpopbench/sweep.py
Line 81 in 4d3dbbe
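For concreteness, the change I have in mind is just swapping the selection class at the referenced line. A hypothetical sketch — only the two class names are taken from sweep.py; the surrounding code is elided:

```python
# subpopbench/sweep.py, around the line referenced above:
# selection_method = ValWorstAccAttributeYes   # validation attributes available
selection_method = ValWorstAccAttributeNo      # attributes unknown (--train_attr no)
```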
Thanks
I tried installing the modules at the versions pinned in the requirements file, under Python 3.9.19, and ended up with the following dependency conflicts:
astropy 5.3.4 requires numpy<2,>=1.21, but you have numpy 1.19.5 which is incompatible.
bokeh 3.4.0 requires contourpy>=1.2, but you have contourpy 1.1.1 which is incompatible.
dask-expr 1.1.0 requires pandas>=2, but you have pandas 1.3.2 which is incompatible.
imbalanced-learn 0.11.0 requires scikit-learn>=1.0.2, but you have scikit-learn 0.24.1 which is incompatible.
numba 0.59.1 requires numpy<1.27,>=1.22, but you have numpy 1.19.5 which is incompatible.
pywavelets 1.5.0 requires numpy<2.0,>=1.22.4, but you have numpy 1.19.5 which is incompatible.
scikit-image 0.22.0 requires numpy>=1.22, but you have numpy 1.19.5 which is incompatible.
scikit-image 0.22.0 requires scipy>=1.8, but you have scipy 1.7.0 which is incompatible.
xarray 2023.6.0 requires numpy>=1.21, but you have numpy 1.19.5 which is incompatible.
xarray 2023.6.0 requires pandas>=1.4, but you have pandas 1.3.2 which is incompatible.
I was able to resolve this by upgrading scipy to 1.8.0 in the requirements file and upgrading numpy as well.
Is it a Python version issue that the old requirements file did not work for me? Which Python version did the authors use?
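For reference, intersecting the version constraints printed above gives a consistent set of pins (derived solely from those error messages; not an official requirements update):

```
contourpy>=1.2          # bokeh 3.4.0
numpy>=1.22.4,<1.27     # pywavelets wants >=1.22.4; numba wants <1.27
pandas>=2               # dask-expr 1.1.0 (also covers xarray's >=1.4)
scikit-learn>=1.0.2     # imbalanced-learn 0.11.0
scipy>=1.8              # scikit-image 0.22.0
```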
FYI: the "--" prefix is missing from the datasets argument in download.py, which causes issues when passing arguments to this script.
parser.add_argument('datasets', nargs='+', type=str, default=[
change to
parser.add_argument('--datasets', nargs='+', type=str, default=[
SubpopBench/subpopbench/scripts/download.py
Line 691 in 4d3dbbe
I installed the conda environment using the provided environment.yml. I am using RedHat OS.
Environment:
Python: 3.9.7
PyTorch: 1.13.0+cu117
Torchvision: 0.14.0+cu117
CUDA: 11.7
CUDNN: 8500
NumPy: 1.19.5
PIL: 10.0.0
When I run ERM on CelebA without training attributes (CelebA_ERM_attrNo) using hparams_seed=0 and seeds={0,1,2}, I get the following results, which are quite different from the paper.
Total records: [93]
-------- Dataset: CelebA, model selection method: test set worst accuracy (oracle)
Algorithm Avg Worst AvgPrec WorstPrec AvgF1 WorstF1 Adjusted Balanced AUROC ECE
ERM 94.0 +/- 0.2 67.6 +/- 2.4 85.0 +/- 0.5 71.3 +/- 1.1 88.4 +/- 0.2 80.4 +/- 0.4 87.7 +/- 0.6 93.2 +/- 0.3 98.1 +/- 0.1 4.5 +/- 0.2
-------- Worst-case accuracy, model selection method: test set worst accuracy (oracle)
Algorithm CelebA Avg
ERM 67.6 +/- 2.4 67.6
-------- Dataset: CelebA, model selection method: validation set worst accuracy (with attributes)
Algorithm Avg Worst AvgPrec WorstPrec AvgF1 WorstF1 Adjusted Balanced AUROC ECE
ERM 94.0 +/- 0.2 67.6 +/- 2.4 85.0 +/- 0.5 71.3 +/- 1.1 88.4 +/- 0.2 80.4 +/- 0.4 87.7 +/- 0.6 93.2 +/- 0.3 98.1 +/- 0.1 4.5 +/- 0.2
-------- Worst-case accuracy, model selection method: validation set worst accuracy (with attributes)
Algorithm CelebA Avg
ERM 67.6 +/- 2.4 67.6
-------- Dataset: CelebA, model selection method: validation set worst accuracy (without attributes)
Algorithm Avg Worst AvgPrec WorstPrec AvgF1 WorstF1 Adjusted Balanced AUROC ECE
ERM 94.0 +/- 0.2 67.6 +/- 2.4 85.0 +/- 0.5 71.3 +/- 1.1 88.4 +/- 0.2 80.4 +/- 0.4 87.7 +/- 0.6 93.2 +/- 0.3 98.1 +/- 0.1 4.5 +/- 0.2
-------- Worst-case accuracy, model selection method: validation set worst accuracy (without attributes)
Algorithm CelebA Avg
ERM 67.6 +/- 2.4 67.6
I am also unable to reproduce Waterbirds_ERM_attrNo or the results on the other datasets. I did not modify the code in subpopbench.
I'd appreciate any help in reproducing the results!