tensorclouds.transport.diffusion
Classes
- ModelPrediction
- TensorCloudDiffuser: Base class for all neural network modules (docstring inherited from flax.linen.Module).
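Functions
- linear_beta_schedule(timesteps)
- sigmoid_beta_schedule(timesteps, start, end, tau, clamp_min)
- compute_constants(timesteps, start_at, scheduler)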
Module Contents
- tensorclouds.transport.diffusion.linear_beta_schedule(timesteps)
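This page gives no docstring for linear_beta_schedule. The sketch below assumes the linear variance schedule popularized by DDPM, with betas evenly spaced between 1e-4 and 0.02. Those endpoints, and the extra beta_start/beta_end keyword arguments, are assumptions for illustration and may not match what tensorclouds actually uses.

```python
import jax.numpy as jnp


def linear_beta_schedule(timesteps: int, beta_start: float = 1e-4, beta_end: float = 0.02) -> jnp.ndarray:
    """Linearly spaced noise variances beta_1..beta_T (DDPM-style sketch).

    beta_start and beta_end are hypothetical keyword arguments; the
    documented signature takes only `timesteps`.
    """
    return jnp.linspace(beta_start, beta_end, timesteps, dtype=jnp.float32)


betas = linear_beta_schedule(1000)
```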
- tensorclouds.transport.diffusion.sigmoid_beta_schedule(timesteps, start=0, end=3, tau=0.3, clamp_min=1e-05)
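sigmoid_beta_schedule is likewise undocumented. Its parameters (start, end, tau, clamp_min) match the sigmoid schedule found in several open-source diffusion codebases, where the cumulative signal level alpha_bar(t) follows a shifted, rescaled sigmoid of normalized time and the betas are recovered from ratios of consecutive alpha_bar values. The following is a minimal sketch under that assumption; in particular, using clamp_min as a lower bound on the betas is a guess that this reference does not confirm.

```python
import jax
import jax.numpy as jnp


def sigmoid_beta_schedule(timesteps, start=0, end=3, tau=0.3, clamp_min=1e-5):
    """Sketch of a sigmoid noise schedule (assumed formulation)."""
    steps = timesteps + 1
    t = jnp.linspace(0.0, 1.0, steps)  # normalized time grid, t in [0, 1]
    v_start = jax.nn.sigmoid(start / tau)
    v_end = jax.nn.sigmoid(end / tau)
    # Rescaled sigmoid: alpha_bar(0) = 1 and alpha_bar(1) is approximately 0.
    alphas_cumprod = (v_end - jax.nn.sigmoid((t * (end - start) + start) / tau)) / (v_end - v_start)
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    # beta_t = 1 - alpha_bar_t / alpha_bar_{t-1}
    betas = 1.0 - alphas_cumprod[1:] / alphas_cumprod[:-1]
    # Assumption: clamp_min bounds the betas from below; 0.999 caps the final step.
    return jnp.clip(betas, clamp_min, 0.999)


betas = sigmoid_beta_schedule(1000)
```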
- tensorclouds.transport.diffusion.compute_constants(timesteps, start_at=1.0, scheduler=linear_beta_schedule)
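compute_constants presumably turns a beta schedule into the per-step quantities a diffusion model needs at training and sampling time. The sketch below computes the usual DDPM constants. The dictionary keys are hypothetical, a stand-in linear schedule is included so the block runs on its own, and start_at is left out because its meaning is not documented here.

```python
import jax.numpy as jnp


def linear_beta_schedule(timesteps):
    # Hypothetical stand-in schedule (DDPM-style linear betas).
    return jnp.linspace(1e-4, 0.02, timesteps)


def compute_constants(timesteps, scheduler=linear_beta_schedule):
    """Precompute per-step diffusion constants from a beta schedule (sketch)."""
    betas = scheduler(timesteps)
    alphas = 1.0 - betas
    alphas_cumprod = jnp.cumprod(alphas)  # alpha_bar_t = prod_{s <= t} alpha_s
    alphas_cumprod_prev = jnp.concatenate([jnp.ones((1,)), alphas_cumprod[:-1]])
    return {
        "betas": betas,
        "alphas": alphas,
        "alphas_cumprod": alphas_cumprod,
        "alphas_cumprod_prev": alphas_cumprod_prev,
        "sqrt_alphas_cumprod": jnp.sqrt(alphas_cumprod),
        "sqrt_one_minus_alphas_cumprod": jnp.sqrt(1.0 - alphas_cumprod),
        # Variance of the posterior q(x_{t-1} | x_t, x_0).
        "posterior_variance": betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod),
    }


consts = compute_constants(1000)
```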
- class tensorclouds.transport.diffusion.ModelPrediction
- prediction: tensorclouds.tensorcloud.TensorCloud
- target: dict
- reweight: float
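ModelPrediction bundles a prediction, a target dictionary and a scalar reweight factor. One plausible use, sketched below purely as an assumption, is a reweighted regression loss. ModelPredictionSketch and weighted_mse are hypothetical names, and plain arrays stand in for the TensorCloud prediction because the TensorCloud API is not described on this page.

```python
from dataclasses import dataclass

import jax.numpy as jnp


@dataclass
class ModelPredictionSketch:
    """Hypothetical stand-in for ModelPrediction using plain arrays."""
    prediction: dict  # the real field holds a TensorCloud
    target: dict
    reweight: float


def weighted_mse(out: ModelPredictionSketch) -> jnp.ndarray:
    # Assumed usage: mean squared error per target key, summed and scaled by `reweight`.
    losses = [jnp.mean((out.prediction[k] - out.target[k]) ** 2) for k in out.target]
    return out.reweight * sum(losses)


out = ModelPredictionSketch(
    prediction={"coord": jnp.zeros((4, 3)), "features": jnp.ones((4, 8))},
    target={"coord": jnp.ones((4, 3)), "features": jnp.ones((4, 8))},
    reweight=0.5,
)
loss = weighted_mse(out)
```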
- class tensorclouds.transport.diffusion.TensorCloudDiffuser
Bases: flax.linen.Module
Base class for all neural network modules.
Layers and models should subclass this class.
All Flax Modules are Python 3.7 dataclasses. Since dataclasses take over __init__, you should instead override setup(), which is automatically called to initialize the module.
Modules can contain submodules, and in this way can be nested in a tree structure. Submodules can be assigned as regular attributes inside the setup() method.
You can define arbitrary “forward pass” methods on your Module subclass. While no methods are special-cased, __call__ is a popular choice because it allows you to use module instances as if they are functions:
    >>> from flax import linen as nn
    >>> from typing import Tuple
    >>> class Module(nn.Module):
    ...   features: Tuple[int, ...] = (16, 4)
    ...   def setup(self):
    ...     self.dense1 = nn.Dense(self.features[0])
    ...     self.dense2 = nn.Dense(self.features[1])
    ...   def __call__(self, x):
    ...     return self.dense2(nn.relu(self.dense1(x)))
Optionally, for more concise module implementations where submodule definitions are co-located with their usage, you can use the compact() wrapper.
- network: flax.linen.Module
- irreps: e3nn_jax.Irreps
- var_features: float
- var_coords: float
- timesteps: int = 1000
- leading_shape: Tuple = (1,)
- setup()
Initializes a Module lazily (similar to a lazy __init__).
setup is called once lazily on a module instance when a module is bound, immediately before any other methods like __call__ are invoked, or before a setup-defined attribute on self is accessed.
This can happen in three cases:
1. Immediately when invoking apply(), init() or init_and_output().
2. Once the module is given a name by being assigned to an attribute of another module inside the other module’s setup method (see __setattr__()):
    >>> class MyModule(nn.Module):
    ...   def setup(self):
    ...     submodule = nn.Conv(...)
    ...     # Accessing `submodule` attributes does not yet work here.
    ...     # The following line invokes `self.__setattr__`, which gives
    ...     # `submodule` the name "conv1".
    ...     self.conv1 = submodule
    ...     # Accessing `submodule` attributes or methods is now safe, and
    ...     # either one causes setup() to be called once.
3. Once a module is constructed inside a method wrapped with compact(), immediately before another method is called or a setup-defined attribute is accessed.
- sample(cond: e3nn_jax.IrrepsArray = None, mask_coord=None, mask_features=None)
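sample has no docstring. Assuming it performs standard ancestral (DDPM) sampling, the loop below shows the idea on plain arrays: start from Gaussian noise and repeatedly apply the reverse update from step T down to 0. The noise predictor eps_model, the epsilon parameterization and the variance choice sigma_t^2 = beta_t are all assumptions. The real method works on TensorCloud objects through its network attribute and also accepts cond, mask_coord and mask_features, which this sketch ignores.

```python
import jax
import jax.numpy as jnp


def ddpm_sample(key, eps_model, shape, betas):
    """Ancestral DDPM sampling sketch using a hypothetical `eps_model(x, t)`."""
    alphas = 1.0 - betas
    alphas_cumprod = jnp.cumprod(alphas)
    key, sub = jax.random.split(key)
    x = jax.random.normal(sub, shape)  # x_T ~ N(0, I)
    for t in range(betas.shape[0] - 1, -1, -1):
        eps = eps_model(x, t)
        # Mean of p(x_{t-1} | x_t) under the epsilon parameterization.
        mean = (x - betas[t] / jnp.sqrt(1.0 - alphas_cumprod[t]) * eps) / jnp.sqrt(alphas[t])
        if t > 0:
            key, sub = jax.random.split(key)
            # Assumed variance choice: sigma_t^2 = beta_t.
            x = mean + jnp.sqrt(betas[t]) * jax.random.normal(sub, shape)
        else:
            x = mean  # no noise is added at the final step
    return x


def zero_eps(x, t):
    # Trivial stand-in predictor, for demonstration only.
    return jnp.zeros_like(x)


betas = jnp.linspace(1e-4, 0.02, 50)
x = ddpm_sample(jax.random.PRNGKey(0), zero_eps, (8, 3), betas)
```

A Python loop keeps the sketch readable; a jitted implementation would more likely use jax.lax.fori_loop or jax.lax.scan.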
- q_sample(x0, t: int)
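q_sample(x0, t) is, by its name, the forward noising step that draws x_t from q(x_t | x_0). A minimal sketch of the standard closed form x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps follows. Unlike the documented method, it takes an explicit PRNG key and an alphas_cumprod array; the module may instead draw noise from a Flax RNG stream and scale it by var_features and var_coords, neither of which is shown here.

```python
import jax
import jax.numpy as jnp


def q_sample(key, x0, t, alphas_cumprod):
    """Draw x_t ~ N(sqrt(alpha_bar_t) * x0, (1 - alpha_bar_t) * I); also return the noise."""
    noise = jax.random.normal(key, x0.shape)
    x_t = jnp.sqrt(alphas_cumprod[t]) * x0 + jnp.sqrt(1.0 - alphas_cumprod[t]) * noise
    return x_t, noise


alphas_cumprod = jnp.cumprod(1.0 - jnp.linspace(1e-4, 0.02, 1000))
x0 = jnp.ones((5, 3))
x_t, noise = q_sample(jax.random.PRNGKey(0), x0, 500, alphas_cumprod)
```

Returning the noise alongside x_t is convenient for an epsilon-prediction training target.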