tensorclouds.nn.utils

Classes

dotdict

Subclass of dict; its docstring is inherited from dict.

ModelOutput

EmbeddingsDataset

Functions

l2_norm(tree)

Compute the l2 norm of a pytree of arrays. Useful for weight decay.

clip_grads(grad_tree, max_norm)

Clip gradients stored as a pytree of arrays to maximum norm max_norm.

inner_stack(pytrees)

inner_split(pytree)

rescale_irreps(irreps, rescale[, chunk_factor])

multiscale_irreps(→ List[e3nn_jax.Irreps])

next_multiple(→ int)

Return the next multiple of factor.

up_conv_seq_len(→ int)

Output sequence length of an upsampling convolutional layer.

down_conv_seq_len(→ int)

Output sequence length of a downsampling convolutional layer.

safe_norm(→ jax.Array)

safe_norm(x) = norm(x) if norm(x) != 0 else 1.0

safe_normalize(→ jax.Array)

Module Contents

class tensorclouds.nn.utils.dotdict

Bases: dict

dict() -> new empty dictionary
dict(mapping) -> new dictionary initialized from a mapping object's (key, value) pairs
dict(iterable) -> new dictionary initialized as if via:

    d = {}
    for k, v in iterable:
        d[k] = v

dict(**kwargs) -> new dictionary initialized with the name=value pairs in the keyword argument list. For example: dict(one=1, two=2)

tensorclouds.nn.utils.l2_norm(tree)

Compute the l2 norm of a pytree of arrays. Useful for weight decay.
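A minimal sketch of what a pytree L2 norm computes, assuming it is the global norm over every leaf (the square root of the sum of all squared entries). The name l2_norm_sketch is hypothetical and this is not the tensorclouds source; in particular, whether l2_norm returns the norm or its square is an assumption to check.

```python
import jax
import jax.numpy as jnp


def l2_norm_sketch(tree):
    """Hypothetical re-implementation: global L2 norm over all leaves of a pytree."""
    leaves = jax.tree_util.tree_leaves(tree)
    # Sum the squared entries of every leaf, then take a single square root.
    return jnp.sqrt(sum(jnp.sum(jnp.square(x)) for x in leaves))


params = {"w": jnp.array([3.0, 0.0]), "b": jnp.array([4.0])}
print(l2_norm_sketch(params))  # 5.0, i.e. sqrt(3**2 + 4**2)
```

For weight decay, a term proportional to this norm (or its square) over the parameter pytree would be added to the training loss.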

tensorclouds.nn.utils.clip_grads(grad_tree, max_norm)

Clip gradients stored as a pytree of arrays to maximum norm max_norm.
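The sketch below shows gradient clipping by global norm, assuming clip_grads rescales all leaves jointly (rather than clipping each leaf separately). The helper names global_norm and clip_grads_sketch and the 1e-6 epsilon are hypothetical choices, not taken from tensorclouds.

```python
import jax
import jax.numpy as jnp


def global_norm(tree):
    # Global L2 norm across all leaves of the pytree.
    return jnp.sqrt(sum(jnp.sum(jnp.square(g)) for g in jax.tree_util.tree_leaves(tree)))


def clip_grads_sketch(grad_tree, max_norm):
    """Hypothetical sketch: rescale every gradient by one shared factor so the global norm is <= max_norm."""
    norm = global_norm(grad_tree)
    # The factor is 1 when the norm is already within the limit; eps guards against dividing by zero.
    scale = jnp.minimum(1.0, max_norm / (norm + 1e-6))
    return jax.tree_util.tree_map(lambda g: g * scale, grad_tree)


grads = {"w": jnp.array([3.0, 0.0]), "b": jnp.array([4.0])}
clipped = clip_grads_sketch(grads, max_norm=1.0)
print(global_norm(clipped))  # approximately 1.0
```

Scaling all leaves by one shared factor preserves the direction of the full gradient. If optax is available, optax.clip_by_global_norm provides the same idea as a gradient transformation.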

tensorclouds.nn.utils.inner_stack(pytrees)

tensorclouds.nn.utils.inner_split(pytree)

class tensorclouds.nn.utils.ModelOutput

datum: moleculib.protein.datum.ProteinDatum
encoder_internals: List[tensorclouds.tensorcloud.TensorCloud]
decoder_internals: List[tensorclouds.tensorcloud.TensorCloud]
atom_perm_loss: jax.Array
diff_loss: List[jax.Array]

tensorclouds.nn.utils.rescale_irreps(irreps: e3nn_jax.Irreps, rescale: float, chunk_factor: int = 0)

tensorclouds.nn.utils.multiscale_irreps(irreps: e3nn_jax.Irreps, depth: int, rescale: float, chunk_factor: int = 0) → List[e3nn_jax.Irreps]

tensorclouds.nn.utils.next_multiple(x: int, factor: int) → int

Return the next multiple of factor relative to x.
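A minimal sketch of rounding an integer up to a multiple, assuming next_multiple returns the smallest multiple of factor that is greater than or equal to x (so an x that is already a multiple is returned unchanged). That boundary behavior is an assumption, and next_multiple_sketch is a hypothetical name.

```python
def next_multiple_sketch(x: int, factor: int) -> int:
    """Hypothetical sketch: smallest multiple of factor that is >= x."""
    # -(-x // factor) is integer ceil(x / factor), avoiding float rounding.
    return -(-x // factor) * factor


print(next_multiple_sketch(10, 4))  # 12
print(next_multiple_sketch(12, 4))  # 12
```

A typical use of this kind of helper is padding a sequence length so that strided downsampling divides it evenly.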

tensorclouds.nn.utils.up_conv_seq_len(size: int, kernel: int, stride: int, mode: str) → int

Output sequence length of an upsampling convolutional layer applied to a sequence of length size.
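The sketch below uses the conventional output-length formulas for a transposed (upsampling) convolution. It assumes mode takes the values "same" and "valid"; the accepted mode strings and the exact formulas used by tensorclouds are assumptions, and up_conv_seq_len_sketch is a hypothetical name.

```python
def up_conv_seq_len_sketch(size: int, kernel: int, stride: int, mode: str) -> int:
    """Hypothetical sketch of a transposed-convolution output length (standard formulas)."""
    if mode == "same":
        # "same" padding: each input position expands to stride outputs.
        return size * stride
    if mode == "valid":
        # No padding: the last kernel window extends past the final strided position.
        return (size - 1) * stride + kernel
    raise ValueError(f"unknown mode: {mode!r}")


print(up_conv_seq_len_sketch(8, kernel=3, stride=2, mode="same"))   # 16
print(up_conv_seq_len_sketch(8, kernel=3, stride=2, mode="valid"))  # 17
```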

tensorclouds.nn.utils.down_conv_seq_len(size: int, kernel: int, stride: int, mode: str) → int

Output sequence length of a downsampling convolutional layer applied to a sequence of length size.
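For the downsampling direction, a sketch using the standard strided-convolution formulas, again assuming "same" and "valid" as the mode values. Both the mode strings and the name down_conv_seq_len_sketch are hypothetical.

```python
def down_conv_seq_len_sketch(size: int, kernel: int, stride: int, mode: str) -> int:
    """Hypothetical sketch of a strided-convolution output length (standard formulas)."""
    if mode == "same":
        # "same" padding: ceil(size / stride), computed with integer arithmetic.
        return -(-size // stride)
    if mode == "valid":
        # No padding: count the kernel windows that fit entirely inside the input.
        return (size - kernel) // stride + 1
    raise ValueError(f"unknown mode: {mode!r}")


print(down_conv_seq_len_sketch(16, kernel=3, stride=2, mode="same"))   # 8
print(down_conv_seq_len_sketch(16, kernel=3, stride=2, mode="valid"))  # 7
```

Under these formulas, "same" mode round-trips cleanly when size is a multiple of stride (16 down to 8, then 8 up to 16), which is one reason to pad lengths to a multiple first.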

tensorclouds.nn.utils.safe_norm(vector: jax.Array, axis: int = -1) → jax.Array

safe_norm(x) = norm(x) if norm(x) != 0 else 1.0
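A minimal sketch of the rule above, together with a plausible safe_normalize built on it. It assumes the norm is Euclidean and keeps the reduced axis (keepdims=True) so the result broadcasts against vector; both are assumptions, and the _sketch names are hypothetical. Replacing the zero before taking the square root keeps gradients finite at an exact zero vector, which a plain jnp.where on the finished norm would not.

```python
import jax
import jax.numpy as jnp


def safe_norm_sketch(vector, axis=-1):
    """Hypothetical sketch: norm(x) where it is nonzero, otherwise 1.0."""
    sq = jnp.sum(jnp.square(vector), axis=axis, keepdims=True)
    # Substitute 1.0 before sqrt: sqrt(1.0) == 1.0, and the gradient stays finite at zero.
    return jnp.sqrt(jnp.where(sq == 0.0, 1.0, sq))


def safe_normalize_sketch(vector):
    # Dividing by safe_norm maps a zero vector to zero instead of NaN.
    return vector / safe_norm_sketch(vector)


v = jnp.array([[3.0, 4.0], [0.0, 0.0]])
print(safe_norm_sketch(v))       # 5.0 for the first row, 1.0 for the zero row
print(safe_normalize_sketch(v))  # [0.6, 0.8] and [0.0, 0.0]
print(jax.grad(lambda x: safe_norm_sketch(x).sum())(jnp.zeros(3)))  # all zeros, no NaN
```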

tensorclouds.nn.utils.safe_normalize(vector: jax.Array) → jax.Array

class tensorclouds.nn.utils.EmbeddingsDataset(path, transform=[])

Bases: moleculib.abstract.dataset.PreProcessedDataset

path