diff --git a/Dockerfile b/Dockerfile index f4685b19..b62959fe 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,5 +1,10 @@ -FROM python:3.8-buster -RUN apt-get update && apt-get install -y build-essential +FROM nvidia/cuda:11.0.3-cudnn8-devel-ubuntu18.04 +CMD nvidia-smi + +RUN apt-get update && apt-get install -y build-essential && apt-get -y install curl +RUN apt-get -y install python3.8 python3-distutils && ln -s /usr/bin/python3.8 /usr/bin/python +RUN curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py && \ + python get-pip.py && ln -s /usr/bin/pip3 /usr/bin/pip RUN mkdir /SDGym && \ mkdir /SDGym/sdgym && \ @@ -13,7 +18,8 @@ COPY /privbayes/ /SDGym/privbayes WORKDIR /SDGym # Install project -RUN make install-ydata compile +RUN make install-all compile +RUN pip install -U numpy==1.20 ENV PRIVBAYES_BIN /SDGym/privbayes/privBayes.bin ENV TF_CPP_MIN_LOG_LEVEL 2 diff --git a/HISTORY.md b/HISTORY.md index 2fe72943..96e56e91 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,12 +1,29 @@ # History +## v0.4.0 - 2021-06-17 +This release fixed a bug where passing a `json` file as configuration for a multi-table synthesizer crashed the model. +It also adds a number of fixes and enhancements, including: (1) a function and CLI command to list the available synthesizer names, +(2) a curate set of dependencies and making `Gretel` into an optional dependency, (3) updating `Gretel` to use temp directories, +(4) using `nvidia-smi` to get the number of gpus and (5) multiple `dockerfile` updates to improve functionality. + +### Issues closed + +* Bug when using JSON configuration for multiple multi-table evaluation - [Issue #115](https://github.com/sdv-dev/SDGym/issues/115) by @pvk-developer +* Use nvidia-smi to get number of gpus - [PR #113](https://github.com/sdv-dev/SDGym/issues/113) by @katxiao +* List synthesizer names - [Issue #82](https://github.com/sdv-dev/SDGym/issues/82) by @fealho +* Use nvidia base for dockerfile - [PR #108](https://github.com/sdv-dev/SDGym/issues/108) by @katxiao +* Add Makefile target to install gretel and ydata - [PR #107](https://github.com/sdv-dev/SDGym/issues/107) by @katxiao +* Curate dependencies and make Gretel optional - [PR #106](https://github.com/sdv-dev/SDGym/issues/106) by @csala +* Update gretel checkpoints to use temp directory - [PR #105](https://github.com/sdv-dev/SDGym/issues/105) by @katxiao +* Initialize variable before reference - [PR #104](https://github.com/sdv-dev/SDGym/issues/104) by @katxiao + ## v0.4.0 - 2021-06-17 This release adds new synthesizers for Gretel and ydata, and creates a Docker image for SDGym. It also includes enhancements to the accepted SDGym arguments, adds a summary command to aggregate metrics, and adds the normalized score to the benchmark results. -## New Features +### New Features * Add normalized score to benchmark results - [Issue #102](https://github.com/sdv-dev/SDGym/issues/102) by @katxiao * Add max rows and max columns args - [Issue #96](https://github.com/sdv-dev/SDGym/issues/96) by @katxiao diff --git a/Makefile b/Makefile index cd9888bc..27b4a318 100644 --- a/Makefile +++ b/Makefile @@ -100,6 +100,19 @@ install-ydata-develop: clean-build clean-compile clean-pyc compile ## install th pip install 'ydata-synthetic>=0.3.0,<0.4' pip install -e .[dev] +.PHONY: install-gretel +install-gretel: clean-build clean-compile clean-pyc compile ## install the package with gretel + pip install .[gretel] + +.PHONY: install-gretel-develop +install-gretel-develop: clean-build clean-compile clean-pyc compile ## install the package with gretel and dependencies for development + pip install -e .[dev,gretel] + +.PHONY: install-all +install-all: clean-build clean-compile clean-pyc compile ## install the package with gretel and ydata + pip install 'ydata-synthetic>=0.3.0,<0.4' + pip install .[gretel] + # LINT TARGETS .PHONY: lint @@ -126,12 +139,8 @@ test-readme: ## run the readme snippets cd tests/readme_test && rundoc run --single-session python3 -t python3 ../../README.md rm -rf tests/readme_test -# .PHONY: test-tutorials -# test-tutorials: ## run the tutorial notebooks -# jupyter nbconvert --execute --ExecutePreprocessor.timeout=600 tutorials/*.ipynb --stdout > /dev/null - .PHONY: test -test: test-unit test-readme # test-tutorials ## test everything that needs test dependencies +test: test-unit test-readme ## test everything that needs test dependencies .PHONY: test-devel test-devel: lint ## test everything that needs development dependencies @@ -187,26 +196,31 @@ publish-test: dist publish-confirm ## package and upload a release on TestPyPI publish: dist publish-confirm ## package and upload a release twine upload dist/* -.PHONY: bumpversion-release -bumpversion-release: ## Merge master to stable and bumpversion release +.PHONY: git-merge-master-stable +git-merge-master-stable: ## Merge master into stable git checkout stable || git checkout -b stable git merge --no-ff master -m"make release-tag: Merge branch 'master' into stable" - bumpversion release + +.PHONY: git-merge-stable-master +git-merge-stable-master: ## Merge stable into master + git checkout master + git merge stable + +.PHONY: git-push +git-push: ## Simply push the repository to github + git push + +.PHONY: git-push-tags-stable +git-push-tags-stable: ## Push tags and stable to github git push --tags origin stable -.PHONY: bumpversion-release-test -bumpversion-release-test: ## Merge master to stable and bumpversion release - git checkout stable || git checkout -b stable - git merge --no-ff master -m"make release-tag: Merge branch 'master' into stable" - bumpversion release --no-tag - @echo git push --tags origin stable +.PHONY: bumpversion-release +bumpversion-release: ## Bump the version to the next release + bumpversion release .PHONY: bumpversion-patch -bumpversion-patch: ## Merge stable to master and bumpversion patch - git checkout master - git merge stable +bumpversion-patch: ## Bump the version to the next patch bumpversion --no-tag patch - git push .PHONY: bumpversion-candidate bumpversion-candidate: ## Bump the version to the next candidate @@ -222,12 +236,13 @@ bumpversion-major: ## Bump the version the next major skipping the release .PHONY: bumpversion-revert bumpversion-revert: ## Undo a previous bumpversion-release + git tag --delete $(shell git tag --points-at HEAD) git checkout master git branch -D stable -CURRENT_VERSION := $(shell grep "^current_version" setup.cfg | grep -o "dev[0-9]*") CLEAN_DIR := $(shell git status --short | grep -v ??) CURRENT_BRANCH := $(shell git rev-parse --abbrev-ref HEAD 2>/dev/null) +CURRENT_VERSION := $(shell grep "^current_version" setup.cfg | grep -o "dev[0-9]*") CHANGELOG_LINES := $(shell git diff HEAD..origin/stable HISTORY.md 2>&1 | wc -l) .PHONY: check-clean @@ -236,18 +251,18 @@ ifneq ($(CLEAN_DIR),) $(error There are uncommitted changes) endif -.PHONY: check-candidate -check-candidate: ## Check if a release candidate has been made -ifeq ($(CURRENT_VERSION),dev0) - $(error Please make a release candidate and test it before atempting a release) -endif - .PHONY: check-master check-master: ## Check if we are in master branch ifneq ($(CURRENT_BRANCH),master) $(error Please make the release from master branch\n) endif +.PHONY: check-candidate +check-candidate: ## Check if a release candidate has been made +ifeq ($(CURRENT_VERSION),dev0) + $(error Please make a release candidate and test it before atempting a release) +endif + .PHONY: check-history check-history: ## Check if HISTORY.md has been modified ifeq ($(CHANGELOG_LINES),0) @@ -255,17 +270,18 @@ ifeq ($(CHANGELOG_LINES),0) endif .PHONY: check-release -check-release: check-candidate check-clean check-master check-history ## Check if the release can be made +check-release: check-clean check-candidate check-master check-history ## Check if the release can be made @echo "A new release can be made" .PHONY: release -release: check-release bumpversion-release publish bumpversion-patch +release: check-release git-merge-master-stable bumpversion-release git-push-tags-stable \ + publish git-merge-stable-master bumpversion-patch git-push .PHONY: release-test -release-test: check-release bumpversion-release-test publish-test bumpversion-revert +release-test: check-release git-merge-master-stable bumpversion-release bumpversion-revert .PHONY: release-candidate -release-candidate: check-master publish bumpversion-candidate +release-candidate: check-master publish bumpversion-candidate git-push .PHONY: release-candidate-test release-candidate-test: check-clean check-master publish-test diff --git a/conda/meta.yaml b/conda/meta.yaml index c60d039a..778087f9 100644 --- a/conda/meta.yaml +++ b/conda/meta.yaml @@ -1,5 +1,5 @@ {% set name = 'sdgym' %} -{% set version = '0.4.0' %} +{% set version = '0.4.1.dev3' %} package: name: "{{ name|lower }}" @@ -17,52 +17,55 @@ build: requirements: host: - - ctgan >=0.2.2.dev1,<0.3 - - gretel-synthetics >=0.15.4,<0.16 - - humanfriendly >=8.2,<9 - - numpy >=1.15.4,<2 - - pandas >=0.23.4,<2 - pip - - pomegranate >=0.13.0,<0.13.5 + - pytest-runner + - graphviz + - python >=3.6,<3.9 + - appdirs >=1.1.4,<2 + - boto3 >=1.15.0,<2 + - botocore >=1.20,<2 + - compress-pickle >=1.2.0,<2 + - humanfriendly >=8.2,<9 + - numpy >=1.18.0,<2 + - pandas >=1.1,<1.1.5 + - pomegranate >=0.13.4,<0.14.2 - psutil >=5.7,<6 - - python - - scikit-learn >=0.20,<0.24 - - scipy >=1.3.0,<2 - - sdv >=0.4.4.dev0,<0.6 + - rdt >=0.4.1 + - sdmetrics >=0.3.0 + - sdv >=0.9.0 + - scikit-learn >=0.23,<1 - tabulate >=0.8.3,<0.9 - - pytorch >=1.1.0,<2 - - tensorflow ==2.4.0rc1 - - torchvision >=0.3.0 - - tqdm >=4,<5 - - xlsxwriter >=1.2.8,<1.3 - - pytest-runner + - torch >=1.4,<2 + - tqdm >=4.14,<5 + - XlsxWriter >=1.2.8,<1.3 run: - - ctgan >=0.2.2.dev1,<0.3 - - gretel-synthetics >=0.15.4,<0.16 + - python >=3.6,<3.9 + - appdirs >=1.1.4,<2 + - boto3 >=1.15.0,<2 + - botocore >=1.20,<2 + - compress-pickle >=1.2.0,<2 - humanfriendly >=8.2,<9 - - numpy >=1.15.4,<2 - - pandas >=0.23.4,<2 - - pomegranate >=0.13.0,<0.13.5 + - numpy >=1.18.0,<2 + - pandas >=1.1,<1.1.5 + - pomegranate >=0.13.4,<0.14.2 - psutil >=5.7,<6 - - python - - scikit-learn >=0.20,<0.24 - - scipy >=1.3.0,<2 - - sdv >=0.4.4.dev0,<0.6 + - rdt >=0.4.1 + - sdmetrics >=0.3.0 + - sdv >=0.9.0 + - scikit-learn >=0.23,<1 - tabulate >=0.8.3,<0.9 - - pytorch >=1.1.0,<2 - - tensorflow ==2.4.0rc1 - - torchvision >=0.3.0 - - tqdm >=4,<5 - - xlsxwriter >=1.2.8,<1.3 + - torch >=1.4,<2 + - tqdm >=4.14,<5 + - XlsxWriter >=1.2.8,<1.3 about: home: "https://github.com/sdv-dev/SDGym" license: MIT license_family: MIT - license_file: + license_file: summary: "A framework to benchmark the performance of synthetic data generators for non-temporal tabular data" - doc_url: - dev_url: + doc_url: + dev_url: extra: recipe-maintainers: diff --git a/sdgym/__init__.py b/sdgym/__init__.py index c07605a0..4a31eff0 100644 --- a/sdgym/__init__.py +++ b/sdgym/__init__.py @@ -8,7 +8,7 @@ __copyright__ = 'Copyright (c) 2018, MIT Data To AI Lab' __email__ = 'dailabmit@gmail.com' __license__ = 'MIT' -__version__ = '0.4.0' +__version__ = '0.4.1.dev3' from sdgym import benchmark, synthesizers from sdgym.benchmark import run diff --git a/sdgym/__main__.py b/sdgym/__main__.py index 98e02658..aa31ee3d 100644 --- a/sdgym/__main__.py +++ b/sdgym/__main__.py @@ -13,6 +13,8 @@ import tqdm import sdgym +from sdgym.synthesizers.base import Baseline +from sdgym.utils import get_synthesizers def _env_setup(logfile, verbosity): @@ -134,6 +136,11 @@ def _list_available(args): _print_table(datasets, args.sort, args.reverse, {'size': humanfriendly.format_size}) +def _list_synthesizers(args): + synthesizers = Baseline.get_baselines() + _print_table(pd.DataFrame(get_synthesizers(list(synthesizers)))) + + def _collect(args): sdgym.collect.collect_results(args.input_path, args.output_file, args.aws_key, args.aws_secret) @@ -241,6 +248,11 @@ def _get_parser(): list_available.add_argument('-as', '--aws-secret', type=str, required=False, help='Aws secret access key to use when reading datasets.') + # list-synthesizers + list_available = action.add_parser('list-synthesizers', + help='List synthesizers available for use.') + list_available.set_defaults(action=_list_synthesizers) + # collect collect = action.add_parser('collect', help='Collect sdgym results.') collect.set_defaults(action=_collect) diff --git a/sdgym/benchmark.py b/sdgym/benchmark.py index 42a85f6c..35e82016 100644 --- a/sdgym/benchmark.py +++ b/sdgym/benchmark.py @@ -11,7 +11,6 @@ import compress_pickle import numpy as np import pandas as pd -import torch import tqdm from sdgym.datasets import get_dataset_paths, load_dataset, load_tables @@ -20,6 +19,7 @@ from sdgym.progress import TqdmLogger, progress from sdgym.s3 import is_s3_path, write_csv, write_file from sdgym.synthesizers.base import Baseline +from sdgym.synthesizers.utils import get_num_gpus from sdgym.utils import ( build_synthesizer, format_exception, get_synthesizers, import_object, used_memory) @@ -72,6 +72,7 @@ def _compute_scores(metrics, real_data, synthetic_data, metadata, output): error = None score = None + normalized_score = None start = datetime.utcnow() try: LOGGER.info('Computing %s on dataset %s', metric_name, metadata._metadata['name']) @@ -309,8 +310,9 @@ def run(synthesizers=None, datasets=None, datasets_path=None, modalities=None, b run_id = os.getenv('RUN_ID') or str(uuid.uuid4())[:10] if workers == -1: - if torch.cuda.is_available(): - workers = torch.cuda.device_count() + num_gpus = get_num_gpus() + if num_gpus > 0: + workers = num_gpus else: workers = multiprocessing.cpu_count() diff --git a/sdgym/synthesizers/__init__.py b/sdgym/synthesizers/__init__.py index 858a8957..8336ce23 100644 --- a/sdgym/synthesizers/__init__.py +++ b/sdgym/synthesizers/__init__.py @@ -23,7 +23,6 @@ 'CTGAN', 'Uniform', 'VEEGAN', - 'CTGAN', 'CopulaGAN', 'GaussianCopulaCategorical', 'GaussianCopulaCategoricalFuzzy', diff --git a/sdgym/synthesizers/base.py b/sdgym/synthesizers/base.py index 4cbf94f8..bd649971 100644 --- a/sdgym/synthesizers/base.py +++ b/sdgym/synthesizers/base.py @@ -1,3 +1,4 @@ +import abc import logging import pandas as pd @@ -8,7 +9,7 @@ LOGGER = logging.getLogger(__name__) -class Baseline: +class Baseline(abc.ABC): """Base class for all the ``SDGym`` baselines.""" MODALITIES = () @@ -31,11 +32,21 @@ def get_subclasses(cls, include_parents=False): return subclasses + @classmethod + def get_baselines(cls): + subclasses = cls.get_subclasses(include_parents=True) + synthesizers = [] + for _, subclass in subclasses.items(): + if abc.ABC not in subclass.__bases__: + synthesizers.append(subclass) + + return synthesizers + def fit_sample(self, real_data, metadata): pass -class SingleTableBaseline(Baseline): +class SingleTableBaseline(Baseline, abc.ABC): """Base class for all the SingleTable Baselines. Subclasses can choose to implement ``_fit_sample``, which will @@ -77,7 +88,7 @@ def fit_sample(self, real_data, metadata): return _fit_sample(real_data, metadata) -class MultiSingleTableBaseline(Baseline): +class MultiSingleTableBaseline(Baseline, abc.ABC): """Base class for SingleTableBaselines that are used on multi table scenarios. These classes model and sample each table independently and then just @@ -111,7 +122,7 @@ def fit_sample(self, real_data, metadata): return self._fit_sample(real_data, metadata) -class LegacySingleTableBaseline(SingleTableBaseline): +class LegacySingleTableBaseline(SingleTableBaseline, abc.ABC): """Single table baseline which passes ordinals and categoricals down. This class exists here to support the legacy baselines which do not operate diff --git a/sdgym/synthesizers/gretel.py b/sdgym/synthesizers/gretel.py index 2cdfdfd0..ec3b3dfb 100644 --- a/sdgym/synthesizers/gretel.py +++ b/sdgym/synthesizers/gretel.py @@ -1,19 +1,24 @@ -import os +import tempfile import numpy as np -from gretel_synthetics.batch import DataFrameBatch from sdgym.synthesizers.base import SingleTableBaseline +try: + from gretel_synthetics.batch import DataFrameBatch +except ImportError: + DataFrameBatch = None + class Gretel(SingleTableBaseline): """Class to represent Gretel's neural network model.""" - DEFAULT_CHECKPOINT_DIR = os.path.join(os.getcwd(), 'checkpoints') - def __init__(self, max_lines=0, max_line_len=2048, epochs=None, vocab_size=20000, gen_lines=None, dp=False, field_delimiter=",", overwrite=True, - checkpoint_dir=DEFAULT_CHECKPOINT_DIR): + checkpoint_dir=None): + if DataFrameBatch is None: + raise ImportError('Please install gretel-synthetics using `pip install sdgym[gretel]`') + self.max_lines = max_lines self.max_line_len = max_line_len self.epochs = epochs @@ -22,7 +27,7 @@ def __init__(self, max_lines=0, max_line_len=2048, epochs=None, vocab_size=20000 self.dp = dp self.field_delimiter = field_delimiter self.overwrite = overwrite - self.checkpoint_dir = checkpoint_dir + self.checkpoint_dir = checkpoint_dir or tempfile.TemporaryDirectory().name def _fit_sample(self, data, metadata): config = { diff --git a/sdgym/synthesizers/sdv.py b/sdgym/synthesizers/sdv.py index af615666..cecd7995 100644 --- a/sdgym/synthesizers/sdv.py +++ b/sdgym/synthesizers/sdv.py @@ -1,3 +1,4 @@ +import abc import logging import sdv @@ -9,7 +10,7 @@ LOGGER = logging.getLogger(__name__) -class SDV(Baseline): +class SDV(Baseline, abc.ABC): MODALITIES = ('single-table', 'multi-table') @@ -22,7 +23,7 @@ def fit_sample(self, data, metadata): return model.sample_all() -class SDVTabular(SingleTableBaseline): +class SDVTabular(SingleTableBaseline, abc.ABC): MODALITIES = ('single-table', ) _MODEL = None @@ -62,7 +63,7 @@ class GaussianCopulaOneHot(SDVTabular): } -class CUDATabular(SDVTabular): +class CUDATabular(SDVTabular, abc.ABC): def _fit_sample(self, data, metadata): LOGGER.info('Fitting %s', self.__class__.__name__) @@ -90,7 +91,7 @@ class CopulaGAN(CUDATabular): _MODEL = sdv.tabular.CopulaGAN -class SDVRelational(Baseline): +class SDVRelational(Baseline, abc.ABC): MODALITIES = ('single-table', 'multi-table') _MODEL = None @@ -111,7 +112,7 @@ class HMA1(SDVRelational): _MODEL = sdv.relational.HMA1 -class SDVTimeseries(SingleTableBaseline): +class SDVTimeseries(SingleTableBaseline, abc.ABC): MODALITIES = ('timeseries', ) _MODEL = None diff --git a/sdgym/synthesizers/utils.py b/sdgym/synthesizers/utils.py index 9f908b9e..192efd79 100644 --- a/sdgym/synthesizers/utils.py +++ b/sdgym/synthesizers/utils.py @@ -2,7 +2,6 @@ import numpy as np import pandas as pd -import torch from sklearn.mixture import BayesianGaussianMixture, GaussianMixture from sklearn.preprocessing import KBinsDiscretizer @@ -438,16 +437,21 @@ def inverse_transform(self, data): return data_t -def select_device(): - if not torch.cuda.is_available(): - return 'cpu' +def get_num_gpus(): + try: + command = ['nvidia-smi', '--query-gpu=utilization.gpu', '--format=csv,noheader,nounits'] + output = subprocess.run(command, stdout=subprocess.PIPE) + return len(output.stdout.decode().split()) + except Exception: + return 0 + +def select_device(): try: command = ['nvidia-smi', '--query-gpu=utilization.gpu', '--format=csv,noheader,nounits'] output = subprocess.run(command, stdout=subprocess.PIPE) loads = np.array(output.stdout.decode().split()).astype(float) device = loads.argmin() + return f'cuda:{device}' except Exception: - device = np.random.randint(torch.cuda.device_count()) - - return f'cuda:{device}' + return 'cpu' diff --git a/sdgym/synthesizers/ydata.py b/sdgym/synthesizers/ydata.py index 7aaeefe8..af8ace7f 100644 --- a/sdgym/synthesizers/ydata.py +++ b/sdgym/synthesizers/ydata.py @@ -1,12 +1,14 @@ +import abc + +from sdgym.synthesizers.base import SingleTableBaseline + try: import ydata_synthetic.synthesizers.regular as ydata except ImportError: ydata = None -from sdgym.synthesizers.base import SingleTableBaseline - -class YData(SingleTableBaseline): +class YData(SingleTableBaseline, abc.ABC): def _fit_sample(self, real_data, table_metadata): if ydata is None: diff --git a/sdgym/utils.py b/sdgym/utils.py index 86ba9df8..12fcee12 100644 --- a/sdgym/utils.py +++ b/sdgym/utils.py @@ -1,5 +1,6 @@ """Random utils used by SDGym.""" +import copy import importlib import json import logging @@ -81,7 +82,7 @@ def _get_synthesizer(synthesizer, name=None): with open(synthesizer, 'r') as json_file: return json.load(json_file) - baselines = Baseline.get_subclasses() + baselines = Baseline.get_subclasses(include_parents=True) if synthesizer in baselines: LOGGER.info('Trying to import synthesizer by name.') synthesizer = baselines[synthesizer] @@ -187,14 +188,17 @@ def build_synthesizer(synthesizer, synthesizer_dict): callable: The synthesizer function """ + + _synthesizer_dict = copy.deepcopy(synthesizer_dict) + def _synthesizer_function(real_data, metadata): - metadata_keyword = synthesizer_dict.get('metadata', '$metadata') - real_data_keyword = synthesizer_dict.get('real_data', '$real_data') - device_keyword = synthesizer_dict.get('device', '$device') - device_attribute = synthesizer_dict.get('device_attribute') + metadata_keyword = _synthesizer_dict.get('metadata', '$metadata') + real_data_keyword = _synthesizer_dict.get('real_data', '$real_data') + device_keyword = _synthesizer_dict.get('device', '$device') + device_attribute = _synthesizer_dict.get('device_attribute') device = select_device() - multi_table = 'multi-table' in synthesizer_dict['modalities'] + multi_table = 'multi-table' in _synthesizer_dict['modalities'] if not multi_table: table = metadata.get_tables()[0] metadata = metadata.get_table_meta(table) @@ -206,8 +210,8 @@ def _synthesizer_function(real_data, metadata): (device_keyword, device), ] - init_kwargs = _get_kwargs(synthesizer_dict, 'init', replace) - fit_kwargs = _get_kwargs(synthesizer_dict, 'fit', replace) + init_kwargs = _get_kwargs(_synthesizer_dict, 'init', replace) + fit_kwargs = _get_kwargs(_synthesizer_dict, 'fit', replace) instance = synthesizer(**init_kwargs) if device_attribute: diff --git a/setup.cfg b/setup.cfg index 576868f3..0857a96f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 0.4.0 +current_version = 0.4.1.dev3 commit = True tag = True parse = (?P\d+)\.(?P\d+)\.(?P\d+)(\.(?P[a-z]+)(?P\d+))? diff --git a/setup.py b/setup.py index c5a67950..a5dd4bb4 100644 --- a/setup.py +++ b/setup.py @@ -12,32 +12,44 @@ history = history_file.read() install_requires = [ - 'appdirs>1.1.4,<2', + 'appdirs>=1.1.4,<2', 'boto3>=1.15.0,<2', + 'botocore>=1.20,<2', 'compress-pickle>=1.2.0,<2', - 'gretel-synthetics>=0.15.4,<0.16', 'humanfriendly>=8.2,<9', - 'numpy>=1.15.4,<1.20', - 'pandas<1.1.5,>=1.1', - 'pomegranate>=0.13.0,<0.13.5', + 'numpy>=1.18.0,<2', + 'pandas>=1.1,<1.1.5', + 'pomegranate>=0.13.4,<0.14.2', 'psutil>=5.7,<6', - 'scikit-learn>=0.20,<1', - 'tabulate>=0.8.3,<0.9', - 'torch>=1.1.0,<2', - 'tqdm>=4,<5', - 'XlsxWriter>=1.2.8,<1.3', 'rdt>=0.4.1', + 'scikit-learn>=0.23,<1', + 'scipy>=1.4.1,<1.7', 'sdmetrics>=0.3.0', 'sdv>=0.9.0', - 'tensorflow==2.4.0rc1', - 'wheel~=0.35', + 'tabulate>=0.8.3,<0.9', + 'torch>=1.4,<2', + 'tqdm>=4.14,<5', + 'XlsxWriter>=1.2.8,<1.3', +] + + +dask_requires = [ + 'dask', + 'distributed', ] + ydata_requires = [ # preferably install using make install-ydata 'ydata-synthetic>=0.3.0,<0.4', ] +gretel_requires = [ + 'gretel-synthetics>=0.15.4,<0.16', + 'tensorflow==2.4.0rc1', + 'wheel~=0.35', +] + setup_requires = [ 'pytest-runner>=2.11.1', ] @@ -103,8 +115,11 @@ ], }, extras_require={ - 'dev': development_requires + tests_require, + 'all': development_requires + tests_require + dask_requires + gretel_requires, + 'dev': development_requires + tests_require + dask_requires, 'test': tests_require, + 'gretel': gretel_requires, + 'dask': dask_requires, }, include_package_data=True, install_requires=install_requires, @@ -119,6 +134,6 @@ test_suite='tests', tests_require=tests_require, url='https://github.com/sdv-dev/SDGym', - version='0.4.0', + version='0.4.1.dev3', zip_safe=False, ) diff --git a/tests/integration/test_benchmark.py b/tests/integration/test_benchmark.py index c2984724..87b78d19 100644 --- a/tests/integration/test_benchmark.py +++ b/tests/integration/test_benchmark.py @@ -39,11 +39,11 @@ def test_identity_jobs(): def test_json_synthesizer(): synthesizer = { - "name": "synthesizer_name", - "synthesizer": "sdgym.synthesizers.ydata.PreprocessedVanillaGAN", - "modalities": ["single-table"], - "init_kwargs": {"categorical_transformer": "label_encoding"}, - "fit_kwargs": {"data": "$real_data"} + 'name': 'synthesizer_name', + 'synthesizer': 'sdgym.synthesizers.ydata.PreprocessedVanillaGAN', + 'modalities': ['single-table'], + 'init_kwargs': {'categorical_transformer': 'label_encoding'}, + 'fit_kwargs': {'data': '$real_data'} } output = sdgym.run( @@ -52,4 +52,28 @@ def test_json_synthesizer(): iterations=1, ) - assert set(output['synthesizer']) == {"synthesizer_name"} + assert set(output['synthesizer']) == {'synthesizer_name'} + + +def test_json_synthesizer_multi_table(): + synthesizer = { + 'name': 'HMA1', + 'synthesizer': 'sdv.relational.HMA1', + 'modalities': [ + 'multi-table' + ], + 'init_kwargs': { + 'metadata': '$metadata' + }, + 'fit_kwargs': { + 'tables': '$real_data' + } + } + + output = sdgym.run( + synthesizers=[json.dumps(synthesizer)], + datasets=['university_v1', 'trains_v1'], + iterations=1, + ) + + assert not output.error.any() diff --git a/tests/unit/test_datasets.py b/tests/unit/test_datasets.py index 89d25703..9c55f193 100644 --- a/tests/unit/test_datasets.py +++ b/tests/unit/test_datasets.py @@ -9,6 +9,7 @@ class AnyConfigWith: """AnyConfigWith matches any s3 config with the specified signature version.""" + def __init__(self, signature_version): self.signature_version = signature_version