交叉验证

交叉验证

关于交叉验证的更全面讨论,请参阅 scikit-learn 交叉验证文档。本文档仅描述了为支持 Dask 数组而进行的扩展。

分割一个或多个 Dask 数组的最简单方法是使用 dask_ml.model_selection.train_test_split()

In [1]: import dask.array as da

In [2]: from dask_ml.datasets import make_regression

In [3]: from dask_ml.model_selection import train_test_split

In [4]: X, y = make_regression(n_samples=125, n_features=4, random_state=0, chunks=50)

In [5]: X
Out[5]: dask.array<normal, shape=(125, 4), dtype=float64, chunksize=(50, 4), chunktype=numpy.ndarray>

分割 Dask 数组的接口与 scikit-learn 的版本相同。

In [6]: X_train, X_test, y_train, y_test = train_test_split(X, y)

In [7]: X_train  # A dask Array
Out[7]: dask.array<concatenate, shape=(112, 4), dtype=float64, chunksize=(45, 4), chunktype=numpy.ndarray>

In [8]: X_train.compute()[:3]
Out[8]: 
array([[ 0.72278447,  1.44812909, -0.48500419, -0.13090068],
       [-0.10656729, -0.19912227, -1.29342938,  0.67371525],
       [-1.05879148, -0.89722898,  0.96089406, -0.97585992]])

虽然可以将 dask 数组传递给 sklearn.model_selection.train_test_split(),但出于性能考虑,我们建议使用 Dask 版本:Dask 版本更快,原因有二

首先,Dask 版本按块进行混洗。在分布式环境中,块之间的混洗可能需要在机器之间传输大量数据,这可能会很慢。但是,如果您的数据中存在强烈的模式,您将需要执行完全混洗。

其次,Dask 版本避免分配大的中间 NumPy 数组来存储切片索引。对于非常大的数据集,创建和传输 np.arange(n_samples) 可能非常耗时。