CS224W – Colab 3 MessagePassing实现GraphSAGE

Implement the GraphSAGE layer directly

1.GraphSage

对于一个具有编码

h

v

l

1

h_v^{l-1}

hvl1的中心节点

v

v

v,进行下一步状态更新的规则为:

h

v

(

l

)

=

W

l

h

v

(

l

1

)

+

W

r

A

G

G

(

{

h

u

(

l

1

)

,

u

N

(

v

)

}

)

h_v^{(l)} = W_lcdot h_v^{(l-1)} + W_r cdot AGG({h_u^{(l-1)}, forall u in N(v) })

hv(l)=Wlhv(l1)+WrAGG({hu(l1),uN(v)})

W

l

W_l

Wl

W

r

W_r

Wr为可学习的权重,

N

(

v

)

N(v)

N(v) 代表

v

v

v的邻接节点。

A

G

G

(

)

AGG(·)

AGG() 为消息聚合函数,当采用 mean aggregation时,有

A

G

G

(

{

h

u

(

l

1

)

,

u

N

(

v

)

}

)

=

1

N

(

v

)

u

N

(

v

)

h

u

(

l

1

)

AGG({h_u^{(l-1)}, forall u in N(v) }) = frac{1}{|N(v)|} sum_{uin N(v)} h_u^{(l-1)}

AGG({hu(l1),uN(v)})=N(v)1uN(v)hu(l1)

2.Implement

(1)实现方法

实现分三步,分别为

1)每一个邻居

u

u

u节点传递当前状态

u

l

1

u^{l-1}

ul1

2)中心节点

v

v

v 使用聚合函数聚合收到的消息,在GraphSage中为简单求平均;

3)中心节点使用聚合消息更新自己的状态,在GraphSage中为残差。

(2)实现步骤

pytorch提供了MessagePassing父类,我们借此可以简洁实现消息传递。

class GraphSage(MessagePassing):
    
    def __init__(self, in_channels, out_channels, normalize = True,
                 bias = False, **kwargs):  
        super(GraphSage, self).__init__(**kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.normalize = normalize

        self.lin_l=nn.Linear(in_features=in_channels, out_features=out_channels)
        self.lin_r=nn.Linear(in_features=in_channels, out_features=out_channels)
    
    def message(self, x_j):
        out = None
        out = self.lin_r(x_j)
        return out
   
    def aggregate(self, inputs, index, dim_size = None):
        out = None
        node_dim = self.node_dim
        out=torch_scatter.scatter(inputs, index, dim=node_dim,reduce='mean')
        return out
 
    def forward(self, x, edge_index, size = None):
        out=self.propagate(edge_index,x=(x,x))
        out=self.lin_l(x)+out
        if self.normalize:
            out=F.normalize(out)
        return out

message函数定义全局消息传递的内容。参数x_j描述所有消息传递关系中源节点的特征,形状为

[

E

,

d

]

[|mathcal{E}|, d]

[E,d]

(

i

,

j

)

E

(i, j) in mathcal{E}

(i,j)E.

aggregate函数定义了中心节点接收和聚合消息的方法。参数inputsmessage函数的返回值,index描述了每个中心节点

v

v

v接收来自邻居节点

u

u

u的消息在inputs的哪一行行。scatter函数声明为

torch_scatter.scatter(input: Tensor, index: Tensor, dim: int = -1, out: Optional[Tensor] = None, dim_size: Optional[int] = None, reduce: str = 'sum')→ Tensor[source]

函数功能为用indexdim指定的维度索引张量input,再根据reduce规则计算返回值。

在这里插入图片描述

如图所示,中心节点0的邻居节点在input的第0、1、3个索引。

propagate函数定义在MessagePassing父类。用于启动一次消息传递过程。edge_index为整张图的边索引信息,形状是

[

2

,

E

]

[2,mathcal{E}]

[2,E]。参数x存放邻居节点和中心节点的特征。因为每个节点既是中心节点又是邻居节点,且采用一样的特征描述,所以元组的两个元素是一样的。propagate函数会自动调用messageaggregate完成消息传递和消息聚合。

④当GraphSage对象被调用时,默认调用forward来启动消息传递。forward函数返回更新后的节点特征张量,形状为

[

N

,

d

]

[|N|, d]

[N,d].

N

N

N是所有节点的集合。

3. Train and Test

使用CORA dataset数据集进行节点分类任务。训练过程如下

在这里插入图片描述

本图文内容来源于网友网络收集整理提供,作为学习参考使用,版权属于原作者。
THE END
分享
二维码
< <上一篇
下一篇>>