跳到内容

Tensorflow 集成

Lance 可以在 Tensorflow 中用作常规的 tf.data.Dataset

警告

此功能是实验性的,API 未来可能会更改。

从 Lance 读取

使用 lance.tf.data.from_lance,您可以轻松创建 tf.data.Dataset

import tensorflow as tf
import lance

# Create tf dataset
ds = lance.tf.data.from_lance("s3://my-bucket/my-dataset")

# Chain tf dataset with other tf primitives

for batch in ds.shuffling(32).map(lambda x: tf.io.decode_png(x["image"])):
    print(batch)

在 Lance 列式格式的支持下,使用 lance.tf.data.from_lance 支持高效的列选择、过滤等功能。

ds = lance.tf.data.from_lance(
    "s3://my-bucket/my-dataset",
    columns=["image", "label"],
    filter="split = 'train' AND collected_time > timestamp '2020-01-01'",
    batch_size=256)

默认情况下,Lance 将从投影列中推断 Tensor 规范。您也可以手动指定 tf.TensorSpec

batch_size = 256
ds = lance.tf.data.from_lance(
    "s3://my-bucket/my-dataset",
    columns=["image", "labels"],
    batch_size=batch_size,
    output_signature={
        "image": tf.TensorSpec(shape=(), dtype=tf.string),
        "labels": tf.RaggedTensorSpec(
            dtype=tf.int32, shape=(batch_size, None), ragged_rank=1),
    },

分布式训练和混洗

由于 Lance 数据集是一组片段,我们可以将片段分发和混洗到不同的 worker。

import tensorflow as tf
from lance.tf.data import from_lance, lance_fragments

world_size = 32
rank = 10
seed = 123  #
epoch = 100

dataset_uri = "s3://my-bucket/my-dataset"

# Shuffle fragments distributedly.
fragments =
    lance_fragments("s3://my-bucket/my-dataset")
    .shuffling(32, seed=seed)
    .repeat(epoch)
    .enumerate()
    .filter(lambda i, _: i % world_size == rank)
    .map(lambda _, fid: fid)

ds = from_lance(
    uri,
    columns=["image", "label"],
    fragments=fragments,
    batch_size=32
    )
for batch in ds:
    print(batch)

警告

对于多进程,您可能不应该使用 fork,因为 Lance 内部是多线程的,而 fork 和多线程不能很好地协同工作。请参阅此讨论