Interfacing torch to heyoka.py
Note
For an introduction to feed-forward neural networks in heyoka.py, check out the example: Feed-Forward Neural Networks.
Warning
This tutorial assumes that torch is installed.
heyoka.py is not a library meant for machine learning, nor does it aspire to be one. However, given its support for feed-forward neural networks and their potential use in numerical integration, it is useful to be able to port a torch model to the heyoka.py ffnn() factory. This tutorial shows how to do exactly that!
import heyoka as hk
import numpy as np
# We will need torch for this notebook. The CPU version is enough though.
import torch
from torch import nn
We start by defining a ffnn model in torch. As a test case, we use a feed-forward neural network with two hidden layers of 32 neurons each, tanh nonlinearities and a sigmoid for the output layer. We define the model so that it maps easily to the heyoka ffnn factory, but other styles are also possible.
This will typically look something like:
# Let us use float64 (this is not necessary, as heyoka also supports the float32 type, but we go for high precision here!)
torch.set_default_dtype(torch.float64)

class torch_net(nn.Module):
    def __init__(
        self,
        n_inputs=4,
        nn_hidden=[32, 32],
        n_out=1,
        activations=[nn.Tanh(), nn.Tanh(), nn.Sigmoid()],
    ):
        super().__init__()
        # We treat all layers equally.
        dims = [n_inputs] + nn_hidden + [n_out]
        # Linear layers.
        self.fcs = nn.ModuleList([nn.Linear(dims[i], dims[i + 1]) for i in range(len(dims) - 1)])
        # Non-linearities.
        self.acts = nn.ModuleList(activations)

    def forward(self, x):
        for fc, act in zip(self.fcs, self.acts):
            x = act(fc(x))
        return x
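As a quick sanity check (not part of the workflow, purely an illustration), we can instantiate the network and verify the shape of a forward pass on a dummy batch:

net = torch_net()
dummy = torch.rand((10, 4))  # a batch of 10 samples with 4 features each
print(net(dummy).shape)      # expected: torch.Size([10, 1])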
torch stores the weights and biases in the model layer by layer, as separate tensors, while heyoka flattens everything into a single one-dimensional sequence containing all the weights first, followed by all the biases.
The following function takes care of converting the torch representation to heyoka’s:
def weights_and_biases_heyoka(model):
    # Collect the weight and bias tensors layer by layer.
    weights = []
    biases = []
    for name, param in model.named_parameters():
        if "weight" in name:
            weights.append(param.data.clone().numpy())
        elif "bias" in name:
            biases.append(param.data.clone().numpy())
    # Flatten all the weights first, then all the biases.
    w_flat = np.concatenate([w.flatten() for w in weights])
    b_flat = np.concatenate([b.flatten() for b in biases])
    print(w_flat.shape)
    return np.concatenate((w_flat, b_flat))
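For the architecture used here we can check by hand what to expect: a [4, 32, 32, 1] network has 4·32 + 32·32 + 32·1 = 1184 weights and 32 + 32 + 1 = 65 biases. A little check, just for illustration:

n_w = 4 * 32 + 32 * 32 + 32 * 1
n_b = 32 + 32 + 1
print(n_w, n_b, n_w + n_b)  # 1184 65 1249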
We are now ready to instantiate the model and extract its weights and biases in the format expected by the heyoka.ffnn factory:
model = torch_net(n_inputs=4,
nn_hidden=[32, 32],
n_out=1,
activations=[nn.Tanh(), nn.Tanh(), nn.Sigmoid()])
# Here one would likely perform some training ... at the end, we extract the model parameters:
flattened_weights = weights_and_biases_heyoka(model)
(1184,)
Let us instantiate the feed-forward neural network in heyoka.py using those parameters:
inp_1, inp_2, inp_3, inp_4 = hk.make_vars("inp_1", "inp_2", "inp_3", "inp_4")

model_heyoka = hk.model.ffnn(inputs=[inp_1, inp_2, inp_3, inp_4],
                             nn_hidden=[32, 32],
                             n_out=1,
                             activations=[hk.tanh, hk.tanh, hk.sigmoid],
                             nn_wb=flattened_weights)
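Note that model_heyoka is simply a list of symbolic heyoka expressions, one per output neuron. A minimal inspection, just to see what the factory returned:

print(type(model_heyoka), len(model_heyoka))  # a list holding the single output expression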
Good! Just to be sure, let us now verify that the two models produce the same output at inference. We first compile the network so that we can evaluate it:
model_heyoka_compiled = hk.cfunc(model_heyoka, [inp_1, inp_2, inp_3, inp_4])
… and create some random inputs. Note that heyoka’s compiled functions expect the inputs with shape (n_inputs, n_evals), while the torch model expects (n_samples, n_inputs), hence the transpose:

random_input = torch.rand((4, 1000000))
random_input_torch = random_input.t()      # torch wants (n_samples, n_inputs)
random_input_numpy = random_input.numpy()  # heyoka wants (n_inputs, n_evals)
out_array = np.zeros((1, 1000000))
Now, let’s compare the output of heyoka.ffnn and torch to see whether they are identical:
out_hey = model_heyoka_compiled(random_input_numpy, outputs=out_array)
# Note: we store the torch output under a name that does not shadow the torch module.
out_torch = model(random_input_torch).detach().numpy().reshape(1, -1)
print("Maximum difference in the inference is: ", (out_hey - out_torch).max())
Maximum difference in the inference is: 2.220446049250313e-16
In this way we have managed to port the torch model to heyoka.py, reproducing the same results to machine precision.
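As a closing remark, the whole point of having the network as heyoka expressions is that they can now enter the right-hand side of an ODE. A minimal, purely illustrative sketch (the dynamics here are made up, and we assume the state variables play the role of the network inputs):

x, y, z, w = hk.make_vars("x", "y", "z", "w")
# Rebuild the network on the state variables and take its (single) output.
net_out = hk.model.ffnn(inputs=[x, y, z, w],
                        nn_hidden=[32, 32],
                        n_out=1,
                        activations=[hk.tanh, hk.tanh, hk.sigmoid],
                        nn_wb=flattened_weights)[0]
# A made-up dynamical system where the network acts as a forcing term.
ta = hk.taylor_adaptive([(x, y), (y, net_out), (z, w), (w, -z)],
                        [0.1, 0.2, 0.3, 0.4])
ta.propagate_until(1.0)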