Optimizations

This commit is contained in:
reedts 2018-06-23 15:40:45 +02:00
parent 60137462ed
commit c2a489ce71
3 changed files with 12 additions and 16 deletions

View file

@ -2,14 +2,6 @@ import pandas
import tensorflow as tf
# def pywatts_input_fn(X, y=None, num_epochs=None, shuffle=True, batch_size=1):
#
# return tf.estimator.inputs.pandas_input_fn(x=X,
# y=y,
# num_epochs=num_epochs,
# shuffle=shuffle,
# batch_size=batch_size)
def pywatts_input_fn(X, y=None, num_epochs=None, shuffle=True, batch_size=1):
# Create dictionary for features in hour 0 ... 335
features = {str(idx): [] for idx in range(336)}
@ -28,6 +20,9 @@ def pywatts_input_fn(X, y=None, num_epochs=None, shuffle=True, batch_size=1):
else:
dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
if shuffle:
dataset.shuffle(len(features['0']))
return dataset.batch(batch_size)
@ -37,11 +32,11 @@ class Net:
def __init__(self, feature_cols=__feature_cols):
self.__regressor = tf.estimator.DNNRegressor(feature_columns=feature_cols,
hidden_units=[2],
hidden_units=[75, 75],
model_dir='tf_pywatts_model')
def train(self, training_data, training_results, steps):
self.__regressor.train(input_fn=lambda: pywatts_input_fn(training_data, y=training_results, num_epochs=None, shuffle=True, batch_size=1), steps=steps)
def train(self, training_data, training_results, batch_size, steps):
self.__regressor.train(input_fn=lambda: pywatts_input_fn(training_data, y=training_results, num_epochs=None, shuffle=True, batch_size=batch_size), steps=steps)
def evaluate(self, eval_data, eval_results):
return self.__regressor.evaluate(input_fn=lambda: pywatts_input_fn(eval_data, y=eval_results, num_epochs=1, shuffle=False), steps=1)

View file

@ -5,7 +5,7 @@ from pywatts.main import *
PREDICT_QUERY = "query-sample_1hour.json"
PREDICT_RESULT = PREDICT_QUERY.replace("query", "result")
QUERY_ID = 0
QUERY_ID = 1
pred_query = input_query("../sample_data/" + PREDICT_QUERY, QUERY_ID)
@ -18,5 +18,6 @@ n = pywatts.neural.Net(feature_cols=feature_col)
prediction = predict(n, pred_query)
print(prediction)
pywatts.main.eval_prediction(prediction, pred_result)

View file

@ -3,12 +3,12 @@ import tensorflow as tf
import pywatts.db
from pywatts.main import *
NUM_STATIONS_FROM_DB = 50
NUM_TRAIN_STATIONS = 1
NUM_EVAL_STATIONS = 1
NUM_STATIONS_FROM_DB = 75
NUM_TRAIN_STATIONS = 60
NUM_EVAL_STATIONS = 15
TRAIN = True
PLOT = True
TRAIN_STEPS = 1
TRAIN_STEPS = 10
df = pywatts.db.rows_to_df(list(range(1, NUM_STATIONS_FROM_DB)))