使用 Lance 进行深度学习工件管理¶
除了数据集,Lance 文件格式还可以用于保存和版本化深度学习模型权重。事实上,使用 Lance 文件格式来管理 PyTorch 模型权重可以使深度学习工件管理更加精简(与普通的权重保存方法相比)。
在本例中,我们将演示如何使用 Lance 保存、版本化和加载 PyTorch 模型的权重。更具体地说,我们将加载一个预训练的 ResNet 模型,以 Lance 文件格式保存它,然后将其重新加载到 PyTorch 中,并验证权重是否确实相同。我们还将演示如何在一个 Lance 数据集中版本化模型权重,这得益于我们的零拷贝、自动版本化功能。
关键思想:当你在 PyTorch 中保存模型的权重(即:状态字典)时,权重以键值对的形式存储在 OrderedDict
中,其中键表示权重的名称,值表示相应的权重张量。为了尽可能地模拟这一点,我们将把权重保存在三列中。第一列将包含权重的名称,第二列将包含权重本身(但被展平为列表),第三列将包含权重的原始形状,以便在加载到模型中时可以重新构建它们。
导入和设置¶
我们将从导入和加载所有必要的模块开始。
import os
import shutil
import lance
import pyarrow as pa
import torch
from collections import OrderedDict
我们还将定义一个 GLOBAL_SCHEMA
,它将决定权重表的结构。
GLOBAL_SCHEMA = pa.schema(
[
pa.field("name", pa.string()),
pa.field("value", pa.list_(pa.float64(), -1)),
pa.field("shape", pa.list_(pa.int64(), -1)), # Is a list with variable shape because weights can have any number of dims
]
)
正如我们前面提到的,权重表将有三列——一列用于存储权重名称,一列用于存储展平的权重值,一列用于存储原始权重形状以便重新加载。
保存和版本化模型¶
首先我们将关注模型保存部分。让我们从编写一个实用函数开始,该函数将接受模型的 state dict,遍历每个权重,将其展平,然后以 pyarrow RecordBatch
的形式返回权重名称、展平的权重和权重的原始形状。
def _save_model_writer(state_dict):
"""Yields a RecordBatch for each parameter in the model state dict"""
for param_name, param in state_dict.items():
param_shape = list(param.size())
param_value = param.flatten().tolist()
yield pa.RecordBatch.from_arrays(
[
pa.array(
[param_name],
pa.string(),
),
pa.array(
[param_value],
pa.list_(pa.float64(), -1),
),
pa.array(
[param_shape],
pa.list_(pa.int64(), -1),
),
],
["name", "value", "shape"],
)
现在谈谈版本化:假设你用一些新数据训练了你的模型,但又不想覆盖旧的检查点,你现在可以将这些新训练的模型权重作为 Lance 权重数据集中的一个版本保存。这将允许你从一个 Lance 权重数据集中加载特定版本的权重,而不必为每个模型检查点创建单独的文件夹。
让我们编写一个函数来处理模型的保存工作,无论是否带有版本。
def save_model(state_dict: OrderedDict, file_name: str, version=False):
"""Saves a PyTorch model in lance file format
Args:
state_dict (OrderedDict): Model state dict
file_name (str): Lance model name
version (bool): Whether to save as a new version or overwrite the existing versions,
if the lance file already exists
"""
# Create a reader
reader = pa.RecordBatchReader.from_batches(
GLOBAL_SCHEMA, _save_model_writer(state_dict)
)
if os.path.exists(file_name):
if version:
# If we want versioning, we use the overwrite mode to create a new version
lance.write_dataset(
reader, file_name, schema=GLOBAL_SCHEMA, mode="overwrite"
)
else:
# If we don't want versioning, we delete the existing file and write a new one
shutil.rmtree(file_name)
lance.write_dataset(reader, file_name, schema=GLOBAL_SCHEMA)
else:
# If the file doesn't exist, we write a new one
lance.write_dataset(reader, file_name, schema=GLOBAL_SCHEMA)
上述函数将接收模型状态字典、Lance 保存的文件名和权重版本。该函数将首先使用全局 schema 和我们上面编写的实用函数创建一个 RecordBatchReader
。如果权重 Lance 数据集已存在于目录中,我们将只将其保存为新版本(如果启用了版本控制)或删除旧文件并保存为新权重。否则,权重将正常保存。
加载模型¶
将权重从 Lance 权重数据集加载到模型中与保存它们的过程正好相反。关键部分是将展平的权重恢复到其原始形状,由于你保存了与权重对应的形状,这变得更容易。为了更好的可读性,我们将此过程分为三个函数。
第一个函数是 _load_weight
函数,它将从 Lance 权重数据集中检索到的“权重”作为 torch 张量以其原始形状返回。我们从 Lance 权重数据集中检索到的“权重”将是一个字典,其中值对应于每列,以键的形式表示。
def _load_weight(weight: dict) -> torch.Tensor:
"""Converts a weight dict to a torch tensor"""
return torch.tensor(weight["value"], dtype=torch.float64).reshape(weight["shape"])
(可选地,您还可以添加一个选项来指定权重的datatype。)
下一个函数将是将所有权重从 lance 权重数据集加载到状态字典中,这是 PyTorch 在我们将权重加载到模型中时所期望的。
def _load_state_dict(file_name: str, version: int = 1, map_location=None) -> OrderedDict:
"""Reads the model weights from lance file and returns a model state dict
If the model weights are too large, this function will fail with a memory error.
Args:
file_name (str): Lance model name
version (int): Version of the model to load
map_location (str): Device to load the model on
Returns:
OrderedDict: Model state dict
"""
ds = lance.dataset(file_name, version=version)
weights = ds.take([x for x in range(ds.count_rows())]).to_pylist()
state_dict = OrderedDict()
for weight in weights:
state_dict[weight["name"]] = _load_weight(weight).to(map_location)
return state_dict
load_state_dict
函数将期望一个 lance 权重数据集文件名、一个版本以及权重将加载到的设备。我们本质上将所有权重从 lance 权重数据集加载到内存中,并使用我们之前编写的实用函数迭代地将它们转换为权重,然后将它们放置在设备上。
这里需要注意的一点是,如果保存的权重大于内存,此函数将失败。为了简化起见,我们假设要加载的权重可以放入内存中,并且我们不必处理任何分片。
最后,我们将编写一个更高级别的函数,它是我们唯一会调用的用于加载权重的函数。
def load_model(
model: torch.nn.Module, file_name: str, version: int = 1, map_location=None
):
"""Loads the model weights from lance file and sets them to the model
Args:
model (torch.nn.Module): PyTorch model
file_name (str): Lance model name
version (int): Version of the model to load
map_location (str): Device to load the model on
"""
state_dict = _load_state_dict(file_name, version=version, map_location=map_location)
model.load_state_dict(state_dict)
load_model
函数将需要模型、lance 权重数据集名称、要加载的权重版本和映射位置。这只会调用 _load_state_dict
实用程序来获取状态字典,然后将该状态字典加载到模型中。
总结¶
总而言之,你只需调用 save_model
和 load_model
两个函数即可分别保存和加载模型,只要权重可以放入内存并使用 PyTorch,就应该没问题。
尽管仍处于实验阶段,但这种方法定义了一种进行深度学习工件管理的新方式。