Federated learning 101: warts and all!
Frenemies often want to collaborate, but there’s an underlying mistrust between them. This is where federated learning (FL) can help.
FL enables collaborative training of robust ML & AI models across decentralized datasets without requiring raw, private information to be centralized. FL is what’s known as a privacy-enhancing technology (PET). It’s often used alongside other PETs, such as differential privacy, to help minimise the use of PII.
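To give a feel for how differential privacy is often paired with FL, here’s a minimal Python sketch of one common pattern: each participant clips its model update to a maximum L2 norm and adds Gaussian noise before sending it. The function names and parameter values (`clip_norm`, `noise_multiplier`) are my own illustrative assumptions. Real deployments usually rely on a DP library plus a privacy accountant to turn these settings into an actual (ε, δ) guarantee, and often add the noise to the aggregate on the server instead of on each client.

```python
import math
import random


def clip_update(update, clip_norm):
    """Scale a model update (list of floats) so its L2 norm is at most clip_norm."""
    norm = math.sqrt(sum(x * x for x in update))
    scale = min(1.0, clip_norm / norm) if norm > 0 else 1.0
    return [x * scale for x in update]


def privatise_update(update, clip_norm=1.0, noise_multiplier=1.0, rng=None):
    """Clip the update, then add Gaussian noise with std = noise_multiplier * clip_norm.

    Illustrative only: the values here are NOT calibrated to any privacy budget.
    """
    rng = rng or random.Random()
    clipped = clip_update(update, clip_norm)
    sigma = noise_multiplier * clip_norm
    return [x + rng.gauss(0.0, sigma) for x in clipped]


if __name__ == "__main__":
    raw_update = [3.0, 4.0]  # L2 norm is 5.0, so clipping to 1.0 scales it by 0.2
    print(clip_update(raw_update, clip_norm=1.0))
    print(privatise_update(raw_update, clip_norm=1.0, noise_multiplier=0.5, rng=random.Random(0)))
```

Clipping bounds how much any single participant can move the model; the noise then masks what remains of their individual contribution.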
FL can also address situations beyond the frenemies scenario where data cannot be moved from its source, such as financial services and healthcare solutions where PII and regulations mandate this. FL is also gaining traction as a way to protect the source data used to train models.
At a high level, FL can be categorised by how the data and features are distributed among the participating parties:
- Horizontal Federated Learning:
This involves training a model on multiple datasets that share the same features (columns) but contain different samples, e.g. different hospitals whose patient records have the same fields but cover different patient populations.
- Vertical Federated Learning:
This involves training a model on datasets that share the same sample IDs (rows) but hold different features, e.g. different hospitals holding different information about the same patients. Each participant contributes complementary features about the same individuals.
- Federated Transfer Learning:
This is where there is only a partial overlap in both samples and features among the datasets, e.g. two hospitals might have patient records with some overlapping individuals and some shared medical parameters, but also features unique to each dataset.
- Split learning:
I’m including this one, although purists may disagree, but hey, it’s my post! It involves splitting a model’s architecture into segments that are trained by different parties (clients and a server). Only intermediate representations (activations and their gradients) are exchanged, not raw data.
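To make the horizontal vs vertical distinction concrete, here’s a toy Python sketch that partitions one made-up patient table both ways. The records and helper names are hypothetical. Note the `shared_ids` step: vertical FL only works on the samples every party holds, and in real systems that alignment is done with a privacy-preserving technique such as private set intersection rather than the plain set intersection used here.

```python
# Made-up patient table for illustration: every record has the same columns.
RECORDS = [
    {"patient_id": 1, "age": 54, "bp": 130, "dx": "A"},
    {"patient_id": 2, "age": 61, "bp": 145, "dx": "B"},
    {"patient_id": 3, "age": 38, "bp": 118, "dx": "A"},
    {"patient_id": 4, "age": 47, "bp": 125, "dx": "C"},
]


def horizontal_split(records, n_parties):
    """Same columns, disjoint rows: party k gets every n_parties-th record."""
    return [records[k::n_parties] for k in range(n_parties)]


def vertical_split(records, feature_groups, id_key="patient_id"):
    """Same rows (keyed by id), disjoint columns: each party holds one feature group."""
    return [
        [{id_key: r[id_key], **{f: r[f] for f in group}} for r in records]
        for group in feature_groups
    ]


def shared_ids(*parties, id_key="patient_id"):
    """IDs present at every party (plain intersection; real systems would use PSI)."""
    id_sets = [{r[id_key] for r in party} for party in parties]
    return set.intersection(*id_sets)


if __name__ == "__main__":
    hospital_a, hospital_b = horizontal_split(RECORDS, 2)
    clinic, lab = vertical_split(RECORDS, [["age", "bp"], ["dx"]])
    # Give each vertical party a different (overlapping) subset of patients.
    print(shared_ids(clinic[:3], lab[1:]))  # only patients 2 and 3 are usable for training
```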
FL can then be further categorised by who the participants are:
- Cross-silo: the participating members are organizations or companies. In practice, the number of members is usually small (for example, fewer than one hundred).
- Cross-device: the participating members are end-user devices such as mobile phones, vehicles, or IoT devices. The number of members can reach millions or even tens of millions.
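In cross-device FL it’s impractical to wait for every device in every round, so a common pattern (used, for example, in the FedAvg paper mentioned further down) is to sample only a fraction of the available clients per round. A minimal sketch, where `sample_clients` and its parameters are hypothetical names:

```python
import random


def sample_clients(client_ids, fraction, rng, min_clients=1):
    """Choose the distinct clients that take part in one training round.

    Picks max(min_clients, int(fraction * K)) clients, capped at K = len(client_ids).
    """
    k = max(min_clients, int(fraction * len(client_ids)))
    k = min(k, len(client_ids))
    return rng.sample(client_ids, k)


if __name__ == "__main__":
    rng = random.Random(0)
    devices = list(range(1_000_000))  # hypothetical pool of eligible device IDs
    round_participants = sample_clients(devices, fraction=0.0001, rng=rng)
    print(len(round_participants))  # 100 devices train in this round
```

Cross-silo deployments, with their handful of members, often just include everyone in each round.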
Regardless of how the data and features are distributed (split learning aside, as that is more complex) or whether it’s cross-silo or cross-device, the basic FL workflow is:
- Create a global model
- Send a copy of the model or part of the model to where the data is located
- Use the co-located model to train on the local data
- Send the locally updated weights back, where they are aggregated to update the global model
Rinse and repeat.
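For the "update the global model" step, the classic choice is Federated Averaging (FedAvg) from the 2016 Google paper referenced below: the server averages the returned weights, weighting each participant by how much data it trained on. With $S_t$ the set of participants in round $t$, $n_k$ the number of local training examples at participant $k$, and $w_{t+1}^{k}$ its locally trained weights:

$$
w_{t+1} = \sum_{k \in S_t} \frac{n_k}{m_t} \, w_{t+1}^{k}, \qquad m_t = \sum_{k \in S_t} n_k
$$

Participants with more data pull the global model further, and the weights $n_k / m_t$ always sum to 1.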
Obviously there’s detail missing from that process, and the nuances between cross-silo and cross-device aren’t covered, but that’s the general gist and this is a 101.
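The four-step loop can be sketched end to end in plain Python. This is a toy under stated assumptions, not a production implementation: the "model" is a single weight `w` for y = w * x, each hypothetical client runs a few epochs of SGD on its own data, and the server combines the results with the data-size-weighted average used by FedAvg. The raw `(x, y)` pairs never leave `local_train`; only the trained weight does.

```python
import random


def local_train(w, data, lr=0.05, epochs=5):
    """Client step: a few epochs of SGD on squared error for the model y = w * x."""
    for _ in range(epochs):
        for x, y in data:
            grad = 2 * (w * x - y) * x  # derivative of (w*x - y)^2 with respect to w
            w -= lr * grad
    return w


def fedavg_round(global_w, clients):
    """One round: send global_w to each client, train locally, weighted-average the results."""
    updates = [(local_train(global_w, data), len(data)) for data in clients]
    total = sum(n for _, n in updates)
    return sum(w * n for w, n in updates) / total


def make_client(rng, n, true_w=3.0):
    """Hypothetical client dataset: n noisy samples of y = true_w * x."""
    xs = [rng.uniform(-1, 1) for _ in range(n)]
    return [(x, true_w * x + rng.gauss(0, 0.1)) for x in xs]


if __name__ == "__main__":
    rng = random.Random(42)
    clients = [make_client(rng, n) for n in (20, 50, 100)]  # unequal dataset sizes
    w = 0.0  # 1. create a global model
    for r in range(10):  # rinse and repeat
        w = fedavg_round(w, clients)  # 2-4. send, train locally, aggregate
        print(f"round {r}: w = {w:.3f}")
```

Running it, `w` should move towards the true value of 3.0 within a few rounds. Swapping the scalar for a real model’s parameters and `local_train` for a real training loop keeps the same overall shape.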
Google introduced FL as a concept back in 2016 and has written a number of seminal papers on it, including [1602.05629] Communication-Efficient Learning of Deep Networks from Decentralized Data and Towards Federated Learning at Scale: System Design. Google has also incorporated FL into its products, e.g. Applied Federated Learning: Improving Google Keyboard Query Suggestions, and the concept is delightfully explained here: Federated Learning.
But, I hear you ask, how can I as an enterprise use FL in production in a way my security admins will allow?
Cloud service providers (CSPs) have reference architectures, and as you’d expect I’ll point you at my cloud of choice, Google Cloud, which has an enterprise-ready federated learning reference architecture built on Google Kubernetes Engine (GKE).
What I like about the approach (for transparency, I am one of the authors!) is that it does not lock you into a single framework. You have the flexibility to deploy and run other FL frameworks that are compatible with Kubernetes and containerized environments. The repo for the reference architecture includes an example of doing just this using NVIDIA FLARE.
FL unlocks value in decentralised data; for example, it’s a great way to make data available in a privacy-enhanced way for fine-tuning LLMs (Integration of large language models and federated learning - ScienceDirect, [2402.06954] OpenFedLLM: Training Large Language Models on Decentralized Private Data via Federated Learning). However, the lack of centralised control can also cause problems: one of the participants may be compromised or may be deliberately poisoning its local model. FL requires a strong degree of trust between collaborators, so it’s great to start seeing more ways of tackling the "security" issues with FL. As part of my collection of posts at GAI is going well, I’ve started stumbling across articles focused on addressing this Achilles heel of FL, including the ones below:
- DeTrigger: A Gradient-Centric Approach to Backdoor Attack Mitigation in Federated Learning proposes a framework to help defend against backdoor attacks (tl;dr: a local model is deliberately poisoned and the poisoned updates get passed back to the global model).
- Tazza: Shuffling Neural Network Parameters for Secure and Private Federated Learning proposes a framework to defend against both data leakage via gradient inversion and model poisoning from malicious clients.
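Neither paper’s method is reproduced here. As a generic illustration of why a single poisoned participant matters, and of one classic, much simpler mitigation (coordinate-wise median aggregation, a swapped-in technique rather than DeTrigger or Tazza), this Python sketch compares a plain mean with a median when one of five hypothetical clients sends an outsized update. Robust aggregators like this are known to be insufficient against carefully crafted backdoors that look like normal updates, which is part of why work like the papers above exists.

```python
from statistics import mean, median


def aggregate(updates, how="mean"):
    """Coordinate-wise aggregation of client updates (each a list of floats)."""
    agg = mean if how == "mean" else median
    # zip(*updates) groups the k-th coordinate from every client together.
    return [agg(coords) for coords in zip(*updates)]


if __name__ == "__main__":
    honest = [[0.9, -0.2], [1.1, -0.1], [1.0, -0.3], [0.95, -0.2]]
    poisoned = [[50.0, 40.0]]  # one malicious client sends a huge update
    print("mean:  ", aggregate(honest + poisoned, "mean"))    # dragged far from the honest updates
    print("median:", aggregate(honest + poisoned, "median"))  # stays with the honest majority
```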