diff --git a/tests/test_all.py b/tests/test_all.py index 5e70b10..8670253 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -13,6 +13,7 @@ import pandas as pd import numpy as np import pytest +import joblib import psweep as ps @@ -885,3 +886,20 @@ def test_run_local_deprecated(): params = ps.plist("a", [1, 2, 3]) with pytest.deprecated_call(): ps.run_local(func_a, params, save=False) + + +def test_pickle_io(): + obj = dict(a=1, b=_Foo(), c=np.sin) + hsh = lambda obj: joblib.hash(obj, hash_name="sha1") + with tempfile.TemporaryDirectory() as tmpdir: + fn = f"{tmpdir}/path/that/has/to/be/created/file.pk" + ps.pickle_write(fn, obj) + assert hsh(obj) == hsh(ps.pickle_read(fn)) + + +def test_file_io(): + txt = "some random text" + with tempfile.TemporaryDirectory() as tmpdir: + fn = f"{tmpdir}/path/that/has/to/be/created/file.txt" + ps.file_write(fn, txt) + assert txt == ps.file_read(fn)