@@ -492,11 +492,13 @@ def test_collection_add(client, cleanup):
492492 assert set (collection3 .list_documents ()) == {document_ref5 }
493493
494494
495- def test_query_stream (client , cleanup ):
495+ @pytest .fixture
496+ def query_docs (client ):
496497 collection_id = "qs" + UNIQUE_RESOURCE_ID
497498 sub_collection = "child" + UNIQUE_RESOURCE_ID
498499 collection = client .collection (collection_id , "doc" , sub_collection )
499500
501+ cleanup = []
500502 stored = {}
501503 num_vals = 5
502504 allowed_vals = six .moves .xrange (num_vals )
@@ -505,38 +507,82 @@ def test_query_stream(client, cleanup):
505507 document_data = {
506508 "a" : a_val ,
507509 "b" : b_val ,
510+ "c" : [a_val , num_vals * 100 ],
508511 "stats" : {"sum" : a_val + b_val , "product" : a_val * b_val },
509512 }
510513 _ , doc_ref = collection .add (document_data )
511514 # Add to clean-up.
512- cleanup (doc_ref .delete )
515+ cleanup . append (doc_ref .delete )
513516 stored [doc_ref .id ] = document_data
514517
515- # 0. Limit to snapshots where ``a==1``.
516- query0 = collection .where ("a" , "==" , 1 )
517- values0 = {snapshot .id : snapshot .to_dict () for snapshot in query0 .stream ()}
518- assert len (values0 ) == num_vals
519- for key , value in six .iteritems (values0 ):
518+ yield collection , stored , allowed_vals
519+
520+ for operation in cleanup :
521+ operation ()
522+
523+
524+ def test_query_stream_w_simple_field_eq_op (query_docs ):
525+ collection , stored , allowed_vals = query_docs
526+ query = collection .where ("a" , "==" , 1 )
527+ values = {snapshot .id : snapshot .to_dict () for snapshot in query .stream ()}
528+ assert len (values ) == len (allowed_vals )
529+ for key , value in six .iteritems (values ):
530+ assert stored [key ] == value
531+ assert value ["a" ] == 1
532+
533+
534+ def test_query_stream_w_simple_field_array_contains_op (query_docs ):
535+ collection , stored , allowed_vals = query_docs
536+ query = collection .where ("c" , "array_contains" , 1 )
537+ values = {snapshot .id : snapshot .to_dict () for snapshot in query .stream ()}
538+ assert len (values ) == len (allowed_vals )
539+ for key , value in six .iteritems (values ):
540+ assert stored [key ] == value
541+ assert value ["a" ] == 1
542+
543+
544+ def test_query_stream_w_simple_field_in_op (query_docs ):
545+ collection , stored , allowed_vals = query_docs
546+ num_vals = len (allowed_vals )
547+ query = collection .where ("a" , "in" , [1 , num_vals + 100 ])
548+ values = {snapshot .id : snapshot .to_dict () for snapshot in query .stream ()}
549+ assert len (values ) == len (allowed_vals )
550+ for key , value in six .iteritems (values ):
520551 assert stored [key ] == value
521552 assert value ["a" ] == 1
522553
523- # 1. Order by ``b``.
524- query1 = collection .order_by ("b" , direction = query0 .DESCENDING )
525- values1 = [(snapshot .id , snapshot .to_dict ()) for snapshot in query1 .stream ()]
526- assert len (values1 ) == len (stored )
527- b_vals1 = []
528- for key , value in values1 :
554+
555+ def test_query_stream_w_simple_field_array_contains_any_op (query_docs ):
556+ collection , stored , allowed_vals = query_docs
557+ num_vals = len (allowed_vals )
558+ query = collection .where ("c" , "array_contains_any" , [1 , num_vals * 200 ])
559+ values = {snapshot .id : snapshot .to_dict () for snapshot in query .stream ()}
560+ assert len (values ) == len (allowed_vals )
561+ for key , value in six .iteritems (values ):
529562 assert stored [key ] == value
530- b_vals1 .append (value ["b" ])
563+ assert value ["a" ] == 1
564+
565+
566+ def test_query_stream_w_order_by (query_docs ):
567+ collection , stored , allowed_vals = query_docs
568+ query = collection .order_by ("b" , direction = firestore .Query .DESCENDING )
569+ values = [(snapshot .id , snapshot .to_dict ()) for snapshot in query .stream ()]
570+ assert len (values ) == len (stored )
571+ b_vals = []
572+ for key , value in values :
573+ assert stored [key ] == value
574+ b_vals .append (value ["b" ])
531575 # Make sure the ``b``-values are in DESCENDING order.
532- assert sorted (b_vals1 , reverse = True ) == b_vals1
576+ assert sorted (b_vals , reverse = True ) == b_vals
577+
533578
534- # 2. Limit to snapshots where ``stats.sum > 1`` (a field path).
535- query2 = collection .where ("stats.sum" , ">" , 4 )
536- values2 = {snapshot .id : snapshot .to_dict () for snapshot in query2 .stream ()}
537- assert len (values2 ) == 10
579+ def test_query_stream_w_field_path (query_docs ):
580+ collection , stored , allowed_vals = query_docs
581+ query = collection .where ("stats.sum" , ">" , 4 )
582+ values = {snapshot .id : snapshot .to_dict () for snapshot in query .stream ()}
583+ assert len (values ) == 10
538584 ab_pairs2 = set ()
539- for key , value in six .iteritems (values2 ):
585+ for key , value in six .iteritems (values ):
540586 assert stored [key ] == value
541587 ab_pairs2 .add ((value ["a" ], value ["b" ]))
542588
@@ -550,63 +596,72 @@ def test_query_stream(client, cleanup):
550596 )
551597 assert expected_ab_pairs == ab_pairs2
552598
553- # 3. Use a start and end cursor.
554- query3 = (
599+
600+ def test_query_stream_w_start_end_cursor (query_docs ):
601+ collection , stored , allowed_vals = query_docs
602+ num_vals = len (allowed_vals )
603+ query = (
555604 collection .order_by ("a" )
556605 .start_at ({"a" : num_vals - 2 })
557606 .end_before ({"a" : num_vals - 1 })
558607 )
559- values3 = [(snapshot .id , snapshot .to_dict ()) for snapshot in query3 .stream ()]
560- assert len (values3 ) == num_vals
561- for key , value in values3 :
608+ values = [(snapshot .id , snapshot .to_dict ()) for snapshot in query .stream ()]
609+ assert len (values ) == num_vals
610+ for key , value in values :
562611 assert stored [key ] == value
563612 assert value ["a" ] == num_vals - 2
564- b_vals1 .append (value ["b" ])
565-
566- # 4. Send a query with no results.
567- query4 = collection .where ("b" , "==" , num_vals + 100 )
568- values4 = list (query4 .stream ())
569- assert len (values4 ) == 0
570-
571- # 5. Select a subset of fields.
572- query5 = collection .where ("b" , "<=" , 1 )
573- query5 = query5 .select (["a" , "stats.product" ])
574- values5 = {snapshot .id : snapshot .to_dict () for snapshot in query5 .stream ()}
575- assert len (values5 ) == num_vals * 2 # a ANY, b in (0, 1)
576- for key , value in six .iteritems (values5 ):
613+
614+
615+ def test_query_stream_wo_results (query_docs ):
616+ collection , stored , allowed_vals = query_docs
617+ num_vals = len (allowed_vals )
618+ query = collection .where ("b" , "==" , num_vals + 100 )
619+ values = list (query .stream ())
620+ assert len (values ) == 0
621+
622+
623+ def test_query_stream_w_projection (query_docs ):
624+ collection , stored , allowed_vals = query_docs
625+ num_vals = len (allowed_vals )
626+ query = collection .where ("b" , "<=" , 1 ).select (["a" , "stats.product" ])
627+ values = {snapshot .id : snapshot .to_dict () for snapshot in query .stream ()}
628+ assert len (values ) == num_vals * 2 # a ANY, b in (0, 1)
629+ for key , value in six .iteritems (values ):
577630 expected = {
578631 "a" : stored [key ]["a" ],
579632 "stats" : {"product" : stored [key ]["stats" ]["product" ]},
580633 }
581634 assert expected == value
582635
583- # 6. Add multiple filters via ``where()``.
584- query6 = collection .where ("stats.product" , ">" , 5 )
585- query6 = query6 .where ("stats.product" , "<" , 10 )
586- values6 = {snapshot .id : snapshot .to_dict () for snapshot in query6 .stream ()}
587636
637+ def test_query_stream_w_multiple_filters (query_docs ):
638+ collection , stored , allowed_vals = query_docs
639+ query = collection .where ("stats.product" , ">" , 5 ).where ("stats.product" , "<" , 10 )
640+ values = {snapshot .id : snapshot .to_dict () for snapshot in query .stream ()}
588641 matching_pairs = [
589642 (a_val , b_val )
590643 for a_val in allowed_vals
591644 for b_val in allowed_vals
592645 if 5 < a_val * b_val < 10
593646 ]
594- assert len (values6 ) == len (matching_pairs )
595- for key , value in six .iteritems (values6 ):
647+ assert len (values ) == len (matching_pairs )
648+ for key , value in six .iteritems (values ):
596649 assert stored [key ] == value
597650 pair = (value ["a" ], value ["b" ])
598651 assert pair in matching_pairs
599652
600- # 7. Skip the first three results, when ``b==2``
601- query7 = collection .where ("b" , "==" , 2 )
653+
654+ def test_query_stream_w_offset (query_docs ):
655+ collection , stored , allowed_vals = query_docs
656+ num_vals = len (allowed_vals )
602657 offset = 3
603- query7 = query7 .offset (offset )
604- values7 = {snapshot .id : snapshot .to_dict () for snapshot in query7 .stream ()}
658+ query = collection . where ( "b" , "==" , 2 ) .offset (offset )
659+ values = {snapshot .id : snapshot .to_dict () for snapshot in query .stream ()}
605660 # NOTE: We don't check the ``a``-values, since that would require
606661 # an ``order_by('a')``, which combined with the ``b == 2``
607662 # filter would necessitate an index.
608- assert len (values7 ) == num_vals - offset
609- for key , value in six .iteritems (values7 ):
663+ assert len (values ) == num_vals - offset
664+ for key , value in six .iteritems (values ):
610665 assert stored [key ] == value
611666 assert value ["b" ] == 2
612667
0 commit comments