tensorclouds.nn.utils
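Classes

- dotdict: Dictionary subclass (its summary is the docstring inherited from dict: "dict() -> new empty dictionary").
- ModelOutput

Functions

- l2_norm: Compute the L2 norm of a pytree of arrays. Useful for weight decay.
- clip_grads: Clip gradients stored as a pytree of arrays to a maximum norm of max_norm.
- inner_stack
- inner_split
- rescale_irreps
- multiscale_irreps
- next_multiple: Return the next multiple of factor.
- up_conv_seq_len: Output size of an upsampling convolutional layer.
- down_conv_seq_len: Output size of a downsampling convolutional layer.
- safe_norm: safe_norm(x) = norm(x) if norm(x) != 0 else 1.0
- safe_normalize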
Module Contents
- class tensorclouds.nn.utils.dotdict
Bases: dict
(Docstring inherited from dict.)
dict() -> new empty dictionary
dict(mapping) -> new dictionary initialized from a mapping object's (key, value) pairs
dict(iterable) -> new dictionary initialized as if via:
    d = {}
    for k, v in iterable:
        d[k] = v
dict(**kwargs) -> new dictionary initialized with the name=value pairs in the keyword argument list. For example: dict(one=1, two=2)
- tensorclouds.nn.utils.l2_norm(tree)
Compute the L2 norm of a pytree of arrays. Useful for weight decay.
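A minimal sketch of a pytree L2 norm in JAX, assuming the norm is global: the square root of the summed squares of every element across all leaves. The name `l2_norm_sketch` is hypothetical, and the tensorclouds implementation may differ.

```python
import jax
import jax.numpy as jnp


def l2_norm_sketch(tree) -> jax.Array:
    """Global L2 norm over all array leaves of a pytree (hypothetical sketch)."""
    leaves = jax.tree_util.tree_leaves(tree)
    # Sum the squared entries of every leaf, then take a single square root.
    sq_sum = sum(jnp.sum(jnp.square(x)) for x in leaves)
    return jnp.sqrt(sq_sum)


params = {"w": jnp.array([3.0, 0.0]), "b": jnp.array([4.0])}
norm = l2_norm_sketch(params)  # sqrt(9 + 0 + 16) = 5.0

# Typical weight-decay use: add a scaled squared norm to the training loss.
weight_decay = 1e-4
penalty = weight_decay * norm**2
```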
- tensorclouds.nn.utils.clip_grads(grad_tree, max_norm)
Clip gradients stored as a pytree of arrays to a maximum norm of max_norm.
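One common way to implement this is global-norm clipping, the same idea as `optax.clip_by_global_norm`: if the L2 norm of the whole gradient pytree exceeds `max_norm`, every leaf is scaled by one shared factor so the clipped norm equals `max_norm`. The sketch below assumes that global (not per-leaf) behavior; `clip_grads_sketch` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def clip_grads_sketch(grad_tree, max_norm: float):
    """Rescale a gradient pytree so its global L2 norm is at most max_norm."""
    leaves = jax.tree_util.tree_leaves(grad_tree)
    norm = jnp.sqrt(sum(jnp.sum(jnp.square(g)) for g in leaves))
    # scale = min(1, max_norm / norm); the maximum() guard avoids dividing by 0,
    # and an all-zero gradient then gets scale 1 (left unchanged).
    scale = jnp.minimum(1.0, max_norm / jnp.maximum(norm, 1e-12))
    return jax.tree_util.tree_map(lambda g: g * scale, grad_tree)


grads = {"w": jnp.array([6.0, 0.0]), "b": jnp.array([8.0])}  # global norm 10
clipped = clip_grads_sketch(grads, max_norm=5.0)    # every leaf scaled by 0.5
untouched = clip_grads_sketch(grads, max_norm=20.0)  # already within bound
```

Using one shared factor preserves the direction of the full gradient, which per-leaf clipping would not.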
- tensorclouds.nn.utils.inner_stack(pytrees)
- tensorclouds.nn.utils.inner_split(pytree)
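Neither function is documented. Judging only by the names, a plausible reading is that `inner_stack` turns a list of structurally identical pytrees into one pytree whose leaves are stacked along a new leading axis, and `inner_split` undoes that. The sketch below implements this assumed behavior with `jax.tree_util`; the `_sketch` names are hypothetical and the real functions may behave differently.

```python
import jax
import jax.numpy as jnp


def inner_stack_sketch(pytrees):
    """Stack a list of same-structure pytrees leaf-wise along a new axis 0."""
    return jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *pytrees)


def inner_split_sketch(pytree):
    """Inverse of inner_stack_sketch: split axis 0 back into a list of pytrees."""
    leaves, treedef = jax.tree_util.tree_flatten(pytree)
    n = leaves[0].shape[0]
    # Slice the i-th entry out of every leaf and rebuild the original structure.
    return [
        jax.tree_util.tree_unflatten(treedef, [leaf[i] for leaf in leaves])
        for i in range(n)
    ]


trees = [
    {"x": jnp.ones(3), "y": jnp.zeros(())},
    {"x": jnp.zeros(3), "y": jnp.ones(())},
]
stacked = inner_stack_sketch(trees)     # leaves: x -> shape (2, 3), y -> shape (2,)
restored = inner_split_sketch(stacked)  # list of 2 dicts matching `trees`
```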
- class tensorclouds.nn.utils.ModelOutput
- datum: moleculib.protein.datum.ProteinDatum
- encoder_internals: List[tensorclouds.tensorcloud.TensorCloud]
- decoder_internals: List[tensorclouds.tensorcloud.TensorCloud]
- atom_perm_loss: jax.Array
- diff_loss: List[jax.Array]
- tensorclouds.nn.utils.rescale_irreps(irreps: e3nn_jax.Irreps, rescale: float, chunk_factor: int = 0)
- tensorclouds.nn.utils.multiscale_irreps(irreps: e3nn_jax.Irreps, depth: int, rescale: float, chunk_factor: int = 0) -> List[e3nn_jax.Irreps]
- tensorclouds.nn.utils.next_multiple(x: int, factor: int) -> int
Return the next multiple of factor.
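The docstring does not say whether an `x` that is already a multiple of `factor` is kept or bumped to the following multiple. A minimal sketch, assuming the common convention of rounding `x` up to the smallest multiple of `factor` that is greater than or equal to `x`, with a positive `factor`; `next_multiple_sketch` is a hypothetical name.

```python
def next_multiple_sketch(x: int, factor: int) -> int:
    """Smallest multiple of `factor` that is >= x (assumed semantics, factor > 0)."""
    # -(-x // factor) is ceiling division for integers; multiply back up.
    return -(-x // factor) * factor


print(next_multiple_sketch(10, 4))  # 12
print(next_multiple_sketch(12, 4))  # 12 (already a multiple, kept as-is)
```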
- tensorclouds.nn.utils.up_conv_seq_len(size: int, kernel: int, stride: int, mode: str) -> int
Output size of an upsampling convolutional layer.
- tensorclouds.nn.utils.down_conv_seq_len(size: int, kernel: int, stride: int, mode: str) -> int
Output size of a downsampling convolutional layer.
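The accepted `mode` values are not documented here. A sketch assuming `mode` is either `"same"` or `"valid"`, with no dilation or extra padding, using the textbook length formulas: a strided (down) convolution gives `ceil(size / stride)` in same mode and `(size - kernel) // stride + 1` in valid mode, while a transposed (up) convolution gives `size * stride` in same mode and `(size - 1) * stride + kernel` in valid mode. The `_sketch` names and mode strings are hypothetical.

```python
def down_conv_seq_len_sketch(size: int, kernel: int, stride: int, mode: str) -> int:
    """Output length of a strided 1D convolution (assumed formulas)."""
    if mode == "same":
        return -(-size // stride)  # ceil(size / stride)
    if mode == "valid":
        return (size - kernel) // stride + 1
    raise ValueError(f"unknown mode: {mode!r}")


def up_conv_seq_len_sketch(size: int, kernel: int, stride: int, mode: str) -> int:
    """Output length of a 1D transposed convolution (assumed formulas)."""
    if mode == "same":
        return size * stride
    if mode == "valid":
        return (size - 1) * stride + kernel
    raise ValueError(f"unknown mode: {mode!r}")


print(down_conv_seq_len_sketch(100, kernel=3, stride=2, mode="same"))   # 50
print(down_conv_seq_len_sketch(100, kernel=3, stride=2, mode="valid"))  # 49
print(up_conv_seq_len_sketch(49, kernel=3, stride=2, mode="valid"))     # 99
```

The last two lines show that a valid-mode down/up round trip need not recover the original length (100 -> 49 -> 99), whereas same mode maps 100 -> 50 -> 100.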
- tensorclouds.nn.utils.safe_norm(vector: jax.Array, axis: int = -1) -> jax.Array
Norm along axis, with zero norms replaced by 1.0: safe_norm(x) = norm(x) if norm(x) != 0 else 1.0
- tensorclouds.nn.utils.safe_normalize(vector: jax.Array) -> jax.Array
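A sketch of the documented `safe_norm` rule, plus one plausible `safe_normalize` built on it; `safe_normalize` is undocumented, so dividing by `safe_norm` along the last axis is an assumption. Substituting 1.0 for a zero squared norm before the square root, rather than after it, keeps gradients finite at the zero vector, where the derivative of `sqrt` is unbounded and a post-hoc `jnp.where` would yield NaN. Both `_sketch` names are hypothetical.

```python
import jax
import jax.numpy as jnp


def safe_norm_sketch(vector: jax.Array, axis: int = -1) -> jax.Array:
    """norm(x) if norm(x) != 0 else 1.0, with finite gradients at x = 0."""
    sq = jnp.sum(jnp.square(vector), axis=axis)
    # Replace 0 by 1.0 *before* sqrt so sqrt is never evaluated at 0;
    # sqrt(1.0) == 1.0 yields the documented fallback value.
    return jnp.sqrt(jnp.where(sq == 0, 1.0, sq))


def safe_normalize_sketch(vector: jax.Array) -> jax.Array:
    """Unit vectors along the last axis; zero vectors stay zero (assumed)."""
    return vector / safe_norm_sketch(vector)[..., None]


v = jnp.array([[3.0, 4.0], [0.0, 0.0]])
norms = safe_norm_sketch(v)      # [5.0, 1.0]
units = safe_normalize_sketch(v)  # rows [0.6, 0.8] and [0.0, 0.0]
grad = jax.grad(lambda x: safe_norm_sketch(x).sum())(jnp.zeros(3))  # finite zeros
```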