API Reference

multigrad

class multigrad.OnePointModel(aux_data: Any = None, comm: Any = None, loss_func_has_aux: bool = False, sumstats_func_has_aux: bool = False)

Allows differentiable one-point calculations to be performed on separate MPI ranks, and automatically sums the partial summary statistics over all ranks controlled by comm. This is an abstract base class only: the user must define the calc_partial_sumstats_from_params and calc_loss_from_sumstats methods in a subclass.

Parameters:
  • aux_data (Any (default=None)) – Any auxiliary data for easy access within sumstats or loss functions

  • comm (Comm (default=COMM_WORLD)) – MPI communicator

  • loss_func_has_aux (bool (default=False)) – If true, the calc_loss_from_sumstats(…) -> (loss, aux) signature will be assumed

  • sumstats_func_has_aux (bool (default=False)) – If true, the calc_partial_sumstats_from_params(x) -> (y, aux) and calc_loss_from_sumstats(y, aux) -> … signatures will be assumed
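The design rests on the summary statistics being a sum of per-rank contributions, so each rank only needs its own slice of the data and the loss is evaluated once on the reduced total. The single-process sketch below mimics that pattern in plain Python. The data are dealt out to hypothetical "ranks", each chunk yields partial sumstats, and the list-wise sum stands in for the MPI reduction. The toy model (a linear map whose mean is compared to a target), the helper names and the data are all illustrative assumptions, not part of multigrad.

```python
def partial_sumstats(params, chunk):
    """Per-rank contribution: sums over this rank's slice of the data only."""
    a, b = params
    ys = [a * x + b for x in chunk]
    return [sum(ys), float(len(ys))]  # [sum of y, number of points]


def loss_from_sumstats(sumstats, target_mean=3.0):
    """The loss needs the *total* sumstats, so it runs after the reduction."""
    total, count = sumstats
    return (total / count - target_mean) ** 2


def split(data, nranks):
    """Deal the data out to hypothetical ranks (round-robin)."""
    return [data[r::nranks] for r in range(nranks)]


data = [0.1 * i for i in range(20)]
params = (1.5, 0.5)

# Each "rank" computes its partial sumstats; summing them plays the role of the MPI reduction
partials = [partial_sumstats(params, chunk) for chunk in split(data, nranks=4)]
total = [sum(vals) for vals in zip(*partials)]

serial = partial_sumstats(params, data)  # same quantity computed on one process
loss = loss_from_sumstats(total)
```

Because summation does not care how the data were split, total matches serial. This is why only the partial sumstats, and not the raw data, need to be communicated between ranks.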

calc_dloss_dparams(params, randkey=None)

Calculate the gradient of the loss with respect to the model parameters, evaluated at the given params

Parameters:
  • params (array-like) – Model parameters

  • randkey (PRNGKey | int (default=None)) – If set to a value other than None, the “randkey” kwarg will be passed to user-defined methods

Returns:

Gradient of the loss with respect to each parameter

Return type:

array

calc_loss_and_grad_from_params(params, randkey=None)

Calculate the loss and its gradient.

This function returns the equivalent of (calc_loss_from_params(params), calc_dloss_dparams(params)), but it is significantly cheaper than calling them separately.

Parameters:
  • params (array-like) – Model parameters

  • randkey (PRNGKey | int (default=None)) – If set to a value other than None, the “randkey” kwarg will be passed to user-defined methods

Returns:

  • float – The loss evaluated at the parameters given

  • array – Gradient of the loss with respect to each parameter
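The loss depends on the parameters only through the summed sumstats S = Σ_r s_r. Its gradient therefore factorizes by the chain rule as dL/dθ = Σ_r (∂s_r/∂θ)ᵀ · ∂L/∂S. One plausible reason a combined call is cheaper follows from this: the forward pass (partial sumstats plus reduction) runs once and is reused for both outputs. The sketch below illustrates this with the same toy model as the OnePointModel sketch. Its hand-written Jacobians stand in for automatic differentiation. The helper names are hypothetical, and this is not a description of multigrad's internals.

```python
def partial_sumstats_and_jac(params, chunk):
    """This rank's partial sumstats s_r = [sum(y), n] and the Jacobian ds_r/dparams."""
    a, b = params
    s = [sum(a * x + b for x in chunk), float(len(chunk))]
    # d(sum y)/d(a, b) = (sum x, n); the count n does not depend on the parameters
    jac = [[sum(chunk), float(len(chunk))], [0.0, 0.0]]
    return s, jac


def loss_and_dloss_dsumstats(S, target_mean=3.0):
    """Loss from the total sumstats S = [sum(y), n], plus dL/dS."""
    total, count = S
    resid = total / count - target_mean
    dL_dtotal = 2.0 * resid / count
    dL_dcount = -2.0 * resid * total / count ** 2
    return resid ** 2, [dL_dtotal, dL_dcount]


def loss_and_grad(params, chunks):
    # Single forward pass: partial sumstats and Jacobians on every "rank"
    results = [partial_sumstats_and_jac(params, chunk) for chunk in chunks]
    S = [sum(vals) for vals in zip(*(s for s, _ in results))]  # the reduction
    loss, dL_dS = loss_and_dloss_dsumstats(S)
    # Chain rule, summed over ranks: grad_j = sum_r sum_i dL/dS_i * ds_{r,i}/dparam_j
    grad = [0.0] * len(params)
    for _, jac in results:
        for i, row in enumerate(jac):
            for j, d in enumerate(row):
                grad[j] += dL_dS[i] * d
    return loss, grad


data = [0.1 * i for i in range(20)]
chunks = [data[r::4] for r in range(4)]
loss, grad = loss_and_grad((1.5, 0.5), chunks)
```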

calc_loss_from_params(params, randkey=None)

Calculate the loss evaluated at a given set of parameters

Parameters:
  • params (array-like) – Model parameters

  • randkey (PRNGKey | int (default=None)) – If set to a value other than None, the “randkey” kwarg will be passed to user-defined methods

Returns:

The loss evaluated at the parameters given

Return type:

float

calc_loss_from_sumstats(sumstats, sumstats_aux=None, randkey=None)

User-defined method that maps summary statistics to the loss. Must be implemented by the subclass.

calc_partial_sumstats_from_params(params, randkey=None)

User-defined method that maps parameters to this rank's partial summary statistics. Must be implemented by the subclass.

calc_sumstats_from_params(params, total=True, randkey=None)

Compute summary statistics at given parameters

Parameters:
  • params (array-like) – Model parameters

  • total (bool (default=True)) – If true (default), sumstats will be summed over all MPI ranks

  • randkey (PRNGKey | int (default=None)) – If set to a value other than None, the “randkey” kwarg will be passed to user-defined methods

Returns:

Summary statistics evaluated at given parameters

Return type:

array

run_adam(guess, nsteps=100, param_bounds=None, learning_rate=0.01, randkey=None, const_randkey=False, comm=None)

Run the Adam optimizer to descend the gradient and optimize the model parameters, given an initial guess. Stochasticity is allowed if randkey is passed.

Parameters:
  • guess (array-like) – The starting parameters.

  • nsteps (int (default=100)) – The number of steps to take.

  • param_bounds (Sequence, optional) – Lower and upper bounds of each parameter, with “shape” (ndim, 2). Pass None as the bound for each unbounded parameter. By default None

  • learning_rate (float (default=0.01)) – The Adam learning rate.

  • randkey (PRNGKey | int (default=None)) – If given, a new PRNGKey will be generated at each iteration and passed to calc_loss_and_grad_from_params() as the “randkey” kwarg

  • const_randkey (bool (default=False)) – By default, randkey is regenerated at each gradient descent iteration. Remove this behavior by setting const_randkey=True

Returns:

The optimal parameters.

Return type:

array-like
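The entry above describes run_adam only at the interface level. For orientation, here is a minimal pure-Python sketch of the standard Adam update applied to a deterministic toy loss. The loss_and_grad callable stands in for calc_loss_and_grad_from_params, and beta1, beta2 and eps are the commonly used defaults. These are assumptions, and the sketch ignores param_bounds, randkey and MPI entirely.

```python
import math


def adam(loss_and_grad, guess, nsteps=100, learning_rate=0.01,
         beta1=0.9, beta2=0.999, eps=1e-8):
    """Minimize with the standard Adam update; returns the final parameters."""
    params = list(guess)
    m = [0.0] * len(params)  # first-moment (mean) estimate of the gradient
    v = [0.0] * len(params)  # second-moment (uncentered variance) estimate
    for t in range(1, nsteps + 1):
        _, grad = loss_and_grad(params)
        for i, g in enumerate(grad):
            m[i] = beta1 * m[i] + (1 - beta1) * g
            v[i] = beta2 * v[i] + (1 - beta2) * g * g
            m_hat = m[i] / (1 - beta1 ** t)  # bias correction for zero init
            v_hat = v[i] / (1 - beta2 ** t)
            params[i] -= learning_rate * m_hat / (math.sqrt(v_hat) + eps)
    return params


def quadratic(p):
    """Toy loss (p0 - 1)^2 + (p1 + 2)^2 with its analytic gradient."""
    loss = (p[0] - 1.0) ** 2 + (p[1] + 2.0) ** 2
    return loss, [2.0 * (p[0] - 1.0), 2.0 * (p[1] + 2.0)]


best = adam(quadratic, guess=[0.0, 0.0], nsteps=2000, learning_rate=0.05)
```

Adam's step size is roughly learning_rate regardless of the raw gradient scale. This is one reason it tolerates noisy, randkey-dependent gradients better than a fixed-step method.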

run_bfgs(guess, maxsteps=100, param_bounds=None, randkey=None, comm=None)

Run BFGS to descend the gradient and optimize the model parameters, given an initial guess. Stochasticity must be held fixed via a random key

Parameters:
  • guess (array-like) – The starting parameters.

  • maxsteps (int (default=100)) – The maximum number of steps to take.

  • param_bounds (Sequence, optional) – Lower and upper bounds of each parameter, with “shape” (ndim, 2). Pass None as the bound for each unbounded parameter. By default None

  • randkey (PRNGKey | int (default=None)) – Since BFGS requires a deterministic function, this key will be passed unchanged at every iteration to calc_loss_and_grad_from_params() as the “randkey” kwarg

Returns:

An OptimizeResult containing the following attributes:

  • message (str) – Describes the reason for termination

  • success (bool) – True if converged

  • fun (float) – Minimum loss found

  • x (array) – Parameters at the minimum loss found

  • jac (array) – Gradient of the loss at the minimum loss found

  • nfev (int) – Number of function evaluations

  • nit (int) – Number of gradient descent iterations

Return type:

OptimizeResult
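The result fields listed above appear to match scipy.optimize.OptimizeResult. For comparison, the sketch below runs SciPy's L-BFGS-B directly on a toy loss that returns (loss, grad), the same output shape as calc_loss_and_grad_from_params. Passing jac=True tells SciPy that the function returns both values. Whether run_bfgs uses L-BFGS-B or another variant internally is not stated in this reference. The bounds list follows the per-parameter (low, high) convention, with None for an unbounded side, which is how param_bounds is described.

```python
import numpy as np
from scipy.optimize import minimize


def loss_and_grad(x):
    """Toy deterministic loss |x - target|^2 and its gradient."""
    target = np.array([1.0, -2.0])
    resid = x - target
    return float(resid @ resid), 2.0 * resid


# One (low, high) pair per parameter; None marks an unbounded side
bounds = [(None, None), (-1.5, None)]
result = minimize(loss_and_grad, x0=np.zeros(2), jac=True,
                  method="L-BFGS-B", bounds=bounds, options={"maxiter": 100})

print(result.success, result.x, result.fun, result.nfev, result.nit)
```

Here the lower bound on the second parameter is active, so the optimizer stops at the bound rather than at the unconstrained minimum.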

run_lhs_param_scan(xmins, xmaxs, n_dim, num_evaluations, seed=None, randkey=None)

Compute sumstat and loss values over a Latin Hypercube sample

Parameters:
  • xmins (float | array-like) – Lower bound on each parameter

  • xmaxs (float | array-like) – Upper bound on each parameter

  • n_dim (int) – Number of parameters

  • num_evaluations (int) – Number of Latin Hypercube samples to draw and evaluate

  • seed (int (default=None)) – Seed to make the Latin Hypercube draws reproducible; randomized by default

  • randkey (PRNGKey | int (default=None)) – Random key passed to each sumstat and loss evaluation

Returns:

  • params (array-like) – Parameter values at each Latin Hypercube draw

  • sumstats (array-like) – Sumstats evaluated at each draw of parameters

  • losses (array-like) – Loss evaluated at each draw of parameters
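A Latin Hypercube draw divides each parameter range into num_evaluations equal strata. It places exactly one sample in each stratum per dimension, shuffling the stratum order independently across dimensions. A minimal stdlib sketch of that sampling step follows. The function name lhs_sample is hypothetical, and it infers the dimension from len(xmins) instead of taking n_dim. multigrad's own sampler, and how it uses seed, may differ.

```python
import random


def lhs_sample(xmins, xmaxs, num_evaluations, seed=None):
    """Return num_evaluations points, each a list with one value per dimension."""
    rng = random.Random(seed)
    columns = []
    for lo, hi in zip(xmins, xmaxs):
        strata = list(range(num_evaluations))
        rng.shuffle(strata)  # independent stratum order for each dimension
        width = (hi - lo) / num_evaluations
        # One uniform draw inside each stratum
        columns.append([lo + (k + rng.random()) * width for k in strata])
    return [list(point) for point in zip(*columns)]


points = lhs_sample([0.0, -1.0], [1.0, 1.0], num_evaluations=5, seed=42)
```

Projected onto any single axis, the points cover every stratum exactly once. This gives more even coverage than the same number of independent uniform draws.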

run_simple_grad_descent(guess, nsteps=100, learning_rate=0.01)

Descend the gradient with a fixed learning rate to optimize parameters, given an initial guess. Stochasticity is not allowed.

Parameters:
  • guess (array-like) – The starting parameters.

  • nsteps (int (default=100)) – The number of steps to take.

  • learning_rate (float (default=0.01)) – The fixed learning rate.

Returns:

A GradientDescentResult containing the following attributes:

  • loss (array) – Loss values returned at each iteration

  • params (array) – Trial parameters at each iteration

  • aux (array) – Aux values returned at each iteration

Return type:

GradientDescentResult
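Fixed-learning-rate gradient descent repeats params <- params - learning_rate * grad. The sketch below records the loss and trial parameters at each iteration, loosely mirroring the loss and params histories described above. The loss_and_grad callable, the toy objective and the plain-tuple return are assumptions. There is no aux, and GradientDescentResult is not reproduced.

```python
def simple_grad_descent(loss_and_grad, guess, nsteps=100, learning_rate=0.01):
    """Return per-iteration histories of loss values and trial parameters."""
    params = list(guess)
    losses, history = [], []
    for _ in range(nsteps):
        loss, grad = loss_and_grad(params)
        losses.append(loss)
        history.append(list(params))  # parameters at which this loss was evaluated
        params = [p - learning_rate * g for p, g in zip(params, grad)]
    return losses, history


def quadratic(p):
    """Toy loss (p0 - 1)^2 with gradient 2 * (p0 - 1)."""
    return (p[0] - 1.0) ** 2, [2.0 * (p[0] - 1.0)]


losses, history = simple_grad_descent(quadratic, guess=[0.0], nsteps=50, learning_rate=0.1)
```

On this quadratic, each step multiplies the distance to the minimum by (1 - 2 * learning_rate) = 0.8. Too large a fixed learning rate would instead make that factor exceed 1 in magnitude and diverge.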

class multigrad.OnePointGroup(models: tuple[OnePointModel, ...] | OnePointModel, main_comm: Any = None)

Allows different OnePointModels to run their calc_loss_and_grad_from_params methods simultaneously. The resulting losses and gradients are summed.

Parameters:
  • models (tuple[OnePointModel]) – Sequence of models, each providing a loss component to be summed.

  • main_comm (Comm (default=COMM_WORLD)) – MPI communicator for the entire group (each model should be assigned its own sub-communicator)
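Because the total loss is a sum of components, the group's loss and gradient are the sums of the members' losses and gradients at the same parameters. The sketch below shows only that summation, in a single process, with hypothetical component functions fit_a and fit_b. In OnePointGroup the components are evaluated concurrently on their own sub-communicators, which this sketch does not attempt to reproduce.

```python
def group_loss_and_grad(components, params):
    """Sum (loss, grad) pairs from each component evaluated at the same params."""
    total_loss = 0.0
    total_grad = [0.0] * len(params)
    for loss_and_grad in components:
        loss, grad = loss_and_grad(params)
        total_loss += loss
        total_grad = [t + g for t, g in zip(total_grad, grad)]
    return total_loss, total_grad


def fit_a(p):
    """Hypothetical component pulling p0 toward 1."""
    return (p[0] - 1.0) ** 2, [2.0 * (p[0] - 1.0), 0.0]


def fit_b(p):
    """Hypothetical component pulling p0 toward 3 and p1 toward 0."""
    return (p[0] - 3.0) ** 2 + p[1] ** 2, [2.0 * (p[0] - 3.0), 2.0 * p[1]]


loss, grad = group_loss_and_grad([fit_a, fit_b], [2.0, 1.0])
```

At p0 = 2 the two components pull in opposite directions, so their p0 gradients cancel in the sum.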

multigrad.split_subcomms(num_groups=None, ranks_per_group=None, comm=None)

Split comm into sub-comms (not grouped by nodes)

Parameters:
  • num_groups (int, optional) – Specify the number of evenly divided groups of subcomms

  • ranks_per_group (list[int], optional) – Specify the number of ranks given to each sub-comm

  • comm (MPI.Comm, optional) – Specify a sub-communicator to split into sub-sub-communicators

Returns:

  • subcomm (MPI.Comm) – The sub-comm that now controls this process

  • num_groups (int) – The number of groups of subcomms (same as input if not None)

  • group_rank (int) – The rank of this group (0 <= group_rank < num_groups)
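Splitting a communicator amounts to giving every rank a group index (its "color", in MPI terms) and a rank within that group. The pure-Python sketch below computes one plausible assignment, contiguous blocks of ranks, from either num_groups or ranks_per_group. The function name assign_groups and the contiguous-block layout are assumptions, and the library's actual assignment may differ. With mpi4py, the resulting group index would typically be passed as the color argument of comm.Split(color, key).

```python
def assign_groups(size, num_groups=None, ranks_per_group=None):
    """Map each global rank (list index) to (group_rank, rank_within_group)."""
    if ranks_per_group is None:
        if num_groups is None:
            raise ValueError("give num_groups or ranks_per_group")
        # Divide `size` ranks as evenly as possible into num_groups blocks
        base, extra = divmod(size, num_groups)
        ranks_per_group = [base + (1 if g < extra else 0) for g in range(num_groups)]
    if sum(ranks_per_group) != size:
        raise ValueError("ranks_per_group must sum to the communicator size")
    assignment = []
    for group_rank, n in enumerate(ranks_per_group):
        assignment.extend((group_rank, local) for local in range(n))
    return assignment


# 10 ranks split into 3 groups of sizes 4, 3 and 3
layout = assign_groups(10, num_groups=3)
```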

multigrad.split_subcomms_by_node(comm=None)

Split comm into sub-comms (grouped by nodes)

Parameters:

comm (MPI.Comm, optional) – Specify a sub-communicator to split into sub-sub-communicators

Returns:

  • subcomm (MPI.Comm) – The sub-comm that now controls this process

  • num_groups (int) – The number of groups of subcomms (= number of nodes)

  • group_rank (int) – The rank of this group (0 <= group_rank < num_groups)

multigrad.reduce_sum(value, root=None, comm=None)

Returns the sum of value across all MPI processes

Parameters:
  • value (np.ndarray | float | int) – value input by each MPI process to be summed

  • root (int, optional) – rank of the process to receive and sum the values, by default None (broadcast result to all ranks)

  • comm (MPI.Intracomm (default = MPI.COMM_WORLD)) – option to pass a sub-communicator in case the operation is not performed by all MPI ranks

Returns:

Sum of values given by each rank of the communicator

Return type:

np.ndarray | float