
rubencart / sssl-food-security


Code and data export scripts accompanying the article "Spatiotemporal self-supervised pre-training on satellite imagery improves food insecurity prediction"

Home Page: https://www.cambridge.org/core/journals/environmental-data-science/article/spatiotemporal-selfsupervised-pretraining-on-satellite-imagery-improves-food-insecurity-prediction/47FDCFF96FF9A99D31548C1539D506A5

License: MIT License

Python 97.76% JavaScript 2.24%


Spatiotemporal self-supervised pre-training on satellite imagery improves food insecurity prediction

Code and data export scripts accompanying our article, published in Environmental Data Science.

We used code from Patacchiola's Relational Reasoning repository.

Set up

Dependencies

  • pytorch
  • pytorch-lightning
  • typed-argument-parser
  • tqdm
  • rasterio 1.3
  • geopandas 0.8.1
  • pyyaml
  • wandb
  • h5py
  • scikit-learn
  • matplotlib
  • seaborn
  • shap

The commands below create a conda environment with these dependencies:

conda create -n ssslenv python=3.10 ipython
conda activate ssslenv
conda install pytorch==1.13.0 torchvision==0.14.0 torchaudio==0.13.0 pytorch-cuda=11.7 -c pytorch -c nvidia
conda install -c conda-forge rasterio=1.3.0.post1 gdal=3.5.1 poppler=22.04.0 pytorch-lightning=1.8.6
# if error `ImportError: libLerc.so.4: cannot open shared object file: No such file or directory` when `import rasterio`:
# conda install -c conda-forge lerc=4.0.0
pip install geopandas tqdm wandb scikit-learn typed-argument-parser matplotlib seaborn shap
conda install h5py -c conda-forge
# https://github.com/ContinuumIO/anaconda-issues/issues/10351
# conda install "poppler<0.62"
# conda install -c conda-forge poppler=21.09 gdal=3.3.3
# conda install -c conda-forge poppler=22.04 gdal=3.5.1 rasterio=1.3 --force-reinstall

For reference, the relevant package versions as listed by conda list:

rasterio                  1.3.0.post1     py310h1bedc6d_0    conda-forge
geopandas                 0.8.1              pyhd3eb1b0_0  
geopandas-base            0.11.1             pyha770c72_0    conda-forge
gdal                      3.5.1           py310hb7951cf_2    conda-forge
poppler                   22.04.0              h1434ded_1    conda-forge
poppler-data              0.4.11               hd8ed1ab_0    conda-forge
pytorch                   1.13.0          py3.10_cuda11.7_cudnn8.5.0_0    pytorch
pytorch-cuda              11.7                 h67b0de4_0    pytorch

Finally (!), install the sssl package as a module so the scripts can import it. Run from the repository root directory:

pip install -e .

Data

  • Start date: '2013-05-01'
  • End date 3-month intervals: '2015-11-01'
  • End date 4-month intervals: '2020-03-01'
  • Total months: 7y * 12m - 2m = 82m
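
The month count above can be checked with a few lines of Python. This is only a sanity check on the arithmetic; the helper name months_between is ours and does not come from the repository.

```python
from datetime import date


def months_between(start: date, end: date) -> int:
    """Number of whole months from `start` to `end` (both on the 1st of a month)."""
    return (end.year - start.year) * 12 + (end.month - start.month)


start = date(2013, 5, 1)  # start date of the GEE images
end = date(2020, 3, 1)    # end date of the 4-month intervals

total_months = months_between(start, end)
print(total_months)  # 82, i.e. 7 * 12 - 2
```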

We took 2013-05-01 as the start date for the GEE images, which means the first composite runs from 2013-05-01 until 2013-08-01, and from it we predict the IPC score corresponding to 2013-08-01. As a result we can't predict a score for 2013-05-01, even though the .csv file contains a score for this date. To include this date, we would have had to use 2013-02-01 as the start date in GEE. We exclude this date from the data. No change is needed for this in the neural net code, since no exported tile corresponds to this date, but code that uses the .csv directly, e.g. the random forest code, does need to filter it out.
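
To make the date logic concrete, here is a minimal sketch that enumerates the composite windows and the dates for which a score can be predicted. It assumes our reading of the list above: composites span 3 months until 2015-11-01 and 4 months from then until 2020-03-01, and each composite predicts the score at its end date. The function names and the csv_dates list are hypothetical and not taken from the repository.

```python
from datetime import date


def add_months(d: date, months: int) -> date:
    """Shift a first-of-month date by a number of months."""
    index = d.year * 12 + (d.month - 1) + months
    return date(index // 12, index % 12 + 1, 1)


def composite_windows(start: date, end_3m: date, end_4m: date):
    """List (window_start, window_end) pairs covering start..end_4m.

    ASSUMPTION: windows are 3 months long while they start before `end_3m`,
    and 4 months long afterwards.
    """
    windows = []
    current = start
    while current < end_4m:
        step = 3 if current < end_3m else 4
        window_end = add_months(current, step)
        windows.append((current, window_end))
        current = window_end
    return windows


start = date(2013, 5, 1)
windows = composite_windows(start, date(2015, 11, 1), date(2020, 3, 1))

# Each composite predicts the IPC score at the END of its window.
prediction_dates = {window_end for _, window_end in windows}

# Hypothetical list of score dates as they might appear in the .csv.
csv_dates = [start] + sorted(prediction_dates)
usable_dates = [d for d in csv_dates if d in prediction_dates]

print(windows[0])             # first composite: 2013-05-01 until 2013-08-01
print(start in usable_dates)  # False: no composite ends on the start date
```

Code that reads the .csv directly would apply the same kind of filter before training.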

Download & Preprocess

Make sure you ran pip install -e . first!

IPC scores:

  • Extract data/predicting_food_crises_data.zip
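
Extraction can also be scripted, for instance when setting up the data on a remote server. The sketch below uses only the Python standard library. Extracting next to the archive is an assumption on our part, so adjust the target directory if the scripts expect a different location. The function name extract_archive is hypothetical.

```python
import zipfile
from pathlib import Path


def extract_archive(archive: Path, target_dir: Path) -> list[str]:
    """Extract `archive` into `target_dir` and return the member names."""
    target_dir.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(archive) as zf:
        zf.extractall(target_dir)
        return zf.namelist()


if __name__ == "__main__":
    archive = Path("data/predicting_food_crises_data.zip")
    if archive.exists():
        # ASSUMPTION: extract into the directory that holds the archive.
        names = extract_archive(archive, archive.parent)
        print(f"Extracted {len(names)} files into {archive.parent}")
    else:
        print(f"{archive} not found; run this from the repository root")
```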

If running on a different country/region/timeframe than Somalia 2013-2020:

  • Adjust and run scripts/preprocess/preprocess_ipc_csv.py
  • Do the set-up under 'Tiles' below
  • Run scripts/preprocess/ipc_class_weights.py to recompute the relative weights of the IPC classes, and update them in utils.Constants
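
For intuition on what relative class weights look like, the sketch below computes inverse-frequency weights from a list of class labels. This is a common weighting scheme that we use purely for illustration; scripts/preprocess/ipc_class_weights.py may compute its weights differently, so use that script for actual runs. The function name and the toy labels are made up.

```python
from collections import Counter


def inverse_frequency_weights(labels):
    """Weight each class by n_samples / (n_classes * class_count).

    ASSUMPTION: inverse-frequency weighting, chosen for illustration only.
    """
    counts = Counter(labels)
    n_samples = len(labels)
    n_classes = len(counts)
    return {
        cls: n_samples / (n_classes * count)
        for cls, count in sorted(counts.items())
    }


# Toy labels, made up for illustration (not real IPC data).
toy_labels = [1, 1, 1, 1, 2, 2, 3, 3]
weights = inverse_frequency_weights(toy_labels)
for cls, weight in weights.items():
    print(cls, round(weight, 3))
```

Rare classes receive larger weights, so that each class contributes equally to a weighted loss in total.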

Tiles:

  • Download data from GEE with code in scripts/earth_engine/js/export_somalia.js
  • Download the exported tiles from Google Cloud to your local server
  • Run scripts/preprocess/build_indices.py. The indices (and output of the next 2 steps) used for results in the publication are included in data/indices.zip. You can hence skip this and the subsequent 2 items by extracting that zip file.
  • Check for tiles that don't have enough positives with scripts/preprocess/search_tiles_not_enough_neighbors.py
  • Run scripts/preprocess/pseudo_random_order.py so tiles in val set for pretraining always have the same positives.
  • Run scripts/preprocess/make_h5.py (optional but much faster), use cfg.use_h5 = True in subsequent runs
  • Run scripts/preprocess/channel_mean_std.py if using different data than Somalia Landsat 8 2013-2020, and update the means/stds in utils.Constants
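
The last step recomputes per-channel normalization statistics. As a rough illustration of the idea, and not a copy of scripts/preprocess/channel_mean_std.py, the sketch below accumulates sums and squared sums tile by tile, so that all tiles never have to be in memory at once. It uses plain Python lists to stay dependency-free; with real rasters you would work on arrays instead. The function name and the toy tiles are hypothetical.

```python
import math


def channel_mean_std(tiles):
    """Per-channel mean and population std over an iterable of tiles.

    Each tile is a list of channels; each channel is a flat list of pixels.
    """
    sums, sq_sums, counts = None, None, None
    for tile in tiles:
        if sums is None:
            sums = [0.0] * len(tile)
            sq_sums = [0.0] * len(tile)
            counts = [0] * len(tile)
        for c, channel in enumerate(tile):
            sums[c] += sum(channel)
            sq_sums[c] += sum(v * v for v in channel)
            counts[c] += len(channel)
    if sums is None:
        raise ValueError("no tiles given")
    means = [s / n for s, n in zip(sums, counts)]
    # Var[x] = E[x^2] - E[x]^2; clamp at 0 to guard against rounding error.
    stds = [
        math.sqrt(max(sq / n - m * m, 0.0))
        for sq, n, m in zip(sq_sums, counts, means)
    ]
    return means, stds


# Two toy tiles with two channels of two pixels each (made-up values).
toy_tiles = [
    [[1.0, 2.0], [10.0, 10.0]],
    [[3.0, 4.0], [10.0, 10.0]],
]
means, stds = channel_mean_std(toy_tiles)
print(means)
print(stds)
```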

Config files:

  • Run config/generate_pretrain_configs.py to generate pretrain run config files, to run with python code/pretrain.py --cfg config/pretrain/<config>.yaml
  • Run config/generate_downstr_configs.py to generate downstream IPC run config files, to run with python code/finetune.py --cfg config/ipc/<config>.yaml
  • Run config/generate_comb_configs.py to generate combined pretrain-then-finetune run config files, to run with python code/pretrain_then_finetune.py --cfg config/pretrain_ipc/<config>.yaml

Run

Pretrain:

CUDA_VISIBLE_DEVICES=0 python code/pretrain.py --cfg config/pretrain/debug.yaml

Finetune:

CUDA_VISIBLE_DEVICES=0 python code/finetune.py --cfg config/ipc/debug.yaml

Pretrain then finetune:

CUDA_VISIBLE_DEVICES=0 python code/pretrain_then_finetune.py --cfg config/pretrain_ipc/debug.yaml --seed 41
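
To repeat the last command for several seeds, a small launcher like the one below may help. It is a sketch that relies only on the --cfg and --seed flags shown above; the seed values 42 and 43, the function names and the dry_run switch are our own additions. With dry_run=True it only prints the commands.

```python
import os
import subprocess
import sys


def build_command(script, cfg, seed):
    """Build the argument list for one run."""
    return [sys.executable, script, "--cfg", cfg, "--seed", str(seed)]


def run_seeds(script, cfg, seeds, gpu="0", dry_run=True):
    """Run (or just print) one command per seed on the given GPU."""
    env = {**os.environ, "CUDA_VISIBLE_DEVICES": gpu}
    commands = [build_command(script, cfg, seed) for seed in seeds]
    for cmd in commands:
        print(" ".join(cmd))
        if not dry_run:
            # check=True stops the sweep as soon as one run fails.
            subprocess.run(cmd, env=env, check=True)
    return commands


commands = run_seeds(
    "code/pretrain_then_finetune.py",
    "config/pretrain_ipc/debug.yaml",
    seeds=[41, 42, 43],  # 42 and 43 are arbitrary example seeds
)
```

Set dry_run=False to actually launch the runs one after another.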

Checkpoints

The file checkpoints/sssl_resnet18_t1_s04_all_bands.zip contains the best checkpoint of SSSL pretraining with temporal threshold set to 1 month and spatial threshold to 0.4 degrees.

from sssl.model.backbone_module import BackboneModule
path = "checkpoints/sssl_resnet18_t1_s04_all_bands.zip"
# Option 1: restore the module directly from the checkpoint
module = BackboneModule.load_from_checkpoint(path)
# Option 2: construct the module yourself and let the Trainer load the weights
from pytorch_lightning import Trainer
module = BackboneModule(...)
trainer = Trainer(...)
trainer.predict(module, ckpt_path=path)

The file checkpoints/ipc_finetuned_resnet18_t1_s04_all_bands.zip contains the best checkpoint of IPC finetuning, started from weights initialized to the SSSL checkpoint above.

from sssl.model.ipc_module import IPCModule
path = "checkpoints/ipc_finetuned_resnet18_t1_s04_all_bands.zip"
# Option 1: restore the module directly from the checkpoint
module = IPCModule.load_from_checkpoint(path)
# Option 2: construct the module yourself and let the Trainer load the weights
from pytorch_lightning import Trainer
module = IPCModule(...)
trainer = Trainer(...)
trainer.predict(module, ckpt_path=path)

Plots and results

  • Run the limit_plot function in scripts/results/loss_and_threshold_plots.py to generate plots like in Figure 5.
  • Run scripts/results/tile_ipc_plots.py to generate plots like in Figures 2, 3 and 4.
  • Run scripts/results/seasonality_plots.py to generate a plot like in Figure 8.
  • Run scripts/results/less_data_future_plots.py to generate plots like in Figures 6 and 7.
  • Run scripts/results/deeplift_shap_plots.py to generate SHAP value plots like in Figure 9.

Reference

If you use our work, please cite:

@article{cartuyvels2023spatiotemporal,
  title={Spatiotemporal self-supervised pre-training on satellite imagery improves food insecurity prediction},
  author={Cartuyvels, Ruben and Fierens, Tom and Coppieters, Emiel and Moens, Marie-Francine and Sileo, Damien},
  journal={Environmental Data Science},
  volume={2},
  pages={e48},
  year={2023},
  publisher={Cambridge University Press}
}
