You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
42 lines
1.4 KiB
42 lines
1.4 KiB
import numpy as np |
|
from sklearn import metrics |
|
from sklearn.metrics import average_precision_score |
|
from sklearn.metrics import roc_auc_score, auc |
|
|
|
|
|
class Evaluator(object):
    """Score node embeddings on a fixed set of positive and negative test edges.

    The edge sets are stored once at construction and reused by every call to
    ``get_roc_score``, so the same evaluator can score several embeddings.
    """

    def __init__(self, edges_pos, edges_neg):
        # Each edge is read as ``e[0], e[1]`` (two node indices). Presumably an
        # (E, 2) array or a list of pairs -- TODO confirm against callers.
        self.edges_pos = edges_pos
        self.edges_neg = edges_neg

    @staticmethod
    def _edge_probs(emb, edges):
        """Return ``sigmoid(emb[u] . emb[v])`` for every edge ``(u, v)`` in ``edges``."""
        emb = np.asarray(emb)
        pairs = list(edges)
        src = [e[0] for e in pairs]
        dst = [e[1] for e in pairs]
        # Row-wise dot products: equivalent to np.dot(emb, emb.T)[src, dst]
        # but without materialising the dense N x N reconstruction matrix.
        # An empty edge list yields an empty (0,) array.
        logits = np.einsum("ij,ij->i", emb[src], emb[dst])
        # exp() overflows for large negative logits; the resulting 1/(1+inf)
        # is the correct limit 0.0, so only the warning is suppressed.
        with np.errstate(over="ignore"):
            return 1.0 / (1.0 + np.exp(-logits))

    def get_roc_score(self, emb, feas):
        """Evaluate link-prediction quality of ``emb`` on the stored test edges.

        Each edge ``(u, v)`` is scored as ``sigmoid(emb[u] . emb[v])``. Positive
        edges are labelled 1 and negative edges 0.

        Args:
            emb: 2-D array of node embeddings; row ``i`` is the embedding of
                node ``i``.
            feas: Not read by this method; kept so existing call sites that
                pass it keep working.

        Returns:
            Tuple ``(roc_score, ap_score, emb, aupr_score)``:
            - ``roc_score``: ROC AUC.
            - ``ap_score``: average precision.
            - ``emb``: the input embeddings, returned unchanged.
            - ``aupr_score``: trapezoidal area under the precision-recall curve.
        """
        preds_pos = self._edge_probs(emb, self.edges_pos)
        preds_neg = self._edge_probs(emb, self.edges_neg)

        preds_all = np.hstack([preds_pos, preds_neg])
        # One label per prediction. The negative block is sized by the number
        # of negative edges, which need not equal the number of positive edges.
        labels_all = np.hstack([np.ones(len(preds_pos)), np.zeros(len(preds_neg))])

        roc_score = roc_auc_score(labels_all, preds_all)
        ap_score = average_precision_score(labels_all, preds_all)

        precision, recall, _thresholds = metrics.precision_recall_curve(labels_all, preds_all)
        aupr_score = auc(recall, precision)

        return roc_score, ap_score, emb, aupr_score
|
|
|