Skip to content

Commit

Permalink
Add the check for columns
Browse files Browse the repository at this point in the history
  • Loading branch information
lajohn4747 committed Feb 26, 2024
1 parent 2d57c07 commit 8e763cd
Showing 1 changed file with 17 additions and 12 deletions.
29 changes: 17 additions & 12 deletions sdgym/synthesizers/uniform.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,23 @@ def _get_trained_synthesizer(self, real_data, metadata):
hyper_transformer = HyperTransformer()
hyper_transformer.detect_initial_config(real_data)
if metadata:
metadata_dict = metadata.to_dict()
supported_sdtypes = hyper_transformer._get_supported_sdtypes()
config = {}
for column in metadata_dict['columns']:
sdtype = metadata_dict['columns'][column]['sdtype']
if sdtype in supported_sdtypes:
config[column] = metadata_dict['columns'][column]['sdtype']
else:
LOGGER.info(
f'Column {column} sdtype: {sdtype} is not supported, '
f'defaulting to inferred type.')
hyper_transformer.update_sdtypes(config)
metadata_dict = {}
if isinstance(metadata, dict):
metadata_dict = metadata
else:
metadata_dict = metadata.to_dict()
if 'columns' in metadata_dict:
supported_sdtypes = hyper_transformer._get_supported_sdtypes()
config = {}
for column in metadata_dict['columns']:
sdtype = metadata_dict['columns'][column]['sdtype']
if sdtype in supported_sdtypes:
config[column] = metadata_dict['columns'][column]['sdtype']
else:
LOGGER.info(
f'Column {column} sdtype: {sdtype} is not supported, '
f'defaulting to inferred type.')
hyper_transformer.update_sdtypes(config)

# This is done to match the behavior of the synthesizer for SDGym <= 0.6.0
columns_to_remove = [
Expand Down

0 comments on commit 8e763cd

Please sign in to comment.