Moved input_fn outside of class
This commit is contained in:
parent
5521e7de80
commit
3f4602c306
1 changed files with 8 additions and 8 deletions
|
@ -1,5 +1,12 @@
|
|||
import tensorflow as tf
|
||||
|
||||
def pywatts_input_fn(X, y=None, num_epochs=None, shuffle=True, batch_size=400):
|
||||
return tf.estimator.inputs.pandas_input_fn(x=X,
|
||||
y=y,
|
||||
num_epochs=num_epochs,
|
||||
shuffle=shuffle,
|
||||
batch_size=batch_size)
|
||||
|
||||
|
||||
class Net:
|
||||
__regressor = None
|
||||
|
@ -11,15 +18,8 @@ class Net:
|
|||
hidden_units=[50, 50],
|
||||
model_dir='tf_pywatts_model')
|
||||
|
||||
def pywatts_input_fn(X, y=None, num_epochs=None, shuffle=True, batch_size=400):
|
||||
return tf.estimator.inputs.pandas_input_fn(x=X,
|
||||
y=y,
|
||||
num_epochs=num_epochs,
|
||||
shuffle=shuffle,
|
||||
batch_size=batch_size)
|
||||
|
||||
def train(self, training_data, steps):
|
||||
self.__regressor.train(input_fn=self.pywatts_input_fn(training_data, num_epochs=None, shuffle=True), steps=steps)
|
||||
self.__regressor.train(input_fn=pywatts_input_fn(training_data, num_epochs=None, shuffle=True), steps=steps)
|
||||
|
||||
def evaluate(self, eval_data):
|
||||
self.__regressor.evaluate(input_fn=self.pywatts_input_fn(eval_data, num_epochs=1, shuffle=False), steps=1)
|
||||
|
|
Loading…
Reference in a new issue