You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

68 lines
1.8 KiB
Python

import os
from typing import Type, Union

import keras
from keras.engine.data_adapter import KerasSequenceAdapter

from normal_use import *
from nn_use import FNN_Net, NN_Net
# Regressor types that train_one_models() trains and then evaluates, in order.
Regressors = [NN_Net]
# Params = ['','']
# Annotation alias for the `regType` arguments below: the regressor type
# itself (it is called with no arguments to build a model), not an instance.
# The previous `Union[type(Regressors)]` evaluated to plain `list`, because
# type() of a list object is `list`.
Regressor = Type[Union[FNN_Net, NN_Net]]
# Train/test splits shared by train_model() and eval_model(); assigned by
# train_one_models() before either of them runs.
train_test_data = None
def train_model(id, regType: Regressor):
    """
    Description
    -----------
    Fit one ``regType`` model per target column and save them all together.

    Parameters
    ----------
    id : int
        Position of ``regType`` in ``Regressors``. Currently unused; the name
        shadows the builtin but is kept so existing callers keep working.
    regType : Regressor
        Regressor type; called with no arguments to build a fresh model for
        every target column.

    Returns
    -------
    NO returns, but a dict {target column: fitted model} is dumped to
    "nn_models/<regType.__name__>.model" and progress is printed.
    """
    X, ys = train_test_data['X_train'], train_test_data['y_train']
    # Validation only: the value returned by check_X_y is not used.
    check_X_y(X, ys, multi_output=True)
    models = {}
    for target_col in ys.columns:
        reg = regType()
        reg.fit(X, ys[target_col])
        models[target_col] = reg
        print(regType.__name__, target_col)
    # Make sure the output folder exists first: otherwise the dump fails on a
    # fresh checkout only after all the training work has been done.
    model_dir = "nn_models"
    os.makedirs(model_dir, exist_ok=True)
    joblib.dump(models, f"{model_dir}/{regType.__name__}.model")
    # keras.models.save_model(models, f"nn_models/{regType.__name__}.model")
def eval_model(regType: Regressor):
    """
    Description
    -----------
    Load the models saved for ``regType`` and print their test-set scores.

    Parameters
    ----------
    regType : Regressor
        Regressor type whose "nn_models/<regType.__name__>.model" file is
        loaded; the file holds a dict {target column: fitted model}.

    Returns
    -------
    NO returns, but prints the regressor name, a table with one row per
    target column (RMSE and R^2) and the average R2.
    """
    # joblib.load unpickles the file, so only load model files written by
    # this project (never files from an untrusted source).
    models = joblib.load(f"nn_models/{regType.__name__}.model")
    X, ys = train_test_data['X_test'], train_test_data['y_test']
    evals = []
    for target_col, reg in models.items():
        y_hat = reg.predict(X)  # predicted values
        y = ys[target_col]  # observed values
        # RMSE as sqrt(MSE) instead of `squared=False`: assuming `metrics` is
        # sklearn.metrics, that keyword was deprecated in scikit-learn 1.4
        # and removed in 1.6, while this form works on every version and is
        # equivalent for a single target column.
        rmse = metrics.mean_squared_error(y, y_hat) ** 0.5
        r2 = metrics.r2_score(y, y_hat)
        evals.append({'Error': target_col, 'RMSE': rmse, 'R^2': r2})
    print(regType.__name__)
    print(pd.DataFrame(evals))
    print("Average R2: ", average_R2(evals))
def train_one_models(trainsets):
    """
    Description
    -----------
    Entry point: train every regressor in ``Regressors``, then evaluate each.

    Parameters
    ----------
    trainsets : dict
        Train/test splits, read through the keys 'X_train', 'y_train',
        'X_test' and 'y_test'. Stored in the module-level
        ``train_test_data`` so train_model() and eval_model() can read it.

    Returns
    -------
    NO returns, but the trained models are saved under "nn_models" and the
    evaluation scores are printed on screen.
    """
    global train_test_data
    train_test_data = trainsets
    # Finish all training before any evaluation starts.
    for position, regressor_type in enumerate(Regressors):
        train_model(position, regressor_type)
    for regressor_type in Regressors:
        eval_model(regressor_type)