QUANTIZATION IN PYTORCH
Raghu Krishnamoorthi, Facebook
QUANTIZATION
• Can neural networks run in lower precision? float16, int8
• Supported by modern hardware: x86 CPU, ARM CPU, NVIDIA Volta & Turing, Qualcomm DSP, …
• Maintaining accuracy is hard: working approaches exist, and research is ongoing
N × float32  →  N × uint8 values + float32 scale + int32 zero_point
    float_val = (uint8_val - zero_point) × scale
Result: 4x less memory, 2-4x compute speedup
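The affine mapping above can be checked by hand. The sketch below is plain Python, not PyTorch's kernel; the scale of 0.02 and zero_point of 128 are arbitrary example values. Quantizing inverts the slide's formula with rounding and then clamps to the uint8 range [0, 255].

```python
# Sketch of affine (asymmetric) uint8 quantization.
# SCALE and ZERO_POINT are example values chosen for illustration.
SCALE = 0.02
ZERO_POINT = 128
QMIN, QMAX = 0, 255  # uint8 range


def quantize(x: float) -> int:
    """float -> uint8 code: round to the nearest step, shift by zero_point, clamp."""
    q = round(x / SCALE) + ZERO_POINT
    return max(QMIN, min(QMAX, q))


def dequantize(q: int) -> float:
    """uint8 code -> float, exactly the slide's formula."""
    return (q - ZERO_POINT) * SCALE


values = [-1.0, 0.0, 0.5, 1.0, 2.0, 5.0]
codes = [quantize(v) for v in values]
restored = [dequantize(q) for q in codes]
for v, q, r in zip(values, codes, restored):
    print(f"{v:6.2f} -> {q:3d} -> {r:6.2f}")
```

Values inside the representable range [-2.56, 2.54] come back within half a step (SCALE / 2), while 5.0 saturates at code 255 and returns as 2.54. That clipping error is why choosing the range well matters. Each value is stored in 1 byte instead of 4, which is where the 4x memory saving comes from.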
SYSTEM-MODEL CO-DESIGN
• Neural network inference is expensive
• IoT and mobile devices have limited resources
• Design models for efficient inference at scale
OUR MISSION: GIVE TOOLS FOR BUILDING AND RUNNING EFFICIENT MODELS
PYTORCH AT CORE
Same framework, no conversion
• Same serialization
• Python or TorchScript
Eager at its core
• Most logic is in Python
• Extensibility, debuggers, stack traces
Extensible API
• New layers
• Observers
• Quantization techniques
• Partial quantization
Key APIs: torch.quantize_per_tensor, torch.quantize_per_channel, torch.nn.quantized.*, torch.nn.quantized.dynamic.*, torch.quantization.Observer, torch.quantization.FakeQuantize
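Because quantized tensors are ordinary torch.Tensor objects, the core API can be exercised directly. The minimal sketch below uses torch.quantize_per_tensor. The scale (0.02) and zero_point (128) are arbitrary example values, not defaults. It reads the stored parameters back with the quantized-tensor accessors q_scale(), q_zero_point() and int_repr().

```python
import torch

x = torch.tensor([-1.0, 0.0, 0.5, 1.0, 2.0])

# Example quantization parameters (arbitrary, chosen for illustration).
q = torch.quantize_per_tensor(x, scale=0.02, zero_point=128, dtype=torch.quint8)

print(q.is_quantized)                  # a quantized tensor is still a torch.Tensor
print(q.q_scale(), q.q_zero_point())   # the stored quantization parameters
print(q.int_repr())                    # the underlying uint8 codes

# Dequantizing applies float_val = (uint8_val - zero_point) * scale.
manual = (q.int_repr().to(torch.int32) - q.q_zero_point()).float() * q.q_scale()
print(torch.allclose(manual, q.dequantize()))
```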
PYTORCH QUANTIZATION
Turn-key workflows
• Dynamic quantization
• Post training quantization
• Quantization aware training
Components for tuning & research
• Every part of the workflow is flexible
• Use or build your own (in PyTorch)
Core support
• Quantized tensor and operations
• Optimized kernels for x86 and ARM CPUs (other backends coming)
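The optimized CPU kernels are exposed as quantized engines: FBGEMM for x86 and QNNPACK for ARM. The minimal sketch below inspects and selects the engine at runtime. Which engines appear in the list depends on how your PyTorch build was compiled, so treat the preference order as an assumption.

```python
import torch

# Engines compiled into this build; the exact list depends on platform and version.
engines = torch.backends.quantized.supported_engines
print("available:", engines)

# Prefer FBGEMM (x86) and fall back to QNNPACK (ARM) when that is what is available.
for preferred in ("fbgemm", "qnnpack"):
    if preferred in engines:
        torch.backends.quantized.engine = preferred
        break

print("active:", torch.backends.quantized.engine)
```

The backend string passed to get_default_qconfig in the post training workflow ('fbgemm') should match the engine that will run the model.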
WORKFLOWS
Workflow                    | Quantization                       | Dataset requirements | Works best for             | Accuracy
--------------------------- | ---------------------------------- | -------------------- | -------------------------- | --------
Dynamic Quantization        | weights only (both fp16 and int8)  | None                 | small-batch LSTMs and MLPs | good
Post Training Quantization  | weights and activations (8 bit)    | calibration          | CNNs                       | good
Quantization-Aware Training | weights and activations (8 bit)    | fine-tuning          | all                        | best
Or build your own.
WORKFLOW: DYNAMIC QUANTIZATION
X (float) → nnqd.Linear (W: int8, bias: float) → Y (float)
How: tweak model, one-line API
What: quantize weights once, activations at runtime
Good for: LSTMs/Transformers and MLPs with small batch size
Savings: 2x faster compute, 4x less memory
    import torch
    from torch.quantization import quantize_dynamic

    # load or train your model
    model = WordLanguageModel()
    model.load_state_dict(torch.load("model.pt"))
    model.eval()

    # quantize (int8 weights)
    qmodel = quantize_dynamic(model, dtype=torch.qint8)

    # use or deploy for C++ inference
    output = qmodel(input)
    torch.jit.script(qmodel).save("scripted.pt")
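The one-line API can be tried without a trained language model. The minimal sketch below uses a hypothetical three-layer MLP with arbitrary sizes as a stand-in for WordLanguageModel. The set {nn.Linear} restricts which module types are swapped for their dynamically quantized versions.

```python
import torch
import torch.nn as nn
from torch.quantization import quantize_dynamic

# Hypothetical stand-in model; layer sizes are arbitrary.
model = nn.Sequential(
    nn.Linear(64, 128),
    nn.ReLU(),
    nn.Linear(128, 10),
).eval()

# Weights of every nn.Linear are quantized to int8 once, here;
# activations are quantized on the fly at each call.
qmodel = quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)

x = torch.randn(1, 64)  # batch size 1, the case dynamic quantization targets
with torch.no_grad():
    y_fp32 = model(x)
    y_int8 = qmodel(x)

print(qmodel)                         # Linear layers are now dynamic quantized Linear modules
print((y_fp32 - y_int8).abs().max())  # small numerical difference; size is model-dependent
```

Inputs and outputs stay float, so the quantized model is a drop-in replacement for the original in surrounding code.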
WORKFLOW: POST TRAINING
CALIBRATE: X (float) → observer → Conv2d (W: float, bias: float) → observer → Y (float)
QUANTIZE: X (uint8) → nnq.Conv (W: int8, bias: float, out qparams) → Y (uint8)
How: tweak model, calibrate on data, convert
What: quantize weights and activations for entire model or submodules
Good for: CNNs (if the accuracy drop is acceptable)
Savings: 1.5-2x faster compute, 4x less memory
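During calibration an observer only has to record the range of values it sees. Turning that range into the "out qparams" (scale and zero_point) is simple arithmetic. The sketch below is a simplified, hypothetical min/max observer in plain Python, and the function names are ours, not PyTorch's. PyTorch's own observers add refinements such as reduced ranges and symmetric schemes.

```python
# Hypothetical, simplified min/max observer for affine uint8 quantization.
QMIN, QMAX = 0, 255


def observe(batches):
    """Track the running min and max over all calibration batches."""
    lo, hi = float("inf"), float("-inf")
    for batch in batches:
        lo = min(lo, min(batch))
        hi = max(hi, max(batch))
    return lo, hi


def qparams_from_range(lo, hi):
    """Map [lo, hi], widened to include 0.0, onto [QMIN, QMAX]."""
    lo, hi = min(lo, 0.0), max(hi, 0.0)           # zero must be exactly representable
    scale = (hi - lo) / (QMAX - QMIN) or 1.0      # fall back to 1.0 for all-zero data
    zero_point = QMIN - round(lo / scale)
    zero_point = max(QMIN, min(QMAX, zero_point))
    return scale, zero_point


# Post-ReLU activations are non-negative, so zero_point ends up at 0.
calib = [[0.0, 1.2, 3.3], [0.7, 4.55, 2.0]]
lo, hi = observe(calib)
scale, zero_point = qparams_from_range(lo, hi)
print(lo, hi, scale, zero_point)
```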
WORKFLOW: POST TRAINING
1. MODIFY MODEL
How:
1. Use modules for all operations with state
2. Explicitly control where activations are quantized and dequantized
3. Fuse operations prior to quantization for performance and accuracy benefits
    import torch
    import torch.quantization as quantization

    # load or train your model
    model = ResNet50()
    model.load_state_dict(torch.load("model.pt"))

    # tweak model for best results
    # change code directly or use manipulation APIs
    model.eval()
    model = quantization.fuse_modules(model, [["conv1", "bn1", "relu1"]])

    print(model.conv1)
    # ConvReLU2d(
    #   (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
    #   (1): ReLU(inplace=True)
    # )
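Fusion can be checked on a small model before touching ResNet50. The minimal sketch below uses a hypothetical TinyNet that mimics the ResNet stem (conv1 → bn1 → relu1). In eval mode, fuse_modules folds the batch norm statistics into the convolution. It replaces bn1 and relu1 with Identity, so the fused model should compute the same function up to float rounding.

```python
import torch
import torch.nn as nn
import torch.quantization as quantization


class TinyNet(nn.Module):
    """Hypothetical stand-in for the ResNet stem: conv -> bn -> relu."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(8)
        self.relu1 = nn.ReLU()

    def forward(self, x):
        return self.relu1(self.bn1(self.conv1(x)))


model = TinyNet().eval()  # eval mode: BN uses running stats, which can be folded
x = torch.randn(2, 3, 16, 16)
with torch.no_grad():
    before = model(x)

# inplace=False (the default) returns a fused copy and leaves `model` untouched.
fused = quantization.fuse_modules(model, [["conv1", "bn1", "relu1"]])
with torch.no_grad():
    after = fused(x)

print(fused.conv1)  # a ConvReLU2d holding the BN-folded convolution
print(type(fused.bn1).__name__, type(fused.relu1).__name__)  # Identity Identity
print(torch.allclose(before, after, atol=1e-5))
```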
WORKFLOW: POST TRAINING
2. PREPARE AND CALIBRATE
How:
1. Specify which parts of the model need to be quantized
2. Specify how to collect statistics (Observers)
    import torch
    import torch.quantization as quantization

    # load or train your model
    model = ResNet50()
    model.load_state_dict(torch.load("model.pt"))

    # tweak model for best results
    # change code directly or use manipulation APIs
    model.eval()
    model = quantization.fuse_modules(model, [["conv1", "bn1", "relu1"]])

    # specify which part to quantize and how
    model.qconfig = quantization.get_default_qconfig('fbgemm')  # configurable!
    qmodel = quantization.prepare(model, inplace=False)

    # collect calibration statistics
    qmodel.eval()
    with torch.no_grad():
        for batch, target in data_loader:
            qmodel(batch)

    print(qmodel.conv1)
    # ConvReLU2d(3, 64, kernel_size=(7, 7), ...
    #   (observer): MinMaxObserver(min_val=0.0, max_val=4.55)
    # )
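An observer is itself a small module: it passes tensors through unchanged and records statistics as a side effect. The minimal sketch below drives a MinMaxObserver directly. It feeds it a few hypothetical non-negative "activation" batches and then asks it for the qparams that convert() would later bake into the quantized module.

```python
import torch
from torch.quantization import MinMaxObserver

obs = MinMaxObserver(dtype=torch.quint8, qscheme=torch.per_tensor_affine)

# Hypothetical post-ReLU activation batches in [0, 4).
torch.manual_seed(0)
for _ in range(3):
    obs(torch.rand(8, 16) * 4.0)

print(obs.min_val, obs.max_val)  # running range seen during "calibration"
scale, zero_point = obs.calculate_qparams()
print(scale, zero_point)         # qparams derived from that range
```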
WORKFLOW: POST TRAINING
3. CONVERT
How: call torch.quantization.convert()
What: converts operations from fp32 to int8 arithmetic
    import torch
    import torch.quantization as quantization

    # load or train your model
    model = ResNet50()
    model.load_state_dict(torch.load("model.pt"))

    # tweak model for best results
    # change code directly or use manipulation APIs
    model.eval()
    model = quantization.fuse_modules(model, [["conv1", "bn1", "relu1"]])

    # specify which part to quantize and how
    model.qconfig = quantization.get_default_qconfig('fbgemm')  # configurable!
    qmodel = quantization.prepare(model, inplace=False)

    # collect calibration statistics
    qmodel.eval()
    with torch.no_grad():
        for batch, target in data_loader:
            qmodel(batch)

    # get the quantized model
    qmodel = quantization.convert(qmodel)

    print(qmodel.conv1)
    # QuantizedConvReLU2d(3, 64, scale=0.035, zero_point=0, kernel_size=(7, 7), ...)
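The full prepare, calibrate and convert sequence can be run end to end without a dataset. The minimal sketch below uses a hypothetical SmallCNN, with random tensors standing in for data_loader. QuantStub and DeQuantStub mark where activations enter and leave the quantized region, which is point 2 of "MODIFY MODEL". Accuracy is meaningless here; only the module types after convert() matter.

```python
import torch
import torch.nn as nn
import torch.quantization as quantization


class SmallCNN(nn.Module):
    """Hypothetical model; QuantStub/DeQuantStub mark the quantized region."""

    def __init__(self):
        super().__init__()
        self.quant = quantization.QuantStub()
        self.conv1 = nn.Conv2d(3, 8, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(8)
        self.relu1 = nn.ReLU()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(8, 10)
        self.dequant = quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)                       # float -> quint8 after convert()
        x = self.relu1(self.bn1(self.conv1(x)))
        x = torch.flatten(self.pool(x), 1)
        x = self.fc(x)
        return self.dequant(x)                  # quint8 -> float after convert()


backend = "fbgemm" if "fbgemm" in torch.backends.quantized.supported_engines else "qnnpack"
torch.backends.quantized.engine = backend

model = SmallCNN().eval()
model = quantization.fuse_modules(model, [["conv1", "bn1", "relu1"]])
model.qconfig = quantization.get_default_qconfig(backend)
qmodel = quantization.prepare(model, inplace=False)

# Calibrate on random data standing in for a real data_loader.
with torch.no_grad():
    for _ in range(4):
        qmodel(torch.randn(2, 3, 32, 32))

qmodel = quantization.convert(qmodel)
print(qmodel.conv1)

with torch.no_grad():
    out = qmodel(torch.randn(1, 3, 32, 32))
print(out.shape, out.dtype)  # float output thanks to DeQuantStub
```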
WORKFLOW: POST TRAINING
4. DEPLOY
    import torch
    import torch.quantization as quantization

    # load or train your model
    model = ResNet50()
    model.load_state_dict(torch.load("model.pt"))

    # tweak model for best results
    # change code directly or use manipulation APIs
    model.eval()
    model = quantization.fuse_modules(model, [["conv1", "bn1", "relu1"]])

    # specify which part to quantize and how
    model.qconfig = quantization.get_default_qconfig('fbgemm')  # configurable!
    qmodel = quantization.prepare(model, inplace=False)

    # collect calibration statistics
    qmodel.eval()
    with torch.no_grad():
        for batch, target in data_loader:
            qmodel(batch)

    # get the quantized model
    qmodel = quantization.convert(qmodel)

    # use or deploy for C++ inference
    output = qmodel(input)
    torch.jit.script(qmodel).save("quantized.pt")
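Deployment relies on a converted quantized model scripting like any other module. The minimal sketch below builds a hypothetical one-convolution model, quantizes it, and scripts it. It saves to an in-memory buffer (standing in for "quantized.pt") and checks that the reloaded TorchScript module reproduces the eager result.

```python
import io

import torch
import torch.nn as nn
import torch.quantization as quantization


class Tiny(nn.Module):
    """Hypothetical minimal statically quantizable model."""

    def __init__(self):
        super().__init__()
        self.quant = quantization.QuantStub()
        self.conv = nn.Conv2d(3, 4, 3)
        self.relu = nn.ReLU()
        self.dequant = quantization.DeQuantStub()

    def forward(self, x):
        return self.dequant(self.relu(self.conv(self.quant(x))))


backend = "fbgemm" if "fbgemm" in torch.backends.quantized.supported_engines else "qnnpack"
torch.backends.quantized.engine = backend

model = Tiny().eval()
model.qconfig = quantization.get_default_qconfig(backend)
qmodel = quantization.prepare(model, inplace=False)
with torch.no_grad():
    qmodel(torch.randn(4, 3, 8, 8))  # calibration on random stand-in data
qmodel = quantization.convert(qmodel)

# Script and serialize; an in-memory buffer stands in for "quantized.pt".
scripted = torch.jit.script(qmodel)
buffer = io.BytesIO()
torch.jit.save(scripted, buffer)
buffer.seek(0)
loaded = torch.jit.load(buffer)

x = torch.randn(1, 3, 8, 8)
with torch.no_grad():
    same = torch.allclose(loaded(x), qmodel(x))
print(same)  # same int8 kernels, same result
```

On the C++ side, the saved file is loaded with torch::jit::load, as with any other TorchScript model.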
WORKFLOW: QUANTIZATION AWARE TRAINING
How: steps are almost identical to the post training quantization workflow
• Identical modifications to model
• Specify a different qconfig and use prepare_qat
• Train instead of calibrate
What: quantize weights and activations for entire model or submodules
Good for: provides the best accuracy vs performance tradeoff
Savings: identical to that of post training quantization
    import torch
    import torch.quantization as quantization

    # load or train your model
    model = ResNet50()
    model.load_state_dict(torch.load("model.pt"))

    # tweak model for best results (model stays in training mode for QAT)
    # change code directly or use manipulation APIs
    model.train()
    model = quantization.fuse_modules(model, [["conv1", "bn1", "relu1"]])

    # specify which part to quantize and how
    model.qconfig = quantization.get_default_qat_qconfig('fbgemm')  # configurable!
    qmodel = quantization.prepare_qat(model, inplace=False)

    # fine tune model
    train(qmodel, train_data)
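A QAT run can be sketched on a toy problem. Here a hypothetical two-layer MLP and random labels stand in for ResNet50 and train(qmodel, train_data). prepare_qat swaps in modules that fake-quantize weights and activations during training. After a few optimizer steps, convert() produces the real int8 model. Fusion is skipped here to keep the sketch short.

```python
import torch
import torch.nn as nn
import torch.quantization as quantization


class TinyQAT(nn.Module):
    """Hypothetical model; sizes and data are arbitrary."""

    def __init__(self):
        super().__init__()
        self.quant = quantization.QuantStub()
        self.fc1 = nn.Linear(16, 32)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(32, 4)
        self.dequant = quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.relu(self.fc1(x))
        return self.dequant(self.fc2(x))


backend = "fbgemm" if "fbgemm" in torch.backends.quantized.supported_engines else "qnnpack"
torch.backends.quantized.engine = backend

model = TinyQAT().train()  # prepare_qat expects a model in training mode
model.qconfig = quantization.get_default_qat_qconfig(backend)
qat_model = quantization.prepare_qat(model, inplace=False)  # inserts fake-quant modules

# A few fine-tuning steps on random data.
opt = torch.optim.SGD(qat_model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()
torch.manual_seed(0)
for _ in range(10):
    x, y = torch.randn(8, 16), torch.randint(0, 4, (8,))
    opt.zero_grad()
    loss = loss_fn(qat_model(x), y)
    loss.backward()  # gradients pass through fake quantization
    opt.step()

qat_model.eval()
int8_model = quantization.convert(qat_model, inplace=False)
print(type(qat_model.fc1).__module__, type(int8_model.fc1).__module__)
```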
WORKFLOW: QUANTIZATION AWARE TRAINING, UNDER THE HOOD
torch.quantization.FakeQuantize
• Fake quantization to mimic quantization in the forward pass
• Straight-through estimator in the backward pass
• Can also provide your own fake-quantization module
torch.nn.qat.*
• Special handling of batch normalization
• Fold batch normalization to mimic inference during training
• Freeze batch norm statistics updates for improved accuracy during quantization aware training
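Fake quantization and the straight-through estimator can both be observed on a single tensor. The minimal sketch below uses torch.fake_quantize_per_tensor_affine, the functional form of what FakeQuantize modules apply. The scale of 0.1 and the signed 8-bit range are arbitrary example values. The forward pass quantizes and immediately dequantizes, staying in float. The backward pass treats rounding as identity inside the representable range and zeroes the gradient where values were clipped.

```python
import torch

# Example qparams for a signed 8-bit range (arbitrary, for illustration).
scale, zero_point, qmin, qmax = 0.1, 0, -128, 127

x = torch.tensor([0.04, 0.26, -3.0, 20.0], requires_grad=True)
y = torch.fake_quantize_per_tensor_affine(x, scale, zero_point, qmin, qmax)
print(y)  # snapped to multiples of 0.1; 20.0 is clipped to 12.7

# Straight-through estimator: d(round)/dx is treated as 1 inside [qmin, qmax]
# (after scaling) and 0 for values that were clipped.
y.sum().backward()
print(x.grad)
```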
EXAMPLE MODELS
Model       | fp32 accuracy | int8 accuracy | Change | Metric         | Technique                   | CPU inference speedup
----------- | ------------- | ------------- | ------ | -------------- | --------------------------- | -----------------------------------------------------
ResNet50    | 76.1          | 75.9          | -0.2   | ImageNet       | Post Training               | 2x: 214 ms ➙ 102 ms, Intel Skylake-DE
MobileNetV2 | 71.9          | 71.6          | -0.3   | ImageNet       | Quantization-Aware Training | 4x: 75 ms ➙ 18 ms, OnePlus 5, Snapdragon 835
BERT        | 90.2          | 89.6          | -0.6   | F1 (GLUE MRPC) | Dynamic Quantization        | 1.6x: 581 ms ➙ 313 ms, Intel Skylake-DE, batch size = 1
Quantization-ready torchvision models available now!
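To measure latency comparisons like these on your own hardware, you need a warmed-up, averaged timing loop. The minimal sketch below uses a hypothetical stack of Transformer-sized feed-forward blocks as a Linear-heavy stand-in; it is not BERT itself. It compares the fp32 model with its dynamically quantized version at batch size 1. The measured speedup depends entirely on the CPU, thread count and PyTorch build, so no particular number should be expected.

```python
import time

import torch
import torch.nn as nn
from torch.quantization import quantize_dynamic


def mean_latency_ms(model, x, warmup=3, iters=10):
    """Average wall-clock latency of model(x) in milliseconds."""
    with torch.no_grad():
        for _ in range(warmup):
            model(x)
        start = time.perf_counter()
        for _ in range(iters):
            model(x)
    return (time.perf_counter() - start) / iters * 1000.0


# Hypothetical Linear-heavy stand-in (two feed-forward blocks, BERT-base widths).
model = nn.Sequential(
    *[nn.Sequential(nn.Linear(768, 3072), nn.ReLU(), nn.Linear(3072, 768)) for _ in range(2)]
).eval()
qmodel = quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)

torch.set_num_threads(1)     # single-threaded for a more stable comparison
x = torch.randn(1, 64, 768)  # batch size 1, sequence length 64
fp32_ms = mean_latency_ms(model, x)
int8_ms = mean_latency_ms(qmodel, x)
print(f"fp32 {fp32_ms:.1f} ms, int8 {int8_ms:.1f} ms, speedup {fp32_ms / int8_ms:.2f}x")
```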
QUANTIZING MOBILENETS
MobileNetV2
Technique                                | fp32 accuracy | int8 accuracy | Change | Metric
---------------------------------------- | ------------- | ------------- | ------ | --------
Post Training: per-tensor quantization   | 71.9          | 65.6          | -6.3   | ImageNet
Post Training: per-channel quantization  | 71.9          | 67.1          | -4.8   | ImageNet
Quantization-Aware Training              | 71.9          | 71.6          | -0.3   | ImageNet
Extend workflow with custom observers and fake-quant modules
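Per-channel quantization helps when output channels have very different weight ranges, as often happens in depthwise-heavy networks like MobileNet. The minimal sketch below uses a hypothetical weight whose three rows differ in magnitude by factors of 100. It quantizes that weight once with a single per-tensor scale and once with torch.quantize_per_channel. It then compares the reconstruction error per channel. The symmetric scale choice (max |w| / 127) is our assumption for illustration.

```python
import torch

torch.manual_seed(0)
# Hypothetical conv-like weight: 3 output channels with very different magnitudes.
w = torch.randn(3, 27) * torch.tensor([[0.01], [1.0], [100.0]])

# Per-tensor: one symmetric scale for the whole weight.
pt_scale = w.abs().max().item() / 127
w_pt = torch.quantize_per_tensor(w, pt_scale, 0, torch.qint8).dequantize()

# Per-channel: one symmetric scale per output channel (axis 0).
pc_scales = (w.abs().amax(dim=1) / 127).double()
pc_zero_points = torch.zeros(3, dtype=torch.long)
w_pc = torch.quantize_per_channel(w, pc_scales, pc_zero_points, 0, torch.qint8).dequantize()

err_pt = (w - w_pt).abs().mean(dim=1)
err_pc = (w - w_pc).abs().mean(dim=1)
print("per-tensor error per channel: ", err_pt)
print("per-channel error per channel:", err_pc)
```

With one shared scale, the small channels round almost entirely to zero. Per-channel scales give each channel its own resolution, which is consistent with the accuracy gap between the two post training rows above.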
SOON: JIT TO SIMPLIFY PREPARATION
Structural tweaks for TorchScript models automatically: fusion, batch norm folding, etc.
Coming in 1.5, check nightlies
    model = torch.jit.script(model)

    # tweak model for best results
    # change code or use manipulation APIs
    model = quantization.fuse_modules(model, [["conv1", "bn1", "relu1"]])

    qmodel = quantization.prepare_script(model, {"": quantization.default_qconfig})
    ...
    qmodel = quantization.convert_script(qmodel)
    qmodel.save("quantized.pt")
FRAMEWORK SUPPORT
Basic support: enough for CNNs and RNNs*
Backends
• x86 CPU in 1.3 (via FBGEMM)
• ARM CPU early alpha (QNNPACK)
Operators: Conv2d, Linear, RNN, LSTM, relu, max_pool2d, avg_pool2d, upsample_nearest2d, interpolate, resize, clone, sort, topk, max, slice, view
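Many of the listed operators run directly on quantized tensors, so quantized activations can flow between layers without dequantizing. The minimal sketch below applies a handful of them to a quantized NCHW tensor and checks that each result is still quantized. The scale and zero_point are arbitrary example values.

```python
import torch
import torch.nn.functional as F

# A quantized activation tensor (example qparams covering roughly [-1, 1)).
x = torch.rand(1, 2, 4, 4) * 2 - 1
qx = torch.quantize_per_tensor(x, scale=1 / 128, zero_point=128, dtype=torch.quint8)

# Each op runs on the quantized tensor and returns a quantized tensor.
outs = {
    "relu": torch.relu(qx),
    "max_pool2d": F.max_pool2d(qx, kernel_size=2),
    "avg_pool2d": F.avg_pool2d(qx, kernel_size=2),
    "interpolate": F.interpolate(qx, size=(8, 8), mode="nearest"),
    "view": qx.view(1, -1),
    "clone": qx.clone(),
}
for name, t in outs.items():
    print(f"{name:12s} {tuple(t.shape)} quantized={t.is_quantized}")
```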
TRY IT NOW
EXPERIMENTAL IN 1.3: QUANTIZATION CORE AND WORKFLOWS
• PyTorch quantization documentation
• PyTorch quantization tutorials
MODELS AND EXAMPLE SCRIPTS: QUANTIZED MODELS AND TUTORIALS TO OBTAIN THEM
• Quantized models in torchvision
• Reference scripts for quantization aware training and post training quantization
• Tutorials for:
  1. Dynamic quantization for LSTM models
  2. Static quantization and quantization aware training for ResNet
  3. Dynamic quantization for BERT
  4. Transfer learning with quantized models
COMING IN 1.5: MORE BACKENDS AND JIT WORKFLOW
• Simpler workflow for TorchScript
• Expanding operator coverage
Thank you