aboutsummaryrefslogtreecommitdiffhomepage
path: root/libs/flask_restx/reqparse.py
diff options
context:
space:
mode:
Diffstat (limited to 'libs/flask_restx/reqparse.py')
-rw-r--r--libs/flask_restx/reqparse.py455
1 files changed, 455 insertions, 0 deletions
diff --git a/libs/flask_restx/reqparse.py b/libs/flask_restx/reqparse.py
new file mode 100644
index 000000000..632606603
--- /dev/null
+++ b/libs/flask_restx/reqparse.py
@@ -0,0 +1,455 @@
+# -*- coding: utf-8 -*-
+from __future__ import unicode_literals
+
+import decimal
+import six
+
+try:
+ from collections.abc import Hashable
+except ImportError:
+ from collections import Hashable
+from copy import deepcopy
+from flask import current_app, request
+
+from werkzeug.datastructures import MultiDict, FileStorage
+from werkzeug import exceptions
+
+from .errors import abort, SpecsError
+from .marshalling import marshal
+from .model import Model
+from ._http import HTTPStatus
+
+
+class ParseResult(dict):
+ """
+ The default result container as an Object dict.
+ """
+
+ def __getattr__(self, name):
+ try:
+ return self[name]
+ except KeyError:
+ raise AttributeError(name)
+
+ def __setattr__(self, name, value):
+ self[name] = value
+
+
+_friendly_location = {
+ "json": "the JSON body",
+ "form": "the post body",
+ "args": "the query string",
+ "values": "the post body or the query string",
+ "headers": "the HTTP headers",
+ "cookies": "the request's cookies",
+ "files": "an uploaded file",
+}
+
+#: Maps Flask-RESTX RequestParser locations to Swagger ones
+LOCATIONS = {
+ "args": "query",
+ "form": "formData",
+ "headers": "header",
+ "json": "body",
+ "values": "query",
+ "files": "formData",
+}
+
+#: Maps Python primitives types to Swagger ones
+PY_TYPES = {
+ int: "integer",
+ str: "string",
+ bool: "boolean",
+ float: "number",
+ None: "void",
+}
+
+SPLIT_CHAR = ","
+
+text_type = lambda x: six.text_type(x) # noqa
+
+
+class Argument(object):
+ """
+ :param name: Either a name or a list of option strings, e.g. foo or -f, --foo.
+ :param default: The value produced if the argument is absent from the request.
+ :param dest: The name of the attribute to be added to the object
+ returned by :meth:`~reqparse.RequestParser.parse_args()`.
+ :param bool required: Whether or not the argument may be omitted (optionals only).
+ :param string action: The basic type of action to be taken when this argument
+ is encountered in the request. Valid options are "store" and "append".
+ :param bool ignore: Whether to ignore cases where the argument fails type conversion
+ :param type: The type to which the request argument should be converted.
+ If a type raises an exception, the message in the error will be returned in the response.
+ Defaults to :class:`unicode` in python2 and :class:`str` in python3.
+ :param location: The attributes of the :class:`flask.Request` object
+ to source the arguments from (ex: headers, args, etc.), can be an
+ iterator. The last item listed takes precedence in the result set.
+ :param choices: A container of the allowable values for the argument.
+ :param help: A brief description of the argument, returned in the
+ response when the argument is invalid. May optionally contain
+ an "{error_msg}" interpolation token, which will be replaced with
+ the text of the error raised by the type converter.
+ :param bool case_sensitive: Whether argument values in the request are
+ case sensitive or not (this will convert all values to lowercase)
+ :param bool store_missing: Whether the arguments default value should
+ be stored if the argument is missing from the request.
+ :param bool trim: If enabled, trims whitespace around the argument.
+ :param bool nullable: If enabled, allows null value in argument.
+ """
+
+ def __init__(
+ self,
+ name,
+ default=None,
+ dest=None,
+ required=False,
+ ignore=False,
+ type=text_type,
+ location=("json", "values",),
+ choices=(),
+ action="store",
+ help=None,
+ operators=("=",),
+ case_sensitive=True,
+ store_missing=True,
+ trim=False,
+ nullable=True,
+ ):
+ self.name = name
+ self.default = default
+ self.dest = dest
+ self.required = required
+ self.ignore = ignore
+ self.location = location
+ self.type = type
+ self.choices = choices
+ self.action = action
+ self.help = help
+ self.case_sensitive = case_sensitive
+ self.operators = operators
+ self.store_missing = store_missing
+ self.trim = trim
+ self.nullable = nullable
+
+ def source(self, request):
+ """
+ Pulls values off the request in the provided location
+ :param request: The flask request object to parse arguments from
+ """
+ if isinstance(self.location, six.string_types):
+ value = getattr(request, self.location, MultiDict())
+ if callable(value):
+ value = value()
+ if value is not None:
+ return value
+ else:
+ values = MultiDict()
+ for l in self.location:
+ value = getattr(request, l, None)
+ if callable(value):
+ value = value()
+ if value is not None:
+ values.update(value)
+ return values
+
+ return MultiDict()
+
+ def convert(self, value, op):
+ # Don't cast None
+ if value is None:
+ if not self.nullable:
+ raise ValueError("Must not be null!")
+ return None
+
+ elif isinstance(self.type, Model) and isinstance(value, dict):
+ return marshal(value, self.type)
+
+ # and check if we're expecting a filestorage and haven't overridden `type`
+ # (required because the below instantiation isn't valid for FileStorage)
+ elif isinstance(value, FileStorage) and self.type == FileStorage:
+ return value
+
+ try:
+ return self.type(value, self.name, op)
+ except TypeError:
+ try:
+ if self.type is decimal.Decimal:
+ return self.type(str(value), self.name)
+ else:
+ return self.type(value, self.name)
+ except TypeError:
+ return self.type(value)
+
+ def handle_validation_error(self, error, bundle_errors):
+ """
+ Called when an error is raised while parsing. Aborts the request
+ with a 400 status and an error message
+
+ :param error: the error that was raised
+ :param bool bundle_errors: do not abort when first error occurs, return a
+ dict with the name of the argument and the error message to be
+ bundled
+ """
+ error_str = six.text_type(error)
+ error_msg = (
+ " ".join([six.text_type(self.help), error_str]) if self.help else error_str
+ )
+ errors = {self.name: error_msg}
+
+ if bundle_errors:
+ return ValueError(error), errors
+ abort(HTTPStatus.BAD_REQUEST, "Input payload validation failed", errors=errors)
+
+ def parse(self, request, bundle_errors=False):
+ """
+ Parses argument value(s) from the request, converting according to
+ the argument's type.
+
+ :param request: The flask request object to parse arguments from
+ :param bool bundle_errors: do not abort when first error occurs, return a
+ dict with the name of the argument and the error message to be
+ bundled
+ """
+ bundle_errors = current_app.config.get("BUNDLE_ERRORS", False) or bundle_errors
+ source = self.source(request)
+
+ results = []
+
+ # Sentinels
+ _not_found = False
+ _found = True
+
+ for operator in self.operators:
+ name = self.name + operator.replace("=", "", 1)
+ if name in source:
+ # Account for MultiDict and regular dict
+ if hasattr(source, "getlist"):
+ values = source.getlist(name)
+ else:
+ values = [source.get(name)]
+
+ for value in values:
+ if hasattr(value, "strip") and self.trim:
+ value = value.strip()
+ if hasattr(value, "lower") and not self.case_sensitive:
+ value = value.lower()
+
+ if hasattr(self.choices, "__iter__"):
+ self.choices = [choice.lower() for choice in self.choices]
+
+ try:
+ if self.action == "split":
+ value = [
+ self.convert(v, operator)
+ for v in value.split(SPLIT_CHAR)
+ ]
+ else:
+ value = self.convert(value, operator)
+ except Exception as error:
+ if self.ignore:
+ continue
+ return self.handle_validation_error(error, bundle_errors)
+
+ if self.choices and value not in self.choices:
+ msg = "The value '{0}' is not a valid choice for '{1}'.".format(
+ value, name
+ )
+ return self.handle_validation_error(msg, bundle_errors)
+
+ if name in request.unparsed_arguments:
+ request.unparsed_arguments.pop(name)
+ results.append(value)
+
+ if not results and self.required:
+ if isinstance(self.location, six.string_types):
+ location = _friendly_location.get(self.location, self.location)
+ else:
+ locations = [_friendly_location.get(loc, loc) for loc in self.location]
+ location = " or ".join(locations)
+ error_msg = "Missing required parameter in {0}".format(location)
+ return self.handle_validation_error(error_msg, bundle_errors)
+
+ if not results:
+ if callable(self.default):
+ return self.default(), _not_found
+ else:
+ return self.default, _not_found
+
+ if self.action == "append":
+ return results, _found
+
+ if self.action == "store" or len(results) == 1:
+ return results[0], _found
+ return results, _found
+
+ @property
+ def __schema__(self):
+ if self.location == "cookie":
+ return
+ param = {"name": self.name, "in": LOCATIONS.get(self.location, "query")}
+ _handle_arg_type(self, param)
+ if self.required:
+ param["required"] = True
+ if self.help:
+ param["description"] = self.help
+ if self.default is not None:
+ param["default"] = (
+ self.default() if callable(self.default) else self.default
+ )
+ if self.action == "append":
+ param["items"] = {"type": param["type"]}
+ param["type"] = "array"
+ param["collectionFormat"] = "multi"
+ if self.action == "split":
+ param["items"] = {"type": param["type"]}
+ param["type"] = "array"
+ param["collectionFormat"] = "csv"
+ if self.choices:
+ param["enum"] = self.choices
+ return param
+
+
+class RequestParser(object):
+ """
+ Enables adding and parsing of multiple arguments in the context of a single request.
+ Ex::
+
+ from flask_restx import RequestParser
+
+ parser = RequestParser()
+ parser.add_argument('foo')
+ parser.add_argument('int_bar', type=int)
+ args = parser.parse_args()
+
+ :param bool trim: If enabled, trims whitespace on all arguments in this parser
+ :param bool bundle_errors: If enabled, do not abort when first error occurs,
+ return a dict with the name of the argument and the error message to be
+ bundled and return all validation errors
+ """
+
+ def __init__(
+ self,
+ argument_class=Argument,
+ result_class=ParseResult,
+ trim=False,
+ bundle_errors=False,
+ ):
+ self.args = []
+ self.argument_class = argument_class
+ self.result_class = result_class
+ self.trim = trim
+ self.bundle_errors = bundle_errors
+
+ def add_argument(self, *args, **kwargs):
+ """
+ Adds an argument to be parsed.
+
+ Accepts either a single instance of Argument or arguments to be passed
+ into :class:`Argument`'s constructor.
+
+ See :class:`Argument`'s constructor for documentation on the available options.
+ """
+
+ if len(args) == 1 and isinstance(args[0], self.argument_class):
+ self.args.append(args[0])
+ else:
+ self.args.append(self.argument_class(*args, **kwargs))
+
+ # Do not know what other argument classes are out there
+ if self.trim and self.argument_class is Argument:
+ # enable trim for appended element
+ self.args[-1].trim = kwargs.get("trim", self.trim)
+
+ return self
+
+ def parse_args(self, req=None, strict=False):
+ """
+ Parse all arguments from the provided request and return the results as a ParseResult
+
+ :param bool strict: if req includes args not in parser, throw 400 BadRequest exception
+ :return: the parsed results as :class:`ParseResult` (or any class defined as :attr:`result_class`)
+ :rtype: ParseResult
+ """
+ if req is None:
+ req = request
+
+ result = self.result_class()
+
+ # A record of arguments not yet parsed; as each is found
+ # among self.args, it will be popped out
+ req.unparsed_arguments = (
+ dict(self.argument_class("").source(req)) if strict else {}
+ )
+ errors = {}
+ for arg in self.args:
+ value, found = arg.parse(req, self.bundle_errors)
+ if isinstance(value, ValueError):
+ errors.update(found)
+ found = None
+ if found or arg.store_missing:
+ result[arg.dest or arg.name] = value
+ if errors:
+ abort(
+ HTTPStatus.BAD_REQUEST, "Input payload validation failed", errors=errors
+ )
+
+ if strict and req.unparsed_arguments:
+ arguments = ", ".join(req.unparsed_arguments.keys())
+ msg = "Unknown arguments: {0}".format(arguments)
+ raise exceptions.BadRequest(msg)
+
+ return result
+
+ def copy(self):
+ """Creates a copy of this RequestParser with the same set of arguments"""
+ parser_copy = self.__class__(self.argument_class, self.result_class)
+ parser_copy.args = deepcopy(self.args)
+ parser_copy.trim = self.trim
+ parser_copy.bundle_errors = self.bundle_errors
+ return parser_copy
+
+ def replace_argument(self, name, *args, **kwargs):
+ """Replace the argument matching the given name with a new version."""
+ new_arg = self.argument_class(name, *args, **kwargs)
+ for index, arg in enumerate(self.args[:]):
+ if new_arg.name == arg.name:
+ del self.args[index]
+ self.args.append(new_arg)
+ break
+ return self
+
+ def remove_argument(self, name):
+ """Remove the argument matching the given name."""
+ for index, arg in enumerate(self.args[:]):
+ if name == arg.name:
+ del self.args[index]
+ break
+ return self
+
+ @property
+ def __schema__(self):
+ params = []
+ locations = set()
+ for arg in self.args:
+ param = arg.__schema__
+ if param:
+ params.append(param)
+ locations.add(param["in"])
+ if "body" in locations and "formData" in locations:
+ raise SpecsError("Can't use formData and body at the same time")
+ return params
+
+
+def _handle_arg_type(arg, param):
+ if isinstance(arg.type, Hashable) and arg.type in PY_TYPES:
+ param["type"] = PY_TYPES[arg.type]
+ elif hasattr(arg.type, "__apidoc__"):
+ param["type"] = arg.type.__apidoc__["name"]
+ param["in"] = "body"
+ elif hasattr(arg.type, "__schema__"):
+ param.update(arg.type.__schema__)
+ elif arg.location == "files":
+ param["type"] = "file"
+ else:
+ param["type"] = "string"