graphchem.nn.MoleculeGCN

Bases: nn.Module

Source code in graphchem/nn/gcn.py
# Imports used by this excerpt (assumed to appear near the top of gcn.py):
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.nn as gnn


class MoleculeGCN(nn.Module):

    def __init__(self, atom_vocab_size: int, bond_vocab_size: int,
                 output_dim: int, embedding_dim: int = 64, n_messages: int = 2,
                 n_readout: int = 2, readout_dim: int = 64,
                 dropout: float = 0.0):
        """ MoleculeGCN, extends torch.nn.Module; combination of GeneralConv
        and EdgeConv modules and feed-forward readout layer(s) for regressing
        on target variables using molecular structure

        Molecule graphs are first embedded (torch.nn.Embedding), then each
        message passing operation consists of:

        bond_embedding > EdgeConv > updated bond_embedding
        atom_embedding + bond_embedding > GeneralConv > updated
            atom_embedding

        The sum of all atom states is then passed through a series of
        fully-connected readout layers to regress on a variable:

        atom_embedding > fully-connected readout layers > target variable

        Args:
            atom_vocab_size (int): size of the atom vocabulary (see
                MoleculeEncoder.vocab_sizes)
            bond_vocab_size (int): size of the bond vocabulary (see
                MoleculeEncoder.vocab_sizes)
            output_dim (int): number of target values per compound
            embedding_dim (int): number of embedded features for atoms and
                bonds
            n_messages (int): number of message passes between atoms
            n_readout (int): number of feed-forward post-readout
                layers (think standard NN/MLP)
            readout_dim (int): number of neurons in readout layers
            dropout (float): random neuron dropout during training
        """

        super(MoleculeGCN, self).__init__()
        self._dropout = dropout
        self._n_messages = n_messages

        self.emb_atom = nn.Embedding(atom_vocab_size, embedding_dim)
        self.emb_bond = nn.Embedding(bond_vocab_size, embedding_dim)

        self.atom_conv = gnn.GeneralConv(embedding_dim, embedding_dim,
                                         embedding_dim, aggr='add')
        self.bond_conv = gnn.EdgeConv(nn.Sequential(
            nn.Linear(2 * embedding_dim, embedding_dim)
        ))

        self.readout = nn.ModuleList()
        self.readout.append(nn.Sequential(
            nn.Linear(embedding_dim, readout_dim)
        ))
        if n_readout > 1:
            for _ in range(n_readout - 1):
                self.readout.append(nn.Sequential(
                    nn.Linear(readout_dim, readout_dim)
                ))
        self.readout.append(nn.Sequential(
            nn.Linear(readout_dim, output_dim)
        ))

    def forward(self,
                data: 'torch_geometric.data.Data'
                ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """ forward operation for PyTorch module; given a sample of
        torch_geometric.data.Data, with atom/bond attributes and connectivity,
        perform message passing operations and readout

        Args:
            data ('torch_geometric.data.Data'): torch_geometric Data object
                (or an instance of a subclass)

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: (readout output
                (target prediction), atom embeddings, bond embeddings); the
                embeddings are the pre-pooling values present at each
                atom/bond, useful for determining which atoms/bonds
                contribute to the target value
        """

        x, edge_attr, edge_index, batch = data.x, data.edge_attr,\
            data.edge_index, data.batch
        if data.num_node_features == 0:
            # fallback when no node features exist; nn.Embedding requires
            # integer indices (torch.ones defaults to float)
            x = torch.ones(data.num_nodes, dtype=torch.long)

        out_atom = self.emb_atom(x)
        out_atom = F.softplus(out_atom)

        out_bond = self.emb_bond(edge_attr)
        out_bond = F.softplus(out_bond)

        for _ in range(self._n_messages):

            out_bond = self.bond_conv(out_bond, edge_index)
            out_bond = F.softplus(out_bond)
            out_bond = F.dropout(out_bond, p=self._dropout,
                                 training=self.training)

            out_atom = self.atom_conv(out_atom, edge_index, out_bond)
            out_atom = F.softplus(out_atom)
            out_atom = F.dropout(out_atom, p=self._dropout,
                                 training=self.training)

        out = gnn.global_add_pool(out_atom, batch)

        for layer in self.readout[:-1]:
            out = layer(out)
            out = F.softplus(out)
            out = F.dropout(out, p=self._dropout, training=self.training)
        out = self.readout[-1](out)

        return (out, out_atom, out_bond)

__init__(atom_vocab_size, bond_vocab_size, output_dim, embedding_dim=64, n_messages=2, n_readout=2, readout_dim=64, dropout=0.0)

MoleculeGCN extends torch.nn.Module; it combines GeneralConv and EdgeConv modules with feed-forward readout layer(s) for regressing on target variables using molecular structure

Molecule graphs are first embedded (torch.nn.Embedding), then each message passing operation consists of:

bond_embedding > EdgeConv > updated bond_embedding
atom_embedding + bond_embedding > GeneralConv > updated atom_embedding

The sum of all atom states is then passed through a series of fully-connected readout layers to regress on a variable:

atom_embedding > fully-connected readout layers > target variable

Parameters:

Name             Type   Description                                                     Default
atom_vocab_size  int    size of the atom vocabulary (see MoleculeEncoder.vocab_sizes)  required
bond_vocab_size  int    size of the bond vocabulary (see MoleculeEncoder.vocab_sizes)  required
output_dim       int    number of target values per compound                            required
embedding_dim    int    number of embedded features for atoms and bonds                 64
n_messages       int    number of message passes between atoms                          2
n_readout        int    number of feed-forward post-readout layers (a standard MLP)     2
readout_dim      int    number of neurons in each readout layer                         64
dropout          float  random neuron dropout during training                           0.0
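
A minimal construction sketch (not taken from the graphchem docs): the vocabulary sizes and dropout value below are placeholders, and in practice the vocabulary sizes would come from an encoder such as MoleculeEncoder (via its vocab_sizes attribute, per the parameter descriptions above):

from graphchem.nn import MoleculeGCN

model = MoleculeGCN(
    atom_vocab_size=100,  # placeholder: atom vocabulary size
    bond_vocab_size=20,   # placeholder: bond vocabulary size
    output_dim=1,         # one regression target per compound
    embedding_dim=64,
    n_messages=2,
    n_readout=2,
    readout_dim=64,
    dropout=0.1,          # placeholder dropout rate
)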
forward(data)

Forward operation for the PyTorch module: given a torch_geometric.data.Data sample with atom/bond attributes and connectivity, perform message passing and readout.

Parameters:

Name  Type                        Description                                          Default
data  torch_geometric.data.Data   torch_geometric Data object (or subclass instance)   required

Returns:

Type                                             Description
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]  (readout output (target prediction), atom
                                                 embeddings, bond embeddings); the embeddings
                                                 are the pre-pooling values at each atom and
                                                 bond, useful for determining which atoms and
                                                 bonds contribute to the target value
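
As a hedged illustration of the forward pass, the sketch below builds a toy three-atom graph with invented integer token ids (integer-encoded atom/bond attributes are assumed, matching the nn.Embedding inputs) and reuses the model constructed above:

import torch
from torch_geometric.data import Batch, Data

graph = Data(
    x=torch.tensor([3, 7, 3]),              # invented per-atom token ids
    edge_index=torch.tensor([[0, 1, 1, 2],  # each bond stored in both directions
                             [1, 0, 2, 1]]),
    edge_attr=torch.tensor([1, 1, 2, 2]),   # invented per-directed-bond token ids
)
batch = Batch.from_data_list([graph])       # supplies the batch vector for pooling

model.eval()
with torch.no_grad():
    pred, atom_emb, bond_emb = model(batch)

print(pred.shape)      # (1, output_dim): one prediction per molecule
print(atom_emb.shape)  # (3, embedding_dim): per-atom pre-pooling embeddings
print(bond_emb.shape)  # (4, embedding_dim): per-directed-bond embeddings

Inspecting the rows of atom_emb and bond_emb is one way to gauge which atoms and bonds drive the prediction, as noted in the Returns description.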

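Finally, a short, hedged training sketch for regression; the loss, optimizer, and hyperparameters are illustrative choices rather than values prescribed by graphchem, and each graph is assumed to store its target in y:

import torch
from torch_geometric.loader import DataLoader

graph.y = torch.tensor([[0.5]])   # invented target, shape (1, output_dim)
loader = DataLoader([graph], batch_size=32, shuffle=True)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = torch.nn.MSELoss()      # regression on target variables

model.train()
for epoch in range(10):
    for batch in loader:
        optimizer.zero_grad()
        pred, _, _ = model(batch)  # atom/bond embeddings unused during training
        loss = loss_fn(pred, batch.y)
        loss.backward()
        optimizer.step()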