Keras 和 Tensorflow
目录
Keras 和 Tensorflow¶
SciKeras 软件包为 Keras 带来了 Scikit-learn API。这使得 Dask-ML 可以与 Keras 模型无缝结合使用。
安装¶
按照 Tensorflow 安装说明 和 SciKeras 安装指南,需要安装以下软件包:
$ pip install tensorflow>=2.3.0
$ pip install scikeras>=0.1.8
这些是 Dask-ML 使用 Tensorflow/Keras 所需的最低版本。
用法¶
首先,我们先定义一个常规函数来创建模型。这是创建 Keras Sequential 模型 的常规方法
import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Sequential
def build_model(lr=0.01, momentum=0.9):
layers = [Dense(512, input_shape=(784,), activation="relu"),
Dense(10, input_shape=(512,), activation="softmax")]
model = Sequential(layers)
opt = tf.keras.optimizers.SGD(
learning_rate=lr, momentum=momentum, nesterov=True,
)
model.compile(loss="categorical_crossentropy", optimizer=opt, metrics=["accuracy"])
return model
现在,我们可以使用 SciKeras 来创建一个与 Scikit-learn 兼容的模型
from scikeras.wrappers import KerasClassifier
niceties = dict(verbose=False)
model = KerasClassifier(build_fn=build_model, lr=0.1, momentum=0.9, **niceties)
这个模型将与 Dask-ML 的所有功能配合使用:它可以将 NumPy 数组作为输入,并遵循 Scikit-learn API。例如,可以使用 Dask-ML 来执行以下操作:
将 Keras 与 Dask-ML 的模型选择结合使用,包括
HyperbandSearchCV
。将 Keras 与 Dask-ML 的
Incremental
结合使用。
如果我们想调整 lr
和 momentum
,SciKeras 要求我们在初始化时传入 lr
和 momentum
model = KerasClassifier(build_fn=build_model, lr=None, momentum=None, **niceties)
SciKeras 支持更多模型创建方法,包括一些与 Tensorflow 向后兼容的方法。详细信息请参考其文档。
示例:超参数优化¶
如果需要,我们可以使用上面提到的模型与 HyperbandSearchCV
结合使用。让我们在 MNIST 数据集上调整这个模型
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical
import numpy as np
from typing import Tuple
def get_mnist() -> Tuple[np.ndarray, np.ndarray]:
(X_train, y_train), _ = mnist.load_data()
X_train = X_train.reshape(X_train.shape[0], 784)
X_train = X_train.astype("float32")
X_train /= 255
return X_train, y_train
然后,我们执行调整 SGD 实现的基本任务
from scipy.stats import loguniform, uniform
params = {"lr": loguniform(1e-3, 1e-1), "momentum": uniform(0, 1)}
X, y = get_mnist()
现在,可以运行搜索了
from dask.distributed import Client
client = Client()
from dask_ml.model_selection import HyperbandSearchCV
search = HyperbandSearchCV(model, params, max_iter=27)
search.fit(X, y)