SchNet: A continuous-filter convolutional neural network for modeling quantum interactions

Kisung Moon 2021. 12. 21. 15:26

SchNet's forward pass proceeds in three stages: embedding the atoms, passing them through K interaction layers, and reading out a per-molecule prediction.

- K: the number of hidden (interaction) layers

- The continuous-filter convolution layer lets the model handle the continuous positions of atoms.

 

The SchNet forward pass:

def forward(self, batch_data):
    z, pos, batch = batch_data.z, batch_data.pos, batch_data.batch

    # build the radius graph and expand edge distances into Gaussian features
    edge_index = radius_graph(pos, r=self.cutoff, batch=batch)
    row, col = edge_index
    dist = (pos[row] - pos[col]).norm(dim=-1)
    dist_emb = self.dist_emb(dist)

    # initial atom embeddings from atomic numbers
    v = self.init_v(z)

    # K interaction blocks: edge messages (update_e) then node updates (update_v)
    for update_e, update_v in zip(self.update_es, self.update_vs):
        e = update_e(v, dist, dist_emb, edge_index)
        v = update_v(v, e, edge_index)

    # readout: aggregate atom embeddings into a per-molecule prediction
    u = self.update_u(v, batch)
    return u

- z: the atomic numbers of the atoms (SchNet uses only the atomic number as the atom feature)

- pos: the 3D coordinates of the atoms

- batch: the batch index of each atom (which molecule in the batch it belongs to)

 

- radius_graph: a function provided by torch_geometric that returns the edge index of every atom pair within the cutoff

- dist: the distance of each edge produced by the radius graph

- dist_emb: passes dist through the emb module (defined below); a toy example of the graph construction follows this list
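
To illustrate radius_graph and the distance computation, a standalone toy example (the coordinates here are made up):

import torch
from torch_geometric.nn import radius_graph

# three atoms on a line: only atoms 0 and 1 are within the 1.5 A cutoff
pos = torch.tensor([[0.0, 0.0, 0.0],
                    [1.0, 0.0, 0.0],
                    [3.0, 0.0, 0.0]])
batch = torch.zeros(3, dtype=torch.long)

edge_index = radius_graph(pos, r=1.5, batch=batch)  # two directed edges: 0->1 and 1->0
row, col = edge_index
dist = (pos[row] - pos[col]).norm(dim=-1)           # tensor([1., 1.])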

 

- emb (expands each distance into Gaussian radial basis features)

  1. offset: num_gaussians evenly spaced values from start to stop (the cutoff)

  2. coeff: -0.5 divided by the squared offset spacing. The spacing grows with the cutoff, so a larger cutoff makes |coeff| smaller; conversely, the spacing shrinks as num_gaussians grows, so a larger num_gaussians makes |coeff| larger. coeff is always negative (about -48 with cutoff=5, num_gaussians=50).

  3. offset is registered with register_buffer, so it is stored with the model but not treated as a trainable parameter

  4. dist.view(-1, 1) [e, 1] minus offset.view(1, -1) [1, num_gaussians] broadcasts dist out to [e, num_gaussians]

  5. the squared differences are multiplied by coeff and passed through the exponential => entries far from their Gaussian center become ~0 after emb

  6. the net effect is embedding each distance into num_gaussians dimensions (a quick numeric check follows this list)
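
A standalone check of the coeff arithmetic in item 2:

import torch

offset = torch.linspace(0.0, 5.0, 50)               # spacing 5/49, about 0.102
coeff = -0.5 / (offset[1] - offset[0]).item() ** 2
print(round(coeff, 1))                              # -48.0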

# imports assumed by the snippets in this post (PyTorch + PyTorch Geometric)
import torch
from torch import nn
from torch.nn import Embedding, Linear, Sequential
from torch_geometric.nn import radius_graph, global_add_pool
from math import pi as PI

class emb(torch.nn.Module):
    def __init__(self, start=0.0, stop=5.0, num_gaussians=50):
        super(emb, self).__init__()
        offset = torch.linspace(start, stop, num_gaussians)
        self.coeff = -0.5 / (offset[1] - offset[0]).item()**2
        self.register_buffer('offset', offset)  # stored, but not trainable

    def forward(self, dist):
        # [e] -> [e, num_gaussians] by broadcasting
        dist = dist.view(-1, 1) - self.offset.view(1, -1)
        return torch.exp(self.coeff * torch.pow(dist, 2))

class SchNet(torch.nn.Module):
    def __init__(self, cutoff, num_layers, hidden_channels, out_channels, num_filters, num_gaussians, dropout_rate):
        super(SchNet, self).__init__()

        self.cutoff = cutoff
        self.num_layers = num_layers
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.num_filters = num_filters
        self.num_gaussians = num_gaussians
        self.dropout_rate = dropout_rate

        self.init_v = Embedding(100, hidden_channels)
        self.dist_emb = emb(0.0, cutoff, num_gaussians)

        self.update_vs = torch.nn.ModuleList([update_v(hidden_channels, num_filters, dropout_rate) for _ in range(num_layers)])

        self.update_es = torch.nn.ModuleList([
            update_e(hidden_channels, num_filters, num_gaussians, cutoff, dropout_rate) for _ in range(num_layers)])
        
        self.update_u = update_u(hidden_channels, out_channels, dropout_rate)

        self.reset_parameters()

 

- init_v: the atom embedding step

  1. an Embedding layer maps z to a vector of continuous values [n, hidden_channels] (a small example follows)
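
For instance (hidden_channels=128 is a made-up value here):

import torch
from torch.nn import Embedding

init_v = Embedding(100, 128)   # lookup table covering atomic numbers 0..99
z = torch.tensor([1, 6, 8])    # e.g. H, C, O
v = init_v(z)                  # shape [3, 128]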

  

- update_e (updates each edge's source-node embedding, weighted by the edge distance)

  1. j, _: selects the indices of $node_{j}$ from edge_index [e]

  2. C: the distance transform (a cosine cutoff) [e]

  - dist is rescaled by PI / cutoff and passed through cosine, so the weight shrinks smoothly as the distance grows; edges near the cutoff are damped toward zero

  - adding 1 and multiplying by 0.5 maps the result into the range 0 to 1
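
The classes below use ShiftedSoftplus, which the snippets never define. A minimal sketch of the standard SchNet activation ssp(x) = softplus(x) - log 2, assumed to match the implementation this post is based on:

import torch
import torch.nn.functional as F

class ShiftedSoftplus(torch.nn.Module):
    # ssp(x) = log(0.5 * exp(x) + 0.5); the shift makes ssp(0) = 0
    def __init__(self):
        super(ShiftedSoftplus, self).__init__()
        self.shift = torch.log(torch.tensor(2.0)).item()

    def forward(self, x):
        return F.softplus(x) - self.shift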

class update_e(torch.nn.Module):
    def __init__(self, hidden_channels, num_filters, num_gaussians, cutoff, dropout_rate):
        super(update_e, self).__init__()
        self.cutoff = cutoff
        self.lin = Linear(hidden_channels, num_filters, bias=False)
        # filter-generating network: distance features -> convolution filter
        self.mlp = Sequential(
            Linear(num_gaussians, num_filters),
            ShiftedSoftplus(),
            Linear(num_filters, num_filters),
        )

        self.reset_parameters()
        self.dropout = nn.Dropout(dropout_rate)

    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.lin.weight)
        torch.nn.init.xavier_uniform_(self.mlp[0].weight)
        self.mlp[0].bias.data.fill_(0)
        torch.nn.init.xavier_uniform_(self.mlp[2].weight)
        self.mlp[2].bias.data.fill_(0)

    def forward(self, v, dist, dist_emb, edge_index):
        j, _ = edge_index
        # cosine cutoff: 1 at dist=0, 0 at dist=cutoff
        C = 0.5 * (torch.cos(dist * PI / self.cutoff) + 1.0)
        W = self.mlp(dist_emb) * C.view(-1, 1)
        W = self.dropout(W)
        v = self.lin(v)
        e = v[j] * W
        return e

  3. self.mlp(dist_emb): the MLP maps [e, num_gaussians] to [e, num_filters]

  4. C.view(-1, 1): reshapes [e] into [e, 1]

  5. W: self.mlp(dist_emb) multiplied by C (broadcasting): [e, num_filters] * [e, 1] = [e, num_filters]

  6. self.lin(v): a linear layer maps v from [n, hidden_channels] to [n, num_filters]

  7. v[j]: indexes [n, num_filters] with j [e], keeping only the source node of each edge [e, num_filters]

  8. e: v[j] * W = the distance-filtered embeddings of the nodes forming each edge [e, num_filters]
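
The cosine cutoff from step 2 can be checked numerically:

import torch
from math import pi as PI

cutoff = 5.0
dist = torch.tensor([0.0, 2.5, 5.0])
C = 0.5 * (torch.cos(dist * PI / cutoff) + 1.0)
print(C)  # tensor([1.0000, 0.5000, 0.0000]); weights shrink smoothly to 0 at the cutoff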

 

- update_v (aggregates the distance-weighted edge messages into each node and adds the original node embedding back)

  1. takes v, the output of update_e, and edge_index as input

  2. _, i = edge_index: the indices of $node_{i}$, the target of each edge

  3. global_add_pool: sums the embeddings of the incoming edges (the $node_{j}$ messages) per $node_{i}$ -> [n, num_filters] (a toy example follows this list)

  4. out: linear transformations -> [n, hidden_channels]

  5. v + out: a residual connection, [n, hidden_channels] + [n, hidden_channels] = [n, hidden_channels]
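
Step 3 in isolation, with made-up messages:

import torch
from torch_geometric.nn import global_add_pool

e = torch.ones(4, 2)            # 4 edge messages with 2 features each
i = torch.tensor([0, 0, 1, 2])  # target node of each edge
out = global_add_pool(e, i)     # [3, 2]; node 0 receives the sum of two messages
print(out)  # tensor([[2., 2.], [1., 1.], [1., 1.]])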

 

class update_v(torch.nn.Module):
    def __init__(self, hidden_channels, num_filters, dropout_rate):
        super(update_v, self).__init__()
        self.act = ShiftedSoftplus()
        self.lin1 = Linear(num_filters, hidden_channels)
        self.lin2 = Linear(hidden_channels, hidden_channels)

        self.reset_parameters()
        self.dropout = nn.Dropout(dropout_rate)

    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.lin1.weight)
        self.lin1.bias.data.fill_(0)
        torch.nn.init.xavier_uniform_(self.lin2.weight)
        self.lin2.bias.data.fill_(0)

    def forward(self, v, e, edge_index):
        _, i = edge_index
        # sum edge messages into their target nodes: [e, num_filters] -> [n, num_filters]
        #out = scatter(e, i, dim=0)
        out = global_add_pool(e, i)
        out = self.lin1(out)
        out = self.act(out)
        out = self.dropout(out)
        out = self.lin2(out)
        return v + out  # residual connection
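
One caveat about the commented-out scatter line: global_add_pool(e, i) infers the output size as i.max() + 1, so if the highest-numbered nodes happen to have no incoming edges, the pooled tensor has fewer than n rows and the residual addition fails. scatter(e, i, dim=0, dim_size=v.size(0)) makes the output size explicit and avoids that edge case; otherwise the two calls are equivalent here.

Putting the pieces together, here is the complete SchNet class again, now including reset_parameters: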
        
        
class SchNet(torch.nn.Module):
    def __init__(self, cutoff, num_layers, hidden_channels, out_channels, num_filters, num_gaussians, dropout_rate):
        super(SchNet, self).__init__()

        self.cutoff = cutoff
        self.num_layers = num_layers
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.num_filters = num_filters
        self.num_gaussians = num_gaussians
        self.dropout_rate = dropout_rate

        self.init_v = Embedding(100, hidden_channels)
        self.dist_emb = emb(0.0, cutoff, num_gaussians)

        self.update_vs = torch.nn.ModuleList([update_v(hidden_channels, num_filters, dropout_rate) for _ in range(num_layers)])

        self.update_es = torch.nn.ModuleList([
            update_e(hidden_channels, num_filters, num_gaussians, cutoff, dropout_rate) for _ in range(num_layers)])
        
        self.update_u = update_u(hidden_channels, out_channels, dropout_rate)

        self.reset_parameters()

    def reset_parameters(self):
        self.init_v.reset_parameters()
        for update_e in self.update_es:
            update_e.reset_parameters()
        for update_v in self.update_vs:
            update_v.reset_parameters()
        self.update_u.reset_parameters()

    def forward(self, batch_data):
        z, pos, batch = batch_data.z, batch_data.pos, batch_data.batch       

        edge_index = radius_graph(pos, r=self.cutoff, batch=batch, max_num_neighbors=100)  # return edge_index based on pos within cutoff
        row, col = edge_index
        dist = (pos[row] - pos[col]).norm(dim=-1) # return distance between edge
        dist_emb = self.dist_emb(dist)
        
        v = self.init_v(z)
        
        for update_e, update_v in zip(self.update_es, self.update_vs):
            e = update_e(v, dist, dist_emb, edge_index)
            v = update_v(v, e, edge_index)
        u = self.update_u(v, batch)
        
        return u

 

- update_u (readout)

  1. an MLP maps each atom embedding from hidden_channels to out_channels (through hidden_channels // 2)

  2. Readout: global sum pooling over all atoms of each molecule

class update_u(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels, dropout_rate):
        super(update_u, self).__init__()
        self.lin1 = Linear(hidden_channels, hidden_channels // 2)
        self.act = ShiftedSoftplus()
        self.lin2 = Linear(hidden_channels // 2, out_channels)

        self.reset_parameters()
        self.dropout = nn.Dropout(dropout_rate)

    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.lin1.weight)
        self.lin1.bias.data.fill_(0)
        torch.nn.init.xavier_uniform_(self.lin2.weight)
        self.lin2.bias.data.fill_(0)

    def forward(self, v, batch):
        v = self.lin1(v)
        v = self.act(v)
        v = self.dropout(v)
        v = self.lin2(v)
        #u = scatter(v, batch, dim=0)
        u = global_add_pool(v, batch)
        return u
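
Finally, an end-to-end sanity check: a sketch with made-up hyperparameters, assuming the classes above are in scope and PyTorch Geometric is installed.

import torch
from torch_geometric.data import Data

data = Data(
    z=torch.randint(1, 10, (5,)),            # atomic numbers of 5 atoms
    pos=torch.rand(5, 3),                    # coordinates in the unit cube, so every pair is within the cutoff
    batch=torch.zeros(5, dtype=torch.long),  # all atoms belong to molecule 0
)

model = SchNet(cutoff=5.0, num_layers=3, hidden_channels=128, out_channels=1,
               num_filters=128, num_gaussians=50, dropout_rate=0.0)
u = model(data)
print(u.shape)  # torch.Size([1, 1]): one prediction per molecule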