update
This commit is contained in:
@ -85,26 +85,16 @@ loss_func = nn.CrossEntropyLoss() # the target label is no
|
||||
|
||||
# following function (plot_with_labels) is for visualization, can be ignored if not interested
|
||||
from matplotlib import cm
|
||||
try:
|
||||
from sklearn.manifold import TSNE
|
||||
HAS_SK = True
|
||||
except:
|
||||
HAS_SK = False
|
||||
print('Please install sklearn for layer visualization')
|
||||
try: from sklearn.manifold import TSNE; HAS_SK = True
|
||||
except: HAS_SK = False; print('Please install sklearn for layer visualization')
|
||||
def plot_with_labels(lowDWeights, labels):
|
||||
plt.cla()
|
||||
X, Y = lowDWeights[:, 0], lowDWeights[:, 1]
|
||||
for x, y, s in zip(X, Y, labels):
|
||||
c = cm.rainbow(int(255 * s / 9))
|
||||
plt.text(x, y, s, backgroundcolor=c, fontsize=9)
|
||||
plt.xlim(X.min(), X.max())
|
||||
plt.ylim(Y.min(), Y.max())
|
||||
plt.title('Visualize last layer')
|
||||
plt.show()
|
||||
plt.pause(0.01)
|
||||
c = cm.rainbow(int(255 * s / 9)); plt.text(x, y, s, backgroundcolor=c, fontsize=9)
|
||||
plt.xlim(X.min(), X.max()); plt.ylim(Y.min(), Y.max()); plt.title('Visualize last layer'); plt.show(); plt.pause(0.01)
|
||||
|
||||
plt.ion()
|
||||
|
||||
# training and testing
|
||||
for epoch in range(EPOCH):
|
||||
for step, (x, y) in enumerate(train_loader): # gives batch data, normalize x when iterate train_loader
|
||||
|
||||
Reference in New Issue
Block a user