Source code for hessQuik.activations.hessQuik_activation_function

import torch.nn as nn
from torch import Tensor
from typing import Union, Tuple


class hessQuikActivationFunction(nn.Module):
    r"""
    Base class for all hessQuik activation functions.

    Subclasses are expected to override :meth:`forward` and
    :meth:`compute_derivatives`; both raise ``NotImplementedError`` here.
    :meth:`backward` is shared and evaluates :meth:`compute_derivatives` on
    whatever has been stored in ``self.ctx``.
    """

    def __init__(self) -> None:
        super().__init__()
        # Context variable: holds the positional arguments that ``backward``
        # forwards to ``compute_derivatives``. ``None`` means nothing is stored.
        self.ctx = None

    def forward(self, x: Tensor, do_gradient: bool = False, do_Hessian: bool = False,
                forward_mode: bool = True) \
            -> Tuple[Tensor, Union[Tensor, None], Union[Tensor, None]]:
        r"""
        Applies a pointwise activation function to the incoming data.

        :param x: input into the activation function.  :math:`(*)` where :math:`*` means any shape.
        :type x: torch.Tensor
        :param do_gradient: If set to ``True``, the gradient will be computed during the forward call. Default: ``False``
        :type do_gradient: bool, optional
        :param do_Hessian: If set to ``True``, the Hessian will be computed during the forward call. Default: ``False``
        :type do_Hessian: bool, optional
        :param forward_mode: If set to ``False``, the derivatives will be computed in backward mode. Default: ``True``
        :type forward_mode: bool, optional
        :return:
            - **sigma** (*torch.Tensor*) - value of activation function at input x, same size as x
            - **dsigma** (*torch.Tensor* or ``None``) - first derivative of activation function at input x, same size as x
            - **d2sigma** (*torch.Tensor* or ``None``) - second derivative of activation function at input x, same size as x
        :raises NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def backward(self, do_Hessian: bool = False) -> Tuple[Tensor, Union[Tensor, None]]:
        r"""
        Computes derivatives of the activation function in backward mode.

        Unpacks the variables stored in ``self.ctx`` and passes them, together
        with ``do_Hessian``, to :meth:`compute_derivatives`.  Inherited by all
        subclasses.

        :param do_Hessian: forwarded unchanged to :meth:`compute_derivatives`. Default: ``False``
        :type do_Hessian: bool, optional
        :return: the ``(dsigma, d2sigma)`` pair returned by :meth:`compute_derivatives`
        :raises RuntimeError: if ``self.ctx`` is ``None``, i.e. no context has been stored yet.
        """
        if self.ctx is None:
            # Unpacking ``None`` would otherwise fail with an opaque
            # "argument after * must be an iterable" TypeError.
            raise RuntimeError(
                "backward() called with an empty context (self.ctx is None); "
                "call forward() with forward_mode=False so the variables needed "
                "for the derivatives are stored first."
            )

        dsigma, d2sigma = self.compute_derivatives(*self.ctx, do_Hessian=do_Hessian)
        return dsigma, d2sigma

    def compute_derivatives(self, *args, do_Hessian: bool = False) -> Tuple[Tensor, Union[Tensor, None]]:
        r"""
        Computes the derivatives of the activation function; to be implemented by subclasses.

        :param args: variables needed to compute derivatives
        :type args: torch.Tensor
        :param do_Hessian: If set to ``True``, the second derivative is requested as well. Default: ``False``
        :type do_Hessian: bool, optional
        :return:
            - **dsigma** (*torch.Tensor* or ``None``) - first derivative of activation function at input x, same size as x
            - **d2sigma** (*torch.Tensor* or ``None``) - second derivative of activation function at input x, same size as x
        :raises NotImplementedError: always, in this base class.
        """
        raise NotImplementedError