目录

PyTorch

目录

PyTorch

SkorchPyTorch_ 带来了 Scikit-learn API。Skorch 允许将 PyTorch 模型封装在与 Scikit-learn 兼容的估计器中。这意味着用 Skorch 封装的 PyTorch 模型可以与 Dask-ML API 的其余部分一起使用。例如,将用 Skorch 封装的 PyTorch 模型与 Dask-ML 的 HyperbandSearchCVIncremental 一起使用是可能的。

我们建议查阅 Skorch 文档以获取完整详细信息。

使用示例

首先,我们创建一个普通的 PyTorch 模型

import torch.nn as nn
import torch.nn.functional as F

class ShallowNet(nn.Module):
    def __init__(self, n_features=5):
        super().__init__()
        self.layer1 = nn.Linear(n_features, 1)

    def forward(self, x):
        return F.relu(self.layer1(x))

有了它,使用 Skorch 变得很容易

from skorch import NeuralNetRegressor
import torch.optim as optim

niceties = {
    "callbacks": False,
    "warm_start": False,
    "train_split": None,
    "max_epochs": 1,
}

model = NeuralNetRegressor(
    module=ShallowNet,
    module__n_features=5,
    criterion=nn.MSELoss,
    optimizer=optim.SGD,
    optimizer__lr=0.1,
    optimizer__momentum=0.9,
    batch_size=64,
    **niceties,
)

PyTorch nn.Module 的每个参数都以 module__ 为前缀,优化器(optim.SGD 接受 lrmomentum 参数)也一样。niceties 确保 Skorch 使用所有数据进行训练,并且不会打印过多的日志。

现在,这个模型可以与 Dask-ML 一起使用。例如,可以执行以下操作