FrEIA
Subclasses of torch.nn.Module that are reversible and can be used in the nodes of the GraphINN class. The only additions needed compared to the base class are an @staticmethod output_dims and the ‘rev’ argument of the forward method.
Abstract template:
InvertibleModule
Coupling blocks:
AllInOneBlock
NICECouplingBlock
RNVPCouplingBlock
GLOWCouplingBlock
GINCouplingBlock
AffineCouplingOneSided
ConditionalAffineTransform
Other learned transforms:
ActNorm
IResNetLayer
InvAutoAct
InvAutoActFixed
InvAutoActTwoSided
InvAutoConv2D
InvAutoFC
LearnedElementwiseScaling
OrthogonalTransform
HouseholderPerm
Fixed (non-learned) transforms:
PermuteRandom
FixedLinearTransform
Fixed1x1Conv
Graph topology:
SplitChannel
ConcatChannel
Split1D
Concat1d
Reshaping:
IRevNetDownsampling
IRevNetUpsampling
HaarDownsampling
HaarUpsampling
Flatten
Reshape
FrEIA.modules.
__init__
Parameters:
dims_in – list of tuples specifying the shape of the inputs to this operator: dims_in = [shape_x_0, shape_x_1, …]
forward
Perform a forward (default, rev=False) or backward pass (rev=True) through this module/operator.
Note to implementers:
- Subclasses MUST return a Jacobian when jac=True, but CAN return a valid Jacobian when jac=False (not punished). The latter is only recommended if the computation of the Jacobian is trivial.
- Subclasses MUST follow the convention that the returned Jacobian be consistent with the evaluation direction. Let’s make this more precise: let $f$ be the function that the subclass represents. Then
$$J = \log \det \frac{\partial f}{\partial x}, \qquad -J = \log \det \frac{\partial f^{-1}}{\partial z}.$$
Any subclass MUST return $J$ for forward evaluation (rev=False), and $-J$ for backward evaluation (rev=True).
x_or_z – input data (array-like of one or more tensors)
c – conditioning data (array-like of none or more tensors)
rev – perform backward pass
jac – return Jacobian associated to the direction
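To make this Jacobian convention concrete, here is a minimal sketch of a custom InvertibleModule subclass (the class FixedScaling and its scale argument are illustrative, not part of the library): an invertible elementwise scaling that returns $J$ in forward mode and $-J$ in backward mode.
```
import math
import torch
import FrEIA.modules as Fm

# Illustrative example, not part of FrEIA: invertible elementwise scaling y = s * x
class FixedScaling(Fm.InvertibleModule):
    def __init__(self, dims_in, dims_c=None, scale=2.0):
        super().__init__(dims_in, dims_c)
        self.scale = scale
        self.n_dims = int(torch.prod(torch.tensor(dims_in[0])))

    def forward(self, x_or_z, c=None, rev=False, jac=True):
        (x,) = x_or_z
        # log|det J| of an elementwise scaling is n_dims * log(scale)
        log_jac = x.new_full((x.shape[0],), self.n_dims * math.log(self.scale))
        if not rev:
            return (x * self.scale,), log_jac   # +J in forward mode
        return (x / self.scale,), -log_jac      # -J in backward mode (rev=True)

    @staticmethod
    def output_dims(input_dims):
        return input_dims

mod = FixedScaling([(8,)])
(y,), jac = mod([torch.randn(4, 8)])            # jac == 8 * log(2) for every sample
```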
initialize_with_data
output_dims
training
Half of a coupling block following the GLOWCouplingBlock design. This means only one affine transformation on half the inputs. In the case where random permutations or orthogonal transforms are used after every block, this is not a restriction and simplifies the design.
Additional args in the docstring of the base class.
subnet_constructor – function or class, with signature constructor(dims_in, dims_out). The result should be a torch nn.Module that takes dims_in input channels and dims_out output channels. See tutorial for examples. One subnetwork will be initialized in the block.
clamp – Soft clamping for the multiplicative component. The amplification or attenuation of each input dimension can be at most exp(±clamp).
clamp_activation – Function to perform the clamping. String values “ATAN”, “TANH”, and “SIGMOID” are recognized, or a function or object can be passed. “TANH” behaves like the original RealNVP paper. A custom function should take tensors and map -inf to -1 and +inf to +1.
See base class docstring
Module combining the most common operations in a normalizing flow or similar model.
It combines an affine coupling, a permutation, and a global affine transformation (‘ActNorm’). It can also be used as a GIN coupling block, perform learned Householder permutations, and use an inverted pre-permutation (see the constructor docstring for details). A usage sketch follows the parameter list below.
subnet_constructor – class or callable f, called as f(channels_in, channels_out), which should return a torch.nn.Module
affine_clamping – clamp the output of the multiplicative coefficients (before exponentiation) to +/- affine_clamping.
gin_block – Turn the block into a GIN block from Sorrenson et al, 2019
global_affine_init – Initial value for the global affine scaling beta
global_affine_type – ‘SIGMOID’, ‘SOFTPLUS’, or ‘EXP’. Defines the activation to be used on the beta for the global affine scaling.
permute_soft – bool; whether to sample the permutation matrix from SO(N), or to use hard permutations instead. Note that permute_soft=True is very slow when working with >512 dimensions.
learned_householder_permutation – int; if >0, use that many learned Householder reflections. Slow for large numbers of reflections, and it is dubious whether it actually helps.
reverse_permutation – Reverse the permutation before the block, as introduced by Putzky et al, 2019.
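As a usage sketch (the fully connected subnet_fc, the hidden width of 512, the depth of four blocks and the 8-dimensional data are illustrative choices, not prescribed by the library), an AllInOneBlock-based flow can be assembled with FrEIA.framework.SequenceINN:
```
import torch
import torch.nn as nn
import FrEIA.framework as Ff
import FrEIA.modules as Fm

# Illustrative subnetwork constructor, called as subnet_fc(channels_in, channels_out)
def subnet_fc(dims_in, dims_out):
    return nn.Sequential(nn.Linear(dims_in, 512), nn.ReLU(),
                         nn.Linear(512, dims_out))

inn = Ff.SequenceINN(8)                          # flow on 8-dimensional vectors
for _ in range(4):
    inn.append(Fm.AllInOneBlock, subnet_constructor=subnet_fc, permute_soft=True)

x = torch.randn(16, 8)
z, log_jac_det = inn(x)                          # forward pass with log|det J|
x_rev, _ = inn(z, rev=True)                      # inverse pass recovers x
```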
Concat
Invertible merge operation.
Concatenates a list of incoming tensors along a given dimension and passes on the result. Inverse is the corresponding split operation.
dims_in
A list of tuples containing the non-batch dimensionality of all incoming tensors. Handled automatically during compute graph setup. Dimensionality of incoming tensors must be identical, except in the merge dimension dim. Concat only makes sense with multiple input tensors.
dim
Index of the dimension along which to concatenate, not counting the batch dimension. Defaults to 0, i.e. the channel dimension in structured data.
Inits the Concat module with the attributes described above and checks that all dimensions are compatible.
See super class InvertibleModule. Jacobian log-det of concatenation is always zero.
See super class InvertibleModule.
alias of FrEIA.modules.graph_topology._deprecated_by.<locals>.deprecated_class
Similar to the conditioning layers from SPADE (Park et al, 2019): Perform an affine transformation on the whole input, where the affine coefficients are predicted from only the condition.
Fixed 1x1 conv transformation with matrix M.
Fixed transformation according to y = Mx + b, with invertible matrix M.
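A hedged sketch of using FixedLinearTransform; the keyword names M and b are assumed from the formula above and may differ between FrEIA versions, and the random orthogonal matrix is only an example of an invertible M:
```
import torch
import FrEIA.framework as Ff
import FrEIA.modules as Fm

dim = 4
M = torch.linalg.qr(torch.randn(dim, dim)).Q     # random orthogonal, hence invertible
b = torch.zeros(dim)

inn = Ff.SequenceINN(dim)
inn.append(Fm.FixedLinearTransform, M=M, b=b)    # keyword names assumed, check your version

x = torch.randn(8, dim)
z, log_jac = inn(x)                              # y = Mx + b, log|det J| = log|det M|
```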
Flattens N-D tensors into 1-D tensors.
See docstring of base class (FrEIA.modules.InvertibleModule).
Coupling Block following the GIN design. The difference from GLOWCouplingBlock (and other affine coupling blocks) is that the Jacobian determinant is constrained to be 1, making the block volume-preserving. Volume preservation is achieved by subtracting the mean of the output of the s subnetwork from itself. While volume preserving, GIN is still more powerful than NICE, as GIN is not volume preserving within each dimension individually.
Note: this implementation differs slightly from the originally published one, which scales the final component of the s subnetwork so that the sum of the outputs of s is zero. No practical difference between the two implementations was found, but subtracting the mean guarantees that all outputs of s are at most ±exp(clamp), which might be more stable in certain cases.
subnet_constructor – function or class, with signature constructor(dims_in, dims_out). The result should be a torch nn.Module that takes dims_in input channels and dims_out output channels. See tutorial for examples. Two of these subnetworks will be initialized in the block.
Coupling Block following the GLOW design. Note, this is only the coupling part itself, and does not include ActNorm, invertible 1x1 convolutions, etc. See AllInOneBlock for a block combining these functions at once. The only difference to the RNVPCouplingBlock coupling blocks is that it uses a single subnetwork to jointly predict [s_i, t_i], instead of two separate subnetworks. This reduces computational cost and speeds up learning.
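A sketch of the classic pattern of alternating coupling blocks with fixed random permutations, using the graph API (the subnet width of 128, the depth of four blocks, the 2-dimensional data and the node names are illustrative choices):
```
import torch
import torch.nn as nn
import FrEIA.framework as Ff
import FrEIA.modules as Fm

def subnet_fc(dims_in, dims_out):
    return nn.Sequential(nn.Linear(dims_in, 128), nn.ReLU(),
                         nn.Linear(128, dims_out))

nodes = [Ff.InputNode(2, name='input')]
for k in range(4):
    nodes.append(Ff.Node(nodes[-1], Fm.GLOWCouplingBlock,
                         {'subnet_constructor': subnet_fc, 'clamp': 2.0},
                         name=f'coupling_{k}'))
    nodes.append(Ff.Node(nodes[-1], Fm.PermuteRandom, {'seed': k},
                         name=f'permute_{k}'))
nodes.append(Ff.OutputNode(nodes[-1], name='output'))

inn = Ff.GraphINN(nodes)
z, log_jac_det = inn(torch.randn(16, 2))
```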
GaussianMixtureModel
An invertible Gaussian mixture model. The weights, means, covariance parameterization and component index must be supplied as conditional inputs to the module and can come from an external feed-forward network, which may be trained by backpropagating through the GMM. Weights should first be normalized via GaussianMixtureModel.normalize_weights(w) and component indices can be sampled via GaussianMixtureModel.pick_mixture_component(w). If component indices are specified, the model reduces to that Gaussian mixture component and maps between data x and standard normal latent variable z. Components can also be chosen consistently at random, by supplying an integer random seed instead of indices. If a None value is supplied instead of indices, the model maps between K data points x and K latent codes z simultaneously, where K is the number of mixture components. Mathematical derivations are found in the technical report “Training Mixture Density Networks with full covariance matrices” on arXiv.
Map between data distribution and standard normal latent distribution of mixture components or entire mixture, in an invertible way.
x_or_z – data x or latent codes z; must be [batch_size, n_dims] if component indices i are specified, and [batch_size, n_components, n_dims] if not.
The conditional input c must be a list [w, mu, U, i] of parameters for the Gaussian mixture model with the following properties:
w – mixture component weights; must be positive, sum to one, and have size [batch_size, n_components].
mu – means of the mixture components; must have size [batch_size, n_components, n_dims].
U – entries of the (upper triangular) Cholesky factors of the precision matrices of the mixture components. These are needed to parameterize the covariance of the mixture components and must have size [batch_size, n_components, n_dims * (n_dims + 1) / 2].
i – tensor of mixture component indices, or an integer to be used as random number generator seed for component selection, or None to indicate that all mixture components are modelled.
nll_loss
Negative log-likelihood loss for training a Mixture Density Network.
w – mixture component weights; must be positive and sum to one. Tensor must be of size [batch_size, n_components].
z – latent codes for all mixture components, of size [batch_size, n_components, n_dims].
log_jacobian – Jacobian log-determinants for each mixture component. Tensor size must be [batch_size, n_components].
nll_upper_bound
Numerically more stable upper bound of the negative log-likelihood loss for training a Mixture Density Network.
normalize_weights
Apply softmax to ensure component weights are positive and sum to one. Works on batches of component weights.
w – unnormalized component weights (logits); tensor of size [batch_size, n_components]
pick_mixture_component
Randomly choose mixture component indices with probability given by the component weights w. Works on batches of component weights.
w – weights of the mixture components; must be positive and sum to one.
seed – optional RNG seed for consistent decisions.
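A rough sketch of assembling the conditional inputs [w, mu, U, i] and the MDN loss described above. The constructor argument, the expansion of x to all components, and the argument order of nll_upper_bound are assumptions for illustration; check the signatures in your FrEIA version.
```
import torch
import FrEIA.modules as Fm

batch_size, n_components, n_dims = 16, 5, 3

# Constructor argument is an assumption: dims_in as for other InvertibleModules
gmm = Fm.GaussianMixtureModel([(n_dims,)])

# GMM parameters, e.g. predicted by an external feed-forward network
w = Fm.GaussianMixtureModel.normalize_weights(torch.randn(batch_size, n_components))
mu = torch.randn(batch_size, n_components, n_dims)
U = torch.randn(batch_size, n_components, n_dims * (n_dims + 1) // 2)

# With i=None all mixture components are modelled, so x is expanded to
# [batch_size, n_components, n_dims]
x = torch.randn(batch_size, n_dims)
x_all = x[:, None, :].expand(-1, n_components, -1)
z, log_jac = gmm([x_all], c=[w, mu, U, None], jac=True)

# Negative log-likelihood upper bound for training (argument order assumed)
loss = Fm.GaussianMixtureModel.nll_upper_bound(w, z[0], log_jac)
```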
Uses Haar wavelets to split each channel into 4 channels, with half the width and height dimensions.
See docstring of base class (FrEIA.modules.InvertibleModule) for more.
order_by_wavelet – whether to group the output by original channels or by wavelet. E.g. if the average, vertical, horizontal and diagonal wavelets for channel 1 are a1, v1, h1, d1, those for channel 2 are a2, v2, h2, d2, etc., then the output channels will be structured as follows:
set to True: a1, a2, …, v1, v2, …, h1, h2, …, d1, d2, …
set to False: a1, v1, h1, d1, a2, v2, h2, d2, …
The True option is slightly slower to compute than the False option. The option is useful if e.g. the average channels should be split off by a FrEIA.modules.Split. Then, setting order_by_wavelet=True allows splitting off the first quarter of channels to isolate the average wavelets only.
rebalance – must be != 0. There exist different conventions for how to define the Haar wavelets. The wavelet components in the forward direction are multiplied with this factor, and those in the inverse direction are adjusted accordingly, so that the module as a whole is invertible. Stability of the network may be increased for rebalance < 1 (e.g. 0.5).
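For example, the average wavelets can be isolated by splitting off the first quarter of the channels. This is a hedged sketch; the 3x32x32 input shape, the node names and the section sizes are illustrative choices.
```
import torch
import FrEIA.framework as Ff
import FrEIA.modules as Fm

nodes = [Ff.InputNode(3, 32, 32, name='input')]
# 3 -> 12 channels, ordered a1, a2, a3, v1, v2, v3, ... with order_by_wavelet=True
nodes.append(Ff.Node(nodes[-1], Fm.HaarDownsampling,
                     {'order_by_wavelet': True}, name='haar'))
# Split off the first quarter of the channels (the average wavelets)
split = Ff.Node(nodes[-1], Fm.Split, {'section_sizes': [3, 9], 'dim': 0}, name='split')
nodes.append(split)
nodes.append(Ff.OutputNode(split.out0, name='averages'))
nodes.append(Ff.OutputNode(split.out1, name='details'))

inn = Ff.GraphINN(nodes)
(avg, detail), log_jac = inn(torch.randn(4, 3, 32, 32))
```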
HaarUpsampling
The inverted operation of HaarDownsampling (see that docstring for details).
Implementation of the i-ResNet architecture as proposed in https://arxiv.org/pdf/1811.00995.pdf
lipschitz_correction
The invertible spatial downsampling used in i-RevNet. Each group of four neighboring pixels is reordered into one pixel with four times the channels in a checkerboard-like pattern. See i-RevNet, Jacobsen 2018 et al.
See docstring of base class (FrEIA.modules.InvertibleModule) for more.
legacy_backend – if True, uses the splitting and concatenating method adapted from github.com/jhjacobsen/pytorch-i-revnet/blob/master/models/model_utils.py for use in FrEIA. This is usually slower on GPU. If False, uses a 2d strided convolution with a kernel representing the downsampling. Note that the ordering of the output channels will be different. If the pixels in each patch in channel 1 are a1, b1, …, and in channel 2 are a2, b2, …, then the output channels will be the following:
legacy_backend=True: a1, a2, …, b1, b2, …, c1, c2, …
legacy_backend=False: a1, b1, …, a2, b2, …, a3, b3, …
(see also order_by_wavelet in the HaarDownsampling module). Usually this difference is completely irrelevant.
The inverted operation of IRevNetDownsampling (see that docstring for details).
Initializes internal Module state, shared by both nn.Module and ScriptModule.
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
Module
e
log_e
log of the nonlinear function e
Base class for all invertible modules in FrEIA.
Given module, an instance of some InvertibleOperator. This module shall be invertible in its input dimensions, so that the input can be recovered by applying the module in backwards mode (rev=True), not to be confused with pytorch.backward(), which computes the gradient of an operation.
```
x = torch.randn(BATCH_SIZE, DIM_COUNT)
c = torch.randn(BATCH_SIZE, CONDITION_DIM)

# Forward mode
z, jac = module([x], [c], jac=True)

# Backward mode
x_rev, jac_rev = module(z, [c], rev=True, jac=True)
```
The module returns $\log \det J = \log \det \frac{\partial f}{\partial x}$ of the operation in forward mode, and $-\log \det J = \log \det \frac{\partial f^{-1}}{\partial z} = -\log \det \frac{\partial f}{\partial x}$ in backward mode (rev=True).
Then, torch.allclose(x, x_rev[0]) == True and jac == -jac_rev.
dims_in – list of tuples specifying the shape of the inputs to this operator: dims_in = [shape_x_0, shape_x_1, …]
dims_c – list of tuples specifying the shape of the conditions to this operator.
Perform a forward (default, rev=False) or backward pass (rev=True) through this module/operator.
Note to implementers:
- Subclasses MUST return a Jacobian when jac=True, but CAN return a valid Jacobian when jac=False (not punished). The latter is only recommended if the computation of the Jacobian is trivial.
- Subclasses MUST follow the convention that the returned Jacobian be consistent with the evaluation direction. Let’s make this more precise: let $f$ be the function that the subclass represents. Then
$$J = \log \det \frac{\partial f}{\partial x}, \qquad -J = \log \det \frac{\partial f^{-1}}{\partial z}.$$
Any subclass MUST return $J$ for forward evaluation (rev=False), and $-J$ for backward evaluation (rev=True).
x_or_z – input data (array-like of one or more tensors)
c – conditioning data (array-like of none or more tensors)
rev – perform backward pass
jac – return Jacobian associated to the direction
jacobian
Coupling Block following the NICE (Dinh et al, 2015) design. The inputs are split in two halves. For 2D, 3D, 4D inputs, the split is performed along the channel dimension. Then, residual coefficients are predicted by two subnetworks that are added to each half in turn.
Additional args in the docstring of the base class.
subnet_constructor – callable function, class, or factory object, with signature constructor(dims_in, dims_out). The result should be a torch nn.Module that takes dims_in input channels and dims_out output channels. See tutorial for examples. Two of these subnetworks will be initialized inside the block.
Permutes the input vector in a random but fixed way.
Coupling Block following the RealNVP design (Dinh et al, 2017) with some minor differences. The inputs are split in two halves. For 2D, 3D, 4D inputs, the split is performed along the channel dimension. For checkerboard-splitting, prepend an i_RevNet_downsampling module. Two affine coupling operations are performed in turn on both halves of the input.
subnet_constructor – function or class, with signature constructor(dims_in, dims_out). The result should be a torch nn.Module that takes dims_in input channels and dims_out output channels. See tutorial for examples. Four of these subnetworks will be initialized in the block.
Reshapes N-D tensors into target dim tensors. Note that the reshape resulting from e.g. (3, 32, 32) -> (12, 16, 16) will not necessarily be spatially sensible. See IRevNetDownsampling, IRevNetUpsampling, HaarDownsampling, HaarUpsampling for spatially meaningful reshaping operations.
See docstring of base class (FrEIA.modules.InvertibleModule) for more.
output_dims – the shape the reshaped output is supposed to have (not including the batch dimension)
target_dim – Deprecated name for output_dims
Split
Invertible split operation.
Splits the incoming tensor along the given dimension, and returns a list of separate output tensors. The inverse is the corresponding merge operation.
A list of tuples containing the non-batch dimensionality of all incoming tensors. Handled automatically during compute graph setup. Split only takes one input tensor.
section_sizes
If set, takes precedence over ‘n_sections’ and behaves like the argument in torch.split(), except that when a list of section sizes is given that does not add up to the size of ‘dim’, an additional split section is created to take the slack. Defaults to None.
n_sections
If ‘section_sizes’ is None, the tensor is split into ‘n_sections’ parts of equal size or close to it. This mode behaves like numpy.array_split(). Defaults to 2, i.e. splitting the data into two equal halves.
Index of the dimension along which to split, not counting the batch dimension. Defaults to 0, i.e. the channel dimension in structured data.
Inits the Split module with the attributes described above and checks that split sizes and dimensionality are compatible.
See super class InvertibleModule. Jacobian log-det of splitting is always zero.
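A small sketch tying Split and Concat together (the 6-channel input, the section sizes and the node names are illustrative): splitting and re-concatenating along the same dimension is the identity, with zero log-Jacobian.
```
import torch
import FrEIA.framework as Ff
import FrEIA.modules as Fm

nodes = [Ff.InputNode(6, name='input')]
split = Ff.Node(nodes[-1], Fm.Split, {'section_sizes': [2, 4], 'dim': 0}, name='split')
concat = Ff.Node([split.out0, split.out1], Fm.Concat, {'dim': 0}, name='concat')
nodes += [split, concat, Ff.OutputNode(concat, name='output')]

inn = Ff.GraphINN(nodes)
x = torch.randn(8, 6)
z, log_jac = inn(x)      # z equals x up to numerical precision; log_jac is zero
```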