Distributed Deep Learning Using Java on the Client
Distributed Deep Learning Using Java on the Client and in the Cloud Johan Vos, Gluon, @johanvos https://gluonhq.com
What we will discuss • Brief introduction to deep learning and how to do it in Java using deeplearning4j. • Using Java on the Oracle Cloud allows for high-performance deep learning. • Client-side deep learning (on mobile) helps with privacy and security, and improves the model in the Cloud. • Leverage Java deep learning algorithms in cross-platform, performant and beautiful Java mobile apps.
Deep Learning • A class of machine learning algorithms • Can be supervised or unsupervised • Raw input data (e.g. images, text, numbers) traverses a number of layers, where the last layer typically performs a classification • Resources: https://deeplearning4j.org/deeplearningforbeginners.html
Deep Learning
Deeplearning4j • When science shops become mature, they turn into Java shops • https://deeplearning4j.org • Java framework providing deep learning functionality • Implementations use native mathematical libraries and the GPU (if available) • Created by Skymind • Open source on GitHub
Learning about digits • MNIST dataset • 60,000 labeled training images • 10,000 labeled test images • Supervised learning
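DL4J ships an MnistDataSetIterator that downloads and decodes these images for you. For reference, the raw MNIST files use the IDX format: a big-endian header (a magic number, the item count and, for image files, the number of rows and columns) followed by unsigned bytes. A minimal plain-Java sketch of reading that header; the IdxHeader class is hypothetical and not part of the sample:

```java
import java.nio.ByteBuffer;

/** Hypothetical reader for the header of an MNIST IDX file (images or labels). */
final class IdxHeader {
    // Magic numbers of the MNIST IDX files: 0x00000803 (images), 0x00000801 (labels)
    static final int IMAGES_MAGIC = 2051;
    static final int LABELS_MAGIC = 2049;

    final int magic;
    final int count;
    final int rows; // 0 for label files
    final int cols; // 0 for label files

    private IdxHeader(int magic, int count, int rows, int cols) {
        this.magic = magic;
        this.count = count;
        this.rows = rows;
        this.cols = cols;
    }

    /** Parses the header at the start of an IDX file; all fields are big-endian ints. */
    static IdxHeader parse(byte[] data) {
        ByteBuffer buf = ByteBuffer.wrap(data); // ByteBuffer is big-endian by default
        int magic = buf.getInt();
        int count = buf.getInt();
        if (magic == IMAGES_MAGIC) {
            return new IdxHeader(magic, count, buf.getInt(), buf.getInt());
        } else if (magic == LABELS_MAGIC) {
            return new IdxHeader(magic, count, 0, 0);
        }
        throw new IllegalArgumentException("Not an MNIST IDX file, magic=" + magic);
    }
}
```

Image files start with magic number 2051 and carry a 16-byte header; label files start with 2049 and carry an 8-byte header. The pixel data (or one label byte per item) follows immediately.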
About the code • All code shown in this presentation is at https://github.com/gluonhq/gluon-samples/tree/master/mnist • Standalone Java application to train the model and make predictions • Server using Payara on Oracle Cloud • Mobile app sending requests to the cloud and making local predictions
Train and evaluate network • Retrieve MNIST data (70,000 labeled images) • Create a 2-layer neural network • Train the network with MNIST data • Evaluate quality (what percentage of test images is misclassified?) • Make predictions with your own data
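Creating a Model for MNIST
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.LossFunctions;

MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
        .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
        .list() // switch to the list builder so layers can be added by index
        // hidden layer: one input per pixel, 100 ReLU units
        .layer(0, new DenseLayer.Builder()
                .nIn(height * width).nOut(100)
                .activation(Activation.RELU)
                .weightInit(WeightInit.XAVIER)
                .build())
        // output layer: one softmax unit per class (outputNum = 10 for digits)
        .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                .nIn(100).nOut(outputNum)
                .activation(Activation.SOFTMAX)
                .weightInit(WeightInit.XAVIER)
                .build())
        .backprop(true)
        .setInputType(InputType.convolutional(height, width, channels))
        .build();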
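DL4J performs this computation internally. To illustrate what the configured network computes at prediction time, here is a plain-Java sketch (no DL4J; the TwoLayerForward class is hypothetical and the weights are passed in as ordinary arrays) of a dense ReLU layer followed by a softmax output layer:

```java
/** Hypothetical plain-Java forward pass: dense ReLU layer, then a softmax output layer. */
final class TwoLayerForward {

    /** Computes act(W x + b), with W stored as [outputs][inputs]. */
    static double[] dense(double[][] w, double[] b, double[] x, boolean relu) {
        double[] out = new double[w.length];
        for (int i = 0; i < w.length; i++) {
            double sum = b[i];
            for (int j = 0; j < x.length; j++) {
                sum += w[i][j] * x[j];
            }
            out[i] = relu ? Math.max(0.0, sum) : sum;
        }
        return out;
    }

    /** Softmax; subtracting the largest logit first keeps Math.exp from overflowing. */
    static double[] softmax(double[] z) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : z) {
            max = Math.max(max, v);
        }
        double total = 0.0;
        double[] p = new double[z.length];
        for (int i = 0; i < z.length; i++) {
            p[i] = Math.exp(z[i] - max);
            total += p[i];
        }
        for (int i = 0; i < p.length; i++) {
            p[i] /= total;
        }
        return p;
    }

    /** Hidden layer (ReLU) followed by output layer (softmax): returns class probabilities. */
    static double[] forward(double[][] w1, double[] b1, double[][] w2, double[] b2, double[] x) {
        double[] hidden = dense(w1, b1, x, true);
        return softmax(dense(w2, b2, hidden, false));
    }
}
```

For MNIST, x would hold the 784 scaled pixel values, w1 would be 100x784 and w2 10x100; the predicted digit is the index of the largest probability.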
Train Model
=========================Confusion Matrix=========================
(rows: actual digit; columns: predicted digit; 10,000 MNIST test images)
    0    1    2    3    4    5    6    7    8    9
---------------------------------------------------
  961    0    2    2    0    6    5    1    3    0 | 0 = 0
    0 1113    2    2    0    1    4    2   11    0 | 1 = 1
    8    8  908   18    9    3   14   14   45    5 | 2 = 2
    2    1   15  929    0   29    2   12   14    6 | 3 = 3
    1    4    5    1  917    0    9    2    6   37 | 4 = 4
    9    4    3   26    9  794   10    5   24    8 | 5 = 5
   11    3    4    1   15   19  899    2    4    0 | 6 = 6
    2   11   23    7    5    1    0  948    3   28 | 7 = 7
    7    6    6   22    8   16   10   10  881    8 | 8 = 8
   10    7    1   12   40   11    0   21    4  903 | 9 = 9
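The "how much of the test set fails" question can be answered directly from the confusion matrix: correct predictions lie on the diagonal. A small plain-Java sketch (the ConfusionStats class is hypothetical; DL4J's Evaluation class reports these figures itself) that computes overall accuracy and per-digit recall, with the matrix above copied in:

```java
/** Hypothetical helper: accuracy and recall from a confusion matrix (rows = actual, columns = predicted). */
final class ConfusionStats {

    static final int[][] MNIST_TEST = {
        {961, 0, 2, 2, 0, 6, 5, 1, 3, 0},
        {0, 1113, 2, 2, 0, 1, 4, 2, 11, 0},
        {8, 8, 908, 18, 9, 3, 14, 14, 45, 5},
        {2, 1, 15, 929, 0, 29, 2, 12, 14, 6},
        {1, 4, 5, 1, 917, 0, 9, 2, 6, 37},
        {9, 4, 3, 26, 9, 794, 10, 5, 24, 8},
        {11, 3, 4, 1, 15, 19, 899, 2, 4, 0},
        {2, 11, 23, 7, 5, 1, 0, 948, 3, 28},
        {7, 6, 6, 22, 8, 16, 10, 10, 881, 8},
        {10, 7, 1, 12, 40, 11, 0, 21, 4, 903}
    };

    /** Fraction of all examples on the diagonal (predicted == actual). */
    static double accuracy(int[][] m) {
        long correct = 0;
        long total = 0;
        for (int i = 0; i < m.length; i++) {
            for (int j = 0; j < m[i].length; j++) {
                total += m[i][j];
                if (i == j) {
                    correct += m[i][j];
                }
            }
        }
        return (double) correct / total;
    }

    /** Fraction of images of the given digit that were classified as that digit. */
    static double recall(int[][] m, int digit) {
        long rowTotal = 0;
        for (int v : m[digit]) {
            rowTotal += v;
        }
        return (double) m[digit][digit] / rowTotal;
    }
}
```

For the matrix above, 9,253 of the 10,000 test images lie on the diagonal, so accuracy is 92.53% (an error rate of 7.47%); digit 2 has the lowest recall (908 of 1,032).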
Cloud setup • Oracle Cloud • Use an existing account or create a trial account • Compute Classic instance with Java
Server application setup • Simple Java EE application using Payara Micro (https://www.payara.fish/payara_micro) • Accessible via a JAX-RS REST handler
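Server application
import javax.inject.Inject;
import javax.ws.rs.HeaderParam;
import javax.ws.rs.POST;
import javax.ws.rs.Path;
import javax.ws.rs.Produces;
import javax.ws.rs.core.MediaType;
import org.nd4j.linalg.api.ndarray.INDArray;

@Path("handler")
public class Handler {

    @Inject
    private ModelService service; // the sample's own service wrapping the trained network

    @POST
    @Produces(MediaType.APPLICATION_JSON)
    @Path("classifyImage")
    public String classifyImage(@HeaderParam("authorization") String authHeader, byte[] rawData) {
        // rawData is the raw image sent in the request body; predict returns the network output
        INDArray answer = service.predict(rawData);
        return answer.toString();
    }
}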
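The handler declares MediaType.APPLICATION_JSON but returns INDArray.toString(), which is not guaranteed to be valid JSON. One option is to build the response body explicitly. A sketch, assuming the network output has already been copied into a double[] (for example with INDArray.toDoubleVector()); PredictionJson and the JSON field names are hypothetical:

```java
import java.util.Locale;

/** Hypothetical helper turning class probabilities into a small JSON response body. */
final class PredictionJson {

    /** Index of the largest probability, i.e. the predicted digit (first index wins on ties). */
    static int argMax(double[] p) {
        int best = 0;
        for (int i = 1; i < p.length; i++) {
            if (p[i] > p[best]) {
                best = i;
            }
        }
        return best;
    }

    /** Builds {"digit":d,"probabilities":[...]} without needing a JSON library. */
    static String toJson(double[] p) {
        StringBuilder sb = new StringBuilder();
        sb.append("{\"digit\":").append(argMax(p)).append(",\"probabilities\":[");
        for (int i = 0; i < p.length; i++) {
            if (i > 0) {
                sb.append(',');
            }
            // Locale.ROOT guarantees a '.' decimal separator regardless of server locale
            sb.append(String.format(Locale.ROOT, "%.4f", p[i]));
        }
        return sb.append("]}").toString();
    }
}
```

With this helper, classifyImage could return PredictionJson.toJson(answer.toDoubleVector()), giving the mobile client both the predicted digit and the probabilities it was based on.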
Mobile client • Java app, using JavaFX for the UI • Use the camera to take pictures, or select pictures from the gallery • Send the image to the Cloud • Cloud returns a prediction • Show the prediction in the app • If incorrect, ask the user to provide the correct result • Send the corrected result to the server and retrain
Java on iOS/Android: Gluon Mobile • Java 8 (released) or Java 9 (beta) on iOS and Android, including JavaFX • Tooling to build and deploy your application on desktop, mobile (iOS/Android) and embedded • Glisten, the UI library of controls with cross-platform behaviour but platform-specific look and feel • Connect, API for communicating with CloudLink and other services • Down, device API in cross-platform fashion • https://gluonhq.com/products/mobile/
From Java client to Java server • Sending data to the server using authentication • Don’t open the REST endpoint to the public world
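The handler receives an authorization header, but the slides do not show how it is validated. A minimal sketch of one way to check it, assuming a bearer token shared between the caller and the server; the "Bearer" scheme, the AuthCheck class and the token are assumptions, not part of the sample:

```java
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;

/** Hypothetical check of an "Authorization: Bearer <token>" header against a shared secret. */
final class AuthCheck {

    private static final String PREFIX = "Bearer ";

    static boolean isAuthorized(String authHeader, String expectedToken) {
        if (authHeader == null || !authHeader.startsWith(PREFIX)) {
            return false; // missing header or unexpected scheme
        }
        byte[] presented = authHeader.substring(PREFIX.length()).getBytes(StandardCharsets.UTF_8);
        byte[] expected = expectedToken.getBytes(StandardCharsets.UTF_8);
        // MessageDigest.isEqual does not stop at the first differing byte,
        // which avoids leaking the token through response timing
        return MessageDigest.isEqual(presented, expected);
    }
}
```

In the handler, a failed check would typically end the request with HTTP 401 before the image is passed to the model.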
Security • Direct connections between mobile devices and enterprise back-ends are a security nightmare (for developers and operators)
From Java client to Java server • Sending data to the server using authentication • Don’t open the REST endpoint to the public world • Client calls a “Remote Function” (Java API) which sends a request to Gluon CloudLink, where the request is mapped to a REST endpoint in the server at Oracle Cloud • https://gluonhq.com/products/cloudlink/
Gluon CloudLink on Oracle ACCS (Application Container Cloud Service) [Architecture diagram. Components: Mobile Application; Gluon CloudLink; Gluon Dashboard (https://gluon.io); Configuration; Application Container Cloud; Backend Application; Oracle Bare Metal Cloud; public cloud infrastructure (Storage, DB).]
From Java client to Java server
byte[] rawBody = …;
GluonObservableObject<String> classified = RemoteFunctionBuilder.create("classifyImage")
        .rawBody(rawBody)
        .object()
        .call(String.class);
…
// call() is asynchronous: classified.get() only holds the result once the call has completed
String result = classified.get();
Monitor calls via Gluon CloudLink
Did we have permission to do this? • Sending personal, raw data to cloud systems might be sensitive. • Pictures, locations, audio, video, … are all very relevant input for AI and Deep Learning • If users have to give permission for this, chances are they will refuse.
Distributed Deep Learning • Since we are using Java on the client, we can run the exact same code to use the model and make predictions • The model is still retrieved from the server • Predictions are made on the client, using raw input data from the client
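Deep Learning on client
import org.datavec.image.loader.NativeImageLoader;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;

NativeImageLoader loader = new NativeImageLoader(height, width, channels, true); // (height, width, channels, centerCropShortestSide)
ImagePreProcessingScaler scaler = new ImagePreProcessingScaler(1, 0); // maps pixels [0, 255] onto [1, 0], i.e. normalizes and inverts
INDArray row = loader.asRowVector(image);
scaler.transform(row);
final MultiLayerNetwork nnModel = model.getNnModel();
int answer = nnModel.predict(row)[0]; // predicted digit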
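The scaler arguments (1, 0) are easy to misread. Assuming the scaler maps 8-bit pixel values linearly onto [minRange, maxRange] (our reading of ImagePreProcessingScaler), the following plain-Java sketch reproduces the flatten-and-scale step on an ordinary array; the Preprocess class is hypothetical:

```java
/** Hypothetical plain-Java version of the client-side preprocessing: flatten and rescale pixels. */
final class Preprocess {

    /**
     * Flattens a grayscale image (values 0..255) into a row vector in row-major order,
     * mapping each value v to minRange + (v / 255) * (maxRange - minRange).
     * With minRange = 1 and maxRange = 0 this yields 1 - v/255, inverting the image.
     */
    static double[] scale(int[][] pixels, double minRange, double maxRange) {
        int h = pixels.length;
        int w = pixels[0].length;
        double[] row = new double[h * w];
        for (int y = 0; y < h; y++) {
            for (int x = 0; x < w; x++) {
                row[y * w + x] = minRange + (pixels[y][x] / 255.0) * (maxRange - minRange);
            }
        }
        return row;
    }
}
```

The inversion presumably makes photos of dark digits on light paper resemble MNIST, whose digits are light on a dark background.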
Quality concerns • Is supervised learning not possible anymore? • We can do local training • And synchronize the model with the server
Distributed Deep Learning
Client-side learning
INDArray row = loader.asRowVector(image);
// one-hot label: a 1x10 row with 1.0 at the index of the correct digit
INDArray labels = Nd4j.zeros(1, 10);
labels.putScalar(0, label, 1.0d);
nnModel.fit(row, labels);
Gradient gradient = nnModel.gradient();
INDArray updateVector = gradient.gradient(); // flattened gradient of all parameters
// Nd4j.toByteArray throws IOException, so the enclosing method must handle or declare it
GluonObservableObject<Void> function = RemoteFunctionBuilder.create("publishGradient")
        .cachingEnabled(false)
        .rawBody(Nd4j.toByteArray(updateVector))
        .object()
        .call(new VoidInputConverter());
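To make the client-side gradient more concrete: for a softmax output layer trained with negative log-likelihood, as in the MNIST configuration, the gradient with respect to the output logits is simply p - y (predicted probabilities minus the one-hot label). A plain-Java sketch (no DL4J; OutputGradient is a hypothetical name) of the last-layer gradient that a single labeled image contributes:

```java
/** Hypothetical sketch: output-layer gradient of softmax + negative log-likelihood for one example. */
final class OutputGradient {

    /** One-hot label vector, e.g. digit 3 with 10 classes -> [0,0,0,1,0,0,0,0,0,0]. */
    static double[] oneHot(int label, int numClasses) {
        double[] y = new double[numClasses];
        y[label] = 1.0;
        return y;
    }

    /** dL/dz = p - y, where p is the softmax output and y the one-hot label. */
    static double[] logitGradient(double[] p, double[] y) {
        double[] g = new double[p.length];
        for (int i = 0; i < p.length; i++) {
            g[i] = p[i] - y[i];
        }
        return g;
    }

    /** dL/dW[i][j] = (p_i - y_i) * h_j for the weight from hidden unit j to output unit i. */
    static double[][] weightGradient(double[] dz, double[] hidden) {
        double[][] gw = new double[dz.length][hidden.length];
        for (int i = 0; i < dz.length; i++) {
            for (int j = 0; j < hidden.length; j++) {
                gw[i][j] = dz[i] * hidden[j];
            }
        }
        return gw;
    }
}
```

The gradient is small when the prediction already matches the label and large when the user corrected a wrong prediction, which is exactly the signal worth sending to the server.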
Server-side update
public void publishGradient(byte[] clientGradient) throws IOException {
    // deserialize the gradient the client serialized with Nd4j.toByteArray
    INDArray updateGradient = Nd4j.fromByteArray(clientGradient);
    Gradient gradient = model.gradient();
    // do some smart merging (strategy not shown on the slide)
    System.out.println("Thanks for the update, our model just became better");
}
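The slide leaves the merging step open. One simple option, similar to the gradient averaging used in federated learning, is to buffer the gradients received from several clients and average them element-wise before applying them. A sketch, assuming the gradients have already been deserialized into equally long double[] arrays; GradientMerger is a hypothetical name and this is not the sample's strategy:

```java
import java.util.List;

/** Hypothetical merging strategy: element-wise average of the gradients sent by several clients. */
final class GradientMerger {

    static double[] average(List<double[]> clientGradients) {
        if (clientGradients.isEmpty()) {
            throw new IllegalArgumentException("no gradients to merge");
        }
        int n = clientGradients.get(0).length;
        double[] avg = new double[n];
        for (double[] g : clientGradients) {
            if (g.length != n) {
                throw new IllegalArgumentException("gradient length mismatch");
            }
            for (int i = 0; i < n; i++) {
                avg[i] += g[i];
            }
        }
        for (int i = 0; i < n; i++) {
            avg[i] /= clientGradients.size();
        }
        return avg;
    }
}
```

Averaging limits the influence of any single client, which matters when updates come from unverified user corrections.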
Server-side update
import java.util.Collection;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.api.TrainingListener;

// mln is the MultiLayerNetwork being trained
public void applyGradient(Gradient[] gradient, int batchSize) {
    MultiLayerConfiguration mlnConf = mln.getLayerWiseConfigurations();
    int iterationCount = mlnConf.getIterationCount();
    int epochCount = mlnConf.getEpochCount();
    // let the configured updater transform the raw gradient into the actual update
    mln.getUpdater().update(mln, gradient[0], iterationCount, epochCount, batchSize,
            LayerWorkspaceMgr.noWorkspaces());
    // apply the update to the flattened parameter vector
    mln.params().subi(gradient[0].gradient());
    Collection<TrainingListener> iterationListeners = mln.getListeners();
    if (iterationListeners != null && iterationListeners.size() > 0) {
        for (TrainingListener listener : iterationListeners) {
            listener.iterationDone(mln, iterationCount, epochCount);
        }
    }
    mlnConf.setIterationCount(iterationCount + 1);
}
Source: https://github.com/deeplearning4j/rl4j/blob/master/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/DQN.java
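As we read the rl4j code, applying a gradient is a two-step process: the network's updater turns the raw gradient into an update (for plain SGD that is a multiplication by the learning rate; updaters such as momentum or Adam also keep state), and params().subi(...) subtracts that update from the parameters. A plain-Java sketch of the same two steps, assuming plain SGD without momentum; SgdStep is a hypothetical name:

```java
/** Hypothetical plain-SGD version of the two-step update: scale the gradient, then subtract it. */
final class SgdStep {

    /** Updates params in place: params[i] -= learningRate * gradient[i]. */
    static void apply(double[] params, double[] gradient, double learningRate) {
        if (params.length != gradient.length) {
            throw new IllegalArgumentException("length mismatch");
        }
        for (int i = 0; i < params.length; i++) {
            // "updater" step: for plain SGD, just scale by the learning rate
            double update = learningRate * gradient[i];
            // "subi" step: move the parameter against the gradient
            params[i] -= update;
        }
    }
}
```

Feeding the averaged client gradients into such a step is what lets corrections made on phones gradually improve the shared model.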
Conclusions • Deep learning algorithms provide many opportunities • Java developers can use deep learning APIs • Deep learning applications run very well on Oracle Cloud • Java mobile applications can be used as clients for deep learning cloud applications • Java on mobile allows more processing to be done on the mobile client, respecting user privacy and improving the cloud model.
Thanks For Attending Any Questions? @johanvos Johan.vos@gluonhq.com @gluonhq https://gluonhq.com support@gluonhq.com