dask_ml.xgboost.train
dask_ml.xgboost.train¶
- dask_ml.xgboost.train(client, params, data, labels, dmatrix_kwargs={}, evals_result=None, sample_weight=None, **kwargs)¶
在 Dask 集群上训练 XGBoost 模型
这会在所有 Dask worker 上启动 XGBoost,将输入数据移动到这些 worker,然后在输入数据上调用
xgboost.train
。- 参数
- client: dask.distributed.Client
- params: dict
传递给 XGBoost 的参数 (参见 xgb.Booster.train)
- data: Dask 数组或 Dask 数据帧
- labels: Dask 数组或 Dask 数据帧
- dmatrix_kwargs: 传递给 Xgboost DMatrix 的关键字参数
- evals_result: 字典,可选
通过原地修改 evals_result,存储 eval_set 中所有项的评估结果历史。
- sample_weight类似数组,可选
实例权重
- **kwargs: 传递给 XGBoost train 的关键字参数
另请参阅
示例
>>> client = Client('scheduler-address:8786') >>> data = dd.read_csv('s3://...') >>> labels = data['outcome'] >>> del data['outcome'] >>> train(client, params, data, labels, **normal_kwargs) <xgboost.core.Booster object at ...>