diff --git a/sdgym/synthesizers/uniform.py b/sdgym/synthesizers/uniform.py index efe41b39..6eb04b8d 100644 --- a/sdgym/synthesizers/uniform.py +++ b/sdgym/synthesizers/uniform.py @@ -17,18 +17,23 @@ def _get_trained_synthesizer(self, real_data, metadata): hyper_transformer = HyperTransformer() hyper_transformer.detect_initial_config(real_data) if metadata: - metadata_dict = metadata.to_dict() - supported_sdtypes = hyper_transformer._get_supported_sdtypes() - config = {} - for column in metadata_dict['columns']: - sdtype = metadata_dict['columns'][column]['sdtype'] - if sdtype in supported_sdtypes: - config[column] = metadata_dict['columns'][column]['sdtype'] - else: - LOGGER.info( - f'Column {column} sdtype: {sdtype} is not supported, ' - f'defaulting to inferred type.') - hyper_transformer.update_sdtypes(config) + metadata_dict = {} + if isinstance(metadata, dict): + metadata_dict = metadata + else: + metadata_dict = metadata.to_dict() + if 'columns' in metadata_dict: + supported_sdtypes = hyper_transformer._get_supported_sdtypes() + config = {} + for column in metadata_dict['columns']: + sdtype = metadata_dict['columns'][column]['sdtype'] + if sdtype in supported_sdtypes: + config[column] = metadata_dict['columns'][column]['sdtype'] + else: + LOGGER.info( + f'Column {column} sdtype: {sdtype} is not supported, ' + f'defaulting to inferred type.') + hyper_transformer.update_sdtypes(config) # This is done to match the behavior of the synthesizer for SDGym <= 0.6.0 columns_to_remove = [