Details
Type: New Feature
Status: To Do
Priority: Minor
Resolution: Unresolved
None
Description
This is a suggestion for a small usability improvement to the Gluon API. The proposal is to give every tensor variable a member `F` that points to the MXNet `ndarray` or `symbol` module, depending on whether the variable is an NDArray or a Symbol. For instance:
```python
x = mx.nd.zeros(10)
x.F  # this is mx.ndarray
```
Although this is only a small change, it simplifies the API of functions and classes that have no learnable parameters: such code would no longer need an explicit `F` argument, nor would it need to be wrapped in a `HybridBlock`.
As an example, one could implement a Gaussian distribution like this:
```python
import math


class Gaussian:
    def __init__(self, mu, sigma):
        self.mu = mu
        self.sigma = sigma
        # Proposed attribute: mx.nd for NDArray inputs, mx.sym for Symbol inputs
        self.F = mu.F

    def sample(self):
        return self.F.sample_normal(mu=self.mu, sigma=self.sigma)

    def log_prob(self, x):
        F = self.F
        return -0.5 * F.square((x - self.mu) / self.sigma) - F.log(self.sigma) - 0.5 * math.log(2 * math.pi)
```
This gives a clean API that works for both Symbols and NDArrays:
```python
distr = Gaussian(mx.nd.zeros(10), mx.nd.ones(10))
s = distr.sample()
lp = distr.log_prob(s)
```
One can currently shim this with a helper function:

```python
def get_F(var):
    if isinstance(var, mx.nd.NDArray):
        return mx.nd
    else:
        return mx.sym
```

Still, it makes sense to have a standard mechanism for this.