import numpy as np
from sklearn import metrics
from sklearn.metrics import average_precision_score
from sklearn.metrics import roc_auc_score, auc


class Evaluator(object):
    def __init__(self, edges_pos, edges_neg):
        self.edges_pos = edges_pos
        self.edges_neg = edges_neg

    def get_roc_score(self, emb, feas):
        """Score link prediction from node embeddings.

        Returns ROC-AUC, average precision, the embeddings, and AUPR.
        """
        def sigmoid(x):
            return 1 / (1 + np.exp(-x))

        # Reconstruct edge scores from the inner product of the embeddings.
        adj_rec = np.dot(emb, emb.T)

        # Predict on the held-out positive edges.
        preds = []
        pos = []
        for e in self.edges_pos:
            preds.append(sigmoid(adj_rec[e[0], e[1]]))
            pos.append(feas['adj_orig'][e[0], e[1]])

        # Predict on the sampled negative (non-)edges.
        preds_neg = []
        neg = []
        for e in self.edges_neg:
            preds_neg.append(sigmoid(adj_rec[e[0], e[1]]))
            neg.append(feas['adj_orig'][e[0], e[1]])

        # Stack scores and labels: 1 for positive edges, 0 for negatives.
        preds_all = np.hstack([preds, preds_neg])
        labels_all = np.hstack([np.ones(len(preds)), np.zeros(len(preds_neg))])

        roc_score = roc_auc_score(labels_all, preds_all)
        ap_score = average_precision_score(labels_all, preds_all)
        precision, recall, _thresholds = metrics.precision_recall_curve(labels_all, preds_all)
        aupr_score = auc(recall, precision)
        return roc_score, ap_score, emb, aupr_score
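

# Example usage: a minimal sketch with synthetic data. In practice `emb` would
# be node embeddings from a trained model and `feas['adj_orig']` the original
# adjacency matrix; the shapes, edge lists, and random embeddings below are
# stand-ins for illustration only.
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    n_nodes, dim = 10, 4
    emb = rng.standard_normal((n_nodes, dim))

    # Hypothetical positive edges (present in the graph) and negative samples.
    edges_pos = [(0, 1), (2, 3), (4, 5)]
    edges_neg = [(0, 2), (1, 3), (5, 6)]

    adj_orig = np.zeros((n_nodes, n_nodes))
    for i, j in edges_pos:
        adj_orig[i, j] = 1

    evaluator = Evaluator(edges_pos, edges_neg)
    roc, ap, _, aupr = evaluator.get_roc_score(emb, {'adj_orig': adj_orig})
    print(f"ROC-AUC: {roc:.3f}  AP: {ap:.3f}  AUPR: {aupr:.3f}")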