Commit b94d27db authored by MaNingChina

submit files

parent 6c5848bf
import os
import os.path as osp
import shutil
import pickle
import torch.utils.data as data
import numpy as np
import torch
import random
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchnet as tnt
import gl  # global variables
class GraphDataSet(data.Dataset):
def __init__(self,phase="train",base_folder="/home/data/yangjieyu/graph_maml/data/exp_data/TRIANGLES",dataset_name="TRIANGLES"):
super(GraphDataSet,self).__init__()
self.base_folder=base_folder
self.datasetName=dataset_name
self.phase=phase
self.base_folder=os.path.join("./data",dataset_name)
        # per-dataset node feature dimensionality (exact-name lookup; the original
        # substring check with `in` is replaced by an explicit mapping)
        feature_dims = {
            "COIL-RAG": 64,
            "COIL-DEL": 2,
            "20ng": 1,
            "REDDIT-MULTI-12K": 1,
            "ohsumed": 1,
            "R52": 1,
            "Letter_high": 1,
            "TRIANGLES": 1,
            "Reddit": 1,
            "ENZYMES": 1,
        }
        self.num_features = feature_dims.get(dataset_name, 1)
        node_attributes_path = os.path.join(self.base_folder, dataset_name + "_node_attributes.pickle")
        if gl.global_attributes is None:
            attrs = self.load_pickle(node_attributes_path)
            # print(attrs.shape)
            attrs = list(map(float, attrs))
            gl.global_attributes = attrs
        self.node_attributes = gl.global_attributes
self.saved_set = self.load_pickle(
os.path.join(self.base_folder, dataset_name + "_{}_set.pickle".format(phase)))
self.graph_indicator = {}
graph2nodes = self.saved_set["graph2nodes"]
for graph_index, node_list in graph2nodes.items():
for node in node_list:
self.graph_indicator[node] = graph_index
self.num_graph=len(graph2nodes)
def load_pickle(self,file_name):
with open(file_name, 'rb') as f:
data = pickle.load(f)
return data
def __getitem__(self, index):
return self.graph_indicator[index]
def __len__(self):
return len(self.saved_set["label2graphs"])
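# Expected on-disk layout (inferred from the loading code above; a sketch, not authoritative):
# for a dataset NAME under ./data/NAME the loader reads
#   NAME_node_attributes.pickle      : flat list of per-node attribute values, indexed by global node id
#   NAME_{train,val,test}_set.pickle : dict with keys
#       "label2graphs" : {class label -> list of graph ids}
#       "graph2nodes"  : {graph id    -> list of global node ids}
#       "graph2edges"  : {graph id    -> list of [src, dst] node-id pairs}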
class FewShotDataloader():
def __init__(self,
dataset,
n_way=5, # number of novel categories.
n_shot=5, # number of training examples per novel category.
n_query=5, # number of test examples for all the novel categories.
batch_size=1, # number of training episodes per batch.
num_workers=4,
epoch_size=2000, # number of batches per epoch.
):
self.label2graphs,self.graph2nodes,self.graph2edges=dataset.saved_set["label2graphs"],dataset.saved_set["graph2nodes"],dataset.saved_set["graph2edges"]
self.dataset = dataset
self.phase = self.dataset.phase
# max_possible_nKnovel = (self.dataset.num_cats_base if self.phase=='train'
# else self.dataset.num_cats_novel)
self.n_way=n_way
self.n_shot=n_shot
self.n_query=n_query
self.batch_size = batch_size
self.epoch_size = epoch_size
self.num_workers = num_workers
self.is_eval_mode=(self.phase=='test') or (self.phase=='val')
def sample_graph_id(self,sampled_classes):
"""
:param sampled_class: class id selected,==n_way
:return: support_graphs,graph ids of support,shape:n_way*n_shot
query_graph=graph ids of query ,shape:n_way*n_query
"""
support_graphs=[]
query_graphs = []
support_labels=[]
query_labels=[]
for index,label in enumerate(sampled_classes):
graphs=self.label2graphs[label]
assert (len(graphs)) >= self.n_shot + self.n_query
selected_graphs=random.sample(graphs,self.n_shot+self.n_query)
support_graphs.extend(selected_graphs[0:self.n_shot])
query_graphs.extend(selected_graphs[self.n_shot:])
support_labels.extend([index]*self.n_shot)
query_labels.extend([index]*self.n_query)
sindex=list(range(len(support_graphs)))
random.shuffle(sindex)
support_graphs=np.array(support_graphs)[sindex]
support_labels=np.array(support_labels)[sindex]
qindex=list(range(len(query_graphs)))
random.shuffle(qindex)
query_graphs=np.array(query_graphs)[qindex]
query_labels=np.array(query_labels)[qindex]
return np.array(support_graphs),np.array(query_graphs),np.array(support_labels),np.array(query_labels)
def sample_graph_data(self,graph_ids):
"""
:param graph_ids: a numpy shape n_way*n_shot/query
:return:
"""
edge_index=[]
graph_indicator=[]
node_attr=[]
node_number=0
for index,gid in enumerate(graph_ids):
nodes=self.graph2nodes[gid]
new_nodes=list(range(node_number,node_number+len(nodes)))
node_number=node_number+len(nodes)
node2new_number=dict(zip(nodes,new_nodes))
            node_attr.append(np.array([self.dataset.node_attributes[node] for node in nodes]).reshape(len(nodes), -1))
edge_index.extend([[node2new_number[edge[0]],node2new_number[edge[1]]]for edge in self.graph2edges[gid]])
graph_indicator.extend([index]*len(nodes))
node_attr = np.vstack(node_attr)
return [torch.from_numpy(node_attr).float(), \
torch.from_numpy(np.array(edge_index)).long(), \
torch.from_numpy(np.array(graph_indicator)).long()]
def sample_episode(self):
"""Samples a training episode."""
        classes = random.sample(list(self.label2graphs.keys()), self.n_way)  # list() keeps sampling valid on newer Python versions
# print(classes)
support_graphs,query_graphs,support_labels,query_labels=self.sample_graph_id(classes)
support_data=self.sample_graph_data(support_graphs)
support_labels=torch.from_numpy(support_labels).long()
support_data.append(support_labels)
query_data = self.sample_graph_data(query_graphs)
query_labels=torch.from_numpy(query_labels).long()
query_data.append(query_labels)
return support_data,query_data
def get_iterator(self, epoch=0):
rand_seed = epoch
random.seed(rand_seed)
np.random.seed(rand_seed)
def load_function(iter_idx):
support_data,query_data =self.sample_episode()
return support_data,query_data
tnt_dataset = tnt.dataset.ListDataset(
elem_list=range(self.epoch_size), load=load_function)
data_loader = tnt_dataset.parallel(
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=(False if self.is_eval_mode else True)
# shuffle=True
)
return data_loader
def __call__(self, epoch=0):
return self.get_iterator(epoch)
def __len__(self):
return int(self.epoch_size / self.batch_size)
if __name__ == '__main__':
GraphDataSet()
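    # A minimal smoke test of the episode sampler (a sketch: assumes the TRIANGLES
    # pickles exist under ./data/TRIANGLES in the layout described above).
    train_data = GraphDataSet(phase="train", dataset_name="TRIANGLES")
    loader = FewShotDataloader(train_data, n_way=3, n_shot=5, n_query=5,
                               batch_size=1, num_workers=0, epoch_size=2)
    for support_data, query_data in loader(0):
        # each is [node_attr, edge_index, graph_indicator, labels] with a leading batch dim
        print([tuple(t.shape) for t in support_data], [tuple(t.shape) for t in query_data])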
# Module-level caches shared across dataset instances, used to speed up data loading.
global_edges = None
global_labels = None
global_indicator = None
global_attributes = None
import argparse
from data.dataset1 import GraphDataSet,FewShotDataloader
from models.meta_ada import Meta
from tqdm import tqdm
from tensorboardX import SummaryWriter
from utils import *
setup_seed()
def get_dataset(dataset):
val_data = GraphDataSet(phase="val", dataset_name=dataset)
train_data=GraphDataSet(phase="train",dataset_name=dataset)
test_data=GraphDataSet(phase="test",dataset_name=dataset)
return train_data,val_data,test_data
def get_model(config):
model=None
model_type=config["model_type"]
print("encoder:",model_type)
if model_type in "gcn":
from models.GCN4maml import Model
model=Model(config)
elif model_type in "sage":
from models.sage4maml_model import Model
model = Model(config)
return model
parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, default='./config/adamaml_tri.json')
parser.add_argument('--root_directory', type=str, default='./experiments')
args = parser.parse_args()
# print(args)
config = unserialize(args.config)
print("dataset", config["dataset"], "{}way{}shot".format(config["test_way"], config["val_shot"]),"gpu",config["device"])
training_set, validation_set, test_set = get_dataset(config["dataset"])
config["num_features"] = training_set.num_features
train_loader = FewShotDataloader(training_set,
n_way=config["train_way"],
n_shot=config["train_shot"],
n_query=config["train_query"],
batch_size=1, # number of training episodes per batch.
num_workers=4,
epoch_size=config["train_episode"], # number of batches per epoch.
)
val_loader=None
if validation_set is not None:
val_loader = FewShotDataloader(validation_set,
n_way=config["test_way"],
n_shot=config["val_shot"],
n_query=config["val_query"],
batch_size=1,
num_workers=4,
epoch_size=config["val_episode"],
)
test_loader = FewShotDataloader(test_set,
n_way=config["test_way"], # number of novel categories.
n_shot=config["val_shot"], # number of training examples per novel category.
n_query=config["val_query"], # number of test examples for all the novel categories.
batch_size=1, # number of training episodes per batch.
num_workers=4,
epoch_size=config["val_episode"], # number of batches per epoch.
)
model = get_model(config).to(config["device"])
meta_model=Meta(model,config).to(config["device"])
if config["double"]==True:
model=model.double()
meta_model=meta_model.double()
pa=get_para_num(meta_model)
print(pa)
writer=None
root_directory = args.root_directory
if config['save']:
project_name = config["dataset"]
data_directory = os.path.join(root_directory, "log", "-".join((project_name, time.strftime("%m-%d"))))
check_dir(data_directory)
log_file = os.path.join(data_directory, "_".join(("log", time.strftime("%m-%d-%H-%M"))))
check_dir(log_file)
writer = SummaryWriter(log_file, comment='Normal')
# print("log_dir:",log_file)
serialize(config, os.path.join(log_file, "config.json"), in_json=True)
else:
data_directory, writer = None, None
log_file = os.path.join(data_directory, "_".join(("log", time.strftime("%m-%d-%H-%M"))))
config["save_path"] = log_file
config["log_file"] = os.path.join(log_file, "print.txt")
log(config["log_file"], str(vars(args)))
np.set_printoptions(precision=3)
def run():
write_count=0
val_count=0
device=config["device"]
t = time.time()
max_val_acc=0
max_score_val_acc=0
min_step=config["min_step"]
test_step=config["step_test"]
for epoch in range(config["epochs"]):
loss_train = 0.0
correct = 0
meta_model.train()
train_accs, train_final_losses,train_total_losses, val_accs, val_losses = [], [], [], [],[]
score_val_acc=[]
for i, data in enumerate(tqdm(train_loader(epoch)), 1):
support_data, query_data=data
support_data=[item.to(device) for item in support_data]
if config["double"] == True:
support_data[0]=support_data[0].double()
query_data[0] = query_data[0].double()
query_data=[item.to(device) for item in query_data]
accs,step,final_loss,total_loss,stop_gates,scores,train_losses,train_accs_support=meta_model(support_data, query_data)
train_accs.append(accs[step])
train_final_losses.append(final_loss)
train_total_losses.append(total_loss)
#
if (i+1)%100==0:
if np.sum(stop_gates) > 0:
print("\nstep",len(stop_gates),np.array(stop_gates))
print("accs{:.6f},final_loss{:.6f},total_loss{:.6f}".format(np.mean(train_accs),np.mean(train_final_losses),
np.mean(train_total_losses)))
# validation_stage
meta_model.eval()
for i, data in enumerate(tqdm(val_loader(epoch)), 1):
support_data, query_data = data
if config["double"]==True:
support_data[0] = support_data[0].double()
query_data[0] = query_data[0].double()
support_data = [item.to(device) for item in support_data]
query_data = [item.to(device) for item in query_data]
accs, step, stop_gates, scores, query_losses = meta_model.finetunning(support_data, query_data)
acc=get_max_acc(accs,step,scores,min_step,test_step)
val_accs.append(accs[step])
# train_losses.append(loss)
if (i+1) % 200 == 0:
print("\n{}th test".format(i))
if np.sum(stop_gates)>0:
print("stop_prob", len(stop_gates), np.array(stop_gates))
print("scores", len(scores), np.array(scores))
print("query_losses", len(query_losses), np.array(query_losses))
print("accs", step, np.array([accs[i] for i in range(0, step + 1)]))
val_acc_avg=np.mean(val_accs)
train_acc_avg=np.mean(train_accs)
train_loss_avg =np.mean(train_final_losses)
val_acc_ci95 = 1.96 * np.std(np.array(val_accs)) / np.sqrt(config["val_episode"])
if val_acc_avg > max_val_acc:
max_val_acc = val_acc_avg
log(config["log_file"],'\nEpoch(***Best***): {:04d},loss_train: {:.6f},acc_train: {:.6f},'
'acc_val:{:.2f} ±{:.2f},meta_lr: {:.6f},time: {:.2f}s,best {:.2f}'
.format(epoch,train_loss_avg,train_acc_avg,val_acc_avg,val_acc_ci95,
meta_model.get_meta_learning_rate(),time.time() - t,max_val_acc))
torch.save({'epoch': epoch, 'embedding':meta_model.state_dict(),
# 'optimizer': optimizer.state_dict()
}, os.path.join(config["save_path"], 'best_model.pth'))
        else:
log(config["log_file"], '\nEpoch: {:04d},loss_train: {:.6f},acc_train: {:.6f},'
'acc_val:{:.2f} ±{:.2f},meta_lr: {:.6f},time: {:.2f}s,best {:.2f}'
.format(epoch, train_loss_avg, train_acc_avg, val_acc_avg, val_acc_ci95,
meta_model.get_meta_learning_rate(), time.time() - t, max_val_acc))
meta_model.adapt_meta_learning_rate(train_loss_avg)
print('Optimization Finished! Total time elapsed: {:.6f}'.format(time.time() - t))
def get_max_acc(accs,step,scores,min_step,test_step):
step=np.argmax(scores[min_step-1:test_step])+min_step-1
return accs[step]
if __name__ == '__main__':
# Model training
best_model = run()
import torch
import torch.nn.functional as F
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp
from models.TopKPoolfw import TopKPooling
from models.GcnConv import GCNConvFw
from models.layersFw import LinearFw
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import add_self_loops, remove_self_loops
from torch_scatter import scatter_add
class NodeInformationScore(MessagePassing):
def __init__(self, improved=False, cached=False, **kwargs):
super(NodeInformationScore, self).__init__(aggr='add', **kwargs)
self.improved = improved
self.cached = cached
self.cached_result = None
self.cached_num_edges = None
@staticmethod
def norm(edge_index, num_nodes, edge_weight, dtype=None):
edge_index, _ = remove_self_loops(edge_index)
if edge_weight is None:
edge_weight = torch.ones((edge_index.size(1),), dtype=dtype, device=edge_index.device)
row, col = edge_index
deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
deg_inv_sqrt = deg.pow(-0.5)
deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
edge_index, edge_weight = add_self_loops(edge_index, edge_weight, 0, num_nodes)
row, col = edge_index
expand_deg = torch.zeros((edge_weight.size(0),), dtype=dtype, device=edge_index.device)
expand_deg[-num_nodes:] = torch.ones((num_nodes,), dtype=dtype, device=edge_index.device)
return edge_index, expand_deg - deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]
def forward(self, x, edge_index, edge_weight=None):
if self.cached and self.cached_result is not None:
if edge_index.size(1) != self.cached_num_edges:
raise RuntimeError(
'Cached {} number of edges, but found {}'.format(self.cached_num_edges, edge_index.size(1)))
if not self.cached or self.cached_result is None:
self.cached_num_edges = edge_index.size(1)
edge_index, norm = self.norm(edge_index, x.size(0), edge_weight, x.dtype)
self.cached_result = edge_index, norm
edge_index, norm = self.cached_result
return self.propagate(edge_index, x=x, norm=norm)
def message(self, x_j, norm):
return norm.view(-1, 1) * x_j
def update(self, aggr_out):
return aggr_out
class Model(torch.nn.Module):
def __init__(self, config):
super(Model, self).__init__()
self.config = config
self.num_features = config["num_features"]
self.nhid = config["nhid"]
self.num_classes = config["train_way"]
# if self.num_classes==2:
# self.num_classes=1
self.pooling_ratio=config["pooling_ratio"]
self.conv1 = GCNConvFw(self.num_features, self.nhid)
self.conv2 = GCNConvFw(self.nhid, self.nhid)
self.conv3 = GCNConvFw(self.nhid, self.nhid)
self.calc_information_score = NodeInformationScore()
# self.pool1 = HGPSLPoolFw(self.nhid, self.pooling_ratio, self.sample, self.sparse, self.sl, self.lamb)
self.pool1 = TopKPooling(self.nhid, self.pooling_ratio)
self.pool2 = TopKPooling(self.nhid, self.pooling_ratio)
self.pool3 = TopKPooling(self.nhid, self.pooling_ratio)
self.lin1 = LinearFw(self.nhid * 2, self.nhid)
self.lin2 = LinearFw(self.nhid, self.nhid // 2)
self.lin3 = LinearFw(self.nhid // 2, self.num_classes)
# self.bn1 = torch.nn.BatchNorm1d(self.nhid,affine=False)
# self.bn2 = torch.nn.BatchNorm1d(self.nhid // 2,affine=False)
# self.bn3 = torch.nn.BatchNorm1d(self.num_classes,affine=False)
self.relu=F.leaky_relu
def forward(self, data):
x, edge_index, batch = data
edge_attr = None
edge_index=edge_index.transpose(0,1)
x = self.relu(self.conv1(x, edge_index, edge_attr),negative_slope=0.1)
x, edge_index, edge_attr, batch, _, _ = self.pool1(x, edge_index, None, batch)
# x, edge_index, edge_attr, batch, _ = self.pool1(x, edge_index, edge_attr, batch)
x1 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)
x =self.relu(self.conv2(x, edge_index, edge_attr),negative_slope=0.1)
x, edge_index, edge_attr, batch, _, _ = self.pool2(x, edge_index, None, batch)
# x, edge_index, edge_attr, batch, _ = self.pool2(x, edge_index, edge_attr, batch)
x2 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)
#
x = self.relu(self.conv3(x, edge_index, edge_attr),negative_slope=0.1)
x, edge_index, edge_attr, batch, _, _ = self.pool3(x, edge_index, None, batch)
# x, edge_index, edge_attr, batch, _ = self.pool3(x, edge_index, edge_attr, batch)
x_information_score = self.calc_information_score(x, edge_index)
score = torch.sum(torch.abs(x_information_score), dim=1)
x3 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)
x = self.relu(x1,negative_slope=0.1) + self.relu(x2,negative_slope=0.1) + self.relu(x3,negative_slope=0.1)
# x = F.relu(x1)
x = self.relu(self.lin1(x),negative_slope=0.1)
# x=self.bn1(x)
# x = F.dropout(x, p=self.dropout_ratio, training=self.training)
x = self.relu(self.lin2(x),negative_slope=0.1)
# x=self.bn2(x)
# x = F.dropout(x, p=self.dropout_ratio, training=self.training)
x=self.lin3(x)
# x = F.log_softmax(x, dim=-1)
return x,score.mean()
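# A minimal forward-pass sketch (not part of the original file): it assumes this repo's
# models/ package and torch_geometric are importable, and the config keys below simply
# mirror the ones read in __init__ above.
if __name__ == '__main__':
    config = {"num_features": 1, "nhid": 16, "train_way": 3, "pooling_ratio": 0.5}
    model = Model(config)
    x = torch.randn(10, 1)                                      # 10 nodes, 1 attribute each
    edge_index = torch.tensor([[i, i + 1] for i in range(9)])   # chain graph, shape [9, 2]
    batch = torch.zeros(10, dtype=torch.long)                   # a single graph in the batch
    logits, info_score = model([x, edge_index, batch])
    print(logits.shape, float(info_score))                      # torch.Size([1, 3]) and a scalar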
import torch
from torch.nn import Parameter
from torch_scatter import scatter_add
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import add_remaining_self_loops
import math
# from ..inits import glorot, zeros
class GCNConvFw(MessagePassing):
r"""The graph convolutional operator from the `"Semi-supervised
Classification with Graph Convolutional Networks"
<https://arxiv.org/abs/1609.02907>`_ paper
.. math::
\mathbf{X}^{\prime} = \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}}
\mathbf{\hat{D}}^{-1/2} \mathbf{X} \mathbf{\Theta},
where :math:`\mathbf{\hat{A}} = \mathbf{A} + \mathbf{I}` denotes the
adjacency matrix with inserted self-loops and
:math:`\hat{D}_{ii} = \sum_{j=0} \hat{A}_{ij}` its diagonal degree matrix.
Args:
in_channels (int): Size of each input sample.
out_channels (int): Size of each output sample.
improved (bool, optional): If set to :obj:`True`, the layer computes
:math:`\mathbf{\hat{A}}` as :math:`\mathbf{A} + 2\mathbf{I}`.
(default: :obj:`False`)
cached (bool, optional): If set to :obj:`True`, the layer will cache
the computation of :math:`\mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}}
\mathbf{\hat{D}}^{-1/2}` on first execution, and will use the
cached version for further executions.
This parameter should only be set to :obj:`True` in transductive
learning scenarios. (default: :obj:`False`)
bias (bool, optional): If set to :obj:`False`, the layer will not learn
an additive bias. (default: :obj:`True`)
**kwargs (optional): Additional arguments of
:class:`torch_geometric.nn.conv.MessagePassing`.
"""
def __init__(self, in_channels, out_channels, improved=False, cached=False,
bias=True, **kwargs):
super(GCNConvFw, self).__init__(aggr='add', **kwargs)
self.in_channels = in_channels
self.out_channels = out_channels
self.improved = improved
self.cached = cached
self.weight = Parameter(torch.Tensor(in_channels, out_channels))
self.weight.fast = None
# self.fast = None
if bias:
self.bias = Parameter(torch.Tensor(out_channels))
self.bias.fast=None
else:
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self):
glorot(self.weight)
zeros(self.bias)
self.cached_result = None
self.cached_num_edges = None
@staticmethod
def norm(edge_index, num_nodes, edge_weight=None, improved=False,
dtype=None):
if edge_weight is None:
edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype,
device=edge_index.device)
fill_value = 1 if not improved else 2
edge_index, edge_weight = add_remaining_self_loops(
edge_index, edge_weight, fill_value, num_nodes)
row, col = edge_index
deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
deg_inv_sqrt = deg.pow(-0.5)
deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]
def forward(self, x, edge_index, edge_weight=None):
""""""
if self.weight.fast is not None:
x = torch.matmul(x, self.weight.fast)
else:
x = torch.matmul(x, self.weight)
if self.cached and self.cached_result is not None:
if edge_index.size(1) != self.cached_num_edges:
raise RuntimeError(
'Cached {} number of edges, but found {}. Please '
'disable the caching behavior of this layer by removing '
'the `cached=True` argument in its constructor.'.format(
self.cached_num_edges, edge_index.size(1)))
if not self.cached or self.cached_result is None:
self.cached_num_edges = edge_index.size(1)
edge_index, norm = self.norm(edge_index, x.size(self.node_dim),
edge_weight, self.improved, x.dtype)
self.cached_result = edge_index, norm
edge_index, norm = self.cached_result
return self.propagate(edge_index, x=x, norm=norm)
def message(self, x_j, norm):
return norm.view(-1, 1) * x_j
def update(self, aggr_out):
if self.bias is not None:
if self.bias.fast is not None:
aggr_out=aggr_out + self.bias.fast
else:
aggr_out = aggr_out + self.bias
return aggr_out
def __repr__(self):
return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
self.out_channels)
def glorot(tensor):
if tensor is not None:
stdv = math.sqrt(6.0 / (tensor.size(-2) + tensor.size(-1)))
tensor.data.uniform_(-stdv, stdv)
def zeros(tensor):
if tensor is not None:
tensor.data.fill_(0)
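# A minimal sketch (not part of the original file) of the `.fast` mechanism this layer
# supports for MAML-style inner loops: the adapted weights are written into
# `weight.fast`/`bias.fast`, so the original Parameters are left untouched.
if __name__ == '__main__':
    conv = GCNConvFw(in_channels=4, out_channels=2)
    x = torch.randn(5, 4)                                        # 5 nodes, 4 features
    edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]])      # 4 directed edges, shape [2, 4]
    out = conv(x, edge_index)
    loss = out.pow(2).mean()
    grads = torch.autograd.grad(loss, [conv.weight, conv.bias], create_graph=True)
    inner_lr = 0.01
    conv.weight.fast = conv.weight - inner_lr * grads[0]         # fast weight, original kept intact
    conv.bias.fast = conv.bias - inner_lr * grads[1]
    out_adapted = conv(x, edge_index)                            # this call uses the fast weights
    print(out.shape, out_adapted.shape)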
import torch
from torch.nn import Parameter
from torch_scatter import scatter_add, scatter_max
from torch_geometric.utils import softmax
import math
# from ..inits import uniform
# from ...utils.num_nodes import maybe_num_nodes
def maybe_num_nodes(index, num_nodes=None):
return index.max().item() + 1 if num_nodes is None else num_nodes
def uniform(size, tensor):
bound = 1.0 / math.sqrt(size)
if tensor is not None:
tensor.data.uniform_(-bound, bound)
def topk(x, ratio, batch, min_score=None, tol=1e-7):
if min_score is not None:
# Make sure that we do not drop all nodes in a graph.
scores_max = scatter_max(x, batch)[0][batch] - tol
scores_min = scores_max.clamp(max=min_score)
perm = torch.nonzero(x > scores_min).view(-1)
else:
num_nodes = scatter_add(batch.new_ones(x.size(0)), batch, dim=0)
batch_size, max_num_nodes = num_nodes.size(0), num_nodes.max().item()
cum_num_nodes = torch.cat(
[num_nodes.new_zeros(1),
num_nodes.cumsum(dim=0)[:-1]], dim=0)
index = torch.arange(batch.size(0), dtype=torch.long, device=x.device)
index = (index - cum_num_nodes[batch]) + (batch * max_num_nodes)
dense_x = x.new_full((batch_size * max_num_nodes, ), -2)
dense_x[index] = x
dense_x = dense_x.view(batch_size, max_num_nodes)
_, perm = dense_x.sort(dim=-1, descending=True)
perm = perm + cum_num_nodes.view(-1, 1)
perm = perm.view(-1)
k = (ratio * num_nodes.to(torch.float)).ceil().to(torch.long)
mask = [
torch.arange(k[i], dtype=torch.long, device=x.device) +
i * max_num_nodes for i in range(batch_size)
]
mask = torch.cat(mask, dim=0)
perm = perm[mask]
return perm
def filter_adj(edge_index, edge_attr, perm, num_nodes=None):
num_nodes = maybe_num_nodes(edge_index, num_nodes)
mask = perm.new_full((num_nodes, ), -1)
i = torch.arange(perm.size(0), dtype=torch.long, device=perm.device)
mask[perm] = i
row, col = edge_index
row, col = mask[row], mask[col]
mask = (row >= 0) & (col >= 0)
row, col = row[mask], col[mask]
if edge_attr is not None:
edge_attr = edge_attr[mask]
return torch.stack([row, col], dim=0), edge_attr
class TopKPooling(torch.nn.Module):
r""":math:`\mathrm{top}_k` pooling operator from the `"Graph U-Nets"
<https://arxiv.org/abs/1905.05178>`_, `"Towards Sparse
Hierarchical Graph Classifiers" <https://arxiv.org/abs/1811.01287>`_
and `"Understanding Attention and Generalization in Graph Neural
Networks" <https://arxiv.org/abs/1905.02850>`_ papers
if min_score :math:`\tilde{\alpha}` is None:
.. math::
\mathbf{y} &= \frac{\mathbf{X}\mathbf{p}}{\| \mathbf{p} \|}
\mathbf{i} &= \mathrm{top}_k(\mathbf{y})
\mathbf{X}^{\prime} &= (\mathbf{X} \odot
\mathrm{tanh}(\mathbf{y}))_{\mathbf{i}}
\mathbf{A}^{\prime} &= \mathbf{A}_{\mathbf{i},\mathbf{i}}
if min_score :math:`\tilde{\alpha}` is a value in [0, 1]:
.. math::
\mathbf{y} &= \mathrm{softmax}(\mathbf{X}\mathbf{p})
\mathbf{i} &= \mathbf{y}_i > \tilde{\alpha}
\mathbf{X}^{\prime} &= (\mathbf{X} \odot \mathbf{y})_{\mathbf{i}}
\mathbf{A}^{\prime} &= \mathbf{A}_{\mathbf{i},\mathbf{i}},
where nodes are dropped based on a learnable projection score
:math:`\mathbf{p}`.
Args:
in_channels (int): Size of each input sample.
ratio (float): Graph pooling ratio, which is used to compute
:math:`k = \lceil \mathrm{ratio} \cdot N \rceil`.
This value is ignored if min_score is not None.
(default: :obj:`0.5`)
min_score (float, optional): Minimal node score :math:`\tilde{\alpha}`
which is used to compute indices of pooled nodes
:math:`\mathbf{i} = \mathbf{y}_i > \tilde{\alpha}`.
When this value is not :obj:`None`, the :obj:`ratio` argument is
ignored. (default: :obj:`None`)
multiplier (float, optional): Coefficient by which features gets
multiplied after pooling. This can be useful for large graphs and
when :obj:`min_score` is used. (default: :obj:`1`)
nonlinearity (torch.nn.functional, optional): The nonlinearity to use.
(default: :obj:`torch.tanh`)
"""
def __init__(self, in_channels, ratio=0.5, min_score=None, multiplier=1,
nonlinearity=torch.tanh):
super(TopKPooling, self).__init__()
self.in_channels = in_channels
self.ratio = ratio
self.min_score = min_score
self.multiplier = multiplier
self.nonlinearity = nonlinearity
self.weight = Parameter(torch.Tensor(1, in_channels))
self.weight.fast=None
self.reset_parameters()
def reset_parameters(self):
size = self.in_channels
uniform(size, self.weight)
def forward(self, x, edge_index, edge_attr=None, batch=None, attn=None):
""""""
if batch is None:
batch = edge_index.new_zeros(x.size(0))
attn = x if attn is None else attn
attn = attn.unsqueeze(-1) if attn.dim() == 1 else attn
if self.weight.fast is not None:
score = (attn * self.weight.fast).sum(dim=-1)
else:
score = (attn * self.weight).sum(dim=-1)
if self.min_score is None:
if self.weight.fast is not None:
score = self.nonlinearity(score / self.weight.fast.norm(p=2, dim=-1))
else:
score = self.nonlinearity(score / self.weight.norm(p=2, dim=-1))
else:
score = softmax(score, batch)
perm = topk(score, self.ratio, batch, self.min_score)
x = x[perm] * score[perm].view(-1, 1)
x = self.multiplier * x if self.multiplier != 1 else x
batch = batch[perm]
edge_index, edge_attr = filter_adj(edge_index, edge_attr, perm,
num_nodes=score.size(0))
return x, edge_index, edge_attr, batch, perm, score[perm]
def __repr__(self):
return '{}({}, {}={}, multiplier={})'.format(
self.__class__.__name__, self.in_channels,
'ratio' if self.min_score is None else 'min_score',
self.ratio if self.min_score is None else self.min_score,
self.multiplier)
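# A small self-contained check (a sketch, not from the original repo) of the pooling
# forward pass on a toy two-graph batch; assumes torch_scatter/torch_geometric are installed.
if __name__ == '__main__':
    pool = TopKPooling(in_channels=3, ratio=0.5)
    x = torch.randn(6, 3)                                        # 6 nodes, 3 features
    edge_index = torch.tensor([[0, 1, 3, 4], [1, 2, 4, 5]])      # edges stay inside each graph
    batch = torch.tensor([0, 0, 0, 1, 1, 1])                     # two graphs of 3 nodes each
    x, edge_index, edge_attr, batch, perm, score = pool(x, edge_index, None, batch)
    print(x.shape, batch.tolist(), perm.tolist())                # ceil(0.5 * 3) = 2 nodes kept per graph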
import torch
from torch.nn import Parameter
from torch_geometric.nn.conv import MessagePassing
from models.layersFw import LinearFw
from utils import uniform
class GraphConv(MessagePassing):
r"""The graph neural network operator from the `"Weisfeiler and Leman Go
Neural: Higher-order Graph Neural Networks"
<https://arxiv.org/abs/1810.02244>`_ paper
.. math::
\mathbf{x}^{\prime}_i = \mathbf{\Theta}_1 \mathbf{x}_i +
\sum_{j \in \mathcal{N}(i)} \mathbf{\Theta}_2 \mathbf{x}_j.
Args:
in_channels (int): Size of each input sample.
out_channels (int): Size of each output sample.
aggr (string, optional): The aggregation scheme to use
(:obj:`"add"`, :obj:`"mean"`, :obj:`"max"`).
(default: :obj:`"add"`)
bias (bool, optional): If set to :obj:`False`, the layer will not learn
an additive bias. (default: :obj:`True`)
**kwargs (optional): Additional arguments of
:class:`torch_geometric.nn.conv.MessagePassing`.
"""
def __init__(self, in_channels, out_channels, aggr='add', bias=True,
**kwargs):
super(GraphConv, self).__init__(aggr=aggr, **kwargs)
self.in_channels = in_channels
self.out_channels = out_channels
self.weight = Parameter(torch.Tensor(in_channels, out_channels))
self.lin = LinearFw(in_channels, out_channels)
self.weight.fast=None
self.lin.weight.fast=None
self.lin.bias.fast=None
self.reset_parameters()
def reset_parameters(self):
uniform(self.in_channels, self.weight)
self.lin.reset_parameters()
def forward(self, x, edge_index, edge_weight=None, size=None):
""""""
if self.weight.fast is not None:
h = torch.matmul(x, self.weight.fast)
else:
h = torch.matmul(x, self.weight)
return self.propagate(edge_index, size=size, x=x, h=h,
edge_weight=edge_weight)
def message(self, h_j, edge_weight):
return h_j if edge_weight is None else edge_weight.view(-1, 1) * h_j
def update(self, aggr_out, x):
# tmp=None
# if self.lin.weight.fast is not None:
# tmp=torch.matmul(x,self.lin.weight.fast)+self.lin.bias.fast
# else:
# tmp=self.lin(x)
return aggr_out + self.lin(x)
def __repr__(self):
return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
self.out_channels)
import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from torch.utils.data import TensorDataset, DataLoader
from torch import optim
import numpy as np
from copy import deepcopy
class Meta(nn.Module):
"""
Meta Learner
"""
def __init__(self,model, config):
"""
:param args:
"""
super(Meta, self).__init__()
self.update_lr = config["inner_lr"]
self.meta_lr = config["lr"]
self.n_way = config["train_way"]
self.k_spt = config["train_shot"]
self.k_qry = config["train_query"]
# self.task_num = args.task_num
self.update_step = config["step"]
self.clip=config["grad_clip"]
self.update_step_test = config["step_test"]
self.net = model
self.meta_optim = optim.Adam(self.net.parameters(), lr=self.meta_lr,weight_decay=config["weight_decay"])
self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.meta_optim, mode='min',
factor=0.2, patience=config["patience"],
verbose=True, min_lr=1e-6)
self.task_index=1
self.update_flag=config["batch_per_episodes"]
def forward(self, support_data, query_data):
"""
:param x_spt: [b, setsz, c_, h, w]
:param y_spt: [b, setsz]
:param x_qry: [b, querysz, c_, h, w]
:param y_qry: [b, querysz]
:return:
"""
(support_nodes, support_edge_index, support_graph_indicator, support_label) = support_data
(query_nodes, query_edge_index, query_graph_indicator, query_label) = query_data
task_num = support_nodes.size()[0]
querysz = query_label.size()[1]
losses_q = [0 for _ in range(self.update_step)] # losses_q[i] is the loss on step i
corrects = [0 for _ in range(self.update_step)]
for i in range(task_num):
            fast_parameters = list(self.parameters())  # the first inner-loop gradient is taken w.r.t. the original weights
for weight in self.parameters():
weight.fast = None
# self.zero_grad()
for k in range(0, self.update_step):
# 1. run the i-th task and compute loss for k=1~K-1
logits, _ = self.net([support_nodes[i], support_edge_index[i], support_graph_indicator[i]])
loss = F.nll_loss(logits, support_label[i])
                # build the graph so that gradients of gradients can flow (second-order MAML)
grad = torch.autograd.grad(loss, fast_parameters,create_graph=True)
fast_parameters = []
for index, weight in enumerate(self.parameters()):
# for usage of weight.fast, please see Linear_fw, Conv_fw in backbone.py
# if grad[k] is None:
# fast_parameters.append(weight.fast)
# continue
if weight.fast is None:
weight.fast = weight - self.update_lr * grad[index] # create weight.fast
else:
# create an updated weight.fast,
# note the '-' is not merely minus value, but to create a new weight.fast
weight.fast = weight.fast - self.update_lr * grad[index]
                    # later gradients are taken w.r.t. the newest fast weights, while the graph keeps links to the older weight.fast tensors
fast_parameters.append(weight.fast)
logits_q, _ = self.net([query_nodes[i], query_edge_index[i], query_graph_indicator[i]])
                # loss_q is recomputed every step; losses_q[k] accumulates the query loss of step k across tasks
loss_q = F.nll_loss(logits_q, query_label[i])
losses_q[k] += loss_q
with torch.no_grad():
pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
                    correct = torch.eq(pred_q, query_label[i]).sum().item()  # .item() converts to a python scalar
corrects[k] = corrects[k] + correct
# end of all tasks
# sum over all losses on query set across all tasks
loss_q = losses_q[-1] / task_num
# print("loss",loss_q.item())
# optimize theta parameters
loss_q.backward()
if self.task_index==self.update_flag:
            if self.clip > 0.1:  # 0.1 acts as a threshold for whether to clip gradients
torch.nn.utils.clip_grad_norm_(self.parameters(), self.clip)
self.meta_optim.step()
self.meta_optim.zero_grad()
self.task_index=1
else:
self.task_index=self.task_index+1
# for p in self.net.parameters():
# print(torch.norm(p.grad).item())
accs = 100*np.array(corrects) / (querysz * task_num)
return accs,loss_q.item()
def finetunning(self, support_data, query_data):
"""
:param x_spt: [setsz, c_, h, w]
:param y_spt: [setsz]
:param x_qry: [querysz, c_, h, w]
:param y_qry: [querysz]
:return:
"""
(support_nodes, support_edge_index, support_graph_indicator, support_label) = support_data
(query_nodes, query_edge_index, query_graph_indicator, query_label) = query_data
task_num = support_nodes.size()[0]
querysz = query_label.size()[1]
# losses_q = [0 for _ in range(self.update_step_test)] # losses_q[i] is the loss on step i
corrects = [0 for _ in range(self.update_step_test)]
for i in range(task_num):
            fast_parameters = list(self.parameters())  # the first inner-loop gradient is taken w.r.t. the original weights
for weight in self.parameters():
weight.fast = None
self.zero_grad()
for k in range(0, self.update_step_test):
# 1. run the i-th task and compute loss for k=1~K-1
logits, _ = self.net([support_nodes[i], support_edge_index[i], support_graph_indicator[i]])
loss = F.nll_loss(logits, support_label[i])
                # build the graph so that gradients of gradients can flow (second-order MAML)
grad = torch.autograd.grad(loss, fast_parameters, create_graph=True)
fast_parameters = []
for index, weight in enumerate(self.parameters()):
# for usage of weight.fast, please see Linear_fw, Conv_fw in backbone.py
if weight.fast is None:
weight.fast = weight - self.update_lr * grad[index] # create weight.fast
else:
# create an updated weight.fast,
# note the '-' is not merely minus value, but to create a new weight.fast
weight.fast = weight.fast - self.update_lr * grad[index]
                    # later gradients are taken w.r.t. the newest fast weights, while the graph keeps links to the older weight.fast tensors
fast_parameters.append(weight.fast)
# print('add')
logits_q, _= self.net([query_nodes[i], query_edge_index[i], query_graph_indicator[i]])
# # loss_q will be overwritten and just keep the loss_q on last update step.
# loss_q = F.nll_loss(logits_q, query_label[i])
#
# losses_q[k] += loss_q
with torch.no_grad():
pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
                    correct = torch.eq(pred_q, query_label[i]).sum().item()  # .item() converts to a python scalar
corrects[k] = corrects[k] + correct
        accs = 100 * np.array(corrects) / (querysz * task_num)
return accs
def adapt_meta_learning_rate(self,loss):
self.scheduler.step(loss)
def get_meta_learning_rate(self):
epoch_learning_rate=[]
for param_group in self.meta_optim.param_groups:
epoch_learning_rate.append(param_group['lr'])
return epoch_learning_rate[0]
if __name__ == '__main__':
pass
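# Shape conventions (inferred from forward/finetunning above; a sketch, not a spec):
# support_data and query_data are 4-tuples of batched tensors with a leading task dimension:
#   nodes           : [task_num, num_nodes, num_features]
#   edge_index      : [task_num, num_edges, 2]    (transposed to 2 x num_edges inside the model)
#   graph_indicator : [task_num, num_nodes]       (episode-graph id of every node)
#   label           : [task_num, n_way * n_shot]  (n_way * n_query for the query set)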
from models.graph_convfw import GraphConv
from torch_geometric.nn.pool.topk_pool import topk, filter_adj
from torch_geometric.utils import softmax
class SAGPooling(torch.nn.Module):
r"""The self-attention pooling operator from the `"Self-Attention Graph
Pooling" <https://arxiv.org/abs/1904.08082>`_ and `"Understanding
......
import torch
import torch.nn.functional as F
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp
# from models.TopKPoolfw import TopKPooling
# from torch_geometric.nn import GCNConv
from models.sage_conv_fw import SAGEConv
from models.sag_poolfw import SAGPooling
from models.layersFw import LinearFw
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import add_self_loops, remove_self_loops
from torch_scatter import scatter_add
class NodeInformationScore(MessagePassing):
def __init__(self, improved=False, cached=False, **kwargs):
super(NodeInformationScore, self).__init__(aggr='add', **kwargs)
self.improved = improved
self.cached = cached
self.cached_result = None
self.cached_num_edges = None
@staticmethod
def norm(edge_index, num_nodes, edge_weight, dtype=None):
edge_index, _ = remove_self_loops(edge_index)
if edge_weight is None:
edge_weight = torch.ones((edge_index.size(1),), dtype=dtype, device=edge_index.device)
row, col = edge_index
deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
deg_inv_sqrt = deg.pow(-0.5)
deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
edge_index, edge_weight = add_self_loops(edge_index, edge_weight, 0, num_nodes)
row, col = edge_index
expand_deg = torch.zeros((edge_weight.size(0),), dtype=dtype, device=edge_index.device)
expand_deg[-num_nodes:] = torch.ones((num_nodes,), dtype=dtype, device=edge_index.device)
return edge_index, expand_deg - deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]
def forward(self, x, edge_index, edge_weight=None):
if self.cached and self.cached_result is not None:
if edge_index.size(1) != self.cached_num_edges:
raise RuntimeError(
'Cached {} number of edges, but found {}'.format(self.cached_num_edges, edge_index.size(1)))
if not self.cached or self.cached_result is None:
self.cached_num_edges = edge_index.size(1)
edge_index, norm = self.norm(edge_index, x.size(0), edge_weight, x.dtype)
self.cached_result = edge_index, norm
edge_index, norm = self.cached_result
return self.propagate(edge_index, x=x, norm=norm)
def message(self, x_j, norm):
return norm.view(-1, 1) * x_j
def update(self, aggr_out):
return aggr_out
class Model(torch.nn.Module):
def __init__(self, config):
super(Model, self).__init__()
self.config = config
self.num_features = config["num_features"]
self.nhid = config["nhid"]
self.num_classes = config["train_way"]
# if self.num_classes==2:
# self.num_classes=1
self.pooling_ratio = config["pooling_ratio"]
self.dropout_ratio = config["dropout_ratio"]
self.conv1 = SAGEConv(self.num_features, self.nhid)
self.conv2 = SAGEConv(self.nhid, self.nhid)
self.conv3 = SAGEConv(self.nhid, self.nhid)
# self.conv4 = SAGEConv(self.nhid, self.nhid)
# self.conv5 = SAGEConv(self.nhid, self.nhid)
self.calc_information_score = NodeInformationScore()
# self.pool1 = HGPSLPoolFw(self.nhid, self.pooling_ratio, self.sample, self.sparse, self.sl, self.lamb)
self.pool1 = SAGPooling(self.nhid, self.pooling_ratio)
self.pool2 = SAGPooling(self.nhid, self.pooling_ratio)
self.pool3 = SAGPooling(self.nhid, self.pooling_ratio)
# self.pool4 = SAGPooling(self.nhid, self.pooling_ratio)
# self.pool5 = SAGPooling(self.nhid, self.pooling_ratio)
self.lin1 = LinearFw(self.nhid*2 , self.nhid)
self.lin2 = LinearFw(self.nhid, self.nhid // 2)
self.lin3 = LinearFw(self.nhid // 2, self.num_classes)
# self.bn1 = torch.nn.BatchNorm1d(self.nhid,affine=False)
# self.bn2 = torch.nn.BatchNorm1d(self.nhid // 2,affine=False)
# self.bn3 = torch.nn.BatchNorm1d(self.num_classes,affine=False)
self.relu=F.leaky_relu
def forward(self, data):
x, edge_index, batch = data
edge_attr = None
edge_index=edge_index.transpose(0,1)
x = self.relu(self.conv1(x, edge_index, edge_attr),negative_slope=0.1)
x, edge_index, edge_attr, batch, _, _ = self.pool1(x, edge_index, None, batch)
# x, edge_index, edge_attr, batch, _ = self.pool1(x, edge_index, edge_attr, batch)
x1 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)
x =self.relu(self.conv2(x, edge_index, edge_attr),negative_slope=0.1)
x, edge_index, edge_attr, batch, _, _ = self.pool2(x, edge_index, None, batch)
# x, edge_index, edge_attr, batch, _ = self.pool2(x, edge_index, edge_attr, batch)
x2 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)
#
x = self.relu(self.conv3(x, edge_index, edge_attr), negative_slope=0.1)
x, edge_index, edge_attr, batch, _, _ = self.pool3(x, edge_index, None, batch)
# x, edge_index, edge_attr, batch, _ = self.pool2(x, edge_index, edge_attr, batch)
x3 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)
# x = self.relu(self.conv4(x, edge_index, edge_attr), negative_slope=0.1)
# x, edge_index, edge_attr, batch, _, _ = self.pool4(x, edge_index, None, batch)
# x, edge_index, edge_attr, batch, _ = self.pool2(x, edge_index, edge_attr, batch)
# x4 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)
# x = self.relu(self.conv5(x, edge_index, edge_attr),negative_slope=0.1)
# x, edge_index, edge_attr, batch, _, _ = self.pool5(x, edge_index, None, batch)
# x, edge_index, edge_attr, batch, _ = self.pool3(x, edge_index, edge_attr, batch)
x_information_score = self.calc_information_score(x, edge_index)
score = torch.sum(torch.abs(x_information_score), dim=1)
# x5 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)
x = self.relu(x1,negative_slope=0.1) + self.relu(x2,negative_slope=0.1)+self.relu(x3,negative_slope=0.1)
# x = self.relu(x1,negative_slope=0.1) + self.relu(x2,negative_slope=0.1)
# + self.relu(x3,negative_slope=0.1)+self.relu(x4,negative_slope=0.1)+self.relu(x5,negative_slope=0.1)
# x = F.relu(x1)
graph_emb=x
# x = self.lin1(x)
x = self.relu(self.lin1(x),negative_slope=0.1)
# x=self.bn1(x)
# x = F.dropout(x, p=self.dropout_ratio, training=self.training)
x = self.relu(self.lin2(x),negative_slope=0.1)
# x=self.bn2(x)
# x = F.dropout(x, p=self.dropout_ratio, training=self.training)
x=self.lin3(x)
# x = F.log_softmax(x, dim=-1)
return x,score.mean(),graph_emb
import argparse
import time
import torch
from models.sage4maml_model import Model
import ssl
import os
from data.dataset1 import GraphDataSet,FewShotDataloader
from models.meta_ada import Meta
from tqdm import tqdm
import numpy as np
from utils import *
def get_dataset(dataset):
train_data=None
val_data=None
test_data=None
val_data = GraphDataSet(phase="val", dataset_name=dataset)
train_data=GraphDataSet(phase="train",dataset_name=dataset)
test_data=GraphDataSet(phase="test",dataset_name=dataset)
return train_data,val_data,test_data
parser = argparse.ArgumentParser()
parser.add_argument('--model_dir', type=str, default='./experiments/log/COIL-DEL-06-23/log_06-23-10-15')
parser.add_argument('--gpu', type=str, default='cuda:0')
parser.add_argument('--root_directory', type=str, default='./experiments')
args = parser.parse_args()
print(args)
config = unserialize(os.path.join(args.model_dir,'config.json'))
config["device"]=args.gpu
print("dataset", config["dataset"], "{}way{}shot".format(config["test_way"], config["val_shot"]),"gpu",config["device"])
_,_, test_set = get_dataset(config["dataset"])
config["num_features"] = test_set.num_features
config["val_episode"]=400
test_loader = FewShotDataloader(test_set,
n_way=config["test_way"], # number of novel categories.
n_shot=config["val_shot"], # number of training examples per novel category.
n_query=config["val_query"], # number of test examples for all the novel categories.
batch_size=1, # number of training episodes per batch.
num_workers=4,
epoch_size=config["val_episode"], # number of batches per epoch.
)
model = Model(config)
meta_model=Meta(model,config)
saved_models = torch.load(os.path.join(args.model_dir, 'best_model.pth'))
meta_model.load_state_dict(saved_models['embedding'])
model=meta_model.net
if config["double"]==True:
model=model.double()
meta_model=meta_model.double()
model=model.to(config["device"])
meta_model=meta_model.to(config["device"])
pa=get_para_num(meta_model)
print(pa)
def run():
device=config["device"]
t = time.time()
max_val_acc=0
val_accs=[]
# validation_stage
meta_model.eval()
for i, data in enumerate(tqdm(test_loader(1)), 1):
support_data, query_data = data
if config["double"]==True:
support_data[0] = support_data[0].double()
query_data[0] = query_data[0].double()
support_data = [item.to(device) for item in support_data]
query_data = [item.to(device) for item in query_data]
accs,step,stop_gates,scores,query_losses= meta_model.finetunning(support_data, query_data)
val_accs.append(accs[step])
if i % 100 == 0:
print("\n{}th test".format(i))
print("stop_prob", len(stop_gates), [stop_gate for stop_gate in stop_gates])
print("scores", len(scores), [score for score in scores])
print("stop_prob", len(query_losses), [query_loss for query_loss in query_losses])
print("accs", len(accs), [accs[i] for i in range(0,step+1)])
print("query_accs{:.2f}".format(np.mean(val_accs)))
val_acc_avg=np.mean(val_accs)
val_acc_ci95 = 1.96 * np.std(np.array(val_accs)) / np.sqrt(config["val_episode"])
print('\nacc_val:{:.2f} ±{:.2f},time: {:.2f}s'.format(val_acc_avg,val_acc_ci95,time.time() - t))
return None
if __name__ == '__main__':
# Model training
best_model = run()
import os
import time
import torch
import numpy as np
import json
import _pickle
import math
def get_para_num(net):
total_num = sum(p.numel() for p in net.parameters())
trainable_num = sum(p.numel() for p in net.parameters() if p.requires_grad)
return {'Total': total_num, 'Trainable': trainable_num}
def setup_seed(seed=0):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
# np.seed(seed)
torch.backends.cudnn.deterministic = True
def serialize(obj, path, in_json=False):
if isinstance(obj, np.ndarray):
np.save(path, obj)
elif in_json:
with open(path, "w") as file:
json.dump(obj, file, indent=2)
else:
with open(path, 'wb') as file:
_pickle.dump(obj, file)
def unserialize(path):
suffix = os.path.basename(path).split(".")[-1]
if suffix == "npy":
return np.load(path)
elif suffix == "json":
with open(path, "r") as file:
return json.load(file)
else:
with open(path, 'rb') as file:
return _pickle.load(file)
def set_gpu(x):
os.environ['CUDA_VISIBLE_DEVICES'] = x
print('using gpu:', x)
def check_dir(path):
'''
Create directory if it does not exist.
path: Path of directory.
'''
    if not os.path.exists(path):
        os.makedirs(path)  # makedirs so that nested log directories can be created in one call
def uniform(size, tensor):
bound = 1.0 / math.sqrt(size)
if tensor is not None:
tensor.data.uniform_(-bound, bound)
def count_accuracy(logits, label):
pred = torch.argmax(logits, dim=1).view(-1)
label = label.view(-1)
accuracy = 100 * pred.eq(label).float().mean().item()
return accuracy
class Timer():
def __init__(self):
self.o = time.time()
def measure(self, p=1):
x = (time.time() - self.o) / float(p)
x = int(x)
if x >= 3600:
return '{:.1f}h'.format(x / 3600)
if x >= 60:
return '{}m'.format(round(x / 60))
return '{}s'.format(x)
import datetime
def log(log_file_path, string):
'''
Write one line of log into screen and file.
log_file_path: Path of log file.
string: String to write in log file.
'''
time=datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
with open(log_file_path, 'a+') as f:
f.write(string+" "+time + '\n')
f.flush()
print(string)
def store(st,writer,epoch=None):
update_step=len(st["loss"])
for step in range(update_step):
writer.add_scalars("l_s_s",{"loss":st["loss"][step],
"stop_gate":st["stop_gates"][step],
"scores":st["scores"][step]
},step)
for item in ["grads","input_gates","forget_gates"]:
for step in range(update_step):
d={}
for index,v in enumerate(st[item][step]):
d["layer"+str(index)]=v
writer.add_scalars(item, d, step)