
graphchem.data.MoleculeGraph

Bases: Data

A custom graph class representing a molecular structure.

This class extends the Data class from PyTorch Geometric to represent molecules with node attributes (atoms), edge attributes (bonds), and connectivity information. It also includes an optional target value.

Attributes

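x : torch.Tensor
    The node features (atom attributes), of shape (num_atoms, num_atom_features).
edge_index : torch.Tensor
    A 2D tensor of shape (2, num_bonds) describing the connectivity between atoms.
edge_attr : torch.Tensor
    Edge features (bond attributes), of shape (num_bonds, num_bond_features).
y : torch.Tensor
    Target value(s) of the molecule, of shape (1, num_targets).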
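To make the `edge_index` layout concrete, the sketch below reads neighbor lists out of a connectivity array for the heavy atoms of ethanol (C-C-O). It is plain Python, with nested lists standing in for the tensor, and the helper `neighbors_from_edge_index` is hypothetical, not part of graphchem. It assumes each bond appears once as a (source, target) column, as the `connectivity` parameter describes, and adds both directions itself; if an `edge_index` already stores both directions, each neighbor would be listed twice.

```python
from collections import defaultdict


def neighbors_from_edge_index(edge_index: list[list[int]]) -> dict[int, list[int]]:
    """Hypothetical helper: map each atom index to its bonded neighbors.

    edge_index follows the documented layout: row 0 holds source atom
    indices, row 1 holds target atom indices, and column j is bond j.
    Bonds are treated as undirected, so each column adds both directions.
    """
    sources, targets = edge_index
    if len(sources) != len(targets):
        raise ValueError("edge_index rows must have equal length")
    neighbors: dict[int, list[int]] = defaultdict(list)
    for src, dst in zip(sources, targets):
        neighbors[src].append(dst)
        neighbors[dst].append(src)
    return dict(neighbors)


# Ethanol heavy atoms: 0 = C, 1 = C, 2 = O; bonds C0-C1 and C1-O2.
edge_index = [
    [0, 1],  # source atom of each bond
    [1, 2],  # target atom of each bond
]
print(neighbors_from_edge_index(edge_index))  # {0: [1], 1: [0, 2], 2: [1]}
```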

Source code in graphchem/data/structs.py
from typing import Optional

import torch
from torch_geometric.data import Data


class MoleculeGraph(Data):
    """
    A custom graph class representing a molecular structure.

    This class extends the `Data` class from PyTorch Geometric to represent
    molecules with node attributes (atoms), edge attributes (bonds), and
    connectivity information. It also includes an optional target value.

    Attributes
    ----------
    x : torch.Tensor
        The node features (atom attributes).
    edge_index : torch.Tensor
        A 2D tensor describing the connectivity between atoms.
    edge_attr : torch.Tensor
        Edge features (bond attributes).
    y : torch.Tensor
        Target value(s) of the molecule.
    """

    def __init__(self, atom_attr: torch.Tensor,
                 bond_attr: torch.Tensor,
                 connectivity: torch.Tensor,
                 target: Optional[torch.Tensor] = None):
        """
        Initialize the MoleculeGraph object.

        Parameters
        ----------
        atom_attr : torch.Tensor
            A 2D tensor of shape (num_atoms, num_atom_features) representing
            the attributes of each atom in the molecule.
        bond_attr : torch.Tensor
            A 2D tensor of shape (num_bonds, num_bond_features) representing
            the attributes of each bond in the molecule.
        connectivity : torch.Tensor
            A 2D tensor of shape (2, num_bonds) where each column represents an
            edge (bond) between two atoms. The first row contains the source
            atom indices and the second row contains the target atom indices.
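        target : Optional[torch.Tensor]
            An optional tensor representing the target value(s) of the
            molecule. A 1D tensor of length num_targets is reshaped to
            (1, num_targets); a 2D tensor must already have shape
            (1, num_targets). If not provided, it defaults to a tensor of
            shape (1, 1) containing 0.0.

        Raises
        ------
        ValueError
            If the target's first dimension is not 1 after reshaping.
        """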

        if target is None:
            # Set default target to a tensor with shape (1, 1) and value 0.0
            target = torch.tensor([0.0]).type(torch.float32).reshape(1, 1)
        elif len(target.shape) == 1:
            # Reshape target if it's a 1D tensor to (1, target.shape[0])
            target = target.reshape(1, -1)
        if target.shape[0] != 1:
            raise ValueError("Target tensor must have shape (1, num_targets)")

        super().__init__(
            x=atom_attr,
            edge_index=connectivity,
            edge_attr=bond_attr,
            y=target
        )
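The target handling in `__init__` can be traced on shapes alone. The following is a hypothetical pure-Python mirror (`normalized_target_shape` is not a graphchem function) that follows the same three steps: the default when no target is given, the reshape of a 1D target, and the final first-dimension check.

```python
from typing import Optional, Sequence


def normalized_target_shape(shape: Optional[Sequence[int]]) -> tuple[int, ...]:
    """Hypothetical mirror of MoleculeGraph's target handling, on shapes only.

    `shape` is the target tensor's shape, or None when no target is given.
    Returns the shape that would be stored as `y`, or raises like __init__.
    """
    if shape is None:
        # Default target: a (1, 1) tensor holding 0.0.
        return (1, 1)
    shape = tuple(shape)
    if len(shape) == 1:
        # A 1D target of length n is reshaped to (1, n), like reshape(1, -1).
        shape = (1, shape[0])
    if shape[0] != 1:
        raise ValueError("Target tensor must have shape (1, num_targets)")
    return shape


print(normalized_target_shape(None))    # (1, 1)
print(normalized_target_shape((3,)))    # (1, 3)
print(normalized_target_shape((1, 2)))  # (1, 2)
```

A 0-D (scalar) target matches neither branch, so in both this mirror and the original code `shape[0]` raises an `IndexError` rather than the documented `ValueError`; passing a 1D tensor such as `torch.tensor([value])` sidesteps this. A likely reason for the (1, num_targets) convention is batching: PyTorch Geometric concatenates non-index attributes such as `y` along dimension 0, so a batch of B molecules ends up with `y` of shape (B, num_targets).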

__init__(atom_attr, bond_attr, connectivity, target=None)

Initialize the MoleculeGraph object.

Parameters

atom_attr : torch.Tensor
    A 2D tensor of shape (num_atoms, num_atom_features) representing the attributes of each atom in the molecule.
bond_attr : torch.Tensor
    A 2D tensor of shape (num_bonds, num_bond_features) representing the attributes of each bond in the molecule.
connectivity : torch.Tensor
    A 2D tensor of shape (2, num_bonds) in which each column represents an edge (bond) between two atoms. The first row contains the source atom indices and the second row contains the target atom indices.
target : Optional[torch.Tensor]
    An optional tensor representing the target value(s) of the molecule. A 1D tensor of length num_targets is reshaped to (1, num_targets); a 2D tensor must already have shape (1, num_targets). If not provided, it defaults to a tensor of shape (1, 1) containing 0.0.

Raises

ValueError
    If the target's first dimension is not 1 after reshaping.
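The constructor passes `atom_attr`, `bond_attr`, and `connectivity` to `Data` without checking that their shapes agree; only the target shape is validated. A consistency check along the lines of the documented shapes might look like the sketch below. `check_molecule_inputs` is hypothetical (not part of graphchem) and works on plain shapes and nested lists so that it runs without PyTorch.

```python
def check_molecule_inputs(
    atom_shape: tuple[int, int],
    bond_shape: tuple[int, int],
    connectivity: list[list[int]],
) -> None:
    """Hypothetical consistency check for MoleculeGraph constructor inputs.

    atom_shape   -- shape of atom_attr, (num_atoms, num_atom_features)
    bond_shape   -- shape of bond_attr, (num_bonds, num_bond_features)
    connectivity -- the (2, num_bonds) index array as nested lists
    """
    num_atoms, _ = atom_shape
    num_bonds, _ = bond_shape
    if len(connectivity) != 2:
        raise ValueError("connectivity must have exactly 2 rows")
    sources, targets = connectivity
    # One column per bond, so both rows must match num_bonds.
    if len(sources) != num_bonds or len(targets) != num_bonds:
        raise ValueError("connectivity must have one column per bond")
    # Every index must refer to an existing atom row in atom_attr.
    for idx in (*sources, *targets):
        if not 0 <= idx < num_atoms:
            raise ValueError(f"atom index {idx} out of range for {num_atoms} atoms")


# Ethanol heavy atoms: 3 atoms with 4 features each, 2 bonds with 1 feature each.
check_molecule_inputs((3, 4), (2, 1), [[0, 1], [1, 2]])  # passes silently

try:
    check_molecule_inputs((3, 4), (3, 1), [[0, 1], [1, 2]])
except ValueError as err:
    print(err)  # connectivity must have one column per bond
```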
