Conversation
|
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request. |
ryan112358
left a comment
There was a problem hiding this comment.
Thanks for the hard work on this, looks like a great start! Will provide another round of feedback after my initial comments are resolved
src/mbi/relaxed_projection.py
Outdated
|
|
||
| import jax | ||
| import jax.numpy as jnp | ||
| from jax import random, jit, value_and_grad, vmap, lax |
There was a problem hiding this comment.
Let's stick to the convention of just "import jax" and "jax.numpy as jnp, and not import these additional members
| @@ -0,0 +1,593 @@ | |||
| import numpy as np | |||
There was a problem hiding this comment.
can you add a unit test to test_estimation.py to ensure this works as expected when no noise is added?
There was a problem hiding this comment.
examples won't necessarily be run by pytest or github actions, can you port over a simplified version of the example to test_estimation.py?
ryan112358
left a comment
There was a problem hiding this comment.
Appreciate all the work you've put into this so far!
src/mbi/relaxed_projection.py
Outdated
| D_start = _initialize_synthetic_dataset(key, num_generated_points=D_size, data_dimension=np.sum(domain.shape)) | ||
|
|
||
| stat_dim = _obtain_dim(measurements = measurements) | ||
| statistics = [MarginalStatistics(domain, dim) for dim in stat_dim] |
There was a problem hiding this comment.
I think the MarginalStatistics class should not be necessary if you have a MarginalLossFn. The pattern could be:
marginals = { cl : D.project(cl) for cl in cliques }
loss = loss_fn(marginals)
And this whole computation should be differentiable by jax
src/mbi/relaxed_projection.py
Outdated
| from .domain import Domain | ||
| from .marginal_loss import LinearMeasurement | ||
| from .clique_vector import CliqueVector | ||
| from .estimation import mirror_descent |
There was a problem hiding this comment.
You shouldn't need this dependency here
There was a problem hiding this comment.
You mean mirror_descent or all of these four dependencies? Domain and LinearMeasurement are imported for signatures of function, while others, except mirror_descent, are used in the function.
|
I pushed the most recent code to this branch. I went through previous comments and addressed most of them. One exception is the two-case loss function in RP, for which I left some replies. |
Support relaxed projection and neural network generators.
Two simple examples are provided in
mechanisms/NN.pyandmechanisms/RP.py