Spark Preprocessing & FE Practice


I've recently been working on a recommender system, so to get familiar with PySpark and with processing log data, I decided to work through this dataset.


Dataset Overview
Ali_Display_Ad_Click is a Taobao display-advertising click-through-rate (CTR) prediction dataset released by Alibaba. It can be downloaded from Tianchi (dataset ID 56).

The dataset consists of four tables:

raw_sample: the raw sample skeleton (user ID, ad ID, timestamp, ad position, whether clicked)
ad_feature: basic ad information (ad ID, campaign ID, category ID, brand ID)
user_profile: basic user information (user ID, age level, gender, etc.)
raw_behavior_log: user behavior log (user ID, behavior type, timestamp, item category ID, brand ID)

Raw sample skeleton: raw_sample
We randomly sampled 1.14 million users from Taobao and collected 8 days of their ad impression/click logs (26 million records) to form the raw sample skeleton.
Fields:
(1) user_id: de-identified user ID;
(2) adgroup_id: de-identified ad unit ID;
(3) time_stamp: timestamp;
(4) pid: ad position (resource slot);
(5) nonclk: 1 = not clicked, 0 = clicked;
(6) clk: 0 = not clicked, 1 = clicked.
The first 7 days are used as training samples (20170506-20170512) and the 8th day as test samples (20170513).

Ad information table: ad_feature
This table covers the basic information of all ads appearing in raw_sample. Fields:
(1) adgroup_id: de-identified ad ID;
(2) cate_id: de-identified item category ID;
(3) campaign_id: de-identified ad campaign ID;
(4) customer_id: de-identified advertiser ID (the column in the CSV is named customer);
(5) brand: de-identified brand ID;
(6) price: price of the item.
One ad ID corresponds to one item; an item belongs to one category and one brand.

User information table: user_profile
This table covers the basic information of all users appearing in raw_sample. Fields:
(1) userid: de-identified user ID;
(2) cms_segid: micro-group ID;
(3) cms_group_id: CMS group ID;
(4) final_gender_code: gender; 1 = male, 2 = female;
(5) age_level: age bracket;
(6) pvalue_level: consumption level; 1 = low, 2 = mid, 3 = high;
(7) shopping_level: shopping depth; 1 = light user, 2 = moderate user, 3 = heavy user;
(8) occupation: college student or not; 1 = yes, 0 = no;
(9) new_user_class_level: city tier.

User behavior log: behavior_log
This table covers 22 days of shopping behavior for all users appearing in raw_sample (about 700 million records). Fields:
(1) user: de-identified user ID;
(2) time_stamp: timestamp;
(3) btag: behavior type, one of four kinds: pv (browse), cart (add to cart), fav (favorite), buy (purchase);
(4) cate: de-identified item category;
(5) brand: de-identified brand.
Taking user + time_stamp as the key, there are many duplicate records. This is because the different behavior types were logged by different departments, and small offsets crept in when the logs were packaged together (two identical time_stamps may in fact be two slightly different times). The sketch below shows one way to quantify and collapse such duplicates.
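A minimal sketch, assuming the behavior log is loaded into a DataFrame called behavior_df with the columns above (the path is hypothetical, mirroring the CSVs used later; the log itself is not processed further in this walkthrough):

from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.appName('behavior_log').getOrCreate()
# hypothetical path; adjust to wherever behavior_log.csv actually lives
behavior_df = spark.read.csv('D:/阿里ctr预估数据集/behavior_log.csv', header=True)

# how many (user, time_stamp) keys occur more than once
dup_keys = behavior_df.groupBy('user', 'time_stamp').count().filter(F.col('count') > 1)
print('duplicated (user, time_stamp) keys:', dup_keys.count())

# rows identical in every field can safely be collapsed to one
deduped = behavior_df.dropDuplicates(['user', 'time_stamp', 'btag', 'cate', 'brand'])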

Preprocessing & Feature Engineering

from pyspark.sql import SparkSession

spark = SparkSession.builder.appName('raw_sample').getOrCreate()
df = spark.read.csv(r'D:\阿里ctr预估数据集\raw_sample.csv', header=True)
df.show()
+------+----------+----------+-----------+------+---+
|  user|time_stamp|adgroup_id|        pid|nonclk|clk|
+------+----------+----------+-----------+------+---+
|581738|1494137644|         1|430548_1007|     1|  0|
|449818|1494638778|         3|430548_1007|     1|  0|
|914836|1494650879|         4|430548_1007|     1|  0|
|914836|1494651029|         5|430548_1007|     1|  0|
|399907|1494302958|         8|430548_1007|     1|  0|
|628137|1494524935|         9|430548_1007|     1|  0|
|298139|1494462593|         9|430539_1007|     1|  0|
|775475|1494561036|         9|430548_1007|     1|  0|
|555266|1494307136|        11|430539_1007|     1|  0|
|117840|1494036743|        11|430548_1007|     1|  0|
|739815|1494115387|        11|430539_1007|     1|  0|
|623911|1494625301|        11|430548_1007|     1|  0|
|623911|1494451608|        11|430548_1007|     1|  0|
|421590|1494034144|        11|430548_1007|     1|  0|
|976358|1494156949|        13|430548_1007|     1|  0|
|286630|1494218579|        13|430539_1007|     1|  0|
|286630|1494289247|        13|430539_1007|     1|  0|
|771431|1494153867|        13|430548_1007|     1|  0|
|707120|1494220810|        13|430548_1007|     1|  0|
|530454|1494293746|        13|430548_1007|     1|  0|
+------+----------+----------+-----------+------+---+
only showing top 20 rows

# data overview
print('number of samples', df.count())
print('rows with nulls', df.count() - df.dropna().count())
number of samples 26557961
rows with nulls 0
# count per class; sort by clk so clk=0 comes first (collect() order is otherwise not guaranteed)
row1, row2 = df.groupBy('clk').count().orderBy('clk').collect()
r = row2.asDict()['count'] / row1.asDict()['count']
print('ratio of clicked to non-clicked impressions', r)
ratio of clicked to non-clicked impressions 0.05422599045209166
df.printSchema()
root
 |-- user: string (nullable = true)
 |-- time_stamp: string (nullable = true)
 |-- adgroup_id: string (nullable = true)
 |-- pid: string (nullable = true)
 |-- nonclk: string (nullable = true)
 |-- clk: string (nullable = true)

from pyspark.sql.types import StructField, StructType, IntegerType, LongType

# cast columns to their proper types
raw_sample_df = df.\
    withColumn('user', df.user.cast(IntegerType())).\
    withColumn('time_stamp', df.time_stamp.cast(LongType())).\
    withColumn('nonclk', df.nonclk.cast(IntegerType())).\
    withColumn('clk', df.clk.cast(IntegerType()))
raw_sample_df.printSchema()
raw_sample_df.show()
root
 |-- user: integer (nullable = true)
 |-- time_stamp: long (nullable = true)
 |-- adgroup_id: string (nullable = true)
 |-- pid: string (nullable = true)
 |-- nonclk: integer (nullable = true)
 |-- clk: integer (nullable = true)

+------+----------+----------+-----------+------+---+
|  user|time_stamp|adgroup_id|        pid|nonclk|clk|
+------+----------+----------+-----------+------+---+
|581738|1494137644|         1|430548_1007|     1|  0|
|449818|1494638778|         3|430548_1007|     1|  0|
|914836|1494650879|         4|430548_1007|     1|  0|
|914836|1494651029|         5|430548_1007|     1|  0|
|399907|1494302958|         8|430548_1007|     1|  0|
|628137|1494524935|         9|430548_1007|     1|  0|
|298139|1494462593|         9|430539_1007|     1|  0|
|775475|1494561036|         9|430548_1007|     1|  0|
|555266|1494307136|        11|430539_1007|     1|  0|
|117840|1494036743|        11|430548_1007|     1|  0|
|739815|1494115387|        11|430539_1007|     1|  0|
|623911|1494625301|        11|430548_1007|     1|  0|
|623911|1494451608|        11|430548_1007|     1|  0|
|421590|1494034144|        11|430548_1007|     1|  0|
|976358|1494156949|        13|430548_1007|     1|  0|
|286630|1494218579|        13|430539_1007|     1|  0|
|286630|1494289247|        13|430539_1007|     1|  0|
|771431|1494153867|        13|430548_1007|     1|  0|
|707120|1494220810|        13|430548_1007|     1|  0|
|530454|1494293746|        13|430548_1007|     1|  0|
+------+----------+----------+-----------+------+---+
only showing top 20 rows

Feature Engineering

from pyspark.ml.feature import OneHotEncoder
from pyspark.ml.feature import StringIndexer
from pyspark.ml import Pipeline

# StringIndexer encodes a string column into label indices, e.g. 'a', 'b', 'c' -> 0, 1, 2
stringindexer = StringIndexer(inputCol='pid', outputCol='pid_feature')
# one-hot encoding
encoder = OneHotEncoder(dropLast=False, inputCol='pid_feature', outputCol='pid_value')
# wrap the encoding steps in a pipeline
pipeline = Pipeline(stages=[stringindexer, encoder])
pipeline = pipeline.fit(raw_sample_df)
df = pipeline.transform(raw_sample_df)
df.show()
+------+----------+----------+-----------+------+---+-----------+-------------+
|  user|time_stamp|adgroup_id|        pid|nonclk|clk|pid_feature|    pid_value|
+------+----------+----------+-----------+------+---+-----------+-------------+
|581738|1494137644|         1|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|449818|1494638778|         3|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|914836|1494650879|         4|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|914836|1494651029|         5|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|399907|1494302958|         8|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|628137|1494524935|         9|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|298139|1494462593|         9|430539_1007|     1|  0|        1.0|(2,[1],[1.0])|
|775475|1494561036|         9|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|555266|1494307136|        11|430539_1007|     1|  0|        1.0|(2,[1],[1.0])|
|117840|1494036743|        11|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|739815|1494115387|        11|430539_1007|     1|  0|        1.0|(2,[1],[1.0])|
|623911|1494625301|        11|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|623911|1494451608|        11|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|421590|1494034144|        11|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|976358|1494156949|        13|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|286630|1494218579|        13|430539_1007|     1|  0|        1.0|(2,[1],[1.0])|
|286630|1494289247|        13|430539_1007|     1|  0|        1.0|(2,[1],[1.0])|
|771431|1494153867|        13|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|707120|1494220810|        13|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
|530454|1494293746|        13|430548_1007|     1|  0|        0.0|(2,[0],[1.0])|
+------+----------+----------+-----------+------+---+-----------+-------------+
only showing top 20 rows

# the new column produced by pyspark.ml.feature.OneHotEncoder has type
# pyspark.ml.linalg.SparseVector (a sparse vector)
# the dense representation of the vector (1.0, 0.0, 1.0, 3.0) is [1.0, 0.0, 1.0, 3.0]
# its sparse representation is (4, [0, 2, 3], [1.0, 1.0, 3.0]) => (length, indices, values)
from pyspark.ml.linalg import SparseVector

print(SparseVector(4, [1, 3], [3.0, 4.0]))
print(SparseVector(4, [1, 3], [3.0, 4.0]).toArray())  # convert to numpy.ndarray
print(df.select("pid_value").first())
print(df.select("pid_value").first().pid_value.toArray())
(4,[1,3],[3.0,4.0])
[0. 3. 0. 4.]
Row(pid_value=SparseVector(2, {0: 1.0}))
[1. 0.]
df.describe('time_stamp').show()
+-------+--------------------+
|summary|          time_stamp|
+-------+--------------------+
|  count|            26557961|
|   mean|1.4943547981415155E9|
| stddev|  198755.26175048228|
|    min|          1494000000|
|    max|          1494691186|
+-------+--------------------+

# time span; the timestamps are in seconds
time_temp = 1494691186 - 1494000000
# the gap between the max and min timestamps is 8 days
time_temp / (24*60*60)
7.999837962962963
# 8 days of data in total: the first 7 days form the training set, the last day the test set
train_data = raw_sample_df.filter(raw_sample_df['time_stamp'] <= (1494691186 - (24*60*60)))
test_data = raw_sample_df.filter(raw_sample_df['time_stamp'] > (1494691186 - (24*60*60)))
num1, num2 = train_data.count(), test_data.count()
num1, num2, num1 / (num1+num2)
(23249291, 3308670, 0.8754170171422422)
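As a sanity check on the split boundary, the second-resolution timestamps can be rendered as calendar dates. A small sketch using from_unixtime (dates appear in the session's local time zone):

from pyspark.sql import functions as F

# the training window should end on 2017-05-12 and the test day should be 2017-05-13
raw_sample_df.select(
    F.from_unixtime(F.min('time_stamp')).alias('first_record'),
    F.from_unixtime(F.max('time_stamp')).alias('last_record')
).show()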
# process the ad_feature dataset
df = spark.read.csv('D:/阿里ctr预估数据集/ad_feature.csv', header=True)
df.show()
+----------+-------+-----------+--------+------+-----+
|adgroup_id|cate_id|campaign_id|customer| brand|price|
+----------+-------+-----------+--------+------+-----+
|     63133|   6406|      83237|       1| 95471|170.0|
|    313401|   6406|      83237|       1| 87331|199.0|
|    248909|    392|      83237|       1| 32233| 38.0|
|    208458|    392|      83237|       1|174374|139.0|
|    110847|   7211|     135256|       2|145952|32.99|
|    607788|   6261|     387991|       6|207800|199.0|
|    375706|   4520|     387991|       6|  NULL| 99.0|
|     11115|   7213|     139747|       9|186847| 33.0|
|     24484|   7207|     139744|       9|186847| 19.0|
|     28589|   5953|     395195|      13|  NULL|428.0|
|     23236|   5953|     395195|      13|  NULL|368.0|
|    300556|   5953|     395195|      13|  NULL|639.0|
|     92560|   5953|     395195|      13|  NULL|368.0|
|    590965|   4284|      28145|      14|454237|249.0|
|    529913|   4284|      70206|      14|  NULL|249.0|
|    546930|   4284|      28145|      14|  NULL|249.0|
|    639794|   6261|      70206|      14| 37004| 89.9|
|    335413|   4284|      28145|      14|  NULL|249.0|
|    794890|   4284|      70206|      14|454237|249.0|
|    684020|   6261|      70206|      14| 37004| 99.0|
+----------+-------+-----------+--------+------+-----+
only showing top 20 rows

# replace the 'NULL' strings with -1, then continue processing (NULL here is a literal string in the CSV, not a missing value)
df = df.replace(to_replace='NULL', value='-1')
df.printSchema()
root
 |-- adgroup_id: string (nullable = true)
 |-- cate_id: string (nullable = true)
 |-- campaign_id: string (nullable = true)
 |-- customer: string (nullable = true)
 |-- brand: string (nullable = true)
 |-- price: string (nullable = true)

from pyspark.sql.types import FloatType

df = df.\
    withColumn("adgroup_id", df.adgroup_id.cast(IntegerType())).\
    withColumn("cate_id", df.cate_id.cast(IntegerType())).\
    withColumn("campaign_id", df.campaign_id.cast(IntegerType())).\
    withColumn("customer", df.customer.cast(IntegerType())).\
    withColumn("brand", df.brand.cast(IntegerType())).\
    withColumn("price", df.price.cast(FloatType()))
df.printSchema()
df.show()
root
 |-- adgroup_id: integer (nullable = true)
 |-- cate_id: integer (nullable = true)
 |-- campaign_id: integer (nullable = true)
 |-- customer: integer (nullable = true)
 |-- brand: integer (nullable = true)
 |-- price: float (nullable = true)

+----------+-------+-----------+--------+------+-----+
|adgroup_id|cate_id|campaign_id|customer| brand|price|
+----------+-------+-----------+--------+------+-----+
|     63133|   6406|      83237|       1| 95471|170.0|
|    313401|   6406|      83237|       1| 87331|199.0|
|    248909|    392|      83237|       1| 32233| 38.0|
|    208458|    392|      83237|       1|174374|139.0|
|    110847|   7211|     135256|       2|145952|32.99|
|    607788|   6261|     387991|       6|207800|199.0|
|    375706|   4520|     387991|       6|    -1| 99.0|
|     11115|   7213|     139747|       9|186847| 33.0|
|     24484|   7207|     139744|       9|186847| 19.0|
|     28589|   5953|     395195|      13|    -1|428.0|
|     23236|   5953|     395195|      13|    -1|368.0|
|    300556|   5953|     395195|      13|    -1|639.0|
|     92560|   5953|     395195|      13|    -1|368.0|
|    590965|   4284|      28145|      14|454237|249.0|
|    529913|   4284|      70206|      14|    -1|249.0|
|    546930|   4284|      28145|      14|    -1|249.0|
|    639794|   6261|      70206|      14| 37004| 89.9|
|    335413|   4284|      28145|      14|    -1|249.0|
|    794890|   4284|      70206|      14|454237|249.0|
|    684020|   6261|      70206|      14| 37004| 99.0|
+----------+-------+-----------+--------+------+-----+
only showing top 20 rows

df.describe().show()
+-------+-----------------+-----------------+------------------+------------------+------------------+------------------+
|summary|       adgroup_id|          cate_id|       campaign_id|          customer|             brand|             price|
+-------+-----------------+-----------------+------------------+------------------+------------------+------------------+
|  count|           846811|           846811|            846811|            846811|            846811|            846811|
|   mean|         423406.0|5868.593464185043|206552.60428005777|113180.40600559038|162566.00186464275|1838.8671081309947|
| stddev|244453.4237388931|2705.171203318181|125192.34090758237| 73435.83494972257|152482.73866344756| 310887.7001702612|
|    min|                1|                1|                 1|                 1|                -1|              0.01|
|    max|           846811|            12960|            423436|            255875|            461497|             1.0E8|
+-------+-----------------+-----------------+------------------+------------------+------------------+------------------+

# apart from the ad price, every column is a categorical feature; check the number of unique values
for col in df.columns[:-1]:
    print(col, df.groupBy(col).count().count())
adgroup_id 846811
cate_id 6769
campaign_id 423436
customer 255875
brand 99815

These are high-cardinality categorical features, so one-hot encoding is not appropriate. Target (mean) encoding with smoothing could be used instead, but since this isn't a competition, we'll skip it here; the sketch below shows roughly what it would look like.
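A minimal sketch of smoothed target (mean) encoding, computed on the training split only to avoid leakage. It assumes train_data from the raw_sample split above and the ad_feature frame currently bound to df; alpha is a hypothetical smoothing strength, not something tuned here:

from pyspark.sql import functions as F

alpha = 100.0  # hypothetical smoothing strength
global_ctr = train_data.agg(F.avg('clk')).first()[0]

# per-brand CTR shrunk toward the global CTR; brands with few impressions
# stay close to the prior instead of getting noisy estimates
brand_ctr = train_data.\
    withColumn('adgroup_id', F.col('adgroup_id').cast('int')).\
    join(df.select('adgroup_id', 'brand'), on='adgroup_id').\
    groupBy('brand').\
    agg(F.sum('clk').alias('clicks'), F.count('*').alias('n')).\
    withColumn('brand_ctr_smooth',
               (F.col('clicks') + alpha * global_ctr) / (F.col('n') + alpha))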


Price on its own is a meaningful reflection of an ad's attributes, so no standardization or normalization is applied here.
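That said, the describe() output above shows price is extremely right-skewed (mean about 1839, max 1e8), so if a scale-sensitive model were used downstream, a log transform would be the usual first resort. A sketch, not applied in the rest of this walkthrough (ad_df_log is a hypothetical name):

from pyspark.sql import functions as F

# log1p compresses the heavy right tail while keeping the 0.01 minimum near zero
ad_df_log = df.withColumn('price_log', F.log1p('price'))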

# the user_profile dataset
df = spark.read.csv('D:/阿里ctr预估数据集/user_profile.csv', header=True)
df.show()
df.count()
+------+---------+------------+-----------------+---------+------------+--------------+----------+---------------------+
|userid|cms_segid|cms_group_id|final_gender_code|age_level|pvalue_level|shopping_level|occupation|new_user_class_level |
+------+---------+------------+-----------------+---------+------------+--------------+----------+---------------------+
|   234|        0|           5|                2|        5|        null|             3|         0|                    3|
|   523|        5|           2|                2|        2|           1|             3|         1|                    2|
|   612|        0|           8|                1|        2|           2|             3|         0|                 null|
|  1670|        0|           4|                2|        4|        null|             1|         0|                 null|
|  2545|        0|          10|                1|        4|        null|             3|         0|                 null|
|  3644|       49|           6|                2|        6|           2|             3|         0|                    2|
|  5777|       44|           5|                2|        5|           2|             3|         0|                    2|
|  6211|        0|           9|                1|        3|        null|             3|         0|                    2|
|  6355|        2|           1|                2|        1|           1|             3|         0|                    4|
|  6823|       43|           5|                2|        5|           2|             3|         0|                    1|
|  6972|        5|           2|                2|        2|           2|             3|         1|                    2|
|  9293|        0|           5|                2|        5|        null|             3|         0|                    4|
|  9510|       55|           8|                1|        2|           2|             2|         0|                    2|
| 10122|       33|           4|                2|        4|           2|             3|         0|                    2|
| 10549|        0|           4|                2|        4|           2|             3|         0|                 null|
| 10812|        0|           4|                2|        4|        null|             2|         0|                 null|
| 10912|        0|           4|                2|        4|           2|             3|         0|                 null|
| 10996|        0|           5|                2|        5|        null|             3|         0|                    4|
| 11256|        8|           2|                2|        2|           1|             3|         0|                    3|
| 11310|       31|           4|                2|        4|           1|             3|         0|                    4|
+------+---------+------------+-----------------+---------+------------+--------------+----------+---------------------+
only showing top 20 rows

1061768
df.printSchema()
root
 |-- userid: string (nullable = true)
 |-- cms_segid: string (nullable = true)
 |-- cms_group_id: string (nullable = true)
 |-- final_gender_code: string (nullable = true)
 |-- age_level: string (nullable = true)
 |-- pvalue_level: string (nullable = true)
 |-- shopping_level: string (nullable = true)
 |-- occupation: string (nullable = true)
 |-- new_user_class_level : string (nullable = true)

# here null means a real missing value; the earlier NULL was a string
df.printSchema()
df = df.\
    withColumn('userid', df.userid.cast(IntegerType())).\
    withColumn('cms_segid', df.cms_segid.cast(IntegerType())).\
    withColumn('cms_group_id', df.cms_group_id.cast(IntegerType())).\
    withColumn('final_gender_code', df.final_gender_code.cast(IntegerType())).\
    withColumn('age_level', df.age_level.cast(IntegerType())).\
    withColumn('pvalue_level', df.pvalue_level.cast(IntegerType())).\
    withColumn('shopping_level', df.shopping_level.cast(IntegerType())).\
    withColumn('occupation', df.occupation.cast(IntegerType())).\
    withColumn('new_user_class_level ', df['new_user_class_level '].cast(IntegerType()))  # note the trailing space in 'new_user_class_level '
df.printSchema()
df.show()
root
 |-- userid: string (nullable = true)
 |-- cms_segid: string (nullable = true)
 |-- cms_group_id: string (nullable = true)
 |-- final_gender_code: string (nullable = true)
 |-- age_level: string (nullable = true)
 |-- pvalue_level: string (nullable = true)
 |-- shopping_level: string (nullable = true)
 |-- occupation: string (nullable = true)
 |-- new_user_class_level : string (nullable = true)

root
 |-- userid: integer (nullable = true)
 |-- cms_segid: integer (nullable = true)
 |-- cms_group_id: integer (nullable = true)
 |-- final_gender_code: integer (nullable = true)
 |-- age_level: integer (nullable = true)
 |-- pvalue_level: integer (nullable = true)
 |-- shopping_level: integer (nullable = true)
 |-- occupation: integer (nullable = true)
 |-- new_user_class_level : integer (nullable = true)

+------+---------+------------+-----------------+---------+------------+--------------+----------+---------------------+
|userid|cms_segid|cms_group_id|final_gender_code|age_level|pvalue_level|shopping_level|occupation|new_user_class_level |
+------+---------+------------+-----------------+---------+------------+--------------+----------+---------------------+
|   234|        0|           5|                2|        5|        null|             3|         0|                    3|
|   523|        5|           2|                2|        2|           1|             3|         1|                    2|
|   612|        0|           8|                1|        2|           2|             3|         0|                 null|
|  1670|        0|           4|                2|        4|        null|             1|         0|                 null|
|  2545|        0|          10|                1|        4|        null|             3|         0|                 null|
|  3644|       49|           6|                2|        6|           2|             3|         0|                    2|
|  5777|       44|           5|                2|        5|           2|             3|         0|                    2|
|  6211|        0|           9|                1|        3|        null|             3|         0|                    2|
|  6355|        2|           1|                2|        1|           1|             3|         0|                    4|
|  6823|       43|           5|                2|        5|           2|             3|         0|                    1|
|  6972|        5|           2|                2|        2|           2|             3|         1|                    2|
|  9293|        0|           5|                2|        5|        null|             3|         0|                    4|
|  9510|       55|           8|                1|        2|           2|             2|         0|                    2|
| 10122|       33|           4|                2|        4|           2|             3|         0|                    2|
| 10549|        0|           4|                2|        4|           2|             3|         0|                 null|
| 10812|        0|           4|                2|        4|        null|             2|         0|                 null|
| 10912|        0|           4|                2|        4|           2|             3|         0|                 null|
| 10996|        0|           5|                2|        5|        null|             3|         0|                    4|
| 11256|        8|           2|                2|        2|           1|             3|         0|                    3|
| 11310|       31|           4|                2|        4|           1|             3|         0|                    4|
+------+---------+------------+-----------------+---------+------------+--------------+----------+---------------------+
only showing top 20 rows

# every feature in the user profile table is categorical as well
na_col = ['pvalue_level', 'new_user_class_level ']
for col in df.columns:
    if col not in na_col:
        print(col, df.groupBy(col).count().count())
userid 1061768
cms_segid 97
cms_group_id 13
final_gender_code 2
age_level 7
shopping_level 3
occupation 2
# inspect the missing values of 'pvalue_level' and 'new_user_class_level '
# (show() prints the table and returns None, hence the trailing None in the output)
for col in na_col:
    print(col, df.groupBy(col).count().show())
+------------+------+
|pvalue_level| count|
+------------+------+
|        null|575917|
|           1|154436|
|           3| 37759|
|           2|293656|
+------------+------+
pvalue_level None
+---------------------+------+
|new_user_class_level | count|
+---------------------+------+
|                 null|344920|
|                    1| 80548|
|                    3|173047|
|                    4|138833|
|                    2|324420|
+---------------------+------+
new_user_class_level  None
_sum = df.count()
# fraction of null values
print('pvalue_level', 1 - df.dropna(subset=['pvalue_level']).count() / _sum)
print('new_user_class_level ', 1 - df.dropna(subset=['new_user_class_level ']).count() / _sum)
pvalue_level 0.5424132202138321
new_user_class_level  0.32485439380354275
df
DataFrame[userid: int, cms_segid: int, cms_group_id: int, final_gender_code: int, age_level: int, pvalue_level: int, shopping_level: int, occupation: int, new_user_class_level : int]
# use a random forest to impute the missing values of pvalue_level:
# treat pvalue_level as the label and the other profile fields as the feature vector;
# train on the rows where pvalue_level is present, then fill the gaps with predictions
from pyspark.mllib.regression import LabeledPoint

train = df.dropna(subset=['pvalue_level']).rdd.map(
    # LabeledPoint is a (label, feature vector) pair;
    # class labels must start at 0, and pvalue_level takes values 1, 2, 3, so subtract 1
    lambda r: LabeledPoint(r.pvalue_level - 1, [r.cms_segid, r.cms_group_id,
                                                r.final_gender_code, r.age_level,
                                                r.shopping_level, r.occupation])
)

Official documentation:

classmethod trainClassifier(data, numClasses, categoricalFeaturesInfo, numTrees, featureSubsetStrategy='auto', impurity='gini', maxDepth=4, maxBins=32, seed=None)[source]

Train a random forest model for binary or multiclass classification.

  • Parameters

    • data – Training dataset: RDD of LabeledPoint. Labels should take values {0, 1, …, numClasses-1}.

    • numClasses – Number of classes for classification.

    • categoricalFeaturesInfo – Map storing arity of categorical features. An entry (n -> k) indicates that feature n is categorical with k categories indexed from 0: {0, 1, …, k-1}.

    • numTrees – Number of trees in the random forest.

    • featureSubsetStrategy – Number of features to consider for splits at each node. Supported values: “auto”, “all”, “sqrt”, “log2”, “onethird”. If “auto” is set, this parameter is set based on numTrees: if numTrees == 1, set to “all”; if numTrees > 1 (forest) set to “sqrt”. (default: “auto”)

    • impurity – Criterion used for information gain calculation. Supported values: “gini” or “entropy”. (default: “gini”)

    • maxDepth – Maximum depth of tree (e.g. depth 0 means 1 leaf node, depth 1 means 1 internal node + 2 leaf nodes). (default: 4)

    • maxBins – Maximum number of bins used for splitting features. (default: 32)

    • seed – Random seed for bootstrapping and choosing feature subsets. Set as None to generate seed based on system time. (default: None)

    Returns

    RandomForestModel that can be used for prediction.

%%time
from pyspark.mllib.tree import RandomForest

rfc = RandomForest.trainClassifier(data=train, numClasses=3,
                                   categoricalFeaturesInfo={}, numTrees=10)
Wall time: 12.7 s
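The quality of this imputation model is never checked in the original flow; a rough in-sample sanity check, following the usual mllib pattern of predicting on an RDD of feature vectors and zipping with the labels, might look like this:

# predict on the training features and compare against the labels (in-sample only)
predictions = rfc.predict(train.map(lambda lp: lp.features))
labels_and_preds = train.map(lambda lp: lp.label).zip(predictions)
accuracy = labels_and_preds.filter(lambda lp: lp[0] == lp[1]).count() / train.count()
print('train accuracy:', accuracy)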
# select the rows where 'pvalue_level' is missing so they can be filled
pvalue_level_na_df = df.na.fill(-1).where('pvalue_level=-1')
pvalue_level_na_df.show()
+------+---------+------------+-----------------+---------+------------+--------------+----------+---------------------+
|userid|cms_segid|cms_group_id|final_gender_code|age_level|pvalue_level|shopping_level|occupation|new_user_class_level |
+------+---------+------------+-----------------+---------+------------+--------------+----------+---------------------+
|   234|        0|           5|                2|        5|          -1|             3|         0|                    3|
|  1670|        0|           4|                2|        4|          -1|             1|         0|                   -1|
|  2545|        0|          10|                1|        4|          -1|             3|         0|                   -1|
|  6211|        0|           9|                1|        3|          -1|             3|         0|                    2|
|  9293|        0|           5|                2|        5|          -1|             3|         0|                    4|
| 10812|        0|           4|                2|        4|          -1|             2|         0|                   -1|
| 10996|        0|           5|                2|        5|          -1|             3|         0|                    4|
| 11602|        0|           5|                2|        5|          -1|             3|         0|                    2|
| 11727|        0|           3|                2|        3|          -1|             3|         0|                    1|
| 12195|        0|          10|                1|        4|          -1|             3|         0|                    2|
| 12620|        0|           4|                2|        4|          -1|             2|         0|                   -1|
| 12873|        0|           5|                2|        5|          -1|             3|         0|                    2|
| 14027|        0|          10|                1|        4|          -1|             3|         0|                    3|
| 14437|        0|           5|                2|        5|          -1|             3|         0|                   -1|
| 14574|        0|           1|                2|        1|          -1|             2|         0|                   -1|
| 14985|        0|          11|                1|        5|          -1|             2|         0|                   -1|
| 15525|        0|           3|                2|        3|          -1|             3|         0|                    1|
| 17025|        0|           5|                2|        5|          -1|             3|         0|                   -1|
| 17097|        0|           4|                2|        4|          -1|             2|         0|                   -1|
| 18799|        0|           5|                2|        5|          -1|             3|         0|                    4|
+------+---------+------------+-----------------+---------+------------+--------------+----------+---------------------+
only showing top 20 rows

def feature_row(r):
    '''Select the columns that make up the feature vector.'''
    return r.cms_segid, r.cms_group_id, r.final_gender_code, r.age_level, r.shopping_level, r.occupation

# build the feature vectors for the rows to be predicted
rdd = pvalue_level_na_df.rdd.map(feature_row)
pred = rfc.predict(rdd)
pred
MapPartitionsRDD[373] at mapPartitions at PythonMLLibAPI.scala:1336
# shift the labels/predictions back up by 1 (the labels were pvalue_level - 1)
pred_df = pred.map(lambda value:value + 1).collect()
pred_df
[2.0,2.0,2.0,...]
type(pred_df)
list
# convert to a pandas DataFrame to attach the predictions: merging two Spark DataFrames
# column-wise is a pain (only pieces of one DataFrame union back together easily)
p_obj = pvalue_level_na_df.toPandas()
p_obj['pvalue_level'] = pred_df
pdf = spark.createDataFrame(p_obj)
pdf.printSchema()
root
 |-- userid: long (nullable = true)
 |-- cms_segid: long (nullable = true)
 |-- cms_group_id: long (nullable = true)
 |-- final_gender_code: long (nullable = true)
 |-- age_level: long (nullable = true)
 |-- pvalue_level: double (nullable = true)
 |-- shopping_level: long (nullable = true)
 |-- occupation: long (nullable = true)
 |-- new_user_class_level : long (nullable = true)

pdf = pdf.\
    withColumn('userid', pdf.userid.cast(IntegerType())).\
    withColumn('cms_segid', pdf.cms_segid.cast(IntegerType())).\
    withColumn('cms_group_id', pdf.cms_group_id.cast(IntegerType())).\
    withColumn('final_gender_code', pdf.final_gender_code.cast(IntegerType())).\
    withColumn('age_level', pdf.age_level.cast(IntegerType())).\
    withColumn('pvalue_level', pdf.pvalue_level.cast(IntegerType())).\
    withColumn('shopping_level', pdf.shopping_level.cast(IntegerType())).\
    withColumn('occupation', pdf.occupation.cast(IntegerType())).\
    withColumn('new_user_class_level ', pdf['new_user_class_level '].cast(IntegerType()))  # again, note the trailing space
pdf.printSchema()
root
 |-- userid: integer (nullable = true)
 |-- cms_segid: integer (nullable = true)
 |-- cms_group_id: integer (nullable = true)
 |-- final_gender_code: integer (nullable = true)
 |-- age_level: integer (nullable = true)
 |-- pvalue_level: integer (nullable = true)
 |-- shopping_level: integer (nullable = true)
 |-- occupation: integer (nullable = true)
 |-- new_user_class_level : integer (nullable = true)

# unionAll concatenates by column position (newer Spark versions call this union())
new_df = df.dropna(subset=['pvalue_level']).unionAll(pdf)
new_df.show()
+------+---------+------------+-----------------+---------+------------+--------------+----------+---------------------+
|userid|cms_segid|cms_group_id|final_gender_code|age_level|pvalue_level|shopping_level|occupation|new_user_class_level |
+------+---------+------------+-----------------+---------+------------+--------------+----------+---------------------+
|   523|        5|           2|                2|        2|           1|             3|         1|                    2|
|   612|        0|           8|                1|        2|           2|             3|         0|                 null|
|  3644|       49|           6|                2|        6|           2|             3|         0|                    2|
|  5777|       44|           5|                2|        5|           2|             3|         0|                    2|
|  6355|        2|           1|                2|        1|           1|             3|         0|                    4|
|  6823|       43|           5|                2|        5|           2|             3|         0|                    1|
|  6972|        5|           2|                2|        2|           2|             3|         1|                    2|
|  9510|       55|           8|                1|        2|           2|             2|         0|                    2|
| 10122|       33|           4|                2|        4|           2|             3|         0|                    2|
| 10549|        0|           4|                2|        4|           2|             3|         0|                 null|
| 10912|        0|           4|                2|        4|           2|             3|         0|                 null|
| 11256|        8|           2|                2|        2|           1|             3|         0|                    3|
| 11310|       31|           4|                2|        4|           1|             3|         0|                    4|
| 11739|       20|           3|                2|        3|           2|             3|         0|                    4|
| 12549|       33|           4|                2|        4|           2|             3|         0|                    2|
| 15155|       36|           5|                2|        5|           2|             1|         0|                 null|
| 15347|       20|           3|                2|        3|           2|             3|         0|                    3|
| 15455|        8|           2|                2|        2|           2|             3|         0|                    3|
| 15783|        0|           4|                2|        4|           2|             3|         0|                 null|
| 16749|        5|           2|                2|        2|           1|             3|         1|                    4|
+------+---------+------------+-----------------+---------+------------+--------------+----------+---------------------+
only showing top 20 rows
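Since unionAll matches columns by position rather than by name, it is worth a quick check that no rows were lost and the schema lined up; both counts below should equal the 1061768 rows seen earlier:

# the imputed frame should cover every original row exactly once
print(df.count(), new_df.count())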

%%time
# one-hot encoding
df = df.withColumnRenamed('new_user_class_level ', 'new_user_class_level')  # I've put up with that trailing space long enough; drop it
df = df.na.fill(-1)
# note: the encoding below continues from df (nulls filled with -1) rather than from the
# RF-imputed new_df, so -1 is simply treated as one more category
df.show()
+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+
|userid|cms_segid|cms_group_id|final_gender_code|age_level|pvalue_level|shopping_level|occupation|new_user_class_level|
+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+
|   234|        0|           5|                2|        5|          -1|             3|         0|                   3|
|   523|        5|           2|                2|        2|           1|             3|         1|                   2|
|   612|        0|           8|                1|        2|           2|             3|         0|                  -1|
|  1670|        0|           4|                2|        4|          -1|             1|         0|                  -1|
|  2545|        0|          10|                1|        4|          -1|             3|         0|                  -1|
|  3644|       49|           6|                2|        6|           2|             3|         0|                   2|
|  5777|       44|           5|                2|        5|           2|             3|         0|                   2|
|  6211|        0|           9|                1|        3|          -1|             3|         0|                   2|
|  6355|        2|           1|                2|        1|           1|             3|         0|                   4|
|  6823|       43|           5|                2|        5|           2|             3|         0|                   1|
|  6972|        5|           2|                2|        2|           2|             3|         1|                   2|
|  9293|        0|           5|                2|        5|          -1|             3|         0|                   4|
|  9510|       55|           8|                1|        2|           2|             2|         0|                   2|
| 10122|       33|           4|                2|        4|           2|             3|         0|                   2|
| 10549|        0|           4|                2|        4|           2|             3|         0|                  -1|
| 10812|        0|           4|                2|        4|          -1|             2|         0|                  -1|
| 10912|        0|           4|                2|        4|           2|             3|         0|                  -1|
| 10996|        0|           5|                2|        5|          -1|             3|         0|                   4|
| 11256|        8|           2|                2|        2|           1|             3|         0|                   3|
| 11310|       31|           4|                2|        4|           1|             3|         0|                   4|
+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+
only showing top 20 rows
Wall time: 492 ms
# the column must be cast to string before it can be one-hot encoded this way
from pyspark.sql.types import StringType

df = df.withColumn('new_user_class_level', df['new_user_class_level'].cast(StringType()))
stringindexer = StringIndexer(inputCol='new_user_class_level', outputCol='nucl_onehot_feature')
encoder = OneHotEncoder(dropLast=False, inputCol='nucl_onehot_feature', outputCol='nucl_onehot_value')
pipeline = Pipeline(stages=[stringindexer, encoder])
pipeline_fit = pipeline.fit(df)
df = pipeline_fit.transform(df)
df.show()
+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+-------------------+-----------------+
|userid|cms_segid|cms_group_id|final_gender_code|age_level|pvalue_level|shopping_level|occupation|new_user_class_level|nucl_onehot_feature|nucl_onehot_value|
+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+-------------------+-----------------+
|   234|        0|           5|                2|        5|          -1|             3|         0|                   3|                2.0|    (5,[2],[1.0])|
|   523|        5|           2|                2|        2|           1|             3|         1|                   2|                1.0|    (5,[1],[1.0])|
|   612|        0|           8|                1|        2|           2|             3|         0|                  -1|                0.0|    (5,[0],[1.0])|
|  1670|        0|           4|                2|        4|          -1|             1|         0|                  -1|                0.0|    (5,[0],[1.0])|
|  2545|        0|          10|                1|        4|          -1|             3|         0|                  -1|                0.0|    (5,[0],[1.0])|
|  3644|       49|           6|                2|        6|           2|             3|         0|                   2|                1.0|    (5,[1],[1.0])|
|  5777|       44|           5|                2|        5|           2|             3|         0|                   2|                1.0|    (5,[1],[1.0])|
|  6211|        0|           9|                1|        3|          -1|             3|         0|                   2|                1.0|    (5,[1],[1.0])|
|  6355|        2|           1|                2|        1|           1|             3|         0|                   4|                3.0|    (5,[3],[1.0])|
|  6823|       43|           5|                2|        5|           2|             3|         0|                   1|                4.0|    (5,[4],[1.0])|
|  6972|        5|           2|                2|        2|           2|             3|         1|                   2|                1.0|    (5,[1],[1.0])|
|  9293|        0|           5|                2|        5|          -1|             3|         0|                   4|                3.0|    (5,[3],[1.0])|
|  9510|       55|           8|                1|        2|           2|             2|         0|                   2|                1.0|    (5,[1],[1.0])|
| 10122|       33|           4|                2|        4|           2|             3|         0|                   2|                1.0|    (5,[1],[1.0])|
| 10549|        0|           4|                2|        4|           2|             3|         0|                  -1|                0.0|    (5,[0],[1.0])|
| 10812|        0|           4|                2|        4|          -1|             2|         0|                  -1|                0.0|    (5,[0],[1.0])|
| 10912|        0|           4|                2|        4|           2|             3|         0|                  -1|                0.0|    (5,[0],[1.0])|
| 10996|        0|           5|                2|        5|          -1|             3|         0|                   4|                3.0|    (5,[3],[1.0])|
| 11256|        8|           2|                2|        2|           1|             3|         0|                   3|                2.0|    (5,[2],[1.0])|
| 11310|       31|           4|                2|        4|           1|             3|         0|                   4|                3.0|    (5,[3],[1.0])|
+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+-------------------+-----------------+
only showing top 20 rows

df = df.withColumn('pvalue_level', df['pvalue_level'].cast(StringType()))
stringindexer = StringIndexer(inputCol='pvalue_level', outputCol='pvalue_level_onehot_feature')
encoder = OneHotEncoder(dropLast=False, inputCol='pvalue_level_onehot_feature', outputCol='pl_onehot_value')
pipeline = Pipeline(stages=[stringindexer, encoder])
pipeline_fit = pipeline.fit(df)
df = pipeline_fit.transform(df)
df.printSchema()
df.columns
root
 |-- userid: integer (nullable = true)
 |-- cms_segid: integer (nullable = true)
 |-- cms_group_id: integer (nullable = true)
 |-- final_gender_code: integer (nullable = true)
 |-- age_level: integer (nullable = true)
 |-- pvalue_level: string (nullable = true)
 |-- shopping_level: integer (nullable = true)
 |-- occupation: integer (nullable = true)
 |-- new_user_class_level: string (nullable = true)
 |-- nucl_onehot_feature: double (nullable = false)
 |-- nucl_onehot_value: vector (nullable = true)
 |-- pvalue_level_onehot_feature: double (nullable = false)
 |-- pl_onehot_value: vector (nullable = true)

['userid', 'cms_segid', 'cms_group_id', 'final_gender_code', 'age_level', 'pvalue_level', 'shopping_level', 'occupation', 'new_user_class_level', 'nucl_onehot_feature', 'nucl_onehot_value', 'pvalue_level_onehot_feature', 'pl_onehot_value']
# assemble the chosen feature columns into a single vector column
from pyspark.ml.feature import VectorAssembler

feature_df = VectorAssembler().setInputCols(['age_level', 'pl_onehot_value', 'nucl_onehot_value']).\
    setOutputCol('features').transform(df)
feature_df.show()
+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+-------------------+-----------------+---------------------------+---------------+--------------------+
|userid|cms_segid|cms_group_id|final_gender_code|age_level|pvalue_level|shopping_level|occupation|new_user_class_level|nucl_onehot_feature|nucl_onehot_value|pvalue_level_onehot_feature|pl_onehot_value|            features|
+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+-------------------+-----------------+---------------------------+---------------+--------------------+
|   234|        0|           5|                2|        5|          -1|             3|         0|                   3|                2.0|    (5,[2],[1.0])|                        0.0|  (4,[0],[1.0])|(10,[0,1,7],[5.0,...|
|   523|        5|           2|                2|        2|           1|             3|         1|                   2|                1.0|    (5,[1],[1.0])|                        2.0|  (4,[2],[1.0])|(10,[0,3,6],[2.0,...|
|   612|        0|           8|                1|        2|           2|             3|         0|                  -1|                0.0|    (5,[0],[1.0])|                        1.0|  (4,[1],[1.0])|(10,[0,2,5],[2.0,...|
|  1670|        0|           4|                2|        4|          -1|             1|         0|                  -1|                0.0|    (5,[0],[1.0])|                        0.0|  (4,[0],[1.0])|(10,[0,1,5],[4.0,...|
|  2545|        0|          10|                1|        4|          -1|             3|         0|                  -1|                0.0|    (5,[0],[1.0])|                        0.0|  (4,[0],[1.0])|(10,[0,1,5],[4.0,...|
|  3644|       49|           6|                2|        6|           2|             3|         0|                   2|                1.0|    (5,[1],[1.0])|                        1.0|  (4,[1],[1.0])|(10,[0,2,6],[6.0,...|
|  5777|       44|           5|                2|        5|           2|             3|         0|                   2|                1.0|    (5,[1],[1.0])|                        1.0|  (4,[1],[1.0])|(10,[0,2,6],[5.0,...|
|  6211|        0|           9|                1|        3|          -1|             3|         0|                   2|                1.0|    (5,[1],[1.0])|                        0.0|  (4,[0],[1.0])|(10,[0,1,6],[3.0,...|
|  6355|        2|           1|                2|        1|           1|             3|         0|                   4|                3.0|    (5,[3],[1.0])|                        2.0|  (4,[2],[1.0])|(10,[0,3,8],[1.0,...|
|  6823|       43|           5|                2|        5|           2|             3|         0|                   1|                4.0|    (5,[4],[1.0])|                        1.0|  (4,[1],[1.0])|(10,[0,2,9],[5.0,...|
|  6972|        5|           2|                2|        2|           2|             3|         1|                   2|                1.0|    (5,[1],[1.0])|                        1.0|  (4,[1],[1.0])|(10,[0,2,6],[2.0,...|
|  9293|        0|           5|                2|        5|          -1|             3|         0|                   4|                3.0|    (5,[3],[1.0])|                        0.0|  (4,[0],[1.0])|(10,[0,1,8],[5.0,...|
|  9510|       55|           8|                1|        2|           2|             2|         0|                   2|                1.0|    (5,[1],[1.0])|                        1.0|  (4,[1],[1.0])|(10,[0,2,6],[2.0,...|
| 10122|       33|           4|                2|        4|           2|             3|         0|                   2|                1.0|    (5,[1],[1.0])|                        1.0|  (4,[1],[1.0])|(10,[0,2,6],[4.0,...|
| 10549|        0|           4|                2|        4|           2|             3|         0|                  -1|                0.0|    (5,[0],[1.0])|                        1.0|  (4,[1],[1.0])|(10,[0,2,5],[4.0,...|
| 10812|        0|           4|                2|        4|          -1|             2|         0|                  -1|                0.0|    (5,[0],[1.0])|                        0.0|  (4,[0],[1.0])|(10,[0,1,5],[4.0,...|
| 10912|        0|           4|                2|        4|           2|             3|         0|                  -1|                0.0|    (5,[0],[1.0])|                        1.0|  (4,[1],[1.0])|(10,[0,2,5],[4.0,...|
| 10996|        0|           5|                2|        5|          -1|             3|         0|                   4|                3.0|    (5,[3],[1.0])|                        0.0|  (4,[0],[1.0])|(10,[0,1,8],[5.0,...|
| 11256|        8|           2|                2|        2|           1|             3|         0|                   3|                2.0|    (5,[2],[1.0])|                        2.0|  (4,[2],[1.0])|(10,[0,3,7],[2.0,...|
| 11310|       31|           4|                2|        4|           1|             3|         0|                   4|                3.0|    (5,[3],[1.0])|                        2.0|  (4,[2],[1.0])|(10,[0,3,8],[4.0,...|
+------+---------+------------+-----------------+---------+------------+--------------+----------+--------------------+-------------------+-----------------+---------------------------+---------------+--------------------+
only showing top 20 rows
