import glob
from pathlib import Path
from typing import Tuple, Union
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from PIL import Image
from .font_property import _get_font_properties
def _clean_pattern(pattern, image_extensions):
# find the image file extensions used in the pattern
pattern_extensions = [ext for ext in image_extensions if ext in pattern][0]
pattern = pattern.replace(pattern_extensions, "")
# Remove common regex characters (you might need to extend this list)
regex_chars = ["*", ".", "?", "+", "^", "$", "(", ")", "[", "]", "{", "}", "|"]
for char in regex_chars:
pattern = pattern.replace(char, "")
return pattern
[docs]
def plot_image(
dir_image_input: Union[str, Path],
n_row: int,
n_col: int,
subplot_width: int = 3,
subplot_height: int = 3,
dir_csv_input: Union[str, Path] = None,
csv_file_pattern: str = "*.csv",
image_file_pattern: str = None,
sort_by: str = "random",
ascending: bool = True,
use_all: bool = False,
title: str = None,
path_output: Union[str, Path] = None,
random_seed: int = 42,
font_size: int = 30,
dark_mode: bool = False,
dpi: int = 300,
) -> Tuple[plt.Figure, plt.Axes]:
"""Generates a grid of images based on specified parameters and optionally annotates
them using data from a CSV file. Images can be displayed in a random or sorted order
according to metadata provided in a CSV file.
Args:
dir_image_input (Union[str, Path]): Directory path containing image files.
n_row (int): Number of rows in the image grid.
n_col (int): Number of columns in the image grid.
subplot_width (int, optional): Width of each subplot. Defaults to 3.
subplot_height (int, optional): Height of each subplot. Defaults to 3.
dir_csv_input (Union[str, Path], optional): Directory path containing CSV files with metadata. Defaults to None.
csv_file_pattern (str, optional): Pattern to match CSV files in the directory. Defaults to None.
image_file_pattern (str, optional): Pattern to match image files in the directory. Defaults to None.
sort_by (str, optional): Column name to sort the images by; set to "random" for random order. Defaults to "random".
ascending (bool, optional): Sort order. True for ascending, False for descending. Defaults to True.
use_all (bool, optional): If True, use all available images, otherwise use only a subset to fit the grid. Defaults to False.
title (str, optional): Title of the plot. Defaults to None.
path_output (Union[str, Path], optional): Path to save the output plot. Defaults to None.
random_seed (int, optional): Seed for random operations to ensure reproducibility. Defaults to 42.
font_size (int, optional): Font size for the plot title. Defaults to 30.
dark_mode (bool, optional): Set to True to use a dark theme for the plot. Defaults to False.
dpi (int, optional): Resolution in dots per inch for saving the image. Defaults to 300.
Returns:
Tuple[plt.Figure, plt.Axes]: A tuple containing the matplotlib figure and axes objects.
Raises:
ValueError: If the specified number of rows and columns does not match the available number of images.
KeyError: If the 'sort_by' column is not found in the provided CSV files.
"""
# Function implementation remains the same.
image_extensions = [
".jpg",
".jpeg",
".png",
".tif",
".tiff",
".bmp",
".dib",
".pbm",
".pgm",
".ppm",
".sr",
".ras",
".exr",
".jp2",
]
if dark_mode:
plt.style.use("dark_background")
# get a list of image files recursively
# List of possible image file extensions
if image_file_pattern is not None:
image_files = list(Path(dir_image_input).rglob(image_file_pattern))
else:
dir_image_input = Path(dir_image_input)
# Collect all image files matching the extensions
image_files = []
for ext in image_extensions:
for file in dir_image_input.rglob(f"*{ext}"):
image_files.append(file)
# Find CSV files matching the pattern
if dir_csv_input is not None and csv_file_pattern is not None:
csv_files = glob.glob(str(Path(dir_csv_input) / "**" / csv_file_pattern), recursive=True)
# combine all CSV files into a single DataFrame
df = pd.concat([pd.read_csv(file) for file in csv_files], ignore_index=True)
# Pre-filter image file names without extensions
if image_file_pattern is not None:
image_file_pattern_no_regex = _clean_pattern(image_file_pattern, image_extensions)
image_file_names = {str(file.stem).replace(image_file_pattern_no_regex, ""): file for file in image_files}
else:
image_file_names = {file.stem: file for file in image_files}
# map image file names to the DataFrame by using "filename_key" column and image_file_names keys
df["filename_key"] = df["filename_key"].astype(str)
df["image_full_path"] = df["filename_key"].map(image_file_names)
# Remove rows with missing image file paths
df_filtered = df.dropna(subset=["image_full_path"])
# Randomly shuffle the DataFrame
if not use_all:
# only get the random n_row * n_col rows
if n_row * n_col > len(df_filtered):
raise ValueError(
f"n_row * n_col ({n_row * n_col}) is greater than the number of images ({len(df_filtered)})"
)
rows = np.random.choice(df_filtered.index, n_row * n_col, replace=False)
df_filtered = df_filtered.loc[rows]
# Filter and sort the DataFrame
if sort_by.lower() != "random":
try:
df_filtered = df_filtered.sort_values(by=sort_by, ascending=ascending).reset_index(drop=True)
except KeyError:
raise KeyError(f"Column '{sort_by}' not found in the CSV file")
else:
df_filtered = df_filtered.sample(frac=1, random_state=random_seed).reset_index(
drop=True
) # Randomly shuffle the DataFrame
else:
# Randomly shuffle the image files
np.random.seed(random_seed)
np.random.shuffle(image_files)
# Create a DataFrame with the image files
df_filtered = pd.DataFrame({"image_full_path": image_files})
# Prepare the subplot
fig, axes = plt.subplots(n_row, n_col, figsize=(n_col * subplot_width, n_row * subplot_height))
fig.suptitle(title)
# Flatten the axes array for easy indexing
axes = axes.flatten()
for idx, ax in enumerate(axes):
if idx < len(df_filtered):
image_path = df_filtered.iloc[idx]["image_full_path"]
img = Image.open(image_path)
ax.imshow(img)
ax.axis("off")
else:
ax.axis("off")
# set title
prop_title, _, _ = _get_font_properties(font_size)
fig.suptitle(title, fontproperties=prop_title, color="#2b2b2b" if not dark_mode else "white")
if path_output:
plt.savefig(path_output, bbox_inches="tight", dpi=dpi)
return fig, ax