First, we import the packages needed for both the GPU and CPU versions of Spark XGBoost:
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql._
import org.apache.spark.ml._
import org.apache.spark.ml.feature._
import org.apache.spark.ml.evaluation._
import ml.dmlc.xgboost4j.scala.spark.{XGBoostRegressor, XGBoostRegressionModel}
When using the GPU version of Spark XGBoost, you also need the following import:
import ml.dmlc.xgboost4j.scala.spark.rapids.{GpuDataReader, GpuDataset}
We specify the schema with a Spark StructType.
lazy val schema =
  StructType(Array(
    StructField("vendor_id", DoubleType),
    StructField("passenger_count", DoubleType),
    StructField("trip_distance", DoubleType),
    StructField("pickup_longitude", DoubleType),
    StructField("pickup_latitude", DoubleType),
    StructField("rate_code", DoubleType),
    StructField("store_and_fwd", DoubleType),
    StructField("dropoff_longitude", DoubleType),
    StructField("dropoff_latitude", DoubleType),
    StructField(labelName, DoubleType),
    StructField("hour", DoubleType),
    StructField("year", IntegerType),
    StructField("month", IntegerType),
    StructField("day", DoubleType),
    StructField("day_of_week", DoubleType),
    StructField("is_weekend", DoubleType)
  ))
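In the following code, we create a Spark session and set the paths to the training and evaluation data files. (Note: if you are using a notebook, you do not need to create the SparkSession.)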
val trainPath = "/FileStore/tables/taxi_tsmall.csv"
val evalPath = "/FileStore/tables/taxi_esmall.csv"
val spark = SparkSession.builder().appName("Taxi-GPU").getOrCreate
We load the data from the CSV files into Spark DataFrames, specifying the data source and the schema to apply, as shown below.
val tdf = spark.read
  .option("inferSchema", "false")
  .option("header", true)
  .schema(schema)
  .csv(trainPath)
val edf = spark.read
  .option("inferSchema", "false")
  .option("header", true)
  .schema(schema)
  .csv(evalPath)
DataFrame show(5) displays the first 5 rows:
tdf.select("trip_distance", "rate_code","fare_amount").show(5)
result:
+------------------+-------------+-----------+
|     trip_distance|    rate_code|fare_amount|
+------------------+-------------+-----------+
|              2.72|-6.77418915E8|       11.5|
|              0.94|-6.77418915E8|        5.5|
|              3.63|-6.77418915E8|       13.0|
|             11.86|-6.77418915E8|       33.5|
|              3.03|-6.77418915E8|       11.0|
+------------------+-------------+-----------+
The describe() function returns a DataFrame containing descriptive summary statistics for each numeric column: count, mean, standard deviation, minimum, and maximum.
tdf.select("trip_distance", "rate_code","fare_amount").describe().show
+-------+------------------+--------------------+------------------+
|summary|     trip_distance|           rate_code|       fare_amount|
+-------+------------------+--------------------+------------------+
|  count|              7999|                7999|              7999|
|   mean| 3.278923615451919|-6.569284350812602E8|12.348543567945994|
| stddev|3.6320775770793547|1.6677419425906155E8|10.221929466939088|
|    min|               0.0|       -6.77418915E8|               2.5|
|    max|35.970000000000006|       1.957796822E9|             107.5|
+-------+------------------+--------------------+------------------+
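To make the rows of the describe() output concrete, here is a minimal plain-Scala sketch of the same five statistics for a single numeric column. It is not Spark's implementation; the names SummarySketch, Summary, and summarize are hypothetical, and it assumes that the stddev row is the sample standard deviation (dividing by n - 1), which is our understanding of what describe() reports. The input is the five trip_distance values from the show(5) output above, not the full 7999-row training set.

```scala
// Illustrative sketch only: the statistics describe() reports, computed by hand.
// SummarySketch, Summary, and summarize are hypothetical names, not Spark APIs.
object SummarySketch {
  final case class Summary(count: Long, mean: Double, stddev: Double, min: Double, max: Double)

  def summarize(values: Seq[Double]): Summary = {
    require(values.nonEmpty, "need at least one value")
    val n = values.size
    val mean = values.sum / n
    // Assumption: sample standard deviation (divides by n - 1); NaN for a single value.
    val stddev =
      if (n > 1) math.sqrt(values.map(v => (v - mean) * (v - mean)).sum / (n - 1))
      else Double.NaN
    Summary(n.toLong, mean, stddev, values.min, values.max)
  }

  def main(args: Array[String]): Unit = {
    // The five trip_distance values from the show(5) output.
    val sample = Seq(2.72, 0.94, 3.63, 11.86, 3.03)
    println(summarize(sample))
  }
}
```

Because only five rows are used, the numbers will differ from the table above, which summarizes the whole training DataFrame.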
The following scatter plot explores the correlation between fare amount and trip distance.
%sql
select trip_distance, fare_amount
from taxi
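The SQL cell above queries a table named taxi, so the DataFrame presumably has to be visible to SQL under that name. The sketch below shows one way this could be done, with createOrReplaceTempView, and also computes the Pearson correlation numerically with DataFrame.stat.corr as a complement to the plot. It is an illustration under assumptions: the object name TaxiCorrSketch and the method fareDistanceCorr are hypothetical, the local SparkSession is only there to make the example self-contained, and the five rows from the show(5) output stand in for tdf.

```scala
import org.apache.spark.sql.SparkSession

// Illustrative sketch only; TaxiCorrSketch and fareDistanceCorr are hypothetical names.
object TaxiCorrSketch {
  def fareDistanceCorr(spark: SparkSession): Double = {
    import spark.implicits._

    // The five rows from the show(5) output, standing in for the training DataFrame.
    val sample = Seq(
      (2.72, 11.5), (0.94, 5.5), (3.63, 13.0), (11.86, 33.5), (3.03, 11.0)
    ).toDF("trip_distance", "fare_amount")

    // Register the DataFrame so that SQL can refer to it as `taxi`.
    sample.createOrReplaceTempView("taxi")
    spark.sql("select trip_distance, fare_amount from taxi").show()

    // Pearson correlation coefficient between the two columns.
    sample.stat.corr("trip_distance", "fare_amount")
  }

  def main(args: Array[String]): Unit = {
    // Local session for illustration; in a notebook, reuse the provided `spark` instead.
    val spark = SparkSession.builder().appName("taxi-corr-sketch").master("local[*]").getOrCreate()
    println(s"Pearson correlation: ${fareDistanceCorr(spark)}")
    spark.stop()
  }
}
```

With the real data, the equivalent calls would be tdf.createOrReplaceTempView("taxi") followed by tdf.stat.corr("trip_distance", "fare_amount").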