Skip to content

Commit

Permalink
feat(dataclass): Support copy and deepcopy (#4)
Browse files Browse the repository at this point in the history
This PR introduces support for Python's native `__copy__` and `__deepcopy__` methods for all MLC dataclasses. Both are implemented on top of the field visitor and the topological visitor.
  • Loading branch information
potatomashed authored Jan 8, 2025
1 parent 73bbe0f commit 72fa855
Show file tree
Hide file tree
Showing 14 changed files with 513 additions and 54 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ jobs:
with:
python-version: ${{ env.MLC_PYTHON_VERSION }}
- uses: pre-commit/[email protected]
- uses: ytanikin/[email protected]
with:
task_types: '["feat", "fix", "ci", "chore", "test"]'
add_label: 'false'
windows:
name: Windows
runs-on: windows-latest
Expand Down
3 changes: 3 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# See https://pre-commit.com for more information
# See https://pre-commit.com/hooks.html for more hooks
default_install_hook_types:
- pre-commit
- commit-msg
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
Expand Down
71 changes: 38 additions & 33 deletions cpp/json.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,26 +31,30 @@ inline mlc::Str Serialize(Any any) {
using TObj2Idx = std::unordered_map<Object *, int32_t>;
using TJsonTypeIndex = decltype(get_json_type_index);
struct Emitter {
MLC_INLINE void operator()(MLCTypeField *, const Any *any) { EmitAny(any); }
// clang-format off
MLC_INLINE void operator()(MLCTypeField *, const Any *any) { EmitAny(any); }
MLC_INLINE void operator()(MLCTypeField *, ObjectRef *obj) { if (Object *v = obj->get()) EmitObject(v); else EmitNil(); }
MLC_INLINE void operator()(MLCTypeField *, Optional<ObjectRef> *opt) { if (Object *v = opt->get()) EmitObject(v); else EmitNil(); }
MLC_INLINE void operator()(MLCTypeField *, Optional<int64_t> *opt) { if (const int64_t *v = opt->get()) EmitInt(*v); else EmitNil(); }
MLC_INLINE void operator()(MLCTypeField *, Optional<double> *opt) { if (const double *v = opt->get()) EmitFloat(*v); else EmitNil(); }
MLC_INLINE void operator()(MLCTypeField *, Optional<DLDevice> *opt) { if (const DLDevice *v = opt->get()) EmitDevice(*v); else EmitNil(); }
MLC_INLINE void operator()(MLCTypeField *, Optional<DLDataType> *opt) { if (const DLDataType *v = opt->get()) EmitDType(*v); else EmitNil(); }
MLC_INLINE void operator()(MLCTypeField *, int8_t *v) { EmitInt(static_cast<int64_t>(*v)); }
MLC_INLINE void operator()(MLCTypeField *, int16_t *v) { EmitInt(static_cast<int64_t>(*v)); }
MLC_INLINE void operator()(MLCTypeField *, int32_t *v) { EmitInt(static_cast<int64_t>(*v)); }
MLC_INLINE void operator()(MLCTypeField *, int64_t *v) { EmitInt(static_cast<int64_t>(*v)); }
MLC_INLINE void operator()(MLCTypeField *, float *v) { EmitFloat(static_cast<double>(*v)); }
MLC_INLINE void operator()(MLCTypeField *, double *v) { EmitFloat(static_cast<double>(*v)); }
MLC_INLINE void operator()(MLCTypeField *, DLDataType *v) { EmitDType(*v); }
MLC_INLINE void operator()(MLCTypeField *, DLDevice *v) { EmitDevice(*v); }
MLC_INLINE void operator()(MLCTypeField *, Optional<void *> *) { MLC_THROW(TypeError) << "Unserializable type: void *"; }
MLC_INLINE void operator()(MLCTypeField *, void **) { MLC_THROW(TypeError) << "Unserializable type: void *"; }
MLC_INLINE void operator()(MLCTypeField *, const char **) { MLC_THROW(TypeError) << "Unserializable type: const char *"; }
MLC_INLINE void operator()(MLCTypeField *, ObjectRef *obj) { if (Object *v = obj->get()) EmitObject(v); else EmitNil(); }
MLC_INLINE void operator()(MLCTypeField *, Optional<ObjectRef> *opt) { if (Object *v = opt->get()) EmitObject(v); else EmitNil(); }
MLC_INLINE void operator()(MLCTypeField *, Optional<int64_t> *opt) { if (const int64_t *v = opt->get()) EmitInt(*v); else EmitNil(); }
MLC_INLINE void operator()(MLCTypeField *, Optional<double> *opt) { if (const double *v = opt->get()) EmitFloat(*v); else EmitNil(); }
MLC_INLINE void operator()(MLCTypeField *, Optional<DLDevice> *opt) { if (const DLDevice *v = opt->get()) EmitDevice(*v); else EmitNil(); }
MLC_INLINE void operator()(MLCTypeField *, Optional<DLDataType> *opt) { if (const DLDataType *v = opt->get()) EmitDType(*v); else EmitNil(); }
// clang-format on
MLC_INLINE void operator()(MLCTypeField *, int8_t *v) { EmitInt(static_cast<int64_t>(*v)); }
MLC_INLINE void operator()(MLCTypeField *, int16_t *v) { EmitInt(static_cast<int64_t>(*v)); }
MLC_INLINE void operator()(MLCTypeField *, int32_t *v) { EmitInt(static_cast<int64_t>(*v)); }
MLC_INLINE void operator()(MLCTypeField *, int64_t *v) { EmitInt(static_cast<int64_t>(*v)); }
MLC_INLINE void operator()(MLCTypeField *, float *v) { EmitFloat(static_cast<double>(*v)); }
MLC_INLINE void operator()(MLCTypeField *, double *v) { EmitFloat(static_cast<double>(*v)); }
MLC_INLINE void operator()(MLCTypeField *, DLDataType *v) { EmitDType(*v); }
MLC_INLINE void operator()(MLCTypeField *, DLDevice *v) { EmitDevice(*v); }
MLC_INLINE void operator()(MLCTypeField *, Optional<void *> *) {
MLC_THROW(TypeError) << "Unserializable type: void *";
}
MLC_INLINE void operator()(MLCTypeField *, void **) { MLC_THROW(TypeError) << "Unserializable type: void *"; }
MLC_INLINE void operator()(MLCTypeField *, const char **) {
MLC_THROW(TypeError) << "Unserializable type: const char *";
}
inline void EmitNil() { (*os) << ", null"; }
inline void EmitFloat(double v) { (*os) << ", " << std::fixed << std::setprecision(19) << v; }
inline void EmitInt(int64_t v) {
Expand Down Expand Up @@ -98,10 +102,17 @@ inline mlc::Str Serialize(Any any) {
const TObj2Idx *obj2index;
};

std::unordered_map<Object *, int32_t> topo_indices;
std::ostringstream os;
auto on_visit = [get_json_type_index = &get_json_type_index, os = &os, is_first_object = true](
Object *object, MLCTypeInfo *type_info, const TObj2Idx &obj2index) mutable -> void {
Emitter emitter{os, get_json_type_index, &obj2index};
auto on_visit = [&topo_indices, get_json_type_index = &get_json_type_index, os = &os,
is_first_object = true](Object *object, MLCTypeInfo *type_info) mutable -> void {
int32_t &topo_index = topo_indices[object];
if (topo_index == 0) {
topo_index = static_cast<int32_t>(topo_indices.size()) - 1;
} else {
MLC_THROW(InternalError) << "This should never happen: object already visited";
}
Emitter emitter{os, get_json_type_index, &topo_indices};
if (is_first_object) {
is_first_object = false;
} else {
Expand Down Expand Up @@ -163,29 +174,23 @@ inline mlc::Str Serialize(Any any) {
}

inline Any Deserialize(const char *json_str, int64_t json_str_len) {
MLCVTableHandle init_vtable;
MLCVTableGetGlobal(nullptr, "__init__", &init_vtable);
MLCVTableHandle init_table = ::mlc::base::LibState::init;
// Step 0. Parse JSON string
UDict json_obj = JSONLoads(json_str, json_str_len);
// Step 1. type_key => constructors
UList type_keys = json_obj->at("type_keys");
std::vector<Func> constructors;
std::vector<FuncObj *> constructors;
constructors.reserve(type_keys.size());
for (Str type_key : type_keys) {
Any init_func;
int32_t type_index = ::mlc::base::TypeKey2TypeIndex(type_key->data());
MLCVTableGetFunc(init_vtable, type_index, false, &init_func);
if (!::mlc::base::IsTypeIndexNone(init_func.type_index)) {
constructors.push_back(init_func.operator Func());
} else {
MLC_THROW(InternalError) << "Method `__init__` is not defined for type " << type_key;
}
FuncObj *func = ::mlc::base::LibState::VTableGetFunc(init_table, type_index, "__init__");
constructors.push_back(func);
}
auto invoke_init = [&constructors](UList args) {
int32_t json_type_index = args[0];
Any ret;
::mlc::base::FuncCall(constructors.at(json_type_index).get(), static_cast<int32_t>(args.size()) - 1,
args->data() + 1, &ret);
::mlc::base::FuncCall(constructors.at(json_type_index), static_cast<int32_t>(args.size()) - 1, args->data() + 1,
&ret);
return ret;
};
// Step 2. Translate JSON object to objects
Expand Down
131 changes: 131 additions & 0 deletions cpp/structure.cc
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include "mlc/core/error.h"
#include <algorithm>
#include <cmath>
#include <cstdint>
Expand Down Expand Up @@ -532,11 +533,141 @@ inline uint64_t StructuralHash(Object *obj) {
#undef MLC_CORE_HASH_S_POD
#undef MLC_CORE_HASH_S_ANY

// Produces a shallow copy of `source`.
//
// - POD values are returned unchanged.
// - A UList / UDict is rebuilt as a new container from the source's iterator range.
// - Str, Error and Func objects are returned as-is (no new object is created).
// - Any other object has its fields collected with `VisitFields` and handed to the
//   `__init__` function found in the `__init__` vtable for its type index.
inline Any CopyShallow(AnyView source) {
  const int32_t src_type_index = source.type_index;
  if (::mlc::base::IsTypeIndexPOD(src_type_index)) {
    return source;
  }
  if (UListObj *src_list = source.TryCast<UListObj>()) {
    return UList(src_list->begin(), src_list->end());
  }
  if (UDictObj *src_dict = source.TryCast<UDictObj>()) {
    return UDict(src_dict->begin(), src_dict->end());
  }
  if (source.IsInstance<StrObj>() || source.IsInstance<ErrorObj>() || source.IsInstance<FuncObj>()) {
    return source;
  }
  // Field visitor: appends a view of every visited field, in visiting order, to `out`.
  struct FieldCollector {
    MLC_INLINE void operator()(MLCTypeField *, const Any *field) { out->push_back(AnyView(*field)); }
    MLC_INLINE void operator()(MLCTypeField *, ObjectRef *field) { out->push_back(AnyView(*field)); }
    MLC_INLINE void operator()(MLCTypeField *, Optional<ObjectRef> *field) { out->push_back(AnyView(*field)); }
    MLC_INLINE void operator()(MLCTypeField *, Optional<int64_t> *field) { out->push_back(AnyView(*field)); }
    MLC_INLINE void operator()(MLCTypeField *, Optional<double> *field) { out->push_back(AnyView(*field)); }
    MLC_INLINE void operator()(MLCTypeField *, Optional<DLDevice> *field) { out->push_back(AnyView(*field)); }
    MLC_INLINE void operator()(MLCTypeField *, Optional<DLDataType> *field) { out->push_back(AnyView(*field)); }
    MLC_INLINE void operator()(MLCTypeField *, int8_t *field) { out->push_back(AnyView(*field)); }
    MLC_INLINE void operator()(MLCTypeField *, int16_t *field) { out->push_back(AnyView(*field)); }
    MLC_INLINE void operator()(MLCTypeField *, int32_t *field) { out->push_back(AnyView(*field)); }
    MLC_INLINE void operator()(MLCTypeField *, int64_t *field) { out->push_back(AnyView(*field)); }
    MLC_INLINE void operator()(MLCTypeField *, float *field) { out->push_back(AnyView(*field)); }
    MLC_INLINE void operator()(MLCTypeField *, double *field) { out->push_back(AnyView(*field)); }
    MLC_INLINE void operator()(MLCTypeField *, DLDataType *field) { out->push_back(AnyView(*field)); }
    MLC_INLINE void operator()(MLCTypeField *, DLDevice *field) { out->push_back(AnyView(*field)); }
    MLC_INLINE void operator()(MLCTypeField *, Optional<void *> *field) { out->push_back(AnyView(*field)); }
    MLC_INLINE void operator()(MLCTypeField *, void **field) { out->push_back(AnyView(*field)); }
    MLC_INLINE void operator()(MLCTypeField *, const char **field) { out->push_back(AnyView(*field)); }
    std::vector<AnyView> *out;
  };
  // Look up the constructor before visiting any field, so a missing `__init__`
  // is reported first.
  FuncObj *ctor = ::mlc::base::LibState::VTableGetFunc(::mlc::base::LibState::init, src_type_index, "__init__");
  MLCTypeInfo *src_type_info = ::mlc::base::TypeIndex2TypeInfo(src_type_index);
  std::vector<AnyView> ctor_args;
  VisitFields(source.operator Object *(), src_type_info, FieldCollector{&ctor_args});
  Any copied;
  ::mlc::base::FuncCall(ctor, static_cast<int32_t>(ctor_args.size()), ctor_args.data(), &copied);
  return copied;
}

// Produces a deep copy of the object graph reachable from `source`.
//
// POD values are returned unchanged. Otherwise the graph is walked with
// `TopoVisit`, and for every visited object a copy is created and recorded in a
// memo (`orig2copy`) keyed by the original object's address:
//   - UList / UDict: rebuilt from their elements, with object elements replaced
//     by their memoized copies;
//   - Str / Error / Func: the original object is reused, not duplicated;
//   - anything else: fields are collected via `VisitFields` (object-valued
//     fields replaced by their memoized copies) and passed to the type's
//     `__init__` function from the `__init__` vtable.
// The result is the memo entry for `source` itself.
//
// The algorithm relies on `TopoVisit` invoking the callback for an object only
// after every object it references has been visited; if a referenced object is
// missing from the memo, an InternalError is thrown.
inline Any CopyDeep(AnyView source) {
  if (::mlc::base::IsTypeIndexPOD(source.type_index)) {
    return source;
  }
  // Field visitor that appends one constructor argument per visited field.
  // Object-valued fields are substituted with their already-created copies.
  struct Copier {
    MLC_INLINE void operator()(MLCTypeField *, const Any *any) { HandleAny(any); }
    MLC_INLINE void operator()(MLCTypeField *, ObjectRef *ref) {
      if (const Object *obj = ref->get()) {
        HandleObject(obj);
      } else {
        // Null reference: pass a default-constructed AnyView.
        fields->push_back(AnyView());
      }
    }
    MLC_INLINE void operator()(MLCTypeField *, Optional<ObjectRef> *opt) {
      if (const Object *obj = opt->get()) {
        HandleObject(obj);
      } else {
        // Empty optional: pass a default-constructed AnyView.
        fields->push_back(AnyView());
      }
    }
    MLC_INLINE void operator()(MLCTypeField *, Optional<int64_t> *opt) { fields->push_back(AnyView(*opt)); }
    MLC_INLINE void operator()(MLCTypeField *, Optional<double> *opt) { fields->push_back(AnyView(*opt)); }
    MLC_INLINE void operator()(MLCTypeField *, Optional<DLDevice> *opt) { fields->push_back(AnyView(*opt)); }
    MLC_INLINE void operator()(MLCTypeField *, Optional<DLDataType> *opt) { fields->push_back(AnyView(*opt)); }
    MLC_INLINE void operator()(MLCTypeField *, int8_t *v) { fields->push_back(AnyView(*v)); }
    MLC_INLINE void operator()(MLCTypeField *, int16_t *v) { fields->push_back(AnyView(*v)); }
    MLC_INLINE void operator()(MLCTypeField *, int32_t *v) { fields->push_back(AnyView(*v)); }
    MLC_INLINE void operator()(MLCTypeField *, int64_t *v) { fields->push_back(AnyView(*v)); }
    MLC_INLINE void operator()(MLCTypeField *, float *v) { fields->push_back(AnyView(*v)); }
    MLC_INLINE void operator()(MLCTypeField *, double *v) { fields->push_back(AnyView(*v)); }
    MLC_INLINE void operator()(MLCTypeField *, DLDataType *v) { fields->push_back(AnyView(*v)); }
    MLC_INLINE void operator()(MLCTypeField *, DLDevice *v) { fields->push_back(AnyView(*v)); }
    MLC_INLINE void operator()(MLCTypeField *, Optional<void *> *v) { fields->push_back(AnyView(*v)); }
    MLC_INLINE void operator()(MLCTypeField *, void **v) { fields->push_back(AnyView(*v)); }
    MLC_INLINE void operator()(MLCTypeField *, const char **v) { fields->push_back(AnyView(*v)); }

    // Appends the memoized copy of `obj`; throws if `obj` has not been copied yet.
    void HandleObject(const Object *obj) {
      if (auto it = orig2copy->find(obj); it != orig2copy->end()) {
        fields->push_back(AnyView(it->second));
      } else {
        MLC_THROW(InternalError) << "InternalError: object doesn't exist in the memo: " << AnyView(obj);
      }
    }

    // Dispatches on the dynamic content of `any`: objects go through the memo,
    // everything else is forwarded as a view of the original value.
    void HandleAny(const Any *any) {
      if (const Object *obj = any->TryCast<Object>()) {
        HandleObject(obj);
      } else {
        fields->push_back(AnyView(*any));
      }
    }

    std::unordered_map<const Object *, ObjectRef> *orig2copy;
    std::vector<AnyView> *fields;
  };
  std::unordered_map<const Object *, ObjectRef> orig2copy;
  // Scratch argument buffer, cleared and reused for every visited object.
  std::vector<AnyView> fields;
  TopoVisit(source.operator Object *(), nullptr, [&](Object *object, MLCTypeInfo *type_info) mutable -> void {
    Any ret;
    if (UListObj *list = object->TryCast<UListObj>()) {
      fields.clear();
      fields.reserve(list->size());
      for (Any &e : *list) {
        Copier{&orig2copy, &fields}.HandleAny(&e);
      }
      UList::FromAnyTuple(static_cast<int32_t>(fields.size()), fields.data(), &ret);
    } else if (UDictObj *dict = object->TryCast<UDictObj>()) {
      fields.clear();
      // Bind by const reference: a by-value structured binding would copy the
      // key and value `Any` of every entry only to read them.
      for (const auto &[key, value] : *dict) {
        // Arguments are laid out as alternating key, value entries.
        Copier{&orig2copy, &fields}.HandleAny(&key);
        Copier{&orig2copy, &fields}.HandleAny(&value);
      }
      UDict::FromAnyTuple(static_cast<int32_t>(fields.size()), fields.data(), &ret);
    } else if (object->IsInstance<StrObj>() || object->IsInstance<ErrorObj>() || object->IsInstance<FuncObj>()) {
      // These objects are reused by the copy rather than duplicated.
      ret = object;
    } else {
      fields.clear();
      VisitFields(object, type_info, Copier{&orig2copy, &fields});
      FuncObj *func =
          ::mlc::base::LibState::VTableGetFunc(::mlc::base::LibState::init, type_info->type_index, "__init__");
      ::mlc::base::FuncCall(func, static_cast<int32_t>(fields.size()), fields.data(), &ret);
    }
    // Record the copy so that objects referencing `object` can pick it up.
    orig2copy[object] = ret.operator ObjectRef();
  });
  return orig2copy.at(source.operator Object *());
}

// Register the structural utilities under their "mlc.core.*" names.
MLC_REGISTER_FUNC("mlc.core.StructuralEqual").set_body(::mlc::core::StructuralEqual);
// StructuralHash yields a uint64_t; the registered wrapper casts it to int64_t
// before returning it.
MLC_REGISTER_FUNC("mlc.core.StructuralHash").set_body([](::mlc::Object *obj) -> int64_t {
uint64_t ret = ::mlc::core::StructuralHash(obj);
return static_cast<int64_t>(ret);
});
// Copy entry points; per the commit description these back Python's
// `__copy__` and `__deepcopy__`.
MLC_REGISTER_FUNC("mlc.core.CopyShallow").set_body(::mlc::core::CopyShallow);
MLC_REGISTER_FUNC("mlc.core.CopyDeep").set_body(::mlc::core::CopyDeep);
} // namespace
} // namespace core
} // namespace mlc
2 changes: 1 addition & 1 deletion include/mlc/base/all.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ template <typename T> MLC_INLINE AnyView::AnyView(Ref<T> &&src) : AnyView(static
// `src` is not reset here because `AnyView` does not take ownership of the object
}

template <typename T> MLC_INLINE AnyView::AnyView(const Optional<T> &src) {
template <typename T> MLC_INLINE AnyView::AnyView(const Optional<T> &src) : MLCAny() {
if (const auto *value = src.get()) {
if constexpr (::mlc::base::IsPOD<T>) {
using TPOD = T;
Expand Down
7 changes: 6 additions & 1 deletion include/mlc/base/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,11 @@ struct LibState {
DecRef(func.v.v_obj);
}
FuncObj *ret = reinterpret_cast<FuncObj *>(func.v.v_obj);
if (func.type_index != kMLCFunc) {
if (func.type_index == kMLCNone) {
MLC_THROW(TypeError) << "Function `" << vtable_name
<< "` for type: " << ::mlc::base::TypeIndex2TypeKey(type_index)
<< " is not defined in the vtable";
} else if (func.type_index != kMLCFunc) {
MLC_THROW(TypeError) << "Function `" << vtable_name
<< "` for type: " << ::mlc::base::TypeIndex2TypeKey(type_index)
<< " is not callable. Its type is " << ::mlc::base::TypeIndex2TypeKey(func.type_index);
Expand All @@ -401,6 +405,7 @@ struct LibState {
static MLC_SYMBOL_HIDE inline MLCVTableHandle cxx_str = VTableGetGlobal("__cxx_str__");
static MLC_SYMBOL_HIDE inline MLCVTableHandle str = VTableGetGlobal("__str__");
static MLC_SYMBOL_HIDE inline MLCVTableHandle ir_print = VTableGetGlobal("__ir_print__");
static MLC_SYMBOL_HIDE inline MLCVTableHandle init = VTableGetGlobal("__init__");
};

} // namespace base
Expand Down
5 changes: 4 additions & 1 deletion include/mlc/core/dict.h
Original file line number Diff line number Diff line change
Expand Up @@ -138,11 +138,14 @@ struct UDict : public ObjectRef {
MLC_INLINE const_iterator end() const { return get()->end(); }
MLC_INLINE const_reverse_iterator rbegin() const { return get()->rbegin(); }
MLC_INLINE const_reverse_iterator rend() const { return get()->rend(); }
// Builds a dict from a packed argument array by forwarding to
// DictBase::Accessor<UDictObj>::New; the result is written to `*ret`.
// NOTE(review): `args` is presumably laid out as alternating key, value
// entries (`num_args` counting both) -- confirm against Accessor::New.
MLC_INLINE static void FromAnyTuple(int32_t num_args, const AnyView *args, Any *ret) {
::mlc::core::DictBase::Accessor<UDictObj>::New(num_args, args, ret);
}
MLC_DEF_OBJ_REF(UDict, UDictObj, ObjectRef)
.FieldReadOnly("capacity", &MLCDict::capacity)
.FieldReadOnly("size", &MLCDict::size)
.FieldReadOnly("data", &MLCDict::data)
.StaticFn("__init__", ::mlc::core::DictBase::Accessor<UDictObj>::New)
.StaticFn("__init__", FromAnyTuple)
.MemFn("__str__", &UDictObj::__str__)
.MemFn("__getitem__", ::mlc::core::DictBase::Accessor<UDictObj>::GetItem)
.MemFn("__iter_get_key__", ::mlc::core::DictBase::Accessor<UDictObj>::GetKey)
Expand Down
17 changes: 4 additions & 13 deletions include/mlc/core/field_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -164,9 +164,7 @@ template <typename Visitor> inline void VisitStructure(Object *root, MLCTypeInfo
}

inline void TopoVisit(Object *root, std::function<void(Object *object, MLCTypeInfo *type_info)> pre_visit,
std::function<void(Object *object, MLCTypeInfo *type_info,
const std::unordered_map<Object *, int32_t> &topo_indices)>
on_visit) {
std::function<void(Object *object, MLCTypeInfo *type_info)> on_visit) {
struct TopoInfo {
Object *obj;
MLCTypeInfo *type_info;
Expand Down Expand Up @@ -271,20 +269,13 @@ inline void TopoVisit(Object *root, std::function<void(Object *object, MLCTypeIn
}
}
// Step 3. Traverse the graph by topological order
std::unordered_map<Object *, int32_t> topo_indices;
size_t num_objects = 0;
for (; !stack.empty(); ++num_objects) {
TopoInfo *current = stack.back();
stack.pop_back();
// Step 3.1. Lable object index
int32_t &topo_index = topo_indices[current->obj];
if (topo_index != 0) {
MLC_THROW(InternalError) << "This should never happen: object already visited";
}
topo_index = static_cast<int32_t>(num_objects);
// Step 3.2. Visit object
on_visit(current->obj, current->type_info, topo_indices);
// Step 3.3. Decrease the dependency count of topo_parents
// Step 3.1. Visit object
on_visit(current->obj, current->type_info);
// Step 3.2. Decrease the dependency count of topo_parents
for (TopoInfo *parent : current->topo_parents) {
if (--parent->topo_deps == 0) {
stack.push_back(parent);
Expand Down
Loading

0 comments on commit 72fa855

Please sign in to comment.