Source code for hessQuik.utils.training

import torch
from typing import Union


def train_one_epoch(f: torch.nn.Module, x: torch.Tensor, y: torch.Tensor, optimizer: torch.optim.Optimizer,
                    batch_size: int = 5, do_gradient: bool = False, do_Hessian: bool = False,
                    loss_weights: Union[tuple, list] = (1.0, 1.0, 1.0), drop_last: bool = True):
    r"""
    Train a network for one epoch with a weighted mean-squared loss.

    The loss function is

    .. math::

        L(\theta) = \frac{1}{2N}\left(w_0\|f_{\theta}(x) - y_f\|^2
                    + w_1\|\nabla f_{\theta}(x) - y_{\nabla f}\|^2
                    + w_2\|\nabla^2 f_{\theta}(x) - y_{\nabla^2 f}\|^2\right)

    where :math:`f_{\theta}` is the network, :math:`\theta` are the network weights, and :math:`N` is the
    number of training samples. The terms for the function value, gradient, and Hessian can have different
    weights :math:`w_0`, :math:`w_1`, and :math:`w_2`, respectively. Each mini-batch term uses the factor
    ``0.5 / b``, where ``b`` is the number of samples in that batch, before one optimizer step is taken.

    The targets are read column-wise from ``y``:

    - column ``0`` holds the function values;
    - columns ``1`` through ``d`` hold the gradient, where ``d = x.shape[1]``;
    - the remaining columns hold the Hessian.

    Each slice is reshaped with ``view_as`` to match the corresponding network output. The gradient and
    Hessian terms are only included when ``do_gradient`` / ``do_Hessian`` are set.

    :param f: hessQuik neural network to train. It is called as ``f(xb, do_gradient=..., do_Hessian=...)``
        and must return a triple ``(function values, gradient, Hessian)``.
    :type f: torch.nn.Module
    :param x: training data of shape :math:`(N, d, ...)`; ``x.shape[1]`` is used to split ``y``
        when gradient or Hessian terms are active
    :type x: torch.Tensor
    :param y: target data of shape :math:`(N, *)` laid out column-wise as described above
    :type y: torch.Tensor
    :param optimizer: method for updating the network weights
    :type optimizer: torch.optim.Optimizer
    :param batch_size: size of mini-batches for stochastic training; must be at least 1. Default: 5
    :type batch_size: int
    :param do_gradient: If set to ``True``, the gradient will be computed during the forward call and
        included in the loss. Default: ``False``
    :type do_gradient: bool, optional
    :param do_Hessian: If set to ``True``, the Hessian will be computed during the forward call and
        included in the loss. Default: ``False``
    :type do_Hessian: bool, optional
    :param loss_weights: weight for each term in the loss function
    :type loss_weights: tuple or list
    :param drop_last: If ``True``, the trailing ``N % batch_size`` samples are skipped this epoch. Because
        the data is shuffled first, these are a random subset. If ``False``, they form one extra, smaller
        mini-batch. Default: ``True``
    :type drop_last: bool, optional
    :return: tuple ``(total, loss_f[, loss_df][, loss_d2f])``.

        - Every entry is averaged over the samples actually used for training this epoch.
        - ``total`` is weighted by ``loss_weights``; the per-term entries are unweighted.
        - ``loss_df`` / ``loss_d2f`` are present only when ``do_gradient`` / ``do_Hessian`` are set.
        - If no mini-batch was processed, all entries are ``0.0``.
    :raises ValueError: if ``batch_size`` is less than 1
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    f.train()
    n = x.shape[0]

    # Number of samples covered by mini-batches this epoch.
    n_used = n if not drop_last else (n // batch_size) * batch_size

    running_loss = 0.0
    running_loss_f = 0.0
    running_loss_df = 0.0 if do_gradient else None
    running_loss_d2f = 0.0 if do_Hessian else None

    # shuffle once per epoch
    idx = torch.randperm(n)
    for start in range(0, n_used, batch_size):
        idxb = idx[start:start + batch_size]
        xb, yb = x[idxb], y[idxb]
        # Actual batch size. It is smaller than batch_size only for the last batch when drop_last is False.
        b = xb.shape[0]

        optimizer.zero_grad()
        fb, dfb, d2fb = f(xb, do_gradient=do_gradient, do_Hessian=do_Hessian)

        loss_f = (0.5 / b) * torch.norm(fb - yb[:, 0].view_as(fb)) ** 2
        loss = loss_weights[0] * loss_f
        running_loss_f += b * loss_f.item()

        if do_gradient:
            # gradient targets occupy columns 1 .. d
            loss_df = (0.5 / b) * torch.norm(dfb - yb[:, 1:xb.shape[1] + 1].view_as(dfb)) ** 2
            loss = loss + loss_weights[1] * loss_df
            running_loss_df += b * loss_df.item()

        if do_Hessian:
            # Hessian targets occupy every column after the gradient block
            loss_d2f = (0.5 / b) * torch.norm(d2fb - yb[:, xb.shape[1] + 1:].view_as(d2fb)) ** 2
            loss = loss + loss_weights[2] * loss_d2f
            running_loss_d2f += b * loss_d2f.item()

        # Running sums are weighted by batch size so the final division gives a per-sample mean.
        running_loss += b * loss.item()

        # update network weights
        loss.backward()
        optimizer.step()

    # Average over the samples actually trained on, not over all n. If no batch ran, the sums are
    # still 0.0, so dividing by 1 reports zeros instead of raising ZeroDivisionError.
    denom = max(n_used, 1)
    output = (running_loss / denom, running_loss_f / denom)
    if do_gradient:
        output += (running_loss_df / denom,)
    if do_Hessian:
        output += (running_loss_d2f / denom,)
    return output
[docs]def test(f: torch.nn.Module, x: torch.Tensor, y: torch.Tensor, do_gradient: bool = False, do_Hessian: bool = False, loss_weights: Union[tuple, list] = (1.0, 1.0, 1.0)): r""" Evaluate mean-squared loss function without training See :py:func:`hessQuik.utils.training.train_one_epoch` for details. """ f.eval() (loss_f, loss_df, loss_d2f) = (torch.zeros(1), torch.zeros(1), torch.zeros(1)) with torch.no_grad(): n = x.shape[0] f0, df0, d2f0 = f(x, do_gradient=do_gradient, do_Hessian=do_Hessian) loss_f = (0.5 / n) * torch.norm(f0 - y[:, 0].view_as(f0)) ** 2 loss = loss_weights[0] * loss_f if do_gradient: loss_df = (0.5 / n) * torch.norm(df0 - y[:, 1:x.shape[1]+1].view_as(df0)) ** 2 loss += loss_weights[1] * loss_df if do_Hessian: loss_d2f = (0.5 / n) * torch.norm(d2f0 - y[:, x.shape[1]+1:].view_as(d2f0)) ** 2 loss += loss_weights[2] * loss_d2f output = (loss.item(), loss_f.item()) if do_gradient: output += (loss_df.item(),) if do_Hessian: output += (loss_d2f.item(),) return output