GithubHelp home page GithubHelp logo

gbdt_leaf_spark's Introduction

gbdt_leaf_spark(gbdt+lr)

/**
  * Demo driver for a "gbdt+lr" Spark ML workflow: it fits and saves a pipeline ending in a
  * `GBTLeafClassifier`, reloads it, and then trains a `LogisticRegression` on top of the
  * reloaded stages.
  *
  * All input data is expected to contain the columns `ad__cr_id`, `gender` and `label`,
  * which are the only columns referenced by the stages built here.
  */
object GbdtLeafTest {
  // Placeholder locations; `modelPath` must be set to a real path before the demo can
  // save or load anything. `mleapModelPath` is not read anywhere in this object.
  val modelPath = ""
  val mleapModelPath = ""

  /** Fails fast, before any Spark job is started, when `modelPath` is still the empty placeholder. */
  private def requireModelPath(): Unit =
    require(
      modelPath.nonEmpty,
      "modelPath must be set to a non-empty location before saving or loading the model")

  /** Builds a `StringIndexer` mapping `column` to `<column>_1`, keeping unseen labels. */
  private def indexerFor(column: String): PipelineStage =
    new StringIndexer()
      .setInputCol(column)
      .setOutputCol(s"${column}_1")
      .setHandleInvalid("keep")

  /** The two categorical indexing stages shared by every pipeline in this object. */
  private def categoricalIndexers: Seq[PipelineStage] =
    Seq(indexerFor("ad__cr_id"), indexerFor("gender"))

  /**
    * Loads the pipeline model stored at `modelPath`, prints its output for `df`, and returns it.
    *
    * @param spark unused; kept so existing callers keep compiling
    * @param df    data to run through the loaded model for display
    * @return the loaded pipeline model
    * @throws IllegalArgumentException if `modelPath` is empty
    */
  def loadModel(spark: SparkSession, df: DataFrame): PipelineModel = {
    requireModelPath()

    val restored = PipelineModel.load(modelPath)
    restored.transform(df).show(false)
    restored
  }

  /**
    * Fits the indexing + assembling + `GBTLeafClassifier` pipeline on `df`, prints the
    * transformed data, and writes the fitted pipeline to `modelPath` (overwriting).
    *
    * @param spark unused; kept so existing callers keep compiling
    * @param df    training data
    * @throws IllegalArgumentException if `modelPath` is empty
    */
  def saveModel(spark: SparkSession, df: DataFrame): Unit = {
    // Check the destination first so a missing path does not cost a full training run.
    requireModelPath()

    val assembler = new VectorAssembler()
      .setInputCols(Array("ad__cr_id_1", "gender_1"))
      .setOutputCol("gbdt_features")

    // NOTE(review): GBTLeafClassifier is declared elsewhere in the project; its output
    // column presumably holds per-tree leaf features - confirm against its implementation.
    val gbt = new GBTLeafClassifier()
      .setLabelCol("label")
      .setFeaturesCol("gbdt_features")
      .setOutputCol("gbdtGenFeatures1")
      .setMaxIter(1)
      .setPredictionCol("gbdt_prediction")
      .setRawPredictionCol("gbdt_predictiton_raw")
      .setProbabilityCol("gbdt_prob")
      .setMaxBins(1000)
      .setMaxDepth(3)
      .setMinInfoGain(0.01)

    val stages: Seq[PipelineStage] = categoricalIndexers :+ assembler :+ gbt

    val onlinePipeline = new Pipeline()
      .setStages(stages.toArray[PipelineStage])
      .fit(df)
    onlinePipeline.transform(df).show(false)

    onlinePipeline.write.overwrite().save(modelPath)
  }

  /**
    * Indexes the two categorical columns and combines them with an `Interaction` stage
    * into a `features` column.
    *
    * @param spark unused; kept so existing callers keep compiling
    * @param df    input data
    * @return `df` transformed by the fitted indexing + interaction pipeline
    */
  def getTrainData(spark: SparkSession, df: DataFrame): DataFrame = {
    val interaction = new Interaction()
      .setInputCols(Array("ad__cr_id_1", "gender_1"))
      .setOutputCol("features")

    val stages: Seq[PipelineStage] = categoricalIndexers :+ interaction

    new Pipeline()
      .setStages(stages.toArray[PipelineStage])
      .fit(df)
      .transform(df)
  }

  /**
    * Reloads the saved GBDT pipeline, appends a `VectorAssembler` and a
    * `LogisticRegression`, fits the combined pipeline on `df`, and prints the result of
    * each step.
    *
    * @param spark passed through to [[loadModel]]
    * @param df    data used both for fitting and for display
    */
  def testGbdtLR(spark: SparkSession, df: DataFrame): Unit = {
    // Feature engineering
    // 1. load gbdt
    val gbdtPipline = loadModel(spark, df)

    gbdtPipline.transform(df).show(false)
    // Join the stages explicitly: interpolating the array itself would only print its
    // JVM identity rather than the stages it contains.
    val stageList = gbdtPipline.stages.mkString(", ")
    println(s"gbdt pipline $stageList")

    val gbdtStages: Seq[PipelineStage] = gbdtPipline.stages.toSeq
    println(s"after add gbdt length ${gbdtStages.length}")
    new Pipeline()
      .setStages(gbdtStages.toArray[PipelineStage])
      .fit(df)
      .transform(df)
      .show(false)

    val assembler = new VectorAssembler()
      .setInputCols(Array("ad__cr_id_1", "gender_1", "gbdtGenFeatures1"))
      .setOutputCol("features")

    // NOTE(review): the regression reads "gbdtGenFeatures1" directly, so the assembled
    // "features" column above is produced but not consumed by it - confirm which column
    // is intended before changing either stage.
    val lr = new LogisticRegression()
      .setMaxIter(2)
      .setTol(0.1)
      .setElasticNetParam(0.01)
      .setFeaturesCol("gbdtGenFeatures1")
      .setLabelCol("label")

    val allStages: Seq[PipelineStage] = gbdtStages :+ assembler :+ lr
    val onlinePipline = new Pipeline()
      .setStages(allStages.toArray[PipelineStage])
      .fit(df)

    onlinePipline.transform(df).show(false)
  }

  /**
    * Entry point: reads the tab-separated sample CSV with a local Spark session, then
    * runs save, load and the GBDT+LR test in order. The session is always stopped, even
    * when one of the steps fails.
    */
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setMaster("local").setAppName("Test Application")
    val spark =
      SparkSession.builder().config(conf).getOrCreate()
    try {
      // Lower the log level before the schema-inferring read so that job is quiet too.
      spark.sparkContext.setLogLevel("WARN")
      val df = spark.read
        .option("header", "true")
        .option("inferSchema", "true")
        .option("delimiter", "\t")
        .csv("src/main/scala/org.apache.spark.ml.mleap.gbdt/data.csv")
      saveModel(spark, df)
      loadModel(spark, df)
      testGbdtLR(spark, df)
    } finally {
      spark.stop()
    }
  }

}

gbdt_leaf_spark's People

Contributors

fangwendong avatar

Stargazers

 avatar  avatar

Watchers

 avatar

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Something interesting about the web. A new door to the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Something interesting about games that makes everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor authentication.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.