-
Notifications
You must be signed in to change notification settings - Fork 1
/
tsne_plot.py
63 lines (47 loc) · 1.38 KB
/
tsne_plot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import numpy as np
import matplotlib.pyplot as plt, mpld3
import nltk
from nltk.corpus import stopwords
import sys
from sklearn.manifold import TSNE
import pdb
# Plot 2-D t-SNE word coordinates as an interactive mpld3 scatter plot.
#
# Usage: python tsne_plot.py COORDS_FILE
#
# COORDS_FILE holds one "label x y" record per line (whitespace separated);
# blank lines are ignored. Stop words and words whose NLTK part-of-speech tag
# does not start with 'N' are dropped. The remaining points are saved to
# tsne.pdf and then shown in the browser via mpld3, with the tagged word as a
# hover tooltip.

# A set gives constant-time membership tests; the list returned by
# stopwords.words() would be rescanned for every word in the file.
STOPWORDS = set(stopwords.words('english'))
BACKGROUND = '#EEEEEE'


def _read_tsne_rows(path):
    """Return a list of (label, x, y) tuples parsed from the file at *path*.

    Every non-blank line must contain exactly three whitespace-separated
    fields; x and y are converted to float.

    Raises:
        ValueError: if a non-blank line does not have exactly three fields,
            or if its coordinates are not numeric.
    """
    rows = []
    with open(path, 'r') as handle:
        for line in handle:
            fields = line.split()
            if not fields:
                # Skip blank lines wherever they occur. Slicing off the last
                # element of split('\n') only handled a single trailing newline
                # and dropped the final record when the file had none.
                continue
            label, x, y = fields
            rows.append((label, float(x), float(y)))
    return rows


datax = []
datay = []
labels = []
tags = {}        # first letter of the POS tag -> colour index
color_tags = []  # colour index of each plotted point

for label, x, y in _read_tsne_rows(sys.argv[1]):
    if label in STOPWORDS:
        continue
    # Each word is tagged in isolation, as a one-element token list;
    # pos_tag returns [(word, tag)].
    pos = nltk.pos_tag([label])
    tag_class = pos[0][1][:1]
    # Use != rather than `is not`: identity comparison against a string
    # literal only worked when CPython happened to intern the string.
    if tag_class != 'N':
        continue
    datax.append(x)
    datay.append(y)
    labels.append(str(pos))
    # Assign the next free colour index the first time a tag class is seen.
    color_tags.append(tags.setdefault(tag_class, len(tags)))

N = len(datax)
print(N)

fig, ax = plt.subplots()
# The `axisbg` keyword was removed in matplotlib 2.2. set_facecolor exists
# from 2.0 onwards, so the old setter is only used on older releases.
if hasattr(ax, 'set_facecolor'):
    ax.set_facecolor(BACKGROUND)
else:
    ax.set_axis_bgcolor(BACKGROUND)

scatter = ax.scatter(datax,
                     datay,
                     c=color_tags,
                     s=100,
                     alpha=0.3,
                     cmap=plt.cm.jet)
ax.grid(color='white', linestyle='solid')

tooltip = mpld3.plugins.PointLabelTooltip(scatter, labels=labels)
mpld3.plugins.connect(fig, tooltip)

# Write the PDF first: mpld3.show() serves the page from a local web server
# and blocks until that server is stopped.
fig.savefig('tsne.pdf')
mpld3.show()