
victor-qtp / harmofl


This project forked from med-air/harmofl


[AAAI'22] HarmoFL: Harmonizing Local and Global Drifts in Federated Learning on Heterogeneous Medical Images



HarmoFL: Harmonizing Local and Global Drifts in Federated Learning on Heterogeneous Medical Images

This is the PyTorch implementation of our paper HarmoFL: Harmonizing Local and Global Drifts in Federated Learning on Heterogeneous Medical Images by Meirui Jiang, Zirui Wang and Qi Dou.

Abstract

Multiple medical institutions collaboratively training a model using federated learning (FL) has become a promising solution for maximizing the potential of data-driven models, yet the non-independent and identically distributed (non-iid) data in medical images is still an outstanding challenge in real-world practice. The feature heterogeneity caused by diverse scanners or sensors introduces a drift in the learning process, in both local (client) and global (server) optimizations, which harms the convergence as well as model performance. Many previous works have attempted to address the non-iid issue by tackling the drift locally or globally, but how to jointly solve the two essentially coupled drifts is still unclear. In this work, we concentrate on handling both local and global drifts and introduce a new harmonizing framework called HarmoFL. First, we propose to mitigate the local update drift by normalizing amplitudes of images transformed into the frequency domain to mimic a unified scanner/sensor, in order to generate a harmonized feature space across local clients. Second, based on harmonized features, we design a client weight perturbation guiding each local model to reach a flat optimum, where a neighborhood area of the local optimal solution has a uniformly low loss. Without any extra communication cost, the perturbation assists the global model to optimize towards a converged optimal solution by aggregating several local flat optima. We have theoretically analyzed the proposed method and empirically conducted extensive experiments on three medical image classification and segmentation tasks, showing that HarmoFL outperforms a set of recent state-of-the-art methods with promising convergence behavior.
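To make the first step more concrete, here is a minimal NumPy sketch of frequency-domain amplitude normalization under one simple reading of the abstract: each image is transformed with a 2-D FFT, its amplitude spectrum is replaced by the mean amplitude over a batch of images (assumed here to stand in for one client's local data), its phase is kept, and the result is transformed back. The function name `amplitude_normalize` is hypothetical, and the repository's Cython routine may differ in detail, for example in which frequency band is normalized or how the reference amplitude is computed.

```python
import numpy as np


def amplitude_normalize(images: np.ndarray) -> np.ndarray:
    """Sketch of amplitude normalization for a batch of real images of shape (N, H, W).

    Assumption: the shared reference amplitude is the mean amplitude over the batch,
    standing in for a single "unified scanner" within one client.
    """
    spectra = np.fft.fft2(images, axes=(-2, -1))
    amplitude = np.abs(spectra)   # per-image amplitude spectrum
    phase = np.angle(spectra)     # per-image phase spectrum (kept unchanged)

    # One shared amplitude spectrum for every image in the batch.
    mean_amplitude = amplitude.mean(axis=0, keepdims=True)

    # Recombine the shared amplitude with each image's own phase.
    harmonized = mean_amplitude * np.exp(1j * phase)

    # For real inputs the recombined spectrum stays conjugate-symmetric,
    # so the inverse FFT is real up to floating-point error.
    return np.real(np.fft.ifft2(harmonized, axes=(-2, -1)))
```

The phase is commonly understood to carry much of an image's structural content, so sharing one amplitude spectrum while keeping each image's phase mainly changes intensity and texture statistics, which matches the abstract's goal of mimicking a unified scanner or sensor.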


Usage

Setup

Conda

We recommend using conda to set up the environment. See environment.yaml for the environment configuration.

If conda is not installed on your PC, please download an installer from https://www.anaconda.com/products/individual.

If you have already installed conda, run the following commands:

conda env create -f environment.yaml
conda activate harmofl

Build cython file

Build the Cython extension used for amplitude normalization:

python utils/setup.py build_ext --inplace

Dataset & Trained Model

Classification

  • Please download the histology breast cancer classification datasets here, extract them, and put the folder 'patches' under the data/camelyon17 directory.

Segmentation

  • Please download the prostate MRI datasets here and put the folder data under the data/prostate directory.
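Before training, it can help to confirm that the datasets landed where the instructions above expect them. The following is a small, hypothetical helper (not part of the repository) that checks for the two directories named above, relative to the repository root.

```python
from pathlib import Path
from typing import List

# Directories the dataset instructions ask for, relative to the repository root.
EXPECTED_DIRS = ["data/camelyon17/patches", "data/prostate/data"]


def missing_dataset_dirs(repo_root: str = ".") -> List[str]:
    """Return the expected dataset directories that do not exist under repo_root."""
    root = Path(repo_root)
    return [d for d in EXPECTED_DIRS if not (root / d).is_dir()]


if __name__ == "__main__":
    missing = missing_dataset_dirs()
    if missing:
        print("Missing dataset directories:", ", ".join(missing))
    else:
        print("All dataset directories found.")
```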

Train

fed_train.py is the main file for running the federated experiments. Use the following command to train a model with the federated learning strategy:

bash train.sh
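fed_train.py is not reproduced here, but the general shape of one federated communication round can be sketched in plain Python: every client starts from the current global weights, runs a number of local epochs (the role of the --wk_iters option), and the server then averages the returned weights. This is a FedAvg-style sketch; the flat weight list, the hypothetical `local_epoch` callback, and weighting clients by their sample counts are all assumptions for illustration, not the repository's API.

```python
from typing import Callable, List, Sequence, Tuple

Weights = List[float]
# Hypothetical callback: one local epoch on a client's data, returning new weights.
LocalEpoch = Callable[[Weights, object], Weights]


def federated_round(
    global_weights: Weights,
    clients: Sequence[Tuple[int, object]],  # (number of local samples, client data)
    local_epoch: LocalEpoch,
    wk_iters: int = 1,
) -> Weights:
    """One FedAvg-style round: local training on every client, then weighted averaging."""
    client_weights: List[Weights] = []
    sizes: List[int] = []
    for num_samples, data in clients:
        w = list(global_weights)      # each client starts from the global model
        for _ in range(wk_iters):     # wk_iters local epochs per round
            w = local_epoch(w, data)
        client_weights.append(w)
        sizes.append(num_samples)

    total = sum(sizes)
    # Server step: average every parameter, weighting clients by sample count.
    return [
        sum(n * w[i] for n, w in zip(sizes, client_weights)) / total
        for i in range(len(global_weights))
    ]
```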

Below please find some useful options:

  • --alpha: specifies the degree of weight perturbation (default: 0.05).
  • --wk_iters: specifies the number of local update epochs (default: 1).
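The abstract describes the client weight perturbation only at a high level. The closest standard technique we know of is a sharpness-aware minimization (SAM) step, sketched below in plain Python under the assumption that --alpha acts as the perturbation radius: step a distance alpha along the normalized gradient, evaluate the gradient at that perturbed point, and apply it to the original weights. This is not necessarily how HarmoFL implements its perturbation; `perturbed_step` and `grad_fn` are hypothetical names.

```python
import math
from typing import Callable, List

Weights = List[float]


def perturbed_step(
    w: Weights,
    grad_fn: Callable[[Weights], Weights],
    lr: float,
    alpha: float = 0.05,  # assumed to play the role of --alpha
) -> Weights:
    """One SAM-style update: perturb the weights uphill by radius alpha, then descend."""
    g = grad_fn(w)
    norm = math.sqrt(sum(gi * gi for gi in g))
    if norm == 0.0:
        # Stationary point: there is no direction to perturb along.
        return [wi - lr * gi for wi, gi in zip(w, g)]
    # Move to the first-order worst-case neighbour within distance alpha.
    w_adv = [wi + alpha * gi / norm for wi, gi in zip(w, g)]
    # Use the gradient at the perturbed point to update the ORIGINAL weights.
    g_adv = grad_fn(w_adv)
    return [wi - lr * gi for wi, gi in zip(w, g_adv)]
```

Compared with a plain gradient step, the update uses the gradient from a slightly worse nearby point, which penalizes sharp minima; this is in the spirit of the abstract's goal of reaching a flat optimum whose neighbourhood has uniformly low loss.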

Test

Suppose the path to your test model is 'model/data/harmofl'. Then run:

bash test.sh

Contributors

meiruijiang
