Bases: Data
A custom graph class representing a molecular structure.
This class extends the Data class from PyTorch Geometric to represent
molecules with node attributes (atoms), edge attributes (bonds), and
connectivity information. It also includes an optional target value.
Attributes
x : torch.Tensor
The node features (atom attributes).
edge_index : torch.Tensor
A 2D tensor describing the connectivity between atoms.
edge_attr : torch.Tensor
Edge features (bond attributes).
y : torch.Tensor
Target value(s) of the molecule.
Source code in `graphchem/data/structs.py` (lines 7–66):
class MoleculeGraph(Data):
    """
    A custom graph class representing a molecular structure.

    This class extends the `Data` class from PyTorch Geometric to represent
    molecules with node attributes (atoms), edge attributes (bonds), and
    connectivity information. It also includes an optional target value.

    Attributes
    ----------
    x : torch.Tensor
        The node features (atom attributes).
    edge_index : torch.Tensor
        A 2D tensor describing the connectivity between atoms.
    edge_attr : torch.Tensor
        Edge features (bond attributes).
    y : torch.Tensor
        Target value(s) of the molecule, shaped (1, num_targets).
    """

    def __init__(self, atom_attr: torch.Tensor,
                 bond_attr: torch.Tensor,
                 connectivity: torch.Tensor,
                 target: Optional[torch.Tensor] = None):
        """
        Initialize the MoleculeGraph object.

        Parameters
        ----------
        atom_attr : torch.Tensor
            A 2D tensor of shape (num_atoms, num_atom_features) representing
            the attributes of each atom in the molecule.
        bond_attr : torch.Tensor
            A 2D tensor of shape (num_bonds, num_bond_features) representing
            the attributes of each bond in the molecule.
        connectivity : torch.Tensor
            A 2D tensor of shape (2, num_bonds) where each column represents an
            edge (bond) between two atoms. The first row contains the source
            atom indices and the second row contains the target atom indices.
        target : Optional[torch.Tensor]
            An optional 0D (scalar), 1D or 2D tensor representing the target
            value(s) of the molecule. Scalars and 1D tensors are reshaped to
            (1, num_targets). If not provided, it defaults to a float32
            tensor of shape (1, 1) containing 0.0.

        Raises
        ------
        ValueError
            If `target`, after any reshaping, is not a 2D tensor with exactly
            one row.
        """
        # NOTE(review): some PyG utilities may construct subclasses without
        # arguments (e.g. when un-batching); confirm whether atom_attr,
        # bond_attr and connectivity should get defaults.
        if target is None:
            # Placeholder so `y` is always present with shape (1, 1).
            target = torch.zeros((1, 1), dtype=torch.float32)
        elif target.dim() <= 1:
            # Promote a scalar or a 1D vector of targets to a single row,
            # i.e. shape (1, num_targets). Checking dim() instead of
            # len(shape) == 1 also covers 0-D tensors, which previously
            # crashed on `target.shape[0]` below.
            target = target.reshape(1, -1)
        # Enforce the documented (1, num_targets) contract: exactly two
        # dimensions and a single row.
        if target.dim() != 2 or target.shape[0] != 1:
            raise ValueError("Target tensor must have shape (1, num_targets)")
        super().__init__(
            x=atom_attr,
            edge_index=connectivity,
            edge_attr=bond_attr,
            y=target
        )
__init__(atom_attr, bond_attr, connectivity, target=None)
Initialize the MoleculeGraph object.
Parameters
atom_attr : torch.Tensor
A 2D tensor of shape (num_atoms, num_atom_features) representing
the attributes of each atom in the molecule.
bond_attr : torch.Tensor
A 2D tensor of shape (num_bonds, num_bond_features) representing
the attributes of each bond in the molecule.
connectivity : torch.Tensor
A 2D tensor of shape (2, num_bonds) where each column represents an
edge (bond) between two atoms. The first row contains the source
atom indices and the second row contains the target atom indices.
target : Optional[torch.Tensor]
An optional 1D or 2D tensor representing the target value(s) of the
molecule; a 1D tensor is reshaped to shape (1, num_targets). If not
provided, it defaults to a float32 tensor of shape (1, 1) containing 0.0.
Source code in `graphchem/data/structs.py` (lines 27–66):
def __init__(self, atom_attr: torch.Tensor,
             bond_attr: torch.Tensor,
             connectivity: torch.Tensor,
             target: Optional[torch.Tensor] = None):
    """
    Initialize the MoleculeGraph object.

    Parameters
    ----------
    atom_attr : torch.Tensor
        A 2D tensor of shape (num_atoms, num_atom_features) representing
        the attributes of each atom in the molecule.
    bond_attr : torch.Tensor
        A 2D tensor of shape (num_bonds, num_bond_features) representing
        the attributes of each bond in the molecule.
    connectivity : torch.Tensor
        A 2D tensor of shape (2, num_bonds) where each column represents an
        edge (bond) between two atoms. The first row contains the source
        atom indices and the second row contains the target atom indices.
    target : Optional[torch.Tensor]
        An optional 0D (scalar), 1D or 2D tensor representing the target
        value(s) of the molecule. Scalars and 1D tensors are reshaped to
        (1, num_targets). If not provided, it defaults to a float32
        tensor of shape (1, 1) containing 0.0.

    Raises
    ------
    ValueError
        If `target`, after any reshaping, is not a 2D tensor with exactly
        one row.
    """
    if target is None:
        # Placeholder so `y` is always present with shape (1, 1).
        target = torch.zeros((1, 1), dtype=torch.float32)
    elif target.dim() <= 1:
        # Promote a scalar or a 1D vector of targets to a single row,
        # i.e. shape (1, num_targets). Checking dim() instead of
        # len(shape) == 1 also covers 0-D tensors, which previously
        # crashed on `target.shape[0]` below.
        target = target.reshape(1, -1)
    # Enforce the documented (1, num_targets) contract: exactly two
    # dimensions and a single row.
    if target.dim() != 2 or target.shape[0] != 1:
        raise ValueError("Target tensor must have shape (1, num_targets)")
    super().__init__(
        x=atom_attr,
        edge_index=connectivity,
        edge_attr=bond_attr,
        y=target
    )
|