Skip to content

Commit

Permalink
Major refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
zwimer committed Jan 13, 2024
1 parent 17a7d9b commit db009aa
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 112 deletions.
36 changes: 26 additions & 10 deletions rpipe/_shared.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,33 @@
class ReadCode:
from dataclasses import dataclass, asdict


WEB_VERSION = "0.0.0"
ENCRYPTED_HEADER = "encrypted"


class ErrorCode:
"""
HTTP error codes the rpipe client may be sent
"""

wrong_version: int = 412
illegal_version: int = 409
no_data: int = 410
ok: int = 200


class WriteCode:
missing_version: int = 412
illegal_version: int = 409
ok: int = 201
@dataclass(kw_only=True, frozen=True)
class RequestParams:
version: str
override: bool = False # Not passed for upload
encrypted: bool = False # Not passed for download

def to_dict(self) -> dict[str, str]:
return {i: str(k) for i, k in asdict(self).items()}

class Headers: # Underscores are not allowed
version_override: str = "Version-Override"
client_version: str = "Client-Version"
encrypted: str = "Encrypted-Data"
@classmethod
def from_dict(cls, d: dict[str, str]) -> "RequestParams":
return cls(
version=d.get("version", WEB_VERSION),
override=d.get("override", "") == "True",
encrypted=d.get("encrypted", "") == "True",
)
2 changes: 1 addition & 1 deletion rpipe/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__: str = "3.3.1" # Must be "<major>.<minor>.<patch>", all numbers
__version__: str = "4.0.0" # Must be "<major>.<minor>.<patch>", all numbers
57 changes: 26 additions & 31 deletions rpipe/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@
import sys
import os

from requests import Request, Session
from Cryptodome.Random import get_random_bytes
from Cryptodome.Cipher import AES
from requests import Request, Session

from ._shared import WriteCode, ReadCode, Headers
from ._shared import ENCRYPTED_HEADER, RequestParams, ErrorCode
from ._version import __version__

if TYPE_CHECKING:
Expand Down Expand Up @@ -114,8 +114,6 @@ def _crypt(encrypt: bool, data: bytes, password: str | None) -> bytes:
def _request(*args, **kwargs) -> "Response":
r = Request(*args, **kwargs).prepare()
logging.debug("Preparing request:\n %s %s", r.method, r.url)
for i, k in r.headers.items():
logging.debug(" %s: %s", i, k)
if r.body:
logging.debug(" len(request.body) = %d", len(r.body))
logging.debug(" timeout=%d", _TIMEOUT)
Expand All @@ -133,24 +131,23 @@ def _recv(config: ValidConfig, peek: bool, force: bool) -> None:
Receive data from the remote pipe
"""
logging.debug("Reading from channel %s with peek=%s and force=%s", config.channel, peek, force)
r = _request(
"GET",
f"{config.url}/{'peek' if peek else 'read'}/{quote(config.channel)}",
headers={Headers.client_version: __version__, Headers.version_override: str(force)},
)
url = f"{config.url}/{'peek' if peek else 'read'}/{quote(config.channel)}"
r = _request("GET", url, params=RequestParams(version=__version__, override=force).to_dict())
encrypted: bool = r.headers.get(ENCRYPTED_HEADER, "") == "True"
if r.ok:
sys.stdout.buffer.write(_crypt(False, r.content, config.password if encrypted else None))
sys.stdout.flush()
return
match r.status_code:
case ReadCode.ok:
encrypted = r.headers.get(Headers.encrypted, "False") == "True"
sys.stdout.buffer.write(_crypt(False, r.content, config.password if encrypted else None))
sys.stdout.flush()
case ReadCode.wrong_version:
case ErrorCode.wrong_version:
raise VersionError(f"Version mismatch; uploader version = {r.text}; force a read with --force")
case ReadCode.illegal_version:
case ErrorCode.illegal_version:
raise VersionError(f"Server requires version >= {r.text}")
case ReadCode.no_data:
case ErrorCode.no_data:
raise NoData(f"The channel {config.channel} is empty.")
case _:
raise RuntimeError(f"Unknown status code: {r.status_code}\nBody: {r.text}")
logging.debug("Body: %s", r.text)
raise RuntimeError(f"Unknown status code: {r.status_code}")


def _send(config: ValidConfig) -> None:
Expand All @@ -162,19 +159,14 @@ def _send(config: ValidConfig) -> None:
r = _request(
"POST",
f"{config.url}/write/{quote(config.channel)}",
headers={
Headers.client_version: __version__,
Headers.encrypted: str(isinstance(config.password, str)),
},
params=RequestParams(version=__version__, encrypted=config.password is not None).to_dict(),
data=data,
)
if r.ok:
return
match r.status_code:
case WriteCode.ok:
pass
case WriteCode.illegal_version:
case ErrorCode.illegal_version:
raise VersionError(f"Server requires version >= {r.text}")
case WriteCode.missing_version:
raise VersionError("Client failed to set headers correctly; please report this")
case _:
raise RuntimeError(f"Unexpected status code: {r.status_code}\nBody: {r.text}")

Expand All @@ -184,7 +176,7 @@ def _clear(config: ValidConfig) -> None:
Clear the remote pipe
"""
logging.debug("Clearing channel %s", config.channel)
r = _request("GET", f"{config.url}/clear/{quote(config.channel)}")
r = _request("DELETE", f"{config.url}/clear/{quote(config.channel)}")
if not r.ok:
raise RuntimeError(f"Unexpected status code: {r.status_code}\nBody: {r.text}")

Expand Down Expand Up @@ -250,22 +242,23 @@ def _verify_config(conf: Config, encrypt: bool) -> None:
raise UsageError("Missing: --encrypt requires a password")


def rpipe(conf: Config, mode: Mode) -> None:
def rpipe(conf: Config, mode: Mode) -> bool:
"""
rpipe
:returns: True on success
"""
logging.debug("Config file: %s", CONFIG_FILE)
_mode_check(mode)
if mode.print_config:
_print_config()
return
return True
# Load pipe config and save is requested
conf = _load_config(conf, mode.plaintext)
msg = "Loaded config with:\n url = %s\n channel = %s\n has password: %s"
logging.debug(msg, conf.url, conf.channel, conf.password is not None)
if mode.save_config:
_save_config(conf, mode.encrypt)
return
return True
if not (mode.encrypt or mode.plaintext or mode.read or mode.clear):
logging.info("Write mode: No password found, falling back to --plaintext")
_verify_config(conf, mode.encrypt)
Expand All @@ -277,6 +270,7 @@ def rpipe(conf: Config, mode: Mode) -> None:
_recv(valid_conf, mode.peek, mode.force)
else:
_send(valid_conf)
return True


def main(prog: str, *args: str) -> None:
Expand Down Expand Up @@ -322,7 +316,8 @@ def main(prog: str, *args: str) -> None:
mode_d = {i: k for i, k in ns.items() if i in keys(Mode)}
assert set(ns) == set(conf_d) | set(mode_d)
conf_d["password"] = None
return rpipe(Config(**conf_d), Mode(read=sys.stdin.isatty(), **mode_d))
if not rpipe(Config(**conf_d), Mode(read=sys.stdin.isatty(), **mode_d)):
sys.exit(1)


def cli() -> None:
Expand Down
105 changes: 35 additions & 70 deletions rpipe/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,13 @@
import time
import sys

from flask import Flask, Response, request, redirect
from flask import Flask, Response, request
import waitress

from ._shared import WriteCode, ReadCode, Headers
from ._shared import ENCRYPTED_HEADER, WEB_VERSION, RequestParams, ErrorCode
from ._version import __version__

if TYPE_CHECKING:
from werkzeug.datastructures import Headers as HeadersType
from werkzeug.wrappers import Response as BaseResponse


Expand All @@ -28,8 +27,7 @@

app = Flask(__name__)

MIN_CLIENT_VERSION = (3, 0, 0)
WEB_VERSION = "0.0.0"
MIN_VERSION = (4, 0, 0)
PRUNE_DELAY: int = 5


Expand All @@ -54,71 +52,56 @@ class Data(NamedTuple):
#


def _version_to_tuple(version: str) -> tuple[int, int, int]:
def _version_tup(version: str) -> tuple[int, int, int]:
ret = tuple(int(i) for i in version.split("."))
if len(ret) != 3:
raise ValueError("Bad version")
return ret


def _version_from_tuple(version: tuple[int, int, int]) -> str:
def _version_str(version: tuple[int, int, int]) -> str:
return ".".join(str(i) for i in version)


def _check_required_version(client_version: str) -> Response | None:
def _check_version(version: str) -> Response | None:
"""
:param client_version: The version to check
:return: A flask Response if the version is not acceptable
"""
if not client_version:
return Response(
"Try updating your client or visit /help if using a browser", status=WriteCode.missing_version
)
try:
if _version_to_tuple(client_version) < MIN_CLIENT_VERSION:
if _version_tup(version) < MIN_VERSION:
raise ValueError()
except (AttributeError, ValueError):
return Response(_version_from_tuple(MIN_CLIENT_VERSION), status=WriteCode.illegal_version)
msg = f"Bad version: {_version_str(MIN_VERSION)}"
return Response(msg, status=ErrorCode.illegal_version)
return None


def _get(channel: str, path: str, headers: HeadersType, delete: bool) -> Response:
def _get(channel: str, args: RequestParams, delete: bool) -> Response:
"""
Get the data from channel, delete it afterwards if required
If web version: Fail if not encrypted, bypass version checks
If non web-version, but should be: redirect to web version
Otherwise: Version check
"""
# Redirect non-client requests to the web version and mark the version as the web version
if Headers.client_version not in request.headers:
if not path.startswith("/web"):
return redirect(f"/web{path}", code=308) # type: ignore
client_version = WEB_VERSION
# For client requests, verify the version is new enough to accept
else:
client_version = headers.get(Headers.client_version, "")
if (ret := _check_required_version(client_version)) is not None:
return ret
# Read data from the channel
if args.version != WEB_VERSION and (ret := _check_version(args.version)) is not None:
return ret
with lock:
got: Data | None = data.get(channel, None)
# No data found?
if got is None:
return Response(f"No data on channel {channel}", status=ReadCode.no_data)
return Response(f"No data on channel {channel}", status=ErrorCode.no_data)
# Web version cannot handle encryption
if client_version == WEB_VERSION and got.encrypted:
return Response(
"Web version cannot read encrypted data. Use the CLI: pip install rpipe", status=422
)
if args.version == WEB_VERSION and got.encrypted:
msg = "Web version cannot read encrypted data. Use the CLI: pip install rpipe"
return Response(msg, status=422)
# Version comparison; bypass if web version or override requested
got_ver = _version_from_tuple(got.client_version)
if client_version not in (WEB_VERSION, got_ver):
if headers.get(Headers.version_override, "") != "True":
return Response(got_ver, status=ReadCode.wrong_version)
got_ver = _version_str(got.client_version)
if args.version not in (WEB_VERSION, got_ver) and not args.override:
return Response(got_ver, status=ErrorCode.wrong_version)
# Delete data from channel if needed
if got is not None and delete:
del data[channel]
return Response(got.data, headers={Headers.encrypted: str(got.encrypted)}, status=ReadCode.ok)
return Response(got.data, headers={ENCRYPTED_HEADER: str(got.encrypted)})


def _periodic_prune() -> None:
Expand All @@ -129,8 +112,9 @@ def _periodic_prune() -> None:
while True:
old: datetime = datetime.now() - prune_age
with lock:
for i in [i for i, k in data.items() if k.when < old]:
del data[i]
for i, k in data.items():
if k.when < old:
del data[i]
time.sleep(60)


Expand All @@ -140,14 +124,12 @@ def _periodic_prune() -> None:


@app.route("/")
@app.route("/web")
@app.route("/help")
@app.route("/web/help")
def _help() -> str:
return (
"Write to /web/write, read from /web/read or /web/peek, clear with "
"/web/clear; add a trailing /<channel> to specify the channel. "
"Note: Using the /web/ API bypasses version consistenct checks "
"Write to /write, read from /read or /peek, clear with "
"/clear; add a trailing /<channel> to specify the channel. "
"Note: Using the web version bypasses version consistenct checks "
"and may result in safe but unexpected behavior (such as failing "
"an uploaded message; if possible use the rpipe CLI instead. "
"Install the CLI via: pip install rpipe"
Expand All @@ -159,49 +141,32 @@ def _show_version() -> str:
return __version__


@app.route("/clear/<channel>")
@app.route("/web/clear/<channel>")
@app.route("/clear/<channel>", methods=["DELETE"])
def _clear(channel: str) -> BaseResponse:
if Headers.client_version not in request.headers:
if not request.path.startswith("/web"):
return redirect(f"/web{request.path}", code=308)
with lock:
if channel in data:
del data[channel]
return Response("Cleared", status=202)


@app.route("/peek/<channel>")
@app.route("/web/peek/<channel>")
def _peek(channel: str) -> BaseResponse:
return _get(channel, request.path, request.headers, False)
return _get(channel, RequestParams.from_dict(request.args.to_dict()), False)


@app.route("/read/<channel>")
@app.route("/web/read/<channel>")
def _read(channel: str) -> BaseResponse:
return _get(channel, request.path, request.headers, True)
return _get(channel, RequestParams.from_dict(request.args.to_dict()), True)


@app.route("/write/<channel>", methods=["POST"])
@app.route("/web/write/<channel>", methods=["POST"])
def _write(channel: str) -> BaseResponse:
# Redirect non-client requests to web API, variables for web usage
if Headers.client_version not in request.headers:
if not request.path.startswith("/web"):
return redirect(f"/web{request.path}", code=308)
client_version = WEB_VERSION
encrypted = False
# Version check client, determine if message is encrypted
else:
client_version = request.headers.get(Headers.client_version, "")
if (ret := _check_required_version(client_version)) is not None:
return ret
encrypted = request.headers.get(Headers.encrypted, "False") == "True"
# Store the uploaded data
args = RequestParams.from_dict(request.args.to_dict())
if args.version != WEB_VERSION and (ret := _check_version(args.version)) is not None:
return ret
with lock:
data[channel] = Data(request.get_data(), datetime.now(), encrypted, _version_to_tuple(client_version))
return Response(status=WriteCode.ok)
data[channel] = Data(request.get_data(), datetime.now(), args.encrypted, _version_tup(args.version))
return Response(status=201)


#
Expand All @@ -226,7 +191,7 @@ def main(prog, *args) -> None:
parser.add_argument(
"--min-client-version",
action="version",
version=f"rpipe>={_version_from_tuple(MIN_CLIENT_VERSION)}",
version=f"rpipe>={_version_str(MIN_VERSION)}",
help="Print the minimum supported client version then exit",
)
parser.add_argument("port", type=int, help="The port waitress will listen on")
Expand Down

0 comments on commit db009aa

Please sign in to comment.