dask_ml.xgboost.train

dask_ml.xgboost.train

dask_ml.xgboost.train(client, params, data, labels, dmatrix_kwargs={}, evals_result=None, sample_weight=None, **kwargs)

在 Dask 集群上训练 XGBoost 模型

这会在所有 Dask worker 上启动 XGBoost,将输入数据移动到这些 worker,然后在输入数据上调用 xgboost.train

参数
client: dask.distributed.Client
params: dict

传递给 XGBoost 的参数 (参见 xgb.Booster.train)

data: Dask 数组或 Dask 数据帧
labels: Dask 数组或 Dask 数据帧
dmatrix_kwargs: 传递给 Xgboost DMatrix 的关键字参数
evals_result: 字典,可选

通过原地修改 evals_result,存储 eval_set 中所有项的评估结果历史。

sample_weight类似数组,可选

实例权重

**kwargs: 传递给 XGBoost train 的关键字参数

另请参阅

predict

示例

>>> client = Client('scheduler-address:8786')  
>>> data = dd.read_csv('s3://...')  
>>> labels = data['outcome']  
>>> del data['outcome']  
>>> train(client, params, data, labels, **normal_kwargs)  
<xgboost.core.Booster object at ...>