Description
Sidenote: make the formatting updates described in https://github.com/apache/spark/pull/39188
Currently, `Distributor().run(...)` takes only files as input. We will add functionality so that it accepts functions as well. This requires the following process on each task on the executor nodes:
1. Take the input function and its args and pickle them (a sketch of this step follows the list)
2. Create a temporary `train.py` file that looks like this:
```python
import os

import cloudpickle

if __name__ == "__main__":
    # {tempdir} is filled in when the distributor generates this file
    with open(f"{tempdir}/train_input.pkl", "rb") as f:
        train, args = cloudpickle.load(f)
    output = train(*args)
    if output and os.environ.get("RANK", "") == "0":  # this is for partitionId == 0
        with open(f"{tempdir}/train_output.pkl", "wb") as f:
            cloudpickle.dump(output, f)
```
3. Run that `train.py` file with `torchrun` (see the launch sketch below)
4. Check whether `train_output.pkl` has been created by the process on partitionId == 0; if it has, deserialize it and return that output through `.collect()` (see the read sketch below)
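For reference, the intended call shape after this change would be something like the following; the training function and argument values are illustrative, and the constructor parameters are whatever `Distributor` already exposes:

```python
# Hypothetical usage: pass a function and its args instead of a file path.
def train(learning_rate, use_gpu):
    ...  # user-defined training loop; may return a result from rank 0

output = Distributor().run(train, 1e-3, True)
```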
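A minimal sketch of step 1, assuming the task serializes the function and its args together; `write_train_input` is a hypothetical helper, and the file name matches the `train_input.pkl` that the generated `train.py` reads:

```python
import os

import cloudpickle

def write_train_input(train_fn, args, tempdir):
    # Serialize the user's function and its args so the generated
    # train.py can reload them inside the torchrun subprocess.
    with open(os.path.join(tempdir, "train_input.pkl"), "wb") as f:
        cloudpickle.dump((train_fn, args), f)
```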
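One way step 3 could launch the generated file, assuming `torchrun` is on the PATH of the executor task; `run_training_file` and the flag values are illustrative, and the real distributor would pass its full set of torchrun arguments:

```python
import os
import subprocess

def run_training_file(tempdir, num_processes):
    # torchrun sets RANK, WORLD_SIZE, etc. in each worker's environment,
    # which the generated train.py uses to decide who writes the output.
    cmd = [
        "torchrun",
        f"--nproc_per_node={num_processes}",
        os.path.join(tempdir, "train.py"),
    ]
    # check=True turns a non-zero exit code into an exception on the task.
    subprocess.run(cmd, check=True)
```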
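And a sketch of step 4's read side; `read_train_output` is a hypothetical helper that each task would call after `torchrun` finishes, so that `.collect()` on the driver sees the output from the partitionId == 0 task and `None` from the rest:

```python
import os

import cloudpickle

def read_train_output(tempdir):
    # Only the worker with RANK == "0" (on the partitionId == 0 task)
    # writes train_output.pkl; every other task finds no file here.
    output_path = os.path.join(tempdir, "train_output.pkl")
    if not os.path.exists(output_path):
        return None
    with open(output_path, "rb") as f:
        return cloudpickle.load(f)
```

On the driver, the single non-`None` element of the collected results would then be returned to the caller.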