当前位置:   article > 正文

内容增强网络表示学习的一般框架(A General Framework for Content-enhanced Network Representation Learning)(2)

a general framework for content-enhanced network representation learning

在上篇内容增强网络表示学习的一般框架(A General Framework for Content-enhanced Network Representation Learning)(1)中,我们可以了解到这个学习框架的基本构成,这篇文章就来简单的实现下。

前提假设

在这里插入图片描述
即:
S P SP SP 是通过随机游走所生成的路径上的邻接节点对;
S N SN SN 是所有的负采样的集合;
节点 e u e_u eu的数值化表示为节点的嵌入表示,类似与DeepWalknode2vec

任务

1)整合node-node连接、node-content连接为一个完整的图表示;
2)图中任意节点 u u u的嵌入表示 e u e_u eu
3)从下图可以看见一个sent2vec
在这里插入图片描述
sent2vec原文:Skip-Thought Vector
代码地址:here

数据集

使用DBLP引文网络V1数据集,按照文中所述:
1)每篇论文视为一个节点,节点之间的边表示引用关系;
2)文中所述只有16.7%的节点是有contents存在。
3)观察上图Figure 2,可以发现并不是每个节点都存在contents,但是在该数据集中不存在多个连接指向一个文档内容的这种情况。

1. 数据集处理

将存在contents的论文的contents抽取为一个节点;

数据集格式:
#* — paperTitle
#@ — Authors
#t ---- Year
#c — publication venue
#index 00---- index id of this paper
#% ---- the id of references of this paper (there are multiple lines, with each indicating a reference)
#! — Abstract

不妨进行按行读取文本文件内容:

lines = []
with open("outputacm.txt", "r") as f:
    lines = f.readlines()
lines = lines[1:] # 数据集中第一行是个数统计,不需要
  • 1
  • 2
  • 3
  • 4

然后,我们按照每行的前置判断类型,找到对应的节点indexreferences,以及存储部分节点的文本对应关系,如下:

Nodes = [] # 存储所有的节点列表
Cites = [] # 存储节点之间的引用关系元组
Contents = [] # 临时存储节点和内容,用于后面的文本内容编号,即虚拟节点生成
for line in lines:
    if line.startswith("#index"): # 节点下标
        index = line[6: -1]
        Nodes.append(index)
    if line.startswith("#!"): # 节点的内容
        content = line[2: -1]
        Contents.append((Nodes[-1], content))
    if line.startswith("#%"): # 引文网络的引用文章ID
        node = line[2: -1]
        Cites.append((Nodes[-1], node))
        pass

# 将内容文本进行编号,然后添加关系到Cites中,虚拟节点加入到Nodes中
virtual_node_map = {}
for item in Contents:
    virtual_node_index = str(int(Nodes[-1]) + 1)
    virtual_node_map[virtual_node_index] = item[1]
    Nodes.append(virtual_node_index)
    Cites.append((item[0], virtual_node_index))
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22

然后,不妨存储这个边集到本地的CSV文件中:

# 写生成的边集Cites到文本'edges.csv'中
import csv
with open('edges.csv', 'w', newline='') as csvfile:
    writer  = csv.writer(csvfile)
    writer.writerow(("node1", "node2"))
    for row in Cites:
        writer.writerow(row)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

然后,使用我们前面文章中定义的读取CSV文件到networkx格式的图变量中:

import networkx as nx

def loadGraph(filename):
    # 定义一个Graph来存储节点边等信息,默认初始权重为1
    G = nx.Graph()
    with open(filename, mode="r", encoding="utf-8") as f:
        # 第一行存储的是顶点的数目和边的数目
        n, m = f.readline().split(",")
        for line in f:
            u, v = map(int, line.split(","))
            try:
                G[u][v]['weight'] += 1
            except:
                G.add_edge(u, v, weight=1)
    return G
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15

然后,根据节点和度的关系来进行统计,以绘制幂律分布的图形:

G = loadGraph("edges.csv")
# 遍历所有的节点,然后统计每个节点的度数,然后绘制幂律分布图像
count = {} # 统计图中的度-频率的字典
for node in G.nodes():
    degree = G.degree(node)
    if degree not in count.keys():
        count[degree] = 1
    else:
        count[degree] += 1
# 得到绘制的度-频率的值字典count
# 由于个别节点的度数目过于多,这里去除度大于10000的度的值,实际中貌似也不需要去除
# import copy
# keys = copy.copy(list(count.keys()))
# for key in keys:
#     if count[key] > 10000:
#         del count[key]
# count
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17

在绘制之前,不妨先简单绘制下原始的度-频率的分布图像:

# 绘制pow law
import matplotlib.pyplot as plt
import numpy as np


x = np.array(list(count.keys()))
y = np.array(list(count.values()))

plt.scatter(x, y, s=20, c='r', alpha=0.5)

plt.show()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

即:
在这里插入图片描述
上图显然不是幂律分布图,那么不妨直接取log()来试试:

import math
x_ln = [math.log(i) for i in x]
y_ln = [math.log(i) for i in y]
plt.scatter(x_ln, y_ln, s=20, c='r', alpha=0.5)
plt.show()
  • 1
  • 2
  • 3
  • 4
  • 5

即:
在这里插入图片描述
那么,可以说明我们将节点的文本整理成虚拟节点的网络图是符合无标度特性的。
这里整理下代码,给出完整代码:

import networkx as nx
import math
import matplotlib.pyplot as plt
import numpy as np

# 对引文网络数据集outputacm.txt的数据进行处理
def readFile():
    lines = []
    with open("outputacm.txt", "r") as f:
        lines = f.readlines()

    return lines[1:] # 数据集中第一行是个数统计,不需要

# 写生成的边集Cites到文本'edges.csv'中
def save2csv(Cites):
    import csv
    with open('edges.csv', 'w', newline='') as csvfile:
        writer  = csv.writer(csvfile)
        writer.writerow(("node1", "node2"))
        for row in Cites:
            writer.writerow(row)

# 统计数据
def statistics(lines):
    Nodes = [] # 存储所有的节点列表
    Cites = [] # 存储节点之间的引用关系元组
    Contents = [] # 临时存储节点和内容,用于后面的文本内容编号,即虚拟节点生成
    for line in lines:
        if line.startswith("#index"): # 节点下标
            index = line[6: -1]
            Nodes.append(index)
        if line.startswith("#!"): # 节点的内容
            content = line[2: -1]
            Contents.append((Nodes[-1], content))
        if line.startswith("#%"): # 引文网络的引用文章ID
            node = line[2: -1]
            Cites.append((Nodes[-1], node))
            pass

    # 将内容文本进行编号,然后添加关系到Cites中,虚拟节点加入到Nodes中
    virtual_node_map = {}
    for item in Contents:
        virtual_node_index = str(int(Nodes[-1]) + 1)
        virtual_node_map[virtual_node_index] = item[1]
        Nodes.append(virtual_node_index)
        Cites.append((item[0], virtual_node_index))
    
    # 存储边的关系元组到本地
    save2csv(Cites)
    
    return Nodes,Contents

# 读取csv文件到networkx的图G中
def loadGraph(filename):
    # 定义一个Graph来存储节点边等信息,默认初始权重为1
    G = nx.Graph()
    with open(filename, mode="r", encoding="utf-8") as f:
        # 第一行存储的是顶点的数目和边的数目
        n, m = f.readline().split(",")
        for line in f:
            u, v = map(int, line.split(","))
            try:
                G[u][v]['weight'] += 1
            except:
                G.add_edge(u, v, weight=1)
    return G

# 统计Degree-Frequency
# 遍历所有的节点,然后统计每个节点的度数,然后绘制幂律分布图像
def countDF():
    lines = readFile()
    Nodes,Contents = statistics(lines)
    G = loadGraph('edges.csv')
    count = {} # 统计图中的度-频率的字典
    for node in G.nodes():
        degree = G.degree(node)
        if degree not in count.keys():
            count[degree] = 1
        else:
            count[degree] += 1
    # 得到绘制的度-频率的值字典count
    # 由于个别节点的度数目过于多,这里去除度大于10000的度的值
    # import copy
    # keys = copy.copy(list(count.keys()))
    # for key in keys:
    #     if count[key] > 10000:
    #         del count[key]
    # count
    x = np.array(list(count.keys()))
    y = np.array(list(count.values()))
    return x, y

def plot(x, y):
    x_ln = [math.log(i) for i in x]
    y_ln = [math.log(i) for i in y]
    plt.scatter(x_ln, y_ln, s=20, c='r', alpha=0.5)
    plt.show()
    

# 函数入口
if __name__=="__main__":
    x, y = countDF()
    plot(x, y)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103

接下来,就对这个网络进行嵌入表示,得出每个节点的向量表示,

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/weixin_40725706/article/detail/354950
推荐阅读
相关标签
  

闽ICP备14008679号