The NonIID Data Quagmire of Decentralized Machine Learning

ICML 2020. Kevin Hsieh, Amar Phanishayee, Onur Mutlu, Phillip Gibbons

ML Training with Decentralized Data: Geo-Distributed Learning; Federated Learning; Data Sovereignty and Privacy

Major Challenges in Decentralized ML (Geo-Distributed Learning, Federated Learning). Challenge 1: Communication Bottlenecks. Solutions: Federated Averaging, Gaia, Deep Gradient Compression
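The communication-bottleneck solutions above all trade synchronization frequency for bandwidth. As a concrete illustration, here is a minimal sketch of one Federated Averaging round on a toy linear-regression objective; the function name and the toy objective are illustrative assumptions, not the paper's implementation:

```python
import numpy as np

def federated_averaging(global_w, client_data, local_steps=5, lr=0.1):
    """One round of Federated Averaging (McMahan et al., AISTATS'17):
    each client runs local SGD from the global model, then the server
    averages the client models weighted by local dataset size."""
    client_models, sizes = [], []
    for X, y in client_data:
        w = global_w.copy()
        for _ in range(local_steps):
            # Gradient of mean squared error for a linear model (toy objective).
            grad = 2.0 * X.T @ (X @ w - y) / len(y)
            w -= lr * grad
        client_models.append(w)
        sizes.append(len(y))
    # Weighted average: communication happens only once per round.
    return np.average(client_models, axis=0, weights=np.asarray(sizes, float))
```

Because clients communicate only once per round rather than every step, the communication cost drops by roughly `local_steps`X, which is exactly the lever that skewed data makes dangerous.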

Major Challenges in Decentralized ML (Geo-Distributed Learning, Federated Learning). Challenge 2: Data are often highly skewed (non-IID data). Solutions: understudied! Is it a real problem?

Our Work in a Nutshell: a real-world dataset, an experimental study, and a proposed solution

Real-World Dataset: geo-tagged mammal images from Flickr. 736K pictures in 42 mammal classes, with highly skewed labels among geographic regions.

Experimental Study: skewed data labels are a fundamental and pervasive problem; the problem is even worse for DNNs with batch normalization; and the degree of skew determines the difficulty of the problem.

Proposed Solution: replace batch normalization with group normalization. SkewScout: communication-efficient decentralized learning over arbitrarily skewed data.

Real-World Dataset

Flickr-Mammal Dataset: 42 mammal classes drawn from Open Images and ImageNet, 40,000 images per class (https://doi.org/10.5281/zenodo.3676081). Images are cleaned with PNAS [Liu et al., '18] and reverse-geocoded to country, subcontinent, and continent, yielding 736K pictures with labels and geographic information.

Top-3 Mammals in Each Continent: each continent's top-3 mammal accounts for a 44-92% share of that mammal's global images.

Label Distribution Across Continents (stacked-bar chart over Africa, Americas, Asia, Europe, and Oceania; the legend lists the mammal classes: alpaca, antelope, armadillo, brown bear, bull, camel, cattle, cheetah, deer, dolphin, elephant, fox, goat, hamster, harbor seal, hedgehog, hippopotamus, jaguar, kangaroo, koala, leopard, lion, lynx, monkey, mule, otter, panda, pig, polar bear, porcupine, rabbit, red panda, sea lion, sheep, skunk, squirrel, teddy bear, tiger, whale, zebra). The vast majority of mammals are dominated by 2-3 continents, and the labels are even more skewed among subcontinents.

Experimental Study

Scope of Experimental Study: ML applications (image classification with various DNNs and datasets, and face recognition) × decentralized learning algorithms (Gaia [NSDI'17], Federated Averaging [AISTATS'17], Deep Gradient Compression [ICLR'18]) × skewness of data label partitions (2-5 partitions; more partitions are worse).

Results: GoogLeNet over CIFAR-10, top-1 validation accuracy on shuffled vs. skewed data. Relative to BSP (Bulk Synchronous Parallel), Gaia (20X faster than BSP) loses 12% accuracy on skewed data, Federated Averaging (20X faster than BSP) loses 15%, and Deep Gradient Compression (30X faster than BSP) loses 69%. All decentralized learning algorithms lose significant accuracy; tight synchronization (BSP) is accurate but too slow.

Results Across the Board: skewed data is a pervasive and fundamental problem, and even BSP loses accuracy for DNNs with batch normalization layers. (Charts: top-1 validation accuracy, shuffled vs. skewed data, for BSP, Gaia, Federated Averaging, and Deep Gradient Compression, on image classification over CIFAR-10 with AlexNet, LeNet, ResNet20, and GoogLeNet; image classification over ImageNet with ResNet10; image classification over Flickr-Mammal; and face recognition trained on CASIA and tested with LFW.)

Degree of Skew Is a Key Factor: top-1 validation accuracy on CIFAR-10 with GN-LeNet under 20%, 40%, 60%, and 80% skewed data. Accuracy losses for Gaia, Federated Averaging, and Deep Gradient Compression grow with the degree of skew, from as little as 0.5% at 20% skew to as much as 8.5% at 80% skew. The degree of skew can determine the difficulty of the problem.
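A "x% skewed" setting like the ones above can be reproduced with a simple partitioner. Here is a minimal sketch, assuming a construction in which a fraction `skew` of each label's samples is concentrated on that label's "home" partition and the rest is spread evenly; the function name and the exact construction are illustrative, not the paper's code:

```python
import numpy as np

def make_skewed_partitions(labels, n_partitions, skew, seed=0):
    """Assign sample indices to partitions with a controllable degree of
    label skew: skew=0.0 gives roughly IID partitions, skew=1.0 gives
    fully disjoint label sets (an assumed construction for illustration)."""
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    parts = [[] for _ in range(n_partitions)]
    for lbl in np.unique(labels):
        idx = rng.permutation(np.flatnonzero(labels == lbl))
        home = int(lbl) % n_partitions       # this label's home partition
        n_home = int(round(skew * len(idx)))
        parts[home].extend(idx[:n_home])
        # Remaining samples are dealt round-robin across all partitions.
        for i, j in enumerate(idx[n_home:]):
            parts[(home + 1 + i) % n_partitions].append(j)
    return [np.array(sorted(p)) for p in parts]
```

With `skew=0.8`, for example, 80% of each label sits on one partition, matching the spirit of the "80% skewed" bars in the chart.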

Batch Normalization: Problem and Solution

Background: Batch Normalization [Ioffe & Szegedy, 2015]. A BN layer sits between a layer's weighted output and the next layer, normalizing each channel to a standard normal distribution (μ = 0, σ = 1) within each minibatch at training time, and normalizing with estimated global μ and σ at test time. Batch normalization enables larger learning rates and avoids sharp local minima (so models generalize better).
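To make the training-time/test-time asymmetry concrete, here is a minimal NumPy sketch of batch normalization for a (batch, channels) activation matrix. The running-average update with a `momentum` knob follows the common convention and is an assumption, not something stated on the slide:

```python
import numpy as np

def batch_norm(x, running_mean, running_var, gamma, beta,
               training, momentum=0.1, eps=1e-5):
    """Batch normalization (Ioffe & Szegedy, 2015), sketched for a
    2-D activation matrix of shape (batch, channels)."""
    if training:
        mu = x.mean(axis=0)   # per-channel minibatch mean
        var = x.var(axis=0)   # per-channel minibatch variance
        # Running estimates of the *global* statistics, used at test time.
        running_mean[:] = (1 - momentum) * running_mean + momentum * mu
        running_var[:] = (1 - momentum) * running_var + momentum * var
    else:
        mu, var = running_mean, running_var
    x_hat = (x - mu) / np.sqrt(var + eps)
    return gamma * x_hat + beta
```

The key point for what follows: at training time the statistics come from the local minibatch, while at test time every node uses the same global estimates.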

Batch Normalization with Skewed Data. Minibatch mean divergence, defined as ||Mean1 - Mean2|| / AVG(Mean1, Mean2) and measured per channel on CIFAR-10 with BN-LeNet (2 partitions): across channels 0-30, divergence on skewed data is far larger than on shuffled data. Minibatch μ and σ vary significantly among partitions, so global μ and σ do not work for all partitions.
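The slide's metric can be computed directly from two partitions' activations. A sketch follows; the element-wise absolute values in the denominator are an assumption to keep the ratio well defined, since the slide gives only the formula ||Mean1 - Mean2|| / AVG(Mean1, Mean2):

```python
import numpy as np

def minibatch_mean_divergence(x1, x2, eps=1e-12):
    """Per-channel divergence between the minibatch means of two
    partitions' activations, each of shape (batch, channels)."""
    m1, m2 = x1.mean(axis=0), x2.mean(axis=0)
    # Relative gap between means; eps guards against all-zero channels.
    return np.abs(m1 - m2) / ((np.abs(m1) + np.abs(m2)) / 2 + eps)
```

On shuffled data the two partitions' means nearly coincide and the divergence is close to zero; on label-skewed partitions it is large, which is what breaks BN's single set of global statistics.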

Solution: Use Group Normalization [Wu and He, ECCV'18]. Whereas batch normalization normalizes each channel C over the batch dimension N, group normalization normalizes groups of channels over C, H, and W within each individual sample, independent of the minibatch. GN was designed for small minibatches; we apply it as a solution for skewed data.
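A minimal NumPy sketch of group normalization for NCHW activations clarifies why it sidesteps the problem: every statistic is computed within a single sample, so the statistics are identical no matter how the data is partitioned. The per-channel affine parameters follow Wu & He's formulation; the function itself is illustrative:

```python
import numpy as np

def group_norm(x, num_groups, gamma, beta, eps=1e-5):
    """Group normalization (Wu & He, ECCV'18) for activations of shape
    (N, C, H, W): normalize over channel groups and spatial dims
    within each sample, with no dependence on the batch."""
    n, c, h, w = x.shape
    g = x.reshape(n, num_groups, c // num_groups, h, w)
    mu = g.mean(axis=(2, 3, 4), keepdims=True)   # per-sample, per-group
    var = g.var(axis=(2, 3, 4), keepdims=True)
    x_hat = ((g - mu) / np.sqrt(var + eps)).reshape(n, c, h, w)
    # Learnable per-channel scale and shift, as in batch norm.
    return gamma.reshape(1, c, 1, 1) * x_hat + beta.reshape(1, c, 1, 1)
```

Note there are no running statistics at all: training and test time behave identically, which is exactly the property skewed partitions need.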

Results with Group Normalization: validation accuracy on shuffled vs. skewed data for BSP, Gaia, Federated Averaging, and Deep Gradient Compression, comparing BatchNorm against GroupNorm. GroupNorm recovers the accuracy loss for BSP and reduces the accuracy losses for the decentralized algorithms.

SkewScout: Decentralized Learning over Arbitrarily Skewed Data

Overview of SkewScout. Recall that the degree of data skew determines the difficulty of the problem. SkewScout adapts communication to the skew-induced accuracy loss through three mechanisms: model traveling, accuracy-loss estimation, and communication control. It minimizes communication when the accuracy loss is acceptable, and it works with different decentralized learning algorithms.
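The communication-control step might look like the following sketch. Everything here (the function name, threshold values, and the integer "communication level" knob) is an assumed simplification, not the authors' implementation: the loss estimate comes from comparing a node's accuracy on its own data against a visiting model's accuracy, and the knob stands in for an algorithm-specific parameter such as a significance threshold in Gaia:

```python
def adapt_communication(acc_local, acc_visiting, comm_level,
                        threshold=0.02, step=1, max_level=10):
    """Sketch of SkewScout-style communication control: estimate the
    skew-induced accuracy loss from a visiting ("traveling") model,
    then tighten or relax communication accordingly."""
    loss_estimate = acc_local - acc_visiting
    if loss_estimate > threshold:
        # Models have diverged too far: communicate more aggressively.
        comm_level = min(max_level, comm_level + step)
    elif loss_estimate < threshold / 2:
        # Loss is comfortably acceptable: save communication.
        comm_level = max(0, comm_level - step)
    return comm_level, loss_estimate
```

Run periodically, this feedback loop spends communication only when the estimated accuracy loss demands it, which is how the approach stays close to an oracle tuned per skew level.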

Evaluation of SkewScout: communication saving over BSP, with all data points achieving the same validation accuracy. On CIFAR-10 with AlexNet, SkewScout saves 34.1X (Oracle: 51.8X) at 20% skew, 19.9X (24.9X) at 60% skew, and 9.6X (10.6X) at 100% skew. On CIFAR-10 with GoogLeNet, SkewScout saves 29.6X (Oracle: 42.1X), 19.1X (23.6X), and 9.9X (11.0X) at the same skew levels. That is a significant saving over BSP, and only within 1.5X of the Oracle.

Key Takeaways
• Flickr-Mammal dataset: highly skewed label distributions in the real world
• Skewed data is a pervasive problem, and batch normalization is particularly problematic
• Group normalization is a good alternative to batch normalization
• SkewScout adapts decentralized learning to arbitrarily skewed data