Training
rydberggpt.training
callbacks
module_info_callback
ModelInfoCallback
Bases: Callback
A custom PyTorch Lightning callback that logs model information at the start of training.
This callback extracts the model's structure, total parameter count, and trainable parameter count when training begins, and saves them as a YAML file in the logger's log directory.
Source code in src/rydberggpt/training/callbacks/module_info_callback.py
on_train_start(trainer, pl_module) -> None
Run the callback at the beginning of training.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
trainer | Trainer | The PyTorch Lightning trainer instance. | required |
pl_module | LightningModule | The PyTorch Lightning module instance. | required |
Source code in src/rydberggpt/training/callbacks/module_info_callback.py
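The bookkeeping behind this callback can be sketched as two plain helpers. This is a minimal sketch, not the package's implementation: the helper names `collect_model_info` and `write_model_info`, the dictionary keys, and the file name `model_info.yaml` are hypothetical, and writing YAML assumes PyYAML is installed.

```python
import os

from torch import nn


def collect_model_info(model: nn.Module) -> dict:
    """Summarize a model's structure and parameter counts (hypothetical helper)."""
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return {
        "model_structure": str(model),  # the module tree as printed by PyTorch
        "total_parameters": total,
        "trainable_parameters": trainable,
    }


def write_model_info(info: dict, log_dir: str, filename: str = "model_info.yaml") -> str:
    """Save the summary as YAML in `log_dir` and return the file path (hypothetical helper)."""
    import yaml  # PyYAML, imported here so the counting helper works without it

    os.makedirs(log_dir, exist_ok=True)
    path = os.path.join(log_dir, filename)
    with open(path, "w") as f:
        yaml.safe_dump(info, f)
    return path
```

Inside `on_train_start`, a callback along these lines would call `collect_model_info(pl_module)` and pass the result, together with the logger's log directory (for example `trainer.logger.log_dir` when a `TensorBoardLogger` is attached), to `write_model_info`.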
logger
setup_logger(log_path)
Set up the logger to write logs to a file and the console.
Source code in src/rydberggpt/training/logger.py
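A file-plus-console logger of this kind can be built with the standard `logging` module alone. The sketch below is an assumption about how such a helper might look, not the package's code; the logger name, level, and message format are illustrative choices.

```python
import logging
import sys


def setup_logger(log_path: str, name: str = "rydberggpt") -> logging.Logger:
    """Return a logger that writes to `log_path` and to stdout (illustrative sketch)."""
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)

    # Remove and close handlers from any earlier call so lines are not duplicated
    for handler in list(logger.handlers):
        logger.removeHandler(handler)
        handler.close()

    formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
    for handler in (logging.FileHandler(log_path), logging.StreamHandler(sys.stdout)):
        handler.setFormatter(formatter)
        logger.addHandler(handler)
    return logger
```

`logging.getLogger` returns the same object for the same name, which is why the sketch clears old handlers before attaching new ones.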
loss
NLLLoss
Bases: LightningModule
This class implements the negative log-likelihood (NLL) loss as a PyTorch Lightning module.
The NLL loss measures how well a model's predicted probability distribution over classes accounts for the true class labels, which makes it a standard training objective for multi-class classification.
It is computed as the negative of the log-probability that the model assigns to the true class label.
Methods:
Name | Description |
---|---|
forward | Computes the NLL loss based on the conditional log probabilities and the target values. |
Examples:
Source code in src/rydberggpt/training/loss.py
forward(cond_log_probs: Tensor, tgt: Tensor) -> Tensor
Computes the NLL loss based on the conditional log probabilities and the target values.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
cond_log_probs | Tensor | The conditional log probabilities predicted by the model. | required |
tgt | Tensor | The target values. | required |
Returns:
Type | Description |
---|---|
Tensor | The computed NLL loss. |
Source code in src/rydberggpt/training/loss.py
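For an autoregressive model the conditional log-probabilities factorize the sequence likelihood, log p(x | cond) = sum_i log p(x_i | x_<i, cond), so with one-hot targets the loss reduces to a masked sum. The sketch below assumes `cond_log_probs` and `tgt` both have shape `(batch, seq_len, num_outcomes)` with `tgt` one-hot, and that the loss is summed over sites and averaged over the batch; the package's exact normalization may differ. `nll_loss` is a hypothetical standalone name.

```python
import torch


def nll_loss(cond_log_probs: torch.Tensor, tgt: torch.Tensor) -> torch.Tensor:
    """Negative log-likelihood of one-hot targets, summed over sites, averaged over the batch.

    cond_log_probs: (batch, seq_len, num_outcomes) log p(x_i | x_<i, cond)
    tgt:            (batch, seq_len, num_outcomes) one-hot observed outcomes
    """
    # Multiplying by the one-hot target keeps only the log-prob of the observed outcome
    site_log_probs = (cond_log_probs * tgt).sum(dim=-1)  # (batch, seq_len)
    # Summing over sites gives log p(x | cond) for each sequence
    seq_log_likelihood = site_log_probs.sum(dim=-1)  # (batch,)
    return -seq_log_likelihood.mean()
```

Because `torch.nn.functional.nll_loss` averages over every site rather than summing per sequence, this value equals `seq_len` times `F.nll_loss(cond_log_probs.reshape(-1, num_outcomes), labels.reshape(-1))` when the targets are given as integer class labels.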
train
trainer
RydbergGPTTrainer
Bases: LightningModule
A custom PyTorch Lightning module for training a Rydberg GPT model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model | Module | The model to be trained. | required |
config | dataclass | A dataclass containing the model's configuration parameters. | required |
logger | TensorBoardLogger | A TensorBoard logger instance for logging training progress. | None |
example_input_array | tensor | An example input tensor used for generating the model summary. | None |
Source code in src/rydberggpt/training/trainer.py
configure_optimizers() -> Dict[str, Union[optim.Optimizer, Dict]]
Configures the optimizer and learning rate scheduler for the RydbergGPTTrainer.
Returns:
Type | Description |
---|---|
Dict[str, Union[Optimizer, Dict]] | A dictionary containing the optimizer and lr_scheduler configurations. |
Source code in src/rydberggpt/training/trainer.py
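Lightning accepts a dictionary with an `"optimizer"` key and an optional `"lr_scheduler"` entry, which can itself be a dictionary with `"scheduler"`, `"interval"`, and `"frequency"` keys. The sketch below builds such a dictionary. AdamW, cosine warm restarts, and the default values are illustrative assumptions rather than what `RydbergGPTTrainer` necessarily uses, and `build_optimizer_config` is a hypothetical name.

```python
from torch import nn, optim


def build_optimizer_config(model: nn.Module, learning_rate: float = 1e-3, t_0: int = 1000) -> dict:
    """Return a Lightning-style optimizer/scheduler dictionary (illustrative choices)."""
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
    # Cosine schedule that restarts every t_0 scheduler steps
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=t_0)
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": scheduler,
            "interval": "step",  # advance the scheduler after every optimizer step
            "frequency": 1,
        },
    }
```

With `"interval": "step"`, Lightning advances the scheduler after every optimizer step instead of once per epoch, which suits schedules whose period is measured in steps.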
forward(m_onehot: torch.Tensor, cond: torch.Tensor) -> torch.Tensor
Perform a forward pass through the model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
m_onehot | Tensor | One-hot encoded measurements tensor. | required |
cond | Tensor | Conditioning tensor (the prompt). | required |
Returns:
Type | Description |
---|---|
Tensor | Conditional log probabilities tensor. |
Source code in src/rydberggpt/training/trainer.py
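The shape contract of a forward pass like this one can be illustrated with a toy conditional model. Everything here is hypothetical: `ToyConditionalModel` is not RydbergGPT's architecture, `cond` is assumed to be a plain `(batch, cond_dim)` vector, and each site sees only the previous outcome. What it is meant to illustrate is the assumed interface: one-hot measurements of shape `(batch, seq_len, 2)` in, conditional log-probabilities of the same shape out, normalized over the last dimension.

```python
import torch
from torch import nn


class ToyConditionalModel(nn.Module):
    """Hypothetical stand-in for the wrapped model; not RydbergGPT's architecture."""

    def __init__(self, cond_dim: int, hidden: int = 16, num_outcomes: int = 2):
        super().__init__()
        self.embed = nn.Linear(num_outcomes, hidden)
        self.cond_proj = nn.Linear(cond_dim, hidden)
        self.head = nn.Linear(hidden, num_outcomes)

    def forward(self, m_onehot: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        # Shift right by one site so position i only sees outcome i-1 (zeros at i=0)
        shifted = torch.cat([torch.zeros_like(m_onehot[:, :1]), m_onehot[:, :-1]], dim=1)
        # Broadcast the conditioning vector across all sites
        h = torch.tanh(self.embed(shifted) + self.cond_proj(cond).unsqueeze(1))
        # Normalize over outcomes to get log p(x_i | x_{i-1}, cond)
        return torch.log_softmax(self.head(h), dim=-1)
```

Shifting the inputs right by one site keeps the model autoregressive: the prediction at site i never depends on the outcome at site i or at any later site.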
training_step(batch: torch.Tensor, batch_idx: int) -> torch.Tensor
Perform a single training step.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
batch | Batch | A batch of data during training. | required |
batch_idx | int | The index of the current batch. | required |
Returns:
Type | Description |
---|---|
Tensor | The training loss for the current batch. |
Source code in src/rydberggpt/training/trainer.py
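Inside Lightning, `training_step` only needs to return the loss; the trainer then calls `zero_grad`, `backward`, and `step` itself. The manual equivalent below makes those hidden steps explicit. It is a sketch under assumptions: each batch is taken to be an `(m_onehot, cond)` tuple, the loss is the one-hot NLL summed over sites and averaged over the batch, and `PerSiteModel` and `manual_training_step` are hypothetical names, the model being a deliberately small, non-autoregressive stand-in.

```python
import torch
from torch import nn


class PerSiteModel(nn.Module):
    """Hypothetical stand-in: predicts every site from `cond` alone (not autoregressive)."""

    def __init__(self, cond_dim: int, seq_len: int, num_outcomes: int = 2):
        super().__init__()
        self.seq_len, self.num_outcomes = seq_len, num_outcomes
        self.linear = nn.Linear(cond_dim, seq_len * num_outcomes)

    def forward(self, m_onehot: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        logits = self.linear(cond).view(-1, self.seq_len, self.num_outcomes)
        return torch.log_softmax(logits, dim=-1)  # (batch, seq_len, num_outcomes)


def manual_training_step(model: nn.Module, optimizer: torch.optim.Optimizer,
                         batch: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
    """One optimization step on an (m_onehot, cond) batch; returns the pre-step loss."""
    m_onehot, cond = batch
    cond_log_probs = model(m_onehot, cond)
    # One-hot NLL: sum log-probs of observed outcomes over sites, average over batch
    loss = -(cond_log_probs * m_onehot).sum(dim=(-2, -1)).mean()
    optimizer.zero_grad()
    loss.backward()   # Lightning performs this automatically for the returned loss
    optimizer.step()  # as well as this optimizer update
    return loss.detach()
```

In the Lightning version, the body would stop after computing `loss` and simply return it.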
utils
set_example_input_array(train_loader: DataLoader) -> Tuple[Any, Any]
Get an example input array from the train loader.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
train_loader | DataLoader | The DataLoader instance for the training data. | required |
Returns:
Type | Description |
---|---|
Tuple[Any, Any] | A tuple containing m_onehot and graph from the example batch. |
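Pulling one batch from the loader is enough to obtain an example input. The sketch assumes each batch unpacks as a two-element `(m_onehot, cond)` pair, as a `TensorDataset` produces; if batches are objects with attributes such as `m_onehot` and `graph`, the unpacking line would change accordingly. `first_batch_inputs` is a hypothetical name.

```python
from typing import Any, Tuple

import torch
from torch.utils.data import DataLoader, TensorDataset


def first_batch_inputs(train_loader: DataLoader) -> Tuple[Any, Any]:
    """Return the two model inputs from the loader's first batch (hypothetical helper)."""
    # Assumes the batch unpacks as (m_onehot, cond); adjust for other batch types
    m_onehot, cond = next(iter(train_loader))
    return m_onehot, cond


# Usage with a synthetic dataset: 10 samples, 4 sites, 3 conditioning features
dataset = TensorDataset(torch.zeros(10, 4, 2), torch.ones(10, 3))
m_onehot, cond = first_batch_inputs(DataLoader(dataset, batch_size=5))
```

With `shuffle=True` each call may return a different batch, which is harmless when the batch is only used to build a model summary.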