diff --git a/pywatts/predict_for_json.py b/pywatts/photovoltaic_gruppe1.py similarity index 69% rename from pywatts/predict_for_json.py rename to pywatts/photovoltaic_gruppe1.py index 98f666f..3889f6c 100644 --- a/pywatts/predict_for_json.py +++ b/pywatts/photovoltaic_gruppe1.py @@ -9,21 +9,20 @@ from pywatts.routines import * # get rid of TF debug message os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' -if len(sys.argv) != 3: - print("Usage: python predict_for_json.py 24h|1h ") +if len(sys.argv) != 2: + print("Usage: python photovoltaic_gruppe1.py ") exit(1) -type = sys.argv[1] # '1h' or '24h' -json_file = sys.argv[2] # json file +json_file = sys.argv[1] # json file -queries = input_queries(json_file) +oneH, queries = input_queries(json_file) feature_col = [tf.feature_column.numeric_column(str(idx)) for idx in range(336)] n = pywatts.neural.Net(feature_cols=feature_col) predictions = [] for query in queries: - if type == '1h': + if oneH: predictions.extend(predict(n, query).astype('Float64').tolist()) else: predictions.append(predict24h(n, query)) diff --git a/pywatts/routines.py b/pywatts/routines.py index bf3919f..54e3beb 100644 --- a/pywatts/routines.py +++ b/pywatts/routines.py @@ -39,6 +39,12 @@ def input_query(json_str, idx=0): def input_queries(json_str): tmp_df = pandas.read_json(json_str) + oneH = False + try: + s = tmp_df['max_temp'][0] + except KeyError: + oneH = True + queries = [] for i in range(len(tmp_df)): queries.append(pandas.DataFrame.from_dict( @@ -46,7 +52,7 @@ def input_queries(json_str): 'temp': tmp_df['temp'][i], 'wind': tmp_df['wind'][i]} )) - return queries + return oneH, queries def input_result(json_str, idx=0):