diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..918f8dc --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +/.mypy_cache +/__pycache__ diff --git a/README.md b/README.md index d6c56d8..819a221 100644 --- a/README.md +++ b/README.md @@ -1,23 +1,18 @@ -A switch "statement" for Python. Please don't actually use this. +A switch "statement" for Python. x = 5 - switch(x) - - if case(4): - assert False - - if case(5): - assert True - fallthrough() - - if case(lambda x: x > 3): - assert True - - if case(6): - assert False - - if default(): - assert False + with switch(x) as case: + if case(4): + assert False + if case(5): + assert True + case.fallthrough() + if case(lambda x: x > 3): + assert True + if case(6): + assert False + if case.default(): + assert False Tests and usage in test.py. diff --git a/switch.py b/switch.py index f8e61c6..fe76428 100644 --- a/switch.py +++ b/switch.py @@ -1,44 +1,43 @@ -__all__ = ['switch', 'case', 'default', 'fallthrough'] +__all__ = ['switch'] -class Switch: - def __init__(self, value): - self.value = value - self.finished = False - - def case(self, cond): - if self.finished: - return False - match = False +from contextlib import contextmanager +from typing import ( + cast, + Callable, + Generator, + Generic, + TypeVar, + Union, +) - if hasattr(cond, '__call__'): - if cond(self.value): - match = True - elif cond == self.value: - match = True - if match: - self.finished = True +T = TypeVar('T') +GuardT = Callable[[T], bool] - return match - def fallthrough(self): +class Case(Generic[T]): + def __init__(self, value: T): + self.value = value self.finished = False - def default(self): - return not self.finished + def __call__(self, cond: Union[T, GuardT]) -> bool: + if (self.finished + or (hasattr(cond, '__call__') + and not cast(GuardT, cond)(self.value)) + or cond != self.value): + return False -the_switch = None + self.finished = True + return True -def switch(value): - global the_switch - the_switch = Switch(value) + def fallthrough(self) -> None: + self.finished = False -def case(cond): - return the_switch.case(cond) + def default(self) -> bool: + return not self.finished -def fallthrough(): - the_switch.fallthrough() -def default(): - the_switch.default() +@contextmanager +def switch(value: T) -> Generator[Case[T], None, None]: + yield Case(value) diff --git a/test.py b/test.py index 506a33b..7749b34 100644 --- a/test.py +++ b/test.py @@ -1,30 +1,23 @@ -from switch import * +from switch import switch x = 5 -switch(x) - -if case(4): - assert False - -if case(5): - assert True - fallthrough() - -if case(5): - assert True - -if case(5): - assert False - -if case(6): - assert False - - -switch(3) - -if case(lambda x: x > 5): - assert False - -if default(): - assert True +with switch(x) as case: + if case(4): + assert False + if case(5): + assert True + case.fallthrough() + if case(5): + assert True + if case(5): + assert False + if case(6): + assert False + + +with switch(3) as case: + if case(lambda x: x > 5): + assert False + if case.default(): + assert True