Commit 2a3a6ee (first commit)
goodman1204 committed Aug 28, 2020
Showing 36 changed files with 3,029,865 additions and 0 deletions.
108 changes: 108 additions & 0 deletions .gitignore
@@ -0,0 +1,108 @@
.idea/
logs/
pretrain_model.pk
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
result/*.npy

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
21 changes: 21 additions & 0 deletions LICENSE
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2018 zfjsail

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
63 changes: 63 additions & 0 deletions README.md
@@ -0,0 +1,63 @@
# CAN-pytorch
This is a PyTorch implementation of the CAN model described in the paper below. It builds on the PyTorch GAE code at <https://github.com/zfjsail/gae-pytorch> and the original CAN implementation at <https://github.com/mengzaiqiao/CAN>.
>Zaiqiao Meng, Shangsong Liang, Hongyan Bao, Xiangliang Zhang. Co-embedding Attributed Networks. (WSDM 2019)

### Requirements
- Python 3.7.4
- PyTorch 1.5.0
- Install the requirements via

```
pip install -r requirements.txt
```

### How to run

```
python train.py
```


#### Facebook dataset with the default parameter settings

```
Epoch: 0194 train_loss= 0.75375 log_lik= 0.69382 KL= 0.05993 train_acc= 0.73382 val_edge_roc= 0.98516 val_edge_ap= 0.98455 val_attr_roc= 0.95473 val_attr_ap= 0.95721 time= 1.71142
Epoch: 0195 train_loss= 0.75215 log_lik= 0.69217 KL= 0.05998 train_acc= 0.73484 val_edge_roc= 0.98577 val_edge_ap= 0.98492 val_attr_roc= 0.95465 val_attr_ap= 0.95746 time= 1.64731
Epoch: 0196 train_loss= 0.75135 log_lik= 0.69133 KL= 0.06002 train_acc= 0.73486 val_edge_roc= 0.98588 val_edge_ap= 0.98486 val_attr_roc= 0.95322 val_attr_ap= 0.95755 time= 1.64199
Epoch: 0197 train_loss= 0.75140 log_lik= 0.69134 KL= 0.06006 train_acc= 0.73556 val_edge_roc= 0.98545 val_edge_ap= 0.98477 val_attr_roc= 0.95652 val_attr_ap= 0.95914 time= 1.63010
Epoch: 0198 train_loss= 0.75157 log_lik= 0.69146 KL= 0.06010 train_acc= 0.73477 val_edge_roc= 0.98573 val_edge_ap= 0.98490 val_attr_roc= 0.95497 val_attr_ap= 0.95753 time= 1.65039
Epoch: 0199 train_loss= 0.75122 log_lik= 0.69107 KL= 0.06015 train_acc= 0.73400 val_edge_roc= 0.98620 val_edge_ap= 0.98523 val_attr_roc= 0.95420 val_attr_ap= 0.95829 time= 1.66717
Epoch: 0200 train_loss= 0.74931 log_lik= 0.68914 KL= 0.06017 train_acc= 0.73667 val_edge_roc= 0.98601 val_edge_ap= 0.98515 val_attr_roc= 0.95426 val_attr_ap= 0.95744 time= 1.65484
Optimization Finished!
Test edge ROC score: 0.9853779088016957
Test edge AP score: 0.9836879718079673
Test attr ROC score: 0.9578314765862058
Test attr AP score: 0.9577498373032282
```
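The edge and attribute ROC/AP scores reported above are standard ranking metrics; a minimal sketch of how such scores are computed with scikit-learn, on synthetic labels and scores (this is illustrative only, not the repo's actual evaluation code):

```python
import numpy as np
from sklearn.metrics import roc_auc_score, average_precision_score

# Synthetic ground truth (1 = edge/attribute present) and predicted scores.
rng = np.random.default_rng(0)
labels = rng.integers(0, 2, size=1000)
# Scores shifted upward for the positive class, with deliberate overlap,
# so the metrics land well above chance but below 1.0.
scores = labels * 0.5 + rng.random(1000)

roc = roc_auc_score(labels, scores)
ap = average_precision_score(labels, scores)
print(f"ROC-AUC: {roc:.4f}  AP: {ap:.4f}")
```

With a real model, `scores` would be the decoder's predicted probabilities for held-out edges or attributes.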

#### CiteSeer dataset with the default parameter settings
```
Epoch: 0198 train_loss= 0.81845 log_lik= 0.76834 KL= 0.05011 train_acc= 0.66264 val_edge_roc= 0.94756 val_edge_ap= 0.95467 val_attr_roc= 0.92974 val_attr_ap= 0.92059 time= 1.70837
/Users/storen/anaconda3/lib/python3.7/site-packages/torch/nn/functional.py:1558: UserWarning: nn.functional.tanh is deprecated. Use torch.tanh instead.
Optimization Finished!
Test edge ROC score: 0.9490130318561492
Test edge AP score: 0.95856792990438
Test attr ROC score: 0.9239066109625775
Test attr AP score: 0.9121636661521142
```

#### Cora dataset with the default parameter settings
```
Epoch: 0197 train_loss= 0.92730 log_lik= 0.85651 KL= 0.07079 train_acc= 0.64933 val_edge_roc= 0.98261 val_edge_ap= 0.97795 val_attr_roc= 0.89457 val_attr_ap= 0.88565 time= 0.85166
Epoch: 0198 train_loss= 0.92594 log_lik= 0.85511 KL= 0.07083 train_acc= 0.65040 val_edge_roc= 0.98230 val_edge_ap= 0.97761 val_attr_roc= 0.89448 val_attr_ap= 0.88571 time= 0.82273
Epoch: 0199 train_loss= 0.92517 log_lik= 0.85432 KL= 0.07085 train_acc= 0.65058 val_edge_roc= 0.98256 val_edge_ap= 0.97801 val_attr_roc= 0.89523 val_attr_ap= 0.88651 time= 0.81977
Epoch: 0200 train_loss= 0.92596 log_lik= 0.85508 KL= 0.07088 train_acc= 0.64968 val_edge_roc= 0.98289 val_edge_ap= 0.97837 val_attr_roc= 0.89585 val_attr_ap= 0.88720 time= 0.90153
Optimization Finished!
Test edge ROC score: 0.983134251252218
Test edge AP score: 0.9817151099782778
Test attr ROC score: 0.895140476178776
Test attr AP score: 0.8847338611264453
```
3 changes: 3 additions & 0 deletions __init__.py
@@ -0,0 +1,3 @@



61 changes: 61 additions & 0 deletions classification.py
@@ -0,0 +1,61 @@
from __future__ import division
from __future__ import print_function

import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
from sklearn.metrics import f1_score


def multiclass_node_classification_eval(X, y, ratio=0.5, rnd=2018):
    """Hold out `ratio` of the nodes as a test set, fit an SVC, and predict."""
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=ratio, random_state=rnd)
    clf = SVC()
    clf.fit(X_train, y_train)

    y_pred = clf.predict(X_test)

macro_f1 = f1_score(y_test, y_pred, average="macro")
micro_f1 = f1_score(y_test, y_pred, average="micro")

return macro_f1, micro_f1


def node_classification_F1(Embeddings, y, ratio):
    """Average macro/micro F1 over 10 random train/test splits."""
    macro_f1_avg = 0
    micro_f1_avg = 0
    for i in range(10):
        rnd = np.random.randint(2018)
        macro_f1, micro_f1 = multiclass_node_classification_eval(
            Embeddings, y, ratio, rnd)
        macro_f1_avg += macro_f1
        micro_f1_avg += micro_f1
    macro_f1_avg /= 10
    micro_f1_avg /= 10
    print("Macro_f1: " + str(macro_f1_avg))
    print("Micro_f1: " + str(micro_f1_avg))


def read_label(inputFileName):
    """Read one integer class label per line from `inputFileName`."""
    with open(inputFileName, "r") as f:
        lines = f.readlines()
    y = np.zeros(len(lines), dtype=int)
    for i, line in enumerate(lines):
        y[i] = int(line.strip())
    return y


datasets = ['cora']  # other options: 'citeseer', 'pubmed', 'BlogCatalog'
for datasetname in datasets:
    for ratio in [0.2]:
        print('dataset:', datasetname, ', ratio:', ratio)
        embedding_node_result_file = "result/AGAE_{}_n_mu.emb.npy".format(datasetname)
        label_file = "data/" + datasetname + ".label"
        y = read_label(label_file)
        Embeddings = np.load(embedding_node_result_file)
        node_classification_F1(Embeddings, y, ratio)
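The classification protocol above (random split, SVC, macro/micro F1) can be exercised end-to-end on synthetic embeddings; in this sketch, two Gaussian clusters stand in for real learned embeddings, so the names and numbers are illustrative, not from the repo's datasets:

```python
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
from sklearn.metrics import f1_score

# Synthetic "embeddings": two well-separated Gaussian clusters, one per class.
rng = np.random.default_rng(2018)
X = np.vstack([rng.normal(0, 1, (100, 16)), rng.normal(3, 1, (100, 16))])
y = np.repeat([0, 1], 100)

# Same protocol as multiclass_node_classification_eval: hold out half the
# nodes, fit an SVC on the rest, and score macro/micro F1 on the held-out set.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.5, random_state=2018)
clf = SVC().fit(X_train, y_train)
y_pred = clf.predict(X_test)
macro_f1 = f1_score(y_test, y_pred, average="macro")
micro_f1 = f1_score(y_test, y_pred, average="micro")
print(f"Macro_f1: {macro_f1:.3f}  Micro_f1: {micro_f1:.3f}")
```

Because the clusters are cleanly separable, both F1 scores should come out near 1.0; real AGAE embeddings would be noisier.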