Added new test configuration

This commit is contained in:
reedts 2018-08-14 22:20:40 +02:00
parent dfddb8799e
commit 51d0e9cea8
3 changed files with 5 additions and 5 deletions

View File

@ -18,7 +18,7 @@ def split(data, k):
data_list = data['dc'].tolist()
# Each sample has 337 elements
samples = [data_list[i:i+337] for i in range(0, len(data_list) - 337, 337)]
samples = [data_list[i:i+337] for i in range(0, len(data_list) - 337, 20)]
# Randomly shuffle samples
random.shuffle(samples)
@ -49,7 +49,7 @@ def train(nn, X_train, y_train, X_eval, y_eval, steps=10):
evaluation = []
for count, train_data in enumerate(X_train):
for i in range(steps):
nn.train(train_data, y_train[count], batch_size=int(len(train_data['dc'])/336), steps=1)
nn.train(train_data, y_train[count], batch_size=30, steps=100) #batch_size=int(len(train_data['dc'])/336), steps=1)
evaluation.append(nn.evaluate(X_eval[count], y_eval[count], batch_size=int(len(X_eval[count]['dc'])/336)))
print("Training %s: %s/%s" % (count, (i+1), steps))

View File

@ -22,7 +22,7 @@ def pywatts_input_fn(X, y=None, num_epochs=None, shuffle=True, batch_size=1):
dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
if shuffle:
return dataset.shuffle(len(features['0']*batch_size*4)).repeat().batch(batch_size)
return dataset.shuffle(len(features['0']*len(features)*4)).repeat().batch(batch_size)
else:
return dataset.batch(batch_size)

View File

@ -4,11 +4,11 @@ import pywatts.db
from pywatts import kcross
NUM_STATIONS_FROM_DB = 75
K = 4
K = 10
NUM_EVAL_STATIONS = 40
TRAIN = True
PLOT = True
TRAIN_STEPS = 4
TRAIN_STEPS = 20
df = pywatts.db.rows_to_df(list(range(1, NUM_STATIONS_FROM_DB)))