일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | 3 | 4 | 5 | 6 | 7 |
8 | 9 | 10 | 11 | 12 | 13 | 14 |
15 | 16 | 17 | 18 | 19 | 20 | 21 |
22 | 23 | 24 | 25 | 26 | 27 | 28 |
29 | 30 |
- Prompt Tuning for Graph Neural Networks
- pandas row 제거
- pandas
- layer 일부 freeze
- sktime 튜토리얼
- EDA 추천 파이썬
- python 경우의 수
- EDA in python
- Graph Theory
- pandas 특정 조건 열 제거
- sktime
- 판다스 조건
- molecular representation
- 선형함수 딥러닝
- 비선형함수 딥러닝
- pytorch 데이터셋 나누기
- sktime tutorial
- 일부 레이어 고정
- pandas 조건
- pretraining
- 경우의 수 파이썬
- pytorch dataset split
- Does GNN Pretraining Help Molecular Representation?
- weight 일부 고정
- 모델 freeze
- Skip connection
- 비선형함수
- 시계열 라이브러리
- sktime 예제
- pandas 행 제거
- Today
- Total
MoonNote
Schnet: A continuous-filter convolutional neural network for modeling quantum interactions 본문
Schnet: A continuous-filter convolutional neural network for modeling quantum interactions
Kisung Moon 2021. 12. 21. 15:26SchNet은 다음과 같은 3가지 단계를 거친다.
- K: the number of hidden layers
- continuous-filter convolution layer로 원자의 연속적인 postion을 모델링 할 수 있음
SchNet 실행 코드
def forward(self, batch_data):
z, pos, batch = batch_data.z, batch_data.pos, batch_data.batch
edge_index = radius_graph(pos, r=self.cutoff, batch=batch)
row, col = edge_index
dist = (pos[row] - pos[col]).norm(dim=-1)
dist_emb = self.dist_emb(dist)
v = self.init_v(z)
for update_e, update_v in zip(self.update_es, self.update_vs):
e = update_e(v, dist, dist_emb, edge_index)
v = update_v(v,e, edge_index)
u = self.update_u(v, batch)
return u
- z: atom들의 원자번호 (SchNet은 atom feature로 atom number만 사용)
- pos: atom들의 3차원 coordinate
- batch: atom들의 batch index
- radius_graph: torch geometric에서 제공하는 함수로, cutoff 이내의 edge index를 리턴함
- dist: radius graph로 생성된 edge들의 거리를 계산함
- dist_emb: emb라는 함수에 dist를 input으로 사용
- emb
1. offset: start에서 stop (cutoff)까지 num_gaussian의 갯수만큼의 간격 생성
2. coeff: -0.5 / offset의 간격 제곱 값. cutoff와 offset 간격이 비례하기 때문에 cutoff가 커질수록 coeff가 작아짐. 반대로, num gaussian과 offset 간격이 반비례하기 때문에 num gaussian이 커질수록 coeff 값이 커짐. 모두 음의 크기 (cutoff=5, num_gaussians=50으로 설정하면 약 -48)
3. register_buffer 에 offset을 등록하여 모델의 파라미터로 사용하지 않음
4. dist[n, 1] - offset[1, num_gaussians] 으로 dist를 줄이면서 num_gaussisan 크기만큼 늘림 [e, num_gaussians]
5. dist의 제곱에 coeff를 곱한 후 exponential function을 취해준다. => 많은 dist가 emb 후에 0이 된다.
6. Distance를 num_gaussian 차원만큼 Embedding 해주는 효과
class emb(torch.nn.Module):
def __init__(self, start=0.0, stop=5.0, num_gaussians=50):
super(emb, self).__init__()
offset = torch.linspace(start, stop, num_gaussians)
self.coeff = -0.5 / (offset[1] - offset[0]).item()**2
self.register_buffer('offset', offset)
def forward(self, dist):
dist = dist.view(-1, 1) - self.offset.view(1, -1)
return torch.exp(self.coeff * torch.pow(dist, 2))
class SchNet(torch.nn.Module):
def __init__(self, cutoff, num_layers, hidden_channels, out_channels, num_filters, num_gaussians, dropout_rate):
super(SchNet, self).__init__()
self.cutoff = cutoff
self.num_layers = num_layers
self.hidden_channels = hidden_channels
self.out_channels = out_channels
self.num_filters = num_filters
self.num_gaussians = num_gaussians
self.dropout_rate = dropout_rate
self.init_v = Embedding(100, hidden_channels)
self.dist_emb = emb(0.0, cutoff, num_gaussians)
self.update_vs = torch.nn.ModuleList([update_v(hidden_channels, num_filters, dropout_rate) for _ in range(num_layers)])
self.update_es = torch.nn.ModuleList([
update_e(hidden_channels, num_filters, num_gaussians, cutoff, dropout_rate) for _ in range(num_layers)])
self.update_u = update_u(hidden_channels, out_channels, dropout_rate)
self.reset_parameters()
- init_v: atom을 embedding하는 과정
1. Embedding 함수를 사용하여 z를 연속적인 값을 가지는 벡터로 변환 [n, hidden_channels]
- update_e (edge를 구성하는 node를 distance를 반영하여 update)
1. j, _: edge_index 중에 $node_{j}$의 index 선택 [e]
2. C: distance 변환 과정 [e]
- dist 값을 변환해준 후 cosine을 취함 -> distance가 크더라도 distance가 작은 edge보다 더 작아질 수 있음
- 1을 더한 후 0.5를 곱해서 0과 1 사이의 값으로 변환
class update_e(torch.nn.Module):
def __init__(self, hidden_channels, num_filters, num_gaussians, cutoff, dropout_rate):
super(update_e, self).__init__()
self.cutoff = cutoff
self.lin = Linear(hidden_channels, num_filters, bias=False)
self.mlp = Sequential(
Linear(num_gaussians, num_filters),
ShiftedSoftplus(),
Linear(num_filters, num_filters),
)
self.reset_parameters()
self.dropout = nn.Dropout(dropout_rate)
def reset_parameters(self):
torch.nn.init.xavier_uniform_(self.lin.weight)
torch.nn.init.xavier_uniform_(self.mlp[0].weight)
self.mlp[0].bias.data.fill_(0)
torch.nn.init.xavier_uniform_(self.mlp[2].weight)
self.mlp[0].bias.data.fill_(0)
def forward(self, v, dist, dist_emb, edge_index):
j, _ = edge_index
C = 0.5 * (torch.cos(dist * PI / self.cutoff) + 1.0)
W = self.mlp(dist_emb) * C.view(-1, 1)
W = self.dropout(W)
v = self.lin(v)
e = v[j] * W
return e
3. self.mlp(dist_emb): MLP를 태워서 [e, num_gaussian] 에서 [e, num_filters]로 변환
4. C.view(-1, 1): [e]을 [e, 1]로 변환
5. W: self.mlp(dist_emb)에 C를 곱해줘서(broad casting) [e, num_filters] * [e, 1] = [e, num_filters]
6. self.lin(v): MLP에 태워서 v [n, hidden_channels]를 [n, num_filters]로 변환
7. v[j]: [n, num_filters]에서 j[e] 에 해당하는 index만 가져와서 edge를 구성하는 node만 취함 [e, num_filters]
8. e: v[j] * W = edge를 구성하는 node들의 embedding [e, num_filters]
- update_v (edge를 구성하는 node를 distance를 반영하여 update + 원래의 node embedding)
1. input으로 v, update_e와 edge_index를 받음
2. _, i = edge_index: $node_{i}$의 index
3. global_add_pool: $node_{i}$ 기준으로 edge($node_{j}$들의 embedding을 sum -> [n, num_filters]
4. out: 선형변환 -> [n, hidden_channels]
5. v + out: [n, hidden_channels] + [n, hidden_channels] = [n, hidden_channels]
class update_v(torch.nn.Module):
def __init__(self, hidden_channels, num_filters, dropout_rate):
super(update_v, self).__init__()
self.act = ShiftedSoftplus()
self.lin1 = Linear(num_filters, hidden_channels)
self.lin2 = Linear(hidden_channels, hidden_channels)
self.reset_parameters()
self.dropout = nn.Dropout(dropout_rate)
def reset_parameters(self):
torch.nn.init.xavier_uniform_(self.lin1.weight)
self.lin1.bias.data.fill_(0)
torch.nn.init.xavier_uniform_(self.lin2.weight)
self.lin2.bias.data.fill_(0)
def forward(self, v, e, edge_index):
_, i = edge_index
#out = scatter(e, i, dim=0)
out = global_add_pool(e, i)
out = self.lin1(out)
out = self.act(out)
out = self.dropout(out)
out = self.lin2(out)
return v + out
class SchNet(torch.nn.Module):
def __init__(self, cutoff, num_layers, hidden_channels, out_channels, num_filters, num_gaussians, dropout_rate):
super(SchNet, self).__init__()
self.cutoff = cutoff
self.num_layers = num_layers
self.hidden_channels = hidden_channels
self.out_channels = out_channels
self.num_filters = num_filters
self.num_gaussians = num_gaussians
self.dropout_rate = dropout_rate
self.init_v = Embedding(100, hidden_channels)
self.dist_emb = emb(0.0, cutoff, num_gaussians)
self.update_vs = torch.nn.ModuleList([update_v(hidden_channels, num_filters, dropout_rate) for _ in range(num_layers)])
self.update_es = torch.nn.ModuleList([
update_e(hidden_channels, num_filters, num_gaussians, cutoff, dropout_rate) for _ in range(num_layers)])
self.update_u = update_u(hidden_channels, out_channels, dropout_rate)
self.reset_parameters()
def reset_parameters(self):
self.init_v.reset_parameters()
for update_e in self.update_es:
update_e.reset_parameters()
for update_v in self.update_vs:
update_v.reset_parameters()
self.update_u.reset_parameters()
def forward(self, batch_data):
z, pos, batch = batch_data.z, batch_data.pos, batch_data.batch
edge_index = radius_graph(pos, r=self.cutoff, batch=batch, max_num_neighbors=100) # return edge_index based on pos within cutoff
row, col = edge_index
dist = (pos[row] - pos[col]).norm(dim=-1) # return distance between edge
dist_emb = self.dist_emb(dist)
v = self.init_v(z)
for update_e, update_v in zip(self.update_es, self.update_vs):
e = update_e(v, dist, dist_emb, edge_index)
v = update_v(v, e, edge_index)
u = self.update_u(v, batch)
return u
- update_u
1. MLP
2. Readout: global sum pooling
class update_u(torch.nn.Module):
def __init__(self, hidden_channels, out_channels, dropout_rate):
super(update_u, self).__init__()
self.lin1 = Linear(hidden_channels, hidden_channels // 2)
self.act = ShiftedSoftplus()
self.lin2 = Linear(hidden_channels // 2, out_channels)
self.reset_parameters()
self.dropout = nn.Dropout(dropout_rate)
def reset_parameters(self):
torch.nn.init.xavier_uniform_(self.lin1.weight)
self.lin1.bias.data.fill_(0)
torch.nn.init.xavier_uniform_(self.lin2.weight)
self.lin2.bias.data.fill_(0)
def forward(self, v, batch):
v = self.lin1(v)
v = self.act(v)
v = self.dropout(v)
v = self.lin2(v)
#u = scatter(v, batch, dim=0)
u = global_add_pool(v, batch)
return u