Merge
This commit is contained in:
commit
65756a18a4
3 changed files with 6 additions and 8 deletions
|
@ -16,7 +16,7 @@ def split(data, k):
|
|||
data_list = data['dc'].tolist()
|
||||
|
||||
# Each sample has 337 elements
|
||||
samples = [data_list[i:i+337] for i in range(0, len(data_list) - 337)]
|
||||
samples = [data_list[i:i+337] for i in range(0, len(data_list) - 337, 20)]
|
||||
# Randomly shuffle samples
|
||||
random.shuffle(samples)
|
||||
|
||||
|
@ -47,7 +47,7 @@ def train(nn, X_train, y_train, X_eval, y_eval, steps=10):
|
|||
evaluation = []
|
||||
for count, train_data in enumerate(X_train):
|
||||
for i in range(steps):
|
||||
nn.train(train_data, y_train[count], batch_size=int(len(train_data['dc'])/336), steps=1)
|
||||
nn.train(train_data, y_train[count], batch_size=30, steps=100) #batch_size=int(len(train_data['dc'])/336), steps=1)
|
||||
evaluation.append(nn.evaluate(X_eval[count], y_eval[count], batch_size=int(len(X_eval[count]['dc'])/336)))
|
||||
print("Training %s: %s/%s" % (count, (i+1), steps))
|
||||
|
||||
|
|
|
@ -20,9 +20,7 @@ def pywatts_input_fn(X, y=None, num_epochs=None, shuffle=True, batch_size=1):
|
|||
dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
|
||||
|
||||
if shuffle:
|
||||
dataset.shuffle(len(features['0']*len(features)*4))
|
||||
|
||||
return dataset.repeat().batch(batch_size)
|
||||
return dataset.shuffle(len(features['0']*len(features)*4)).repeat().batch(batch_size)
|
||||
else:
|
||||
return dataset.batch(batch_size)
|
||||
|
||||
|
@ -33,7 +31,7 @@ class Net:
|
|||
|
||||
def __init__(self, feature_cols=__feature_cols):
|
||||
self.__regressor = tf.estimator.DNNRegressor(feature_columns=feature_cols,
|
||||
hidden_units=[75, 75],
|
||||
hidden_units=[64, 128, 64],
|
||||
model_dir='tf_pywatts_model')
|
||||
|
||||
def train(self, training_data, training_results, batch_size, steps):
|
||||
|
|
|
@ -4,11 +4,11 @@ import pywatts.db
|
|||
from pywatts import kcross
|
||||
|
||||
NUM_STATIONS_FROM_DB = 75
|
||||
K = 4
|
||||
K = 10
|
||||
NUM_EVAL_STATIONS = 40
|
||||
TRAIN = True
|
||||
PLOT = True
|
||||
TRAIN_STEPS = 4
|
||||
TRAIN_STEPS = 20
|
||||
|
||||
|
||||
df = pywatts.db.rows_to_df(list(range(1, NUM_STATIONS_FROM_DB)))
|
||||
|
|
Loading…
Reference in a new issue