{ "cells": [ { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## Income Prediction\n", "\n", "The dataset, credited to Ronny Kohavi and Barry Becker, was extracted from the 1994 United States Census Bureau data. The task is to use personal details such as education level to predict whether an individual earns more or less than $50,000 per year.\n", "\n", "The dataset provides 14 input variables that are a mixture of categorical, ordinal, and numerical data types. The complete list of variables is as follows:\n", "\n", "- Age.\n", "- Workclass.\n", "- Final Weight.\n", "- Education.\n", "- Education Number of Years.\n", "- Marital-status.\n", "- Occupation.\n", "- Relationship.\n", "- Race.\n", "- Sex (loaded as the `Gender` column below).\n", "- Capital-gain.\n", "- Capital-loss.\n", "- Hours-per-week.\n", "- Native-country.\n", "\n", "There are 48,842 rows in total, 3,620 of which contain missing values, leaving 45,222 complete rows. The `adult.data` file loaded below holds the 32,561-row training portion.\n", "\n", "There are two class values, ‘>50K’ and ‘<=50K’, so this is a binary classification task. The classes are imbalanced, skewed toward the ‘<=50K’ label.\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "### This notebook covers the following topics:\n", "- Data Exploration\n", "    - Load Dataset\n", "    - Data Statistics\n", "- Bias Analysis\n", "    - Bias Metric\n", "    - Fairness Visualization\n", "- Fair machine learning methods\n", "    - Fair constrained learning\n", "    - Fair Representation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "

" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Data Exploration" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "slideshow": { "slide_type": "fragment" } }, "outputs": [], "source": [ "import os\n", "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "\n", "import warnings\n", "warnings.filterwarnings('ignore')" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "scrolled": true, "slideshow": { "slide_type": "fragment" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(32561, 15)\n" ] }, { "data": { "text/html": [ "
<div>\n", "<table border=\"1\" class=\"dataframe\">\n", "<thead>\n",
"<tr style=\"text-align: right;\"><th></th><th>Age</th><th>Workclass</th><th>Final Weight</th><th>Education</th><th>Education Number of Years</th><th>Marital-status</th><th>Occupation</th><th>Relationship</th><th>Race</th><th>Gender</th><th>Capital-gain</th><th>Capital-loss</th><th>Hours-per-week</th><th>Native-country</th><th>Income</th></tr>\n", "</thead>\n", "<tbody>\n",
"<tr><th>0</th><td>39</td><td>State-gov</td><td>77516</td><td>Bachelors</td><td>13</td><td>Never-married</td><td>Adm-clerical</td><td>Not-in-family</td><td>White</td><td>Male</td><td>2174</td><td>0</td><td>40</td><td>United-States</td><td>&lt;=50K</td></tr>\n",
"<tr><th>1</th><td>50</td><td>Self-emp-not-inc</td><td>83311</td><td>Bachelors</td><td>13</td><td>Married-civ-spouse</td><td>Exec-managerial</td><td>Husband</td><td>White</td><td>Male</td><td>0</td><td>0</td><td>13</td><td>United-States</td><td>&lt;=50K</td></tr>\n",
"<tr><th>2</th><td>38</td><td>Private</td><td>215646</td><td>HS-grad</td><td>9</td><td>Divorced</td><td>Handlers-cleaners</td><td>Not-in-family</td><td>White</td><td>Male</td><td>0</td><td>0</td><td>40</td><td>United-States</td><td>&lt;=50K</td></tr>\n",
"<tr><th>3</th><td>53</td><td>Private</td><td>234721</td><td>11th</td><td>7</td><td>Married-civ-spouse</td><td>Handlers-cleaners</td><td>Husband</td><td>Black</td><td>Male</td><td>0</td><td>0</td><td>40</td><td>United-States</td><td>&lt;=50K</td></tr>\n",
"<tr><th>4</th><td>28</td><td>Private</td><td>338409</td><td>Bachelors</td><td>13</td><td>Married-civ-spouse</td><td>Prof-specialty</td><td>Wife</td><td>Black</td><td>Female</td><td>0</td><td>0</td><td>40</td><td>Cuba</td><td>&lt;=50K</td></tr>\n",
"</tbody>\n", "</table>\n", "</div>
" ], "text/plain": [ " Age Workclass Final Weight Education \\\n", "0 39 State-gov 77516 Bachelors \n", "1 50 Self-emp-not-inc 83311 Bachelors \n", "2 38 Private 215646 HS-grad \n", "3 53 Private 234721 11th \n", "4 28 Private 338409 Bachelors \n", "\n", " Education Number of Years Marital-status Occupation \\\n", "0 13 Never-married Adm-clerical \n", "1 13 Married-civ-spouse Exec-managerial \n", "2 9 Divorced Handlers-cleaners \n", "3 7 Married-civ-spouse Handlers-cleaners \n", "4 13 Married-civ-spouse Prof-specialty \n", "\n", " Relationship Race Gender Capital-gain Capital-loss \\\n", "0 Not-in-family White Male 2174 0 \n", "1 Husband White Male 0 0 \n", "2 Not-in-family White Male 0 0 \n", "3 Husband Black Male 0 0 \n", "4 Wife Black Female 0 0 \n", "\n", " Hours-per-week Native-country Income \n", "0 40 United-States <=50K \n", "1 13 United-States <=50K \n", "2 40 United-States <=50K \n", "3 40 United-States <=50K \n", "4 40 Cuba <=50K " ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "columns = [\"Age\", \"Workclass\", \"Final Weight\", \"Education\", \"Education Number of Years\", \"Marital-status\", \"Occupation\", \"Relationship\", \"Race\", \"Gender\", \"Capital-gain\", \"Capital-loss\", \"Hours-per-week\", \"Native-country\", \"Income\"]\n", "df = pd.read_csv(\"../datasets/adult/raw/adult.data\", index_col=None, header=None, names=columns)\n", "print(df.shape)\n", "df.head()" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "slideshow": { "slide_type": "slide" } }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Values in adult.data carry a leading space, hence the \" Female\" and \" Male\" filters\n", "fig,ax=plt.subplots(ncols=3,figsize=(16,5))\n", "df[\"Income\"].value_counts().plot.pie(autopct=\"%.2f%%\",ax=ax[0], title=\"Overall\")\n", "df[df[\"Gender\"]==\" Female\"][\"Income\"].value_counts().plot.pie(autopct=\"%.2f%%\",ax=ax[1], title=\"Female\")\n", "df[df[\"Gender\"]==\" Male\"][\"Income\"].value_counts().plot.pie(autopct=\"%.2f%%\",ax=ax[2], title = \"Male\")" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Bias Analysis\n", "\n", "Taking sex as the sensitive attribute, we train a classifier with the code below and report its accuracy (`acc`) and demographic parity (`dp`):\n" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "slideshow": { "slide_type": "fragment" } }, "outputs": [], "source": [ "\n", "import os\n", "import sys\n", "# Make the project modules in ../src (dataset, utils, metrics, networks, loss) importable\n", "sys.path.append(os.path.join(os.getcwd(), '../src'))\n", "import numpy as np\n", "import pandas as pd\n", "import torch\n", "import torch.nn as nn\n", "import torch.optim as optim\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.preprocessing import StandardScaler\n", "from torch.utils.data import DataLoader\n", "from torch.optim.lr_scheduler import StepLR\n", "from tabulate import tabulate\n", "\n", "\n", "from dataset import load_german_data, load_adult_data\n", "from utils import seed_everything, PandasDataSet, print_metrics, clear_lines, InfiniteDataLoader\n", "from metrics import metric_evaluation\n", "from networks import MLP\n", "\n", "from loss import DiffDP\n" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [], "source": [ "def train_step(model, data, target, sensitive, scheduler, optimizer, clf_criterion, fair_criterion, lam, device, args=None):\n", "    model.train()\n", "    optimizer.zero_grad()\n", "    h, output = model(data)\n", "    clf_loss = clf_criterion(output, target)\n", "    fair_loss = fair_criterion(output, sensitive)\n", "    # Total objective: classification loss plus the lam-weighted fairness penalty\n", "    loss = clf_loss + lam * fair_loss\n", "    loss.backward()\n", "    optimizer.step()\n", "    # The scheduler is stepped once per training step, so StepLR's step_size counts steps\n", "    scheduler.step()\n", "    return model, loss.item(), clf_loss.item(), fair_loss.item()\n", "\n", "def test(model, test_loader, clf_criterion, fair_criterion, lam, device, prefix=\"test\", args=None):\n", "    model.eval()\n", "    clf_loss = 0\n", "    fair_loss = 0\n", "    target_hat_list = []\n", "    target_list = []\n", "    sensitive_list = []\n", "\n", "    with torch.no_grad():\n", "        for data, target, sensitive in test_loader:\n", "            data, target, sensitive = (data.to(device), target.to(device), sensitive.to(device))\n", "            h, output = model(data)\n", "\n", "            clf_loss += clf_criterion(output, target).item()\n", "            fair_loss += fair_criterion(output, sensitive).item()\n", "            target_hat_list.append(output.cpu().numpy())\n", "            target_list.append(target.cpu().numpy())\n", "            sensitive_list.append(sensitive.cpu().numpy())\n", "\n", "    target_hat_list = np.concatenate(target_hat_list, axis=0)\n", "    target_list = np.concatenate(target_list, axis=0)\n", "    sensitive_list = np.concatenate(sensitive_list, axis=0)\n", "    metric = metric_evaluation(y_gt=target_list, y_pre=target_hat_list, s=sensitive_list, prefix=f\"{prefix}\")\n", "\n", "    # Average the per-batch losses\n", "    clf_loss /= len(test_loader)\n", "    fair_loss /= len(test_loader)\n", "\n", "    metric[f\"{prefix}/clf_loss\"] = clf_loss\n", "    metric[f\"{prefix}/fair_loss\"] = fair_loss\n", "    metric[f\"{prefix}/loss\"] = clf_loss + lam*fair_loss\n", "    metric[f\"{prefix}/y_hat\"] = target_hat_list\n", "    metric[f\"{prefix}/sensitive\"] = sensitive_list\n", "\n", "    return metric\n", "\n" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+---------------+-------+\n", "| dataset | adult |\n", "+---------------+-------+\n", "| num_features | 101 |\n", "+---------------+-------+\n", "| num_classes | 2 |\n", "+---------------+-------+\n", "| num_sensitive | 2 |\n", "+---------------+-------+\n", "| num_samples | 45222 |\n", "+---------------+-------+\n", "| num_train | 18088 |\n", "+---------------+-------+\n", "| num_val | 13567 |\n", "+---------------+-------+\n", "| num_test | 13567 |\n", "+---------------+-------+\n", "| num_y1 | 11208 |\n", "+---------------+-------+\n", "| num_y0 | 34014 |\n", "+---------------+-------+\n", "| num_s1 | 30527 |\n", "+---------------+-------+\n", "| num_s0 | 14695 |\n", "+---------------+-------+\n" ] } ], "source": [ "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", "dataset = \"adult\"\n", "sensitive_attr = \"sex\"\n", "seed = 42\n", "batch_size = 64\n", "lr = 1e-2\n", "mlp_layers = \"256,256\"\n", "evaluation_metrics = \"acc,dp\"\n", "num_training_steps = 200\n", "\n", "# X, y, s = load_german_data(path=\"../datasets/german/raw\", sensitive_attribute=sensitive_attr)\n", "X, y, s = load_adult_data(path=\"../datasets/adult/raw\", sensitive_attribute=sensitive_attr)\n", "\n", "# One-hot encode the categorical (string) columns\n", "categorical_cols = X.select_dtypes(\"string\").columns\n", "if len(categorical_cols) > 0:\n", "    X = pd.get_dummies(X, columns=categorical_cols)\n", "\n", "n_features = X.shape[1]\n", "n_classes = len(np.unique(y))\n", "\n", "# Stratified 40% / 30% / 30% split into train / validation / test\n", "X_train, X_testvalid, y_train, y_testvalid, s_train, s_testvalid = train_test_split(X, y, s, test_size=0.6, stratify=y, random_state=seed)\n", "X_test, X_val, y_test, y_val, s_test, s_val = train_test_split(X_testvalid, y_testvalid, s_testvalid, test_size=0.5, stratify=y_testvalid, random_state=seed)\n", "\n", "dataset_stats = {\n", "    \"dataset\": dataset,\n", "    \"num_features\": X.shape[1],\n", "    \"num_classes\": len(np.unique(y)),\n", "    \"num_sensitive\": len(np.unique(s)),\n", "    \"num_samples\": X.shape[0],\n", "    \"num_train\": X_train.shape[0],\n", "    \"num_val\": X_val.shape[0],\n", "    \"num_test\": X_test.shape[0],\n", "    \"num_y1\": (y.values == 1).sum(),\n", "    \"num_y0\": (y.values == 0).sum(),\n", "    \"num_s1\": (s.values == 1).sum(),\n", "    \"num_s0\": (s.values == 0).sum(),\n", "}\n", "\n", "# Create the table using the tabulate function\n", "table = tabulate([(k, v) for k, v in dataset_stats.items()], tablefmt='grid')\n", "\n", "print(table)\n" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "# Standardize the numerical features; the scaler is fit on the training split only\n", "numerical_cols = X.select_dtypes(\"float32\").columns\n", "if len(numerical_cols) > 0:\n", "    def scale_df(df, scaler):\n", "        return pd.DataFrame(scaler.transform(df), columns=df.columns, index=df.index)\n", "\n", "    scaler = StandardScaler().fit(X_train[numerical_cols])\n", "    X_train[numerical_cols] = X_train[numerical_cols].pipe(scale_df, scaler)\n", "    X_val[numerical_cols] = X_val[numerical_cols].pipe(scale_df, scaler)\n", "    X_test[numerical_cols] = X_test[numerical_cols].pipe(scale_df, scaler)\n", "\n", "train_data = PandasDataSet(X_train, y_train, s_train)\n", "val_data = PandasDataSet(X_val, y_val, s_val)\n", "test_data = PandasDataSet(X_test, y_test, s_test)\n", "\n", "# Infinite loader for step-based training; finite loaders for evaluation\n", "train_infinite_loader = InfiniteDataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=4, drop_last=True)\n", "train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)\n", "val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False)\n", "test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)\n" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "MLP(\n", " (network): ModuleList(\n", " (0): Linear(in_features=101, out_features=256, bias=True)\n", " (1): Linear(in_features=256, out_features=256, bias=True)\n", " )\n", " (head): Linear(in_features=256, out_features=1, bias=True)\n", ")\n" ] } ], "source": [ "mlp_layers = [int(x) for x in mlp_layers.split(\",\")]\n", "net = MLP(n_features=n_features, num_classes=1, mlp_layers=mlp_layers).to(device)\n", "clf_criterion = nn.BCELoss()\n", "fair_criterion = DiffDP()\n", "optimizer = optim.Adam(net.parameters(), lr=lr)\n", "# Decay the learning rate by a factor of 10 every 50 training steps\n", "scheduler = StepLR(optimizer, step_size=50, gamma=0.1)\n", "\n", "print(net)\n" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+-------------------+-------------------+-------------------+\n", "| Step(Tr|Val|Te) | acc | dp |\n", "+===================+===================+===================+\n", "| 0 | 75.22|75.22|75.21 | 0.00|0.00|0.00 |\n", "+-------------------+-------------------+-------------------+\n", "| 20 | 83.24|82.52|82.69 | 11.83|11.69|11.01 |\n", "+-------------------+-------------------+-------------------+\n", "| 40 | 84.54|83.93|84.06 | 16.19|16.87|15.69 |\n", "+-------------------+-------------------+-------------------+\n", "| 60 | 85.06|84.25|84.42 | 17.94|18.54|17.53 |\n", "+-------------------+-------------------+-------------------+\n", "| 80 | 85.16|84.40|84.48 | 17.77|18.20|17.31 |\n", "+-------------------+-------------------+-------------------+\n", "| 100 | 85.19|84.49|84.55 | 18.61|18.96|18.14 |\n", "+-------------------+-------------------+-------------------+\n", "| 120 | 85.18|84.49|84.58 | 18.56|18.97|18.12 |\n", "+-------------------+-------------------+-------------------+\n", "| 140 | 85.17|84.48|84.59 | 18.39|18.80|17.93 |\n", "+-------------------+-------------------+-------------------+\n", "| 160 | 85.19|84.48|84.61 | 18.05|18.63|17.70 |\n", "+-------------------+-------------------+-------------------+\n", "| 180 | 85.19|84.47|84.59 | 18.06|18.64|17.70 |\n", "+-------------------+-------------------+-------------------+\n" ] } ], "source": [ "# Baseline: lam = 0 disables the fairness penalty\n", "lam = 0\n", "logs = []\n", "headers = [\"Step(Tr|Val|Te)\"] + evaluation_metrics.split(\",\")\n", "\n", "# evaluation_metrics = \"ap,dp,prule\"\n", "\n", "for step, (X_batch, y_batch, s_batch) in enumerate(train_infinite_loader):\n", "    if step >= num_training_steps:\n", "        break\n", "\n", "    X_batch, y_batch, s_batch = X_batch.to(device), y_batch.to(device), s_batch.to(device)\n", "    net, loss, clf_loss, fair_loss = train_step(model=net, data=X_batch, target=y_batch, sensitive=s_batch, optimizer=optimizer, scheduler=scheduler, clf_criterion=clf_criterion, fair_criterion=fair_criterion, lam=lam, device=device)\n", "\n", "    # Evaluate on the train, validation and test splits every 20 steps\n", "    if step % 20 == 0:\n", "        train_metrics = test(model=net, test_loader=train_loader, clf_criterion=clf_criterion, fair_criterion=fair_criterion, lam=lam, device=device, prefix=\"train\")\n", "        val_metrics = test(model=net, test_loader=val_loader, clf_criterion=clf_criterion, fair_criterion=fair_criterion, lam=lam, device=device, prefix=\"val\")\n", "        test_metrics = test(model=net, test_loader=test_loader, clf_criterion=clf_criterion, fair_criterion=fair_criterion, lam=lam, device=device, prefix=\"test\")\n", "        res_dict = {}\n", "        res_dict.update(train_metrics)\n", "        res_dict.update(val_metrics)\n", "        res_dict.update(test_metrics)\n", "        res = print_metrics(res_dict, evaluation_metrics, train=True)\n", "        logs.append([step, *res])\n", "\n", "table = tabulate(logs, headers=headers, tablefmt=\"grid\", floatfmt=\"02.2f\")\n", "print(table)\n" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Predicted scores on the training split from the last evaluation (lam = 0)\n", "y_hat = train_metrics[\"train/y_hat\"].reshape(-1)\n", "sensitive_vals = train_metrics[\"train/sensitive\"].reshape(-1)\n", "\n", "# Separate y_hat by the value of the sensitive attribute\n", "y_hat_0 = y_hat[sensitive_vals == 0]\n", "y_hat_1 = y_hat[sensitive_vals == 1]\n", "\n", "# fill=True replaces the deprecated shade=True argument (seaborn >= 0.11)\n", "sns.kdeplot(y_hat_0, bw_adjust=.5, label='sensitive_attr=0', fill=True)\n", "sns.kdeplot(y_hat_1, bw_adjust=.5, label='sensitive_attr=1', fill=True)\n", "\n", "# Plot formatting\n", "plt.legend(prop={'size': 12})\n", "plt.title('Distribution of y_hat with respect to sensitive_attr (lam = 0)')\n", "plt.xlabel('y_hat')\n", "plt.ylabel('Density')\n", "\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we retrain the same MLP from scratch with the `DiffDP` fairness penalty enabled. With `lam = 0.88`, each training step takes a gradient step on `clf_loss + lam * fair_loss`.\n" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "MLP(\n", " (network): ModuleList(\n", " (0): Linear(in_features=101, out_features=256, bias=True)\n", " (1): Linear(in_features=256, out_features=256, bias=True)\n", " )\n", " (head): Linear(in_features=256, out_features=1, bias=True)\n", ")\n", "+-------------------+-------------------+----------------+\n", "| Step(Tr|Val|Te) | acc | dp |\n", "+===================+===================+================+\n", "| 0 | 75.67|75.50|75.51 | 0.67|0.37|0.30 |\n", "+-------------------+-------------------+----------------+\n", "| 20 | 82.57|81.91|82.31 | 9.42|9.39|9.02 |\n", "+-------------------+-------------------+----------------+\n", "| 40 | 83.57|82.73|83.33 | 7.39|7.39|6.95 |\n", "+-------------------+-------------------+----------------+\n", "| 60 | 81.66|81.16|81.59 | 1.27|1.97|1.09 |\n", "+-------------------+-------------------+----------------+\n", "| 80 | 83.18|82.58|82.88 | 4.56|4.13|4.09 |\n", "+-------------------+-------------------+----------------+\n", "| 100 | 82.92|82.27|82.60 | 3.47|3.52|3.16 |\n", "+-------------------+-------------------+----------------+\n", "| 120 | 82.97|82.36|82.64 | 3.54|3.59|3.15 |\n", "+-------------------+-------------------+----------------+\n", "| 140 | 83.06|82.41|82.74 | 3.80|3.81|3.41 |\n", "+-------------------+-------------------+----------------+\n", "| 160 | 83.09|82.41|82.79 | 3.79|3.89|3.41 |\n", "+-------------------+-------------------+----------------+\n", "| 180 | 83.09|82.42|82.80 | 3.80|3.90|3.43 |\n", "+-------------------+-------------------+----------------+\n" ] } ], "source": [ "# Re-initialize the model, loss functions, optimizer and scheduler\n", "net = MLP(n_features=n_features, num_classes=1, mlp_layers=mlp_layers).to(device)\n", "clf_criterion = nn.BCELoss()\n", "fair_criterion = DiffDP()\n", "optimizer = optim.Adam(net.parameters(), lr=lr)\n", "scheduler = StepLR(optimizer, step_size=50, gamma=0.1)\n", "print(net)\n", "\n", "# Weight of the DiffDP fairness penalty in the training objective\n", "lam = 0.88\n", "logs = []\n", "headers = [\"Step(Tr|Val|Te)\"] + evaluation_metrics.split(\",\")\n", "\n", "# evaluation_metrics = \"ap,dp,prule\"\n", "\n", "for step, (X_batch, y_batch, s_batch) in enumerate(train_infinite_loader):\n", "    if step >= num_training_steps:\n", "        break\n", "\n", "    X_batch, y_batch, s_batch = X_batch.to(device), y_batch.to(device), s_batch.to(device)\n", "    net, loss, clf_loss, fair_loss = train_step(model=net, data=X_batch, target=y_batch, sensitive=s_batch, optimizer=optimizer, scheduler=scheduler, clf_criterion=clf_criterion, fair_criterion=fair_criterion, lam=lam, device=device)\n", "\n", "    # Evaluate on the train, validation and test splits every 20 steps\n", "    if step % 20 == 0:\n", "        train_metrics = test(model=net, test_loader=train_loader, clf_criterion=clf_criterion, fair_criterion=fair_criterion, lam=lam, device=device, prefix=\"train\")\n", "        val_metrics = test(model=net, test_loader=val_loader, clf_criterion=clf_criterion, fair_criterion=fair_criterion, lam=lam, device=device, prefix=\"val\")\n", "        test_metrics = test(model=net, test_loader=test_loader, clf_criterion=clf_criterion, fair_criterion=fair_criterion, lam=lam, device=device, prefix=\"test\")\n", "        res_dict = {}\n", "        res_dict.update(train_metrics)\n", "        res_dict.update(val_metrics)\n", "        res_dict.update(test_metrics)\n", "        res = print_metrics(res_dict, evaluation_metrics, train=True)\n", "        logs.append([step, *res])\n", "\n", "table = tabulate(logs, headers=headers, tablefmt=\"grid\", floatfmt=\"02.2f\")\n", "print(table)\n" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Predicted scores on the training split from the last evaluation (lam = 0.88)\n", "y_hat = train_metrics[\"train/y_hat\"].reshape(-1)\n", "sensitive_vals = train_metrics[\"train/sensitive\"].reshape(-1)\n", "\n", "# Separate y_hat by the value of the sensitive attribute\n", "y_hat_0 = y_hat[sensitive_vals == 0]\n", "y_hat_1 = y_hat[sensitive_vals == 1]\n", "\n", "# Create the distribution plots (fill=True replaces the deprecated shade=True)\n", "sns.kdeplot(y_hat_0, label='sensitive_attr=0', fill=True)\n", "sns.kdeplot(y_hat_1, label='sensitive_attr=1', fill=True)\n", "\n", "# Plot formatting\n", "plt.legend(prop={'size': 12})\n", "plt.title('Distribution of y_hat with respect to sensitive_attr (lam = 0.88)')\n", "plt.xlabel('y_hat')\n", "plt.ylabel('Density')\n", "\n", "plt.show()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.16" } }, "nbformat": 4, "nbformat_minor": 4 }