
Spark Code Notes: mllib/feature/VectorAssembler

1 Introduction

VectorAssembler combines multiple input columns into a single vector column, for example:

import org.apache.spark.ml.feature.VectorAssembler

val assembler = new VectorAssembler()
  .setInputCols(Array("fea1", "fea2"))
  .setOutputCol("features")
val newdata = assembler.transform(dataset)
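
For a concrete picture, here is a minimal sketch of the input and output (the SparkSession variable spark and all values are assumed purely for illustration):

// Assumes a SparkSession named spark is in scope; values are made up.
val df = spark.createDataFrame(Seq((1.0, 2.0), (3.0, 4.0))).toDF("fea1", "fea2")
assembler.transform(df).show()
// Roughly:
// +----+----+---------+
// |fea1|fea2| features|
// +----+----+---------+
// | 1.0| 2.0|[1.0,2.0]|
// | 3.0| 4.0|[3.0,4.0]|
// +----+----+---------+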

2 Code Notes

@Since("1.4.0")
class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
  extends Transformer with HasInputCols with HasOutputCol with DefaultParamsWritable {
  • VectorAssembler is a Transformer
  • it accepts multiple input columns: HasInputCols
  • it produces a single output column: HasOutputCol (see the small usage note below)
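
Because HasInputCols and HasOutputCol are shared Params mixins, the assembler also exposes the usual getters; a tiny usage note continuing the intro example:

// Getters provided by the shared-param traits mixed into VectorAssembler.
assembler.getInputCols              // Array(fea1, fea2)
assembler.getOutputCol              // features
println(assembler.explainParams())  // prints the documented inputCols / outputCol params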

2.1 transform

@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
  transformSchema(dataset.schema, logging = true)

First the field validity is checked (sketched below):

  • each input column must be of type NumericType, BooleanType, or VectorUDT
  • the output column must not already exist
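
As a rough illustration of what this validation amounts to (this is not the actual Spark source; the helper name and messages are invented):

import org.apache.spark.ml.linalg.SQLDataTypes
import org.apache.spark.sql.types._

// Illustrative only: accept numeric/boolean/vector input columns, require a fresh output column.
def checkSchema(schema: StructType, inputCols: Array[String], outputCol: String): Unit = {
  inputCols.foreach { c =>
    schema(c).dataType match {
      case _: NumericType | BooleanType =>        // scalar columns are fine
      case dt if dt == SQLDataTypes.VectorType => // vector columns are fine
      case other =>
        throw new IllegalArgumentException(s"Column $c has unsupported type $other")
    }
  }
  require(!schema.fieldNames.contains(outputCol), s"Output column $outputCol already exists.")
}

Back in the Spark source, the schema is captured and the first row is prepared lazily (it is only needed as a fallback later):
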
// Schema transformation.
val schema = dataset.schema
lazy val first = dataset.toDF.first()

Next, the code iterates over the input columns and assembles the ML attribute information for the output column:

  • if the column is DoubleType: when the input column already carries an ML attribute, use it directly; otherwise create a numeric attribute named after the column
  • if it is another NumericType or BooleanType: create a numeric attribute named after the column
  • if it is VectorUDT, which is slightly more involved:
    1. first try to extract the attribute group from the input column
    2. if the per-component attributes of the vector are defined, create a new attribute for each component, with its name built by rule:
      • if the component's attribute has a name: inputColumnName_attributeName
      • if it does not: inputColumnName_componentIndex
    3. if no attributes are defined, treat every component as numeric; the number of components comes from the metadata or, failing that, from the first row of data

The source follows; a short snippet after it shows how to read the generated attributes back from the output column.
        
val attrs = $(inputCols).flatMap { c =>
  val field = schema(c)
  val index = schema.fieldIndex(c)
  field.dataType match {
    case DoubleType =>
      val attr = Attribute.fromStructField(field)
      // If the input column doesn't have ML attribute, assume numeric.
      if (attr == UnresolvedAttribute) {
        Some(NumericAttribute.defaultAttr.withName(c))
      } else {
        Some(attr.withName(c))
      }
    case _: NumericType | BooleanType =>
      // If the input column type is a compatible scalar type, assume numeric.
      Some(NumericAttribute.defaultAttr.withName(c))
    case _: VectorUDT =>
      val group = AttributeGroup.fromStructField(field)
      if (group.attributes.isDefined) {
        // If attributes are defined, copy them with updated names.
        group.attributes.get.zipWithIndex.map { case (attr, i) =>
          if (attr.name.isDefined) {
            // TODO: Define a rigorous naming scheme.
            attr.withName(c + "_" + attr.name.get)
          } else {
            attr.withName(c + "_" + i)
          }
        }
      } else {
        // Otherwise, treat all attributes as numeric. If we cannot get the number of attributes
        // from metadata, check the first row.
        val numAttrs = group.numAttributes.getOrElse(first.getAs[Vector](index).size)
        Array.tabulate(numAttrs)(i => NumericAttribute.defaultAttr.withName(c + "_" + i))
      }
    case otherType =>
      throw new SparkException(s"VectorAssembler does not support the $otherType type")
  }
}
val metadata = new AttributeGroup($(outputCol), attrs).toMetadata()
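
The attribute metadata attached above can be read back from the output column; a small usage sketch (reusing newdata and the column names from the intro example):

import org.apache.spark.ml.attribute.AttributeGroup

// Inspect the ML attributes that VectorAssembler attached to the output column.
val group = AttributeGroup.fromStructField(newdata.schema("features"))
group.attributes.foreach(_.foreach(a => println(a.name.getOrElse("<unnamed>"))))
// Scalar inputs keep their column names (fea1, fea2); a vector input column v
// with unnamed components would contribute v_0, v_1, ...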

With the ML attributes extracted, the actual data processing begins:

  • Define a udf that is applied to each row. There is a trick here for processing an entire row of data at once (a standalone sketch of the pattern follows the method body below):
    • args is an Array[Column] listing which columns to pass in
    • when calling the udf, wrap them with struct(args: _*) so they are packed into a Row
    • then define the udf so that it accepts a Row
  // Data transformation.
  val assembleFunc = udf { r: Row =>
    VectorAssembler.assemble(r.toSeq: _*)
  }
  val args = $(inputCols).map { c =>
    schema(c).dataType match {
      case DoubleType => dataset(c)
      case _: VectorUDT => dataset(c)
      case _: NumericType | BooleanType => dataset(c).cast(DoubleType).as(s"${c}_double_$uid")
    }
  }

  dataset.select(col("*"), assembleFunc(struct(args: _*)).as($(outputCol), metadata))
}
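
The struct-plus-Row-udf pattern above is useful outside VectorAssembler as well. A minimal standalone sketch in the same Spark 2.x API (the DataFrame, its columns, and the summing logic are assumptions for illustration):

import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.{struct, udf}

// Assumes a SparkSession named spark; sums all double columns of each row via a whole-row udf.
val df = spark.createDataFrame(Seq((1.0, 2.0), (3.0, 4.0))).toDF("fea1", "fea2")
val sumAll = udf { r: Row => r.toSeq.map(_.asInstanceOf[Double]).sum }
val cols = Array("fea1", "fea2").map(df(_))
df.withColumn("sum", sumAll(struct(cols: _*))).show()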

Date: Sat Jan 28 19:24:56 2017

Author: Menglong TAN
