Refactoring moscot to include moscot.neural #763

Open
3 of 8 tasks
MUCDK opened this issue Nov 7, 2024 · 5 comments
@MUCDK
Collaborator

MUCDK commented Nov 7, 2024

We want a moscot.neural module that implements the discrete problems as equivalent neural problems. To get there, we take the following steps in order:

moscot.neural.solvers should eventually include

  • GenotLin,
  • GenotQuad,
  • and OTCFM.

Start with GenotLin, because it is already implemented, and extend it to GenotQuad. OTCFM will be added later.

  • Port the dataloader (or roughly the idea of the dataloader) from CellFlow, and extend it to include quadratic terms.
  • Let's also distinguish between the data representation in which we compute the coupling and the one that is the input/output of the flow.
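
To make the distinction in the last bullet concrete, here is a minimal NumPy sketch; it is not moscot or CellFlow code. It computes an entropic coupling in one representation and uses the sampled index pairs to batch a second representation that the flow would consume. The function names `sinkhorn_coupling` and `sample_matched_pairs`, the array names, and all sizes are assumptions made for illustration. Only the linear term is shown; a quadratic term would need a Gromov-Wasserstein solver instead.

```python
import numpy as np


def sinkhorn_coupling(x, y, epsilon=0.1, n_iters=200):
    """Entropic OT coupling between two uniformly weighted point clouds."""
    cost = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)
    cost = cost / cost.max()  # rescale so epsilon has a comparable meaning
    kernel = np.exp(-cost / epsilon)
    a = np.full(len(x), 1.0 / len(x))
    b = np.full(len(y), 1.0 / len(y))
    u = np.ones_like(a)
    for _ in range(n_iters):
        v = b / (kernel.T @ u)
        u = a / (kernel @ v)
    return u[:, None] * kernel * v[None, :]


def sample_matched_pairs(coupling, n_pairs, rng):
    """Draw (source index, target index) pairs with probability given by the coupling."""
    flat = coupling.ravel() / coupling.sum()
    idx = rng.choice(flat.size, size=n_pairs, p=flat)
    return np.unravel_index(idx, coupling.shape)


rng = np.random.default_rng(0)
n_src, n_tgt = 50, 60
# Hypothetical representation used only for matching (e.g. a low-dimensional embedding).
src_match = rng.normal(size=(n_src, 5))
tgt_match = rng.normal(size=(n_tgt, 5)) + 1.0
# Hypothetical representation that the flow takes as input and produces as output.
src_flow = rng.normal(size=(n_src, 20))
tgt_flow = rng.normal(size=(n_tgt, 20))

coupling = sinkhorn_coupling(src_match, tgt_match)
i, j = sample_matched_pairs(coupling, n_pairs=32, rng=rng)
batch = {"src_flow": src_flow[i], "tgt_flow": tgt_flow[j]}
```

The matching arrays never reach the flow: they only decide which rows of the flow arrays are paired.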
@selmanozleyen
Collaborator

selmanozleyen commented Jan 8, 2025

Hi @MUCDK, #778 should fix some of the items.

I think the best order for the rest is:

  • Port the dataloader so that it works with all three solvers you mentioned (my thesis code should be enough for this). In my thesis code I yield a dictionary with src_flow, tgt_flow, and (src,tgt)_(lin,quad), all of which are optional. For this code I think it would be best to yield aug_flow, src_flow, tgt_flow, and (src,tgt)_(lin,quad). This way OTCFM uses src_flow and tgt_flow, and GENOT uses aug_flow and tgt_flow. This is very easy to implement. PS: I'd prefer calling it aug_flow, since calling it cond was very confusing: time can also be a condition.

  • This will also distinguish between the data representations (the matching function and the flow part).
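
A rough sketch of what such a dataloader could look like, assuming plain NumPy arrays and independent random sampling of source and target cells. The function name `iterate_batches`, its arguments, and the choice to index `aug_flow` by source cell are assumptions for illustration; neither the thesis code nor CellFlow is reproduced here.

```python
from typing import Dict, Iterator, Optional

import numpy as np


def iterate_batches(
    src: Dict[str, np.ndarray],
    tgt: Dict[str, np.ndarray],
    batch_size: int,
    rng: np.random.Generator,
    aug: Optional[np.ndarray] = None,
) -> Iterator[Dict[str, np.ndarray]]:
    """Yield batches keyed by role; roles missing from the inputs are not yielded.

    `src` and `tgt` map a role ("flow", "lin", "quad") to an array whose first
    axis indexes cells. `aug` is an optional per-source-cell array for `aug_flow`.
    """
    n_src = len(next(iter(src.values())))
    n_tgt = len(next(iter(tgt.values())))
    while True:
        i = rng.integers(0, n_src, size=batch_size)
        j = rng.integers(0, n_tgt, size=batch_size)
        batch = {f"src_{role}": arr[i] for role, arr in src.items()}
        batch.update({f"tgt_{role}": arr[j] for role, arr in tgt.items()})
        if aug is not None:
            batch["aug_flow"] = aug[i]
        yield batch


rng = np.random.default_rng(0)
# No "quad" role is passed, so no *_quad keys appear in the batch.
src = {"flow": rng.normal(size=(100, 8)), "lin": rng.normal(size=(100, 4))}
tgt = {"flow": rng.normal(size=(120, 8)), "lin": rng.normal(size=(120, 4))}
loader = iterate_batches(src, tgt, batch_size=16, rng=rng, aug=src["lin"])
batch = next(loader)
```

A solver then picks the keys it needs from each batch: src_flow and tgt_flow for OTCFM, aug_flow and tgt_flow for GENOT.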

From what I see, I'd need to update NeuralOTProblem and then make smaller changes to GenotLinSolver and Problem for the new dataloader. The rest should come more easily. If I can get help with writing tests, this will speed up progress.

But I also have some questions:

  • DistributionCollection is a dictionary dict[K, DistributionContainer], but DistributionContainer is always a pair of distributions, right? Unlike what the name suggests.
  • I also don't understand why NeuralOTProblem has multiple distributions: self._distributions: Optional[DistributionCollection[K]] = None. Are these distribution pairs all trained on the same neural network?

@MUCDK
Collaborator Author

MUCDK commented Jan 8, 2025

Great, thanks, this makes sense.

Does aug stand for augmented?

The dataloader in the MSc is similar to CellFlow right? Then that makes sense!

Yes, once you start, @LeonStadelmann can help with tests!

I think we can get rid of DistributionCollection and follow the CellFlow data handling. And yes, DistributionCollection contains one distribution.

Second q: yes, all distributions are trained with the same neural network.

@selmanozleyen
Collaborator

selmanozleyen commented Jan 8, 2025

> Does aug stand for augmented?

yes

> The dataloader in the MSc is similar to CellFlow right? Then that makes sense!

yes

> I think we can get rid of DistributionCollection and follow the CellFlow data handling. And yes, DistributionCollection contains one distribution.

Just so we are on the same page: currently, DistributionContainer holds a pair of distributions and DistributionCollection is a dictionary of pairs.
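
As a simplified picture of that structure, and of how all pairs could feed one network, here is a sketch with stand-in types. `PairContainer`, `PairCollection`, and `sample_from_collection` are hypothetical names and do not reflect the actual moscot classes or their attributes.

```python
from dataclasses import dataclass
from typing import Dict, Tuple

import numpy as np


@dataclass
class PairContainer:
    """Simplified stand-in for a container holding one (source, target) pair."""

    src: np.ndarray
    tgt: np.ndarray


# Simplified stand-in for the collection: one pair per key.
PairCollection = Dict[Tuple[int, int], PairContainer]


def sample_from_collection(collection, batch_size, rng):
    """Pick one pair at random, then draw a batch from it.

    Calling this once per training step means a single model sees every pair.
    """
    keys = list(collection)
    key = keys[rng.integers(len(keys))]
    pair = collection[key]
    i = rng.integers(0, len(pair.src), size=batch_size)
    j = rng.integers(0, len(pair.tgt), size=batch_size)
    return key, pair.src[i], pair.tgt[j]


rng = np.random.default_rng(0)
collection: PairCollection = {
    (0, 1): PairContainer(rng.normal(size=(30, 4)), rng.normal(size=(40, 4))),
    (1, 2): PairContainer(rng.normal(size=(40, 4)), rng.normal(size=(25, 4))),
}
key, src_batch, tgt_batch = sample_from_collection(collection, batch_size=8, rng=rng)
```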

@selmanozleyen
Collaborator

selmanozleyen commented Jan 8, 2025

@MUCDK I have another question. Currently, prepare for NeuralOTProblem looks like this:

    @wrap_prepare
    def prepare(
        self,
        policy_key: str,
        policy: Policy_t,
        xy: Mapping[str, Any],
        xx: Mapping[str, Any],
        conditions: Mapping[str, Any],
        a: Optional[str] = None,
        b: Optional[str] = None,
        subset: Optional[Sequence[Tuple[K, K]]] = None,
        reference: K = None,
        **kwargs: Any,
    ) -> "NeuralOTProblem":

I will update this. I currently don't know what xx is, or whether it's something different from x_attrs. I am planning to make it very similar to OTProblem.prepare. What do you think?

    def prepare(
        self,
        xy: Mapping[str, Any],
        x: Mapping[str, Any],
        x_flow: Mapping[str, Any],
        x_aug: Mapping[str, Any],
        y: Mapping[str, Any],
        y_flow: Mapping[str, Any],
        a: Optional[Union[bool, str, ArrayLike]] = None,
        b: Optional[Union[bool, str, ArrayLike]] = None,
        marginal_kwargs: Dict[str, Any] = types.MappingProxyType({}),
    ) -> "NeuralOTProblem":

I also think it is better to rewrite GENOTLinSolver as a generalized GENOTSolver. (It should work with OTCFM as well, but since that is lower priority, we can check later whether we can make it a NeuralOTSolver.)

@MUCDK
Collaborator Author

MUCDK commented Jan 8, 2025

xx would be the quadratic source term. Regarding your suggestion, I would have xy, x, y. By default, x_flow=None; we then set x_flow=xy for OTFM, and x_flow=gaussian in any case for GENOT. Also for GENOT, I would by default have x_aug=concat(x_lin, x_quad), and analogously for y. The rest makes sense!
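
One way these defaults could be resolved is sketched below on plain arrays. The helper `resolve_source_inputs`, the solver strings "otfm" and "genot", and the `flow_dim` argument are assumptions for illustration; in particular, the dimension of the Gaussian noise is not specified above and is passed in explicitly here.

```python
from typing import Optional, Tuple

import numpy as np


def resolve_source_inputs(
    solver: str,
    x_lin: Optional[np.ndarray] = None,
    x_quad: Optional[np.ndarray] = None,
    x_flow: Optional[np.ndarray] = None,
    x_aug: Optional[np.ndarray] = None,
    flow_dim: Optional[int] = None,
    rng: Optional[np.random.Generator] = None,
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
    """Fill in x_flow and x_aug for the source side when they are left as None."""
    if solver == "otfm":
        # The flow starts from the source data itself (the xy / linear term).
        if x_flow is None:
            if x_lin is None:
                raise ValueError("x_lin is required when x_flow is not given.")
            x_flow = x_lin
        return x_flow, x_aug
    if solver == "genot":
        parts = [p for p in (x_lin, x_quad) if p is not None]
        if not parts:
            raise ValueError("GENOT needs at least one of x_lin or x_quad.")
        if flow_dim is None:
            raise ValueError("flow_dim is required to draw Gaussian noise.")
        if rng is None:
            rng = np.random.default_rng()
        # The flow always starts from Gaussian noise, regardless of any x_flow passed in.
        x_flow = rng.normal(size=(len(parts[0]), flow_dim))
        if x_aug is None:
            x_aug = np.concatenate(parts, axis=-1)
        return x_flow, x_aug
    raise ValueError(f"Unknown solver: {solver!r}")


rng = np.random.default_rng(0)
x_lin = rng.normal(size=(10, 4))
x_quad = rng.normal(size=(10, 3))

otfm_flow, otfm_aug = resolve_source_inputs("otfm", x_lin=x_lin)
genot_flow, genot_aug = resolve_source_inputs(
    "genot", x_lin=x_lin, x_quad=x_quad, flow_dim=6, rng=rng
)
```

The target side would be handled analogously for y.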
