From Notebook To Production Part 1
In my previous project, I used Fast.AI to train a text classification model that classifies a tweet as either Conservative or Liberal. I would like to share it in a way that lets users make predictions on whatever text they want.
Flask is a web application framework for publishing Python-based web apps. Using Flask, I can make a very lightweight application that can serve my trained model as an API.
Web Application Requirements
I know the application will have two endpoints: the base URL / and the model API URL /api.
My app needs to perform these tasks:
- If a user sends a blank GET request to the base URL of the web app, the app responds with API UP!
- If a user sends a POST request to /api with a JSON containing the text to be classified by the model:
   a. If the request JSON is well-formed, the application takes the text from the request, sends it to the model, and sends the model's response back to the user as JSON.
   b. If the request JSON does not have a text key, the application sends an error message as a response.
   c. If the text key value in the request JSON is not a string, the application sends an error message as a response.
Example User Request JSON:
{
  "text": "Text to be classified"
}
Example Successful Application Response:
{
  "percent_in_class": 61.0,
  "predicted_class": "Liberal"
}
Example Application Error Message Response:
{
  "error": "Bad request"
}
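The rules above can be sketched without any web framework, as a function that maps a parsed request body to a status code and a response body. This is a minimal sketch, not the app's actual code: handle_api_request and the stub fake_model are hypothetical names, and the stub always returns the example response shown above. It also treats a body that is not a JSON object as one that is missing the text key.

```python
# A framework-free sketch of the /api request rules.
# `fake_model` is a hypothetical stand-in for the trained classifier.

def fake_model(text):
    """Hypothetical stub: always predicts Liberal with 61% confidence."""
    return {"predicted_class": "Liberal", "percent_in_class": 61.0}


def handle_api_request(request_json, model=fake_model):
    """Return (status_code, response_body) for a parsed request body."""
    # Rule b: the body must be a JSON object that has a "text" key
    if not isinstance(request_json, dict) or "text" not in request_json:
        return 400, {"error": "Bad request"}
    text = request_json["text"]
    # Rule c: the value of "text" must be a string
    if not isinstance(text, str):
        return 400, {"error": "Bad request"}
    # Rule a: the request is well-formed, so pass the text to the model
    return 200, model(text)


print(handle_api_request({"text": "Text to be classified"}))
# (200, {'predicted_class': 'Liberal', 'percent_in_class': 61.0})
print(handle_api_request({"not_text": "Text to be classified"}))
# (400, {'error': 'Bad request'})
print(handle_api_request({"text": 7}))
# (400, {'error': 'Bad request'})
```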
Unit Testing
Now that I have some requirements for my application, I can write tests that check whether my app does what I want it to do.
You may have noticed that I have not written a single line of code, and I am talking about testing an application that doesn’t exist. Let’s talk about Test-Driven Development.
Developers usually write some code, then run it to see if it does what it is supposed to do. This is called “manual testing.” It’s not a bad way to do things if your application is minimal, but if your application does hundreds of different things, it gets tedious. Do you want to be the person who has to do this by hand every time someone makes a change? You could write a script; devs love scripts. Whenever you write new code, you add to the script. That is getting closer to automated testing.
With Test-Driven Development, the idea is to write your tests first and automate all testing. Then you write code. When your code passes the new test without breaking tests that already passed, your portion is complete and can be integrated into the codebase.
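As a small illustration of the test-first cycle, here is a sketch using a simplified version of the extract_text_from_request helper that appears later in this post. The test is written first; run at that point, it would fail with a NameError. The function is then written to make it pass.

```python
# Step 1 (red): write the test first. Run before the function
# below exists, it fails with a NameError.
def test_extract_text_from_request():
    assert extract_text_from_request({"text": "hello"}) == "hello"


# Step 2 (green): write just enough code to make the test pass.
def extract_text_from_request(request_json):
    return request_json["text"]


# Step 3: rerun the test (PyTest would find and run it automatically).
test_extract_text_from_request()
print("test passed")
```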
Writing tests before code forces you to plan what the application should do. Prior Planning Prevents Poor Performance, and writing tests first builds that planning into the process.
Python has several frameworks for running tests. For this project, I use PyTest.
Above, I listed four things my app needs to do. These will be the basis of my tests: when all of them pass, my app meets my needs.
Test GET request to the / endpoint
If a user sends a blank GET request to the / endpoint, the app responds with API UP! and an HTTP status code of 200.
import json

from app import app


# If a blank GET request is sent to the base URL,
# the response should be the string 'API UP!'
def test_hello():
    response = app.test_client().get('/')
    assert response.status_code == 200
    assert response.data == b'API UP!'
Test POST request to the /api endpoint
a. If the JSON sent to the /api endpoint has the correct key and the value is a string, the app sends the model prediction back to the user with an HTTP status code of 200.
# If a user sends a well-formed request, the response code
# should be 200 and a JSON prediction should be sent back.
def test_api_good_data():
    tweet = ("Great honor, I think? Mark Zuckerberg recently stated that "
             "“Donald J. Trump is Number 1 on Facebook. Number 2 is "
             "Prime Minister Modi of India.” Actually, I am going to "
             "India in two weeks. Looking forward to it!")
    response = app.test_client().post(
        '/api',
        data=json.dumps({"text": tweet}),
        content_type='application/json',
    )
    data = json.loads(response.get_data(as_text=True))
    assert response.status_code == 200
    assert data['predicted_class'] == "Conservative"
b. If the request JSON sent to the /api endpoint does not have a text key, the app sends an error message and an HTTP status code of 400 as a response.
# If a user sends a JSON without a "text" key, the response code
# should be 400 and an error message should be sent.
def test_api_bad_json_key_data():
    tweet = ("Great honor, I think? Mark Zuckerberg recently stated that "
             "“Donald J. Trump is Number 1 on Facebook. Number 2 is "
             "Prime Minister Modi of India.” Actually, I am going to "
             "India in two weeks. Looking forward to it!")
    response = app.test_client().post(
        '/api',
        data=json.dumps({"not_text": tweet}),
        content_type='application/json',
    )
    data = json.loads(response.get_data(as_text=True))
    assert response.status_code == 400
    assert data['error'] == "Bad request"
c. If the text key value in the request JSON sent to the /api endpoint is not a string, the app sends an error message and an HTTP status code of 400 as a response.
# If a user sends a JSON with a non-string value for the "text" key,
# the response code should be 400 and an error message should be sent.
def test_api_not_string_data():
    response = app.test_client().post(
        '/api',
        data=json.dumps({"text": 7}),
        content_type='application/json',
    )
    data = json.loads(response.get_data(as_text=True))
    assert response.status_code == 400
    assert data['error'] == "Bad request"
The Flask Application
Our tests are ready, and we know what the application needs to do. In a separate script, I wrote some functions that load my model and parse the incoming and outgoing data.
predict_functions.py
from fastai.text import load_learner

# load the trained model from export.pkl in the current directory
learn = load_learner("", file="export.pkl")


def extract_text_from_request(request_json):
    '''
    take the text from the request json and return it as a string
    '''
    request_text = request_json["text"]
    return request_text


def predict_text_political_class(request_text):
    '''
    send the string to the model to be classified
    return the response from the model
    '''
    pred = learn.predict(request_text)
    return pred


def prediction_to_dict(prediction):
    '''
    return the prediction as a dictionary,
    extracting the probability of the predicted class
    '''
    # fastai v1's predict returns (category, class index, class probabilities)
    class_index = int(prediction[0])
    # class 1 is Liberal, class 0 is Conservative
    political_class = 'Liberal' if class_index == 1 else 'Conservative'
    # scale to a percent before rounding; rounding first and then
    # multiplying by 100 can produce values like 56.99999999999999
    percent_in_class = round(float(prediction[2][class_index]) * 100, 0)
    response_dict = {'predicted_class': political_class,
                     'percent_in_class': percent_in_class}
    return response_dict
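Rounding a probability to two decimals and then multiplying by 100, as in round(p, 2) * 100, can leave floating-point noise in percent_in_class, because most decimal fractions have no exact binary representation. A quick plain-Python check, using a hypothetical probability of 0.57:

```python
# Compare rounding before vs. after converting a probability to a percent.
probability = 0.57  # hypothetical model probability for the predicted class

round_then_scale = round(probability, 2) * 100
scale_then_round = round(probability * 100, 0)

print(round_then_scale)  # 56.99999999999999
print(scale_then_round)  # 57.0
```

Scaling first and then rounding the percent keeps the whole-number style of the example response (61.0).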
Next, I create the Flask application. Flask works by mapping routes to functions, where each route is a different URL. My application has two routes: the base URL / and the URL for interacting with my model, /api. I also define a function for sending an error response.
Notice that I import from my predict_functions script.
app.py
from flask import Flask, request, jsonify, abort, make_response
from predict_functions import (extract_text_from_request,
                               predict_text_political_class,
                               prediction_to_dict)

app = Flask(__name__)


# Taken from https://blog.miguelgrinberg.com
@app.errorhandler(400)
def bad_request(error):
    return make_response(jsonify({'error': 'Bad request'}), 400)


# home url, returns API UP!
@app.route('/', methods=['GET'])
def confirm_server_function():
    return "API UP!"


# api prediction url, returns a prediction from a POST request
@app.route('/api', methods=['POST'])
def answer_api_request():
    # parse the JSON body; silent=True returns None instead of raising
    # when the body is missing or is not valid JSON
    request_json = request.get_json(silent=True)
    # the body must be a JSON object containing a "text" key,
    # otherwise return the error JSON
    if not isinstance(request_json, dict) or "text" not in request_json:
        abort(400)
    # get text from the json
    request_text = extract_text_from_request(request_json)
    # the content of the "text" key must be a string
    if not isinstance(request_text, str):
        abort(400)
    # take text and predict class
    prediction = predict_text_political_class(request_text=request_text)
    # convert model prediction to a dictionary and return it as JSON
    response = prediction_to_dict(prediction=prediction)
    return jsonify(response)


if __name__ == '__main__':
    app.run(host='0.0.0.0')
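A check like "text" in request_json assumes the body is a JSON object. A JSON array such as ["text"] also supports in, so it passes that check and then fails on request_json["text"] with a TypeError, which Flask reports as a 500 error rather than the intended 400. A minimal standard-library sketch of the pitfall and of a hypothetical is_valid_payload guard that checks the type first:

```python
import json


def is_valid_payload(payload):
    """Hypothetical guard: True only for a JSON object with a string "text"."""
    return isinstance(payload, dict) and isinstance(payload.get("text"), str)


# A JSON array body also supports the `in` operator...
body = json.loads('["text"]')
print("text" in body)  # True, so a bare `"text" in request_json` check passes

# ...but indexing it with a string key raises TypeError, which an
# unhandled Flask view turns into a 500 Internal Server Error.
try:
    body["text"]
except TypeError as exc:
    print("TypeError:", exc)

print(is_valid_payload(body))               # False
print(is_valid_payload({"text": 7}))        # False
print(is_valid_payload({"text": "Hello"}))  # True
```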
Next, I add the Gunicorn HTTP server, because it is easy to set up and the basic development server that comes with Flask is not made for production. Here I add another Python script that gives Gunicorn an entry point to import the app from.
wsgi.py
from app import app

if __name__ == "__main__":
    app.run()
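Gunicorn can also be configured from a Python file. The sketch below is a hypothetical gunicorn.conf.py; recent Gunicorn versions read ./gunicorn.conf.py from the working directory by default, and otherwise the file can be passed with -c. The port and the worker cap are assumptions, not values from this project.

```python
# gunicorn.conf.py: a hypothetical Gunicorn config, read when Gunicorn
# starts from this directory (e.g. `gunicorn wsgi:app`).
import multiprocessing

# Listen on all interfaces, port 8000 (an assumed port; pick your own)
bind = "0.0.0.0:8000"

# Gunicorn's docs suggest (2 x CPU cores) + 1 workers as a starting point.
# Each worker imports predict_functions and loads its own copy of the
# model, so this sketch caps the count (the cap of 4 is an assumption).
workers = min(2 * multiprocessing.cpu_count() + 1, 4)
```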
My file structure looks like this:
FLASK_APP
|
+--app.py
|
+--predict_functions.py
|
+--wsgi.py
|
+--test_app.py
I run my tests. They pass.
PyTest says that my application does what I want it to do. Let’s start it up and see if it runs.
Next, I test the API by sending some curl requests to the application while it is running.
The first is an empty GET request to the / endpoint.
The second is a well-formed JSON sent to the /api endpoint that should return a prediction.
The third and fourth are malformed requests that should result in an error.
Each request did what PyTest says it should do.
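The same four checks can be scripted from Python with only the standard library. This sketch assumes the Flask development server is running on its default port 5000 (BASE_URL is an assumption), and the payloads for the third and fourth requests are assumed to mirror the two error tests; build_request and send are hypothetical helpers. HTTPError is caught so that the 400 responses are returned rather than raised.

```python
import json
import urllib.error
import urllib.request

BASE_URL = "http://localhost:5000"  # assumed: Flask's default dev port


def build_request(path, payload=None):
    """Build a GET request, or a JSON POST request when a payload is given."""
    url = BASE_URL + path
    if payload is None:
        return urllib.request.Request(url)  # no body, so urllib sends GET
    return urllib.request.Request(
        url,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def send(req):
    """Send the request and return (status_code, body_text), errors included."""
    try:
        with urllib.request.urlopen(req) as resp:
            return resp.status, resp.read().decode("utf-8")
    except urllib.error.HTTPError as err:
        # 4xx/5xx responses raise HTTPError, which still carries the body
        return err.code, err.read().decode("utf-8")


# Hypothetical equivalents of the four curl checks
REQUESTS = [
    build_request("/"),
    build_request("/api", {"text": "Text to be classified"}),
    build_request("/api", {"not_text": "Text to be classified"}),
    build_request("/api", {"text": 7}),
]

if __name__ == "__main__":
    # Requires the app to be running at BASE_URL
    for req in REQUESTS:
        print(req.get_method(), req.full_url, send(req))
```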
We have a working Flask web application, and it behaves exactly as we planned. We use PyTest to run our tests. If I want to add a new feature, I can write another test. When the new code passes that test without breaking the previous four, I will integrate it.
In Part 2 of this blog series, I will containerize this web application with Docker.
View the full repo on GitLab