|
|
|
@ -11,6 +11,7 @@ from evaluation import Evaluator |
|
|
|
|
from model import BGAN |
|
|
|
|
from optimizer import Optimizer, update |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 解析参数 |
|
|
|
|
def parse_args(): |
|
|
|
|
parser = argparse.ArgumentParser(description='BGANDTI') |
|
|
|
@ -19,7 +20,7 @@ def parse_args(): |
|
|
|
|
parser.add_argument('--hidden3', type=int, default=64, help='隐藏层3神经元数量.') |
|
|
|
|
parser.add_argument('--learning_rate', type=float, default=.6 * 0.001, help='学习率') |
|
|
|
|
parser.add_argument('--discriminator_learning_rate', type=float, default=0.001, help='判别器学习率') |
|
|
|
|
parser.add_argument('--epoch', type=int, default=20, help='迭代次数') |
|
|
|
|
parser.add_argument('--epoch', type=int, default=250, help='迭代次数') |
|
|
|
|
parser.add_argument('--seed', type=int, default=50, help='用来打乱数据集') |
|
|
|
|
parser.add_argument('--features', type=int, default=1, help='是(1)否(0)使用特征') |
|
|
|
|
parser.add_argument('--dropout', type=float, default=0., help='Dropout rate (1 - keep probability).') |
|
|
|
@ -56,11 +57,9 @@ if __name__ == "__main__": |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
# 构造模型 |
|
|
|
|
# d_real, discriminator, ae_model, model_z2g, D_Graph, GD_real = DBGAN(placeholders, feas['num_features'], feas['num_nodes'], feas['features_nonzero'], settings) |
|
|
|
|
model = BGAN(placeholders, feas['num_features'], feas['num_nodes'], feas['features_nonzero'], settings) |
|
|
|
|
|
|
|
|
|
# 定义优化器 |
|
|
|
|
# opt = Optimizer(ae_model, model_z2g, D_Graph, discriminator, placeholders, feas['pos_weight'], feas['norm'], d_real, feas['num_nodes'], GD_real) |
|
|
|
|
optimizer = Optimizer(model.ae_model, model.model_z2g, model.D_Graph, model.discriminator, placeholders, feas['pos_weight'], feas['norm'], model.d_real, feas['num_nodes'], model.GD_real, |
|
|
|
|
settings) |
|
|
|
|
|
|
|
|
@ -96,10 +95,8 @@ if __name__ == "__main__": |
|
|
|
|
record.append([roc_score, aupr_score, ap_score]) |
|
|
|
|
record_emb.append(emb) |
|
|
|
|
rec = np.array(record) |
|
|
|
|
# index = rec[:, 0].tolist().index(max(rec[:, 0].tolist())) |
|
|
|
|
# index_pr = rec[:, 1].tolist().index(max(rec[:, 1].tolist())) |
|
|
|
|
emb = record_emb[rec[:, 0].tolist().index(max(rec[:, 0].tolist()))] |
|
|
|
|
ana = record[rec[:, 0].tolist().index(max(rec[:, 0].tolist()))] |
|
|
|
|
ana_pr = record[rec[:, 1].tolist().index(max(rec[:, 1].tolist()))] |
|
|
|
|
# ana_pr = record[rec[:, 1].tolist().index(max(rec[:, 1].tolist()))] |
|
|
|
|
print('The peak [auc] test_roc={:.7f}, aupr={:.7f}, ap={:.7f}'.format(ana[0], ana[1], ana[2])) |
|
|
|
|
print('The peak [aupr] test_roc={:.7f}, aupr={:.7f}, ap={:.7f}'.format(ana_pr[0], ana_pr[1], ana_pr[2])) |
|
|
|
|
# print('The peak [aupr] test_roc={:.7f}, aupr={:.7f}, ap={:.7f}'.format(ana_pr[0], ana_pr[1], ana_pr[2])) |
|
|
|
|