Toolkit for building machine learning models that generalize to unseen domains and are robust to privacy and other attacks.

Overview

Toolkit for Building Robust ML models that generalize to unseen domains (RobustDG)

Divyat Mahajan, Shruti Tople, Amit Sharma

Privacy & Causal Learning (ICML 2020) | MatchDG: Causal View of DG (ICML 2021) | Privacy & DG Connection paper

For machine learning models to be reliable, they need to generalize to data beyond the train distribution. In addition, ML models should be robust to privacy attacks like membership inference and domain knowledge-based attacks like adversarial attacks.

To advance research in building robust and generalizable models, we are releasing a toolkit for building and evaluating ML models, RobustDG. RobustDG contains implementations of domain generalization algorithms and includes evaluation benchmarks based on out-of-distribution accuracy and robustness to membership privacy attacks. We will be adding evaluation for adversarial attacks and more privacy attacks soon.

It is easily extendable. Add your own DG algorithms and evaluate them on different benchmarks.

Installation

To use the command-line interface of RobustDG, clone this repo and add the folder to your system's PATH (or alternatively, run the commands from the RobustDG root directory).

Load dataset

Let's first load the rotatedMNIST dataset in a suitable format for the resnet18 architecture.

python data/data_gen_mnist.py --dataset rot_mnist --model resnet18 --img_h 224 --img_w 224 --subset_size 2000

Train and evaluate ML model

The following commands would train and evalute the MatchDG method on the Rotated MNIST dataset.

python train.py --dataset rot_mnist --method_name matchdg_ctr --match_case 0.0 --match_flag 1 --epochs 50 --batch_size 64 --pos_metric cos --match_func_aug_case 1

python train.py --dataset rot_mnist --method_name matchdg_erm --penalty_ws 0.1 --match_case -1 --ctr_match_case 0.0 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet18 --epochs 25

python test.py --dataset rot_mnist --method_name matchdg_erm --penalty_ws 0.1 --match_case -1 --ctr_match_case 0.0 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet18 --epochs 25 --test_metric acc

python test.py --dataset rot_mnist --method_name matchdg_ctr --match_case 0.0 --match_flag 1 --pos_metric cos --test_metric match_score

Demo

A quick introduction on how to use our repository can be accessed here in the Getting Started notebook.

If you are interested in reproducing results from the MatchDG paper, check out the Reproducing results notebook.

Roadmap

  • Support for more domain generalization algorithms like CSD and IRM. If you are an author of a DG algorithm and would like to contribute, please raise a pull request here or get in touch.
  • More evaluation metrics based on adversarial attacks, privacy attacks like model inversion. If you'd like to see an evaluation metric implemented, please raise an issue here.

Contributing

This project welcomes contributions and suggestions. Most contributions require you to agree to a Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.

When you submit a pull request, a CLA bot will automatically determine whether you need to provide a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions provided by the bot. You will only need to do this once across all repos using our CLA.

This project has adopted the Microsoft Open Source Code of Conduct. For more information see the Code of Conduct FAQ or contact [email protected] with any additional questions or comments.

Owner
Microsoft
Open source projects and samples from Microsoft
Microsoft
A Python package to preprocess time series

Disclaimer: This package is WIP. Do not take any APIs for granted. tspreprocess Time series can contain noise, may be sampled under a non fitting rate

Maximilian Christ 57 Dec 17, 2022
Timeseries analysis for neuroscience data

=================================================== Nitime: timeseries analysis for neuroscience data ===============================================

NIPY developers 212 Dec 09, 2022
Lingtrain Alignment Studio is an ML based app for texts alignment on different languages.

Lingtrain Alignment Studio Intro Lingtrain Alignment Studio is the ML based app for accurate texts alignment on different languages. Extracts parallel

Sergei Averkiev 186 Jan 03, 2023
Dragonfly is an open source python library for scalable Bayesian optimisation.

Dragonfly is an open source python library for scalable Bayesian optimisation. Bayesian optimisation is used for optimising black-box functions whose

744 Jan 02, 2023
fastFM: A Library for Factorization Machines

Citing fastFM The library fastFM is an academic project. The time and resources spent developing fastFM are therefore justified by the number of citat

1k Dec 24, 2022
Azure Cloud Advocates at Microsoft are pleased to offer a 12-week, 24-lesson curriculum all about Machine Learning

Azure Cloud Advocates at Microsoft are pleased to offer a 12-week, 24-lesson curriculum all about Machine Learning

Microsoft 43.4k Jan 04, 2023
Estudos e projetos feitos com PySpark.

PySpark (Spark com Python) PySpark é uma biblioteca Spark escrita em Python, e seu objetivo é permitir a análise interativa dos dados em um ambiente d

Karinne Cristina 54 Nov 06, 2022
CVXPY is a Python-embedded modeling language for convex optimization problems.

CVXPY The CVXPY documentation is at cvxpy.org. We are building a CVXPY community on Discord. Join the conversation! For issues and long-form discussio

4.3k Jan 08, 2023
a distributed deep learning platform

Apache SINGA Distributed deep learning system http://singa.apache.org Quick Start Installation Examples Issues JIRA tickets Code Analysis: Mailing Lis

The Apache Software Foundation 2.7k Jan 05, 2023
A simple machine learning package to cluster keywords in higher-level groups.

Simple Keyword Clusterer A simple machine learning package to cluster keywords in higher-level groups. Example: "Senior Frontend Engineer" -- "Fronte

Andrea D'Agostino 10 Dec 18, 2022
MLReef is an open source ML-Ops platform that helps you collaborate, reproduce and share your Machine Learning work with thousands of other users.

The collaboration platform for Machine Learning MLReef is an open source ML-Ops platform that helps you collaborate, reproduce and share your Machine

MLReef 1.4k Dec 27, 2022
CD) in machine learning projectsImplementing continuous integration & delivery (CI/CD) in machine learning projects

CML with cloud compute This repository contains a sample project using CML with Terraform (via the cml-runner function) to launch an AWS EC2 instance

Iterative 19 Oct 03, 2022
Machine learning model evaluation made easy: plots, tables, HTML reports, experiment tracking and Jupyter notebook analysis.

sklearn-evaluation Machine learning model evaluation made easy: plots, tables, HTML reports, experiment tracking, and Jupyter notebook analysis. Suppo

Eduardo Blancas 354 Dec 31, 2022
ELI5 is a Python package which helps to debug machine learning classifiers and explain their predictions

A library for debugging/inspecting machine learning classifiers and explaining their predictions

154 Dec 17, 2022
A machine learning web application for binary classification using streamlit

Machine Learning web App This is a machine learning web application for binary classification using streamlit options this application contains 3 clas

abdelhak mokri 1 Dec 20, 2021
This repository demonstrates the usage of hover to understand and supervise a machine learning task.

Hover Example Apps (works out-of-the-box on Binder) This repository demonstrates the usage of hover to understand and supervise a machine learning tas

Pavel 43 Dec 03, 2021
WAGMA-SGD is a decentralized asynchronous SGD for distributed deep learning training based on model averaging.

WAGMA-SGD is a decentralized asynchronous SGD based on wait-avoiding group model averaging. The synchronization is relaxed by making the collectives externally-triggerable, namely, a collective can b

Shigang Li 6 Jun 18, 2022
Toolkit for building machine learning models that generalize to unseen domains and are robust to privacy and other attacks.

Toolkit for Building Robust ML models that generalize to unseen domains (RobustDG) Divyat Mahajan, Shruti Tople, Amit Sharma Privacy & Causal Learning

Microsoft 149 Jan 06, 2023
Python ML pipeline that showcases mltrace functionality.

mltrace tutorial Date: October 2021 This tutorial builds a training and testing pipeline for a toy ML prediction problem: to predict whether a passeng

Log Labs 28 Nov 09, 2022
ParaMonte is a serial/parallel library of Monte Carlo routines for sampling mathematical objective functions of arbitrary-dimensions

ParaMonte is a serial/parallel library of Monte Carlo routines for sampling mathematical objective functions of arbitrary-dimensions, in particular, the posterior distributions of Bayesian models in

Computational Data Science Lab 182 Dec 31, 2022