Skip to content

Commit 31340d5

Browse files
committed
Refactoring
1 parent d8312d2 commit 31340d5

File tree

17 files changed

+249
-233
lines changed

17 files changed

+249
-233
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
__pycache__/
2+
runs/
3+
.ipynb_checkpoints/

1d_demo.ipynb

Lines changed: 56 additions & 59 deletions
Large diffs are not rendered by default.

2d_demo.ipynb

Lines changed: 9 additions & 9 deletions
Large diffs are not rendered by default.

convcnp.py

Lines changed: 0 additions & 150 deletions
This file was deleted.

convcnp/dataset/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .dataset import Synthetic1D

dataset.py renamed to convcnp/dataset/dataset.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
from torch.utils import data as tdata
33
from gpytorch.utils.cholesky import psd_safe_cholesky
44

5+
from .kernels import eq_kernel, matern_kernel, periodic_kernel
6+
57

68
class Synthetic1D(tdata.Dataset):
79
def __init__(self,
@@ -17,7 +19,15 @@ def __init__(self,
1719
self.x_dim = 1
1820
self.y_dim = 1
1921

20-
self.kernel = kernel
22+
if kernel == 'eq':
23+
self.kernel = eq_kernel
24+
elif kernel == 'matern':
25+
self.kernel = matern_kernel
26+
elif kernel == 'periodic':
27+
self.kernel = periodic_kernel
28+
else:
29+
raise NotImplementedError('{} kernel is not implemented'.format(kernel))
30+
2131
self.length_scale = length_scale
2232
self.output_scale = output_scale
2333

File renamed without changes.
File renamed without changes.

convcnp/models/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .convcnp1d import ConvCNP1d
2+
from .convcnp2d import ConvCNP2d

convcnp/models/convcnp1d.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import torch
2+
import torch.nn as nn
3+
from torch.distributions import MultivariateNormal
4+
5+
from gpytorch.kernels import RBFKernel, ScaleKernel
6+
7+
from ..modules import PowerFunction
8+
9+
10+
class ConvCNP1d(nn.Module):
11+
def __init__(self, density=16):
12+
super().__init__()
13+
14+
self.density = density
15+
16+
self.psi = ScaleKernel(RBFKernel())
17+
self.phi = PowerFunction(K=1)
18+
19+
self.cnn = nn.Sequential(
20+
nn.Conv1d(3, 16, 5, 1, 2),
21+
nn.ReLU(),
22+
nn.Conv1d(16, 32, 5, 1, 2),
23+
nn.ReLU(),
24+
nn.Conv1d(32, 16, 5, 1, 2),
25+
nn.ReLU(),
26+
nn.Conv1d(16, 2, 5, 1, 2)
27+
)
28+
29+
def weights_init(m):
30+
if isinstance(m, nn.Conv1d):
31+
torch.nn.init.xavier_uniform_(m.weight)
32+
torch.nn.init.zeros_(m.bias)
33+
self.cnn.apply(weights_init)
34+
35+
self.pos = nn.Softplus()
36+
self.psi_rho = ScaleKernel(RBFKernel())
37+
38+
def forward(self, xc: torch.Tensor, yc: torch.Tensor, xt: torch.Tensor):
39+
tmp = torch.cat([xc, xt], 1)
40+
lower, upper = tmp.min(), tmp.max()
41+
num_t = int((self.density * (upper - lower)).item())
42+
t = torch.linspace(start=lower, end=upper, steps=num_t).reshape(1, -1, 1).repeat(xc.size(0), 1, 1).to(xc.device)
43+
44+
h = self.psi(t, xc).matmul(self.phi(yc))
45+
h0, h1 = h.split(1, -1)
46+
h1 = h1.div(h0 + 1e-8)
47+
h = torch.cat([h0, h1], -1)
48+
49+
rep = torch.cat([t, h], -1).transpose(-1, -2)
50+
f = self.cnn(rep).transpose(-1, -2)
51+
f_mu, f_sigma = f.split(1, -1)
52+
53+
mu = self.psi_rho(xt, t).matmul(f_mu)
54+
55+
sigma = self.psi_rho(xt, t).matmul(self.pos(f_sigma))
56+
return MultivariateNormal(mu, scale_tril=sigma.diag_embed())

0 commit comments

Comments
 (0)