关键词:EEG;自监督学习;对比学习
有监督学习方法受到数据收集与标记成本的限制,而自监督学习方法作为一个pre-training或者feature learning的方法,在计算机视觉和时间序列分析等领域的前景十分的广阔。本文中,我们提出的自监督模型可用于学习信息性表示(informative representations)。一种成功的方法依赖于预测是否从相同的时间上下文中采样了时间窗口。正如临床相关任务(睡眠评分)和两个脑电图数据集所证明的那样,我们的方法在低数据情况下优于纯监督方法,同时无需获取标签即可捕获重要的生理信息。
借助SSL,数据的结构可用于将无监督的学习问题转变为有监督的学习问题,称为“前置任务” [3]。
在自我监督的前置任务上学习的表示然后可以在监督的下游任务上重用,从而可能大大减少所需的带标签示例的数量。
Hyvérinen从非线性独立组件分析的角度正式对SSL提出了一种general且有理论依据的方法。
该方法中提出SSL任务是通过使用辅助变量u(time index、segment index、history of the data)来构建,以训练对比分类器。此分类器学习预测样本与辅助变量是否配对。
本文提出的自监督方法是为了从无标签的EEG信号中学习端到端的特征。引入了两个时间对比学习任务:relative positioning 和 temporal shuffling。实验中显示这些基于预测时间窗口是否在时间上接近的对比学习任务可用于学习 EEG 功能,这些特征可以捕获数据背后的结构的多个组件。文中证明了这些功能在下游任务中重复使用时的效果优于无监督模型与传统的监督模型。
section2:the SSL tasks and learning problems
$S\in R^{M*C}$:输入数据 $M$ :时间样本数 $C$:通道数
$y\in -1 ,1$ :二标签
$T$:每个时间窗口的采样点数量
$\tau_{pos}$:positive context的持续时间
$\tau_{neg}$:对应于每个窗口周围的negative context的范围
2.1 pretext—-relative position
假设:相邻的时间窗口对应的标签相同
对时间窗口进行取样,$x_{t}$是锚窗口anchor window
将自定义的$N$个标签对定义为:
2.2 pretext—-temporal shuffling
从positive context 中取样第三个样本$x{t^{‘’}}$,并用它提供额外的参考点与$x{t^{‘}}$对比。此时标签给予以下定义:
$y_{i}={\begin{array}{rcl}1 & \mbox if & t<t^{‘}
2.3 feature extractor
为了了解端到端如何根据时间窗口的相对位置或顺序来区分时间窗口,我们引入了一个特征提取器:
$h:R^{T*C}\to R^{D}$ 参数$\Theta$ 将窗口$x$映射到其特征空间上。然后使用对比模块聚合每个窗口的特征表示。
对于RP pretext:$g_{RP}:R^{D}*R^{D}\to R^{D}$
例如通过计算an elementwise absolute difference
对于TS pretext:$g_{TS}:R^{D}R^{D}R^{D}\to R^{2D}$ 通过合并绝对差异
$g_{TS}(h(x),h(x^{‘},h(x’’))=(abs(h(x)-h(x^{‘})),abs((h(x^{‘})-h(x^{‘’})))\in R^{2D}$
2.4 predict
一个带有参数$\omega \in R^{D}$ or $\in R^{2D}$和偏差$\omega_{0}$的线性上下文判别模型用于预测相关目标$y$
联合损失函数写为:
按照$y$使用的惯例,预测目标是的符号
RP 和 TS 模型都可分别被视为具有两个或三个子网络的siamese神经网络。
section 3:在睡眠数据上的应用
3.1 两个睡眠数据集
提取30 s的非重叠窗口,在Sleep EDF上生成T = 2000和C = 2的窗口,在MASS上生成C = 3的T = 3840。
对窗口进行归一化,以便通道的均值为0,标准差为1。在每个recording中,总共对2000个锚点窗口进行了均匀采样。对于每个锚窗口,采样了三个正样本和三个负样本。
训练集测试集和验证集划分:
Sleep EDF数据集上:验证集——受试者0-19;测试集——受试者20-39;训练集——受试者40-82
==分别生成了训练集:512622;验证集:267,630;测试集:342,300对==
MASS 数据集:训练集——1-41;验证集——42-52;测试集——52-62
==分别生成了训练集:237,882;验证集:52,152;测试集:73,650对==
3.2 模型算法
对于特征提取器h,采用的是之前提出的架构【S. Chambon, M. N. Galtier, P. J. Arnal, G. Wainrib, and
A. Gramfort, “A deep learning architecture for temporal sleep stage classification using multivariate and multimodal time series,” IEEE Trans Neur Syst Rehab Eng, vol. 26, no. 4, pp. 758–769, 2018. 】
输入:$(C,T,1)$
CNN结构:
SleepEDF数据集:k=50,m=13,D=100 可训练参数个数:55545
MASS数据:k=64,m=16,D=100 可训练参数个数:67173
3.3 模型比较
将经过SSL任务训练的模型与三个神经网络的baseline进行比较:
1)随机初始化
2)卷积自动编码器。自编码器使用特征提取器h作为编码器,四层卷积作为解码器,均方误差作为重建损失。
3)纯监督学习。在特征提取器h中增加softmax层
人工提取了脑电图特征:均值方差、偏度峰度、标准差、(0.5,4,8,13,30,49)Hz之间的频率对数功率带及其所有可能的比率、峰峰值、Hurst指数、近似熵和Hjorth复杂度。致使每个EEG通道有34个特征,将这些特征串联至单个向量组成特征向量。
为解决class之间的不平衡问题,我们使用平衡acc(bal acc)【定义为每个class的平均召回率】来评估下游任务的模型性能。另外,训练时,加权损失也可解决class的不平衡。
3.4 实验
实验一:分析了不同SSL超参数值下CNN的性能及其对CNN的影响。
实验二:带有有限标签的SSL任务对提高预测性能的作用。
实验三:探究SSL学习到的特征,研究他们的生理相关性。
实验一
首先评估CNN架构学习SSL任务的能力。使用具有三组超参数$\tau{pos}$和$\tau{neg}$的RP和TS任务,在整个训练集上训练特征提取器h。训练了h后,我们将标记的样本投影到网络各自的特征空间中,然后在每组特征上训练多项式线性logistic回归模型以预测睡眠阶段。
【MASS数据集上,超参数$\tau{pos}$和$\tau{neg}$(in minutes)不同值时,SSL任务和下游分类任务的平衡准确度】
从结果中可以得到以下:
1)在MASS数据集上,前两组实验($\tau{pos}$=2,$\tau{neg}$=2和$\tau{pos}$=4,$\tau{neg}$=15)中,SSL任务和下游分类任务的平衡准确度近似。
2)增大超参数$\tau_{pos}$到120后,任务更加困难,SSL任务和分类任务的性能也降低。
最终,决定选择$\tau{pos}$=4,$\tau{neg}$=15,因为相较于第一组可以增加从正context中抽取到的窗口数量
实验二
用不同的方法对特征提取器h进行训练(AE、RP、TS在无标签数据上及全监督模型在标签数据上)然后提取特征。同时使用随机初始化权重的模型(未经过培训的模型)提取特征。下游分类任务采用逻辑回归。
在MASS数据集上,SSL的性能要优于纯监督模型,RP要略高于TS,AE和随机初始化模型都较低。SleepEDF数据集上也表现出类似的结果:
使用自动编码器预制的模型获得了非常低的性能,因为使用均方误差损失的重建任务鼓励模型专注于输入信号的低频率。事实上,这些频率比像EEG这样的生物信号中的高频具有更高的功率。
实验三
为了进一步探索使用 SSL 学到的功能,我们使用 UMAP [20] 将标记的 Sleep EDF 数据集上获得的 100 维嵌入投影到两个维度。
上图中可以看出,从使用标签的彩色编码样本来看,groups不仅对应着睡眠阶段,同时按照顺序排列:从图形的右侧开始,向左移动,我们可以绘制一个连续穿过 W、N1、N2 和 N3 的轨迹。R阶段与W和N1重叠。
此外,在图2-B中,人们可以观察到嵌入编码与年龄相关的信息。年轻受试者的样本占据点云的左外部分,而来自较老受试者的样本则位于U形结构的内侧。这种现象在N1、N2和N3阶段可见,但在W和R阶段不可见,那里看不到明显的老化结构。这可以解释为睡眠主轴的流行,主要特征用于识别随着年龄的增长而变化的N2和N3。
Code
feature extractor
1 | class EEG_FeatureExtractor(nn.Module): |
ReflectionPad2d
对输入数据进行扩边,扩充方法采用镜像填充。ReflectionPad2d(n)相比原来数组每行每列都要增加n
dropout
防止过拟合的方法之一,对于神经网络单元,按照一定的概率将其暂时从网络中丢弃。
dropout=0.5时效果最好