Tensorflow 集成¶
Lance 可以在 Tensorflow 中用作常规的 tf.data.Dataset。
警告
此功能是实验性的,API 未来可能会更改。
从 Lance 读取¶
使用 lance.tf.data.from_lance
,您可以轻松创建 tf.data.Dataset
。
import tensorflow as tf
import lance
# Create tf dataset
ds = lance.tf.data.from_lance("s3://my-bucket/my-dataset")
# Chain tf dataset with other tf primitives
for batch in ds.shuffling(32).map(lambda x: tf.io.decode_png(x["image"])):
print(batch)
在 Lance 列式格式的支持下,使用 lance.tf.data.from_lance
支持高效的列选择、过滤等功能。
ds = lance.tf.data.from_lance(
"s3://my-bucket/my-dataset",
columns=["image", "label"],
filter="split = 'train' AND collected_time > timestamp '2020-01-01'",
batch_size=256)
默认情况下,Lance 将从投影列中推断 Tensor 规范。您也可以手动指定 tf.TensorSpec
。
batch_size = 256
ds = lance.tf.data.from_lance(
"s3://my-bucket/my-dataset",
columns=["image", "labels"],
batch_size=batch_size,
output_signature={
"image": tf.TensorSpec(shape=(), dtype=tf.string),
"labels": tf.RaggedTensorSpec(
dtype=tf.int32, shape=(batch_size, None), ragged_rank=1),
},
分布式训练和混洗¶
由于 Lance 数据集是一组片段,我们可以将片段分发和混洗到不同的 worker。
import tensorflow as tf
from lance.tf.data import from_lance, lance_fragments
world_size = 32
rank = 10
seed = 123 #
epoch = 100
dataset_uri = "s3://my-bucket/my-dataset"
# Shuffle fragments distributedly.
fragments =
lance_fragments("s3://my-bucket/my-dataset")
.shuffling(32, seed=seed)
.repeat(epoch)
.enumerate()
.filter(lambda i, _: i % world_size == rank)
.map(lambda _, fid: fid)
ds = from_lance(
uri,
columns=["image", "label"],
fragments=fragments,
batch_size=32
)
for batch in ds:
print(batch)
警告
对于多进程,您可能不应该使用 fork,因为 Lance 内部是多线程的,而 fork 和多线程不能很好地协同工作。请参阅此讨论。