Commit 8fac9517 authored by Gihan Jayatilaka's avatar Gihan Jayatilaka

vis2 completed

parent 172e0b44
......@@ -17,7 +17,7 @@ SQRT_CELLS=0
def plot_confusion_matrix(y_true, y_pred, classes,
normalize=False,
title=None,
cmap=plt.cm.Blues):
cmap=plt.cm.Blues, param1=0,param2=0):
"""
This function prints and plots the confusion matrix.
Normalization can be applied by setting `normalize=True`.
......@@ -36,6 +36,20 @@ def plot_confusion_matrix(y_true, y_pred, classes,
# Only use the labels that appear in the data
#classes = classes[unique_labels(y_true, y_pred)]
classesBin=[]
for x in range(len(classes)):
xx=bin(classes[x])[2:]
while len(xx)<(param1+param2):
xx='0'+xx
for j in range(1,param2):
xx=xx[:(j*param1)]+'\n'+xx[(j*param1):]
classesBin.append(xx)
classes=classesBin
if normalize:
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
print("Normalized confusion matrix")
......@@ -57,7 +71,7 @@ def plot_confusion_matrix(y_true, y_pred, classes,
xlabel='Predicted label')
# Rotate the tick labels and set their alignment.
plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
plt.setp(ax.get_xticklabels(), rotation=0, ha="right",
rotation_mode="anchor")
# Loop over data dimensions and create text annotations.
......@@ -132,7 +146,7 @@ def countMisses(yTrue,yPred,windowH,windowW):
if DEBUG:
print("convTrue shape={}, convPred shape={}".format(convTrue.shape,convPred.shape))
plot_confusion_matrix(convTrue,convPred,np.arange(2**(windowH*windowW)))
plot_confusion_matrix(convTrue,convPred,np.arange(2**(windowH*windowW)),param1=windowW,param2=windowH,normalize=True)
'''for y in range(convTrue.shape[0]):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment