Source code for hessQuik.utils.input_derivative_check

import torch
from math import log2, floor, ceil
import hessQuik
from hessQuik.utils import convert_to_base
from typing import Callable, Tuple, Optional, Union
import math

def _order_check(errors: torch.Tensor, order: int, base: float, tol: float, eps: float) -> bool:
    """
    Decide whether a sequence of Taylor-approximation errors decays at the expected order.

    :param errors: 1D tensor of errors, one per step size :math:`h = base^{-k}`, :math:`k = 0, 1, \\dots`
    :param order: expected order of decay (2 for the gradient test, 3 for the Hessian test)
    :param base: base of the step sizes
    :param tol: slack subtracted from ``order`` before comparing with the observed order
    :param eps: machine epsilon of the working precision
    :return: ``True`` if the check passes
    """
    num_test = errors.numel()

    # observed order between consecutive step sizes: log_base(E_k / E_{k+1});
    # 0/0 gives NaN, which compares False and is therefore simply not counted
    rates = torch.log2(errors[:-1] / errors[1:]) / log2(base)
    if int(torch.sum(rates > (order - tol))) > num_test // 3:
        return True

    # if enough of the errors are already at round-off level the order cannot be observed; accept.
    # kthvalue is 1-indexed, so clamp k for num_test < 4
    k = max(1, num_test // 4)
    return bool(torch.kthvalue(errors, k)[0] < (100 * eps))


def input_derivative_check(f: Union[torch.nn.Module, Callable], x: torch.Tensor, do_Hessian: bool = False,
                           forward_mode: bool = True, num_test: int = 15, base: float = 2.0, tol: float = 0.1,
                           verbose: bool = False) -> Tuple[Optional[bool], Optional[bool]]:
    r"""
    Taylor approximation test to verify derivatives.  Form the approximation by perturbing the input :math:`x`
    in the direction :math:`p` with step size :math:`h > 0` via

    .. math::

        f(x + h p) \approx f(x) + h\nabla f(x)^\top p + \frac{h^2}{2}p^\top \nabla^2f(x) p

    As :math:`h \downarrow 0^+`, the error between the approximation and the true value will decrease.
    The rate of decrease indicates the accuracy of the derivative computation.
    For details, see Chapter 5 of `Computational Methods for Electromagnetics`_ by Eldad Haber.

    .. _Computational Methods for Electromagnetics: https://epubs.siam.org/doi/book/10.1137/1.9781611973808

    Examples::

        >>> import hessQuik.activations as act
        >>> from hessQuik.layers import singleLayer
        >>> torch.set_default_dtype(torch.float64)  # use double precision to check implementations
        >>> x = torch.randn(10, 4)
        >>> f = singleLayer(4, 7, act=act.softplusActivation())
        >>> input_derivative_check(f, x, do_Hessian=True, verbose=True, forward_mode=True)
            h                   E0                  E1                  E2
            1.00 x 2^(00)		1.62 x 2^(-02)		1.70 x 2^(-07)		1.02 x 2^(-12)
            1.00 x 2^(-01)		1.63 x 2^(-03)		1.70 x 2^(-09)		1.06 x 2^(-15)
            1.00 x 2^(-02)		1.63 x 2^(-04)		1.69 x 2^(-11)		1.08 x 2^(-18)
            1.00 x 2^(-03)		1.63 x 2^(-05)		1.69 x 2^(-13)		1.09 x 2^(-21)
            1.00 x 2^(-04)		1.63 x 2^(-06)		1.69 x 2^(-15)		1.09 x 2^(-24)
            1.00 x 2^(-05)		1.63 x 2^(-07)		1.69 x 2^(-17)		1.10 x 2^(-27)
            1.00 x 2^(-06)		1.63 x 2^(-08)		1.69 x 2^(-19)		1.10 x 2^(-30)
            1.00 x 2^(-07)		1.63 x 2^(-09)		1.69 x 2^(-21)		1.10 x 2^(-33)
            1.00 x 2^(-08)		1.63 x 2^(-10)		1.69 x 2^(-23)		1.10 x 2^(-36)
            1.00 x 2^(-09)		1.63 x 2^(-11)		1.69 x 2^(-25)		1.10 x 2^(-39)
            1.00 x 2^(-10)		1.63 x 2^(-12)		1.69 x 2^(-27)		1.10 x 2^(-42)
            1.00 x 2^(-11)		1.63 x 2^(-13)		1.69 x 2^(-29)		1.10 x 2^(-45)
            1.00 x 2^(-12)		1.63 x 2^(-14)		1.69 x 2^(-31)		1.15 x 2^(-48)
            1.00 x 2^(-13)		1.63 x 2^(-15)		1.69 x 2^(-33)		1.33 x 2^(-50)
            1.00 x 2^(-14)		1.63 x 2^(-16)		1.69 x 2^(-35)		1.70 x 2^(-51)
            Gradient PASSED!
            Hessian PASSED!

    :param f: callable function that returns value, gradient, and Hessian
    :type f: torch.nn.Module or Callable
    :param x: input data
    :type x: torch.Tensor
    :param do_Hessian: If set to ``True``, the Hessian will be computed during the forward call. Default: ``False``
    :type do_Hessian: bool, optional
    :param forward_mode:  If set to ``False``, the derivatives will be computed in backward mode. Default: ``True``
    :type forward_mode: bool, optional
    :param num_test: number of perturbations; must be at least 1
    :type num_test: int
    :param base: step size :math:`h = base^{-k}` for :math:`k = 0, \dots, num\_test - 1`
    :type base: float
    :param tol: small tolerance to account for numerical errors when computing the order of approximation
    :type tol: float
    :param verbose: printout flag
    :type verbose: bool
    :return:
        - **grad_check** (*bool*) - if ``True``, gradient check passes
        - **hess_check** (*bool, optional*) - if ``True``, Hessian check passes;
          ``None`` if ``f`` returned no Hessian
    :raises ValueError: if ``num_test < 1``
    """
    if num_test < 1:
        raise ValueError('num_test must be at least 1, got %d' % num_test)

    # initial evaluation
    f0, df0, d2f0 = f(x, do_gradient=True, do_Hessian=do_Hessian, forward_mode=forward_mode)

    # ---------------------------------------------------------------------------------------------------------------- #
    # directional derivatives
    dx = torch.randn_like(x)
    # normalize the direction by its own norm (dividing by the norm of x breaks down when x is all zeros)
    dx = dx / torch.norm(dx)
    curvx = None
    if isinstance(f, hessQuik.activations.hessQuikActivationFunction):
        # activation functions are handled with elementwise products of the returned derivatives
        dfdx = df0 * dx
        if d2f0 is not None:
            curvx = torch.sum(dx.unsqueeze(0) * d2f0 * dx.unsqueeze(0), dim=0)
    else:
        # gradient-direction product and direction-Hessian-direction product per example
        dfdx = torch.matmul(df0.transpose(1, 2), dx.unsqueeze(2)).squeeze(2)
        if d2f0 is not None:
            curvx = torch.sum(dx.unsqueeze(2).unsqueeze(3) * d2f0 * dx.unsqueeze(1).unsqueeze(3), dim=(1, 2))

    # ---------------------------------------------------------------------------------------------------------------- #
    # derivative check
    E0, E1, E2 = [], [], []

    if verbose:
        headers = ('h', 'E0', 'E1')
        if do_Hessian:
            headers += ('E2',)
        print(('{:<20s}' * len(headers)).format(*headers))

    for k in range(num_test):
        h = base ** (-k)
        ft, *_ = f(x + h * dx, do_gradient=False, do_Hessian=False)

        # zeroth-, first-, and second-order Taylor errors
        E0.append(torch.norm(f0 - ft).item())
        E1.append(torch.norm(f0 + h * dfdx - ft).item())
        if curvx is not None:
            E2.append(torch.norm(f0 + h * dfdx + 0.5 * (h ** 2) * curvx - ft).item())

        if verbose:
            # the table is rendered in powers of 2; h is exactly 2^(-k) only when base == 2,
            # otherwise convert h like the errors so the printed value stays correct
            # NOTE(review): assumes convert_to_base maps each float to a (mantissa, exponent) pair in base 2
            printouts = (1, -k) if base == 2 else convert_to_base((h,))
            printouts += convert_to_base((E0[-1], E1[-1]))
            if curvx is not None:
                printouts += convert_to_base((E2[-1],))
            print(((len(printouts) // 2) * '%0.2f x 2^(%0.2d)\t\t') % printouts)

    E1, E2 = torch.tensor(E1), torch.tensor(E2)

    # ---------------------------------------------------------------------------------------------------------------- #
    # check if the observed order is high enough (2 for gradient, 3 for Hessian) enough of the time
    eps = torch.finfo(x.dtype).eps
    grad_check = _order_check(E1, 2, base, tol, eps)
    hess_check = _order_check(E2, 3, base, tol, eps) if curvx is not None else None

    if verbose:
        if grad_check:
            print('Gradient PASSED!')
        else:
            print('Gradient FAILED.')

        if curvx is not None:
            if hess_check:
                print('Hessian PASSED!')
            else:
                print('Hessian FAILED.')

    return grad_check, hess_check
def input_derivative_check_finite_difference(f: Callable, x: torch.Tensor, do_Hessian: bool = False,
                                             forward_mode: bool = True, eps: float = 1e-4,
                                             atol: float = 1e-5, rtol: float = 1e-3,
                                             verbose: bool = False) -> Tuple[Optional[bool], Optional[bool]]:
    r"""
    Finite difference test to verify derivatives.  Form the approximation by perturbing each entry in the input
    in the unit direction with step size :math:`\varepsilon > 0`:

    .. math::

        \widetilde{\nabla f}_{i} = \frac{f(x_i + \varepsilon) - f(x_i - \varepsilon)}{2\varepsilon}

    where :math:`x_i \pm \varepsilon` means add or subtract :math:`\varepsilon` from the i-th entry of the
    input :math:`x`, but leave the other entries unchanged.
    The notation :math:`\widetilde{(\cdot)}` indicates the finite difference approximation.

    The Hessian is approximated entrywise with the four-point formula

    .. math::

        \widetilde{\nabla^2 f}_{ij} = \frac{f(x + \varepsilon e_i + \varepsilon e_j)
        - f(x + \varepsilon e_i - \varepsilon e_j) - f(x - \varepsilon e_i + \varepsilon e_j)
        + f(x - \varepsilon e_i - \varepsilon e_j)}{4\varepsilon^2}

    where :math:`e_i` is the i-th standard basis vector.

    Examples::

        >>> import hessQuik.activations as act
        >>> from hessQuik.layers import singleLayer
        >>> torch.set_default_dtype(torch.float64)  # use double precision to check implementations
        >>> x = torch.randn(10, 4)
        >>> f = singleLayer(4, 7, act=act.tanhActivation())
        >>> input_derivative_check_finite_difference(f, x, do_Hessian=True, verbose=True, forward_mode=True)
            Gradient Finite Difference: Error = 8.1720e-10, Relative Error = 2.5602e-10
            Gradient PASSED!
            Hessian Finite Difference: Error = 4.5324e-08, Relative Error = 4.4598e-08
            Hessian PASSED!

    :param f: callable function that returns value, gradient, and Hessian
    :type f: Callable
    :param x: input data
    :type x: torch.Tensor
    :param do_Hessian: If set to ``True``, the Hessian will be computed during the forward call. Default: ``False``
    :type do_Hessian: bool, optional
    :param forward_mode:  If set to ``False``, the derivatives will be computed in backward mode. Default: ``True``
    :type forward_mode: bool, optional
    :param eps: step size. Default: 1e-4
    :type eps: float
    :param atol: absolute tolerance, e.g., :math:`\|\nabla f - \widetilde{\nabla f}\| < atol`. Default: 1e-5
    :type atol: float
    :param rtol: relative tolerance, e.g.,
        :math:`\|\nabla f - \widetilde{\nabla f}\|/\|\nabla f\| < rtol`. Default: 1e-3
    :type rtol: float
    :param verbose: printout flag
    :type verbose: bool
    :return:
        - **grad_check** (*bool*) - if ``True``, gradient check passes
        - **hess_check** (*bool, optional*) - if ``True``, Hessian check passes;
          ``None`` if the Hessian was not requested or ``f`` returned no Hessian
    """
    # compute initial gradient (and Hessian, if requested)
    f0, df0, d2f0 = f(x, do_gradient=True, do_Hessian=do_Hessian, forward_mode=forward_mode)
    d = x.shape[1]

    # row i is the perturbation eps * e_i; built once and reused by both tests
    steps = eps * torch.eye(d, device=f0.device, dtype=f0.dtype)

    # ---------------------------------------------------------------------------------------------------------------- #
    # test gradient with central differences
    df0_approx = torch.zeros_like(df0)
    for i in range(d):
        ei = steps[i].unsqueeze(0)
        f_pos, *_ = f(x + ei)
        f_neg, *_ = f(x - ei)
        df0_approx[:, i] = (f_pos - f_neg) / (2 * eps)

    err = torch.norm(df0 - df0_approx).item()
    # guard the denominator so that an identically zero gradient does not raise ZeroDivisionError
    rel_err = err / max(torch.norm(df0).item(), eps)
    grad_check = (err < atol and rel_err < rtol)

    if verbose:
        print('Gradient Finite Difference: Error = %0.4e, Relative Error = %0.4e' % (err, rel_err))
        if grad_check:
            print('Gradient PASSED!')
        else:
            print('Gradient FAILED.')

    # ---------------------------------------------------------------------------------------------------------------- #
    # test Hessian
    # https://v8doc.sas.com/sashtml/ormp/chap5/sect28.htm
    hess_check = None
    # skip (and report None) when f returned no Hessian
    if do_Hessian and d2f0 is not None:
        d2f0_approx = torch.zeros_like(d2f0)
        for i in range(d):
            ei = steps[i]
            for j in range(d):
                ej = steps[j]
                f1, *_ = f(x + (ei + ej).unsqueeze(0))
                f2, *_ = f(x + (ei - ej).unsqueeze(0))
                f3, *_ = f(x + (-ei + ej).unsqueeze(0))
                f4, *_ = f(x - (ei + ej).unsqueeze(0))
                d2f0_approx[:, i, j] = (f1 - f2 - f3 + f4) / (4 * (eps ** 2))

        err = torch.norm(d2f0 - d2f0_approx).item()
        rel_err = err / max(torch.norm(d2f0).item(), eps)
        hess_check = (err < atol and rel_err < rtol)

        if verbose:
            print('Hessian Finite Difference: Error = %0.4e, Relative Error = %0.4e' % (err, rel_err))
            if hess_check:
                print('Hessian PASSED!')
            else:
                print('Hessian FAILED.')

    return grad_check, hess_check
if __name__ == '__main__':
    # smoke test: run both derivative checks on a small two-layer network with identity activations
    from hessQuik.networks import NN
    from hessQuik.layers import singleLayer
    import hessQuik.activations as act

    # use double precision so that round-off does not mask the expected error decay
    torch.set_default_dtype(torch.float64)

    nex = 11  # no. of examples
    d = 4  # no. of input features
    x = torch.randn(nex, d)

    f = NN(singleLayer(d, 7, act=act.identityActivation()),
           singleLayer(7, 5, act=act.identityActivation()))

    input_derivative_check(f, x, do_Hessian=True, verbose=True)
    input_derivative_check_finite_difference(f, x, do_Hessian=True, verbose=True, forward_mode=False)