graphchem.nn.MoleculeGCN
Bases: nn.Module
Source code in graphchem/nn/gcn.py
__init__(atom_vocab_size, bond_vocab_size, output_dim, embedding_dim=64, n_messages=2, n_readout=2, readout_dim=64, dropout=0.0)
MoleculeGCN extends torch.nn.Module. It combines GeneralConv and EdgeConv modules with feed-forward readout layer(s) to regress on target variables from molecular structure.
Molecule graphs are first embedded (torch.nn.Embedding); each message passing operation then consists of:
bond_embedding > EdgeConv > updated bond_embedding
atom_embedding + bond_embedding > GeneralConv > updated atom_embedding
The sum of all atom states is then passed through a series of fully-connected readout layers to regress on a variable:
atom_embedding > fully-connected readout layers > target variable
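The embedding and sum-readout steps described above can be sketched in plain PyTorch. This is a minimal illustration and not the graphchem implementation: the toy atom tokens, the `batch` vector, the small dimensions, and the single linear readout layer are all assumptions, and the EdgeConv/GeneralConv message passing steps are omitted.

```python
import torch
from torch import nn

torch.manual_seed(0)

embedding_dim, output_dim = 8, 1

# Hypothetical toy batch: 5 atoms belonging to 2 molecules
atom_tokens = torch.tensor([0, 1, 1, 2, 0])  # integer atom identifiers
batch = torch.tensor([0, 0, 0, 1, 1])        # molecule index of each atom

# Embed integer atom tokens into dense vectors
atom_embed = nn.Embedding(num_embeddings=3, embedding_dim=embedding_dim)
atom_states = atom_embed(atom_tokens)        # shape: (5, embedding_dim)

# (message passing would update atom_states here; omitted in this sketch)

# Sum the atom states belonging to each molecule
n_molecules = int(batch.max()) + 1
pooled = torch.zeros(n_molecules, embedding_dim).index_add(0, batch, atom_states)

# A single linear layer stands in for the fully-connected readout layers
readout = nn.Linear(embedding_dim, output_dim)
prediction = readout(pooled)                 # shape: (2, output_dim)
```

Each row of `pooled` is the sum of the atom states of one molecule, so the readout sees one fixed-size vector per molecule regardless of how many atoms it contains.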
Parameters:
Name | Type | Description | Default |
---|---|---|---|
atom_vocab_size | int | num features (MoleculeEncoder.vocab_sizes) | required |
bond_vocab_size | int | num features (MoleculeEncoder.vocab_sizes) | required |
output_dim | int | number of target values per compound | required |
embedding_dim | int | number of embedded features for atoms and bonds | 64 |
n_messages | int | number of message passes between atoms | 2 |
n_readout | int | number of feed-forward post-readout layers (think standard NN/MLP) | 2 |
readout_dim | int | number of neurons in readout layers | 64 |
dropout | float | random neuron dropout during training | 0.0 |
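To make the roles of n_readout, readout_dim, dropout and output_dim more concrete, here is one way such parameters could map onto a feed-forward stack. The helper `build_readout` is hypothetical, and the ReLU activation and layer ordering are assumptions; the actual layout in graphchem/nn/gcn.py may differ.

```python
import torch
from torch import nn


def build_readout(input_dim, output_dim, n_readout=2, readout_dim=64, dropout=0.0):
    """Hypothetical helper: stack n_readout hidden layers, then an output layer."""
    layers = []
    in_dim = input_dim
    for _ in range(n_readout):
        layers.append(nn.Linear(in_dim, readout_dim))
        layers.append(nn.ReLU())            # activation choice is an assumption
        layers.append(nn.Dropout(dropout))  # only active in training mode
        in_dim = readout_dim
    layers.append(nn.Linear(in_dim, output_dim))
    return nn.Sequential(*layers)


# Defaults mirror the table above: embedding_dim=64 as input, one target value
mlp = build_readout(input_dim=64, output_dim=1)
out = mlp(torch.randn(4, 64))  # 4 pooled molecule vectors -> 4 predictions
```

With the defaults this gives two hidden layers of 64 neurons followed by a linear output layer of width output_dim.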
forward(data)
Forward operation for the PyTorch module: given a torch_geometric.data.Data sample with atom/bond attributes and connectivity, performs the message passing operations and readout.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data | torch_geometric.data.Data | torch_geometric data object or subclass instance | required |
Returns:
Type | Description |
---|---|
Tuple[torch.tensor] | (readout output (target prediction), atom embeddings, bond embeddings); the embeddings are the pre-sum/readout values present at each atom/bond, useful for determining which atoms/bonds contribute to the target value |
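As a rough illustration of how the returned atom embeddings might be inspected, the sketch below ranks atoms by the magnitude of their embedding vectors. The tensor here is a hand-written stand-in for the second element of the returned tuple, and scoring atoms by L2 norm is an assumption made for illustration, not a method prescribed by graphchem.

```python
import torch

# Stand-in for the per-atom embeddings returned by forward():
# 3 atoms, 3 embedded features each (values are made up)
atom_embeddings = torch.tensor([
    [0.1, -0.2, 0.0],
    [1.5, 0.5, -1.0],
    [0.3, 0.3, 0.3],
])

# Score each atom by the L2 norm of its embedding (illustrative choice)
scores = atom_embeddings.norm(dim=1)

# Atom indices ordered from largest to smallest score
ranking = torch.argsort(scores, descending=True)
```

The same pattern applies to the bond embeddings, with one row per bond instead of one row per atom.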