Commit 4204663c authored by Gihan Jayatilaka's avatar Gihan Jayatilaka

Visualizer done

parent 13baece6
File added
This diff is collapsed.
......@@ -294,7 +294,8 @@ xTrain, xTest, yTrain, yTest= train_test_split(X, Y, test_size=0.5)
print("Test set evaluation for Movie video (Without training for movie video): ",model.metrics_names,' ', model.evaluate(xTest,yTest))
model.fit(xTrain,yTrain,validation_split=0.1,epochs=EPOCHS,verbose=1,batch_size=BATCH_SIZE,shuffle=True)
hist= model.fit(xTrain,yTrain,validation_split=0.1,epochs=EPOCHS,verbose=1,batch_size=BATCH_SIZE,shuffle=True)
print("Training history for multiple videos\n",hist.history)
print("Test set evaluation for Movie video (After training for movie video):",model.metrics_names,' ', model.evaluate(xTest,yTest))
......
......@@ -19,7 +19,7 @@
"source": [
"'''\n",
"gihanchanaka@gmail.com\n",
"11-03-2019\n",
"13-03-2019\n",
" 1)This is to learn from BW and predict random\n",
" 2) learn from random and predict movie\n",
"'''"
......@@ -194,7 +194,7 @@
" cost=tf.reduce_mean(sqError,1)\n",
" cost=tf.reduce_mean(cost,0)\n",
" if DEBUG: print(\"DEBUG: cost shape= {}\".format(cost.shape))\n",
" optimizer=tf.train.AdamOptimizer(learning_rate=0.01).minimize(cost)\n",
" optimizer=tf.train.AdamOptimizer(learning_rate=0.001).minimize(cost)\n",
" yPred=tf.round(outputTensor)\n",
" if DEBUG: print(\"DEBUG: yPred shape= {}\".format(yPred.shape))\n",
" correctPrediction=tf.equal(yPred,targetTensor)\n",
......@@ -309,7 +309,7 @@
" batches+=1\n",
"\n",
" except:\n",
" if DEBUG: print(\"\\t No batches of data \".format(batches))\n",
" if DEBUG: print(\"\\t {} batches of data \".format(batches))\n",
" if countSum==0: break\n",
" costVal=costCountProductSum/countSum\n",
" accVal=trainOrTest,accCountProductSum/countSum\n",
......@@ -333,7 +333,7 @@
"OUTPUT_DIM=CELLS_PER_FRAME\n",
"\n",
"\n",
"EPOCHS=100\n",
"EPOCHS=20\n",
"BATCH_SIZE=64\n",
"CUDA1=0\n",
"CUDA2=1\n",
......@@ -349,6 +349,8 @@
"outputs": [],
"source": [
"def trainAndTestForVideo(fileName,noFrames,framesToSkip=0,videoFileFormat='.avi',testSplit=0.1):\n",
" print(\">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>Starting new video file\\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>{}\".format(fileName))\n",
" \n",
" FILE_NAME=fileName\n",
" FILE_NAME_VIDEO=FILE_NAME+videoFileFormat\n",
" FILE_NAME_CSV=FILE_NAME+'.csv'\n",
......@@ -365,15 +367,27 @@
" if DEBUG: print(\"X type: {}, Y type: {}.\".format(dataX.dtype,dataY.dtype))\n",
" xTrain, xTest, yTrain, yTest= train_test_split(dataX, dataY, test_size=testSplit)\n",
" print(\"SIZES: xTrain {}, yTrain {}, xTest {}, yTest {}\".format(xTrain.shape,yTrain.shape,xTest.shape,yTest.shape))\n",
"\n",
" \n",
" hist={\"trainAcc\":[],\"trainCost\":[],\"testAcc\":[],\"testCost\":[]}\n",
" \n",
" \n",
" for i in range(EPOCHS):\n",
" trainAcc,trainCost=printAccuracy(iterr,xTrain,yTrain,acc,cst,\"Training\")\n",
" testAcc,testCost=printAccuracy(iterr,xTest,yTest,acc,cst,\"Testing\")\n",
" \n",
" hist[\"trainAcc\"].append(trainAcc)\n",
" hist[\"trainCost\"].append(trainCost)\n",
" hist[\"testAcc\"].append(testAcc)\n",
" hist[\"testCost\"].append(testCost)\n",
" \n",
" \n",
" if min(trainCost,testCost) < TARGET_COST:\n",
" print(\"Converged!\")\n",
" break\n",
"\n",
" train(iterr,opt,1,xTrain,yTrain)\n",
" \n",
" return hist\n",
" "
]
},
......@@ -469,7 +483,7 @@
}
],
"source": [
"trainAndTestForVideo('./video/bw',10000)"
"histBw=trainAndTestForVideo('./video/bw',10000)"
]
},
{
......@@ -499,7 +513,7 @@
}
],
"source": [
"trainAndTestForVideo('./video/ran',10000)"
"histRan=trainAndTestForVideo('./video/ran',100000)"
]
},
{
......@@ -587,7 +601,30 @@
}
],
"source": [
"trainAndTestForVideo('./video/multipleVideos',10000,videoFileFormat='.mp4',testSplit=0.5)"
"histMultipleVideo=trainAndTestForVideo('./video/multipleVideos',10000,videoFileFormat='.mp4',testSplit=0.5)"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"ename": "NameError",
"evalue": "name 'histBw' is not defined",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-25-61876b85bfed>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhistBw\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhistRan\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhistMultipleVideo\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mNameError\u001b[0m: name 'histBw' is not defined"
]
}
],
"source": [
"print(histBw)\n",
"print(histRan)\n",
"print(histMultipleVideo)"
]
},
{
......@@ -595,7 +632,26 @@
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
"source": [
"'''\n",
"{'trainCost': [0.487190250597947, 6.275400916450763e-06], \n",
"'testAcc': [('Testing', 0.49980468675494194), ('Testing', 1.0)], \n",
"'testCost': [0.48725608363747597, 1.0347399794592357e-07], \n",
"'trainAcc': [('Training', 0.4994202127270665), ('Training', 0.9999911345488636)]}\n",
"\n",
"\n",
"{'trainCost': [0.20937026709529524, 0.061433447273910466, 0.026764632209290005, 0.015015229714889052, 0.009539862714892796, 0.006624541298241903, 0.004842561076837757, 0.003697494075448642, 0.0029006636097155354, 0.0023166576849547684, 0.0019074803687725381, 0.0015896577150676505, 0.0013644386110030584, 0.001186302326459554, 0.0010262928389373825, 0.0008936834280917434, 0.0007987172024358903, 0.0007147448452139304, 0.0006410042675296034, 0.0005829130136686566], \n",
"'testAcc': [('Testing', 0.7569765672087669), ('Testing', 0.9179531298577785), ('Testing', 0.9588046856224537), ('Testing', 0.9735781326889992), ('Testing', 0.9810937456786633), ('Testing', 0.9852421917021275), ('Testing', 0.9873281307518482), ('Testing', 0.9882656261324883), ('Testing', 0.9894609339535236), ('Testing', 0.9904375039041042), ('Testing', 0.9911250062286854), ('Testing', 0.9915624968707561), ('Testing', 0.9914609342813492), ('Testing', 0.9920546822249889), ('Testing', 0.9925390593707561), ('Testing', 0.9926249943673611), ('Testing', 0.9929453171789646), ('Testing', 0.9930312484502792), ('Testing', 0.9930156245827675), ('Testing', 0.9933359436690807)], \n",
"'testCost': [0.20510075148195028, 0.06685484061017632, 0.03253534168470651, 0.020639818103518337, 0.01502221537521109, 0.011874609510414302, 0.010066515096696094, 0.009086779988138005, 0.008293081540614367, 0.007626592996530235, 0.007099631344317459, 0.006742824451066554, 0.006500241201138124, 0.006219661634531803, 0.005867137078894302, 0.005834336916450411, 0.005574006267124787, 0.00550500194367487, 0.0053980302036507055, 0.0052129948599031195], \n",
"'trainAcc': [('Training', 0.7515354613040356), ('Training', 0.9251923764005621), ('Training', 0.9677101048171943), ('Training', 0.9822925510981404), ('Training', 0.9889725218427942), ('Training', 0.9924822703320929), ('Training', 0.9946498211393965), ('Training', 0.9959246419000287), ('Training', 0.9968235797070443), ('Training', 0.9975150671411068), ('Training', 0.9979884721708636), ('Training', 0.9983262360518705), ('Training', 0.9985983929735549), ('Training', 0.9987508764503695), ('Training', 0.9989361602363857), ('Training', 0.9990868716375202), ('Training', 0.9991790673411485), ('Training', 0.9992748122688726), ('Training', 0.999352827985236), ('Training', 0.9994015862755742)]}\n",
"\n",
"\n",
"{'trainCost': [0.5029900013645993, 0.4776768835285042, 0.4591610099695906, 0.44652620636964147, 0.42667750467227983, 0.4050721575187731, 0.38265376641780513, 0.3606759505935862, 0.3475723809833768, 0.32737633812276623, 0.3119316508498373, 0.2990805570837818, 0.2864646523059169, 0.2791970398607133, 0.2705821962673453, 0.26238762823086753, 0.2565659776895861, 0.24874630155442637, 0.2458962330335303, 0.2412038983046254], \n",
"'testAcc': [('Testing', 0.4862816446944128), ('Testing', 0.5078481017034265), ('Testing', 0.5216139245636856), ('Testing', 0.5325632868688318), ('Testing', 0.5485759513287605), ('Testing', 0.5647943042501619), ('Testing', 0.5853639260123048), ('Testing', 0.6034256348127052), ('Testing', 0.6154746858379508), ('Testing', 0.6347151873986933), ('Testing', 0.6476028493688076), ('Testing', 0.6608702540397644), ('Testing', 0.67348101320146), ('Testing', 0.680324366575555), ('Testing', 0.6890348075311395), ('Testing', 0.6955775286577925), ('Testing', 0.70223892565015), ('Testing', 0.7094066444831558), ('Testing', 0.7109651889982103), ('Testing', 0.7164477856853341)], \n",
"'testCost': [0.5040931848785545, 0.4840255138240283, 0.4693682189983658, 0.4575264374666576, 0.44012309667430344, 0.4228188406817521, 0.4004590971560418, 0.382104583556139, 0.3697694204276121, 0.35061238234556175, 0.3376556023766723, 0.3246032493778422, 0.3124394220641897, 0.3059287022186231, 0.2977016475758975, 0.2912550908100756, 0.2847993453092213, 0.2783400008950052, 0.2768173849658121, 0.2716038244057305], \n",
"'trainAcc': [('Training', 0.48648733919179893), ('Training', 0.5144857590711569), ('Training', 0.532064874715443), ('Training', 0.5438528490971916), ('Training', 0.5622626588314394), ('Training', 0.5832911382747602), ('Training', 0.6044620275497437), ('Training', 0.626368674296367), ('Training', 0.6394382926482188), ('Training', 0.6598180434371852), ('Training', 0.6762341781507565), ('Training', 0.6898813262770448), ('Training', 0.7035522151596939), ('Training', 0.7105221499370623), ('Training', 0.7195411384860172), ('Training', 0.7288370238074774), ('Training', 0.7352531650398351), ('Training', 0.7440980997266649), ('Training', 0.7474762646457817), ('Training', 0.7525158207627791)]}\n",
"'''\n"
]
}
],
"metadata": {
......
......@@ -147,7 +147,7 @@ def buildNetwork(outputTensor,targetTensor):
cost=tf.reduce_mean(sqError,1)
cost=tf.reduce_mean(cost,0)
if DEBUG: print("DEBUG: cost shape= {}".format(cost.shape))
optimizer=tf.train.AdamOptimizer(learning_rate=0.01).minimize(cost)
optimizer=tf.train.AdamOptimizer(learning_rate=0.001).minimize(cost)
yPred=tf.round(outputTensor)
if DEBUG: print("DEBUG: yPred shape= {}".format(yPred.shape))
correctPrediction=tf.equal(yPred,targetTensor)
......@@ -268,7 +268,7 @@ INPUT_DIM=(FRAME_HEIGHT,FRAME_WIDTH)
OUTPUT_DIM=CELLS_PER_FRAME
EPOCHS=100
EPOCHS=20
BATCH_SIZE=64
CUDA1=0
CUDA2=1
......@@ -281,6 +281,8 @@ sess = tf.Session()
def trainAndTestForVideo(fileName,noFrames,framesToSkip=0,videoFileFormat='.avi',testSplit=0.1):
print(">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>Starting new video file\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>{}".format(fileName))
FILE_NAME=fileName
FILE_NAME_VIDEO=FILE_NAME+videoFileFormat
FILE_NAME_CSV=FILE_NAME+'.csv'
......@@ -297,16 +299,28 @@ def trainAndTestForVideo(fileName,noFrames,framesToSkip=0,videoFileFormat='.avi'
if DEBUG: print("X type: {}, Y type: {}.".format(dataX.dtype,dataY.dtype))
xTrain, xTest, yTrain, yTest= train_test_split(dataX, dataY, test_size=testSplit)
print("SIZES: xTrain {}, yTrain {}, xTest {}, yTest {}".format(xTrain.shape,yTrain.shape,xTest.shape,yTest.shape))
hist={"trainAcc":[],"trainCost":[],"testAcc":[],"testCost":[]}
for i in range(EPOCHS):
trainAcc,trainCost=printAccuracy(iterr,xTrain,yTrain,acc,cst,"Training")
testAcc,testCost=printAccuracy(iterr,xTest,yTest,acc,cst,"Testing")
hist["trainAcc"].append(trainAcc)
hist["trainCost"].append(trainCost)
hist["testAcc"].append(testAcc)
hist["testCost"].append(testCost)
if min(trainCost,testCost) < TARGET_COST:
print("Converged!")
break
train(iterr,opt,1,xTrain,yTrain)
return hist
# In[20]:
......@@ -352,23 +366,48 @@ if DEBUG: print("layer shapes 1:{} 2:{} 3:{} ".format(l1.shape,l2.shape,l3.shape
# In[22]:
trainAndTestForVideo('./video/bw',10000)
histBw=trainAndTestForVideo('./video/bw',10000)
# In[23]:
trainAndTestForVideo('./video/ran',10000)
histRan=trainAndTestForVideo('./video/ran',100000)
# In[24]:
trainAndTestForVideo('./video/multipleVideos',10000,videoFileFormat='.mp4',testSplit=0.5)
histMultipleVideo=trainAndTestForVideo('./video/multipleVideos',10000,videoFileFormat='.mp4',testSplit=0.5)
# In[25]:
print(histBw)
print(histRan)
print(histMultipleVideo)
# In[ ]:
'''
{'trainCost': [0.487190250597947, 6.275400916450763e-06],
'testAcc': [('Testing', 0.49980468675494194), ('Testing', 1.0)],
'testCost': [0.48725608363747597, 1.0347399794592357e-07],
'trainAcc': [('Training', 0.4994202127270665), ('Training', 0.9999911345488636)]}
{'trainCost': [0.20937026709529524, 0.061433447273910466, 0.026764632209290005, 0.015015229714889052, 0.009539862714892796, 0.006624541298241903, 0.004842561076837757, 0.003697494075448642, 0.0029006636097155354, 0.0023166576849547684, 0.0019074803687725381, 0.0015896577150676505, 0.0013644386110030584, 0.001186302326459554, 0.0010262928389373825, 0.0008936834280917434, 0.0007987172024358903, 0.0007147448452139304, 0.0006410042675296034, 0.0005829130136686566],
'testAcc': [('Testing', 0.7569765672087669), ('Testing', 0.9179531298577785), ('Testing', 0.9588046856224537), ('Testing', 0.9735781326889992), ('Testing', 0.9810937456786633), ('Testing', 0.9852421917021275), ('Testing', 0.9873281307518482), ('Testing', 0.9882656261324883), ('Testing', 0.9894609339535236), ('Testing', 0.9904375039041042), ('Testing', 0.9911250062286854), ('Testing', 0.9915624968707561), ('Testing', 0.9914609342813492), ('Testing', 0.9920546822249889), ('Testing', 0.9925390593707561), ('Testing', 0.9926249943673611), ('Testing', 0.9929453171789646), ('Testing', 0.9930312484502792), ('Testing', 0.9930156245827675), ('Testing', 0.9933359436690807)],
'testCost': [0.20510075148195028, 0.06685484061017632, 0.03253534168470651, 0.020639818103518337, 0.01502221537521109, 0.011874609510414302, 0.010066515096696094, 0.009086779988138005, 0.008293081540614367, 0.007626592996530235, 0.007099631344317459, 0.006742824451066554, 0.006500241201138124, 0.006219661634531803, 0.005867137078894302, 0.005834336916450411, 0.005574006267124787, 0.00550500194367487, 0.0053980302036507055, 0.0052129948599031195],
'trainAcc': [('Training', 0.7515354613040356), ('Training', 0.9251923764005621), ('Training', 0.9677101048171943), ('Training', 0.9822925510981404), ('Training', 0.9889725218427942), ('Training', 0.9924822703320929), ('Training', 0.9946498211393965), ('Training', 0.9959246419000287), ('Training', 0.9968235797070443), ('Training', 0.9975150671411068), ('Training', 0.9979884721708636), ('Training', 0.9983262360518705), ('Training', 0.9985983929735549), ('Training', 0.9987508764503695), ('Training', 0.9989361602363857), ('Training', 0.9990868716375202), ('Training', 0.9991790673411485), ('Training', 0.9992748122688726), ('Training', 0.999352827985236), ('Training', 0.9994015862755742)]}
{'trainCost': [0.5029900013645993, 0.4776768835285042, 0.4591610099695906, 0.44652620636964147, 0.42667750467227983, 0.4050721575187731, 0.38265376641780513, 0.3606759505935862, 0.3475723809833768, 0.32737633812276623, 0.3119316508498373, 0.2990805570837818, 0.2864646523059169, 0.2791970398607133, 0.2705821962673453, 0.26238762823086753, 0.2565659776895861, 0.24874630155442637, 0.2458962330335303, 0.2412038983046254],
'testAcc': [('Testing', 0.4862816446944128), ('Testing', 0.5078481017034265), ('Testing', 0.5216139245636856), ('Testing', 0.5325632868688318), ('Testing', 0.5485759513287605), ('Testing', 0.5647943042501619), ('Testing', 0.5853639260123048), ('Testing', 0.6034256348127052), ('Testing', 0.6154746858379508), ('Testing', 0.6347151873986933), ('Testing', 0.6476028493688076), ('Testing', 0.6608702540397644), ('Testing', 0.67348101320146), ('Testing', 0.680324366575555), ('Testing', 0.6890348075311395), ('Testing', 0.6955775286577925), ('Testing', 0.70223892565015), ('Testing', 0.7094066444831558), ('Testing', 0.7109651889982103), ('Testing', 0.7164477856853341)],
'testCost': [0.5040931848785545, 0.4840255138240283, 0.4693682189983658, 0.4575264374666576, 0.44012309667430344, 0.4228188406817521, 0.4004590971560418, 0.382104583556139, 0.3697694204276121, 0.35061238234556175, 0.3376556023766723, 0.3246032493778422, 0.3124394220641897, 0.3059287022186231, 0.2977016475758975, 0.2912550908100756, 0.2847993453092213, 0.2783400008950052, 0.2768173849658121, 0.2716038244057305],
'trainAcc': [('Training', 0.48648733919179893), ('Training', 0.5144857590711569), ('Training', 0.532064874715443), ('Training', 0.5438528490971916), ('Training', 0.5622626588314394), ('Training', 0.5832911382747602), ('Training', 0.6044620275497437), ('Training', 0.626368674296367), ('Training', 0.6394382926482188), ('Training', 0.6598180434371852), ('Training', 0.6762341781507565), ('Training', 0.6898813262770448), ('Training', 0.7035522151596939), ('Training', 0.7105221499370623), ('Training', 0.7195411384860172), ('Training', 0.7288370238074774), ('Training', 0.7352531650398351), ('Training', 0.7440980997266649), ('Training', 0.7474762646457817), ('Training', 0.7525158207627791)]}
'''
This diff is collapsed.
#!/usr/bin/env python
# coding: utf-8
# In[19]:
DEBUG=False
FRAME_WIDTH=200
FRAME_HEIGHT=100
# In[35]:
def displayWeights(W,outputNode):
fig = plt.figure()
ax = fig.gca(projection='3d')
X=np.arange(FRAME_WIDTH)
Y=np.arange(FRAME_HEIGHT)
X,Y=np.meshgrid(X, Y)
Z=W
if DEBUG: print(Z.shape)
Z=np.reshape(Z[:,outputNode],(FRAME_HEIGHT,FRAME_WIDTH))
if DEBUG: print(X.shape,Y.shape,Z.shape)
surf = ax.plot_surface(X, Y, Z, cmap=cm.coolwarm,
linewidth=0, antialiased=True)
ax.set_zlim(np.min(Z), np.max(Z))
ax.zaxis.set_major_locator(LinearLocator(10))
ax.zaxis.set_major_formatter(FormatStrFormatter('%.02f'))
plt.title("NN weigts deciding the cell no. {}".format(outputNode))
plt.xlabel('Screen width')
plt.ylabel('Screen height')
ax.set_zlabel('\n' + 'NN weights\n')
# Add a color bar which maps values to colors.
fig.colorbar(surf)#, shrink=0.5, aspect=5)
plt.show()
# In[36]:
import numpy as np
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.ticker import LinearLocator, FormatStrFormatter
data=np.load('./weights/cnnModel07-weights.npz')
w3Ran=data['w3Ran']
data=np.load('./weights/cnnModel07-weights-bwoverfit.npz')
w3Bw=data['w3Bw']
data=np.load('./weights/cnnModel07-weights-multiplevideos-20.npz')
w3MultipleVideo20=data['w3MultipleVideo']
data=np.load('./weights/cnnModel07-weights-multiplevideos-50.npz')
w3MultipleVideo50=data['w3MultipleVideo']
for o in range(25):
displayWeights(w3MultipleVideo50,o)
# In[ ]:
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment