You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
68 lines
1.8 KiB
Python
68 lines
1.8 KiB
Python
import math
import os
from typing import Type, Union

import keras
from keras.engine.data_adapter import KerasSequenceAdapter

from normal_use import *
from nn_use import FNN_Net, NN_Net
|
|
|
|
|
|
# Regressor classes to train and evaluate. Each is instantiated with no
# arguments and must provide fit(X, y) and predict(X).
Regressors = [NN_Net]

# Params = ['','']

# Type alias meaning "one of the classes listed in Regressors".
# (The previous ``Union[type(Regressors)]`` evaluated type() of the list
# itself, so the alias collapsed to plain ``list``.)
Regressor = Type[Union[tuple(Regressors)]]

# Train/test split shared by train_model/eval_model; assigned by
# train_one_models(). Read keys: 'X_train', 'y_train', 'X_test', 'y_test'.
train_test_data = None
|
|
|
|
|
|
def train_model(id, regType: Regressor):
    """Fit one regressor per target column and save them with joblib.

    Reads ``'X_train'`` and ``'y_train'`` from the module-level
    ``train_test_data`` dict (assigned by ``train_one_models``). For every
    column of ``y_train`` a fresh ``regType()`` instance is fitted on
    ``X_train`` and that column. The resulting ``{column: fitted_model}``
    dict is written to ``nn_models/<regType.__name__>.model``.

    Parameters
    ----------
    id : int
        Position of ``regType`` in ``Regressors``. Currently unused; kept so
        existing positional callers keep working.
    regType : class
        Regressor class, constructed with no arguments and exposing
        ``fit(X, y)``.

    Raises
    ------
    RuntimeError
        If ``train_test_data`` has not been assigned yet.
    """
    if train_test_data is None:
        raise RuntimeError(
            "train_test_data is not set; call train_one_models() first")

    X, ys = train_test_data['X_train'], train_test_data['y_train']

    # Validation only: the converted arrays returned by check_X_y are
    # discarded so the column-labelled ``ys`` can be indexed below.
    check_X_y(X, ys, multi_output=True)

    models = {}
    for target_col in ys.columns:
        y = ys[target_col]
        reg = regType()
        reg.fit(X, y)
        models[target_col] = reg
        print(regType.__name__, target_col)

    # joblib.dump does not create missing parent directories.
    os.makedirs("nn_models", exist_ok=True)
    # NOTE(review): this pickles the fitted models; confirm the regressor
    # classes (e.g. Keras-backed ones) are picklable in the Keras version used.
    joblib.dump(models, f"nn_models/{regType.__name__}.model")
|
|
|
|
|
|
def eval_model(regType: Regressor):
    """Load the saved per-target models for ``regType`` and print test metrics.

    Loads ``nn_models/<regType.__name__>.model`` (the dict written by
    ``train_model``). Each model predicts on ``train_test_data['X_test']``,
    and the predictions are compared with the matching column of
    ``train_test_data['y_test']``. Prints a DataFrame with RMSE and R^2 per
    target, followed by the value returned by ``average_R2``.

    Parameters
    ----------
    regType : class
        Regressor class whose saved models should be evaluated.

    Raises
    ------
    RuntimeError
        If ``train_test_data`` has not been assigned yet.
    """
    if train_test_data is None:
        raise RuntimeError(
            "train_test_data is not set; call train_one_models() first")

    models = joblib.load(f"nn_models/{regType.__name__}.model")
    X, ys = train_test_data['X_test'], train_test_data['y_test']

    evals = []
    for target_col, reg in models.items():
        y_hat = reg.predict(X)  # predicted values
        y = ys[target_col]      # ground-truth values
        # RMSE as sqrt(MSE): the ``squared=False`` keyword of
        # mean_squared_error was deprecated in scikit-learn 1.4 and removed
        # in 1.6, so this form works across versions.
        rmse = math.sqrt(metrics.mean_squared_error(y, y_hat))
        r2 = metrics.r2_score(y, y_hat)
        evals.append({'Error': target_col, 'RMSE': rmse, 'R^2': r2})

    print(regType.__name__)
    print(pd.DataFrame(evals))
    print("Average R2: ", average_R2(evals))
|
|
|
|
|
|
def train_one_models(trainsets):
    """
    Description
    -----------
    Train every regressor listed in ``Regressors``, then evaluate each one.

    Parameters
    ----------
    trainsets : dict
        Train/test split, e.g. loaded with joblib from the dataset produced
        by create_datas. Must contain the keys 'X_train', 'y_train',
        'X_test' and 'y_test'. The y entries need a ``.columns`` attribute,
        because one model is trained per column.

    Returns
    -------
    None
        Fitted models are saved in the "nn_models" folder, and per-target
        RMSE / R^2 plus the average R2 are printed.

    Raises
    ------
    KeyError
        If ``trainsets`` lacks any of the required keys.
    """
    global train_test_data

    required = ('X_train', 'y_train', 'X_test', 'y_test')
    missing = [key for key in required if key not in trainsets]
    if missing:
        raise KeyError(f"trainsets is missing required keys: {missing}")

    train_test_data = trainsets

    # All regressors are trained before any of them is evaluated.
    for i, reg in enumerate(Regressors):
        train_model(i, reg)

    for reg in Regressors:
        eval_model(reg)
|
|
|
|
|