Add evaluation method
This commit is contained in:
parent
ba0b67565b
commit
0f8ebcff23
2 changed files with 19 additions and 3 deletions
|
@ -1,6 +1,8 @@
|
||||||
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
import matplotlib.pyplot as pp
|
import matplotlib.pyplot as pp
|
||||||
import pywatts.neural
|
import pywatts.neural
|
||||||
|
from sklearn.metrics import explained_variance_score, mean_absolute_error, median_absolute_error
|
||||||
|
|
||||||
from sklearn.model_selection import train_test_split
|
from sklearn.model_selection import train_test_split
|
||||||
|
|
||||||
|
@ -33,3 +35,18 @@ def plot_training(evaluation):
|
||||||
loss.append(e['loss'])
|
loss.append(e['loss'])
|
||||||
pp.plot(loss)
|
pp.plot(loss)
|
||||||
|
|
||||||
|
|
||||||
|
def predict(X_pred):
|
||||||
|
pred = n.predict1h(X_pred)
|
||||||
|
predictions = np.array([p['predictions'][0] for p in pred])
|
||||||
|
return predictions
|
||||||
|
|
||||||
|
|
||||||
|
def eval_prediction(prediction):
|
||||||
|
print("The Explained Variance: %.2f" % explained_variance_score(
|
||||||
|
y_test, prediction))
|
||||||
|
print("The Mean Absolute Error: %.2f volt dc" % mean_absolute_error(
|
||||||
|
y_test, prediction))
|
||||||
|
print("The Median Absolute Error: %.2f volt dc" % median_absolute_error(
|
||||||
|
y_test, prediction))
|
||||||
|
|
||||||
|
|
|
@ -24,6 +24,5 @@ class Net:
|
||||||
def evaluate(self, eval_data, eval_results):
|
def evaluate(self, eval_data, eval_results):
|
||||||
return self.__regressor.evaluate(input_fn=pywatts_input_fn(eval_data, y=eval_results, num_epochs=1, shuffle=False), steps=1)
|
return self.__regressor.evaluate(input_fn=pywatts_input_fn(eval_data, y=eval_results, num_epochs=1, shuffle=False), steps=1)
|
||||||
|
|
||||||
def predict1h(self, df):
|
def predict1h(self, predict_data):
|
||||||
df = df.drop(['month', 'day', 'hour'])
|
return self.__regressor.predict(input_fn=pywatts_input_fn(predict_data, num_epochs=1, shuffle=False))
|
||||||
return self.__regressor.predict(input_fn=pywatts_input_fn(df, num_epochs=1, shuffle=False))
|
|
||||||
|
|
Loading…
Reference in a new issue