Learnax - JAX-based Training Framework

Learnax is a training framework built on JAX for accelerated machine learning research. It provides a modular approach to training neural networks, with built-in support for distributed training, checkpointing, and experiment tracking.

Key Features

  • JAX-based acceleration: Builds on JAX’s function transformations and GPU/TPU support

  • Distributed training: Automatic parallel training across multiple devices using jax.pmap

  • Robust checkpointing: Flexible checkpoint saving and resuming with automatic management

  • Experiment tracking: Seamless integration with Weights & Biases for experiment logging

  • Modular architecture: Clean abstractions for models, losses, training, and evaluation

Core Components

Trainer

The Trainer class is the central component of Learnax. It orchestrates the training process, handling:

  • Model initialization

  • Training loop execution with gradient updates

  • Metrics tracking

  • Checkpointing

  • Distributed training across multiple devices

Example usage:

# my_model, my_loss_pipe, train_data, and wandb_run are assumed to be defined earlier.
trainer = Trainer(
    model=my_model,
    learning_rate=1e-4,
    losses=my_loss_pipe,
    seed=42,
    train_dataset=train_data,
    num_epochs=100,
    batch_size=32,
    num_workers=4,
    run=wandb_run
)

trainer.init()
trainer.train()
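
To make the "gradient updates" and "distributed training" responsibilities above more concrete, here is a minimal sketch of a data-parallel training step written directly against jax.pmap. It is not Learnax's actual implementation: the linear model, the names loss_fn and train_step, the axis name "devices", and the fixed SGD step size of 0.1 are all assumptions chosen for illustration.

```python
import functools

import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    # Mean squared error of a toy linear model (illustrative only).
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


@functools.partial(jax.pmap, axis_name="devices")
def train_step(params, x, y):
    # Each device computes the loss and gradients on its own shard of the batch.
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # Average over devices so every replica applies the same update.
    grads = jax.lax.pmean(grads, axis_name="devices")
    loss = jax.lax.pmean(loss, axis_name="devices")
    # Plain SGD with an assumed step size of 0.1.
    params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
    return params, loss


n_dev = jax.local_device_count()
key = jax.random.PRNGKey(42)
# Leading axis is the device axis: (devices, per-device batch, features).
x = jax.random.normal(key, (n_dev, 8, 3))
y = x @ jnp.array([1.0, -2.0, 0.5])

params = {"w": jnp.zeros(3), "b": jnp.zeros(())}
# Replicate the parameters by adding a leading device axis.
params = jax.tree_util.tree_map(lambda p: jnp.stack([p] * n_dev), params)

losses = []
for _ in range(50):
    params, loss = train_step(params, x, y)
    losses.append(float(loss[0]))
```

The same code runs on a single device, where the device axis has length 1 and pmean is a no-op. A framework such as Learnax would presumably hide the replication and sharding steps behind the Trainer.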

Checkpointer

The Checkpointer class manages model checkpoints, providing:

  • Automatic saving of periodic checkpoints

  • Retaining the best checkpoints based on metrics

  • Limiting the number of checkpoints to save disk space

  • Simple APIs for loading and resuming training

Example usage:

checkpointer = Checkpointer(
    registry_path="./experiments",
    run_id="experiment_1",
    save_every=100,
    checkpoint_every=1000,
    keep_best=True,
    metric_name="val/loss"
)

# Save checkpoint (train_state is assumed to be the current training state)
checkpointer.maybe_save_checkpoints(train_state, step=1000, metrics={"val/loss": 0.1})

# Load checkpoint
checkpoint_data = checkpointer.load_latest()
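
The retention behaviour described above (periodic saves, keeping the best checkpoint, and capping the number kept on disk) can be sketched in plain Python. This is only an illustration of the idea, not the Checkpointer's real logic: the RetentionPolicy class, its max_to_keep and lower_is_better fields, and the update method are hypothetical names, and the sketch tracks step numbers only rather than writing files.

```python
from dataclasses import dataclass, field
from typing import List, Optional, Tuple


@dataclass
class RetentionPolicy:
    """Hypothetical helper that decides which checkpoint steps to keep."""

    checkpoint_every: int = 1000
    max_to_keep: int = 3  # assumed cap on periodic checkpoints (must be >= 1)
    lower_is_better: bool = True
    periodic: List[int] = field(default_factory=list)  # steps, oldest first
    best: Optional[Tuple[float, int]] = None  # (metric value, step)

    def update(self, step: int, metric: float) -> List[int]:
        """Record a step and return the steps that should remain on disk."""
        if step % self.checkpoint_every == 0:
            self.periodic.append(step)
            # Drop the oldest periodic checkpoints beyond the cap.
            self.periodic = self.periodic[-self.max_to_keep:]
        if self.best is None:
            improved = True
        elif self.lower_is_better:
            improved = metric < self.best[0]
        else:
            improved = metric > self.best[0]
        if improved:
            self.best = (metric, step)
        # The best checkpoint survives even after it leaves the periodic window.
        return sorted(set(self.periodic) | {self.best[1]})


policy = RetentionPolicy(checkpoint_every=1000, max_to_keep=2)
kept = []
for step, val_loss in [(1000, 0.9), (2000, 0.4), (3000, 0.6), (4000, 0.5)]:
    kept = policy.update(step, val_loss)
```

In this run the checkpoint from step 2000 has the lowest validation loss, so it is retained alongside the two most recent periodic checkpoints.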

Loss System

Learnax provides a modular loss system with:

  • LossFunction: Base class for individual loss components

  • LossPipe: Combines multiple loss functions with weights and scheduling

Example:

loss_pipe = LossPipe([
    MSELoss(weight=1.0),
    RegularizationLoss(weight=0.01, scheduler=lambda step: min(1.0, step/1000))
])
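
One plausible reading of the example above is that each loss term contributes weight × scheduler(step) × value to the total, with unscheduled terms using a constant factor of 1. The sketch below illustrates that reading in plain Python. It does not use the Learnax classes, and the functions mse, l2_penalty, warmup, and combined_loss are hypothetical stand-ins; how LossPipe actually applies its schedulers is an assumption here.

```python
def mse(pred, target):
    # Mean squared error over two equal-length sequences of floats.
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def l2_penalty(params):
    # Sum of squared parameters, standing in for a regularization term.
    return sum(w * w for w in params)


def warmup(step, warmup_steps=1000):
    # Linear ramp from 0 to 1, same shape as the lambda in the example above.
    return min(1.0, step / warmup_steps)


def combined_loss(pred, target, params, step):
    # Each entry: (weight, optional scheduler, loss value).
    terms = [
        (1.0, None, mse(pred, target)),
        (0.01, warmup, l2_penalty(params)),
    ]
    total = 0.0
    for weight, scheduler, value in terms:
        # Assumption: the scheduler rescales the static weight at each step.
        scale = scheduler(step) if scheduler is not None else 1.0
        total += weight * scale * value
    return total


early = combined_loss([1.0, 2.0], [1.0, 4.0], [3.0, 4.0], step=0)
mid = combined_loss([1.0, 2.0], [1.0, 4.0], [3.0, 4.0], step=500)
late = combined_loss([1.0, 2.0], [1.0, 4.0], [3.0, 4.0], step=2000)
```

Under this reading the regularization term is absent at step 0, contributes half its weight at step 500, and reaches its full weight from step 1000 onward.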

Experiment Registry

The Registry and Run classes help manage experiments:

  • Track configurations, metrics, and results

  • Restore previous runs for analysis or continued training

  • Integrate with Weights & Biases for enhanced visualization

Installation

Install Learnax via pip:

pip install learnax

Requirements:

  • JAX >= 0.3.0

  • Flax >= 0.5.0

  • Optax >= 0.1.0

  • Weights & Biases (optional)

Getting Started

For a quick start, see the following examples:

  • Basic model training

  • Distributed training

  • Custom loss functions

  • Checkpoint management

  • Experiment tracking

API Reference

For detailed API documentation, see the following sections: