add sample data and main functions
This commit is contained in:
parent
c3c782bf02
commit
ba0b67565b
6 changed files with 54 additions and 4 deletions
|
@ -1,4 +1,5 @@
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
import matplotlib.pyplot as pp
|
||||||
import pywatts.neural
|
import pywatts.neural
|
||||||
|
|
||||||
from sklearn.model_selection import train_test_split
|
from sklearn.model_selection import train_test_split
|
||||||
|
@ -9,5 +10,26 @@ y = df['dc']
|
||||||
|
|
||||||
X_train, X_tmp, y_train, y_tmp = train_test_split(X, y, test_size=0.2, random_state=23)
|
X_train, X_tmp, y_train, y_tmp = train_test_split(X, y, test_size=0.2, random_state=23)
|
||||||
|
|
||||||
|
X_test, X_val, y_test, y_val = train_test_split(X_tmp, y_tmp, test_size=0.5, random_state=23)
|
||||||
|
|
||||||
|
X_train.shape, X_test.shape, X_val.shape
|
||||||
|
|
||||||
feature_cols = [tf.feature_column.numeric_column(col) for col in X.columns]
|
feature_cols = [tf.feature_column.numeric_column(col) for col in X.columns]
|
||||||
n = pywatts.neural.Net(feature_cols=feature_cols)
|
n = pywatts.neural.Net(feature_cols=feature_cols)
|
||||||
|
|
||||||
|
|
||||||
|
def train(steps=100):
|
||||||
|
evaluation = []
|
||||||
|
for i in range(steps):
|
||||||
|
n.train(X_train, y_train, steps=400)
|
||||||
|
evaluation.append(n.evaluate(X_val, y_val))
|
||||||
|
print("Training %s of %s" % (i, steps))
|
||||||
|
return evaluation
|
||||||
|
|
||||||
|
|
||||||
|
def plot_training(evaluation):
|
||||||
|
loss = []
|
||||||
|
for e in evaluation:
|
||||||
|
loss.append(e['loss'])
|
||||||
|
pp.plot(loss)
|
||||||
|
|
||||||
|
|
|
@ -21,9 +21,9 @@ class Net:
|
||||||
def train(self, training_data, training_results, steps):
|
def train(self, training_data, training_results, steps):
|
||||||
self.__regressor.train(input_fn=pywatts_input_fn(training_data, y=training_results, num_epochs=None, shuffle=True), steps=steps)
|
self.__regressor.train(input_fn=pywatts_input_fn(training_data, y=training_results, num_epochs=None, shuffle=True), steps=steps)
|
||||||
|
|
||||||
def evaluate(self, eval_data):
|
def evaluate(self, eval_data, eval_results):
|
||||||
self.__regressor.evaluate(input_fn=self.pywatts_input_fn(eval_data, num_epochs=1, shuffle=False), steps=1)
|
return self.__regressor.evaluate(input_fn=pywatts_input_fn(eval_data, y=eval_results, num_epochs=1, shuffle=False), steps=1)
|
||||||
|
|
||||||
def predict1h(self, df):
|
def predict1h(self, df):
|
||||||
df = df.drop(['month', 'day', 'hour'])
|
df = df.drop(['month', 'day', 'hour'])
|
||||||
return self.__regressor.predict(input_fn=self.pywatts_input_fn(df, num_epochs=1, shuffle=False))
|
return self.__regressor.predict(input_fn=pywatts_input_fn(df, num_epochs=1, shuffle=False))
|
||||||
|
|
7
sample_data/query-sample_1hour.json
Normal file
7
sample_data/query-sample_1hour.json
Normal file
File diff suppressed because one or more lines are too long
7
sample_data/query-sample_24hour.json
Normal file
7
sample_data/query-sample_24hour.json
Normal file
File diff suppressed because one or more lines are too long
7
sample_data/result-sample_1hour.json
Normal file
7
sample_data/result-sample_1hour.json
Normal file
|
@ -0,0 +1,7 @@
|
||||||
|
[
|
||||||
|
[3403.8909999999996],
|
||||||
|
[0.0],
|
||||||
|
[312.218],
|
||||||
|
[0.0],
|
||||||
|
[2609.3089999999997]
|
||||||
|
]
|
7
sample_data/result-sample_24hour.json
Normal file
7
sample_data/result-sample_24hour.json
Normal file
|
@ -0,0 +1,7 @@
|
||||||
|
[
|
||||||
|
[0.0, 0.0, 0.0, 0.0, 26.877, 282.751, 677.8530000000001, 1793.24, 3116.135, 4308.566, 5204.581, 5719.605, 5700.894, 5469.004, 4907.611, 3983.098, 2998.6240000000003, 1690.155, 701.6519999999999, 277.964, 31.974, 0.0, 0.0, 0.0],
|
||||||
|
[0.0, 0.0, 0.0, 0.0, 38.739000000000004, 122.022, 320.14, 575.778, 829.742, 1055.714, 1230.401, 1350.3039999999999, 4218.804, 2571.766, 2437.692, 2836.6690000000003, 2504.74, 1645.876, 679.4889999999999, 183.74400000000003, 25.428, 0.0, 0.0, 0.0],
|
||||||
|
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 58.684, 477.876, 1870.129, 3450.4309999999996, 4026.9629999999997, 5087.083, 5438.415, 4964.932, 4084.5290000000005, 2302.481, 678.784, 118.728, 6.505, 0.0, 0.0, 0.0, 0.0, 0.0],
|
||||||
|
[0.0, 0.0, 0.0, 0.0, 0.0, 94.98899999999999, 582.484, 1826.751, 2995.259, 4350.571, 4532.098, 3855.2729999999997, 3992.255, 3785.785, 4398.564, 3474.6040000000003, 2636.6809999999996, 1293.424, 407.849, 46.41, 0.0, 0.0, 0.0, 0.0],
|
||||||
|
[0.0, 0.0, 0.0, 0.0, 0.0, 50.655, 264.48400000000004, 684.3539999999999, 2482.857, 3943.58, 4959.603, 5584.967, 4303.115, 3335.736, 2817.76, 1555.6370000000002, 580.747, 662.675, 252.368, 30.809, 0.0, 0.0, 0.0, 0.0]
|
||||||
|
]
|
Loading…
Reference in a new issue