diff --git a/pyformlang/objects/base_epsilon.py b/pyformlang/objects/base_epsilon.py index 926c859..17c06b2 100644 --- a/pyformlang/objects/base_epsilon.py +++ b/pyformlang/objects/base_epsilon.py @@ -30,3 +30,6 @@ def __hash__(self) -> int: def __repr__(self) -> str: return "epsilon" + + def _is_equal_to(self, other: FormalObject) -> bool: + return isinstance(other, BaseEpsilon) diff --git a/pyformlang/objects/base_terminal.py b/pyformlang/objects/base_terminal.py index 3f5d48b..1b37802 100644 --- a/pyformlang/objects/base_terminal.py +++ b/pyformlang/objects/base_terminal.py @@ -1,6 +1,5 @@ """ General terminal representation """ -from typing import Any from abc import abstractmethod from .formal_object import FormalObject @@ -9,16 +8,9 @@ class BaseTerminal(FormalObject): """ General terminal representation """ - def __eq__(self, other: Any) -> bool: - if isinstance(other, BaseTerminal): - return self.value == other.value - if isinstance(other, FormalObject): - return False - return self.value == other - - def __hash__(self) -> int: - return super().__hash__() - @abstractmethod def __repr__(self): raise NotImplementedError + + def _is_equal_to(self, other: FormalObject) -> bool: + return isinstance(other, BaseTerminal) and self.value == other.value diff --git a/pyformlang/objects/cfg_objects/variable.py b/pyformlang/objects/cfg_objects/variable.py index 46a031a..85c78f5 100644 --- a/pyformlang/objects/cfg_objects/variable.py +++ b/pyformlang/objects/cfg_objects/variable.py @@ -1,6 +1,5 @@ """ A variable in a CFG """ -from typing import Any from string import ascii_uppercase from .cfg_object import CFGObject @@ -16,16 +15,6 @@ class Variable(CFGObject): The value of the variable """ - def __eq__(self, other: Any) -> bool: - if isinstance(other, Variable): - return self.value == other.value - if isinstance(other, FormalObject): - return False - return self.value == other - - def __hash__(self) -> int: - return super().__hash__() - def __repr__(self) -> str: return f"Variable({self})" @@ -34,3 +23,6 @@ def to_text(self) -> str: if text and text[0] not in ascii_uppercase: return '"VAR:' + text + '"' return text + + def _is_equal_to(self, other: FormalObject) -> bool: + return isinstance(other, Variable) and self.value == other.value diff --git a/pyformlang/objects/finite_automaton_objects/state.py b/pyformlang/objects/finite_automaton_objects/state.py index d8fb45c..41244c3 100644 --- a/pyformlang/objects/finite_automaton_objects/state.py +++ b/pyformlang/objects/finite_automaton_objects/state.py @@ -2,8 +2,6 @@ Representation of a state in a finite state automaton """ -from typing import Any - from .finite_automaton_object import FiniteAutomatonObject from ..formal_object import FormalObject @@ -24,15 +22,8 @@ class State(FiniteAutomatonObject): """ - def __eq__(self, other: Any) -> bool: - if isinstance(other, State): - return self.value == other.value - if isinstance(other, FormalObject): - return False - return self.value == other - - def __hash__(self) -> int: - return super().__hash__() - def __repr__(self) -> str: return f"State({self})" + + def _is_equal_to(self, other: FormalObject) -> bool: + return isinstance(other, State) and self.value == other.value diff --git a/pyformlang/objects/formal_object.py b/pyformlang/objects/formal_object.py index 4af8560..31d88f5 100644 --- a/pyformlang/objects/formal_object.py +++ b/pyformlang/objects/formal_object.py @@ -23,9 +23,10 @@ def value(self) -> Hashable: """ return self._value - @abstractmethod def __eq__(self, other: Any) -> bool: - raise NotImplementedError + if not isinstance(other, FormalObject): + return self.value == other + return self._is_equal_to(other) and other._is_equal_to(self) def __hash__(self) -> int: if self._hash is None: @@ -38,3 +39,7 @@ def __str__(self) -> str: @abstractmethod def __repr__(self) -> str: raise NotImplementedError + + @abstractmethod + def _is_equal_to(self, other: "FormalObject") -> bool: + raise NotImplementedError diff --git a/pyformlang/objects/pda_objects/stack_symbol.py b/pyformlang/objects/pda_objects/stack_symbol.py index 0fb4c91..c22f29e 100644 --- a/pyformlang/objects/pda_objects/stack_symbol.py +++ b/pyformlang/objects/pda_objects/stack_symbol.py @@ -1,7 +1,5 @@ """ A StackSymbol in a pushdown automaton """ -from typing import Any - from .symbol import Symbol from ..formal_object import FormalObject @@ -16,15 +14,8 @@ class StackSymbol(Symbol): """ - def __eq__(self, other: Any) -> bool: - if isinstance(other, StackSymbol): - return self.value == other.value - if isinstance(other, FormalObject): - return False - return self.value == other - - def __hash__(self) -> int: - return super().__hash__() - def __repr__(self) -> str: return f"StackSymbol({self})" + + def _is_equal_to(self, other: FormalObject) -> bool: + return isinstance(other, StackSymbol) and self.value == other.value diff --git a/pyformlang/objects/pda_objects/state.py b/pyformlang/objects/pda_objects/state.py index b174354..8b0a385 100644 --- a/pyformlang/objects/pda_objects/state.py +++ b/pyformlang/objects/pda_objects/state.py @@ -1,7 +1,5 @@ """ A State in a pushdown automaton """ -from typing import Any - from .pda_object import PDAObject from ..formal_object import FormalObject @@ -16,15 +14,8 @@ class State(PDAObject): """ - def __eq__(self, other: Any) -> bool: - if isinstance(other, State): - return self.value == other.value - if isinstance(other, FormalObject): - return False - return self.value == other - - def __hash__(self) -> int: - return super().__hash__() - def __repr__(self) -> str: return f"State({self})" + + def _is_equal_to(self, other: FormalObject) -> bool: + return isinstance(other, State) and self.value == other.value diff --git a/pyformlang/pda/tests/test_pda.py b/pyformlang/pda/tests/test_pda.py index e0f7f17..03173cd 100644 --- a/pyformlang/pda/tests/test_pda.py +++ b/pyformlang/pda/tests/test_pda.py @@ -389,6 +389,8 @@ def test_object_eq(self): assert StackSymbol("ABC") != Symbol("ABC") assert State("ABC") != FAState("ABC") assert Symbol("s") == Terminal("s") + assert Terminal(1) != StackSymbol(1) + assert StackSymbol(42) != FAState(42) def test_contains(self, pda_example: PDA): """ Tests the transition containment checks """