@@ -95,6 +95,16 @@ def _get_entities(include_entities):
9595 return entities
9696
9797
98+ def make_mock_client (response ):
99+ import mock
100+ from google .cloud .language .connection import Connection
101+ from google .cloud .language .client import Client
102+
103+ connection = mock .Mock (spec = Connection )
104+ connection .api_request .return_value = response
105+ return mock .Mock (_connection = connection , spec = Client )
106+
107+
98108class TestDocument (unittest .TestCase ):
99109
100110 @staticmethod
@@ -187,7 +197,36 @@ def _verify_entity(self, entity, name, entity_type, wiki_url, salience):
187197 self .assertEqual (entity .salience , salience )
188198 self .assertEqual (entity .mentions , [name ])
189199
200+ @staticmethod
201+ def _expected_data (content , encoding_type = None ,
202+ extract_sentiment = False ,
203+ extract_entities = False ,
204+ extract_syntax = False ):
205+ from google .cloud .language .document import DEFAULT_LANGUAGE
206+ from google .cloud .language .document import Document
207+
208+ expected = {
209+ 'document' : {
210+ 'language' : DEFAULT_LANGUAGE ,
211+ 'type' : Document .PLAIN_TEXT ,
212+ 'content' : content ,
213+ },
214+ }
215+ if encoding_type is not None :
216+ expected ['encodingType' ] = encoding_type
217+ if extract_sentiment :
218+ features = expected .setdefault ('features' , {})
219+ features ['extractDocumentSentiment' ] = True
220+ if extract_entities :
221+ features = expected .setdefault ('features' , {})
222+ features ['extractEntities' ] = True
223+ if extract_syntax :
224+ features = expected .setdefault ('features' , {})
225+ features ['extractSyntax' ] = True
226+ return expected
227+
190228 def test_analyze_entities (self ):
229+ from google .cloud .language .document import Encoding
191230 from google .cloud .language .entity import EntityType
192231
193232 name1 = 'R-O-C-K'
@@ -229,8 +268,7 @@ def test_analyze_entities(self):
229268 ],
230269 'language' : 'en-US' ,
231270 }
232- connection = _Connection (response )
233- client = _Client (connection = connection )
271+ client = make_mock_client (response )
234272 document = self ._make_one (client , content )
235273
236274 entities = document .analyze_entities ()
@@ -243,10 +281,10 @@ def test_analyze_entities(self):
243281 wiki2 , salience2 )
244282
245283 # Verify the request.
246- self .assertEqual ( len ( connection . _requested ), 1 )
247- req = connection . _requested [ 0 ]
248- self . assertEqual ( req [ 'path' ], 'analyzeEntities' )
249- self . assertEqual ( req [ 'method' ], 'POST' )
284+ expected = self ._expected_data (
285+ content , encoding_type = Encoding . UTF8 )
286+ client . _connection . api_request . assert_called_once_with (
287+ path = 'analyzeEntities' , method = 'POST' , data = expected )
250288
251289 def _verify_sentiment (self , sentiment , polarity , magnitude ):
252290 from google .cloud .language .sentiment import Sentiment
@@ -266,18 +304,16 @@ def test_analyze_sentiment(self):
266304 },
267305 'language' : 'en-US' ,
268306 }
269- connection = _Connection (response )
270- client = _Client (connection = connection )
307+ client = make_mock_client (response )
271308 document = self ._make_one (client , content )
272309
273310 sentiment = document .analyze_sentiment ()
274311 self ._verify_sentiment (sentiment , polarity , magnitude )
275312
276313 # Verify the request.
277- self .assertEqual (len (connection ._requested ), 1 )
278- req = connection ._requested [0 ]
279- self .assertEqual (req ['path' ], 'analyzeSentiment' )
280- self .assertEqual (req ['method' ], 'POST' )
314+ expected = self ._expected_data (content )
315+ client ._connection .api_request .assert_called_once_with (
316+ path = 'analyzeSentiment' , method = 'POST' , data = expected )
281317
282318 def _verify_sentences (self , include_syntax , annotations ):
283319 from google .cloud .language .syntax import Sentence
@@ -307,6 +343,7 @@ def _verify_tokens(self, annotations, token_info):
307343 def _annotate_text_helper (self , include_sentiment ,
308344 include_entities , include_syntax ):
309345 from google .cloud .language .document import Annotations
346+ from google .cloud .language .document import Encoding
310347 from google .cloud .language .entity import EntityType
311348
312349 token_info , sentences = _get_token_and_sentences (include_syntax )
@@ -324,8 +361,7 @@ def _annotate_text_helper(self, include_sentiment,
324361 'magnitude' : ANNOTATE_MAGNITUDE ,
325362 }
326363
327- connection = _Connection (response )
328- client = _Client (connection = connection )
364+ client = make_mock_client (response )
329365 document = self ._make_one (client , ANNOTATE_CONTENT )
330366
331367 annotations = document .annotate_text (
@@ -352,16 +388,13 @@ def _annotate_text_helper(self, include_sentiment,
352388 self .assertEqual (annotations .entities , [])
353389
354390 # Verify the request.
355- self .assertEqual (len (connection ._requested ), 1 )
356- req = connection ._requested [0 ]
357- self .assertEqual (req ['path' ], 'annotateText' )
358- self .assertEqual (req ['method' ], 'POST' )
359- features = req ['data' ]['features' ]
360- self .assertEqual (features .get ('extractDocumentSentiment' , False ),
361- include_sentiment )
362- self .assertEqual (features .get ('extractEntities' , False ),
363- include_entities )
364- self .assertEqual (features .get ('extractSyntax' , False ), include_syntax )
391+ expected = self ._expected_data (
392+ ANNOTATE_CONTENT , encoding_type = Encoding .UTF8 ,
393+ extract_sentiment = include_sentiment ,
394+ extract_entities = include_entities ,
395+ extract_syntax = include_syntax )
396+ client ._connection .api_request .assert_called_once_with (
397+ path = 'annotateText' , method = 'POST' , data = expected )
365398
366399 def test_annotate_text (self ):
367400 self ._annotate_text_helper (True , True , True )
@@ -374,20 +407,3 @@ def test_annotate_text_entities_only(self):
374407
375408 def test_annotate_text_syntax_only (self ):
376409 self ._annotate_text_helper (False , False , True )
377-
378-
379- class _Connection (object ):
380-
381- def __init__ (self , response ):
382- self ._response = response
383- self ._requested = []
384-
385- def api_request (self , ** kwargs ):
386- self ._requested .append (kwargs )
387- return self ._response
388-
389-
390- class _Client (object ):
391-
392- def __init__ (self , connection = None ):
393- self ._connection = connection
0 commit comments