The other day I came across this interesting video made by TensorFlow and Françoise Beaufays (research scientist at Google). This article is just a summary and some notes on that awesome video, so if you are interested and want to know more, check the TensorFlow channel and particularly the video Federated learning: Basics and application to the mobile keyboard (ML Tech Talks).
That being said, let’s start talking about Federated Learning and how it is used to protect your data privacy, specifically on your keyboard.
Introduction to Federated Learning: the what & why
Why Federated Learning?
When we think about how a machine learning model is built, we usually think of a train and a test dataset used to train the model, and when it is time to deploy it, the approach (most of the time) is server-side. One of the advantages of server-side deployment is that we can keep capturing data that will later enhance our model.
What if the model is deployed on-device?
Data privacy is a critical matter when we are talking about users and products, so gathering data or logging it becomes a difficult task for sure. Federated Learning offers a privacy-sensitive solution.
Federated Learning allows us to keep enhancing models on the user’s device, keeping the user’s raw data, and the learning done on it, on their device.
The key part of Federated Learning is that it lets us aggregate the learning across different users, combining it server-side into a new model that is then downloaded again. And the cool part is that we only need phones, data (which stays inside users’ devices) and new protocols / algorithms.
How Federated Learning works
Let’s start with the two basic elements for making it work:
- A server with its own data and an initial trained model.
- The user’s device with its own data stored inside.
In federated models, the device downloads and uses the initial model as a starting point for training.
The idea behind Federated Learning is to train a centralized model based on decentralized data
There is communication between the device and the server, where the device asks the server whether it needs to perform some learning task. If so, federated learning is applied:
- The server prepares the initial model and pushes it to the device, which downloads it.
- Then, the device starts training from that model, based on the data it has in its own little data store.
- Finally, the device uploads this new model to the server.
The cool thing here is that this process goes through multiple devices, not just one! The server is therefore fed with tons of trained models, which it aggregates into a single new model. In this way, it is possible to repeat the process multiple times to keep enhancing the model.
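To make that round a bit more concrete, here is a minimal, purely illustrative sketch in Python/NumPy (my own toy example, not the implementation described in the talk): each simulated “device” starts from the downloaded global model, trains it on its own local data, and the server averages the uploaded models into a new one.

```python
import numpy as np

def local_training(global_weights, X, y, lr=0.1, epochs=5):
    """Client side: start from the downloaded global model and train on local data."""
    w = global_weights.copy()
    for _ in range(epochs):
        grad = 2 * X.T @ (X @ w - y) / len(y)    # gradient of the mean squared error
        w -= lr * grad
    return w

def server_aggregate(client_weights):
    """Server side: combine the uploaded models into a single new model."""
    return np.mean(client_weights, axis=0)

rng = np.random.default_rng(0)
global_model = np.zeros(3)                       # initial model prepared by the server

for round_num in range(10):                      # the process is repeated over several rounds
    updates = []
    for _ in range(5):                           # 5 simulated devices per round
        X = rng.normal(size=(20, 3))             # each device has its own local data
        y = X @ np.array([1.0, -2.0, 0.5]) + rng.normal(scale=0.1, size=20)
        updates.append(local_training(global_model, X, y))
    global_model = server_aggregate(updates)     # new global model, pushed back to devices

print(global_model)
```

In a real deployment the clients are phones, the communication happens over the network, and the aggregation is more sophisticated than a plain average, but the shape of the loop is the same.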
AI and Federated Learning on mobile keyboards
OK, so now we “know” what Federated Learning is, but how does it apply to mobile keyboards? Taking Google’s GBoard as an example, while we are typing on the keyboard we can see word suggestions that speed up our writing. Nowadays we mostly use the keyboard for chatting or making personal searches on the Internet, so privacy is key in both cases.
Keyboards nowadays have functionalities that use machine learning as their engine.
An extra reason why it is great to use machine learning in keyboards is that the model can learn from the corrections that users make while they are writing (think of typing a word where the suggested word is not the one you want to send; the model can learn from that case, for example).
Process overview
As said at the beginning of the article, there are two basic elements that make federated learning work: the device (client) and the server. In this case, as we are talking about keyboards, the device will be a smartphone and the goal will be to enhance the keyboard.
- The first step is that we come up with an initial model on the server side and select the clients where the training rounds will happen.
- Once the client-side training is done, the client returns the updated set of weights, or the updated model, to the server.
- Finally, the server combines all the models into a new model.
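As a rough idea of how this loop looks with a real framework, here is a minimal sketch based on the classic TensorFlow Federated tutorial API (`tff.learning.build_federated_averaging_process`). The tiny model and the synthetic per-client datasets are placeholders of my own, and the exact function names change between TFF releases, so treat this as a sketch rather than as the GBoard code.

```python
import tensorflow as tf
import tensorflow_federated as tff

# Toy "per-client" datasets standing in for the data that never leaves each device.
def client_dataset(client_id):
    x = tf.random.stateless_normal((32, 10), seed=(client_id, 0))
    y = tf.cast(tf.reduce_sum(x, axis=1, keepdims=True) > 0, tf.float32)
    return tf.data.Dataset.from_tensor_slices((x, y)).batch(8)

federated_train_data = [client_dataset(i) for i in range(5)]

def model_fn():
    # A new Keras model must be built inside this function on every call.
    keras_model = tf.keras.Sequential([
        tf.keras.layers.Dense(1, activation='sigmoid', input_shape=(10,))
    ])
    return tff.learning.from_keras_model(
        keras_model,
        input_spec=federated_train_data[0].element_spec,
        loss=tf.keras.losses.BinaryCrossentropy())

# Clients train locally; the server aggregates the updated weights into a new global model.
iterative_process = tff.learning.build_federated_averaging_process(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.1))

state = iterative_process.initialize()          # the initial model lives on the server
for round_num in range(3):                      # a few training rounds
    state, metrics = iterative_process.next(state, federated_train_data)
    print(round_num, metrics)
```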
The initial model
Using an initial model for this process provides multiple advantages:
- It consumes less privacy budget.
- It takes less time to reach the same accuracy. In the example that Françoise Beaufays provides, to reach 15% accuracy:
- Without an initial model: It takes 350 rounds of training.
- Using an initial model: It takes 80 rounds.
Although it is not possible to use private data to train that initial model, there are tons of web documents, books, online examples, etc. that can feed it.
Optimization in the client side
That being said, there are some challenges when training on a small device, for instance the code size (in this case based on TensorFlow code), which can’t be too large. Therefore, only a few TensorFlow libraries are selected (the graph-pruning concept), depending on the task at hand.
Another client-side challenge is the normalizers needed (not just lemmatization or stemming; we are talking about the heavy processing typical of NLP problems). So in this case, server-side normalizers must be rewritten to make them portable and reduce their code size.
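As a purely illustrative example (the real GBoard normalizers are not public), a portable client-side text normalizer could be reduced to something as small and dependency-free as this:

```python
import re
import unicodedata

def normalize(text: str) -> str:
    """A tiny, dependency-free text normalizer that could realistically fit on a device."""
    text = unicodedata.normalize("NFKC", text)   # unify equivalent Unicode representations
    text = text.lower().strip()
    text = re.sub(r"\s+", " ", text)             # collapse repeated whitespace
    return text

print(normalize("  Héllo\u00A0  WORLD!!  "))      # -> "héllo world!!"
```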
Federated learning needs a lot of previous planning before implementation.
Besides all the above things that should be borne in mind, the optimization on the client side might be more complicated in this case, taking into account that:
- The hyperparameters need to be optimized for fast convergence (batch size, epochs, learning rate).
- The data (amount & quality) that each device holds will be really different from that of other devices (different usage, different data).
Therefore, it is important to find algorithms that can cope with all of those different conditions and situations.
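One common way to deal with devices holding very different amounts of data is the weighting used in the original Federated Averaging algorithm: each uploaded model counts in proportion to the number of examples it was trained on. A tiny illustrative sketch with made-up numbers:

```python
import numpy as np

def weighted_aggregate(client_weights, client_num_examples):
    """Server side: average client models, weighting each by its amount of local data."""
    sizes = np.array(client_num_examples, dtype=float)
    coefficients = sizes / sizes.sum()
    stacked = np.stack(client_weights)           # shape: (num_clients, num_params)
    return (coefficients[:, None] * stacked).sum(axis=0)

# A device with 500 local examples influences the new model more than one with 20.
updates = [np.array([0.9, -1.8]), np.array([0.2, -0.4]), np.array([1.1, -2.1])]
print(weighted_aggregate(updates, client_num_examples=[500, 20, 300]))
```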
Requirements for the keyboard training (GBoard case)
There are some requirements aimed at making the whole process invisible to the user. First of all, the server specifies the client requirements. Then, the device must be plugged in / charging, connected to Wi-Fi and idle for a period of time.
Also, the server picks a random subset of devices for the training round, and it may even discard some of them (this process of selecting and discarding devices is called pace steering).
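A simplified sketch of what that client selection could look like (the real check-in protocol and the pace steering logic are considerably more involved; the field names here are made up for illustration):

```python
import random
from dataclasses import dataclass

@dataclass
class Device:
    device_id: str
    charging: bool
    on_wifi: bool
    idle: bool

def eligible(device: Device) -> bool:
    """Client requirements specified by the server: plugged in, on Wi-Fi and idle."""
    return device.charging and device.on_wifi and device.idle

def select_for_round(devices, round_size):
    """Pick a random subset of the eligible devices for this training round."""
    candidates = [d for d in devices if eligible(d)]
    return random.sample(candidates, min(round_size, len(candidates)))

fleet = [Device(f"device-{i}", charging=random.random() > 0.5,
                on_wifi=random.random() > 0.3, idle=random.random() > 0.4)
         for i in range(100)]
participants = select_for_round(fleet, round_size=10)
print([d.device_id for d in participants])
```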
Real world applications
Applying Federated Learning to multiple keyboard features has improved several of them. For instance, the Next Word Prediction feature improved by around +24% in accuracy and +10% in CTR.
Last but not least, Federated Learning enables new functionalities, like learning “new trendy” next words instead of using a static list of words as a baseline.
Federated Learning might also be applied to other fields where utility and privacy seem to be in conflict, like helping banks suggest the best personalized products to their users, or helping hospitals improve diagnostics using data shared between them.
Sources:
- Federated learning: Basics and application to the mobile keyboard (ML Tech Talks)
- Federated Learning with Google
- Introducing TensorFlow Federated
Contact info
- 📱 Linkedin: Juan Antonio Cabeza Sousa
- 📬 Email: juaancabsou@gmail.com
- 🖥️ Twitter: @Aceconhielo