Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
190 changes: 190 additions & 0 deletions crates/RustQuant_math/src/interpolation/lagrange_interpolator.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// RustQuant: A Rust library for quantitative finance tools.
// Copyright (C) 2023 https://github.com/avhz
// Dual licensed under Apache 2.0 and MIT.
// See:
// - LICENSE-APACHE.md
// - LICENSE-MIT.md
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

//! Module containing functionality for interpolation.

use crate::interpolation::{InterpolationIndex, InterpolationValue, Interpolator};
use RustQuant_error::RustQuantError;

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// STRUCTS & ENUMS
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Linear Interpolator.
pub struct LagrangeInterpolator<IndexType, ValueType>
where
IndexType: InterpolationIndex<DeltaDiv = ValueType>,
ValueType: InterpolationValue,
{
/// X-axis values for the interpolator.
pub xs: Vec<IndexType>,

/// Y-axis values for the interpolator.
pub ys: Vec<ValueType>,

/// Whether the interpolator has been fitted.
pub fitted: bool,
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// IMPLEMENTATIONS, FUNCTIONS, AND MACROS
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

impl<IndexType, ValueType> LagrangeInterpolator<IndexType, ValueType>
where
IndexType: InterpolationIndex<DeltaDiv = ValueType>,
ValueType: InterpolationValue,
{
/// Create a new LagrangeInterpolator.
///
/// # Errors
/// - `RustQuantError::UnequalLength` if ```xs.length() != ys.length()```.
///
/// # Panics
/// Panics if NaN is in the index.
pub fn new(
xs: Vec<IndexType>,
ys: Vec<ValueType>,
) -> Result<LagrangeInterpolator<IndexType, ValueType>, RustQuantError> {
if xs.len() != ys.len() {
return Err(RustQuantError::UnequalLength);
}

let mut tmp: Vec<_> = xs.into_iter().zip(ys).collect();

tmp.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());

let (xs, ys): (Vec<IndexType>, Vec<ValueType>) = tmp.into_iter().unzip();

Ok(Self {
xs,
ys,
fitted: false,
})
}

fn lagrange_basis(&self, point: IndexType, node: IndexType, index: usize) -> ValueType {
let mut basis: ValueType = ValueType::one();
for (i, x) in self.xs.iter().enumerate() {
if i != index {
basis *= (point - *x) / (node - *x);
}
}
basis
}

fn lagrange_polynomial(&self, point: IndexType) -> ValueType {
let mut polynomial: ValueType = ValueType::zero();
for (i, (x, y)) in self.xs.iter().zip(&self.ys).enumerate() {
polynomial += *y * self.lagrange_basis(point, *x, i);

}
polynomial
}
}

impl<IndexType, ValueType> Interpolator<IndexType, ValueType>
for LagrangeInterpolator<IndexType, ValueType>
where
IndexType: InterpolationIndex<DeltaDiv = ValueType>,
ValueType: InterpolationValue,
{
fn fit(&mut self) -> Result<(), RustQuantError> {
self.fitted = true;
Ok(())
}

fn range(&self) -> (IndexType, IndexType) {
(*self.xs.first().unwrap(), *self.xs.last().unwrap())
}

fn add_point(&mut self, point: (IndexType, ValueType)) {
let idx = self.xs.partition_point(|&x| x < point.0);
self.xs.insert(idx, point.0);
self.ys.insert(idx, point.1);
}

fn interpolate(&self, point: IndexType) -> Result<ValueType, RustQuantError> {
let range = self.range();
if point.partial_cmp(&range.0).unwrap() == std::cmp::Ordering::Less
|| point.partial_cmp(&range.1).unwrap() == std::cmp::Ordering::Greater
{
return Err(RustQuantError::OutsideOfRange);
}
if let Ok(idx) = self
.xs
.binary_search_by(|p| p.partial_cmp(&point).expect("Cannot compare values."))
{
return Ok(self.ys[idx]);
}

Ok(self.lagrange_polynomial(point))
}
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Unit tests
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

#[cfg(test)]
mod tests_lagrange_interpolation {
use super::*;
use RustQuant_utils::{assert_approx_equal, RUSTQUANT_EPSILON};

#[test]
fn test_lagrange_interpolation() {
let xs: Vec<f64> = vec![0., 1., 2., 3., 4.];
let ys: Vec<f64> = vec![1., 2., 4., 8., 16.];

let mut interpolator = LagrangeInterpolator::new(xs, ys).unwrap();
let _ = interpolator.fit();

assert_approx_equal!(
5.6484375,
interpolator.interpolate(2.5).unwrap(),
RUSTQUANT_EPSILON
);
}

#[test]
fn test_lagrange_interpolation_dates() {
let now: time::OffsetDateTime = time::OffsetDateTime::now_utc();

let xs: Vec<time::OffsetDateTime> = vec![
now,
now + time::Duration::days(1),
now + time::Duration::days(2),
now + time::Duration::days(3),
now + time::Duration::days(4),
];
let ys: Vec<f64> = vec![1., 2., 4., 8., 16.];

let mut interpolator: LagrangeInterpolator<time::OffsetDateTime, f64> = LagrangeInterpolator::new(xs.clone(), ys).unwrap();
let _ = interpolator.fit();

assert_approx_equal!(
5.6484375,
interpolator
.interpolate(xs[2] + time::Duration::hours(12))
.unwrap(),
RUSTQUANT_EPSILON
);
}

#[test]
fn test_linear_interpolation_out_of_range() {
let xs: Vec<f64> = vec![1., 2., 3., 4., 5.];
let ys: Vec<f64> = vec![1., 2., 3., 4., 5.];

let mut interpolator: LagrangeInterpolator<f64, f64> = LagrangeInterpolator::new(xs, ys).unwrap();
let _ = interpolator.fit();

assert!(interpolator.interpolate(6.).is_err());
}
}
3 changes: 3 additions & 0 deletions crates/RustQuant_math/src/interpolation/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ pub use exponential_interpolator::*;
pub mod b_splines;
pub use b_splines::*;

pub mod lagrange_interpolator;
pub use lagrange_interpolator::*;

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down
Loading