From e019f1bee71ac36571e70dc16d705c72b5f8502d Mon Sep 17 00:00:00 2001 From: reedts Date: Tue, 14 Aug 2018 15:21:39 +0200 Subject: [PATCH] Fixed shuffling --- pywatts/neural.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/pywatts/neural.py b/pywatts/neural.py index 4180366..aa377db 100644 --- a/pywatts/neural.py +++ b/pywatts/neural.py @@ -22,9 +22,7 @@ def pywatts_input_fn(X, y=None, num_epochs=None, shuffle=True, batch_size=1): dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels)) if shuffle: - dataset.shuffle(len(features['0']*len(features)*4)) - - return dataset.repeat().batch(batch_size) + return dataset.shuffle(len(features['0']*batch_size*4)).repeat().batch(batch_size) else: return dataset.batch(batch_size) @@ -35,7 +33,7 @@ class Net: def __init__(self, feature_cols=__feature_cols): self.__regressor = tf.estimator.DNNRegressor(feature_columns=feature_cols, - hidden_units=[75, 75], + hidden_units=[64, 128, 64], model_dir='tf_pywatts_model') def train(self, training_data, training_results, batch_size, steps):