Skip to content

Commit

Permalink
Introduce Structural Equality Checks
Browse files Browse the repository at this point in the history
  • Loading branch information
potatomashed committed Nov 30, 2024
1 parent 574d5ae commit b8c78bc
Show file tree
Hide file tree
Showing 36 changed files with 1,907 additions and 752 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ cmake_minimum_required(VERSION 3.15)

project(
mlc
VERSION 0.0.9
VERSION 0.0.10
DESCRIPTION "MLC-Python"
LANGUAGES C CXX
)
Expand Down
132 changes: 110 additions & 22 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,44 +4,132 @@
MLC-Python
</h1>

* [:key: Key features](#keykey-features)
* [:inbox_tray: Installation](#inbox_trayinstallation)
+ [:package: Install From PyPI](#packageinstall-from-pypi)
+ [:gear: Build from Source](#gearbuild-from-source)
+ [:ferris_wheel: Create MLC-Python Wheels](#ferris_wheel-create-mlc-python-wheels)
* [:inbox_tray: Installation](#inbox_tray-installation)
* [:key: Key Features](#key-key-features)
+ [:building_construction: MLC Dataclass](#building_construction-mlc-dataclass)
+ [:dart: Structure-Aware Tooling](#dart-structure-aware-tooling)
+ [:snake: Text Formats in Python AST](#snake-text-formats-in-python-ast)
+ [:zap: Zero-Copy Interoperability with C++ Plugins](#zap-zero-copy-interoperability-with-c-plugins)
* [:fuelpump: Development](#fuelpump-development)
+ [:gear: Editable Build](#gear-editable-build)
+ [:ferris_wheel: Create Wheels](#ferris_wheel-create-wheels)

MLC is a Python-first toolkit that makes it more ergonomic to build AI compilers, runtimes, and compound AI systems. It provides Pythonic dataclasses with rich tooling infra, which includes:

- Structure-aware equality and hashing methods;
- Serialization in JSON / pickle;
- Text format printing and parsing in Python syntax.
MLC is a Python-first toolkit that makes it more ergonomic to build AI compilers, runtimes, and compound AI systems with Pythonic dataclass, rich tooling infra and zero-copy interoperability with C++ plugins.

Additionally, MLC language bindings support:
## :inbox_tray: Installation

- Zero-copy bidirectional functioning calling for all MLC dataclasses.
```bash
pip install -U mlc-python
```

## :key: Key features
## :key: Key Features

TBD
### :building_construction: MLC Dataclass

## :inbox_tray: Installation
MLC dataclass is similar to Python’s native dataclass:

### :package: Install From PyPI
```python
import mlc.dataclasses as mlcd

```bash
pip install -U mlc-python
@mlcd.py_class("demo.MyClass")
class MyClass(mlcd.PyClass):
a: int
b: str
c: float | None

instance = MyClass(12, "test", c=None)
```

**Type safety**. MLC dataclass checks type strictly in Cython and C++.

```python
>>> instance.c = 10; print(instance)
demo.MyClass(a=12, b='test', c=10.0)

>>> instance.c = "wrong type"
TypeError: must be real number, not str

>>> instance.non_exist = 1
AttributeError: 'MyClass' object has no attribute 'non_exist' and no __dict__ for setting new attributes
```

**Serialization**. MLC dataclasses are picklable and JSON-serializable.

```python
>>> MyClass.from_json(instance.json())
demo.MyClass(a=12, b='test', c=None)

>>> import pickle; pickle.loads(pickle.dumps(instance))
demo.MyClass(a=12, b='test', c=None)
```

### :gear: Build from Source
### :dart: Structure-Aware Tooling

An extra `structure` field are used to specify a dataclass's structure, indicating def site and scoping in an IR.

```python
import mlc.dataclasses as mlcd

@mlcd.py_class
class Expr(mlcd.PyClass):
def __add__(self, other):
return Add(a=self, b=other)

@mlcd.py_class(structure="nobind")
class Add(Expr):
a: Expr
b: Expr

@mlcd.py_class(structure="var")
class Var(Expr):
name: str = mlcd.field(structure=None) # excludes `name` from defined structure

@mlcd.py_class(structure="bind")
class Let(Expr):
rhs: Expr
lhs: Var = mlcd.field(structure="bind") # `Let.lhs` is the def-site
body: Expr
```

**Structural equality**. Method eq_s is ready to use to compare the structural equality (alpha equivalence) of two IRs.

```python
"""
L1: let z = x + y; z
L2: let x = y + z; x
L3: let z = x + x; z
"""
>>> x, y, z = Var("x"), Var("y"), Var("z")
>>> L1 = Let(rhs=x + y, lhs=z, body=z)
>>> L2 = Let(rhs=y + z, lhs=x, body=x)
>>> L3 = Let(rhs=x + x, lhs=z, body=z)
>>> L1.eq_s(L2)
True
>>> L1.eq_s(L3, assert_mode=True)
ValueError: Structural equality check failed at {root}.rhs.b: Inconsistent binding. RHS has been bound to a different node while LHS is not bound
```

**Structural hashing**. TBD

### :snake: Text Formats in Python AST

TBD

### :zap: Zero-Copy Interoperability with C++ Plugins

TBD

## :fuelpump: Development

### :gear: Editable Build

```bash
python -m venv .venv
source .venv/bin/activate
python -m pip install --verbose --editable ".[dev]"
pip install --verbose --editable ".[dev]"
pre-commit install
```

### :ferris_wheel: Create MLC-Python Wheels
### :ferris_wheel: Create Wheels

This project uses `cibuildwheel` to build cross-platform wheels. See `.github/workflows/wheels.ym` for more details.

Expand Down
22 changes: 14 additions & 8 deletions cpp/c_api.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#include "./registry.h"
#include "mlc/core/str.h"
#include <mlc/all.h>

namespace mlc {
namespace registry {
Expand All @@ -20,22 +20,23 @@ using ::mlc::registry::TypeTable;
namespace {
thread_local Any last_error;
MLC_REGISTER_FUNC("mlc.ffi.LoadDSO").set_body([](std::string name) { TypeTable::Get(nullptr)->LoadDSO(name); });
MLC_REGISTER_FUNC("mlc.core.JSONParse").set_body([](AnyView json_str) {
MLC_REGISTER_FUNC("mlc.core.JSONLoads").set_body([](AnyView json_str) {
if (json_str.type_index == kMLCRawStr) {
return ::mlc::core::ParseJSON(json_str.operator const char *());
return ::mlc::core::JSONLoads(json_str.operator const char *());
} else {
::mlc::Str str = json_str;
return ::mlc::core::ParseJSON(str);
return ::mlc::core::JSONLoads(str);
}
});
MLC_REGISTER_FUNC("mlc.core.JSONSerialize").set_body(::mlc::core::Serialize);
MLC_REGISTER_FUNC("mlc.core.JSONSerialize").set_body(::mlc::core::Serialize); // TODO: `AnyView` as function argument
MLC_REGISTER_FUNC("mlc.core.JSONDeserialize").set_body([](AnyView json_str) {
if (json_str.type_index == kMLCRawStr) {
return ::mlc::core::Deserialize(json_str.operator const char *());
} else {
return ::mlc::core::Deserialize(json_str.operator ::mlc::Str());
}
});
MLC_REGISTER_FUNC("mlc.core.StructuralEqual").set_body(::mlc::core::StructuralEqual);
} // namespace

MLC_API MLCAny MLCGetLastError() {
Expand Down Expand Up @@ -64,9 +65,14 @@ MLC_API int32_t MLCTypeKey2Info(MLCTypeTableHandle _self, const char *type_key,
}

MLC_API int32_t MLCTypeDefReflection(MLCTypeTableHandle self, int32_t type_index, int64_t num_fields,
MLCTypeField *fields, int64_t num_methods, MLCTypeMethod *methods) {
MLCTypeField *fields, int64_t num_methods, MLCTypeMethod *methods,
int32_t structure_kind, int64_t num_sub_structures, int32_t *sub_structure_indices,
int32_t *sub_structure_kinds) {
MLC_SAFE_CALL_BEGIN();
TypeTable::Get(self)->TypeDefReflection(type_index, num_fields, fields, num_methods, methods);
auto *type_info = TypeTable::Get(self)->GetTypeInfoWrapper(type_index);
type_info->SetFields(num_fields, fields);
type_info->SetMethods(num_methods, methods);
type_info->SetStructure(structure_kind, num_sub_structures, sub_structure_indices, sub_structure_kinds);
MLC_SAFE_CALL_END(&last_error);
}

Expand Down Expand Up @@ -168,7 +174,7 @@ MLC_API int32_t MLCExtObjCreate(int32_t num_bytes, int32_t type_index, MLCAny *r

MLC_API int32_t _MLCExtObjDeleteImpl(void *objptr) {
MLC_SAFE_CALL_BEGIN();
::mlc::core::DeleteExternObject(static_cast<MLCAny *>(objptr));
::mlc::core::DeleteExternObject(static_cast<::mlc::Object *>(objptr));
MLC_SAFE_CALL_END(&last_error);
}

Expand Down
44 changes: 39 additions & 5 deletions cpp/registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,11 @@ struct TypeInfoWrapper {
void Reset();
void ResetFields();
void ResetMethods();
void ResetStructure();
void SetFields(int64_t new_num_fields, MLCTypeField *fields);
void SetMethods(int64_t new_num_methods, MLCTypeMethod *methods);
void SetStructure(int32_t structure_kind, int64_t num_sub_structures, int32_t *sub_structure_indices,
int32_t *sub_structure_kinds);
~TypeInfoWrapper() { this->Reset(); }
};

Expand Down Expand Up @@ -210,6 +213,9 @@ struct TypeTable {
}
info->fields = nullptr;
info->methods = nullptr;
info->structure_kind = 0;
info->sub_structure_indices = nullptr;
info->sub_structure_kinds = nullptr;
wrapper->table = this;
return info;
}
Expand All @@ -231,9 +237,7 @@ struct TypeTable {
this->NewObjPtr(&it->second, func);
}

void TypeDefReflection(int32_t type_index, //
int64_t num_fields, MLCTypeField *fields, //
int64_t num_methods, MLCTypeMethod *methods) {
TypeInfoWrapper *GetTypeInfoWrapper(int32_t type_index) {
TypeInfoWrapper *wrapper = nullptr;
try {
wrapper = this->type_table.at(type_index).get();
Expand All @@ -242,8 +246,7 @@ struct TypeTable {
if (wrapper == nullptr || wrapper->table != this) {
MLC_THROW(KeyError) << "Type index `" << type_index << "` not registered";
}
wrapper->SetFields(num_fields, fields);
wrapper->SetMethods(num_methods, methods);
return wrapper;
}

void LoadDSO(std::string name) {
Expand Down Expand Up @@ -352,6 +355,15 @@ inline void TypeInfoWrapper::ResetMethods() {
}
}

inline void TypeInfoWrapper::ResetStructure() {
if (this->info.sub_structure_indices) {
this->table->DelArray(this->info.sub_structure_indices);
}
if (this->info.sub_structure_kinds) {
this->table->DelArray(this->info.sub_structure_kinds);
}
}

inline void TypeInfoWrapper::SetFields(int64_t new_num_fields, MLCTypeField *fields) {
this->ResetFields();
this->num_fields = new_num_fields;
Expand All @@ -360,6 +372,9 @@ inline void TypeInfoWrapper::SetFields(int64_t new_num_fields, MLCTypeField *fie
dst[i] = fields[i];
dst[i].name = this->table->NewArray(fields[i].name);
this->table->NewObjPtr(&dst[i].ty, dst[i].ty);
if (dst[i].index != i) {
MLC_THROW(ValueError) << "Field index mismatch: " << i << " vs " << dst[i].index;
}
}
dst[num_fields] = MLCTypeField{};
std::sort(dst, dst + num_fields, [](const MLCTypeField &a, const MLCTypeField &b) { return a.offset < b.offset; });
Expand All @@ -382,6 +397,25 @@ inline void TypeInfoWrapper::SetMethods(int64_t new_num_methods, MLCTypeMethod *
[](const MLCTypeMethod &a, const MLCTypeMethod &b) { return std::strcmp(a.name, b.name) < 0; });
}

inline void TypeInfoWrapper::SetStructure(int32_t structure_kind, int64_t num_sub_structures,
int32_t *sub_structure_indices, int32_t *sub_structure_kinds) {
this->ResetStructure();
this->info.structure_kind = structure_kind;
if (num_sub_structures > 0) {
this->info.sub_structure_indices = this->table->NewArray<int32_t>(num_sub_structures + 1);
this->info.sub_structure_kinds = this->table->NewArray<int32_t>(num_sub_structures + 1);
std::memcpy(this->info.sub_structure_indices, sub_structure_indices, num_sub_structures * sizeof(int32_t));
std::memcpy(this->info.sub_structure_kinds, sub_structure_kinds, num_sub_structures * sizeof(int32_t));
std::reverse(this->info.sub_structure_indices, this->info.sub_structure_indices + num_sub_structures);
std::reverse(this->info.sub_structure_kinds, this->info.sub_structure_kinds + num_sub_structures);
this->info.sub_structure_indices[num_sub_structures] = -1;
this->info.sub_structure_kinds[num_sub_structures] = -1;
} else {
this->info.sub_structure_indices = nullptr;
this->info.sub_structure_kinds = nullptr;
}
}

} // namespace registry
} // namespace mlc

Expand Down
4 changes: 4 additions & 0 deletions include/mlc/base/any.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ struct AnyView : public MLCAny {
}
/***** Misc *****/
bool defined() const { return this->type_index != static_cast<int32_t>(MLCTypeIndex::kMLCNone); }
const char *GetTypeKey() const { return ::mlc::base::TypeIndex2TypeInfo(this->type_index)->type_key; }
int32_t GetTypeIndex() const { return this->type_index; }
Str str() const;
friend std::ostream &operator<<(std::ostream &os, const AnyView &src);

Expand Down Expand Up @@ -76,6 +78,8 @@ struct Any : public MLCAny {
}
/***** Misc *****/
bool defined() const { return this->type_index != static_cast<int32_t>(MLCTypeIndex::kMLCNone); }
const char *GetTypeKey() const { return ::mlc::base::TypeIndex2TypeInfo(this->type_index)->type_key; }
int32_t GetTypeIndex() const { return this->type_index; }
Str str() const;
friend std::ostream &operator<<(std::ostream &os, const Any &src);

Expand Down
Loading

0 comments on commit b8c78bc

Please sign in to comment.