
nusdbsystem/model-slicing


Model Slicing


This repository contains our PyTorch implementation of "Model Slicing for Supporting Complex Analytics with Elastic Inference Cost and Resource Constraints". Model Slicing is a general dynamic-width technique that enables neural networks to support budgeted inference, i.e., producing predictions within a prescribed computational budget by dynamically trading off accuracy for efficiency at runtime.

Budgeted inference is achieved by dividing each layer of the network into equal-sized groups of basic components (i.e., neurons in dense layers and channels in convolutional layers). A single parameter, the slice rate r, controls the fraction of groups involved in computation for all layers at runtime, and therefore the width of the network in both training and inference.

In particular, the groups involved in computation always start from the first group and extend contiguously to the last group, whose index is determined dynamically by the current slice rate. E.g., a slice rate of 0.5 selects the first two groups in a layer of 4 groups.
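The group selection described above can be sketched in plain Python. This is a minimal illustration, not the code in models/model_slicing.py: the function names (`active_units`, `dense_cost_fraction`, `pick_slice_rate`) are hypothetical, the rounding policy is an assumption, and the cost model only counts multiply-adds of a single dense layer whose input and output are both sliced at the same rate.

```python
def active_units(num_units, num_groups, slice_rate):
    """Number of leading neurons/channels kept for a given slice rate."""
    if num_units % num_groups != 0:
        raise ValueError("layer width must divide evenly into groups")
    if not 0.0 < slice_rate <= 1.0:
        raise ValueError("slice rate must be in (0, 1]")
    group_size = num_units // num_groups
    # Assumption: round to the nearest group and always keep at least one.
    kept_groups = max(1, round(slice_rate * num_groups))
    # Groups are contiguous from the first, so only the count is needed.
    return kept_groups * group_size


def dense_cost_fraction(in_units, out_units, num_groups, slice_rate):
    """Fraction of multiply-adds kept when input and output are both sliced."""
    kept_in = active_units(in_units, num_groups, slice_rate)
    kept_out = active_units(out_units, num_groups, slice_rate)
    return (kept_in * kept_out) / (in_units * out_units)


def pick_slice_rate(sr_list, budget_fraction, in_units, out_units, num_groups):
    """Largest rate in sr_list whose cost fraction fits the budget, else None."""
    fitting = [
        r for r in sr_list
        if dense_cost_fraction(in_units, out_units, num_groups, r) <= budget_fraction
    ]
    return max(fitting) if fitting else None


print(active_units(64, 4, 0.5))             # 32: the first two of four groups of 16
print(dense_cost_fraction(64, 64, 4, 0.5))  # 0.25
print(pick_slice_rate([1.0, 0.75, 0.5, 0.25], 0.3, 64, 64, 8))  # 0.5
```

Under this simplified cost model, halving the slice rate keeps about a quarter of the layer's multiply-adds, which is why a small list of rates can cover a wide range of budgets.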

This repo includes:

  1. representative models (/models)
  2. code for model slicing training (train.py)
  3. code for supporting model slicing functionalities (models/model_slicing.py)

Training

  1. Dependencies
     pip install -r requirements.txt
  2. Model Training
     Example training command:
     CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --exp_name resnet_50 --net_type resnet --group 8 --depth 50 --sr_list 1.0 0.75 0.5 0.25 --sr_scheduler_type random_min_max --sr_rand_num 1 --epoch 100 --batch_size 256 --lr 0.1 --dataset imagenet --data_dir /data/ --log_freq 50

     See the help text of the argparse.ArgumentParser in train.py for configuration details.
  3. One line to support Model Slicing
     model = upgrade_dynamic_layers(model, args.groups, args.sr_list)

    * groups:   the number of groups for each layer, e.g. 8
    * sr_list:  slice rate list, e.g. [1.0, 0.75, 0.5, 0.25]
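The example command passes `--sr_scheduler_type random_min_max` with `--sr_rand_num 1`. The sketch below shows one plausible reading of those flag names: each training step uses the largest and smallest rates in `sr_list` plus a few randomly chosen intermediate ones. This interpretation is an assumption, and `sample_slice_rates` is a hypothetical helper, not a function from this repository; train.py remains the authoritative reference.

```python
import random


def sample_slice_rates(sr_list, rand_num, rng=random):
    """Rates for one training step: max, min, plus rand_num random others.

    Assumption: this mirrors what a "random_min_max" scheduler might do;
    it is not taken from train.py.
    """
    ordered = sorted(sr_list, reverse=True)
    if len(ordered) < 2:
        return ordered
    largest, smallest = ordered[0], ordered[-1]
    middle = ordered[1:-1]
    # Never request more intermediate rates than are available.
    extra = rng.sample(middle, min(rand_num, len(middle)))
    return [largest] + extra + [smallest]


rng = random.Random(0)  # seeded only to make the illustration repeatable
for step in range(3):
    print(step, sample_slice_rates([1.0, 0.75, 0.5, 0.25], 1, rng))
```

Always including the two extremes keeps both the full network and the narrowest subnet trained at every step, while the random picks spread the remaining updates over the intermediate widths.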

Contact

To ask questions or report issues, you can directly drop us an email.
