Skip to content

Data

rydberggpt.data

dataclasses

BaseGraph dataclass

Bases: ABC

A base dataclass representing a graph configuration.

Source code in src/rydberggpt/data/dataclasses.py
@dataclass
class BaseGraph(ABC):
    """A base dataclass representing a graph configuration.

    Collects the fields shared by every graph configuration. Concrete
    configurations (for example ``GridGraph``) subclass it and append their
    own geometry-specific fields.
    """

    # Number of atoms in the configuration.
    # NOTE(review): presumably equals the number of graph nodes — confirm.
    num_atoms: int
    # Name selecting the graph type; ``get_graph`` compares it against
    # "grid_graph".
    graph_name: str
    # NOTE(review): presumably the Rydberg blockade radius — confirm.
    Rb: float
    # NOTE(review): presumably the detuning — confirm.
    delta: float
    # NOTE(review): presumably the Rabi frequency — confirm.
    omega: float
    # NOTE(review): presumably the inverse temperature — confirm.
    beta: float

Batch dataclass

A dataclass representing a batch of graphs

Source code in src/rydberggpt/data/dataclasses.py
@dataclass
class Batch:
    """A dataclass representing a batch of graphs.

    Pairs a PyTorch Geometric graph object with the measurement tensor that
    belongs to it. ``custom_collate`` merges several instances into one.
    """

    # PyTorch Geometric graph; after ``custom_collate`` this holds the
    # batched graph built with ``PyGBatch.from_data_list``.
    graph: Data
    # Measurement tensor; ``custom_collate`` zero-pads it per graph and casts
    # it to float32.
    # NOTE(review): name suggests one-hot encoded measurements — confirm.
    m_onehot: torch.Tensor

GridGraph dataclass

Bases: BaseGraph

A dataclass representing the configuration of a grid graph

Source code in src/rydberggpt/data/dataclasses.py
@dataclass
class GridGraph(BaseGraph):
    """A dataclass representing the configuration of a grid graph.

    Extends ``BaseGraph`` with the grid dimensions that ``get_graph`` forwards
    to ``generate_grid_graph``.
    """

    # Number of rows in the grid.
    n_rows: int
    # Number of columns in the grid.
    n_cols: int

custom_collate(batch: List[Batch]) -> Batch

Custom collate function to handle Batch objects when creating a DataLoader.

Parameters:

Name Type Description Default
batch List[Batch]

A list of Batch objects to be collated.

required

Returns:

Type Description
Batch

A single Batch object containing the collated data.

Source code in src/rydberggpt/data/dataclasses.py
def custom_collate(batch: List[Batch]) -> Batch:
    """
    Merge a list of Batch objects into one Batch for use with a DataLoader.

    Args:
        batch (List[Batch]): The individual Batch objects to merge.

    Returns:
        (Batch): One Batch holding the merged graph and the padded
            measurement tensor (cast to float32).
    """
    graphs = []
    measurements = []
    for sample in batch:
        graphs.append(sample.graph)
        measurements.append(sample.m_onehot)

    merged_graph = PyGBatch.from_data_list(graphs)

    # NOTE: Graphs and their measurement data differ in size between samples.
    # The measurements are therefore concatenated along the second-to-last
    # axis and handed to to_dense_batch, which uses the per-node graph
    # assignment in `merged_graph.batch` to produce a zero-padded tensor
    # suitable for the neural network.
    # see: https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/utils/to_dense_batch.html
    stacked = torch.cat(measurements, axis=-2)
    padded = to_dense_batch(stacked, batch=merged_graph.batch)[0]

    return Batch(graph=merged_graph, m_onehot=padded.to(torch.float32))

graph_structures

generate_grid_graph(n_rows: int, n_cols: int) -> nx.Graph

Generates a fully connected grid graph whose edge weights are the inverse sixth power of the Euclidean distance between the two nodes (distance ** -6). Coordinates are in units of the lattice constant a.

Parameters:

Name Type Description Default
n_rows int

The number of rows in the grid.

required
n_cols int

The number of columns in the grid.

required

Returns:

Type Description
Graph

The generated grid graph with node positions and edge weights.

Source code in src/rydberggpt/data/graph_structures.py
def generate_grid_graph(n_rows: int, n_cols: int) -> nx.Graph:
    """
    Generates a fully connected grid graph whose edge weights are the inverse
    sixth power of the Euclidean distance between the two nodes.

    Node ``i * n_cols + j`` sits at position ``(i, j)``; coordinates are in
    units of the lattice constant a.

    Args:
        n_rows (int): The number of rows in the grid.
        n_cols (int): The number of columns in the grid.

    Returns:
        (nx.Graph): The generated grid graph. Every node carries a ``pos``
            attribute and every edge a ``weight`` attribute equal to
            ``distance ** -6``.
    """
    # Local import keeps this block self-contained.
    from itertools import combinations

    graph = nx.Graph()

    # Add nodes row by row, remembering each position as an array so it does
    # not have to be rebuilt for every pair below.
    positions = {}
    for i in range(n_rows):
        for j in range(n_cols):
            node_id = i * n_cols + j
            graph.add_node(node_id, pos=(i, j))
            positions[node_id] = np.array((i, j))

    # The graph is undirected, so each unordered pair only needs to be
    # visited once. Distinct nodes never share a position, hence the distance
    # is always non-zero.
    for node1, node2 in combinations(graph.nodes, 2):
        distance = np.linalg.norm(positions[node1] - positions[node2])
        graph.add_edge(node1, node2, weight=distance ** (-6))

    return graph

get_graph(config: BaseGraph) -> nx.Graph

Generates a graph based on the given configuration.

Parameters:

Name Type Description Default
config BaseGraph

The graph configuration, an instance of a subclass of the BaseGraph dataclass.

required

Returns:

Type Description
Graph

The generated graph based on the configuration.

Raises:

Type Description
NotImplementedError

If the graph name provided in the configuration is not implemented.

Source code in src/rydberggpt/data/graph_structures.py
def get_graph(config: BaseGraph) -> nx.Graph:
    """
    Build the graph described by a configuration object.

    Args:
        config (BaseGraph): The graph configuration, an instance of a subclass
            of the BaseGraph dataclass. Its ``graph_name`` selects the
            generator to use.

    Returns:
        (nx.Graph): The graph generated for the configuration.

    Raises:
        NotImplementedError: If ``config.graph_name`` is not a supported
            graph name.
    """
    # Guard clause: "grid_graph" is the only graph type handled here.
    if config.graph_name != "grid_graph":
        raise NotImplementedError(f"Graph name {config.graph_name} not implemented.")

    return generate_grid_graph(config.n_rows, config.n_cols)

rydberg_dataset

build_datapipes(root_dir: str, batch_size: int, buffer_size: int)

Builds a data pipeline for processing files from a specified directory.

This function initializes a FileLister to list files from the specified directory and its subdirectories. It then demultiplexes the files into three separate data pipes for processing configuration, dataset, and graph files respectively. The configuration and graph files are opened, parsed as JSON, and processed using a custom selection function. The data pipes are then zipped together, shuffled, filtered, and buffered into batches using a custom collate function.

Parameters:

Name Type Description Default
root_dir str

The root directory from which to list files.

required
batch_size int

The number of samples per batch.

required
buffer_size int

The buffer size to use when buffering data into batches.

required

Returns:

Type Description
IterDataPipe

The final data pipe containing batches of processed data.

Source code in src/rydberggpt/data/rydberg_dataset.py
def build_datapipes(root_dir: str, batch_size: int, buffer_size: int):
    """
    Assemble the data pipeline that turns files under a directory into batches.

    A FileLister walks ``root_dir`` recursively. Its output is demultiplexed
    by ``classify_file_fn`` into three pipes: configuration files, dataset
    files and graph files. Configuration and graph files are opened and
    parsed as JSON. The three pipes are zipped, mapped through ``map_fn``,
    shuffled, passed through a ``Buffer``, grouped into batches, collated
    with ``custom_collate`` and finally passed through a sharding filter.

    Args:
        root_dir (str): The root directory from which to list files.
        batch_size (int): The number of samples per batch.
        buffer_size (int): The buffer size handed to ``Buffer``.

    Returns:
        (IterDataPipe): The final data pipe containing batches of processed data.
    """
    listed_files = FileLister([root_dir], recursive=True)

    # One output pipe per file category; files the classifier maps to None
    # are dropped (drop_none=True).
    configs, datasets, graphs = listed_files.demux(
        3, classify_file_fn, drop_none=True, buffer_size=-1
    )

    # Only the config and graph files are parsed as JSON here; the dataset
    # pipe is passed on as produced by demux.
    configs = configs.open_files().parse_json_files()
    graphs = graphs.open_files().parse_json_files()

    # Nested zip: each element is ((config, dataset), graph) before map_fn.
    samples = configs.zip(datasets).zip(graphs)
    samples = samples.map(map_fn)
    samples = samples.shuffle()
    samples = Buffer(source_datapipe=samples, buffer_size=buffer_size)

    batches = samples.batch(batch_size)
    batches = batches.collate(custom_collate)
    return batches.sharding_filter()

utils_graph

batch_pyg_data(data_list: List[Data]) -> Data

Batch a list of PyTorch Geometric Data objects into a single Data object.

Parameters:

Name Type Description Default
data_list List[Data]

List of PyTorch Geometric Data objects.

required

Returns:

Type Description
Data

A single batched Data object containing all input Data objects.

Source code in src/rydberggpt/data/utils_graph.py
def batch_pyg_data(data_list: List[Data]) -> Data:
    """
    Combine several PyTorch Geometric Data objects into one batched object.

    Args:
        data_list: The PyTorch Geometric Data objects to combine.

    Returns:
        (Data): The result of ``PyGBatch.from_data_list`` applied to
            ``data_list``.
    """
    return PyGBatch.from_data_list(data_list)

dict_to_graph(graph_dict: Dict) -> nx.Graph

Create a NetworkX graph from a dictionary.

Parameters:

Name Type Description Default
graph_dict Dict

Dictionary representing a NetworkX graph.

required

Returns:

Type Description
Graph

NetworkX graph object.

Source code in src/rydberggpt/data/utils_graph.py
def dict_to_graph(graph_dict: Dict) -> nx.Graph:
    """
    Rebuild a NetworkX graph from its node-link dictionary form.

    Args:
        graph_dict: Dictionary in the node-link format understood by
            ``nx.node_link_graph``.

    Returns:
        (nx.Graph): The reconstructed NetworkX graph.
    """
    return nx.node_link_graph(graph_dict)

graph_to_dict(graph: nx.Graph) -> Dict

Convert a NetworkX graph to a dictionary.

Parameters:

Name Type Description Default
graph Graph

NetworkX graph object.

required

Returns:

Type Description
Dict

A dictionary representing the NetworkX graph.

Source code in src/rydberggpt/data/utils_graph.py
def graph_to_dict(graph: nx.Graph) -> Dict:
    """
    Serialize a NetworkX graph into its node-link dictionary form.

    Args:
        graph: The NetworkX graph to convert.

    Returns:
        (Dict): The node-link dictionary produced by ``nx.node_link_data``.
    """
    return nx.node_link_data(graph)

networkx_to_pyg_data(graph: nx.Graph, node_features: torch.Tensor) -> Data

Convert a NetworkX graph to a PyTorch Geometric Data object.

Parameters:

Name Type Description Default
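graph Graph

NetworkX graph object.

required
node_features Tensor

Feature tensor that is repeated once per node to form the node feature matrix.

required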

Returns:

Type Description
Data

A PyTorch Geometric Data object representing the input graph.

Source code in src/rydberggpt/data/utils_graph.py
def networkx_to_pyg_data(graph: nx.Graph, node_features: torch.Tensor) -> Data:
    """
    Convert a NetworkX graph to a PyTorch Geometric Data object.

    Args:
        graph: NetworkX graph object. Its edges provide ``edge_index`` and
            their ``weight`` attributes provide ``edge_attr``.
        node_features: Feature tensor shared by all nodes; it is repeated once
            per node (``repeat(num_nodes, 1)``) to build ``x``.

    Returns:
        (Data): A PyTorch Geometric Data object representing the input graph.
    """
    # Every node receives the same feature vector.
    x = node_features.repeat(len(graph.nodes()), 1)

    # Convert the edge list to a PyTorch Geometric edge_index tensor.
    # NOTE(review): node labels are used directly as indices, so they are
    # assumed to be integers in [0, num_nodes) — confirm against callers.
    edges = list(graph.edges)
    if edges:
        edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    else:
        # A graph without edges (e.g. a single node) would otherwise produce
        # a 1-D tensor of shape (0,); keep the expected (2, num_edges) layout.
        edge_index = torch.empty((2, 0), dtype=torch.long)

    # Get edge weights from the graph, in the same order as the edge list.
    edge_weight = torch.tensor(
        list(nx.get_edge_attributes(graph, "weight").values()), dtype=torch.float
    )

    return Data(x=x, edge_index=edge_index, edge_attr=edge_weight)

pyg_graph_data(config, graph_data)

Convert a graph in node-link format to a PyG Data object.

Parameters:

Name Type Description Default
graph_data Dict

The graph in node-link format.

required
config Dict

The configuration data for the graph.

required

Returns:

Type Description
Data

The graph as a PyG Data object.

Source code in src/rydberggpt/data/utils_graph.py
def pyg_graph_data(config, graph_data):
    """
    Convert a graph in node-link format to a PyG Data object.

    The values stored under "delta", "omega", "beta" and "Rb" in ``config``
    (in that order) form the float32 feature vector that
    ``networkx_to_pyg_data`` assigns to every node.

    Args:
        config (Dict): The configuration data for the graph. Must provide the
            keys "delta", "omega", "beta" and "Rb".
        graph_data (Dict): The graph in node-link format.

    Returns:
        (Data): The graph as a PyG Data object.

    Raises:
        KeyError: If ``config`` is a dict lacking one of the required keys.
    """
    # The order fixes the layout of the per-node feature vector.
    feature_keys = ("delta", "omega", "beta", "Rb")
    node_features = torch.tensor(
        [config[key] for key in feature_keys],
        dtype=torch.float32,
    )
    graph_nx = nx.node_link_graph(graph_data)
    return networkx_to_pyg_data(graph_nx, node_features)

read_graph_from_json(file_path: str) -> Dict

Read a JSON file and convert it to a dictionary representing a NetworkX graph.

Parameters:

Name Type Description Default
file_path str

Path to the JSON file to read.

required

Returns:

Type Description
Dict

A dictionary representing a NetworkX graph.

Source code in src/rydberggpt/data/utils_graph.py
def read_graph_from_json(file_path: str) -> Dict:
    """
    Read a JSON file and convert it to a dictionary representing a NetworkX graph.

    Args:
        file_path: Path to the JSON file to read.

    Returns:
        (Dict): The parsed content of the file, as returned by ``json.load``.
    """
    # JSON text is UTF-8; decoding explicitly avoids depending on the
    # platform's locale encoding.
    with open(file_path, "r", encoding="utf-8") as f:
        return json.load(f)

save_graph_to_json(graph_dict: Dict, file_path: str) -> None

Save a dictionary representing a NetworkX graph to a JSON file.

Parameters:

Name Type Description Default
graph_dict Dict

Dictionary representing a NetworkX graph.

required
file_path str

Path to the JSON file to save.

required
Source code in src/rydberggpt/data/utils_graph.py
def save_graph_to_json(graph_dict: Dict, file_path: str) -> None:
    """
    Write a dictionary representing a NetworkX graph to a JSON file.

    Args:
        graph_dict: Dictionary representing a NetworkX graph.
        file_path: Path of the JSON file to write; opened in "w" mode, so any
            existing content is replaced.
    """
    with open(file_path, "w") as out_file:
        out_file.write(json.dumps(graph_dict))