Keras 和 Tensorflow

Keras 和 Tensorflow

SciKeras 软件包为 Keras 带来了 Scikit-learn API。这使得 Dask-ML 可以与 Keras 模型无缝结合使用。

安装

按照 Tensorflow 安装说明SciKeras 安装指南,需要安装以下软件包:

$ pip install tensorflow>=2.3.0
$ pip install scikeras>=0.1.8

这些是 Dask-ML 使用 Tensorflow/Keras 所需的最低版本。

用法

首先,我们先定义一个常规函数来创建模型。这是创建 Keras Sequential 模型 的常规方法

import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Sequential

def build_model(lr=0.01, momentum=0.9):
    layers = [Dense(512, input_shape=(784,), activation="relu"),
              Dense(10, input_shape=(512,), activation="softmax")]
    model = Sequential(layers)

    opt = tf.keras.optimizers.SGD(
        learning_rate=lr, momentum=momentum, nesterov=True,
    )
    model.compile(loss="categorical_crossentropy", optimizer=opt, metrics=["accuracy"])
    return model

现在,我们可以使用 SciKeras 来创建一个与 Scikit-learn 兼容的模型

from scikeras.wrappers import KerasClassifier
niceties = dict(verbose=False)
model = KerasClassifier(build_fn=build_model, lr=0.1, momentum=0.9, **niceties)

这个模型将与 Dask-ML 的所有功能配合使用:它可以将 NumPy 数组作为输入,并遵循 Scikit-learn API。例如,可以使用 Dask-ML 来执行以下操作:

如果我们想调整 lrmomentum,SciKeras 要求我们在初始化时传入 lrmomentum

model = KerasClassifier(build_fn=build_model, lr=None, momentum=None, **niceties)

SciKeras 支持更多模型创建方法,包括一些与 Tensorflow 向后兼容的方法。详细信息请参考其文档。

示例:超参数优化

如果需要,我们可以使用上面提到的模型与 HyperbandSearchCV 结合使用。让我们在 MNIST 数据集上调整这个模型

from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical
import numpy as np
from typing import Tuple

def get_mnist() -> Tuple[np.ndarray, np.ndarray]:
    (X_train, y_train), _ = mnist.load_data()
    X_train = X_train.reshape(X_train.shape[0], 784)
    X_train = X_train.astype("float32")
    X_train /= 255
    return X_train, y_train

然后,我们执行调整 SGD 实现的基本任务

from scipy.stats import loguniform, uniform
params = {"lr": loguniform(1e-3, 1e-1), "momentum": uniform(0, 1)}
X, y = get_mnist()

现在,可以运行搜索了

from dask.distributed import Client
client = Client()

from dask_ml.model_selection import HyperbandSearchCV
search = HyperbandSearchCV(model, params, max_iter=27)
search.fit(X, y)