diff --git a/pywatts/neural.py b/pywatts/neural.py index 2c09e79..10b75bc 100644 --- a/pywatts/neural.py +++ b/pywatts/neural.py @@ -1,5 +1,12 @@ import tensorflow as tf +def pywatts_input_fn(X, y=None, num_epochs=None, shuffle=True, batch_size=400): + return tf.estimator.inputs.pandas_input_fn(x=X, + y=y, + num_epochs=num_epochs, + shuffle=shuffle, + batch_size=batch_size) + class Net: __regressor = None @@ -11,15 +18,8 @@ class Net: hidden_units=[50, 50], model_dir='tf_pywatts_model') - def pywatts_input_fn(X, y=None, num_epochs=None, shuffle=True, batch_size=400): - return tf.estimator.inputs.pandas_input_fn(x=X, - y=y, - num_epochs=num_epochs, - shuffle=shuffle, - batch_size=batch_size) - def train(self, training_data, steps): - self.__regressor.train(input_fn=self.pywatts_input_fn(training_data, num_epochs=None, shuffle=True), steps=steps) + self.__regressor.train(input_fn=pywatts_input_fn(training_data, num_epochs=None, shuffle=True), steps=steps) def evaluate(self, eval_data): self.__regressor.evaluate(input_fn=self.pywatts_input_fn(eval_data, num_epochs=1, shuffle=False), steps=1)