Source code for lbfgsb.cauchy

# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2025 Antoine COLLET

"""
Implement a function to compute the generalized Cauchy point (GCP) for the L-BFGS-B
algorithm, mainly for internal use.

The target of the GCP procedure is to find a step size t such that
x(t) = x0 - t * g is a local minimum of the quadratic function m(x),
where m(x) is a local approximation to the objective function.

First determine a sequence of break points t0=0, t1, t2, ..., tn.
On each interval [t[i-1], t[i]], x is changing linearly.
After passing a break point, one or more coordinates of x will be fixed at the bounds.
We search the first local minimum of m(x) by examining the intervals [t[i-1], t[i]]
sequentially.

Functions
^^^^^^^^^

.. autosummary::
   :toctree: _autosummary

    get_cauchy_point

Reference:
[1] R. H. Byrd, P. Lu, and J. Nocedal (1995). A limited memory algorithm for bound
constrained optimization.
"""

import logging
from typing import Optional, Tuple

import numpy as np

from lbfgsb._numba_helpers import njit
from lbfgsb.bfgsmats import LBFGSB_MATRICES, bmv, bmv_numba
from lbfgsb.types import NDArrayFloat, NDArrayInt

_FLOAT_MAX = np.finfo(np.float64).max


def display_start_point(
    nseg: int,
    f_prime: float,
    f_second: float,
    delta_t: Optional[float],
    delta_t_min: float,
    iprint: int,
    logger: Optional[logging.Logger],
) -> None:
    """
    Display the start point status.

    Parameters
    ----------
    nseg : int
        Number of explored segment.
    f_prime : float
        First derivative.
    f_second : float
        Second derivative.
    delta_t : float
        See Algorithm CP: Computation of the generalized Cauchy point in [1].
    delta_t_min : float
        See Algorithm CP: Computation of the generalized Cauchy point in [1].
    iprint : int, optional
        Controls the frequency of output. ``iprint < 0`` means no output;
        ``iprint = 0``    print only one line at the last iteration;
        ``0 < iprint < 99`` print also f and ``|proj g|`` every iprint iterations;
        ``iprint >= 99``   print details of every iteration except n-vectors;
    logger: Optional[Logger], optional
        :class:`logging.Logger` instance. If None, nothing is displayed, no matter the
        value of `iprint`, by default None.

    """
    if iprint < 100:
        return
    if logger is None:
        return
    logger.info(f"Piece    , {nseg},  --f1, f2 at start point , {f_prime} , {f_second}")
    if delta_t is not None:
        logger.info(f"Distance to the next break point =  {delta_t}")
    logger.info(f"Distance to the stationary point =  {delta_t_min}")


def _get_cauchy_point_numpy(
    x: NDArrayFloat,
    grad: NDArrayFloat,
    lb: NDArrayFloat,
    ub: NDArrayFloat,
    W: NDArrayFloat,
    theta: float,
    invMfactors: Tuple[NDArrayFloat, NDArrayFloat],
    use_factor: bool,
    iprint: int,
    logger: Optional[logging.Logger] = None,
):
    eps_f_sec = 1e-30
    x_cp: NDArrayFloat = x.copy()

    # To define the breakpoints in each coordinate direction, we compute
    t = np.empty_like(grad)
    t.fill(np.inf)
    # masks
    neg = grad < 0
    pos = grad > 0
    # update breakpoints
    t[neg] = (x[neg] - ub[neg]) / grad[neg]
    t[pos] = (x[pos] - lb[pos]) / grad[pos]

    # used to store the Cauchy direction `P(x-tg)-x`.
    d = np.where(t == 0, 0.0, -grad)

    # In the end, F is the list of ordered breakpoint indices
    # sort {t;,i = 1,. ..,n} in increasing order to obtain the ordered
    # set {tj :tj <= tj+1 ,j = 1, ...,n}.
    # Keep only the indices where t > 0
    # Note: sorts only positive breakpoints to reduces sort cost from O(n log n)
    # to O(k log k) where k ≪ n in practice
    pos_idx = np.flatnonzero(t > 0)
    sorted_t_idx: NDArrayInt = pos_idx[np.argsort(t[pos_idx])]

    # Initialization
    p = W.T @ d  # 2mn operations

    # Initialize c = W'(xcp - x) = 0.
    c: NDArrayFloat = np.zeros(p.size)

    # Initialize f1
    f_prime: float = -d.dot(d)  # n operations

    # Initialize derivative f2.
    f_second: float = -theta * f_prime
    f2_org: float = f_second + 0.0  # make a copy

    # Update f2 with - d^{T} @ W @ M @ W^{T} @ d = - p^{T} @ M @ p
    # old way: f2 = f2 - p.dot(M.dot(p))  # O(m^{2}) operations
    # new_way: not at first iteration -> invMfactors and M are worse zero.
    # And cho_solve produces nan so we use bmv
    if use_factor:
        f_second = f_second - p.dot(bmv(invMfactors, p))  # O(m^{2}) operations

    # dtm in the fortran code
    delta_t_min: float = -f_prime / f_second

    # Number of breakpoints
    nbreak = len(sorted_t_idx)
    # Handler the case where there are no breakpoints
    if nbreak == 0:
        # is a zero vector, return with the initial xcp as GCP.
        return x_cp, c

    # iter in the fortran code and b in [1]
    _i = 0
    # break point index (b in section 4 [1])
    ibp: int = sorted_t_idx[_i].item()
    # value of the smallest breakpoint, t in section 4 [1]
    t_cur: float = t[ibp]
    # previous breakpoint value
    t_old = 0.0

    delta_t: float = t_cur - 0.0

    # Number of the breakpoint segment -> Nseg in Fortran
    nseg: int = 1

    if iprint >= 99 and logger is not None:
        logger.info(f"There are {nbreak} breakpoints ")

    # flag
    is_gpc_found = False

    nbreak = len(sorted_t_idx)
    while _i < nbreak:
        display_start_point(
            nseg, f_prime, f_second, delta_t, delta_t_min, iprint, logger
        )

        if delta_t_min < delta_t:
            is_gpc_found = True
            break

        # Fix one variable and reset the corresponding component of d to zero.
        if d[ibp] > 0:
            x_cp[ibp] = ub[ibp]
        elif d[ibp] < 0:
            x_cp[ibp] = lb[ibp]
        zb = x_cp[ibp] - x[ibp]

        if iprint >= 100 and logger is not None:
            # ibp +1 to match the Fortran code (because index starts at 1)
            logger.info(f"Variable  {ibp + 1} is fixed.")

        c += delta_t * p
        W_b = W[ibp, :]
        g_b = grad[ibp]

        # Update the derivative information
        # 1) Old way
        # f1 += delta_t * f2 + g_b * (g_b + theta * zb - W_b.dot(M.dot(c)))
        # f2 -= g_b * (g_b * theta + W_b.dot(M.dot(2 * p + g_b * W_b)))
        # 2) New way with the cholesky factorization
        f_prime += delta_t * f_second + g_b * (g_b + theta * zb)
        f_second -= g_b * g_b * theta

        # First iteration -> invMfactors and M are worse zero.
        # And cho_solve produces nan
        if use_factor:
            invMWb = bmv(invMfactors, W_b)
            f_prime -= g_b * invMWb.dot(c)
            f_second -= g_b * (2.0 * invMWb.dot(p) + g_b * invMWb.dot(W_b))

        # this is a trick of the original FORTRAN code that prevents very low
        # values of f2
        f_second = max(f_second, eps_f_sec * f2_org)

        # Fix one variable and reset the corresponding component of d to zero.
        p += g_b * W_b
        d[ibp] = 0
        delta_t_min = -f_prime / f_second
        t_old = t_cur + 0.0  # copy

        _i += 1
        if _i + 1 < nbreak:
            ibp = sorted_t_idx[_i + 1].item()
            t_cur = t[ibp]
        else:
            # to ensure that delta_t > delta_t_min and break the while
            t_cur = np.inf

        delta_t = t_cur - t_old
        nseg += 1

    if iprint >= 99 and logger is not None:
        if is_gpc_found:
            logger.info("GCP found in this segment")
            display_start_point(
                nseg, f_prime, f_second, None, delta_t_min, iprint, logger
            )

    delta_t_min = 0 if delta_t_min < 0 else delta_t_min
    t_old += delta_t_min

    mask = t >= t_cur
    x_cp[mask] = x[mask] + t_old * d[mask]

    return x_cp, c + delta_t_min * p


@njit(cache=True)
def _all_finite_numba(x: NDArrayFloat) -> bool:
    for i in range(x.size):
        if not np.isfinite(x[i]):
            return False
    return True


@njit(cache=True)
def _safe_sumsq_numba(x: NDArrayFloat) -> float:
    """
    Return sum(x**2) without emitting overflow warnings.

    If the squared norm cannot be represented in float64, return np.inf.
    """
    max_abs = 0.0

    for i in range(x.size):
        ax = abs(x[i])
        if ax > max_abs:
            max_abs = ax

    if max_abs == 0.0:
        return 0.0

    if not np.isfinite(max_abs):
        return np.inf

    scaled_sum = 0.0
    for i in range(x.size):
        y = x[i] / max_abs
        scaled_sum += y * y

    if scaled_sum == 0.0:
        return 0.0

    # Need max_abs**2 * scaled_sum <= FLOAT_MAX.
    # Avoid computing max_abs**2 directly before checking.
    limit = np.sqrt(_FLOAT_MAX / scaled_sum)

    if max_abs > limit:
        return np.inf

    return max_abs * max_abs * scaled_sum


@njit(cache=True)
def _safe_delta_t_min_numba(f_prime: float, f_second: float) -> float:
    """
    Safely compute -f_prime / f_second.

    Returns np.inf when the local quadratic model is unusable.
    """
    if not np.isfinite(f_prime):
        return np.inf

    if not np.isfinite(f_second):
        return np.inf

    if f_second <= 0.0:
        return np.inf

    return -f_prime / f_second


@njit(cache=True)
def _safe_curvature_floor_numba(
    f_second: float,
    f2_org: float,
    eps_f_sec: float,
) -> float:
    """Apply the curvature floor only when meaningful."""
    if not np.isfinite(f_second):
        return np.inf

    if not np.isfinite(f2_org):
        return f_second

    if f2_org <= 0.0:
        return f_second

    floor = eps_f_sec * f2_org

    if f_second < floor:
        return floor

    return f_second


@njit(cache=True)
def _get_cauchy_point_numba(
    x: NDArrayFloat,
    grad: NDArrayFloat,
    lb: NDArrayFloat,
    ub: NDArrayFloat,
    W: NDArrayFloat,
    theta: float,
    invMfactors: Tuple[NDArrayFloat, NDArrayFloat],
    use_factor: bool,
):
    n = x.size
    m = W.shape[1]
    eps_f_sec = 1e-30

    x_cp = x.copy()
    t = np.empty(n)
    d = np.empty(n)

    # Breakpoints
    for i in range(n):
        gi = grad[i]

        if gi < 0.0:
            t[i] = (x[i] - ub[i]) / gi
        elif gi > 0.0:
            t[i] = (x[i] - lb[i]) / gi
        else:
            t[i] = np.inf

        d[i] = -gi if t[i] != 0.0 else 0.0

    # If the initial search direction is already non-finite, stop early.
    if not _all_finite_numba(d):
        return x_cp, np.zeros(m)

    # Positive breakpoints
    pos_idx = np.empty(n, dtype=np.int64)
    k = 0

    for i in range(n):
        if t[i] > 0.0:
            pos_idx[k] = i
            k += 1

    if k == 0:
        return x_cp, np.zeros(m)

    pos_idx = pos_idx[:k]
    pos_idx = pos_idx[np.argsort(t[pos_idx])]

    # Initialization
    p = W.T @ d
    c = np.zeros(m)

    if not _all_finite_numba(p):
        return x_cp, c

    dd = _safe_sumsq_numba(d)

    # If ||d||^2 overflows, the local quadratic model is unusable.
    # Return the current finite Cauchy state instead of propagating inf/nan.
    if not np.isfinite(dd):
        return x_cp, c

    f_prime = -dd
    f_second = theta * dd
    f2_org = f_second

    if not np.isfinite(f_second) or f_second <= 0.0:
        return x_cp, c

    if use_factor:
        tmp = bmv_numba(*invMfactors, p)

        if not _all_finite_numba(tmp):
            return x_cp, c

        correction = np.dot(p, tmp)

        if not np.isfinite(correction):
            return x_cp, c

        f_second -= correction

    f_second = _safe_curvature_floor_numba(f_second, f2_org, eps_f_sec)

    delta_t_min = _safe_delta_t_min_numba(f_prime, f_second)

    if not np.isfinite(delta_t_min):
        return x_cp, c

    t_old = 0.0
    ibp = pos_idx[0]
    t_cur = t[ibp]
    delta_t = t_cur
    i = 0

    while i < k:
        if not np.isfinite(delta_t):
            return x_cp, c

        if delta_t_min < delta_t:
            break

        zb = ub[ibp] - x[ibp] if d[ibp] > 0.0 else lb[ibp] - x[ibp]

        if not np.isfinite(zb):
            return x_cp, c

        x_cp[ibp] = x[ibp] + zb

        # c += delta_t * p
        if not _all_finite_numba(p):
            return x_cp, c

        c += delta_t * p

        if not _all_finite_numba(c):
            return x_cp, np.zeros(m)

        Wb = W[ibp]
        gb = grad[ibp]

        if not np.isfinite(gb):
            return x_cp, c

        # Update f_prime and f_second carefully.
        f_prime += delta_t * f_second + gb * (gb + theta * zb)
        f_second -= gb * gb * theta

        if not np.isfinite(f_prime) or not np.isfinite(f_second):
            return x_cp, c

        if use_factor:
            invMWb = bmv_numba(*invMfactors, Wb)

            if not _all_finite_numba(invMWb):
                return x_cp, c

            dot_invMWb_c = np.dot(invMWb, c)
            dot_invMWb_p = np.dot(invMWb, p)
            dot_invMWb_Wb = np.dot(invMWb, Wb)

            if (
                not np.isfinite(dot_invMWb_c)
                or not np.isfinite(dot_invMWb_p)
                or not np.isfinite(dot_invMWb_Wb)
            ):
                return x_cp, c

            f_prime -= gb * dot_invMWb_c
            f_second -= gb * (2.0 * dot_invMWb_p + gb * dot_invMWb_Wb)

            if not np.isfinite(f_prime) or not np.isfinite(f_second):
                return x_cp, c

        f_second = _safe_curvature_floor_numba(f_second, f2_org, eps_f_sec)

        if not np.isfinite(f_second) or f_second <= 0.0:
            return x_cp, c

        p += gb * Wb

        if not _all_finite_numba(p):
            return x_cp, c

        d[ibp] = 0.0

        delta_t_min = _safe_delta_t_min_numba(f_prime, f_second)

        if not np.isfinite(delta_t_min):
            return x_cp, c

        t_old = t_cur

        i += 1

        if i < k:
            ibp = pos_idx[i]
            t_cur = t[ibp]
            delta_t = t_cur - t_old
        else:
            t_cur = np.inf
            delta_t = t_cur

    if delta_t_min < 0.0:
        delta_t_min = 0.0

    t_old += delta_t_min

    if not np.isfinite(t_old):
        return x_cp, c

    for i in range(n):
        if t[i] >= t_cur:
            xi = x[i] + t_old * d[i]

            if np.isfinite(xi):
                x_cp[i] = xi
            else:
                return x_cp, c

    c += delta_t_min * p

    if not _all_finite_numba(c):
        return x_cp, np.zeros(m)

    return x_cp, c


[docs] def get_cauchy_point( x: NDArrayFloat, grad: NDArrayFloat, lb: NDArrayFloat, ub: NDArrayFloat, mats: LBFGSB_MATRICES, iprint: int, logger: Optional[logging.Logger] = None, is_use_numba_jit: bool = False, ) -> Tuple[NDArrayFloat, NDArrayFloat]: r""" Computes the generalized Cauchy point (GCP). This is the Generalized Cauchy point procedure in section 4 of [1]. It is defined as the first local minimizer of the quadratic .. math:: \[\langle grad,s\rangle + \frac{1}{2} \langle s, (\theta I + WMW^\intercal)s\rangle\] along the projected gradient direction .. math::`P_[l,u](x-\theta grad).` Parameters ---------- x : NDArrayFloat Starting point for the GCP computation. grad : NDArrayFloat Gradient of fun with respect to x. lb : NDArrayFloat Lower bound vector. ub : NDArrayFloat Upper bound vector. mats: LBFGSB_MATRICES Wrapper for L-BFGS-B matrices. iprint : int, optional Controls the frequency of output. ``iprint < 0`` means no output; ``iprint = 0`` print only one line at the last iteration; ``0 < iprint < 99`` print also f and ``|proj g|`` every iprint iterations; ``iprint >= 99`` print details of every iteration except n-vectors; logger: Optional[Logger], optional :class:`logging.Logger` instance. If None, nothing is displayed, no matter the value of `iprint`, by default None. is_use_numba_jit: bool Whether to use `numba` just-in-time compilation to speed-up the computation intensive part of the algorithm. The default is False. .. versionadded:: 1.0 Returns ------- Tuple[NDArrayFloat, NDArrayFloat] The array of Cauchy points and c = W @ (Zc - Zk). References ---------- * R. H. Byrd, P. Lu and J. Nocedal. A Limited Memory Algorithm for Bound Constrained Optimization, (1995), SIAM Journal on Scientific and Statistical Computing, 16, 5, pp. 1190-1208. * C. Zhu, R. H. Byrd and J. Nocedal. L-BFGS-B: Algorithm 778: L-BFGS-B, FORTRAN routines for large scale bound constrained optimization (1997), ACM Transactions on Mathematical Software, 23, 4, pp. 550 - 560. * J.L. Morales and J. Nocedal. L-BFGS-B: Remark on Algorithm 778: L-BFGS-B, FORTRAN routines for large scale bound constrained optimization (2011), ACM Transactions on Mathematical Software, 38, 1. """ # Note: the variable names follow the FORTRAN original implementation if iprint >= 99 and logger is not None: logger.info("---------------- CAUCHY entered-------------------") if is_use_numba_jit: x_cp, c = _get_cauchy_point_numba( x, grad, lb, ub, mats.W, mats.theta, mats.invMfactors, mats.use_factor ) else: x_cp, c = _get_cauchy_point_numpy( x, grad, lb, ub, mats.W, mats.theta, mats.invMfactors, mats.use_factor, iprint, logger, ) if logger is not None: if iprint > 100: logger.info(f"Cauchy X = {x_cp}") if iprint >= 99: logger.info("---------------- exit CAUCHY----------------------") return x_cp, c