Federated Learning: Training AI without sharing data


How to build powerful ML models while keeping data private and distributed
The Data Dilemma
In the world of large language models, the demand for publicly available data is growing every day.
"LLMs are running out of training data"
The problem is that we're running out of high-quality publicly available data. Current LLMs are trained on only about 15 trillion tokens of public data. Therefore, the traditional approach of centralising data for machine learning is hitting its limits. Meanwhile, the most valuable data sits locked away in organisations (healthcare, government, finance, manufacturing) and on billions of devices (phones, laptops, cars, smart home devices).
Traditional centralised training paradigms require aggregating data in a single location, creating significant barriers, including regulatory compliance, data sovereignty concerns, and privacy implications.
AI FACT: Meta Llama 3 8B outperformed its predecessor, Llama 2 70B, primarily because it was trained on significantly more data and four times as much code.
Solution:
Federated Learning - a paradigm shift that brings training to the data, not the other way around.
Federated Learning: Principles and Architecture
Federated Learning represents a distributed computational framework where machine learning models are trained across multiple decentralised data sources without requiring data centralisation. The core principle involves bringing computation to data rather than data to computation.
Architectural Components
Decentralised Training Nodes:
Individual participants (organisations, devices, or systems) maintain complete control over their local datasets while contributing to global model improvement through local training iterations.
Central Orchestration Server:
A coordination layer that manages model distribution, aggregates parameter updates, and maintains the global model state without accessing raw training data.
Federated Aggregation Protocols:
Sophisticated algorithms that combine local model updates into a cohesive global model, with Federated Averaging (FedAvg) serving as the foundational approach.
Use Cases:
Healthcare and Medical Research
Example: Consider a scenario where hospitals across different geographical regions contribute to training a cancer detection model without sharing sensitive patient imaging data or medical records.
Financial Services and Risk Management
Example: Banks can identify sophisticated financial crimes that span multiple institutions while preserving customer privacy and regulatory compliance requirements.
Mobile and Edge Computing
Example: Google's Gboard exemplifies this approach, where millions of mobile devices contribute to improving autocorrect and predictive text functionality through local training on user interaction patterns, transmitting only mathematical parameter updates rather than keystrokes or messages.
Reference: https://arxiv.org/abs/2305.18465
The Intuition and the Implementation
The federated learning process follows a systematic cycle:
Initialisation:
The server initialises the global model.
Communication round:
For each round, the server sends the global model to each client participating in the training.
Training:
Each client trains the model on their private data and sends their locally updated model to the server.
Model Aggregation:
The server aggregates the updated models received from all clients using aggregation algorithms.
Convergence Check:
The FL process stops when the convergence criterion is met; otherwise, the next communication round is initiated.
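The cycle above can be sketched as a minimal single-process simulation. This is an illustrative toy, not a production system: the clients, their synthetic datasets, and the `local_train` helper (plain gradient descent on a linear model) are all made up for demonstration; real deployments use frameworks such as Flower or TensorFlow Federated.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic "private" datasets: each client holds (X, y) with y = 2x + 1 + noise.
clients = []
for _ in range(3):
    X = rng.uniform(-1, 1, size=(50, 1))
    y = 2 * X[:, 0] + 1 + rng.normal(0, 0.1, size=50)
    clients.append((X, y))

def local_train(w, b, X, y, epochs=5, lr=0.1):
    """One client's local training: gradient descent on mean squared error."""
    for _ in range(epochs):
        err = X[:, 0] * w + b - y
        w -= lr * 2 * np.mean(err * X[:, 0])
        b -= lr * 2 * np.mean(err)
    return w, b

# 1. Initialisation: the server creates the global model.
w_glob, b_glob = 0.0, 0.0

for rnd in range(20):                        # 2. Communication rounds
    updates = []
    for X, y in clients:                     # 3. Each client trains locally
        updates.append(local_train(w_glob, b_glob, X, y))
    # 4. Aggregation: equal-weight average (every client holds 50 samples).
    w_glob = float(np.mean([u[0] for u in updates]))
    b_glob = float(np.mean([u[1] for u in updates]))
    # 5. A real system would check convergence here and stop early.

print(w_glob, b_glob)  # should approach the true w = 2, b = 1
```

Note that the server only ever sees `(w, b)` pairs, never the clients' `(X, y)` data, which is the whole point of the protocol.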
FedAvg: The classic aggregation algorithm
$$w_{t+1} = \sum_{k=1}^{K} \frac{n_k}{n} \, w_k^{t+1}$$
Where $w_{t+1}$ represents the updated global model, $n_k$ is the number of samples at client $k$, $n$ is the total number of samples across all clients, and $w_k^{t+1}$ represents the local model update from client $k$.
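FedAvg itself is just a sample-count-weighted average of parameter vectors. A minimal sketch with NumPy (the client parameters and dataset sizes below are made-up values for illustration):

```python
import numpy as np

def fedavg(client_weights, client_sizes):
    """Weighted average of client parameters: sum over k of (n_k / n) * w_k."""
    n = sum(client_sizes)
    stacked = np.stack(client_weights)      # shape: (K, num_params)
    coeffs = np.asarray(client_sizes) / n   # n_k / n for each client
    return coeffs @ stacked                 # weighted sum over clients

# Three clients with different dataset sizes; bigger datasets weigh more.
w_clients = [np.array([1.0, 2.0]), np.array([3.0, 4.0]), np.array([5.0, 6.0])]
n_clients = [100, 300, 600]

w_global = fedavg(w_clients, n_clients)
print(w_global)  # → [4. 5.]
```

The client with 600 samples contributes 60% of the average, which is why its parameters dominate the result.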
Hyperparameters
The following is a useful reference outlining the various categories of hyperparameters to consider when implementing Federated Learning.
| Server Hyperparameters | Client Hyperparameters |
| --- | --- |
| Client selection | Pre-processing |
| Client configuration | Local training |
| Result aggregation | Post-processing |
Data Privacy
You might think that Federated Learning is quite secure due to its decentralised nature, but three different levels of attacks can still occur within an FL infrastructure.
| Attack | Goal |
| --- | --- |
| Membership Inference Attack | Infer whether specific data samples participated in training |
| Attribute Inference Attack | Infer unseen attributes of the training data |
| Reconstruction Attack | Reconstruct specific training data samples |
Reference: https://arxiv.org/abs/2211.14952
NOTE: The primary goal of federated learning is “data minimisation”
Data Minimisation
This refers to the concept of minimising the share of data sent to the server. It is a privacy principle and design strategy that means collecting, processing, and sharing only the minimum amount of data necessary to achieve the learning objective.
While federated learning inherently provides data minimisation benefits, comprehensive privacy protection requires additional Privacy-Enhancing Technologies (PETs).
Differential Privacy (DP)
Differential Privacy serves as a mathematical framework ensuring individual data points cannot be inferred from model outputs.
One of the most commonly used mechanisms to achieve DP is adding enough noise to the output of the analysis to mask the contribution of each individual in the data while preserving the overall accuracy of the analysis.
DP Techniques
Gradient Clipping
Bounds the sensitivity of model updates by limiting the magnitude of gradient changes, reducing the impact of outliers and preventing information leakage through extreme parameter values.
Noise Injection
Introduces calibrated statistical noise to model updates, making outputs statistically indistinguishable regardless of individual data point presence or absence.
Sensitivity Analysis
Quantifies the maximum output change when a single data point is added or removed from the dataset, providing measurable privacy guarantees.
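Gradient clipping and noise injection combine into the core DP-style update applied to each client's contribution. The sketch below uses illustrative values for the clip norm and noise multiplier; they are not tuned guarantees, and computing a formal (ε, δ) privacy budget requires an accountant such as those in Opacus or TensorFlow Privacy.

```python
import numpy as np

def privatize_update(update, clip_norm=1.0, noise_multiplier=1.1, rng=None):
    """Clip an update to bounded L2 sensitivity, then add Gaussian noise."""
    rng = rng or np.random.default_rng()
    # Gradient clipping: scale the update so its L2 norm is at most clip_norm.
    norm = np.linalg.norm(update)
    clipped = update * min(1.0, clip_norm / max(norm, 1e-12))
    # Noise injection: Gaussian noise calibrated to the clipping bound, so the
    # presence or absence of any one sample is statistically masked.
    noise = rng.normal(0.0, noise_multiplier * clip_norm, size=update.shape)
    return clipped + noise

raw = np.array([3.0, 4.0])   # L2 norm 5.0, exceeds the clip bound of 1.0
private = privatize_update(raw, rng=np.random.default_rng(0))
```

Clipping is what makes the sensitivity analysis tractable: once every update has norm at most `clip_norm`, the noise scale needed for a given privacy level follows directly.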
Depending on the level of privacy…
Central Differential Privacy
Clients send raw model updates to a trusted aggregator, which then adds calibrated noise to the aggregated result, providing privacy guarantees on the global model while preserving higher utility.
Local Differential Privacy
Each client independently adds noise to its model updates or data before sending them to the server, so privacy is protected even if the server is untrusted.
Bandwidth
Federated learning systems must carefully manage communication overhead, so bandwidth requirements should be estimated before deployment. The total bandwidth can be calculated as:
$$BW = (M_{out} + M_{in}) \times C \times f \times n$$
Where:
- $M_{out}$: Outbound model size
- $M_{in}$: Inbound model size
- $C$: Cohort size (number of participants)
- $f$: Fraction of participants selected per round
- $n$: Number of training rounds
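Plugging illustrative numbers into the formula (a hypothetical 10 MiB model, a cohort of 1,000 clients with 10% sampled per round, over 100 rounds):

```python
def fl_bandwidth(m_out, m_in, cohort, fraction, rounds):
    """Total data transferred: (M_out + M_in) * C * f * n."""
    return (m_out + m_in) * cohort * fraction * rounds

MiB = 1024 ** 2
total = fl_bandwidth(m_out=10 * MiB, m_in=10 * MiB,
                     cohort=1000, fraction=0.1, rounds=100)
print(total / 1024 ** 4, "TiB")  # ~0.19 TiB of total transfer
```

Even this modest model produces roughly 0.19 TiB of traffic, which is why the compression techniques below matter in practice.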
Uses of calculating bandwidth in FL
Estimate network requirements
Know how much data needs to be sent/received per round, so you can check if your network can handle it.
Optimise communication
Decide how to compress updates (quantisation, sparsification) to reduce data transfer.
Cost planning
For commercial or large deployments, bandwidth translates into real costs (e.g., cellular data, cloud transfer fees).
Scalability analysis
Predict how FL behaves when the number of clients increases or when models get larger.
Energy efficiency
On mobile or edge devices, less bandwidth means less energy consumption and longer battery life.
Common ways of reducing bandwidth
| Reduce Update Size | Communicate Less |
| --- | --- |
| Sparsification | Pre-trained models |
| Quantisation | Train for more local epochs per round |
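Both update-size reductions fit in a few lines each. A sketch of top-k sparsification and uniform 8-bit quantisation (the value of `k` and the bit width are illustrative choices; production systems also track the error left behind by compression):

```python
import numpy as np

def top_k_sparsify(update, k):
    """Keep only the k largest-magnitude entries; zero out the rest."""
    out = np.zeros_like(update)
    idx = np.argsort(np.abs(update))[-k:]
    out[idx] = update[idx]
    return out

def quantize_int8(update):
    """Uniformly map float32 values to int8; return the scale to undo it."""
    scale = np.max(np.abs(update)) / 127 or 1.0  # avoid divide-by-zero
    q = np.round(update / scale).astype(np.int8)
    return q, scale

u = np.array([0.01, -0.9, 0.05, 0.7, -0.02])
sparse = top_k_sparsify(u, k=2)   # only the -0.9 and 0.7 entries survive
q, scale = quantize_int8(u)       # 1 byte per value instead of 8
```

Sparsification shrinks what you send by transmitting only (index, value) pairs, while quantisation shrinks every value from float64 to a single byte; the two are often combined.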
Apart from model compression, other optimisation strategies include asynchronous aggregation and hierarchical federated learning.
Challenges and Considerations
System Heterogeneity:
Participants operate diverse hardware configurations, network capabilities, and computational resources.
Statistical Heterogeneity:
Non-IID data distributions across participants create challenges for global model convergence.
Security and Trust:
Federated learning systems must protect against various attack vectors, including model poisoning, inference attacks, and byzantine failures.
Conclusion
Federated Learning transcends traditional machine learning paradigms by enabling collaborative model development without compromising data privacy or sovereignty. As organisations increasingly recognise data as a strategic asset requiring careful protection, federated learning provides a pathway to harness collective intelligence while maintaining individual control.
The successful implementation of federated learning systems requires careful consideration of technical, privacy, and operational challenges. However, the potential benefits, including enhanced model performance, improved privacy protection, and expanded collaboration opportunities, position federated learning as an essential component of the future AI ecosystem.
Organisations investing in federated learning capabilities today will be better positioned to navigate the evolving landscape of privacy-preserving AI development and collaborative machine learning initiatives.
References and Further Reading
Wen, J., et al. (2023). "A Survey on Federated Learning: Challenges and Applications."
Li, T., et al. (2020). "Federated Learning: Challenges, Methods, and Future Directions."
Yang, Q., et al. (2019). "Federated Machine Learning: Concept and Applications."