Optimizations
This commit is contained in:
parent
60137462ed
commit
c2a489ce71
3 changed files with 12 additions and 16 deletions
|
@ -2,14 +2,6 @@ import pandas
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
|
|
||||||
# def pywatts_input_fn(X, y=None, num_epochs=None, shuffle=True, batch_size=1):
|
|
||||||
#
|
|
||||||
# return tf.estimator.inputs.pandas_input_fn(x=X,
|
|
||||||
# y=y,
|
|
||||||
# num_epochs=num_epochs,
|
|
||||||
# shuffle=shuffle,
|
|
||||||
# batch_size=batch_size)
|
|
||||||
|
|
||||||
def pywatts_input_fn(X, y=None, num_epochs=None, shuffle=True, batch_size=1):
|
def pywatts_input_fn(X, y=None, num_epochs=None, shuffle=True, batch_size=1):
|
||||||
# Create dictionary for features in hour 0 ... 335
|
# Create dictionary for features in hour 0 ... 335
|
||||||
features = {str(idx): [] for idx in range(336)}
|
features = {str(idx): [] for idx in range(336)}
|
||||||
|
@ -28,6 +20,9 @@ def pywatts_input_fn(X, y=None, num_epochs=None, shuffle=True, batch_size=1):
|
||||||
else:
|
else:
|
||||||
dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
|
dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
|
||||||
|
|
||||||
|
if shuffle:
|
||||||
|
dataset.shuffle(len(features['0']))
|
||||||
|
|
||||||
return dataset.batch(batch_size)
|
return dataset.batch(batch_size)
|
||||||
|
|
||||||
|
|
||||||
|
@ -37,11 +32,11 @@ class Net:
|
||||||
|
|
||||||
def __init__(self, feature_cols=__feature_cols):
|
def __init__(self, feature_cols=__feature_cols):
|
||||||
self.__regressor = tf.estimator.DNNRegressor(feature_columns=feature_cols,
|
self.__regressor = tf.estimator.DNNRegressor(feature_columns=feature_cols,
|
||||||
hidden_units=[2],
|
hidden_units=[75, 75],
|
||||||
model_dir='tf_pywatts_model')
|
model_dir='tf_pywatts_model')
|
||||||
|
|
||||||
def train(self, training_data, training_results, steps):
|
def train(self, training_data, training_results, batch_size, steps):
|
||||||
self.__regressor.train(input_fn=lambda: pywatts_input_fn(training_data, y=training_results, num_epochs=None, shuffle=True, batch_size=1), steps=steps)
|
self.__regressor.train(input_fn=lambda: pywatts_input_fn(training_data, y=training_results, num_epochs=None, shuffle=True, batch_size=batch_size), steps=steps)
|
||||||
|
|
||||||
def evaluate(self, eval_data, eval_results):
|
def evaluate(self, eval_data, eval_results):
|
||||||
return self.__regressor.evaluate(input_fn=lambda: pywatts_input_fn(eval_data, y=eval_results, num_epochs=1, shuffle=False), steps=1)
|
return self.__regressor.evaluate(input_fn=lambda: pywatts_input_fn(eval_data, y=eval_results, num_epochs=1, shuffle=False), steps=1)
|
||||||
|
|
|
@ -5,7 +5,7 @@ from pywatts.main import *
|
||||||
|
|
||||||
PREDICT_QUERY = "query-sample_1hour.json"
|
PREDICT_QUERY = "query-sample_1hour.json"
|
||||||
PREDICT_RESULT = PREDICT_QUERY.replace("query", "result")
|
PREDICT_RESULT = PREDICT_QUERY.replace("query", "result")
|
||||||
QUERY_ID = 0
|
QUERY_ID = 1
|
||||||
|
|
||||||
|
|
||||||
pred_query = input_query("../sample_data/" + PREDICT_QUERY, QUERY_ID)
|
pred_query = input_query("../sample_data/" + PREDICT_QUERY, QUERY_ID)
|
||||||
|
@ -18,5 +18,6 @@ n = pywatts.neural.Net(feature_cols=feature_col)
|
||||||
|
|
||||||
prediction = predict(n, pred_query)
|
prediction = predict(n, pred_query)
|
||||||
|
|
||||||
|
print(prediction)
|
||||||
|
|
||||||
pywatts.main.eval_prediction(prediction, pred_result)
|
pywatts.main.eval_prediction(prediction, pred_result)
|
||||||
|
|
|
@ -3,12 +3,12 @@ import tensorflow as tf
|
||||||
import pywatts.db
|
import pywatts.db
|
||||||
from pywatts.main import *
|
from pywatts.main import *
|
||||||
|
|
||||||
NUM_STATIONS_FROM_DB = 50
|
NUM_STATIONS_FROM_DB = 75
|
||||||
NUM_TRAIN_STATIONS = 1
|
NUM_TRAIN_STATIONS = 60
|
||||||
NUM_EVAL_STATIONS = 1
|
NUM_EVAL_STATIONS = 15
|
||||||
TRAIN = True
|
TRAIN = True
|
||||||
PLOT = True
|
PLOT = True
|
||||||
TRAIN_STEPS = 1
|
TRAIN_STEPS = 10
|
||||||
|
|
||||||
|
|
||||||
df = pywatts.db.rows_to_df(list(range(1, NUM_STATIONS_FROM_DB)))
|
df = pywatts.db.rows_to_df(list(range(1, NUM_STATIONS_FROM_DB)))
|
||||||
|
|
Loading…
Reference in a new issue