Skip to content

Commit

Permalink
Merge pull request #298 from NikolayBaranovv/add_async_generator
Browse files Browse the repository at this point in the history
feat: async generator
  • Loading branch information
psi29a authored Oct 25, 2024
2 parents 54d1a78 + dc32c7a commit 58f13c0
Show file tree
Hide file tree
Showing 3 changed files with 441 additions and 190 deletions.
100 changes: 84 additions & 16 deletions tests/basic/test_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from unittest.mock import patch

import bson
Expand Down Expand Up @@ -40,6 +41,7 @@
only_for_mongodb_starting_from,
)
from tests.utils import SingleCollectionTest
from txmongo.collection import Cursor
from txmongo.errors import TimeExceeded
from txmongo.protocol import MongoProtocol

Expand All @@ -58,6 +60,17 @@ class TestMongoQueries(SingleCollectionTest):

timeout = 15

@defer.inlineCallbacks
def test_find_return_type(self):
dfr = self.coll.find()
dfr_one = self.coll.find_one()
try:
self.assertIsInstance(dfr, defer.Deferred)
self.assertIsInstance(dfr_one, defer.Deferred)
finally:
yield dfr
yield dfr_one

@defer.inlineCallbacks
def test_SingleCursorIteration(self):
yield self.coll.insert_many([{"v": i} for i in range(10)])
Expand Down Expand Up @@ -190,6 +203,34 @@ def test_CursorClosingWithCursor(self):

yield self.__check_no_open_cursors()

@defer.inlineCallbacks
def test_TimeoutAndDeadline(self):
yield self.coll.insert_many([{"a": i} for i in range(10)])

# Success cases
result = yield self.coll.find()
self.assertEqual(len(result), 10)
result = yield self.coll.find({"$where": "sleep(40); true"}, timeout=0.5)
self.assertEqual(len(result), 10)
result = yield self.coll.find(
{"$where": "sleep(40); true"}, timeout=0.5, batch_size=2
)
self.assertEqual(len(result), 10)

# Timeout cases
dfr = self.coll.find({"$where": "sleep(55); true"}, timeout=0.5)
yield self.assertFailure(dfr, TimeExceeded)
dfr = self.coll.find({"$where": "sleep(55); true"}, timeout=0.5, batch_size=2)
yield self.assertFailure(dfr, TimeExceeded)

# Deadline cases
dfr = self.coll.find({"$where": "sleep(55); true"}, deadline=time.time() + 0.5)
yield self.assertFailure(dfr, TimeExceeded)
dfr = self.coll.find(
{"$where": "sleep(55); true"}, deadline=time.time() + 0.5, batch_size=2
)
yield self.assertFailure(dfr, TimeExceeded)

@defer.inlineCallbacks
def test_CursorClosingWithTimeout(self):
yield self.coll.insert_many({"x": x} for x in range(10))
Expand Down Expand Up @@ -243,6 +284,7 @@ def test_FindOneNone(self):

@defer.inlineCallbacks
def test_AllowPartialResults(self):

with patch.object(
MongoProtocol, "send_msg", side_effect=MongoProtocol.send_msg, autospec=True
) as mock:
Expand All @@ -253,28 +295,54 @@ def test_AllowPartialResults(self):
cmd = bson.decode(msg.body)
self.assertEqual(cmd["allowPartialResults"], True)

async def test_FindIterate(self):
await self.coll.insert_many([{"b": i} for i in range(50)])

class TestMongoQueriesEdgeCases(SingleCollectionTest):
sum_of_doc, doc_count = 0, 0
async for doc in self.coll.find(batch_size=10):
sum_of_doc += doc["b"]
doc_count += 1

timeout = 15
self.assertEqual(sum_of_doc, 1225)
self.assertEqual(doc_count, 50)

@defer.inlineCallbacks
def test_BelowBatchThreshold(self):
yield self.coll.insert_many([{"v": i} for i in range(100)])
res = yield self.coll.find()
self.assertEqual(len(res), 100)
async def test_FindIterateBatches(self):
await self.coll.insert_many([{"a": i} for i in range(100)])

@defer.inlineCallbacks
def test_EqualToBatchThreshold(self):
yield self.coll.insert_many([{"v": i} for i in range(101)])
res = yield self.coll.find()
self.assertEqual(len(res), 101)
all_batches_len = 0
async for batch in self.coll.find(batch_size=10).batches():
batch_len = len(batch)
self.assertEqual(batch_len, 10)
all_batches_len += batch_len

self.assertEqual(all_batches_len, 100)

async def test_FindIterateCloseCursor(self):
await self.coll.insert_many([{"c": i} for i in range(50)])

doc_count = 0
async for _ in self.coll.find(batch_size=10):
doc_count += 1
if doc_count == 25:
break

self.assertEqual(doc_count, 25)

await self.__check_no_open_cursors()

@defer.inlineCallbacks
def test_AboveBatchThreshold(self):
yield self.coll.insert_many([{"v": i} for i in range(102)])
res = yield self.coll.find()
self.assertEqual(len(res), 102)
def test_IterateNextBatch(self):
yield self.coll.insert_many([{"c": i} for i in range(50)])

all_docs = []
cursor = self.coll.find(batch_size=10)
while not cursor.exhausted:
batch = yield cursor.next_batch()
all_docs.extend(batch)

self.assertEqual(len(all_docs), 50)

yield self.__check_no_open_cursors()


class TestLimit(SingleCollectionTest):
Expand Down
Loading

0 comments on commit 58f13c0

Please sign in to comment.