tensorclouds.transport.diffusion

Classes

ModelPrediction

TensorCloudDiffuser

Base class for all neural network modules (docstring inherited from flax.linen.Module).

Functions

linear_beta_schedule(timesteps)

sigmoid_beta_schedule(timesteps[, start, end, tau, ...])

compute_constants(timesteps[, start_at, scheduler])

Module Contents

tensorclouds.transport.diffusion.linear_beta_schedule(timesteps)
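The page gives only the signature. The function name suggests the standard DDPM linear schedule, which spaces the per-step noise variances beta_1..beta_T evenly. The sketch below assumes the endpoints 1e-4 and 0.02 are rescaled by 1000 / timesteps, as in common open-source DDPM code. This is an assumption: the tensorclouds implementation is not shown here and may use different constants.

```python
import jax.numpy as jnp


def linear_beta_schedule(timesteps: int) -> jnp.ndarray:
    """Evenly spaced noise variances beta_1..beta_T (DDPM-style sketch).

    Assumption: endpoints 1e-4 and 0.02, rescaled by 1000 / timesteps so that
    shorter chains take proportionally larger steps.
    """
    scale = 1000 / timesteps
    beta_start = scale * 1e-4
    beta_end = scale * 0.02
    return jnp.linspace(beta_start, beta_end, timesteps)


betas = linear_beta_schedule(1000)
print(betas.shape, float(betas[0]), float(betas[-1]))
```

With `timesteps=1000` the scale factor is 1, so the schedule runs from 1e-4 to 0.02.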
tensorclouds.transport.diffusion.sigmoid_beta_schedule(timesteps, start=0, end=3, tau=0.3, clamp_min=1e-05)
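The parameter names (`start`, `end`, `tau`, `clamp_min`) match the sigmoid noise schedule used in several open-source diffusion codebases. In that schedule, the cumulative signal level alpha_bar(t) follows a rescaled, decreasing sigmoid of t, and the betas are recovered from ratios of consecutive alpha_bar values. The sketch below implements that schedule. Treating `clamp_min` as a floor on alpha_bar is an assumption made for illustration. This is not the tensorclouds source.

```python
import jax
import jax.numpy as jnp


def sigmoid_beta_schedule(timesteps, start=0.0, end=3.0, tau=0.3, clamp_min=1e-5):
    """Sigmoid noise schedule sketch.

    Assumption: clamp_min floors alpha_bar so the ratios below stay finite.
    """
    t = jnp.linspace(0.0, 1.0, timesteps + 1)
    v_start = jax.nn.sigmoid(start / tau)
    v_end = jax.nn.sigmoid(end / tau)
    # alpha_bar decreases from 1 at t=0 to 0 at t=1 (end > start).
    alphas_cumprod = (v_end - jax.nn.sigmoid((t * (end - start) + start) / tau)) / (v_end - v_start)
    alphas_cumprod = jnp.clip(alphas_cumprod / alphas_cumprod[0], clamp_min, 1.0)
    # beta_t = 1 - alpha_bar_t / alpha_bar_{t-1}
    betas = 1.0 - alphas_cumprod[1:] / alphas_cumprod[:-1]
    return jnp.clip(betas, 0.0, 0.999)


betas = sigmoid_beta_schedule(1000)
print(betas.shape, float(betas.min()), float(betas.max()))
```

A smaller `tau` sharpens the sigmoid, which concentrates most of the noise injection in a narrower band of timesteps.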
tensorclouds.transport.diffusion.compute_constants(timesteps, start_at=1.0, scheduler=linear_beta_schedule)
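The page does not document what `compute_constants` returns. Diffusion code with this shape of signature usually precomputes the per-timestep quantities that the forward and reverse processes index into. The sketch below uses a hypothetical `ddpm_constants` to show the standard DDPM set of such quantities. The role of `start_at` is not documented, so it is left out rather than guessed.

```python
import jax.numpy as jnp


def linear_beta_schedule(timesteps):
    scale = 1000 / timesteps
    return jnp.linspace(scale * 1e-4, scale * 0.02, timesteps)


def ddpm_constants(timesteps, scheduler=linear_beta_schedule):
    """Hypothetical stand-in for compute_constants (start_at omitted)."""
    betas = scheduler(timesteps)
    alphas = 1.0 - betas
    alphas_cumprod = jnp.cumprod(alphas)  # alpha_bar_t
    # alpha_bar_{t-1}, with alpha_bar_0 := 1 for the first step.
    alphas_cumprod_prev = jnp.concatenate([jnp.ones((1,)), alphas_cumprod[:-1]])
    return {
        "betas": betas,
        "alphas_cumprod": alphas_cumprod,
        "alphas_cumprod_prev": alphas_cumprod_prev,
        "sqrt_alphas_cumprod": jnp.sqrt(alphas_cumprod),
        "sqrt_one_minus_alphas_cumprod": jnp.sqrt(1.0 - alphas_cumprod),
        # Variance of the posterior q(x_{t-1} | x_t, x_0).
        "posterior_variance": betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod),
    }


c = ddpm_constants(1000)
print({k: v.shape for k, v in c.items()})
```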
class tensorclouds.transport.diffusion.ModelPrediction
prediction: tensorclouds.tensorcloud.TensorCloud
target: dict
reweight: float
class tensorclouds.transport.diffusion.TensorCloudDiffuser

Bases: flax.linen.Module

Base class for all neural network modules (docstring inherited from flax.linen.Module).

Layers and models should subclass this class.

All Flax Modules are Python 3.7 dataclasses. Since dataclasses take over __init__, you should instead override setup(), which is automatically called to initialize the module.

Modules can contain submodules, and in this way can be nested in a tree structure. Submodules can be assigned as regular attributes inside the setup() method.

You can define arbitrary “forward pass” methods on your Module subclass. While no methods are special-cased, __call__ is a popular choice because it allows you to use module instances as if they are functions:

>>> from flax import linen as nn
>>> from typing import Tuple

>>> class Module(nn.Module):
...   features: Tuple[int, ...] = (16, 4)
...
...   def setup(self):
...     self.dense1 = nn.Dense(self.features[0])
...     self.dense2 = nn.Dense(self.features[1])
...
...   def __call__(self, x):
...     return self.dense2(nn.relu(self.dense1(x)))

Optionally, for more concise module implementations where submodule definitions are co-located with their usage, you can use the compact() wrapper.

network: flax.linen.Module
irreps: e3nn_jax.Irreps
var_features: float
var_coords: float
timesteps: int = 1000
leading_shape: Tuple = (1,)
setup()

Initializes a Module lazily (similar to a lazy __init__).

setup is called once lazily on a module instance when a module is bound, immediately before any other methods like __call__ are invoked, or before a setup-defined attribute on self is accessed.

This can happen in three cases:

  1. Immediately when invoking apply(), init() or init_and_output().

  2. Once the module is given a name by being assigned to an attribute of another module inside the other module’s setup method (see __setattr__()):

    >>> class MyModule(nn.Module):
    ...   def setup(self):
    ...     submodule = nn.Conv(...)
    ...
    ...     # Accessing `submodule` attributes does not yet work here.
    ...
    ...     # The following line invokes `self.__setattr__`, which gives
    ...     # `submodule` the name "conv1".
    ...     self.conv1 = submodule
    ...
    ...     # Accessing `submodule` attributes or methods is now safe and
    ...     # causes its setup() to be called once.

  3. Once a module is constructed inside a method wrapped with compact(), immediately before another method is called or a setup-defined attribute is accessed.

sample(cond: e3nn_jax.IrrepsArray = None, mask_coord=None, mask_features=None)
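`sample` is listed without a description. In DDPM-style models, sampling usually means ancestral sampling: start from Gaussian noise and repeatedly apply the learned reverse step from t = T-1 down to 0. The sketch below shows that loop on a plain array. It assumes a hypothetical noise predictor `eps_fn(x_t, t)`, a linear beta schedule, and sigma_t^2 = beta_t. It does not model `cond`, `mask_coord`, or `mask_features`, whose semantics are not documented here.

```python
import jax
import jax.numpy as jnp


def ddpm_sample(eps_fn, key, shape, timesteps=1000):
    """Ancestral DDPM sampling sketch; eps_fn is a hypothetical noise predictor."""
    betas = jnp.linspace(1e-4, 0.02, timesteps)
    alphas = 1.0 - betas
    alphas_cumprod = jnp.cumprod(alphas)

    key, sub = jax.random.split(key)
    x = jax.random.normal(sub, shape)  # x_T ~ N(0, I)

    def step(carry, t):
        x, key = carry
        key, sub = jax.random.split(key)
        eps = eps_fn(x, t)
        # Mean of p(x_{t-1} | x_t) given the predicted noise.
        mean = (x - betas[t] / jnp.sqrt(1.0 - alphas_cumprod[t]) * eps) / jnp.sqrt(alphas[t])
        noise = jax.random.normal(sub, shape)
        # No noise is added at the final step (t == 0).
        x = mean + jnp.where(t > 0, jnp.sqrt(betas[t]), 0.0) * noise
        return (x, key), None

    (x, _), _ = jax.lax.scan(step, (x, key), jnp.arange(timesteps - 1, -1, -1))
    return x


# A dummy predictor that always returns zero noise, just to exercise the loop.
out = ddpm_sample(lambda x, t: jnp.zeros_like(x), jax.random.PRNGKey(0), (4, 3), timesteps=50)
print(out.shape)
```

`jax.lax.scan` keeps the loop compiled as a single traced computation instead of unrolling T Python iterations.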
q_sample(x0, t: int)
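The name `q_sample` conventionally denotes sampling from the forward (noising) process in closed form, x_t = sqrt(alpha_bar_t) x_0 + sqrt(1 - alpha_bar_t) eps with eps ~ N(0, I). The sketch below implements that formula for a plain array. The explicit `key` and `alphas_cumprod` arguments are assumptions: the documented method takes only `x0` and `t` and presumably reads the schedule from the module.

```python
import jax
import jax.numpy as jnp


def q_sample(x0, t, key, alphas_cumprod):
    """Draw x_t ~ q(x_t | x_0) = N(sqrt(abar_t) x0, (1 - abar_t) I); returns (x_t, eps)."""
    noise = jax.random.normal(key, x0.shape)
    abar = alphas_cumprod[t]
    return jnp.sqrt(abar) * x0 + jnp.sqrt(1.0 - abar) * noise, noise


alphas_cumprod = jnp.cumprod(1.0 - jnp.linspace(1e-4, 0.02, 1000))
x0 = jnp.ones((4, 3))
xt, eps = q_sample(x0, 500, jax.random.PRNGKey(0), alphas_cumprod)
print(xt.shape)
```

Returning the sampled noise alongside x_t lets a training loop use it directly as the regression target for a noise-prediction network.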