Spark / SPARK-41589 PyTorch Distributor / SPARK-41775

Implement training functions as input


Details

    • Type: Sub-task
    • Status: Resolved
    • Priority: Major
    • Resolution: Fixed
    • Affects Version/s: 3.4.0
    • Fix Version/s: 3.4.0
    • Component/s: ML, PySpark
    • Labels: None

    Description

      Sidenote: make formatting updates described in https://github.com/apache/spark/pull/39188


      Currently, `Distributor().run(...)` accepts only files as input. This task adds the ability to pass functions as well. Supporting functions requires the following process on each task on the executor nodes:
      1. Pickle the input function together with its args.
      2. Create a temp `train.py` file that looks like:

      import cloudpickle
      import os

      if __name__ == "__main__":
          with open(f"{tempdir}/train_input.pkl", "rb") as f:
              train, args = cloudpickle.load(f)
          output = train(*args)
          # RANK == "0" corresponds to partitionId == 0
          if output is not None and os.environ.get("RANK", "") == "0":
              with open(f"{tempdir}/train_output.pkl", "wb") as f:
                  cloudpickle.dump(output, f)

      3. Run that train.py file with `torchrun`

      4. Check whether `train_output.pkl` has been created by the process with partitionId == 0; if it has, deserialize it and return that output through `.collect()`.
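      Taken together, the four steps above can be sketched as one driver-side helper. This is a self-contained approximation, not the actual PySpark implementation: `run_function_as_script` is a hypothetical name, stdlib `pickle` stands in for `cloudpickle` (they share the `load`/`dump` API), and a single plain-Python subprocess stands in for `torchrun`, which would launch one such process per rank with `RANK` set in its environment.

      ```python
      import os
      import pickle  # the real implementation uses cloudpickle to support closures/lambdas
      import subprocess
      import sys

      # Template for the generated temp train.py (step 2); {tempdir} is filled in below.
      TRAIN_TEMPLATE = '''\
      import os
      import pickle  # cloudpickle in the real implementation

      if __name__ == "__main__":
          with open("{tempdir}/train_input.pkl", "rb") as f:
              train, args = pickle.load(f)
          output = train(*args)
          # RANK == "0" corresponds to partitionId == 0
          if output is not None and os.environ.get("RANK", "") == "0":
              with open("{tempdir}/train_output.pkl", "wb") as f:
                  pickle.dump(output, f)
      '''

      def run_function_as_script(train, args, tempdir):
          # Step 1: pickle the function and its args for the generated script to load.
          with open(os.path.join(tempdir, "train_input.pkl"), "wb") as f:
              pickle.dump((train, args), f)
          # Step 2: materialize the temp train.py from the template.
          script = os.path.join(tempdir, "train.py")
          with open(script, "w") as f:
              f.write(TRAIN_TEMPLATE.format(tempdir=tempdir))
          # Step 3: the distributor launches this with `torchrun`, which sets RANK
          # per process; a single plain-Python process with RANK=0 stands in here.
          subprocess.run([sys.executable, script], check=True,
                         env={**os.environ, "RANK": "0"})
          # Step 4: if the rank-0 process wrote an output file, deserialize and return it.
          out_path = os.path.join(tempdir, "train_output.pkl")
          if os.path.exists(out_path):
              with open(out_path, "rb") as f:
                  return pickle.load(f)
          return None
      ```

      For example, `run_function_as_script(operator.add, (2, 3), some_tempdir)` returns `5`. Note that a function defined in the driver's `__main__` would need cloudpickle here: stdlib pickle serializes such functions only by reference, and the generated script's own `__main__` cannot resolve them.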


People

    Assignee: Rithwik Ediga Lakhamsani (erithwik)
    Reporter: Rithwik Ediga Lakhamsani (erithwik)
    Votes: 0
    Watchers: 4
