本研究由德国弗莱堡大学的Noah Hollmann、Samuel Müller、Katharina Eggensperger和Frank Hutter(同时任职于Bosch人工智能中心)共同完成,以会议论文形式发表于ICLR 2023(International Conference on Learning Representations)。
领域定位:该研究属于机器学习与自动机器学习(AutoML)交叉领域,聚焦于表格数据分类这一长期被深度学习忽视但实际应用最广泛的数据类型。传统上,梯度提升决策树(GBDT)因其训练速度快、鲁棒性强而主导该领域,而深度学习模型因计算成本高、调参复杂未能突破。
核心问题:现有AutoML系统(如Auto-Sklearn、AutoGluon)虽能实现高性能,但需要数十分钟至数小时的超参数优化和模型训练时间,难以满足实时决策需求。
创新目标:研究团队提出TabPFN(Tabular Prior-Data Fitted Network),一种基于Transformer架构的预训练模型,可在1秒内完成小型表格数据(≤1000样本,≤100数值型特征,≤10类别)的分类任务,无需超参数调优,性能媲美主流AutoML系统。其核心突破在于将传统“逐数据集训练”模式转变为基于先验的贝叶斯推理近似,通过单次前向传播实现预测。
理论基础:PFN是一种通过合成数据预训练来近似贝叶斯后验预测分布(PPD)的模型。其关键思想是:
- 离线训练阶段:从预设的先验分布中生成大量合成数据集,训练Transformer学习如何根据任意训练集预测测试样本标签。
- 在线推理阶段:对真实数据仅需单次前向传播即可输出PPD,无需梯度更新。
TabPFN的改进:
- 架构优化:采用12层Transformer,支持可变长度输入(通过零填充处理不同特征数),修改注意力掩码以加速推理。
- 混合先验设计:结合结构因果模型(SCM)和贝叶斯神经网络(BNN)先验,强化对表格数据因果关系的建模能力。
SCM先验:
- 生成机制:从随机生成的有向无环图(DAG)中采样因果模型,节点表示特征或目标变量,边表示因果关系(如图2所示)。
- 奥卡姆剃刀原则:偏好结构简单的SCM(如节点数少的模型概率更高)。
- 数据多样性:通过不同激活函数和噪声分布生成复杂的数据分布(如图3a展示的合成数据与真实数据对比)。
BNN先验:
- 从随机采样的神经网络架构(层数、宽度、激活函数)生成数据,作为SCM先验的补充。
分类任务转换:将SCM/BNN生成的连续标签通过区间划分映射为离散类别,支持多类不平衡数据。
训练阶段:
- 数据生成:18,000批次×512个合成数据集,总计20小时(8×NVIDIA RTX 2080 Ti)。
- 损失函数:交叉熵损失,优化Transformer对合成数据集中隐藏样本的预测能力。
评估基准:
- 数据集:OpenML-CC18中18个纯数值型无缺失值数据集,以及额外67个验证数据集。
- 对比方法:包括XGBoost、LightGBM、CatBoost及Auto-Sklearn 2.0、AutoGluon等AutoML系统。
- 指标:ROC AUC(多类采用OvO)、准确率、交叉熵。
作者提出17项改进计划,包括扩展至回归任务、大规模数据、对抗鲁棒性等,尤其强调因果干预推理和公平性验证的潜在突破。
(注:本文图表引用均来自原论文,实验细节可参考附录B-F及开源代码。)