Utilities

rydberggpt.utils

create_config_from_yaml(yaml_content: Dict) -> dataclass

Create a dataclass config object from the given YAML content.

Parameters:

    yaml_content (Dict): A dictionary containing the YAML content. Required.

Returns:

    dataclass: A dataclass object representing the config.

Source code in src/rydberggpt/utils.py
def create_config_from_yaml(yaml_content: Dict) -> dataclass:
    """
    Create a dataclass config object from the given YAML content.

    Args:
        yaml_content (Dict): A dictionary containing the YAML content.

    Returns:
        (dataclass): A dataclass object representing the config.
    """
    flattened_config = flatten_yaml(yaml_content)
    Config = create_dataclass_from_dict("Config", flattened_config)
    return Config(**flattened_config)
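
A minimal usage sketch; the nested dictionary below is hypothetical:

yaml_content = {"model": {"num_layers": 4}, "training": {"learning_rate": 0.001}}
config = create_config_from_yaml(yaml_content)
print(config.num_layers, config.learning_rate)  # 4 0.001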

create_dataclass_from_dict(name: str, data: Dict[str, Any]) -> Type

Create a dataclass from a dictionary.

Parameters:

    name (str): The name of the dataclass. Required.
    data (Dict[str, Any]): A dictionary containing the dataclass fields and their values. Required.

Returns:

    Type: A new dataclass with the specified name and fields.

Source code in src/rydberggpt/utils.py
def create_dataclass_from_dict(name: str, data: Dict[str, Any]) -> Type:
    """
    Create a dataclass from a dictionary.

    Args:
        name (str): The name of the dataclass.
        data (Dict[str, Any]): A dictionary containing the dataclass fields and their values.

    Returns:
        (Type): A new dataclass with the specified name and fields.
    """
    fields = [(key, type(value)) for key, value in data.items()]
    return make_dataclass(name, fields)
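
Note that this returns a class, not an instance. A minimal sketch with hypothetical fields:

Config = create_dataclass_from_dict("Config", {"num_layers": 4, "learning_rate": 0.001})
config = Config(num_layers=4, learning_rate=0.001)  # instantiate the generated class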

flatten_yaml(yaml_config: Dict[str, Dict[str, Any]]) -> Dict[str, Any]

Flatten a nested YAML configuration dictionary.

Parameters:

    yaml_config (Dict[str, Dict[str, Any]]): A nested dictionary representing the YAML configuration. Required.

Returns:

    Dict[str, Any]: A flattened dictionary with the nested structure removed.

Source code in src/rydberggpt/utils.py
def flatten_yaml(yaml_config: Dict[str, Dict[str, Any]]) -> Dict[str, Any]:
    """
    Flatten a nested YAML configuration dictionary.

    Args:
        yaml_config (Dict[str, Dict[str, Any]]): A nested dictionary representing the YAML configuration.

    Returns:
        Dict[str, Any]: A flattened dictionary with the nested structure removed.
    """
    flattened_config = {}
    for section, section_values in yaml_config.items():
        if isinstance(section_values, dict):
            for key, value in section_values.items():
                flattened_config[key] = value
        else:
            flattened_config[section] = section_values
    return flattened_config
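
Only one level of nesting is removed and section names are discarded, so keys that repeat across sections overwrite each other. A small sketch with hypothetical keys:

nested = {"model": {"num_layers": 4}, "training": {"learning_rate": 0.001}, "seed": 42}
flatten_yaml(nested)  # {"num_layers": 4, "learning_rate": 0.001, "seed": 42}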

load_config_file(checkpoint_path: str, config_file: str = 'hparams.yaml') -> str

Locate the configuration file associated with a given checkpoint and return its path.

Parameters:

    checkpoint_path (str): The path to the checkpoint file. Required.
    config_file (str): The name of the configuration file. Defaults to "hparams.yaml".

Returns:

    str: The path to the configuration file.

Raises:

    FileNotFoundError: If the configuration file is not found in the specified directory.

Source code in src/rydberggpt/utils.py
def load_config_file(checkpoint_path: str, config_file: str = "hparams.yaml") -> str:
    """
    Load the configuration file associated with a given checkpoint.

    Args:
        checkpoint_path (str): The path to the checkpoint file.
        config_file (str, optional): The name of the configuration file, defaults to "hparams.yaml".

    Returns:
        (str): The path to the configuration file.

    Raises:
        FileNotFoundError: If the configuration file is not found in the specified directory.
    """
    config_dir = os.path.dirname(os.path.dirname(checkpoint_path))

    if not os.path.exists(os.path.join(config_dir, config_file)):
        raise FileNotFoundError(f"No config file found in {config_dir}")

    return os.path.join(config_dir, config_file)
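
The config path is resolved two directory levels above the checkpoint file, i.e. next to the checkpoints/ folder of a PyTorch Lightning log directory. A sketch, assuming the hypothetical layout below and that hparams.yaml exists there:

ckpt = "logs/lightning_logs/version_0/checkpoints/epoch=9.ckpt"  # hypothetical path
load_config_file(ckpt)  # -> "logs/lightning_logs/version_0/hparams.yaml"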

load_yaml_file(path: str, yaml_file_name: str) -> Dict[str, Any]

Load the content of a YAML file given its path and file name.

Parameters:

    path (str): The path to the directory containing the YAML file. Required.
    yaml_file_name (str): The name of the YAML file. Required.

Returns:

    Dict[str, Any]: A dictionary containing the YAML content.

Source code in src/rydberggpt/utils.py
def load_yaml_file(path: str, yaml_file_name: str) -> Dict[str, Any]:
    """
    Load the content of a YAML file given its path and file name.

    Args:
        path (str): The path to the directory containing the YAML file.
        yaml_file_name (str): The name of the YAML file.

    Returns:
        Dict[str, Any]: A dictionary containing the YAML content.
    """
    if not yaml_file_name.endswith(".yaml"):
        yaml_file_name += ".yaml"

    yaml_path = os.path.join(path, yaml_file_name)
    with open(yaml_path, "r") as file:
        yaml_content = yaml.safe_load(file)
    return yaml_content
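
The ".yaml" suffix is appended automatically when missing. A sketch with a hypothetical directory:

hparams = load_yaml_file("logs/lightning_logs/version_0", "hparams")  # reads hparams.yaml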

save_to_yaml(data: Dict[str, Any], filename: str) -> None

Save a dictionary to a file in YAML format.

Parameters:

    data (Dict[str, Any]): The dictionary to be saved. Required.
    filename (str): The path to the file where the dictionary will be saved. Required.

Source code in src/rydberggpt/utils.py
def save_to_yaml(data: Dict[str, Any], filename: str) -> None:
    """
    Save a dictionary to a file in YAML format.

    Args:
        data (Dict[str, Any]): The dictionary to be saved.
        filename (str): The path to the file where the dictionary will be saved.
    """
    with open(filename, "w") as file:
        yaml.dump(data, file)
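
A minimal round trip with the loader above (hypothetical file name):

save_to_yaml({"num_layers": 4, "learning_rate": 0.001}, "hparams.yaml")
load_yaml_file(".", "hparams")  # -> {"learning_rate": 0.001, "num_layers": 4}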

shift_inputs(tensor: torch.Tensor) -> torch.Tensor

Shifts the second dimension (S) of the input tensor by one position to the right and pads the beginning with zeros.

Parameters:

    tensor (torch.Tensor): The input tensor of shape [B, S, D]. Required.

Returns:

    torch.Tensor: The resulting tensor after the shift-and-pad operation.

Source code in src/rydberggpt/utils.py
def shift_inputs(tensor: torch.Tensor) -> torch.Tensor:
    """
    Shifts the second dimension (S) of the input tensor by one position to the right
    and pads the beginning with zeros.

    Args:
        tensor (torch.Tensor): The input tensor of shape [B, S, D].

    Returns:
        (torch.Tensor): The resulting tensor after the shift and pad operation.
    """
    B, _, D = tensor.size()
    zero_padding = torch.zeros((B, 1, D), device=tensor.device, dtype=tensor.dtype)
    shifted_tensor = torch.cat((zero_padding, tensor[:, :-1, :]), dim=1)
    return shifted_tensor
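
A small sketch of the shift: the first step along S becomes zeros and the last step is dropped:

import torch

x = torch.arange(6.0).reshape(1, 3, 2)  # B=1, S=3, D=2
shift_inputs(x)
# tensor([[[0., 0.],
#          [0., 1.],
#          [2., 3.]]])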

time_and_log(fn: Callable[..., Any]) -> Callable[..., Any]

Decorator that measures how long a function takes to execute and logs the elapsed time.

Parameters:

    fn (Callable[..., Any]): The function to be wrapped. Required.

Returns:

    Callable[..., Any]: The wrapped function.

Usage:

    @time_and_log
    def my_function(arg1, arg2):
        # function logic here

Source code in src/rydberggpt/utils.py
def time_and_log(fn: Callable[..., Any]) -> Callable[..., Any]:
    """
    Decorator function to measure the time taken by a function to execute and log it.

    Args:
        fn (Callable[..., Any]): The function to be wrapped.

    Returns:
        Callable[..., Any]: The wrapped function.

    Usage:
        ```py
        @time_and_log
        def my_function(arg1, arg2):
            # function logic here
        ```
    """

    def wrapped(*args: Any, **kwargs: Any) -> Any:
        start_time = time.time()
        result = fn(*args, **kwargs)
        elapsed_time = time.time() - start_time

        # Convert elapsed time to HH:MM:SS format
        formatted_time = str(timedelta(seconds=elapsed_time))

        logging.info(f"{fn.__name__} took {formatted_time} to run.")
        return result

    return wrapped
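
A usage sketch, assuming logging is configured to emit INFO messages; slow_step is a hypothetical function:

import logging
logging.basicConfig(level=logging.INFO)

@time_and_log
def slow_step(n):  # hypothetical function
    return sum(range(n))

slow_step(10**6)  # logs e.g. "slow_step took 0:00:00.012345 to run."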

to_one_hot(data: Union[torch.Tensor, List[int], Tuple[int]], num_classes: int) -> torch.Tensor

Converts the input data into a one-hot representation.

Parameters:

    data (Union[torch.Tensor, List[int], Tuple[int]]): Input data to be converted into one-hot form; a 1D tensor, list, or tuple of integers. Required.
    num_classes (int): Number of classes in the one-hot representation. Required.

Returns:

    torch.Tensor: The one-hot representation of the input data.

Source code in src/rydberggpt/utils.py
def to_one_hot(
    data: Union[torch.Tensor, List[int], Tuple[int]], num_classes: int
) -> torch.Tensor:
    """
    Converts the input data into one-hot representation.

    Args:
        data: Input data to be converted into one-hot. It can be a 1D tensor, list or tuple of integers.
        num_classes: Number of classes in the one-hot representation.

    Returns:
        data (torch.Tensor): The one-hot representation of the input data.
    """

    if isinstance(data, (list, tuple)):
        data = torch.tensor(data, dtype=torch.int64)
    elif not isinstance(data, torch.Tensor):
        raise TypeError("Input data must be a tensor, list or tuple of integers.")

    data = nn.functional.one_hot(data.long(), num_classes)

    return data.to(torch.float)
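
A quick example:

to_one_hot([0, 2, 1], num_classes=3)
# tensor([[1., 0., 0.],
#         [0., 0., 1.],
#         [0., 1., 0.]])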

rydberggpt.utils_ckpt

find_best_ckpt(log_dir: str) -> str

Find the best checkpoint file (with the lowest training loss) in the specified log directory.

Parameters:

    log_dir (str): The path to the log directory containing the checkpoint files. Required.

Returns:

    str: The path to the checkpoint file with the lowest training loss.

Raises:

    FileNotFoundError: If no checkpoint files are found in the directory.

Source code in src/rydberggpt/utils_ckpt.py
def find_best_ckpt(log_dir: str) -> str:
    """
    Find the best checkpoint file (with the lowest training loss) in the specified log directory.

    Args:
        log_dir (str): The path to the log directory containing the checkpoint files.

    Returns:
        (str): The path to the checkpoint file with the lowest training loss.
    """
    log_dir = os.path.join(log_dir, "checkpoints")
    ckpt_files = [file for file in os.listdir(log_dir) if file.endswith(".ckpt")]

    if not ckpt_files:
        raise FileNotFoundError(f"No checkpoint files found in {log_dir}")

    # Extract the training loss from the ckpt filenames
    ckpt_losses = []
    for file in ckpt_files:
        match = re.search(r"train_loss=(\d+\.\d+)", file)
        if match:
            ckpt_losses.append(float(match.group(1)))
        else:
            ckpt_losses.append(float("inf"))

    # Find the index of the ckpt with the lowest training loss
    best_ckpt_index = ckpt_losses.index(min(ckpt_losses))
    best_ckpt = ckpt_files[best_ckpt_index]

    return os.path.join(log_dir, best_ckpt)
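
Checkpoint filenames must embed the loss as train_loss=<value> to be ranked; files without the pattern are treated as having infinite loss. A sketch with a hypothetical log directory and result:

find_best_ckpt("logs/lightning_logs/version_0")
# -> "logs/lightning_logs/version_0/checkpoints/epoch=3-train_loss=0.1234.ckpt"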

find_latest_ckpt(log_dir: str) -> str

Find the latest checkpoint file (based on modification time) in the specified log directory.

Parameters:

    log_dir (str): The path to the log directory containing the checkpoint files. Required.

Returns:

    str: The path to the latest checkpoint file.

Raises:

    FileNotFoundError: If no checkpoint files are found in the directory.

Source code in src/rydberggpt/utils_ckpt.py
def find_latest_ckpt(log_dir: str) -> str:
    """
    Find the latest checkpoint file (based on modification time) in the specified log directory.

    Args:
        log_dir (str): The path to the log directory containing the checkpoint files.

    Returns:
        (str): The path to the latest checkpoint file.
    """
    log_dir = os.path.join(log_dir, "checkpoints")
    ckpt_files = [file for file in os.listdir(log_dir) if file.endswith(".ckpt")]

    if not ckpt_files:
        raise FileNotFoundError(f"No checkpoint files found in {log_dir}")

    ckpt_files.sort(key=lambda x: os.path.getmtime(os.path.join(log_dir, x)))
    latest_ckpt = ckpt_files[-1]
    return os.path.join(log_dir, latest_ckpt)
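
A sketch with a hypothetical log directory; the .ckpt file with the most recent modification time is returned:

find_latest_ckpt("logs/lightning_logs/version_0")  # most recently modified .ckpt file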

get_ckpt_path(from_ckpt: int, log_dir: str = 'logs/lightning_logs') -> str

Get the checkpoint path from a specified checkpoint version number.

Parameters:

    from_ckpt (int): The version number of the checkpoint. Required.
    log_dir (str): The root directory where checkpoints are stored. Defaults to "logs/lightning_logs".

Returns:

    str: The path to the specified checkpoint version directory.

Raises:

    FileNotFoundError: If no checkpoint is found in the specified directory.

Source code in src/rydberggpt/utils_ckpt.py
def get_ckpt_path(from_ckpt: int, log_dir: str = "logs/lightning_logs") -> str:
    """
    Get the checkpoint path from a specified checkpoint version number.

    Args:
        from_ckpt (int): The version number of the checkpoint.
        log_dir (str, optional): The root directory where checkpoints are stored.
                                Defaults to "logs/lightning_logs".

    Returns:
        (str): The path to the specified checkpoint version directory.

    Raises:
        FileNotFoundError: If no checkpoint is found in the specified directory.
    """
    log_dir = os.path.join(log_dir, f"version_{from_ckpt}")

    if not os.path.exists(log_dir):
        raise FileNotFoundError(f"No checkpoint found in {log_dir}")

    return log_dir
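
This only builds and validates the version directory path; combine it with find_best_ckpt or find_latest_ckpt to get an actual checkpoint file. A sketch with the default log root:

get_ckpt_path(from_ckpt=0)  # -> "logs/lightning_logs/version_0", if that directory exists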

get_model_from_ckpt(log_path: str, model: nn.Module, ckpt: str = 'best', trainer: pl.LightningModule = RydbergGPTTrainer) -> nn.Module

Load a model from a specified checkpoint file in the log directory.

Parameters:

    log_path (str): The path to the log directory containing the checkpoint files. Required.
    model (nn.Module): The model class to load. Required.
    ckpt (str): The checkpoint to load; must be either "best" or "latest". Defaults to "best".
    trainer (pl.LightningModule): The trainer class to use for loading the model. Defaults to RydbergGPTTrainer.

Returns:

    nn.Module: The loaded model.

Raises:

    ValueError: If the value of ckpt is not "best" or "latest".

Source code in src/rydberggpt/utils_ckpt.py
def get_model_from_ckpt(
    log_path: str,
    model: nn.Module,
    ckpt: str = "best",
    trainer: pl.LightningModule = RydbergGPTTrainer,
) -> nn.Module:
    """
    Load a model from a specified checkpoint file in the log directory.

    Args:
        log_path (str): The path to the log directory containing the checkpoint files.
        model (nn.Module): The model class to load.
        ckpt (str, optional): The checkpoint to load. Must be either "best" or "latest". Defaults to "best".
        trainer (pl.LightningModule, optional): The trainer class to use for loading the model. Defaults to RydbergGPTTrainer.

    Returns:
        (nn.Module): The loaded model.

    Raises:
        ValueError: If the value of ckpt is not "best" or "latest".
    """
    if ckpt == "best":
        ckpt_path = find_best_ckpt(log_path)
    elif ckpt == "latest":
        ckpt_path = find_latest_ckpt(log_path)
    else:
        raise ValueError(f"ckpt must be 'best' or 'latest', not {ckpt}")

    yaml_dict = load_yaml_file(log_path, "hparams.yaml")
    config = create_config_from_yaml(yaml_dict)

    rydberg_gpt_trainer = trainer.load_from_checkpoint(
        ckpt_path,
        model=model,
        config=config,
        logger=None,
        example_input_array=None,
    )
    return rydberg_gpt_trainer.model
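
A usage sketch; my_model stands for an instantiated nn.Module and the log path is hypothetical:

model = get_model_from_ckpt(
    "logs/lightning_logs/version_0",
    model=my_model,  # hypothetical nn.Module instance
    ckpt="best",
)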