"""Tools for building RESTful services that exposes machine learning models.
Building and running an app with the tools in this module is as simple as
1. Instantiating :class:`ModelApp`.
2. Instantiating a "service". E.g. instantiate :class:`PredictionService` for each
model you wish to add to the service.
3. Use the service(s) created in 2. to add models to the app with either
:meth:`ModelApp.add_service()` or :meth:`ModelApp.add_services()`.
For example,
>>> model_app = ModelApp()
>>> prediction_service1 = PredictionService(...)
>>> prediction_service2 = PredictionService(...)
>>> model_app.add_services(prediction_servie1, prediction_service2)
Now the model app can be run with ``model_app.run()`` for development, or as an
example of running the app in production ``$ gunicorn my_module:model_app``.
"""
import abc
import gzip
import json
import logging
import warnings
import flask
import pandas as pd
import werkzeug.exceptions as werkzeug_exc
from . import api
from . import config as cf
from . import constants as cn
from . import responses as porter_responses
from . import schemas
from .exceptions import PorterException
from . import __version__ as VERSION
# alias for convenience
_ID = cn.PREDICTION_PREDICTIONS_KEYS.ID
_logger = logging.getLogger(__name__)
[docs]
class StatefulRoute:
"""Helper class to ensure that classes we intend to route via their
__call__() method satisfy the flask interface.
"""
def __new__(cls, *args, **kwargs):
# flask looks for the __name__ attribute of the routed callable,
# and each name of a routed object must be unique.
# Therefore we define a unique name here to meet flask's expectations.
instance = super().__new__(cls)
if not hasattr(cls, '_instances'):
cls._instances = 0
cls._instances += 1
instance.__name__ = '%s_%s' % (cls.__name__.lower(), cls._instances)
return instance
[docs]
def serve_error_message(error):
response = porter_responses.make_error_response(error)
_logger.exception(response.data)
return response.jsonify()
[docs]
class ServeRoot(StatefulRoute):
_message = 'Send POST requests to /< model-name >/prediction/'
def __init__(self, app):
self.app = app
def __call__(self):
if self.app.expose_docs:
return flask.redirect(self.app.docs_url)
return self._message
[docs]
class ServeAlive(StatefulRoute):
"""Class for building stateful liveness routes.
Args:
app (object): A :class:`ModelApp` instance. Instances of this class inspect
``app`` when called to determine if the app is alive.
"""
logger = logging.getLogger(__name__)
def __init__(self, app):
self.app = app
def __call__(self):
"""Serve liveness response."""
response = porter_responses.make_alive_response(self.app)
self.logger.info(response.data)
return response.jsonify()
[docs]
class ServeReady(StatefulRoute):
"""Class for building stateful readiness routes.
Args:
app (object): A :class:`ModelApp` instance. Instances of this class inspect
``app`` when called to determine if the app is alive.
"""
logger = logging.getLogger(__name__)
def __init__(self, app):
self.app = app
def __call__(self):
"""Serve readiness response."""
response = porter_responses.make_ready_response(self.app)
self.logger.info(response.data)
return response.jsonify()
[docs]
class BaseService(abc.ABC, StatefulRoute):
"""
A service class contains all necessary state and functionality to route a
service and serve requests.
This class does nothing on its own and is meant to be extended.
Args:
name (str): The model name. The final routed endpoint is generally
derived from this parameter.
api_version (str): The service API version. The final routed endpoint is
generally derived from this parameter.
meta (dict): Additional meta data added to the response body. Optional.
log_api_calls (bool): Log request and response and response data.
Default is False.
namespace (str): String identifying a namespace that the service belongs
to. Used to route services by subclasses. Default is "".
validate_request_data (bool): Whether to validate the request data or
not. Applies to all HTTP methods and does nothing if
:meth:`add_request_schema()` is never called.
validate_response_data (bool): Whether to validate the response data
or not. Applies to all HTTP methods and does nothing if
:meth:`add_response_schema()` is never called.
Attributes:
id (str): A unique ID for the service.
name (str): The model name. The final routed endpoint is generally
derived from this attribute.
api_version (str): The service version.
meta (dict): Additional meta data added to the response body. Optional.
log_api_calls (bool): Log request and response and response data.
Default is False.
namespace (str): A namespace that the service belongs to.
validate_request_data (bool): Whether to validate the request data or
not. Applies to all HTTP methods and does nothing if
:meth:`add_request_schema()` is never called.
validate_response_data (bool): Whether to validate the response data
or not. Applies to all HTTP methods and does nothing if
:meth:`add_response_schema()` is never called.
action (str): ``str`` describing the action of the service, e.g.
"prediction". Used to determine the final routed endpoint.
endpoint (str): The endpoint where the service is exposed.
request_schemas (dict): Dictionary mapping HTTP methods to instances
of :class:`porter.schemas.RequestSchema`. Each ``RequestSchema``
object is added from calls to :meth:`add_request_schema` and
instantiated from the corresponding arguments.
response_schemas(dict): Dictionary mapping HTTP methods to a list of
:class:`porter.schemas.ResponseSchema`. Each ``ResponseSchema`` object
is added from calls to :meth:`add_request_schema` and
instantiated from the corresponding arguments.
"""
_ids = set()
_logger = logging.getLogger(__name__)
_default_response_schemas = [
# bad request: raised by flask if json can't be parsed
('POST', 400, schemas.model_context_error, None),
# Unsupported media type: content-encoding not supported
('POST', 415, schemas.model_context_error, None),
# Unprocessable entity: valid json with semantic errors raised by porter
('POST', 422, schemas.model_context_error, None),
# internal server error: any unhandled exception in .serve() will cause this
('POST', 500, schemas.model_context_error, None),
]
# subclasses can override this to add additional defaults
_service_default_schemas = []
def __init__(self, *, name, api_version, meta=None, log_api_calls=False,
namespace='', validate_request_data=False,
validate_response_data=False):
self.name = name
self.api_version = api_version
self.meta = {} if meta is None else meta
self.check_meta(self.meta)
self.namespace = namespace
self.validate_request_data = validate_request_data
self.validate_response_data = validate_response_data
if self.validate_response_data:
warnings.warn('Setting ``validate_response_data`` may significantly '
'impact the latency of responses and return confusing '
'error messages to the user. '
'Use only during development for testing and debugging. '
'This is an experimental feature and may be removed in '
'future releases.')
# Assign endpoint and ID last so they can be determined from other
# instance attributes. If the order of assignment changes here these
# methods may attempt to access attributes that have not been set yet
# and fail.
self.endpoint = self.define_endpoint()
self.id = self.define_id()
self.meta = self.update_meta(self.meta)
self.log_api_calls = log_api_calls
# these are a public interface exposing user registered schemas
self.request_schemas = {}
self.response_schemas = {}
# these are used internally for lookups at runtime
self._request_schemas = {}
self._response_schemas = {}
# add response schemas explicitly returned (i.e. raised) by porter.
for schema in self._default_response_schemas:
self.add_response_schema(*schema)
for schema in self._service_default_schemas:
self.add_response_schema(*schema)
def __call__(self):
"""Process HTTP requests and return a response.
This method is the main mechanism for serving model predictions.
Essentially this method wraps ``self.serve()`` (which must be implemented
by a subclass) with boiler plate such as
- Converting Python objects to raw HTTP responses.
- Logging API requests.
- Error handling.
- Validation of response schemas.
Returns:
A "Response" object or ``None``: The output of ``self.serve()`` converted
wrapped in a Response object with a status code to be served
to the client. Currently this is a `flask.Response` object in
particular.
Raises:
:class:`werkzeug.exceptions.HTTPException`
"""
# Default response is `null` in the event that an error occurs in
# self.serve()
response = None
# When we do `except Exception as error` the name `error` will not
# live outside of the except scope. We'll check if this name has been
# overwritten in the finally block to determine if we need to log an
# error.
caught_error = None
# Add the service to a context that is unique by http transaction.
# This allows us to determine how to approach error handling in
# resonses.py (or anywhere else for that matter).
api.set_model_context(self)
try:
response = self.serve()
# Allow users to return a JSON-like object instead of a `Response`.
# This is much more user friendly.
if not isinstance(response, porter_responses.Response):
response = porter_responses.Response(response)
# Note on the error handling here. The purpose of this block is to log
# any errors that were raised in self.serve(). If we happen to catch
# an unhandled exception we'll re-raise an internal server error.
#
# The only reason we distinguish between HTTPException and Exception
# here is to give unhandled exceptions a message more relevant to `porter`
# than the werkzeug default.
except (PorterException, werkzeug_exc.HTTPException) as error:
caught_error = error
# we need to make a porter response here so the status code exists
# below when we validate the response data.
response = porter_responses.make_error_response(caught_error)
raise error
except Exception as error:
caught_error = error
msg = 'Could not serve model results successfully.'
wrapped_error = werkzeug_exc.InternalServerError(msg)
# we need to make a porter response here so the status code exists
# below when we validate the response data.
response = porter_responses.make_error_response(wrapped_error)
raise wrapped_error from error
finally:
response = response.jsonify()
# log the original error, not necessarily the one we raised
# (i.e. InternalServerError)
if caught_error is not None:
self._log_error(caught_error)
if self.log_api_calls:
request_data = api.request_json()
response_data = getattr(response, 'raw_data', response)
self._log_api_call(request_data, response_data)
if self.validate_response_data:
schema = self._response_schemas.get((api.request_method(), response.status_code))
if schema is not None:
# eh, this is a bit of hack to work around
# fastjsonschema not understanding how to validate numpy
# types.
# There is probably a better way to do this, but this validates
# exactly what we want to and is a quick fix for a feature that
# is experimental anyway.
validation_data = response.data
if response.headers.get('Content-Encoding', None) == 'gzip':
validation_data = gzip.decompress(validation_data).decode('utf-8')
schema.validate(json.loads(validation_data))
return response
[docs]
def define_endpoint(self):
"""Return the service endpoint derived from instance attributes."""
endpoint = cn.ENDPOINT_TEMPLATE.format(
namespace=self.namespace, service_name=self.name,
api_version=self.api_version, action=self.action)
return endpoint
[docs]
@abc.abstractmethod
def serve(self):
"""Return a response to be served to the user.
Users extending this base class will want to return a native Python
object such as a ``str`` or ``dict``. In such cases the object must be
compatible with :obj:`porter.config.json_encoder`.
For subclasses defined internally, this should be the return value of
one of the functions in :mod:`porter.responses` or an instance of
:class:`porter.responses.Response`.
"""
@abc.abstractproperty
def status(self):
"""Return ``str`` representing the status of the service."""
@property
def route_kwargs(self):
"""Keyword arguments to use when routing ``self.serve()``."""
return {}
@property
@abc.abstractproperty
def action(self):
"""``str`` describing the action of the service, e.g. "prediction".
Used to determine the final routed endpoint.
"""
[docs]
def define_id(self):
"""Return a unique ID for the service. This is used to set the ``id``
attribute.
"""
return self.endpoint
@property
def id(self):
"""A unique ID for the instance."""
return self._id
@id.setter
def id(self, value):
"""
Raises:
:class:`ValueError`
"""
if value in self._ids:
raise ValueError(
f'The id={value} has already been used. '
'This likely means that you tried to instantiate a service '
'with parameters that were already used.')
self._ids.add(value)
self._id = value
@property
def namespace(self):
"""A namespace that the service belongs to."""
return self._namespace
@namespace.setter
def namespace(self, value):
if value and not value.startswith('/'):
value = '/' + value
if value and value.endswith('/'):
value = value[:-1]
self._namespace = value
@property
def name(self):
"""The model name. The final routed endpoint is generally derived from
this parameter.
"""
return self._name
@name.setter
def name(self, value):
self._name = value
@property
def api_version(self):
"""The model version. The final routed endpoint is generally derived from
this parameter.
"""
return self._api_version
@api_version.setter
def api_version(self, value):
self._api_version = value
[docs]
def get_post_data(self):
"""Return POST data.
Returns:
The result of ``porter.config.json_encoder``
Raises:
:class:`werkzeug.exceptions.UnprocessableEntity`
If ``self.validate_request_data is True`` and a request schema has
been defined the data will be validated against the schema.
"""
data = api.request_json()
if self.validate_request_data:
schema = self._request_schemas.get('POST')
if schema is not None:
try:
schema.validate(data)
except ValueError as err:
if err.args[0].startswith('Schema validation failed'):
raise werkzeug_exc.UnprocessableEntity(*err.args)
else:
raise err
return data
def _log_api_call(self, request_data, response_data):
self._logger.info('api logging',
extra={'request_id': api.request_id(),
'request_data': request_data,
'response_data': response_data,
'service_class': self.__class__.__name__,
'event': 'api_call'})
def _log_error(self, error):
self._logger.exception(error,
extra={'request_id': api.request_id(),
'service_class': self.__class__.__name__,
'event': 'exception'})
[docs]
def add_request_schema(self, method, api_obj, description=None):
"""Add a request schema.
Args:
method (str): The HTTP method, usually GET or POST.
api_obj (:class:`porter.schemas.ApiObject`): The request data schema.
description (str): Description of the schema. Optional.
"""
method = method.upper()
self.request_schemas[method] = schemas.RequestSchema(api_obj, description)
self._request_schemas[method] = api_obj
[docs]
def add_response_schema(self, method, status_code, api_obj, description=None):
"""Add a response schema.
Args:
method (str): The HTTP method, usually GET or POST.
status_code (int): The HTTP response status code.
api_obj (:class:`porter.schemas.ApiObject`): The request data schema.
description (str): Description of the schema. Optional.
"""
method = method.upper()
self._response_schemas[(method, status_code)] = api_obj
if not method in self.response_schemas:
self.response_schemas[method] = []
self.response_schemas[method].append(schemas.ResponseSchema(api_obj, status_code, description))
[docs]
class PredictionService(BaseService):
"""
A prediction service. Instances can be added to instances of :class:`ModelApp`
to serve predictions.
Args:
name (str): The model name. The final routed endpoint will become
"/<namespace>/<name>/<api version>/<action>/".
api_version (str): The model API version. The final routed endpoint
will become "/<namespace>/<name>/<api version>/<action>/".
meta (dict): Additional meta data added to the response body. Optional.
log_api_calls (bool): Log request and response and response data.
Default is False.
namespace (str): String identifying a namespace that the service belongs
to. The final routed endpoint will become
"/<namespace>/<name>/<api version>/<action>/". Default is "".
action (str): ``str`` describing the action of the service. Used to
determine the final routed endpoint. Defaults to "prediction". The
final routed endpoint will become
"/<namespace>/<name>/<api version>/<action>/".
model (object): An object implementing the interface defined by
:class:`porter.datascience.BaseModel`.
preprocessor (object or None): An object implementing the interface
defined by :class:`porter.datascience.BaseProcessor`. If not ``None``, the
`.process()` method of this object will be called on the POST
request data and its output will be passed to ``model.predict()``.
Optional.
postprocessor (object or None): An object implementing the interface
defined by :class:`porter.datascience.BaseProcessor`. If not ``None``, the
`.process()` method of this object will be called on the output of
``model.predict()`` and its return value will be used to populate
the predictions returned to the user. Optional.
batch_prediction (bool): Whether or not batch predictions are
supported or not. If ``True`` the API will accept an array of objects
to predict on. If ``False`` the API will only accept a single object
per request. Optional.
additional_checks (callable): If ``additional_checks`` raises a
``ValueError`` when called, a 422 UnprocessableEntity response
will be returned to the user. This method allows users to
implement additional data validations that cannot be expressed
with an OpenAPI schema. The signature should accept a single
positional argument for the validated POST input parsed to a
:obj:`pandas.DataFrame`.
feature_schema (:class:`porter.schemas.Object` or `None`): Description
of an a single feature set. Can be used to validate inputs if
``validate_request_data=True`` and document the API if added to an
instance of :class:`ModelApp` where ``expose_docs=True``.
prediction_schema (:class:`porter.schemas.Object` or `None`):
Description of a single model prediction. Can be used to validate
outputs if ``validate_request_data=True`` and document the API if
added to an instance of :class:`ModelApp` where
``expose_docs=True``.
**kwargs: Keyword arguments passed on to :class:`BaseService`.
Attributes:
id (str): A unique ID for the model. Composed of ``name`` and ``api_version``.
name (str): The model's name.
meta (dict): Additional meta data added to the response body. Optional.
log_api_calls (bool): Log request and response and response data.
Default is False.
namespace (str): String identifying a namespace that the service belongs
to. The final routed endpoint will become
"/<namespace>/<name>/<api version>/prediction/". Default is "".
api_version (str): The model API version.
endpoint (str): The endpoint where the model predictions are exposed.
This is computed as "/<name>/<api version>/prediction/".
model (object): An object implementing the interface defined by
:class:`porter.datascience.BaseModel`.
preprocessor (object or None): An object implementing the interface
defined by :class:`porter.datascience.BaseProcessor`. If not `None`, the
``.process()`` method of this object will be called on the POST
request data and its output will be passed to ``model.predict()``.
Optional.
postprocessor (object or None): An object implementing the interface
defined by :class:`porter.datascience.BaseProcessor`. If not `None`, the
`.process()` method of this object will be called on the output of
``model.predict()`` and its return value will be used to populate
the predictions returned to the user. Optional.
batch_prediction (bool): Whether or not the endpoint supports batch
predictions or not. If ``True`` the API will accept an array of
objects to predict on. If ``False`` the API will only accept a
single object per request. Optional.
additional_checks (callable): Raises ValueError or subclass thereof if
POST request is invalid.
feature_schema (:class:`porter.schemas.Object` or None): Description of an
individual instance to be predicted on. Can be used to validate
inputs if `validate_request_data=True` and document the API if
added to an instance of `ModelApp` where `expose_docs=True`.
prediction_schema (:class:`porter.schemas.Object` or None): Description of an
individual prediction returned to the user. Can be used to
validate outputs if `validate_request_data=True` and document the
API if added to an instance of `ModelApp` where
`expose_docs=True`.
request_schema (:class:`porter.schemas.Object` or None) Description of valid
request format, including instance IDs, and wrapped as Array if
``batch_prediction=True``. Can be used for validation outside of ``porter``.
response_schema (:class:`porter.schemas.Object` or None) Description of valid
POST 200 response format, including ``request_id``, ``model_context``, etc.
"""
route_kwargs = {'methods': ['GET', 'POST'], 'strict_slashes': False}
_service_default_schemas = [
('GET', 200, schemas.String(), None)
]
def __init__(self, *, model, preprocessor=None, postprocessor=None,
action='prediction', batch_prediction=True,
additional_checks=None, feature_schema=None,
prediction_schema=None, **kwargs):
self.model = model
self.preprocessor = preprocessor
self.postprocessor = postprocessor
self.batch_prediction = batch_prediction
if additional_checks is not None and not callable(additional_checks):
raise ValueError('`additional_checks` must be callable')
self._action = action
self.additional_checks = additional_checks
self._preprocess_model_input = self.preprocessor is not None
self._postprocess_model_output = self.postprocessor is not None
# need to do this before handling schemas
super().__init__(**kwargs)
self.feature_schema = feature_schema
self.prediction_schema = prediction_schema
self.request_schema = None
self.response_schema = None
if self.feature_schema is not None:
self._add_feature_schema(self.feature_schema)
self.feature_columns = list(self.feature_schema.properties.keys())
else:
self.feature_columns = None
# if None, we'll add the default schema anyway
self._add_prediction_schema(self.prediction_schema)
@property
def status(self):
"""Return 'READY'. Instances of this class are always ready."""
return cn.HEALTH_CHECK_VALUES.IS_READY
@property
def action(self):
"""``str`` describing the action of the service. Used to
determine the final routed endpoint. The final routed endpoint
will become "/<namespace>/<name>/<api version>/<action>/"."""
return self._action
[docs]
def serve(self):
"""Retrive POST request data from flask and return a response
containing the corresponding predictions.
Returns:
object: A "jsonified" object representing the response to return
to the user.
Raises:
:class:`werkzeug.exceptions.BadRequest`: Raised when request data cannot
be parsed (in super().get_post_data).
:class:`werkzeug.exceptions.UnprocessableEntity`: Raised when parsed
request data does not follow the specified schema (in
super().get_post_data).
:class:`werkzeug.exceptions.UnsupportedMediaType`: Raised when request data
is given in an unsupported Content-Encoding.
"""
if api.request_method() == 'GET':
return porter_responses.Response(
'This endpoint is live. Send POST requests for predictions')
return self._predict()
def _predict(self):
# retrieve the data and validate the inputs. If
# self.validate_request_data is True and a feature schema was
# provided, the schema is vetted in get_post_data()
X_input = self.get_post_data()
# Only perform user checks after the schema has been (optionally)
# validated. This way users don't need to do any error handling in
# additional_checks.
# If the user checks fail (ValueError raised) raise a 422.
if self.additional_checks is not None:
try:
self.additional_checks(X_input)
except ValueError as err:
raise werkzeug_exc.UnprocessableEntity(*err.args) from err
# Once the input data has been fully validated, extract the feature
# columns (all features provided in ``feature_schema``) if provided.
# This allows the user to fully anticipate what features are passed
# to the preprocessor.
if self.feature_columns:
X_preprocessed = X_input[self.feature_columns]
else:
X_preprocessed = X_input
# preprocess if user specified a preprocessor
if self._preprocess_model_input:
X_preprocessed = self.preprocessor.process(X_preprocessed)
# get the predictions
preds = self.model.predict(X_preprocessed)
# postprocess
if self._postprocess_model_output:
preds = self.postprocessor.process(X_input, X_preprocessed, preds)
# finally format the predictions and return
if self.batch_prediction:
response = porter_responses.make_batch_prediction_response(X_input[_ID], preds)
else:
response = porter_responses.make_prediction_response(X_input[_ID].iloc[0], preds[0])
return response
[docs]
def get_post_data(self):
"""Return data from the most recent POST request as a ``pandas.DataFrame``.
Returns:
``pandas.DataFrame``. Each ``row`` represents a single instance to
predict on. If ``self.batch_prediction`` is ``False`` the ``DataFrame``
will only contain one ``row``.
"""
data = super().get_post_data()
if not self.batch_prediction:
data = [data]
return pd.DataFrame(data)
def _add_feature_schema(self, user_schema):
assert isinstance(user_schema, schemas.Object), '``feature_schema`` must be an Object'
# add ID to schema
request_schema = schemas.Object(
properties={
'id': schemas.Integer('An ID uniquely identifying each instance in the POST body.'),
**user_schema.properties},
reference_name=user_schema.reference_name)
if self.batch_prediction:
request_schema = schemas.Array(item_type=request_schema)
# save this so the user can access it
self.request_schema = request_schema
# TODO: should a description be passed?
# https://github.com/CadentTech/porter/issues/32
self.add_request_schema('POST', request_schema)
def _add_prediction_schema(self, user_schema):
prediction_schema = schemas.Object(
'Model output',
properties={
'id': schemas.Integer('An ID uniquely identifying each instance in the POST body'),
'prediction': user_schema or schemas.Number('Model Prediction')
},
reference_name=getattr(user_schema, 'reference_name', None)
)
if self.batch_prediction:
prediction_schema = schemas.Array(item_type=prediction_schema)
response_schema = schemas.Object(
properties={
'request_id': schemas.request_id,
'model_context': schemas.model_context,
'predictions': prediction_schema
}
)
# save this so the user can access it
self.response_schema = response_schema
# TODO: should a description be passed?
# https://github.com/CadentTech/porter/issues/32
self.add_response_schema('POST', 200, response_schema)
[docs]
class ModelApp:
"""
Abstraction used to simplify building REST APIs that expose predictive
models.
Essentially this class is a wrapper around an instance of ``flask.Flask``.
Args:
name (str): Name for the application. This will appear in the documentation
if ``expose_docs=True``. Optional.
description (str): Description of the application. This will appear in
the documentation if ``expose_docs=True``. HTML allowed. Optional.
version (str): Version of the application. This will appear in the
documentation if ``expose_docs=True``. Optional.
meta (dict): Additional meta data added to the response body in health
checks. Optional.
expose_docs (bool): If ``True`` API documentation will be served at
``docs_url``. The documentation is built from the
``request_schemas`` and ``response_schemas`` attributes of
services added to the instance. Default is ``False``.
docs_url (str): Endpoint for the API documentation. Ignored if
``expose_docs=False``. Defaults to '/docs/'. Note this does
_not_ override `docs_prefix`.
docs_json_url (str): URL where documentation JSON is exposed. Ignored if
``expose_docs=False``. Defaults to '/_docs.json'. Note this does
_not_ override `docs_prefix`.
docs_prefix (str): Prefix to applied to all documentation endpoints.
Must begin with a `/` and end without one.
Attributes:
name (str): Name for the application.
description (str): Description of the application.
version (str): Version of the application.
meta (dict): Additional meta data added to the response body in health
checks.
expose_docs (bool): Whether the instance is configured to expose API
documentation.
docs_url (str): Endpoint the API documentation is exposed at.
docs_json_url (str): URL where documentation JSON is exposed.
docs_prefix (str): Prefix to applied to all documentation endpoints.
docs_json (dict or None): The OpenAPI spec used to serve the Swagger
documentation. `None` if `expose_docs` is `False`.
"""
# note: eventually we may want to save this state somewhere else.
# perhaps it might make sense serve health checks from
# BaseService subclasses since we're basically trying to
# replicate that here, but we might not actually want all
# of that overhead either.
# see _route_health_checks() below.
_health_check_response_schemas = {'GET': [schemas.ResponseSchema(schemas.health_check, 200)]}
def __init__(self, services, *, name=None, description=None, version=None, meta=None,
expose_docs=False, docs_url='/docs/', docs_json_url='/_docs.json', docs_prefix=''):
self.services = services
self.name = name
self.meta = {} if meta is None else meta
base_description = f'(porter v{VERSION})'
description = description if description is not None else ''
self.description = f'<div>{description}</div><div><p>{base_description}</p></div>'
self.version = version
self.check_meta(self.meta)
self.expose_docs = expose_docs
self.docs_url = docs_prefix + docs_url
self.docs_json_url = docs_prefix + docs_json_url
self.docs_prefix = docs_prefix
self.app = self._init_app()
self._request_schemas = {}
self._response_schemas = {}
self._additional_params = {}
# this is just a cache of service IDs we can use to verify that
# each service is given a unique ID
self._service_ids = set()
self.meta.update(self._init_meta())
# this gets updated in _build_app if expose_docs is True
self.docs_json = None
self._build_app()
def __call__(self, *args, **kwargs):
"""Return a WSGI interface to the model app."""
return self.app(*args, **kwargs)
[docs]
def run(self, *args, **kwargs):
"""
Run the app.
Args:
*args: Positional arguments passed on to the wrapped ``flask`` app.
**kwargs: Keyword arguments passed on to the wrapped ``flask`` app.
"""
self.app.run(*args, **kwargs)
def _init_meta(self):
meta = [
('name', self.name),
('description', self.description),
('version', self.version),
('expose_docs', self.expose_docs),
]
if self.expose_docs:
meta.extend([
('docs_url', self.docs_url),
('docs_json_url', self.docs_json_url),
('docs_prefix', self.docs_prefix),
])
return meta
def _init_app(self):
name = __name__ if self.name is None else self.name
return api.App(name, static_folder=cn.ASSETS_DIR)
def _build_app(self):
"""Build and return the app.
Any global properties of the app, such as error handling and response
formatting, are added here.
Returns:
An instance of :class:`porter.api.App`.
"""
# register error handler for all werkzeug default exceptions
for error in werkzeug_exc.default_exceptions:
self.app.register_error_handler(error, serve_error_message)
# route the health checks
self._route_health_checks()
# route root
serve_root = ServeRoot(self)
self.app.route('/', methods=['GET'])(serve_root)
for service in self.services:
self._add_service(service)
if self.expose_docs:
self.docs_json = schemas.make_openapi_spec(
self.name, self.description, self.version, self._request_schemas,
self._response_schemas, self._additional_params)
self._route_docs()
def _route_endpoint(self, endpoint, fn, route_kwargs, *, request_schemas=None,
response_schemas=None, additional_params=None):
"""Route an endpoint with a contract and do the corresponding "book keeping"."""
self.app.route(endpoint, **route_kwargs)(fn)
if request_schemas is not None:
self._request_schemas[endpoint] = request_schemas
if response_schemas is not None:
self._response_schemas[endpoint] = response_schemas
if additional_params is not None:
self._additional_params[endpoint] = additional_params
def _add_service(self, service):
"""Add a service to the app from ``service``.
Args:
service (object): Instance of :class:`porter.services.BaseService`.
Returns:
None
Raises:
ValueError: If ``service.id`` has already been registered on the
app. This prevents errors from trying to route multiple classes
on the same endpoint.
"""
# register the service with the add
if service.id in self._service_ids:
raise ValueError(
f'a service has already been added using id={service.id}')
self._service_ids.add(service.id)
# get an iterable of methods exposed by the service.
# `or` accounts for schema attributes can be None
methods = (
set((service.request_schemas or {}).keys())
| set((service.response_schemas or {}).keys()))
# create tags for the service for the API docs
additional_params = {method: {'tags': [service.name]} for method in methods}
self._route_endpoint(service.endpoint, service, service.route_kwargs,
request_schemas=service.request_schemas,
response_schemas=service.response_schemas,
additional_params=additional_params)
def _route_health_checks(self):
serve_alive = ServeAlive(self)
serve_ready = ServeReady(self)
# note: Eventually we may want to save this state somewhere else.
# perhaps it might make sense serve health checks from
# BaseService subclasses since we're basically trying to
# replicate that here, but we might not actually want all
# of that overhead either. Another option is to just store
# it on the class as we do with the health check responses.
route_kwargs = {'methods': ['GET']}
additional_params = {'GET': {'tags': ['Health Check']}}
response = self._health_check_response_schemas
self._route_endpoint(cn.LIVENESS_ENDPOINT, serve_alive, route_kwargs,
response_schemas=response, additional_params=additional_params)
self._route_endpoint(cn.READINESS_ENDPOINT, serve_ready, route_kwargs,
response_schemas=response, additional_params=additional_params)
# TODO: perhaps this should be moved into the schemas module at some point
# https://github.com/CadentTech/porter/issues/32
def _route_docs(self):
docs_assets_path = self.docs_prefix + '/assets/swagger-ui/<path:filename>'
@self.app.route(self.docs_url)
def docs():
html = schemas.make_docs_html(self.docs_prefix, self.docs_json_url)
return html
@self.app.route(docs_assets_path)
def swagger_ui(filename):
import flask
return flask.send_from_directory(cn.ASSETS_DIR, 'swagger-ui/' + filename)
@self.app.route(self.docs_json_url)
def docs_json():
return json.dumps(self.docs_json)