From 51d0e9cea847610c5214dabdcbc8a6a23be10968 Mon Sep 17 00:00:00 2001 From: reedts Date: Tue, 14 Aug 2018 22:20:40 +0200 Subject: [PATCH] Added new test configuration --- pywatts/kcross.py | 4 ++-- pywatts/neural.py | 2 +- pywatts/test_kcross_train.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pywatts/kcross.py b/pywatts/kcross.py index 9c3151d..29a7bde 100644 --- a/pywatts/kcross.py +++ b/pywatts/kcross.py @@ -18,7 +18,7 @@ def split(data, k): data_list = data['dc'].tolist() # Each sample has 337 elements - samples = [data_list[i:i+337] for i in range(0, len(data_list) - 337, 337)] + samples = [data_list[i:i+337] for i in range(0, len(data_list) - 337, 20)] # Randomly shuffle samples random.shuffle(samples) @@ -49,7 +49,7 @@ def train(nn, X_train, y_train, X_eval, y_eval, steps=10): evaluation = [] for count, train_data in enumerate(X_train): for i in range(steps): - nn.train(train_data, y_train[count], batch_size=int(len(train_data['dc'])/336), steps=1) + nn.train(train_data, y_train[count], batch_size=30, steps=100) #batch_size=int(len(train_data['dc'])/336), steps=1) evaluation.append(nn.evaluate(X_eval[count], y_eval[count], batch_size=int(len(X_eval[count]['dc'])/336))) print("Training %s: %s/%s" % (count, (i+1), steps)) diff --git a/pywatts/neural.py b/pywatts/neural.py index aa377db..66dc16c 100644 --- a/pywatts/neural.py +++ b/pywatts/neural.py @@ -22,7 +22,7 @@ def pywatts_input_fn(X, y=None, num_epochs=None, shuffle=True, batch_size=1): dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels)) if shuffle: - return dataset.shuffle(len(features['0']*batch_size*4)).repeat().batch(batch_size) + return dataset.shuffle(len(features['0']*len(features)*4)).repeat().batch(batch_size) else: return dataset.batch(batch_size) diff --git a/pywatts/test_kcross_train.py b/pywatts/test_kcross_train.py index 4e29d0a..14550b1 100644 --- a/pywatts/test_kcross_train.py +++ b/pywatts/test_kcross_train.py @@ -4,11 +4,11 @@ import pywatts.db from pywatts import kcross NUM_STATIONS_FROM_DB = 75 -K = 4 +K = 10 NUM_EVAL_STATIONS = 40 TRAIN = True PLOT = True -TRAIN_STEPS = 4 +TRAIN_STEPS = 20 df = pywatts.db.rows_to_df(list(range(1, NUM_STATIONS_FROM_DB)))