Utilities
rydberggpt.utils
create_config_from_yaml(yaml_content: Dict) -> dataclass
Create a dataclass config object from the given YAML content.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
yaml_content | Dict | A dictionary containing the YAML content. | required |
Returns:
Type | Description |
---|---|
dataclass | A dataclass object representing the config. |
Source code in src/rydberggpt/utils.py
create_dataclass_from_dict(name: str, data: Dict[str, Any]) -> Type
Create a dataclass from a dictionary.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name | str | The name of the dataclass. | required |
data | Dict[str, Any] | A dictionary containing the dataclass fields and their values. | required |
Returns:
Type | Description |
---|---|
Type | A new dataclass with the specified name and fields. |
Source code in src/rydberggpt/utils.py
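As a rough illustration of building a dataclass from a dictionary, the sketch below uses the standard library's `dataclasses.make_dataclass`. The name `create_dataclass_from_dict_sketch` is hypothetical, and the choice to infer each field type from its value is an assumption; the library's actual implementation may differ.

```python
from dataclasses import asdict, field, is_dataclass, make_dataclass
from typing import Any, Dict


def create_dataclass_from_dict_sketch(name: str, data: Dict[str, Any]) -> type:
    """Hypothetical stand-in: build a dataclass whose fields mirror `data`."""
    fields = []
    for key, value in data.items():
        # default_factory sidesteps the ValueError that dataclasses raise for
        # list/dict defaults; note the same object is shared across instances.
        fields.append((key, type(value), field(default_factory=lambda v=value: v)))
    return make_dataclass(name, fields)


# Keys must be valid Python identifiers to become field names.
Config = create_dataclass_from_dict_sketch("Config", {"num_heads": 8, "dropout": 0.1})
config = Config()
print(config.num_heads, asdict(config))
```

Attribute access (`config.num_heads`) is the main benefit over a plain dictionary, since typos in field names fail immediately.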
flatten_yaml(yaml_config: Dict[str, Dict[str, Any]]) -> Dict[str, Any]
Flatten a nested YAML configuration dictionary.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
yaml_config | Dict[str, Dict[str, Any]] | A nested dictionary representing the YAML configuration. | required |
Returns:
Type | Description |
---|---|
Dict[str, Any] | A flattened dictionary with the nested structure removed. |
Source code in src/rydberggpt/utils.py
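The flattening step can be pictured with a minimal sketch. It assumes a single level of nesting and that inner keys are kept without a section prefix; neither detail is stated above, and `flatten_yaml_sketch` is a hypothetical name, not the library function.

```python
from typing import Any, Dict


def flatten_yaml_sketch(yaml_config: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical stand-in: merge every section's keys into one flat dict."""
    flat: Dict[str, Any] = {}
    for section, values in yaml_config.items():
        if isinstance(values, dict):
            # Assumption: inner keys keep their own names, so a key repeated
            # in two sections is overwritten by the later section.
            flat.update(values)
        else:
            # Top-level scalars are passed through unchanged.
            flat[section] = values
    return flat


nested = {"model": {"d_model": 32, "num_heads": 8}, "training": {"batch_size": 128}}
flat = flatten_yaml_sketch(nested)
print(flat)
```

Under these assumptions, key names must be unique across sections, otherwise values are silently lost.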
load_config_file(checkpoint_path: str, config_file: str = 'hparams.yaml') -> str
Load the configuration file associated with a given checkpoint.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
checkpoint_path | str | The path to the checkpoint file. | required |
config_file | str | The name of the configuration file. Defaults to "hparams.yaml". | 'hparams.yaml' |
Returns:
Type | Description |
---|---|
str | The path to the configuration file. |
Raises:
Type | Description |
---|---|
FileNotFoundError | If the configuration file is not found in the specified directory. |
Source code in src/rydberggpt/utils.py
load_yaml_file(path: str, yaml_file_name: str) -> Dict[str, Any]
Load the content of a YAML file given its path and file name.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path | str | The path to the directory containing the YAML file. | required |
yaml_file_name | str | The name of the YAML file. | required |
Returns:
Type | Description |
---|---|
Dict[str, Any] | A dictionary containing the YAML content. |
Source code in src/rydberggpt/utils.py
save_to_yaml(data: Dict[str, Any], filename: str) -> None
Save a dictionary to a file in YAML format.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data | Dict[str, Any] | The dictionary to be saved. | required |
filename | str | The path to the file where the dictionary will be saved. | required |
Source code in src/rydberggpt/utils.py
shift_inputs(tensor: torch.Tensor) -> torch.Tensor
Shifts the second dimension (S) of the input tensor by one position to the right and pads the beginning with zeros.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
tensor | Tensor | The input tensor of shape [B, S, D]. | required |
Returns:
Type | Description |
---|---|
Tensor | The resulting tensor after the shift and pad operation. |
Source code in src/rydberggpt/utils.py
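To make the shift-and-pad operation concrete, here is a dependency-free sketch that uses nested lists in place of a `torch.Tensor`. It assumes the output keeps the same length S, so the last step is dropped; that detail is an assumption, and `shift_inputs_sketch` is a hypothetical name.

```python
from typing import List

Batch = List[List[List[float]]]  # nested lists standing in for a [B, S, D] tensor


def shift_inputs_sketch(batch: Batch) -> Batch:
    """Shift each sequence one step to the right along S, zero-padding step 0."""
    shifted = []
    for sequence in batch:
        feature_dim = len(sequence[0])  # assumes S >= 1
        zero_step = [0.0] * feature_dim
        # Assumption: drop the last step so the sequence length S is unchanged.
        shifted.append([zero_step] + [list(step) for step in sequence[:-1]])
    return shifted


batch = [[[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]]]  # B=1, S=3, D=2
shifted = shift_inputs_sketch(batch)
print(shifted)  # [[[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]]
```

A shift of this kind is commonly used in autoregressive models so that position t only sees inputs from positions before t.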
time_and_log(fn: Callable[..., Any]) -> Callable[..., Any]
Decorator that measures how long a function takes to execute and logs the result.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
fn | Callable[..., Any] | The function to be wrapped. | required |
Returns:
Type | Description |
---|---|
Callable[..., Any] | The wrapped function. |
Source code in src/rydberggpt/utils.py
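A timing decorator of this kind can be sketched with the standard library alone. The logger name, the log level, and the message format below are assumptions, and `time_and_log_sketch` is a hypothetical name rather than the library's implementation.

```python
import functools
import logging
import time
from typing import Any, Callable

logger = logging.getLogger(__name__)


def time_and_log_sketch(fn: Callable[..., Any]) -> Callable[..., Any]:
    """Hypothetical stand-in: log the wall-clock duration of each call to `fn`."""

    @functools.wraps(fn)  # keep the wrapped function's name and docstring
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        start = time.perf_counter()
        try:
            return fn(*args, **kwargs)
        finally:
            # Runs even if fn raises, so failed calls are timed as well.
            elapsed = time.perf_counter() - start
            logger.info("%s took %.4f s", fn.__name__, elapsed)

    return wrapper


@time_and_log_sketch
def add(a: int, b: int) -> int:
    return a + b


result = add(2, 3)
```

The message only appears if logging is configured at `INFO` level or lower, for example via `logging.basicConfig(level=logging.INFO)`.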
to_one_hot(data: Union[torch.Tensor, List[int], Tuple[int]], num_classes: int) -> torch.Tensor
Converts the input data into one-hot representation.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data | Union[Tensor, List[int], Tuple[int]] | Input data to be converted into one-hot. It can be a 1D tensor, list or tuple of integers. | required |
num_classes | int | Number of classes in the one-hot representation. | required |
Returns:
Name | Type | Description |
---|---|---|
data | Tensor | The one-hot representation of the input data. |
Source code in src/rydberggpt/utils.py
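The encoding itself is simple enough to show without PyTorch. The sketch below works on plain lists and returns nested lists rather than a tensor; `to_one_hot_sketch` is a hypothetical name, and the range check is an assumption about how invalid labels should be handled.

```python
from typing import List, Sequence


def to_one_hot_sketch(data: Sequence[int], num_classes: int) -> List[List[int]]:
    """Hypothetical stand-in: one row per label, with a single 1 at the label index."""
    rows = []
    for label in data:
        if not 0 <= label < num_classes:
            raise ValueError(f"label {label} is outside [0, {num_classes})")
        row = [0] * num_classes
        row[label] = 1
        rows.append(row)
    return rows


one_hot = to_one_hot_sketch([0, 1, 1], num_classes=2)
print(one_hot)  # [[1, 0], [0, 1], [0, 1]]
```

For tensors, PyTorch provides `torch.nn.functional.one_hot`, which performs the same encoding on integer tensors; whether `to_one_hot` relies on it is not stated here.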
rydberggpt.utils_ckpt
find_best_ckpt(log_dir: str) -> Optional[str]
Find the best checkpoint file (with the lowest training loss) in the specified log directory.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
log_dir | str | The path to the log directory containing the checkpoint files. | required |
Returns:
Type | Description |
---|---|
Optional[str] | The path to the checkpoint file with the lowest training loss. |
Source code in src/rydberggpt/utils_ckpt.py
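One way to pick the lowest-loss checkpoint is to read the loss from the file name. The sketch below assumes names of the form `...train_loss=<float>.ckpt`, which is a hypothetical naming scheme not confirmed by this page, and it returns `None` when nothing matches.

```python
import os
import re
import tempfile
from typing import Optional

# Assumption: the loss is embedded in the file name, e.g. "epoch=3-train_loss=0.1234.ckpt".
LOSS_PATTERN = re.compile(r"train_loss=(\d+(?:\.\d+)?)")


def find_best_ckpt_sketch(log_dir: str) -> Optional[str]:
    """Hypothetical stand-in: return the .ckpt path with the smallest parsed loss."""
    best_path, best_loss = None, float("inf")
    for root, _dirs, files in os.walk(log_dir):
        for name in files:
            match = LOSS_PATTERN.search(name)
            if name.endswith(".ckpt") and match:
                loss = float(match.group(1))
                if loss < best_loss:
                    best_path, best_loss = os.path.join(root, name), loss
    return best_path


with tempfile.TemporaryDirectory() as tmp:
    for fname in ("epoch=1-train_loss=0.9000.ckpt", "epoch=2-train_loss=0.2500.ckpt"):
        open(os.path.join(tmp, fname), "w").close()
    best_name = os.path.basename(find_best_ckpt_sketch(tmp))
    # os.walk yields nothing for a missing directory, so the result is None.
    empty_result = find_best_ckpt_sketch(os.path.join(tmp, "missing"))
```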
find_latest_ckpt(log_dir: str)
Find the latest checkpoint file (based on modification time) in the specified log directory.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
log_dir | str | The path to the log directory containing the checkpoint files. | required |
Returns:
Type | Description |
---|---|
str | The path to the latest checkpoint file. |
Source code in src/rydberggpt/utils_ckpt.py
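Selecting by modification time can be sketched with `os.walk` and `os.path.getmtime`. The recursive search, the `.ckpt` extension filter, and the `None` result for an empty directory are assumptions; `find_latest_ckpt_sketch` is a hypothetical name.

```python
import os
import tempfile
from typing import Optional


def find_latest_ckpt_sketch(log_dir: str) -> Optional[str]:
    """Hypothetical stand-in: return the most recently modified .ckpt file."""
    ckpts = [
        os.path.join(root, name)
        for root, _dirs, files in os.walk(log_dir)
        for name in files
        if name.endswith(".ckpt")
    ]
    # Assumption: return None rather than raising when nothing is found.
    return max(ckpts, key=os.path.getmtime) if ckpts else None


with tempfile.TemporaryDirectory() as tmp:
    old, new = os.path.join(tmp, "old.ckpt"), os.path.join(tmp, "new.ckpt")
    for path in (old, new):
        open(path, "w").close()
    # Set explicit timestamps so the ordering does not depend on creation speed.
    os.utime(old, (1_000, 1_000))
    os.utime(new, (2_000, 2_000))
    latest_name = os.path.basename(find_latest_ckpt_sketch(tmp))
```

Modification time reflects the last write to the file, so copying or touching checkpoints can change which one counts as latest.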
get_ckpt_path(from_ckpt: int, log_dir: str = 'logs/lightning_logs') -> str
Get the checkpoint path from a specified checkpoint version number.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
from_ckpt | int | The version number of the checkpoint. | required |
log_dir | str | The root directory where checkpoints are stored. Defaults to "logs/lightning_logs". | 'logs/lightning_logs' |
Returns:
Type | Description |
---|---|
str | The path to the specified checkpoint version directory. |
Raises:
Type | Description |
---|---|
FileNotFoundError | If no checkpoint is found in the specified directory. |
Source code in src/rydberggpt/utils_ckpt.py
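The lookup by version number might resemble the sketch below. It assumes versions live in `version_<n>` subdirectories of `log_dir`, which matches the default layout of PyTorch Lightning's loggers but is not confirmed for this function; `get_ckpt_path_sketch` is a hypothetical name.

```python
import os
import tempfile


def get_ckpt_path_sketch(from_ckpt: int, log_dir: str = "logs/lightning_logs") -> str:
    """Hypothetical stand-in: resolve a version number to its directory."""
    # Assumption: each run is stored under "<log_dir>/version_<n>".
    ckpt_path = os.path.join(log_dir, f"version_{from_ckpt}")
    if not os.path.isdir(ckpt_path):
        raise FileNotFoundError(f"No checkpoint found in {ckpt_path}")
    return ckpt_path


with tempfile.TemporaryDirectory() as tmp:
    os.makedirs(os.path.join(tmp, "version_0"))
    found = os.path.basename(get_ckpt_path_sketch(0, log_dir=tmp))
    try:
        get_ckpt_path_sketch(7, log_dir=tmp)
        missing_raised = False
    except FileNotFoundError:
        missing_raised = True
```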
get_model_from_ckpt(log_path: str, model: nn.Module, ckpt: str = 'best', trainer: pl.LightningModule = RydbergGPTTrainer) -> nn.Module
Load a model from a specified checkpoint file in the log directory.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
log_path | str | The path to the log directory containing the checkpoint files. | required |
model | Module | The model class to load. | required |
ckpt | str | The checkpoint to load. Must be either "best" or "latest". Defaults to "best". | 'best' |
trainer | LightningModule | The trainer class to use for loading the model. Defaults to RydbergGPTTrainer. | RydbergGPTTrainer |
Returns:
Type | Description |
---|---|
Module | The loaded model. |
Raises:
Type | Description |
---|---|
ValueError | If the value of ckpt is not "best" or "latest". |
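The `ckpt` argument acts as a selector between two lookup strategies. A minimal sketch of that dispatch, using stub finders in place of `find_best_ckpt` and `find_latest_ckpt`, could look like this; the helper name `select_ckpt_sketch` and the stub paths are hypothetical, and the actual model loading is omitted.

```python
from typing import Callable, Dict, Optional

Finder = Callable[[str], Optional[str]]


def select_ckpt_sketch(log_path: str, ckpt: str, finders: Dict[str, Finder]) -> Optional[str]:
    """Hypothetical stand-in: map "best"/"latest" to a finder and call it."""
    if ckpt not in finders:
        raise ValueError(f"ckpt must be one of {sorted(finders)}, got {ckpt!r}")
    return finders[ckpt](log_path)


# Stub finders for illustration only; they do not touch the file system.
stub_finders: Dict[str, Finder] = {
    "best": lambda log_path: f"{log_path}/best.ckpt",
    "latest": lambda log_path: f"{log_path}/latest.ckpt",
}

chosen = select_ckpt_sketch("logs/run", "best", stub_finders)

try:
    select_ckpt_sketch("logs/run", "newest", stub_finders)
    invalid_raised = False
except ValueError:
    invalid_raised = True
```

Validating the selector before any file access means a typo in `ckpt` fails fast with a clear message.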