1.spark.ml.feature.VectorAssembler

    xiaoxiao2021-04-16  55

    package org.apache.spark.ml.feature

    import scala.collection.mutable.ArrayBuilder

    import org.apache.spark.SparkException
    import org.apache.spark.annotation.Since
    import org.apache.spark.ml.Transformer
    import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute, UnresolvedAttribute}
    import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT}
    import org.apache.spark.ml.param.ParamMap
    import org.apache.spark.ml.param.shared._
    import org.apache.spark.ml.util._
    import org.apache.spark.sql.{DataFrame, Dataset, Row}
    import org.apache.spark.sql.functions._
    import org.apache.spark.sql.types._

    /**
     * A feature transformer that merges multiple columns into a vector column.
     *
     * Input columns may be numeric, boolean or vector typed. Numeric and boolean columns are cast
     * to double before assembly, and vector columns are concatenated in the order given by
     * `inputCols`.
     *
     * @param uid unique identifier of this transformer instance
     */
    @Since("1.4.0")
    class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
      extends Transformer with HasInputCols with HasOutputCol with DefaultParamsWritable {

      /** Creates an assembler with a randomly generated uid prefixed by "vecAssembler". */
      @Since("1.4.0")
      def this() = this(Identifiable.randomUID("vecAssembler"))

      /** @group setParam */
      @Since("1.4.0")
      def setInputCols(value: Array[String]): this.type = set(inputCols, value)

      /** @group setParam */
      @Since("1.4.0")
      def setOutputCol(value: String): this.type = set(outputCol, value)

      /**
       * Appends the assembled vector column to `dataset`.
       *
       * The output column carries ML attribute metadata derived from the input columns: one
       * attribute per scalar input and one attribute per element of every vector input.
       *
       * @param dataset input data containing every column named in `inputCols`
       * @return all original columns plus the assembled vector column named by `outputCol`
       * @throws SparkException if an input column has a type that cannot be assembled
       */
      @Since("2.0.0")
      override def transform(dataset: Dataset[_]): DataFrame = {
        transformSchema(dataset.schema, logging = true)
        // Schema transformation.
        val schema = dataset.schema
        // Only evaluated when a vector column lacks size metadata (see below).
        // NOTE(review): reading the first row presumably fails on an empty dataset -- confirm.
        lazy val first = dataset.toDF.first()
        val attrs = $(inputCols).flatMap { c =>
          val field = schema(c)
          val index = schema.fieldIndex(c)
          field.dataType match {
            case DoubleType =>
              val attr = Attribute.fromStructField(field)
              // If the input column doesn't have ML attribute, assume numeric.
              if (attr == UnresolvedAttribute) {
                Some(NumericAttribute.defaultAttr.withName(c))
              } else {
                Some(attr.withName(c))
              }
            case _: NumericType | BooleanType =>
              // If the input column type is a compatible scalar type, assume numeric.
              Some(NumericAttribute.defaultAttr.withName(c))
            case _: VectorUDT =>
              val group = AttributeGroup.fromStructField(field)
              group.attributes match {
                case Some(definedAttrs) =>
                  // If attributes are defined, copy them with updated names.
                  definedAttrs.zipWithIndex.map { case (attr, i) =>
                    if (attr.name.isDefined) {
                      // TODO: Define a rigorous naming scheme.
                      attr.withName(c + "_" + attr.name.get)
                    } else {
                      attr.withName(c + "_" + i)
                    }
                  }
                case None =>
                  // Otherwise, treat all attributes as numeric. If we cannot get the number of
                  // attributes from metadata, check the first row.
                  val numAttrs = group.numAttributes.getOrElse(first.getAs[Vector](index).size)
                  Array.tabulate(numAttrs)(i => NumericAttribute.defaultAttr.withName(c + "_" + i))
              }
            case otherType =>
              throw new SparkException(s"VectorAssembler does not support the $otherType type")
          }
        }
        val metadata = new AttributeGroup($(outputCol), attrs).toMetadata()

        // Data transformation.
        val assembleFunc = udf { r: Row =>
          VectorAssembler.assemble(r.toSeq: _*)
        }
        // Doubles and vectors are passed through; other numeric/boolean columns are cast to double.
        // Unsupported types were already rejected by the schema handling above.
        val args = $(inputCols).map { c =>
          schema(c).dataType match {
            case DoubleType => dataset(c)
            case _: VectorUDT => dataset(c)
            case _: NumericType | BooleanType => dataset(c).cast(DoubleType).as(s"${c}_double_$uid")
          }
        }

        dataset.select(col("*"), assembleFunc(struct(args: _*)).as($(outputCol), metadata))
      }

      /**
       * Validates the input columns and derives the output schema.
       *
       * @param schema schema of the input data
       * @return `schema` with a nullable vector column named by `outputCol` appended
       * @throws IllegalArgumentException if an input column is neither numeric, boolean nor vector
       *                                  typed, or if the output column already exists
       */
      @Since("1.4.0")
      override def transformSchema(schema: StructType): StructType = {
        val inputColNames = $(inputCols)
        val outputColName = $(outputCol)
        val inputDataTypes = inputColNames.map(name => schema(name).dataType)
        inputDataTypes.foreach {
          case _: NumericType | BooleanType =>
          case _: VectorUDT =>
          case other =>
            throw new IllegalArgumentException(s"Data type $other is not supported.")
        }
        if (schema.fieldNames.contains(outputColName)) {
          throw new IllegalArgumentException(s"Output column $outputColName already exists.")
        }
        StructType(schema.fields :+ StructField(outputColName, new VectorUDT, nullable = true))
      }

      /** Returns a copy of this transformer with `extra` params applied via `defaultCopy`. */
      @Since("1.4.1")
      override def copy(extra: ParamMap): VectorAssembler = defaultCopy(extra)
    }

    /** Companion providing persistence support and the row-level assembly routine. */
    @Since("1.6.0")
    object VectorAssembler extends DefaultParamsReadable[VectorAssembler] {

      /** Loads a `VectorAssembler` previously saved at `path`. */
      @Since("1.6.0")
      override def load(path: String): VectorAssembler = super.load(path)

      /**
       * Concatenates doubles and vectors into a single vector.
       *
       * Only non-zero entries are recorded while walking the inputs; the result is built with
       * `Vectors.sparse` and then returned in its `compressed` form.
       *
       * @param vv values to assemble, each a `Double` or a `Vector`
       * @return the concatenation of all inputs, in argument order
       * @throws SparkException if a value is null or of an unsupported type
       */
      private[feature] def assemble(vv: Any*): Vector = {
        val indices = ArrayBuilder.make[Int]
        val values = ArrayBuilder.make[Double]
        // Offset of the next input within the assembled vector.
        var cur = 0
        vv.foreach {
          case v: Double =>
            if (v != 0.0) {
              indices += cur
              values += v
            }
            cur += 1
          case vec: Vector =>
            vec.foreachActive { case (i, v) =>
              if (v != 0.0) {
                indices += cur + i
                values += v
              }
            }
            cur += vec.size
          case null =>
            // TODO: output Double.NaN?
            throw new SparkException("Values to assemble cannot be null.")
          case o =>
            throw new SparkException(s"$o of type ${o.getClass.getName} is not supported.")
        }
        Vectors.sparse(cur, indices.result(), values.result()).compressed
      }
    }
    转载请注明原文地址: https://ju.6miu.com/read-673051.html

    最新回复(0)