vanderschaarlab / synthcity

A library for generating and evaluating synthetic tabular data for privacy, fairness and data augmentation.

Home Page: https://www.vanderschaar-lab.com/

License: Apache License 2.0

Python 81.25% Jupyter Notebook 18.75%
pytorch tabular-data privacy machine-learning generative-model data-augmentation fairness-ml synthetic-data

synthcity's People

Contributors

2045ga, bcebere, bvanbreugel, dependabot[bot], drshushen, eltociear, gsel9, hlasse, pravsels, robsdavis, seedatnabeel, vholstein, zhaozhiqian


synthcity's Issues

Import rdt fail

Running the first commands in the README:

from synthcity.plugins import Plugins
Plugins(categories=["generic"]).list()

results in the error:

cannot import name 'ClusterBasedNormalizer' from 'rdt.transformers'

Bayesian network

I checked all the mainstream Bayesian network libraries in Python but none of them supports continuous or mixed data types.

Hence, I propose to do the following:

  1. Discretize the continuous variables, e.g. using sklearn's KBinsDiscretizer.
  2. Fit the BN on the discretized data.
  3. During sampling, first generate the discrete bin id using the BN, then randomly sample a continuous value within the bin's range (see the sketch below).
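
A minimal sketch of steps 1 and 3, assuming sklearn's KBinsDiscretizer (not synthcity code; the BN fit in step 2 is omitted):

import numpy as np
from sklearn.preprocessing import KBinsDiscretizer

rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 1))  # toy continuous column

# Step 1: discretize into ordinal bin ids.
disc = KBinsDiscretizer(n_bins=10, encode="ordinal", strategy="uniform")
bins = disc.fit_transform(X).astype(int)

# Step 2 (omitted): fit the Bayesian network on `bins`.

# Step 3: given a bin id sampled from the BN, draw a continuous value
# uniformly within that bin's edges.
edges = disc.bin_edges_[0]
sampled_bin = 3  # pretend this came from the BN
value = rng.uniform(edges[sampled_bin], edges[sampled_bin + 1])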

[Bug] IntegerDistribution returns float

Sampling from an IntegerDistribution returns a float type. This causes batch_size to be a float, which triggers an exception when running tvae.

from synthcity.plugins import Plugins

syn_model = Plugins().get("tvae")
params = syn_model.hyperparameter_space()
param_val = [x.sample()[0] for x in params]
param_name = [x.name for x in params]

param_dict = dict(zip(param_name, param_val))
isinstance(param_dict['batch_size'], int)

returns False.

Passing the float batch size triggers the following exception when running tvae:

[2022-06-08T19:50:33.558025+0000][297][CRITICAL] [tvae][param 19][take 0] failed: batch_size should be a positive integer value, but got batch_size=150

[Install] pytorch_wavelets dependency

This is a low priority issue. Fix it only when you have time.

Synthcity now depends on the library pytorch_wavelets. This library cannot be installed automatically via pip; instead, one has to download it from GitHub and install it from the source directory. This might make installation difficult for new users.

Is there a possible workaround? If not, we need to update the installation guide.
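
One possible workaround, assuming the upstream repository is fbcotter/pytorch_wavelets with a standard setup.py, is to point pip at GitHub directly:

pip install git+https://github.com/fbcotter/pytorch_wavelets.git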

Also, please add PyWavelets to the dependencies.

Input format for time series data

Question

Which input format is required for time series data?

Further Information

Dear SynthCity developers, I really like your work and wanted to test the package on my own time series dataset. I have a phone dataset consisting of passively sensed data, sampled daily, with some days missing for some individuals. The number of days of collected data varies between individuals. To familiarize myself with the required input format, I went through the PBC dataset.

loader = TimeSeriesDataLoader(
    temporal_data=temporal,
    observation_times=temporal_horizons,
    outcome=outcome_surv,
    static_data=static_surv,
)

As far as I understand, temporal_data is a list of variable-length dataframes containing the variables of interest, with time as the index column. observation_times is a list of lists with the timestamps for each observation. outcome is a tuple of two series of outcomes, and static_data is just a dataframe.

If I understand correctly, I'd have to split the temporal features into multiple dataframes, make the timestamps the index, and put these in a list. Then I'd generate lists of the timestamps for each dataframe, add them to the list of observation times, and select the outcome and static features with the same ordering as the two lists (sketched below). Before I mess up the analysis, is there anything I'm missing here?
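
For concreteness, a minimal sketch of that preparation (not an official example; the column names "id" and "time" and the toy data are made up):

import pandas as pd

# Toy flat dataframe; "id" and "time" are hypothetical column names.
df = pd.DataFrame({
    "id":   [1, 1, 1, 2, 2],
    "time": [0, 1, 3, 0, 2],
    "hr":   [70, 72, 75, 80, 78],
})

temporal_data, observation_times = [], []
for _, group in df.groupby("id"):
    frame = group.set_index("time").drop(columns=["id"])
    temporal_data.append(frame)                   # one dataframe per individual
    observation_times.append(list(frame.index))   # matching timestamps
# outcome and static_data must then be ordered identically to temporal_data.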

If this works out I'd be willing to write a short tutorial on this - could help other labs import their own data.

Integrate jaxtyping for advanced parameter validation

Description

Right now, synthcity uses pydantic to validate the parameters of various functions.

An improvement on top of that would be to integrate jaxtyping, which allows validating tensor shapes as well. jaxtyping supports PyTorch tensors and numpy arrays.

Example

from jaxtyping import Array, Float, PyTree

# Accepts floating-point 2D arrays with matching dimensions
def matrix_multiply(x: Float[Array, "dim1 dim2"],
                    y: Float[Array, "dim2 dim3"]
                  ) -> Float[Array, "dim1 dim3"]:
    ...

def accepts_pytree_of_ints(x: PyTree[int]):
    ...

def accepts_pytree_of_arrays(x: PyTree[Float[Array, "batch c1 c2"]]):
    ...

https://github.com/google/jaxtyping

Dataloader train_size argument not passed

Custom dataloaders (e.g. GenericDataLoader) do not pass train_size to the DataLoader initialisation (e.g. "train_size=train_size," is missing in line 258, etc.), so dataloaders always use the default train_size=0.8.

[Notebook] Benchmark argument change

Low priority issue related to notebooks.

In commit #42 the Benchmarks.evaluate takes

tests: List[Tuple[str, str, dict]], # test name, plugin name, plugin args

But in the notebooks, it takes

plugins: List,

The notebooks need to be updated to reflect this change (an illustrative call is sketched below).
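
For illustration, a hedged sketch of a call under the new signature; the import path, the loader argument, and the plugin arguments are assumptions, not confirmed API:

from synthcity.benchmark import Benchmarks

score = Benchmarks.evaluate(
    [
        # (test name, plugin name, plugin args)
        ("ctgan_default", "ctgan", {}),
        ("tvae_small", "tvae", {"n_iter": 100}),  # "n_iter" is an assumed plugin arg
    ],
    loader,  # a DataLoader holding the reference data (assumed positional argument)
)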

Can't suppress warnings when evaluating xgb performance

When evaluating the xgb performance metric for the dpgan and pategan synthetic models, the console is spammed with warnings from xgbse. warnings.filterwarnings("ignore") does not suppress them.

Here's the code I'm running.

from synthcity.metrics import Metrics
from synthcity.utils import serialization

syn_model = serialization.load_from_file("some_saved_dpgan_model.bkp")
selected_metrics = {
    'performance': ['xgb'],
}
my_metrics = Metrics()
selected_metrics_in_my_metrics = {k: my_metrics.list()[k] for k in my_metrics.list().keys() & selected_metrics.keys()}
X_syn = syn_model.generate(count=6882)
evaluation = my_metrics.evaluate(
    loader,  # the DataLoader holding the real data
    X_syn,
    task_type="survival_analysis",
    metrics=selected_metrics_in_my_metrics,
    workspace="workspace",
)

Early stopping

A question that we are almost certain to get is how to set the number of training iterations.

I propose to implement an early stopping mechanism that the user can choose to enable. The user supplies a dictionary of {metric: weight}; we calculate the weighted sum of several metrics (e.g. 0.8 * MMD + 0.2 * performance) and do early stopping on that (the user also specifies a patience parameter and so on).
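
A minimal sketch of such a mechanism (the class name and the lower-is-better convention are assumptions, not synthcity API):

class WeightedEarlyStopping:
    def __init__(self, weights: dict, patience: int = 10):
        self.weights = weights            # e.g. {"mmd": 0.8, "performance": 0.2}
        self.patience = patience
        self.best = float("inf")
        self.bad_rounds = 0

    def step(self, metrics: dict) -> bool:
        """Return True when training should stop (assumes lower scores are better)."""
        score = sum(w * metrics[name] for name, w in self.weights.items())
        if score < self.best:
            self.best, self.bad_rounds = score, 0
        else:
            self.bad_rounds += 1
        return self.bad_rounds >= self.patience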

[Plugin] Saving generative models

Hi Bogdan, what's the best way to save a trained generator?

I tried pickle on CT-GAN but it raises an error:

_pickle.PicklingError: Can't pickle <class 'plugin_ctgan.py.CTGANPlugin'>: import of module 'plugin_ctgan.py' failed

Do you think we can add save (and load) methods to the plugin class?
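
As a possible interim workaround (a sketch; whether it covers every plugin is unclear, see the PicklingError issue further down), cloudpickle tends to handle dynamically loaded classes better than the stdlib pickle:

import cloudpickle
from synthcity.plugins import Plugins

syn_model = Plugins().get("ctgan")
# syn_model.fit(loader)  # fit on your data loader first

with open("ctgan_model.bkp", "wb") as f:
    cloudpickle.dump(syn_model, f)

with open("ctgan_model.bkp", "rb") as f:
    restored = cloudpickle.load(f)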

Progress bar and logging during training

The training procedure can take a long time. We should add a progress bar or some logging messages during training to keep the user informed. The user can control the verbosity of the messages by changing the logging level.
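
A minimal sketch of the idea (not synthcity code), combining a tqdm bar with log messages gated by the logging level:

import logging
from tqdm import tqdm

logger = logging.getLogger("synthcity")   # hypothetical logger name
logging.basicConfig(level=logging.INFO)   # the user raises/lowers verbosity here

for epoch in tqdm(range(100), desc="training"):
    loss = 1.0 / (epoch + 1)              # placeholder for the real training step
    logger.debug("epoch %d: loss=%.4f", epoch, loss)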

Checking directory exists before saving to file

Description

The save_to_file function (utils/serialization.py) does not check whether the file's directory exists. When it does not, the call raises a FileNotFoundError. The improvement is to add that check and create the directory if it does not exist before writing to the file.

Are you interested in working on this improvement yourself?

  • Yes, I am.

Additional Context

Note the directory 'saved_models/' does not exist.

 19 def save_to_file(path: Union[str, Path], model: Any) -> Any:
---> 20     with open(path, "wb") as f:
     21         return cloudpickle.dump(model, f)

FileNotFoundError: [Errno 2] No such file or directory: 'saved_models/XXX.bkp'
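
For illustration, a sketch of the proposed check on top of the cloudpickle-based implementation shown above (not the actual patch):

from pathlib import Path
from typing import Any, Union

import cloudpickle

def save_to_file(path: Union[str, Path], model: Any) -> Any:
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)  # create missing directories
    with open(path, "wb") as f:
        return cloudpickle.dump(model, f)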

Error in fitting privbayes on categorical data

I'm hitting an error when fitting privbayes on a dataset containing both numerical fields and categorical text fields. I do not seem to hit the same error for a dataset solely comprised of numerical data.

The code:

import pandas as pd

from synthcity.plugins import Plugins
from synthcity.plugins.core.dataloader import GenericDataLoader

X = pd.read_csv("...")  # Read a csv file containing both numerical fields and categorical text fields.
loader = GenericDataLoader(X, target_column="some_column", sensitive_features=["some_sensitive_columns"])
syn_model = Plugins().get("privbayes")
syn_model.fit(loader)

Here's the traceback:
"""
Traceback (most recent call last):
File "tutorials/privbayes_error.py", line 29, in
syn_model.fit(loader)
File "pydantic/decorator.py", line 40, in pydantic.decorator.validate_arguments.validate.wrapper_function
from contextlib import _GeneratorContextManager
File "pydantic/decorator.py", line 134, in pydantic.decorator.ValidatedFunction.call

File "pydantic/decorator.py", line 206, in pydantic.decorator.ValidatedFunction.execute

File "/home/rob/miniconda3/envs/synthcity/lib/python3.8/site-packages/synthcity/plugins/core/plugin.py", line 183, in fit
return self._fit(X, *args, **kwargs)
File "/home/rob/miniconda3/envs/synthcity/lib/python3.8/site-packages/synthcity/plugins/generic/plugin_privbayes.py", line 576, in _fit
self.model.fit(X.dataframe())
File "pydantic/decorator.py", line 40, in pydantic.decorator.validate_arguments.validate.wrapper_function
from contextlib import _GeneratorContextManager
File "pydantic/decorator.py", line 134, in pydantic.decorator.ValidatedFunction.call

File "pydantic/decorator.py", line 206, in pydantic.decorator.ValidatedFunction.execute

File "/home/rob/miniconda3/envs/synthcity/lib/python3.8/site-packages/synthcity/plugins/generic/plugin_privbayes.py", line 109, in fit
self.dag = self._greedy_bayes(data)
File "pydantic/decorator.py", line 40, in pydantic.decorator.validate_arguments.validate.wrapper_function
from contextlib import _GeneratorContextManager
File "pydantic/decorator.py", line 134, in pydantic.decorator.ValidatedFunction.call

File "pydantic/decorator.py", line 206, in pydantic.decorator.ValidatedFunction.execute

File "/home/rob/miniconda3/envs/synthcity/lib/python3.8/site-packages/synthcity/plugins/generic/plugin_privbayes.py", line 212, in _greedy_bayes
) = self._evaluate_parent_mutual_information(
File "pydantic/decorator.py", line 40, in pydantic.decorator.validate_arguments.validate.wrapper_function
from contextlib import _GeneratorContextManager
File "pydantic/decorator.py", line 134, in pydantic.decorator.ValidatedFunction.call

File "pydantic/decorator.py", line 206, in pydantic.decorator.ValidatedFunction.execute

File "/home/rob/miniconda3/envs/synthcity/lib/python3.8/site-packages/synthcity/plugins/generic/plugin_privbayes.py", line 430, in _evaluate_parent_mutual_information
score = self.mutual_info_score(data, parents, candidate)
File "pydantic/decorator.py", line 40, in pydantic.decorator.validate_arguments.validate.wrapper_function
from contextlib import _GeneratorContextManager
File "pydantic/decorator.py", line 134, in pydantic.decorator.ValidatedFunction.call

File "pydantic/decorator.py", line 206, in pydantic.decorator.ValidatedFunction.execute

File "/home/rob/miniconda3/envs/synthcity/lib/python3.8/site-packages/synthcity/plugins/generic/plugin_privbayes.py", line 451, in mutual_info_score
target_bins, _ = pd.cut(target, bins=self.n_bins, retbins=True)
File "/home/rob/miniconda3/envs/synthcity/lib/python3.8/site-packages/pandas/core/reshape/tile.py", line 259, in cut
mn, mx = (mi + 0.0 for mi in rng)
File "/home/rob/miniconda3/envs/synthcity/lib/python3.8/site-packages/pandas/core/reshape/tile.py", line 259, in
mn, mx = (mi + 0.0 for mi in rng)
TypeError: can only concatenate str (not "float") to str
"""

[Bug] PicklingError for several plugins

The save/load utility does not work for nflow, adsgan, privbayes, pategan, and rtvae plugins.

from synthcity.plugins import Plugins
from synthcity.utils.serialization import save_to_file

syn_model = Plugins().get("rtvae")
save_to_file('temp.pkl', syn_model)

raises an exception:

PicklingError: Can't pickle <cyfunction RTVAEPlugin.__init__ at 0x7ff2d1faf6c0>: import of module 'plugin_rtvae.py' failed

[Install] PyWavelets

Please add PyWavelets (pywt) to the dependencies:

pip install PyWavelets

Note that this is different from the pytorch_wavelets library that is already included in the dependencies. Thanks.

Add fairness to the metrics collection

I suggest adding various metrics for evaluating potential bias in the synthetic data with respect to groups of entities from a protected category (e.g. gender, age, race, location).
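
A minimal sketch of one such measure (not an existing synthcity metric): the demographic parity gap of predictions across a protected group column:

import numpy as np

def demographic_parity_gap(y_pred: np.ndarray, group: np.ndarray) -> float:
    """Absolute difference in positive-prediction rates across groups."""
    rates = [y_pred[group == g].mean() for g in np.unique(group)]
    return float(max(rates) - min(rates))

y_pred = np.array([1, 0, 1, 1, 0, 1])
group = np.array(["a", "a", "a", "b", "b", "b"])
print(demographic_parity_gap(y_pred, group))  # 0.0 would indicate parity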

[Metrics, Bug?] detection.detection_xgb

The metric detection.detection_xgb is always > 90% for all datasets and all methods in the Jupyter notebook (except for bayesian_network). The value is very high compared to detection.detection_mlp and detection.detection_gmm.

This is quite odd. Could you please take a look at it? What happens if you pass a subset of the real data as synthetic? In principle, this should give us around 50%. Thanks.
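
A hedged sketch of that check, reusing the evaluation call pattern from the xgb-warnings issue above; the metrics-dict format and argument order are assumptions:

import numpy as np
import pandas as pd
from synthcity.metrics import Metrics
from synthcity.plugins.core.dataloader import GenericDataLoader

X = pd.DataFrame(np.random.normal(size=(1000, 4)), columns=list("abcd"))
real = GenericDataLoader(X.iloc[:500])
fake_syn = GenericDataLoader(X.iloc[500:])   # "synthetic" data that is actually real
evaluation = Metrics().evaluate(
    real,
    fake_syn,
    metrics={"detection": ["detection_xgb"]},  # assumed metrics-dict format
)
# A detection score near 0.5 would mean the metric cannot tell the two apart.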

[Model] Bayesian Networks

Please take a look at these two Python libraries for Bayesian Networks. Let's discuss in more detail at next week's catch-up.

bnlearn
pgmpy

bnlearn is built on top of pgmpy. Both have many stars on GitHub and are actively maintained.

We need to have some Bayesian Network models in the library.
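
For reference, a minimal pgmpy sketch of fitting and sampling a discrete BN (recent pgmpy versions; not synthcity code):

import pandas as pd
from pgmpy.models import BayesianNetwork
from pgmpy.estimators import MaximumLikelihoodEstimator
from pgmpy.sampling import BayesianModelSampling

data = pd.DataFrame({"a": [0, 1, 0, 1, 1], "b": [1, 1, 0, 0, 1]})
model = BayesianNetwork([("a", "b")])        # structure: a -> b
model.fit(data, estimator=MaximumLikelihoodEstimator)
samples = BayesianModelSampling(model).forward_sample(size=5)
print(samples)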

[Metrics] Review inlier/outlier metrics

Calls: evaluate_inlier_probability, evaluate_outlier_probability

The current implementation might be confusing.

The reference is "Generating high-fidelity synthetic patient data for assessing machine learning healthcare software", section "Detecting re-identification risks using outlier analysis with distance metrics".
