Description
Add a `TorchDistributor` data loader that loads data from Spark partition data.
We can add two APIs:
Add a `TorchDistributor` method API:
```python
def train_on_dataframe(self, train_function, spark_dataframe, *args, **kwargs):
    """
    Runs distributed training using the provided Spark DataFrame as input data.
    You should ensure the input Spark DataFrame has evenly divided partitions;
    this method starts a barrier Spark job in which each Spark task processes
    one partition of the input Spark DataFrame.

    Parameters
    ----------
    train_function :
        Either a PyTorch function or a PyTorch Lightning function that launches
        distributed training. Note that inside the function you can call the
        `pyspark.ml.torch.distributor.get_spark_partition_data_loader` API to
        get a torch data loader; the data loader loads data from the
        corresponding partition of the input Spark DataFrame.
    spark_dataframe :
        An input Spark DataFrame that can be used in the PyTorch
        `train_function` function. See the `train_function` argument doc for
        details.
    args :
        `args` are the positional input parameters to the `train_function`
        function. It would look like

        >>> model = distributor.train_on_dataframe(train, df, 1e-3, 64)

        where train is a function and 1e-3 and 64 are regular numeric inputs
        to the function.
    kwargs :
        `kwargs` are the keyword input parameters to the `train_function`
        function. It would look like

        >>> model = distributor.train_on_dataframe(train, df, tol=1e-3, max_iter=64)

        where train is a function that has 2 arguments `tol` and `max_iter`.

    Returns
    -------
    Returns the output of `train_function` called with the given args inside
    the Spark rank 0 task.
    """
```
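For illustration, here is a minimal driver-side sketch of how the proposed method might be invoked; the Parquet path, the partition count of 4, and `train_fn` (defined in the worker-side sketch after the loader API below) are assumptions made for this example:

```python
from pyspark.sql import SparkSession
from pyspark.ml.torch.distributor import TorchDistributor

spark = SparkSession.builder.getOrCreate()

# Hypothetical input data; repartition so the DataFrame has evenly divided
# partitions, assuming one partition per barrier task (num_processes).
df = spark.read.parquet("/tmp/train_data").repartition(4)  # hypothetical path

distributor = TorchDistributor(num_processes=4, local_mode=False, use_gpu=True)

# The positional args after the DataFrame (1e-3 and 64) are forwarded to
# `train_fn` as its learning rate and batch size.
model = distributor.train_on_dataframe(train_fn, df, 1e-3, 64)
```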
Add a loader API:
```python
def get_spark_partition_data_loader(num_samples, batch_size, prefetch=2):
    """
    This function must be called inside the `train_function`, where
    `train_function` is the input argument of
    `TorchDistributor.train_on_dataframe`. The function returns a PyTorch data
    loader that loads data from the corresponding Spark partition data.

    Parameters
    ----------
    num_samples :
        Number of samples to generate per epoch. If `num_samples` is less than
        the number of rows in the Spark partition, it generates the first
        `num_samples` rows of the Spark partition; if `num_samples` is greater
        than the number of rows in the Spark partition, then after the iterator
        has loaded all rows from the partition, it wraps around back to the
        first row.
    batch_size :
        How many samples per batch to load.
    prefetch :
        Number of batches loaded in advance.
    """
```
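To show how the two proposed APIs compose, here is a worker-side sketch of a `train_fn` that reads its partition through the proposed loader. The model, loss, `num_samples` budget, and the assumption that each batch yields a `(features, label)` tensor pair are illustrative, not part of the proposal:

```python
import torch
import torch.nn as nn
from pyspark.ml.torch.distributor import get_spark_partition_data_loader

def train_fn(lr, batch_size):
    # Runs inside each barrier task; the loader reads only this task's Spark
    # partition. num_samples=10_000 is an arbitrary per-epoch budget: per the
    # proposal, the loader wraps around if the partition has fewer rows.
    loader = get_spark_partition_data_loader(
        num_samples=10_000, batch_size=batch_size, prefetch=2
    )

    model = nn.Linear(8, 1)  # hypothetical model; 8 input features assumed
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    loss_fn = nn.MSELoss()

    for features, label in loader:  # assumed batch structure
        optimizer.zero_grad()
        loss = loss_fn(model(features), label)
        loss.backward()
        optimizer.step()

    # Per the proposal, the return value of the rank 0 task becomes the
    # result of train_on_dataframe on the driver.
    return model
```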