FedJAX is a library for developing custom Federated Learning (FL) algorithms in JAX.

Overview

FedJAX: Federated learning with JAX

What is FedJAX?

FedJAX is a library for developing custom Federated Learning (FL) algorithms in JAX. FedJAX prioritizes ease-of-use and is intended to be useful for anyone with knowledge of NumPy.

FedJAX is built around the common core components needed in the FL setting:

  • Federated datasets: Clients and a dataset for each client
  • Models: CNN, ResNet, etc.
  • Optimizers: SGD, Momentum, etc.
  • Federated algorithms: Client updates and server aggregation

For Models and Optimizers, FedJAX provides lightweight wrappers and containers that can work with a variety of existing implementations (e.g. a model wrapper that can support both Haiku and Stax). Similarly, for Federated datasets, TFF provides a well established API for working with federated datasets, and FedJAX just provides utilties for converting to NumPy input acceptable to JAX.

However, what FL researchers will find most useful is the collection and customizability of Federated algorithms provided out of box by FedJAX.

Quickstart

The FedJAX Intro notebook provides an introduction into running existing FedJAX experiments. For more custom use cases, please refer to the FedJAX Advanced notebook.

You can also take a look at some of our examples:

Installation

You will need Python 3.6 or later and a working JAX installation. For a CPU-only version:

pip install --upgrade pip
pip install --upgrade jax jaxlib  # CPU-only version

For other devices (e.g. GPU), follow these instructions.

Then, install fedjax from PyPi:

pip install fedjax

Or, to upgrade to the latest version of fedjax:

pip install --upgrade git+https://github.com/google/fedjax.git

Useful pointers

NOTE: FedJAX is not an officially supported Google product. FedJAX is still in the early stages and the API will likely continue to change.

Comments
  • FedJax depends on TensorFlow Federated?

    FedJax depends on TensorFlow Federated?

    I am helping users install FedJax for use in their federated learning research projects and I noticed that installing FedJax is pulling in TensorFlow Federated (0.17) and TensorFlow (2.3). I don't see either of these listed as dependencies of FedJax so I am trying to understand why they are being pulled in by pip install fedjax.

    opened by davidrpugh 7
  • CIFAR 100 Questions

    CIFAR 100 Questions

    Hi, thanks for the awesome library! I want to ask a couple of questions related to CIFAR100 datasets.

    1. I noticed that while the dataset is available in the library, the model is not. Curious if a model for CIFAR100 is work-in-progress, or if there is no short-term plan for this?
    2. Looking at the CIFAR100 dataset, this seems to be inconsistent with Google's TFF. Notably, the cropping size and normalizing are done differently from TFF. Is this intentional? Would it be correct to say that we could expect this to mirror TFF's design eventually?

    Thanks in advance for all the help!

    opened by HanGuo97 5
  • unbiased scale for DRIVE

    unbiased scale for DRIVE

    Following a discussion with @stheertha, I suggest using the unbiased scale (section 4.2 in Drive's paper) for cases where there is more than 1 client.

    Thank you for considering.

    opened by amitport 3
  • Problem of Quick Start in Readme.md

    Problem of Quick Start in Readme.md

    I tried to run the code in the QuickStart and I found some problems. federated_data = fedjax.FederatedData() can not be executed because it is an abstract class. So I replaced it as

    client_a_data = {
            'x': np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
            'y': np.array([7, 8])
        }
    client_b_data = {'x': np.array([[9.0, 10.0, 11.0]]), 'y': np.array([12])}
    client_to_data_mapping = {'a': client_a_data, 'b': client_b_data}
    federated_data = fedjax.InMemoryFederatedData(client_to_data_mapping)
    

    The other things are same as the QuickStart, but i got an error

    for client_id, client_output, _ in func(shared_input, clients):
    for client_id, client_batches, client_input in clients:
    ValueError: not enough values to unpack (expected 3, got 2)
    

    It seems that client_batches is missing and we need to batch the dataset, but there is no example which fits this situation.

    opened by Ichiruchan 2
  • Full EMNIST example does not exhibit parallelization

    Full EMNIST example does not exhibit parallelization

    Hi! I am facing an issue with parallelizing the base code provided by the developers.

    • My local workstation contains two GPUs.
    • I installed FedJax in a conda environment
    • I downloaded "emnist_fed_avg.py" file from the folder "examples", deleted the "fedjax.training.set_tf_cpu_only()" line and replaced fed_avg.federated_averaging to fedjax.algorithms.fed_avg.federated_averaging on line 61
    • Having activated the conda environment, I ran the file with python emnist_fed_avg.py. The file runs correctly and prints the expected output (round nums and train/test metrics on each 10th round)
    • The nvidia-smi command shows zero percent utilization and almost zero memory usage on one of the GPUs (and ~40% utilization/maximum memory usage on another node)

    Any ideas what I am doing wrong?

    opened by gaseln 2
  • Clarifying the meaning of

    Clarifying the meaning of "weight"

    In the Intro notebook, the backward_pass_output from model.backward has a weight feature. It seems to me that this is used for performing a weighted averaging in FedAvg, but this is not clear to me how. Perhaps this can be renamed to batch_size?

    opened by Saipraneet 1
  • [NumPy] Remove references to deprecated NumPy type aliases.

    [NumPy] Remove references to deprecated NumPy type aliases.

    [NumPy] Remove references to deprecated NumPy type aliases.

    This change replaces references to a number of deprecated NumPy type aliases (np.bool, np.int, np.float, np.complex, np.object, np.str) with their recommended replacement (bool, int, float, complex, object, str).

    NumPy 1.24 drops the deprecated aliases, so we must remove uses before updating NumPy.

    opened by copybara-service[bot] 0
  • Disable pytype import error for old stax import path

    Disable pytype import error for old stax import path

    Disable pytype import error for old stax import path

    Why? The deprecated jax.experimental.stax path will soon be removed (see https://github.com/google/jax/pull/11700), and this causes pytype to fail.

    opened by copybara-service[bot] 0
  • Rename jax.experimental.stax -> jax.example_libraries.stax

    Rename jax.experimental.stax -> jax.example_libraries.stax

    Rename jax.experimental.stax -> jax.example_libraries.stax

    Why? The former name has been deprecated since JAX version 0.2.25, released in November 2021 (see https://github.com/google/jax/blob/main/CHANGELOG.md#jax-0225-nov-10-2021), and will soon be removed.

    opened by copybara-service[bot] 0
  • Implement standard CIFAR-100 model in fedjax.models.cifar100

    Implement standard CIFAR-100 model in fedjax.models.cifar100

    Add a standard implementation of the model for the CIFAR-100 task. The dataset can be found in fedjax.datasets.cifar100.

    For the model architecture, we should follow “Adaptive Federated Optimization”. The model architecture is detailed in section 4 as a ResNet-18 (replacing batch norm with group norm). Code for this paper and a Keras implementation of the model can be found here. We suggest using either haiku or flax to implement the model for use with JAX.

    If you choose to use haiku, you can use fedjax.create_model_from_haiku to create a fedjax compatible model. If you choose to use flax, wrapping it in a fedjax.Model is fairly straightforward and we can provide guidance for this.

    A good example to follow is #265 that checks in a simple linear model for CIFAR-100 and includes the model implementation, tests, and baseline results with FedAvg using this script. Make sure to add a flags file similar to https://github.com/google/fedjax/blob/main/experiments/fed_avg/fed_avg.CIFAR100_LOGISTIC.flags and add the new task to https://github.com/google/fedjax/blob/main/fedjax/training/tasks.py.

    Thanks for your contributions!

    enhancement contributions welcome 
    opened by jaehunro 1
  • Support for manually modifying client/server learning rate

    Support for manually modifying client/server learning rate

    Hi, I'm playing around with clients learning rate but I cannot find a clean way of modifying it.

    Basically, I need to change the LR following a schedule based on the current round. Is that possible?

    Thanks

    opened by marcociccone 1
  • Support for gldv2 and inaturalist datasets

    Support for gldv2 and inaturalist datasets

    I think it would be great to port these datasets from tff to fedjax. I would be happy to make the effort and contribute to the library, but I need a bit of support from the fedjax team 🙂

    By looking at the tff codebase (gldv2, inaturalist) it looks that load_data_from_cache function creates a tfrecords file for each client.

    The only concrete classes that I see are SQLiteFederatedData and InMemoryFederatedData, but I don't think they are meant for this use case. What would be the best way to map the clients into a FederatedDataset? We could replicate something like FilePerUserClientData.

    Thanks!

    opened by marcociccone 7
  • Support for haiku models with non-trainable state

    Support for haiku models with non-trainable state

    Hi! congrats on this great library! I've started using it a few days ago and I love it!

    Is there any way to use a haiku model with a non-trainable state (e.g. to use batch norm)? I didn't find any nontrivial way, but maybe I'm missing something.

    Thanks a lot for your help!

    opened by marcociccone 2
  • How to create a validation dataset?

    How to create a validation dataset?

    Hello!

    I may need to split each client's train dataset into train and validation parts for grid search purposes (for example, tuning the stepsizes in a method). How can this be achieved in the framework?

    opened by gaseln 4
  • Feature request: Convert standard dataset into a federated dataset

    Feature request: Convert standard dataset into a federated dataset

    Synthetic federated datasets can constructed from standard centralized ones by artificially splitting them among clients. This is usually done using a Dirichlet distribution (e.g. Hsu et al. 2019). Such synthetic datasets are very useful since we can explicitly control the total number of users, as well as the heterogeneity.

    It would be great to have primitives which can automatically convert standard numpy dataset into a FedJax datset.

    contributions welcome 
    opened by Saipraneet 5
Releases(v0.0.15)
Owner
Google
Google ❤️ Open Source
Google
Implementation of SegNet: A Deep Convolutional Encoder-Decoder Architecture for Semantic Pixel-Wise Labelling

Caffe SegNet This is a modified version of Caffe which supports the SegNet architecture As described in SegNet: A Deep Convolutional Encoder-Decoder A

Alex Kendall 1.1k Jan 02, 2023
Explaining neural decisions contrastively to alternative decisions.

Contrastive Explanations for Model Interpretability This is the repository for the paper "Contrastive Explanations for Model Interpretability", about

AI2 16 Oct 16, 2022
[3DV 2020] PeeledHuman: Robust Shape Representation for Textured 3D Human Body Reconstruction

PeeledHuman: Robust Shape Representation for Textured 3D Human Body Reconstruction International Conference on 3D Vision, 2020 Sai Sagar Jinka1, Rohan

Rohan Chacko 39 Oct 12, 2022
[Preprint] "Chasing Sparsity in Vision Transformers: An End-to-End Exploration" by Tianlong Chen, Yu Cheng, Zhe Gan, Lu Yuan, Lei Zhang, Zhangyang Wang

Chasing Sparsity in Vision Transformers: An End-to-End Exploration Codes for [Preprint] Chasing Sparsity in Vision Transformers: An End-to-End Explora

VITA 64 Dec 08, 2022
Fast and accurate optimisation for registration with little learningconvexadam

convexAdam Learn2Reg 2021 Submission Fast and accurate optimisation for registration with little learning Excellent results on Learn2Reg 2021 challeng

17 Dec 06, 2022
EncT5: Fine-tuning T5 Encoder for Non-autoregressive Tasks

EncT5 (Unofficial) Pytorch Implementation of EncT5: Fine-tuning T5 Encoder for Non-autoregressive Tasks About Finetune T5 model for classification & r

Jangwon Park 34 Jan 01, 2023
Demonstration of transfer of knowledge and generalization with distillation

Distilling-the-Knowledge-in-a-Neural-Network This is an implementation of a part of the paper "Distilling the Knowledge in a Neural Network" (https://

26 Nov 25, 2022
Face Mask Detector by live camera using tensorflow-keras, openCV and Python

Face Mask Detector 😷 by Live Camera Detecting masked or unmasked faces by live camera with percentange of mask occupation About Project: This an Arti

Karan Shingde 2 Apr 04, 2022
Bayesian optimization in PyTorch

BoTorch is a library for Bayesian Optimization built on PyTorch. BoTorch is currently in beta and under active development! Why BoTorch ? BoTorch Prov

2.5k Dec 31, 2022
Official Code Release for Container : Context Aggregation Network

Container: Context Aggregation Network Official Code Release for Container : Context Aggregation Network Comparion between CNN, MLP-Mixer and Transfor

peng gao 42 Nov 17, 2021
Fully Convolutional DenseNet (A.K.A 100 layer tiramisu) for semantic segmentation of images implemented in TensorFlow.

FC-DenseNet-Tensorflow This is a re-implementation of the 100 layer tiramisu, technically a fully convolutional DenseNet, in TensorFlow (Tiramisu). Th

Hasnain Raza 121 Oct 12, 2022
A Temporal Extension Library for PyTorch Geometric

Documentation | External Resources | Datasets PyTorch Geometric Temporal is a temporal (dynamic) extension library for PyTorch Geometric. The library

Benedek Rozemberczki 1.9k Jan 07, 2023
MXNet implementation for: Drop an Octave: Reducing Spatial Redundancy in Convolutional Neural Networks with Octave Convolution

Octave Convolution MXNet implementation for: Drop an Octave: Reducing Spatial Redundancy in Convolutional Neural Networks with Octave Convolution Imag

Meta Research 549 Dec 28, 2022
Pose estimation with MoveNet Lightning

Pose Estimation With MoveNet Lightning MoveNet is the TensorFlow pre-trained model that identifies 17 different key points of the human body. It is th

Yash Vora 2 Jan 04, 2022
Synthetic structured data generators

Join us on What is Synthetic Data? Synthetic data is artificially generated data that is not collected from real world events. It replicates the stati

YData 850 Jan 07, 2023
An implementation of the Contrast Predictive Coding (CPC) method to train audio features in an unsupervised fashion.

CPC_audio This code implements the Contrast Predictive Coding algorithm on audio data, as described in the paper Unsupervised Pretraining Transfers we

Meta Research 283 Dec 30, 2022
PyTorch implementation of normalizing flow models

PyTorch implementation of normalizing flow models

Vincent Stimper 242 Jan 02, 2023
PyKaldi GOP-DNN on Epa-DB

PyKaldi GOP-DNN on Epa-DB This repository has the tools to run a PyKaldi GOP-DNN algorithm on Epa-DB, a database of non-native English speech by Spani

18 Dec 14, 2022
Train robotic agents to learn pick and place with deep learning for vision-based manipulation in PyBullet.

Ravens is a collection of simulated tasks in PyBullet for learning vision-based robotic manipulation, with emphasis on pick and place. It features a Gym-like API with 10 tabletop rearrangement tasks,

Google Research 367 Jan 09, 2023
Learning Time-Critical Responses for Interactive Character Control

Learning Time-Critical Responses for Interactive Character Control Abstract This code implements the paper Learning Time-Critical Responses for Intera

Movement Research Lab 227 Dec 31, 2022