summaryrefslogtreecommitdiffhomepage
path: root/libs/playhouse/flask_utils.py
blob: 76a2a62f426e9baf475d67a21ebc05e1b0ee7931 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
import math
import sys

from flask import abort
from flask import render_template
from flask import request
from peewee import Database
from peewee import DoesNotExist
from peewee import Model
from peewee import Proxy
from peewee import SelectQuery
from playhouse.db_url import connect as db_url_connect


class PaginatedQuery(object):
    """Paginate a select query, reading the page number from the request.

    :param query_or_model: a ``SelectQuery`` to paginate, or a model class
        (anything exposing ``select()``), in which case all rows are used.
    :param paginate_by: number of objects per page.
    :param page_var: name of the GET argument holding the page number.
    :param page: explicit page number; when given (and truthy) the request
        is not consulted.
    :param check_bounds: when true, ``get_object_list`` aborts with a 404
        if the requested page is greater than the page count.
    """

    def __init__(self, query_or_model, paginate_by, page_var='page', page=None,
                 check_bounds=False):
        self.paginate_by = paginate_by
        self.page_var = page_var
        # Falsy values (0, '') are normalised to None so that ``get_page``
        # falls back to the request argument.
        self.page = page or None
        self.check_bounds = check_bounds

        if isinstance(query_or_model, SelectQuery):
            self.query = query_or_model
            self.model = self.query.model
        else:
            self.model = query_or_model
            self.query = self.model.select()

    @staticmethod
    def _parse_page(raw_value):
        """Convert a raw request value to a page number (always >= 1).

        Missing, empty, or non-numeric values yield 1.
        """
        if not raw_value or not raw_value.isdigit():
            return 1
        try:
            return max(1, int(raw_value))
        except ValueError:
            # ``isdigit`` accepts characters such as superscript digits that
            # ``int`` rejects; treat those like any other invalid value.
            return 1

    def get_page(self):
        """Return the explicit page if one was given, else parse the request."""
        if self.page is not None:
            return self.page
        return self._parse_page(request.args.get(self.page_var))

    def get_page_count(self):
        """Return the total number of pages, caching the result.

        The count query is executed once per instance; the result is stored
        on ``_page_count``.
        """
        if not hasattr(self, '_page_count'):
            self._page_count = int(math.ceil(
                float(self.query.count()) / self.paginate_by))
        return self._page_count

    def get_object_list(self):
        """Return the query restricted to the current page.

        With ``check_bounds`` enabled, a page greater than the page count
        aborts with a 404. Note that an empty result set has a page count of
        zero, so even page 1 is out of bounds in that case.
        """
        if self.check_bounds and self.get_page() > self.get_page_count():
            abort(404)
        return self.query.paginate(self.get_page(), self.paginate_by)


def get_object_or_404(query_or_model, *query):
    """Fetch a single object matching ``query`` or abort with a 404.

    ``query_or_model`` may be a ``SelectQuery`` or anything exposing
    ``select()``; the remaining positional arguments are passed to
    ``where()``.
    """
    if isinstance(query_or_model, SelectQuery):
        select_query = query_or_model
    else:
        select_query = query_or_model.select()

    try:
        matching = select_query.where(*query)
        return matching.get()
    except DoesNotExist:
        abort(404)

def object_list(template_name, query, context_variable='object_list',
                paginate_by=20, page_var='page', page=None, check_bounds=True,
                **kwargs):
    """Render ``template_name`` with one page of results from ``query``.

    The template context receives the page of objects under
    ``context_variable``, the ``PaginatedQuery`` as ``pagination``, the
    current page number as ``page``, plus any extra keyword arguments.
    """
    paginator = PaginatedQuery(
        query,
        paginate_by=paginate_by,
        page_var=page_var,
        page=page,
        check_bounds=check_bounds)

    # Fetch the objects first: with bounds checking this is the call that
    # may abort the request.
    page_objects = paginator.get_object_list()
    kwargs[context_variable] = page_objects
    current_page = paginator.get_page()

    return render_template(
        template_name,
        pagination=paginator,
        page=current_page,
        **kwargs)

def get_current_url():
    """Return the current request path, including the query string if any.

    Returns just ``request.path`` when the request has no query string,
    otherwise ``'<path>?<query string>'``.
    """
    query_string = request.query_string
    if not query_string:
        return request.path
    if isinstance(query_string, bytes) and not isinstance(query_string, str):
        # On Python 3 the raw query string is bytes; formatting it with %s
        # would embed its repr (``b'...'``) in the URL, so decode it first.
        # On Python 2 ``bytes`` is ``str`` and the value is used unchanged.
        query_string = query_string.decode('utf-8', 'replace')
    return '%s?%s' % (request.path, query_string)

def get_next_url(default='/'):
    """Return the ``next`` value from the request, or ``default``.

    The query arguments are checked first, then the form data; an empty
    value counts as missing. The value is returned exactly as supplied by
    the request, without any validation.
    """
    if not request.args.get('next'):
        if not request.form.get('next'):
            return default
        return request.form['next']
    return request.args['next']

class FlaskDB(object):
    """Tie a peewee database to a Flask application.

    The database may be supplied directly (a ``Database`` instance, a
    configuration dict, or a connection URL) or read from the application
    config keys ``DATABASE`` / ``DATABASE_URL``. Once initialised, a
    connection is opened before each request and closed on teardown.

    :param app: application to bind immediately; may be omitted and passed
        to ``init_app`` later.
    :param database: dict, url, ``Database`` instance, or None to read the
        application config.
    :param model_class: base class used for the generated ``Model``.
    """

    def __init__(self, app=None, database=None, model_class=Model):
        self.database = None  # Reference to actual Peewee database instance.
        self.base_model_class = model_class
        self._app = app
        self._db = database  # dict, url, Database, or None (default).
        if app is not None:
            self.init_app(app)

    def init_app(self, app):
        """Bind ``app``: resolve the database and register request hooks.

        :raises ValueError: if no database was given to the constructor and
            the app config has neither ``DATABASE`` nor ``DATABASE_URL``.
        """
        self._app = app

        if self._db is None:
            # ``DATABASE`` takes precedence over ``DATABASE_URL``.
            if 'DATABASE' in app.config:
                initial_db = app.config['DATABASE']
            elif 'DATABASE_URL' in app.config:
                initial_db = app.config['DATABASE_URL']
            else:
                raise ValueError('Missing required configuration data for '
                                 'database: DATABASE or DATABASE_URL.')
        else:
            initial_db = self._db

        self._load_database(app, initial_db)
        self._register_handlers(app)

    def _load_database(self, app, config_value):
        """Turn ``config_value`` into a database and store it.

        ``config_value`` may be a ``Database`` instance, a dict (see
        ``_load_from_config_dict``), or anything else, which is treated as
        a connection URL. ``app`` is currently unused.
        """
        if isinstance(config_value, Database):
            database = config_value
        elif isinstance(config_value, dict):
            # Copy the dict: loading pops keys from it.
            database = self._load_from_config_dict(dict(config_value))
        else:
            # Assume a database connection URL.
            database = db_url_connect(config_value)

        if isinstance(self.database, Proxy):
            # ``Model`` was accessed before initialisation and installed a
            # placeholder; point it at the real database.
            self.database.initialize(database)
        else:
            self.database = database

    def _load_from_config_dict(self, config_dict):
        """Instantiate a database from a configuration dict.

        ``config_dict`` must contain ``name`` and ``engine``; both keys are
        removed from it and the remaining items are passed to the database
        class as keyword arguments. ``engine`` is either a dotted path to a
        class or a bare class name looked up in the ``peewee`` module.

        :raises RuntimeError: if a required key is missing, the module
            cannot be imported, the class is not found, or it is not a
            subclass of ``peewee.Database``.
        """
        try:
            name = config_dict.pop('name')
            engine = config_dict.pop('engine')
        except KeyError:
            raise RuntimeError('DATABASE configuration must specify a '
                               '`name` and `engine`.')

        if '.' in engine:
            path, class_name = engine.rsplit('.', 1)
        else:
            path, class_name = 'peewee', engine

        try:
            __import__(path)
            module = sys.modules[path]
            database_class = getattr(module, class_name)
        except ImportError:
            raise RuntimeError('Unable to import %s' % engine)
        except AttributeError:
            raise RuntimeError('Database engine not found %s' % engine)

        # Validate explicitly rather than with ``assert`` so the check still
        # runs when Python is started with -O, and so that a non-class
        # attribute is reported the same way as a wrong class.
        try:
            is_database_class = issubclass(database_class, Database)
        except TypeError:
            is_database_class = False
        if not is_database_class:
            raise RuntimeError('Database engine not a subclass of '
                               'peewee.Database: %s' % engine)

        return database_class(name, **config_dict)

    def _register_handlers(self, app):
        """Open a connection before each request and close it on teardown."""
        app.before_request(self.connect_db)
        app.teardown_request(self.close_db)

    def get_model_class(self):
        """Build a base model class bound to this instance's database.

        :raises RuntimeError: if no database (or placeholder) is set yet.
        """
        if self.database is None:
            raise RuntimeError('Database must be initialized.')

        class BaseModel(self.base_model_class):
            class Meta:
                database = self.database

        return BaseModel

    @property
    def Model(self):
        """Base model class for this database, created once and cached."""
        if self._app is None:
            # No app bound yet: install a Proxy so models can be declared
            # now; ``_load_database`` initialises it later.
            database = getattr(self, 'database', None)
            if database is None:
                self.database = Proxy()

        if not hasattr(self, '_model_class'):
            self._model_class = self.get_model_class()
        return self._model_class

    def connect_db(self):
        """Open the database connection (registered as a before-request hook)."""
        self.database.connect()

    def close_db(self, exc):
        """Close the connection if it is open (registered as a teardown hook).

        ``exc`` is whatever the teardown hook passes in; it is not used.
        """
        if not self.database.is_closed():
            self.database.close()