import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from hessQuik.layers import hessQuikLayer
from typing import Union, Tuple
class quadraticICNNLayer(hessQuikLayer):
    r"""
    Evaluate and compute derivatives of an ICNN quadratic layer.

    Examples::

        >>> import torch
        >>> import hessQuik.layers as lay
        >>> f = lay.quadraticICNNLayer(4, None, 2)
        >>> x = torch.randn(10, 4)
        >>> fx, dfdx, d2fd2x = f(x, do_gradient=True, do_Hessian=True)
        >>> print(fx.shape, dfdx.shape, d2fd2x.shape)
        torch.Size([10, 1]) torch.Size([10, 4, 1]) torch.Size([10, 4, 4, 1])

    """

    def __init__(self, input_dim: int, in_features: Union[int, None], rank: int, device=None, dtype=None) -> None:
        r"""
        :param input_dim: dimension of network inputs, :math:`d`
        :type input_dim: int
        :param in_features: number of input features, :math:`n_{in}`. For only ICNN quadratic layer, set ``in_features = None``
        :type in_features: int or ``None``
        :param rank: number of rows of quadratic factor, :math:`r`. In practice, :math:`r < n_{in}`
        :type rank: int
        :param device: device on which the parameters are allocated (forwarded to ``torch.empty``)
        :param dtype: data type of the parameters (forwarded to ``torch.empty``)

        :var v: weight vector for network inputs of size :math:`(d,)`
        :var w: weight vector for input features of size :math:`(n_{in},)`; ``None`` if ``in_features`` is ``None``
        :var A: factor of the quadratic term of size :math:`(r, d)`; the quadratic form uses :math:`A^\top A`
        :var mu: additive scalar bias
        :var nonneg: pointwise function to force :math:`w` to have nonnegative weights. Default ``torch.nn.functional.softplus``
        """
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(quadraticICNNLayer, self).__init__()

        self.input_dim = input_dim
        self.in_features = in_features
        self.rank = rank
        self.ctx = None  # filled by forward() when derivatives are deferred to backward()
        self.nonneg = F.softplus

        # create final layer
        if in_features is not None:
            self.w = nn.Parameter(torch.empty(in_features, **factory_kwargs))
        else:
            self.register_parameter('w', None)

        self.v = nn.Parameter(torch.empty(input_dim, **factory_kwargs))
        self.mu = nn.Parameter(torch.empty(1, **factory_kwargs))
        self.A = nn.Parameter(torch.empty(rank, input_dim, **factory_kwargs))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        r"""
        Initialize all parameters from uniform distributions.

        ``w`` (if present) and ``v`` share the bound :math:`1 / \sqrt{n_{in}}`; when there are no input
        features, ``v`` uses :math:`1 / \sqrt{d}`. ``A`` always uses :math:`1 / \sqrt{d}` and ``mu`` uses
        the default range of ``torch.nn.init.uniform_``.
        """
        if self.in_features is not None:
            bound = 1 / math.sqrt(self.in_features)
            nn.init.uniform_(self.w, a=-bound, b=bound)
        else:
            bound = 1 / math.sqrt(self.input_dim)

        # v must be initialized in both cases (it is allocated with torch.empty)
        nn.init.uniform_(self.v, a=-bound, b=bound)
        nn.init.uniform_(self.mu)

        bound = 1 / math.sqrt(self.input_dim)
        nn.init.uniform_(self.A, a=-bound, b=bound)

    def dim_output(self) -> int:
        r"""
        scalar
        """
        return 1

    def _stacked_weights(self) -> torch.Tensor:
        r"""
        Return the linear weights :math:`[w^+; v]` as a single vector of size :math:`(n_{in} + d,)`.
        """
        if self.w is None:
            w = torch.empty(0, dtype=self.v.dtype, device=self.v.device)
        else:
            w = self.nonneg(self.w)
        return torch.cat((w, self.v), dim=0)

    def _local_gradient(self, ux: torch.Tensor, wv: torch.Tensor, xAtA: torch.Tensor) -> torch.Tensor:
        r"""
        Return the gradient of the layer with respect to its own input :math:`[u, x]`, size :math:`(n_s, n_{in} + d)`.
        """
        n_in = 0 if self.in_features is None else self.in_features
        # zero padding for the feature columns; created from ux so it lives on the same device / dtype
        pad = ux.new_zeros(ux.shape[0], n_in)
        return wv.unsqueeze(0) + torch.cat((pad, xAtA), dim=1)

    def forward(self, ux, do_gradient=False, do_Hessian=False, do_Laplacian=False, forward_mode=True,
                dudx=None, d2ud2x=None, v=None):
        r"""
        Forward propagation through ICNN layer of the form, for one sample :math:`n_s = 1`,

        .. math::

            f(x) =
            \left[\begin{array}{cc}u(x) & x\end{array}\right]
            \left[\begin{array}{c}w^+ \\ v\end{array}\right] + \frac{1}{2} x A^\top A x^\top + \mu

        Here, :math:`u(x)` is the input into the layer of size :math:`(n_s, n_{in})` which is
        a function of the input of the network, :math:`x` of size :math:`(n_s, d)`.
        The layer receives both concatenated as ``ux`` of size :math:`(n_s, n_{in} + d)`; the last
        :math:`d` columns are taken as :math:`x`.
        The output features, :math:`f(x)`, are of size :math:`(n_s, 1)`.
        The notation :math:`(\cdot)^+` is a function that makes the weights of a matrix nonnegative.

        As an example, for one sample, :math:`n_s = 1`, the gradient with respect to :math:`x` is of the form

        .. math::

            \nabla_x f = \left[\begin{array}{cc}(w^+)^\top & v^\top\end{array}\right]
            \left[\begin{array}{c} \nabla_x u \\ I\end{array}\right] + x A^\top A

        where :math:`I` is the :math:`d \times d` identity matrix.

        :param ux: layer input :math:`[u(x), x]`
        :param do_gradient: if ``True``, compute the gradient
        :param do_Hessian: if ``True``, compute the Hessian (the gradient is computed as well)
        :param do_Laplacian: not used by this layer
        :param forward_mode: ``True`` computes derivatives during the forward pass; ``False`` stores the input
            and immediately calls :meth:`backward`; any other value (e.g. ``None``) only stores the input
        :param dudx: derivative of ``ux`` with respect to the network input, or ``None``
        :param d2ud2x: second derivative of ``ux`` with respect to the network input, or ``None``
        :param v: not used by this layer
        :return: tuple ``(f, df, d2f)``; ``df`` and ``d2f`` are ``None`` when not requested
        """
        (df, d2f) = (None, None)

        AtA = self.A.t() @ self.A
        wv = self._stacked_weights()
        x = ux[:, -self.input_dim:]
        xAtA = x @ AtA

        # forward propagate
        f = ux @ wv + 0.5 * torch.sum(xAtA * x, dim=1) + self.mu

        # ------------------------------------------------------------------------------------------------------------ #
        if (do_gradient or do_Hessian) and forward_mode is True:
            # gradient with respect to the layer input [u, x]; chain rule is applied below
            df = self._local_gradient(ux, wv, xAtA)

            # -------------------------------------------------------------------------------------------------------- #
            if do_Hessian:
                if d2ud2x is not None:
                    dxdx = dudx[:, :, -self.input_dim:]
                    d2f = dxdx @ AtA @ dxdx.permute(0, 2, 1)
                    # squeeze only the trailing singleton so batch / input axes of size 1 survive
                    d2f = d2f + (d2ud2x @ df.unsqueeze(1).unsqueeze(-1)).squeeze(-1)
                    d2f = d2f.unsqueeze(-1)
                else:
                    # NOTE(review): dudx is ignored here when d2ud2x is None, as in the original code
                    # constant Hessian: replicate A^T A for every sample
                    e = torch.ones(x.shape[0], device=x.device, dtype=x.dtype).view(-1, 1, 1, 1)
                    d2f = e * AtA.unsqueeze(0).unsqueeze(-1)

            # -------------------------------------------------------------------------------------------------------- #
            # finish computing gradient
            if dudx is not None:
                df = (dudx @ df.unsqueeze(-1)).squeeze(-1)
            df = df.unsqueeze(-1)

        if (do_gradient or do_Hessian) and forward_mode is not True:
            # keep the input so that backward() can recompute the derivatives
            self.ctx = (ux,)

            if forward_mode is False:
                df, d2f = self.backward(do_Hessian=do_Hessian)

        return f.unsqueeze(-1), df, d2f

    def backward(self, do_Hessian=False, dgdf=None, d2gd2f=None, v=None):
        r"""
        Backward propagation through quadratic ICNN layer of the form, for one sample :math:`n_s = 1`,

        .. math::

            f\left(\begin{bmatrix} u & x \end{bmatrix}\right) =\left[\begin{array}{cc}u & x\end{array}\right]
            \left[\begin{array}{c}w^+ \\ v\end{array}\right] + \frac{1}{2} x A^\top A x^\top + \mu

        Here, the network :math:`g` is a function of :math:`f(u)`.

        The gradient of the layer with respect to :math:`\begin{bmatrix} u & x \end{bmatrix}` is of the form

        .. math::

            \nabla_{[u,x]} f = \begin{bmatrix}(w^+)^\top & v^\top + x A^\top A\end{bmatrix}.

        Requires a preceding call to :meth:`forward` with ``forward_mode`` not ``True`` so that the input is stored.

        :param do_Hessian: if ``True``, also compute the Hessian with respect to :math:`[u, x]`
        :param dgdf: not used by this layer
        :param d2gd2f: not used by this layer
        :param v: not used by this layer
        :return: tuple ``(df, d2f)`` of sizes :math:`(n_s, n_{in} + d, 1)` and
            :math:`(n_s, n_{in} + d, n_{in} + d, 1)`; ``d2f`` is ``None`` if ``do_Hessian`` is ``False``
        """
        d2f = None

        ux = self.ctx[0]
        x = ux[:, -self.input_dim:]
        AtA = self.A.t() @ self.A  # TODO: recompute this or store it?

        df = self._local_gradient(ux, self._stacked_weights(), x @ AtA)

        if do_Hessian:
            # only the x-x block is nonzero; allocate on the same device / dtype as the stored input
            d2f = ux.new_zeros(x.shape[0], ux.shape[1], ux.shape[1])
            d2f[:, -self.input_dim:, -self.input_dim:] = AtA.unsqueeze(0)
            d2f = d2f.unsqueeze(-1)

        return df.unsqueeze(-1), d2f
if __name__ == '__main__':
    # Self-check: verify the layer's gradient and Hessian in both derivative modes.
    from hessQuik.utils import input_derivative_check

    # use double precision for the derivative check
    torch.set_default_dtype(torch.float64)

    num_examples = 11
    input_dim = 4
    in_feat = 5  # not used below: the layer is built with in_features=None
    rank = 13

    x = torch.randn(num_examples, input_dim)
    f = quadraticICNNLayer(input_dim, None, rank)

    # run the same check with forward_mode=True, then forward_mode=False
    for banner, mode in (('======= FORWARD =======', True), ('======= BACKWARD =======', False)):
        print(banner)
        input_derivative_check(f, x, do_Hessian=True, verbose=True, forward_mode=mode)