SystemDS / SYSTEMDS-1389

Update API: Pass in all outputs from `forward` to `backward` for performance


Details

    • Type: Improvement
    • Status: Open
    • Priority: Major
    • Resolution: Unresolved

    Description

      Currently, the `backward` functions in the nn library do not receive the outputs of the corresponding `forward` functions. This issue proposes updating the `backward` API to accept (1) all relevant gradients from upstream, (2) all outputs from `forward`, and (3) all inputs given to `forward`.

      In effect, this is equivalent to a class object that holds all configuration along with the input and output tensors. In an object-oriented design, most of the parameters to `forward` and `backward` would simply be instance variables, accessible to either function as needed. Since the library does not have that design, this API update mimics it by giving `backward` access to every input available to `forward` and every output produced by `forward`.

      This provides two benefits. First, many layers can run faster when `backward` can reuse the outputs of `forward` instead of recomputing them. Second, the API becomes simpler and less error-prone: the inputs and outputs of `forward` can be copied and pasted directly as arguments to `backward`, which removes ambiguity about which parameters to pass. A downside is that `backward` often does not need every one of these parameters.
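      To make the proposal concrete, here is a minimal Python sketch using an elementwise sigmoid layer as an assumed example. The function names (`sigmoid_forward`, `sigmoid_backward_current`, `sigmoid_backward_proposed`) and the `Sigmoid` class are hypothetical illustrations, not the actual SystemDS nn (DML) API. The "current" backward only sees the upstream gradient and the forward input, so it must recompute the sigmoid. The "proposed" backward receives the upstream gradient, the forward output, and the forward input, and reuses the output. The class at the end shows the object-oriented design the description refers to, where the forward inputs and outputs are kept as instance state.

```python
import math
from typing import List


def sigmoid_forward(X: List[float]) -> List[float]:
    """Elementwise sigmoid: out = 1 / (1 + exp(-x))."""
    return [1.0 / (1.0 + math.exp(-x)) for x in X]


def sigmoid_backward_current(dout: List[float], X: List[float]) -> List[float]:
    # Current-style API (assumed): only the upstream gradient and the forward
    # input are passed, so the forward computation has to be repeated.
    out = sigmoid_forward(X)
    return [d * o * (1.0 - o) for d, o in zip(dout, out)]


def sigmoid_backward_proposed(
    dout: List[float], out: List[float], X: List[float]
) -> List[float]:
    # Proposed-style API (assumed): upstream gradient, all forward outputs,
    # and all forward inputs. `out` is reused, so no exp() is recomputed.
    # `X` is unused here but accepted so callers can pass forward's
    # inputs and outputs without deciding which ones this layer needs.
    return [d * o * (1.0 - o) for d, o in zip(dout, out)]


class Sigmoid:
    """Hypothetical object-oriented equivalent: forward's inputs and
    outputs become instance variables that backward can read."""

    def forward(self, X: List[float]) -> List[float]:
        self.X = X
        self.out = sigmoid_forward(X)
        return self.out

    def backward(self, dout: List[float]) -> List[float]:
        return sigmoid_backward_proposed(dout, self.out, self.X)


if __name__ == "__main__":
    X = [-1.0, 0.0, 2.0]
    out = sigmoid_forward(X)
    dout = [1.0, 1.0, 1.0]
    # Both APIs produce the same gradient; the proposed one skips recomputation.
    print(sigmoid_backward_proposed(dout, out, X))
    print(sigmoid_backward_current(dout, X))
```

      The saving here is one elementwise sigmoid per backward call; the same reasoning is assumed to apply to other layers whose gradients can be expressed in terms of their forward outputs.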


          People

            Assignee: Mike Dusenberry (dusenberrymw)
            Reporter: Mike Dusenberry (dusenberrymw)
            Votes: 0
            Watchers: 2
