You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
152 lines
5.9 KiB
Python
152 lines
5.9 KiB
Python
# Copyright 2020 IBM Corporation
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
# Standard libraries
|
|
import os
|
|
import logging
|
|
import itertools
|
|
import datetime as dt
|
|
# Data processing libraries
|
|
import numpy as np
|
|
import pandas as pd
|
|
from tqdm import tqdm
|
|
from joblib import Parallel, delayed
|
|
# Physics model
|
|
from orbit_prediction import get_state_vect_cols
|
|
from orbit_prediction.physics_model import PhysicsModel
|
|
|
|
# The root log level comes from the LOGLEVEL environment variable and falls
# back to INFO when the variable is not set
logging.basicConfig(level=os.getenv("LOGLEVEL", "INFO"))
# Module-level logger shared by every function in this file
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def predict_orbit(window):
    """Predict the state vectors of each future timestep in the given `window`
    using a physics astrodynamics model.

    The last row of `window` is taken as the starting state. A `PhysicsModel`
    is fit on that single state and then asked to predict the state vector at
    the elapsed time of every other row in the window.

    :param window: The window of timesteps to predict the orbit of the ASO
        for. It is indexed by epoch and reverse sorted by time.
    :type window: pandas.DataFrame

    :return: The original timestep rows (excluding the starting row) with the
        starting state, the elapsed seconds and the predicted state vectors
        added. If the predictions cannot be assigned, the error is logged and
        the rows are returned without the `physics_pred_*` columns.
    :rtype: pandas.DataFrame
    """
    # The `window` DataFrame is reverse sorted by time so the starting position
    # is the last row
    start_row = window.iloc[-1]
    # The frame is indexed by epoch, so the row's name is its epoch
    start_epoch = start_row.name
    # Get the column names of the state vector components
    state_vect_comps = get_state_vect_cols()
    # Extract the position and velocity vectors as a numpy array
    start_state_vect = start_row[state_vect_comps].to_numpy()
    start_state = np.concatenate((np.array([start_epoch]),
                                  start_state_vect))
    # Build an orbit model
    orbit_model = PhysicsModel()
    orbit_model.fit([start_state])
    future_rows = window.iloc[:-1].reset_index()
    # We add the epoch and the state vector components of the starting row
    # to the rows we will use the physics model to make predictions for
    future_rows['start_epoch'] = start_epoch
    for svc in state_vect_comps:
        future_rows[f'start_{svc}'] = start_row[svc]
    # Calculate the elapsed time from the starting epoch to the epoch of all
    # the rows to make predictions for
    time_deltas = future_rows.epoch - future_rows.start_epoch
    elapsed_seconds = time_deltas.dt.total_seconds()
    future_rows['elapsed_seconds'] = elapsed_seconds
    physics_cols = [f'physics_pred_{svc}' for svc in state_vect_comps]
    # Predict the state vectors for each of the rows in the "future"
    predicted_orbits = orbit_model.predict([elapsed_seconds.to_numpy()])
    try:
        future_rows[physics_cols] = predicted_orbits[0]
    except Exception:
        # Best effort: keep processing the remaining windows, but record the
        # failure with its traceback through the module logger instead of
        # printing to stdout
        logger.exception(
            'Could not assign physics predictions for the window starting '
            'at %s (%d future rows)', start_epoch, len(future_rows))
    return future_rows
|
|
|
|
|
|
def predict_orbits(df, last_n_days, n_pred_days):
    """Run the physics astrodynamics model over every prediction window of
    the ASOs found in the provided DataFrame.

    The observations are grouped by `aso_id` and by consecutive time bins of
    `n_pred_days` days on the `epoch` column. Each group is handed to
    `predict_orbit`, and the groups are processed in parallel.

    :param df: The DataFrame containing the observed orbital state vectors
        to use to make predictions from
    :type df: pandas.DataFrame

    :param last_n_days: Keep only the rows whose epoch lies within the last
        `n` days before the newest epoch. A falsy value such as `None` keeps
        all the rows, but this may take a very long time to run
    :type last_n_days: int

    :param n_pred_days: The length in days of each prediction window
    :type n_pred_days: int

    :return: The rows produced by `predict_orbit` for every window, joined
        into a single DataFrame with a fresh integer index
    :rtype: pandas.DataFrame
    """
    if last_n_days:
        newest_epoch = df.epoch.max()
        oldest_allowed = newest_epoch - dt.timedelta(days=last_n_days)
        df = df[df.epoch >= oldest_allowed]
    # Newest observations first, so the earliest row of a window comes last
    newest_first = df.sort_values('epoch', ascending=False)
    by_epoch = newest_first.set_index('epoch')
    # One window per ASO and per `n_pred_days`-day bin of the epoch index
    grouping = ['aso_id', pd.Grouper(freq=f'{n_pred_days}d')]
    windows = [frame for _, frame in by_epoch.groupby(grouping)]
    # Predict the orbits of all the windows in parallel on every core
    jobs = (delayed(predict_orbit)(w) for w in tqdm(windows))
    per_window_preds = Parallel(n_jobs=-1)(jobs)
    # Stack the per-window results into a single DataFrame
    return pd.concat(per_window_preds).reset_index(drop=True)
|
|
|
|
|
|
def calc_physics_error(df):
    """Add the per-component error between the physics model predictions and
    the ground truth observations.

    For each of `r_x`, `r_y`, `r_z`, `v_x`, `v_y` and `v_z`, a column named
    `physics_err_<component>` is written holding
    `physics_pred_<component> - <component>`.

    :param df: The DataFrame containing the ground truth observations and the
        physics model predictions
    :type df: pandas.DataFrame

    :return: The input DataFrame, modified in place, with the physical model
        error columns added
    :rtype: pandas.DataFrame
    """
    # Position components first, then velocity, each in x, y, z order
    component_names = (f'{vect}_{comp}' for vect in 'rv' for comp in 'xyz')
    for name in component_names:
        predicted = df[f'physics_pred_{name}']
        observed = df[name]
        df[f'physics_err_{name}'] = predicted - observed
    return df
|
|
|
|
|
|
def run(input_path="/home/lj020/Downloads/data6.parquet",
        output_path="/home/lj020/Downloads/train_result.parquet",
        text_output_path="/home/lj020/Downloads/train_result.txt",
        last_n_days=None,
        n_pred_days=5):
    """Builds a training data set of physics model errors.

    Loads the observations from a parquet file, predicts the orbits with the
    physics model, computes the prediction errors and writes the result as a
    parquet file and, optionally, as a plain text dump.

    :param input_path: The parquet file holding the observed state vectors
    :type input_path: str

    :param output_path: The parquet file the resulting DataFrame is written to
    :type output_path: str

    :param text_output_path: The text file the resulting DataFrame is dumped
        to with `DataFrame.to_string`. Pass `None` to skip the text dump
    :type text_output_path: str

    :param last_n_days: Forwarded to `predict_orbits`; `None` uses all rows
    :type last_n_days: int

    :param n_pred_days: Forwarded to `predict_orbits`; the number of days in
        each prediction window
    :type n_pred_days: int
    """
    logger.info('Loading input DataFrame...')
    input_df = pd.read_parquet(input_path)
    logger.info('Predicting orbits...')
    physics_pred_df = predict_orbits(input_df,
                                     last_n_days=last_n_days,
                                     n_pred_days=n_pred_days)
    logger.info('Calculating physical model error...')
    physics_pred_df = calc_physics_error(physics_pred_df)
    logger.info('Serializing results...')
    physics_pred_df.to_parquet(output_path)

    if text_output_path is not None:
        # Explicit encoding so the dump does not depend on the platform locale
        with open(text_output_path, "w", encoding="utf-8") as f:
            f.write(physics_pred_df.to_string())
|
|
|
|
# Only run the pipeline when executed as a script. Without the guard, merely
# importing this module would start the whole pipeline, and worker processes
# that re-import the main module would run it again.
if __name__ == "__main__":
    run()