Merge, Then Compress: Demystify Efficient SMoE with Hints from Its Routing Policy

License: MIT

Code for the ICLR 2024 Spotlight paper "Merge, Then Compress: Demystify Efficient SMoE with Hints from Its Routing Policy" (arXiv:2310.01334)

Update:

  • 🚀 MC-SMoE now supports Mixtral-8x7B!

Overview

Sparsely activated Mixture-of-Experts (SMoE) has shown promise in scaling up the learning capacity of neural networks. However, it has two issues: ($a$) $\textit{High Memory Usage,}$ due to duplication of the network layers into multiple copies as experts; and ($b$) $\textit{Redundancy in Experts,}$ as common learning-based routing policies suffer from representational collapse. Vanilla SMoE models are therefore memory-inefficient and non-scalable, especially in resource-constrained downstream scenarios.

In this paper, we ask: Can we craft a compact SMoE model by consolidating expert information? What is the best recipe to merge multiple experts into fewer but more knowledgeable experts? Our pilot investigation reveals that conventional model merging methods are not effective for expert merging in SMoE. The potential reasons are: ($1$) redundant information overshadows critical experts; ($2$) the appropriate neuron permutation for each expert is missing, leaving experts misaligned. To address these challenges, we propose a novel merging algorithm for SMoE, $\textit{i.e.}$, $\texttt{M-SMoE}$, which leverages routing statistics to guide expert merging. Specifically, it starts with neuron permutation alignment for experts; then, dominant experts and their "group members" are formed based on routing policies; lastly, every expert group is merged into a single expert, using each expert's activation frequency as its merging weight, thus diminishing the impact of insignificant experts.

Moreover, we observe that the proposed merging promotes low dimensionality in the merged experts' weight space, naturally paving the way for additional compression. Hence, our final method, $\texttt{MC-SMoE}$ ($\textit{i.e.}$, Merge, then Compress SMoE), further decomposes the merged experts into low-rank and structurally sparse alternatives. Extensive experiments across $8$ benchmarks validate the effectiveness of our proposals. For instance, $\texttt{MC-SMoE}$ achieves up to an $80\%$ reduction in memory and a $20\%$ reduction in FLOPs with virtually no loss in performance.
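To make the frequency-weighted merging step concrete, the following minimal sketch averages a group of expert weight tensors using each expert's routing activation frequency as its merging coefficient. It illustrates the idea only and is not the repository's implementation; the function name, shapes, and inputs are assumed for the example.

import torch

def frequency_weighted_merge(expert_weights, activation_freqs):
    """Merge a group of expert weight tensors into one.

    expert_weights: list of tensors with identical shape, one per expert in the group
    activation_freqs: list of floats, how often the router selected each expert
    """
    freqs = torch.tensor(activation_freqs, dtype=expert_weights[0].dtype)
    coeffs = freqs / freqs.sum()            # normalize frequencies into merging weights
    stacked = torch.stack(expert_weights)   # (num_experts_in_group, *weight_shape)
    # weighted average: experts the router rarely uses contribute little
    return (coeffs.view(-1, *([1] * expert_weights[0].dim())) * stacked).sum(dim=0)

# toy usage: three experts in one group, the first dominates the routing statistics
experts = [torch.randn(768, 3072) for _ in range(3)]
merged = frequency_weighted_merge(experts, [0.7, 0.2, 0.1])
print(merged.shape)  # torch.Size([768, 3072])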

Setup

conda create -n mcsmoe python=3.9 -y && conda activate mcsmoe
pip install -r requirements.txt

Usage and Examples

Full SMoE fine-tuning

accelerate launch --config_file static/finetune_config.yaml \
  mcsmoe/finetune-switch-transformers.py \
  --per_device_train_batch_size=8 \
  --per_device_eval_batch_size=64 \
  --gradient_accumulation_steps=1 \
  --num_epochs=20 \
  --no_eval_until_epochs=1 \
  --save_each_epoch=False \
  --preprocessing_num_workers=8 \
  --num_experts=32 \
  --task="copa" \
  --learning_rate=3e-5 \
  --warmup_steps=16 \
  --output_dir="results/copa/switch-32e"
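If the fine-tuning script saves a standard Hugging Face checkpoint to --output_dir (an assumption worth verifying against the script), the resulting model can be loaded for a quick sanity check as sketched below; the 32-expert configuration is assumed to correspond to google/switch-base-32.

from transformers import AutoTokenizer, SwitchTransformersForConditionalGeneration

# assumes results/copa/switch-32e contains a save_pretrained()-style checkpoint
model = SwitchTransformersForConditionalGeneration.from_pretrained("results/copa/switch-32e")
tokenizer = AutoTokenizer.from_pretrained("google/switch-base-32")  # tokenizer of the assumed base model

# the task-specific prompt format is defined by the training scripts; this is a placeholder input
inputs = tokenizer("example input text", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=8)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))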

M-SMoE Expert Permutation Alignment

python -u mcsmoe/permute-model.py \
  --checkpoint="results/copa/switch-32e" \
  --save_dir="results/copa/switch-32e-permuted" 
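Permutation alignment relies on the fact that reordering the hidden neurons of an expert's feed-forward network, applied consistently to both layers, does not change the expert's function, so all experts can be permuted into a common neuron ordering before they are averaged. A minimal sketch of this invariance (weight names and sizes are illustrative, not the repository's):

import torch

d_model, hidden = 4, 8
wi = torch.randn(hidden, d_model)   # first FFN layer: d_model -> hidden
wo = torch.randn(d_model, hidden)   # second FFN layer: hidden -> d_model
x = torch.randn(2, d_model)

perm = torch.randperm(hidden)       # an arbitrary reordering of the hidden neurons
out_original = torch.relu(x @ wi.T) @ wo.T
out_permuted = torch.relu(x @ wi[perm].T) @ wo[:, perm].T

# permuting wi's rows together with wo's columns leaves the expert's output unchanged
print(torch.allclose(out_original, out_permuted))  # True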

M-SMoE Merging

accelerate launch --config_file static/finetune_config.yaml \
  mcsmoe/msmoe-merging.py \
  --per_device_train_batch_size=16 \    # ======== training arguments from here ========
  --per_device_eval_batch_size=256 \
  --gradient_accumulation_steps=1 \
  --preprocessing_num_workers=8 \
  --num_epochs=10 \
  --num_eval_steps=100 \
  --learning_rate=3e-5 \
  --warmup_steps=16 \
  --weight_decay=0.01 \
  --kd_temperature=2 \
  --mlm_lambda=1.0 \
  --kd_lambda=0.2 \
  --task="copa" \     # ======== merging arguments from here ========
  --num_samples_for_merging=256 \
  --similarity_base="router-logits" \     # for all available options refer to LEGAL_SIMILARITY_BASES in mcsmoe/merging/grouping.py 
  --num_groups=8 \    # average number of experts per SMoE layer
  --globally_group=True \   # if True, apply adaptive merging ratio for each SMoE layer
  --save_stable_rank=False \    # whether to save stable rank of each expert for analysis
  --encoder_merging_layers="3,5,7,9,11" \   # encoder layer indices to be merged
  --decoder_merging_layers="1,3,5,7,9,11" \   # decoder layer indices to be merged
  --output_dir="results/copa/merged/" \     # M-SMoE checkpoint will be saved here
  --teacher_checkpoint="results/copa/switch-32e-permuted" \    # KD teacher checkpoint, full SMoE
  --student_checkpoint="results/copa/switch-32e-permuted"    # KD student checkpoint, will be merged by M-SMoE
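For intuition about how the groups behind --num_groups and --similarity_base are formed, the sketch below picks the most frequently routed experts as dominant experts and assigns every other expert to its most similar dominant expert, using a cosine similarity over per-expert router statistics. This is a simplified illustration with hypothetical inputs; the supported similarity bases and the actual grouping logic live in mcsmoe/merging/grouping.py.

import torch

def group_experts(router_logits, activation_freq, num_groups):
    """router_logits: (num_tokens, num_experts) routing scores collected on a few batches
    activation_freq: (num_experts,) how often each expert was selected
    Returns, for each expert index, the dominant expert of its group."""
    dominant = torch.topk(activation_freq, num_groups).indices           # most-used experts lead groups
    features = torch.nn.functional.normalize(router_logits.T, dim=-1)    # one feature vector per expert
    similarity = features @ features[dominant].T                         # cosine similarity to each dominant expert
    assignment = dominant[similarity.argmax(dim=-1)]                     # join the most similar group
    assignment[dominant] = dominant                                      # dominant experts lead their own group
    return assignment

# toy usage: 8 experts merged into 4 groups
logits = torch.randn(512, 8)
freq = torch.rand(8)
print(group_experts(logits, freq, num_groups=4))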

MC-SMoE Low-rank Compression

accelerate launch --config_file static/finetune_config.yaml \
  --main_process_port 29510 mcsmoe/losparse-downstream.py \
  --per_device_train_batch_size=16 \     # ======== training arguments from here ========
  --per_device_eval_batch_size=256 \
  --gradient_accumulation_steps=1 \
  --preprocessing_num_workers=8 \
  --num_epochs=50 \
  --num_eval_steps=100 \
  --learning_rate=3e-5 \
  --warmup_steps=50 \
  --weight_decay=0.01 \
  --kd_temperature=2 \
  --mlm_lambda=1.0 \
  --kd_lambda=0.2 \
  --hd_lambda=0.0 \
  --task="copa" \     # ======== compression arguments from here ========
  --output_dir="results/copa/switch-32e-merged-8e-compressed/" \      # MC-SMoE checkpoint will be saved here
  --teacher_checkpoint="results/copa/switch-32e-permuted" \      # KD teacher checkpoint, full SMoE
  --student_checkpoint="results/copa/switch-32e-merged-8e" \     # M-SMoE checkpoint, will be further compressed by MC-SMoE
  --final_threshold=0.10 \      # average remaining ratio of S matrices in compression
  --low_rank_factor=32      # low-rank factor for U, V matrices in compression
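The script name losparse-downstream.py suggests a LoSparse-style factorization, in which each merged expert weight is decomposed into a low-rank product plus a sparse residual, W ≈ U V + S; --low_rank_factor sets the rank of U and V, and --final_threshold the fraction of entries ultimately kept in S. The sketch below performs this decomposition in one shot with a truncated SVD and magnitude pruning of the residual, whereas the command above reaches the target sparsity gradually during distillation, so treat it purely as an illustration.

import torch

def low_rank_plus_sparse(weight, rank=32, keep_ratio=0.10):
    """Decompose weight (out_dim, in_dim) into U @ V + S."""
    u, s, vh = torch.linalg.svd(weight, full_matrices=False)
    U = u[:, :rank] * s[:rank]          # (out_dim, rank), singular values folded in
    V = vh[:rank]                       # (rank, in_dim)
    residual = weight - U @ V
    # keep only the largest-magnitude entries of the residual as the sparse matrix S
    k = int(keep_ratio * residual.numel())
    threshold = residual.abs().flatten().kthvalue(residual.numel() - k).values
    S = torch.where(residual.abs() > threshold, residual, torch.zeros_like(residual))
    return U, V, S

W = torch.randn(768, 3072)
U, V, S = low_rank_plus_sparse(W, rank=32, keep_ratio=0.10)
print((S != 0).float().mean())                        # roughly 0.10
print(torch.norm(W - (U @ V + S)) / torch.norm(W))    # relative reconstruction error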

More Examples

Please refer to scripts/t5 and scripts/gpt for more examples (e.g., baselines and ablations).

Hyper-Parameters

General Hyper-Parameters

  • Optimizer: AdamW
  • Adam $\epsilon$: $1e-6$
  • Adam $\beta$: ($0.9$, $0.98$)
  • Warm-up steps: $16$
  • Weight decay: $0.01$
  • LR scheduler: Linear decay
  • KD $\alpha$: $0.2$
  • KD $T$: $2.0$ (a sketch of how $\alpha$ and $T$ enter the distillation loss follows this list)
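Here KD $\alpha$ and KD $T$ correspond to the --kd_lambda and --kd_temperature flags used in the commands above. Assuming a standard soft-target knowledge-distillation objective (an assumption; the exact loss is defined in the mcsmoe training code), the terms combine roughly as in this sketch:

import torch.nn.functional as F

def training_loss(student_logits, teacher_logits, task_loss,
                  mlm_lambda=1.0, kd_lambda=0.2, temperature=2.0):
    # soften both distributions, compare with KL divergence, and rescale by T^2 as is standard for KD
    kd = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=-1),
        F.softmax(teacher_logits / temperature, dim=-1),
        reduction="batchmean",
    ) * temperature ** 2
    # mlm_lambda weighs the supervised task loss, kd_lambda the distillation term
    return mlm_lambda * task_loss + kd_lambda * kd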

Task-Specific Hyper-Parameters

| Task | Batch size | Learning rate |
| --- | --- | --- |
| SST-2 | $64$ | $1e-4$ |
| MRPC | $32$ | $5e-5$ |
| MultiRC | $32$ | $3e-5$ |
| COPA | $8$ | $3e-5$ |
| WinoGrande | $32$ | $1e-5$ |
| SQuAD | $16$ | $5e-5$ |
| WikiQA | $32$ | $5e-5$ |
| HotpotQA | $32$ | $1e-4$ |

Citation

@misc{li2023merge,
      title={Merge, Then Compress: Demystify Efficient SMoE with Hints from Its Routing Policy}, 
      author={Pingzhi Li and Zhenyu Zhang and Prateek Yadav and Yi-Lin Sung and Yu Cheng and Mohit Bansal and Tianlong Chen},
      year={2023},
      eprint={2310.01334},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}
