From 60137462ed5bc8bac18dc6f108ecfdf6c829ade1 Mon Sep 17 00:00:00 2001 From: reedts Date: Sat, 23 Jun 2018 15:40:23 +0200 Subject: [PATCH] Avoid having the same test sample twice --- pywatts/main.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/pywatts/main.py b/pywatts/main.py index cba3cb8..414f87d 100644 --- a/pywatts/main.py +++ b/pywatts/main.py @@ -8,17 +8,22 @@ from random import randint def train_split(data, size): + used_idxs = [] X_values = {'dc': [], 'temp': [], 'wind': []} y_values = [] for i in range(size): rnd_idx = randint(0, data.size / data.shape[1] - 337) + if rnd_idx in used_idxs: + continue + else: + used_idxs.append(rnd_idx) + X_values['dc'].extend(data['dc'][rnd_idx:rnd_idx + 336].tolist()) X_values['temp'].extend(data['temp'][rnd_idx:rnd_idx + 336].tolist()) X_values['wind'].extend(data['wind'][rnd_idx:rnd_idx + 336].tolist()) y_values.append(data['dc'][rnd_idx + 337].tolist()) - return pandas.DataFrame.from_dict(X_values), pandas.DataFrame.from_dict({'dc': y_values}) @@ -31,6 +36,7 @@ def input_query(json_str, idx=0): 'wind': tmp_df['wind'][idx]} ) + def input_result(json_str, idx=0): tmp_df = pandas.read_json(json_str) @@ -40,7 +46,7 @@ def input_result(json_str, idx=0): def train(nn, X_train, y_train, X_val, y_val, steps=100): evaluation = [] for i in range(steps): - nn.train(X_train, y_train, steps=100) + nn.train(X_train, y_train, batch_size=int(len(X_train['dc'].tolist())/336), steps=100) evaluation.append(nn.evaluate(X_val, y_val)) print("Training %s of %s" % ((i+1), steps)) return evaluation