本文介绍一下最近比较火的Late Chunking(延迟分块) ,它用于改进文本嵌入的质量。该方法利用长上下文嵌入模型,在变换器模型之后和均值池化之前进行分块,从而捕获完整的上下文信息。与传统的朴素分块方法相比,延迟分块在各种检索任务中表现更优,且无需额外训练。此外,本文还提出了长延迟分块方法,用于处理超过模型上下文长度的长文档,并介绍了一种基于跨度池化的训练方法,以进一步提高检索准确性。
延迟分块(Late Chunking)
延迟分块是一种利用最近嵌入模型的长上下文输入窗口与大多数应用程序的最佳文本块相对较小的大小之间的差异的策略。这些模型支持更长的输入文本,例如jina-embeddings-v2-small支持8192个标记——大约十页标准文本——而最佳块大小通常要小得多,例如段落的大小。
算法流程如下:
输入
-
文本 :待处理的文本。 -
分块策略 :用于确定分块边界的策略(如固定大小边界、句子边界或语义句子边界)。
输出
-
块嵌入 :由每个块的嵌入向量组成的序列 。
具体步骤如下:
-
分块:使用分块策略 将文本 分块,得到分块序列 。
-
标记化:
-
将文本 标记化为一系列标记(tokens),得到标记ID序列 和每个标记的字符长度序列 。
-
-
嵌入生成:
-
使用嵌入模型对标记ID序列 进行编码,生成每个标记的嵌入向量序列 。
-
-
确定分块边界:
-
计算分块 的起始和结束字符位置 和 。 -
初始化 和 为0。 -
遍历标记序列 : -
将 添加到 列表中。 -
如果当前标记的字符长度 加上之前的字符长度总和等于 ,则设置 为 。 -
如果当前标记的字符长度 加上之前的字符长度总和等于 ,则设置 为 。 -
初始化一个空列表 用于存储分块边界。 -
对于每个分块 :
-
-
均值池化:
-
对标记嵌入序列 进行均值池化,生成该块的嵌入向量 。 -
将 添加到 列表中。 -
初始化一个空列表 用于存储块嵌入。 -
对于每个分块边界 :
-
-
返回块嵌入:
-
返回块嵌入序列 。
-
长延迟分块(Long Late Chunking)
对于超过模型上下文长度的长文档,论文中提出了长延迟分块方法。
目标:
输入:
-
文本 :待处理的文本。 -
分块策略 :用于确定分块边界的策略(如固定大小边界、句子边界或语义句子边界)。 -
最大标记长度 :每个宏块的最大标记数量。 -
重叠长度 :宏块之间的重叠标记数量。
输出:
-
块嵌入 :由每个块的嵌入向量组成的序列 。
具体详细步骤:
-
分块:使用分块策略 将文本 分块,得到分块序列 。
-
标记化:将文本 标记化为一系列标记(tokens),得到标记ID序列 和每个标记的字符长度序列 。
-
检查标记数量:如果标记数量 小于等于最大标记长度 ,则直接使用延迟分块方法处理。
-
初始化变量:初始化变量 为 , 为 1,以及一个空列表 用于存储块嵌入。
-
处理宏块:
-
使用嵌入模型对标记ID序列 进行编码,生成每个标记的嵌入向量序列 。 -
如果 为 1,则将所有标记嵌入 添加到 列表中。 -
否则,将重叠部分的标记嵌入 添加到 列表中。 -
更新标记位置: 更新为 , 更新为 。 -
当 小于标记总数 时,执行以下步骤:
-
-
返回块嵌入:
-
返回块嵌入序列 。
-
训练方法(Training Method)
论文还提出的训练方法旨在进一步提高延迟分块(Late Chunking)的检索准确性。该方法基于跨度池化(Span Pooling)技术,通过训练模型将标注文本跨度中的相关信息编码到其标记嵌入中。以下是对训练方法的详细介绍,包括训练数据、训练过程和相关细节。
1. 训练数据准备
训练数据由查询、相关文档和文档中相关跨度的标注组成。具体步骤如下:
-
数据收集:从数据集中收集查询、相关文档和文档中相关跨度的标注。 -
数据格式化:将数据格式化为元组,其中是查询,是相关文档,是文档中相关跨度的标注。
数据集
论文中使用了两个数据集进行训练:
-
FEVER(Thorne et al., 2018):该数据集包含从维基百科中提取的文档,并标注了文档中支持或反驳查询声明的相关句子。 -
TriviaQA(Joshi et al., 2017):该数据集包含从维基百科和网页中提取的文档,并标注了文档中包含答案的相关短语。
2. 训练过程
训练过程包括以下几个关键步骤:
-
标记化:对文档进行标记化,生成标记序列。 -
嵌入生成:使用长上下文嵌入模型对文档的所有标记进行编码,生成每个标记的嵌入向量。 -
跨度池化:对标注跨度内的标记嵌入进行均值池化,生成文档嵌入。 -
InfoNCE损失:使用InfoNCE损失进行训练,确保查询嵌入与相关文档嵌入的相似度高于与其他文档嵌入的相似度。
2.1 标记化
-
输入:文档 。 -
输出:标记ID序列 和每个标记的字符长度序列 。
2.2 嵌入生成
-
输入:标记ID序列 。 -
输出:标记嵌入向量序列 。
2.3 跨度池化
-
输入:标记嵌入向量序列 和标注跨度 。 -
输出:文档嵌入 。
2.4 InfoNCE损失
-
输入:查询嵌入 和文档嵌入 。 -
输出:InfoNCE损失 。
训练超参数
-
模型:jina-embeddings-v3和jina-embeddings-v2-small-en。 -
训练数据:FEVER和TriviaQA。 -
批量大小:512。 -
训练步数:500步。 -
损失函数:InfoNCE损失。
微调过程本身遵循Gunther et al.(2023)中描述的成对训练阶段,其中模型在文本对上使用InfoNCE(van den Oord et al., 2018)损失进行训练,该损失定义在一批的对上,并使用余弦相似度函数:
在这里,查询向量通过通常的方式将嵌入模型应用于查询文本获得。对于文档嵌入,通过将模型应用于文档获得标记嵌入集,并对跨度内的标记嵌入执行均值池化操作,因此称为“跨度池化”。
文中使用双向版本的损失,其中通过交换对的顺序从获得:


