# ROC curves for multi-class classification

from sklearn.model_selection import train_test_split

import numpy as np

import pandas as pd

from sklearn.preprocessing import StandardScaler

from sklearn.datasets import load_iris

from sklearn.utils import extmath

from sklearn.svm import SVC

from sklearn.preprocessing import OneHotEncoder

from sklearn.metrics import roc_curve

from sklearn.metrics import roc_auc_score

from matplotlib import pyplot

from itertools import cycle

from sklearn.preprocessing import label_binarize

from sklearn.metrics import roc_curve, auc

# Load the iris dataset into a DataFrame whose columns are the feature names.
iris = load_iris()

df = pd.DataFrame(iris['data'], columns=iris['feature_names'])

# Integer class labels, one per sample.
y = iris.target

# Use all four measurement columns as features.
X = df.iloc[:, 0:4]

# Hold out half of the samples for evaluating the ROC curves; the fixed
# random_state makes the split reproducible between runs.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=.5, random_state=0)

# Linear-kernel support vector classifier, fitted on the training split.
clf = SVC(kernel='linear')

clf.fit(X_train, y_train)

# A ROC curve is traced by sweeping a decision threshold over a continuous
# score.  Hard labels from clf.predict() (one-hot encoded) only give one
# threshold per class, so use the per-class decision values instead.  With
# SVC's default decision_function_shape='ovr' there is one column per class,
# ordered like clf.classes_.
y_score = clf.decision_function(X_test)

# Binarize the true labels against the classifier's own class order so that
# column i of y_test_bin lines up with column i of y_score, even if some class
# is never predicted.
# NOTE(review): with exactly two classes decision_function is 1-D and
# label_binarize returns a single column; this indexing assumes >= 3 classes
# (iris has 3).
classes = clf.classes_
y_test_bin = label_binarize(y_test, classes=classes)

n_classes = len(classes)

# Estimate FPR, TPR and the area under the ROC curve for each class,
# treating that class as positive and all others as negative (one-vs-rest).
fpr = {}
tpr = {}
roc_auc = {}

for i in range(n_classes):
    fpr[i], tpr[i], _ = roc_curve(y_test_bin[:, i], y_score[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])

# Draw one ROC curve per class, cycling through a fixed colour palette.
colors = cycle(['blue', 'green', 'red', 'darkorange', 'olive', 'purple', 'navy'])

for i, color in zip(range(n_classes), colors):
    # Label each curve with its species name.
    # NOTE(review): assumes curve index i is iris target value i, which holds
    # when the classes are 0..n_classes-1 in sorted order (as for iris).
    pyplot.plot(
        fpr[i], tpr[i], color=color, lw=1.5,
        label='ROC curve of {0} (area = {1:0.2f})'.format(
            iris.target_names[i], roc_auc[i]
        ),
    )

# Diagonal reference line: the expected ROC curve of a random classifier.
pyplot.plot([0, 1], [0, 1], 'k--', lw=1.5)

pyplot.xlim([-0.05, 1.0])
pyplot.ylim([0.0, 1.05])

pyplot.xlabel('False Positive Rate', fontsize=12, fontweight='bold')
pyplot.ylabel('True Positive Rate', fontsize=12, fontweight='bold')
pyplot.tick_params(labelsize=12)

pyplot.legend(loc="lower right")

# Show the current figure; no extra Axes is created, so nothing is drawn over
# the curves.
pyplot.show()