# pywatts/pywatts/test_kcross_train.py
# Listing captured from a repository web view (blame): 42 lines, 802 B, Python.
# Blame timestamp of the lines that follow: 2018-08-06 13:28:27 +02:00
import sys

import peewee
import tensorflow as tf

import pywatts.db
import pywatts.neural
import pywatts.routines
from pywatts import kcross
# --- Run configuration -----------------------------------------------------

# Data selection. The upper bound used when requesting station ids from the
# database (see the rows_to_df call below).
NUM_STATIONS_FROM_DB = 75
# NOTE(review): not referenced anywhere in this script -- confirm whether it
# is still needed.
NUM_EVAL_STATIONS = 40

# Cross-validation / training parameters.
K = 10            # passed to kcross.split (presumably the number of folds)
TRAIN_STEPS = 20  # passed to kcross.train as the number of training steps

# Stage toggles: run the training stage and/or plot its results.
TRAIN = True
PLOT = True
# Load the rows for the selected station ids into a single DataFrame.
# NOTE(review): range(1, NUM_STATIONS_FROM_DB) yields ids
# 1..NUM_STATIONS_FROM_DB-1 (74 ids for a value of 75), one fewer than the
# constant's name suggests -- confirm whether the bound should be
# NUM_STATIONS_FROM_DB + 1.
df = pywatts.db.rows_to_df(list(range(1, NUM_STATIONS_FROM_DB)))

# Define feature columns and initialize the regressor: one numeric column
# per input feature, named "0" .. "335".
# NOTE(review): 336 must match the number of feature columns the data
# pipeline produces -- confirm against kcross.split / rows_to_df.
feature_col = [tf.feature_column.numeric_column(str(idx)) for idx in range(336)]
n = pywatts.neural.Net(feature_cols=feature_col)

# Split df into training / evaluation features and targets (K is forwarded
# to kcross.split; presumably the number of cross-validation folds).
X_train, y_train, X_eval, y_eval = kcross.split(df, K)

if TRAIN:
    # Train the model for TRAIN_STEPS steps. The return value is handed
    # unchanged to plot_training below.
    train_eval = kcross.train(n, X_train, y_train, X_eval, y_eval, TRAIN_STEPS)
    if PLOT:
        # Plot training success rate (with 'average loss').
        pywatts.routines.plot_training(train_eval)
elif PLOT:
    # Without a training run there is nothing to plot. Previously this path
    # crashed with a NameError on the undefined `train_eval`.
    print("PLOT is enabled but TRAIN is disabled: nothing to plot.",
          file=sys.stderr)

# Use sys.exit(): the bare exit() builtin is injected by the `site` module
# for interactive use and is not guaranteed to exist (e.g. under `python -S`).
sys.exit()