This research aims to develop an artificial intelligence-based system that identifies patients who are more likely to develop heart disease based on their medical history. The heart disease dataset from the UCI Machine Learning Repository was used for training and validation.
In [1]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import os
import warnings
warnings.filterwarnings("ignore")
from sklearn.preprocessing import MinMaxScaler
from sklearn import metrics
import random
Install Required Libraries¶
In [2]:
%%capture
!pip install -q hvplot
!pip install pytorch-tabnet
Import Libraries¶
In [3]:
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
import scikitplot as skplt
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
from xgboost import XGBClassifier
from sklearn.ensemble import RandomForestClassifier,StackingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.ensemble import ExtraTreesClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import AdaBoostClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from scipy import stats
from numpy import isnan
from sklearn.impute import KNNImputer
from sklearn.model_selection import GridSearchCV, cross_val_score, StratifiedKFold, learning_curve
import pytorch_tabnet
from pytorch_tabnet.tab_model import TabNetClassifier
import torch
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import roc_auc_score, accuracy_score
from sklearn.model_selection import KFold
from matplotlib.pyplot import figure
Get Cleveland Data from UCI Repository¶
In [4]:
data=pd.read_csv("https://archive.ics.uci.edu/ml/machine-learning-databases/heart-disease/processed.cleveland.data",header=None)
data = data.replace("?",np.nan)
data = data.dropna().reset_index(drop=True)
data.columns = ['age', 'sex', 'chest pain type', 'resting bp s', 'cholesterol',
'fasting blood sugar', 'resting ecg', 'max heart rate',
'exercise angina', 'oldpeak', 'ST slope','ca', 'thal', 'target']
k=['age', 'sex', 'chest pain type', 'resting bp s', 'cholesterol',
'fasting blood sugar', 'resting ecg', 'max heart rate',
'exercise angina', 'ST slope','ca', 'thal', 'target']
for j in k:
data[j] = data[j].astype('float').astype('int')
data['oldpeak'] = data['oldpeak'].astype('float')
data['target'] = np.where(data.target>0,1,0)
dataTab = data.copy()
data.head()
Out[4]:
age | sex | chest pain type | resting bp s | cholesterol | fasting blood sugar | resting ecg | max heart rate | exercise angina | oldpeak | ST slope | ca | thal | target | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 63 | 1 | 1 | 145 | 233 | 1 | 2 | 150 | 0 | 2.3 | 3 | 0 | 6 | 0 |
1 | 67 | 1 | 4 | 160 | 286 | 0 | 2 | 108 | 1 | 1.5 | 2 | 3 | 3 | 1 |
2 | 67 | 1 | 4 | 120 | 229 | 0 | 2 | 129 | 1 | 2.6 | 2 | 2 | 7 | 1 |
3 | 37 | 1 | 3 | 130 | 250 | 0 | 0 | 187 | 0 | 3.5 | 3 | 0 | 3 | 0 |
4 | 41 | 0 | 2 | 130 | 204 | 0 | 2 | 172 | 0 | 1.4 | 1 | 0 | 3 | 0 |
Convert Nominal Variables - chest pain type, resting ecg and thal¶
In [5]:
CP_Dict = {1:'typical angina',2:'atypical angina',3:'non-anginal',4:'asymptomatic'}
ECG_Dict = {0:'normal',1:'ST-T wave abnormality',2:'left ventricular hypertrophy'}
thal_Dict = {3:'normal',6:'fixed defect',7:'reversable defect'}
data.replace({"chest pain type": CP_Dict},inplace=True)
data.replace({"resting ecg": ECG_Dict},inplace=True)
data.replace({"thal": thal_Dict},inplace=True)
data.head()
Out[5]:
age | sex | chest pain type | resting bp s | cholesterol | fasting blood sugar | resting ecg | max heart rate | exercise angina | oldpeak | ST slope | ca | thal | target | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 63 | 1 | typical angina | 145 | 233 | 1 | left ventricular hypertrophy | 150 | 0 | 2.3 | 3 | 0 | fixed defect | 0 |
1 | 67 | 1 | asymptomatic | 160 | 286 | 0 | left ventricular hypertrophy | 108 | 1 | 1.5 | 2 | 3 | normal | 1 |
2 | 67 | 1 | asymptomatic | 120 | 229 | 0 | left ventricular hypertrophy | 129 | 1 | 2.6 | 2 | 2 | reversable defect | 1 |
3 | 37 | 1 | non-anginal | 130 | 250 | 0 | normal | 187 | 0 | 3.5 | 3 | 0 | normal | 0 |
4 | 41 | 0 | atypical angina | 130 | 204 | 0 | left ventricular hypertrophy | 172 | 0 | 1.4 | 1 | 0 | normal | 0 |
In [6]:
Sex_Dict = {1:'male',0:'female'}
FS_Dict = {0:'under 120mgdl',1:'over 120mgdl'}
exang_Dict = {0:'not induced',1:'induced'}
slope_Dict = {1:'upsloping',2:'flat',3:'downsloping'}
data.replace({"sex": Sex_Dict},inplace=True)
data.replace({"fasting blood sugar": FS_Dict},inplace=True)
data.replace({"exercise angina": exang_Dict},inplace=True)
data.replace({"ST slope": slope_Dict},inplace=True)
dataset = data.copy()
data.head()
Out[6]:
age | sex | chest pain type | resting bp s | cholesterol | fasting blood sugar | resting ecg | max heart rate | exercise angina | oldpeak | ST slope | ca | thal | target | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 63 | male | typical angina | 145 | 233 | over 120mgdl | left ventricular hypertrophy | 150 | not induced | 2.3 | downsloping | 0 | fixed defect | 0 |
1 | 67 | male | asymptomatic | 160 | 286 | under 120mgdl | left ventricular hypertrophy | 108 | induced | 1.5 | flat | 3 | normal | 1 |
2 | 67 | male | asymptomatic | 120 | 229 | under 120mgdl | left ventricular hypertrophy | 129 | induced | 2.6 | flat | 2 | reversable defect | 1 |
3 | 37 | male | non-anginal | 130 | 250 | under 120mgdl | normal | 187 | not induced | 3.5 | downsloping | 0 | normal | 0 |
4 | 41 | female | atypical angina | 130 | 204 | under 120mgdl | left ventricular hypertrophy | 172 | not induced | 1.4 | upsloping | 0 | normal | 0 |
In [7]:
data['target'].value_counts(dropna=False)
Out[7]:
0 160 1 137 Name: target, dtype: int64
Exploratory Data Analysis¶
Target Variables¶
In [8]:
f, axes = plt.subplots(1, 1, figsize=(4, 6))
sns.countplot(ax=axes,x='target', data=data, palette=['green','orange'])
axes.set_title("Target Distribution", fontsize=20)
Out[8]:
Text(0.5, 1.0, 'Target Distribution')
Categorical Binaries - Sex, FBS and Exang¶
In [9]:
f, axes = plt.subplots(1, 3, figsize=(15, 5))
sns.countplot(ax=axes[0],x='sex', data=data, palette=['green','orange'],hue="target")
axes[0].set_title("sex", fontsize=20)
sns.countplot(ax=axes[1],x='fasting blood sugar', data=data, palette=['green','orange'],hue="target")
axes[1].set_title("fasting blood sugar", fontsize=20)
sns.countplot(ax=axes[2],x='exercise angina', data=data, palette=['green','orange'],hue="target")
plt.title("exercise angina", fontsize=20)
Out[9]:
Text(0.5, 1.0, 'exercise angina')
Categorical Variables - CP, ECG and thal¶
In [10]:
plt.figure(figsize=(12,5))
sns.countplot(x='chest pain type', data=data, palette=['green','orange'],hue="target")
plt.title("chest pain type", fontsize=20)
Out[10]:
Text(0.5, 1.0, 'chest pain type')
In [11]:
plt.figure(figsize=(12,5))
sns.countplot(x='ST slope', data=data, palette=['green','orange'],hue="target")
plt.title("ST slope", fontsize=20)
Out[11]:
Text(0.5, 1.0, 'ST slope')
In [12]:
plt.figure(figsize=(12,5))
ax=sns.countplot(x='resting ecg', data=data, palette=['green','orange'],hue="target")
plt.title("resting ecg", fontsize=20)
Out[12]:
Text(0.5, 1.0, 'resting ecg')
In [13]:
plt.figure(figsize=(12,5))
sns.countplot(x='ca', data=data, palette=['green','orange'],hue="target")
plt.title("ca", fontsize=20)
Out[13]:
Text(0.5, 1.0, 'ca')
In [14]:
plt.figure(figsize=(12,5))
sns.countplot(x='thal', data=data, palette=['green','orange'],hue="target")
plt.title("thal", fontsize=20)
Out[14]:
Text(0.5, 1.0, 'thal')
Numeric Variables - Age, Cholestrol,Resting BP and Max heart rate¶
In [15]:
data_disease = data[data["target"] == 1]
data_normal = data[data["target"] == 0]
Age¶
In [16]:
plt.figure(figsize=(8,5))
sns.distplot(data_normal["age"], bins=24, color='g')
sns.distplot(data_disease["age"], bins=24, color='r')
plt.title("Distribuition and density by Age",fontsize=20)
plt.xlabel("Age",fontsize=15)
plt.show()
Cholestrol¶
In [17]:
#figure size
plt.figure(figsize=(8,5))
sns.distplot(data_normal["cholesterol"], bins=24, color='g')
sns.distplot(data_disease["cholesterol"], bins=24, color='r')
plt.title("Distribuition and density by cholesterol",fontsize=20)
plt.xlabel("cholesterol",fontsize=15)
plt.show()
Resting BP¶
In [18]:
plt.figure(figsize=(8,5))
sns.distplot(data_normal["resting bp s"], bins=24, color='g')
sns.distplot(data_disease["resting bp s"], bins=24, color='r')
plt.title("Distribuition and density by resting bp",fontsize=20)
plt.xlabel("resting bp",fontsize=15)
plt.show()
Max Heart Rate¶
In [19]:
plt.figure(figsize=(8,5))
sns.distplot(data_normal["max heart rate"], bins=24, color='g')
sns.distplot(data_disease["max heart rate"], bins=24, color='r')
plt.title("Distribuition and density by max heart rate",fontsize=20)
plt.xlabel("max heart rate",fontsize=15)
plt.show()
Oldpeak¶
In [20]:
plt.figure(figsize=(8,5))
sns.distplot(data_normal["oldpeak"], bins=24, color='g')
sns.distplot(data_disease["oldpeak"], bins=24, color='r')
plt.title("Distribuition and density by old peak",fontsize=20)
plt.xlabel("oldpeak",fontsize=15)
plt.show()
Age vs Max HR¶
In [21]:
plt.figure(figsize=(9, 7))
plt.scatter(data_disease["age"],
data_disease["max heart rate"],
c="salmon")
plt.scatter(data_normal["age"],
data_normal["max heart rate"],
c="lightblue")
plt.title("Heart Disease in function of Age and Max Heart Rate")
plt.xlabel("Age")
plt.ylabel("Max Heart Rate")
plt.legend(["Disease", "No Disease"]);
Correlations¶
In [22]:
import hvplot.pandas
data.drop('target', axis=1).corrwith(data.target).hvplot.barh(
width=600, height=400,
title="Correlation between Heart Disease and Numeric Features",
ylabel='Correlation', xlabel='Numerical Features'
)
Out[22]:
In [23]:
sns.set(rc = {'figure.figsize':(12,12)})
sns.heatmap(data.corr(), annot = True, fmt='.2g',cmap= 'coolwarm')
Out[23]:
<AxesSubplot:>
Test Train Split¶
In [24]:
from sklearn.preprocessing import LabelEncoder
X = dataTab.drop('target', axis = 1)
feats = X.columns
categorical_columns = ['sex','resting ecg', 'chest pain type','fasting blood sugar' ,
'exercise angina','ST slope','thal']
categorical_dims = {}
for col in categorical_columns:
print(col, X[col].nunique())
l_enc = LabelEncoder()
X[col] = l_enc.fit_transform(X[col].values)
categorical_dims[col] = len(l_enc.classes_)
cat_idxs = [ i for i, f in enumerate(feats) if f in categorical_columns]
cat_dims = [ categorical_dims[f] for i, f in enumerate(feats) if f in categorical_columns]
X = np.array(X)
target = np.array(dataTab['target'])
x_train, x_val, y_train, y_val = train_test_split(X, target, test_size=0.3, random_state=42)
x_val, x_test, y_val, y_test = train_test_split(x_val, y_val, test_size=0.4, random_state=42)
sex 2 resting ecg 3 chest pain type 4 fasting blood sugar 2 exercise angina 2 ST slope 3 thal 3
Modeling¶
Base Models¶
- LogisticRegression
- RandomForestClassifier
- XGBClassifier
- GradientBoostingClassifier
In [25]:
LR = LogisticRegression(solver='liblinear',random_state=42)
RF = RandomForestClassifier(criterion = 'entropy',random_state=42)
XGB = XGBClassifier(max_depth=5,random_state=42)
GBM = GradientBoostingClassifier(learning_rate = 0.01,random_state=42)
In [26]:
models = [LR,RF,XGB,GBM]
metric_list = []
for m in models:
mName = type(m).__name__
m.fit(x_train,y_train)
y_pred = m.predict(x_test)
auroc = np.round(roc_auc_score(y_test, y_pred),4)
accuracy = np.round(accuracy_score(y_test, y_pred),4)
precision = np.round(precision_score(y_test, y_pred),4)
recall = np.round(recall_score(y_test, y_pred),4)
f1 = np.round(f1_score(y_test, y_pred),4)
globals()[f"y_pred_{mName}"] = y_pred
l = [mName,auroc,accuracy,precision,recall,f1]
metric_list.append(l)
print(mName,":",auroc,accuracy,precision,recall,f1)
LogisticRegression : 0.9 0.9167 1.0 0.8 0.8889 RandomForestClassifier : 0.8762 0.8889 0.9231 0.8 0.8571 XGBClassifier : 0.8429 0.8611 0.9167 0.7333 0.8148 GradientBoostingClassifier : 0.7857 0.8056 0.8333 0.6667 0.7407
In [27]:
df_metric_list = pd.DataFrame(metric_list)
df_metric_list.columns = ['modelName','auroc','accuracy','precision','recall','f1_score']
df_metric_list = df_metric_list.sort_values(["accuracy","auroc"],ascending=False).reset_index(drop=True)
df_metric_list
Out[27]:
modelName | auroc | accuracy | precision | recall | f1_score | |
---|---|---|---|---|---|---|
0 | LogisticRegression | 0.9000 | 0.9167 | 1.0000 | 0.8000 | 0.8889 |
1 | RandomForestClassifier | 0.8762 | 0.8889 | 0.9231 | 0.8000 | 0.8571 |
2 | XGBClassifier | 0.8429 | 0.8611 | 0.9167 | 0.7333 | 0.8148 |
3 | GradientBoostingClassifier | 0.7857 | 0.8056 | 0.8333 | 0.6667 | 0.7407 |
Confusion Matrix for Base Models¶
In [28]:
font = {'family' : 'normal',
'weight' : 'bold',
'size' : 18}
plt.rc('font', **font)
f, axes = plt.subplots(2, 2, figsize=(8,9))
i=0
axes = axes.ravel()
f.suptitle("Confusion Matrix Base Models", fontsize=20, fontweight='bold')
for m in models:
mName = type(m).__name__
disp = ConfusionMatrixDisplay(confusion_matrix(y_test,globals()[f"y_pred_{mName}"]))
disp.plot(ax=axes[i], values_format='.20g')
axes[i].grid(False)
disp.ax_.set_title(mName,fontweight='bold',fontsize=12)
disp.im_.colorbar.remove()
i=i+1
plt.show()
Tabular Net (TABNET)¶
In [29]:
def set_seed(seed: int = 42):
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
os.environ["PYTHONHASHSEED"] = str(seed)
set_seed()
In [30]:
tabnet = TabNetClassifier(optimizer_fn=torch.optim.Adam,
optimizer_params=dict(lr=0.01),
scheduler_params={"step_size":100,
"gamma":0.95},
scheduler_fn=torch.optim.lr_scheduler.StepLR,
verbose=0,
mask_type='sparsemax',
cat_dims=cat_dims,cat_idxs=cat_idxs
)
tabnet.fit(
x_train,y_train,
eval_set=[(x_train, y_train), (x_val, y_val)],
eval_name=['train', 'valid'],
eval_metric=['auc','accuracy'],
max_epochs=1000 , patience=50,
batch_size=16,virtual_batch_size=16,
num_workers=0,
drop_last=False
)
y_pred = tabnet.predict(x_test)
test_acc = accuracy_score(y_pred, y_test)
preds_valid = tabnet.predict(x_val)
valid_acc = accuracy_score(preds_valid, y_val)
print("valid_acc:",valid_acc,", test_acc:",test_acc)
Early stopping occurred at epoch 126 with best_epoch = 76 and best_valid_accuracy = 0.85185 valid_acc: 0.8518518518518519 , test_acc: 0.9444444444444444
Tabular Nets Plots¶
In [31]:
fig = plt.figure(figsize=(10, 5))
fig.suptitle("Tabnet Training Loss",fontsize=20)
plt.plot(tabnet.history['loss'])
Out[31]:
[<matplotlib.lines.Line2D at 0x7f56006777d0>]
In [32]:
fig = plt.figure(figsize=(10, 5))
fig.suptitle("Tabnet Train and Valid Accuracy",fontsize=20)
plt.plot(tabnet.history['train_accuracy'])
plt.plot(tabnet.history['valid_accuracy'])
Out[32]:
[<matplotlib.lines.Line2D at 0x7f56005f5190>]
Tabnet Metrics¶
In [33]:
metric_list=[]
auroc = np.round(roc_auc_score(y_test, y_pred),4)
accuracy = np.round(accuracy_score(y_test, y_pred),4)
precision = np.round(precision_score(y_test, y_pred),4)
recall = np.round(recall_score(y_test, y_pred),4)
f1 = np.round(f1_score(y_test, y_pred),4)
l = ['Tabular Net',auroc,accuracy,precision,recall,f1]
metric_list.append(l)
df_metric_list = pd.DataFrame(metric_list)
df_metric_list.columns = ['modelName','auroc','accuracy','precision','recall','f1_score']
df_metric_list = df_metric_list.sort_values(["accuracy","auroc"],ascending=False).reset_index(drop=True)
df_metric_list
Out[33]:
modelName | auroc | accuracy | precision | recall | f1_score | |
---|---|---|---|---|---|---|
0 | Tabular Net | 0.9429 | 0.9444 | 0.9333 | 0.9333 | 0.9333 |
Confusion Matrix for Tabular Net¶
In [34]:
font = {'family' : 'normal',
'weight' : 'bold',
'size' : 18}
plt.rc('font', **font)
f, axes = plt.subplots(1, 1, figsize=(5, 5))
i=0
f.suptitle("Confusion Matrix For Tab Net", fontsize=20, fontweight='bold')
for m in models:
disp = ConfusionMatrixDisplay(confusion_matrix(y_test,y_pred))
disp.plot(ax=axes, values_format='.20g')
axes.grid(False)
disp.im_.colorbar.remove()
i=i+1
plt.show()
Feature Importances¶
In [35]:
feat_importances = pd.Series(tabnet.feature_importances_, index=feats)
feat_importances.nlargest(15).plot(kind='barh')
Out[35]:
<AxesSubplot:>
In [36]:
fig = plt.figure(figsize=(5, 7))
importances = tabnet.feature_importances_
indices = np.argsort(importances)[4:]
plt.title('Feature Importances',fontweight='bold',fontsize=22)
plt.barh(range(len(indices)), importances[indices], color='g', align='center')
plt.yticks(range(len(indices)), [feats[i] for i in indices],
fontweight='bold',fontsize=12)
plt.xlabel('Relative Importance')
plt.show()
Feature Explainability¶
By using the masks, we can understand which features are being used at a prediction level
In [37]:
font = {'family' : 'normal',
'weight' : 'bold',
'size' : 18}
plt.rc('font', **font)
explain_matrix, masks = tabnet.explain(x_test)
f, axes = plt.subplots(1, 3, figsize=(12,6))
axes = axes.ravel()
f.suptitle("Masks", fontsize=20, fontweight='bold')
for i in range(3):
axes[i].imshow(masks[i])
axes[i].set_title(f"mask {i}")
plt.show()
In [38]:
masksum = masks[0]+masks[1]+masks[2]
masksumdf = pd.DataFrame(masksum)
masksumdf.columns = feats
masksumdf
Out[38]:
age | sex | chest pain type | resting bp s | cholesterol | fasting blood sugar | resting ecg | max heart rate | exercise angina | oldpeak | ST slope | ca | thal | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 0.421180 | 0.000000 | 0.000000 | 0.000000 | 0.0 | 0.239193 | 0.254920 | 0.084707 | 0.000000 | 1.000000 | 0.000000 | 1.000000 | 0.000000 |
1 | 0.000000 | 0.032162 | 0.000000 | 1.000000 | 0.0 | 0.000000 | 0.000000 | 0.410161 | 0.000000 | 0.557677 | 0.000000 | 1.000000 | 0.000000 |
2 | 0.000000 | 0.000000 | 0.000000 | 0.977114 | 0.0 | 0.000000 | 0.000000 | 0.468335 | 0.000000 | 0.531665 | 0.000000 | 0.022886 | 1.000000 |
3 | 0.272394 | 0.031191 | 0.000000 | 0.000000 | 0.0 | 0.288377 | 0.000000 | 0.054172 | 0.000000 | 1.020705 | 0.000000 | 0.893932 | 0.439229 |
4 | 0.000000 | 0.866500 | 0.000000 | 0.000000 | 0.0 | 0.000000 | 0.564887 | 0.000000 | 0.042170 | 0.000000 | 0.000000 | 1.000000 | 0.526442 |
5 | 0.000000 | 0.000059 | 0.000000 | 0.999941 | 0.0 | 0.000000 | 0.000000 | 0.430287 | 0.000000 | 0.569713 | 0.000000 | 0.000000 | 1.000000 |
6 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.0 | 0.925404 | 0.000000 | 0.000000 | 0.000000 | 1.000000 | 0.000000 | 1.000000 | 0.074596 |
7 | 0.669275 | 0.000000 | 0.000000 | 0.000000 | 0.0 | 0.176432 | 0.063401 | 0.090891 | 0.000000 | 0.815062 | 0.000000 | 1.184938 | 0.000000 |
8 | 0.000000 | 0.277727 | 0.000000 | 0.467442 | 0.0 | 0.254831 | 0.000000 | 0.477747 | 0.000000 | 0.522253 | 0.000000 | 1.000000 | 0.000000 |
9 | 0.091034 | 0.893204 | 0.000000 | 0.000000 | 0.0 | 0.000000 | 0.486976 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 1.000000 | 0.528785 |
10 | 0.000000 | 0.052177 | 0.000000 | 0.656777 | 0.0 | 0.000000 | 0.000000 | 0.540997 | 0.000000 | 0.459003 | 0.000000 | 0.291047 | 1.000000 |
11 | 0.000000 | 0.000000 | 0.001158 | 0.000000 | 0.0 | 0.884686 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 1.000000 | 1.114155 |
12 | 0.180344 | 0.469384 | 0.000000 | 0.000000 | 0.0 | 0.255227 | 0.244604 | 0.155433 | 0.000000 | 0.000000 | 0.000000 | 1.419348 | 0.275661 |
13 | 0.274975 | 0.000000 | 0.000000 | 0.000000 | 0.0 | 0.405568 | 0.030806 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 1.016979 | 1.271672 |
14 | 0.181538 | 0.301859 | 0.000000 | 0.000000 | 0.0 | 0.109740 | 0.505199 | 0.147025 | 0.000000 | 0.610345 | 0.000000 | 1.000000 | 0.144294 |
15 | 0.251564 | 0.867052 | 0.000000 | 0.000000 | 0.0 | 0.029788 | 0.326736 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 1.000000 | 0.524859 |
16 | 0.000000 | 0.515716 | 0.000000 | 0.541488 | 0.0 | 0.000000 | 0.000000 | 0.282020 | 0.000000 | 0.613046 | 0.000000 | 1.047730 | 0.000000 |
17 | 0.225985 | 0.000000 | 0.000000 | 0.000000 | 0.0 | 0.547413 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 1.000000 | 1.226603 |
18 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.0 | 0.316066 | 0.287177 | 0.304699 | 0.084341 | 0.000000 | 0.000000 | 1.000000 | 1.007718 |
19 | 0.000000 | 0.707483 | 0.000000 | 0.000000 | 0.0 | 0.463950 | 0.000000 | 0.000000 | 0.000000 | 0.059420 | 0.000000 | 1.233096 | 0.536050 |
20 | 0.515521 | 0.028114 | 0.000000 | 0.000000 | 0.0 | 0.279003 | 0.131515 | 0.000000 | 0.000000 | 0.420060 | 0.000000 | 0.697103 | 0.928685 |
21 | 0.000000 | 0.753129 | 0.000000 | 0.136883 | 0.0 | 0.000000 | 0.229975 | 0.000000 | 0.000000 | 0.374388 | 0.000000 | 1.453410 | 0.052215 |
22 | 0.329486 | 0.006196 | 0.000000 | 0.000000 | 0.0 | 0.000000 | 0.114499 | 0.010178 | 0.000000 | 0.914524 | 0.000000 | 0.606188 | 1.018930 |
23 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.0 | 1.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 1.000000 | 1.000000 |
24 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.0 | 0.874469 | 0.000000 | 0.125531 | 0.000000 | 0.000000 | 0.000000 | 1.000000 | 1.000000 |
25 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.0 | 0.700572 | 0.011820 | 0.287608 | 0.000000 | 0.000000 | 0.000000 | 1.000000 | 1.000000 |
26 | 0.000000 | 0.000000 | 0.000000 | 1.000000 | 0.0 | 0.000000 | 0.000000 | 0.505667 | 0.000000 | 0.494333 | 0.000000 | 1.000000 | 0.000000 |
27 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.0 | 0.637512 | 0.000000 | 0.362488 | 0.000000 | 0.000000 | 0.000000 | 1.000000 | 1.000000 |
28 | 0.138939 | 0.597025 | 0.000000 | 0.000000 | 0.0 | 0.070083 | 0.552135 | 0.000000 | 0.000000 | 0.402975 | 0.000000 | 1.000000 | 0.238843 |
29 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.0 | 0.955776 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.044224 | 1.000000 | 1.000000 |
30 | 0.000000 | 0.624539 | 0.000000 | 0.100942 | 0.0 | 0.000000 | 0.000000 | 0.000000 | 0.098922 | 0.000000 | 0.000000 | 1.632383 | 0.543213 |
31 | 0.000000 | 0.000000 | 0.000000 | 1.000000 | 0.0 | 0.000000 | 0.000000 | 0.669707 | 0.000000 | 0.330293 | 0.000000 | 0.000000 | 1.000000 |
32 | 0.000000 | 0.000000 | 0.013528 | 0.000000 | 0.0 | 0.986472 | 0.000000 | 0.000000 | 0.000000 | 1.000000 | 0.000000 | 1.000000 | 0.000000 |
33 | 0.000000 | 0.385101 | 0.000000 | 0.645728 | 0.0 | 0.000000 | 0.000000 | 0.416254 | 0.000000 | 0.480495 | 0.000000 | 0.072423 | 1.000000 |
34 | 0.000000 | 0.000000 | 0.000000 | 1.000000 | 0.0 | 0.000000 | 0.000000 | 0.669707 | 0.000000 | 0.330293 | 0.000000 | 0.000000 | 1.000000 |
35 | 0.261785 | 0.000000 | 0.000000 | 0.000000 | 0.0 | 0.385700 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 1.062956 | 1.289559 |
In [39]:
font = {'family' : 'normal',
'weight' : 'bold',
'size' : 18}
plt.figure(figsize=(8,6))
plt.rc('font', **font)
sns.heatmap(masksumdf,cbar=True)
plt.rc('font', **font)
plt.title('Features used for prediction',fontweight='bold',fontsize=16)
Out[39]:
Text(0.5, 1.0, 'Features used for prediction')
No comments:
Post a Comment