Description
As we started the discussion on Slack, and since the design doc isn't up to date on this, I thought it would be good to discuss it here for future reference.
Currently in the design doc:
Input Types
What input types should we support? In Probabilistic Torch we have so far required that parameters are either Variables or Numbers. Should we additionally support Tensor parameter arguments? If so, then how do we deal with gradients in reparameterized distributions when Tensor arguments are supplied?
So it feels like it's not settled, or I'm just not keeping up.
Problems
- We worry about keeping compatibility with Numbers, which creates the need for `validate_log_prob_arg`.
- What should be acceptable query types for `cdf`, `inv_cdf`, `pdf`, `pmf`? Shouldn't these be validated as with `log_prob`? It would be nice if scalar queries were accepted.
- `broadcast_all` is potentially expensive (Implement broadcast_all() in C/C++ #80). See code.
- Despite upcasting we still do things conditionally on type, such as `math.log(self.scale) if isinstance(self.scale, Number) else self.scale.log()` (see the sketch after this list).
- We need specialized Softmax for both Tensors & Variables
- Tests are largely concerned with the problem of scalar, Tensor and Variable mixing.
- Slightly related: we infer type, precision and device placement in `rsample`. But how will we utilize precision-dependent epsilons in a clean way?
- Type confusion seems to block cool ideas
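
To make the mixing problem concrete, here is a rough sketch (illustrative only, not the actual `torch.distributions` code) of the kind of per-type branching that accepting Numbers, Tensors and Variables forces on us:

```python
import math
from numbers import Number

import torch
from torch.autograd import Variable


class SketchNormal(object):
    """Illustrative only: the sort of type-conditional code the bullets above describe."""

    def __init__(self, mean, std):
        # Parameters may be Numbers, Tensors or Variables, so every
        # derived quantity needs an isinstance check somewhere.
        self.mean = mean
        self.std = std

    def log_prob(self, value):
        # Same pattern as `math.log(self.scale) if isinstance(...) else self.scale.log()`.
        log_std = math.log(self.std) if isinstance(self.std, Number) else self.std.log()
        var = self.std ** 2
        return -((value - self.mean) ** 2) / (2 * var) - log_std - math.log(math.sqrt(2 * math.pi))
```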
Possible reasons
- Upcoming Tensor/Variable merge
- Being able to construct and query distributions using scalars is nice. But this niceness comes at a cost, and sometimes without the niceness (e.g. `log_prob` only accepts Tensor xor Variable; illustrated after this list).
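
For example (a hypothetical snippet, assuming the current behaviour described above), the convenience is one-sided:

```python
import torch
from torch.autograd import Variable
from torch.distributions import Normal

d = Normal(0.0, 1.0)                 # constructing from plain Python scalars works
x = Variable(torch.Tensor([0.5]))
d.log_prob(x)                        # fine: a Tensor xor Variable is expected
# d.log_prob(0.5)                    # but a bare scalar query is rejected
```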
Possible solutions
I'm thinking from the assumption that we'll soon have deprecated Tensors in favor of Variables and 0d Variables. Since mixing is a hurdle, can't we just decide something like:
- All queries on a Distribution return `Variable`?
- All distribution parameters upcast to `Variable`?
- All distribution query args upcast to `Variable`?
This is mostly the case currently anyway.
We have a good entry point for casting, since parameters will be upcast and reshaped using `broadcast_all`, and arguments to (at least) `log_prob` have to pass `validate_log_prob_arg`.
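
To sketch the proposed convention (the `upcast_to_variable` helper below is a hypothetical stand-in for a fast `broadcast_all`-style entry point, not existing code):

```python
import math
from numbers import Number

import torch
from torch.autograd import Variable


def upcast_to_variable(*values):
    # Hypothetical single casting entry point: everything becomes a Variable.
    out = []
    for v in values:
        if isinstance(v, Number):
            v = Variable(torch.Tensor([v]))
        elif torch.is_tensor(v):
            v = Variable(v)
        out.append(v)
    return out


class VariableOnlyNormal(object):
    def __init__(self, mean, std):
        # Parameters upcast once; no isinstance checks needed downstream.
        self.mean, self.std = upcast_to_variable(mean, std)

    def log_prob(self, value):
        (value,) = upcast_to_variable(value)  # query args upcast as well
        var = self.std ** 2
        return -((value - self.mean) ** 2) / (2 * var) - self.std.log() - math.log(math.sqrt(2 * math.pi))
```

Under such a guarantee, every query returns a Variable and the casting cost is paid in exactly one place.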
With a strong promise on input and return types this could be made fast (as proposed in #80). If it is made fast before we've figured this out, we might be stuck with it.
I also like how the `nn` module works, and I expect to be able to do things like calling `.double()` or `.float()` on a distribution. This is easy to implement in `Distribution` if we assume programmatic access to all of a distribution's parameters as a tuple or dict, as discussed in #130. By clearing up types now, I imagine the transition will require fewer ifs & elses.
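
As a hedged sketch of that idea (the `params` dict and `_apply` helper below are hypothetical, along the lines of what #130 discusses, not an existing API):

```python
import torch
from torch.autograd import Variable


class Distribution(object):
    # Sketch: assumes subclasses register their Variable parameters in `self.params`.

    def __init__(self, **params):
        self.params = params  # programmatic access to parameters, as discussed in #130

    def _apply(self, fn):
        # Rebuild the distribution with transformed parameters, nn.Module-style.
        return type(self)(**{name: fn(p) for name, p in self.params.items()})

    def float(self):
        return self._apply(lambda p: p.float())

    def double(self):
        return self._apply(lambda p: p.double())


class Normal(Distribution):
    def __init__(self, mean, std):
        super(Normal, self).__init__(mean=mean, std=std)
        self.mean, self.std = self.params['mean'], self.params['std']


d = Normal(Variable(torch.Tensor([0.0])), Variable(torch.Tensor([1.0]))).double()
```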