1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
|
import functools
import json
import logging
import threading
import warnings
from typing import Any, Iterator, List, Optional, Type, Union
# Module-level logger; Database._load() reports duplicate index values on it.
logger = logging.getLogger("pycountry.db")
class Data:
    """Generic record whose fields are readable as attributes.

    Field values live in the ``_fields`` dict. Attribute reads that miss the
    normal lookup fall through to that dict, and attribute writes are
    mirrored into it. Iterating an instance yields ``(field, value)`` pairs,
    so ``dict(obj)`` works.
    """

    def __init__(self, **fields: str):
        self._fields = fields

    def __getattr__(self, key):
        """Return the field named *key*.

        Only invoked when normal attribute lookup fails.

        Raises:
            AttributeError: if *key* is not a known field.
        """
        # Read ``_fields`` straight from the instance dict. On a bare
        # instance (e.g. the one copy/pickle create before restoring state)
        # ``self._fields`` would re-enter __getattr__ and recurse until
        # RecursionError instead of raising AttributeError.
        fields = self.__dict__.get("_fields")
        if fields is not None and key in fields:
            return fields[key]
        raise AttributeError(
            f"{type(self).__name__!r} object has no attribute {key!r}"
        )

    def __setattr__(self, key: str, value: str) -> None:
        """Set attribute *key* and mirror it into ``_fields``."""
        # ``_fields`` itself must not be recorded as a field.
        if key != "_fields":
            self._fields[key] = value
        super().__setattr__(key, value)

    def __repr__(self) -> str:
        cls_name = type(self).__name__
        fields = ", ".join(
            f"{name}={value!r}" for name, value in sorted(self._fields.items())
        )
        return f"{cls_name}({fields})"

    def __dir__(self) -> List[str]:
        """List class attributes followed by the record's field names."""
        return dir(self.__class__) + list(self._fields)

    def __iter__(self):
        """Yield ``(field, value)`` pairs; allows casting into a dict."""
        for field in self._fields:
            yield field, getattr(self, field)
class Country(Data):
    """Country record that falls back to ``name`` for missing name variants."""

    def __getattr__(self, key):
        """Return field *key*, substituting ``name`` for absent name variants.

        ``common_name`` and ``official_name`` fall back to the ``name`` field,
        with a ``UserWarning``, when they are missing or ``None``. Every
        other key is a plain field lookup.

        Raises:
            AttributeError: if the field (and, for the name variants, the
                ``name`` fallback) is not available.
        """
        # Read ``_fields`` from the instance dict: going through
        # ``self._fields`` on a bare instance (copy/pickle reconstruction)
        # would re-enter this method and recurse without end.
        fields = self.__dict__.get("_fields")
        if fields is not None:
            if key in ("common_name", "official_name"):
                # First try the requested name variant itself.
                value = fields.get(key)
                if value is not None:
                    return value
                # Fall back to the plain name when the variant is absent.
                name = fields.get("name")
                if name is not None:
                    warning_message = (
                        f"Country's {key} not found. Country name provided instead."
                    )
                    # stacklevel=2 attributes the warning to the code that
                    # accessed the attribute rather than to this method.
                    warnings.warn(warning_message, UserWarning, stacklevel=2)
                    return name
            elif key in fields:
                return fields[key]
        raise AttributeError(
            f"{type(self).__name__!r} object has no attribute {key!r}"
        )
class Subdivision(Data):
    """Record type for subdivision entries; adds nothing on top of ``Data``."""
def lazy_load(f):
    """Decorate method *f* so the instance is loaded before *f* runs.

    The instance must provide ``_is_loaded`` (bool), ``_load_lock`` (a
    context manager) and ``_load()``. The wrapper keeps *f*'s name and
    docstring.
    """

    @functools.wraps(f)
    def load_if_needed(self, *args, **kw):
        # Cheap unlocked check first; the lock is only taken while the
        # instance still reports itself as not loaded.
        if not self._is_loaded:
            with self._load_lock:
                self._load()
        return f(self, *args, **kw)

    return load_if_needed
class Database:
    """Lazily loaded, case-insensitively indexed collection of records.

    Subclasses configure:

    * ``data_class`` -- the record class, or a string naming a ``Data``
      subclass to generate on the fly;
    * ``root_key`` -- key of the record list inside the JSON file;
    * ``no_index`` -- field names that are left out of the indices.

    The JSON file is read on the first call to any public method.
    """

    data_class: Union[Type, str]
    root_key: Optional[str] = None
    no_index: List[str] = []

    def __init__(self, filename: str) -> None:
        self.filename = filename
        self._is_loaded = False
        self._load_lock = threading.Lock()
        if isinstance(self.data_class, str):
            # Build a dedicated record type named after ``data_class``.
            self.factory = type(self.data_class, (Data,), {})
        else:
            self.factory = self.data_class

    def _clear(self):
        """Reset the database to the empty, not-loaded state."""
        self._is_loaded = False
        self.objects = []
        self.index_names = set()
        self.indices = {}

    def _load(self) -> None:
        """Read ``filename`` and build the record list and the indices."""
        if self._is_loaded:
            # Re-checked here so the lazy_load wrapper stays easy to read.
            return
        self._clear()

        with open(self.filename, encoding="utf-8") as f:
            tree = json.load(f)

        for entry in tree[self.root_key]:
            obj = self.factory(**entry)
            self.objects.append(obj)
            # Inject into index.
            for key, value in entry.items():
                if key in self.no_index:
                    continue
                # Lookups and searches are case insensitive. Normalize
                # here.
                index = self.indices.setdefault(key, {})
                value = value.lower()
                if value in index:
                    # NOTE(review): the message says "ignored", but the
                    # assignment below lets the later record replace the
                    # earlier one in the index.
                    logger.debug(
                        "%s %r already taken in index %r and will be "
                        "ignored. This is an error in the databases.",
                        self.factory.__name__,
                        value,
                        key,
                    )
                index[value] = obj

        self._is_loaded = True

    # Public API

    @lazy_load
    def add_entry(self, **kw):
        """Create a record from the keyword fields and index it."""
        # create the object with the correct dynamic type
        obj = self.factory(**kw)
        # append object
        self.objects.append(obj)
        # update indices
        for key, value in kw.items():
            if key in self.no_index:
                continue
            value = value.lower()
            index = self.indices.setdefault(key, {})
            index[value] = obj

    @lazy_load
    def remove_entry(self, **kw):
        """Remove the record matching the single ``field=value`` criterion.

        Raises:
            KeyError: if no record matches (or *field* has no index).
            TypeError: if more or fewer than one criterion is given.
        """
        # make sure that we receive None if no entry found
        kw.pop("default", None)
        obj = self.get(**kw)
        if obj is None:
            raise KeyError(
                f"{self.factory.__name__} not found and cannot be removed: {kw}"
            )
        # remove object
        self.objects.remove(obj)
        # update indices
        for key, value in obj:
            if key in self.no_index:
                continue
            value = value.lower()
            index = self.indices.setdefault(key, {})
            # Several records can share a value, in which case the slot may
            # belong to another record that is still present. Only drop the
            # slot when it really references the record being removed.
            if index.get(value) is obj:
                del index[value]

    @lazy_load
    def __iter__(self) -> Iterator[Any]:
        """Iterate over all records."""
        return iter(self.objects)

    @lazy_load
    def __len__(self) -> int:
        """Return the number of records."""
        return len(self.objects)

    @lazy_load
    def get(self, **kw: Optional[str]) -> Optional[Any]:
        """Return the record whose indexed *field* equals *value*.

        Call as ``get(field=value)``, optionally with ``default=...``, which
        is returned when no record matches (``None`` if omitted). Matching
        is case-insensitive.

        Raises:
            TypeError: if not exactly one criterion is given.
            LookupError: if the value is not a string.
            KeyError: if *field* has no index.
        """
        default = kw.pop("default", None)
        if len(kw) != 1:
            raise TypeError("Only one criteria may be given")
        field, value = kw.popitem()
        if not isinstance(value, str):
            raise LookupError(
                f"Lookup value for {field!r} must be a string, got {value!r}"
            )
        # Normalize for case-insensitivity
        value = value.lower()
        index = self.indices[field]
        # Pythonic APIs implementing get() shouldn't raise KeyErrors for a
        # missing value; return the (customizable) default instead.
        return index.get(value, default)

    @lazy_load
    def lookup(self, value: str) -> Any:
        """Return the first record having *value* in any field.

        Indexed fields are consulted first, then the ``no_index`` fields of
        every record. Matching is case-insensitive.

        Raises:
            LookupError: if *value* is not a string or nothing matches.
        """
        if not isinstance(value, str):
            raise LookupError(f"Lookup value must be a string, got {value!r}")
        # Normalize for case-insensitivity
        value = value.lower()

        # Use indexes first
        for index in self.indices.values():
            if value in index:
                return index[value]

        # Use non-indexed values now. Avoid going through indexed values.
        for candidate in self:
            for k in self.no_index:
                v = candidate._fields.get(k)
                if v is None:
                    continue
                if v.lower() == value:
                    return candidate

        raise LookupError("Could not find a record for %r" % value)
|