Press "Enter" to skip to content

ECCV 2020 | 华为诺亚提出基于元强化学习的跨任务可迁移网络架构搜索方案

本站内容均来自兴趣收集,如不慎侵害的您的相关权益,请留言告知,我们将尽快删除.谢谢.

 

近年来,神经网络架构搜索(NAS)取得了许多突破,但许多算法仍受限于特定的搜索空间及视觉任务,比如大多数算法都是面向一个固定的数据集,无法在面对多项任务时,对跨任务的知识进行重复利用,因此无法实现对搜索策略进行跨任务的高效迁移学习。

 

本文提出的 CATCH 是一种基于元强化学习(Meta-RL)的网络架构搜索方案,它通过构建多个小型任务对搜索策略进行预训练,从而在迁移到目标任务上时能取得更快更好的搜索效果。元学习(Meta-learning)和强化学习(RL)的结合使得 CATCH 能有效地适应新任务,且由于各类搜索空间都能被建构为序列决策(sequential decision-making)类问题,因此这类方法可以适用于多种搜索空间。近年来的 NAS 研究进展迅速,但大多关注单一任务,本篇论文也是目前少数率先研究多领域(分类、检测、分割)架构搜索的算法之一。

 

 

论文地址:https://www.catch-nas.com/

 

背景和动机

 

当前的 NAS 方法已经在多个领域产生了超过了人工设计的神经网络,不过这些方法在可迁移性和搜索效率上还有待提高。NAS 的技术在未来有很大的应用潜能,但这些美好愿景的实现很大程度要求搜索算法具备一些能力,如: (1) 有效处理大量任务;(2)广泛适用于不同的搜索空间(search space);(3)保持其在各种任务下的搜索表现。这些特征是当前的许多算法所忽略的,主要体现在:

 

1. 搜索策略缺乏在多个任务之间的迁移能力。许多算法只能在遇到新任务时从头开始重复而低效地进行搜索。

 

2. 对源任务搜索结果的直接部署无法保证最优表现。例如当前通常的做法是将 CIFAR-10 分类任务的搜索结果直接部署到 ImageNet 分类任务,这样的做法无法保证直接部署的网络结构是最优的,且当任务的性质差距较大时(如 MNIST 与 ImageNet),直接部署也并不合理。

 

3. 一些算法的搜索空间比较受限。 例如当前一些可微分算法(DARTS)只能应用于微观结构(cell-structure)的搜索,尚未适用于更广泛结构的搜索,缺乏通用性。

 

为了解决这些问题和增强 NAS 算法的可迁移性,作者提出了 CATCH,一个基于元强化学习的跨任务可迁移网络架构搜索方案。如图 1 所示,CATCH 框架中的搜索代理(即 CATCHer)充当决策者。作者首先通过构建资源耗费低的元训练任务,对 CATCHer 进行元训练,然后将其部署到目标任务以快速适应。

 

 

CATCH 方法概述

 

作者的灵感来自元强化学习(meta-RL)。在 meta-RL 中,相同动作空间(action space)但奖励函数(reward function)不同的问题可以被看作为不同的任务(task)。在 NAS 中,大部分问题的搜索空间是不变的,即动作空间一致;但数据集(例如 CIFAR-10 与 ImageNet)或者视觉领域(例如分类、检测、分割)的更改都会改变奖励函数,因此根据定义可以将这些数据集(无论是否在同一领域)视为不同的任务。从这个视角来看,NAS 的跨领域、跨数据集高效搜索问题,就可以被转化为元强化学习中对多个任务的快速迁移与适应的一个问题。

 

元强化学习在这里之所以有效,很大程度也依赖于许多数据集的网络设计中存在着一些共性,这些共性在过去成就了迁移学习(如将一个优秀的 ImageNet 分类网络迁移到 COCO 检测任务上),也为元强化学习创造了快速适应(fast adaptation)的空间。

 

 

搜索过程

 

如上图所示,在每一次任务的搜索过程中,CATCH 会首先选取任意一个网络进行训练、评估,并获得(模型 – 奖励)的一个元组(m,r)用于初始化搜索历史。

 

接下来 CATCHer 的三个核心组件会分别发挥作用:

 

1. 任务信息编码器(Context Encoder):任务信息编码器通过变分推断(Variational Inference)的方式将搜索历史编码为任务表征 z,指导 RL 控制器和网络评估器的表现。 作者将任务表征 z 建模为具有对角协方差矩阵的多元高斯分布,在编码的过程中,编码器旨在估计后验分布 p(z|c_{1:N})。由于 c_{1:N} 只与任务有关,因此可以被分解为高斯因子的乘积,

 

 

2. 其中 f_\phi 用于预测 p(z|c_i)任务信息均值和标准差的神经网络。任务表征 z 从分布中随机采样得到,这样的设定有利于将对任务的不确定性加入模型,并缓解稀疏奖励(sparse reward)的问题实现有效探索。

 

3.RL 控制器(RL Controller):RL 控制器进行序列决策,生成候选网络(Candidate Networks)。网络的生成可以被视为决策问题,其中 RL 控制器的每个动作都决定了最终架构的一个属性。该属性可以是在微观结构搜索中形成特定的操作类型(例如,跳跃连接、卷积操作等),也可以是在宏观结构(macro skeleton)搜索中形成网络的形状(例如,宽度、深度等)。

 

m 可以表示成序列决策列表 [a_1, a_2, …, a_L]。在每个时间步长,RL 控制器输入已经完成的决策和任务表征 z,并输出选择接下来动作的分布,从中采样相应的动作。控制器随机采样 M 个网络作为候选网络。

 

4. 网络评估器(Network Evaluator)。网络评估器用于预测候选网络的性能,并确定选取预测值最高的网络进行实际训练。

 

优化过程

 

所有这三个组件都可以进行端到端优化。RL 控制器(以 \ theta_c 为参数)使用 RL 算法近端策略优化(PPO)进行训练:

 

 

网络评估器通过优化 Huber Loss 进行训练,其中采用了优先经验回放(Prioritized Experience Replay)的技巧,从而提高采样效率。

 

 

任务信息编码器的优化过程将上述两个优化目标作为其优化目标的一部分。每个任务的最终变分下界(Variational Lower Bound)为

 

 

公式中 KL 散度近似于约束 z 和 c 之间的互信息的变分信息瓶颈(Variational Information Bottleneck),此信息瓶颈可作为正则化,以避免过拟合训练任务。p(z)是单位高斯先验。由于 (1) 任务表征 z 用作控制器和评估器的输入,并且 (2) p(z)和 q_\phi(z|c)是高斯分布,且其中 KL 散度使用其均值和标准差进行计算,因此可以使用重参数化技巧(Reparameterization Trick)将梯度端对端反向传播到任务信息编码器进行更新。

 

 

CATCH 框架包括两个阶段:如算法 1 所示的元训练阶段(Meta-training Phase)和适应阶段(Adaptation Phase)。在元训练阶段,我们在一组资源消耗低的元训练任务中训练 CATCHer。此阶段的主要目标是为任务信息编码器提供多样化的任务,使其产生有意义的表征。在适应阶段,经过元训练的 CATCHer 在任务信息编码的指导下来有效找到目标任务的优秀网络。

 

实验和结果

 

作者在两个不同的搜索空间上证明了 CATCH 的有效性和通用性,分别是微观结构搜索空间和基于残差模块(Residual Block)的宏观结构搜索空间。同时也探索了从分类任务迁移到不同领域任务(目标检测、语义分割)的可能性。

 

微观结构搜索空间

 

作者从 ImageNet16(图片缩小为 16×16)数据集中任意抽取 10/20/30 类,构建了 25 个元训练图像分类任务。CATCH 在元训练过后分别迁移到 CIFAR-10、CIFAR-100、ImageNet16-120(图片缩小为 16×16,抽取前 120 类)分类任务上进行搜索。并基于 NAS-Bench-201基准数据集进行快速的算法评估,与其他 NAS 算法进行了比较。实验证明和 sample-based 和 one-shot 的方法进行比较都能够快速适应目标任务找到最优网络。

 

 

 

 

基于残差模块的宏观结构搜索空间

 

作者将 ImageNet 数据集中的图片任意缩小为 16×16、32×32、224×224 并任意抽取 10/20/30 类,同样构建了 25 个元训练图像分类任务。在元训练过后分别迁移到 ImageNet图像分类、COCO目标检测以及 CityScapes语义分割任务上进行网络搜索。得到了具有竞争力的搜索结果。

 

 

 

 

总结

 

本文是跨任务可迁移 NAS 的早期工作之一,提出了基于元强化学习的跨任务可迁移网络架构搜索方案 CATCH。 CATCH 主要通过在大量元训练任务上对搜索策略进行预训练,同时获得提取任务表征的能力来实现在目标任务上的快速适应。在两个搜索空间上的实验显示了将 CATCH 扩展到大型数据集和各种视觉领域方面的潜力。

Be First to Comment

发表评论

电子邮件地址不会被公开。 必填项已用*标注