diff options
Diffstat (limited to 'libs/flask_restx/reqparse.py')
-rw-r--r-- | libs/flask_restx/reqparse.py | 455 |
1 files changed, 455 insertions, 0 deletions
diff --git a/libs/flask_restx/reqparse.py b/libs/flask_restx/reqparse.py new file mode 100644 index 000000000..632606603 --- /dev/null +++ b/libs/flask_restx/reqparse.py @@ -0,0 +1,455 @@ +# -*- coding: utf-8 -*- +from __future__ import unicode_literals + +import decimal +import six + +try: + from collections.abc import Hashable +except ImportError: + from collections import Hashable +from copy import deepcopy +from flask import current_app, request + +from werkzeug.datastructures import MultiDict, FileStorage +from werkzeug import exceptions + +from .errors import abort, SpecsError +from .marshalling import marshal +from .model import Model +from ._http import HTTPStatus + + +class ParseResult(dict): + """ + The default result container as an Object dict. + """ + + def __getattr__(self, name): + try: + return self[name] + except KeyError: + raise AttributeError(name) + + def __setattr__(self, name, value): + self[name] = value + + +_friendly_location = { + "json": "the JSON body", + "form": "the post body", + "args": "the query string", + "values": "the post body or the query string", + "headers": "the HTTP headers", + "cookies": "the request's cookies", + "files": "an uploaded file", +} + +#: Maps Flask-RESTX RequestParser locations to Swagger ones +LOCATIONS = { + "args": "query", + "form": "formData", + "headers": "header", + "json": "body", + "values": "query", + "files": "formData", +} + +#: Maps Python primitives types to Swagger ones +PY_TYPES = { + int: "integer", + str: "string", + bool: "boolean", + float: "number", + None: "void", +} + +SPLIT_CHAR = "," + +text_type = lambda x: six.text_type(x) # noqa + + +class Argument(object): + """ + :param name: Either a name or a list of option strings, e.g. foo or -f, --foo. + :param default: The value produced if the argument is absent from the request. + :param dest: The name of the attribute to be added to the object + returned by :meth:`~reqparse.RequestParser.parse_args()`. + :param bool required: Whether or not the argument may be omitted (optionals only). + :param string action: The basic type of action to be taken when this argument + is encountered in the request. Valid options are "store" and "append". + :param bool ignore: Whether to ignore cases where the argument fails type conversion + :param type: The type to which the request argument should be converted. + If a type raises an exception, the message in the error will be returned in the response. + Defaults to :class:`unicode` in python2 and :class:`str` in python3. + :param location: The attributes of the :class:`flask.Request` object + to source the arguments from (ex: headers, args, etc.), can be an + iterator. The last item listed takes precedence in the result set. + :param choices: A container of the allowable values for the argument. + :param help: A brief description of the argument, returned in the + response when the argument is invalid. May optionally contain + an "{error_msg}" interpolation token, which will be replaced with + the text of the error raised by the type converter. + :param bool case_sensitive: Whether argument values in the request are + case sensitive or not (this will convert all values to lowercase) + :param bool store_missing: Whether the arguments default value should + be stored if the argument is missing from the request. + :param bool trim: If enabled, trims whitespace around the argument. + :param bool nullable: If enabled, allows null value in argument. + """ + + def __init__( + self, + name, + default=None, + dest=None, + required=False, + ignore=False, + type=text_type, + location=("json", "values",), + choices=(), + action="store", + help=None, + operators=("=",), + case_sensitive=True, + store_missing=True, + trim=False, + nullable=True, + ): + self.name = name + self.default = default + self.dest = dest + self.required = required + self.ignore = ignore + self.location = location + self.type = type + self.choices = choices + self.action = action + self.help = help + self.case_sensitive = case_sensitive + self.operators = operators + self.store_missing = store_missing + self.trim = trim + self.nullable = nullable + + def source(self, request): + """ + Pulls values off the request in the provided location + :param request: The flask request object to parse arguments from + """ + if isinstance(self.location, six.string_types): + value = getattr(request, self.location, MultiDict()) + if callable(value): + value = value() + if value is not None: + return value + else: + values = MultiDict() + for l in self.location: + value = getattr(request, l, None) + if callable(value): + value = value() + if value is not None: + values.update(value) + return values + + return MultiDict() + + def convert(self, value, op): + # Don't cast None + if value is None: + if not self.nullable: + raise ValueError("Must not be null!") + return None + + elif isinstance(self.type, Model) and isinstance(value, dict): + return marshal(value, self.type) + + # and check if we're expecting a filestorage and haven't overridden `type` + # (required because the below instantiation isn't valid for FileStorage) + elif isinstance(value, FileStorage) and self.type == FileStorage: + return value + + try: + return self.type(value, self.name, op) + except TypeError: + try: + if self.type is decimal.Decimal: + return self.type(str(value), self.name) + else: + return self.type(value, self.name) + except TypeError: + return self.type(value) + + def handle_validation_error(self, error, bundle_errors): + """ + Called when an error is raised while parsing. Aborts the request + with a 400 status and an error message + + :param error: the error that was raised + :param bool bundle_errors: do not abort when first error occurs, return a + dict with the name of the argument and the error message to be + bundled + """ + error_str = six.text_type(error) + error_msg = ( + " ".join([six.text_type(self.help), error_str]) if self.help else error_str + ) + errors = {self.name: error_msg} + + if bundle_errors: + return ValueError(error), errors + abort(HTTPStatus.BAD_REQUEST, "Input payload validation failed", errors=errors) + + def parse(self, request, bundle_errors=False): + """ + Parses argument value(s) from the request, converting according to + the argument's type. + + :param request: The flask request object to parse arguments from + :param bool bundle_errors: do not abort when first error occurs, return a + dict with the name of the argument and the error message to be + bundled + """ + bundle_errors = current_app.config.get("BUNDLE_ERRORS", False) or bundle_errors + source = self.source(request) + + results = [] + + # Sentinels + _not_found = False + _found = True + + for operator in self.operators: + name = self.name + operator.replace("=", "", 1) + if name in source: + # Account for MultiDict and regular dict + if hasattr(source, "getlist"): + values = source.getlist(name) + else: + values = [source.get(name)] + + for value in values: + if hasattr(value, "strip") and self.trim: + value = value.strip() + if hasattr(value, "lower") and not self.case_sensitive: + value = value.lower() + + if hasattr(self.choices, "__iter__"): + self.choices = [choice.lower() for choice in self.choices] + + try: + if self.action == "split": + value = [ + self.convert(v, operator) + for v in value.split(SPLIT_CHAR) + ] + else: + value = self.convert(value, operator) + except Exception as error: + if self.ignore: + continue + return self.handle_validation_error(error, bundle_errors) + + if self.choices and value not in self.choices: + msg = "The value '{0}' is not a valid choice for '{1}'.".format( + value, name + ) + return self.handle_validation_error(msg, bundle_errors) + + if name in request.unparsed_arguments: + request.unparsed_arguments.pop(name) + results.append(value) + + if not results and self.required: + if isinstance(self.location, six.string_types): + location = _friendly_location.get(self.location, self.location) + else: + locations = [_friendly_location.get(loc, loc) for loc in self.location] + location = " or ".join(locations) + error_msg = "Missing required parameter in {0}".format(location) + return self.handle_validation_error(error_msg, bundle_errors) + + if not results: + if callable(self.default): + return self.default(), _not_found + else: + return self.default, _not_found + + if self.action == "append": + return results, _found + + if self.action == "store" or len(results) == 1: + return results[0], _found + return results, _found + + @property + def __schema__(self): + if self.location == "cookie": + return + param = {"name": self.name, "in": LOCATIONS.get(self.location, "query")} + _handle_arg_type(self, param) + if self.required: + param["required"] = True + if self.help: + param["description"] = self.help + if self.default is not None: + param["default"] = ( + self.default() if callable(self.default) else self.default + ) + if self.action == "append": + param["items"] = {"type": param["type"]} + param["type"] = "array" + param["collectionFormat"] = "multi" + if self.action == "split": + param["items"] = {"type": param["type"]} + param["type"] = "array" + param["collectionFormat"] = "csv" + if self.choices: + param["enum"] = self.choices + return param + + +class RequestParser(object): + """ + Enables adding and parsing of multiple arguments in the context of a single request. + Ex:: + + from flask_restx import RequestParser + + parser = RequestParser() + parser.add_argument('foo') + parser.add_argument('int_bar', type=int) + args = parser.parse_args() + + :param bool trim: If enabled, trims whitespace on all arguments in this parser + :param bool bundle_errors: If enabled, do not abort when first error occurs, + return a dict with the name of the argument and the error message to be + bundled and return all validation errors + """ + + def __init__( + self, + argument_class=Argument, + result_class=ParseResult, + trim=False, + bundle_errors=False, + ): + self.args = [] + self.argument_class = argument_class + self.result_class = result_class + self.trim = trim + self.bundle_errors = bundle_errors + + def add_argument(self, *args, **kwargs): + """ + Adds an argument to be parsed. + + Accepts either a single instance of Argument or arguments to be passed + into :class:`Argument`'s constructor. + + See :class:`Argument`'s constructor for documentation on the available options. + """ + + if len(args) == 1 and isinstance(args[0], self.argument_class): + self.args.append(args[0]) + else: + self.args.append(self.argument_class(*args, **kwargs)) + + # Do not know what other argument classes are out there + if self.trim and self.argument_class is Argument: + # enable trim for appended element + self.args[-1].trim = kwargs.get("trim", self.trim) + + return self + + def parse_args(self, req=None, strict=False): + """ + Parse all arguments from the provided request and return the results as a ParseResult + + :param bool strict: if req includes args not in parser, throw 400 BadRequest exception + :return: the parsed results as :class:`ParseResult` (or any class defined as :attr:`result_class`) + :rtype: ParseResult + """ + if req is None: + req = request + + result = self.result_class() + + # A record of arguments not yet parsed; as each is found + # among self.args, it will be popped out + req.unparsed_arguments = ( + dict(self.argument_class("").source(req)) if strict else {} + ) + errors = {} + for arg in self.args: + value, found = arg.parse(req, self.bundle_errors) + if isinstance(value, ValueError): + errors.update(found) + found = None + if found or arg.store_missing: + result[arg.dest or arg.name] = value + if errors: + abort( + HTTPStatus.BAD_REQUEST, "Input payload validation failed", errors=errors + ) + + if strict and req.unparsed_arguments: + arguments = ", ".join(req.unparsed_arguments.keys()) + msg = "Unknown arguments: {0}".format(arguments) + raise exceptions.BadRequest(msg) + + return result + + def copy(self): + """Creates a copy of this RequestParser with the same set of arguments""" + parser_copy = self.__class__(self.argument_class, self.result_class) + parser_copy.args = deepcopy(self.args) + parser_copy.trim = self.trim + parser_copy.bundle_errors = self.bundle_errors + return parser_copy + + def replace_argument(self, name, *args, **kwargs): + """Replace the argument matching the given name with a new version.""" + new_arg = self.argument_class(name, *args, **kwargs) + for index, arg in enumerate(self.args[:]): + if new_arg.name == arg.name: + del self.args[index] + self.args.append(new_arg) + break + return self + + def remove_argument(self, name): + """Remove the argument matching the given name.""" + for index, arg in enumerate(self.args[:]): + if name == arg.name: + del self.args[index] + break + return self + + @property + def __schema__(self): + params = [] + locations = set() + for arg in self.args: + param = arg.__schema__ + if param: + params.append(param) + locations.add(param["in"]) + if "body" in locations and "formData" in locations: + raise SpecsError("Can't use formData and body at the same time") + return params + + +def _handle_arg_type(arg, param): + if isinstance(arg.type, Hashable) and arg.type in PY_TYPES: + param["type"] = PY_TYPES[arg.type] + elif hasattr(arg.type, "__apidoc__"): + param["type"] = arg.type.__apidoc__["name"] + param["in"] = "body" + elif hasattr(arg.type, "__schema__"): + param.update(arg.type.__schema__) + elif arg.location == "files": + param["type"] = "file" + else: + param["type"] = "string" |