pm21-dragon/lectures/lecture-13/Kalman Filter.ipynb

272 lines
7.4 KiB
Plaintext
Raw Normal View History

2025-01-24 03:28:43 -05:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Kalman filter\n",
"\n",
"We will go through this notebook after a PowerPoint lecture.\n",
"\n",
"In the PowerPoint lecture, we showed -- amongst other images -- figures from [this PDF](https://synapticlab.co.kr/attachment/cfile1.uf@2737C54B590907BA0D46CE.pdf) ([doi:10.1109/MSP.2012.2203621](https://doi.org/10.1109/MSP.2012.2203621)).\n",
"\n",
"As further reading, I recommend this webpage: [How a Kalman filter works, in pictures](https://www.bzarg.com/p/how-a-kalman-filter-works-in-pictures/).\n",
"\n",
"For a wonderful, if hardcore, Python-based Kalman filter library and documentation, please see [FilterPy](https://filterpy.readthedocs.io/en/latest/)."
]
},
{
"cell_type": "code",
2025-01-31 03:53:06 -05:00
"execution_count": null,
2025-01-24 03:28:43 -05:00
"metadata": {},
2025-01-31 03:53:06 -05:00
"outputs": [],
2025-01-24 03:28:43 -05:00
"source": [
"# %pip (unlike !pip) installs into the environment of the running kernel.\n",
"%pip install adskalman"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Kalman filter example in Python"
]
},
{
"cell_type": "code",
2025-01-31 03:53:06 -05:00
"execution_count": null,
2025-01-24 03:28:43 -05:00
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import adskalman.adskalman as adskalman\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
2025-01-31 03:53:06 -05:00
"execution_count": null,
2025-01-24 03:28:43 -05:00
"metadata": {},
"outputs": [],
"source": [
"def column(arr):\n",
"    \"\"\"Convert a 1D array-like to a 2D vertical (column) array.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    arr : array_like\n",
"        One-dimensional input such as a tuple, list or 1D ndarray.\n",
"\n",
"    Returns\n",
"    -------\n",
"    numpy.ndarray\n",
"        A new array of shape ``(len(arr), 1)``.\n",
"\n",
"    Raises\n",
"    ------\n",
"    AssertionError\n",
"        If ``arr`` is not one-dimensional.\n",
"\n",
"    Examples\n",
"    --------\n",
"    >>> column((1, 2, 3))\n",
"    array([[1],\n",
"           [2],\n",
"           [3]])\n",
"    \"\"\"\n",
"    arr = np.array(arr)  # copies, so the result never aliases the input\n",
"    assert arr.ndim == 1, f'expected a 1D array, got ndim={arr.ndim}'\n",
"    return arr[:, np.newaxis]"
]
},
{
"cell_type": "code",
2025-01-31 03:53:06 -05:00
"execution_count": null,
2025-01-24 03:28:43 -05:00
"metadata": {},
"outputs": [],
"source": [
"# Create a 4-dimensional constant-velocity state space model with state\n",
"# vector (x, y, xvel, yvel).\n",
"dt = 0.01  # time step between samples\n",
"true_initial_state = column([0.0, 0.0, 10.0, -5.0])\n",
"\n",
"# Motion model (F in Wikipedia's notation): each position advances by its\n",
"# velocity times dt; velocities are carried over unchanged.\n",
"motion_model = np.array([[1.0, 0.0, dt, 0.0],\n",
"                         [0.0, 1.0, 0.0, dt],\n",
"                         [0.0, 0.0, 1.0, 0.0],\n",
"                         [0.0, 0.0, 0.0, 1.0]])\n",
"\n",
"# Motion noise covariance (Q in Wikipedia's notation). For a constant-velocity\n",
"# model driven by white-noise acceleration it must take this specific form to be\n",
"# correct; only its overall scale (the noise intensity below) is a free choice.\n",
"motion_noise_intensity = 1000.0  # larger values -> more erratic trajectories\n",
"T3 = dt**3/3\n",
"T2 = dt**2/2\n",
"motion_noise_covariance = motion_noise_intensity*np.array([[T3, 0.0, T2, 0.0],\n",
"                                                           [0.0, T3, 0.0, T2],\n",
"                                                           [T2, 0.0, dt, 0.0],\n",
"                                                           [0.0, T2, 0.0, dt]])"
]
},
{
"cell_type": "code",
2025-01-31 03:53:06 -05:00
"execution_count": null,
2025-01-24 03:28:43 -05:00
"metadata": {},
"outputs": [],
"source": [
"# Build the time vector from an explicit step count. np.arange with a float\n",
"# step can add or drop a sample because of rounding in duration/dt.\n",
"duration = 0.5\n",
"n_steps = int(round(duration / dt))\n",
"t = np.arange(n_steps) * dt"
]
},
{
"cell_type": "code",
2025-01-31 03:53:06 -05:00
"execution_count": null,
2025-01-24 03:28:43 -05:00
"metadata": {},
"outputs": [],
"source": [
"# Create some fake data with our model.\n",
"# Seed NumPy's global random number generator, which adskalman.rand_mvn\n",
"# presumably draws from, so that every run produces the same data.\n",
"np.random.seed(0)\n",
"current_state = true_initial_state\n",
"state = []\n",
"for _ in t:\n",
"    # Record the state, then propagate it one step: x_{k+1} = F x_k + w_k.\n",
"    state.append(current_state[:, 0])\n",
"    noise_sample = adskalman.rand_mvn(np.zeros(4), motion_noise_covariance, 1).T\n",
"    current_state = motion_model @ current_state + noise_sample\n",
"state = np.array(state)  # one row per time step"
]
},
{
"cell_type": "code",
2025-01-31 03:53:06 -05:00
"execution_count": null,
2025-01-24 03:28:43 -05:00
"metadata": {},
2025-01-31 03:53:06 -05:00
"outputs": [],
2025-01-24 03:28:43 -05:00
"source": [
"fig, ax = plt.subplots()\n",
"ax.plot(state[:, 0], state[:, 1], '.-')\n",
"ax.set(xlabel='x', ylabel='y', title='True (simulated) trajectory');"
]
},
{
"cell_type": "code",
2025-01-31 03:53:06 -05:00
"execution_count": null,
2025-01-24 03:28:43 -05:00
"metadata": {},
"outputs": [],
"source": [
"# Create observation model. We only observe the position.\n",
"# H picks (x, y) out of the state (x, y, xvel, yvel): a 2x4 identity.\n",
"observation_model = np.eye(2, 4)\n",
"# R: independent x and y measurement noise, each with variance 0.01.\n",
"observation_noise_covariance = 0.01 * np.eye(2)"
]
},
{
"cell_type": "code",
2025-01-31 03:53:06 -05:00
"execution_count": null,
2025-01-24 03:28:43 -05:00
"metadata": {},
"outputs": [],
"source": [
"# Create noisy observations: z_k = H x_k + v_k, one row per true state.\n",
"observation = []\n",
"for current_state in state:\n",
"    projected = observation_model @ column(current_state)\n",
"    noise_sample = adskalman.rand_mvn(np.zeros(2), observation_noise_covariance, 1).T\n",
"    observation.append((projected + noise_sample)[:, 0])\n",
"observation = np.array(observation)"
]
},
{
"cell_type": "code",
2025-01-31 03:53:06 -05:00
"execution_count": null,
2025-01-24 03:28:43 -05:00
"metadata": {},
2025-01-31 03:53:06 -05:00
"outputs": [],
2025-01-24 03:28:43 -05:00
"source": [
"fig, ax = plt.subplots()\n",
"ax.plot(observation[:, 0], observation[:, 1], '.-')\n",
"ax.set(xlabel='x', ylabel='y', title='Noisy position observations');"
]
},
{
"cell_type": "code",
2025-01-31 03:53:06 -05:00
"execution_count": null,
2025-01-24 03:28:43 -05:00
"metadata": {},
"outputs": [],
"source": [
"# Run kalman filter on the noisy observations.\n",
"# Short names matching the usual Kalman filter notation:\n",
"F = motion_model                  # state transition model\n",
"H = observation_model             # observation model\n",
"Q = motion_noise_covariance       # motion (process) noise covariance\n",
"R = observation_noise_covariance  # observation noise covariance\n",
"y = observation                   # same array as observation (not a copy)\n",
"# Initial estimate: the true initial state (known because the data are\n",
"# simulated) with a covariance of 0.1 times the identity.\n",
"initx = true_initial_state[:, 0]\n",
"initV = np.eye(4) * 0.1"
]
},
{
"cell_type": "code",
2025-01-31 03:53:06 -05:00
"execution_count": null,
2025-01-24 03:28:43 -05:00
"metadata": {},
"outputs": [],
"source": [
"# Filter the observations one time step at a time; only the first step is\n",
"# flagged as the initial one.\n",
"kfilt = adskalman.KalmanFilter(F, H, Q, R, initx, initV)\n",
"step_results = [kfilt.step(y=y_i, isinitial=(i == 0)) for i, y_i in enumerate(y)]\n",
"# Each step yields a (state estimate, covariance) pair; stack each kind.\n",
"xfilt = np.array([x_i for x_i, _ in step_results])\n",
"Vfilt = np.array([V_i for _, V_i in step_results])"
]
},
{
"cell_type": "code",
2025-01-31 03:53:06 -05:00
"execution_count": null,
2025-01-24 03:28:43 -05:00
"metadata": {},
2025-01-31 03:53:06 -05:00
"outputs": [],
2025-01-24 03:28:43 -05:00
"source": [
"fig, ax = plt.subplots()\n",
"ax.plot(state[:, 0], state[:, 1], '.-', label='true')\n",
"ax.plot(observation[:, 0], observation[:, 1], '.-', label='observed')\n",
"ax.plot(xfilt[:, 0], xfilt[:, 1], '.-', label='kalman filtered')\n",
"ax.set(xlabel='x', ylabel='y', title='Kalman filter with complete observations')\n",
"ax.legend();"
]
},
{
"cell_type": "code",
2025-01-31 03:53:06 -05:00
"execution_count": null,
2025-01-24 03:28:43 -05:00
"metadata": {},
"outputs": [],
"source": [
"# Now run again with missing data: blank out a stretch of observations.\n",
"# Work on a copy so that `observation` (and `y`, which is the same array)\n",
"# keep the complete data; mutating them in place would silently change the\n",
"# results of the earlier cells whenever those are re-run.\n",
"missing = slice(20, 30)  # time steps without an observation\n",
"y_missing = observation.copy()\n",
"y_missing[missing, :] = np.nan\n",
"kfilt = adskalman.KalmanFilter(F, H, Q, R, initx, initV)\n",
"xfilt = []\n",
"Vfilt = []\n",
"for i, y_i in enumerate(y_missing):\n",
"    is_initial = i == 0\n",
"    xfilt_i, Vfilt_i = kfilt.step(y=y_i, isinitial=is_initial)\n",
"    xfilt.append(xfilt_i)\n",
"    Vfilt.append(Vfilt_i)\n",
"xfilt = np.array(xfilt)\n",
"Vfilt = np.array(Vfilt)"
]
},
{
"cell_type": "code",
2025-01-31 03:53:06 -05:00
"execution_count": null,
2025-01-24 03:28:43 -05:00
"metadata": {},
2025-01-31 03:53:06 -05:00
"outputs": [],
2025-01-24 03:28:43 -05:00
"source": [
"# The observed curve has a gap where the observations were blanked out.\n",
"fig, ax = plt.subplots()\n",
"ax.plot(state[:, 0], state[:, 1], '.-', label='true')\n",
"ax.plot(y_missing[:, 0], y_missing[:, 1], '.-', label='observed')\n",
"ax.plot(xfilt[:, 0], xfilt[:, 1], '.-', label='kalman filtered')\n",
"ax.set(xlabel='x', ylabel='y', title='Kalman filter with missing observations')\n",
"ax.legend();"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.10"
}
},
"nbformat": 4,
"nbformat_minor": 4
}