From 72fa855344e2fc74194518f53896e4a2a10bd65d Mon Sep 17 00:00:00 2001 From: Mashed Potato <38517644+potatomashed@users.noreply.github.com> Date: Tue, 7 Jan 2025 23:16:34 -0800 Subject: [PATCH] feat(dataclass): Support `copy` and `deepcopy` (#4) This PR introduces support for Python's native `__copy__` and `__deepcopy__` method for all MLC dataclasses. This is done by field visitor and topo visitor. --- .github/workflows/ci.yml | 4 + .pre-commit-config.yaml | 3 + cpp/json.cc | 71 +++---- cpp/structure.cc | 131 +++++++++++++ include/mlc/base/all.h | 2 +- include/mlc/base/utils.h | 7 +- include/mlc/core/dict.h | 5 +- include/mlc/core/field_visitor.h | 17 +- include/mlc/core/list.h | 5 +- python/mlc/_cython/core.pyx | 37 +++- python/mlc/core/device.py | 3 + python/mlc/core/dtype.py | 3 + python/mlc/core/object.py | 15 ++ tests/python/test_dataclasses_copy.py | 264 ++++++++++++++++++++++++++ 14 files changed, 513 insertions(+), 54 deletions(-) create mode 100644 tests/python/test_dataclasses_copy.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index d0d4ae00..a689738f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -20,6 +20,10 @@ jobs: with: python-version: ${{ env.MLC_PYTHON_VERSION }} - uses: pre-commit/action@v3.0.1 + - uses: ytanikin/pr-conventional-commits@1.4.0 + with: + task_types: '["feat", "fix", "ci", "chore", "test"]' + add_label: 'false' windows: name: Windows runs-on: windows-latest diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 678b856d..1f6fb984 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,5 +1,8 @@ # See https://pre-commit.com for more information # See https://pre-commit.com/hooks.html for more hooks +default_install_hook_types: + - pre-commit + - commit-msg repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v5.0.0 diff --git a/cpp/json.cc b/cpp/json.cc index 2af5b474..18fedd5a 100644 --- a/cpp/json.cc +++ b/cpp/json.cc @@ -31,26 +31,30 @@ inline mlc::Str Serialize(Any any) { using TObj2Idx = std::unordered_map; using TJsonTypeIndex = decltype(get_json_type_index); struct Emitter { + MLC_INLINE void operator()(MLCTypeField *, const Any *any) { EmitAny(any); } // clang-format off - MLC_INLINE void operator()(MLCTypeField *, const Any *any) { EmitAny(any); } - MLC_INLINE void operator()(MLCTypeField *, ObjectRef *obj) { if (Object *v = obj->get()) EmitObject(v); else EmitNil(); } - MLC_INLINE void operator()(MLCTypeField *, Optional *opt) { if (Object *v = opt->get()) EmitObject(v); else EmitNil(); } - MLC_INLINE void operator()(MLCTypeField *, Optional *opt) { if (const int64_t *v = opt->get()) EmitInt(*v); else EmitNil(); } - MLC_INLINE void operator()(MLCTypeField *, Optional *opt) { if (const double *v = opt->get()) EmitFloat(*v); else EmitNil(); } - MLC_INLINE void operator()(MLCTypeField *, Optional *opt) { if (const DLDevice *v = opt->get()) EmitDevice(*v); else EmitNil(); } - MLC_INLINE void operator()(MLCTypeField *, Optional *opt) { if (const DLDataType *v = opt->get()) EmitDType(*v); else EmitNil(); } - MLC_INLINE void operator()(MLCTypeField *, int8_t *v) { EmitInt(static_cast(*v)); } - MLC_INLINE void operator()(MLCTypeField *, int16_t *v) { EmitInt(static_cast(*v)); } - MLC_INLINE void operator()(MLCTypeField *, int32_t *v) { EmitInt(static_cast(*v)); } - MLC_INLINE void operator()(MLCTypeField *, int64_t *v) { EmitInt(static_cast(*v)); } - MLC_INLINE void operator()(MLCTypeField *, float *v) { EmitFloat(static_cast(*v)); } - MLC_INLINE void operator()(MLCTypeField *, double *v) { EmitFloat(static_cast(*v)); } - MLC_INLINE void operator()(MLCTypeField *, DLDataType *v) { EmitDType(*v); } - MLC_INLINE void operator()(MLCTypeField *, DLDevice *v) { EmitDevice(*v); } - MLC_INLINE void operator()(MLCTypeField *, Optional *) { MLC_THROW(TypeError) << "Unserializable type: void *"; } - MLC_INLINE void operator()(MLCTypeField *, void **) { MLC_THROW(TypeError) << "Unserializable type: void *"; } - MLC_INLINE void operator()(MLCTypeField *, const char **) { MLC_THROW(TypeError) << "Unserializable type: const char *"; } + MLC_INLINE void operator()(MLCTypeField *, ObjectRef *obj) { if (Object *v = obj->get()) EmitObject(v); else EmitNil(); } + MLC_INLINE void operator()(MLCTypeField *, Optional *opt) { if (Object *v = opt->get()) EmitObject(v); else EmitNil(); } + MLC_INLINE void operator()(MLCTypeField *, Optional *opt) { if (const int64_t *v = opt->get()) EmitInt(*v); else EmitNil(); } + MLC_INLINE void operator()(MLCTypeField *, Optional *opt) { if (const double *v = opt->get()) EmitFloat(*v); else EmitNil(); } + MLC_INLINE void operator()(MLCTypeField *, Optional *opt) { if (const DLDevice *v = opt->get()) EmitDevice(*v); else EmitNil(); } + MLC_INLINE void operator()(MLCTypeField *, Optional *opt) { if (const DLDataType *v = opt->get()) EmitDType(*v); else EmitNil(); } // clang-format on + MLC_INLINE void operator()(MLCTypeField *, int8_t *v) { EmitInt(static_cast(*v)); } + MLC_INLINE void operator()(MLCTypeField *, int16_t *v) { EmitInt(static_cast(*v)); } + MLC_INLINE void operator()(MLCTypeField *, int32_t *v) { EmitInt(static_cast(*v)); } + MLC_INLINE void operator()(MLCTypeField *, int64_t *v) { EmitInt(static_cast(*v)); } + MLC_INLINE void operator()(MLCTypeField *, float *v) { EmitFloat(static_cast(*v)); } + MLC_INLINE void operator()(MLCTypeField *, double *v) { EmitFloat(static_cast(*v)); } + MLC_INLINE void operator()(MLCTypeField *, DLDataType *v) { EmitDType(*v); } + MLC_INLINE void operator()(MLCTypeField *, DLDevice *v) { EmitDevice(*v); } + MLC_INLINE void operator()(MLCTypeField *, Optional *) { + MLC_THROW(TypeError) << "Unserializable type: void *"; + } + MLC_INLINE void operator()(MLCTypeField *, void **) { MLC_THROW(TypeError) << "Unserializable type: void *"; } + MLC_INLINE void operator()(MLCTypeField *, const char **) { + MLC_THROW(TypeError) << "Unserializable type: const char *"; + } inline void EmitNil() { (*os) << ", null"; } inline void EmitFloat(double v) { (*os) << ", " << std::fixed << std::setprecision(19) << v; } inline void EmitInt(int64_t v) { @@ -98,10 +102,17 @@ inline mlc::Str Serialize(Any any) { const TObj2Idx *obj2index; }; + std::unordered_map topo_indices; std::ostringstream os; - auto on_visit = [get_json_type_index = &get_json_type_index, os = &os, is_first_object = true]( - Object *object, MLCTypeInfo *type_info, const TObj2Idx &obj2index) mutable -> void { - Emitter emitter{os, get_json_type_index, &obj2index}; + auto on_visit = [&topo_indices, get_json_type_index = &get_json_type_index, os = &os, + is_first_object = true](Object *object, MLCTypeInfo *type_info) mutable -> void { + int32_t &topo_index = topo_indices[object]; + if (topo_index == 0) { + topo_index = static_cast(topo_indices.size()) - 1; + } else { + MLC_THROW(InternalError) << "This should never happen: object already visited"; + } + Emitter emitter{os, get_json_type_index, &topo_indices}; if (is_first_object) { is_first_object = false; } else { @@ -163,29 +174,23 @@ inline mlc::Str Serialize(Any any) { } inline Any Deserialize(const char *json_str, int64_t json_str_len) { - MLCVTableHandle init_vtable; - MLCVTableGetGlobal(nullptr, "__init__", &init_vtable); + MLCVTableHandle init_table = ::mlc::base::LibState::init; // Step 0. Parse JSON string UDict json_obj = JSONLoads(json_str, json_str_len); // Step 1. type_key => constructors UList type_keys = json_obj->at("type_keys"); - std::vector constructors; + std::vector constructors; constructors.reserve(type_keys.size()); for (Str type_key : type_keys) { - Any init_func; int32_t type_index = ::mlc::base::TypeKey2TypeIndex(type_key->data()); - MLCVTableGetFunc(init_vtable, type_index, false, &init_func); - if (!::mlc::base::IsTypeIndexNone(init_func.type_index)) { - constructors.push_back(init_func.operator Func()); - } else { - MLC_THROW(InternalError) << "Method `__init__` is not defined for type " << type_key; - } + FuncObj *func = ::mlc::base::LibState::VTableGetFunc(init_table, type_index, "__init__"); + constructors.push_back(func); } auto invoke_init = [&constructors](UList args) { int32_t json_type_index = args[0]; Any ret; - ::mlc::base::FuncCall(constructors.at(json_type_index).get(), static_cast(args.size()) - 1, - args->data() + 1, &ret); + ::mlc::base::FuncCall(constructors.at(json_type_index), static_cast(args.size()) - 1, args->data() + 1, + &ret); return ret; }; // Step 2. Translate JSON object to objects diff --git a/cpp/structure.cc b/cpp/structure.cc index 0401ece9..b793a6db 100644 --- a/cpp/structure.cc +++ b/cpp/structure.cc @@ -1,3 +1,4 @@ +#include "mlc/core/error.h" #include #include #include @@ -532,11 +533,141 @@ inline uint64_t StructuralHash(Object *obj) { #undef MLC_CORE_HASH_S_POD #undef MLC_CORE_HASH_S_ANY +inline Any CopyShallow(AnyView source) { + int32_t type_index = source.type_index; + if (::mlc::base::IsTypeIndexPOD(type_index)) { + return source; + } else if (UListObj *list = source.TryCast()) { + return UList(list->begin(), list->end()); + } else if (UDictObj *dict = source.TryCast()) { + return UDict(dict->begin(), dict->end()); + } else if (source.IsInstance() || source.IsInstance() || source.IsInstance()) { + return source; + } + struct Copier { + MLC_INLINE void operator()(MLCTypeField *, const Any *any) { fields->push_back(AnyView(*any)); } + MLC_INLINE void operator()(MLCTypeField *, ObjectRef *obj) { fields->push_back(AnyView(*obj)); } + MLC_INLINE void operator()(MLCTypeField *, Optional *opt) { fields->push_back(AnyView(*opt)); } + MLC_INLINE void operator()(MLCTypeField *, Optional *opt) { fields->push_back(AnyView(*opt)); } + MLC_INLINE void operator()(MLCTypeField *, Optional *opt) { fields->push_back(AnyView(*opt)); } + MLC_INLINE void operator()(MLCTypeField *, Optional *opt) { fields->push_back(AnyView(*opt)); } + MLC_INLINE void operator()(MLCTypeField *, Optional *opt) { fields->push_back(AnyView(*opt)); } + MLC_INLINE void operator()(MLCTypeField *, int8_t *v) { fields->push_back(AnyView(*v)); } + MLC_INLINE void operator()(MLCTypeField *, int16_t *v) { fields->push_back(AnyView(*v)); } + MLC_INLINE void operator()(MLCTypeField *, int32_t *v) { fields->push_back(AnyView(*v)); } + MLC_INLINE void operator()(MLCTypeField *, int64_t *v) { fields->push_back(AnyView(*v)); } + MLC_INLINE void operator()(MLCTypeField *, float *v) { fields->push_back(AnyView(*v)); } + MLC_INLINE void operator()(MLCTypeField *, double *v) { fields->push_back(AnyView(*v)); } + MLC_INLINE void operator()(MLCTypeField *, DLDataType *v) { fields->push_back(AnyView(*v)); } + MLC_INLINE void operator()(MLCTypeField *, DLDevice *v) { fields->push_back(AnyView(*v)); } + MLC_INLINE void operator()(MLCTypeField *, Optional *v) { fields->push_back(AnyView(*v)); } + MLC_INLINE void operator()(MLCTypeField *, void **v) { fields->push_back(AnyView(*v)); } + MLC_INLINE void operator()(MLCTypeField *, const char **v) { fields->push_back(AnyView(*v)); } + std::vector *fields; + }; + FuncObj *init_func = ::mlc::base::LibState::VTableGetFunc(::mlc::base::LibState::init, type_index, "__init__"); + MLCTypeInfo *type_info = ::mlc::base::TypeIndex2TypeInfo(type_index); + std::vector fields; + VisitFields(source.operator Object *(), type_info, Copier{&fields}); + Any ret; + ::mlc::base::FuncCall(init_func, static_cast(fields.size()), fields.data(), &ret); + return ret; +} + +inline Any CopyDeep(AnyView source) { + if (::mlc::base::IsTypeIndexPOD(source.type_index)) { + return source; + } + struct Copier { + MLC_INLINE void operator()(MLCTypeField *, const Any *any) { HandleAny(any); } + MLC_INLINE void operator()(MLCTypeField *, ObjectRef *ref) { + if (const Object *obj = ref->get()) { + HandleObject(obj); + } else { + fields->push_back(AnyView()); + } + } + MLC_INLINE void operator()(MLCTypeField *, Optional *opt) { + if (const Object *obj = opt->get()) { + HandleObject(obj); + } else { + fields->push_back(AnyView()); + } + } + MLC_INLINE void operator()(MLCTypeField *, Optional *opt) { fields->push_back(AnyView(*opt)); } + MLC_INLINE void operator()(MLCTypeField *, Optional *opt) { fields->push_back(AnyView(*opt)); } + MLC_INLINE void operator()(MLCTypeField *, Optional *opt) { fields->push_back(AnyView(*opt)); } + MLC_INLINE void operator()(MLCTypeField *, Optional *opt) { fields->push_back(AnyView(*opt)); } + MLC_INLINE void operator()(MLCTypeField *, int8_t *v) { fields->push_back(AnyView(*v)); } + MLC_INLINE void operator()(MLCTypeField *, int16_t *v) { fields->push_back(AnyView(*v)); } + MLC_INLINE void operator()(MLCTypeField *, int32_t *v) { fields->push_back(AnyView(*v)); } + MLC_INLINE void operator()(MLCTypeField *, int64_t *v) { fields->push_back(AnyView(*v)); } + MLC_INLINE void operator()(MLCTypeField *, float *v) { fields->push_back(AnyView(*v)); } + MLC_INLINE void operator()(MLCTypeField *, double *v) { fields->push_back(AnyView(*v)); } + MLC_INLINE void operator()(MLCTypeField *, DLDataType *v) { fields->push_back(AnyView(*v)); } + MLC_INLINE void operator()(MLCTypeField *, DLDevice *v) { fields->push_back(AnyView(*v)); } + MLC_INLINE void operator()(MLCTypeField *, Optional *v) { fields->push_back(AnyView(*v)); } + MLC_INLINE void operator()(MLCTypeField *, void **v) { fields->push_back(AnyView(*v)); } + MLC_INLINE void operator()(MLCTypeField *, const char **v) { fields->push_back(AnyView(*v)); } + + void HandleObject(const Object *obj) { + if (auto it = orig2copy->find(obj); it != orig2copy->end()) { + fields->push_back(AnyView(it->second)); + } else { + MLC_THROW(InternalError) << "InternalError: object doesn't exist in the memo: " << AnyView(obj); + } + } + + void HandleAny(const Any *any) { + if (const Object *obj = any->TryCast()) { + HandleObject(obj); + } else { + fields->push_back(AnyView(*any)); + } + } + + std::unordered_map *orig2copy; + std::vector *fields; + }; + std::unordered_map orig2copy; + std::vector fields; + TopoVisit(source.operator Object *(), nullptr, [&](Object *object, MLCTypeInfo *type_info) mutable -> void { + Any ret; + if (UListObj *list = object->TryCast()) { + fields.clear(); + fields.reserve(list->size()); + for (Any &e : *list) { + Copier{&orig2copy, &fields}.HandleAny(&e); + } + UList::FromAnyTuple(static_cast(fields.size()), fields.data(), &ret); + } else if (UDictObj *dict = object->TryCast()) { + fields.clear(); + for (auto [key, value] : *dict) { + Copier{&orig2copy, &fields}.HandleAny(&key); + Copier{&orig2copy, &fields}.HandleAny(&value); + } + UDict::FromAnyTuple(static_cast(fields.size()), fields.data(), &ret); + } else if (object->IsInstance() || object->IsInstance() || object->IsInstance()) { + ret = object; + } else { + fields.clear(); + VisitFields(object, type_info, Copier{&orig2copy, &fields}); + FuncObj *func = + ::mlc::base::LibState::VTableGetFunc(::mlc::base::LibState::init, type_info->type_index, "__init__"); + ::mlc::base::FuncCall(func, static_cast(fields.size()), fields.data(), &ret); + } + orig2copy[object] = ret.operator ObjectRef(); + }); + return orig2copy.at(source.operator Object *()); +} + MLC_REGISTER_FUNC("mlc.core.StructuralEqual").set_body(::mlc::core::StructuralEqual); MLC_REGISTER_FUNC("mlc.core.StructuralHash").set_body([](::mlc::Object *obj) -> int64_t { uint64_t ret = ::mlc::core::StructuralHash(obj); return static_cast(ret); }); +MLC_REGISTER_FUNC("mlc.core.CopyShallow").set_body(::mlc::core::CopyShallow); +MLC_REGISTER_FUNC("mlc.core.CopyDeep").set_body(::mlc::core::CopyDeep); } // namespace } // namespace core } // namespace mlc diff --git a/include/mlc/base/all.h b/include/mlc/base/all.h index 9543a129..764bac9a 100644 --- a/include/mlc/base/all.h +++ b/include/mlc/base/all.h @@ -58,7 +58,7 @@ template MLC_INLINE AnyView::AnyView(Ref &&src) : AnyView(static // `src` is not reset here because `AnyView` does not take ownership of the object } -template MLC_INLINE AnyView::AnyView(const Optional &src) { +template MLC_INLINE AnyView::AnyView(const Optional &src) : MLCAny() { if (const auto *value = src.get()) { if constexpr (::mlc::base::IsPOD) { using TPOD = T; diff --git a/include/mlc/base/utils.h b/include/mlc/base/utils.h index 745429f8..162f02d3 100644 --- a/include/mlc/base/utils.h +++ b/include/mlc/base/utils.h @@ -387,7 +387,11 @@ struct LibState { DecRef(func.v.v_obj); } FuncObj *ret = reinterpret_cast(func.v.v_obj); - if (func.type_index != kMLCFunc) { + if (func.type_index == kMLCNone) { + MLC_THROW(TypeError) << "Function `" << vtable_name + << "` for type: " << ::mlc::base::TypeIndex2TypeKey(type_index) + << " is not defined in the vtable"; + } else if (func.type_index != kMLCFunc) { MLC_THROW(TypeError) << "Function `" << vtable_name << "` for type: " << ::mlc::base::TypeIndex2TypeKey(type_index) << " is not callable. Its type is " << ::mlc::base::TypeIndex2TypeKey(func.type_index); @@ -401,6 +405,7 @@ struct LibState { static MLC_SYMBOL_HIDE inline MLCVTableHandle cxx_str = VTableGetGlobal("__cxx_str__"); static MLC_SYMBOL_HIDE inline MLCVTableHandle str = VTableGetGlobal("__str__"); static MLC_SYMBOL_HIDE inline MLCVTableHandle ir_print = VTableGetGlobal("__ir_print__"); + static MLC_SYMBOL_HIDE inline MLCVTableHandle init = VTableGetGlobal("__init__"); }; } // namespace base diff --git a/include/mlc/core/dict.h b/include/mlc/core/dict.h index b86e38c7..d57580ed 100644 --- a/include/mlc/core/dict.h +++ b/include/mlc/core/dict.h @@ -138,11 +138,14 @@ struct UDict : public ObjectRef { MLC_INLINE const_iterator end() const { return get()->end(); } MLC_INLINE const_reverse_iterator rbegin() const { return get()->rbegin(); } MLC_INLINE const_reverse_iterator rend() const { return get()->rend(); } + MLC_INLINE static void FromAnyTuple(int32_t num_args, const AnyView *args, Any *ret) { + ::mlc::core::DictBase::Accessor::New(num_args, args, ret); + } MLC_DEF_OBJ_REF(UDict, UDictObj, ObjectRef) .FieldReadOnly("capacity", &MLCDict::capacity) .FieldReadOnly("size", &MLCDict::size) .FieldReadOnly("data", &MLCDict::data) - .StaticFn("__init__", ::mlc::core::DictBase::Accessor::New) + .StaticFn("__init__", FromAnyTuple) .MemFn("__str__", &UDictObj::__str__) .MemFn("__getitem__", ::mlc::core::DictBase::Accessor::GetItem) .MemFn("__iter_get_key__", ::mlc::core::DictBase::Accessor::GetKey) diff --git a/include/mlc/core/field_visitor.h b/include/mlc/core/field_visitor.h index 78fddac3..d8aed12c 100644 --- a/include/mlc/core/field_visitor.h +++ b/include/mlc/core/field_visitor.h @@ -164,9 +164,7 @@ template inline void VisitStructure(Object *root, MLCTypeInfo } inline void TopoVisit(Object *root, std::function pre_visit, - std::function &topo_indices)> - on_visit) { + std::function on_visit) { struct TopoInfo { Object *obj; MLCTypeInfo *type_info; @@ -271,20 +269,13 @@ inline void TopoVisit(Object *root, std::function topo_indices; size_t num_objects = 0; for (; !stack.empty(); ++num_objects) { TopoInfo *current = stack.back(); stack.pop_back(); - // Step 3.1. Lable object index - int32_t &topo_index = topo_indices[current->obj]; - if (topo_index != 0) { - MLC_THROW(InternalError) << "This should never happen: object already visited"; - } - topo_index = static_cast(num_objects); - // Step 3.2. Visit object - on_visit(current->obj, current->type_info, topo_indices); - // Step 3.3. Decrease the dependency count of topo_parents + // Step 3.1. Visit object + on_visit(current->obj, current->type_info); + // Step 3.2. Decrease the dependency count of topo_parents for (TopoInfo *parent : current->topo_parents) { if (--parent->topo_deps == 0) { stack.push_back(parent); diff --git a/include/mlc/core/list.h b/include/mlc/core/list.h index c9b80868..7860f2c4 100644 --- a/include/mlc/core/list.h +++ b/include/mlc/core/list.h @@ -147,11 +147,14 @@ struct UList : public ObjectRef { MLC_INLINE const_iterator end() const { return get()->end(); } MLC_INLINE const_reverse_iterator rbegin() const { return get()->rbegin(); } MLC_INLINE const_reverse_iterator rend() const { return get()->rend(); } + MLC_INLINE static void FromAnyTuple(int32_t num_args, const AnyView *args, Any *ret) { + ::mlc::core::ListBase::Accessor::New(num_args, args, ret); + } MLC_DEF_OBJ_REF(UList, UListObj, ObjectRef) .FieldReadOnly("size", &MLCList::size) .FieldReadOnly("capacity", &MLCList::capacity) .FieldReadOnly("data", &MLCList::data) - .StaticFn("__init__", &::mlc::core::ListBase::Accessor::New) + .StaticFn("__init__", FromAnyTuple) .MemFn("__str__", &UListObj::__str__) .MemFn("__iter_at__", &::mlc::core::ListBase::Accessor::At); }; diff --git a/python/mlc/_cython/core.pyx b/python/mlc/_cython/core.pyx index 93dd60fc..978954ac 100644 --- a/python/mlc/_cython/core.pyx +++ b/python/mlc/_cython/core.pyx @@ -300,6 +300,11 @@ cdef class PyAny: def __init__(self): pass + @property + def _mlc_address(self): + cdef uint64_t ret = ((self._mlc_any.v.v_obj)) if self._mlc_any.type_index >= kMLCStaticObjectBegin else 0 # no-cython-lint + return ret + def _mlc_init(self, *init_args) -> None: cdef int32_t type_index = type(self)._mlc_type_info.type_index cdef MLCFunc* func = _vtable_get_func_ptr(_VTABLE_INIT, type_index, False) @@ -345,6 +350,14 @@ cdef class PyAny: ret += 2 ** 63 return ret + @staticmethod + def _mlc_copy_shallow(PyAny x) -> PyAny: + return func_call(_COPY_SHALLOW, (x,)) + + @staticmethod + def _mlc_copy_deep(PyAny x) -> PyAny: + return func_call(_COPY_DEEP, (x,)) + @classmethod def _C(cls, bytes name, *args): cdef int32_t type_index = cls._mlc_type_info.type_index @@ -902,6 +915,11 @@ cdef class TypeCheckerList: cdef int32_t num_args = 0 cdef MLCAny* c_args = NULL cdef MLCAny ret = _MLCAnyNone() + if isinstance(_value, mlc_list): + for v in _value: + _type_checker_call(self.sub, v, temporary_storage) + temporary_storage.append(_value) + return (_value)._mlc_any if not isinstance(_value, (list, tuple, mlc_list)): raise TypeError(f"Expected `list` or `tuple`, but got: {type(_value)}") value = tuple(_value) @@ -932,6 +950,8 @@ cdef class TypeCheckerDict: @staticmethod cdef MLCAny convert(object _self, object _value, list temporary_storage): + from mlc.core.dict import Dict as mlc_dict + cdef TypeCheckerDict self = _self cdef tuple value cdef int32_t num_args = 0 @@ -939,9 +959,16 @@ cdef class TypeCheckerDict: cdef MLCAny ret = _MLCAnyNone() cdef TypeChecker sub_k = self.sub_k cdef TypeChecker sub_v = self.sub_v - if not isinstance(_value, dict): + if isinstance(_value, mlc_dict): + for k, v in _value.items(): + _type_checker_call(sub_k, k, temporary_storage) + _type_checker_call(sub_v, v, temporary_storage) + temporary_storage.append(_value) + return (_value)._mlc_any + elif isinstance(_value, dict): + value = _flatten_dict_to_tuple(_value) + else: raise TypeError(f"Expected `dict`, but got: {type(_value)}") - value = _flatten_dict_to_tuple(_value) num_args = len(value) c_args = malloc(num_args * sizeof(MLCAny)) try: @@ -1403,18 +1430,20 @@ cdef PyAny _SERIALIZE = func_get_untyped("mlc.core.JSONSerialize") # Any -> str cdef PyAny _DESERIALIZE = func_get_untyped("mlc.core.JSONDeserialize") # str -> Any cdef PyAny _STRUCUTRAL_EQUAL = func_get_untyped("mlc.core.StructuralEqual") cdef PyAny _STRUCUTRAL_HASH = func_get_untyped("mlc.core.StructuralHash") +cdef PyAny _COPY_SHALLOW = func_get_untyped("mlc.core.CopyShallow") +cdef PyAny _COPY_DEEP = func_get_untyped("mlc.core.CopyDeep") -cdef MLCVTableHandle _VTABLE_INIT = _vtable_get_global(b"__init__") cdef MLCVTableHandle _VTABLE_STR = _vtable_get_global(b"__str__") -cdef MLCVTableHandle _VTABLE_NEW_REF = _vtable_get_global(b"__new_ref__") cdef MLCVTableHandle _VTABLE_ANY_TO_REF = _vtable_get_global(b"__any_to_ref__") +cdef MLCVTableHandle _VTABLE_NEW_REF = _vtable_get_global(b"__new_ref__") cdef MLCFunc* _INT_NEW = _vtable_get_func_ptr(_VTABLE_NEW_REF, kMLCInt, False) cdef MLCFunc* _FLOAT_NEW = _vtable_get_func_ptr(_VTABLE_NEW_REF, kMLCFloat, False) cdef MLCFunc* _PTR_NEW = _vtable_get_func_ptr(_VTABLE_NEW_REF, kMLCPtr, False) cdef MLCFunc* _DTYPE_NEW = _vtable_get_func_ptr(_VTABLE_NEW_REF, kMLCDataType, False) cdef MLCFunc* _DEVICE_NEW = _vtable_get_func_ptr(_VTABLE_NEW_REF, kMLCDevice, False) +cdef MLCVTableHandle _VTABLE_INIT = _vtable_get_global(b"__init__") cdef MLCFunc* _DTYPE_INIT = _vtable_get_func_ptr(_VTABLE_INIT, kMLCDataType, False) cdef MLCFunc* _DEVICE_INIT = _vtable_get_func_ptr(_VTABLE_INIT, kMLCDevice, False) cdef MLCFunc* _LIST_INIT = _vtable_get_func_ptr(_VTABLE_INIT, kMLCList, False) diff --git a/python/mlc/core/device.py b/python/mlc/core/device.py index cfef134c..a2358dc8 100644 --- a/python/mlc/core/device.py +++ b/python/mlc/core/device.py @@ -25,3 +25,6 @@ def __eq__(self, other: object) -> bool: def __ne__(self, other: object) -> bool: return isinstance(other, Device) and self._dtype_triple != other._dtype_triple + + def __hash__(self) -> int: + return hash((Device, *self._device_pair)) diff --git a/python/mlc/core/dtype.py b/python/mlc/core/dtype.py index f8c0c51b..b5694725 100644 --- a/python/mlc/core/dtype.py +++ b/python/mlc/core/dtype.py @@ -31,3 +31,6 @@ def __eq__(self, other: object) -> bool: def __ne__(self, other: object) -> bool: return isinstance(other, DataType) and self._dtype_triple != other._dtype_triple + + def __hash__(self) -> int: + return hash((DataType, *self._dtype_triple)) diff --git a/python/mlc/core/object.py b/python/mlc/core/object.py index b6302d10..d68bed15 100644 --- a/python/mlc/core/object.py +++ b/python/mlc/core/object.py @@ -31,3 +31,18 @@ def eq_s( def hash_s(self) -> int: return PyAny._mlc_hash_s(self) # type: ignore[attr-defined] + + def __copy__(self: Object) -> Object: + return PyAny._mlc_copy_shallow(self) # type: ignore[attr-defined] + + def __deepcopy__(self: Object, memo: dict[int, Object] | None) -> Object: + return PyAny._mlc_copy_deep(self) + + def __hash__(self) -> int: + return hash((type(self), self._mlc_address)) + + def __eq__(self, other: object) -> bool: + return isinstance(other, Object) and self._mlc_address == other._mlc_address + + def __ne__(self, other: object) -> bool: + return not self == other diff --git a/tests/python/test_dataclasses_copy.py b/tests/python/test_dataclasses_copy.py new file mode 100644 index 00000000..ff36eea9 --- /dev/null +++ b/tests/python/test_dataclasses_copy.py @@ -0,0 +1,264 @@ +import copy +from typing import Any, Optional + +import mlc +import pytest + + +@mlc.py_class +class PyClassForTest(mlc.PyClass): + i64: int + f64: float + raw_ptr: mlc.Ptr + dtype: mlc.DataType + device: mlc.Device + any: Any + func: mlc.Func + ulist: list[Any] + udict: dict + str_: str + ### + list_any: list[Any] + list_list_int: list[list[int]] + dict_any_any: dict[Any, Any] + dict_str_any: dict[str, Any] + dict_any_str: dict[Any, str] + dict_str_list_int: dict[str, list[int]] + ### + opt_i64: Optional[int] + opt_f64: Optional[float] + opt_raw_ptr: Optional[mlc.Ptr] + opt_dtype: Optional[mlc.DataType] + opt_device: Optional[mlc.Device] + opt_func: Optional[mlc.Func] + opt_ulist: Optional[list] + opt_udict: Optional[dict[Any, Any]] + opt_str: Optional[str] + ### + opt_list_any: Optional[list[Any]] + opt_list_list_int: Optional[list[list[int]]] + opt_dict_any_any: Optional[dict] + opt_dict_str_any: Optional[dict[str, Any]] + opt_dict_any_str: Optional[dict[Any, str]] + opt_dict_str_list_int: Optional[dict[str, list[int]]] + + def i64_plus_one(self) -> int: + return self.i64 + 1 + + +@pytest.fixture +def mlc_class_for_test() -> PyClassForTest: + return PyClassForTest( + i64=64, + f64=2.5, + raw_ptr=mlc.Ptr(0xDEADBEEF), + dtype="float8", + device="cuda:0", + any="hello", + func=lambda x: x + 1, + ulist=[1, 2.0, "three", lambda: 4], + udict={"1": 1, "2": 2.0, "3": "three", "4": lambda: 4}, + str_="world", + ### + list_any=[1, 2.0, "three", lambda: 4], + list_list_int=[[1, 2, 3], [4, 5, 6]], + dict_any_any={1: 1.0, 2.0: 2, "three": "four", 4: lambda: 5}, + dict_str_any={"1": 1.0, "2.0": 2, "three": "four", "4": lambda: 5}, + dict_any_str={1: "1.0", 2.0: "2", "three": "four", 4: "5"}, + dict_str_list_int={"1": [1, 2, 3], "2": [4, 5, 6]}, + ### + opt_i64=-64, + opt_f64=-2.5, + opt_raw_ptr=mlc.Ptr(0xBEEFDEAD), + opt_dtype="float16", + opt_device="cuda:0", + opt_func=lambda x: x - 1, + opt_ulist=[1, 2.0, "three", lambda: 4], + opt_udict={"1": 1, "2": 2.0, "3": "three", "4": lambda: 4}, + opt_str="world", + ### + opt_list_any=[1, 2.0, "three", lambda: 4], + opt_list_list_int=[[1, 2, 3], [4, 5, 6]], + opt_dict_any_any={1: 1.0, 2.0: 2, "three": "four", 4: lambda: 5}, + opt_dict_str_any={"1": 1.0, "2.0": 2, "three": "four", "4": lambda: 5}, + opt_dict_any_str={1: "1.0", 2.0: "2", "three": "four", 4: "5"}, + opt_dict_str_list_int={"1": [1, 2, 3], "2": [4, 5, 6]}, + ) + + +def test_copy_shallow(mlc_class_for_test: PyClassForTest) -> None: + src = mlc_class_for_test + dst = copy.copy(src) + assert src != dst + assert src.i64 == dst.i64 + assert src.f64 == dst.f64 + assert src.raw_ptr.value == dst.raw_ptr.value + assert src.dtype == dst.dtype + assert src.device == dst.device + assert src.any == dst.any + assert src.func(1) == dst.func(1) + assert src.ulist == dst.ulist + assert src.udict == dst.udict + assert src.str_ == dst.str_ + assert src.list_any == dst.list_any + assert src.list_list_int == dst.list_list_int + assert src.dict_any_any == dst.dict_any_any + assert src.dict_str_any == dst.dict_str_any + assert src.dict_any_str == dst.dict_any_str + assert src.dict_str_list_int == dst.dict_str_list_int + assert src.opt_i64 == dst.opt_i64 + assert src.opt_f64 == dst.opt_f64 + assert src.opt_raw_ptr.value == dst.opt_raw_ptr.value # type: ignore[union-attr] + assert src.opt_dtype == dst.opt_dtype + assert src.opt_device == dst.opt_device + assert src.opt_func(2) == dst.opt_func(2) # type: ignore[misc] + assert src.opt_ulist == dst.opt_ulist + assert src.opt_udict == dst.opt_udict + assert src.opt_str == dst.opt_str + assert src.opt_list_any == dst.opt_list_any + assert src.opt_list_list_int == dst.opt_list_list_int + assert src.opt_dict_any_any == dst.opt_dict_any_any + assert src.opt_dict_str_any == dst.opt_dict_str_any + assert src.opt_dict_any_str == dst.opt_dict_any_str + assert src.opt_dict_str_list_int == dst.opt_dict_str_list_int + + +def test_copy_deep(mlc_class_for_test: PyClassForTest) -> None: + src = mlc_class_for_test + dst = copy.deepcopy(src) + assert src != dst + assert src.i64 == dst.i64 + assert src.f64 == dst.f64 + assert src.raw_ptr.value == dst.raw_ptr.value + assert src.dtype == dst.dtype + assert src.device == dst.device + assert src.any == dst.any + assert src.func(1) == dst.func(1) + assert ( + src.ulist != dst.ulist + and len(src.ulist) == len(dst.ulist) + and src.ulist[0] == dst.ulist[0] + and src.ulist[1] == dst.ulist[1] + and src.ulist[2] == dst.ulist[2] + and src.ulist[3]() == dst.ulist[3]() + ) + assert ( + src.udict != dst.udict + and len(src.udict) == len(dst.udict) + and src.udict["1"] == dst.udict["1"] + and src.udict["2"] == dst.udict["2"] + and src.udict["3"] == dst.udict["3"] + and src.udict["4"]() == dst.udict["4"]() + ) + assert src.str_ == dst.str_ + assert ( + src.list_any != dst.list_any + and len(src.list_any) == len(dst.list_any) + and src.list_any[0] == dst.list_any[0] + and src.list_any[1] == dst.list_any[1] + and src.list_any[2] == dst.list_any[2] + and src.list_any[3]() == dst.list_any[3]() + ) + assert ( + src.list_list_int != dst.list_list_int + and len(src.list_list_int) == len(dst.list_list_int) + and tuple(src.list_list_int[0]) == tuple(dst.list_list_int[0]) + and tuple(src.list_list_int[1]) == tuple(dst.list_list_int[1]) + ) + assert ( + src.dict_any_any != dst.dict_any_any + and len(src.dict_any_any) == len(dst.dict_any_any) + and src.dict_any_any[1] == dst.dict_any_any[1] + and src.dict_any_any[2.0] == dst.dict_any_any[2.0] + and src.dict_any_any["three"] == dst.dict_any_any["three"] + and src.dict_any_any[4]() == dst.dict_any_any[4]() + ) + assert ( + src.dict_str_any != dst.dict_str_any + and len(src.dict_str_any) == len(dst.dict_str_any) + and src.dict_str_any["1"] == dst.dict_str_any["1"] + and src.dict_str_any["2.0"] == dst.dict_str_any["2.0"] + and src.dict_str_any["three"] == dst.dict_str_any["three"] + and src.dict_str_any["4"]() == dst.dict_str_any["4"]() + ) + assert ( + src.dict_any_str != dst.dict_any_str + and len(src.dict_any_str) == len(dst.dict_any_str) + and src.dict_any_str[1] == dst.dict_any_str[1] + and src.dict_any_str[2.0] == dst.dict_any_str[2.0] + and src.dict_any_str["three"] == dst.dict_any_str["three"] + and src.dict_any_str[4] == dst.dict_any_str[4] + ) + assert ( + src.dict_str_list_int != dst.dict_str_list_int + and len(src.dict_str_list_int) == len(dst.dict_str_list_int) + and tuple(src.dict_str_list_int["1"]) == tuple(dst.dict_str_list_int["1"]) + and tuple(src.dict_str_list_int["2"]) == tuple(dst.dict_str_list_int["2"]) + ) + assert src.opt_i64 == dst.opt_i64 + assert src.opt_f64 == dst.opt_f64 + assert src.opt_raw_ptr.value == dst.opt_raw_ptr.value # type: ignore[union-attr] + assert src.opt_dtype == dst.opt_dtype + assert src.opt_device == dst.opt_device + assert src.opt_func(2) == dst.opt_func(2) # type: ignore[misc] + assert ( + src.opt_ulist != dst.opt_ulist + and len(src.opt_ulist) == len(dst.opt_ulist) # type: ignore[arg-type] + and src.opt_ulist[0] == dst.opt_ulist[0] # type: ignore[index] + and src.opt_ulist[1] == dst.opt_ulist[1] # type: ignore[index] + and src.opt_ulist[2] == dst.opt_ulist[2] # type: ignore[index] + and src.opt_ulist[3]() == dst.opt_ulist[3]() # type: ignore[index] + ) + assert ( + src.opt_udict != dst.opt_udict + and len(src.opt_udict) == len(dst.opt_udict) # type: ignore[arg-type] + and src.opt_udict["1"] == dst.opt_udict["1"] # type: ignore[index] + and src.opt_udict["2"] == dst.opt_udict["2"] # type: ignore[index] + and src.opt_udict["3"] == dst.opt_udict["3"] # type: ignore[index] + and src.opt_udict["4"]() == dst.opt_udict["4"]() # type: ignore[index] + ) + assert src.opt_str == dst.opt_str + assert ( + src.opt_list_any != dst.opt_list_any + and len(src.opt_list_any) == len(dst.opt_list_any) # type: ignore[arg-type] + and src.opt_list_any[0] == dst.opt_list_any[0] # type: ignore[index] + and src.opt_list_any[1] == dst.opt_list_any[1] # type: ignore[index] + and src.opt_list_any[2] == dst.opt_list_any[2] # type: ignore[index] + and src.opt_list_any[3]() == dst.opt_list_any[3]() # type: ignore[index] + ) + assert ( + src.opt_list_list_int != dst.opt_list_list_int + and len(src.opt_list_list_int) == len(dst.opt_list_list_int) # type: ignore[arg-type] + and tuple(src.opt_list_list_int[0]) == tuple(dst.opt_list_list_int[0]) # type: ignore[index] + and tuple(src.opt_list_list_int[1]) == tuple(dst.opt_list_list_int[1]) # type: ignore[index] + ) + assert ( + src.opt_dict_any_any != dst.opt_dict_any_any + and len(src.opt_dict_any_any) == len(dst.opt_dict_any_any) # type: ignore[arg-type] + and src.opt_dict_any_any[1] == dst.opt_dict_any_any[1] # type: ignore[index] + and src.opt_dict_any_any[2.0] == dst.opt_dict_any_any[2.0] # type: ignore[index] + and src.opt_dict_any_any["three"] == dst.opt_dict_any_any["three"] # type: ignore[index] + and src.opt_dict_any_any[4]() == dst.opt_dict_any_any[4]() # type: ignore[index] + ) + assert ( + src.opt_dict_str_any != dst.opt_dict_str_any + and len(src.opt_dict_str_any) == len(dst.opt_dict_str_any) # type: ignore[arg-type] + and src.opt_dict_str_any["1"] == dst.opt_dict_str_any["1"] # type: ignore[index] + and src.opt_dict_str_any["2.0"] == dst.opt_dict_str_any["2.0"] # type: ignore[index] + and src.opt_dict_str_any["three"] == dst.opt_dict_str_any["three"] # type: ignore[index] + and src.opt_dict_str_any["4"]() == dst.opt_dict_str_any["4"]() # type: ignore[index] + ) + assert ( + src.opt_dict_any_str != dst.opt_dict_any_str + and len(src.opt_dict_any_str) == len(dst.opt_dict_any_str) # type: ignore[arg-type] + and src.opt_dict_any_str[1] == dst.opt_dict_any_str[1] # type: ignore[index] + and src.opt_dict_any_str[2.0] == dst.opt_dict_any_str[2.0] # type: ignore[index] + and src.opt_dict_any_str["three"] == dst.opt_dict_any_str["three"] # type: ignore[index] + and src.opt_dict_any_str[4] == dst.opt_dict_any_str[4] # type: ignore[index] + ) + assert ( + src.opt_dict_str_list_int != dst.opt_dict_str_list_int + and len(src.opt_dict_str_list_int) == len(dst.opt_dict_str_list_int) # type: ignore[arg-type] + and tuple(src.opt_dict_str_list_int["1"]) == tuple(dst.opt_dict_str_list_int["1"]) # type: ignore[index] + and tuple(src.opt_dict_str_list_int["2"]) == tuple(dst.opt_dict_str_list_int["2"]) # type: ignore[index] + )