Moved input_fn outside of class
This commit is contained in:
parent
5521e7de80
commit
3f4602c306
1 changed files with 8 additions and 8 deletions
|
@ -1,5 +1,12 @@
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
|
def pywatts_input_fn(X, y=None, num_epochs=None, shuffle=True, batch_size=400):
|
||||||
|
return tf.estimator.inputs.pandas_input_fn(x=X,
|
||||||
|
y=y,
|
||||||
|
num_epochs=num_epochs,
|
||||||
|
shuffle=shuffle,
|
||||||
|
batch_size=batch_size)
|
||||||
|
|
||||||
|
|
||||||
class Net:
|
class Net:
|
||||||
__regressor = None
|
__regressor = None
|
||||||
|
@ -11,15 +18,8 @@ class Net:
|
||||||
hidden_units=[50, 50],
|
hidden_units=[50, 50],
|
||||||
model_dir='tf_pywatts_model')
|
model_dir='tf_pywatts_model')
|
||||||
|
|
||||||
def pywatts_input_fn(X, y=None, num_epochs=None, shuffle=True, batch_size=400):
|
|
||||||
return tf.estimator.inputs.pandas_input_fn(x=X,
|
|
||||||
y=y,
|
|
||||||
num_epochs=num_epochs,
|
|
||||||
shuffle=shuffle,
|
|
||||||
batch_size=batch_size)
|
|
||||||
|
|
||||||
def train(self, training_data, steps):
|
def train(self, training_data, steps):
|
||||||
self.__regressor.train(input_fn=self.pywatts_input_fn(training_data, num_epochs=None, shuffle=True), steps=steps)
|
self.__regressor.train(input_fn=pywatts_input_fn(training_data, num_epochs=None, shuffle=True), steps=steps)
|
||||||
|
|
||||||
def evaluate(self, eval_data):
|
def evaluate(self, eval_data):
|
||||||
self.__regressor.evaluate(input_fn=self.pywatts_input_fn(eval_data, num_epochs=1, shuffle=False), steps=1)
|
self.__regressor.evaluate(input_fn=self.pywatts_input_fn(eval_data, num_epochs=1, shuffle=False), steps=1)
|
||||||
|
|
Loading…
Reference in a new issue