
Spark Code Notes: mllib/feature/OneHotEncoder

1 Introduction

One-hot encodes an input column and outputs the result as a vector.

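2 Code Notes

  • OneHotEncoder is a Transformer with one input column and one output column
  • It has a slightly odd dropLast feature, true by default: the last category (the one with the highest index) gets no slot of its own and is encoded as an all-zero vector instead. I personally turn this switch off with setDropLast(false)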
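
To make the dropLast behaviour concrete, here is a minimal sketch in plain Scala with no Spark dependency. The object DropLastSketch and the function oneHot are hypothetical names for illustration, and the sketch returns a dense Vector[Double] where Spark returns a sparse vector.

```scala
object DropLastSketch {
  // Dense one-hot vector for category `index` out of `numCategories`.
  // With dropLast = true the vector has one slot fewer, so the last
  // category (index == numCategories - 1) matches no slot and is all zeros.
  def oneHot(index: Int, numCategories: Int, dropLast: Boolean): Vector[Double] = {
    val size = if (dropLast) numCategories - 1 else numCategories
    Vector.tabulate(size)(i => if (i == index) 1.0 else 0.0)
  }
}
```

With three categories, oneHot(2, 3, dropLast = true) yields a two-element all-zero vector, while oneHot(2, 3, dropLast = false) yields a three-element vector with a 1.0 in the last slot.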

2.1 transform

@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {

First, the ML attributes are assembled:

// schema transformation
val inputColName = $(inputCol)
val outputColName = $(outputCol)
val shouldDropLast = $(dropLast)
var outputAttrGroup = AttributeGroup.fromStructField(
  transformSchema(dataset.schema)(outputColName))
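  • transformSchema assembles the schema of the output column:
    1. Check that the input column's type is NumericType
    2. Check that the output column does not already exist
    3. Extract attribute names from the input column's ML attribute
      • NominalAttribute: use its values as the names if they are defined; otherwise, if numValues is defined, use the value indices as the names
      • BinaryAttribute: same as nominal, falling back to the indices "0" and "1"
      • NumericAttribute: throw an exception
      • Any other (unresolved) attribute: no names, so the number of attributes stays unknown
    4. Wrap each name (minus the last one when dropLast is set) in a BinaryAttribute for the output column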
if (outputAttrGroup.size < 0) {
  // If the number of attributes is unknown, we check the values from the input column.
  val numAttrs = dataset.select(col(inputColName).cast(DoubleType)).rdd.map(_.getDouble(0))
    .aggregate(0.0)(
      (m, x) => {
        assert(x <= Int.MaxValue,
          s"OneHotEncoder only supports up to ${Int.MaxValue} indices, but got $x")
        assert(x >= 0.0 && x == x.toInt,
          s"Values from column $inputColName must be indices, but got $x.")
        math.max(m, x)
      },
      (m0, m1) => {
        math.max(m0, m1)
      }
    ).toInt + 1
  val outputAttrNames = Array.tabulate(numAttrs)(_.toString)
  val filtered = if (shouldDropLast) outputAttrNames.dropRight(1) else outputAttrNames
  val outputAttrs: Array[Attribute] =
    filtered.map(name => BinaryAttribute.defaultAttr.withName(name))
  outputAttrGroup = new AttributeGroup(outputColName, outputAttrs)
}

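If transformSchema could not determine the number of attributes (outputAttrGroup.size < 0), the whole dataset is scanned once to find the largest index, and the attribute names are built from the indices 0 through that maximum. Falling into this branch costs an extra full pass over the data, so it is considerably slower.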
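
The aggregation in this fallback can be sketched in plain Scala. The names below (NumAttrsSketch, seqOp, combOp, numAttrs) are hypothetical, a Seq[Seq[Double]] stands in for the partitions of the RDD, and require replaces the assert calls of the original. As far as I know, feeding the encoder a column that already carries nominal metadata (for example the output of StringIndexer) avoids this branch.

```scala
object NumAttrsSketch {
  // Per-partition step: validate one value and keep the running maximum.
  def seqOp(m: Double, x: Double): Double = {
    require(x <= Int.MaxValue, s"only up to ${Int.MaxValue} indices are supported, but got $x")
    require(x >= 0.0 && x == x.toInt, s"values must be indices, but got $x")
    math.max(m, x)
  }

  // Merge step: combine the maxima of two partitions.
  def combOp(m0: Double, m1: Double): Double = math.max(m0, m1)

  // Fold each partition with seqOp, merge the results with combOp,
  // then add 1 because indices start at 0.
  def numAttrs(partitions: Seq[Seq[Double]]): Int =
    partitions.map(_.foldLeft(0.0)(seqOp)).foldLeft(0.0)(combOp).toInt + 1
}
```

Only the maximum index is tracked, so gaps in the observed indices still count as categories: a column containing just 0.0 and 4.0 produces five attributes.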

  val metadata = outputAttrGroup.toMetadata()
  // data transformation
  val size = outputAttrGroup.size
  val oneValue = Array(1.0)
  val emptyValues = Array.empty[Double]
  val emptyIndices = Array.empty[Int]
  val encode = udf { label: Double =>
    if (label < size) {
      Vectors.sparse(size, Array(label.toInt), oneValue)
    } else {
      Vectors.sparse(size, emptyIndices, emptyValues)
    }
  }

  dataset.select(col("*"), encode(col(inputColName).cast(DoubleType)).as(outputColName, metadata))
}

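What remains is the actual pass over the data that performs the encoding. Put simply, the Double value of the input column is cast to Int and used as the index of the single non-zero entry in a sparse vector.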
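
The body of the udf can be mirrored without Spark as follows. SparseVec is a hypothetical stand-in for the result of Vectors.sparse, and EncodeSketch and encode are illustrative names.

```scala
object EncodeSketch {
  // Hypothetical stand-in for a sparse vector: total size plus active entries.
  final case class SparseVec(size: Int, indices: Seq[Int], values: Seq[Double])

  // An in-range label becomes a single 1.0 at position label.toInt;
  // any label >= size becomes a vector with no active entries.
  def encode(label: Double, size: Int): SparseVec =
    if (label < size) SparseVec(size, Seq(label.toInt), Seq(1.0))
    else SparseVec(size, Seq.empty, Seq.empty)
}
```

The else branch is what implements dropLast: size is already one smaller than the number of categories, so the last label fails the label < size test. The same branch also means that any other label at or beyond size is silently encoded as all zeros rather than rejected.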

2.2 transformSchema

@Since("1.4.0")
override def transformSchema(schema: StructType): StructType = {
  val inputColName = $(inputCol)
  val outputColName = $(outputCol)

  require(schema(inputColName).dataType.isInstanceOf[NumericType],
    s"Input column must be of type NumericType but got ${schema(inputColName).dataType}")
  val inputFields = schema.fields
  require(!inputFields.exists(_.name == outputColName),
    s"Output column $outputColName already exists.")

  val inputAttr = Attribute.fromStructField(schema(inputColName))
  val outputAttrNames: Option[Array[String]] = inputAttr match {
    case nominal: NominalAttribute =>
      if (nominal.values.isDefined) {
        nominal.values
      } else if (nominal.numValues.isDefined) {
        nominal.numValues.map(n => Array.tabulate(n)(_.toString))
      } else {
        None
      }
    case binary: BinaryAttribute =>
      if (binary.values.isDefined) {
        binary.values
      } else {
        Some(Array.tabulate(2)(_.toString))
      }
    case _: NumericAttribute =>
      throw new RuntimeException(
        s"The input column $inputColName cannot be numeric.")
    case _ =>
      None // optimistic about unknown attributes
  }

  val filteredOutputAttrNames = outputAttrNames.map { names =>
    if ($(dropLast)) {
      require(names.length > 1,
        s"The input column $inputColName should have at least two distinct values.")
      names.dropRight(1)
    } else {
      names
    }
  }

  val outputAttrGroup = if (filteredOutputAttrNames.isDefined) {
    val attrs: Array[Attribute] = filteredOutputAttrNames.get.map { name =>
      BinaryAttribute.defaultAttr.withName(name)
    }
    new AttributeGroup($(outputCol), attrs)
  } else {
    new AttributeGroup($(outputCol))
  }

  val outputFields = inputFields :+ outputAttrGroup.toStructField()
  StructType(outputFields)
}
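
To see what each branch of the match produces, the name resolution can be reduced to a small model. Everything below is a simplified, hypothetical stand-in for Spark's attribute classes (NominalAttr, BinaryAttr, NumericAttr, UnresolvedAttr, outputNames), not the real API.

```scala
object AttrNamesSketch {
  // Simplified stand-ins for Spark's ML attribute types.
  sealed trait Attr
  final case class NominalAttr(values: Option[Seq[String]], numValues: Option[Int]) extends Attr
  final case class BinaryAttr(values: Option[Seq[String]]) extends Attr
  case object NumericAttr extends Attr
  case object UnresolvedAttr extends Attr

  // Output attribute names, or None when the count is unknown at schema time.
  def outputNames(attr: Attr, dropLast: Boolean): Option[Seq[String]] = {
    val names: Option[Seq[String]] = attr match {
      case NominalAttr(Some(vs), _)  => Some(vs)
      case NominalAttr(None, Some(n)) => Some(Seq.tabulate(n)(_.toString))
      case NominalAttr(None, None)    => None
      case BinaryAttr(vs)             => vs.orElse(Some(Seq("0", "1")))
      case NumericAttr                => throw new RuntimeException("The input column cannot be numeric.")
      case UnresolvedAttr             => None // optimistic about unknown attributes
    }
    names.map { ns =>
      if (dropLast) {
        require(ns.length > 1, "The input column should have at least two distinct values.")
        ns.dropRight(1)
      } else ns
    }
  }
}
```

A None result corresponds to the AttributeGroup of unknown size, which is what sends transform into the dataset-scanning branch.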

Date: Sat Jan 28 20:14:56 2017

Author: Menglong TAN
