跳到内容

使用 Lance 创建多模态数据集

得益于 Lance 文件格式能够存储不同模态的数据,Lance 在存储多模态数据集方面表现出色。在这个简短的示例中,我们将介绍如何获取多模态数据集并将其存储为 Lance 文件格式。

此处选择的数据集是 Flickr8k 数据集。Flickr8k 是一个用于基于句子的图像描述和搜索的基准集合,包含 8,000 张图像,每张图像都配有五个不同的标题,这些标题清晰地描述了显著的实体和事件。图像选自六个不同的 Flickr 群组,通常不包含任何知名人物或地点,而是经过手动选择以描绘各种场景和情况。

我们将使用上述 Flickr8k 数据集创建用于多模态模型训练的图像-标题对数据集,并将其保存为 Lance 数据集形式,其中包含图像文件名、每张图像的所有标题(顺序保留)以及图像本身(二进制格式)。

导入和设置

我们假设您已下载数据集,更具体地说,是“Flickr8k.token.txt”文件和“Flicker8k_Dataset/”文件夹,并且两者都存在于当前目录中。可以从此处下载(下载数据集和文本 zip 文件)。

我们还假设您已安装 pyarrow 和 pylance,以及 opencv(用于读取图像)和 tqdm(用于美观的进度条)。

现在让我们从导入开始,并定义标题文件和图像数据集文件夹。

import os
import cv2
import random

import lance
import pyarrow as pa

import matplotlib.pyplot as plt

from tqdm.auto import tqdm

captions = "Flickr8k.token.txt"
image_folder = "Flicker8k_Dataset/"

加载和处理

在 flickr8k 数据集中,每张图像都有多个按顺序对应的标题。我们将把所有这些标题放在一个列表中,对应于每张图像,列表中它们的位置代表了它们最初出现的顺序。让我们将注解(图像路径和相应的标题)加载到一个列表中,列表的每个元素都是一个元组,由图像名称、标题编号和标题本身组成。

with open(captions, "r") as fl:
    annotations = fl.readlines()

# Converts the annotations where each element of this list is a tuple consisting of image file name, caption number and caption itself
annotations = list(map(lambda x: tuple([*x.split('\t')[0].split('#'), x.split('\t')[1]]), annotations))

现在,对于同一图像的所有标题,我们将它们放入一个按顺序排序的列表中。

captions = []
image_ids = set(ann[0] for ann in annotations)
for img_id in tqdm(image_ids):
    current_img_captions = []
    for ann_img_id, num, caption in annotations:
        if img_id == ann_img_id:
            current_img_captions.append((num, caption))

    # Sort by the annotation number
    current_img_captions.sort(key=lambda x: x[0])
    captions.append((img_id, tuple([x[1] for x in current_img_captions])))

转换为 Lance 数据集

现在我们的标题列表已采用正确的格式,我们将编写一个 `process()` 函数,该函数将以所述标题作为参数,并生成一个 Pyarrow 记录批次,其中包含 `image_id`、`image` 和 `captions`。此记录批次中的图像将是二进制格式,并且图像的所有标题都将放在一个列表中,并保留其顺序。

def process(captions):
    for img_id, img_captions in tqdm(captions):
        try:
            with open(os.path.join(image_folder, img_id), 'rb') as im:
                binary_im = im.read()

        except FileNotFoundError:
            print(f"img_id '{img_id}' not found in the folder, skipping.")
            continue

        img_id = pa.array([img_id], type=pa.string())
        img = pa.array([binary_im], type=pa.binary())
        capt = pa.array([img_captions], pa.list_(pa.string(), -1))

        yield pa.RecordBatch.from_arrays(
            [img_id, img, capt], 
            ["image_id", "image", "captions"]
        )

我们还要定义相同的模式,以告诉 Pyarrow 表中应该期望的数据类型。

schema = pa.schema([
    pa.field("image_id", pa.string()),
    pa.field("image", pa.binary()),
    pa.field("captions", pa.list_(pa.string(), -1)),
])

我们包括 `image_id`(即原始图像名称),以便将来更容易引用和调试。

最后,我们定义一个读取器来迭代地读取这些记录批次,然后将它们写入磁盘上的 lance 数据集。

reader = pa.RecordBatchReader.from_batches(schema, process(captions))
lance.write_dataset(reader, "flickr8k.lance", schema)

基本上就是这样!如果您想以 Notebook 形式执行此操作,可以在我们的 deeplearning-recipes 存储库此处查看此示例。

有关使用 Lance 数据集的更多深度学习相关示例,请务必查看 lance-deeplearning-recipes 存储库!