API Reference
multigrad
- class multigrad.OnePointModel(aux_data: Any = None, comm: Any = None, loss_func_has_aux: bool = False, sumstats_func_has_aux: bool = False)
Allows differentiable one-point calculations to be performed on separate MPI ranks and automatically sums the partial summary statistics over all ranks of comm. This is an abstract base class: subclasses must define the calc_partial_sumstats_from_params and calc_loss_from_sumstats methods.
- Parameters:
aux_data (Any (default=None)) – Auxiliary data made available to the sumstats and loss methods
comm (Comm (default=COMM_WORLD)) – MPI communicator
loss_func_has_aux (bool (default=False)) – If true, the signature calc_loss_from_sumstats(…) -> (loss, aux) will be assumed
sumstats_func_has_aux (bool (default=False)) – If true, the signatures calc_partial_sumstats_from_params(x) -> (y, aux) and calc_loss_from_sumstats(y, aux) -> … will be assumed
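The division of labour described above (partial summary statistics per rank, a sum over ranks, then a single loss computed from the totals) can be sketched in one process. This is a hypothetical illustration, not multigrad code: there is no MPI, each "rank" is just one chunk of a list, and `ToyModel`, `rank_data`, `target_mean`, and `total_sumstats` are illustrative names. Unlike the real methods, the toy partial-sumstats method takes a `rank` argument because all simulated ranks live in the same process.

```python
# Hypothetical single-process sketch of the OnePointModel pattern.
# No MPI here: each "rank" is simulated by one chunk of the data.


class ToyModel:
    """Fits a scale factor so that the scaled data mean matches a target."""

    def __init__(self, rank_data, target_mean):
        self.rank_data = rank_data      # one list of samples per simulated rank
        self.target_mean = target_mean  # value the scaled mean should match

    def calc_partial_sumstats_from_params(self, params, rank):
        # Each rank only sees its own samples and returns
        # [scaled sum, count] for its chunk.
        scale = params[0]
        chunk = self.rank_data[rank]
        return [scale * sum(chunk), float(len(chunk))]

    def calc_loss_from_sumstats(self, sumstats):
        # The loss only needs the summed (total) statistics.
        weighted_sum, count = sumstats
        mean = weighted_sum / count
        return (mean - self.target_mean) ** 2

    def total_sumstats(self, params):
        # Element-wise sum over simulated ranks; in the real class an
        # MPI reduction over comm plays this role.
        partials = [self.calc_partial_sumstats_from_params(params, r)
                    for r in range(len(self.rank_data))]
        return [sum(values) for values in zip(*partials)]

    def calc_loss_from_params(self, params):
        return self.calc_loss_from_sumstats(self.total_sumstats(params))


model = ToyModel(rank_data=[[1.0, 2.0], [3.0], [4.0, 5.0, 6.0]], target_mean=7.0)
print(model.total_sumstats([2.0]))         # [42.0, 6.0]
print(model.calc_loss_from_params([2.0]))  # 0.0
```

The point of the split is that only the small sumstats vector crosses rank boundaries; the raw data in each chunk never leaves its rank.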
- calc_dloss_dparams(params, randkey=None)
Calculate the gradient of the loss with respect to the given model parameters
- Parameters:
params (array-like) – Model parameters
randkey (PRNGKey | int (default=None)) – If set to a value other than None, the “randkey” kwarg will be passed to user-defined methods
- Returns:
Gradient of the loss with respect to each parameter
- Return type:
array
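A common way to sanity-check a gradient such as the one returned by calc_dloss_dparams is to compare it against central finite differences. The helper below is a hypothetical, library-independent sketch: `finite_difference_grad` is an illustrative name, and `loss` stands in for any scalar function of a parameter list (for example, a thin wrapper around calc_loss_from_params).

```python
def finite_difference_grad(loss, params, step=1e-6):
    """Central-difference estimate of d(loss)/d(params[i]) for each i."""
    grad = []
    for i in range(len(params)):
        up = list(params)
        down = list(params)
        up[i] += step
        down[i] -= step
        grad.append((loss(up) - loss(down)) / (2 * step))
    return grad


def loss(p):
    # Toy loss with a known analytic gradient: [2 * (p0 - 1), 6 * p1]
    return (p[0] - 1.0) ** 2 + 3.0 * p[1] ** 2


approx = finite_difference_grad(loss, [2.0, -1.0])
exact = [2 * (2.0 - 1.0), 6 * -1.0]  # [2.0, -6.0]
print(all(abs(a - e) < 1e-5 for a, e in zip(approx, exact)))  # True
```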
- calc_loss_and_grad_from_params(params, randkey=None)
Calculate the loss and its gradient.
This function returns the equivalent of (calc_loss_from_params(params), calc_dloss_dparams(params)), but it is significantly cheaper than calling the two methods separately
- Parameters:
params (array-like) – Model parameters
randkey (PRNGKey | int (default=None)) – If set to a value other than None, the “randkey” kwarg will be passed to user-defined methods
- Returns:
float – The loss evaluated at the parameters given
array – Gradient of the loss with respect to each parameter
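The saving comes from shared work: a combined call can compute intermediate quantities once and reuse them for both the loss and its gradient, which is also how reverse-mode tools such as `jax.value_and_grad` return both values from one evaluation. The pure-Python sketch below makes the point with a hypothetical least-squares loss, loss = sum((a*x - y)**2), whose residuals feed both outputs; `loss_and_grad` and the data are illustrative only.

```python
def loss_and_grad(a, xs, ys):
    """Return (loss, dloss/da) for loss = sum((a*x - y)**2)."""
    # Residuals are computed once and reused for both outputs.
    residuals = [a * x - y for x, y in zip(xs, ys)]
    loss = sum(r * r for r in residuals)
    grad = sum(2.0 * x * r for x, r in zip(xs, residuals))
    return loss, grad


xs = [1.0, 2.0, 3.0]
ys = [2.0, 4.0, 7.0]
print(loss_and_grad(2.0, xs, ys))  # (1.0, -6.0)
```

Calling a loss-only function and a gradient-only function separately would compute the residuals twice.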
- calc_loss_from_params(params, randkey=None)
Calculate the loss evaluated at a given set of parameters
- Parameters:
params (array-like) – Model parameters
randkey (PRNGKey | int (default=None)) – If set to a value other than None, the “randkey” kwarg will be passed to user-defined methods
- Returns:
The loss evaluated at the parameters given
- Return type:
float
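- calc_loss_from_sumstats(sumstats, sumstats_aux=None, randkey=None)
User-defined method that maps the total summary statistics to the loss. Must be implemented by subclasses
- calc_partial_sumstats_from_params(params, randkey=None)
User-defined method that maps parameters to this rank's partial summary statistics, which are summed over all MPI ranks. Must be implemented by subclasses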
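The aux flags change the shape of these two user-defined methods. When sumstats_func_has_aux is true, calc_partial_sumstats_from_params returns (sumstats, aux) and that aux reaches calc_loss_from_sumstats as sumstats_aux; when loss_func_has_aux is true, calc_loss_from_sumstats returns (loss, aux). The hypothetical `evaluate_loss` helper below only illustrates how such tuples would be unpacked; it is not multigrad's internal code, and it skips the sum over ranks.

```python
def evaluate_loss(sumstats_fn, loss_fn, params,
                  sumstats_func_has_aux=False, loss_func_has_aux=False):
    """Hypothetical helper showing how the aux flags change the call shapes."""
    out = sumstats_fn(params)
    if sumstats_func_has_aux:
        sumstats, sumstats_aux = out           # (sumstats, aux) tuple
        out = loss_fn(sumstats, sumstats_aux)  # aux forwarded to the loss
    else:
        out = loss_fn(out)
    if loss_func_has_aux:
        loss, loss_aux = out                   # (loss, aux) tuple
        return loss, loss_aux
    return out, None


def sumstats_with_aux(params):
    stats = [p * 2.0 for p in params]
    return stats, {"n_params": len(params)}


def loss_with_aux(sumstats, sumstats_aux):
    loss = sum(s * s for s in sumstats) / sumstats_aux["n_params"]
    return loss, {"max_stat": max(sumstats)}


print(evaluate_loss(sumstats_with_aux, loss_with_aux, [1.0, 2.0],
                    sumstats_func_has_aux=True, loss_func_has_aux=True))
# (10.0, {'max_stat': 4.0})
```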
- calc_sumstats_from_params(params, total=True, randkey=None)
Compute summary statistics at given parameters
- Parameters:
params (array-like) – Model parameters
total (bool (default=True)) – If true (default), sumstats will be summed over all MPI ranks
randkey (PRNGKey | int (default=None)) – If set to a value other than None, the “randkey” kwarg will be passed to user-defined methods
- Returns:
Summary statistics evaluated at given parameters
- Return type:
array
- run_adam(guess, nsteps=100, param_bounds=None, learning_rate=0.01, randkey=None, const_randkey=False, comm=None)
Run the Adam optimizer to minimize the loss, starting from an initial guess. Stochastic losses are supported if randkey is passed.
- Parameters:
guess (array-like) – The starting parameters.
nsteps (int (default=100)) – The number of steps to take.
param_bounds (Sequence, optional) – Lower and upper bounds of each parameter, with shape (ndim, 2). Use None as the bound for any unbounded parameter. By default (None), all parameters are unbounded
learning_rate (float (default=0.01)) – The Adam learning rate.
randkey (int | PRNGKey (default=None)) – If given, a new PRNGKey will be generated at each iteration and passed to calc_loss_and_grad_from_params() as the “randkey” kwarg
const_randkey (bool (default=False)) – By default, randkey is regenerated at each gradient descent iteration. Remove this behavior by setting const_randkey=True
- Returns:
The optimal parameters.
- Return type:
array-like
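For reference, the standard Adam update is sketched below in pure Python. It is not multigrad's implementation: it ignores param_bounds, randkey, and MPI, and the decay rates beta1=0.9, beta2=0.999 and eps=1e-8 are the commonly used defaults, assumed here. `grad_fn` is a hypothetical stand-in for the gradient half of calc_loss_and_grad_from_params.

```python
import math


def adam(grad_fn, guess, nsteps=100, learning_rate=0.01,
         beta1=0.9, beta2=0.999, eps=1e-8):
    """Minimal Adam optimizer; returns the final parameters."""
    params = list(guess)
    m = [0.0] * len(params)  # running mean of the gradient
    v = [0.0] * len(params)  # running mean of the squared gradient
    for t in range(1, nsteps + 1):
        grad = grad_fn(params)
        for i, g in enumerate(grad):
            m[i] = beta1 * m[i] + (1 - beta1) * g
            v[i] = beta2 * v[i] + (1 - beta2) * g * g
            # Bias corrections for the zero-initialized moment estimates.
            m_hat = m[i] / (1 - beta1 ** t)
            v_hat = v[i] / (1 - beta2 ** t)
            params[i] -= learning_rate * m_hat / (math.sqrt(v_hat) + eps)
    return params


def quadratic_grad(p):
    # Gradient of (x - 3)**2 + (y + 1)**2, minimized at [3, -1].
    return [2 * (p[0] - 3), 2 * (p[1] + 1)]


result = adam(quadratic_grad, [0.0, 0.0], nsteps=2000, learning_rate=0.05)
print(result)
```

Because each step is normalized by sqrt(v_hat), the early steps move each parameter by roughly learning_rate regardless of the gradient's magnitude, which is why the learning rate behaves more like a step length than a gradient multiplier.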
- run_bfgs(guess, maxsteps=100, param_bounds=None, randkey=None, comm=None)
Run BFGS to minimize the loss, starting from an initial guess. BFGS requires a deterministic loss, so any stochasticity must be held fixed with a constant random key.
- Parameters:
guess (array-like) – The starting parameters.
maxsteps (int (default=100)) – The maximum number of steps to take.
param_bounds (Sequence, optional) – Lower and upper bounds of each parameter, with shape (ndim, 2). Use None as the bound for any unbounded parameter. By default (None), all parameters are unbounded
randkey (int | PRNGKey (default=None)) – Since BFGS requires a deterministic function, this key is passed unchanged to calc_loss_and_grad_from_params() as the “randkey” kwarg at every iteration
- Returns:
message (str) – Description of the reason for termination
success (bool) – True if the optimizer converged
fun (float) – Minimum loss found
x (array) – Parameters at the minimum loss found
jac (array) – Gradient of the loss at the minimum loss found
nfev (int) – Number of function evaluations
nit (int) – Number of gradient descent iterations
- Return type:
OptimizeResult (contains the attributes listed above)
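Why the key must stay constant can be seen with any noisy loss: with a fresh random draw per call, two evaluations at the same parameters disagree, which undermines the consistent function and gradient values that BFGS's line search and curvature updates rely on. Holding the key fixed turns the loss into a deterministic function of the parameters. In the sketch below, Python's `random.Random(seed)` stands in for a JAX PRNGKey, and `noisy_loss` is a hypothetical loss, not part of multigrad.

```python
import random


def noisy_loss(params, seed):
    """Hypothetical stochastic loss: a quadratic plus Monte Carlo noise."""
    rng = random.Random(seed)  # stands in for a PRNGKey
    noise = sum(rng.gauss(0.0, 0.1) for _ in range(10)) / 10
    return sum(p * p for p in params) + noise


params = [1.0, 2.0]

# A different seed per call: the same params give different losses.
print(noisy_loss(params, seed=1) == noisy_loss(params, seed=2))  # False

# A constant seed: repeated calls are reproducible, as BFGS needs.
print(noisy_loss(params, seed=0) == noisy_loss(params, seed=0))  # True
```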
- run_lhs_param_scan(xmins, xmaxs, n_dim, num_evaluations, seed=None, randkey=None)
Compute sumstat and loss values over a Latin Hypercube sample
- Parameters:
xmins (float | array-like) – Lower bound on each parameter
xmaxs (float | array-like) – Upper bound on each parameter
n_dim (int) – Number of parameters
num_evaluations (int) – Number of Latin Hypercube samples to draw and evaluate
seed (int (default=None)) – Seed that makes the Latin Hypercube draws reproducible; if None, the draws are randomized
randkey (PRNGKey | int (default=None)) – Random key passed to each sumstat and loss evaluation
- Returns:
params (array-like) – Parameter values at each Latin Hypercube draw
sumstats (array-like) – Sumstats evaluated at each draw of parameters
losses (array-like) – Loss evaluated at each draw of parameters
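A Latin Hypercube sample of num_evaluations points splits each parameter's range into num_evaluations equal strata and places exactly one point in each stratum per dimension, pairing strata across dimensions at random. The pure-Python sketch below shows one standard way to draw such a design; it is an illustration, not necessarily how run_lhs_param_scan draws its samples, and `latin_hypercube` is a hypothetical name.

```python
import random


def latin_hypercube(xmins, xmaxs, num_evaluations, seed=None):
    """Return num_evaluations points, each a list with one value per dimension."""
    rng = random.Random(seed)
    columns = []
    for xmin, xmax in zip(xmins, xmaxs):
        # One uniform draw inside each of the num_evaluations strata of [0, 1)...
        unit = [(k + rng.random()) / num_evaluations
                for k in range(num_evaluations)]
        # ...shuffled so that strata are paired randomly across dimensions.
        rng.shuffle(unit)
        columns.append([xmin + u * (xmax - xmin) for u in unit])
    # Transpose the per-dimension columns into per-sample rows.
    return [list(row) for row in zip(*columns)]


samples = latin_hypercube([0.0, 10.0], [1.0, 20.0], num_evaluations=5, seed=42)
# Each dimension has exactly one sample in each fifth of its range.
print(sorted(int(s[0] * 5) for s in samples))                   # [0, 1, 2, 3, 4]
print(sorted(int((s[1] - 10.0) / 10.0 * 5) for s in samples))  # [0, 1, 2, 3, 4]
```

Compared with independent uniform draws, this stratification guarantees that every slice of each parameter's range is visited, which is what makes the design useful for coarse parameter scans.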
- run_simple_grad_descent(guess, nsteps=100, learning_rate=0.01)
Descend the gradient with a fixed learning rate to optimize the parameters, given an initial guess. Stochasticity is not supported.
- Parameters:
guess (array-like) – The starting parameters.
nsteps (int (default=100)) – The number of steps to take.
learning_rate (float (default=0.01)) – The fixed learning rate.
- Returns:
loss (array) – Loss value returned at each iteration
params (array) – Trial parameters at each iteration
aux (array) – Aux values returned at each iteration
- Return type:
GradientDescentResult (contains the attributes listed above)
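Fixed-learning-rate gradient descent repeats params <- params - learning_rate * grad and records the trajectory. The sketch below mirrors the loss and params histories listed above but is an illustration only: `simple_grad_descent` and `loss_and_grad` are hypothetical, it returns a plain dict rather than a GradientDescentResult, and it omits aux.

```python
def simple_grad_descent(loss_and_grad, guess, nsteps=100, learning_rate=0.01):
    """Fixed-step gradient descent recording loss and params at each iteration."""
    params = list(guess)
    history = {"loss": [], "params": []}
    for _ in range(nsteps):
        loss, grad = loss_and_grad(params)
        history["loss"].append(loss)
        history["params"].append(list(params))
        params = [p - learning_rate * g for p, g in zip(params, grad)]
    return history


def loss_and_grad(p):
    # loss = (p0 - 2)**2, gradient = 2 * (p0 - 2)
    return (p[0] - 2.0) ** 2, [2.0 * (p[0] - 2.0)]


result = simple_grad_descent(loss_and_grad, [0.0], nsteps=3, learning_rate=0.25)
print(result["params"])  # [[0.0], [1.0], [1.5]]
print(result["loss"])    # [4.0, 1.0, 0.25]
```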
- class multigrad.OnePointGroup(models: tuple[OnePointModel, ...] | OnePointModel, main_comm: Any = None)
Runs the calc_loss_and_grad_from_params method of several OnePointModels simultaneously and sums their losses and gradients.
- Parameters:
models (tuple[OnePointModel, ...] | OnePointModel) – Model or sequence of models, each providing a loss component to be summed.
main_comm (Comm (default=COMM_WORLD)) – MPI communicator for the entire group (each model should be assigned its own sub-communicator)
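Because the group's loss is the sum of the models' losses, its gradient is the element-wise sum of their gradients. The sketch below shows that combination step in a single process; `group_loss_and_grad`, `model_a`, and `model_b` are hypothetical, and each model is represented by a plain loss-and-gradient function, whereas in OnePointGroup each model would run on its own sub-communicator.

```python
def group_loss_and_grad(loss_and_grad_fns, params):
    """Sum the losses and element-wise sum the gradients of several models."""
    total_loss = 0.0
    total_grad = [0.0] * len(params)
    for fn in loss_and_grad_fns:
        loss, grad = fn(params)
        total_loss += loss
        total_grad = [t + g for t, g in zip(total_grad, grad)]
    return total_loss, total_grad


def model_a(p):
    # Constrains only the first parameter.
    return (p[0] - 1.0) ** 2, [2.0 * (p[0] - 1.0), 0.0]


def model_b(p):
    # Constrains only the second parameter.
    return (p[1] + 2.0) ** 2, [0.0, 2.0 * (p[1] + 2.0)]


print(group_loss_and_grad([model_a, model_b], [0.0, 0.0]))  # (5.0, [-2.0, 4.0])
```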
- multigrad.split_subcomms(num_groups=None, ranks_per_group=None, comm=None)
Split comm into sub-comms (not grouped by nodes)
- Parameters:
num_groups (int, optional) – Number of evenly divided groups (sub-comms) to create
ranks_per_group (list[int], optional) – Number of ranks assigned to each sub-comm
comm (MPI.Comm, optional) – Communicator to split; pass a sub-communicator to split it further into sub-sub-communicators
- Returns:
subcomm (MPI.Comm) – The sub-comm that now controls this process
num_groups (int) – The number of groups of subcomms (same as input if not None)
group_rank (int) – The rank of this group (0 <= group_rank < num_groups)
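The bookkeeping behind such a split can be illustrated without MPI: each rank is assigned a group index (a "color"), and ranks that share a color form one sub-communicator, which is what mpi4py's `Comm.Split(color, key)` does. The hypothetical `assign_groups` helper below shows one contiguous assignment for either num_groups or ranks_per_group; multigrad's actual assignment may differ.

```python
def assign_groups(size, num_groups=None, ranks_per_group=None):
    """Return the group index (color) of each rank 0..size-1."""
    if ranks_per_group is None:
        if num_groups is None:
            raise ValueError("give num_groups or ranks_per_group")
        # Spread ranks as evenly as possible over num_groups contiguous blocks.
        return [rank * num_groups // size for rank in range(size)]
    if sum(ranks_per_group) != size:
        raise ValueError("ranks_per_group must add up to the communicator size")
    colors = []
    for group, count in enumerate(ranks_per_group):
        colors.extend([group] * count)
    return colors


print(assign_groups(8, num_groups=3))               # [0, 0, 0, 1, 1, 1, 2, 2]
print(assign_groups(6, ranks_per_group=[1, 2, 3]))  # [0, 1, 1, 2, 2, 2]
# With mpi4py, each rank could then join its group via
# comm.Split(color=colors[rank], key=rank).
```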
- multigrad.split_subcomms_by_node(comm=None)
Split comm into sub-comms (grouped by nodes)
- Parameters:
comm (MPI.Comm, optional) – Communicator to split; pass a sub-communicator to split it further into sub-sub-communicators
- Returns:
subcomm (MPI.Comm) – The sub-comm that now controls this process
num_groups (int) – The number of groups of subcomms (= number of nodes)
group_rank (int) – The rank of this group (0 <= group_rank < num_groups)
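Grouping by node can be sketched the same way, using each rank's host name as the grouping key so that ranks on the same host share a group index. The hypothetical `assign_groups_by_node` helper below numbers hosts in order of first appearance; in an MPI program each rank's host name could come from mpi4py's `MPI.Get_processor_name()`, and this is not necessarily how multigrad performs the split.

```python
def assign_groups_by_node(hostnames):
    """Map each rank's host name to a group index, one group per node."""
    group_of_host = {}
    colors = []
    for host in hostnames:
        if host not in group_of_host:
            group_of_host[host] = len(group_of_host)  # next unused index
        colors.append(group_of_host[host])
    return colors


# hostnames[rank] is the node that rank runs on (illustrative values).
print(assign_groups_by_node(["nodeA", "nodeA", "nodeB", "nodeB", "nodeC"]))
# [0, 0, 1, 1, 2]
```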
- multigrad.reduce_sum(value, root=None, comm=None)
Returns the sum of value across all MPI processes
- Parameters:
value (np.ndarray | float | int) – value input by each MPI process to be summed
root (int, optional) – rank of the process to receive and sum the values, by default None (broadcast result to all ranks)
comm (MPI.Intracomm (default = MPI.COMM_WORLD)) – Communicator to sum over; pass a sub-communicator if not all MPI ranks take part in the operation
- Returns:
Sum of values given by each rank of the communicator
- Return type:
np.ndarray | float
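With the default root=None, the result of reduce_sum matches an MPI allreduce with a sum operation: every rank contributes a value and every rank receives the element-wise total. The single-process sketch below simulates that with one list entry per rank; `simulated_allreduce_sum` is hypothetical, and with mpi4py a comparable collective would be `comm.allreduce(value, op=MPI.SUM)`.

```python
def simulated_allreduce_sum(values_per_rank):
    """Return what each rank would receive: the element-wise sum over ranks."""
    first = values_per_rank[0]
    if isinstance(first, (list, tuple)):
        total = [sum(column) for column in zip(*values_per_rank)]
    else:
        total = sum(values_per_rank)
    # Every simulated rank receives the same total.
    return [total] * len(values_per_rank)


print(simulated_allreduce_sum([1.5, 2.5, 3.0]))   # [7.0, 7.0, 7.0]
print(simulated_allreduce_sum([[1, 2], [3, 4]]))  # [[4, 6], [4, 6]]
```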