You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

95 lines
3.2 KiB

2 years ago
import os
import random
import numpy as np
import scipy.sparse as sp
from src import config
def load_luo_data(dataset, data_dir='../../data/RawData/luo'):
    """Load the Luo drug/protein network as one block adjacency matrix.

    Reads the whitespace-separated integer matrices ``mat_drug_drug.txt``,
    ``mat_drug_protein.txt`` and ``mat_protein_protein.txt`` from *data_dir*,
    stacks them as::

        [[dd,   dp],
         [dp.T, pp]]

    and adds the identity so every node carries a self-loop.

    :param dataset: unused -- the Luo file names do not depend on it. Kept for
        interface compatibility with existing callers.
    :param data_dir: directory holding the three ``mat_*.txt`` files. The
        default is resolved relative to the current working directory.
    :return: ``(adj, n_drug)`` where ``adj`` is a float CSR matrix and
        ``n_drug`` is the number of rows of the drug-drug matrix (i.e. the
        index at which protein nodes start).
    """
    def read(name):
        # ndmin=2 keeps a single-row file two-dimensional so stacking works.
        return np.loadtxt(os.path.join(data_dir, name), dtype=int, ndmin=2)

    dd = read('mat_drug_drug.txt')
    dp = read('mat_drug_protein.txt')
    pp = read('mat_protein_protein.txt')
    adj = np.block([[dd, dp], [dp.T, pp]])
    return sp.csr_matrix(adj + np.eye(adj.shape[0])), dd.shape[0]
def _read_labeled_matrix(path, dtype):
    """Load a tab-separated matrix whose first row and column are labels.

    The header row and the leading label column are dropped; the remaining
    cells are converted from text to *dtype*.
    """
    table = np.loadtxt(path, dtype=str, delimiter='\t', ndmin=2)
    return table[1:, 1:].astype(dtype)


def load_yam_data(dataset, data_dir='../../data/RawData/Yamanishi', threshold=0.5):
    """Load a Yamanishi benchmark as one block adjacency matrix.

    Reads ``{dataset}_admat_dgc.txt`` (interaction matrix),
    ``{dataset}_simmat_dc.txt`` (drug similarity) and
    ``{dataset}_simmat_dg.txt`` (protein similarity) from *data_dir*.
    Similarities are binarised (``< threshold`` -> 0, otherwise 1) and the
    pieces are stacked as::

        [[dd,   dp],
         [dp.T, pp]]

    No identity is added here.

    :param dataset: file-name prefix, e.g. the benchmark subset name.
    :param data_dir: directory holding the three files. The default is
        resolved relative to the current working directory.
    :param threshold: similarity cut-off used to binarise ``dd`` and ``pp``.
    :return: ``(adj, n_drug)`` where ``adj`` is an integer CSR matrix and
        ``n_drug`` is the number of rows of the drug-drug block.
    """
    def path(suffix):
        return os.path.join(data_dir, '{}_{}.txt'.format(dataset, suffix))

    # The interaction file is transposed so its rows line up with the
    # drug-drug block when stacked horizontally below.
    dp = _read_labeled_matrix(path('admat_dgc'), int).T
    dd = _read_labeled_matrix(path('simmat_dc'), float)
    pp = _read_labeled_matrix(path('simmat_dg'), float)
    dd = np.where(dd < threshold, 0, 1)
    pp = np.where(pp < threshold, 0, 1)
    adj = np.block([[dd, dp], [dp.T, pp]])
    return sp.csr_matrix(adj), dd.shape[0]
def is_symmetry(adj):
    """Return True if *adj* is a square matrix equal to its transpose.

    Accepts dense array-likes and SciPy sparse matrices. Non-square input
    is reported as not symmetric rather than raising.
    """
    if sp.issparse(adj):
        if adj.shape[0] != adj.shape[1]:
            return False
        # Any stored mismatch between adj and adj.T shows up as a nonzero.
        return (adj != adj.T).nnz == 0
    arr = np.asarray(adj)
    if arr.ndim != 2 or arr.shape[0] != arr.shape[1]:
        return False
    return bool(np.array_equal(arr, arr.T))
def is_1_diag(adj):
    """Return True if every main-diagonal entry of *adj* equals 1.

    The previous implementation only compared the diagonal *sum* with the
    row count, so e.g. ``[[2, 0], [0, 0]]`` was wrongly accepted. The diagonal
    must also be as long as the number of rows (as in the original check).
    Works for dense array-likes and SciPy sparse matrices.
    """
    if sp.issparse(adj):
        diag = adj.diagonal()
    else:
        diag = np.diagonal(np.asarray(adj))
    diag = np.asarray(diag).ravel()
    return diag.size == np.shape(adj)[0] and bool(np.all(diag == 1))
def change_unbalanced(adj, percent, dp_line, dataset, seed=None):
    """Randomly remove a fraction of the known associations in the dp block.

    The input is first normalised:

    1. Its diagonal is forced to exactly 1.
    2. It is made symmetric by mirroring its upper triangle.

    Then ``floor(percent * k)`` of the ``k`` entries equal to 1 in the
    off-diagonal block ``adj[:dp_line, dp_line:]`` are chosen uniformly at
    random without replacement. Each chosen entry is zeroed together with its
    mirror entry.

    :param adj: square adjacency matrix (SciPy sparse or dense array-like).
    :param percent: fraction in ``[0, 1]`` of known associations to remove.
    :param dp_line: index of the first column of the removable block (rows
        ``[0, dp_line)`` x columns ``[dp_line, n)``).
    :param dataset: unused; kept for interface compatibility.
    :param seed: optional seed for a private ``random.Random``. When ``None``
        the module-level ``random`` generator is used, so an external
        ``random.seed(...)`` still controls the result.
    :return: integer CSR matrix with the selected associations removed.
    :raises ValueError: if *percent* is outside ``[0, 1]``.
    """
    if not 0 <= percent <= 1:
        raise ValueError('percent must be in [0, 1], got {}'.format(percent))

    adj = sp.csr_matrix(adj)
    eye = sp.eye(adj.shape[0])
    # Subtract whatever is on the diagonal, then add the identity.
    diag = sp.dia_matrix((adj.diagonal()[np.newaxis, :], [0]), shape=adj.shape)
    adj = adj - diag + eye
    assert np.all(adj.diagonal() == 1), 'diagonal should be all ones here'

    # Symmetrise from the upper triangle. triu includes the diagonal, so it
    # is counted twice and one identity is subtracted back out.
    upper = sp.triu(adj)
    adj = (upper + upper.T - eye).toarray()

    # Candidate edges in row-major order. Every row index is < dp_line and
    # every column index is >= dp_line, so none of them lie on the diagonal.
    rows, cols = np.nonzero(adj[:dp_line, dp_line:] == 1)
    idx = list(zip(rows.tolist(), (cols + dp_line).tolist()))
    num = int(np.floor(percent * len(idx)))

    rng = random if seed is None else random.Random(seed)
    for i, j in rng.sample(idx, num):
        adj[i, j] = 0
        adj[j, i] = 0
    return sp.csr_matrix(adj.astype(int))