本文主要介绍了一篇名为《Attention-based Deep Multiple Instance Learning》的研究论文,由Maximilian Ilse、Jakub M. Tomczak和Max Welling合作完成,他们都来自荷兰的阿姆斯特丹大学。该论文发表于2018年的第35届国际机器学习会议(International Conference on Machine Learning, ICML)论文集,收录于PMLR 80卷。
这项研究属于机器学习和计算机视觉交叉领域,特别是聚焦于多示例学习(Multiple Instance Learning, MIL)这一方向。多示例学习是一种弱监督学习范式,其特点是训练数据由一组组“包”组成,每个“包”包含多个“示例”,但只有一个包级别的标签。例如,在医学图像分析中,一张被标注为“恶性肿瘤”的整张病理切片(包)由成千上万个图像块(示例)构成,但只有部分图像块实际包含癌细胞。传统的MIL方法在实现高精度包分类和定位关键示例(即触发包标签的示例,如癌细胞区域)之间存在权衡,往往难以兼顾。本研究旨在解决这一矛盾,提出一种既灵活又可解释的深度MIL模型。其学术背景是,随着深度学习的发展,将神经网络应用于MIL问题成为趋势,但如何设计一个既能有效聚合包内信息、又能保持对关键实例识别能力的池化(pooling)操作是关键挑战。本研究的目标是:1)提出一个基于神经网络的、完全可微的MIL框架;2)引入一种新颖的、基于注意力机制的可训练MIL池化算子,以替代传统的非自适应算子(如最大值max或均值mean);3)在保证分类性能的同时,提供对决策的解释,即通过注意力权重定位关键实例;4)在多个数据集上验证方法的有效性和优越性。
该研究的详细工作流程遵循了一个清晰的三步范式,这源于对称函数基本定理的启示。整个模型流程可以概括为:首先,使用一个神经网络(记作函数 f)将每个输入示例(如图像块)映射到一个低维嵌入表示(embedding)。其次,使用一个置换不变(即顺序无关)的池化函数(记作σ)将所有示例的嵌入表示聚合成一个固定长度的包表示。最后,使用另一个神经网络(记作函数 g)对这个包表示进行处理,输出包标签的概率。
研究的关键创新和核心流程在于第二步——池化操作的设计。作者没有使用预定义的非自适应池化算子(如max或mean),而是提出了一种基于注意力机制的可训练池化层。其具体操作如下:对于一个包含K个嵌入向量 {h₁, h₂, …, h_K} 的包,其聚合表示z通过加权平均计算:z = Σ (a_k * h_k),其中k从1到K。权重a_k(即注意力权重)的计算方式是核心:a_k = exp{ w^T * tanh(V * h_k^T) } / Σ_j exp{ w^T * tanh(V * h_j^T) }。这里,w和V是可学习的参数矩阵,tanh是双曲正切非线性激活函数。这个设计确保所有权重之和为1,且计算过程与示例顺序无关。此外,作者还提出了一个“门控注意力”变体,在计算中引入了额外的门控机制:a_k = exp{ w^T * ( tanh(V * h_k^T) ⊙ sigmoid(U * h_k^T) ) } / Σ_j exp{ w^T * ( tanh(V * h_j^T) ⊙ sigmoid(U * h_j^T) ) },其中U是另一组参数,⊙是逐元素乘法,sigmoid是sigmoid函数。门控机制旨在增强模型的表达能力,克服tanh函数在某些区间的近似线性问题。
在实验部分,研究包含了多个步骤,覆盖了不同类型的MIL数据集,以全面验证模型。实验对象和样本量如下: 1. 经典MIL基准数据集:包括MUSK1, MUSK2, Fox, Tiger, Elephant五个数据集。这些数据集的特征是示例已预先提取好特征,包和示例数量相对较少。研究使用10折交叉验证,重复5次,报告平均分类准确率。 2. MNIST-Bags图像数据集:这是一个基于MNIST手写数字集构造的合成数据集。通过从MNIST训练集和测试集中随机采样图像构建包,包的大小呈高斯分布。包被标记为正例的条件是其中至少包含一个数字“9”。研究系统性地考察了不同平均包大小(10, 50, 100)和不同训练包数量(从50到500)下模型的性能,使用固定1000个测试包进行评估,主要评价指标是曲线下面积(AUC)。 3. 真实世界组织病理学数据集:包括两个数据集。 * 乳腺癌数据集(Breast Cancer):包含58张H&E染色图像(包),标注为恶性或良性。每张图像被分割成32x32的图像块(示例),总计每包约672个块,丢弃空白区域过多的块。 * 结肠癌数据集(Colon Cancer):包含100张H&E染色图像。任务是判断一个包中是否包含至少一个上皮细胞(epithelial cell)核,因为结肠癌源于上皮细胞。每个包由以标记细胞核为中心的27x27图像块构成。 对于这两个医学数据集,研究使用10折交叉验证,重复5次,报告准确率、精确率、召回率、F1分数和AUC等多个指标。
在数据处理和分析流程上,对于图像数据集(MNIST-Bags和两个病理数据集),函数f采用成熟的卷积神经网络架构:MNIST-Bags使用LeNet5,病理数据集采用一篇先前工作中提出的CNN模型。所有模型均使用Adam优化器进行端到端训练。为了防止医学数据集上的过拟合,采用了数据增强技术。整个工作流的数据分析完全基于模型在测试集或交叉验证折上预测得到的性能指标和注意力权重图。
研究的主要结果在各个实验环节均得到了体现: 在经典MIL基准数据集上,论文中的表1显示,提出的“注意力”和“门控注意力”方法在五个数据集上的分类准确率与当时性能最佳的传统MIL方法(如MI-Graph, MI-VLAD, MI-FV)以及深度MIL方法(如MI-Net with DS/RC)相当或接近,所有结果均在平均值的标准误差范围内。这表明即使在非图像、小样本的传统MIL问题上,基于注意力的深度模型也能保持竞争力。
在MNIST-Bags图像数据集上,结果更具洞察力。如图1、2、3所示,当训练样本量较小(即训练包数量少或平均包大小较小时),提出的注意力方法(无论是实例级还是嵌入级)的AUC显著高于其他所有对比方法,包括基于SVM的MI-SVM、使用max或mean池化的深度模型。这证明了注意力机制在“小样本制度”下的优越性。随着训练数据量的增加,所有方法的性能趋于接近,但注意力方法始终处于领先或持平状态。研究还发现,嵌入级方法普遍优于实例级方法;均值池化(mean)的表现显著差于最大值池化(max)。此外,图4展示了一个关键的可视化结果:对于一个包含数字9的正包,模型预测正确,并且其计算出的注意力权重高度集中在几个数字“9”的图像上(权重值如0.226, 0.248, 0.280, 0.246),而其他数字的权重近乎为零。这直观地证明了注意力机制能够自动定位并高亮关键实例,实现了可解释性。
在组织病理学数据集上的结果进一步强化了上述结论。如表2和表3所示,在乳腺癌和结肠癌数据集上,基于注意力(尤其是门控注意力)的方法在大多数评估指标上(特别是召回率和AUC)都取得了最佳或接近最佳的性能。值得注意的是,在乳腺癌数据集上,嵌入级+max池化的方法几乎失败,而注意力方法表现稳健,凸显了自适应池化的重要性。更高的召回率在医疗领域尤为重要,意味着更少的漏诊。更重要的结果是图5展示的可解释性应用:对于一张结肠癌病理图像,模型仅使用图像级别的标签进行训练,但其生成的注意力权重热图与医生标注的上皮细胞核区域(地面真相)高度吻合。这证明了模型不仅能够做出准确诊断,还能可靠地定位出疑似病变区域(ROI),为医生提供了有价值的参考。相比之下,实例级分类器往往只选择一小部分阳性块,可靠性较低。
基于以上结果,本研究得出结论:提出了一种灵活且可解释的、完全由神经网络参数化的多示例学习方法。该方法通过引入基于(门控)注意力机制的可训练MIL池化层,成功地统一了分类性能与模型可解释性。实验表明,该方法在经典数据集上与最优方法持平,在更具挑战性的图像数据集(尤其是小样本情况和真实的医疗图像数据集)上超越了现有方法,并能通过注意力权重有效定位关键实例,生成与专家标注一致的热力图。
本研究的科学价值和应用价值显著。在科学上,它提供了一个将对称函数定理、深度学习和注意力机制优雅结合的MIL通用框架,为后续研究指明了方向。在应用上,特别是在计算病理学等医疗影像分析领域,该方法能够利用易获取的弱标注数据(仅需整张图像的诊断标签),同时实现高精度诊断和自动病灶区域提示,有望大幅减轻病理医生的工作负担,具有重要的临床转化潜力。论文也指出了未来的研究方向,如扩展到多类别MIL问题、考虑示例间的依赖关系等。
本研究的亮点在于:第一,方法论的创新性:首次将注意力机制作为一种完全可训练、自适应的置换不变池化算子系统性地引入MIL框架,并给出了理论依据(对称函数定理)。第二,性能的优越性:尤其在数据有限或任务复杂的真实场景(如医学图像)中,表现出超越传统固定池化方法的分类能力。第三,可解释性的实现:模型内置的注意力权重自然提供了对决策的解释,能够直接可视化关键实例,弥合了实例级与嵌入级方法之间的鸿沟,满足了如欧盟《通用数据保护条例》等对算法决策可解释性的法规要求。第四,广泛的验证:通过在合成数据、经典基准和真实医疗数据上的全面实验,扎实地证明了方法的有效性、鲁棒性和实用价值。