Details
Description
LoR (LogisticRegression) and AFT (AFTSurvivalRegression) use an internal status to optimize prediction/transform, but that status is not correctly updated in some cases:
from pyspark.sql import Row
from pyspark.ml.classification import *
from pyspark.ml.linalg import Vectors

df = spark.createDataFrame(
    [
        (1.0, 1.0, Vectors.dense(0.0, 5.0)),
        (0.0, 2.0, Vectors.dense(1.0, 2.0)),
        (1.0, 3.0, Vectors.dense(2.0, 1.0)),
        (0.0, 4.0, Vectors.dense(3.0, 3.0)),
    ],
    ["label", "weight", "features"],
)

lor = LogisticRegression(weightCol="weight")
model = lor.fit(df)

# status changes 1
for t in [0.0, 0.1, 0.2, 0.5, 1.0]:
    model.setThreshold(t).transform(df)

# status changes 2
[model.setThreshold(t).predict(Vectors.dense(0.0, 5.0)) for t in [0.0, 0.1, 0.2, 0.5, 1.0]]

for t in [0.0, 0.1, 0.2, 0.5, 1.0]:
    print(t)
    model.setThreshold(t).transform(df).show()  # <- error results
results:
0.0
+-----+------+---------+--------------------+--------------------+----------+
|label|weight| features|       rawPrediction|         probability|prediction|
+-----+------+---------+--------------------+--------------------+----------+
|  1.0|   1.0|[0.0,5.0]|[0.10932013376341...|[0.52730284774069...|       0.0|
|  0.0|   2.0|[1.0,2.0]|[-0.8619624039359...|[0.29692950635762...|       0.0|
|  1.0|   3.0|[2.0,1.0]|[-0.3634508721860...|[0.41012446452385...|       0.0|
|  0.0|   4.0|[3.0,3.0]|[2.33975176373760...|[0.91211618852612...|       0.0|
+-----+------+---------+--------------------+--------------------+----------+

0.1
+-----+------+---------+--------------------+--------------------+----------+
|label|weight| features|       rawPrediction|         probability|prediction|
+-----+------+---------+--------------------+--------------------+----------+
|  1.0|   1.0|[0.0,5.0]|[0.10932013376341...|[0.52730284774069...|       0.0|
|  0.0|   2.0|[1.0,2.0]|[-0.8619624039359...|[0.29692950635762...|       0.0|
|  1.0|   3.0|[2.0,1.0]|[-0.3634508721860...|[0.41012446452385...|       0.0|
|  0.0|   4.0|[3.0,3.0]|[2.33975176373760...|[0.91211618852612...|       0.0|
+-----+------+---------+--------------------+--------------------+----------+

0.2
+-----+------+---------+--------------------+--------------------+----------+
|label|weight| features|       rawPrediction|         probability|prediction|
+-----+------+---------+--------------------+--------------------+----------+
|  1.0|   1.0|[0.0,5.0]|[0.10932013376341...|[0.52730284774069...|       0.0|
|  0.0|   2.0|[1.0,2.0]|[-0.8619624039359...|[0.29692950635762...|       0.0|
|  1.0|   3.0|[2.0,1.0]|[-0.3634508721860...|[0.41012446452385...|       0.0|
|  0.0|   4.0|[3.0,3.0]|[2.33975176373760...|[0.91211618852612...|       0.0|
+-----+------+---------+--------------------+--------------------+----------+

0.5
+-----+------+---------+--------------------+--------------------+----------+
|label|weight| features|       rawPrediction|         probability|prediction|
+-----+------+---------+--------------------+--------------------+----------+
|  1.0|   1.0|[0.0,5.0]|[0.10932013376341...|[0.52730284774069...|       0.0|
|  0.0|   2.0|[1.0,2.0]|[-0.8619624039359...|[0.29692950635762...|       0.0|
|  1.0|   3.0|[2.0,1.0]|[-0.3634508721860...|[0.41012446452385...|       0.0|
|  0.0|   4.0|[3.0,3.0]|[2.33975176373760...|[0.91211618852612...|       0.0|
+-----+------+---------+--------------------+--------------------+----------+

1.0
+-----+------+---------+--------------------+--------------------+----------+
|label|weight| features|       rawPrediction|         probability|prediction|
+-----+------+---------+--------------------+--------------------+----------+
|  1.0|   1.0|[0.0,5.0]|[0.10932013376341...|[0.52730284774069...|       0.0|
|  0.0|   2.0|[1.0,2.0]|[-0.8619624039359...|[0.29692950635762...|       0.0|
|  1.0|   3.0|[2.0,1.0]|[-0.3634508721860...|[0.41012446452385...|       0.0|
|  0.0|   4.0|[3.0,3.0]|[2.33975176373760...|[0.91211618852612...|       0.0|
+-----+------+---------+--------------------+--------------------+----------+
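
For comparison, the predictions one would expect at each threshold can be worked out by hand from the probability column. The sketch below is independent of Spark and rests on two assumptions: that binary LogisticRegression predicts 1.0 when P(class 1) is strictly greater than the threshold, and that the first probability component of each row can be rounded from the truncated values shown above. The helper name `expected_predictions` is hypothetical.

```python
def expected_predictions(p_class0, threshold):
    """Return 1.0 where P(class 1) = 1 - P(class 0) exceeds the threshold, else 0.0.

    Assumption: this mirrors the usual binary decision rule; it is not
    taken from the Spark source.
    """
    return [1.0 if (1.0 - p0) > threshold else 0.0 for p0 in p_class0]


# First probability component per row, rounded from the truncated output above.
P_CLASS0 = [0.5273, 0.2969, 0.4101, 0.9121]

for t in [0.0, 0.1, 0.2, 0.5, 1.0]:
    print(t, expected_predictions(P_CLASS0, t))
```

Under these assumptions the predictions should differ across thresholds (for example, all 1.0 at threshold 0.0 and all 0.0 at threshold 1.0), whereas the output above shows 0.0 for every row at every threshold.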