Source code for pccg.splines

from scipy.interpolate import BSpline
from scipy.integrate import quad
import numpy as np
import math
import matplotlib.pyplot as plt
import torch


def bs(x, knots, boundary_knots, degree=3, intercept=False):
    """Generate the B-spline basis matrix for a polynomial spline.

    This function mimics the function ``bs`` in the R package ``splines``.

    Args:
        x (Tensor): Values at which the basis functions are evaluated.
            The splines are built with ``extrapolate=False``, so values
            outside the boundary knots evaluate to NaN.
        knots (Tensor): Internal breakpoints that define the spline.
        boundary_knots (Tensor): Boundary points; ``boundary_knots[0]`` is
            used as the lower and ``boundary_knots[1]`` as the upper boundary.
        degree (int, optional): The degree of the piecewise polynomial.
            The default is 3 for cubic splines.
        intercept (bool, optional): If true, the first basis function is
            kept. Any falsy value drops it. Default is False.

    Returns:
        design_matrix (Tensor): A tensor of dimension (len(x), df). With two
        boundary knots, df = len(knots) + degree + 3 if ``intercept`` is
        true and df = len(knots) + degree + 2 otherwise.

    Note:
        The inputs are detached and moved to the CPU before being handed to
        SciPy, so no gradient flows through the returned matrix.

        NOTE(review): the boundary knots are merged into the knot vector AND
        repeated ``degree + 1`` times on each side, which is two basis
        functions more than R's ``bs`` (df = len(knots) + degree [+ 1]).
        The outermost basis functions then have all of their knots
        coincident and are presumably identically zero - confirm before
        changing the column count, since callers may depend on it.
    """
    # Detached CPU views: tensors that require grad or live on an accelerator
    # would otherwise make ``Tensor.numpy()`` raise.
    x = x.detach().cpu().numpy()
    boundary = boundary_knots.detach().cpu().numpy()
    breakpoints = np.sort(
        np.concatenate([knots.detach().cpu().numpy(), boundary])
    )

    pad = degree + 1
    augmented_knots = np.concatenate(
        [
            np.full(pad, boundary[0]),
            breakpoints,
            np.full(pad, boundary[1]),
        ]
    )
    num_of_basis = len(augmented_knots) - pad

    # An identity coefficient matrix makes trailing column j of the result
    # the j-th basis function, so every basis is evaluated in a single call.
    basis = BSpline(
        augmented_knots, np.eye(num_of_basis), degree, extrapolate=False
    )
    design_matrix = basis(x)

    ## if the intercept is False, drop the first basis term, which is often
    ## referred to as the "intercept". Note that np.sum(design_matrix, -1) = 1.
    ## see https://cran.r-project.org/web/packages/crs/vignettes/spline_primer.pdf
    if not intercept:
        design_matrix = design_matrix[:, 1:]

    return torch.from_numpy(design_matrix)
def pbs(
    x,
    knots,
    boundary_knots=torch.tensor([-math.pi, math.pi]),
    degree=3,
    intercept=False,
):
    """Compute the design matrix of a periodic B-spline.

    This function mimics the ``pbs`` function in the R package ``pbs``.

    Args:
        x (Tensor): Values at which the basis functions are evaluated.
            The splines are built with ``extrapolate=False``, so values
            outside the boundary knots evaluate to NaN.
        knots (Tensor): Internal breakpoints that define the spline.
        boundary_knots (Tensor): Boundary points. The default tensor is
            shared between calls; it is only read, never modified.
        degree (int, optional): The degree of the piecewise polynomial.
            The default is 3 for cubic splines.
        intercept (bool, optional): If true, the first basis function is
            kept. Any falsy value drops it. Default is False.

    Returns:
        design_matrix (Tensor): A tensor of dimension (len(x), df), where
        df = len(knots) if intercept is false and df = len(knots) + 1 if
        intercept is true (for two boundary knots).

    Note:
        The inputs are detached and moved to the CPU before being handed to
        SciPy, so no gradient flows through the returned matrix.
    """
    # Detached CPU views: tensors that require grad or live on an accelerator
    # would otherwise make ``Tensor.numpy()`` raise.
    x = x.detach().cpu().numpy()
    boundary = boundary_knots.detach().cpu().numpy()
    grid = np.sort(np.concatenate([knots.detach().cpu().numpy(), boundary]))

    # Extend the knot vector by ``degree`` knots on each side, repeating the
    # knot spacing found at the opposite end of the grid.
    head = [
        grid[0] - (grid[-1] - grid[-2 - i]) for i in reversed(range(degree))
    ]
    tail = [grid[-1] + grid[i + 1] - grid[0] for i in range(degree)]
    augmented_knots = np.concatenate(
        [
            np.array(head, dtype=grid.dtype),
            grid,
            np.array(tail, dtype=grid.dtype),
        ]
    )
    num_of_basis = len(augmented_knots) - (degree + 1)

    # An identity coefficient matrix makes trailing column j of the result
    # the j-th basis function, so every basis is evaluated in a single call.
    basis = BSpline(
        augmented_knots, np.eye(num_of_basis), degree, extrapolate=False
    )(x)

    # The first and last ``degree`` basis functions are summed pairwise and
    # placed after the untouched middle columns.
    wrapped = basis[:, 0:degree] + basis[:, -degree:]
    design_matrix = np.concatenate(
        [basis[:, degree:-degree], wrapped], axis=-1
    )

    ## if the intercept is False, drop the first basis term, which is often
    ## referred to as the "intercept".
    ## see https://cran.r-project.org/web/packages/crs/vignettes/spline_primer.pdf
    if not intercept:
        design_matrix = design_matrix[:, 1:]

    return torch.from_numpy(design_matrix)
def bs_lj(r, r_min, r_max, num_of_basis, omega=False):
    """Compute the design matrix of a customized B-spline for Lennard-Jones
    type interaction.

    Args:
        r (Tensor): Distances at which the basis functions are evaluated.
        r_min (float): A cutoff distance. When r < r_min, the interaction
            becomes repulsive and the basis function will provide a positive
            value. Every basis function except the first is set to zero for
            r <= r_min.
        r_max (float): A cutoff distance. When r > r_max, all basis functions
            are zeros.
        num_of_basis (int): The number of basis functions.
        omega (bool): Integral of second derivatives. If True, the function
            also returns a matrix omega with
            omega[i, j] = integral over [r_min, r_max] of
            basis_i''(r) * basis_j''(r) dr, evaluated numerically. Row 0 and
            column 0 of the matrix are zero. This matrix is useful when
            fitting a smoothing spline by adding a penalty term controlling
            the second derivative of the spline.

    Returns:
        design_matrix (Tensor): A matrix of dimension (len(r), num_of_basis).
        omega (Tensor): Only returned when ``omega`` is true. A matrix of
        dimension (num_of_basis, num_of_basis) containing the integrals of
        the products of the splines' second derivatives.

    Note:
        ``r`` is detached and moved to the CPU before being handed to SciPy,
        so no gradient flows through the returned tensors.
    """
    # Detached CPU view: a tensor that requires grad or lives on an
    # accelerator would otherwise make ``Tensor.numpy()`` raise.
    r = r.detach().cpu().numpy()

    ## degree of spline
    degree = 3

    ## knots of cubic spline; the grid extends as far above r_max as r_min
    ## lies below it
    upper = r_max + (r_max - r_min)
    t = np.linspace(r_min, upper, num_of_basis * 2 + 3)

    ## number of basis functions on the full (padded) knot vector
    n = len(t) - 2 + degree + 1

    ## prepend and append knots
    t = np.concatenate((np.full(degree, r_min), t, np.full(degree, upper)))

    def unit_spline(index):
        """Return the spline whose only non-zero coefficient is ``index``."""
        coeff = np.zeros(n)
        coeff[index] = 1.0
        return BSpline(t, coeff, degree, extrapolate=True)

    # Keep the leading basis functions only and leave out the second one.
    # Only the splines that are kept are constructed.
    kept = [i for i in range(n)[: -(n // 2 + 2)] if i != 1]
    splines = [unit_spline(i) for i in kept]

    inside_core = r <= r_min
    columns = []
    for position, spline in enumerate(splines):
        values = spline(r)
        if position != 0:
            # only the first basis function may be non-zero for r <= r_min
            values[inside_core] = 0.0
        columns.append(values)
    design_matrix = np.array(columns).T

    if not omega:
        return torch.from_numpy(design_matrix)

    size = len(splines)
    penalty = np.zeros((size, size))

    # Differentiate each spline once instead of once per (i, j) pair.
    curvature = [spline.derivative(2) for spline in splines]

    # Row and column 0 are zero by construction, so start at index 1.
    for i in range(1, size):
        d2_i = curvature[i]
        for j in range(i, size):
            d2_j = curvature[j]
            integral = quad(
                lambda s: d2_i(s) * d2_j(s), r_min, r_max, limit=10_000
            )[0]
            penalty[i, j] = integral
            penalty[j, i] = integral

    return torch.from_numpy(design_matrix), torch.from_numpy(penalty)
def bs_rmsd(r, r_max, num_of_basis):
    """Compute the design matrix of a customized B-spline for a biasing
    potential on RMSD.

    Args:
        r (Tensor): Distances at which the basis functions are evaluated.
        r_max (float): A cutoff distance. When r > r_max, all basis functions
            are zeros.
        num_of_basis (int): The number of basis functions.

    Returns:
        design_matrix (Tensor): A matrix of dimension (len(r), num_of_basis).

    Note:
        ``r`` is detached and moved to the CPU before being handed to SciPy,
        so no gradient flows through the returned matrix.
    """
    # Detached CPU view: a tensor that requires grad or lives on an
    # accelerator would otherwise make ``Tensor.numpy()`` raise.
    r = r.detach().cpu().numpy()

    ## degree of spline
    degree = 3

    ## knots of cubic spline; the grid starts at zero and extends to 2 * r_max
    r_min = 0.0
    upper = r_max + (r_max - r_min)
    t = np.linspace(r_min, upper, num_of_basis * 2 + 2)

    ## number of basis functions on the full (padded) knot vector
    n = len(t) - 2 + degree + 1

    ## prepend and append knots
    t = np.concatenate((np.full(degree, r_min), t, np.full(degree, upper)))

    # Only the leading basis functions are kept, so only those are built.
    columns = []
    for index in range(n)[: -(n // 2 + 2)]:
        coeff = np.zeros(n)
        coeff[index] = 1.0
        columns.append(BSpline(t, coeff, degree, extrapolate=True)(r))

    return torch.from_numpy(np.array(columns).T)
if __name__ == "__main__":
    # Visual smoke test: plot every basis function of each spline family and
    # save the figures as PDFs under ./output.
    import os

    # savefig does not create missing directories
    os.makedirs("./output", exist_ok=True)

    ## testing functions bs and pbs
    knots = torch.linspace(start=-math.pi, end=math.pi, steps=10)
    knots = knots[1:-1]
    boundary_knots = torch.tensor([-math.pi, math.pi])
    x = torch.linspace(start=-math.pi, end=math.pi, steps=200)
    degree = 3

    design_matrix_bs = bs(x, knots, boundary_knots, degree).numpy()
    design_matrix_pbs = pbs(x, knots, boundary_knots, degree).numpy()

    # the bs figure is drawn without a legend
    fig, axes = plt.subplots()
    for j in range(design_matrix_bs.shape[-1]):
        plt.plot(x, design_matrix_bs[:, j], label=f"{j}", linewidth=3)
    plt.tight_layout()
    fig.savefig("./output/design_matrix_bs.pdf")
    plt.close()

    fig, axes = plt.subplots()
    for j in range(design_matrix_pbs.shape[-1]):
        plt.plot(x, design_matrix_pbs[:, j], label=f"{j}")
    plt.legend()
    plt.tight_layout()
    fig.savefig("./output/design_matrix_pbs.pdf")
    plt.close()

    ## test the function bs_lj
    r_min, r_max = 0.3, 2.0
    num_of_basis = 12
    r = torch.linspace(r_min - 0.05, r_max + 1.0, 1000)
    design_matrix, omega = bs_lj(r, r_min, r_max, num_of_basis, True)

    fig, axes = plt.subplots()
    for j in range(design_matrix.shape[-1]):
        plt.plot(r, design_matrix[:, j], label=f"{j}")
    # mark the knots used by bs_lj that lie at or below r_max
    t = np.linspace(r_min, r_max + (r_max - r_min), num_of_basis * 2 + 3)
    for knot in t[t <= r_max]:
        plt.axvline(knot, linestyle="--")
    plt.legend()
    plt.tight_layout()
    fig.savefig("./output/design_matrix_bs_lj.pdf")
    plt.close()

    ## test the function bs_rmsd
    r_min, r_max = 0.0, 2.0
    num_of_basis = 12
    r = torch.linspace(0.0, r_max + 1.0, 1000)
    design_matrix = bs_rmsd(r, r_max, num_of_basis)

    fig, axes = plt.subplots()
    for j in range(design_matrix.shape[-1]):
        plt.plot(r, design_matrix[:, j], label=f"{j}")
    # bs_rmsd places num_of_basis * 2 + 2 knots (not + 3 as bs_lj does), so
    # the guide lines must use the same count to fall on the real knots
    t = np.linspace(r_min, r_max + (r_max - r_min), num_of_basis * 2 + 2)
    for knot in t[t <= r_max]:
        plt.axvline(knot, linestyle="--")
    plt.legend()
    plt.tight_layout()
    fig.savefig("./output/design_matrix_bs_rmsd.pdf")
    plt.close()