Skip to content

Commit

Permalink
chore: fmt code
Browse files Browse the repository at this point in the history
  • Loading branch information
Ovler-Young committed Nov 30, 2024
1 parent 965af17 commit 67b8fd2
Showing 1 changed file with 32 additions and 21 deletions.
53 changes: 32 additions & 21 deletions src/ia_collection_analyzer/streamlit.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,10 @@
if "original_values" not in st.session_state:
st.session_state.original_values = {}


@st.fragment
def collection_input():
"""Fragment for collection ID input and metadata fetching"""
"""Fragment for collection ID input and metadata fetching"""
# input the collection name
col1, col2 = st.columns([6, 1], vertical_alignment="bottom")
with col1:
Expand All @@ -59,7 +60,10 @@ def collection_input():
if not conform_button and not st.session_state.got_metadata or collection_id == "":
st.stop()

if st.session_state.got_metadata and collection_id == st.session_state.collection_id:
if (
st.session_state.got_metadata
and collection_id == st.session_state.collection_id
):
items_pd = st.session_state.items_pd
# progress_message
progress_message = st.session_state.progress_message
Expand All @@ -71,14 +75,15 @@ def collection_input():
st.markdown("Failed to display top 10 lines. Only first will be shown.")
st.write(items_pd.head(1))
st.write(e)

return

# Check if we need to fetch new data
if not st.session_state.got_metadata or collection_id != st.session_state.collection_id:
st.markdown(
f"Getting fresh metadata for collection: **{collection_id}**"
)
if (
not st.session_state.got_metadata
or collection_id != st.session_state.collection_id
):
st.markdown(f"Getting fresh metadata for collection: **{collection_id}**")
items, progress_message = fetch_metadata(collection_id)
data_transform_text = st.text("Transforming data...")
items_pd = pd.DataFrame(items)
Expand Down Expand Up @@ -117,29 +122,32 @@ def collection_input():
# Update cache
st.session_state.items_pd = items_pd
else:
st.markdown(
f"Using cached metadata for collection: **{collection_id}**"
)
st.markdown(f"Using cached metadata for collection: **{collection_id}**")
items_pd = st.session_state.items_pd

st.session_state.got_metadata = True
st.session_state.collection_id = collection_id
st.session_state.progress_message = progress_message
st.session_state.selected_columns = []

st.rerun()


@st.fragment
def column_selector():
"""Fragment for selecting columns to analyze"""
items_pd = st.session_state.items_pd

st.header("Selecting columns to analyze")
st.write("Select additional columns you want to analyze:")
seleactable_columns = [col for col in items_pd.columns if col not in REQUIRED_METADATA]
seleactable_columns = [
col for col in items_pd.columns if col not in REQUIRED_METADATA
]

col1, col2 = st.columns([6, 1], vertical_alignment="bottom")
selected_columns = st.multiselect("Select columns:", seleactable_columns, default=[])
selected_columns = st.multiselect(
"Select columns:", seleactable_columns, default=[]
)

# Update the filtering code to use cache
if (
Expand All @@ -159,6 +167,7 @@ def column_selector():
st.write("Preview of the selected columns:")
st.write(filtered_pd.head(30))


@st.fragment
def transform_data():
"""Fragment for transforming data"""
Expand All @@ -168,13 +177,12 @@ def transform_data():
index=0,
placeholder="No",
)

if transform_needed == "No":
return

filtered_pd = st.session_state.filtered_pd



st.header("Transform Column")
st.write("Transform an existing column with data transformations")

Expand Down Expand Up @@ -360,18 +368,19 @@ def transform_data():
{"source_col": source_col, "transform_type": transform_type}
)
st.session_state.original_values[source_col] = preview_df["Original"]

st.rerun()


@st.fragment
def plot_data():
"""Fragment for data visualization"""
if not st.session_state.filtered_pd is not None:
return

filtered_pd = st.session_state.filtered_pd
plotable_columns = st.session_state.selected_columns + REQUIRED_METADATA

col1, col2, col3 = st.columns([3, 3, 1], vertical_alignment="bottom")
with col1:
x_axis = st.selectbox("Select the x-axis:", plotable_columns, index=0)
Expand Down Expand Up @@ -442,6 +451,7 @@ def plot_data():
counts_df = pd.crosstab(expanded_df[x_axis], expanded_df[y_axis])
st.write(counts_df)


def main():
collection_input()
if st.session_state.got_metadata:
Expand All @@ -450,5 +460,6 @@ def main():
transform_data()
plot_data()


if __name__ == "__main__":
main()
main()

0 comments on commit 67b8fd2

Please sign in to comment.