From Notebook To Production Part 1

In my previous project, I used Fast.AI to train a text classification model that classifies a tweet as either Conservative or Liberal. I would like to share it in a way that lets users get predictions for whatever text they want.

Flask is a lightweight framework for building web applications in Python. Using Flask, I can make a very small application that serves my trained model as an API.

Web Application Requirements

The application will have two endpoints: the base URL / and the model API URL /api.

My app needs to perform these tasks:

  1. If a user sends a blank GET request to the base URL of the web app, the app responds with API UP!

  2. The user sends a POST request to /api with a JSON body containing the text to be classified by the model.
    a. If the request JSON is well-formed, the application takes the text from the request, sends it to the model, and then sends the model's response back to the user as JSON.
    b. If the request JSON does not have a text key, send an error message as a response.
    c. If the value of the text key in the request JSON is not a string, send an error message as a response.


Example User Request JSON:

{
  "text": "Text to be classified"
}

Example Successful Application Response:

{
  "percent_in_class": 61.0,
  "predicted_class": "Liberal"
}

Example Application Error Message Response:

{
  "error": "Bad request"
}
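
Requirements 2a to 2c boil down to a small validation rule. Below is a minimal, framework-free sketch of that rule. It uses a hypothetical helper named validate_request_json, which is not part of the app built later. It also shows that json.dumps emits the double-quoted keys and strings that JSON requires.

```python
import json
from typing import Any, Optional


def validate_request_json(payload: Any) -> Optional[dict]:
    """Hypothetical helper: None if the payload satisfies requirement 2a,
    otherwise the error body described in requirements 2b and 2c."""
    # The body must be a JSON object that has a "text" key (requirement 2b).
    if not isinstance(payload, dict) or "text" not in payload:
        return {"error": "Bad request"}
    # The value under "text" must be a string (requirement 2c).
    if not isinstance(payload["text"], str):
        return {"error": "Bad request"}
    return None


good = json.loads('{"text": "Text to be classified"}')
missing_key = {"not_text": "Text to be classified"}
not_a_string = {"text": 7}

print(validate_request_json(good))          # None
print(validate_request_json(missing_key))   # {'error': 'Bad request'}
print(validate_request_json(not_a_string))  # {'error': 'Bad request'}

# json.dumps always uses double quotes, unlike a Python dict's repr.
print(json.dumps({"text": "Text to be classified"}))  # {"text": "Text to be classified"}
```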

Unit Testing

Now that I have some requirements for my application, I can write tests that make sure my app does what I want it to do.

You may have noticed that I have not written a single line of code, yet I am talking about testing an application that doesn’t exist. Let’s talk about Test-Driven Development.

Developers usually write some code and then run it to see if it does what it is supposed to do. This is called “manual testing.” It’s not a bad approach if your application is minimal, but if your application does hundreds of different things, it gets tedious. Do you want to be the person who has to do this by hand every time someone makes a change? You could write a script (devs love scripts) and add to it whenever you write new code. That gets you closer to automated testing.

With Test-Driven Development, you write your tests first and automate all testing. Then you write the code. When your code passes its new test without breaking any test that already passed, your portion is complete and can be integrated into the codebase.

Writing tests before code forces you to plan what the application should do up front. Prior Planning Prevents Poor Performance, and testing does the same thing.
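
Here is a toy illustration of that cycle. The function label_for_index and its mapping are hypothetical; they were chosen to mirror the Liberal and Conservative labels used later. First you write the test and watch it fail because the function does not exist yet. Then you write just enough code to make it pass.

```python
# Step 1 (red): write the test first. Calling it before
# label_for_index is defined fails with a NameError.
def test_label_for_index():
    assert label_for_index(0) == "Conservative"
    assert label_for_index(1) == "Liberal"


# Step 2 (green): write just enough code to make the test pass.
def label_for_index(index: int) -> str:
    return "Liberal" if index == 1 else "Conservative"


# Step 3: rerun the test, and keep rerunning it while you refactor.
# PyTest would collect test_label_for_index automatically; here it is
# called directly.
test_label_for_index()
print("test passed")
```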

Python has many frameworks built for running tests. For this project, I use PyTest.

Above, I listed four things my app needs to do (1, 2a, 2b and 2c). These are the basis of my tests. When all of the tests pass, my app meets my needs.

Test GET request to the / endpoint

If a user sends a blank GET request to the / endpoint, the app responds with API UP! and an HTTP status code of 200.

# test_app.py: imports used by all of the tests below
import json

from app import app


# If a blank GET request is sent to the base URL,
# the response should be the string 'API UP!'
def test_hello():
    response = app.test_client().get('/')

    assert response.status_code == 200
    assert response.data == b'API UP!'

Test POST request to /api endpoint

a. If the JSON sent to the /api endpoint has the correct key and the value is a string, send the model prediction back to the user with an HTTP status code of 200.

# If a user sends a good request,
# the response code should be 200 and a JSON prediction should be sent back
def test_api_good_data():
    response = app.test_client().post(
        '/api',
        data=json.dumps({"text": "Great honor, I think? "
                                 "Mark Zuckerberg recently stated that “Donald J. "
                                 "Trump is Number 1 on Facebook. Number 2 is "
                                 "Prime Minister Modi of India.” Actually, I "
                                 "am going to India in two weeks. "
                                 "Looking forward to it!"}),
        content_type='application/json',
    )
    data = json.loads(response.get_data(as_text=True))
    assert response.status_code == 200
    assert data['predicted_class'] == "Conservative"

b. If the request JSON sent to the /api endpoint does not have a text key, send an error message and an HTTP status code of 400 as a response.

# If a user sends a JSON without a text key,
# the response code should be 400 and an error message should be sent.
def test_api_bad_json_key_data():
    response = app.test_client().post(
        '/api',
        data=json.dumps({"not_text": "Great honor, I think? "
                                     "Mark Zuckerberg recently stated that “Donald J. "
                                     "Trump is Number 1 on Facebook. Number 2 is "
                                     "Prime Minister Modi of India.” Actually, I "
                                     "am going to India in two weeks. "
                                     "Looking forward to it!"}),
        content_type='application/json',
    )
    data = json.loads(response.get_data(as_text=True))
    assert response.status_code == 400
    assert data['error'] == "Bad request"

c. If the text key value in the request JSON sent to the /api endpoint is not a string, send an error message and an HTTP status code of 400 as a response.

# If a user sends a JSON with non-string data as the text key value,
# the response code should be 400 and an error message should be sent.
def test_api_not_string_data():
    response = app.test_client().post(
        '/api',
        data=json.dumps({"text": 7}),
        content_type='application/json',
    )
    data = json.loads(response.get_data(as_text=True))
    assert response.status_code == 400
    assert data['error'] == "Bad request"

The Flask Application

Our tests are ready, and we know what we need the application to do. In a separate script, I wrote some functions that load my model and parse the incoming and outgoing data.

predict_functions.py

from fastai.text import load_learner

# load the trained model from export.pkl in the current directory
learn = load_learner("", file="export.pkl")


def extract_text_from_request(request_json):
    '''
    take text from request json save as a string
    '''
    request_text = request_json["text"]
    return request_text


def predict_text_political_class(request_text):
    '''
    send the string to the model to be classified
    return the response from the model
    '''
    pred = learn.predict(request_text)
    return pred


def prediction_to_dict(prediction):
    '''
    return the prediction as a dictionary
    extract correct portion of the prediction
    based on the predicted class.
    '''
    # prediction[0] is the predicted class (1 = Liberal, 0 = Conservative)
    # prediction[2] holds the probability of each class
    # multiply before rounding so 0.57 becomes 57.0, not 56.99999999999999

    # if response is liberal
    if int(prediction[0]) == 1:
        political_class = 'Liberal'
        percent_in_class = round(float(prediction[2][1]) * 100, 2)

    # if response is conservative
    else:
        political_class = 'Conservative'
        percent_in_class = round(float(prediction[2][0]) * 100, 2)

    response_dict = {'predicted_class': political_class,
                     'percent_in_class': percent_in_class}
    return response_dict
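
One subtle point in prediction_to_dict is the order of rounding and scaling. Rounding the probability before multiplying by 100 can leak floating-point noise into the response: round(0.57, 2) * 100 evaluates to 56.99999999999999. Below is a sketch of a more compact, table-driven variant, named prediction_to_dict_table (a hypothetical name), that multiplies first and then rounds. Like the code above, it assumes the prediction is a tuple whose first item converts to the class index (0 for Conservative, 1 for Liberal) and whose third item holds the class probabilities. The fake tuple in the usage lines stands in for the model's output.

```python
from typing import Sequence

# Class index -> label, matching the branches in prediction_to_dict above.
LABELS = {0: "Conservative", 1: "Liberal"}


def prediction_to_dict_table(prediction: Sequence) -> dict:
    class_index = int(prediction[0])
    probability = float(prediction[2][class_index])
    return {
        "predicted_class": LABELS[class_index],
        # Multiply first, then round, to avoid float artifacts.
        "percent_in_class": round(probability * 100, 2),
    }


# Naive order: round the probability, then scale it.
print(round(0.57, 2) * 100)  # 56.99999999999999

# Fake prediction tuple standing in for learn.predict(...) output.
fake_prediction = (1, 1, [0.43, 0.57])
print(prediction_to_dict_table(fake_prediction))
# {'predicted_class': 'Liberal', 'percent_in_class': 57.0}
```

The lookup table removes the duplicated branches, so adding a third class would only mean adding one entry to LABELS.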

Next, I create the Flask application. Flask works by mapping routes to functions, and each route is a different URL. My application has two routes: the base URL / and the URL for interacting with my model, /api. I also define a function that sends an error response.

Notice that I import the helper functions from my predict_functions script.

app.py


from flask import Flask, request, jsonify, abort, make_response

from predict_functions import (extract_text_from_request,
                               predict_text_political_class,
                               prediction_to_dict)

app = Flask(__name__)


# Taken from https://blog.miguelgrinberg.com
# every abort(400) ends up here and returns a JSON error body
@app.errorhandler(400)
def bad_request(error):
    return make_response(jsonify({'error': 'Bad request'}), 400)


# home url, returns API UP
@app.route('/', methods=['GET'])
def confirm_server_function():
    return "API UP!"


# api prediction url, returns a prediction from a POST request
@app.route('/api', methods=['POST'])
def answer_api_request():
    # parse the JSON body; silent=True returns None instead of raising
    # when the body is missing or is not valid JSON
    request_json = request.get_json(silent=True)

    # if the body is a JSON object with the correct key, move on to next step
    # else return error json
    if isinstance(request_json, dict) and "text" in request_json:

        # get text from the json
        request_text = extract_text_from_request(request_json)

        # if content of text key is a string advance to next step
        if isinstance(request_text, str):

            # take text and predict class
            prediction = predict_text_political_class(request_text=request_text)

            # convert model prediction to a dictionary
            response = prediction_to_dict(prediction=prediction)

            # return the response as a JSON
            return jsonify(response)

        # return error
        else:
            abort(400)

    # return error
    else:
        abort(400)


if __name__ == '__main__':
    app.run(host='0.0.0.0')
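
The @app.route decorator registers a function under a URL and a set of HTTP methods. The toy sketch below shows the idea with a plain dictionary; it is not how Flask is implemented internally. A hypothetical route decorator records each handler, and a hypothetical dispatch function looks it up. Loosely like Flask, dispatch answers 404 for an unknown path and 405 for a known path with the wrong method.

```python
from typing import Callable, Dict, Iterable, Tuple

Handler = Callable[[], str]

# Toy routing table: (path, HTTP method) -> handler function.
ROUTES: Dict[Tuple[str, str], Handler] = {}


def route(path: str, methods: Iterable[str] = ("GET",)):
    """Hypothetical decorator that records a handler for each method."""
    def decorator(func: Handler) -> Handler:
        for method in methods:
            ROUTES[(path, method.upper())] = func
        return func  # the decorated function itself is unchanged
    return decorator


@route("/", methods=["GET"])
def confirm_server_function() -> str:
    return "API UP!"


def dispatch(path: str, method: str) -> Tuple[int, str]:
    """Look up the handler: 404 for unknown paths, 405 for wrong methods."""
    handler = ROUTES.get((path, method.upper()))
    if handler is not None:
        return 200, handler()
    if any(known_path == path for known_path, _ in ROUTES):
        return 405, "Method Not Allowed"
    return 404, "Not Found"


print(dispatch("/", "GET"))      # (200, 'API UP!')
print(dispatch("/", "POST"))     # (405, 'Method Not Allowed')
print(dispatch("/nope", "GET"))  # (404, 'Not Found')
```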

Next, I add the Gunicorn HTTP server. It is easy to use, and the built-in server that comes with Flask is meant for development, not production. I add another Python script to serve as Gunicorn's entry point.

wsgi.py

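# Entry point for Gunicorn, which can load the app with: gunicorn wsgi:app
from app import app

if __name__ == "__main__":
    # running this file directly uses Flask's development server instead
    app.run()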
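
Gunicorn is a WSGI server. It imports a module and calls the WSGI application object inside it; for this project, that object is the Flask app that wsgi.py imports. The rough sketch below shows what that contract looks like. It defines a hypothetical, hand-written WSGI callable named simple_app and calls it directly with a minimal environ dictionary, the way a server would. Flask's app object follows the same calling convention.

```python
from typing import Callable, Iterable, List, Tuple

Headers = List[Tuple[str, str]]


def simple_app(environ: dict, start_response: Callable) -> Iterable[bytes]:
    """A bare WSGI application that answers GET / like the Flask app."""
    if environ.get("REQUEST_METHOD") == "GET" and environ.get("PATH_INFO") == "/":
        status, body = "200 OK", b"API UP!"
    else:
        status, body = "404 Not Found", b"Not Found"
    start_response(status, [("Content-Type", "text/plain"),
                            ("Content-Length", str(len(body)))])
    return [body]


# Call the app the way a WSGI server would, capturing the status line.
captured = {}


def fake_start_response(status: str, headers: Headers) -> None:
    captured["status"] = status
    captured["headers"] = headers


body = b"".join(simple_app({"REQUEST_METHOD": "GET", "PATH_INFO": "/"},
                           fake_start_response))
print(captured["status"], body)  # 200 OK b'API UP!'
```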

My file structure looks like this:

FLASK_APP
   |
   +--app.py
   |
   +--predict_functions.py
   |
   +--wsgi.py
   |
   +--test_app.py

I run my tests. They pass.

PyTest says that my application does what I want it to do. Let’s start it up and see if it runs.

Next, I test the API by sending some curl requests to the application while it is running.

The first is an empty GET request to the / endpoint.

The second is a well-formed JSON sent to the /api endpoint that should return a prediction.

The third and fourth are malformed requests that should result in an error.

Each request did what PyTest says it should do.
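
If you prefer Python to curl, you can build the same kinds of requests with the standard library's urllib.request. This is a sketch under a few assumptions. The base URL http://localhost:5000 assumes Flask's default development port; Gunicorn's default is port 8000. The request bodies are illustrative, not the exact ones used above. Because urllib raises HTTPError for 4xx responses, the hypothetical send helper reads the error body from the exception.

```python
import json
import urllib.error
import urllib.request
from typing import List, Tuple

# Assumption: Flask's development server default port (Gunicorn defaults to 8000).
BASE_URL = "http://localhost:5000"


def build_requests(base_url: str = BASE_URL) -> List[urllib.request.Request]:
    """Build a GET to / plus one good and two malformed POSTs to /api."""
    bodies = [
        {"text": "Text to be classified"},      # well-formed
        {"not_text": "Text to be classified"},  # missing "text" key
        {"text": 7},                            # "text" is not a string
    ]
    reqs = [urllib.request.Request(base_url + "/", method="GET")]
    for body in bodies:
        reqs.append(urllib.request.Request(
            base_url + "/api",
            data=json.dumps(body).encode("utf-8"),
            headers={"Content-Type": "application/json"},
            method="POST",
        ))
    return reqs


def send(req: urllib.request.Request) -> Tuple[int, str]:
    """Send one request; 4xx/5xx responses raise HTTPError, so read the body there."""
    try:
        with urllib.request.urlopen(req, timeout=10) as resp:
            return resp.status, resp.read().decode("utf-8")
    except urllib.error.HTTPError as err:
        return err.code, err.read().decode("utf-8")


def run_all(base_url: str = BASE_URL) -> None:
    """Print method, URL, status and body for each request (needs a running app)."""
    for req in build_requests(base_url):
        status, body = send(req)
        print(req.get_method(), req.full_url, status, body)
```

With the app running, calling run_all() should show a 200 for the first two requests and a 400 with the Bad request body for the last two, matching the tests.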

We have a working Flask web application that behaves exactly as we planned, and we use PyTest to run the tests. If I want to add a new feature, I can write another test. When the new code passes that test without breaking the previous four, I will integrate it.

In Part 2 of this blog series, I will containerize this web application with Docker.

View the full repo on GitLab

Updated: