卡尔曼滤波.py 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. #开发时间:2024/6/12 14:29
  2. import numpy as np
  3. import pandas as pd
  4. from filterpy.kalman import KalmanFilter
  5. import matplotlib.pyplot as plt
  6. from matplotlib import rcParams
  # Kalman-filter denoising.
def kalman_denoise(signal, *, measurement_noise=5.0, process_noise=1e-5,
                   initial_covariance=1000.0):
    """Smooth a 1-D signal with a constant-velocity Kalman filter.

    The state is [level, slope]. It is advanced one sample per step with
    F = [[1, 1], [0, 1]], and only the level is observed (H = [1, 0]).

    Args:
        signal: sequence of scalar samples (e.g. one DataFrame column's values).
        measurement_noise: measurement noise variance R (default 5, as before).
        process_noise: variance placed on each diagonal entry of Q
            (default 1e-5, as before).
        initial_covariance: scale of the initial state covariance P
            (default 1000, as before). A large value lets the first
            measurements quickly override the zero initial state.

    Returns:
        np.ndarray of filtered level estimates, one per input sample.
        NaN samples are treated as missing measurements: the filter only
        predicts at those steps, so a gap no longer turns every later output
        into NaN.
    """
    kf = KalmanFilter(dim_x=2, dim_z=1)
    kf.x = np.array([0., 0.])
    kf.F = np.array([[1., 1.],
                     [0., 1.]])
    kf.H = np.array([[1., 0.]])
    kf.P = np.eye(2) * initial_covariance
    # R is dim_z x dim_z; a 1x1 matrix is numerically identical to the scalar.
    kf.R = np.array([[measurement_noise]])
    kf.Q = np.eye(2) * process_noise

    samples = np.asarray(signal, dtype=float)
    filtered_signal = np.empty(samples.shape[0])
    for i, z in enumerate(samples):
        kf.predict()
        # update(None) keeps the prior estimate instead of poisoning x with NaN.
        kf.update(None if np.isnan(z) else z)
        filtered_signal[i] = kf.x[0]
    return filtered_signal
  # Global Matplotlib style applied to every figure this script draws.
# Ticks point inward. 'SimHei' is registered as the sans-serif font (intended
# for Chinese labels). Minus signs use the ASCII hyphen instead of the Unicode
# minus so they render in fonts lacking that glyph.
plt.rcParams.update({
    'xtick.direction': 'in',
    'ytick.direction': 'in',
    'font.sans-serif': ['SimHei'],
    'axes.unicode_minus': False,
})
# Body text: 20 pt serif (Times New Roman), with STIX glyphs for math text.
# NOTE(review): 'font.family' becomes serif here and 'font.serif' lists only
# Times New Roman, so the SimHei setting above is not used for ordinary text.
# Chinese labels would likely show missing glyphs; confirm if they are needed.
config = {
    "font.family": 'serif',
    "font.size": 20,
    "mathtext.fontset": 'stix',
    "font.serif": ['Times New Roman'],
    'axes.unicode_minus': False,
}
rcParams.update(config)
  # Read the noisy signals, denoise every signal column, and save the result.
fs = 1000  # sampling rate in Hz; not referenced elsewhere in this script
noisy_signal = pd.read_csv('noisy_signals_time.csv')

# Split off an optional 'time' column so only signal columns get filtered.
if 'time' in noisy_signal.columns:
    # Keep time 1-D. DataFrame.insert into a frame with no index yet rejects a
    # 2-D (n, 1) array, which broke CSVs that contain only a time column.
    time_column = noisy_signal['time'].to_numpy()
    other_columns = noisy_signal.drop(columns='time')
else:
    time_column = None
    other_columns = noisy_signal

# Filter each signal column independently. Reusing the source index keeps the
# frame's row count defined even when there are no signal columns.
denoised_signals = pd.DataFrame(
    {column: kalman_denoise(other_columns[column].values)
     for column in other_columns.columns},
    index=other_columns.index,
)

# Put the time column back in front, if there was one.
if time_column is not None:
    denoised_signals.insert(0, 'time', time_column)

# Save the denoised signals to CSV.
denoised_signals.to_csv('denoise_kalman.csv', index=False)
  # Plot the noisy input and the Kalman-filtered output for the first signal.
# Select the first *signal* column by name. The former positional iloc[:, 1]
# only matched it when a 'time' column came first; otherwise it plotted the
# second signal or raised IndexError for a single-signal CSV.
if len(other_columns.columns) > 0:
    first_signal = other_columns.columns[0]
    plt.figure(figsize=(12, 10))
    plt.subplot(2, 1, 1)
    plt.plot(other_columns[first_signal], label='Noisy Signal')
    plt.legend()
    plt.subplot(2, 1, 2)
    plt.plot(denoised_signals[first_signal], label='Kalman Filtered')
    plt.legend()
    plt.tight_layout()
    plt.show()