This commit is contained in:
Morvan Zhou
2017-06-18 00:18:42 +10:00
committed by Morvan Zhou
parent dcc4e7cb84
commit 51f1c938f3
7 changed files with 32 additions and 79 deletions

View File

@ -85,26 +85,16 @@ loss_func = nn.CrossEntropyLoss() # the target label is no
# following function (plot_with_labels) is for visualization, can be ignored if not interested
from matplotlib import cm
try:
from sklearn.manifold import TSNE
HAS_SK = True
except:
HAS_SK = False
print('Please install sklearn for layer visualization')
try: from sklearn.manifold import TSNE; HAS_SK = True
except: HAS_SK = False; print('Please install sklearn for layer visualization')
def plot_with_labels(lowDWeights, labels):
plt.cla()
X, Y = lowDWeights[:, 0], lowDWeights[:, 1]
for x, y, s in zip(X, Y, labels):
c = cm.rainbow(int(255 * s / 9))
plt.text(x, y, s, backgroundcolor=c, fontsize=9)
plt.xlim(X.min(), X.max())
plt.ylim(Y.min(), Y.max())
plt.title('Visualize last layer')
plt.show()
plt.pause(0.01)
c = cm.rainbow(int(255 * s / 9)); plt.text(x, y, s, backgroundcolor=c, fontsize=9)
plt.xlim(X.min(), X.max()); plt.ylim(Y.min(), Y.max()); plt.title('Visualize last layer'); plt.show(); plt.pause(0.01)
plt.ion()
# training and testing
for epoch in range(EPOCH):
for step, (x, y) in enumerate(train_loader): # gives batch data, normalize x when iterate train_loader