Capping to zero

This commit is contained in:
reedts 2018-08-13 14:31:39 +02:00
parent 0e228772dc
commit 841690f98b
5 changed files with 9 additions and 8 deletions

View File

@ -1,5 +1,5 @@
from pywatts import db from pywatts import db
from pywatts import fetchdata from pywatts import fetchdata
from pywatts import neural from pywatts import neural
from pywatts import main from pywatts import routines
from pywatts import kcross from pywatts import kcross

View File

@ -75,7 +75,8 @@ def predict24h(nn, X_pred):
for i in range(24): for i in range(24):
pred = nn.predict1h(pandas.DataFrame.from_dict(input)) pred = nn.predict1h(pandas.DataFrame.from_dict(input))
predictions.extend(list([p['predictions'][0] for p in pred])) # Cap prediction to 0
predictions.extend(list([max(p['predictions'][0], 0) for p in pred]))
# Remove first value and append predicted value # Remove first value and append predicted value
del input['dc'][0] del input['dc'][0]
input['dc'].append(predictions[-1]) input['dc'].append(predictions[-1])

View File

@ -1,11 +1,11 @@
import tensorflow as tf import tensorflow as tf
import pywatts.db import pywatts.db
from pywatts.main import * from pywatts.routines import *
PREDICT_QUERY = "query-sample_1hour.json" PREDICT_QUERY = "query-sample_1hour.json"
PREDICT_RESULT = PREDICT_QUERY.replace("query", "result") PREDICT_RESULT = PREDICT_QUERY.replace("query", "result")
QUERY_ID = 1 QUERY_ID = 0
pred_query = input_query("../sample_data/" + PREDICT_QUERY, QUERY_ID) pred_query = input_query("../sample_data/" + PREDICT_QUERY, QUERY_ID)
@ -21,4 +21,4 @@ prediction = predict(n, pred_query)
print(prediction) print(prediction)
print(pred_result) print(pred_result)
pywatts.main.eval_prediction(prediction, pred_result) pywatts.routines.eval_prediction(prediction, pred_result)

View File

@ -1,6 +1,6 @@
import tensorflow as tf import tensorflow as tf
import pywatts.db import pywatts.db
from pywatts.main import * from pywatts.routines import *
import matplotlib.pyplot as pp import matplotlib.pyplot as pp

View File

@ -1,7 +1,7 @@
import peewee import peewee
import tensorflow as tf import tensorflow as tf
import pywatts.db import pywatts.db
from pywatts.main import * from pywatts.routines import *
NUM_STATIONS_FROM_DB = 75 NUM_STATIONS_FROM_DB = 75
NUM_TRAIN_STATIONS = 400 NUM_TRAIN_STATIONS = 400
@ -43,7 +43,7 @@ if TRAIN:
if PLOT: if PLOT:
# Plot training success rate (with 'average loss') # Plot training success rate (with 'average loss')
pywatts.main.plot_training(train_eval) pywatts.routines.plot_training(train_eval)
exit() exit()