diff --git a/pywatts/neural.py b/pywatts/neural.py index 4180366..aa377db 100644 --- a/pywatts/neural.py +++ b/pywatts/neural.py @@ -22,9 +22,7 @@ def pywatts_input_fn(X, y=None, num_epochs=None, shuffle=True, batch_size=1): dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels)) if shuffle: - dataset.shuffle(len(features['0']*len(features)*4)) - - return dataset.repeat().batch(batch_size) + return dataset.shuffle(len(features['0']*batch_size*4)).repeat().batch(batch_size) else: return dataset.batch(batch_size) @@ -35,7 +33,7 @@ class Net: def __init__(self, feature_cols=__feature_cols): self.__regressor = tf.estimator.DNNRegressor(feature_columns=feature_cols, - hidden_units=[75, 75], + hidden_units=[64, 128, 64], model_dir='tf_pywatts_model') def train(self, training_data, training_results, batch_size, steps):