Uploaded image for project: 'Spark'
  1. Spark
  2. SPARK-16872

Impl Gaussian Naive Bayes Classifier

    XMLWordPrintableJSON

Details

    • New Feature
    • Status: Resolved
    • Major
    • Resolution: Fixed
    • None
    • 3.0.0
    • ML, PySpark
    • None

    Description

      I implemented Gaussian NB according to scikit-learn's GaussianNB.
      In the GaussianNB model, the theta matrix stores the means, and there is an extra sigma matrix storing the variance of each feature.

      GaussianNB in Spark

      scala> import org.apache.spark.ml.classification.GaussianNaiveBayes
      import org.apache.spark.ml.classification.GaussianNaiveBayes
      
      scala> val path = "/Users/zrf/.dev/spark-2.1.0-bin-hadoop2.7/data/mllib/sample_multiclass_classification_data.txt"
      path: String = /Users/zrf/.dev/spark-2.1.0-bin-hadoop2.7/data/mllib/sample_multiclass_classification_data.txt
      
      scala> val data = spark.read.format("libsvm").load(path).persist()
      data: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [label: double, features: vector]
      
      scala> val gnb = new GaussianNaiveBayes()
      gnb: org.apache.spark.ml.classification.GaussianNaiveBayes = gnb_54c50467306c
      
      scala> val model = gnb.fit(data)
      17/01/03 14:25:48 INFO Instrumentation: GaussianNaiveBayes-gnb_54c50467306c-720112035-1: training: numPartitions=1 storageLevel=StorageLevel(1 replicas)
      17/01/03 14:25:48 INFO Instrumentation: GaussianNaiveBayes-gnb_54c50467306c-720112035-1: {}
      17/01/03 14:25:49 INFO Instrumentation: GaussianNaiveBayes-gnb_54c50467306c-720112035-1: {"numFeatures":4}
      17/01/03 14:25:49 INFO Instrumentation: GaussianNaiveBayes-gnb_54c50467306c-720112035-1: {"numClasses":3}
      17/01/03 14:25:49 INFO Instrumentation: GaussianNaiveBayes-gnb_54c50467306c-720112035-1: training finished
      model: org.apache.spark.ml.classification.GaussianNaiveBayesModel = GaussianNaiveBayesModel (uid=gnb_54c50467306c) with 3 classes
      
      scala> model.pi
      res0: org.apache.spark.ml.linalg.Vector = [-1.0986122886681098,-1.0986122886681098,-1.0986122886681098]
      
      scala> model.pi.toArray.map(math.exp)
      res1: Array[Double] = Array(0.3333333333333333, 0.3333333333333333, 0.3333333333333333)
      
      scala> model.theta
      res2: org.apache.spark.ml.linalg.Matrix =
      0.2711110067018001   -0.18833335400000006  0.5430507200000001   0.605000046
      -0.6077777799999998  0.181666672           -0.8427117400000006  -0.8800001399999998
      -0.0911111425964     -0.3583333580000001   0.105084738          0.021666701507102017
      
      scala> model.sigma
      res3: org.apache.spark.ml.linalg.Matrix =
      0.1223012510889361   0.07078051983960698  0.03430000595243976   0.051336071297393815
      0.03758145300924998  0.09880280046403413  0.003390296940069426  0.007822241779598893
      0.08058763609659315  0.06701386661293329  0.024866409227781675  0.02661391644759426
      
      
      scala> model.transform(data).select("probability").take(10)
      [rdd_68_0]
      res4: Array[org.apache.spark.sql.Row] = Array([[1.0627410543476422E-21,0.9999999999999938,6.2765233965353945E-15]], [[7.254521422345374E-26,1.0,1.3849442153180895E-18]], [[1.9629244119173135E-24,0.9999999999999998,1.9424765181237926E-16]], [[6.061218297948492E-22,0.9999999999999902,9.853216073401884E-15]], [[0.9972225671942837,8.844241161578932E-165,0.002777432805716399]], [[5.361683970373604E-26,1.0,2.3004604508982183E-18]], [[0.01062850630038623,3.3102617689978775E-100,0.9893714936996136]], [[1.9297314618271785E-4,2.124922209137708E-71,0.9998070268538172]], [[3.118816393732361E-27,1.0,6.5310299615983584E-21]], [[0.9999926009854522,8.734773657627494E-206,7.399014547943611E-6]])
      
      scala> model.transform(data).select("prediction").take(10)
      [rdd_68_0]
      res5: Array[org.apache.spark.sql.Row] = Array([1.0], [1.0], [1.0], [1.0], [0.0], [1.0], [2.0], [2.0], [1.0], [0.0])
      

      GaussianNB in scikit-learn

      import numpy as np
      from sklearn.naive_bayes import GaussianNB
      from sklearn.datasets import load_svmlight_file
      
      path = '/Users/zrf/.dev/spark-2.1.0-bin-hadoop2.7/data/mllib/sample_multiclass_classification_data.txt'
      X, y = load_svmlight_file(path)
      X = X.toarray()
      
      clf = GaussianNB()
      
      clf.fit(X, y)
      
      >>> clf.class_prior_
      array([ 0.33333333,  0.33333333,  0.33333333])
      
      >>> clf.theta_
      array([[ 0.27111101, -0.18833335,  0.54305072,  0.60500005],
             [-0.60777778,  0.18166667, -0.84271174, -0.88000014],
             [-0.09111114, -0.35833336,  0.10508474,  0.0216667 ]])
             
      >>> clf.sigma_
      array([[ 0.12230125,  0.07078052,  0.03430001,  0.05133607],
             [ 0.03758145,  0.0988028 ,  0.0033903 ,  0.00782224],
             [ 0.08058764,  0.06701387,  0.02486641,  0.02661392]])
             
      >>> clf.predict_proba(X)[:10]
      array([[  1.06274105e-021,   1.00000000e+000,   6.27652340e-015],
             [  7.25452142e-026,   1.00000000e+000,   1.38494422e-018],
             [  1.96292441e-024,   1.00000000e+000,   1.94247652e-016],
             [  6.06121830e-022,   1.00000000e+000,   9.85321607e-015],
             [  9.97222567e-001,   8.84424116e-165,   2.77743281e-003],
             [  5.36168397e-026,   1.00000000e+000,   2.30046045e-018],
             [  1.06285063e-002,   3.31026177e-100,   9.89371494e-001],
             [  1.92973146e-004,   2.12492221e-071,   9.99807027e-001],
             [  3.11881639e-027,   1.00000000e+000,   6.53102996e-021],
             [  9.99992601e-001,   8.73477366e-206,   7.39901455e-006]])
             
      >>> clf.predict(X)[:10]
      array([ 1.,  1.,  1.,  1.,  0.,  1.,  2.,  2.,  1.,  0.])
      

      Attachments

        Issue Links

          Activity

            People

              podongfeng Ruifeng Zheng
              podongfeng Ruifeng Zheng
              Yanbo Liang Yanbo Liang
              Votes:
              0 Vote for this issue
              Watchers:
              5 Start watching this issue

              Dates

                Created:
                Updated:
                Resolved: