# Source code for spectral_connectivity.minimum_phase_decomposition

"""A spectral density matrix can be decomposed into minimum phase functions.
This is used in computing the pairwise spectral Granger prediction."""

import os
from logging import getLogger

import numpy as np

if os.environ.get("SPECTRAL_CONNECTIVITY_ENABLE_GPU") == "true":
    try:
        import cupy as xp
        from cupyx.scipy.fft import fft, ifft
    except ImportError:
        import numpy as xp
        from scipy.fft import fft, ifft
else:
    import numpy as xp
    from scipy.fft import fft, ifft


# One logger per module, named after the module's dotted import path.
logger = getLogger(name=__name__)


def _conjugate_transpose(x):
    """Conjugate transpose of the last two dimensions of array x"""
    return x.swapaxes(-1, -2).conjugate()


def _get_intial_conditions(cross_spectral_matrix):
    """Returns a guess for the minimum phase factor using the Cholesky
    factorization.

    Parameters
    ----------
    cross_spectral_matrix : array, shape (n_time_samples, ...,
                                          n_fft_samples, n_signals,
                                          n_signals)

    Returns
    -------
    minimum_phase_factor : array, shape (n_time_samples, ..., 1, n_signals,
                                         n_signals)
    """
    try:
        return xp.linalg.cholesky(
            ifft(cross_spectral_matrix, axis=-3)[..., 0:1, :, :].real
        ).swapaxes(-1, -2)
    except xp.linalg.linalg.LinAlgError:
        logger.warning(
            "Computing the initial conditions using the Cholesky failed. "
            "Using a random initial condition."
        )

        new_shape = list(cross_spectral_matrix.shape)
        N_RAND = 1000
        new_shape[-3] = N_RAND
        random_start = xp.random.standard_normal(size=new_shape)

        random_start = xp.matmul(random_start, _conjugate_transpose(random_start)).mean(
            axis=-3, keepdims=True
        )

        return xp.linalg.cholesky(random_start)


def _get_causal_signal(linear_predictor):
    """Takes half the roots on the unit circle (zero lag) and all the roots
    inside the unit circle (positive lags).

    Gives you A_(t+1)(Z) / A_(t)(Z)
    This is the plus operator in [1]

    Parameters
    ----------
    linear_predictor : array, shape (..., n_fft_samples, n_signals,
                                     n_signals)

    Returns
    -------
    causal_part_of_linear_predictor : array, shape (..., n_fft_samples,
                                                    n_signals, n_signals)

    """
    n_signals = linear_predictor.shape[-1]
    n_fft_samples = linear_predictor.shape[-3]
    linear_predictor_coefficients = ifft(linear_predictor, axis=-3)

    # Take half of the roots on the unit circle
    linear_predictor_coefficients[..., 0, :, :] *= 0.5

    # Make the unit circle roots upper triangular
    lower_triangular_ind = np.tril_indices(n_signals, k=-1)
    linear_predictor_coefficients[
        ..., 0, lower_triangular_ind[0], lower_triangular_ind[1]
    ] = 0

    # Take only the roots inside the unit circle (positive lags)
    linear_predictor_coefficients[..., (n_fft_samples + 1) // 2 :, :, :] = 0
    return fft(linear_predictor_coefficients, axis=-3)


def _check_convergence(current, old, tolerance=1e-8):
    """Check convergence of Wilson algorithm at each time point.

    Parameters
    ----------
    current : array, shape (n_time_points, ...)
        Current guess.
    old : array, shape (n_time_points, ...)
        Previous guess.
    tolerance : float
        Largest difference between guesses for the matrix to be judged as
        similar.

    Returns
    -------
    is_converged : array, shape (n_time_points,)
        Boolean array that indicates whether the array has converged for
        each time point.
    """
    n_time_points = current.shape[0]
    error = xp.linalg.norm(
        xp.reshape(current - old, (n_time_points, -1)), ord=xp.inf, axis=1
    )
    return error < tolerance


def _get_linear_predictor(minimum_phase_factor, cross_spectral_matrix, I):
    """Measure how close the minimum phase factor is to the original
    cross spectral matrix.

    Parameters
    ----------
    minimum_phase_factor : array, shape (n_time_samples, ...,
                                         n_fft_samples, n_signals,
                                         n_signals)
        The current minimum phase square root guess.
    cross_spectral_matrix : array, shape (n_time_samples, ...,
                                          n_fft_samples, n_signals,
                                          n_signals)
        The matrix to be factored.
    I : array, shape (n_signals, n_signals)
        Identity matrix.

    Returns
    -------
    linear_predictor : array, shape (n_time_samples, ..., n_fft_samples,
                                     n_signals, n_signals)
        How much to adjust for the next guess for minimum phase factor.

    """
    covariance_sandwich_estimator = xp.linalg.solve(
        minimum_phase_factor, cross_spectral_matrix
    )
    covariance_sandwich_estimator = xp.linalg.solve(
        minimum_phase_factor, _conjugate_transpose(covariance_sandwich_estimator)
    )
    return covariance_sandwich_estimator + I


def minimum_phase_decomposition(
    cross_spectral_matrix, tolerance=1e-8, max_iterations=60
):
    """Find a minimum phase matrix square root of the cross spectral density
    using the Wilson algorithm.

    Starting from `_get_intial_conditions`, each iteration does three things:

    * builds the linear predictor from the current factor
      (`_get_linear_predictor`);
    * takes its causal part (`_get_causal_signal`);
    * right-multiplies the current factor by that causal part.

    Time points whose iterates have converged are frozen at their previous
    value. The loop stops early once every time point has converged.

    Parameters
    ----------
    cross_spectral_matrix : array, shape (n_time_samples, ...,
                                          n_fft_samples, n_signals,
                                          n_signals)
    tolerance : float
        The maximum difference between guesses. This is the largest absolute
        elementwise change per time point, below which that time point is
        considered converged.
    max_iterations : int
        The maximum number of iterations for the algorithm to converge.

    Returns
    -------
    minimum_phase_factor : array, shape (n_time_samples, ...,
                                         n_fft_samples, n_signals,
                                         n_signals)
        The square root of the `cross_spectral_matrix` where all the poles
        are inside the unit circle (minimum phase). If `max_iterations` is
        reached first, a warning is logged and the latest iterate is
        returned.
    """
    n_time_points = cross_spectral_matrix.shape[0]
    n_signals = cross_spectral_matrix.shape[-1]
    I = xp.eye(n_signals)
    is_converged = xp.zeros(n_time_points, dtype=bool)

    # The initial guess has a length-1 frequency axis; broadcast it across
    # every frequency. The array starts real-valued. It becomes complex after
    # the first update because `_get_causal_signal` returns FFT output.
    minimum_phase_factor = xp.zeros(cross_spectral_matrix.shape)
    minimum_phase_factor[..., :, :, :] = _get_intial_conditions(
        cross_spectral_matrix
    )

    for iteration in range(max_iterations):
        # Lazy %-style args: the message is only formatted if DEBUG is enabled.
        logger.debug(
            "iteration: %s, %s of %s converged",
            iteration,
            is_converged.sum(),
            len(is_converged),
        )
        old_minimum_phase_factor = minimum_phase_factor.copy()

        linear_predictor = _get_linear_predictor(
            minimum_phase_factor, cross_spectral_matrix, I
        )
        minimum_phase_factor = xp.matmul(
            minimum_phase_factor, _get_causal_signal(linear_predictor)
        )

        # If already converged at a time point, don't change.
        minimum_phase_factor[is_converged, ...] = old_minimum_phase_factor[
            is_converged, ...
        ]

        is_converged = _check_convergence(
            minimum_phase_factor, old_minimum_phase_factor, tolerance
        )
        if xp.all(is_converged):
            return minimum_phase_factor

    # Only reached when the loop ran out of iterations without converging.
    logger.warning(
        "Maximum iterations reached. %s of %s converged",
        is_converged.sum(),
        len(is_converged),
    )
    return minimum_phase_factor