赞
踩
在DGL框架中,当我们明白了边、顶点、图以及边和顶点的属性后,接下来需要了解的概念就是 三座大山了:
dgl的节点有一个概念是mailbox, 用于暂存消息函数发送过来的数据。 消息-> 邮箱, 这样的概念是比较容易理解的。
默认的消息函数是 ϕ,它接受的参数是edges ,类型是dgl.EdgeBatch. edges有src,dst和data三个属性,分别是源顶点、目标顶点和边,可以用这三个属性访问各自的特征。
表示: node + node -> mailbox 或 Node + edge -> mailbox
内置消息函数:
一元: copy 函数
二元: add, sub, mul, div, dot 函数
约定: 名字上的u表示源节点,v表示目标节点,e表示边。
参数: 字符串参数, 表示相应节点的输入和输出特征名字段名
例: 对源节点的hu特征和目标节点的hv特征求和,然后将结果保存在边的he特征上
dgl.function.u_add_v('hu','hv','he')
如果要自定义此消息函数,等价于以下代表,注意返回的是dict格式的数据。
- def message_func(edges):
- return {'he': edges.src['hu'] + edges.dst['hv']}
默认的聚合函数是ρ , 接受的参数类型是nodes ,也就是顶点集合,类型为 dgl.NodeBatch, nodes有成员属性mailbox, 用来访问节点收到的消息。 mailbox可以理解为一块临时存贮区,在消息函数运行后用来暂存数据。 此时并不会更新目标节点数据。
内置的聚合函数:
sum, max,min,mean 操作
参数:这些函数通常都是两个参数,类型为字符串
一个用于指定mailbox中的字段名
一个用于指定目标节点特征的字段名
如dgl.function.sum('m','h')等价于如下所示的自定义函数。 注意,聚合只是聚合,并不更新任何值 ,只执行聚合的任务,说白了就是把消息函数中传来的数据进行处理但不更新,切记。
- import torch
- def reduce_func(nodes):
- return {'h': torch.sum(nodes.mailbox['m'], dim=1)}
前两步接收和聚合后的数据,需要更新目标节点的特征,参数为nodes, 类型为dgl.NodeBatch. 此函数对聚合函数的聚合结果进行操作,在消息传递的最后一步将其与其它节点的特征组合后,作为节点的新特征。
前面讲了消息和mailbox的概念,更新函数的作用就是按需将mailbox中的数据搬回家(与节点的数据合并)
在不涉及消息传递时,可以调用apply_edges()函数进行逐边计算
参数为一个消息函数,默认为更新所有的边
例子:
- import dgl.function as fn
- graph.apply_edges(fn.u_add_v('el', 'er', 'e'))
该接口合并了消息生成,消息聚合,节点特征更新,好处是可以给这三步操作作一个整体优化,用更底层的高效算法,比如直接调用cuda函数进行操作,从而提高运行效率。
参数为: 一个消息函数,一个聚合函数,一个更新函数。 官方文档不建议在这儿使用更新函数,可以自己在随后进行操作。
示例:
- def updata_all_example(graph):
- # store the result in graph.ndata['ft']
- graph.update_all(fn.u_mul_e('ft', 'a', 'm'),
- fn.sum('m', 'ft'))
- # Call update function outside of update_all
- final_ft = graph.ndata['ft'] * 2
- return final_ft
这段代码 将节点 特征字段 ft与边特征字段a相乘后生成消息m (存放于暂存位置mailbox中),然后对所有的消息求和来更新节点特征ft, 再将ft乘以2得到最终结果 final_ft, 调用后mailbox中的中间结果m会被清除。公式表示为:
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。