
zhlinup / ris-fl


This project forked from liuhang1994/ris-fl


Simulation Codes for "Reconfigurable Intelligent Surface Enabled Federated Learning: A Unified Communication-Learning Design Approach"

License: MIT License

Python 100.00%


RIS-FL

This is the simulation code package for the following paper:

Hang Liu, Xiaojun Yuan, and Ying-Jun Angela Zhang. "Reconfigurable intelligent surface enabled federated learning: A unified communication-learning design approach," to appear in IEEE Transactions on Wireless Communications, 2020. [ArXiv Version]

The package, written in Python 3, reproduces the numerical results of the proposed algorithm in the above paper.

Abstract of Article:

To exploit massive amounts of data generated at mobile edge networks, federated learning (FL) has been proposed as an attractive substitute for centralized machine learning (ML). By collaboratively training a shared learning model at edge devices, FL avoids direct data transmission and thus overcomes high communication latency and privacy issues as compared to centralized ML. To improve the communication efficiency in FL model aggregation, over-the-air computation has been introduced to support a large number of simultaneous local model uploading by exploiting the inherent superposition property of wireless channels. However, due to the heterogeneity of communication capacities among edge devices, over-the-air FL suffers from the straggler issue in which the device with the weakest channel acts as a bottleneck of the model aggregation performance. This issue can be alleviated by device selection to some extent, but the latter still suffers from a tradeoff between data exploitation and model communication. In this paper, we leverage the reconfigurable intelligent surface (RIS) technology to relieve the straggler issue in over-the-air FL. Specifically, we develop a learning analysis framework to quantitatively characterize the impact of device selection and model aggregation error on the convergence of over-the-air FL. Then, we formulate a unified communication-learning optimization problem to jointly optimize device selection, over-the-air transceiver design, and RIS configuration. Numerical experiments show that the proposed design achieves substantial learning accuracy improvement compared with the state-of-the-art approaches, especially when channel conditions vary dramatically across edge devices.
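The over-the-air computation idea in the abstract can be illustrated with a small NumPy sketch. It is not the paper's transceiver design: the receive beamformer f and the RIS phases theta are drawn at random rather than optimized, every device is selected, and each device simply inverts its effective channel with no transmit-power limit. The helper names complex_gaussian(), effective_channels(), and over_the_air_mean() are hypothetical and are not taken from AirComp.py.

```python
import numpy as np


def complex_gaussian(rng, *shape):
    """i.i.d. CN(0, 1) samples (a Rayleigh-fading channel model, assumed here)."""
    return (rng.standard_normal(shape) + 1j * rng.standard_normal(shape)) / np.sqrt(2)


def effective_channels(h_d, G, h_r, theta):
    """Combine the direct and RIS-reflected paths.

    h_d: (N, M) direct device-to-receiver channels
    G: (N, L) RIS-to-receiver channel
    h_r: (L, M) device-to-RIS channels
    theta: (L,) RIS phase shifts in radians
    Column m of the result is h_d[:, m] + G @ diag(exp(1j*theta)) @ h_r[:, m].
    """
    return h_d + G @ (np.exp(1j * theta)[:, None] * h_r)


def over_the_air_mean(grads, H, f, noise_std, rng):
    """Estimate the mean of the rows of grads from one simultaneous transmission.

    grads: (M, d) real local gradients, one row per device
    H: (N, M) effective channels; f: (N,) receive beamformer
    """
    M, d = grads.shape
    N = H.shape[0]
    gains = f.conj() @ H                # (M,) scalar gain f^H h_m seen by each device
    b = 1.0 / gains                     # channel inversion: f^H h_m b_m = 1 (no power cap)
    tx = b[:, None] * grads             # (M, d) transmitted symbols
    y = H @ tx + noise_std * complex_gaussian(rng, N, d)  # signals superpose in the air
    return np.real(f.conj() @ y) / M    # receive beamforming gives the sum; divide for the mean


rng = np.random.default_rng(1)
N, M, L, d = 5, 40, 40, 10              # N, M, L follow the default parameter values
H = effective_channels(complex_gaussian(rng, N, M),
                       complex_gaussian(rng, N, L),
                       complex_gaussian(rng, L, M),
                       rng.uniform(0.0, 2 * np.pi, L))
f = complex_gaussian(rng, N)
f /= np.linalg.norm(f)
grads = rng.standard_normal((M, d))
est = over_the_air_mean(grads, H, f, noise_std=0.0, rng=rng)
print(np.max(np.abs(est - grads.mean(axis=0))))  # rounding-level error without noise
```

With noise_std=0 the estimate matches the true mean up to rounding. Once each device has a power budget, the common scaling is limited by the device with the smallest |f^H h_m|, which is the straggler issue described above; device selection and the RIS phases are the knobs the paper optimizes to relieve it.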

Dependencies

This package is written in Python 3. It requires the following libraries:

  • Python >= 3.5
  • torch
  • torchvision
  • scipy
  • CUDA (if GPU is used)

How to Use

The main file is main.py. It accepts the following user-input parameters via an argument parser (see also the function initial() in main.py):

Parameter | Meaning | Default | Type/Range
M | total number of devices | 40 | int
N | total number of receive antennas | 5 | int
L | total number of RIS elements | 40 | int
nit | maximum number of iterations for Algorithm 1, I_max | 100 | int
Jmax | number of iterations for Gibbs sampling | 50 | int
threshold | threshold for early stopping in Algorithm 1 | 1e-2 | float
tau | SCA regularization term for Algorithm 1 | 1 | float
trial | total number of Monte Carlo trials | 50 | int
SNR | signal-to-noise ratio P_0/sigma^2_n, in dB | 90.0 | float
verbose | output level while running the scripts: none (0), important messages (1), or detailed messages (2) | 0 | 0, 1, 2
set | simulation setting to use (see Section V-A) | 2 | 1, 2
seed | random seed | 1 | int
gpu | GPU index used for learning (if available) | 1 | int
momentum | SGD momentum, used only for multiple local updates | 0.9 | float
epochs | number of training rounds, T | 500 | int
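The SNR parameter is defined as P_0/sigma_n^2 in dB. A minimal sketch of the conversion between the dB value and the noise variance, assuming P_0 is normalized to 1 (the scripts may use a different normalization); both function names are hypothetical:

```python
import math


def noise_variance_from_snr(snr_db: float, p0: float = 1.0) -> float:
    """Return sigma_n^2 such that 10*log10(p0 / sigma_n^2) equals snr_db."""
    return p0 / 10.0 ** (snr_db / 10.0)


def snr_db_from_noise_variance(p0: float, noise_var: float) -> float:
    """Inverse mapping: P_0 / sigma_n^2 expressed in dB."""
    return 10.0 * math.log10(p0 / noise_var)


sigma2 = noise_variance_from_snr(90.0)  # default SNR; P_0 = 1 is an assumption
print(sigma2, snr_db_from_noise_variance(1.0, sigma2))
```

Under this normalization the default of 90 dB corresponds to a noise variance of 1e-9.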

Here is an example of running the script in a Linux terminal:

python -u main.py --gpu=0 --trial=50 --set=2
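A minimal sketch of how a parser like initial() could expose the parameters listed in the table and read the example command above. This is a hypothetical reconstruction, not the code in main.py: the flag names and defaults are taken from the table, but the actual script may differ in help text, ordering, or additional options.

```python
import argparse


def initial(argv=None):
    """Hypothetical parser mirroring the parameter table (not the actual main.py code)."""
    p = argparse.ArgumentParser(description="RIS-FL simulation (sketch)")
    p.add_argument("--M", type=int, default=40, help="total number of devices")
    p.add_argument("--N", type=int, default=5, help="total number of receive antennas")
    p.add_argument("--L", type=int, default=40, help="total number of RIS elements")
    p.add_argument("--nit", type=int, default=100, help="I_max for Algorithm 1")
    p.add_argument("--Jmax", type=int, default=50, help="Gibbs sampling iterations")
    p.add_argument("--threshold", type=float, default=1e-2, help="early-stopping threshold")
    p.add_argument("--tau", type=float, default=1.0, help="SCA regularization term")
    p.add_argument("--trial", type=int, default=50, help="number of Monte Carlo trials")
    p.add_argument("--SNR", type=float, default=90.0, help="P_0/sigma_n^2 in dB")
    p.add_argument("--verbose", type=int, default=0, choices=[0, 1, 2])
    p.add_argument("--set", type=int, default=2, choices=[1, 2], help="simulation setting")
    p.add_argument("--seed", type=int, default=1, help="random seed")
    p.add_argument("--gpu", type=int, default=1, help="GPU index")
    p.add_argument("--momentum", type=float, default=0.9, help="SGD momentum")
    p.add_argument("--epochs", type=int, default=500, help="number of training rounds T")
    return p.parse_args(argv)  # argv=None falls back to sys.argv[1:]


args = initial(["--gpu=0", "--trial=50", "--set=2"])
print(args.M, args.N, args.L, args.gpu, args.set)  # 40 5 40 0 2
```

Passing an explicit list makes the parser easy to exercise without a terminal; choices=[0, 1, 2] and choices=[1, 2] reject out-of-range values for verbose and set at parse time.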

Documentation (see each file for more details):

  • main.py: Initialize the simulation system, optimize the variables, train the learning model, and store the results in Store/ as an npz file

    • initial(): Initialize the parser function to read the user-input parameters
  • optlib.py:

    • Gibbs(): Optimize x, f, and theta via Algorithm 2 on top of the following two functions
    • find_obj_inner(): Given x, compute the objective value by executing sca_fmincon()
    • sca_fmincon(): Given the device selection decision x, optimize f and theta via Algorithm 1
  • flow.py:

    • learning_flow(): Read the optimization result, initialize the learning model, and perform training and testing on top of Learning_iter()
    • Learning_iter(): Given a learning model, compute the gradients, update the training models, and perform testing on top of train_script.py
    • FedAvg_grad(): Given the aggregated global gradient and the current model, update the global model by Eq. (4)
  • Nets.py:

    • CNNMnist(): Specify the convolutional neural network structure used for learning
  • AirComp.py:

    • transmission(): Given the local gradients, perform over-the-air model aggregation; see Section II-C
  • train_script.py:

    • Load_FMNIST_IID(): Download (if needed) and load the Fashion-MNIST data, and distribute them to the local devices
    • local_update(): Given a learning model and the distributed training data, compute the local gradients/model changes
    • test_model(): Given a learning model, test the accuracy/loss based on certain test images
  • Monte_Carlo_Averaging.py: Load the npz files from store/ and average the results over the Monte Carlo trials

  • data/: Store the Fashion-MNIST dataset. On the first run, the dataset is downloaded automatically from the Internet.

  • store/: Store output files (*.npz)
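Gibbs() optimizes the device selection x via Algorithm 2 (Gibbs sampling), using find_obj_inner() to score each candidate. The sketch below shows a generic annealed Gibbs sampler over a binary selection vector: each sweep resamples every x_m from its conditional under the Boltzmann distribution exp(-objective/T), while the temperature T decays geometrically. The function name, the temperature schedule (temp0, decay), and the toy objective are assumptions for illustration, not the package's code; in the package the objective would come from find_obj_inner(), which solves the inner problem for f and theta.

```python
import math
import random
from typing import Callable, List, Tuple


def gibbs_device_selection(objective: Callable[[List[int]], float], M: int,
                           Jmax: int = 50, temp0: float = 1.0, decay: float = 0.9,
                           seed: int = 1) -> Tuple[List[int], float]:
    """Annealed Gibbs sampling over x in {0,1}^M; returns the best x seen and its objective."""
    rng = random.Random(seed)
    x = [1] * M                          # start with all devices selected
    best_x, best_obj = list(x), objective(x)
    temp = temp0
    for _ in range(Jmax):                # Jmax sweeps over all coordinates
        for m in range(M):
            x[m] = 0
            obj0 = objective(x)
            x[m] = 1
            obj1 = objective(x)
            # P(x_m = 1 | rest) = 1 / (1 + exp((obj1 - obj0) / T)); clip to avoid overflow
            z = (obj1 - obj0) / temp
            p1 = 1.0 / (1.0 + math.exp(min(z, 700.0)))
            x[m] = 1 if rng.random() < p1 else 0
            cur = obj1 if x[m] == 1 else obj0
            if cur < best_obj:
                best_x, best_obj = list(x), cur
        temp *= decay                    # lower temperature -> greedier sampling
    return best_x, best_obj


# Toy objective (hypothetical): Hamming distance to a fixed "ideal" selection.
target = [1, 0, 1, 1, 0, 0, 1, 0]


def toy_objective(x: List[int]) -> float:
    return sum(abs(a - b) for a, b in zip(x, target))


best_x, best_obj = gibbs_device_selection(toy_objective, M=len(target))
print(best_x, best_obj)
```

Each evaluation of the real objective runs the SCA loop of Algorithm 1, so a practical implementation would likely cache the objective of the current x rather than recompute both branches per coordinate, as this sketch does for clarity.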

Referencing

If you in any way use this code for research that results in publications, please cite our original article listed above.

Contributors: liuhang1994
