From 8d4a6728912d49fca6a5f19ab8334d41cbb5894e Mon Sep 17 00:00:00 2001 From: Sean McGregor Date: Mon, 14 Mar 2016 22:47:20 -0700 Subject: [PATCH] added caching and updated server --- flask_server.py | 42 +++++++++++++++++++++++++++++------------- 1 file changed, 29 insertions(+), 13 deletions(-) diff --git a/flask_server.py b/flask_server.py index 5d3d834..0b6135a 100644 --- a/flask_server.py +++ b/flask_server.py @@ -8,8 +8,10 @@ :copyright: (C) 2015 by Sean McGregor. :license: MIT, see LICENSE for more details. """ -from flask import Flask, jsonify, request, redirect +from flask import Flask, jsonify, request, redirect, json import sys +import re +import os.path print """ Starting Flask Server... @@ -17,18 +19,11 @@ positional argument. """ -# Specify which domain should be selected from the domain bridge. -# This is an optional argument that is used for domain_bridges that sit -# atop a collection of MDP domains. You should leave this as the -# empty string unless you must select between domains. -domain = "" -if len(sys.argv) > 1: - domain = sys.argv[1] - # Add the parent folder path to the sys.path list so # we can include its bridge sys.path.insert(0, '..') -import domain_bridge +import hiv_bridge as domain_bridge +#import mountain_car_bridge as domain_bridge try: # The typical way to import flask-cors @@ -104,8 +99,27 @@ def cross_origin_rollouts(): requested parameters. ''' q = parse_query(request.args) - rollouts = domain_bridge.rollouts(q) - return jsonify({"rollouts": rollouts}) + + # todo: make better + # Quick hack for caching. + filename = 'cache/' + re.sub(r'\W+', '', str(q)) + if os.path.isfile(filename): + print """ + + Served from cache. Clear the cache folder if you want to regenerate. + + """ + f = open(filename, 'r') + json_obj = json.load(f) + resp = jsonify(json_obj) + else: + rollouts = domain_bridge.rollouts(q) + json_obj = {"rollouts": rollouts} + resp = jsonify(json_obj) + f = open(filename, 'w') + json.dump(json_obj, f) + f.close() + return resp @app.route("/optimize", methods=['GET']) @cross_origin(allow_headers=['Content-Type']) @@ -141,4 +155,6 @@ def parse_query(queryObject): # Binds the server to port 8938 and listens to all IP addresses. if __name__ == "__main__": - app.run(host='0.0.0.0', port=8938, debug=True) + print("Starting server...") + app.run(host='0.0.0.0', port=8938, debug=True, use_reloader=False, threaded=True) + print("...started")