PyTorch
目录
PyTorch¶
Skorch 为 PyTorch_ 带来了 Scikit-learn API。Skorch 允许将 PyTorch 模型封装在与 Scikit-learn 兼容的估计器中。这意味着用 Skorch 封装的 PyTorch 模型可以与 Dask-ML API 的其余部分一起使用。例如,将用 Skorch 封装的 PyTorch 模型与 Dask-ML 的 HyperbandSearchCV
或 Incremental
一起使用是可能的。
我们建议查阅 Skorch 文档以获取完整详细信息。
使用示例¶
首先,我们创建一个普通的 PyTorch 模型
import torch.nn as nn
import torch.nn.functional as F
class ShallowNet(nn.Module):
def __init__(self, n_features=5):
super().__init__()
self.layer1 = nn.Linear(n_features, 1)
def forward(self, x):
return F.relu(self.layer1(x))
有了它,使用 Skorch 变得很容易
from skorch import NeuralNetRegressor
import torch.optim as optim
niceties = {
"callbacks": False,
"warm_start": False,
"train_split": None,
"max_epochs": 1,
}
model = NeuralNetRegressor(
module=ShallowNet,
module__n_features=5,
criterion=nn.MSELoss,
optimizer=optim.SGD,
optimizer__lr=0.1,
optimizer__momentum=0.9,
batch_size=64,
**niceties,
)
PyTorch nn.Module
的每个参数都以 module__
为前缀,优化器(optim.SGD
接受 lr
和 momentum
参数)也一样。niceties
确保 Skorch 使用所有数据进行训练,并且不会打印过多的日志。
现在,这个模型可以与 Dask-ML 一起使用。例如,可以执行以下操作
将 PyTorch 与 Dask-ML 的模型选择一起使用,包括
HyperbandSearchCV
。将 PyTorch 与 Dask-ML 的
Incremental
一起使用。