From 3f4602c306185aac3b79a711521408ac12f93ff9 Mon Sep 17 00:00:00 2001 From: reedts Date: Tue, 29 May 2018 15:57:50 +0200 Subject: [PATCH] Moved input_fn outside of class --- pywatts/neural.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/pywatts/neural.py b/pywatts/neural.py index 2c09e79..10b75bc 100644 --- a/pywatts/neural.py +++ b/pywatts/neural.py @@ -1,5 +1,12 @@ import tensorflow as tf +def pywatts_input_fn(X, y=None, num_epochs=None, shuffle=True, batch_size=400): + return tf.estimator.inputs.pandas_input_fn(x=X, + y=y, + num_epochs=num_epochs, + shuffle=shuffle, + batch_size=batch_size) + class Net: __regressor = None @@ -11,15 +18,8 @@ class Net: hidden_units=[50, 50], model_dir='tf_pywatts_model') - def pywatts_input_fn(X, y=None, num_epochs=None, shuffle=True, batch_size=400): - return tf.estimator.inputs.pandas_input_fn(x=X, - y=y, - num_epochs=num_epochs, - shuffle=shuffle, - batch_size=batch_size) - def train(self, training_data, steps): - self.__regressor.train(input_fn=self.pywatts_input_fn(training_data, num_epochs=None, shuffle=True), steps=steps) + self.__regressor.train(input_fn=pywatts_input_fn(training_data, num_epochs=None, shuffle=True), steps=steps) def evaluate(self, eval_data): self.__regressor.evaluate(input_fn=self.pywatts_input_fn(eval_data, num_epochs=1, shuffle=False), steps=1)