
Training

rydberggpt.training

callbacks

module_info_callback

ModelInfoCallback

Bases: Callback

A custom PyTorch Lightning callback that logs model information at the start of training.

This callback extracts and logs information about the model's structure, total parameters, and total trainable parameters at the beginning of the training process. The information is saved as a YAML file in the logger's log directory.
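
A minimal usage sketch (the import path follows the source location shown below; the Trainer needs a logger that exposes log_dir, as the default TensorBoardLogger does):

import pytorch_lightning as pl

from rydberggpt.training.callbacks.module_info_callback import ModelInfoCallback

# At train start the callback writes model_info.yaml into trainer.logger.log_dir.
trainer = pl.Trainer(callbacks=[ModelInfoCallback()], max_epochs=10)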

Source code in src/rydberggpt/training/callbacks/module_info_callback.py
class ModelInfoCallback(Callback):
    """
    A custom PyTorch Lightning callback that logs model information at the start of training.

    This callback extracts and logs information about the model's structure, total parameters, and
    total trainable parameters at the beginning of the training process. The information is saved
    as a YAML file in the logger's log directory.
    """

    def on_train_start(self, trainer, pl_module) -> None:
        """
        Run the callback at the beginning of training.

        Args:
            trainer (pytorch_lightning.Trainer): The PyTorch Lightning trainer instance.
            pl_module (pytorch_lightning.LightningModule): The PyTorch Lightning module instance.
        """
        # This will run at the beginning of training
        log_path = trainer.logger.log_dir

        summary = ModelSummary(pl_module, max_depth=1)
        total_parameters = summary.total_parameters
        total_trainable_parameters = summary.trainable_parameters

        summary_dict = extract_model_info(pl_module.model)
        summary_dict["total_parameters"] = total_parameters
        summary_dict["total_trainable_parameters"] = total_trainable_parameters

        # Save the summary dictionary to a YAML file
        with open(f"{log_path}/model_info.yaml", "w") as file:
            yaml.dump(summary_dict, file)
on_train_start(trainer, pl_module) -> None

Run the callback at the beginning of training.

Parameters:

    trainer (Trainer): The PyTorch Lightning trainer instance. Required.
    pl_module (LightningModule): The PyTorch Lightning module instance. Required.

Source code in src/rydberggpt/training/callbacks/module_info_callback.py
def on_train_start(self, trainer, pl_module) -> None:
    """
    Run the callback at the beginning of training.

    Args:
        trainer (pytorch_lightning.Trainer): The PyTorch Lightning trainer instance.
        pl_module (pytorch_lightning.LightningModule): The PyTorch Lightning module instance.
    """
    # This will run at the beginning of training
    log_path = trainer.logger.log_dir

    summary = ModelSummary(pl_module, max_depth=1)
    total_parameters = summary.total_parameters
    total_trainable_parameters = summary.trainable_parameters

    summary_dict = extract_model_info(pl_module.model)
    summary_dict["total_parameters"] = total_parameters
    summary_dict["total_trainable_parameters"] = total_trainable_parameters

    # Save the summary dictionary to a YAML file
    with open(f"{log_path}/model_info.yaml", "w") as file:
        yaml.dump(summary_dict, file)

logger

setup_logger(log_path)

Set up the logger to write logs to a file and the console.

Source code in src/rydberggpt/training/logger.py
def setup_logger(log_path):
    """
    Set up the logger to write logs to a file and the console.
    """
    # Ensure the log_path exists
    if not os.path.exists(log_path):
        os.makedirs(log_path)

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # Console Handler
    ch = logging.StreamHandler()
    ch.setLevel(logging.INFO)
    formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
    ch.setFormatter(formatter)
    logger.addHandler(ch)

    # File Handler
    fh = logging.FileHandler(os.path.join(log_path, "training.log"))
    fh.setLevel(logging.INFO)
    fh.setFormatter(formatter)
    logger.addHandler(fh)

    return logger
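
A usage sketch (the log directory name is illustrative):

from rydberggpt.training.logger import setup_logger

logger = setup_logger("logs/run")  # creates logs/run/ if it does not exist
logger.info("starting training")   # goes to the console and to logs/run/training.log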

loss

NLLLoss

Bases: LightningModule

This class implements the Negative Log Likelihood (NLL) loss function as a PyTorch Lightning module.

The NLL loss measures the performance of a classification model where the prediction input is a probability distribution over classes. It is useful in training models for multi-class classification problems.

The loss is calculated by taking the negative log of the probabilities predicted by the model for the true class labels.

Methods:

    forward: Computes the NLL loss based on the conditional log probabilities and the target values.

Examples:

>>> nll_loss = NLLLoss()
>>> loss = nll_loss(cond_log_probs, tgt)
Source code in src/rydberggpt/training/loss.py
class NLLLoss(pl.LightningModule):
    """
    This class implements the Negative Log Likelihood (NLL) loss function as a PyTorch Lightning module.

    The NLL loss measures the performance of a classification model where the prediction input is a probability
    distribution over classes. It is useful in training models for multi-class classification problems.

    The loss is calculated by taking the negative log of the probabilities predicted by the model for the true class labels.

    Methods:
        forward:
            Computes the NLL loss based on the conditional log probabilities and the target values.

    Examples:
        >>> nll_loss = NLLLoss()
        >>> loss = nll_loss(cond_log_probs, tgt)
    """

    def __init__(self):
        super(NLLLoss, self).__init__()

    def forward(self, cond_log_probs: Tensor, tgt: Tensor) -> Tensor:
        """
        Computes the NLL loss based on the conditional log probabilities and the target values.

        Args:
            cond_log_probs (Tensor): The conditional log probabilities predicted by the model.
            tgt (Tensor): The target values.

        Returns:
            (Tensor): The computed NLL loss.
        """
        # Count non-padded atoms: padded sites appear as all-zero rows in tgt.
        num_atoms = tgt.shape[-2] - (tgt == 0.0).all(-1).sum(-1)
        # Select the log probability of the true class at each site and sum over sites.
        log_probs = (cond_log_probs * tgt).sum(dim=(-2, -1))
        # Average the per-atom negative log likelihood over the batch.
        loss = -torch.mean(log_probs / num_atoms)
        return loss
forward(cond_log_probs: Tensor, tgt: Tensor) -> Tensor

Computes the NLL loss based on the conditional log probabilities and the target values.

Parameters:

    cond_log_probs (Tensor): The conditional log probabilities predicted by the model. Required.
    tgt (Tensor): The target values. Required.

Returns:

    Tensor: The computed NLL loss.

Source code in src/rydberggpt/training/loss.py
def forward(self, cond_log_probs: Tensor, tgt: Tensor) -> Tensor:
    """
    Computes the NLL loss based on the conditional log probabilities and the target values.

    Args:
        cond_log_probs (Tensor): The conditional log probabilities predicted by the model.
        tgt (Tensor): The target values.

    Returns:
        (Tensor): The computed NLL loss.
    """
    # Count non-padded atoms: padded sites appear as all-zero rows in tgt.
    num_atoms = tgt.shape[-2] - (tgt == 0.0).all(-1).sum(-1)
    # Select the log probability of the true class at each site and sum over sites.
    log_probs = (cond_log_probs * tgt).sum(dim=(-2, -1))
    # Average the per-atom negative log likelihood over the batch.
    loss = -torch.mean(log_probs / num_atoms)
    return loss
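
A small worked example (shapes are illustrative): a batch of 2 sequences with 3 sites and 2 classes, where the second sequence pads its last site with an all-zero row that the num_atoms count excludes from the average:

import torch

from rydberggpt.training.loss import NLLLoss

# Batch of 2 sequences, 3 sites, 2 classes; the last site of the second
# sequence is an all-zero padding row.
tgt = torch.tensor(
    [
        [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]],
        [[0.0, 1.0], [1.0, 0.0], [0.0, 0.0]],
    ]
)
# Matching normalized log probabilities of shape (batch, sites, classes).
cond_log_probs = torch.log_softmax(torch.randn(2, 3, 2), dim=-1)

loss = NLLLoss()(cond_log_probs, tgt)  # scalar: mean per-atom NLL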

train

trainer

RydbergGPTTrainer

Bases: LightningModule

A custom PyTorch Lightning module for training a Rydberg GPT model.

Parameters:

    model (Module): The model to be trained. Required.
    config (dataclass): A dataclass containing the model's configuration parameters. Required.
    logger (TensorBoardLogger): A TensorBoard logger instance for logging training progress. Default: None.
    example_input_array (Tensor): An example input tensor used for generating the model summary. Default: None.

Source code in src/rydberggpt/training/trainer.py
class RydbergGPTTrainer(pl.LightningModule):
    """
    A custom PyTorch Lightning module for training a Rydberg GPT model.

    Args:
        model (nn.Module): The model to be trained.
        config (dataclass): A dataclass containing the model's configuration parameters.
        logger (TensorBoardLogger): A TensorBoard logger instance for logging training progress.
        example_input_array (torch.Tensor, optional): An example input tensor used for
            generating the model summary.
    """

    def __init__(
        self,
        model: nn.Module,
        config: dataclass,
        logger: TensorBoardLogger = None,
        example_input_array: torch.Tensor = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.save_hyperparameters(asdict(config))
        self.model = model
        self.criterion = getattr(loss, self.config.criterion)()
        self.example_input_array = example_input_array

    def forward(self, m_onehot: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        """
        Perform a forward pass through the model.

        Args:
            m_onehot (torch.Tensor): One-hot encoded measurements tensor.
            cond (torch.Tensor): Conditioning tensor. # TODO prompt

        Returns:
            (torch.Tensor): Conditional log probabilities tensor.
        """
        out = self.model.forward(m_onehot, cond)
        cond_log_probs = self.model.generator(out)
        return cond_log_probs

    def training_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor:
        """
        Perform a single training step.

        Args:
            batch (pl.Batch): A batch of data during training.
            batch_idx (int): The index of the current batch.

        Returns:
            (torch.Tensor): The training loss for the current batch.
        """
        # Shift the inputs for autoregressive training (teacher forcing):
        # each site should be predicted without seeing its own measurement.
        m_shifted_onehot = shift_inputs(batch.m_onehot)

        cond_log_probs = self.forward(m_shifted_onehot, batch.graph)
        loss = self.criterion(cond_log_probs, batch.m_onehot)
        self.log("train_loss", loss, sync_dist=True)
        return loss

    def configure_optimizers(self) -> Dict[str, Union[optim.Optimizer, Dict]]:
        """
        Configures the optimizer and learning rate scheduler for the RydbergGPTTrainer.

        Returns:
            (Dict[str, Union[optim.Optimizer, Dict]]): A dictionary containing the optimizer and lr_scheduler configurations.
        """
        optimizer_class = getattr(optim, self.config.optimizer)
        optimizer = optimizer_class(
            self.model.parameters(), lr=self.config.learning_rate
        )

        # Add learning rate scheduler
        scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer,
            T_0=self.config.t_initial,  # initial number of epochs in a period
            T_mult=self.config.t_mult,  # factor to increase the period length after each restart
            eta_min=self.config.eta_min,  # minimum learning rate
        )

        # Return both the optimizer and the scheduler
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "epoch",
                "monitor": "train_loss",
            },
        }
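
A construction sketch. The dataclass below is hypothetical: it carries only the fields this module actually reads (criterion, optimizer, learning_rate, t_initial, t_mult, eta_min); the project's real config class may differ, and model and train_loader are assumed to be built elsewhere:

from dataclasses import dataclass

import pytorch_lightning as pl

from rydberggpt.training.trainer import RydbergGPTTrainer


@dataclass
class TrainConfig:  # hypothetical name and defaults
    criterion: str = "NLLLoss"   # resolved via getattr on rydberggpt's loss module
    optimizer: str = "AdamW"     # resolved via getattr on torch.optim
    learning_rate: float = 1e-3
    t_initial: int = 10          # T_0 for CosineAnnealingWarmRestarts
    t_mult: int = 2              # T_mult for CosineAnnealingWarmRestarts
    eta_min: float = 1e-6        # scheduler's minimum learning rate


rydberg_gpt_trainer = RydbergGPTTrainer(model, TrainConfig())  # model built elsewhere
pl.Trainer(max_epochs=100).fit(rydberg_gpt_trainer, train_loader)
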
configure_optimizers() -> Dict[str, Union[optim.Optimizer, Dict]]

Configures the optimizer and learning rate scheduler for the RydbergGPTTrainer.

Returns:

    Dict[str, Union[Optimizer, Dict]]: A dictionary containing the optimizer and lr_scheduler configurations.

Source code in src/rydberggpt/training/trainer.py
def configure_optimizers(self) -> Dict[str, Union[optim.Optimizer, Dict]]:
    """
    Configures the optimizer and learning rate scheduler for the RydbergGPTTrainer.

    Returns:
        (Dict[str, Union[optim.Optimizer, Dict]]): A dictionary containing the optimizer and lr_scheduler configurations.
    """
    optimizer_class = getattr(optim, self.config.optimizer)
    optimizer = optimizer_class(
        self.model.parameters(), lr=self.config.learning_rate
    )

    # Add learning rate scheduler
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer,
        T_0=self.config.t_initial,  # initial number of epochs in a period
        T_mult=self.config.t_mult,  # factor to increase the period length after each restart
        eta_min=self.config.eta_min,  # minimum learning rate
    )

    # Return both the optimizer and the scheduler
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": scheduler,
            "interval": "epoch",
            "monitor": "train_loss",
        },
    }
forward(m_onehot: torch.Tensor, cond: torch.Tensor) -> torch.Tensor

Perform a forward pass through the model.

Parameters:

    m_onehot (Tensor): One-hot encoded measurements tensor. Required.
    cond (Tensor): Conditioning tensor. Required.

Returns:

    Tensor: Conditional log probabilities tensor.

Source code in src/rydberggpt/training/trainer.py
def forward(self, m_onehot: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
    """
    Perform a forward pass through the model.

    Args:
        m_onehot (torch.Tensor): One-hot encoded measurements tensor.
        cond (torch.Tensor): Conditioning tensor. # TODO prompt

    Returns:
        (torch.Tensor): Conditional log probabilities tensor.
    """
    out = self.model.forward(m_onehot, cond)
    cond_log_probs = self.model.generator(out)
    return cond_log_probs
training_step(batch: torch.Tensor, batch_idx: int) -> torch.Tensor

Perform a single training step.

Parameters:

    batch (Batch): A batch of data during training. Required.
    batch_idx (int): The index of the current batch. Required.

Returns:

    Tensor: The training loss for the current batch.

Source code in src/rydberggpt/training/trainer.py
def training_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor:
    """
    Perform a single training step.

    Args:
        batch (pl.Batch): A batch of data during training.
        batch_idx (int): The index of the current batch.

    Returns:
        (torch.Tensor): The training loss for the current batch.
    """
    # Shift the inputs for autoregressive training (teacher forcing):
    # each site should be predicted without seeing its own measurement.
    m_shifted_onehot = shift_inputs(batch.m_onehot)

    cond_log_probs = self.forward(m_shifted_onehot, batch.graph)
    loss = self.criterion(cond_log_probs, batch.m_onehot)
    self.log("train_loss", loss, sync_dist=True)
    return loss
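
The shift is what makes training autoregressive: each site is predicted from the measurements of preceding sites only. The real shift_inputs lives elsewhere in the project and may differ; a hypothetical stand-in illustrating the intended behavior:

import torch


def shift_inputs_sketch(m_onehot: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for the project's shift_inputs: drop the last site
    # and prepend an all-zero "start" row, so position i holds site i - 1.
    batch_size, _, num_classes = m_onehot.shape
    start = torch.zeros(batch_size, 1, num_classes, dtype=m_onehot.dtype)
    return torch.cat([start, m_onehot[:, :-1, :]], dim=1)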

utils

set_example_input_array(train_loader: DataLoader) -> Tuple[Any, Any]

Get an example input array from the train loader.

Parameters:

    train_loader (DataLoader): The DataLoader instance for the training data. Required.

Returns:

    Tuple[Any, Any]: A tuple containing m_onehot and graph from the example batch.

Source code in src/rydberggpt/training/utils.py
def set_example_input_array(train_loader: DataLoader) -> Tuple[Any, Any]:
    """
    Get an example input array from the train loader.

    Args:
        train_loader (DataLoader): The DataLoader instance for the training data.

    Returns:
        (Tuple[Any, Any]): A tuple containing m_onehot and graph from the example batch.
    """
    logging.info("Setting example input array...")
    example_batch = next(iter(train_loader))
    return example_batch.m_onehot, example_batch.graph
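
A usage sketch (train_loader, model, and config are assumed to be built elsewhere; see RydbergGPTTrainer above):

from rydberggpt.training.utils import set_example_input_array

example_input_array = set_example_input_array(train_loader)
# Passing the (m_onehot, graph) tuple lets Lightning feed both arguments to
# forward(m_onehot, cond) when generating the model summary.
rydberg_gpt_trainer = RydbergGPTTrainer(
    model, config, example_input_array=example_input_array
)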