diff --git a/pysolr.py b/pysolr.py index 3bd50c7c..a949a8b3 100644 --- a/pysolr.py +++ b/pysolr.py @@ -64,7 +64,6 @@ # Ugh. long = int # NOQA: A001 - __author__ = "Daniel Lindsley, Joseph Kocherhans, Jacob Kaplan-Moss, Thomas Rieder" __all__ = ["Solr"] @@ -218,12 +217,8 @@ def is_valid_xml_char_ordinal(i): Char ::= #x9 | #xA | #xD | [#x20-#xD7FF] | [#xE000-#xFFFD] | [#x10000-#x10FFFF] """ # conditions ordered by presumed frequency - return ( - 0x20 <= i <= 0xD7FF - or i in (0x9, 0xA, 0xD) - or 0xE000 <= i <= 0xFFFD - or 0x10000 <= i <= 0x10FFFF - ) + return (0x20 <= i <= 0xD7FF or i in (0x9, 0xA, 0xD) + or 0xE000 <= i <= 0xFFFD or 0x10000 <= i <= 0x10FFFF) def clean_xml_string(s): @@ -299,9 +294,8 @@ def __init__(self, decoded, next_page_query=None): self.qtime = decoded.get("responseHeader", {}).get("QTime", None) self.grouped = decoded.get("grouped", {}) self.nextCursorMark = decoded.get("nextCursorMark", None) - self._next_page_query = ( - self.nextCursorMark is not None and next_page_query or None - ) + self._next_page_query = (self.nextCursorMark is not None + and next_page_query or None) def __len__(self): if self._next_page_query: @@ -346,18 +340,18 @@ class Solr(object): """ def __init__( - self, - url, - decoder=None, - encoder=None, - timeout=60, - results_cls=Results, - search_handler="select", - use_qt_param=False, - always_commit=False, - auth=None, - verify=True, - session=None, + self, + url, + decoder=None, + encoder=None, + timeout=60, + results_cls=Results, + search_handler="select", + use_qt_param=False, + always_commit=False, + auth=None, + verify=True, + session=None, ): self.decoder = decoder or json.JSONDecoder() self.encoder = encoder or json.JSONEncoder() @@ -389,7 +383,12 @@ def _create_full_url(self, path=""): # No path? No problem. return self.url - def _send_request(self, method, path="", body=None, headers=None, files=None): + def _send_request(self, + method, + path="", + body=None, + headers=None, + files=None): url = self._create_full_url(path) method = method.lower() log_body = body @@ -415,7 +414,8 @@ def _send_request(self, method, path="", body=None, headers=None, files=None): try: requests_method = getattr(session, method) except AttributeError: - raise SolrError("Unable to use unknown HTTP method '{0}.".format(method)) + raise SolrError( + "Unable to use unknown HTTP method '{0}.".format(method)) # Everything except the body can be Unicode. The body must be # encoded to bytes to work properly on Py3. @@ -501,11 +501,13 @@ def _select(self, params, handler=None): # Handles very long queries by submitting as a POST. path = "%s/" % handler headers = { - "Content-type": "application/x-www-form-urlencoded; charset=utf-8" + "Content-type": + "application/x-www-form-urlencoded; charset=utf-8" } - return self._send_request( - "post", path, body=params_encoded, headers=headers - ) + return self._send_request("post", + path, + body=params_encoded, + headers=headers) def _mlt(self, params, handler="mlt"): return self._select(params, handler) @@ -514,17 +516,17 @@ def _suggest_terms(self, params, handler="terms"): return self._select(params, handler) def _update( - self, - message, - clean_ctrl_chars=True, - commit=None, - softCommit=False, - waitFlush=None, - waitSearcher=None, - overwrite=None, - handler="update", - solrapi="XML", - min_rf=None, + self, + message, + clean_ctrl_chars=True, + commit=None, + softCommit=False, + waitFlush=None, + waitSearcher=None, + overwrite=None, + handler="update", + solrapi="XML", + min_rf=None, ): """ Posts the given xml or json message to http:///update and @@ -565,7 +567,8 @@ def _update( query_vars.append("overwrite=%s" % str(bool(overwrite)).lower()) if waitSearcher is not None: - query_vars.append("waitSearcher=%s" % str(bool(waitSearcher)).lower()) + query_vars.append("waitSearcher=%s" % + str(bool(waitSearcher)).lower()) if query_vars: path = "%s?%s" % (path, "&".join(query_vars)) @@ -576,8 +579,8 @@ def _update( if solrapi == "XML": return self._send_request( - "post", path, message, {"Content-type": "text/xml; charset=utf-8"} - ) + "post", path, message, + {"Content-type": "text/xml; charset=utf-8"}) elif solrapi == "JSON": return self._send_request( "post", @@ -604,7 +607,8 @@ def _extract_error(self, resp): full_response = resp.content except ValueError: # otherwise we assume it's html - reason, full_html = self._scrape_response(resp.headers, resp.content) + reason, full_html = self._scrape_response( + resp.headers, resp.content) full_response = unescape_html(full_html) msg = "[Reason: %s]" % reason @@ -658,7 +662,8 @@ def _scrape_response(self, headers, response): if server_type == "tomcat": # Tomcat doesn't produce a valid XML response or consistent HTML: - m = re.search(r"<(h1)[^>]*>\s*(.+?)\s*", response, re.IGNORECASE) + m = re.search(r"<(h1)[^>]*>\s*(.+?)\s*", response, + re.IGNORECASE) if m: reason = m.group(2) else: @@ -684,7 +689,9 @@ def _scrape_response(self, headers, response): LOG.warning( # NOQA: G200 "Unable to extract error message from invalid XML: %s", err, - extra={"data": {"response": response}}, + extra={"data": { + "response": response + }}, ) full_html = "%s" % response @@ -918,13 +925,15 @@ def suggest_terms(self, fields, prefix, handler="terms", **kwargs): res[field] = tmp - self.log.debug( - "Found '%d' Term suggestions results.", sum(len(j) for i, j in res.items()) - ) + self.log.debug("Found '%d' Term suggestions results.", + sum(len(j) for i, j in res.items())) return res def _build_json_doc(self, doc): - cleaned_doc = {k: v for k, v in doc.items() if not self._is_null_value(v)} + cleaned_doc = { + k: v + for k, v in doc.items() if not self._is_null_value(v) + } return cleaned_doc def _build_xml_doc(self, doc, boost=None, fieldUpdates=None): @@ -933,7 +942,8 @@ def _build_xml_doc(self, doc, boost=None, fieldUpdates=None): for key, value in doc.items(): if key == NESTED_DOC_KEY: for child in value: - doc_elem.append(self._build_xml_doc(child, boost, fieldUpdates)) + doc_elem.append( + self._build_xml_doc(child, boost, fieldUpdates)) continue if key == "boost": @@ -945,11 +955,11 @@ def _build_xml_doc(self, doc, boost=None, fieldUpdates=None): if isinstance(value, (list, tuple, set)): values = value else: - values = (value,) + values = (value, ) use_field_updates = fieldUpdates and key in fieldUpdates if use_field_updates and not values: - values = ("",) + values = ("", ) for bit in values: attrs = {"name": key} @@ -980,18 +990,18 @@ def _build_xml_doc(self, doc, boost=None, fieldUpdates=None): return doc_elem def add( - self, - docs, - boost=None, - fieldUpdates=None, - commit=None, - softCommit=False, - commitWithin=None, - waitFlush=None, - waitSearcher=None, - overwrite=None, - handler="update", - min_rf=None, + self, + docs, + boost=None, + fieldUpdates=None, + commit=None, + softCommit=False, + commitWithin=None, + waitFlush=None, + waitSearcher=None, + overwrite=None, + handler="update", + min_rf=None, ): """ Adds or updates documents. @@ -1046,7 +1056,9 @@ def add( # json array of docs if isinstance(message, list): # convert to string - cleaned_message = [self._build_json_doc(doc) for doc in message] + cleaned_message = [ + self._build_json_doc(doc) for doc in message + ] m = self.encoder.encode(cleaned_message).encode("utf-8") else: raise ValueError("wrong message type") @@ -1057,7 +1069,9 @@ def add( message.set("commitWithin", commitWithin) for doc in docs: - el = self._build_xml_doc(doc, boost=boost, fieldUpdates=fieldUpdates) + el = self._build_xml_doc(doc, + boost=boost, + fieldUpdates=fieldUpdates) message.append(el) # This returns a bytestring. Ugh. @@ -1084,14 +1098,14 @@ def add( ) def delete( - self, - id=None, # NOQA: A002 - q=None, - commit=None, - softCommit=False, - waitFlush=None, - waitSearcher=None, - handler="update", + self, + id=None, # NOQA: A002 + q=None, + commit=None, + softCommit=False, + waitFlush=None, + waitSearcher=None, + handler="update", ): # NOQA: A002 """ Deletes documents. @@ -1126,7 +1140,8 @@ def delete( else: doc_id = list(filter(None, id)) if doc_id: - m = "%s" % "".join("%s" % i for i in doc_id) + m = "%s" % "".join("%s" % i + for i in doc_id) else: raise ValueError("The list of documents to delete was empty.") elif q is not None: @@ -1142,12 +1157,12 @@ def delete( ) def commit( - self, - softCommit=False, - waitFlush=None, - waitSearcher=None, - expungeDeletes=None, - handler="update", + self, + softCommit=False, + waitFlush=None, + waitSearcher=None, + expungeDeletes=None, + handler="update", ): """ Forces Solr to write the index data to disk. @@ -1166,7 +1181,8 @@ def commit( """ if expungeDeletes is not None: - msg = '' % str(bool(expungeDeletes)).lower() + msg = '' % str( + bool(expungeDeletes)).lower() else: msg = "" @@ -1180,12 +1196,12 @@ def commit( ) def optimize( - self, - commit=True, - waitFlush=None, - waitSearcher=None, - maxSegments=None, - handler="update", + self, + commit=True, + waitFlush=None, + waitSearcher=None, + maxSegments=None, + handler="update", ): """ Tells Solr to streamline the number of segments used, essentially a @@ -1215,7 +1231,11 @@ def optimize( handler=handler, ) - def extract(self, file_obj, extractOnly=True, handler="update/extract", **kwargs): + def extract(self, + file_obj, + extractOnly=True, + handler="update/extract", + **kwargs): """ POSTs a file to the Solr ExtractingRequestHandler so rich content can be processed using Apache Tika. See the Solr wiki for details: @@ -1254,9 +1274,10 @@ def extract(self, file_obj, extractOnly=True, handler="update/extract", **kwargs try: # We'll provide the file using its true name as Tika may use that # as a file type hint: - resp = self._send_request( - "post", handler, body=params, files={"file": (filename, file_obj)} - ) + resp = self._send_request("post", + handler, + body=params, + files={"file": (filename, file_obj)}) except (IOError, SolrError): self.log.exception("Failed to extract document metadata") raise @@ -1300,11 +1321,13 @@ def ping(self, handler="admin/ping", **kwargs): # Handles very long queries by submitting as a POST. path = "%s/" % handler headers = { - "Content-type": "application/x-www-form-urlencoded; charset=utf-8" + "Content-type": + "application/x-www-form-urlencoded; charset=utf-8" } - return self._send_request( - "post", path, body=params_encoded, headers=headers - ) + return self._send_request("post", + path, + body=params_encoded, + headers=headers) class SolrCoreAdmin(object): @@ -1353,15 +1376,22 @@ def status(self, core=None): return self._get_url(self.url, params=params) - def create( - self, name, instance_dir=None, config="solrconfig.xml", schema="schema.xml" - ): + def create(self, + name, + instance_dir=None, + config="solrconfig.xml", + schema="schema.xml"): """ Create a new core See https://wiki.apache.org/solr/CoreAdmin#CREATE """ - params = {"action": "CREATE", "name": name, "config": config, "schema": schema} + params = { + "action": "CREATE", + "name": name, + "config": config, + "schema": schema + } if instance_dir is None: params.update(instanceDir=name) @@ -1407,7 +1437,8 @@ def unload(self, core): return self._get_url(self.url, params=params) def load(self, core): - raise NotImplementedError("Solr 1.4 and below do not support this operation.") + raise NotImplementedError( + "Solr 1.4 and below do not support this operation.") # Using two-tuples to preserve order. @@ -1455,20 +1486,18 @@ def sanitize(data): class SolrCloud(Solr): - def __init__( - self, - zookeeper, - collection, - decoder=None, - encoder=None, - timeout=60, - retry_count=5, - retry_timeout=0.2, - auth=None, - verify=True, - *args, - **kwargs - ): + def __init__(self, + zookeeper, + collection, + decoder=None, + encoder=None, + timeout=60, + retry_count=5, + retry_timeout=0.2, + auth=None, + verify=True, + *args, + **kwargs): url = zookeeper.getRandomURL(collection) self.auth = auth self.collection = collection @@ -1477,22 +1506,26 @@ def __init__( self.verify = verify self.zookeeper = zookeeper - super(SolrCloud, self).__init__( - url, - decoder=decoder, - encoder=encoder, - timeout=timeout, - auth=self.auth, - verify=self.verify, - *args, - **kwargs - ) - - def _send_request(self, method, path="", body=None, headers=None, files=None): + super(SolrCloud, self).__init__(url, + decoder=decoder, + encoder=encoder, + timeout=timeout, + auth=self.auth, + verify=self.verify, + *args, + **kwargs) + + def _send_request(self, + method, + path="", + body=None, + headers=None, + files=None): for retry_number in range(0, self.retry_count): try: self.url = self.zookeeper.getRandomURL(self.collection) - return Solr._send_request(self, method, path, body, headers, files) + return Solr._send_request(self, method, path, body, headers, + files) except (SolrError, requests.exceptions.RequestException): LOG.exception( "%s %s failed on retry %s, will retry after %0.1fs", @@ -1503,9 +1536,8 @@ def _send_request(self, method, path="", body=None, headers=None, files=None): ) time.sleep(self.retry_timeout) - raise SolrError( - "Request %s %s failed after %d attempts" % (method, path, self.retry_count) - ) + raise SolrError("Request %s %s failed after %d attempts" % + (method, path, self.retry_count)) def _update(self, *args, **kwargs): self.url = self.zookeeper.getLeaderURL(self.collection) @@ -1530,9 +1562,14 @@ class ZooKeeper(object): FALSE = "false" COLLECTION = "collection" - def __init__(self, zkServerAddress, timeout=15, max_retries=-1, kazoo_client=None): + def __init__(self, + zkServerAddress, + timeout=15, + max_retries=-1, + kazoo_client=None): if KazooClient is None: - logging.error("ZooKeeper requires the `kazoo` library to be installed") + logging.error( + "ZooKeeper requires the `kazoo` library to be installed") raise RuntimeError self.collections = {} @@ -1564,7 +1601,8 @@ def connectionListener(state): @self.zk.DataWatch(ZooKeeper.CLUSTER_STATE) def watchClusterState(data, *args, **kwargs): if not data: - LOG.warning("No cluster state available: no collections defined?") + LOG.warning( + "No cluster state available: no collections defined?") else: self.collections = json.loads(data.decode("utf-8")) LOG.info("Updated collections: %s", self.collections) @@ -1592,7 +1630,8 @@ def watchAliases(data, stat): def watchCollectionState(data, *args, **kwargs): if not data: - LOG.warning("No cluster state available: no collections defined?") + LOG.warning( + "No cluster state available: no collections defined?") else: self.collections.update(json.loads(data.decode("utf-8"))) LOG.info("Updated collections: %s", self.collections) @@ -1601,7 +1640,8 @@ def watchCollectionState(data, *args, **kwargs): def watchCollectionStatus(children): LOG.info("Updated collection: %s", children) for c in children: - self.zk.DataWatch(self.COLLECTION_STATE % c, watchCollectionState) + self.zk.DataWatch(self.COLLECTION_STATE % c, + watchCollectionState) def getHosts(self, collname, only_leader=False, seen_aliases=None): if self.aliases and collname in self.aliases: @@ -1620,9 +1660,8 @@ def getHosts(self, collname, only_leader=False, seen_aliases=None): replica = replicas[replicaname] if replica[ZooKeeper.STATE] == ZooKeeper.ACTIVE: - if not only_leader or ( - replica.get(ZooKeeper.LEADER, None) == ZooKeeper.TRUE - ): + if not only_leader or (replica.get( + ZooKeeper.LEADER, None) == ZooKeeper.TRUE): base_url = replica[ZooKeeper.BASE_URL] if base_url not in hosts: hosts.append(base_url) @@ -1631,7 +1670,8 @@ def getHosts(self, collname, only_leader=False, seen_aliases=None): def getAliasHosts(self, collname, only_leader, seen_aliases): if seen_aliases: if collname in seen_aliases: - LOG.warning("%s in circular alias definition - ignored", collname) + LOG.warning("%s in circular alias definition - ignored", + collname) return [] else: seen_aliases = [] diff --git a/tests/test_client.py b/tests/test_client.py index 37349609..2881f73c 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -38,34 +38,32 @@ class UtilsTestCase(unittest.TestCase): def test_unescape_html(self): - self.assertEqual(unescape_html("Hello • world"), "Hello \x95 world") + self.assertEqual(unescape_html("Hello • world"), + "Hello \x95 world") self.assertEqual(unescape_html("Hello d world"), "Hello d world") self.assertEqual(unescape_html("Hello & ☃"), "Hello & ☃") - self.assertEqual( - unescape_html("Hello &doesnotexist; world"), "Hello &doesnotexist; world" - ) + self.assertEqual(unescape_html("Hello &doesnotexist; world"), + "Hello &doesnotexist; world") def test_safe_urlencode(self): self.assertEqual( force_unicode( - unquote_plus(safe_urlencode({"test": "Hello ☃! Helllo world!"})) - ), + unquote_plus(safe_urlencode({"test": + "Hello ☃! Helllo world!"}))), "test=Hello ☃! Helllo world!", ) self.assertEqual( force_unicode( unquote_plus( - safe_urlencode({"test": ["Hello ☃!", "Helllo world!"]}, True) - ) - ), + safe_urlencode({"test": ["Hello ☃!", "Helllo world!"]}, + True))), "test=Hello \u2603!&test=Helllo world!", ) self.assertEqual( force_unicode( unquote_plus( - safe_urlencode({"test": ("Hello ☃!", "Helllo world!")}, True) - ) - ), + safe_urlencode({"test": ("Hello ☃!", "Helllo world!")}, + True))), "test=Hello \u2603!&test=Helllo world!", ) @@ -82,18 +80,18 @@ def test_force_unicode(self): # Don't mangle, it's already Unicode. self.assertEqual(force_unicode("Hello ☃"), "Hello ☃") - self.assertEqual(force_unicode(1), "1", "force_unicode() should convert ints") - self.assertEqual( - force_unicode(1.0), "1.0", "force_unicode() should convert floats" - ) - self.assertEqual( - force_unicode(None), "None", "force_unicode() should convert None" - ) + self.assertEqual(force_unicode(1), "1", + "force_unicode() should convert ints") + self.assertEqual(force_unicode(1.0), "1.0", + "force_unicode() should convert floats") + self.assertEqual(force_unicode(None), "None", + "force_unicode() should convert None") def test_force_bytes(self): self.assertEqual(force_bytes("Hello ☃"), b"Hello \xe2\x98\x83") # Don't mangle, it's already a bytestring. - self.assertEqual(force_bytes(b"Hello \xe2\x98\x83"), b"Hello \xe2\x98\x83") + self.assertEqual(force_bytes(b"Hello \xe2\x98\x83"), + b"Hello \xe2\x98\x83") def test_clean_xml_string(self): self.assertEqual(clean_xml_string("\x00\x0b\x0d\uffff"), "\x0d") @@ -102,8 +100,14 @@ def test_clean_xml_string(self): class ResultsTestCase(unittest.TestCase): def test_init(self): default_results = Results( - {"response": {"docs": [{"id": 1}, {"id": 2}], "numFound": 2}} - ) + {"response": { + "docs": [{ + "id": 1 + }, { + "id": 2 + }], + "numFound": 2 + }}) self.assertEqual(default_results.docs, [{"id": 1}, {"id": 2}]) self.assertEqual(default_results.hits, 2) @@ -115,19 +119,28 @@ def test_init(self): self.assertEqual(default_results.debug, {}) self.assertEqual(default_results.grouped, {}) - full_results = Results( - { - "response": {"docs": [{"id": 1}, {"id": 2}, {"id": 3}], "numFound": 3}, - # Fake data just to check assignments. - "highlighting": "hi", - "facet_counts": "fa", - "spellcheck": "sp", - "stats": "st", - "responseHeader": {"QTime": "0.001"}, - "debug": True, - "grouped": ["a"], - } - ) + full_results = Results({ + "response": { + "docs": [{ + "id": 1 + }, { + "id": 2 + }, { + "id": 3 + }], + "numFound": 3 + }, + # Fake data just to check assignments. + "highlighting": "hi", + "facet_counts": "fa", + "spellcheck": "sp", + "stats": "st", + "responseHeader": { + "QTime": "0.001" + }, + "debug": True, + "grouped": ["a"], + }) self.assertEqual(full_results.docs, [{"id": 1}, {"id": 2}, {"id": 3}]) self.assertEqual(full_results.hits, 3) @@ -141,19 +154,43 @@ def test_init(self): def test_len(self): small_results = Results( - {"response": {"docs": [{"id": 1}, {"id": 2}], "numFound": 2}} - ) + {"response": { + "docs": [{ + "id": 1 + }, { + "id": 2 + }], + "numFound": 2 + }}) self.assertEqual(len(small_results), 2) - wrong_hits_results = Results( - {"response": {"docs": [{"id": 1}, {"id": 2}, {"id": 3}], "numFound": 7}} - ) + wrong_hits_results = Results({ + "response": { + "docs": [{ + "id": 1 + }, { + "id": 2 + }, { + "id": 3 + }], + "numFound": 7 + } + }) self.assertEqual(len(wrong_hits_results), 3) def test_iter(self): - long_results = Results( - {"response": {"docs": [{"id": 1}, {"id": 2}, {"id": 3}], "numFound": 7}} - ) + long_results = Results({ + "response": { + "docs": [{ + "id": 1 + }, { + "id": 2 + }, { + "id": 3 + }], + "numFound": 7 + } + }) to_iter = list(long_results) self.assertEqual(to_iter[0], {"id": 1}) @@ -175,23 +212,47 @@ def setUp(self): super(SolrTestCase, self).setUp() self.solr = self.get_solr("core0") self.docs = [ - {"id": "doc_1", "title": "Example doc 1", "price": 12.59, "popularity": 10}, + { + "id": "doc_1", + "title": "Example doc 1", + "price": 12.59, + "popularity": 10 + }, { "id": "doc_2", "title": "Another example ☃ doc 2", "price": 13.69, "popularity": 7, }, - {"id": "doc_3", "title": "Another thing", "price": 2.35, "popularity": 8}, - {"id": "doc_4", "title": "doc rock", "price": 99.99, "popularity": 10}, - {"id": "doc_5", "title": "Boring", "price": 1.12, "popularity": 2}, + { + "id": "doc_3", + "title": "Another thing", + "price": 2.35, + "popularity": 8 + }, + { + "id": "doc_4", + "title": "doc rock", + "price": 99.99, + "popularity": 10 + }, + { + "id": "doc_5", + "title": "Boring", + "price": 1.12, + "popularity": 2 + }, # several with nested docs (not using fields that are used in # normal docs so that they don't interfere with their tests) { - "id": "parentdoc_1", - "type_s": "parent", - "name_t": "Parent no. 1", - "pages_i": 5, + "id": + "parentdoc_1", + "type_s": + "parent", + "name_t": + "Parent no. 1", + "pages_i": + 5, NESTED_DOC_KEY: [ { "id": "childdoc_1", @@ -208,26 +269,30 @@ def setUp(self): ], }, { - "id": "parentdoc_2", - "type_s": "parent", - "name_t": "Parent no. 2", - "pages_i": 500, - NESTED_DOC_KEY: [ - { - "id": "childdoc_3", - "type_s": "child", - "name_t": "Child of another parent", - "comment_t": "Yello", - NESTED_DOC_KEY: [ - { - "id": "grandchilddoc_1", - "type_s": "grandchild", - "name_t": "Grand child of parent", - "comment_t": "Blah", - } - ], - } - ], + "id": + "parentdoc_2", + "type_s": + "parent", + "name_t": + "Parent no. 2", + "pages_i": + 500, + NESTED_DOC_KEY: [{ + "id": + "childdoc_3", + "type_s": + "child", + "name_t": + "Child of another parent", + "comment_t": + "Yello", + NESTED_DOC_KEY: [{ + "id": "grandchilddoc_1", + "type_s": "grandchild", + "name_t": "Grand child of parent", + "comment_t": "Blah", + }], + }], }, ] @@ -250,8 +315,7 @@ def assertURLStartsWith(self, URL, path): # Note that we do not use urljoin to ensure that any changes in trailing # slash handling are caught quickly: return self.assertEqual( - URL, "%s/%s" % (self.solr.url.replace("/core0", ""), path) - ) + URL, "%s/%s" % (self.solr.url.replace("/core0", ""), path)) def get_solr(self, collection, timeout=60, always_commit=False): return Solr( @@ -281,7 +345,8 @@ def test_custom_results_class(self): def test_cursor_traversal(self): solr = Solr("http://localhost:8983/solr/core0") - expected = solr.search(q="*:*", rows=len(self.docs) * 3, sort="id asc").docs + expected = solr.search(q="*:*", rows=len(self.docs) * 3, + sort="id asc").docs results = solr.search(q="*:*", cursorMark="*", rows=2, sort="id asc") all_docs = list(results) self.assertEqual(len(expected), len(all_docs)) @@ -293,14 +358,15 @@ def test__create_full_url_base(self): def test__create_full_url_with_path(self): self.assertURLStartsWith( - self.solr._create_full_url(path="pysolr_tests"), "core0/pysolr_tests" - ) + self.solr._create_full_url(path="pysolr_tests"), + "core0/pysolr_tests") def test__create_full_url_with_path_and_querystring(self): # Note the use of a querystring parameter including a trailing slash to # catch sloppy trimming: self.assertURLStartsWith( - self.solr._create_full_url(path="/pysolr_tests/select/?whatever=/"), + self.solr._create_full_url( + path="/pysolr_tests/select/?whatever=/"), "core0/pysolr_tests/select/?whatever=/", ) @@ -332,16 +398,14 @@ def test__send_request(self): def test__send_request_to_bad_path(self): # Test a non-existent URL: self.solr.url = "http://127.0.0.1:56789/wahtever" - self.assertRaises( - SolrError, self.solr._send_request, "get", "select/?q=doc&wt=json" - ) + self.assertRaises(SolrError, self.solr._send_request, "get", + "select/?q=doc&wt=json") def test_send_request_to_bad_core(self): # Test a bad core on a valid URL: self.solr.url = "http://localhost:8983/solr/bad_core" - self.assertRaises( - SolrError, self.solr._send_request, "get", "select/?q=doc&wt=json" - ) + self.assertRaises(SolrError, self.solr._send_request, "get", + "select/?q=doc&wt=json") def test__select(self): # Short params. @@ -353,12 +417,17 @@ def test__select(self): resp_body = self.solr._select({"q": "doc" * 1024}) resp_data = json.loads(resp_body) self.assertEqual(resp_data["response"]["numFound"], 0) - self.assertEqual(len(resp_data["responseHeader"]["params"]["q"]), 3 * 1024) + self.assertEqual(len(resp_data["responseHeader"]["params"]["q"]), + 3 * 1024) # Test Deep Pagination CursorMark - resp_body = self.solr._select( - {"q": "*", "cursorMark": "*", "sort": "id desc", "start": 0, "rows": 2} - ) + resp_body = self.solr._select({ + "q": "*", + "cursorMark": "*", + "sort": "id desc", + "start": 0, + "rows": 2 + }) resp_data = json.loads(resp_body) self.assertEqual(len(resp_data["response"]["docs"]), 2) self.assertIn("nextCursorMark", resp_data) @@ -398,37 +467,34 @@ def json(self): return json.loads(self.content) # Just the reason. - resp_1 = RubbishResponse("We don't care.", {"reason": "Something went wrong."}) - self.assertEqual( - self.solr._extract_error(resp_1), "[Reason: Something went wrong.]" - ) + resp_1 = RubbishResponse("We don't care.", + {"reason": "Something went wrong."}) + self.assertEqual(self.solr._extract_error(resp_1), + "[Reason: Something went wrong.]") # Empty reason. resp_2 = RubbishResponse("We don't care.", {"reason": None}) - self.assertEqual( - self.solr._extract_error(resp_2), "[Reason: None]\nWe don't care." - ) + self.assertEqual(self.solr._extract_error(resp_2), + "[Reason: None]\nWe don't care.") # No reason. Time to scrape. resp_3 = RubbishResponse( "
Something is broke.
", {"server": "jetty"}, ) - self.assertEqual( - self.solr._extract_error(resp_3), "[Reason: Something is broke.]" - ) + self.assertEqual(self.solr._extract_error(resp_3), + "[Reason: Something is broke.]") # No reason. JSON response. - resp_4 = RubbishResponse( - b'\n {"error": {"msg": "It happens"}}', {"server": "tomcat"} - ) - self.assertEqual(self.solr._extract_error(resp_4), "[Reason: It happens]") + resp_4 = RubbishResponse(b'\n {"error": {"msg": "It happens"}}', + {"server": "tomcat"}) + self.assertEqual(self.solr._extract_error(resp_4), + "[Reason: It happens]") # No reason. Weird JSON response. resp_5 = RubbishResponse(b'{"kinda": "weird"}', {"server": "jetty"}) - self.assertEqual( - self.solr._extract_error(resp_5), '[Reason: None]\n{"kinda": "weird"}' - ) + self.assertEqual(self.solr._extract_error(resp_5), + '[Reason: None]\n{"kinda": "weird"}') def test__scrape_response(self): # Jetty. @@ -485,7 +551,8 @@ def test__scrape_response_tomcat(self): # Invalid XML bogus_xml = '\n\n4000Invalid Date String:\'2015-03-23 10:43:33\'400' # NOQA: E501 - reason, full_html = self.solr._scrape_response({"server": "coyote"}, bogus_xml) + reason, full_html = self.solr._scrape_response({"server": "coyote"}, + bogus_xml) self.assertIsNone(reason, None) self.assertEqual(full_html, bogus_xml.replace("\n", "")) @@ -499,9 +566,8 @@ def test__from_python(self): self.assertEqual(self.solr._from_python("\x01test\x02"), "test") def test__from_python_dates(self): - self.assertEqual( - self.solr._from_python(datetime.date(2013, 1, 18)), "2013-01-18T00:00:00Z" - ) + self.assertEqual(self.solr._from_python(datetime.date(2013, 1, 18)), + "2013-01-18T00:00:00Z") self.assertEqual( self.solr._from_python(datetime.datetime(2013, 1, 18, 0, 30, 28)), "2013-01-18T00:30:28Z", @@ -519,8 +585,13 @@ def dst(self): # Check a UTC timestamp self.assertEqual( self.solr._from_python( - datetime.datetime(2013, 1, 18, 0, 30, 28, tzinfo=FakeTimeZone()) - ), + datetime.datetime(2013, + 1, + 18, + 0, + 30, + 28, + tzinfo=FakeTimeZone())), "2013-01-18T00:30:28Z", ) @@ -528,15 +599,19 @@ def dst(self): FakeTimeZone.offset = -(5 * 60) self.assertEqual( self.solr._from_python( - datetime.datetime(2013, 1, 18, 0, 30, 28, tzinfo=FakeTimeZone()) - ), + datetime.datetime(2013, + 1, + 18, + 0, + 30, + 28, + tzinfo=FakeTimeZone())), "2013-01-18T05:30:28Z", ) def test__to_python(self): - self.assertEqual( - self.solr._to_python("2013-01-18T00:00:00Z"), datetime.datetime(2013, 1, 18) - ) + self.assertEqual(self.solr._to_python("2013-01-18T00:00:00Z"), + datetime.datetime(2013, 1, 18)) self.assertEqual( self.solr._to_python("2013-01-18T00:30:28Z"), datetime.datetime(2013, 1, 18, 0, 30, 28), @@ -549,9 +624,8 @@ def test__to_python(self): self.assertEqual(self.solr._to_python("hello ☃"), "hello ☃") self.assertEqual(self.solr._to_python(["foo", "bar"]), "foo") self.assertEqual(self.solr._to_python(("foo", "bar")), "foo") - self.assertEqual( - self.solr._to_python('tuple("foo", "bar")'), 'tuple("foo", "bar")' - ) + self.assertEqual(self.solr._to_python('tuple("foo", "bar")'), + 'tuple("foo", "bar")') def test__is_null_value(self): self.assertTrue(self.solr._is_null_value(None)) @@ -575,8 +649,7 @@ def test_search(self): # Advanced options. results = self.solr.search( - "doc", - **{ + "doc", **{ "debug": "true", "hl": "true", "hl.fragsize": 8, @@ -585,11 +658,14 @@ def test_search(self): "spellcheck": "true", "spellcheck.collate": "true", "spellcheck.count": 1, - } - ) + }) self.assertEqual(len(results), 3) self.assertIn("explain", results.debug) - self.assertEqual(results.highlighting, {"doc_4": {}, "doc_2": {}, "doc_1": {}}) + self.assertEqual(results.highlighting, { + "doc_4": {}, + "doc_2": {}, + "doc_1": {} + }) self.assertEqual(results.spellcheck, {}) self.assertEqual( results.facets["facet_fields"]["popularity"], @@ -598,10 +674,12 @@ def test_search(self): self.assertIsNotNone(results.qtime) # Nested search #1: find parent where child's comment has 'hello' - results = self.solr.search("{!parent which=type_s:parent}comment_t:hello") + results = self.solr.search( + "{!parent which=type_s:parent}comment_t:hello") self.assertEqual(len(results), 1) # Nested search #2: find child with a child - results = self.solr.search("{!parent which=type_s:child}comment_t:blah") + results = self.solr.search( + "{!parent which=type_s:child}comment_t:blah") self.assertEqual(len(results), 1) def test_multiple_search_handlers(self): @@ -670,17 +748,21 @@ def test__build_xml_doc(self): "popularity": 10, } doc_xml = force_unicode( - ElementTree.tostring(self.solr._build_xml_doc(doc), encoding="utf-8") - ) + ElementTree.tostring(self.solr._build_xml_doc(doc), + encoding="utf-8")) self.assertIn('Example doc ☃ 1', doc_xml) self.assertIn('doc_1', doc_xml) self.assertEqual(len(doc_xml), 152) def test__build_xml_doc_with_sets(self): - doc = {"id": "doc_1", "title": "Set test doc", "tags": {"alpha", "beta"}} + doc = { + "id": "doc_1", + "title": "Set test doc", + "tags": {"alpha", "beta"} + } doc_xml = force_unicode( - ElementTree.tostring(self.solr._build_xml_doc(doc), encoding="utf-8") - ) + ElementTree.tostring(self.solr._build_xml_doc(doc), + encoding="utf-8")) self.assertIn('doc_1', doc_xml) self.assertIn('Set test doc', doc_xml) self.assertIn('alpha', doc_xml) @@ -715,8 +797,10 @@ def test__build_xml_doc_with_sub_docs(self): children_docs = doc_xml.findall("doc") self.assertEqual(len(children_docs), len(sub_docs)) - self.assertEqual(children_docs[0].find("*[@name='id']").text, sub_docs[0]["id"]) - self.assertEqual(children_docs[1].find("*[@name='id']").text, sub_docs[1]["id"]) + self.assertEqual(children_docs[0].find("*[@name='id']").text, + sub_docs[0]["id"]) + self.assertEqual(children_docs[1].find("*[@name='id']").text, + sub_docs[1]["id"]) def test__build_xml_doc_with_empty_values(self): doc = { @@ -726,8 +810,8 @@ def test__build_xml_doc_with_empty_values(self): "tags": [], } doc_xml = force_unicode( - ElementTree.tostring(self.solr._build_xml_doc(doc), encoding="utf-8") - ) + ElementTree.tostring(self.solr._build_xml_doc(doc), + encoding="utf-8")) self.assertNotIn('', doc_xml) self.assertNotIn('', doc_xml) self.assertNotIn('', doc_xml) @@ -749,11 +833,12 @@ def test__build_xml_doc_with_empty_values_and_field_updates(self): ElementTree.tostring( self.solr._build_xml_doc(doc, fieldUpdates=fieldUpdates), encoding="utf-8", - ) - ) - self.assertIn('', doc_xml) + )) + self.assertIn('', + doc_xml) self.assertNotIn('', doc_xml) - self.assertIn('', doc_xml) + self.assertIn('', + doc_xml) self.assertIn('doc_1', doc_xml) self.assertEqual(len(doc_xml), 134) @@ -771,8 +856,14 @@ def test_add(self): self.solr.add( [ - {"id": "doc_6", "title": "Newly added doc"}, - {"id": "doc_7", "title": "Another example doc"}, + { + "id": "doc_6", + "title": "Newly added doc" + }, + { + "id": "doc_7", + "title": "Another example doc" + }, ], commit=True, ) @@ -792,13 +883,18 @@ def test_add(self): def test_add_with_boost(self): self.assertEqual(len(self.solr.search("doc")), 3) - self.solr.add( - [{"id": "doc_6", "title": "Important doc"}], boost={"title": 10.0} - ) + self.solr.add([{ + "id": "doc_6", + "title": "Important doc" + }], + boost={"title": 10.0}) - self.solr.add( - [{"id": "doc_7", "title": "Spam doc doc"}], boost={"title": 0}, commit=True - ) + self.solr.add([{ + "id": "doc_7", + "title": "Spam doc doc" + }], + boost={"title": 0}, + commit=True) res = self.solr.search("doc") self.assertEqual(len(res), 5) @@ -810,21 +906,20 @@ def test_field_update_inc(self): updateList = [] for doc in originalDocs: updateList.append({"id": doc["id"], "popularity": 5}) - self.solr.add(updateList, fieldUpdates={"popularity": "inc"}, commit=True) + self.solr.add(updateList, + fieldUpdates={"popularity": "inc"}, + commit=True) updatedDocs = self.solr.search("doc") self.assertEqual(len(updatedDocs), 3) for (originalDoc, updatedDoc) in zip(originalDocs, updatedDocs): self.assertEqual(len(updatedDoc.keys()), len(originalDoc.keys())) - self.assertEqual(updatedDoc["popularity"], originalDoc["popularity"] + 5) + self.assertEqual(updatedDoc["popularity"], + originalDoc["popularity"] + 5) # TODO: change this to use assertSetEqual: self.assertTrue( - all( - updatedDoc[k] == originalDoc[k] - for k in updatedDoc.keys() - if k not in ["_version_", "popularity"] - ) - ) + all(updatedDoc[k] == originalDoc[k] for k in updatedDoc.keys() + if k not in ["_version_", "popularity"])) def test_field_update_set(self): originalDocs = self.solr.search("doc") @@ -832,8 +927,13 @@ def test_field_update_set(self): self.assertEqual(len(originalDocs), 3) updateList = [] for doc in originalDocs: - updateList.append({"id": doc["id"], "popularity": updated_popularity}) - self.solr.add(updateList, fieldUpdates={"popularity": "set"}, commit=True) + updateList.append({ + "id": doc["id"], + "popularity": updated_popularity + }) + self.solr.add(updateList, + fieldUpdates={"popularity": "set"}, + commit=True) updatedDocs = self.solr.search("doc") self.assertEqual(len(updatedDocs), 3) @@ -842,12 +942,8 @@ def test_field_update_set(self): self.assertEqual(updatedDoc["popularity"], updated_popularity) # TODO: change this to use assertSetEqual: self.assertTrue( - all( - updatedDoc[k] == originalDoc[k] - for k in updatedDoc.keys() - if k not in ["_version_", "popularity"] - ) - ) + all(updatedDoc[k] == originalDoc[k] for k in updatedDoc.keys() + if k not in ["_version_", "popularity"])) def test_field_update_add(self): self.solr.add( @@ -870,24 +966,22 @@ def test_field_update_add(self): self.assertEqual(len(originalDocs), 2) updateList = [] for doc in originalDocs: - updateList.append({"id": doc["id"], "word_ss": ["epsilon", "gamma"]}) + updateList.append({ + "id": doc["id"], + "word_ss": ["epsilon", "gamma"] + }) self.solr.add(updateList, fieldUpdates={"word_ss": "add"}, commit=True) updatedDocs = self.solr.search("multivalued") self.assertEqual(len(updatedDocs), 2) for (originalDoc, updatedDoc) in zip(originalDocs, updatedDocs): self.assertEqual(len(updatedDoc.keys()), len(originalDoc.keys())) - self.assertEqual( - updatedDoc["word_ss"], originalDoc["word_ss"] + ["epsilon", "gamma"] - ) + self.assertEqual(updatedDoc["word_ss"], + originalDoc["word_ss"] + ["epsilon", "gamma"]) # TODO: change this to use assertSetEqual: self.assertTrue( - all( - updatedDoc[k] == originalDoc[k] - for k in updatedDoc.keys() - if k not in ["_version_", "word_ss"] - ) - ) + all(updatedDoc[k] == originalDoc[k] for k in updatedDoc.keys() + if k not in ["_version_", "word_ss"])) def test_delete(self): self.assertEqual(len(self.solr.search("doc")), 3) @@ -929,7 +1023,8 @@ def leaf_doc(doc): self.assertEqual(len(self.solr.search(leaf_q)), len(to_delete_docs)) # Extract a random doc from the list, to later check it wasn't deleted. graced_doc_id = to_delete_ids.pop( - random.randint(0, len(to_delete_ids) - 1) # NOQA: B311 + random.randint(0, + len(to_delete_ids) - 1) # NOQA: B311 ) self.solr.delete(id=to_delete_ids, commit=True) # There should be only one left, our graced id @@ -969,7 +1064,11 @@ def test_can_handles_default_commit_policy(self): commit_arg = [False, True, None] for expected_commit, arg in zip(expected_commits, commit_arg): - self.solr.add([{"id": "doc_6", "title": "Newly added doc"}], commit=arg) + self.solr.add([{ + "id": "doc_6", + "title": "Newly added doc" + }], + commit=arg) args, _ = self.solr._send_request.call_args committing_in_url = "commit" in args[1] self.assertEqual(expected_commit, committing_in_url) @@ -978,8 +1077,14 @@ def test_overwrite(self): self.assertEqual(len(self.solr.search("id:doc_overwrite_1")), 0) self.solr.add( [ - {"id": "doc_overwrite_1", "title": "Kim is awesome."}, - {"id": "doc_overwrite_1", "title": "Kim is more awesome."}, + { + "id": "doc_overwrite_1", + "title": "Kim is awesome." + }, + { + "id": "doc_overwrite_1", + "title": "Kim is more awesome." + }, ], overwrite=False, commit=True, @@ -995,7 +1100,11 @@ def test_overwrite(self): def test_optimize(self): # Make sure it doesn't blow up. Side effects are hard to measure. :/ self.assertEqual(len(self.solr.search("doc")), 3) - self.solr.add([{"id": "doc_6", "title": "Newly added doc"}], commit=False) + self.solr.add([{ + "id": "doc_6", + "title": "Newly added doc" + }], + commit=False) self.assertEqual(len(self.solr.search("doc")), 3) self.solr.optimize() # optimize should default to 'update' handler @@ -1010,8 +1119,7 @@ def test_optimize(self): self.assertTrue(args[1].startswith("fakehandler")) def test_extract(self): - fake_f = StringIO( - """ + fake_f = StringIO(""" @@ -1020,8 +1128,7 @@ def test_extract(self): foobar - """ - ) + """) fake_f.name = "test.html" extracted = self.solr.extract(fake_f) # extract should default to 'update/extract' handler @@ -1044,7 +1151,8 @@ def test_extract(self): self.assertEqual([fake_f.name], m["stream_name"]) - self.assertIn("haystack-test", m, "HTML metadata should have been extracted!") + self.assertIn("haystack-test", m, + "HTML metadata should have been extracted!") self.assertEqual(["test 1234"], m["haystack-test"]) # Note the underhanded use of a double snowman to verify both that Tika @@ -1053,8 +1161,7 @@ def test_extract(self): self.assertEqual(["Test Title ☃☃"], m["title"]) def test_extract_special_char_in_filename(self): - fake_f = StringIO( - """ + fake_f = StringIO(""" @@ -1063,8 +1170,7 @@ def test_extract_special_char_in_filename(self): foobar - """ - ) + """) fake_f.name = "test☃.html" extracted = self.solr.extract(fake_f) # extract should default to 'update/extract' handler @@ -1085,9 +1191,11 @@ def test_extract_special_char_in_filename(self): m = extracted["metadata"] - self.assertEqual([quote(fake_f.name.encode("utf-8"))], m["stream_name"]) + self.assertEqual([quote(fake_f.name.encode("utf-8"))], + m["stream_name"]) - self.assertIn("haystack-test", m, "HTML metadata should have been extracted!") + self.assertIn("haystack-test", m, + "HTML metadata should have been extracted!") self.assertEqual(["test 1234"], m["haystack-test"]) # Note the underhanded use of a double snowman to verify both that Tika @@ -1147,8 +1255,14 @@ def setUp(self): super(SolrCommitByDefaultTestCase, self).setUp() self.solr = self.get_solr("core0", always_commit=True) self.docs = [ - {"id": "doc_1", "title": "Newly added doc"}, - {"id": "doc_2", "title": "Another example doc"}, + { + "id": "doc_1", + "title": "Newly added doc" + }, + { + "id": "doc_2", + "title": "Another example doc" + }, ] def test_does_not_require_commit(self):