Spark / SPARK-6571

MatrixFactorizationModel created by load fails on predictAll


Details

    • Type: Bug
    • Status: Resolved
    • Priority: Major
    • Resolution: Fixed
    • Affects Version/s: 1.3.0
    • Fix Version/s: 1.3.1, 1.4.0
    • Component/s: MLlib, PySpark
    • Labels: None

    Description

      The following code, adapted from the documentation, fails when predictAll is called on a model loaded from disk.

      from pyspark.mllib.recommendation import ALS, Rating, MatrixFactorizationModel

      r1 = (1, 1, 1.0)
      r2 = (1, 2, 2.0)
      r3 = (2, 1, 2.0)
      ratings = sc.parallelize([r1, r2, r3])
      model = ALS.trainImplicit(ratings, 1, seed=10)
      print '(2, 2)', model.predict(2, 2)
      # 0.43...
      testset = sc.parallelize([(1, 2), (1, 1)])
      print 'all', model.predictAll(testset).collect()
      # [Rating(user=1, product=1, rating=1.0...), Rating(user=1, product=2, rating=1.9...)]
      import os, tempfile
      path = tempfile.mkdtemp()
      model.save(sc, path)
      sameModel = MatrixFactorizationModel.load(sc, path)
      print '(2, 2)', sameModel.predict(2, 2)
      sameModel.predictAll(testset).collect()

      This gives:
      (2, 2) 0.443547642944
      all [Rating(user=1, product=1, rating=1.1538351103381217), Rating(user=1, product=2, rating=0.7153473708381739)]
      (2, 2) 0.443547642944
      ---------------------------------------------------------------------------
      Py4JError Traceback (most recent call last)
      <ipython-input-18-af6612bed9d0> in <module>()
      19 sameModel = MatrixFactorizationModel.load(sc, path)
      20 print '(2, 2)', sameModel.predict(2,2)
      ---> 21 sameModel.predictAll(testset).collect()
      22

      /home/ubuntu/spark/python/pyspark/mllib/recommendation.pyc in predictAll(self, user_product)
      104 assert len(first) == 2, "user_product should be RDD of (user, product)"
      105 user_product = user_product.map(lambda (u, p): (int(u), int(p)))
      --> 106 return self.call("predict", user_product)
      107
      108 def userFeatures(self):

      /home/ubuntu/spark/python/pyspark/mllib/common.pyc in call(self, name, *a)
      134 def call(self, name, *a):
      135 """Call method of java_model"""
      --> 136 return callJavaFunc(self._sc, getattr(self._java_model, name), *a)
      137
      138

      /home/ubuntu/spark/python/pyspark/mllib/common.pyc in callJavaFunc(sc, func, *args)
      111 """ Call Java Function """
      112 args = [_py2java(sc, a) for a in args]
      --> 113 return _java2py(sc, func(*args))
      114
      115

      /home/ubuntu/spark/python/lib/py4j-0.8.2.1-src.zip/py4j/java_gateway.py in __call__(self, *args)
      536 answer = self.gateway_client.send_command(command)
      537 return_value = get_return_value(answer, self.gateway_client,
      --> 538 self.target_id, self.name)
      539
      540 for temp_arg in temp_args:

      /home/ubuntu/spark/python/lib/py4j-0.8.2.1-src.zip/py4j/protocol.py in get_return_value(answer, gateway_client, target_id, name)
      302 raise Py4JError(
      303 'An error occurred while calling {0}{1}{2}. Trace:\n{3}\n'.
      --> 304 format(target_id, '.', name, value))
      305 else:
      306 raise Py4JError(

      Py4JError: An error occurred while calling o450.predict. Trace:
      py4j.Py4JException: Method predict([class org.apache.spark.api.java.JavaRDD]) does not exist
      at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:333)
      at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:342)
      at py4j.Gateway.invoke(Gateway.java:252)
      at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:133)
      at py4j.commands.CallCommand.execute(CallCommand.java:79)
      at py4j.GatewayConnection.run(GatewayConnection.java:207)
      at java.lang.Thread.run(Thread.java:744)

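      The Py4J message ("Method predict([class org.apache.spark.api.java.JavaRDD]) does not exist") suggests that the Java object backing the loaded model does not expose the RDD-based predict overload that PySpark's predictAll calls, while the single-pair predict path still works, as the output above shows. Until the fix is picked up, one workaround consistent with what does work here is to recompute scores from the factor matrices returned by userFeatures() and productFeatures(). The sketch below is illustrative, not part of the Spark API; the helper name predict_pairs is made up, and it assumes the model's prediction is the dot product of the user and product factor vectors.

      def predict_pairs(model, pairs):
          """Score (user, product) pairs from a MatrixFactorizationModel's factors.

          pairs: RDD of (user, product) tuples.
          Returns an RDD of (user, product, score) tuples.
          """
          user_features = model.userFeatures()        # RDD of (user, factor vector)
          product_features = model.productFeatures()  # RDD of (product, factor vector)

          # (user, product) joined with (user, uf) gives (user, (product, uf));
          # re-key by product and join with (product, pf).
          joined = (pairs.join(user_features)
                         .map(lambda x: (x[1][0], (x[0], x[1][1])))
                         .join(product_features))

          def score(rec):
              product, ((user, uf), pf) = rec
              return (user, product, sum(a * b for a, b in zip(uf, pf)))

          return joined.map(score)

      # Continuing the reproduction above:
      # print predict_pairs(sameModel, testset).collect()

      For a handful of pairs it is also enough to collect them and call the working single-pair predict in a loop on the driver. Since the issue is marked Fixed for 1.3.1 and 1.4.0, upgrading presumably restores predictAll on loaded models.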
      People

        Assignee: Xiangrui Meng (mengxr)
        Reporter: Charles Hayden (cchayden)
        Votes: 0
        Watchers: 3

            Dates

              Created:
              Updated:
              Resolved: