Skip to content

Commit 28920b5

Browse files
philkrahonzakral
authored andcommitted
Ensure a custom User-Agent header is not overwritten (elastic#992)
1 parent 100a2e6 commit 28920b5

File tree

6 files changed

+52
-11
lines changed

6 files changed

+52
-11
lines changed

elasticsearch/connection/base.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
import logging
22

3+
from platform import python_version
4+
35
try:
46
import simplejson as json
57
except ImportError:
68
import json
79

810
from ..exceptions import TransportError, HTTP_EXCEPTIONS
11+
from .. import __versionstr__
912

1013
logger = logging.getLogger("elasticsearch")
1114

@@ -177,3 +180,6 @@ def _raise_error(self, status_code, raw_data):
177180
raise HTTP_EXCEPTIONS.get(status_code, TransportError)(
178181
status_code, error_message, additional_info
179182
)
183+
184+
def _get_default_user_agent(self):
185+
return "elasticsearch-py/%s (Python %s)" % (__versionstr__, python_version())

elasticsearch/connection/http_requests.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ def __init__(
7373
self.session = requests.Session()
7474
self.session.headers = headers or {}
7575
self.session.headers.setdefault("content-type", "application/json")
76+
self.session.headers.setdefault("user-agent", self._get_default_user_agent())
7677
if http_auth is not None:
7778
if isinstance(http_auth, (tuple, list)):
7879
http_auth = tuple(http_auth)

elasticsearch/connection/http_urllib3.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ def __init__(
127127
self.headers.update({"content-encoding": "gzip"})
128128

129129
self.headers.setdefault("content-type", "application/json")
130+
self.headers.setdefault("user-agent", self._get_default_user_agent())
130131
pool_class = urllib3.HTTPConnectionPool
131132
kw = {}
132133

elasticsearch/transport.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -336,11 +336,6 @@ def perform_request(self, method, url, headers=None, params=None, body=None):
336336
ignore = params.pop("ignore", ())
337337
if isinstance(ignore, int):
338338
ignore = (ignore,)
339-
340-
if headers is None:
341-
headers = {}
342-
headers["user-agent"] = "elasticsearch-py/%s (Python %s)" % (__versionstr__, python_version())
343-
344339
for attempt in range(self.max_retries + 1):
345340
connection = self.get_connection()
346341

test_elasticsearch/test_connection.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import urllib3
66
import warnings
77
from requests.auth import AuthBase
8+
from platform import python_version
89

910
from elasticsearch.exceptions import (
1011
TransportError,
@@ -14,9 +15,9 @@
1415
)
1516
from elasticsearch.connection import RequestsHttpConnection, Urllib3HttpConnection
1617
from elasticsearch.exceptions import ImproperlyConfigured
18+
from elasticsearch import __versionstr__
1719
from .test_cases import TestCase, SkipTest
1820

19-
2021
class TestUrllib3Connection(TestCase):
2122
def test_ssl_context(self):
2223
try:
@@ -48,14 +49,22 @@ def test_http_compression(self):
4849
self.assertTrue(con.http_compress)
4950
self.assertEquals(con.headers["content-encoding"], "gzip")
5051

52+
def test_default_user_agent(self):
53+
con = Urllib3HttpConnection()
54+
self.assertEquals(con._get_default_user_agent(), "elasticsearch-py/%s (Python %s)" % (__versionstr__, python_version()))
55+
5156
def test_timeout_set(self):
5257
con = Urllib3HttpConnection(timeout=42)
5358
self.assertEquals(42, con.timeout)
5459

5560
def test_keep_alive_is_on_by_default(self):
5661
con = Urllib3HttpConnection()
5762
self.assertEquals(
58-
{"connection": "keep-alive", "content-type": "application/json"},
63+
{
64+
"connection": "keep-alive",
65+
"content-type": "application/json",
66+
"user-agent": con._get_default_user_agent(),
67+
},
5968
con.headers,
6069
)
6170

@@ -66,6 +75,7 @@ def test_http_auth(self):
6675
"authorization": "Basic dXNlcm5hbWU6c2VjcmV0",
6776
"connection": "keep-alive",
6877
"content-type": "application/json",
78+
"user-agent": con._get_default_user_agent(),
6979
},
7080
con.headers,
7181
)
@@ -77,6 +87,7 @@ def test_http_auth_tuple(self):
7787
"authorization": "Basic dXNlcm5hbWU6c2VjcmV0",
7888
"content-type": "application/json",
7989
"connection": "keep-alive",
90+
"user-agent": con._get_default_user_agent(),
8091
},
8192
con.headers,
8293
)
@@ -88,6 +99,7 @@ def test_http_auth_list(self):
8899
"authorization": "Basic dXNlcm5hbWU6c2VjcmV0",
89100
"content-type": "application/json",
90101
"connection": "keep-alive",
102+
"user-agent": con._get_default_user_agent(),
91103
},
92104
con.headers,
93105
)
@@ -213,6 +225,21 @@ def test_merge_headers(self):
213225
self.assertEquals(req.headers["h2"], "v2p")
214226
self.assertEquals(req.headers["h3"], "v3")
215227

228+
def test_default_headers(self):
229+
con = self._get_mock_connection()
230+
req = self._get_request(con, "GET", "/")
231+
self.assertEquals(req.headers["content-type"], "application/json")
232+
self.assertEquals(req.headers["user-agent"], con._get_default_user_agent())
233+
234+
def test_custom_headers(self):
235+
con = self._get_mock_connection()
236+
req = self._get_request(con, "GET", "/", headers={
237+
"content-type": "application/x-ndjson",
238+
"user-agent": "custom-agent/1.2.3",
239+
})
240+
self.assertEquals(req.headers["content-type"], "application/x-ndjson")
241+
self.assertEquals(req.headers["user-agent"], "custom-agent/1.2.3")
242+
216243
def test_http_auth(self):
217244
con = RequestsHttpConnection(http_auth="username:secret")
218245
self.assertEquals(("username", "secret"), con.session.auth)

test_elasticsearch/test_transport.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
11
# -*- coding: utf-8 -*-
22
from __future__ import unicode_literals
33
import time
4-
from platform import python_version
54

65
from elasticsearch.transport import Transport, get_host_info
76
from elasticsearch.connection import Connection
87
from elasticsearch.connection_pool import DummyConnectionPool
98
from elasticsearch.exceptions import ConnectionError, ImproperlyConfigured
10-
from elasticsearch import __versionstr__
119

1210
from .test_cases import TestCase
1311

@@ -84,8 +82,21 @@ def test_request_timeout_extracted_from_params_and_passed(self):
8482
self.assertEquals(1, len(t.get_connection().calls))
8583
self.assertEquals(("GET", "/", {}, None), t.get_connection().calls[0][0])
8684
self.assertEquals(
87-
{"timeout": 42, "ignore": (), "headers": {
88-
'user-agent':"elasticsearch-py/%s (Python %s)" % (__versionstr__, python_version())}
85+
{
86+
"timeout": 42,
87+
"ignore": (),
88+
"headers": None,
89+
},
90+
t.get_connection().calls[0][1],
91+
)
92+
93+
def test_request_with_custom_user_agent_header(self):
94+
t = Transport([{}], connection_class=DummyConnection)
95+
96+
t.perform_request("GET", "/", headers={"user-agent": "my-custom-value/1.2.3"})
97+
self.assertEquals(1, len(t.get_connection().calls))
98+
self.assertEquals(
99+
{"timeout": None, "ignore": (), "headers": {"user-agent": "my-custom-value/1.2.3"}
89100
},
90101
t.get_connection().calls[0][1],
91102
)

0 commit comments

Comments
 (0)