FedJAX is a library for developing custom Federated Learning (FL) algorithms in JAX.

Overview

FedJAX: Federated learning with JAX

What is FedJAX?

FedJAX is a library for developing custom Federated Learning (FL) algorithms in JAX. FedJAX prioritizes ease-of-use and is intended to be useful for anyone with knowledge of NumPy.

FedJAX is built around the common core components needed in the FL setting:

  • Federated datasets: Clients and a dataset for each client
  • Models: CNN, ResNet, etc.
  • Optimizers: SGD, Momentum, etc.
  • Federated algorithms: Client updates and server aggregation
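To make the split between client updates and server aggregation concrete, here is a minimal NumPy sketch of one federated averaging round on a linear least-squares model. It does not use the FedJAX API. The functions client_update, server_aggregate and run_round, the model, and the weighting by example count are all hypothetical choices for illustration.

```python
import numpy as np

def client_update(params, x, y, lr=0.1, steps=5):
    """Hypothetical client update: a few full-batch gradient steps on mean squared error."""
    w = params.copy()
    for _ in range(steps):
        grad = 2.0 * x.T @ (x @ w - y) / len(y)  # gradient of the mean squared error
        w -= lr * grad
    return w - params, len(y)  # model delta and example count (used as the weight)

def server_aggregate(params, deltas_and_weights):
    """Hypothetical server step: apply the example-weighted average of client deltas."""
    total = sum(weight for _, weight in deltas_and_weights)
    mean_delta = sum(delta * weight for delta, weight in deltas_and_weights) / total
    return params + mean_delta

def run_round(params, clients):
    """One round: every client trains locally from the same params, then the server aggregates."""
    outputs = [client_update(params, x, y) for x, y in clients]
    return server_aggregate(params, outputs)

rng = np.random.default_rng(0)
true_w = np.array([1.0, -2.0])
clients = []
for n in (20, 50):  # two clients with different dataset sizes
    x = rng.normal(size=(n, 2))
    clients.append((x, x @ true_w))

params = np.zeros(2)
for _ in range(20):
    params = run_round(params, clients)
print(params)  # should be close to [1.0, -2.0]
```

Because both synthetic clients share the same optimum, the averaged model converges to it; with genuinely heterogeneous clients, the choice of aggregation weights starts to matter.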

For Models and Optimizers, FedJAX provides lightweight wrappers and containers that can work with a variety of existing implementations (e.g. a model wrapper that supports both Haiku and Stax). Similarly, for Federated datasets, TensorFlow Federated (TFF) already provides a well-established API, so FedJAX only provides utilities for converting that data into NumPy input that JAX accepts.

However, what FL researchers will find most useful is the collection of Federated algorithms that FedJAX provides out of the box, and how easily they can be customized.

Quickstart

The FedJAX Intro notebook introduces running existing FedJAX experiments. For more customized use cases, see the FedJAX Advanced notebook.

You can also take a look at some of our examples:

Installation

You will need Python 3.6 or later and a working JAX installation. For a CPU-only version:

pip install --upgrade pip
pip install --upgrade jax jaxlib  # CPU-only version

For other devices (e.g. GPU), follow the JAX installation instructions.

Then, install fedjax from PyPI:

pip install fedjax

Or, to install the latest development version of fedjax directly from GitHub:

pip install --upgrade git+https://github.com/google/fedjax.git

Useful pointers

NOTE: FedJAX is not an officially supported Google product. FedJAX is still in the early stages and the API will likely continue to change.

Comments
  • FedJax depends on TensorFlow Federated?

    I am helping users install FedJax for use in their federated learning research projects and I noticed that installing FedJax is pulling in TensorFlow Federated (0.17) and TensorFlow (2.3). I don't see either of these listed as dependencies of FedJax so I am trying to understand why they are being pulled in by pip install fedjax.

    opened by davidrpugh 7
  • CIFAR 100 Questions

    Hi, thanks for the awesome library! I want to ask a couple of questions related to CIFAR100 datasets.

    1. I noticed that while the dataset is available in the library, the model is not. Is a model for CIFAR100 a work in progress, or is there no short-term plan for one?
    2. The preprocessing of the CIFAR100 dataset seems inconsistent with Google's TFF: notably, the cropping size and normalization are done differently. Is this intentional? Should we expect it to mirror TFF's design eventually?

    Thanks in advance for all the help!

    opened by HanGuo97 5
  • unbiased scale for DRIVE

    Following a discussion with @stheertha, I suggest using the unbiased scale (section 4.2 of the DRIVE paper) when there is more than one client.

    Thank you for considering.

    opened by amitport 3
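For readers unfamiliar with the two scales in the DRIVE issue above, here is a rough NumPy sketch of DRIVE-style one-bit compression as we understand the paper: rotate the vector, keep only the signs, and send one scale. It substitutes a dense random orthogonal matrix (from a QR decomposition) for the structured randomized Hadamard rotation, and the function names are hypothetical; it is not FedJAX's implementation. The biased scale ||Rx||_1 / d minimizes the error for a single vector, while the unbiased scale ||x||_2^2 / ||Rx||_1 gives up a little per-vector accuracy so that averaging estimates from many clients is not systematically skewed.

```python
import numpy as np

def random_rotation(d, rng):
    """Dense random orthogonal matrix (a simplification of DRIVE's Hadamard-based rotation)."""
    q, r = np.linalg.qr(rng.normal(size=(d, d)))
    return q * np.sign(np.diag(r))  # column sign fix makes the distribution uniform (Haar)

def drive_compress(x, rot, unbiased):
    """Return one sign per coordinate of the rotated vector plus a single float scale."""
    rx = rot @ x
    signs = np.where(rx >= 0, 1.0, -1.0)
    l1 = np.abs(rx).sum()
    scale = (x @ x) / l1 if unbiased else l1 / len(x)
    return signs, scale

def drive_decompress(signs, scale, rot):
    # rot is orthogonal, so its inverse is its transpose.
    return rot.T @ (scale * signs)

rng = np.random.default_rng(0)
x = rng.normal(size=8)
rot = random_rotation(8, rng)
for unbiased in (False, True):
    x_hat = drive_decompress(*drive_compress(x, rot, unbiased), rot)
    print(unbiased, np.sum((x - x_hat) ** 2))  # squared error for each scale
```

With the unbiased scale, the inner product of the estimate with the original vector equals ||x||_2^2 exactly, which is the property the test below checks.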
  • Problem of Quick Start in Readme.md

    I tried to run the code in the QuickStart and ran into some problems. federated_data = fedjax.FederatedData() cannot be executed because FederatedData is an abstract class, so I replaced it with:

    import numpy as np
    import fedjax

    client_a_data = {
        'x': np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
        'y': np.array([7, 8]),
    }
    client_b_data = {'x': np.array([[9.0, 10.0, 11.0]]), 'y': np.array([12])}
    client_to_data_mapping = {'a': client_a_data, 'b': client_b_data}
    federated_data = fedjax.InMemoryFederatedData(client_to_data_mapping)

    Everything else is the same as in the QuickStart, but I got this error:

    for client_id, client_output, _ in func(shared_input, clients):
    for client_id, client_batches, client_input in clients:
    ValueError: not enough values to unpack (expected 3, got 2)
    

    It seems that client_batches is missing and we need to batch the dataset, but there is no example which fits this situation.

    opened by Ichiruchan 2
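For the QuickStart issue above, the traceback shows the per-client loop unpacking (client_id, client_batches, client_input) triples, while the data only supplies pairs. A minimal, framework-free sketch of producing such triples from the in-memory mapping follows. batch_client_data and make_client_triples are hypothetical helpers, not FedJAX functions, and client_input is left as None because what it should contain depends on the algorithm; a real FedJAX setup should use the library's own batching utilities.

```python
import numpy as np

def batch_client_data(examples, batch_size):
    """Split a dict of equal-length arrays into a list of batch dicts (the last may be smaller)."""
    num_examples = len(next(iter(examples.values())))
    return [
        {name: arr[start:start + batch_size] for name, arr in examples.items()}
        for start in range(0, num_examples, batch_size)
    ]

def make_client_triples(client_to_data_mapping, batch_size, client_input=None):
    """Yield (client_id, client_batches, client_input) triples for a per-client loop."""
    for client_id, examples in client_to_data_mapping.items():
        yield client_id, batch_client_data(examples, batch_size), client_input

client_a_data = {'x': np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), 'y': np.array([7, 8])}
client_b_data = {'x': np.array([[9.0, 10.0, 11.0]]), 'y': np.array([12])}
client_to_data_mapping = {'a': client_a_data, 'b': client_b_data}

for client_id, client_batches, client_input in make_client_triples(client_to_data_mapping, batch_size=1):
    print(client_id, len(client_batches))  # prints "a 2", then "b 1"
```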
  • Full EMNIST example does not exhibit parallelization

    Hi! I am facing an issue with parallelizing the base code provided by the developers.

    • My local workstation contains two GPUs.
    • I installed FedJax in a conda environment
    • I downloaded the "emnist_fed_avg.py" file from the "examples" folder, deleted the "fedjax.training.set_tf_cpu_only()" line, and replaced fed_avg.federated_averaging with fedjax.algorithms.fed_avg.federated_averaging on line 61
    • Having activated the conda environment, I ran the file with python emnist_fed_avg.py. The file runs correctly and prints the expected output (round nums and train/test metrics on each 10th round)
    • The nvidia-smi command shows zero percent utilization and almost zero memory usage on one of the GPUs (and ~40% utilization with maximum memory usage on the other)

    Any ideas what I am doing wrong?

    opened by gaseln 2
  • Clarifying the meaning of "weight"

    In the Intro notebook, the backward_pass_output from model.backward has a weight feature. It seems to me that this is used to perform weighted averaging in FedAvg, but it is not clear to me how. Perhaps this could be renamed to batch_size?

    opened by Saipraneet 1
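One way to see why a per-batch weight is useful, as asked in the "weight" issue above: if each backward pass reports a mean value together with the number of examples it covered, averaging with those weights recovers the mean over all examples, while a plain average over batches over-counts small batches. This sketch assumes the weight is an example count, as the question guesses; FedJAX may define it more generally, and weighted_mean is a hypothetical helper.

```python
import numpy as np

def weighted_mean(values, weights):
    """Average per-batch means using per-batch weights (here, example counts)."""
    values = np.asarray(values, dtype=float)
    weights = np.asarray(weights, dtype=float)
    return (values * weights).sum() / weights.sum()

# Per-example losses split into two batches of unequal size.
batches = [np.array([1.0, 2.0, 3.0, 4.0]), np.array([10.0])]
batch_means = [b.mean() for b in batches]  # [2.5, 10.0]
batch_weights = [len(b) for b in batches]  # [4, 1]

print(weighted_mean(batch_means, batch_weights))  # 4.0, the mean over all 5 examples
print(np.mean(batch_means))                       # 6.25, over-weights the small batch
```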
  • [NumPy] Remove references to deprecated NumPy type aliases.

    This change replaces references to a number of deprecated NumPy type aliases (np.bool, np.int, np.float, np.complex, np.object, np.str) with their recommended replacement (bool, int, float, complex, object, str).

    NumPy 1.24 drops the deprecated aliases, so we must remove uses before updating NumPy.

    opened by copybara-service[bot] 0
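A minimal illustration of the replacement described in the NumPy change above: the removed aliases such as np.float and np.bool were simply the Python builtins, so passing the builtin type keeps the same behavior. The resulting dtype is NumPy's default for that kind; for int it is platform dependent, so code that needs a fixed width should name np.int64 or np.int32 explicitly.

```python
import numpy as np

# Before (fails on NumPy >= 1.24): np.zeros(3, dtype=np.float)
# After: pass the builtin type the alias referred to.
a = np.zeros(3, dtype=float)
mask = np.ones(3, dtype=bool)
labels = np.array(["a", "b"], dtype=str)

print(a.dtype, mask.dtype, labels.dtype)  # float64 bool <U1
```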
  • Disable pytype import error for old stax import path

    Why? The deprecated jax.experimental.stax path will soon be removed (see https://github.com/google/jax/pull/11700), and this causes pytype to fail.

    opened by copybara-service[bot] 0
  • Rename jax.experimental.stax -> jax.example_libraries.stax

    Why? The former name has been deprecated since JAX version 0.2.25, released in November 2021 (see https://github.com/google/jax/blob/main/CHANGELOG.md#jax-0225-nov-10-2021), and will soon be removed.

    opened by copybara-service[bot] 0
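For code that must run on both older and newer JAX releases during the stax rename above, one option is to try the new location first and fall back to the deprecated one. This is a sketch, not what FedJAX itself does; import_stax is a hypothetical helper that returns None when neither path can be imported (for example, when JAX is not installed).

```python
import importlib

def import_stax():
    """Return the stax module from its current or deprecated location, else None."""
    for path in ("jax.example_libraries.stax", "jax.experimental.stax"):
        try:
            return importlib.import_module(path)
        except ImportError:  # also covers ModuleNotFoundError
            continue
    return None

stax = import_stax()
if stax is not None:
    # stax.serial composes layers into an (init_fun, apply_fun) pair.
    init_fun, apply_fun = stax.serial(stax.Dense(10), stax.Relu)
```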
  • Implement standard CIFAR-100 model in fedjax.models.cifar100

    Add a standard implementation of the model for the CIFAR-100 task. The dataset can be found in fedjax.datasets.cifar100.

    For the model architecture, we should follow “Adaptive Federated Optimization”. The model architecture is detailed in section 4 as a ResNet-18 (replacing batch norm with group norm). Code for this paper and a Keras implementation of the model can be found here. We suggest using either haiku or flax to implement the model for use with JAX.

    If you choose to use haiku, you can use fedjax.create_model_from_haiku to create a fedjax compatible model. If you choose to use flax, wrapping it in a fedjax.Model is fairly straightforward and we can provide guidance for this.

    A good example to follow is #265, which checks in a simple linear model for CIFAR-100 and includes the model implementation, tests, and baseline results with FedAvg using this script. Make sure to add a flags file similar to https://github.com/google/fedjax/blob/main/experiments/fed_avg/fed_avg.CIFAR100_LOGISTIC.flags and add the new task to https://github.com/google/fedjax/blob/main/fedjax/training/tasks.py.

    Thanks for your contributions!

    enhancement contributions welcome 
    opened by jaehunro 1
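The CIFAR-100 model request above specifies ResNet-18 with group normalization in place of batch normalization. As a reminder of what that swap means, here is a plain NumPy sketch of group norm on NHWC activations: statistics are computed per example over each group of channels, so they do not depend on the rest of the batch, which is one commonly cited reason to prefer it in federated training. The function name and NHWC layout are assumptions for illustration; in practice Haiku's hk.GroupNorm or Flax's nn.GroupNorm would be used.

```python
import numpy as np

def group_norm(x, num_groups, gamma, beta, eps=1e-5):
    """Group normalization for NHWC input; gamma and beta have shape (C,)."""
    n, h, w, c = x.shape
    assert c % num_groups == 0, "channels must divide evenly into groups"
    g = x.reshape(n, h, w, num_groups, c // num_groups)
    # Per-example statistics over height, width, and the channels within each group.
    mean = g.mean(axis=(1, 2, 4), keepdims=True)
    var = g.var(axis=(1, 2, 4), keepdims=True)
    g = (g - mean) / np.sqrt(var + eps)
    return g.reshape(n, h, w, c) * gamma + beta

rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(2, 4, 4, 8))
y = group_norm(x, num_groups=2, gamma=np.ones(8), beta=np.zeros(8))
```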
  • Support for manually modifying client/server learning rate

    Hi, I'm playing around with client learning rates, but I cannot find a clean way of modifying them.

    Basically, I need to change the LR following a schedule based on the current round. Is that possible?

    Thanks

    opened by marcociccone 1
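Independently of how FedJAX exposes optimizers, the round-based schedule asked about above is just a function from the round number to a learning rate. The sketch below shows step decay; step_decay and its parameters are hypothetical, and wiring the result into FedJAX (for example, reconfiguring the client optimizer at the start of each round) is an assumption to check against the library's optimizer API.

```python
def step_decay(round_num, initial_lr=0.1, decay=0.5, rounds_per_step=100):
    """Learning rate for a round: multiply by `decay` every `rounds_per_step` rounds."""
    return initial_lr * decay ** (round_num // rounds_per_step)

for round_num in (0, 99, 100, 250):
    print(round_num, step_decay(round_num))
```

Keeping the schedule as a pure function of the round number makes it easy to reproduce and to log alongside other per-round metrics.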
  • Support for gldv2 and inaturalist datasets

    I think it would be great to port these datasets from tff to fedjax. I would be happy to make the effort and contribute to the library, but I need a bit of support from the fedjax team 🙂

    Looking at the tff codebase (gldv2, inaturalist), it looks like the load_data_from_cache function creates a tfrecords file for each client.

    The only concrete classes that I see are SQLiteFederatedData and InMemoryFederatedData, but I don't think they are meant for this use case. What would be the best way to map the clients into a FederatedDataset? We could replicate something like FilePerUserClientData.

    Thanks!

    opened by marcociccone 7
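One possible shape for the file-per-client dataset discussed above, loosely following TFF's FilePerUserClientData, is a mapping from client IDs to files that are only read when a client is requested. This sketch stores each client as a NumPy .npz file; the class FilePerClientData and its methods are hypothetical and do not implement FedJAX's FederatedData interface.

```python
import os
import tempfile

import numpy as np

class FilePerClientData:
    """Hypothetical lazy dataset: one .npz file per client, loaded on demand."""

    def __init__(self, client_to_path):
        self._client_to_path = dict(client_to_path)

    def client_ids(self):
        return sorted(self._client_to_path)

    def get_client(self, client_id):
        # Read every array stored in the client's file into a dict of NumPy arrays.
        with np.load(self._client_to_path[client_id]) as f:
            return {name: f[name] for name in f.files}

    @classmethod
    def from_directory(cls, directory):
        # Assumes files are named <client_id>.npz.
        paths = {
            os.path.splitext(name)[0]: os.path.join(directory, name)
            for name in os.listdir(directory)
            if name.endswith(".npz")
        }
        return cls(paths)

with tempfile.TemporaryDirectory() as tmp:
    np.savez(os.path.join(tmp, "client_0.npz"), x=np.zeros((3, 2)), y=np.array([0, 1, 0]))
    np.savez(os.path.join(tmp, "client_1.npz"), x=np.ones((1, 2)), y=np.array([1]))
    data = FilePerClientData.from_directory(tmp)
    print(data.client_ids())                 # ['client_0', 'client_1']
    print(data.get_client("client_1")["y"])  # [1]
```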
  • Support for haiku models with non-trainable state

    Hi! Congrats on this great library! I started using it a few days ago and I love it!

    Is there any way to use a haiku model with a non-trainable state (e.g. to use batch norm)? I didn't find any nontrivial way, but maybe I'm missing something.

    Thanks a lot for your help!

    opened by marcociccone 2
  • How to create a validation dataset?

    Hello!

    I may need to split each client's train dataset into train and validation parts for grid search purposes (for example, tuning the stepsizes in a method). How can this be achieved in the framework?

    opened by gaseln 4
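For the validation-split question above, a per-client split amounts to shuffling each client's example indices with a fixed seed and slicing off a fraction. Below is a NumPy sketch over a dict-of-arrays client format like the client_to_data_mapping in the QuickStart issue; split_client and val_fraction are hypothetical names, not FedJAX APIs.

```python
import numpy as np

def split_client(examples, val_fraction=0.2, seed=0):
    """Split one client's dict of equal-length arrays into (train, validation) dicts."""
    num_examples = len(next(iter(examples.values())))
    perm = np.random.default_rng(seed).permutation(num_examples)
    num_val = int(round(val_fraction * num_examples))
    val_idx, train_idx = perm[:num_val], perm[num_val:]
    # Index every array with the same permutation so features and labels stay aligned.
    train = {name: arr[train_idx] for name, arr in examples.items()}
    val = {name: arr[val_idx] for name, arr in examples.items()}
    return train, val

clients = {
    "a": {"x": np.arange(10).reshape(5, 2), "y": np.arange(5)},
    "b": {"x": np.arange(6).reshape(3, 2), "y": np.arange(3)},
}
splits = {cid: split_client(ex, val_fraction=0.4) for cid, ex in clients.items()}
```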
  • Feature request: Convert standard dataset into a federated dataset

    Synthetic federated datasets can be constructed from standard centralized ones by artificially splitting them among clients. This is usually done using a Dirichlet distribution (e.g. Hsu et al. 2019). Such synthetic datasets are very useful since we can explicitly control the total number of users, as well as the heterogeneity.

    It would be great to have primitives that automatically convert a standard NumPy dataset into a FedJax dataset.

    contributions welcome 
    opened by Saipraneet 5
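A rough sketch of the Dirichlet-based partitioning requested in the last issue above. It uses the common per-class variant: for each class, client proportions are drawn from Dir(alpha) and that class's examples are dealt out accordingly (Hsu et al. 2019 instead draw a class distribution for each client). Smaller alpha gives more heterogeneous clients. dirichlet_partition is a hypothetical helper, not a FedJAX API; its output can be turned into per-client dicts of arrays like the client_to_data_mapping in the QuickStart issue.

```python
import numpy as np

def dirichlet_partition(labels, num_clients, alpha, seed=0):
    """Assign every example index to exactly one client, skewing label mix via Dir(alpha)."""
    rng = np.random.default_rng(seed)
    client_indices = [[] for _ in range(num_clients)]
    for label in np.unique(labels):
        idx = rng.permutation(np.flatnonzero(labels == label))
        proportions = rng.dirichlet(alpha * np.ones(num_clients))
        # Cut points that split this class's examples according to the proportions.
        cuts = (np.cumsum(proportions)[:-1] * len(idx)).astype(int)
        for client, part in enumerate(np.split(idx, cuts)):
            client_indices[client].extend(part.tolist())
    return {f"client_{i}": np.array(sorted(ix), dtype=int) for i, ix in enumerate(client_indices)}

labels = np.random.default_rng(1).integers(0, 10, size=1000)
partition = dirichlet_partition(labels, num_clients=5, alpha=0.5)
client_to_data_mapping = {cid: {"y": labels[ix]} for cid, ix in partition.items()}
```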
Releases (v0.0.15)
Owner: Google