Skip to content

Commit

Permalink
feat/file viewer (#417)
Browse files Browse the repository at this point in the history
* feat: add file tree viewer

* feat: add copy from data catalog

* feat: add s3 viewer in data catalog

* feat: add show_files and show_s3 to config wizard

* fix: make boto3 optional

* fix: boto3 import and other refinements

* fix: use relpath in file tree test

* fix: update help screen; xfail windows test

* feat: cache s3 catalog contents

* fix: update help screen snapshot

* fix: enable all buckets mode with gcs
  • Loading branch information
tconbeer authored Jan 22, 2024
1 parent 0083dbf commit 5d66bd2
Show file tree
Hide file tree
Showing 27 changed files with 2,728 additions and 410 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ repos:
- tomlkit
- pandas-stubs
- importlib_metadata
- boto3-stubs
args:
- "--disallow-untyped-calls"
- "--disallow-untyped-defs"
Expand Down
14 changes: 14 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,20 @@ All notable changes to this project will be documented in this file.

## [Unreleased]

### Features

- Adds an option, `--show-files` (alias `-f`), which will display the passed directory in the Data Catalog, alongside the connected database schema, in a second tab. Like database catalog items, you can use <kbd>ctrl+enter</kbd>, <kbd>ctrl+j</kbd>, or double-click to insert the path into the query editor.
- Adds an option, `--show-s3` (alias `--s3`), which will display objects from the passed URI in the Data Catalog (in another tab). Uses the credentials from the AWS CLI's default profile. Use `--show-s3 all` to show all objects in all buckets for the currently-authenticated user, or pass buckets and key prefixes to restrict the catalog. For example, these all work:
```bash
harlequin --show-s3 my-bucket
harlequin --show-s3 my-bucket/my-nested/key-prefix
harlequin --show-s3 s3://my-bucket
harlequin --show-s3 https://my-storage.com/my-bucket/my-prefix
harlequin --show-s3 https://my-bucket.s3.amazonaws.com/my-prefix
harlequin --show-s3 https://my-bucket.storage.googleapis.com/my-prefix
```
- Items in the Data Catalog can now be copied to the clipboard with <kbd>ctrl+c</kbd>.

## [1.11.0] - 2024-01-12

### Features
Expand Down
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ lint:

.PHONY: serve
serve:
textual run --dev -c harlequin
textual run --dev -c harlequin -f .

.PHONY: sqlite
sqlite:
Expand All @@ -25,4 +25,4 @@ static/themes/%.svg: pyproject.toml src/scripts/export_screenshots.py
python src/scripts/export_screenshots.py

static/harlequin.gif: static/harlequin.mp4
ffmpeg -i static/harlequin.mp4 -vf "fps=24,scale=640:-1:flags=lanczos,split[s0][s1];[s0]palettegen[p];[s1][p]paletteuse" -loop 0 static/harlequin.gif
ffmpeg -i static/harlequin.mp4 -vf "fps=24,scale=640:-1:flags=lanczos,split[s0][s1];[s0]palettegen[p];[s1][p]paletteuse" -loop 0 static/harlequin.gif
685 changes: 602 additions & 83 deletions poetry.lock

Large diffs are not rendered by default.

5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ tomli = { version = "^2.0.1", python = "<3.11.0" }
tomlkit = "^0.12.3"
questionary = "^2.0.1"

# optional deps
boto3 = { version = "^1.34.22", optional = true }

# database adapters (optional installs for extras)
harlequin-postgres = { version = "^0.2", optional = true }
harlequin-mysql = { version = "^0.1", optional = true }
Expand All @@ -59,6 +62,7 @@ ruff = ">=0.0.285"
mypy = "^1.2.0"
types-pygments = "^2.16.0.0"
pandas-stubs = "^2"
boto3-stubs = "^1.34.23"

[tool.poetry.group.test.dependencies]
pytest = "^7.3.1"
Expand All @@ -71,6 +75,7 @@ duckdb = "==0.8.1"
harlequin = "harlequin.cli:harlequin"

[tool.poetry.extras]
s3 = ["boto3"]
postgres = ["harlequin-postgres"]
mysql = ["harlequin-mysql"]
odbc = ["harlequin-odbc"]
Expand Down
95 changes: 75 additions & 20 deletions src/harlequin/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
import time
from functools import partial
from pathlib import Path
from typing import Dict, List, Optional, Type, Union

from textual import on, work
Expand All @@ -26,8 +27,12 @@
from harlequin import HarlequinConnection
from harlequin.adapter import HarlequinAdapter, HarlequinCursor
from harlequin.autocomplete import completer_factory
from harlequin.catalog import Catalog, CatalogItem, NewCatalog
from harlequin.catalog_cache import get_cached_catalog, update_cache_with_catalog
from harlequin.catalog import Catalog, NewCatalog
from harlequin.catalog_cache import (
CatalogCache,
get_catalog_cache,
update_catalog_cache,
)
from harlequin.colors import HarlequinColors
from harlequin.components import (
CodeEditor,
Expand All @@ -53,6 +58,12 @@
)


class CatalogCacheLoaded(Message):
    """Posted once the on-disk catalog cache has been read successfully."""

    def __init__(self, cache: CatalogCache) -> None:
        super().__init__()
        # The deserialized cache payload; handlers read db/s3 catalogs from it.
        self.cache = cache


class DatabaseConnected(Message):
def __init__(self, connection: HarlequinConnection) -> None:
super().__init__()
Expand Down Expand Up @@ -121,6 +132,8 @@ def __init__(
*,
connection_hash: str | None = None,
theme: str = "harlequin",
show_files: Path | None = None,
show_s3: str | None = None,
max_results: int | str = 100_000,
driver_class: Union[Type[Driver], None] = None,
css_path: Union[CSSPathType, None] = None,
Expand All @@ -131,6 +144,8 @@ def __init__(
self.connection_hash = connection_hash
self.catalog: Catalog | None = None
self.theme = theme
self.show_files = show_files
self.show_s3 = show_s3 or None
try:
self.max_results = int(max_results)
except ValueError:
Expand All @@ -156,7 +171,11 @@ def __init__(
def compose(self) -> ComposeResult:
"""Create child widgets for the app."""
with Horizontal():
yield DataCatalog(type_color=self.app_colors.gray)
yield DataCatalog(
type_color=self.app_colors.gray,
show_files=self.show_files,
show_s3=self.show_s3,
)
with Vertical(id="main_panel"):
yield EditorCollection(language="sql", theme=self.theme)
yield RunQueryBar(
Expand Down Expand Up @@ -198,9 +217,9 @@ async def on_mount(self) -> None:

self.editor.focus()
self.run_query_bar.checkbox.value = False
self.data_catalog.loading = True

self._connect()
self._load_catalog_cache()

@on(Button.Pressed, "#run_query")
def submit_query_from_run_query_bar(self, message: Button.Pressed) -> None:
Expand All @@ -212,6 +231,15 @@ def submit_query_from_run_query_bar(self, message: Button.Pressed) -> None:
)
)

@on(CatalogCacheLoaded)
def build_trees(self, message: CatalogCacheLoaded) -> None:
    """Seed the data catalog trees from a freshly loaded on-disk cache."""
    # A cached database catalog is surfaced via the normal NewCatalog path.
    if self.connection_hash is not None:
        cached_db = message.cache.get_db(self.connection_hash)
        if cached_db:
            self.post_message(NewCatalog(catalog=cached_db))
    # The S3 tree reads its own entry out of the cache directly.
    if self.show_s3 is not None:
        self.data_catalog.load_s3_tree_from_cache(message.cache)

@on(CodeEditor.Submitted)
def submit_query_from_editor(self, message: CodeEditor.Submitted) -> None:
message.stop()
Expand All @@ -233,13 +261,22 @@ def initialize_app(self, message: DatabaseConnected) -> None:
self.update_schema_data()

@on(DataCatalog.NodeSubmitted)
def insert_node_into_editor(
self, message: DataCatalog.NodeSubmitted[CatalogItem]
) -> None:
def insert_node_into_editor(self, message: DataCatalog.NodeSubmitted) -> None:
message.stop()
if message.node.data:
self.editor.insert_text_at_selection(text=message.node.data.query_name)
self.editor.focus()
self.editor.insert_text_at_selection(text=message.insert_name)
self.editor.focus()

@on(DataCatalog.NodeCopied)
def copy_node_name(self, message: DataCatalog.NodeCopied) -> None:
    """Copy a catalog item's label to the editor clipboard (and, if enabled,
    the system clipboard)."""
    message.stop()
    # Always stash the text on the editor's internal clipboard first.
    self.editor.text_input.clipboard = message.copy_name
    if self.editor.use_system_clipboard:
        try:
            self.editor.text_input.system_copy(message.copy_name)
        except Exception:
            self.notify("Error copying data to system clipboard.", severity="error")
        else:
            # Success toast fires only when the system copy did not raise.
            self.notify("Selected label copied to clipboard.")

@on(EditorCollection.EditorSwitched)
def update_internal_editor_state(
Expand Down Expand Up @@ -315,7 +352,7 @@ async def _handle_worker_error(self, message: Worker.StateChanged) -> None:
header="Could not update data catalog",
error=message.worker.error,
)
self.data_catalog.loading = False
self.data_catalog.database_tree.loading = False
elif (
message.worker.name == "_execute_query" and message.worker.error is not None
):
Expand Down Expand Up @@ -344,11 +381,18 @@ async def _handle_worker_error(self, message: Worker.StateChanged) -> None:
)
self.exit(return_code=2, message=pretty_error_message(error))

@on(DataCatalog.CatalogError)
def handle_catalog_error(self, message: DataCatalog.CatalogError) -> None:
    """Show an error modal when a catalog tree fails to populate."""
    kind = message.catalog_type
    self._push_error_modal(
        title=f"Catalog Error: {kind}",
        header=f"Could not populate the {kind} data catalog",
        error=message.error,
    )

@on(NewCatalog)
def update_tree_and_completers(self, message: NewCatalog) -> None:
self.catalog = message.catalog
self.data_catalog.update_tree(message.catalog)
self.data_catalog.loading = False
self.data_catalog.update_database_tree(message.catalog)
self.update_completers(message.catalog)

@on(QueriesExecuted)
Expand Down Expand Up @@ -497,8 +541,9 @@ async def action_quit(self) -> None:
BufferState(editor.cursor, editor.selection_anchor, editor.text)
)
write_editor_cache(Cache(focus_index=focus_index, buffers=buffers))
if self.catalog is not None and self.connection_hash is not None:
update_cache_with_catalog(self.connection_hash, self.catalog)
update_catalog_cache(
self.connection_hash, self.catalog, self.data_catalog.s3_tree
)
await super().action_quit()

def action_show_help_screen(self) -> None:
Expand All @@ -520,8 +565,10 @@ def action_toggle_sidebar(self) -> None:
self.sidebar_hidden = not self.sidebar_hidden

def action_refresh_catalog(self) -> None:
self.data_catalog.loading = True
self.data_catalog.database_tree.loading = True
self.update_schema_data()
self.data_catalog.update_file_tree()
self.data_catalog.update_s3_tree()

@work(
thread=True,
Expand All @@ -532,12 +579,20 @@ def action_refresh_catalog(self) -> None:
)
def _connect(self) -> None:
    """Worker: open the adapter connection and announce it to the app."""
    connection = self.adapter.connect()
    # Post a previously cached catalog first — presumably so the tree renders
    # immediately while the fresh catalog is fetched; confirm against callers.
    if self.connection_hash is not None and (
        cached_catalog := get_cached_catalog(self.connection_hash)
    ):
        self.post_message(NewCatalog(catalog=cached_catalog))
    self.post_message(DatabaseConnected(connection=connection))

@work(
    thread=True,
    exclusive=True,
    exit_on_error=False,
    group="cache_loaders",
    description="Loading cached catalog",
)
def _load_catalog_cache(self) -> None:
    """Worker: read the persisted catalog cache and announce it when present."""
    if (cache := get_catalog_cache()) is not None:
        self.post_message(CatalogCacheLoaded(cache=cache))

@work(
thread=True,
exclusive=True,
Expand Down
40 changes: 30 additions & 10 deletions src/harlequin/catalog_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,24 @@
import hashlib
import json
import pickle
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Sequence
from typing import TYPE_CHECKING, Any, Sequence

from platformdirs import user_cache_dir
from textual_textarea.key_handlers import Cursor as Cursor

from harlequin.catalog import Catalog

CACHE_VERSION = 0
if TYPE_CHECKING:
from harlequin.components.data_catalog import S3Tree

CACHE_VERSION = 1


def recursive_dict() -> defaultdict:
    """Return an autovivifying mapping: reading a missing key creates another
    nested level, so arbitrarily deep paths can be assigned without setup."""
    nested: defaultdict = defaultdict(recursive_dict)
    return nested


class PermissiveEncoder(json.JSONEncoder):
Expand All @@ -29,6 +37,15 @@ def default(self, obj: Any) -> Any:
@dataclass
class CatalogCache:
    """On-disk snapshot of catalogs: database catalogs keyed by connection
    hash, plus S3 catalog data keyed by a 3-tuple cache key."""

    # database catalogs, keyed by the connection's hash string
    databases: dict[str, Catalog]
    # S3 catalog payloads keyed by a (str|None, str|None, str|None) tuple —
    # presumably derived from the S3 endpoint/bucket/prefix config; confirm
    # against S3Tree.cache_key.
    s3: dict[tuple[str | None, str | None, str | None], dict]

    def get_db(self, connection_hash: str) -> Catalog | None:
        """Return the cached catalog for this connection, or None."""
        return self.databases.get(connection_hash)

    def get_s3(
        self, cache_key: tuple[str | None, str | None, str | None]
    ) -> dict | None:
        """Return the cached S3 catalog data for this key, or None."""
        return self.s3.get(cache_key)


def get_connection_hash(conn_str: Sequence[str], config: dict[str, Any]) -> str:
Expand All @@ -44,18 +61,20 @@ def get_connection_hash(conn_str: Sequence[str], config: dict[str, Any]) -> str:
)


def get_cached_catalog(connection_hash: str) -> Catalog | None:
    """Return the cached catalog for connection_hash, or None when no cache
    exists or it holds no entry for this connection."""
    cache = _load_cache()
    if cache is None:
        return None
    return cache.databases.get(connection_hash, None)
def get_catalog_cache() -> CatalogCache | None:
    """Load the persisted catalog cache, or None if missing or unreadable."""
    return _load_cache()


def update_cache_with_catalog(connection_hash: str, catalog: Catalog) -> None:
def update_catalog_cache(
    connection_hash: str | None, catalog: Catalog | None, s3_tree: S3Tree | None
) -> None:
    """Merge this session's catalogs into the persisted cache and write it back.

    Entries for other connections / S3 configs already in the cache are kept.
    """
    # Dataclass instances are always truthy, so `or` only substitutes the
    # empty cache when _load_cache() returned None.
    cache = _load_cache() or CatalogCache(databases={}, s3={})
    if connection_hash is not None and catalog is not None:
        cache.databases[connection_hash] = catalog
    if s3_tree is not None and s3_tree.catalog_data is not None:
        cache.s3[s3_tree.cache_key] = s3_tree.catalog_data
    _write_cache(cache)


Expand Down Expand Up @@ -83,6 +102,7 @@ def _load_cache() -> CatalogCache | None:
IndexError,
FileNotFoundError,
AssertionError,
EOFError,
):
return None
else:
Expand Down
Loading

0 comments on commit 5d66bd2

Please sign in to comment.