Skip to content

graphchem.data.MoleculeGraph

Bases: torch_geometric.data.Data

Source code in graphchem/data/structs.py
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
class MoleculeGraph(torch_geometric.data.Data):
    """A single molecule represented as a torch_geometric graph.

    Atom features are stored as node features (``x``), bond features as
    edge features (``edge_attr``) with COO connectivity (``edge_index``),
    and the target value(s) as a single-row float32 tensor (``y``).
    """

    def __init__(self, atom_attr: 'torch.Tensor', bond_attr: 'torch.Tensor',
                 connectivity: 'torch.Tensor', target: 'torch.Tensor' = None):
        """ MoleculeGraph object, extends torch_geometric.data.Data object;
        a single molecule graph/data point

        Args:
            atom_attr (torch.Tensor): atom features, shape (n_atoms,
                n_atom_features); dtype assumed torch.float32
            bond_attr (torch.Tensor): bond features, shape (n_bonds,
                n_bond_features); dtype assumed torch.float32
            connectivity (torch.Tensor): COO graph connectivity index, size
                (2, n_bonds); dtype assumed torch.long
            target (torch.Tensor, default=None): target value(s) as a
                tensor or array-like: a scalar, a 1-D sequence of length
                n_targets, or shape (1, n_targets). It is copied, converted
                to torch.float32 and stored as ``y`` with shape
                (1, n_targets). If None, ``y`` is [[0.0]] (shape (1, 1)).
        """

        if target is None:
            target = [0.0]

        # Copy the target so ``y`` never aliases caller-owned memory.
        # ``detach().clone()`` is exactly what ``torch.tensor(tensor)`` does,
        # without the UserWarning torch emits for tensor input.
        if isinstance(target, torch.Tensor):
            y = target.detach().clone()
        else:
            y = torch.tensor(target)

        # Flatten all target values into one row. ``reshape(1, -1)`` is used
        # instead of ``reshape(1, len(target))`` because ``len`` only counts
        # the first dimension: a (1, n_targets) target with n_targets > 1, or
        # a scalar target, would otherwise fail to reshape.
        y = y.to(torch.float32).reshape(1, -1)

        super(MoleculeGraph, self).__init__(
            x=atom_attr,
            edge_index=connectivity,
            edge_attr=bond_attr,
            y=y
        )

__init__(atom_attr, bond_attr, connectivity, target=None)

MoleculeGraph object; extends the torch_geometric.data.Data object to represent a single molecule graph (one data point).

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `atom_attr` | `torch.Tensor` | atom features, shape (n_atoms, n_atom_features); dtype assumed torch.float32 | required |
| `bond_attr` | `torch.Tensor` | bond features, shape (n_bonds, n_bond_features); dtype assumed torch.float32 | required |
| `connectivity` | `torch.Tensor` | COO graph connectivity index, shape (2, n_bonds); dtype assumed torch.long | required |
| `target` | `torch.Tensor`, default=None | target value(s), shape (1, n_targets); if not supplied (None), set to [[0.0]] (shape (1, 1)); dtype assumed torch.float32 | `None` |
Source code in graphchem/data/structs.py
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
def __init__(self, atom_attr: 'torch.Tensor', bond_attr: 'torch.Tensor',
             connectivity: 'torch.Tensor', target: 'torch.Tensor' = None):
    """ MoleculeGraph object, extends torch_geometric.data.Data object;
    a single molecule graph/data point

    Args:
        atom_attr (torch.Tensor): atom features, shape (n_atoms,
            n_atom_features); dtype assumed torch.float32
        bond_attr (torch.Tensor): bond features, shape (n_bonds,
            n_bond_features); dtype assumed torch.float32
        connectivity (torch.Tensor): COO graph connectivity index, size
            (2, n_bonds); dtype assumed torch.long
        target (torch.Tensor, default=None): target value(s) as a
            tensor or array-like: a scalar, a 1-D sequence of length
            n_targets, or shape (1, n_targets). It is copied, converted
            to torch.float32 and stored as ``y`` with shape
            (1, n_targets). If None, ``y`` is [[0.0]] (shape (1, 1)).
    """

    if target is None:
        target = [0.0]

    # Copy the target so ``y`` never aliases caller-owned memory.
    # ``detach().clone()`` is exactly what ``torch.tensor(tensor)`` does,
    # without the UserWarning torch emits for tensor input.
    if isinstance(target, torch.Tensor):
        y = target.detach().clone()
    else:
        y = torch.tensor(target)

    # Flatten all target values into one row. ``reshape(1, -1)`` is used
    # instead of ``reshape(1, len(target))`` because ``len`` only counts
    # the first dimension: a (1, n_targets) target with n_targets > 1, or
    # a scalar target, would otherwise fail to reshape.
    y = y.to(torch.float32).reshape(1, -1)

    super(MoleculeGraph, self).__init__(
        x=atom_attr,
        edge_index=connectivity,
        edge_attr=bond_attr,
        y=y
    )