Fixed shuffling
This commit is contained in:
parent
d4da4ca121
commit
e019f1bee7
1 changed files with 2 additions and 4 deletions
|
@ -22,9 +22,7 @@ def pywatts_input_fn(X, y=None, num_epochs=None, shuffle=True, batch_size=1):
|
||||||
dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
|
dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
|
||||||
|
|
||||||
if shuffle:
|
if shuffle:
|
||||||
dataset.shuffle(len(features['0']*len(features)*4))
|
return dataset.shuffle(len(features['0']*batch_size*4)).repeat().batch(batch_size)
|
||||||
|
|
||||||
return dataset.repeat().batch(batch_size)
|
|
||||||
else:
|
else:
|
||||||
return dataset.batch(batch_size)
|
return dataset.batch(batch_size)
|
||||||
|
|
||||||
|
@ -35,7 +33,7 @@ class Net:
|
||||||
|
|
||||||
def __init__(self, feature_cols=__feature_cols):
|
def __init__(self, feature_cols=__feature_cols):
|
||||||
self.__regressor = tf.estimator.DNNRegressor(feature_columns=feature_cols,
|
self.__regressor = tf.estimator.DNNRegressor(feature_columns=feature_cols,
|
||||||
hidden_units=[75, 75],
|
hidden_units=[64, 128, 64],
|
||||||
model_dir='tf_pywatts_model')
|
model_dir='tf_pywatts_model')
|
||||||
|
|
||||||
def train(self, training_data, training_results, batch_size, steps):
|
def train(self, training_data, training_results, batch_size, steps):
|
||||||
|
|
Loading…
Reference in a new issue