Federated Learning with PySyft

The new era of training machine learning models with on-device capability. In this tutorial, I will use PyTorch and PySyft to train a deep learning neural network using a federated approach.

Saransh Mittal
Towards Data Science

--

Introduction to Federated Learning

What is Federated Learning?

Federated Learning is a distributed machine learning approach which enables model training on a large corpus of decentralised data. Federated Learning enables mobile phones to collaboratively learn a shared prediction model while keeping all the training data on device, decoupling the ability to do machine learning from the need to store the data in the cloud. This goes beyond the use of local models that make predictions on mobile devices (like the Mobile Vision API and On-Device Smart Reply) by bringing model training to the device as well.

The result is a machine learning setting where the goal is to train a high-quality centralised model with training data distributed over a large number of clients, each with unreliable and relatively slow network connections.

This new field consists of an ensemble of techniques that allow ML engineers to train models without having direct access to the training data, and that prevent them from extracting any information about the data through the use of cryptography.

This framework relies on three main techniques:

  • Federated Learning
  • Differential Privacy
  • Secured Multi-Party Computation

In this article, I will cover Federated Learning and its application for predicting Boston Housing prices.

How does Google use Federated Learning?

How Google uses Federated Learning to make more accurate keyboard suggestions

How does the application work?

With the rise of libraries like PySyft and TensorFlow Federated, it has become easier for developers, researchers and machine learning enthusiasts to build decentralised machine learning training pipelines. In this project, whose aim is to predict the prices of housing properties listed in the city of Boston, I have used PySyft — a Python library for secure, private deep learning. PySyft decouples private data from model training, using Multi-Party Computation (MPC) within PyTorch.

While training the deep learning prediction model, the data remains securely stored locally with Alice and Bob. For private training, I used a federated approach in which the ML model is trained locally, using the on-device capability of mobile devices owned by the two parties, Alice and Bob. With the increase in computing power of mobile devices, it has become feasible to train ML models on-device with much greater efficiency.

Steps involved in the Federated Learning approach

  1. The mobile devices download the global ML model
  2. Data is generated while the user is using the application linked with the ML model
  3. As the user interacts with the application more, they get better predictions tailored to their usage
  4. Once the model is ready for the scheduled sync with the server, the personalised model that was trained on-device is sent to the server
  5. Models from all the devices are collected, and a federated averaging function is used to generate a much improved version of the model
  6. Once trained, the improved version is sent back to all the devices, so each user gets an experience shaped by the usage of devices around the globe
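The aggregation in step 5 can be sketched as a weighted average of the clients' parameters. This is a minimal illustration in plain Python; the function and variable names, and the two-client setup, are my own:

```python
def federated_average(client_weights, client_sizes):
    """Weighted average of model parameters (federated averaging).

    client_weights: list of parameter vectors, one per device
    client_sizes:   number of local training samples on each device
    """
    total = sum(client_sizes)
    n_params = len(client_weights[0])
    averaged = [0.0] * n_params
    for weights, size in zip(client_weights, client_sizes):
        for i, w in enumerate(weights):
            averaged[i] += w * (size / total)
    return averaged

# Two simulated devices (e.g. Bob and Alice) with different amounts of data
bob_weights = [0.2, 0.4]    # trained locally on 100 records
alice_weights = [0.6, 0.8]  # trained locally on 300 records
global_weights = federated_average([bob_weights, alice_weights], [100, 300])
print(global_weights)
```

Devices holding more data contribute proportionally more to the global model, which is the key idea behind the averaging step.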

Installing PySyft

PySyft is a Python library for secure, private Deep Learning. PySyft decouples private data from model training, using Federated Learning, Differential Privacy, and Multi-Party Computation (MPC) within PyTorch.

In order to install PySyft, it is recommended that you first set up a conda environment:

conda create -n pysyft python=3
conda activate pysyft
conda install jupyter notebook

You then need to install the package:

pip install syft

Step by step guide to develop the Neural Network using Federated Learning approach

1) Importing the libraries

The following Python libraries were used for developing the project.

  1. NumPy — NumPy is a library for the Python programming language, adding support for large, multi-dimensional arrays and matrices, along with a large collection of high-level mathematical functions to operate on these arrays.
  2. PyTorch — PyTorch is an open source machine learning library based on the Torch library, used for applications such as computer vision and natural language processing. It is primarily developed by Facebook’s artificial intelligence research group. It is free and open-source software.
  3. PySyft — PySyft is a Python library for secure, private Deep Learning. PySyft decouples private data from model training, using Federated Learning, Differential Privacy, and Multi-Party Computation (MPC) within PyTorch.
  4. Pickle — The pickle module implements binary protocols for serialising and de-serialising a Python object structure.

Importing libraries in the Jupyter Notebook
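As a sketch, the imports for the libraries listed above might look like this. The PySyft hook shown is the API of the older 0.2.x releases of the library, and it is wrapped in a try/except since PySyft may not be installed in every environment:

```python
import pickle          # serialising the trained model
import numpy as np     # numerical operations on the dataset
import torch
import torch.nn as nn
import torch.optim as optim

# PySyft extends PyTorch tensors with .send()/.get() for remote workers.
try:
    import syft as sy
    hook = sy.TorchHook(torch)  # hooks PyTorch (older PySyft 0.2.x API)
except ImportError:
    sy = None  # PySyft not installed; the federated parts will be unavailable
```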

2) Initiating the training parameters

We trained the neural network for more than 100 epochs to get good results, with the records grouped into batches of 8. The learning rate was set to 0.001, using stochastic gradient descent (SGD) as the optimiser for the network. As of now, PySyft only supports the SGD optimiser for the backpropagation algorithm, which calculates the error and updates the network parameters.

Gradient descent is a first-order iterative optimisation algorithm for finding the minimum of a function. To find a local minimum of a function using gradient descent, one takes steps proportional to the negative of the gradient (or approximate gradient) of the function at the current point.
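The update rule described above can be illustrated with a single numeric toy example of my own, minimising f(w) = w², whose gradient is 2w:

```python
def gradient_descent_step(w, grad, lr):
    # Take a step proportional to the negative of the gradient
    return w - lr * grad

w = 4.0
lr = 0.1
for _ in range(3):
    w = gradient_descent_step(w, 2 * w, lr)  # gradient of w**2 is 2w
print(w)  # moves toward the minimum at w = 0
```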

Initiating the learning parameters in the Jupyter Notebook
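A minimal sketch of the training parameters described above (the container and attribute names are my own; the article's notebook may organise them differently):

```python
class Arguments:
    """Hyperparameters used for federated training."""
    def __init__(self):
        self.epochs = 100    # train for 100+ epochs for good results
        self.batch_size = 8  # records are grouped into batches of 8
        self.lr = 0.001      # learning rate for the SGD optimiser

args = Arguments()
print(args.epochs, args.batch_size, args.lr)  # 100 8 0.001
```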

3) Dataset Preprocessing

The next step involves reading the dataset into the Jupyter notebook and preprocessing it before training the neural network. Preprocessing helps us better understand the dataset and select the features that are most useful for predicting the target from the input.

Data Preprocessing for the Boston Housing dataset

In this project we have used the well-known Boston Housing dataset to train the neural network. We are predicting the prices of various kinds of housing properties based on a number of features; some of them are listed below. You can find the dataset here.

Each record in the database describes a Boston suburb or town. The data was drawn from the Boston Standard Metropolitan Statistical Area (SMSA) in 1970.

The attributes are defined as follows:

  1. CRIM: per capita crime rate by town
  2. ZN: proportion of residential land zoned for lots over 25,000 sq.ft.
  3. INDUS: proportion of non-retail business acres per town
  4. CHAS: Charles River dummy variable (= 1 if tract bounds river; 0 otherwise)
  5. NOX: nitric oxides concentration (parts per 10 million)
  6. RM: average number of rooms per dwelling
  7. AGE: proportion of owner-occupied units built prior to 1940
  8. DIS: weighted distances to five Boston employment centers
  9. RAD: index of accessibility to radial highways
  10. TAX: full-value property-tax rate per $10,000
  11. PTRATIO: pupil-teacher ratio by town
  12. B: 1000(Bk − 0.63)², where Bk is the proportion of the Black population by town
  13. LSTAT: % lower status of the population
  14. MEDV: median value of owner-occupied homes in $1000s

We can see that the input attributes have a mixture of units.
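Because the attributes have a mixture of units, a common preprocessing step is to standardise each feature to zero mean and unit variance. A minimal sketch with NumPy on a toy array (the real notebook operates on the full 13-feature dataset):

```python
import numpy as np

def standardise(X):
    """Scale each column to zero mean and unit variance."""
    mean = X.mean(axis=0)
    std = X.std(axis=0)
    return (X - mean) / std

# Toy data: two features with very different scales,
# e.g. TAX (hundreds) and NOX (fractions of a unit)
X = np.array([[296.0, 0.538],
              [242.0, 0.469],
              [222.0, 0.458]])
X_scaled = standardise(X)
print(X_scaled.mean(axis=0))  # ~0 for each column
print(X_scaled.std(axis=0))   # ~1 for each column
```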

4) Creating Neural Network with PyTorch

We now define the neural network architecture for the model using PyTorch. The deep learning network consists of 2 hidden layers and uses the ReLU activation function throughout. The input layer consists of 13 neurons, one for each input feature of the training dataset.
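A sketch of such an architecture follows. The hidden-layer sizes are my own choice, since the article does not state them; only the 13-feature input, two hidden layers, and ReLU activations come from the description above:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(13, 32)  # 13 input features of the Boston dataset
        self.fc2 = nn.Linear(32, 24)  # second hidden layer
        self.fc3 = nn.Linear(24, 1)   # single output: predicted price (MEDV)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

model = Net()
out = model(torch.randn(4, 13))  # a batch of 4 dummy records
print(out.shape)  # torch.Size([4, 1])
```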

5) Connecting the data with the remote mobile devices

Bob and Alice are the two parties involved in the whole cycle. For simulation purposes, we send batches of the dataset to all the network clients who interact with the application using the global ML model.
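In PySyft's older 0.2.x API this is done by calling `.send(worker)` on a tensor, which leaves a pointer locally while the data lives with the worker. The effect can be simulated in plain Python by partitioning the batches between the two workers (a sketch of my own; the helper name is hypothetical):

```python
def distribute_batches(records, batch_size, workers):
    """Split the dataset into batches and assign them round-robin to workers."""
    batches = [records[i:i + batch_size]
               for i in range(0, len(records), batch_size)]
    assignment = {w: [] for w in workers}
    for i, batch in enumerate(batches):
        worker = workers[i % len(workers)]
        assignment[worker].append(batch)  # in PySyft: batch.send(worker)
    return assignment

records = list(range(32))  # 32 dummy records
assignment = distribute_batches(records, batch_size=8, workers=["bob", "alice"])
print(len(assignment["bob"]), len(assignment["alice"]))  # 2 2
```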

We can see that with the PySyft library and its PyTorch extension, we can perform operations with tensor pointers just as we do with the PyTorch API (subject to some limitations that are still to be addressed).

Now we connect with the workers named Alice and Bob to train the neural network on the data held by each party.

6) Training the Neural Network

Thanks to PySyft, we were able to train a model without having any access to the remote, private data: for each batch, we sent the model to the current remote worker, trained it there, and got it back on the local machine before sending it to the worker holding the next batch.
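The loop described above can be sketched as follows. This is a local simulation with plain PyTorch (my own code); the comments mark where PySyft's older 0.2.x `.send()`/`.get()` calls would move the model between workers:

```python
import torch
import torch.nn as nn

model = nn.Linear(13, 1)  # a stand-in for the full network
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
loss_fn = nn.MSELoss()

# Two simulated workers, each holding one private batch of 8 records
worker_batches = {
    "bob":   (torch.randn(8, 13), torch.randn(8, 1)),
    "alice": (torch.randn(8, 13), torch.randn(8, 1)),
}

for worker, (data, target) in worker_batches.items():
    # model.send(worker)  <- in PySyft, move the model to the remote worker
    optimizer.zero_grad()
    loss = loss_fn(model(data), target)
    loss.backward()
    optimizer.step()
    # model.get()         <- in PySyft, retrieve the updated model
    print(worker, loss.item())
```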

There is, however, one limitation of this method: by getting the model back, we can still gain access to some private information. Let's say Bob had only a single data record on his machine. When we get the model back, we can inspect the updated weights to learn about the data with which Bob retrained the model.

In order to address this issue, there are two solutions: Differential Privacy and Secured Multi-Party Computation (SMPC). Differential Privacy is used to make sure the model does not leak private information. SMPC, a kind of encrypted computation, in turn allows you to send the model privately, so that the remote workers holding the data cannot see the weights you are using.

You can check out the whole project in the Jupyter Notebook available below. I have also left a link to the Federated Learning project in my GitHub repository below.

Resources

If you are interested in learning more about Secure and Private AI and how to use PySyft you can also check out this free course on Udacity. It’s a great course for beginners taught by Andrew Trask, the founder of the OpenMined Initiative.

Check out the project on GitHub

About Me

I am a final-year Computer Science and Engineering student in India and have been coding for the past 5 years. After working on several Machine Learning and iOS development projects, I am learning Federated Learning because the next decade of training deep learning models on mobile devices depends on the federated protocol. You can learn more about me on my website. For more projects, check out my GitHub profile.
