You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
ssa/build_training_data.py

75 lines
2.5 KiB
Python

from os import replace
from astropy.time.core import Time, TimeDelta
import pandas as pd
import random
from tqdm import tqdm
from poliastro.bodies import Earth
from poliastro.twobody import Orbit
import astropy.units as U
from poliastro.twobody.propagation import cowell
from joblib import Parallel, delayed
def rv2list(rv):
    """Flatten a ``(position, velocity)`` pair into one plain list.

    Args:
        rv: Two-element iterable ``(r, v)``. ``r`` must be convertible to
            metres and ``v`` to metres per second via ``.to(...)``.

    Returns:
        A list holding the components of ``r`` (in m) followed by the
        components of ``v`` (in m/s), with the units stripped.
    """
    position, velocity = rv
    position_m = position.to(U.m).to_value()
    velocity_m_s = velocity.to(U.m / U.s).to_value()
    flat = list(position_m)
    flat.extend(velocity_m_s)
    return flat
# Column layout of the frame returned by build_training_data:
# 8 source columns, 1 elapsed-time column, 6 predicted and 6 error components.
_TRAINING_COLUMNS = [
    "id",
    "start_epoch",
    "start_r_x",
    "start_r_y",
    "start_r_z",
    "start_v_x",
    "start_v_y",
    "start_v_z",
    "elapsed_seconds",
    "pred_r_x",
    "pred_r_y",
    "pred_r_z",
    "pred_v_x",
    "pred_v_y",
    "pred_v_z",
    "err_r_x",
    "err_r_y",
    "err_r_z",
    "err_v_x",
    "err_v_y",
    "err_v_z",
]


def _propagation_error(src, dst):
    """Propagate the state in *src* to the epoch of *dst* and compare.

    Both arguments are single rows given as arrays laid out positionally as
    ``[id, epoch, r_x, r_y, r_z, v_x, v_y, v_z]``; positions are treated as
    metres and velocities as metres per second.

    Returns:
        A tuple ``(elapsed_seconds, predicted_rv, rv_error)`` where
        ``predicted_rv`` is the Cowell-propagated state at the epoch of
        *dst* and ``rv_error`` is ``recorded - predicted`` per component.
    """
    start = Orbit.from_vectors(Earth, src[2:5] * U.m, src[5:8] * U.m / U.s,
                               Time(src[1]))
    recorded = Orbit.from_vectors(Earth, dst[2:5] * U.m, dst[5:8] * U.m / U.s,
                                  Time(dst[1]))
    predicted = start.propagate(recorded.epoch, method=cowell)
    rv_pred = rv2list(predicted.rv())
    rv_real = rv2list(recorded.rv())
    rv_diff = [real - pred for real, pred in zip(rv_real, rv_pred)]
    # total_seconds() covers the whole interval; the ``.seconds`` attribute
    # is only the sub-day component and would drop any full days.
    elapsed = int((dst[1] - src[1]).total_seconds())
    return elapsed, rv_pred, rv_diff


def build_training_data(df: pd.DataFrame,
                        *,
                        n_starts: int = 50,
                        window: int = 24 * 60,
                        n_targets: int = 100,
                        n_jobs: int = 4) -> pd.DataFrame:
    """Build propagation-error samples from a table of state vectors.

    Random start rows are picked (with replacement). For each start row,
    target rows are sampled (with replacement) from the ``window`` rows
    beginning at the start row. The start state is propagated to every
    target epoch and compared with the state recorded in the target row.

    Args:
        df: Table with exactly eight columns, read by position as
            ``id, epoch, r_x, r_y, r_z, v_x, v_y, v_z``. The epoch values
            must support subtraction yielding an object with
            ``total_seconds()``.
            NOTE(review): rows are presumably ordered by time for one
            object -- confirm against the producer of the input file.
        n_starts: Number of start rows to draw.
        window: Number of consecutive rows, counted from the start row,
            that targets are drawn from.
        n_targets: Number of target rows drawn per start row.
        n_jobs: Worker count passed to ``joblib.Parallel``.

    Returns:
        A DataFrame with one row per (start, target) pair: the eight source
        columns, ``elapsed_seconds``, the predicted state ``pred_*`` and the
        error ``err_*`` (recorded minus predicted).

    Raises:
        IndexError: If ``df`` is empty and ``n_starts`` is positive.
    """
    # Draw *positions*, not index labels: the rows are read with ``iloc``
    # below, so labels would be wrong for any non-default index.
    start_positions = random.choices(range(len(df)), k=n_starts)
    rows = []
    for pos in tqdm(start_positions, position=0):
        src = df.iloc[pos].to_numpy()
        targets = df.iloc[pos:pos + window].sample(n=n_targets,
                                                   replace=True).to_numpy()
        diffs = Parallel(n_jobs=n_jobs)(
            delayed(_propagation_error)(src, dst)
            for dst in tqdm(targets, position=1, leave=False))
        rows.extend([*src, dt, *pred, *err] for dt, pred, err in diffs)
    return pd.DataFrame(rows, columns=_TRAINING_COLUMNS)
if __name__ == "__main__":
    # Script entry point: load the input table, derive the training set
    # from it and write the result next to it as parquet.
    source_frame = pd.read_parquet("1.pq")
    training_frame = build_training_data(source_frame)
    training_frame.to_parquet("data.pq")