分享自:

小数据准确预测:基于表格基础模型的研究

期刊:natureDOI:10.1038/s41586-024-08328-6

《Nature》期刊于2025年1月9日发表了题为《Accurate predictions on small data with a tabular foundation model》的研究论文,由来自德国弗莱堡大学机器学习实验室(Machine Learning Lab, University of Freiburg)、柏林健康研究所计算医学中心(Computational Medicine, Berlin Institute of Health at Charité)及Prior Labs的Noah Hollmann、Samuel Müller等学者共同完成。该研究提出了一种名为TabPFN(Tabular Prior-Data Fitted Network)的表格数据基础模型,通过上下文学习(In-Context Learning, ICL)机制,在小规模表格数据(≤10,000样本)的预测任务中显著超越了当前最优方法(如梯度提升决策树),且训练时间极短(仅需2.8秒)。以下从学术背景、方法流程、结果与结论等方面展开详述。


一、学术背景

表格数据(tabular data)广泛存在于生物医学、经济学、气候科学等领域,其核心预测任务是基于特征列填充标签列的缺失值。过去20年,梯度提升决策树(Gradient-Boosted Decision Trees, GBDT)(如XGBoost、CatBoost)主导了这一领域,而深度学习因表格数据的异质性(如特征类型多样、缺失值、类别不平衡等)表现不佳。传统方法的局限包括:
1. 分布外预测能力差
2. 跨数据集知识迁移困难
3. 无法与神经网络结合(缺乏梯度传播)

为此,研究者提出TabPFN,旨在通过合成数据预训练的Transformer模型,实现端到端的表格学习算法自动化设计,突破传统方法的瓶颈。


二、研究方法与流程

1. 数据生成:基于结构因果模型的合成数据集

  • 生成流程
    • 步骤1:采样高层超参数(如数据集大小、特征数、难度级别);
    • 步骤2:构建有向无环图(DAG),模拟特征与目标的因果关系;
    • 步骤3:通过随机噪声传播生成样本,应用多样化的计算映射(如小规模神经网络、分类特征离散化、决策树结构);
    • 步骤4:后处理(如Kumaraswamy分布扭曲、量化离散化、随机缺失值注入)。
  • 规模:预训练阶段生成1亿个合成数据集,覆盖多样化的数据挑战(如非平滑函数、多峰分布)。

2. 模型架构:二维Transformer设计

  • 核心改进
    • 双向注意力机制:每个单元格(cell)分别对同行(样本内)和同列(特征间)做注意力计算,实现样本与特征顺序的不变性;
    • 缓存优化:分离训练集与测试集的推理状态,CPU上实现300倍加速;
    • 回归输出:采用分段常数分布(piece-wise constant distribution)预测多模态目标分布。
  • 计算效率:单H100 GPU可处理5,000万单元格(如500万行×10特征)。

3. 预训练与推理流程

  • 预训练目标:最小化合成数据集中掩码目标的交叉熵损失,近似贝叶斯后验预测分布。
  • 推理阶段
    • 输入包含带标签的训练集和无标签测试集,通过单次前向传播完成预测;
    • 支持微调、数据生成、密度估计和嵌入学习等基础模型功能。

三、主要实验结果

1. 性能对比:全面超越基线方法

  • 基准数据集:AutoML Benchmark(29分类+28回归数据集,≤10,000样本)。
  • 关键结果
    • 分类任务:TabPFN默认配置(2.8秒)的ROC AUC达0.939,显著优于CatBoost(默认0.752;调优4小时后0.822);
    • 回归任务:归一化RMSE为0.923,优于CatBoost的0.872(调优后差距扩大至0.093);
    • 效率优势:速度比调优4小时的基线快5,140倍(分类)和3,000倍(回归)。

2. 鲁棒性验证

  • 抗干扰能力:对无信息特征(随机添加)、异常值(随机放大至12倍标准差)的敏感度低于传统神经网络(如MLP);
  • 数据缩减实验:仅用50%样本时,性能仍与CatBoost(全量数据)相当。

3. 基础模型能力展示

  • 密度估计:成功建模双缝实验(double-slit experiment)的光强多峰分布;
  • 数据生成:合成数据与真实数据(如German Credit数据集)分布一致;
  • 嵌入学习:在MFAT-Factors手写数字数据集上,特征嵌入呈现清晰的类别聚类。

四、研究结论与价值

  1. 科学价值
    • 首次将上下文学习(ICL)应用于表格数据,证明了通过合成数据预训练可自动发现高效算法;
    • 为小规模表格数据建立了新的性能基准,推动端到端算法设计范式转变。
  2. 应用价值
    • 加速生物医学风险模型、药物发现等领域的决策流程;
    • 支持快速迭代的数据科学工作流(如特征工程、超参数调优)。

五、研究亮点

  1. 方法创新
    • 二维Transformer架构:突破序列化处理的局限,实现表格结构的显式建模;
    • 合成数据驱动:避免真实数据隐私问题,覆盖广泛的数据分布。
  2. 性能突破
    • 在Kaggle竞赛子集(≤10,000样本)中,默认TabPFN全面优于CatBoost;
    • 支持概率预测(如回归任务的多峰分布建模)。

六、其他价值内容

  • 开源与可访问性:代码公开(Zenodo存档),提供API接口,支持消费级GPU运行;
  • 扩展应用:已探索时序数据(如ECG、神经影像)和遗传数据的适配潜力,为后续研究指明方向。

该研究通过“定义行为示例而非显式指令”的算法设计理念,为表格数据建模开辟了新路径,其方法框架或可推广至其他结构化数据领域。

上述解读依据用户上传的学术文献,如有不准确或可能侵权之处请联系本站站长:admin@fmread.com