Moved input_fn outside of class

This commit is contained in:
reedts 2018-05-29 15:57:50 +02:00
parent 5521e7de80
commit 3f4602c306

View file

@ -1,5 +1,12 @@
import tensorflow as tf import tensorflow as tf
def pywatts_input_fn(X, y=None, num_epochs=None, shuffle=True, batch_size=400):
return tf.estimator.inputs.pandas_input_fn(x=X,
y=y,
num_epochs=num_epochs,
shuffle=shuffle,
batch_size=batch_size)
class Net: class Net:
__regressor = None __regressor = None
@ -11,15 +18,8 @@ class Net:
hidden_units=[50, 50], hidden_units=[50, 50],
model_dir='tf_pywatts_model') model_dir='tf_pywatts_model')
def pywatts_input_fn(X, y=None, num_epochs=None, shuffle=True, batch_size=400):
return tf.estimator.inputs.pandas_input_fn(x=X,
y=y,
num_epochs=num_epochs,
shuffle=shuffle,
batch_size=batch_size)
def train(self, training_data, steps): def train(self, training_data, steps):
self.__regressor.train(input_fn=self.pywatts_input_fn(training_data, num_epochs=None, shuffle=True), steps=steps) self.__regressor.train(input_fn=pywatts_input_fn(training_data, num_epochs=None, shuffle=True), steps=steps)
def evaluate(self, eval_data): def evaluate(self, eval_data):
self.__regressor.evaluate(input_fn=self.pywatts_input_fn(eval_data, num_epochs=1, shuffle=False), steps=1) self.__regressor.evaluate(input_fn=self.pywatts_input_fn(eval_data, num_epochs=1, shuffle=False), steps=1)