diff --git a/resources/equalized_odds_improvement_tutorial.ipynb b/resources/equalized_odds_improvement_tutorial.ipynb
new file mode 100644
index 00000000..88b69149
--- /dev/null
+++ b/resources/equalized_odds_improvement_tutorial.ipynb
@@ -0,0 +1,820 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "vscode": {
+ "languageId": "raw"
+ }
+ },
+ "source": [
+ "# Tutorial: EqualizedOddsImprovement Metric\n",
+ "\n",
+ "This notebook demonstrates how to use the `EqualizedOddsImprovement` metric to evaluate fairness in synthetic data generation. We'll use the Adult dataset to show how synthetic data can potentially improve fairness in machine learning models.\n",
+ "\n",
+ "## What is Equalized Odds?\n",
+ "\n",
+ "Equalized odds is a fairness criterion that requires the True Positive Rate (TPR) and False Positive Rate (FPR) to be equal across different groups defined by a sensitive attribute (like gender, race, etc.). \n",
+ "\n",
+ "The `EqualizedOddsImprovement` metric measures whether a model trained on synthetic data is fairer, in terms of equalized odds, than a model trained on the real data, with both models evaluated on the same real validation set.\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "vscode": {
+ "languageId": "raw"
+ }
+ },
+ "source": [
+ "## Setup and Imports\n",
+ "\n",
+ "First, let's install and import all the necessary libraries:\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# %pip (unlike !pip) installs into the environment of the running kernel\n",
+ "%pip install sdv\n",
+ "%pip install xgboost\n",
+ "%pip install matplotlib"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 42,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "All libraries imported successfully!\n"
+ ]
+ }
+ ],
+ "source": [
+ "import pandas as pd\n",
+ "import numpy as np\n",
+ "import matplotlib.pyplot as plt\n",
+ "from sklearn.model_selection import train_test_split\n",
+ "import json\n",
+ "\n",
+ "from sdv.single_table import TVAESynthesizer\n",
+ "from sdv.datasets.demo import download_demo\n",
+ "from sdv.sampling import Condition\n",
+ "\n",
+ "from sdmetrics.single_table.equalized_odds import EqualizedOddsImprovement\n",
+ "\n",
+ "print(\"All libraries imported successfully!\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "vscode": {
+ "languageId": "raw"
+ }
+ },
+ "source": [
+ "## Step 1: Load the Adult Dataset\n",
+ "\n",
+ "We'll use the Adult dataset from the SDV demo datasets. This dataset contains information about individuals and whether they earn more than $50K per year.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 69,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Dataset shape: (32561, 15)\n",
+ "\n",
+ "First few rows:\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " age | \n",
+ " workclass | \n",
+ " fnlwgt | \n",
+ " education | \n",
+ " education-num | \n",
+ " marital-status | \n",
+ " occupation | \n",
+ " relationship | \n",
+ " race | \n",
+ " sex | \n",
+ " capital-gain | \n",
+ " capital-loss | \n",
+ " hours-per-week | \n",
+ " native-country | \n",
+ " label | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " 27 | \n",
+ " Private | \n",
+ " 177119 | \n",
+ " Some-college | \n",
+ " 10 | \n",
+ " Divorced | \n",
+ " Adm-clerical | \n",
+ " Unmarried | \n",
+ " White | \n",
+ " Female | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 44 | \n",
+ " United-States | \n",
+ " <=50K | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " 27 | \n",
+ " Private | \n",
+ " 216481 | \n",
+ " Bachelors | \n",
+ " 13 | \n",
+ " Never-married | \n",
+ " Prof-specialty | \n",
+ " Not-in-family | \n",
+ " White | \n",
+ " Female | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 40 | \n",
+ " United-States | \n",
+ " <=50K | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 25 | \n",
+ " Private | \n",
+ " 256263 | \n",
+ " Assoc-acdm | \n",
+ " 12 | \n",
+ " Married-civ-spouse | \n",
+ " Sales | \n",
+ " Husband | \n",
+ " White | \n",
+ " Male | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 40 | \n",
+ " United-States | \n",
+ " <=50K | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 46 | \n",
+ " Private | \n",
+ " 147640 | \n",
+ " 5th-6th | \n",
+ " 3 | \n",
+ " Married-civ-spouse | \n",
+ " Transport-moving | \n",
+ " Husband | \n",
+ " Amer-Indian-Eskimo | \n",
+ " Male | \n",
+ " 0 | \n",
+ " 1902 | \n",
+ " 40 | \n",
+ " United-States | \n",
+ " <=50K | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 45 | \n",
+ " Private | \n",
+ " 172822 | \n",
+ " 11th | \n",
+ " 7 | \n",
+ " Divorced | \n",
+ " Transport-moving | \n",
+ " Not-in-family | \n",
+ " White | \n",
+ " Male | \n",
+ " 0 | \n",
+ " 2824 | \n",
+ " 76 | \n",
+ " United-States | \n",
+ " >50K | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " age workclass fnlwgt education education-num marital-status \\\n",
+ "0 27 Private 177119 Some-college 10 Divorced \n",
+ "1 27 Private 216481 Bachelors 13 Never-married \n",
+ "2 25 Private 256263 Assoc-acdm 12 Married-civ-spouse \n",
+ "3 46 Private 147640 5th-6th 3 Married-civ-spouse \n",
+ "4 45 Private 172822 11th 7 Divorced \n",
+ "\n",
+ " occupation relationship race sex capital-gain \\\n",
+ "0 Adm-clerical Unmarried White Female 0 \n",
+ "1 Prof-specialty Not-in-family White Female 0 \n",
+ "2 Sales Husband White Male 0 \n",
+ "3 Transport-moving Husband Amer-Indian-Eskimo Male 0 \n",
+ "4 Transport-moving Not-in-family White Male 0 \n",
+ "\n",
+ " capital-loss hours-per-week native-country label \n",
+ "0 0 44 United-States <=50K \n",
+ "1 0 40 United-States <=50K \n",
+ "2 0 40 United-States <=50K \n",
+ "3 1902 40 United-States <=50K \n",
+ "4 2824 76 United-States >50K "
+ ]
+ },
+ "execution_count": 69,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Fetch the single-table 'adult' demo dataset together with its metadata\n",
+ "real_data, metadata = download_demo('single_table', 'adult')\n",
+ "\n",
+ "print('Dataset shape:', real_data.shape)\n",
+ "print('\\nFirst few rows:')\n",
+ "real_data.head()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Step 2: Introduce Data Bias\n",
+ "\n",
+ "We'll introduce bias into the data by relabeling the female rows so that 95% of them are `<=50K` and the remaining 5% are `>50K`. The labels are assigned at random, i.e., they have no correlation with any column other than `sex`.\n",
+ "\n",
+ "For male rows we do the opposite, i.e., 5% of the labels will be `<=50K` and 95% will be `>50K`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 81,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Seed a dedicated generator so the injected bias is identical on every run\n",
+ "rng = np.random.default_rng(42)\n",
+ "\n",
+ "# Female rows: 95% '<=50K' / 5% '>50K', drawn independently of every other column\n",
+ "female_mask = real_data['sex'] == 'Female'\n",
+ "num_females = female_mask.sum()\n",
+ "\n",
+ "female_labels = rng.choice(['<=50K', '>50K'], size=num_females, p=[0.95, 0.05])\n",
+ "real_data.loc[female_mask, 'label'] = female_labels\n",
+ "\n",
+ "# Male rows: the mirror image, 5% '<=50K' / 95% '>50K'\n",
+ "male_mask = real_data['sex'] == 'Male'\n",
+ "num_males = male_mask.sum()\n",
+ "\n",
+ "male_labels = rng.choice(['<=50K', '>50K'], size=num_males, p=[0.05, 0.95])\n",
+ "real_data.loc[male_mask, 'label'] = male_labels\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 71,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "# Two side-by-side panels: label share within each sex, and overall label counts\n",
+ "fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(12, 5))\n",
+ "by_sex_ax, overall_ax = axes\n",
+ "\n",
+ "# Left panel: percentage of each income label per sex (each row sums to 100)\n",
+ "crosstab_pct = pd.crosstab(real_data['sex'], real_data['label'], normalize='index') * 100\n",
+ "crosstab_pct.plot.bar(ax=by_sex_ax, rot=0)\n",
+ "by_sex_ax.set(title='Income Distribution by Sex (%)', xlabel='Sex', ylabel='Percentage')\n",
+ "by_sex_ax.legend(title='Income')\n",
+ "\n",
+ "# Right panel: absolute number of rows per income label\n",
+ "label_counts = real_data['label'].value_counts()\n",
+ "label_counts.plot.bar(ax=overall_ax, rot=0)\n",
+ "overall_ax.set(title='Overall Income Distribution', xlabel='Income', ylabel='Count')\n",
+ "\n",
+ "fig.tight_layout()\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "vscode": {
+ "languageId": "raw"
+ }
+ },
+ "source": [
+ "## Step 3: Split Data into Training and Validation Sets"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 72,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "Training set shape: (22792, 15)\n",
+ "Validation set shape: (9769, 15)\n",
+ "\n",
+ "Training set combinations:\n",
+ "sex Female Male\n",
+ "label \n",
+ "<=50K 7164 738\n",
+ ">50K 354 14536\n",
+ "\n",
+ "Validation set combinations:\n",
+ "sex Female Male\n",
+ "label \n",
+ "<=50K 3089 336\n",
+ ">50K 164 6180\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Hold out 30% of the rows for validation; the fixed random_state keeps the split stable\n",
+ "training_data, validation_data = train_test_split(real_data, test_size=0.3, random_state=42)\n",
+ "\n",
+ "print('\\nTraining set shape:', training_data.shape)\n",
+ "print('Validation set shape:', validation_data.shape)\n",
+ "\n",
+ "# Print the label-by-sex counts of each split so every combination can be checked\n",
+ "for split_name, split_df in (('Training', training_data), ('Validation', validation_data)):\n",
+ "    print(f'\\n{split_name} set combinations:')\n",
+ "    print(pd.crosstab(split_df['label'], split_df['sex']))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "vscode": {
+ "languageId": "raw"
+ }
+ },
+ "source": [
+ "## Step 4: Generate Synthetic Data\n",
+ "\n",
+ "We'll use the TVAE (Tabular Variational AutoEncoder) synthesizer to generate synthetic data.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 73,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Training TVAE synthesizer...\n",
+ "Synthesizer training completed!\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Fit TVAE on the training split only; the validation rows are never shown to it\n",
+ "print('Training TVAE synthesizer...')\n",
+ "\n",
+ "synthesizer = TVAESynthesizer(metadata=metadata)\n",
+ "synthesizer.fit(training_data)\n",
+ "\n",
+ "print('Synthesizer training completed!')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 74,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Synthetic data shape: (22792, 15)\n",
+ "\n",
+ "Target and sensitive attribute distribution:\n",
+ "sex Female Male\n",
+ "label \n",
+ "<=50K 6868 936\n",
+ ">50K 83 14905\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Draw as many synthetic rows as there are real training rows\n",
+ "num_synthetic_rows = len(training_data)\n",
+ "synthetic_data = synthesizer.sample(num_synthetic_rows)\n",
+ "\n",
+ "print('Synthetic data shape:', synthetic_data.shape)\n",
+ "print('\\nTarget and sensitive attribute distribution:')\n",
+ "synthetic_crosstab = pd.crosstab(synthetic_data['label'], synthetic_data['sex'])\n",
+ "print(synthetic_crosstab)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "vscode": {
+ "languageId": "raw"
+ }
+ },
+ "source": [
+ "## Step 5: Evaluate Synthetic Data\n",
+ "\n",
+ "Let's use the `EqualizedOddsImprovement` metric to evaluate the synthetic data we just generated."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 75,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Score: 0.5337\n",
+ "\n",
+ "Score Interpretation:\n",
+ "- Score > 0.5 means synthetic data improves fairness\n",
+ "- Score < 0.5 means synthetic data worsens fairness\n",
+ "- Score = 0.5 means no change in fairness\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Score the unconditionally sampled synthetic data against the real training data,\n",
+ "# using the real validation split as the common evaluation set\n",
+ "adult_table_metadata = metadata.to_dict()['tables']['adult']\n",
+ "result_standard = EqualizedOddsImprovement.compute_breakdown(\n",
+ "    real_training_data=training_data,\n",
+ "    synthetic_data=synthetic_data,\n",
+ "    real_validation_data=validation_data,\n",
+ "    metadata=adult_table_metadata,\n",
+ "    prediction_column_name='label',\n",
+ "    positive_class_label='>50K',\n",
+ "    sensitive_column_name='sex',\n",
+ "    sensitive_column_value='Female',\n",
+ ")\n",
+ "\n",
+ "print('Score: {:.4f}'.format(result_standard['score']))\n",
+ "print('\\nScore Interpretation:')\n",
+ "for interpretation_line in (\n",
+ "    '- Score > 0.5 means synthetic data improves fairness',\n",
+ "    '- Score < 0.5 means synthetic data worsens fairness',\n",
+ "    '- Score = 0.5 means no change in fairness',\n",
+ "):\n",
+ "    print(interpretation_line)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 76,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Full breakdown of the Equalized Odds Improvement metric:\n",
+ "{\n",
+ " \"score\": 0.5337489169733715,\n",
+ " \"real_training_data\": {\n",
+ " \"equalized_odds\": 0.0008090614886731018,\n",
+ " \"prediction_counts_validation\": {\n",
+ " \"Female=True\": {\n",
+ " \"true_positive\": 0,\n",
+ " \"false_positive\": 15,\n",
+ " \"true_negative\": 3074,\n",
+ " \"false_negative\": 164\n",
+ " },\n",
+ " \"Female=False\": {\n",
+ " \"true_positive\": 6175,\n",
+ " \"false_positive\": 336,\n",
+ " \"true_negative\": 0,\n",
+ " \"false_negative\": 5\n",
+ " }\n",
+ " }\n",
+ " },\n",
+ " \"synthetic_data\": {\n",
+ " \"equalized_odds\": 0.06830689543541602,\n",
+ " \"prediction_counts_validation\": {\n",
+ " \"Female=True\": {\n",
+ " \"true_positive\": 14,\n",
+ " \"false_positive\": 211,\n",
+ " \"true_negative\": 2878,\n",
+ " \"false_negative\": 150\n",
+ " },\n",
+ " \"Female=False\": {\n",
+ " \"true_positive\": 6172,\n",
+ " \"false_positive\": 336,\n",
+ " \"true_negative\": 0,\n",
+ " \"false_negative\": 8\n",
+ " }\n",
+ " }\n",
+ " }\n",
+ "}\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Pretty-print the nested breakdown dict as indented JSON\n",
+ "standard_breakdown_json = json.dumps(result_standard, indent=2)\n",
+ "print('Full breakdown of the Equalized Odds Improvement metric:', standard_breakdown_json, sep='\\n')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "vscode": {
+ "languageId": "raw"
+ }
+ },
+ "source": [
+ "## Step 6: Generate Conditionally Sampled Synthetic Data\n",
+ "\n",
+ "Now let's try to improve fairness by using conditional sampling to create a more balanced dataset where each combination of target and sensitive attributes has equal representation (25% each)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 77,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Generating conditionally sampled synthetic data...\n",
+ "Each condition will have 25% of the data (equal representation)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Sampling conditions: 100%|██████████| 22792/22792 [00:14<00:00, 1609.35it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Generated 22792 samples\n",
+ "\n",
+ "Target and sensitive attribute distribution:\n",
+ "sex Female Male\n",
+ "label \n",
+ "<=50K 5698 5698\n",
+ ">50K 5698 5698\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "print('Generating conditionally sampled synthetic data...')\n",
+ "print('Each condition will have 25% of the data (equal representation)')\n",
+ "\n",
+ "# One quarter of the training-set size for each of the four (sex, label) combinations\n",
+ "total_samples = len(training_data)\n",
+ "samples_per_condition = total_samples // 4\n",
+ "\n",
+ "# Same order as listing them by hand: Female/>50K, Female/<=50K, Male/>50K, Male/<=50K\n",
+ "conditions = [\n",
+ "    Condition({'label': income_value, 'sex': sex_value}, num_rows=samples_per_condition)\n",
+ "    for sex_value in ('Female', 'Male')\n",
+ "    for income_value in ('>50K', '<=50K')\n",
+ "]\n",
+ "balanced_synthetic_data = synthesizer.sample_from_conditions(conditions=conditions)\n",
+ "print('Generated', len(balanced_synthetic_data), 'samples')\n",
+ "\n",
+ "print('\\nTarget and sensitive attribute distribution:')\n",
+ "balanced_crosstab = pd.crosstab(balanced_synthetic_data['label'], balanced_synthetic_data['sex'])\n",
+ "print(balanced_crosstab)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "vscode": {
+ "languageId": "raw"
+ }
+ },
+ "source": [
+ "## Step 7: Evaluate Balanced Synthetic Data\n",
+ "\n",
+ "Now let's evaluate the balanced synthetic data to compare it with the standard synthetic data."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 78,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Score: 0.7693\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Same evaluation as for the standard sample, now on the conditionally balanced data\n",
+ "adult_table_metadata = metadata.to_dict()['tables']['adult']\n",
+ "result_balanced = EqualizedOddsImprovement.compute_breakdown(\n",
+ "    real_training_data=training_data,\n",
+ "    synthetic_data=balanced_synthetic_data,\n",
+ "    real_validation_data=validation_data,\n",
+ "    metadata=adult_table_metadata,\n",
+ "    prediction_column_name='label',\n",
+ "    positive_class_label='>50K',\n",
+ "    sensitive_column_name='sex',\n",
+ "    sensitive_column_value='Female',\n",
+ ")\n",
+ "\n",
+ "print('Score: {:.4f}'.format(result_balanced['score']))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 79,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "The full breakdown of the Equalized Odds Improvement metric is:\n",
+ "{\n",
+ " \"score\": 0.7692813165995738,\n",
+ " \"real_training_data\": {\n",
+ " \"equalized_odds\": 0.0008090614886731018,\n",
+ " \"prediction_counts_validation\": {\n",
+ " \"Female=True\": {\n",
+ " \"true_positive\": 0,\n",
+ " \"false_positive\": 15,\n",
+ " \"true_negative\": 3074,\n",
+ " \"false_negative\": 164\n",
+ " },\n",
+ " \"Female=False\": {\n",
+ " \"true_positive\": 6175,\n",
+ " \"false_positive\": 336,\n",
+ " \"true_negative\": 0,\n",
+ " \"false_negative\": 5\n",
+ " }\n",
+ " }\n",
+ " },\n",
+ " \"synthetic_data\": {\n",
+ " \"equalized_odds\": 0.5393716946878206,\n",
+ " \"prediction_counts_validation\": {\n",
+ " \"Female=True\": {\n",
+ " \"true_positive\": 81,\n",
+ " \"false_positive\": 1708,\n",
+ " \"true_negative\": 1381,\n",
+ " \"false_negative\": 83\n",
+ " },\n",
+ " \"Female=False\": {\n",
+ " \"true_positive\": 5899,\n",
+ " \"false_positive\": 317,\n",
+ " \"true_negative\": 19,\n",
+ " \"false_negative\": 281\n",
+ " }\n",
+ " }\n",
+ " }\n",
+ "}\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Pretty-print the nested breakdown dict as indented JSON\n",
+ "balanced_breakdown_json = json.dumps(result_balanced, indent=2)\n",
+ "print('The full breakdown of the Equalized Odds Improvement metric is:', balanced_breakdown_json, sep='\\n')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "vscode": {
+ "languageId": "raw"
+ }
+ },
+ "source": [
+ "## Step 8: Compare Results and Analysis\n",
+ "\n",
+ "Let's compare the results from both approaches to analyze the impact of balanced sampling on fairness."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 80,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "# Left panel: improvement score per dataset; right panel: raw equalized odds score\n",
+ "fig, axes = plt.subplots(1, 2, figsize=(15, 6))\n",
+ "score_ax, odds_ax = axes\n",
+ "\n",
+ "labels = ['Standard\\nSynthetic', 'Balanced\\nSynthetic']\n",
+ "colors = ['lightcoral', 'lightgreen']\n",
+ "bar_style = dict(color=colors, alpha=0.7, edgecolor='black')\n",
+ "\n",
+ "# Improvement scores comparison, with the 0.5 baseline drawn as a dashed line\n",
+ "scores = [result_standard['score'], result_balanced['score']]\n",
+ "bars1 = score_ax.bar(labels, scores, **bar_style)\n",
+ "score_ax.axhline(y=0.5, color='red', linestyle='--', alpha=0.7, label='No Improvement Baseline')\n",
+ "score_ax.set(ylim=(0, 1), ylabel='Improvement Score', title='Overall Score Comparison')\n",
+ "score_ax.legend()\n",
+ "\n",
+ "# Equalized odds scores comparison (models trained on each synthetic dataset)\n",
+ "eq_scores = [\n",
+ "    result_standard['synthetic_data']['equalized_odds'],\n",
+ "    result_balanced['synthetic_data']['equalized_odds'],\n",
+ "]\n",
+ "bars2 = odds_ax.bar(labels, eq_scores, **bar_style)\n",
+ "odds_ax.set(ylim=(0, 1), ylabel='Equalized Odds Score', title='Equalized Odds Score Comparison')\n",
+ "\n",
+ "# Write each value just above its bar, centred horizontally\n",
+ "for panel_ax, panel_bars, panel_values in ((score_ax, bars1, scores), (odds_ax, bars2, eq_scores)):\n",
+ "    for bar, value in zip(panel_bars, panel_values):\n",
+ "        panel_ax.text(\n",
+ "            bar.get_x() + bar.get_width() / 2,\n",
+ "            bar.get_height() + 0.01,\n",
+ "            format(value, '.4f'),\n",
+ "            ha='center',\n",
+ "            va='bottom',\n",
+ "            fontweight='bold',\n",
+ "        )\n",
+ "\n",
+ "fig.tight_layout()\n",
+ "plt.show()\n"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.11"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/resources/visualize.png b/resources/visualize.png
deleted file mode 100644
index 1d9924cb..00000000
Binary files a/resources/visualize.png and /dev/null differ
diff --git a/sdmetrics/single_table/equalized_odds.py b/sdmetrics/single_table/equalized_odds.py
index d6b543d3..d343797f 100644
--- a/sdmetrics/single_table/equalized_odds.py
+++ b/sdmetrics/single_table/equalized_odds.py
@@ -272,6 +272,11 @@ def _validate_parameters(
required_columns = [prediction_column_name, sensitive_column_name]
_validate_required_columns(dataframes_dict, required_columns)
+ # Use base class validation for real_training_data and synthetic_data
+ real_training_data, synthetic_data, metadata = cls._validate_inputs(
+ real_training_data, synthetic_data, metadata
+ )
+
# Validate data and metadata consistency for prediction column
_validate_data_and_metadata(
real_training_data,
@@ -286,11 +291,6 @@ def _validate_parameters(
column_value_pairs = [(sensitive_column_name, sensitive_column_value)]
_validate_column_values_exist(dataframes_dict, column_value_pairs)
- # Use base class validation for real_training_data and synthetic_data
- real_training_data, synthetic_data, metadata = cls._validate_inputs(
- real_training_data, synthetic_data, metadata
- )
-
# Validate the validation data separately (not part of standard _validate_inputs)
real_validation_data = real_validation_data.copy()