分享
[算法学习] Test-Time Training的工作原理
输入“/”快速插入内容
[算法学习] Test-Time Training的工作原理
用户8537
用户5027
2025年12月10日修改
6025
6496
👨💻
作者:吵爷
基础概念
Test-Time Training(TTT)算法是一种
机器学习方法
,旨在通过在测试阶段使用额外的训练步骤来提高模型的性能。它是一种将
训练过程与测试过程动态结合
的技术,主要应用于处理测试分布与训练分布存在偏差的场景,例如
领域自适应
或应对
分布漂移
的问题。
核心思想
TTT算法在模型测试时,不仅直接用模型进行预测,还会引入一个
额外的目标函数(auxiliary loss)
,用测试样本来微调或更新模型的某些部分(通常是低阶参数,如特征提取器)。这个目标函数的设计可以帮助模型在新的分布下重新学习表征,从而提高对当前测试样本的适应性。
TTT的特点
动态自适应
:模型在测试时可以动态调整,以适应测试数据分布。
低计算开销
:相比完全重新训练模型,TTT仅在小规模样本上进行少量训练,效率较高。
无监督适应
:通常只利用测试样本的无监督信号(如自监督任务的损失),不需要额外标注。
举个例子,假设在训练阶段,一个分类模型还同时被训练了一个辅助任务,如预测输入样本的旋转角度。在测试阶段,模型会通过优化旋转预测任务的损失函数,更新特征提取层的参数,使其更好地适应当前测试数据分布,然后再进行分类。
和标准Fine-tuning的区别
计算实例
假设我们有一个训练好的二维平面线性分类器的模型,主要任务是对红点和蓝点进行分类(随便写一个),测试时通过辅助任务微调特征提取器参数。
原始分类模型:
当z 大于等于0,则判定为蓝点,如果z小于0则判定为红点
然后我现在增加了一个没有训练过的辅助任务,目标是让模型预测测试点到原点的距离平方。在训练分类任务时,参数 w1,w2,b仅被调整以优化分类损失。现在,辅助任务的损失引入了一个新的梯度方向,模型的参数将同时适应两种任务(分类和距离预测):
初始设置
1.
数据设置:
测试点:x = (2,4)
辅助任务:预测点到原点的距离
2.
模型设置:特征提取器公式
这里的θ代表数据点两个维度的对应权重
设置初始参数:
初始学习率
3.
辅助任务目标:方便一点直接用欧几里得距离平方,不开根号计算实际值了
步骤1:辅助任务微调