diff --git a/sdmetrics/single_table/bayesian_network.py b/sdmetrics/single_table/bayesian_network.py index d2f03151..cc2ed187 100644 --- a/sdmetrics/single_table/bayesian_network.py +++ b/sdmetrics/single_table/bayesian_network.py @@ -7,6 +7,8 @@ from sdmetrics.goal import Goal from sdmetrics.single_table.base import SingleTableMetric +from sklearn.preprocessing import LabelEncoder +from pomegranate import bayesian_network LOGGER = logging.getLogger(__name__) @@ -16,12 +18,6 @@ class BNLikelihoodBase(SingleTableMetric): @classmethod def _likelihoods(cls, real_data, synthetic_data, metadata=None, structure=None): - try: - from pomegranate import BayesianNetwork - except ImportError: - raise ImportError( - 'Please install pomegranate with `pip install pomegranate` on a version of python ' - '< 3.11. This metric is not supported on python versions >= 3.11.') real_data, synthetic_data, metadata = cls._validate_inputs( real_data, synthetic_data, metadata) @@ -30,19 +26,29 @@ def _likelihoods(cls, real_data, synthetic_data, metadata=None, structure=None): if not fields: return np.full(len(real_data), np.nan) + + encoders = {field: LabelEncoder() for field in fields} + + real_data_encoded = real_data.copy() + synthetic_data_encoded = synthetic_data.copy() + + for field in fields: + real_data_encoded[field] = encoders[field].fit_transform(real_data_encoded[field]) + + for field in fields: + synthetic_data_encoded[field] = encoders[field].fit_transform(synthetic_data_encoded[field]) LOGGER.debug('Fitting the BayesianNetwork to the real data') if structure: - if isinstance(structure, dict): - structure = BayesianNetwork.from_json(json.dumps(structure)).structure - - bn = BayesianNetwork.from_structure(real_data[fields].to_numpy(), structure) + bn = bayesian_network.BayesianNetwork(structure=structure, algorithm='chow-liu') else: - bn = BayesianNetwork.from_samples(real_data[fields].to_numpy(), algorithm='chow-liu') + bn = bayesian_network.BayesianNetwork(algorithm='chow-liu') + + bn.fit(real_data_encoded[fields].to_numpy()) LOGGER.debug('Evaluating likelihood of the synthetic data') probabilities = [] - for _, row in synthetic_data[fields].iterrows(): + for _, row in synthetic_data_encoded[fields].iterrows(): try: probabilities.append(bn.probability([row.to_numpy()])) except ValueError: