0
  • 聊天消息
  • 系统消息
  • 评论与回复
登录后你可以
  • 下载海量资料
  • 学习在线课程
  • 观看技术视频
  • 写文章/发帖/加入社区
会员中心
创作中心

完善资料让更多小伙伴认识你,还能领取20积分哦,立即完善>

3天内不再提示

ICLR 2019论文解读:胶囊图神经网络的PyTorch实现

DPVg_AI_era 来源:lp 2019-03-29 10:11 次阅读

胶囊图神经网络(CapsGNN)是在GNN启发下诞生了基于图片分类的新框架。CapsGNN在10个数据集中的6个的表现排名位居前两名。与所有其他端到端架构相比,CapsGNN在所有社交数据集中均名列首位。

本日Reddit上热议的一个话题是名为“胶囊图神经网络”(CapsGNN)的新框架。从名字不难看出,它是受图神经网络(GNN)的启发,在其基础上改进而来的成果。

CapsGNN框架的作者为新加坡南洋理工大学电气与电子工程学院的Zhang Xinyi和Lihui Chen,该研究的论文将在ICLR 2019上发表。

目前,从图神经网络(GNN)中学到的高质量节点嵌入已经应用于各种基于节点的应用程序中,其中一些程序已经实现了最先进的性能。不过,当应用程序用GNN学习的节点嵌入来生成图形嵌入时,标量节点表示可能不足以有效地保留节点或图形的完整属性,从而导致图形嵌入的性能达不到最优。

胶囊图神经网络(CapsGNN)受到了胶囊神经网络的启发,利用胶囊的概念来解决现有基于GNN的图嵌入算法的缺点。CapsGNN以胶囊形式对节点特征进行提取,利用路由机制来捕获图形级别的重要信息。因此,模型会为每个图生成多个嵌入,从多个不同方面捕获图的属性。

CapsGNN中包含的注意力模块可用于处理各种尺寸的图,让模型能够专注处理图的关键部分。通过对10个图结构数据集的广泛评估表明,CapsGNN具有强大的机制,可通过数据驱动捕获整个图的宏观属性。在几个图分类任务上的性能优于其他SOTA技术。

胶囊图神经网络基本架构

上图所示为CapsGNN的简化版本。它由三个关键模块组成:1)基本节点胶囊提取模块:GNN用于提取具有不同感受野的局部顶点特征,然后在该模块中构建主节点胶囊。 2)高级图胶囊提取模块:融合了注意力模块和动态路由,以生成多个图胶囊。 3)图分类模块:再次利用动态路由,生成用于图分类的类胶囊。

注意力模块

在CapsGNN中,基于每个节点提取主胶囊,即主胶囊的数量取决于输入图的大小。在这种情况下,如果直接应用路由机制,则生成的高级别的胶囊的值将高度依赖于主胶囊的数量(图大小),这种情况并不理想。因此,实验引入一个注意力模块来解决这个问题。

注意力模块架构。首先压平主胶囊,利用两层全连接神经网络产生每个胶囊的注意力值。利用基于节点的归一化(对每行进行归一化)来生成最终注意力值。 将标准化值与主胶囊相乘来计算标度胶囊。

实验设置与结果

我们验证了从CapsGNN中提取的图嵌入与大量SOTA方法的性能,与一些经典方法的最优性能做了对比。此外还进行了实验研究,评估胶囊对图编码特征效率的影响。我们对生成的图/类胶囊进行了简要分析。实验结果和分析如下所示。

表1为生物数据集的实验结果,表2为社会数据集的实验结果。对于每个数据集,以粗体突出显示前2个准确度。

与所有其他算法相比,CapsGNN在10个数据集中的6个的表现排名位居前两名,并且在其他数据集上也实现了基本相当的结果。与所有其他端到端架构相比,CapsGNN在所有社交数据集中均名列首位。

表1:生物数据集的实验结果

表2:社交数据集的实验结果

胶囊的效率

在胶囊的效率测试实验中,GNN的层数设置为L = 3,每层的通道数都设置为Cl = 2。通过调整节点的维度(dn)、图(dg)、胶囊和图形、胶囊的数量(P)来构造不同的CapsGNN。

表3:胶囊效率评估实验中经过测试的体系结构详细信息

图3:特征表示效率的比较。横轴表示测试架构的设置,纵轴表示NCI1的分类精度。

图胶囊的可视化

分类胶囊的可视化

胶囊图网络:基于GNN的高效快捷的新框架

CapsGNN是一个新框架,将胶囊理论融合到GNN中,来实现更高效的图表示学习。该框架受CapsNet的启发,在原体系结构中引入了胶囊的概念,在从GNN提取的节点特征的基础上,以向量的形式提取特征。

利用CapsGNN,一个图可以表示为多个嵌入,每个嵌入都可以捕获不同方面的图属性。生成的图形和类封装不仅可以保留与分类相关的信息,还可以保留关于图属性的其他信息,这些信息可能在后续流程中用到。CapsGNN是一种新颖、高效且强大的数据驱动方法,可以表示图形等高维数据。

与其他SOTA算法相比,CapsGNN模型在10个图表分类任务中有6个成功实现了更好或相当的性能,在社交数据集上的表现尤其显眼。与其他类似的基于标量的体系结构相比,CapsGNN在编码特征方面更有效,这对于处理大型数据集非常有用。

关于开源代码和模型的一些补充信息

运行环境

代码库在Python 3.5.2中实现。用于开发的软件包版本如下:

networkx 1.11tqdm 4.28.1numpy 1.15.4pandas 0.23.4texttable 1.5.0scipy 1.1.0argparse 1.1.0torch 0.4.1torch-scatter 1.1.2torch-sparse 0.2.2torch-cluster 1.2.4torch-geometric 1.0.3torchvision 0.2.1

数据集

代码会从input文件夹中获取训练图,图存储形式为JSON。用于测试的图也存储为JSON文件。每个节点id和节点标签必须从0开始索引。字典的键是存储的字符串,以使JSON能够序列化排布。

每个JSON文件都具有以下的键值结构:

{"edges": [[0, 1],[1, 2],[2, 3],[3, 4]], "labels": {"0": "A", "1": "B", "2": "C", "3": "A", "4": "B"}, "target": 1}

边缘键(edgeskey)具有边缘列表值,用于描述连接结构。标签键具有每个节点的标签,这些标签存储为字典- 在此嵌套字典中,标签是值,节点标识符是键。目标键具有整数值,该值代表了类成员资格。

输出

预测结果保存在output目录中。每个嵌入都有一个标题和一个带有图标识符的列。最后,预测会按标识符列排序。

训练CapsGNN模型由src /main.py脚本处理,该脚本提供以下命令行参数

输入和输出选项

--training-graphs STR Training graphs folder. Default is `dataset/train/`. --testing-graphs STR Testing graphs folder. Default is `dataset/test/`. --prediction-path STR Output predictions file. Default is `output/watts_predictions.csv`.

模型选项

--epochs INT Number of epochs. Default is 10. --batch-size INT Number fo graphs per batch. Default is 32. --gcn-filters INT Number of filters in GCNs. Default is 2. --gcn-layers INT Number of GCNs chained together. Default is 5. --inner-attention-dimension INT Number of neurons in attention. Default is 20. --capsule-dimensions INT Number of capsule neurons. Default is 8. --number-of-capsules INT Number of capsules in layer. Default is 8. --weight-decay FLOAT Weight decay of Adam. Defatuls is 10^-6. --lambd FLOAT Regularization parameter. Default is 1.0. --learning-rate FLOAT Adam learning rate. Default is 0.01.

声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉
  • 神经网络
    +关注

    关注

    42

    文章

    4771

    浏览量

    100752
  • 数据集
    +关注

    关注

    4

    文章

    1208

    浏览量

    24697
  • pytorch
    +关注

    关注

    2

    文章

    808

    浏览量

    13221
  • GNN
    GNN
    +关注

    关注

    1

    文章

    31

    浏览量

    6336

原文标题:基于GNN,强于GNN:胶囊图神经网络的PyTorch实现 | ICLR 2019

文章出处:【微信号:AI_era,微信公众号:新智元】欢迎添加关注!文章转载请注明出处。

收藏 人收藏

    评论

    相关推荐

    labview BP神经网络实现

    请问:我在用labview做BP神经网络实现故障诊断,在NI官网找到了机器学习工具包(MLT),但是里面没有关于这部分VI的帮助文档,对于”BP神经网络分类“这个范例有很多不懂的地方,比如
    发表于 02-22 16:08

    人工神经网络实现方法有哪些?

    人工神经网络(Artificial Neural Network,ANN)是一种类似生物神经网络的信息处理结构,它的提出是为了解决一些非线性,非平稳,复杂的实际问题。那有哪些办法能实现人工神经
    发表于 08-01 08:06

    matlab实现神经网络 精选资料分享

    神经神经网络,对于神经网络实现是如何一直没有具体实现一下:现看到一个简单的神经网络模型用于训
    发表于 08-18 07:25

    轻量化神经网络的相关资料下载

    视觉任务中,并取得了巨大成功。然而,由于存储空间和功耗的限制,神经网络模型在嵌入式设备上的存储与计算仍然是一个巨大的挑战。前面几篇介绍了如何在嵌入式AI芯片上部署神经网络:【嵌入式AI开发】篇五|实战篇一:STM32cubeIDE上部署
    发表于 12-14 07:35

    一种新型神经网络结构:胶囊网络

    胶囊网络是 Geoffrey Hinton 提出的一种新型神经网络结构,为了解决卷积神经网络(ConvNets)的一些缺点,提出了胶囊
    的头像 发表于 02-02 09:25 5871次阅读

    基于PyTorch的深度学习入门教程之使用PyTorch构建一个神经网络

    PyTorch的自动梯度计算 Part3:使用PyTorch构建一个神经网络 Part4:训练一个神经网络分类器 Part5:数据并行化 本文是关于Part3的内容。 Part3:使
    的头像 发表于 02-15 09:40 2100次阅读

    PyTorch教程8.1之深度卷积神经网络(AlexNet)

    电子发烧友网站提供《PyTorch教程8.1之深度卷积神经网络(AlexNet).pdf》资料免费下载
    发表于 06-05 10:09 0次下载
    <b class='flag-5'>PyTorch</b>教程8.1之深度卷积<b class='flag-5'>神经网络</b>(AlexNet)

    PyTorch教程之循环神经网络

    电子发烧友网站提供《PyTorch教程之循环神经网络.pdf》资料免费下载
    发表于 06-05 09:52 0次下载
    <b class='flag-5'>PyTorch</b>教程之循环<b class='flag-5'>神经网络</b>

    PyTorch教程之从零开始的递归神经网络实现

    电子发烧友网站提供《PyTorch教程之从零开始的递归神经网络实现.pdf》资料免费下载
    发表于 06-05 09:55 0次下载
    <b class='flag-5'>PyTorch</b>教程之从零开始的递归<b class='flag-5'>神经网络</b><b class='flag-5'>实现</b>

    PyTorch教程9.6之递归神经网络的简洁实现

    电子发烧友网站提供《PyTorch教程9.6之递归神经网络的简洁实现.pdf》资料免费下载
    发表于 06-05 09:56 0次下载
    <b class='flag-5'>PyTorch</b>教程9.6之递归<b class='flag-5'>神经网络</b>的简洁<b class='flag-5'>实现</b>

    PyTorch教程10.3之深度递归神经网络

    电子发烧友网站提供《PyTorch教程10.3之深度递归神经网络.pdf》资料免费下载
    发表于 06-05 15:12 0次下载
    <b class='flag-5'>PyTorch</b>教程10.3之深度递归<b class='flag-5'>神经网络</b>

    使用PyTorch构建神经网络

    PyTorch是一个流行的深度学习框架,它以其简洁的API和强大的灵活性在学术界和工业界得到了广泛应用。在本文中,我们将深入探讨如何使用PyTorch构建神经网络,包括从基础概念到高级特性的全面解析。本文旨在为读者提供一个完整的
    的头像 发表于 07-02 11:31 708次阅读

    PyTorch神经网络模型构建过程

    PyTorch,作为一个广泛使用的开源深度学习库,提供了丰富的工具和模块,帮助开发者构建、训练和部署神经网络模型。在神经网络模型中,输出层是尤为关键的部分,它负责将模型的预测结果以合适的形式输出。以下将详细解析
    的头像 发表于 07-10 14:57 500次阅读

    pytorch中有神经网络模型吗

    当然,PyTorch是一个广泛使用的深度学习框架,它提供了许多预训练的神经网络模型。 PyTorch中的神经网络模型 1. 引言 深度学习是一种基于人工
    的头像 发表于 07-11 09:59 699次阅读

    PyTorch如何实现多层全连接神经网络

    PyTorch实现多层全连接神经网络(也称为密集连接神经网络或DNN)是一个相对直接的过程,涉及定义网络结构、初始化参数、前向传播、损失
    的头像 发表于 07-11 16:07 1180次阅读