Unverified Commit 6c5848bf authored by NingMa, committed by GitHub

Add files via upload

parent 3cc9f440
{
  "model_type": "sage",
  "nhid": 128,
  "pooling_ratio": 0.5,
  "dropout_ratio": 0.3,
  "double": true,
  "device": "cuda:0",
  "dataset": "TRIANGLES",
  "outer_lr": 0.001,
  "inner_lr": 0.01,
  "stop_lr": 0.0001,
  "weight_decay": 1e-5,
  "max_step": 15,
  "min_step": 5,
  "step_test": 15,
  "flexible_step": true,
  "step_penalty": 0.001,
  "use_score": true,
  "use_grad": false,
  "use_loss": true,
  "train_shot": 10,
  "val_shot": 10,
  "train_query": 15,
  "val_query": 15,
  "train_way": 3,
  "test_way": 3,
  "val_episode": 200,
  "train_episode": 200,
  "batch_per_episodes": 5,
  "epochs": 500,
  "patience": 35,
  "grad_clip": 5,
  "save": true
}
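The JSON above is the hyperparameter configuration for the TRIANGLES experiments: episode sizes (train_way, train_shot, train_query), inner-/outer-loop learning rates, and the adaptive step range (min_step, max_step, step_penalty). A minimal sketch of loading such a config into attribute-style arguments follows; the file name config.json and the Namespace wrapper are illustrative assumptions, not part of this commit.

import json
from argparse import Namespace

def load_config(path="config.json"):
    # Read the JSON hyperparameters and expose the keys as attributes.
    # The default path and the Namespace wrapping are assumptions for illustration.
    with open(path) as f:
        return Namespace(**json.load(f))

# Example: args = load_config(); print(args.dataset, args.train_way, args.train_shot)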
import torch
from models.graph_convfw import GraphConv
from torch_geometric.nn.pool.topk_pool import topk, filter_adj
from torch_geometric.utils import softmax
class SAGPooling(torch.nn.Module):
r"""The self-attention pooling operator from the `"Self-Attention Graph
Pooling" <https://arxiv.org/abs/1904.08082>`_ and `"Understanding
Attention and Generalization in Graph Neural Networks"
<https://arxiv.org/abs/1905.02850>`_ papers
if :obj:`min_score` :math:`\tilde{\alpha}` is :obj:`None`:
.. math::
\mathbf{y} &= \textrm{GNN}(\mathbf{X}, \mathbf{A})
\mathbf{i} &= \mathrm{top}_k(\mathbf{y})
\mathbf{X}^{\prime} &= (\mathbf{X} \odot
\mathrm{tanh}(\mathbf{y}))_{\mathbf{i}}
\mathbf{A}^{\prime} &= \mathbf{A}_{\mathbf{i},\mathbf{i}}
if :obj:`min_score` :math:`\tilde{\alpha}` is a value in [0, 1]:
.. math::
\mathbf{y} &= \mathrm{softmax}(\textrm{GNN}(\mathbf{X},\mathbf{A}))
\mathbf{i} &= \mathbf{y}_i > \tilde{\alpha}
\mathbf{X}^{\prime} &= (\mathbf{X} \odot \mathbf{y})_{\mathbf{i}}
\mathbf{A}^{\prime} &= \mathbf{A}_{\mathbf{i},\mathbf{i}},
where nodes are dropped based on a learnable projection score
:math:`\mathbf{p}`.
Projection scores are learned by a graph neural network layer.
Args:
in_channels (int): Size of each input sample.
ratio (float): Graph pooling ratio, which is used to compute
:math:`k = \lceil \mathrm{ratio} \cdot N \rceil`.
This value is ignored if :obj:`min_score` is not :obj:`None`.
(default: :obj:`0.5`)
GNN (torch.nn.Module, optional): A graph neural network layer for
calculating projection scores (one of
:class:`torch_geometric.nn.conv.GraphConv`,
:class:`torch_geometric.nn.conv.GCNConv`,
:class:`torch_geometric.nn.conv.GATConv` or
:class:`torch_geometric.nn.conv.SAGEConv`). (default:
:class:`torch_geometric.nn.conv.GraphConv`)
min_score (float, optional): Minimal node score :math:`\tilde{\alpha}`
which is used to compute indices of pooled nodes
:math:`\mathbf{i} = \mathbf{y}_i > \tilde{\alpha}`.
When this value is not :obj:`None`, the :obj:`ratio` argument is
ignored. (default: :obj:`None`)
multiplier (float, optional): Coefficient by which features get
multiplied after pooling. This can be useful for large graphs and
when :obj:`min_score` is used. (default: :obj:`1`)
nonlinearity (torch.nn.functional, optional): The nonlinearity to use.
(default: :obj:`torch.tanh`)
**kwargs (optional): Additional parameters for initializing the graph
neural network layer.
"""
def __init__(self, in_channels, ratio=0.5, GNN=GraphConv, min_score=None,
multiplier=1, nonlinearity=torch.tanh, **kwargs):
super(SAGPooling, self).__init__()
self.in_channels = in_channels
self.ratio = ratio
self.gnn = GNN(in_channels, 1, **kwargs)
self.min_score = min_score
self.multiplier = multiplier
self.nonlinearity = nonlinearity
self.reset_parameters()
def reset_parameters(self):
self.gnn.reset_parameters()
def forward(self, x, edge_index, edge_attr=None, batch=None, attn=None):
""""""
if batch is None:
batch = edge_index.new_zeros(x.size(0))
attn = x if attn is None else attn
attn = attn.unsqueeze(-1) if attn.dim() == 1 else attn
score = self.gnn(attn, edge_index).view(-1)
if self.min_score is None:
score = self.nonlinearity(score)
else:
score = softmax(score, batch)
perm = topk(score, self.ratio, batch, self.min_score)
x = x[perm] * score[perm].view(-1, 1)
x = self.multiplier * x if self.multiplier != 1 else x
batch = batch[perm]
edge_index, edge_attr = filter_adj(edge_index, edge_attr, perm,
num_nodes=score.size(0))
return x, edge_index, edge_attr, batch, perm, score[perm]
def __repr__(self):
return '{}({}, {}, {}={}, multiplier={})'.format(
self.__class__.__name__, self.gnn.__class__.__name__,
self.in_channels,
'ratio' if self.min_score is None else 'min_score',
self.ratio if self.min_score is None else self.min_score,
self.multiplier)
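A minimal usage sketch for the SAGPooling layer above on a toy graph; it assumes the class and its GraphConv dependency (models.graph_convfw) are importable from this repository, and the toy tensors below are illustrative only.

def _sag_pooling_demo():
    # Toy graph: 6 nodes with 16-dim features on a directed 6-cycle (illustrative data).
    x = torch.randn(6, 16)
    edge_index = torch.tensor([[0, 1, 2, 3, 4, 5],
                               [1, 2, 3, 4, 5, 0]])
    pool = SAGPooling(in_channels=16, ratio=0.5)
    # Keeps ceil(0.5 * 6) = 3 highest-scoring nodes; kept features are scaled by tanh(score).
    x_out, edge_index_out, _, batch_out, perm, score = pool(x, edge_index)
    return x_out.shape, perm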
import torch
import torch.nn.functional as F
from torch.nn import Parameter
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import add_remaining_self_loops
from utils import uniform
class SAGEConv(MessagePassing):
r"""The GraphSAGE operator from the `"Inductive Representation Learning on
Large Graphs" <https://arxiv.org/abs/1706.02216>`_ paper
.. math::
\mathbf{\hat{x}}_i &= \mathbf{\Theta} \cdot
\mathrm{mean}_{j \in \mathcal{N}(i) \cup \{ i \}}(\mathbf{x}_j)
\mathbf{x}^{\prime}_i &= \frac{\mathbf{\hat{x}}_i}
{\| \mathbf{\hat{x}}_i \|_2}.
Args:
in_channels (int): Size of each input sample.
out_channels (int): Size of each output sample.
normalize (bool, optional): If set to :obj:`True`, output features
will be :math:`\ell_2`-normalized. (default: :obj:`False`)
bias (bool, optional): If set to :obj:`False`, the layer will not learn
an additive bias. (default: :obj:`True`)
**kwargs (optional): Additional arguments of
:class:`torch_geometric.nn.conv.MessagePassing`.
"""
def __init__(self, in_channels, out_channels, normalize=False, bias=True,
**kwargs):
super(SAGEConv, self).__init__(aggr='mean', **kwargs)
self.in_channels = in_channels
self.out_channels = out_channels
self.normalize = normalize
self.weight = Parameter(torch.Tensor(self.in_channels, out_channels))
self.weight.fast = None
if bias:
self.bias = Parameter(torch.Tensor(out_channels))
self.bias.fast = None
else:
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self):
uniform(self.in_channels, self.weight)
uniform(self.in_channels, self.bias)
def forward(self, x, edge_index, edge_weight=None, size=None):
""""""
if size is None and torch.is_tensor(x):
edge_index, edge_weight = add_remaining_self_loops(
edge_index, edge_weight, 1, x.size(0))
if self.weight.fast is not None:
weight = self.weight.fast
else:
weight = self.weight
if torch.is_tensor(x):
x = torch.matmul(x, weight)
else:
x = (None if x[0] is None else torch.matmul(x[0], weight),
None if x[1] is None else torch.matmul(x[1], weight))
return self.propagate(edge_index, size=size, x=x,
edge_weight=edge_weight)
def message(self, x_j, edge_weight):
return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j
def update(self, aggr_out):
if self.bias is not None:
if self.bias.fast is not None:
aggr_out = aggr_out + self.bias.fast
else:
aggr_out = aggr_out + self.bias
if self.normalize:
aggr_out = F.normalize(aggr_out, p=2, dim=-1)
return aggr_out
def __repr__(self):
return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
self.out_channels)
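The .fast attributes attached to weight and bias above allow a meta-learner to run a MAML-style inner loop: adapted weights are stored in .fast and picked up by forward()/update() without overwriting the original parameters. A minimal sketch of that pattern follows; the layer instance, data, squared-error loss, and learning rate are illustrative assumptions, not code from this commit.

def _inner_adapt_demo(conv, x, edge_index, target, inner_lr=0.01):
    # First pass with the original ("slow") parameters (assumes conv.weight.fast is None).
    out = conv(x, edge_index)
    loss = ((out - target) ** 2).mean()
    grads = torch.autograd.grad(loss, [conv.weight, conv.bias], create_graph=True)
    # Store the adapted weights in .fast; subsequent forward passes prefer them when present.
    conv.weight.fast = conv.weight - inner_lr * grads[0]
    conv.bias.fast = conv.bias - inner_lr * grads[1]
    # Second pass uses the fast weights; the outer loop can still backprop to the originals.
    return conv(x, edge_index)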
"""
An original implementation of sparsemax (Martins & Astudillo, 2016) is available at
https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/modules/sparse_activations.py.
See `From Softmax to Sparsemax: A Sparse Model of Attention and Multi-Label Classification, ICML 2016`
for a detailed description.
We make some modifications so that it works in scatter-operation scenarios, e.g., computing the
sparse transform group-wise according to batch indicators.
Usage:
>> x = torch.tensor([ 1.7301, 0.6792, -1.0565, 1.6614, -0.3196, -0.7790, -0.3877, -0.4943,
0.1831, -0.0061])
>> batch = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1, 1, 1])
>> sparse_attention = Sparsemax()
>> res = sparse_attention(x, batch)
>> print(res)
tensor([0.5343, 0.0000, 0.0000, 0.4657, 0.0612, 0.0000, 0.0000, 0.0000, 0.5640,
0.3748])
"""
import torch
import torch.nn as nn
from torch.autograd import Function
from torch_scatter import scatter_add, scatter_max
def scatter_sort(x, batch, fill_value=-1e16):
    """Sort x within each group given by batch (descending) and return the
    sorted values together with their per-group cumulative sums."""
num_nodes = scatter_add(batch.new_ones(x.size(0)), batch, dim=0)
batch_size, max_num_nodes = num_nodes.size(0), num_nodes.max().item()
cum_num_nodes = torch.cat([num_nodes.new_zeros(1), num_nodes.cumsum(dim=0)[:-1]], dim=0)
index = torch.arange(batch.size(0), dtype=torch.long, device=x.device)
index = (index - cum_num_nodes[batch]) + (batch * max_num_nodes)
dense_x = x.new_full((batch_size * max_num_nodes,), fill_value)
dense_x[index] = x
dense_x = dense_x.view(batch_size, max_num_nodes)
sorted_x, _ = dense_x.sort(dim=-1, descending=True)
cumsum_sorted_x = sorted_x.cumsum(dim=-1)
cumsum_sorted_x = cumsum_sorted_x.view(-1)
sorted_x = sorted_x.view(-1)
filled_index = sorted_x != fill_value
sorted_x = sorted_x[filled_index]
cumsum_sorted_x = cumsum_sorted_x[filled_index]
return sorted_x, cumsum_sorted_x
def _make_ix_like(batch):
    """Return a vector that, for each group of size n_i, contains the indices
    1..n_i (concatenated over groups)."""
num_nodes = scatter_add(batch.new_ones(batch.size(0)), batch, dim=0)
idx = [torch.arange(1, i + 1, dtype=torch.long, device=batch.device) for i in num_nodes]
idx = torch.cat(idx, dim=0)
return idx
def _threshold_and_support(x, batch):
"""Sparsemax building block: compute the threshold
Args:
x: input tensor to apply the sparsemax
batch: group indicators
Returns:
the threshold value
"""
num_nodes = scatter_add(batch.new_ones(x.size(0)), batch, dim=0)
cum_num_nodes = torch.cat([num_nodes.new_zeros(1), num_nodes.cumsum(dim=0)[:-1]], dim=0)
sorted_input, input_cumsum = scatter_sort(x, batch)
input_cumsum = input_cumsum - 1.0
rhos = _make_ix_like(batch).to(x.dtype)
support = rhos * sorted_input > input_cumsum
support_size = scatter_add(support.to(batch.dtype), batch)
# mask invalid indices; e.g., if batch does not start from 0 or is not contiguous, it may produce a negative index
idx = support_size + cum_num_nodes - 1
mask = idx < 0
idx[mask] = 0
tau = input_cumsum.gather(0, idx)
tau /= support_size.to(x.dtype)
return tau, support_size
class SparsemaxFunction(Function):
@staticmethod
def forward(ctx, x, batch):
"""sparsemax: normalizing sparse transform
Args:
ctx: context object
x (Tensor): shape (N, )
batch: group indicator
Returns:
output (Tensor): same shape as input
"""
max_val, _ = scatter_max(x, batch)
x -= max_val[batch]
tau, supp_size = _threshold_and_support(x, batch)
output = torch.clamp(x - tau[batch], min=0)
ctx.save_for_backward(supp_size, output, batch)
return output
@staticmethod
def backward(ctx, grad_output):
supp_size, output, batch = ctx.saved_tensors
grad_input = grad_output.clone()
grad_input[output == 0] = 0
v_hat = scatter_add(grad_input, batch) / supp_size.to(output.dtype)
grad_input = torch.where(output != 0, grad_input - v_hat[batch], grad_input)
return grad_input, None
sparsemax = SparsemaxFunction.apply
class Sparsemax(nn.Module):
def __init__(self):
super(Sparsemax, self).__init__()
def forward(self, x, batch):
return sparsemax(x, batch)
if __name__ == '__main__':
sparse_attention = Sparsemax()
input_x = torch.tensor([1.7301, 0.6792, -1.0565, 1.6614, -0.3196, -0.7790, -0.3877, -0.4943, 0.1831, -0.0061])
input_batch = torch.cat([torch.zeros(4, dtype=torch.long), torch.ones(6, dtype=torch.long)], dim=0)
res = sparse_attention(input_x, input_batch)
print(res)
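As an illustrative addition (not in the original file), the defining property of sparsemax can be checked directly: outputs are non-negative and sum to 1 within each group given by batch.

def _check_sparsemax(x, batch):
    # Sanity check on a copy of x (SparsemaxFunction.forward modifies its input in place).
    out = Sparsemax()(x.clone(), batch)
    sums = scatter_add(out, batch, dim=0)
    nonneg = bool((out >= 0).all())
    return nonneg and torch.allclose(sums, torch.ones_like(sums))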