跳到内容

使用 Lance 进行深度学习工件管理

除了数据集,Lance 文件格式还可以用于保存和版本化深度学习模型权重。事实上,使用 Lance 文件格式来管理 PyTorch 模型权重可以使深度学习工件管理更加精简(与普通的权重保存方法相比)。

在本例中,我们将演示如何使用 Lance 保存、版本化和加载 PyTorch 模型的权重。更具体地说,我们将加载一个预训练的 ResNet 模型,以 Lance 文件格式保存它,然后将其重新加载到 PyTorch 中,并验证权重是否确实相同。我们还将演示如何在一个 Lance 数据集中版本化模型权重,这得益于我们的零拷贝、自动版本化功能。

关键思想:当你在 PyTorch 中保存模型的权重(即:状态字典)时,权重以键值对的形式存储在 OrderedDict 中,其中键表示权重的名称,值表示相应的权重张量。为了尽可能地模拟这一点,我们将把权重保存在三列中。第一列将包含权重的名称,第二列将包含权重本身(但被展平为列表),第三列将包含权重的原始形状,以便在加载到模型中时可以重新构建它们。

导入和设置

我们将从导入和加载所有必要的模块开始。

import os
import shutil
import lance
import pyarrow as pa
import torch
from collections import OrderedDict

我们还将定义一个 GLOBAL_SCHEMA,它将决定权重表的结构。

GLOBAL_SCHEMA = pa.schema(
    [
        pa.field("name", pa.string()),
        pa.field("value", pa.list_(pa.float64(), -1)),
        pa.field("shape", pa.list_(pa.int64(), -1)), # Is a list with variable shape because weights can have any number of dims
    ]
)

正如我们前面提到的,权重表将有三列——一列用于存储权重名称,一列用于存储展平的权重值,一列用于存储原始权重形状以便重新加载。

保存和版本化模型

首先我们将关注模型保存部分。让我们从编写一个实用函数开始,该函数将接受模型的 state dict,遍历每个权重,将其展平,然后以 pyarrow RecordBatch 的形式返回权重名称、展平的权重和权重的原始形状。

def _save_model_writer(state_dict):
    """Yields a RecordBatch for each parameter in the model state dict"""
    for param_name, param in state_dict.items():
        param_shape = list(param.size())
        param_value = param.flatten().tolist()
        yield pa.RecordBatch.from_arrays(
            [
                pa.array(
                    [param_name],
                    pa.string(),
                ),
                pa.array(
                    [param_value],
                    pa.list_(pa.float64(), -1),
                ),
                pa.array(
                    [param_shape],
                    pa.list_(pa.int64(), -1),
                ),
            ],
            ["name", "value", "shape"],
        )

现在谈谈版本化:假设你用一些新数据训练了你的模型,但又不想覆盖旧的检查点,你现在可以将这些新训练的模型权重作为 Lance 权重数据集中的一个版本保存。这将允许你从一个 Lance 权重数据集中加载特定版本的权重,而不必为每个模型检查点创建单独的文件夹。

让我们编写一个函数来处理模型的保存工作,无论是否带有版本。

def save_model(state_dict: OrderedDict, file_name: str, version=False):
    """Saves a PyTorch model in lance file format

    Args:
        state_dict (OrderedDict): Model state dict
        file_name (str): Lance model name
        version (bool): Whether to save as a new version or overwrite the existing versions,
            if the lance file already exists
    """
    # Create a reader
    reader = pa.RecordBatchReader.from_batches(
        GLOBAL_SCHEMA, _save_model_writer(state_dict)
    )

    if os.path.exists(file_name):
        if version:
            # If we want versioning, we use the overwrite mode to create a new version
            lance.write_dataset(
                reader, file_name, schema=GLOBAL_SCHEMA, mode="overwrite"
            )
        else:
            # If we don't want versioning, we delete the existing file and write a new one
            shutil.rmtree(file_name)
            lance.write_dataset(reader, file_name, schema=GLOBAL_SCHEMA)
    else:
        # If the file doesn't exist, we write a new one
        lance.write_dataset(reader, file_name, schema=GLOBAL_SCHEMA)

上述函数将接收模型状态字典、Lance 保存的文件名和权重版本。该函数将首先使用全局 schema 和我们上面编写的实用函数创建一个 RecordBatchReader。如果权重 Lance 数据集已存在于目录中,我们将只将其保存为新版本(如果启用了版本控制)或删除旧文件并保存为新权重。否则,权重将正常保存。

加载模型

将权重从 Lance 权重数据集加载到模型中与保存它们的过程正好相反。关键部分是将展平的权重恢复到其原始形状,由于你保存了与权重对应的形状,这变得更容易。为了更好的可读性,我们将此过程分为三个函数。

第一个函数是 _load_weight 函数,它将从 Lance 权重数据集中检索到的“权重”作为 torch 张量以其原始形状返回。我们从 Lance 权重数据集中检索到的“权重”将是一个字典,其中值对应于每列,以键的形式表示。

def _load_weight(weight: dict) -> torch.Tensor:
    """Converts a weight dict to a torch tensor"""
    return torch.tensor(weight["value"], dtype=torch.float64).reshape(weight["shape"])

(可选地,您还可以添加一个选项来指定权重的datatype。)

下一个函数将是将所有权重从 lance 权重数据集加载到状态字典中,这是 PyTorch 在我们将权重加载到模型中时所期望的。

def _load_state_dict(file_name: str, version: int = 1, map_location=None) -> OrderedDict:
    """Reads the model weights from lance file and returns a model state dict
    If the model weights are too large, this function will fail with a memory error.

    Args:
        file_name (str): Lance model name
        version (int): Version of the model to load
        map_location (str): Device to load the model on

    Returns:
        OrderedDict: Model state dict
    """
    ds = lance.dataset(file_name, version=version)
    weights = ds.take([x for x in range(ds.count_rows())]).to_pylist()
    state_dict = OrderedDict()

    for weight in weights:
        state_dict[weight["name"]] = _load_weight(weight).to(map_location)

    return state_dict

load_state_dict 函数将期望一个 lance 权重数据集文件名、一个版本以及权重将加载到的设备。我们本质上将所有权重从 lance 权重数据集加载到内存中,并使用我们之前编写的实用函数迭代地将它们转换为权重,然后将它们放置在设备上。

这里需要注意的一点是,如果保存的权重大于内存,此函数将失败。为了简化起见,我们假设要加载的权重可以放入内存中,并且我们不必处理任何分片。

最后,我们将编写一个更高级别的函数,它是我们唯一会调用的用于加载权重的函数。

def load_model(
    model: torch.nn.Module, file_name: str, version: int = 1, map_location=None
):
    """Loads the model weights from lance file and sets them to the model

    Args:
        model (torch.nn.Module): PyTorch model
        file_name (str): Lance model name
        version (int): Version of the model to load
        map_location (str): Device to load the model on
    """
    state_dict = _load_state_dict(file_name, version=version, map_location=map_location)
    model.load_state_dict(state_dict)

load_model 函数将需要模型、lance 权重数据集名称、要加载的权重版本和映射位置。这只会调用 _load_state_dict 实用程序来获取状态字典,然后将该状态字典加载到模型中。

总结

总而言之,你只需调用 save_modelload_model 两个函数即可分别保存和加载模型,只要权重可以放入内存并使用 PyTorch,就应该没问题。

尽管仍处于实验阶段,但这种方法定义了一种进行深度学习工件管理的新方式。