Source code for hessQuik.layers.resnet_layer

import torch
from hessQuik.layers import hessQuikLayer
import hessQuik.activations as act
from hessQuik.layers import singleLayer
from typing import Union, Tuple


[docs]class resnetLayer(hessQuikLayer): r""" Evaluate and compute derivatives of a residual layer. Examples:: >>> import hessQuik.layers as lay >>> f = lay.resnetLayer(4, h=0.25) >>> x = torch.randn(10, 4) >>> fx, dfdx, d2fd2x = f(x, do_gradient=True, do_Hessian=True) >>> print(fx.shape, dfdx.shape, d2fd2x.shape) torch.Size([10, 4]) torch.Size([10, 4, 4]) torch.Size([10, 4, 4, 4]) """
[docs] def __init__(self, width: int, h: float = 1.0, act: act.hessQuikActivationFunction = act.identityActivation(), bias: bool = True, device=None, dtype=None) -> None: r""" :param width: number of input and output features, :math:`w` :type width: int :param h: step size, :math:`h > 0` :type h: float :param act: activation function :type act: hessQuikActivationFunction :param bias: additive bias flag :type bias: bool :var layer: singleLayer with :math:`w` input features and :math:`w` output features """ factory_kwargs = {'device': device, 'dtype': dtype} super(resnetLayer, self).__init__() self.width = width self.h = h self.layer = singleLayer(width, width, act=act, bias=bias, **factory_kwargs)
[docs] def dim_input(self) -> int: r""" width """ return self.width
[docs] def dim_output(self) -> int: r""" width """ return self.width
[docs] def forward(self, u, do_gradient=False, do_Hessian=False, do_Laplacian=False, forward_mode=True, dudx=None, d2ud2x=None, v=None): r""" Forward propagation through resnet layer of the form .. math:: f(x) = u(x) + h \cdot singleLayer(u(x)) Here, :math:`u(x)` is the input into the layer of size :math:`(n_s, w)` which is a function of the input of the network, :math:`x`. The output features, :math:`f(x)`, are of size :math:`(n_s, w)`. As an example, for one sample, :math:`n_s = 1`, the gradient with respect to :math:`x` is of the form .. math:: \nabla_x f = I + h \nabla_x singleLayer(u(x)) where :math:`I` denotes the :math:`w \times w` identity matrix. """ if do_Laplacian and not do_Hessian: forward_mode = True (dfdx, d2fd2x) = (None, None) fi, dfi, d2fi = self.layer(u, do_gradient=do_gradient, do_Hessian=do_Hessian, do_Laplacian=do_Laplacian, dudx=dudx, d2ud2x=d2ud2x, forward_mode=True if forward_mode is True else None, v=v) # skip connection f = u + self.h * fi if do_gradient and forward_mode is True: if dudx is None: if v is None: v = torch.eye(self.width, dtype=dfi.dtype, device=dfi.device) dfdx = v + self.h * dfi else: dfdx = dudx + self.h * dfi if (do_Hessian or do_Laplacian) and forward_mode is True: d2fd2x = self.h * d2fi if d2ud2x is not None: d2fd2x += d2ud2x if (do_gradient or do_Hessian) and forward_mode is False: dfdx, d2fd2x = self.backward(do_Hessian=do_Hessian, v=v) return f, dfdx, d2fd2x
[docs] def backward(self, do_Hessian=False, dgdf=None, d2gd2f=None, v=None): r""" Backward propagation through single layer of the form .. math:: f(u) = u + h \cdot singleLayer(u) Here, the network is :math:`g` is a function of :math:`f(u)`. As an example, for one sample, :math:`n_s = 1`, the gradient of the network with respect to :math:`u` is of the form .. math:: \nabla_u g = \nabla_f g + h \cdot \nabla_u singleLayer(u) where :math:`\odot` denotes the pointwise product. """ d2gd2u = None if not do_Hessian: dgdu = self.layer.backward(do_Hessian=False, dgdf=dgdf, d2gd2f=None, v=v)[0] if dgdf is None: if v is None: v = torch.eye(self.width, dtype=dgdu.dtype, device=dgdu.device) dgdu = v + self.h * dgdu else: dgdu = dgdf + self.h * dgdu else: dfdx, d2fd2x = self.layer.backward(do_Hessian=do_Hessian, dgdf=None, d2gd2f=None, v=v)[:2] if v is None: v = torch.eye(self.width, dtype=dfdx.dtype, device=dfdx.device) dgdu = v + self.h * dfdx if dgdf is not None: dgdu = dgdu @ dgdf # d2gd2u = self.h * d2fd2x if d2gd2f is None: d2gd2u = self.h * d2fd2x else: # TODO: compare timings for h_dfdx on CPU and GPU h_dfdx = torch.eye(self.width, dtype=dfdx.dtype, device=dfdx.device) + self.h * dfdx # Gauss-Newton approximation h1 = (h_dfdx.unsqueeze(1) @ d2gd2f.permute(0, 3, 1, 2) @ h_dfdx.permute(0, 2, 1).unsqueeze(1)) h1 = h1.permute(0, 2, 3, 1) # extra term to compute full Hessian h2 = d2fd2x @ dgdf.unsqueeze(1) # combine d2gd2u = h1 + self.h * h2 return dgdu, d2gd2u
def extra_repr(self) -> str: r""" :meta private: """ return 'width={}, h={}'.format(self.width, self.h)
if __name__ == '__main__': from hessQuik.utils import input_derivative_check, directional_derivative_check, \ directional_derivative_laplacian_check, input_derivative_check_finite_difference_laplacian torch.set_default_dtype(torch.float64) nex = 11 # no. of examples width = 4 # no. of input features h = 0.25 x = torch.randn(nex, width) f = resnetLayer(width, h=h, act=act.softplusActivation()) print('\n======= FORWARD =======') input_derivative_check(f, x, do_Hessian=True, verbose=True, forward_mode=True) print('\n======= BACKWARD =======') input_derivative_check(f, x, do_Hessian=True, verbose=True, forward_mode=False) print('\n======= LAPLACIAN =======') input_derivative_check_finite_difference_laplacian(f, x, do_Laplacian=True, verbose=True) print('\n======= DIRECTIONAL =======') directional_derivative_check(f, x, verbose=True) directional_derivative_laplacian_check(f, x, verbose=True)