From e65cd56f9a8421a23d56ee73664cd454a8e86519 Mon Sep 17 00:00:00 2001 From: Sangjoon Bob Lee Date: Mon, 12 Jan 2026 06:08:27 -0800 Subject: [PATCH 01/27] Add core, components, webgpu, fft - javascript --- widget/js/CONFIG.ts | 189 ++++ widget/js/components.tsx | 220 +++++ widget/js/core/canvas-utils.ts | 268 ++++++ widget/js/core/canvas.ts | 276 ++++++ widget/js/core/colormaps.ts | 100 +++ widget/js/core/colors.ts | 71 ++ widget/js/core/export.ts | 135 +++ widget/js/core/fft-utils.ts | 161 ++++ widget/js/core/format.ts | 58 ++ widget/js/core/hooks.ts | 211 +++++ widget/js/core/index.ts | 86 ++ widget/js/core/styles.ts | 295 ++++++ widget/js/core/webgpu-hook.ts | 37 + widget/js/index.jsx | 33 - widget/js/shared.ts | 221 +++++ widget/js/show4dstem.css | 19 + widget/js/show4dstem.tsx | 1532 ++++++++++++++++++++++++++++++++ widget/js/webgpu-fft.ts | 558 ++++++++++++ 18 files changed, 4437 insertions(+), 33 deletions(-) create mode 100644 widget/js/CONFIG.ts create mode 100644 widget/js/components.tsx create mode 100644 widget/js/core/canvas-utils.ts create mode 100644 widget/js/core/canvas.ts create mode 100644 widget/js/core/colormaps.ts create mode 100644 widget/js/core/colors.ts create mode 100644 widget/js/core/export.ts create mode 100644 widget/js/core/fft-utils.ts create mode 100644 widget/js/core/format.ts create mode 100644 widget/js/core/hooks.ts create mode 100644 widget/js/core/index.ts create mode 100644 widget/js/core/styles.ts create mode 100644 widget/js/core/webgpu-hook.ts delete mode 100644 widget/js/index.jsx create mode 100644 widget/js/shared.ts create mode 100644 widget/js/show4dstem.css create mode 100644 widget/js/show4dstem.tsx create mode 100644 widget/js/webgpu-fft.ts diff --git a/widget/js/CONFIG.ts b/widget/js/CONFIG.ts new file mode 100644 index 00000000..9443471a --- /dev/null +++ b/widget/js/CONFIG.ts @@ -0,0 +1,189 @@ +/** + * Global configuration for bobleesj.widget. + * Layout constants and styling presets for all widgets. + */ + +// Import colors from single source of truth +import { COLORS, colors } from "./core/colors"; +export { COLORS, colors }; + +// ============================================================================ +// TYPOGRAPHY +// ============================================================================ +export const TYPOGRAPHY = { + LABEL: { + color: COLORS.TEXT_SECONDARY, + fontSize: 11, + }, + LABEL_SMALL: { + color: COLORS.TEXT_MUTED, + fontSize: 10, + }, + VALUE: { + color: COLORS.TEXT_MUTED, + fontSize: 10, + fontFamily: "monospace", + }, + TITLE: { + color: COLORS.ACCENT, + fontWeight: "bold" as const, + }, +}; + +// ============================================================================ +// CONTROL PANEL STYLES +// ============================================================================ +export const CONTROL_PANEL = { + // Standard control group (height: 32px) + GROUP: { + bgcolor: COLORS.BG_PANEL, + px: 1.5, + py: 0.5, + borderRadius: 1, + border: `1px solid ${COLORS.BORDER}`, + height: 32, + }, + // Compact button + BUTTON: { + color: COLORS.TEXT_MUTED, + fontSize: 10, + cursor: "pointer", + "&:hover": { color: COLORS.TEXT_PRIMARY }, + bgcolor: COLORS.BG_PANEL, + px: 1, + py: 0.25, + borderRadius: 0.5, + border: `1px solid ${COLORS.BORDER}`, + }, + // Select dropdown + SELECT: { + minWidth: 90, + bgcolor: COLORS.BG_INPUT, + color: COLORS.TEXT_PRIMARY, + fontSize: 11, + "& .MuiSelect-select": { + py: 0.5, + }, + }, +}; + +// ============================================================================ +// CONTAINER STYLES +// ============================================================================ +export const CONTAINER = { + ROOT: { + p: 2, + bgcolor: COLORS.BG, + color: COLORS.TEXT_PRIMARY, + fontFamily: "monospace", + borderRadius: 1, + // CRITICAL: Allow dropdowns to overflow + overflow: "visible", + }, + IMAGE_BOX: { + bgcolor: "#000", + border: `1px solid ${COLORS.BORDER}`, + overflow: "hidden", + position: "relative" as const, + }, +}; + +// ============================================================================ +// SLIDER SIZES +// ============================================================================ +export const SLIDER = { + // Width presets + WIDTH: { + TINY: 60, // Very compact (e.g., ms/frame slider) + SMALL: 80, // Standard small slider + MEDIUM: 100, // Medium slider + LARGE: 120, // Larger slider + }, + // Container min-widths (for label + slider + value combos) + CONTAINER: { + COMPACT: 120, // Minimal container + STANDARD: 150, // Standard container (e.g., delay slider) + WIDE: 180, // Wider container + }, +}; + +// ============================================================================ +// PANEL SIZES (for canvases and image boxes) +// ============================================================================ +export const PANEL = { + // Main image canvas sizes + MAIN: { + DEFAULT: 300, // Default main canvas size + MIN: 150, // Minimum resizable size + MAX: 600, // Maximum resizable size + }, + // Side panels (FFT, histogram, etc.) + SIDE: { + DEFAULT: 150, // Default side panel size + MIN: 80, // Minimum resizable size + MAX: 250, // Maximum resizable size + }, + // Show4DSTEM specific + DP: { + DEFAULT: 400, // Diffraction pattern panel + }, + VIRTUAL: { + DEFAULT: 300, // Virtual image panel + }, + FFT: { + DEFAULT: 300, // FFT panel + }, + // Gallery mode + GALLERY: { + IMAGE_SIZE: 200, // Target size for gallery images + MIN_COLS: 2, // Minimum columns + MAX_COLS: 4, // Maximum columns + }, +}; + +// ============================================================================ +// ZOOM/PAN LIMITS +// ============================================================================ +export const ZOOM = { + MIN: 0.5, + MAX: 10, + WHEEL_FACTOR: { + IN: 1.1, + OUT: 0.9, + }, +}; + +// ============================================================================ +// ANIMATION/PLAYBACK +// ============================================================================ +export const PLAYBACK = { + MS_PER_FRAME: { + DEFAULT: 1000, // Default: 1 fps + MIN: 200, // Fastest: 5 fps + MAX: 3000, // Slowest: ~0.33 fps + STEP: 100, // Step size for slider + }, +}; + +// ============================================================================ +// LEGACY ALIASES (for backward compatibility during migration) +// These use camelCase keys to match existing widget code +// Note: `colors` is imported from core/colors.ts and re-exported at the top +// ============================================================================ +export const typography = { + label: TYPOGRAPHY.LABEL, + labelSmall: TYPOGRAPHY.LABEL_SMALL, + value: TYPOGRAPHY.VALUE, + title: TYPOGRAPHY.TITLE, +}; + +export const controlPanel = { + group: CONTROL_PANEL.GROUP, + button: CONTROL_PANEL.BUTTON, + select: CONTROL_PANEL.SELECT, +}; + +export const container = { + root: CONTAINER.ROOT, + imageBox: CONTAINER.IMAGE_BOX, +}; diff --git a/widget/js/components.tsx b/widget/js/components.tsx new file mode 100644 index 00000000..c275c8cb --- /dev/null +++ b/widget/js/components.tsx @@ -0,0 +1,220 @@ +/** + * Shared styling constants and simple UI components for bobleesj.widget. + * + * ARCHITECTURE NOTE: Only styling should be shared here. + * Widget-specific logic (resize handlers, zoom handlers) should be inlined per-widget. + */ + +import * as React from "react"; +import Switch from "@mui/material/Switch"; +import Select from "@mui/material/Select"; +import MenuItem from "@mui/material/MenuItem"; +import Stack from "@mui/material/Stack"; +import Typography from "@mui/material/Typography"; +import { colors, controlPanel, typography } from "./CONFIG"; + +// ============================================================================ +// Switch Style Constants +// ============================================================================ +export const switchStyles = { + small: { + '& .MuiSwitch-thumb': { width: 12, height: 12 }, + '& .MuiSwitch-switchBase': { padding: '4px' }, + }, + medium: { + '& .MuiSwitch-thumb': { width: 14, height: 14 }, + '& .MuiSwitch-switchBase': { padding: '4px' }, + }, +}; + +// ============================================================================ +// Select MenuProps for upward dropdown (all widgets use this) +// ============================================================================ +export const upwardMenuProps = { + anchorOrigin: { vertical: "top" as const, horizontal: "left" as const }, + transformOrigin: { vertical: "bottom" as const, horizontal: "left" as const }, + sx: { zIndex: 9999 }, +}; + +// ============================================================================ +// LabeledSwitch - Label + Switch combo (optional, use if needed) +// ============================================================================ +interface LabeledSwitchProps { + label: string; + checked: boolean; + onChange: (checked: boolean) => void; + size?: "small" | "medium"; +} + +export function LabeledSwitch({ label, checked, onChange, size = "small" }: LabeledSwitchProps) { + return ( + + {label}: + onChange(e.target.checked)} + size="small" + sx={switchStyles[size]} + /> + + ); +} + +// ============================================================================ +// LabeledSelect - Label + Select dropdown combo (optional, use if needed) +// ============================================================================ +interface LabeledSelectProps { + label: string; + value: T; + options: readonly T[] | T[]; + onChange: (value: T) => void; + formatLabel?: (value: T) => string; +} + +export function LabeledSelect({ + label, + value, + options, + onChange, + formatLabel, +}: LabeledSelectProps) { + return ( + + {label}: + + + ); +} + +// ============================================================================ +// ScaleBar - Overlay component for canvas scale bars +// ============================================================================ +interface ScaleBarProps { + zoom: number; + size: number; + label?: string; +} + +export function ScaleBar({ zoom, size, label = "px" }: ScaleBarProps) { + const scaleBarPx = 50; + const realPixels = Math.round(scaleBarPx / zoom); + + return ( +
+ + {realPixels} {label} + +
+
+ ); +} + +// ============================================================================ +// ZoomIndicator - Overlay component for zoom level display +// ============================================================================ +interface ZoomIndicatorProps { + zoom: number; +} + +export function ZoomIndicator({ zoom }: ZoomIndicatorProps) { + return ( + + {zoom.toFixed(1)}× + + ); +} + +// ============================================================================ +// ResetButton - Compact reset button +// ============================================================================ +interface ResetButtonProps { + onClick: () => void; + label?: string; +} + +export function ResetButton({ onClick, label = "Reset" }: ResetButtonProps) { + return ( + + {label} + + ); +} + +// ============================================================================ +// ControlGroup - Wrapper for control panel groups +// ============================================================================ +interface ControlGroupProps { + children: React.ReactNode; +} + +export function ControlGroup({ children }: ControlGroupProps) { + return ( + + {children} + + ); +} + +// ============================================================================ +// ColormapSelect - Colormap dropdown with standard options +// ============================================================================ +const COLORMAP_OPTIONS = ["inferno", "viridis", "plasma", "magma", "hot", "gray"] as const; + +interface ColormapSelectProps { + value: string; + onChange: (value: string) => void; +} + +export function ColormapSelect({ value, onChange }: ColormapSelectProps) { + return ( + v.charAt(0).toUpperCase() + v.slice(1)} + /> + ); +} diff --git a/widget/js/core/canvas-utils.ts b/widget/js/core/canvas-utils.ts new file mode 100644 index 00000000..04bf0101 --- /dev/null +++ b/widget/js/core/canvas-utils.ts @@ -0,0 +1,268 @@ +/** + * Shared canvas rendering utilities. + * Used by Show2D, Show3D, Show4DSTEM, and Reconstruct. + */ + +import { COLORMAPS } from "./colormaps"; +import { colors } from "./colors"; + +// ============================================================================ +// Colormap LUT Application +// ============================================================================ + +/** + * Render uint8 data to canvas with colormap LUT. + */ +export function renderWithColormap( + ctx: CanvasRenderingContext2D, + data: Uint8Array, + width: number, + height: number, + cmapName: string = "inferno" +): void { + const lut = COLORMAPS[cmapName] || COLORMAPS.inferno; + const imgData = ctx.createImageData(width, height); + const rgba = imgData.data; + + for (let i = 0; i < data.length; i++) { + const v = data[i]; + const j = i * 4; + const lutIdx = v * 3; + rgba[j] = lut[lutIdx]; + rgba[j + 1] = lut[lutIdx + 1]; + rgba[j + 2] = lut[lutIdx + 2]; + rgba[j + 3] = 255; + } + ctx.putImageData(imgData, 0, 0); +} + +/** + * Render float32 data to canvas with colormap and optional percentile contrast. + */ +export function renderFloat32WithColormap( + ctx: CanvasRenderingContext2D, + data: Float32Array, + width: number, + height: number, + cmapName: string = "inferno", + percentileLow: number = 0, + percentileHigh: number = 100 +): void { + const lut = COLORMAPS[cmapName] || COLORMAPS.inferno; + + // Calculate min/max using percentiles + const sorted = Float32Array.from(data).sort((a, b) => a - b); + const len = sorted.length; + const loIdx = Math.floor((percentileLow / 100) * (len - 1)); + const hiIdx = Math.floor((percentileHigh / 100) * (len - 1)); + const min = sorted[loIdx]; + const max = sorted[hiIdx]; + const range = max - min || 1; + const scale = 255 / range; + + const imgData = ctx.createImageData(width, height); + const rgba = imgData.data; + + for (let i = 0; i < data.length; i++) { + const v = Math.round((data[i] - min) * scale); + const lutIdx = Math.max(0, Math.min(255, v)) * 3; + const j = i * 4; + rgba[j] = lut[lutIdx]; + rgba[j + 1] = lut[lutIdx + 1]; + rgba[j + 2] = lut[lutIdx + 2]; + rgba[j + 3] = 255; + } + ctx.putImageData(imgData, 0, 0); +} + +/** + * Draw image data to canvas with zoom and pan. + */ +export function drawWithZoomPan( + ctx: CanvasRenderingContext2D, + source: HTMLCanvasElement | ImageData, + canvasWidth: number, + canvasHeight: number, + zoom: number, + panX: number, + panY: number +): void { + ctx.imageSmoothingEnabled = false; + ctx.clearRect(0, 0, canvasWidth, canvasHeight); + ctx.save(); + ctx.translate(panX, panY); + ctx.scale(zoom, zoom); + if (source instanceof ImageData) { + ctx.putImageData(source, 0, 0); + } else { + ctx.drawImage(source, 0, 0); + } + ctx.restore(); +} + +// ============================================================================ +// Scale Bar Rendering +// ============================================================================ + +/** Round to a nice value (1, 2, 5, 10, 20, 50, etc.) */ +export function roundToNiceValue(value: number): number { + if (value <= 0) return 1; + const magnitude = Math.pow(10, Math.floor(Math.log10(value))); + const normalized = value / magnitude; + if (normalized < 1.5) return magnitude; + if (normalized < 3.5) return 2 * magnitude; + if (normalized < 7.5) return 5 * magnitude; + return 10 * magnitude; +} + +/** Format scale bar label with appropriate unit */ +export function formatScaleLabel(value: number, unit: string): string { + const nice = roundToNiceValue(value); + + if (unit === "nm") { + if (nice >= 1000) return `${Math.round(nice / 1000)} µm`; + if (nice >= 1) return `${Math.round(nice)} nm`; + return `${nice.toFixed(2)} nm`; + } else if (unit === "mrad") { + if (nice >= 1000) return `${Math.round(nice / 1000)} rad`; + if (nice >= 1) return `${Math.round(nice)} mrad`; + return `${nice.toFixed(2)} mrad`; + } else if (unit === "1/µm") { + if (nice >= 1000) return `${Math.round(nice / 1000)} 1/nm`; + if (nice >= 1) return `${Math.round(nice)} 1/µm`; + return `${nice.toFixed(2)} 1/µm`; + } else if (unit === "px") { + return `${Math.round(nice)} px`; + } + return `${Math.round(nice)} ${unit}`; +} + +/** + * Draw scale bar on high-DPI canvas. + */ +export function drawScaleBarHiDPI( + canvas: HTMLCanvasElement, + dpr: number, + zoom: number, + pixelSize: number, + unit: string = "nm", + imageWidth: number, + imageHeight: number +): void { + const ctx = canvas.getContext("2d"); + if (!ctx) return; + + ctx.clearRect(0, 0, canvas.width, canvas.height); + ctx.save(); + ctx.scale(dpr, dpr); + + const cssWidth = canvas.width / dpr; + const cssHeight = canvas.height / dpr; + const displayScale = Math.min(cssWidth / imageWidth, cssHeight / imageHeight); + const effectiveZoom = zoom * displayScale; + + // Fixed UI sizes in CSS pixels + const targetBarPx = 60; + const barThickness = 5; + const fontSize = 16; + const margin = 12; + + const targetPhysical = (targetBarPx / effectiveZoom) * pixelSize; + const nicePhysical = roundToNiceValue(targetPhysical); + const barPx = (nicePhysical / pixelSize) * effectiveZoom; + + const barY = cssHeight - margin; + const barX = cssWidth - barPx - margin; + + // Draw with shadow for visibility + ctx.shadowColor = "rgba(0, 0, 0, 0.5)"; + ctx.shadowBlur = 2; + ctx.shadowOffsetX = 1; + ctx.shadowOffsetY = 1; + + ctx.fillStyle = "white"; + ctx.fillRect(barX, barY, barPx, barThickness); + + // Draw label + const label = formatScaleLabel(nicePhysical, unit); + ctx.font = `${fontSize}px -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif`; + ctx.textAlign = "center"; + ctx.textBaseline = "bottom"; + ctx.fillText(label, barX + barPx / 2, barY - 4); + + // Draw zoom indicator (bottom left) + ctx.textAlign = "left"; + ctx.textBaseline = "bottom"; + ctx.fillText(`${zoom.toFixed(1)}×`, margin, cssHeight - margin + barThickness); + + ctx.restore(); +} + +/** + * Draw crosshair on high-DPI canvas. + */ +export function drawCrosshairHiDPI( + canvas: HTMLCanvasElement, + dpr: number, + posX: number, + posY: number, + zoom: number, + panX: number, + panY: number, + imageWidth: number, + imageHeight: number, + isDragging: boolean, + color: string = "rgba(0, 255, 0, 0.9)", + dragColor: string = "rgba(255, 255, 0, 0.9)" +): void { + const ctx = canvas.getContext("2d"); + if (!ctx) return; + + ctx.save(); + ctx.scale(dpr, dpr); + + const cssWidth = canvas.width / dpr; + const cssHeight = canvas.height / dpr; + const displayScale = Math.min(cssWidth / imageWidth, cssHeight / imageHeight); + + const screenX = posX * zoom * displayScale + panX * displayScale; + const screenY = posY * zoom * displayScale + panY * displayScale; + + const crosshairSize = 18; + const lineWidth = 3; + const dotRadius = 6; + + ctx.shadowColor = "rgba(0, 0, 0, 0.5)"; + ctx.shadowBlur = 2; + ctx.shadowOffsetX = 1; + ctx.shadowOffsetY = 1; + + ctx.strokeStyle = isDragging ? dragColor : color; + ctx.lineWidth = lineWidth; + + ctx.beginPath(); + ctx.moveTo(screenX - crosshairSize, screenY); + ctx.lineTo(screenX + crosshairSize, screenY); + ctx.moveTo(screenX, screenY - crosshairSize); + ctx.lineTo(screenX, screenY + crosshairSize); + ctx.stroke(); + + ctx.beginPath(); + ctx.arc(screenX, screenY, dotRadius, 0, 2 * Math.PI); + ctx.stroke(); + + ctx.restore(); +} + +// ============================================================================ +// Export to Blob/ZIP Helpers +// ============================================================================ + +/** + * Convert canvas to PNG blob. + */ +export function canvasToBlob(canvas: HTMLCanvasElement): Promise { + return new Promise((resolve) => { + canvas.toBlob((blob) => resolve(blob!), "image/png"); + }); +} diff --git a/widget/js/core/canvas.ts b/widget/js/core/canvas.ts new file mode 100644 index 00000000..336cee64 --- /dev/null +++ b/widget/js/core/canvas.ts @@ -0,0 +1,276 @@ +/** + * Canvas rendering utilities for image widgets. + * Scale bar, overlays, ROI drawing, etc. + */ + +import { colors } from "./colors"; + +/** Nice values for scale bar lengths */ +const NICE_VALUES = [1, 2, 5, 10, 20, 50, 100, 200, 500, 1000]; + +/** + * Calculate a "nice" scale bar length. + * @param imageWidthNm - Total image width in nm + * @param targetFraction - Target fraction of image width (default 0.2) + * @returns Scale bar length in nm + */ +export function calculateNiceScaleBar( + imageWidthNm: number, + targetFraction: number = 0.2 +): number { + const targetNm = imageWidthNm * targetFraction; + const magnitude = Math.pow(10, Math.floor(Math.log10(targetNm))); + + let barNm = magnitude; + for (const v of NICE_VALUES) { + if (v * magnitude <= targetNm * 1.5) { + barNm = v * magnitude; + } + } + return barNm; +} + +/** + * Round to a "nice" scale bar value (1, 2, 5, 10, 20, 50, 100, etc.) + * This ensures scale bars always show clean integer values. + * @param value - Raw value to round + * @returns Nice rounded value + */ +function roundToNiceValue(value: number): number { + if (value <= 0) return 1; + const magnitude = Math.pow(10, Math.floor(Math.log10(value))); + const normalized = value / magnitude; + // Round to 1, 2, 5, or 10 + if (normalized < 1.5) return magnitude; + if (normalized < 3.5) return 2 * magnitude; + if (normalized < 7.5) return 5 * magnitude; + return 10 * magnitude; +} + +/** + * Format scale bar label with appropriate unit. + * Always displays integer values (no decimals). + * @param angstroms - Length in Angstroms + * @returns Formatted string (e.g., "5 Å", "20 nm", "1 µm") + */ +export function formatScaleBarLabel(angstroms: number): string { + // Round to nice value first + const nice = roundToNiceValue(angstroms); + + if (nice >= 10000) { // >= 1 µm + return `${Math.round(nice / 10000)} µm`; + } + if (nice >= 100) { // >= 10 nm, show in nm + return `${Math.round(nice / 10)} nm`; + } + return `${Math.round(nice)} Å`; +} + +/** + * Draw scale bar on canvas overlay with nice integer labels. + * The bar length is dynamically calculated to accurately represent the physical distance. + * @param ctx - Canvas 2D context + * @param canvasWidth - Canvas width in pixels + * @param canvasHeight - Canvas height in pixels + * @param imageWidth - Image width in data pixels + * @param pixelSizeAngstrom - Pixel size in Angstroms + * @param displayScale - Canvas scale factor (includes zoom) + * @param targetBarLength - Target length of the scale bar in pixels (default 50) + * @param barThickness - Thickness of the scale bar (default 4) + * @param fontSize - Font size for the label (default 16) + */ +export function drawScaleBar( + ctx: CanvasRenderingContext2D, + canvasWidth: number, + canvasHeight: number, + _imageWidth: number, + pixelSizeAngstrom: number, + displayScale: number = 1, + targetBarLength: number = 50, + barThickness: number = 4, + fontSize: number = 16 +): void { + // Fallback if pixelSize is missing or invalid: show bar in pixels + if (pixelSizeAngstrom <= 0) { + const x = canvasWidth - targetBarLength - 10; + const y = canvasHeight - 20; + + ctx.fillStyle = colors.textPrimary; + ctx.fillRect(x, y, targetBarLength, barThickness); + + ctx.font = "11px sans-serif"; + ctx.fillStyle = colors.textPrimary; + ctx.textAlign = "right"; + ctx.fillText(`${targetBarLength} px`, x + targetBarLength, y - 5); + return; + } + + // Calculate what the target bar length represents in Angstroms at current zoom + const targetAngstroms = targetBarLength * pixelSizeAngstrom / displayScale; + + // Round to a nice value + const niceAngstroms = roundToNiceValue(targetAngstroms); + + // Calculate the actual bar length for the nice value + const barLength = (niceAngstroms / pixelSizeAngstrom) * displayScale; + + const x = canvasWidth - barLength - 10; + const y = canvasHeight - 20; + + // Draw bar (length matches the nice value) + ctx.fillStyle = colors.textPrimary; + ctx.fillRect(x, y, barLength, barThickness); + + // Draw label + const label = formatScaleBarLabel(niceAngstroms); + ctx.font = `${fontSize}px sans-serif`; + ctx.fillStyle = colors.textPrimary; + ctx.textAlign = "right"; + ctx.fillText(label, x + barLength, y - 5); +} + +/** + * Draw ROI on canvas overlay with different shapes. + * @param ctx - Canvas 2D context + * @param x - Center X in canvas pixels + * @param y - Center Y in canvas pixels + * @param shape - ROI shape: "circle", "square", or "rectangle" + * @param radius - Radius for circle, or half-size for square + * @param width - Width for rectangle + * @param height - Height for rectangle + * @param active - Whether ROI is being dragged + */ +export function drawROI( + ctx: CanvasRenderingContext2D, + x: number, + y: number, + shape: "circle" | "square" | "rectangle", + radius: number, + width: number, + height: number, + active: boolean = false +): void { + const strokeColor = active ? colors.accentYellow : colors.accentGreen; + ctx.strokeStyle = strokeColor; + ctx.lineWidth = 2; + + if (shape === "circle") { + ctx.beginPath(); + ctx.arc(x, y, radius, 0, Math.PI * 2); + ctx.stroke(); + } else if (shape === "square") { + const size = radius * 2; + ctx.strokeRect(x - radius, y - radius, size, size); + } else if (shape === "rectangle") { + const halfW = width / 2; + const halfH = height / 2; + ctx.strokeRect(x - halfW, y - halfH, width, height); + } + + // Center crosshair - only show while dragging + if (active) { + ctx.beginPath(); + ctx.moveTo(x - 5, y); + ctx.lineTo(x + 5, y); + ctx.moveTo(x, y - 5); + ctx.lineTo(x, y + 5); + ctx.stroke(); + } +} + +/** + * Draw ROI circle on canvas overlay. + * @param ctx - Canvas 2D context + * @param x - Center X in canvas pixels + * @param y - Center Y in canvas pixels + * @param radius - Radius in canvas pixels + * @param active - Whether ROI is being dragged + */ +export function drawROICircle( + ctx: CanvasRenderingContext2D, + x: number, + y: number, + radius: number, + active: boolean = false +): void { + const strokeColor = active ? colors.accentYellow : colors.accentGreen; + + // Circle + ctx.strokeStyle = strokeColor; + ctx.lineWidth = 2; + ctx.beginPath(); + ctx.arc(x, y, radius, 0, Math.PI * 2); + ctx.stroke(); + + // Center crosshair + ctx.beginPath(); + ctx.moveTo(x - 5, y); + ctx.lineTo(x + 5, y); + ctx.moveTo(x, y - 5); + ctx.lineTo(x, y + 5); + ctx.stroke(); +} + +/** + * Draw crosshair on canvas. + * @param ctx - Canvas 2D context + * @param x - Center X + * @param y - Center Y + * @param size - Half-length of crosshair arms + * @param color - Stroke color + */ +export function drawCrosshair( + ctx: CanvasRenderingContext2D, + x: number, + y: number, + size: number = 10, + color: string = colors.accentGreen +): void { + ctx.strokeStyle = color; + ctx.lineWidth = 2; + ctx.beginPath(); + ctx.moveTo(x - size, y); + ctx.lineTo(x + size, y); + ctx.moveTo(x, y - size); + ctx.lineTo(x, y + size); + ctx.stroke(); +} + +/** + * Calculate canvas scale factor for display. + * Aims for approximately targetSize pixels on screen. + * @param width - Image width + * @param height - Image height + * @param targetSize - Target display size in pixels (default 400) + * @returns Integer scale factor >= 1 + */ +export function calculateDisplayScale( + width: number, + height: number, + targetSize: number = 400 +): number { + return Math.max(1, Math.floor(targetSize / Math.max(width, height))); +} + +/** + * Extract bytes from DataView (handles anywidget's byte transfer). + * @param dataView - DataView from anywidget + * @returns Uint8Array of bytes + */ +export function extractBytes(dataView: DataView | ArrayBuffer | Uint8Array): Uint8Array { + if (dataView instanceof Uint8Array) { + return dataView; + } + if (dataView instanceof ArrayBuffer) { + return new Uint8Array(dataView); + } + // DataView from anywidget + if (dataView && "buffer" in dataView) { + return new Uint8Array( + dataView.buffer, + dataView.byteOffset, + dataView.byteLength + ); + } + return new Uint8Array(0); +} diff --git a/widget/js/core/colormaps.ts b/widget/js/core/colormaps.ts new file mode 100644 index 00000000..7047d1b4 --- /dev/null +++ b/widget/js/core/colormaps.ts @@ -0,0 +1,100 @@ +/** + * Colormap definitions and utilities for image display. + * Shared across Show2D, Show3D, Show4D widgets. + */ + +// Control points for interpolation +export const COLORMAP_POINTS: Record = { + inferno: [ + [0, 0, 4], [40, 11, 84], [101, 21, 110], [159, 42, 99], + [212, 72, 66], [245, 125, 21], [252, 193, 57], [252, 255, 164], + ], + viridis: [ + [68, 1, 84], [72, 36, 117], [65, 68, 135], [53, 95, 141], + [42, 120, 142], [33, 145, 140], [34, 168, 132], [68, 191, 112], + [122, 209, 81], [189, 223, 38], [253, 231, 37], + ], + plasma: [ + [13, 8, 135], [75, 3, 161], [126, 3, 168], [168, 34, 150], + [203, 70, 121], [229, 107, 93], [248, 148, 65], [253, 195, 40], [240, 249, 33], + ], + magma: [ + [0, 0, 4], [28, 16, 68], [79, 18, 123], [129, 37, 129], + [181, 54, 122], [229, 80, 100], [251, 135, 97], [254, 194, 135], [252, 253, 191], + ], + hot: [ + [0, 0, 0], [87, 0, 0], [173, 0, 0], [255, 0, 0], + [255, 87, 0], [255, 173, 0], [255, 255, 0], [255, 255, 128], [255, 255, 255], + ], + gray: [[0, 0, 0], [255, 255, 255]], +}; + +/** Available colormap names */ +export const COLORMAP_NAMES = Object.keys(COLORMAP_POINTS); + +/** Create 256-entry LUT from control points */ +export function createColormapLUT(points: number[][]): Uint8Array { + const lut = new Uint8Array(256 * 3); + for (let i = 0; i < 256; i++) { + const t = (i / 255) * (points.length - 1); + const idx = Math.floor(t); + const frac = t - idx; + const p0 = points[Math.min(idx, points.length - 1)]; + const p1 = points[Math.min(idx + 1, points.length - 1)]; + lut[i * 3] = Math.round(p0[0] + frac * (p1[0] - p0[0])); + lut[i * 3 + 1] = Math.round(p0[1] + frac * (p1[1] - p0[1])); + lut[i * 3 + 2] = Math.round(p0[2] + frac * (p1[2] - p0[2])); + } + return lut; +} + +/** Pre-computed LUTs for all colormaps (flat Uint8Array, 256*3 bytes each) */ +export const COLORMAPS: Record = Object.fromEntries( + Object.entries(COLORMAP_POINTS).map(([name, points]) => [name, createColormapLUT(points)]) +); + +/** Apply colormap to a single normalized value [0,1] */ +export function applyColormapValue( + value: number, + cmap: number[][] +): [number, number, number] { + const n = cmap.length - 1; + const t = Math.max(0, Math.min(1, value)) * n; + const i = Math.min(Math.floor(t), n - 1); + const f = t - i; + return [ + Math.round(cmap[i][0] * (1 - f) + cmap[i + 1][0] * f), + Math.round(cmap[i][1] * (1 - f) + cmap[i + 1][1] * f), + Math.round(cmap[i][2] * (1 - f) + cmap[i + 1][2] * f), + ]; +} + +/** + * Apply colormap to uint8 grayscale data, returning RGBA ImageData. + * @param data - Uint8Array of grayscale values (0-255) + * @param width - Image width + * @param height - Image height + * @param cmapName - Name of colormap to use + * @returns Uint8ClampedArray of RGBA values + */ +export function applyColormapToImage( + data: Uint8Array, + width: number, + height: number, + cmapName: string +): Uint8ClampedArray { + const lut = COLORMAPS[cmapName] || COLORMAPS.inferno; + const rgba = new Uint8ClampedArray(width * height * 4); + + for (let i = 0; i < data.length; i++) { + const v = Math.max(0, Math.min(255, data[i])); + const j = i * 4; + const lutIdx = v * 3; + rgba[j] = lut[lutIdx]; + rgba[j + 1] = lut[lutIdx + 1]; + rgba[j + 2] = lut[lutIdx + 2]; + rgba[j + 3] = 255; + } + + return rgba; +} diff --git a/widget/js/core/colors.ts b/widget/js/core/colors.ts new file mode 100644 index 00000000..996ee68d --- /dev/null +++ b/widget/js/core/colors.ts @@ -0,0 +1,71 @@ +/** + * Shared color palette for all bobleesj.widget components. + * Single source of truth for theming across Show2D, Show3D, Show4D, and Reconstruct. + */ + +// Primary color definitions (SCREAMING_SNAKE_CASE for constants) +export const COLORS = { + // Backgrounds + BG: "#1a1a1a", + BG_PANEL: "#222", + BG_INPUT: "#333", + BG_CANVAS: "#000", + + // Borders + BORDER: "#444", + BORDER_LIGHT: "#555", + + // Text + TEXT_PRIMARY: "#fff", + TEXT_SECONDARY: "#aaa", + TEXT_MUTED: "#888", + TEXT_DIM: "#666", + + // Accent colors + ACCENT: "#0af", + ACCENT_GREEN: "#0f0", + ACCENT_RED: "#f00", + ACCENT_ORANGE: "#fa0", + ACCENT_CYAN: "#0cf", + ACCENT_YELLOW: "#ff0", +} as const; + +// Convenience alias with camelCase keys (for existing widget code) +export const colors = { + bg: COLORS.BG, + bgPanel: COLORS.BG_PANEL, + bgInput: COLORS.BG_INPUT, + bgCanvas: COLORS.BG_CANVAS, + border: COLORS.BORDER, + borderLight: COLORS.BORDER_LIGHT, + textPrimary: COLORS.TEXT_PRIMARY, + textSecondary: COLORS.TEXT_SECONDARY, + textMuted: COLORS.TEXT_MUTED, + textDim: COLORS.TEXT_DIM, + accent: COLORS.ACCENT, + accentGreen: COLORS.ACCENT_GREEN, + accentRed: COLORS.ACCENT_RED, + accentOrange: COLORS.ACCENT_ORANGE, + accentCyan: COLORS.ACCENT_CYAN, + accentYellow: COLORS.ACCENT_YELLOW, +} as const; + +// CSS variable export for vanilla JS widgets +export const cssVars = ` + --bg: ${COLORS.BG}; + --bg-panel: ${COLORS.BG_PANEL}; + --bg-input: ${COLORS.BG_INPUT}; + --bg-canvas: ${COLORS.BG_CANVAS}; + --border: ${COLORS.BORDER}; + --border-light: ${COLORS.BORDER_LIGHT}; + --text-primary: ${COLORS.TEXT_PRIMARY}; + --text-secondary: ${COLORS.TEXT_SECONDARY}; + --text-muted: ${COLORS.TEXT_MUTED}; + --text-dim: ${COLORS.TEXT_DIM}; + --accent: ${COLORS.ACCENT}; + --accent-green: ${COLORS.ACCENT_GREEN}; + --accent-red: ${COLORS.ACCENT_RED}; + --accent-orange: ${COLORS.ACCENT_ORANGE}; + --accent-cyan: ${COLORS.ACCENT_CYAN}; + --accent-yellow: ${COLORS.ACCENT_YELLOW}; +`; diff --git a/widget/js/core/export.ts b/widget/js/core/export.ts new file mode 100644 index 00000000..6a0324e6 --- /dev/null +++ b/widget/js/core/export.ts @@ -0,0 +1,135 @@ +/** + * Export utilities for downloading widget canvases as images. + * Composites multiple canvas layers and burns in overlays (scale bars, etc). + */ + +/** + * Generate a timestamped filename. + * @param prefix - Filename prefix (e.g., "show2d", "show3d") + * @param extension - File extension (default: "png") + */ +export function generateFilename(prefix: string, extension: string = "png"): string { + const now = new Date(); + const timestamp = now.toISOString() + .replace(/[:.]/g, "-") + .slice(0, 19); + return `${prefix}_${timestamp}.${extension}`; +} + +/** + * Composite multiple canvases into a single canvas. + * Layers are drawn in order (first = bottom, last = top). + * @param layers - Array of canvases to composite + * @param width - Output width + * @param height - Output height + */ +export function compositeCanvases( + layers: (HTMLCanvasElement | null)[], + width: number, + height: number +): HTMLCanvasElement { + const output = document.createElement("canvas"); + output.width = width; + output.height = height; + const ctx = output.getContext("2d"); + + if (ctx) { + // Fill with black background + ctx.fillStyle = "#000"; + ctx.fillRect(0, 0, width, height); + + // Draw each layer + for (const layer of layers) { + if (layer) { + ctx.drawImage(layer, 0, 0, width, height); + } + } + } + + return output; +} + +/** + * Download a canvas as a PNG file. + * @param canvas - The canvas to download + * @param filename - Output filename + */ +export function downloadCanvas(canvas: HTMLCanvasElement, filename: string): void { + canvas.toBlob((blob) => { + if (!blob) return; + + const url = URL.createObjectURL(blob); + const link = document.createElement("a"); + link.href = url; + link.download = filename; + link.click(); + + // Cleanup + URL.revokeObjectURL(url); + }, "image/png"); +} + +/** + * Export a widget's canvas with overlays burned in. + * @param imageCanvas - Main image canvas + * @param overlayCanvas - Overlay canvas (scale bar, etc) + * @param prefix - Filename prefix + * @param label - Optional label to append to filename + */ +export function exportWithOverlay( + imageCanvas: HTMLCanvasElement | null, + overlayCanvas: HTMLCanvasElement | null, + prefix: string, + label?: string +): void { + if (!imageCanvas) return; + + const width = imageCanvas.width; + const height = imageCanvas.height; + + const output = compositeCanvases([imageCanvas, overlayCanvas], width, height); + + // Generate filename with optional label + const cleanLabel = label ? `_${label.replace(/[^a-zA-Z0-9]/g, "_")}` : ""; + const filename = generateFilename(`${prefix}${cleanLabel}`); + + downloadCanvas(output, filename); +} + +/** + * Export multiple canvases as a ZIP file (for galleries). + * Requires JSZip to be available. + */ +export async function exportGalleryAsZip( + canvases: { image: HTMLCanvasElement | null; overlay: HTMLCanvasElement | null; label: string }[], + prefix: string +): Promise { + // Dynamic import to avoid bundling JSZip if not needed + const JSZip = (await import("jszip")).default; + const zip = new JSZip(); + + const timestamp = new Date().toISOString().replace(/[:.]/g, "-").slice(0, 19); + + for (let i = 0; i < canvases.length; i++) { + const { image, overlay, label } = canvases[i]; + if (!image) continue; + + const output = compositeCanvases([image, overlay], image.width, image.height); + const cleanLabel = label.replace(/[^a-zA-Z0-9]/g, "_"); + const filename = `${String(i + 1).padStart(3, "0")}_${cleanLabel}.png`; + + const blob = await new Promise((resolve) => { + output.toBlob((b) => resolve(b!), "image/png"); + }); + + zip.file(filename, blob); + } + + const zipBlob = await zip.generateAsync({ type: "blob" }); + const url = URL.createObjectURL(zipBlob); + const link = document.createElement("a"); + link.href = url; + link.download = `${prefix}_gallery_${timestamp}.zip`; + link.click(); + URL.revokeObjectURL(url); +} diff --git a/widget/js/core/fft-utils.ts b/widget/js/core/fft-utils.ts new file mode 100644 index 00000000..8d6a5626 --- /dev/null +++ b/widget/js/core/fft-utils.ts @@ -0,0 +1,161 @@ +/** + * FFT and histogram rendering utilities. + * Shared across Show2D and Show3D widgets. + */ + +import { COLORMAPS } from "./colormaps"; +import { colors } from "./colors"; + +// ============================================================================ +// FFT Rendering +// ============================================================================ + +/** + * Render FFT magnitude to canvas with log scale and colormap. + * @param ctx - Canvas 2D context + * @param fftMag - FFT magnitude data (Float32Array) + * @param width - Image width + * @param height - Image height + * @param panelSize - Canvas panel size + * @param zoom - Zoom level (default 3 for center detail) + * @param panX - Pan X offset + * @param panY - Pan Y offset + * @param cmapName - Colormap name (default "inferno") + */ +export function renderFFT( + ctx: CanvasRenderingContext2D, + fftMag: Float32Array, + width: number, + height: number, + panelSize: number, + zoom: number = 3, + panX: number = 0, + panY: number = 0, + cmapName: string = "inferno" +): void { + // Log scale and normalize + let min = Infinity; + let max = -Infinity; + const logData = new Float32Array(fftMag.length); + + for (let i = 0; i < fftMag.length; i++) { + logData[i] = Math.log(1 + fftMag[i]); + if (logData[i] < min) min = logData[i]; + if (logData[i] > max) max = logData[i]; + } + + const lut = COLORMAPS[cmapName] || COLORMAPS.inferno; + + // Create offscreen canvas at native resolution + const offscreen = document.createElement("canvas"); + offscreen.width = width; + offscreen.height = height; + const offCtx = offscreen.getContext("2d"); + if (!offCtx) return; + + const imgData = offCtx.createImageData(width, height); + const range = max - min || 1; + + for (let i = 0; i < logData.length; i++) { + const v = Math.floor(((logData[i] - min) / range) * 255); + const j = i * 4; + imgData.data[j] = lut[v * 3]; + imgData.data[j + 1] = lut[v * 3 + 1]; + imgData.data[j + 2] = lut[v * 3 + 2]; + imgData.data[j + 3] = 255; + } + offCtx.putImageData(imgData, 0, 0); + + // Draw with zoom/pan - center the zoomed view + const scale = panelSize / Math.max(width, height); + ctx.imageSmoothingEnabled = false; + ctx.clearRect(0, 0, panelSize, panelSize); + ctx.save(); + + const centerOffsetX = (panelSize - width * scale * zoom) / 2 + panX; + const centerOffsetY = (panelSize - height * scale * zoom) / 2 + panY; + + ctx.translate(centerOffsetX, centerOffsetY); + ctx.scale(zoom, zoom); + ctx.drawImage(offscreen, 0, 0, width * scale, height * scale); + ctx.restore(); +} + +// ============================================================================ +// Histogram Rendering +// ============================================================================ + +/** + * Render histogram to canvas. + * @param ctx - Canvas 2D context + * @param counts - Histogram bin counts + * @param panelSize - Canvas panel size + * @param accentColor - Bar color (default: accent blue) + * @param bgColor - Background color (default: panel background) + */ +export function renderHistogram( + ctx: CanvasRenderingContext2D, + counts: number[], + panelSize: number, + accentColor: string = colors.accent, + bgColor: string = colors.bgPanel +): void { + const w = panelSize; + const h = panelSize; + + // Clear and fill background + ctx.fillStyle = bgColor; + ctx.fillRect(0, 0, w, h); + + // Only draw bars if we have data + if (!counts || counts.length === 0) return; + + const maxCount = Math.max(...counts); + if (maxCount === 0) return; + + // Add padding for centering + const padding = 8; + const drawWidth = w - 2 * padding; + const drawHeight = h - padding - 5; // 5px bottom margin + const barWidth = drawWidth / counts.length; + + ctx.fillStyle = accentColor; + for (let i = 0; i < counts.length; i++) { + const barHeight = (counts[i] / maxCount) * drawHeight; + ctx.fillRect(padding + i * barWidth, h - padding - barHeight, barWidth - 1, barHeight); + } +} + +// ============================================================================ +// FFT Shift (move DC component to center) +// ============================================================================ + +/** + * Shift FFT data to center the DC component. + * Modifies data in place. + */ +export function fftshift(data: Float32Array, width: number, height: number): void { + const halfW = width >> 1; + const halfH = height >> 1; + const temp = new Float32Array(width * height); + + for (let y = 0; y < height; y++) { + for (let x = 0; x < width; x++) { + const newY = (y + halfH) % height; + const newX = (x + halfW) % width; + temp[newY * width + newX] = data[y * width + x]; + } + } + data.set(temp); +} + +/** + * Compute FFT magnitude from real and imaginary parts. + */ +export function computeMagnitude(real: Float32Array, imag: Float32Array): Float32Array { + const mag = new Float32Array(real.length); + for (let i = 0; i < real.length; i++) { + mag[i] = Math.sqrt(real[i] ** 2 + imag[i] ** 2); + } + return mag; +} diff --git a/widget/js/core/format.ts b/widget/js/core/format.ts new file mode 100644 index 00000000..a168c68c --- /dev/null +++ b/widget/js/core/format.ts @@ -0,0 +1,58 @@ +/** + * Number and text formatting utilities. + */ + +/** + * Format a number for display with appropriate precision. + * Uses exponential notation for very large or small values. + * @param val - Value to format + * @param decimals - Number of decimal places (default 2) + * @returns Formatted string + */ +export function formatNumber(val: number, decimals: number = 2): string { + if (val === 0) return "0"; + if (Math.abs(val) >= 1000 || Math.abs(val) < 0.01) { + return val.toExponential(decimals); + } + return val.toFixed(decimals); +} + +/** + * Format bytes as human-readable size. + * @param bytes - Number of bytes + * @returns Formatted string (e.g., "1.5 MB") + */ +export function formatBytes(bytes: number): string { + if (bytes === 0) return "0 B"; + const k = 1024; + const sizes = ["B", "KB", "MB", "GB"]; + const i = Math.floor(Math.log(bytes) / Math.log(k)); + return `${(bytes / Math.pow(k, i)).toFixed(1)} ${sizes[i]}`; +} + +/** + * Format time duration. + * @param seconds - Duration in seconds + * @returns Formatted string (e.g., "1.5 s" or "150 ms") + */ +export function formatDuration(seconds: number): string { + if (seconds < 0.001) { + return `${(seconds * 1e6).toFixed(0)} µs`; + } + if (seconds < 1) { + return `${(seconds * 1000).toFixed(1)} ms`; + } + if (seconds < 60) { + return `${seconds.toFixed(2)} s`; + } + const mins = Math.floor(seconds / 60); + const secs = seconds % 60; + return `${mins}m ${secs.toFixed(0)}s`; +} + +/** + * Clamp a value between min and max. + */ +export function clamp(val: number, min: number, max: number): number { + return Math.max(min, Math.min(max, val)); +} diff --git a/widget/js/core/hooks.ts b/widget/js/core/hooks.ts new file mode 100644 index 00000000..c5d59a39 --- /dev/null +++ b/widget/js/core/hooks.ts @@ -0,0 +1,211 @@ +/** + * Shared React hooks for widget functionality. + * Provides reusable zoom/pan and resize logic. + */ + +import * as React from "react"; + +// ============================================================================ +// Constants +// ============================================================================ +export const ZOOM_LIMITS = { + MIN: 0.5, + MAX: 10, + WHEEL_IN: 1.1, + WHEEL_OUT: 0.9, +} as const; + +// ============================================================================ +// Types +// ============================================================================ +export interface ZoomPanState { + zoom: number; + panX: number; + panY: number; +} + +export const DEFAULT_ZOOM_PAN: ZoomPanState = { + zoom: 1, + panX: 0, + panY: 0, +}; + +// ============================================================================ +// useZoomPan Hook +// ============================================================================ +export interface UseZoomPanOptions { + canvasRef: React.RefObject; + canvasWidth: number; + canvasHeight: number; + initialState?: ZoomPanState; +} + +export interface UseZoomPanResult { + state: ZoomPanState; + setState: React.Dispatch>; + reset: () => void; + handleWheel: (e: React.WheelEvent) => void; + handleMouseDown: (e: React.MouseEvent) => void; + handleMouseMove: (e: React.MouseEvent) => void; + handleMouseUp: () => void; + handleDoubleClick: () => void; + isDragging: boolean; +} + +export function useZoomPan(options: UseZoomPanOptions): UseZoomPanResult { + const { canvasRef, canvasWidth, canvasHeight, initialState = DEFAULT_ZOOM_PAN } = options; + + const [state, setState] = React.useState(initialState); + const [isDragging, setIsDragging] = React.useState(false); + const [dragStart, setDragStart] = React.useState<{ x: number; y: number; panX: number; panY: number } | null>(null); + + const reset = React.useCallback(() => { + setState(DEFAULT_ZOOM_PAN); + }, []); + + const handleWheel = React.useCallback((e: React.WheelEvent) => { + const canvas = canvasRef.current; + if (!canvas) return; + + const rect = canvas.getBoundingClientRect(); + const scaleX = canvas.width / rect.width; + const scaleY = canvas.height / rect.height; + + // Mouse position in canvas coordinates + const mouseX = (e.clientX - rect.left) * scaleX; + const mouseY = (e.clientY - rect.top) * scaleY; + + // Canvas center + const cx = canvasWidth / 2; + const cy = canvasHeight / 2; + + setState(prev => { + // Calculate position in image space + const imageX = (mouseX - cx - prev.panX) / prev.zoom + cx; + const imageY = (mouseY - cy - prev.panY) / prev.zoom + cy; + + // Apply zoom factor + const zoomFactor = e.deltaY > 0 ? ZOOM_LIMITS.WHEEL_OUT : ZOOM_LIMITS.WHEEL_IN; + const newZoom = Math.max(ZOOM_LIMITS.MIN, Math.min(ZOOM_LIMITS.MAX, prev.zoom * zoomFactor)); + + // Calculate new pan to keep mouse position fixed + const newPanX = mouseX - (imageX - cx) * newZoom - cx; + const newPanY = mouseY - (imageY - cy) * newZoom - cy; + + return { zoom: newZoom, panX: newPanX, panY: newPanY }; + }); + }, [canvasRef, canvasWidth, canvasHeight]); + + const handleMouseDown = React.useCallback((e: React.MouseEvent) => { + setIsDragging(true); + setDragStart({ x: e.clientX, y: e.clientY, panX: state.panX, panY: state.panY }); + }, [state.panX, state.panY]); + + const handleMouseMove = React.useCallback((e: React.MouseEvent) => { + if (!isDragging || !dragStart) return; + + const canvas = canvasRef.current; + if (!canvas) return; + + const rect = canvas.getBoundingClientRect(); + const scaleX = canvas.width / rect.width; + const scaleY = canvas.height / rect.height; + + const dx = (e.clientX - dragStart.x) * scaleX; + const dy = (e.clientY - dragStart.y) * scaleY; + + setState(prev => ({ ...prev, panX: dragStart.panX + dx, panY: dragStart.panY + dy })); + }, [isDragging, dragStart, canvasRef]); + + const handleMouseUp = React.useCallback(() => { + setIsDragging(false); + setDragStart(null); + }, []); + + const handleDoubleClick = React.useCallback(() => { + reset(); + }, [reset]); + + return { + state, + setState, + reset, + handleWheel, + handleMouseDown, + handleMouseMove, + handleMouseUp, + handleDoubleClick, + isDragging, + }; +} + +// ============================================================================ +// useResize Hook +// ============================================================================ +export interface UseResizeOptions { + initialSize: number; + minSize?: number; + maxSize?: number; +} + +export interface UseResizeResult { + size: number; + setSize: React.Dispatch>; + isResizing: boolean; + handleResizeStart: (e: React.MouseEvent) => void; +} + +export function useResize(options: UseResizeOptions): UseResizeResult { + const { initialSize, minSize = 80, maxSize = 600 } = options; + + const [size, setSize] = React.useState(initialSize); + const [isResizing, setIsResizing] = React.useState(false); + const [resizeStart, setResizeStart] = React.useState<{ x: number; y: number; size: number } | null>(null); + + const handleResizeStart = React.useCallback((e: React.MouseEvent) => { + e.stopPropagation(); + e.preventDefault(); + setIsResizing(true); + setResizeStart({ x: e.clientX, y: e.clientY, size }); + }, [size]); + + React.useEffect(() => { + if (!isResizing || !resizeStart) return; + + const handleMouseMove = (e: MouseEvent) => { + const delta = Math.max(e.clientX - resizeStart.x, e.clientY - resizeStart.y); + const newSize = Math.max(minSize, Math.min(maxSize, resizeStart.size + delta)); + setSize(newSize); + }; + + const handleMouseUp = () => { + setIsResizing(false); + setResizeStart(null); + }; + + document.addEventListener("mousemove", handleMouseMove); + document.addEventListener("mouseup", handleMouseUp); + return () => { + document.removeEventListener("mousemove", handleMouseMove); + document.removeEventListener("mouseup", handleMouseUp); + }; + }, [isResizing, resizeStart, minSize, maxSize]); + + return { size, setSize, isResizing, handleResizeStart }; +} + +// ============================================================================ +// usePreventScroll Hook +// ============================================================================ +export function usePreventScroll(refs: React.RefObject[]): void { + React.useEffect(() => { + const preventDefault = (e: WheelEvent) => e.preventDefault(); + const elements = refs.map(ref => ref.current).filter(Boolean) as HTMLElement[]; + + elements.forEach(el => el.addEventListener("wheel", preventDefault, { passive: false })); + + return () => { + elements.forEach(el => el.removeEventListener("wheel", preventDefault)); + }; + }, [refs]); +} diff --git a/widget/js/core/index.ts b/widget/js/core/index.ts new file mode 100644 index 00000000..45769f87 --- /dev/null +++ b/widget/js/core/index.ts @@ -0,0 +1,86 @@ +/** + * Core utilities for bobleesj.widget components. + * Re-exports all shared modules. + */ + +// Colors and theming +export { COLORS, colors, cssVars } from "./colors"; + +// Colormaps +export { + COLORMAP_NAMES, + COLORMAP_POINTS, + COLORMAPS, + applyColormapToImage, + applyColormapValue, + createColormapLUT, +} from "./colormaps"; + +// Canvas utilities +export { + calculateDisplayScale, + calculateNiceScaleBar, + drawCrosshair, + drawROI, + drawROICircle, + drawScaleBar, + extractBytes, + formatScaleBarLabel, +} from "./canvas"; + +// Formatting +export { + clamp, + formatBytes, + formatDuration, + formatNumber, +} from "./format"; + +// Base CSS +export { baseCSS } from "./styles"; + +// FFT and histogram utilities +export { + computeMagnitude, + fftshift, + renderFFT, + renderHistogram, +} from "./fft-utils"; + +// React hooks +export { + DEFAULT_ZOOM_PAN, + ZOOM_LIMITS, + usePreventScroll, + useResize, + useZoomPan, + type UseResizeOptions, + type UseResizeResult, + type UseZoomPanOptions, + type UseZoomPanResult, + type ZoomPanState, +} from "./hooks"; + +// WebGPU hook +export { useWebGPU, type UseWebGPUResult } from "./webgpu-hook"; + +// Advanced canvas utilities (high-DPI, colormap rendering) +export { + canvasToBlob, + drawCrosshairHiDPI, + drawScaleBarHiDPI, + drawWithZoomPan, + formatScaleLabel, + renderFloat32WithColormap, + renderWithColormap, + roundToNiceValue, +} from "./canvas-utils"; + +// Export utilities +export { + compositeCanvases, + downloadCanvas, + exportGalleryAsZip, + exportWithOverlay, + generateFilename, +} from "./export"; diff --git a/widget/js/core/styles.ts b/widget/js/core/styles.ts new file mode 100644 index 00000000..fbc3c8cf --- /dev/null +++ b/widget/js/core/styles.ts @@ -0,0 +1,295 @@ +/** + * Shared CSS for widget components. + * Base styles used by Show2D, Show3D, and other vanilla JS widgets. + */ + +export const baseCSS = ` +/* ============================================================================ + Base Styles - Shared across Show2D, Show3D + ============================================================================ */ + +/* Root container */ +.widget-root { + font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif; + background-color: var(--bg, #1a1a1a); + color: var(--text-primary, #fff); + padding: 12px; + border-radius: 6px; + display: inline-block; + min-width: 320px; + + --bg: #1a1a1a; + --bg-panel: #222; + --bg-input: #333; + --bg-canvas: #000; + --border: #444; + --border-light: #555; + --text-primary: #fff; + --text-secondary: #aaa; + --text-muted: #888; + --text-dim: #666; + --accent: #0af; + --accent-green: #0f0; + --accent-red: #f00; +} + +.widget-root:focus { + outline: 2px solid var(--accent); + outline-offset: 2px; +} + +/* Title bar */ +.widget-title-bar { + margin-bottom: 8px; +} + +.widget-title { + color: var(--accent); + font-weight: bold; + font-size: 13px; +} + +/* Canvas container */ +.widget-canvas-container { + position: relative; + background-color: var(--bg-canvas); + border: 1px solid var(--border); + border-radius: 4px; + overflow: hidden; +} + +.widget-canvas { + display: block; + image-rendering: pixelated; + image-rendering: crisp-edges; +} + +.widget-overlay { + position: absolute; + top: 0; + left: 0; + pointer-events: none; +} + +/* Panels */ +.widget-panel { + background-color: var(--bg-panel); + border: 1px solid var(--border); + border-radius: 4px; + padding: 6px; +} + +.widget-panel-title { + font-size: 10px; + color: var(--text-muted); + text-transform: uppercase; + margin-bottom: 4px; +} + +/* Control group */ +.widget-control-group { + display: flex; + align-items: center; + gap: 6px; + background-color: var(--bg-panel); + padding: 4px 8px; + border-radius: 4px; + border: 1px solid var(--border); +} + +/* Buttons */ +.widget-btn { + background-color: var(--bg-input); + border: 1px solid var(--border-light); + color: var(--text-secondary); + min-width: 32px; + height: 28px; + border-radius: 4px; + cursor: pointer; + font-size: 12px; + display: flex; + align-items: center; + justify-content: center; + transition: all 0.15s; + padding: 0 8px; +} + +.widget-btn:hover { + background-color: var(--border); + color: var(--text-primary); +} + +.widget-btn:active, +.widget-btn-active { + background-color: var(--accent); + color: #000; + border-color: var(--accent); +} + +.widget-btn-primary { + background-color: var(--accent); + color: #000; + border-color: var(--accent); +} + +.widget-btn-primary:hover { + background-color: #0cf; +} + +/* Slider */ +.widget-slider { + flex: 1; + height: 6px; + -webkit-appearance: none; + appearance: none; + background: var(--border); + border-radius: 3px; + cursor: pointer; +} + +.widget-slider::-webkit-slider-thumb { + -webkit-appearance: none; + width: 14px; + height: 14px; + background: var(--accent); + border-radius: 50%; + cursor: pointer; + border: 2px solid var(--bg); + box-shadow: 0 1px 3px rgba(0, 0, 0, 0.4); +} + +.widget-slider::-moz-range-thumb { + width: 14px; + height: 14px; + background: var(--accent); + border-radius: 50%; + cursor: pointer; + border: 2px solid var(--bg); +} + +.widget-slider:focus { + outline: none; +} + +/* Inputs */ +.widget-input { + background-color: var(--bg-input); + border: 1px solid var(--border-light); + color: var(--text-secondary); + border-radius: 3px; + padding: 4px 6px; + font-size: 11px; + font-family: monospace; +} + +.widget-input:focus { + outline: none; + border-color: var(--accent); + color: var(--text-primary); +} + +.widget-input-small { + width: 45px; + height: 24px; + text-align: center; +} + +/* Toggles / Checkboxes */ +.widget-toggle { + display: flex; + align-items: center; + gap: 4px; + font-size: 11px; + color: var(--text-muted); + cursor: pointer; + user-select: none; +} + +.widget-toggle:hover { + color: var(--text-primary); +} + +.widget-toggle input[type="checkbox"] { + width: 14px; + height: 14px; + accent-color: var(--accent); + cursor: pointer; +} + +/* Select */ +.widget-select { + background-color: var(--bg-input); + border: 1px solid var(--border-light); + color: var(--text-secondary); + border-radius: 3px; + padding: 4px 8px; + font-size: 11px; + cursor: pointer; +} + +.widget-select:focus { + outline: none; + border-color: var(--accent); +} + +/* Stats bar */ +.widget-stats-bar { + display: flex; + flex-wrap: wrap; + gap: 16px; + background-color: var(--bg-panel); + padding: 6px 12px; + border-radius: 4px; + border: 1px solid var(--border); +} + +.widget-stat-item { + display: flex; + gap: 6px; + align-items: baseline; +} + +.widget-stat-label { + font-size: 10px; + color: var(--text-dim); +} + +.widget-stat-value { + font-size: 11px; + font-family: monospace; + color: var(--accent); +} + +/* Labels */ +.widget-label { + color: var(--text-secondary); + font-size: 11px; +} + +.widget-label-small { + color: var(--text-dim); + font-size: 10px; +} + +/* Layout helpers */ +.widget-row { + display: flex; + align-items: center; + gap: 8px; +} + +.widget-col { + display: flex; + flex-direction: column; + gap: 8px; +} + +.widget-flex { + flex: 1; +} + +/* Monospace text */ +.widget-mono { + font-family: monospace; +} +`; diff --git a/widget/js/core/webgpu-hook.ts b/widget/js/core/webgpu-hook.ts new file mode 100644 index 00000000..3c72f846 --- /dev/null +++ b/widget/js/core/webgpu-hook.ts @@ -0,0 +1,37 @@ +/** + * Shared WebGPU FFT hook for all widgets. + * Provides consistent GPU acceleration across Show4DSTEM and Reconstruct. + */ + +import * as React from "react"; +import { getWebGPUFFT, WebGPUFFT } from "../webgpu-fft"; + +export interface UseWebGPUResult { + gpuFFT: WebGPUFFT | null; + gpuReady: boolean; +} + +/** + * Hook to initialize WebGPU FFT on mount. + * Returns null if WebGPU is not available (falls back to CPU). + */ +export function useWebGPU(): UseWebGPUResult { + const gpuFFTRef = React.useRef(null); + const [gpuReady, setGpuReady] = React.useState(false); + + React.useEffect(() => { + let cancelled = false; + + getWebGPUFFT().then(fft => { + if (cancelled) return; + if (fft) { + gpuFFTRef.current = fft; + setGpuReady(true); + } + }); + + return () => { cancelled = true; }; + }, []); + + return { gpuFFT: gpuFFTRef.current, gpuReady }; +} diff --git a/widget/js/index.jsx b/widget/js/index.jsx deleted file mode 100644 index a3341f63..00000000 --- a/widget/js/index.jsx +++ /dev/null @@ -1,33 +0,0 @@ -import * as React from "react"; -import * as ReactDOM from "react-dom/client"; - -function Widget({ model }) { - const [count, setCount] = React.useState(model.get("count")); - - React.useEffect(() => { - const onChange = () => setCount(model.get("count")); - model.on("change:count", onChange); - return () => model.off("change:count", onChange); - }, [model]); - - const handleClick = () => { - model.set("count", count + 1); - model.save_changes(); - }; - - return ( -
-

quantem.widget

-

Count: {count}

- -
- ); -} - -function render({ model, el }) { - const root = ReactDOM.createRoot(el); - root.render(); - return () => root.unmount(); -} - -export default { render }; diff --git a/widget/js/shared.ts b/widget/js/shared.ts new file mode 100644 index 00000000..cb14fb33 --- /dev/null +++ b/widget/js/shared.ts @@ -0,0 +1,221 @@ +/** + * Shared utilities for widget components. + * Contains CPU FFT fallback and band-pass filtering. + * Re-exports commonly used utilities from core. + */ + +// Re-export colormaps from core +export { COLORMAP_NAMES, COLORMAP_POINTS, COLORMAPS, createColormapLUT, applyColormapValue, applyColormapToImage } from "./core/colormaps"; + +// Re-export fftshift from core (also available here for backward compatibility) +export { fftshift, computeMagnitude, renderFFT, renderHistogram } from "./core/fft-utils"; + +// Re-export zoom constants from core hooks +export { ZOOM_LIMITS } from "./core/hooks"; +export const MIN_ZOOM = 0.5; // Legacy alias +export const MAX_ZOOM = 10; // Legacy alias + +// ============================================================================ +// CPU FFT Implementation (Cooley-Tukey radix-2) - Fallback when WebGPU unavailable +// Supports ANY size via automatic zero-padding to next power of 2 +// ============================================================================ + +/** Get next power of 2 >= n */ +function nextPow2(n: number): number { + return Math.pow(2, Math.ceil(Math.log2(n))); +} + +/** Check if n is a power of 2 */ +function isPow2(n: number): boolean { + return n > 0 && (n & (n - 1)) === 0; +} + +/** Internal 1D FFT - requires power-of-2 size */ +function fft1dPow2(real: Float32Array, imag: Float32Array, inverse: boolean = false) { + const n = real.length; + if (n <= 1) return; + + // Bit-reversal permutation + let j = 0; + for (let i = 0; i < n - 1; i++) { + if (i < j) { + [real[i], real[j]] = [real[j], real[i]]; + [imag[i], imag[j]] = [imag[j], imag[i]]; + } + let k = n >> 1; + while (k <= j) { j -= k; k >>= 1; } + j += k; + } + + // Cooley-Tukey FFT + const sign = inverse ? 1 : -1; + for (let len = 2; len <= n; len <<= 1) { + const halfLen = len >> 1; + const angle = (sign * 2 * Math.PI) / len; + const wReal = Math.cos(angle); + const wImag = Math.sin(angle); + + for (let i = 0; i < n; i += len) { + let curReal = 1, curImag = 0; + for (let k = 0; k < halfLen; k++) { + const evenIdx = i + k; + const oddIdx = i + k + halfLen; + + const tReal = curReal * real[oddIdx] - curImag * imag[oddIdx]; + const tImag = curReal * imag[oddIdx] + curImag * real[oddIdx]; + + real[oddIdx] = real[evenIdx] - tReal; + imag[oddIdx] = imag[evenIdx] - tImag; + real[evenIdx] += tReal; + imag[evenIdx] += tImag; + + const newReal = curReal * wReal - curImag * wImag; + curImag = curReal * wImag + curImag * wReal; + curReal = newReal; + } + } + } + + if (inverse) { + for (let i = 0; i < n; i++) { + real[i] /= n; + imag[i] /= n; + } + } +} + +/** + * 1D FFT - supports ANY size via zero-padding + * Modifies arrays in-place + */ +export function fft1d(real: Float32Array, imag: Float32Array, inverse: boolean = false) { + const n = real.length; + if (isPow2(n)) { + fft1dPow2(real, imag, inverse); + return; + } + + // Pad to next power of 2 + const paddedN = nextPow2(n); + const paddedReal = new Float32Array(paddedN); + const paddedImag = new Float32Array(paddedN); + paddedReal.set(real); + paddedImag.set(imag); + + fft1dPow2(paddedReal, paddedImag, inverse); + + // Copy back (truncate to original size) + for (let i = 0; i < n; i++) { + real[i] = paddedReal[i]; + imag[i] = paddedImag[i]; + } +} + +/** + * 2D FFT - supports ANY size via zero-padding + * Modifies arrays in-place + */ +export function fft2d(real: Float32Array, imag: Float32Array, width: number, height: number, inverse: boolean = false) { + const paddedW = nextPow2(width); + const paddedH = nextPow2(height); + const needsPadding = paddedW !== width || paddedH !== height; + + // Work arrays (padded if needed) + let workReal: Float32Array; + let workImag: Float32Array; + + if (needsPadding) { + workReal = new Float32Array(paddedW * paddedH); + workImag = new Float32Array(paddedW * paddedH); + // Copy original data into top-left corner + for (let y = 0; y < height; y++) { + for (let x = 0; x < width; x++) { + const srcIdx = y * width + x; + const dstIdx = y * paddedW + x; + workReal[dstIdx] = real[srcIdx]; + workImag[dstIdx] = imag[srcIdx]; + } + } + } else { + workReal = real; + workImag = imag; + } + + // FFT on rows (padded width) + const rowReal = new Float32Array(paddedW); + const rowImag = new Float32Array(paddedW); + for (let y = 0; y < paddedH; y++) { + const offset = y * paddedW; + for (let x = 0; x < paddedW; x++) { + rowReal[x] = workReal[offset + x]; + rowImag[x] = workImag[offset + x]; + } + fft1dPow2(rowReal, rowImag, inverse); + for (let x = 0; x < paddedW; x++) { + workReal[offset + x] = rowReal[x]; + workImag[offset + x] = rowImag[x]; + } + } + + // FFT on columns (padded height) + const colReal = new Float32Array(paddedH); + const colImag = new Float32Array(paddedH); + for (let x = 0; x < paddedW; x++) { + for (let y = 0; y < paddedH; y++) { + colReal[y] = workReal[y * paddedW + x]; + colImag[y] = workImag[y * paddedW + x]; + } + fft1dPow2(colReal, colImag, inverse); + for (let y = 0; y < paddedH; y++) { + workReal[y * paddedW + x] = colReal[y]; + workImag[y * paddedW + x] = colImag[y]; + } + } + + // Copy back to original arrays if padded + if (needsPadding) { + for (let y = 0; y < height; y++) { + for (let x = 0; x < width; x++) { + const srcIdx = y * paddedW + x; + const dstIdx = y * width + x; + real[dstIdx] = workReal[srcIdx]; + imag[dstIdx] = workImag[srcIdx]; + } + } + } +} + +// ============================================================================ +// Band-pass Filter +// ============================================================================ + +/** Apply band-pass filter in frequency domain (keeps frequencies between inner and outer radius) */ +export function applyBandPassFilter( + real: Float32Array, + imag: Float32Array, + width: number, + height: number, + innerRadius: number, // High-pass: remove frequencies below this + outerRadius: number // Low-pass: remove frequencies above this +) { + const centerX = width >> 1; + const centerY = height >> 1; + const innerSq = innerRadius * innerRadius; + const outerSq = outerRadius * outerRadius; + + for (let y = 0; y < height; y++) { + for (let x = 0; x < width; x++) { + const dx = x - centerX; + const dy = y - centerY; + const distSq = dx * dx + dy * dy; + const idx = y * width + x; + + // Zero out frequencies outside the band + if (distSq < innerSq || (outerRadius > 0 && distSq > outerSq)) { + real[idx] = 0; + imag[idx] = 0; + } + } + } +} + diff --git a/widget/js/show4dstem.css b/widget/js/show4dstem.css new file mode 100644 index 00000000..f754251e --- /dev/null +++ b/widget/js/show4dstem.css @@ -0,0 +1,19 @@ +/* show4dstem.css - Minimal CSS for Show4DSTEM */ +/* Most styling handled by MUI, this is for canvas-specific styles */ + +.show4dstem-root { + font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif; + background-color: #1a1a1a; +} + +/* Target Jupyter/VS Code output areas */ +.widget-output, +.jp-OutputArea-output, +.jp-RenderedHTMLCommon, +.cell-output-ipywidget-background { + background-color: #1a1a1a !important; +} + +.show4dstem-root canvas { + display: block; +} \ No newline at end of file diff --git a/widget/js/show4dstem.tsx b/widget/js/show4dstem.tsx new file mode 100644 index 00000000..a26f1047 --- /dev/null +++ b/widget/js/show4dstem.tsx @@ -0,0 +1,1532 @@ +import * as React from "react"; +import { createRender, useModelState } from "@anywidget/react"; +import Box from "@mui/material/Box"; +import Typography from "@mui/material/Typography"; +import Stack from "@mui/material/Stack"; +import Select from "@mui/material/Select"; +import MenuItem from "@mui/material/MenuItem"; +import Slider from "@mui/material/Slider"; +import Button from "@mui/material/Button"; +import Switch from "@mui/material/Switch"; +import JSZip from "jszip"; +import { getWebGPUFFT, WebGPUFFT, getGPUInfo } from "./webgpu-fft"; +import { COLORMAPS, fft1d, fft2d, fftshift, applyBandPassFilter, MIN_ZOOM, MAX_ZOOM } from "./shared"; +import { colors, typography, controlPanel, container } from "./CONFIG"; +import { upwardMenuProps, switchStyles } from "./components"; +import "./show4dstem.css"; + +// ============================================================================ +// Constants - Relative sizing for various detector sizes (64x64 to 256x256+) +// ============================================================================ +const RESIZE_HANDLE_FRACTION = 0.05; // Resize handle as fraction of detector size +const RESIZE_HANDLE_MIN_PX = 5; // Minimum resize handle radius +const RESIZE_HANDLE_MAX_PX = 8; // Maximum resize handle radius +const RESIZE_HIT_AREA_FRACTION = 0.06; // Click tolerance as fraction of detector +const RESIZE_HIT_AREA_MIN_PX = 6; // Minimum click tolerance +// Crosshair sizes: fixed pixel sizes for consistent appearance +const CROSSHAIR_SIZE_PX = 18; // Fixed crosshair size for point mode (CSS pixels on 400px canvas) +const CROSSHAIR_SIZE_SMALL_PX = 10; // Fixed small crosshair size for ROI center +const CENTER_DOT_RADIUS_PX = 6; // Center dot radius +const CIRCLE_HANDLE_ANGLE = 0.707; // cos(45°) for circle handle position +// Line widths as fraction of size +const LINE_WIDTH_FRACTION = 0.015; // Line width as fraction of size +const LINE_WIDTH_MIN_PX = 1.5; // Minimum line width +const LINE_WIDTH_MAX_PX = 3; // Maximum line width + +// ============================================================================ +// Scale Bar (dynamic adjustment to nice values) +// ============================================================================ + +/** Round to a nice value (1, 2, 5, 10, 20, 50, etc.) */ +function roundToNiceValue(value: number): number { + if (value <= 0) return 1; + const magnitude = Math.pow(10, Math.floor(Math.log10(value))); + const normalized = value / magnitude; + if (normalized < 1.5) return magnitude; + if (normalized < 3.5) return 2 * magnitude; + if (normalized < 7.5) return 5 * magnitude; + return 10 * magnitude; +} + +/** Format scale bar label with appropriate unit */ +function formatScaleLabel(value: number, unit: string): string { + const nice = roundToNiceValue(value); + + if (unit === "nm") { + if (nice >= 1000) return `${Math.round(nice / 1000)} µm`; + if (nice >= 1) return `${Math.round(nice)} nm`; + return `${nice.toFixed(2)} nm`; + } else if (unit === "mrad") { + if (nice >= 1000) return `${Math.round(nice / 1000)} rad`; + if (nice >= 1) return `${Math.round(nice)} mrad`; + return `${nice.toFixed(2)} mrad`; + } else if (unit === "1/µm") { + if (nice >= 1000) return `${Math.round(nice / 1000)} 1/nm`; + if (nice >= 1) return `${Math.round(nice)} 1/µm`; + return `${nice.toFixed(2)} 1/µm`; + } + return `${Math.round(nice)} ${unit}`; +} + +/** + * Draw scale bar and zoom indicator on a high-DPI UI canvas. + * This renders crisp text/lines independent of the image resolution. + */ +function drawScaleBarHiDPI( + canvas: HTMLCanvasElement, + dpr: number, + zoom: number, + pixelSize: number, + unit: string = "nm", + imageWidth: number, // Original image width in pixels + imageHeight: number // Original image height in pixels +) { + const ctx = canvas.getContext("2d"); + if (!ctx) return; + + // Clear canvas + ctx.clearRect(0, 0, canvas.width, canvas.height); + + // Scale context for device pixel ratio + ctx.save(); + ctx.scale(dpr, dpr); + + // CSS pixel dimensions + const cssWidth = canvas.width / dpr; + const cssHeight = canvas.height / dpr; + + // Calculate the display scale factor (how much the image is scaled to fit the canvas) + const displayScale = Math.min(cssWidth / imageWidth, cssHeight / imageHeight); + const effectiveZoom = zoom * displayScale; + + // Fixed UI sizes in CSS pixels (always crisp) + const targetBarPx = 60; // Target bar length in CSS pixels + const barThickness = 5; + const fontSize = 16; + const margin = 12; + + // Calculate what physical size the target bar represents + const targetPhysical = (targetBarPx / effectiveZoom) * pixelSize; + + // Round to a nice value + const nicePhysical = roundToNiceValue(targetPhysical); + + // Calculate actual bar length for the nice value (in CSS pixels) + const barPx = (nicePhysical / pixelSize) * effectiveZoom; + + const barY = cssHeight - margin; + const barX = cssWidth - barPx - margin; + + // Draw bar with shadow for visibility + ctx.shadowColor = "rgba(0, 0, 0, 0.5)"; + ctx.shadowBlur = 2; + ctx.shadowOffsetX = 1; + ctx.shadowOffsetY = 1; + + ctx.fillStyle = "white"; + ctx.fillRect(barX, barY, barPx, barThickness); + + // Draw label (centered above bar) + const label = formatScaleLabel(nicePhysical, unit); + ctx.font = `${fontSize}px -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif`; + ctx.fillStyle = "white"; + ctx.textAlign = "center"; + ctx.textBaseline = "bottom"; + ctx.fillText(label, barX + barPx / 2, barY - 4); + + // Draw zoom indicator (bottom left) + ctx.textAlign = "left"; + ctx.textBaseline = "bottom"; + ctx.fillText(`${zoom.toFixed(1)}×`, margin, cssHeight - margin + barThickness); + + ctx.restore(); +} + +/** + * Draw VI crosshair on high-DPI canvas (crisp regardless of image resolution) + * Note: Does NOT clear canvas - should be called after drawScaleBarHiDPI + */ +function drawViCrosshairHiDPI( + canvas: HTMLCanvasElement, + dpr: number, + posX: number, // Position in image coordinates + posY: number, + zoom: number, + panX: number, + panY: number, + imageWidth: number, + imageHeight: number, + isDragging: boolean +) { + const ctx = canvas.getContext("2d"); + if (!ctx) return; + + ctx.save(); + ctx.scale(dpr, dpr); + + const cssWidth = canvas.width / dpr; + const cssHeight = canvas.height / dpr; + const displayScale = Math.min(cssWidth / imageWidth, cssHeight / imageHeight); + + // Convert image coordinates to CSS pixel coordinates + const screenX = posY * zoom * displayScale + panX * displayScale; + const screenY = posX * zoom * displayScale + panY * displayScale; + + // Fixed UI sizes in CSS pixels (consistent with DP crosshair) + const crosshairSize = 18; + const lineWidth = 3; + const dotRadius = 6; + + ctx.shadowColor = "rgba(0, 0, 0, 0.5)"; + ctx.shadowBlur = 2; + ctx.shadowOffsetX = 1; + ctx.shadowOffsetY = 1; + + ctx.strokeStyle = isDragging ? "rgba(255, 255, 0, 0.9)" : "rgba(255, 100, 100, 0.9)"; + ctx.lineWidth = lineWidth; + + // Draw crosshair + ctx.beginPath(); + ctx.moveTo(screenX - crosshairSize, screenY); + ctx.lineTo(screenX + crosshairSize, screenY); + ctx.moveTo(screenX, screenY - crosshairSize); + ctx.lineTo(screenX, screenY + crosshairSize); + ctx.stroke(); + + // Draw center dot + ctx.beginPath(); + ctx.arc(screenX, screenY, dotRadius, 0, 2 * Math.PI); + ctx.stroke(); + + ctx.restore(); +} + +/** + * Draw DP crosshair on high-DPI canvas (crisp regardless of detector resolution) + * Note: Does NOT clear canvas - should be called after drawScaleBarHiDPI + */ +function drawDpCrosshairHiDPI( + canvas: HTMLCanvasElement, + dpr: number, + kx: number, // Position in detector coordinates + ky: number, + zoom: number, + panX: number, + panY: number, + detWidth: number, + detHeight: number, + isDragging: boolean +) { + const ctx = canvas.getContext("2d"); + if (!ctx) return; + + ctx.save(); + ctx.scale(dpr, dpr); + + const cssWidth = canvas.width / dpr; + const cssHeight = canvas.height / dpr; + const displayScale = Math.min(cssWidth / detWidth, cssHeight / detHeight); + + // Convert detector coordinates to CSS pixel coordinates (no swap - kx is X, ky is Y) + const screenX = kx * zoom * displayScale + panX * displayScale; + const screenY = ky * zoom * displayScale + panY * displayScale; + + // Fixed UI sizes in CSS pixels (consistent with VI crosshair) + const crosshairSize = 18; + const lineWidth = 3; + const dotRadius = 6; + + ctx.shadowColor = "rgba(0, 0, 0, 0.5)"; + ctx.shadowBlur = 2; + ctx.shadowOffsetX = 1; + ctx.shadowOffsetY = 1; + + ctx.strokeStyle = isDragging ? "rgba(255, 255, 0, 0.9)" : "rgba(0, 255, 0, 0.9)"; + ctx.lineWidth = lineWidth; + + // Draw crosshair + ctx.beginPath(); + ctx.moveTo(screenX - crosshairSize, screenY); + ctx.lineTo(screenX + crosshairSize, screenY); + ctx.moveTo(screenX, screenY - crosshairSize); + ctx.lineTo(screenX, screenY + crosshairSize); + ctx.stroke(); + + // Draw center dot + ctx.beginPath(); + ctx.arc(screenX, screenY, dotRadius, 0, 2 * Math.PI); + ctx.stroke(); + + ctx.restore(); +} + +// Legacy stub - scale bars now drawn on high-DPI UI canvases +function drawScaleBar(..._args: unknown[]) { /* no-op, replaced by drawScaleBarHiDPI */ } + +// ============================================================================ +// Main Component +// ============================================================================ +function Show4DSTEM() { + // ───────────────────────────────────────────────────────────────────────── + // Model State (synced with Python) + // ───────────────────────────────────────────────────────────────────────── + const [shapeX] = useModelState("shape_x"); + const [shapeY] = useModelState("shape_y"); + const [detX] = useModelState("det_x"); + const [detY] = useModelState("det_y"); + + const [posX, setPosX] = useModelState("pos_x"); + const [posY, setPosY] = useModelState("pos_y"); + const [roiCenterX, setRoiCenterX] = useModelState("roi_center_x"); + const [roiCenterY, setRoiCenterY] = useModelState("roi_center_y"); + const [, setRoiActive] = useModelState("roi_active"); + + const [pixelSize] = useModelState("pixel_size"); + const [detPixelSize] = useModelState("det_pixel_size"); + + const [frameBytes] = useModelState("frame_bytes"); + const [virtualImageBytes] = useModelState("virtual_image_bytes"); + + // ROI state + const [roiRadius, setRoiRadius] = useModelState("roi_radius"); + const [roiRadiusInner, setRoiRadiusInner] = useModelState("roi_radius_inner"); + const [roiMode, setRoiMode] = useModelState("roi_mode"); + const [roiWidth, setRoiWidth] = useModelState("roi_width"); + const [roiHeight, setRoiHeight] = useModelState("roi_height"); + + // Display options + const [logScale, setLogScale] = useModelState("log_scale"); + const [autoRange, setAutoRange] = useModelState("auto_range"); + const [percentileLow, setPercentileLow] = useModelState("percentile_low"); + const [percentileHigh, setPercentileHigh] = useModelState("percentile_high"); + + // Detector calibration (for presets) + const [bfRadius] = useModelState("bf_radius"); + const [centerX] = useModelState("center_x"); + const [centerY] = useModelState("center_y"); + + // Path animation state + const [pathPlaying, setPathPlaying] = useModelState("path_playing"); + const [pathIndex, setPathIndex] = useModelState("path_index"); + const [pathLength] = useModelState("path_length"); + const [pathIntervalMs] = useModelState("path_interval_ms"); + const [pathLoop] = useModelState("path_loop"); + + // ───────────────────────────────────────────────────────────────────────── + // Local State (UI-only, not synced to Python) + // ───────────────────────────────────────────────────────────────────────── + const [localKx, setLocalKx] = React.useState(roiCenterX); + const [localKy, setLocalKy] = React.useState(roiCenterY); + const [localPosX, setLocalPosX] = React.useState(posX); + const [localPosY, setLocalPosY] = React.useState(posY); + const [isDraggingDP, setIsDraggingDP] = React.useState(false); + const [isDraggingVI, setIsDraggingVI] = React.useState(false); + const [isDraggingFFT, setIsDraggingFFT] = React.useState(false); + const [fftDragStart, setFftDragStart] = React.useState<{ x: number, y: number, panX: number, panY: number } | null>(null); + const [isDraggingResize, setIsDraggingResize] = React.useState(false); + const [isDraggingResizeInner, setIsDraggingResizeInner] = React.useState(false); // For annular inner handle + const [isHoveringResize, setIsHoveringResize] = React.useState(false); + const [isHoveringResizeInner, setIsHoveringResizeInner] = React.useState(false); + const [colormap, setColormap] = React.useState("inferno"); + + // Band-pass filter range [innerCutoff, outerCutoff] in pixels - [0, 0] means disabled + const [bandpass, setBandpass] = React.useState([0, 0]); + const bpInner = bandpass[0]; + const bpOuter = bandpass[1]; + + // GPU FFT state + const gpuFFTRef = React.useRef(null); + const [gpuReady, setGpuReady] = React.useState(false); + + // Path animation timer + React.useEffect(() => { + if (!pathPlaying || pathLength === 0) return; + + const timer = setInterval(() => { + setPathIndex((prev: number) => { + const next = prev + 1; + if (next >= pathLength) { + if (pathLoop) { + return 0; // Loop back to start + } else { + setPathPlaying(false); // Stop at end + return prev; + } + } + return next; + }); + }, pathIntervalMs); + + return () => clearInterval(timer); + }, [pathPlaying, pathLength, pathIntervalMs, pathLoop, setPathIndex, setPathPlaying]); + + // Keyboard shortcuts + React.useEffect(() => { + const handleKeyDown = (e: KeyboardEvent) => { + // Ignore if typing in an input + if (e.target instanceof HTMLInputElement || e.target instanceof HTMLTextAreaElement) return; + + const step = e.shiftKey ? 10 : 1; + + switch (e.key) { + case 'ArrowUp': + e.preventDefault(); + setPosX(Math.max(0, posX - step)); + break; + case 'ArrowDown': + e.preventDefault(); + setPosX(Math.min(shapeX - 1, posX + step)); + break; + case 'ArrowLeft': + e.preventDefault(); + setPosY(Math.max(0, posY - step)); + break; + case 'ArrowRight': + e.preventDefault(); + setPosY(Math.min(shapeY - 1, posY + step)); + break; + case ' ': // Space bar + e.preventDefault(); + if (pathLength > 0) { + setPathPlaying(!pathPlaying); + } + break; + case 'r': // Reset view + case 'R': + setDpZoom(1); setDpPanX(0); setDpPanY(0); + setViZoom(1); setViPanX(0); setViPanY(0); + setFftZoom(1); setFftPanX(0); setFftPanY(0); + break; + } + }; + + window.addEventListener('keydown', handleKeyDown); + return () => window.removeEventListener('keydown', handleKeyDown); + }, [posX, posY, shapeX, shapeY, pathPlaying, pathLength, setPosX, setPosY, setPathPlaying]); + + // Initialize WebGPU FFT on mount + React.useEffect(() => { + getWebGPUFFT().then(fft => { + if (fft) { + gpuFFTRef.current = fft; + setGpuReady(true); + console.log("WebGPU FFT ready - using GPU acceleration!"); + } else { + console.log("⚠️ WebGPU not available - using CPU FFT"); + } + }); + }, []); + + // Zoom state + const [dpZoom, setDpZoom] = React.useState(1); + const [dpPanX, setDpPanX] = React.useState(0); + const [dpPanY, setDpPanY] = React.useState(0); + const [viZoom, setViZoom] = React.useState(1); + const [viPanX, setViPanX] = React.useState(0); + const [viPanY, setViPanY] = React.useState(0); + const [fftZoom, setFftZoom] = React.useState(1); + const [fftPanX, setFftPanX] = React.useState(0); + const [fftPanY, setFftPanY] = React.useState(0); + + // Sync local state + React.useEffect(() => { + if (!isDraggingDP && !isDraggingResize) { setLocalKx(roiCenterX); setLocalKy(roiCenterY); } + }, [roiCenterX, roiCenterY, isDraggingDP, isDraggingResize]); + + React.useEffect(() => { + if (!isDraggingVI) { setLocalPosX(posX); setLocalPosY(posY); } + }, [posX, posY, isDraggingVI]); + + // Canvas refs + const dpCanvasRef = React.useRef(null); + const dpOverlayRef = React.useRef(null); + const dpUiRef = React.useRef(null); // High-DPI UI overlay for scale bar + const virtualCanvasRef = React.useRef(null); + const virtualOverlayRef = React.useRef(null); + const viUiRef = React.useRef(null); // High-DPI UI overlay for scale bar + const fftCanvasRef = React.useRef(null); + const fftOverlayRef = React.useRef(null); + const fftUiRef = React.useRef(null); // High-DPI UI overlay for scale bar + + // Display size for high-DPI UI overlays + const UI_SIZE = 400; + const DPR = typeof window !== 'undefined' ? window.devicePixelRatio || 1 : 1; + + // ───────────────────────────────────────────────────────────────────────── + // Effects: Canvas Rendering & Animation + // ───────────────────────────────────────────────────────────────────────── + + // Prevent page scroll when scrolling on canvases + React.useEffect(() => { + const preventDefault = (e: WheelEvent) => e.preventDefault(); + const overlays = [dpOverlayRef.current, virtualOverlayRef.current, fftOverlayRef.current]; + overlays.forEach(el => el?.addEventListener("wheel", preventDefault, { passive: false })); + return () => overlays.forEach(el => el?.removeEventListener("wheel", preventDefault)); + }, []); + + // Store raw virtual image data for filtering + const rawVirtualImageRef = React.useRef(null); + + // Parse virtual image bytes into Float32Array + React.useEffect(() => { + if (!virtualImageBytes) return; + const bytes = new Uint8Array(virtualImageBytes.buffer, virtualImageBytes.byteOffset, virtualImageBytes.byteLength); + const floatData = new Float32Array(bytes.length); + for (let i = 0; i < bytes.length; i++) { + floatData[i] = bytes[i]; + } + rawVirtualImageRef.current = floatData; + }, [virtualImageBytes]); + + // Render DP with zoom + React.useEffect(() => { + if (!frameBytes || !dpCanvasRef.current) return; + const canvas = dpCanvasRef.current; + const ctx = canvas.getContext("2d"); + if (!ctx) return; + + const bytes = new Uint8Array(frameBytes.buffer, frameBytes.byteOffset, frameBytes.byteLength); + const lut = COLORMAPS[colormap] || COLORMAPS.inferno; + + const offscreen = document.createElement("canvas"); + offscreen.width = detY; + offscreen.height = detX; + const offCtx = offscreen.getContext("2d"); + if (!offCtx) return; + + const imgData = offCtx.createImageData(detY, detX); + const rgba = imgData.data; + + for (let i = 0; i < bytes.length; i++) { + const v = bytes[i]; + const j = i * 4; + const lutIdx = v * 3; + rgba[j] = lut[lutIdx]; + rgba[j + 1] = lut[lutIdx + 1]; + rgba[j + 2] = lut[lutIdx + 2]; + rgba[j + 3] = 255; + } + offCtx.putImageData(imgData, 0, 0); + + ctx.imageSmoothingEnabled = false; + ctx.clearRect(0, 0, canvas.width, canvas.height); + ctx.save(); + ctx.translate(dpPanX, dpPanY); + ctx.scale(dpZoom, dpZoom); + ctx.drawImage(offscreen, 0, 0); + ctx.restore(); + }, [frameBytes, detX, detY, colormap, dpZoom, dpPanX, dpPanY]); + + // Render DP overlay + React.useEffect(() => { + if (!dpOverlayRef.current) return; + const canvas = dpOverlayRef.current; + const ctx = canvas.getContext("2d"); + if (!ctx) return; + ctx.clearRect(0, 0, canvas.width, canvas.height); + + const screenKx = localKx * dpZoom + dpPanX; + const screenKy = localKy * dpZoom + dpPanY; + + // Convert fixed CSS pixel sizes to canvas pixels (canvas is detX x detY, displayed at 400x400) + const minDetSize = Math.min(detX, detY); + const canvasScale = minDetSize / 400; // How many canvas pixels per CSS pixel + const crosshairSize = CROSSHAIR_SIZE_PX * canvasScale * dpZoom; + const crosshairSizeSmall = CROSSHAIR_SIZE_SMALL_PX * canvasScale * dpZoom; + const centerDotRadius = CENTER_DOT_RADIUS_PX * canvasScale * dpZoom; + const lineWidth = Math.max(LINE_WIDTH_MIN_PX, Math.min(LINE_WIDTH_MAX_PX, minDetSize * LINE_WIDTH_FRACTION)); + + ctx.strokeStyle = isDraggingDP ? "rgba(255, 255, 0, 0.9)" : "rgba(0, 255, 0, 0.9)"; + ctx.lineWidth = lineWidth; + + // Helper to draw resize handle + const drawResizeHandle = (handleX: number, handleY: number) => { + let handleFill: string; + let handleStroke: string; + if (isDraggingResize) { + handleFill = "rgba(0, 200, 255, 1)"; // Cyan when dragging + handleStroke = "rgba(255, 255, 255, 1)"; + } else if (isHoveringResize) { + handleFill = "rgba(255, 100, 100, 1)"; // Red when hovering + handleStroke = "rgba(255, 255, 255, 1)"; + } else { + handleFill = "rgba(0, 255, 0, 0.8)"; // Green default + handleStroke = "rgba(255, 255, 255, 0.8)"; + } + ctx.beginPath(); + ctx.arc(handleX, handleY, RESIZE_HANDLE_RADIUS, 0, 2 * Math.PI); + ctx.fillStyle = handleFill; + ctx.fill(); + ctx.strokeStyle = handleStroke; + ctx.lineWidth = 1.5; + ctx.stroke(); + }; + + if (roiMode === "circle" && roiRadius > 0) { + // Circle mode: draw a filled circular ROI + const screenRadius = roiRadius * dpZoom; + ctx.beginPath(); + ctx.arc(screenKx, screenKy, screenRadius, 0, 2 * Math.PI); + ctx.stroke(); + // Semi-transparent fill + ctx.fillStyle = isDraggingDP ? "rgba(255, 255, 0, 0.15)" : "rgba(0, 255, 0, 0.15)"; + ctx.fill(); + // Draw center crosshair (smaller) + ctx.strokeStyle = isDraggingDP ? "rgba(255, 255, 0, 0.9)" : "rgba(0, 255, 0, 0.9)"; + ctx.lineWidth = lineWidth; + ctx.beginPath(); + ctx.moveTo(screenKx - crosshairSizeSmall, screenKy); ctx.lineTo(screenKx + crosshairSizeSmall, screenKy); + ctx.moveTo(screenKx, screenKy - crosshairSizeSmall); ctx.lineTo(screenKx, screenKy + crosshairSizeSmall); + ctx.stroke(); + + // Draw resize handle at bottom-right of circle (45° position) + const handleOffset = screenRadius * CIRCLE_HANDLE_ANGLE; + drawResizeHandle(screenKx + handleOffset, screenKy + handleOffset); + + } else if (roiMode === "square" && roiRadius > 0) { + // Square mode: draw a filled square ROI + const screenHalfSize = roiRadius * dpZoom; + const left = screenKx - screenHalfSize; + const top = screenKy - screenHalfSize; + const size = screenHalfSize * 2; + + ctx.beginPath(); + ctx.rect(left, top, size, size); + ctx.stroke(); + // Semi-transparent fill + ctx.fillStyle = isDraggingDP ? "rgba(255, 255, 0, 0.15)" : "rgba(0, 255, 0, 0.15)"; + ctx.fill(); + // Draw center crosshair (smaller) + ctx.strokeStyle = isDraggingDP ? "rgba(255, 255, 0, 0.9)" : "rgba(0, 255, 0, 0.9)"; + ctx.lineWidth = lineWidth; + ctx.beginPath(); + ctx.moveTo(screenKx - crosshairSizeSmall, screenKy); ctx.lineTo(screenKx + crosshairSizeSmall, screenKy); + ctx.moveTo(screenKx, screenKy - crosshairSizeSmall); ctx.lineTo(screenKx, screenKy + crosshairSizeSmall); + ctx.stroke(); + + // Draw resize handle at bottom-right corner of square + drawResizeHandle(screenKx + screenHalfSize, screenKy + screenHalfSize); + + } else if (roiMode === "rect" && roiWidth > 0 && roiHeight > 0) { + // Rectangular mode: draw a filled rectangular ROI with independent width/height + const screenHalfW = (roiWidth / 2) * dpZoom; + const screenHalfH = (roiHeight / 2) * dpZoom; + const left = screenKx - screenHalfW; + const top = screenKy - screenHalfH; + + ctx.beginPath(); + ctx.rect(left, top, screenHalfW * 2, screenHalfH * 2); + ctx.stroke(); + // Semi-transparent fill + ctx.fillStyle = isDraggingDP ? "rgba(255, 255, 0, 0.15)" : "rgba(0, 255, 0, 0.15)"; + ctx.fill(); + // Draw center crosshair (smaller) + ctx.strokeStyle = isDraggingDP ? "rgba(255, 255, 0, 0.9)" : "rgba(0, 255, 0, 0.9)"; + ctx.lineWidth = lineWidth; + ctx.beginPath(); + ctx.moveTo(screenKx - crosshairSizeSmall, screenKy); ctx.lineTo(screenKx + crosshairSizeSmall, screenKy); + ctx.moveTo(screenKx, screenKy - crosshairSizeSmall); ctx.lineTo(screenKx, screenKy + crosshairSizeSmall); + ctx.stroke(); + + // Draw resize handle at bottom-right corner of rectangle + drawResizeHandle(screenKx + screenHalfW, screenKy + screenHalfH); + + } else if (roiMode === "annular" && roiRadius > 0) { + // Annular mode: draw donut-shaped ROI (ADF/HAADF) + const screenRadiusOuter = roiRadius * dpZoom; + const screenRadiusInner = (roiRadiusInner || 0) * dpZoom; + + // Draw outer circle + ctx.beginPath(); + ctx.arc(screenKx, screenKy, screenRadiusOuter, 0, 2 * Math.PI); + ctx.stroke(); + + // Draw inner circle (different color for distinction) + ctx.save(); + ctx.strokeStyle = isDraggingDP ? "rgba(255, 200, 0, 0.9)" : "rgba(0, 220, 255, 0.9)"; // Cyan/orange for inner + ctx.beginPath(); + ctx.arc(screenKx, screenKy, screenRadiusInner, 0, 2 * Math.PI); + ctx.stroke(); + ctx.restore(); + + // Fill the annular region (donut) using composite operation + ctx.save(); + ctx.fillStyle = isDraggingDP ? "rgba(255, 255, 0, 0.15)" : "rgba(0, 255, 0, 0.15)"; + ctx.beginPath(); + ctx.arc(screenKx, screenKy, screenRadiusOuter, 0, 2 * Math.PI); + ctx.arc(screenKx, screenKy, screenRadiusInner, 0, 2 * Math.PI, true); // counter-clockwise to cut hole + ctx.fill(); + ctx.restore(); + + // Draw center crosshair (smaller) + ctx.strokeStyle = isDraggingDP ? "rgba(255, 255, 0, 0.9)" : "rgba(0, 255, 0, 0.9)"; + ctx.lineWidth = lineWidth; + ctx.beginPath(); + ctx.moveTo(screenKx - crosshairSizeSmall, screenKy); ctx.lineTo(screenKx + crosshairSizeSmall, screenKy); + ctx.moveTo(screenKx, screenKy - crosshairSizeSmall); ctx.lineTo(screenKx, screenKy + crosshairSizeSmall); + ctx.stroke(); + + // Draw resize handle at outer circle's 45° position (green) + const handleOffsetOuter = screenRadiusOuter * CIRCLE_HANDLE_ANGLE; + drawResizeHandle(screenKx + handleOffsetOuter, screenKy + handleOffsetOuter); + + // Draw resize handle at inner circle's 45° position (cyan) + ctx.save(); + ctx.strokeStyle = isHoveringResizeInner ? "rgba(0, 220, 255, 1)" : "rgba(0, 220, 255, 0.8)"; + ctx.fillStyle = "rgba(0, 40, 50, 0.8)"; + const handleOffsetInner = screenRadiusInner * CIRCLE_HANDLE_ANGLE; + ctx.beginPath(); + ctx.arc(screenKx + handleOffsetInner, screenKy + handleOffsetInner, RESIZE_HANDLE_RADIUS, 0, 2 * Math.PI); + ctx.fill(); + ctx.stroke(); + ctx.restore(); + + } + // Point mode crosshair is drawn on dpUiRef for crisp rendering + + drawScaleBar(ctx, canvas.width, canvas.height, dpZoom, detPixelSize || 1, "mrad", Math.max(detX, detY)); + }, [localKx, localKy, isDraggingDP, isDraggingResize, isDraggingResizeInner, isHoveringResize, isHoveringResizeInner, dpZoom, dpPanX, dpPanY, detPixelSize, roiMode, roiRadius, roiRadiusInner, roiWidth, roiHeight, detX, detY]); + + // Render filtered virtual image + React.useEffect(() => { + if (!rawVirtualImageRef.current || !virtualCanvasRef.current) return; + const canvas = virtualCanvasRef.current; + const ctx = canvas.getContext("2d"); + if (!ctx) return; + + const width = shapeY; + const height = shapeX; + + const renderData = (filtered: Float32Array) => { + // Normalize and render (with optional percentile contrast) + let min = Infinity, max = -Infinity; + + if (autoRange) { + // Percentile-based contrast: sort a sample and pick percentile values + const sorted = Float32Array.from(filtered).sort((a, b) => a - b); + const lowIdx = Math.floor((percentileLow / 100) * sorted.length); + const highIdx = Math.floor((percentileHigh / 100) * sorted.length) - 1; + min = sorted[Math.max(0, lowIdx)]; + max = sorted[Math.min(sorted.length - 1, highIdx)]; + } else { + // Full range + for (let i = 0; i < filtered.length; i++) { + if (filtered[i] < min) min = filtered[i]; + if (filtered[i] > max) max = filtered[i]; + } + } + + const lut = COLORMAPS[colormap] || COLORMAPS.inferno; + const offscreen = document.createElement("canvas"); + offscreen.width = width; + offscreen.height = height; + const offCtx = offscreen.getContext("2d"); + if (!offCtx) return; + + const imageData = offCtx.createImageData(width, height); + for (let i = 0; i < filtered.length; i++) { + const val = Math.floor(((filtered[i] - min) / (max - min || 1)) * 255); + imageData.data[i * 4] = lut[val * 3]; + imageData.data[i * 4 + 1] = lut[val * 3 + 1]; + imageData.data[i * 4 + 2] = lut[val * 3 + 2]; + imageData.data[i * 4 + 3] = 255; + } + offCtx.putImageData(imageData, 0, 0); + + ctx.imageSmoothingEnabled = false; + ctx.clearRect(0, 0, canvas.width, canvas.height); + ctx.save(); + ctx.translate(viPanX, viPanY); + ctx.scale(viZoom, viZoom); + ctx.drawImage(offscreen, 0, 0); + ctx.restore(); + }; + + if (bpInner > 0 || bpOuter > 0) { + if (gpuFFTRef.current && gpuReady) { + // GPU filtering (Async) + const real = rawVirtualImageRef.current.slice(); + const imag = new Float32Array(real.length); + + // We use a local flag to prevent state updates if the effect has already re-run + let isCancelled = false; + + const runGpuFilter = async () => { + // WebGPU version of: Forward -> Filter -> Inverse + // Note: The provided WebGPUFFT doesn't have shift/unshift built-in yet, + // but we can apply the filter in shifted coordinates or modify it. + // For now, let's keep it simple: Forward -> Filter -> Inverse. + const { real: fReal, imag: fImag } = await gpuFFTRef.current!.fft2D(real, imag, width, height, false); + + if (isCancelled) return; + + // Shift in CPU for now (future: do this in WGSL) + fftshift(fReal, width, height); + fftshift(fImag, width, height); + applyBandPassFilter(fReal, fImag, width, height, bpInner, bpOuter); + fftshift(fReal, width, height); + fftshift(fImag, width, height); + + const { real: invReal } = await gpuFFTRef.current!.fft2D(fReal, fImag, width, height, true); + + if (!isCancelled) renderData(invReal); + }; + + runGpuFilter(); + return () => { isCancelled = true; }; + } else { + // CPU Fallback (Sync) + const real = rawVirtualImageRef.current.slice(); + const imag = new Float32Array(real.length); + fft2d(real, imag, width, height, false); + fftshift(real, width, height); + fftshift(imag, width, height); + applyBandPassFilter(real, imag, width, height, bpInner, bpOuter); + fftshift(real, width, height); + fftshift(imag, width, height); + fft2d(real, imag, width, height, true); + renderData(real); + } + } else { + renderData(rawVirtualImageRef.current); + } + }, [virtualImageBytes, shapeX, shapeY, colormap, viZoom, viPanX, viPanY, bpInner, bpOuter, gpuReady, autoRange, percentileLow, percentileHigh]); + + // Render virtual image overlay (just clear - crosshair drawn on high-DPI UI canvas) + React.useEffect(() => { + if (!virtualOverlayRef.current) return; + const canvas = virtualOverlayRef.current; + const ctx = canvas.getContext("2d"); + if (!ctx) return; + ctx.clearRect(0, 0, canvas.width, canvas.height); + // Crosshair and scale bar now drawn on high-DPI UI canvas (viUiRef) + }, [localPosX, localPosY, isDraggingVI, viZoom, viPanX, viPanY, pixelSize, shapeX, shapeY]); + + // Render FFT (computed in JS from filtered image) + React.useEffect(() => { + if (!rawVirtualImageRef.current || !fftCanvasRef.current) return; + const canvas = fftCanvasRef.current; + const ctx = canvas.getContext("2d"); + if (!ctx) return; + + const width = shapeY; + const height = shapeX; + + // Use raw or filtered data + let sourceData = rawVirtualImageRef.current; + if (bpInner > 0 || bpOuter > 0) { + sourceData = rawVirtualImageRef.current.slice(); + } + + const real = sourceData.slice(); + const imag = new Float32Array(real.length); + + // Forward FFT + fft2d(real, imag, width, height, false); + fftshift(real, width, height); + fftshift(imag, width, height); + + // Compute log magnitude + const magnitude = new Float32Array(real.length); + for (let i = 0; i < real.length; i++) { + magnitude[i] = Math.log1p(Math.sqrt(real[i] * real[i] + imag[i] * imag[i])); + } + + // Normalize + let min = Infinity, max = -Infinity; + for (let i = 0; i < magnitude.length; i++) { + if (magnitude[i] < min) min = magnitude[i]; + if (magnitude[i] > max) max = magnitude[i]; + } + + const lut = COLORMAPS[colormap] || COLORMAPS.inferno; + const offscreen = document.createElement("canvas"); + offscreen.width = width; + offscreen.height = height; + const offCtx = offscreen.getContext("2d"); + if (!offCtx) return; + + const imgData = offCtx.createImageData(width, height); + const rgba = imgData.data; + const range = max > min ? max - min : 1; + + for (let i = 0; i < magnitude.length; i++) { + const v = Math.round(((magnitude[i] - min) / range) * 255); + const j = i * 4; + const lutIdx = Math.max(0, Math.min(255, v)) * 3; + rgba[j] = lut[lutIdx]; + rgba[j + 1] = lut[lutIdx + 1]; + rgba[j + 2] = lut[lutIdx + 2]; + rgba[j + 3] = 255; + } + offCtx.putImageData(imgData, 0, 0); + + ctx.imageSmoothingEnabled = false; + ctx.clearRect(0, 0, canvas.width, canvas.height); + ctx.save(); + ctx.translate(fftPanX, fftPanY); + ctx.scale(fftZoom, fftZoom); + ctx.drawImage(offscreen, 0, 0); + ctx.restore(); + }, [virtualImageBytes, shapeX, shapeY, colormap, fftZoom, fftPanX, fftPanY, bpInner, bpOuter]); + + // Render FFT overlay with high-pass filter circle + React.useEffect(() => { + if (!fftOverlayRef.current) return; + const canvas = fftOverlayRef.current; + const ctx = canvas.getContext("2d"); + if (!ctx) return; + ctx.clearRect(0, 0, canvas.width, canvas.height); + + // Draw band-pass filter circles (inner = HP, outer = LP) + const centerX = (shapeY / 2) * fftZoom + fftPanX; + const centerY = (shapeX / 2) * fftZoom + fftPanY; + const minScanSize = Math.min(shapeX, shapeY); + const fftLineWidth = Math.max(LINE_WIDTH_MIN_PX, Math.min(LINE_WIDTH_MAX_PX, minScanSize * LINE_WIDTH_FRACTION)); + + if (bpInner > 0) { + ctx.strokeStyle = "rgba(255, 0, 0, 0.8)"; + ctx.lineWidth = fftLineWidth; + ctx.setLineDash([5, 5]); + ctx.beginPath(); + ctx.arc(centerX, centerY, bpInner * fftZoom, 0, 2 * Math.PI); + ctx.stroke(); + ctx.setLineDash([]); + } + if (bpOuter > 0) { + ctx.strokeStyle = "rgba(0, 150, 255, 0.8)"; + ctx.lineWidth = fftLineWidth; + ctx.setLineDash([5, 5]); + ctx.beginPath(); + ctx.arc(centerX, centerY, bpOuter * fftZoom, 0, 2 * Math.PI); + ctx.stroke(); + ctx.setLineDash([]); + } + + const fftPixelSize = pixelSize ? 1 / (shapeX * pixelSize) : 1; + drawScaleBar(ctx, canvas.width, canvas.height, fftZoom, fftPixelSize * 1000, "1/µm", Math.max(shapeX, shapeY)); + }, [fftZoom, fftPanX, fftPanY, pixelSize, shapeX, shapeY, bpInner, bpOuter]); + + // ───────────────────────────────────────────────────────────────────────── + // High-DPI Scale Bar UI Overlays + // ───────────────────────────────────────────────────────────────────────── + + // DP scale bar + crosshair (high-DPI) + React.useEffect(() => { + if (!dpUiRef.current) return; + // Draw scale bar first (clears canvas) + drawScaleBarHiDPI(dpUiRef.current, DPR, dpZoom, detPixelSize || 1, "mrad", detY, detX); + // Draw crosshair (only for point mode - Python uses "point", not "crosshair") + if (roiMode === "point") { + drawDpCrosshairHiDPI(dpUiRef.current, DPR, localKx, localKy, dpZoom, dpPanX, dpPanY, detY, detX, isDraggingDP); + } + }, [dpZoom, dpPanX, dpPanY, detPixelSize, detX, detY, roiMode, localKx, localKy, isDraggingDP]); + + // VI scale bar + crosshair (high-DPI) + React.useEffect(() => { + if (!viUiRef.current) return; + // Draw scale bar first (clears canvas) + drawScaleBarHiDPI(viUiRef.current, DPR, viZoom, pixelSize || 1, "nm", shapeY, shapeX); + // Then draw crosshair on top + drawViCrosshairHiDPI(viUiRef.current, DPR, localPosX, localPosY, viZoom, viPanX, viPanY, shapeY, shapeX, isDraggingVI); + }, [viZoom, viPanX, viPanY, pixelSize, shapeX, shapeY, localPosX, localPosY, isDraggingVI]); + + // FFT - no scale bar (just clear canvas) + React.useEffect(() => { + if (!fftUiRef.current) return; + const ctx = fftUiRef.current.getContext("2d"); + if (ctx) ctx.clearRect(0, 0, fftUiRef.current.width, fftUiRef.current.height); + }, [fftZoom, shapeX, shapeY]); + + // Generic zoom handler + const createZoomHandler = ( + setZoom: React.Dispatch>, + setPanX: React.Dispatch>, + setPanY: React.Dispatch>, + zoom: number, panX: number, panY: number, + canvasRef: React.RefObject + ) => (e: React.WheelEvent) => { + e.preventDefault(); + const canvas = canvasRef.current; + if (!canvas) return; + const rect = canvas.getBoundingClientRect(); + const mouseX = (e.clientX - rect.left) * (canvas.width / rect.width); + const mouseY = (e.clientY - rect.top) * (canvas.height / rect.height); + const zoomFactor = e.deltaY > 0 ? 0.9 : 1.1; + const newZoom = Math.max(MIN_ZOOM, Math.min(MAX_ZOOM, zoom * zoomFactor)); + const zoomRatio = newZoom / zoom; + setZoom(newZoom); + setPanX(mouseX - (mouseX - panX) * zoomRatio); + setPanY(mouseY - (mouseY - panY) * zoomRatio); + }; + + // ───────────────────────────────────────────────────────────────────────── + // Mouse Handlers + // ───────────────────────────────────────────────────────────────────────── + + // Helper: check if point is near the outer resize handle + const isNearResizeHandle = (imgX: number, imgY: number): boolean => { + if (roiMode === "rect") { + // For rectangle, check near bottom-right corner + const handleX = roiCenterX + roiWidth / 2; + const handleY = roiCenterY + roiHeight / 2; + const dist = Math.sqrt((imgX - handleX) ** 2 + (imgY - handleY) ** 2); + return dist < RESIZE_HIT_AREA_PX / dpZoom; + } + if ((roiMode !== "circle" && roiMode !== "square" && roiMode !== "annular") || !roiRadius) return false; + const offset = roiMode === "square" ? roiRadius : roiRadius * CIRCLE_HANDLE_ANGLE; + const handleX = roiCenterX + offset; + const handleY = roiCenterY + offset; + const dist = Math.sqrt((imgX - handleX) ** 2 + (imgY - handleY) ** 2); + return dist < RESIZE_HIT_AREA_PX / dpZoom; + }; + + // Helper: check if point is near the inner resize handle (annular mode only) + const isNearResizeHandleInner = (imgX: number, imgY: number): boolean => { + if (roiMode !== "annular" || !roiRadiusInner) return false; + const offset = roiRadiusInner * CIRCLE_HANDLE_ANGLE; + const handleX = roiCenterX + offset; + const handleY = roiCenterY + offset; + const dist = Math.sqrt((imgX - handleX) ** 2 + (imgY - handleY) ** 2); + return dist < RESIZE_HIT_AREA_PX / dpZoom; + }; + + // Mouse handlers + const handleDpMouseDown = (e: React.MouseEvent) => { + const canvas = dpOverlayRef.current; + if (!canvas) return; + const rect = canvas.getBoundingClientRect(); + const screenX = (e.clientX - rect.left) * (canvas.width / rect.width); + const screenY = (e.clientY - rect.top) * (canvas.height / rect.height); + const imgX = (screenX - dpPanX) / dpZoom; + const imgY = (screenY - dpPanY) / dpZoom; + + // Check if clicking on resize handle (inner first, then outer) + if (isNearResizeHandleInner(imgX, imgY)) { + setIsDraggingResizeInner(true); + return; + } + if (isNearResizeHandle(imgX, imgY)) { + setIsDraggingResize(true); + return; + } + + setIsDraggingDP(true); + setLocalKx(imgX); setLocalKy(imgY); + setRoiActive(true); + setRoiCenterX(Math.round(Math.max(0, Math.min(detY - 1, imgX)))); + setRoiCenterY(Math.round(Math.max(0, Math.min(detX - 1, imgY)))); + }; + + const handleDpMouseMove = (e: React.MouseEvent) => { + const canvas = dpOverlayRef.current; + if (!canvas) return; + const rect = canvas.getBoundingClientRect(); + const screenX = (e.clientX - rect.left) * (canvas.width / rect.width); + const screenY = (e.clientY - rect.top) * (canvas.height / rect.height); + const imgX = (screenX - dpPanX) / dpZoom; + const imgY = (screenY - dpPanY) / dpZoom; + + // Handle inner resize dragging (annular mode) + if (isDraggingResizeInner) { + const dx = Math.abs(imgX - roiCenterX); + const dy = Math.abs(imgY - roiCenterY); + const newRadius = Math.sqrt(dx ** 2 + dy ** 2); + // Inner radius must be less than outer radius + setRoiRadiusInner(Math.max(1, Math.min(roiRadius - 1, Math.round(newRadius)))); + return; + } + + // Handle outer resize dragging - use model state center, not local values + if (isDraggingResize) { + const dx = Math.abs(imgX - roiCenterX); + const dy = Math.abs(imgY - roiCenterY); + if (roiMode === "rect") { + // For rectangle, update width and height independently + setRoiWidth(Math.max(2, Math.round(dx * 2))); + setRoiHeight(Math.max(2, Math.round(dy * 2))); + } else { + const newRadius = roiMode === "square" ? Math.max(dx, dy) : Math.sqrt(dx ** 2 + dy ** 2); + // For annular mode, outer radius must be greater than inner radius + const minRadius = roiMode === "annular" ? (roiRadiusInner || 0) + 1 : 1; + setRoiRadius(Math.max(minRadius, Math.round(newRadius))); + } + return; + } + + // Check hover state for resize handles + if (!isDraggingDP) { + setIsHoveringResizeInner(isNearResizeHandleInner(imgX, imgY)); + setIsHoveringResize(isNearResizeHandle(imgX, imgY)); + return; + } + + setLocalKx(imgX); setLocalKy(imgY); + setRoiCenterX(Math.round(Math.max(0, Math.min(detY - 1, imgX)))); + setRoiCenterY(Math.round(Math.max(0, Math.min(detX - 1, imgY)))); + }; + + const handleDpMouseUp = () => { setIsDraggingDP(false); setIsDraggingResize(false); setIsDraggingResizeInner(false); }; + const handleDpMouseLeave = () => { setIsDraggingDP(false); setIsDraggingResize(false); setIsDraggingResizeInner(false); setIsHoveringResize(false); setIsHoveringResizeInner(false); }; + const handleDpDoubleClick = () => { setDpZoom(1); setDpPanX(0); setDpPanY(0); }; + + const handleViMouseDown = (e: React.MouseEvent) => { + const canvas = virtualOverlayRef.current; + if (!canvas) return; + const rect = canvas.getBoundingClientRect(); + const screenX = (e.clientX - rect.left) * (canvas.width / rect.width); + const screenY = (e.clientY - rect.top) * (canvas.height / rect.height); + const imgX = (screenY - viPanY) / viZoom; + const imgY = (screenX - viPanX) / viZoom; + setIsDraggingVI(true); + setLocalPosX(imgX); setLocalPosY(imgY); + setPosX(Math.round(Math.max(0, Math.min(shapeX - 1, imgX)))); + setPosY(Math.round(Math.max(0, Math.min(shapeY - 1, imgY)))); + }; + + const handleViMouseMove = (e: React.MouseEvent) => { + if (!isDraggingVI) return; + const canvas = virtualOverlayRef.current; + if (!canvas) return; + const rect = canvas.getBoundingClientRect(); + const screenX = (e.clientX - rect.left) * (canvas.width / rect.width); + const screenY = (e.clientY - rect.top) * (canvas.height / rect.height); + const imgX = (screenY - viPanY) / viZoom; + const imgY = (screenX - viPanX) / viZoom; + setLocalPosX(imgX); setLocalPosY(imgY); + setPosX(Math.round(Math.max(0, Math.min(shapeX - 1, imgX)))); + setPosY(Math.round(Math.max(0, Math.min(shapeY - 1, imgY)))); + }; + + const handleViMouseUp = () => setIsDraggingVI(false); + const handleViDoubleClick = () => { setViZoom(1); setViPanX(0); setViPanY(0); }; + const handleFftDoubleClick = () => { setFftZoom(1); setFftPanX(0); setFftPanY(0); }; + + // FFT drag-to-pan handlers + const handleFftMouseDown = (e: React.MouseEvent) => { + setIsDraggingFFT(true); + setFftDragStart({ x: e.clientX, y: e.clientY, panX: fftPanX, panY: fftPanY }); + }; + + const handleFftMouseMove = (e: React.MouseEvent) => { + if (!isDraggingFFT || !fftDragStart) return; + const canvas = fftOverlayRef.current; + if (!canvas) return; + const rect = canvas.getBoundingClientRect(); + const scaleX = canvas.width / rect.width; + const scaleY = canvas.height / rect.height; + const dx = (e.clientX - fftDragStart.x) * scaleX; + const dy = (e.clientY - fftDragStart.y) * scaleY; + setFftPanX(fftDragStart.panX + dx); + setFftPanY(fftDragStart.panY + dy); + }; + + const handleFftMouseUp = () => { setIsDraggingFFT(false); setFftDragStart(null); }; + const handleFftMouseLeave = () => { setIsDraggingFFT(false); setFftDragStart(null); }; + + // ───────────────────────────────────────────────────────────────────────── + // Render + // ───────────────────────────────────────────────────────────────────────── + return ( + + {/* Wrapper to ensure header and content have same width */} + + {/* Header */} + + + 4D-STEM Explorer + + + + {shapeX}×{shapeY} scan | {detX}×{detY} det + + { + setBandpass([0, 0]); + setDpZoom(1); setDpPanX(0); setDpPanY(0); + setViZoom(1); setViPanX(0); setViPanY(0); + setFftZoom(1); setFftPanX(0); setFftPanY(0); + setRoiMode("point"); + }} + sx={{ ...controlPanel.button }} + > + Reset + + + + + + {/* LEFT: DP */} + + + DP at ({Math.round(localPosX)}, {Math.round(localPosY)}) + k: ({Math.round(localKx)}, {Math.round(localKy)}) + + + + + {/* High-DPI UI overlay for crisp scale bar */} + + + + {/* RIGHT: Virtual Image + FFT */} + + + + + Virtual Image + + { + const timestamp = new Date().toISOString().replace(/[:.]/g, '-').slice(0, 19); + const zip = new JSZip(); + + // Add metadata JSON + const metadata = { + exported_at: new Date().toISOString(), + scan_position: { x: posX, y: posY }, + scan_shape: { x: shapeX, y: shapeY }, + detector_shape: { x: detX, y: detY }, + roi: { + mode: roiMode, + center_x: roiCenterX, + center_y: roiCenterY, + radius_outer: roiRadius, + radius_inner: roiRadiusInner, + }, + display: { + colormap: colormap, + log_scale: logScale, + auto_range: autoRange, + percentile_low: percentileLow, + percentile_high: percentileHigh, + }, + calibration: { + bf_radius: bfRadius, + center_x: centerX, + center_y: centerY, + pixel_size: pixelSize, + det_pixel_size: detPixelSize, + }, + }; + zip.file("metadata.json", JSON.stringify(metadata, null, 2)); + + // Helper to convert canvas to blob + const canvasToBlob = (canvas: HTMLCanvasElement): Promise => { + return new Promise((resolve) => { + canvas.toBlob((blob) => resolve(blob!), 'image/png'); + }); + }; + + // Add images + const viCanvas = virtualCanvasRef.current; + if (viCanvas) { + const blob = await canvasToBlob(viCanvas); + zip.file("virtual_image.png", blob); + } + const dpCanvas = dpCanvasRef.current; + if (dpCanvas) { + const blob = await canvasToBlob(dpCanvas); + zip.file("diffraction_pattern.png", blob); + } + const fftCanvas = fftCanvasRef.current; + if (fftCanvas) { + const blob = await canvasToBlob(fftCanvas); + zip.file("fft.png", blob); + } + + // Generate and download ZIP + const zipBlob = await zip.generateAsync({ type: "blob" }); + const link = document.createElement('a'); + link.download = `4dstem_export_${timestamp}.zip`; + link.href = URL.createObjectURL(zipBlob); + link.click(); + URL.revokeObjectURL(link.href); + }} + sx={{ ...controlPanel.button }} + > + Export + + + + + + {/* High-DPI UI overlay for crisp scale bar */} + + + + + + + FFT + + + + + {/* High-DPI UI overlay for crisp scale bar */} + + + + + + + + {/* Controls - Organized in 3 rows */} + + + {/* Row 1: Presets + Detector */} + + {/* Detector Presets - only show if bf_radius is calibrated */} + {bfRadius > 0 && ( + + Presets: + + + + + + )} + + {/* Virtual Detector Mode */} + + Detector: + + {(roiMode === "circle" || roiMode === "square") && ( + <> + {roiMode === "circle" ? "r:" : "½:"} + setRoiRadius(v as number)} + min={1} + max={Math.min(detX, detY) / 2} + size="small" + sx={{ width: 80 }} + /> + + {Math.round(roiRadius || 10)}px + + + )} + {roiMode === "rect" && ( + <> + W: + setRoiWidth(v as number)} + min={2} + max={detY} + size="small" + sx={{ width: 60 }} + /> + H: + setRoiHeight(v as number)} + min={2} + max={detX} + size="small" + sx={{ width: 60 }} + /> + + {Math.round(roiWidth || 20)}×{Math.round(roiHeight || 10)} + + + )} + {roiMode === "annular" && ( + <> + in: + setRoiRadiusInner(v as number)} + min={0} + max={Math.min(detX, detY) / 2} + size="small" + sx={{ width: 60 }} + /> + out: + setRoiRadius(v as number)} + min={1} + max={Math.min(detX, detY) / 2} + size="small" + sx={{ width: 60 }} + /> + + {Math.round(roiRadiusInner || 5)}-{Math.round(roiRadius || 10)}px + + + )} + + + {/* Path Animation Controls - only show if path is defined */} + {pathLength > 0 && ( + + Path: + { setPathPlaying(false); setPathIndex(0); }} + sx={{ color: "#888", fontSize: 14, cursor: "pointer", "&:hover": { color: "#fff" }, px: 0.5 }} + title="Stop" + >⏹ + setPathPlaying(!pathPlaying)} + sx={{ color: pathPlaying ? "#0f0" : "#888", fontSize: 14, cursor: "pointer", "&:hover": { color: "#fff" }, px: 0.5 }} + title={pathPlaying ? "Pause" : "Play"} + >{pathPlaying ? "⏸" : "▶"} + + {pathIndex + 1}/{pathLength} + + { setPathPlaying(false); setPathIndex(v as number); }} + min={0} + max={Math.max(0, pathLength - 1)} + size="small" + sx={{ width: 100 }} + /> + + )} + + + {/* Row 2: Colormap + Log + Contrast */} + + + Colormap: + + + + + Log: + setLogScale(e.target.checked)} + size="small" + sx={switchStyles.medium} + /> + + + + Contrast: + setAutoRange(e.target.checked)} + size="small" + sx={switchStyles.small} + /> + {autoRange && ( + <> + { + const [low, high] = v as number[]; + setPercentileLow(low); + setPercentileHigh(high); + }} + min={0} + max={100} + size="small" + sx={{ width: 80 }} + valueLabelDisplay="auto" + valueLabelFormat={(v) => `${v}%`} + /> + + {Math.round(percentileLow || 1)}-{Math.round(percentileHigh || 99)}% + + + )} + + + + + ); +} + +export const render = createRender(Show4DSTEM); diff --git a/widget/js/webgpu-fft.ts b/widget/js/webgpu-fft.ts new file mode 100644 index 00000000..4cbdbe7e --- /dev/null +++ b/widget/js/webgpu-fft.ts @@ -0,0 +1,558 @@ +/// + +/** + * WebGPU FFT Implementation + * + * Implements Cooley-Tukey radix-2 FFT using WebGPU compute shaders. + * Supports 1D and 2D FFT with forward and inverse transforms. + */ + +// WGSL Shader for FFT butterfly operations +const FFT_SHADER = /* wgsl */` +// Complex number operations +fn cmul(a: vec2, b: vec2) -> vec2 { + return vec2( + a.x * b.x - a.y * b.y, + a.x * b.y + a.y * b.x + ); +} + +// Twiddle factor: e^(-2πi * k / N) for forward, e^(2πi * k / N) for inverse +fn twiddle(k: u32, N: u32, inverse: f32) -> vec2 { + let angle = inverse * 2.0 * 3.14159265359 * f32(k) / f32(N); + return vec2(cos(angle), sin(angle)); +} + +// Bit reversal for index +fn bitReverse(x: u32, log2N: u32) -> u32 { + var result: u32 = 0u; + var val = x; + for (var i: u32 = 0u; i < log2N; i = i + 1u) { + result = (result << 1u) | (val & 1u); + val = val >> 1u; + } + return result; +} + +struct FFTParams { + N: u32, // FFT size + log2N: u32, // log2(N) + stage: u32, // Current butterfly stage + inverse: f32, // -1.0 for forward, 1.0 for inverse +} + +@group(0) @binding(0) var params: FFTParams; +@group(0) @binding(1) var data: array>; + +// Bit-reversal permutation kernel +@compute @workgroup_size(256) +fn bitReversePermute(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= params.N) { return; } + + let rev = bitReverse(idx, params.log2N); + if (idx < rev) { + let temp = data[idx]; + data[idx] = data[rev]; + data[rev] = temp; + } +} + +// Butterfly operation kernel for one stage +@compute @workgroup_size(256) +fn butterflyStage(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= params.N / 2u) { return; } + + let stage = params.stage; + let halfSize = 1u << stage; // 2^stage + let fullSize = halfSize << 1u; // 2^(stage+1) + + let group = idx / halfSize; + let pos = idx % halfSize; + + let i = group * fullSize + pos; + let j = i + halfSize; + + let w = twiddle(pos, fullSize, params.inverse); + + let u = data[i]; + let t = cmul(w, data[j]); + + data[i] = u + t; + data[j] = u - t; +} + +// Normalization for inverse FFT +@compute @workgroup_size(256) +fn normalize(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= params.N) { return; } + + let scale = 1.0 / f32(params.N); + data[idx] = data[idx] * scale; +} +`; + +// 2D FFT Shader (row-wise and column-wise transforms) +const FFT_2D_SHADER = /* wgsl */` +fn cmul(a: vec2, b: vec2) -> vec2 { + return vec2( + a.x * b.x - a.y * b.y, + a.x * b.y + a.y * b.x + ); +} + +fn twiddle(k: u32, N: u32, inverse: f32) -> vec2 { + let angle = inverse * 2.0 * 3.14159265359 * f32(k) / f32(N); + return vec2(cos(angle), sin(angle)); +} + +fn bitReverse(x: u32, log2N: u32) -> u32 { + var result: u32 = 0u; + var val = x; + for (var i: u32 = 0u; i < log2N; i = i + 1u) { + result = (result << 1u) | (val & 1u); + val = val >> 1u; + } + return result; +} + +struct FFT2DParams { + width: u32, + height: u32, + log2Size: u32, + stage: u32, + inverse: f32, + isRowWise: u32, // 1 for row-wise, 0 for column-wise +} + +@group(0) @binding(0) var params: FFT2DParams; +@group(0) @binding(1) var data: array>; + +// Get linear index for 2D data +fn getIndex(row: u32, col: u32) -> u32 { + return row * params.width + col; +} + +// Bit-reversal for rows +@compute @workgroup_size(16, 16) +fn bitReverseRows(@builtin(global_invocation_id) gid: vec3) { + let row = gid.y; + let col = gid.x; + if (row >= params.height || col >= params.width) { return; } + + let rev = bitReverse(col, params.log2Size); + if (col < rev) { + let idx1 = getIndex(row, col); + let idx2 = getIndex(row, rev); + let temp = data[idx1]; + data[idx1] = data[idx2]; + data[idx2] = temp; + } +} + +// Bit-reversal for columns +@compute @workgroup_size(16, 16) +fn bitReverseCols(@builtin(global_invocation_id) gid: vec3) { + let row = gid.y; + let col = gid.x; + if (row >= params.height || col >= params.width) { return; } + + let rev = bitReverse(row, params.log2Size); + if (row < rev) { + let idx1 = getIndex(row, col); + let idx2 = getIndex(rev, col); + let temp = data[idx1]; + data[idx1] = data[idx2]; + data[idx2] = temp; + } +} + +// Butterfly for rows +@compute @workgroup_size(16, 16) +fn butterflyRows(@builtin(global_invocation_id) gid: vec3) { + let row = gid.y; + let idx = gid.x; + if (row >= params.height || idx >= params.width / 2u) { return; } + + let stage = params.stage; + let halfSize = 1u << stage; + let fullSize = halfSize << 1u; + + let group = idx / halfSize; + let pos = idx % halfSize; + + let col_i = group * fullSize + pos; + let col_j = col_i + halfSize; + + if (col_j >= params.width) { return; } + + let w = twiddle(pos, fullSize, params.inverse); + + let i = getIndex(row, col_i); + let j = getIndex(row, col_j); + + let u = data[i]; + let t = cmul(w, data[j]); + + data[i] = u + t; + data[j] = u - t; +} + +// Butterfly for columns +@compute @workgroup_size(16, 16) +fn butterflyCols(@builtin(global_invocation_id) gid: vec3) { + let col = gid.x; + let idx = gid.y; + if (col >= params.width || idx >= params.height / 2u) { return; } + + let stage = params.stage; + let halfSize = 1u << stage; + let fullSize = halfSize << 1u; + + let group = idx / halfSize; + let pos = idx % halfSize; + + let row_i = group * fullSize + pos; + let row_j = row_i + halfSize; + + if (row_j >= params.height) { return; } + + let w = twiddle(pos, fullSize, params.inverse); + + let i = getIndex(row_i, col); + let j = getIndex(row_j, col); + + let u = data[i]; + let t = cmul(w, data[j]); + + data[i] = u + t; + data[j] = u - t; +} + +// Normalization for inverse 2D FFT +@compute @workgroup_size(16, 16) +fn normalize2D(@builtin(global_invocation_id) gid: vec3) { + let row = gid.y; + let col = gid.x; + if (row >= params.height || col >= params.width) { return; } + + let idx = getIndex(row, col); + let scale = 1.0 / f32(params.width * params.height); + data[idx] = data[idx] * scale; +} +`; + +/** + * Get next power of 2 >= n + */ +function nextPow2(n: number): number { + return Math.pow(2, Math.ceil(Math.log2(n))); +} + +/** + * WebGPU FFT class for 1D and 2D transforms + */ +export class WebGPUFFT { + private device: GPUDevice; + private pipelines1D: { + bitReverse: GPUComputePipeline; + butterfly: GPUComputePipeline; + normalize: GPUComputePipeline; + } | null = null; + private pipelines2D: { + bitReverseRows: GPUComputePipeline; + bitReverseCols: GPUComputePipeline; + butterflyRows: GPUComputePipeline; + butterflyCols: GPUComputePipeline; + normalize: GPUComputePipeline; + } | null = null; + private initialized = false; + + constructor(device: GPUDevice) { + this.device = device; + } + + async init(): Promise { + if (this.initialized) return; + + // Create 1D FFT pipelines + const module1D = this.device.createShaderModule({ code: FFT_SHADER }); + + this.pipelines1D = { + bitReverse: this.device.createComputePipeline({ + layout: 'auto', + compute: { module: module1D, entryPoint: 'bitReversePermute' } + }), + butterfly: this.device.createComputePipeline({ + layout: 'auto', + compute: { module: module1D, entryPoint: 'butterflyStage' } + }), + normalize: this.device.createComputePipeline({ + layout: 'auto', + compute: { module: module1D, entryPoint: 'normalize' } + }) + }; + + // Create 2D FFT pipelines + const module2D = this.device.createShaderModule({ code: FFT_2D_SHADER }); + + this.pipelines2D = { + bitReverseRows: this.device.createComputePipeline({ + layout: 'auto', + compute: { module: module2D, entryPoint: 'bitReverseRows' } + }), + bitReverseCols: this.device.createComputePipeline({ + layout: 'auto', + compute: { module: module2D, entryPoint: 'bitReverseCols' } + }), + butterflyRows: this.device.createComputePipeline({ + layout: 'auto', + compute: { module: module2D, entryPoint: 'butterflyRows' } + }), + butterflyCols: this.device.createComputePipeline({ + layout: 'auto', + compute: { module: module2D, entryPoint: 'butterflyCols' } + }), + normalize: this.device.createComputePipeline({ + layout: 'auto', + compute: { module: module2D, entryPoint: 'normalize2D' } + }) + }; + + this.initialized = true; + console.log('WebGPU FFT initialized'); + } + + /** + * Perform 2D FFT - supports ANY size via automatic zero-padding + */ + async fft2D( + realData: Float32Array, + imagData: Float32Array, + width: number, + height: number, + inverse: boolean = false + ): Promise<{ real: Float32Array, imag: Float32Array }> { + await this.init(); + + // Compute padded power-of-2 dimensions + const paddedWidth = nextPow2(width); + const paddedHeight = nextPow2(height); + const needsPadding = paddedWidth !== width || paddedHeight !== height; + + const log2Width = Math.log2(paddedWidth); + const log2Height = Math.log2(paddedHeight); + + const paddedSize = paddedWidth * paddedHeight; + const originalSize = width * height; + + // Zero-pad input if needed + let workReal: Float32Array; + let workImag: Float32Array; + + if (needsPadding) { + workReal = new Float32Array(paddedSize); + workImag = new Float32Array(paddedSize); + // Copy original data into top-left corner + for (let y = 0; y < height; y++) { + for (let x = 0; x < width; x++) { + const srcIdx = y * width + x; + const dstIdx = y * paddedWidth + x; + workReal[dstIdx] = realData[srcIdx]; + workImag[dstIdx] = imagData[srcIdx]; + } + } + } else { + workReal = realData; + workImag = imagData; + } + + const size = paddedSize; + + // Interleave real and imaginary (use padded work arrays) + const complexData = new Float32Array(size * 2); + for (let i = 0; i < size; i++) { + complexData[i * 2] = workReal[i]; + complexData[i * 2 + 1] = workImag[i]; + } + + // Create buffers + const dataBuffer = this.device.createBuffer({ + size: complexData.byteLength, + usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST, + }); + this.device.queue.writeBuffer(dataBuffer, 0, complexData); + + const paramsBuffer = this.device.createBuffer({ + size: 24, // 6 x u32/f32 + usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST, + }); + + const readBuffer = this.device.createBuffer({ + size: complexData.byteLength, + usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST, + }); + + const inverseVal = inverse ? 1.0 : -1.0; + const workgroupsX = Math.ceil(paddedWidth / 16); + const workgroupsY = Math.ceil(paddedHeight / 16); + + // Helper to run a pass + const runPass = (pipeline: GPUComputePipeline) => { + const bindGroup = this.device.createBindGroup({ + layout: pipeline.getBindGroupLayout(0), + entries: [ + { binding: 0, resource: { buffer: paramsBuffer } }, + { binding: 1, resource: { buffer: dataBuffer } }, + ] + }); + + const encoder = this.device.createCommandEncoder(); + const pass = encoder.beginComputePass(); + pass.setPipeline(pipeline); + pass.setBindGroup(0, bindGroup); + pass.dispatchWorkgroups(workgroupsX, workgroupsY); + pass.end(); + this.device.queue.submit([encoder.finish()]); + }; + + // Row-wise FFT (use padded dimensions) + const params = new ArrayBuffer(24); + const paramsU32 = new Uint32Array(params); + const paramsF32 = new Float32Array(params); + paramsU32[0] = paddedWidth; + paramsU32[1] = paddedHeight; + paramsU32[2] = log2Width; + paramsU32[3] = 0; + paramsF32[4] = inverseVal; + paramsU32[5] = 1; + this.device.queue.writeBuffer(paramsBuffer, 0, params); + runPass(this.pipelines2D!.bitReverseRows); + + for (let stage = 0; stage < log2Width; stage++) { + paramsU32[3] = stage; + this.device.queue.writeBuffer(paramsBuffer, 0, params); + runPass(this.pipelines2D!.butterflyRows); + } + + // Column-wise FFT + paramsU32[2] = log2Height; + paramsU32[3] = 0; + paramsU32[5] = 0; + this.device.queue.writeBuffer(paramsBuffer, 0, params); + runPass(this.pipelines2D!.bitReverseCols); + + for (let stage = 0; stage < log2Height; stage++) { + paramsU32[3] = stage; + this.device.queue.writeBuffer(paramsBuffer, 0, params); + runPass(this.pipelines2D!.butterflyCols); + } + + if (inverse) { + runPass(this.pipelines2D!.normalize); + } + + // Read back results + const encoder = this.device.createCommandEncoder(); + encoder.copyBufferToBuffer(dataBuffer, 0, readBuffer, 0, complexData.byteLength); + this.device.queue.submit([encoder.finish()]); + + await readBuffer.mapAsync(GPUMapMode.READ); + const result = new Float32Array(readBuffer.getMappedRange().slice(0)); + readBuffer.unmap(); + + // Cleanup GPU buffers + dataBuffer.destroy(); + paramsBuffer.destroy(); + readBuffer.destroy(); + + // Deinterleave and crop back to original size if needed + if (needsPadding) { + const realResult = new Float32Array(originalSize); + const imagResult = new Float32Array(originalSize); + for (let y = 0; y < height; y++) { + for (let x = 0; x < width; x++) { + const srcIdx = y * paddedWidth + x; + const dstIdx = y * width + x; + realResult[dstIdx] = result[srcIdx * 2]; + imagResult[dstIdx] = result[srcIdx * 2 + 1]; + } + } + return { real: realResult, imag: imagResult }; + } else { + const realResult = new Float32Array(size); + const imagResult = new Float32Array(size); + for (let i = 0; i < size; i++) { + realResult[i] = result[i * 2]; + imagResult[i] = result[i * 2 + 1]; + } + return { real: realResult, imag: imagResult }; + } + } + + destroy(): void { + this.initialized = false; + } +} + +// Singleton instance +let gpuFFT: WebGPUFFT | null = null; +let gpuDevice: GPUDevice | null = null; +let gpuInfo = "GPU"; + +/** + * Initialize WebGPU and get FFT instance + */ +export async function getWebGPUFFT(): Promise { + if (gpuFFT) return gpuFFT; + + if (!navigator.gpu) { + console.warn('WebGPU not supported, falling back to CPU FFT'); + return null; + } + + try { + const adapter = await navigator.gpu.requestAdapter(); + if (!adapter) { + console.warn('No GPU adapter found'); + return null; + } + + // Attempt to get GPU info + try { + // In modern browsers, we can request adapter info + // @ts-ignore - requestAdapterInfo is not yet in all type definitions + const info = await adapter.requestAdapterInfo?.(); + if (info) { + // Prioritize 'description' which usually has the full name (e.g. "NVIDIA GeForce RTX 4090") + // Fallback to vendor/device if description is missing + gpuInfo = info.description || + `${info.vendor} ${info.architecture || ""} ${info.device || ""}`.trim() || + "Generic WebGPU Adapter"; + } + } catch (e) { + console.log("Could not get detailed adapter info", e); + } + + gpuDevice = await adapter.requestDevice(); + gpuFFT = new WebGPUFFT(gpuDevice); + await gpuFFT.init(); + + console.log(`🚀 WebGPU FFT ready on ${gpuInfo}!`); + return gpuFFT; + } catch (e) { + console.warn('WebGPU init failed:', e); + return null; + } +} + +/** + * Get current GPU info string + */ +export function getGPUInfo(): string { + return gpuInfo; +} + +export default WebGPUFFT; From ec631956787d87879fc007686a9e1803310f8734 Mon Sep 17 00:00:00 2001 From: Sangjoon Bob Lee Date: Mon, 12 Jan 2026 06:52:36 -0800 Subject: [PATCH 02/27] add python source code for show4dstem --- src/quantem/core/io/file_readers.py | 118 +++- src/quantem/core/utils/validators.py | 6 +- uv.lock | 8 +- widget/js/show4dstem.tsx | 515 ++++++++------- widget/package-lock.json | 896 +++++++++++++++++++++++++- widget/package.json | 26 +- widget/pyproject.toml | 2 + widget/src/quantem/widget/__init__.py | 31 +- widget/tests/test_widget.py | 10 + widget/vite.config.js | 10 +- 10 files changed, 1294 insertions(+), 328 deletions(-) diff --git a/src/quantem/core/io/file_readers.py b/src/quantem/core/io/file_readers.py index cb36f1de..91d93eeb 100644 --- a/src/quantem/core/io/file_readers.py +++ b/src/quantem/core/io/file_readers.py @@ -138,6 +138,41 @@ def read_2d( return dataset +def _find_4d_dataset(group: h5py.Group, path: list[str] | None = None) -> tuple[list[str], h5py.Dataset] | None: + """Recursively search for a 4D dataset in an HDF5 group.""" + if path is None: + path = [] + for key in group.keys(): + item = group[key] + current_path = path + [key] + if isinstance(item, h5py.Dataset): + if item.ndim == 4: + return current_path, item + elif isinstance(item, h5py.Group): + result = _find_4d_dataset(item, current_path) + if result is not None: + return result + return None + + +def _find_calibration(group: h5py.Group, path: list[str] | None = None) -> tuple[list[str], h5py.Group] | None: + """Recursively search for a calibration group containing R_pixel_size and Q_pixel_size.""" + if path is None: + path = [] + for key in group.keys(): + item = group[key] + current_path = path + [key] + if isinstance(item, h5py.Group): + # Check if this group has calibration keys + if "R_pixel_size" in item and "Q_pixel_size" in item: + return current_path, item + # Recurse into subgroups + result = _find_calibration(item, current_path) + if result is not None: + return result + return None + + def read_emdfile_to_4dstem( file_path: str | PathLike, data_keys: list[str] | None = None, @@ -146,42 +181,73 @@ def read_emdfile_to_4dstem( """ File reader for legacy `emdFile` / `py4DSTEM` files. + If data_keys and calibration_keys are not provided, the function will + automatically search for a 4D dataset and calibration metadata. + Parameters ---------- file_path: str | PathLike Path to data + data_keys: list[str], optional + List of keys to navigate to the data. If None, auto-detects. + calibration_keys: list[str], optional + List of keys to navigate to calibration. If None, auto-detects. Returns -------- Dataset4dstem """ with h5py.File(file_path, "r") as file: - # Access the data directly - data_keys = ["datacube_root", "datacube", "data"] if data_keys is None else data_keys - print("keys: ", data_keys) - try: - data = file - for key in data_keys: - data = data[key] # type: ignore - except KeyError: - raise KeyError(f"Could not find key {data_keys} in {file_path}") - - # Access calibration values directly - calibration_keys = ( - ["datacube_root", "metadatabundle", "calibration"] - if calibration_keys is None - else calibration_keys - ) - try: - calibration = file - for key in calibration_keys: - calibration = calibration[key] # type: ignore - except KeyError: - raise KeyError(f"Could not find calibration key {calibration_keys} in {file_path}") - r_pixel_size = calibration["R_pixel_size"][()] # type: ignore - q_pixel_size = calibration["Q_pixel_size"][()] # type: ignore - r_pixel_units = calibration["R_pixel_units"][()] # type: ignore - q_pixel_units = calibration["Q_pixel_units"][()] # type: ignore + # Auto-detect or use provided data keys + if data_keys is None: + result = _find_4d_dataset(file) + if result is None: + raise KeyError(f"Could not find any 4D dataset in {file_path}") + data_keys, data = result + else: + try: + data = file + for key in data_keys: + data = data[key] # type: ignore + except KeyError: + raise KeyError(f"Could not find key {data_keys} in {file_path}") + + # Auto-detect or use provided calibration keys + if calibration_keys is None: + result = _find_calibration(file) + if result is None: + # No calibration found, use defaults + r_pixel_size = 1.0 + q_pixel_size = 1.0 + r_pixel_units = "pixels" + q_pixel_units = "pixels" + else: + calibration_keys, calibration = result + r_pixel_size = calibration["R_pixel_size"][()] # type: ignore + q_pixel_size = calibration["Q_pixel_size"][()] # type: ignore + r_pixel_units = calibration.get("R_pixel_units", [()]) + if hasattr(r_pixel_units, "__getitem__"): + r_pixel_units = r_pixel_units[()] + q_pixel_units = calibration.get("Q_pixel_units", [()]) + if hasattr(q_pixel_units, "__getitem__"): + q_pixel_units = q_pixel_units[()] + else: + try: + calibration = file + for key in calibration_keys: + calibration = calibration[key] # type: ignore + except KeyError: + raise KeyError(f"Could not find calibration key {calibration_keys} in {file_path}") + r_pixel_size = calibration["R_pixel_size"][()] # type: ignore + q_pixel_size = calibration["Q_pixel_size"][()] # type: ignore + r_pixel_units = calibration["R_pixel_units"][()] # type: ignore + q_pixel_units = calibration["Q_pixel_units"][()] # type: ignore + + # Decode bytes to string if needed + if isinstance(r_pixel_units, bytes): + r_pixel_units = r_pixel_units.decode("utf-8") + if isinstance(q_pixel_units, bytes): + q_pixel_units = q_pixel_units.decode("utf-8") dataset = Dataset4dstem.from_array( array=data, diff --git a/src/quantem/core/utils/validators.py b/src/quantem/core/utils/validators.py index 78ddfef5..0bc11670 100644 --- a/src/quantem/core/utils/validators.py +++ b/src/quantem/core/utils/validators.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os from pathlib import Path from typing import TYPE_CHECKING, Any, Optional, TypeAlias, Union, overload @@ -12,13 +14,15 @@ if TYPE_CHECKING: import cupy as cp import torch + + TensorLike: TypeAlias = ArrayLike | torch.Tensor else: if config.get("has_torch"): import torch if config.get("has_cupy"): import cupy as cp -TensorLike: TypeAlias = ArrayLike | "torch.Tensor" + TensorLike: TypeAlias = Union[ArrayLike, "torch.Tensor"] # --- Dataset Validation Functions --- diff --git a/uv.lock b/uv.lock index a489a851..80ed1ec0 100644 --- a/uv.lock +++ b/uv.lock @@ -2658,10 +2658,16 @@ version = "0.0.1" source = { editable = "widget" } dependencies = [ { name = "anywidget" }, + { name = "numpy" }, + { name = "traitlets" }, ] [package.metadata] -requires-dist = [{ name = "anywidget", specifier = ">=0.9.0" }] +requires-dist = [ + { name = "anywidget", specifier = ">=0.9.0" }, + { name = "numpy", specifier = ">=2.0.0" }, + { name = "traitlets", specifier = ">=5.0.0" }, +] [[package]] name = "referencing" diff --git a/widget/js/show4dstem.tsx b/widget/js/show4dstem.tsx index a26f1047..f32e7adb 100644 --- a/widget/js/show4dstem.tsx +++ b/widget/js/show4dstem.tsx @@ -21,8 +21,10 @@ import "./show4dstem.css"; const RESIZE_HANDLE_FRACTION = 0.05; // Resize handle as fraction of detector size const RESIZE_HANDLE_MIN_PX = 5; // Minimum resize handle radius const RESIZE_HANDLE_MAX_PX = 8; // Maximum resize handle radius +const RESIZE_HANDLE_RADIUS = 6; // Fixed handle radius for drawing const RESIZE_HIT_AREA_FRACTION = 0.06; // Click tolerance as fraction of detector const RESIZE_HIT_AREA_MIN_PX = 6; // Minimum click tolerance +const RESIZE_HIT_AREA_PX = 10; // Fixed hit area for click detection // Crosshair sizes: fixed pixel sizes for consistent appearance const CROSSHAIR_SIZE_PX = 18; // Fixed crosshair size for point mode (CSS pixels on 400px canvas) const CROSSHAIR_SIZE_SMALL_PX = 10; // Fixed small crosshair size for ROI center @@ -260,8 +262,188 @@ function drawDpCrosshairHiDPI( ctx.restore(); } -// Legacy stub - scale bars now drawn on high-DPI UI canvases -function drawScaleBar(..._args: unknown[]) { /* no-op, replaced by drawScaleBarHiDPI */ } +/** + * Draw ROI overlay (circle, square, rect, annular) on high-DPI canvas + * Note: Does NOT clear canvas - should be called after drawScaleBarHiDPI + */ +function drawRoiOverlayHiDPI( + canvas: HTMLCanvasElement, + dpr: number, + roiMode: string, + centerX: number, + centerY: number, + radius: number, + radiusInner: number, + roiWidth: number, + roiHeight: number, + zoom: number, + panX: number, + panY: number, + detWidth: number, + detHeight: number, + isDragging: boolean, + isDraggingResize: boolean, + isDraggingResizeInner: boolean, + isHoveringResize: boolean, + isHoveringResizeInner: boolean +) { + const ctx = canvas.getContext("2d"); + if (!ctx) return; + + ctx.save(); + ctx.scale(dpr, dpr); + + const cssWidth = canvas.width / dpr; + const cssHeight = canvas.height / dpr; + const displayScale = Math.min(cssWidth / detWidth, cssHeight / detHeight); + + // Convert detector coordinates to CSS pixel coordinates + const screenX = centerX * zoom * displayScale + panX * displayScale; + const screenY = centerY * zoom * displayScale + panY * displayScale; + + // Fixed UI sizes in CSS pixels + const lineWidth = 2.5; + const crosshairSizeSmall = 10; + const handleRadius = 6; + + ctx.shadowColor = "rgba(0, 0, 0, 0.4)"; + ctx.shadowBlur = 2; + ctx.shadowOffsetX = 1; + ctx.shadowOffsetY = 1; + + // Helper to draw resize handle + const drawResizeHandle = (handleX: number, handleY: number, isInner: boolean = false) => { + let handleFill: string; + let handleStroke: string; + const dragging = isInner ? isDraggingResizeInner : isDraggingResize; + const hovering = isInner ? isHoveringResizeInner : isHoveringResize; + + if (dragging) { + handleFill = "rgba(0, 200, 255, 1)"; + handleStroke = "rgba(255, 255, 255, 1)"; + } else if (hovering) { + handleFill = "rgba(255, 100, 100, 1)"; + handleStroke = "rgba(255, 255, 255, 1)"; + } else { + handleFill = isInner ? "rgba(0, 220, 255, 0.8)" : "rgba(0, 255, 0, 0.8)"; + handleStroke = "rgba(255, 255, 255, 0.8)"; + } + ctx.beginPath(); + ctx.arc(handleX, handleY, handleRadius, 0, 2 * Math.PI); + ctx.fillStyle = handleFill; + ctx.fill(); + ctx.strokeStyle = handleStroke; + ctx.lineWidth = 1.5; + ctx.stroke(); + }; + + // Helper to draw center crosshair + const drawCenterCrosshair = () => { + ctx.strokeStyle = isDragging ? "rgba(255, 255, 0, 0.9)" : "rgba(0, 255, 0, 0.9)"; + ctx.lineWidth = lineWidth; + ctx.beginPath(); + ctx.moveTo(screenX - crosshairSizeSmall, screenY); + ctx.lineTo(screenX + crosshairSizeSmall, screenY); + ctx.moveTo(screenX, screenY - crosshairSizeSmall); + ctx.lineTo(screenX, screenY + crosshairSizeSmall); + ctx.stroke(); + }; + + const HANDLE_ANGLE = 0.707; // cos(45°) + + if (roiMode === "circle" && radius > 0) { + const screenRadius = radius * zoom * displayScale; + + // Draw circle + ctx.strokeStyle = isDragging ? "rgba(255, 255, 0, 0.9)" : "rgba(0, 255, 0, 0.9)"; + ctx.lineWidth = lineWidth; + ctx.beginPath(); + ctx.arc(screenX, screenY, screenRadius, 0, 2 * Math.PI); + ctx.stroke(); + + // Semi-transparent fill + ctx.fillStyle = isDragging ? "rgba(255, 255, 0, 0.12)" : "rgba(0, 255, 0, 0.12)"; + ctx.fill(); + + drawCenterCrosshair(); + + // Resize handle at 45° + const handleOffset = screenRadius * HANDLE_ANGLE; + drawResizeHandle(screenX + handleOffset, screenY + handleOffset); + + } else if (roiMode === "square" && radius > 0) { + const screenHalfSize = radius * zoom * displayScale; + const left = screenX - screenHalfSize; + const top = screenY - screenHalfSize; + const size = screenHalfSize * 2; + + ctx.strokeStyle = isDragging ? "rgba(255, 255, 0, 0.9)" : "rgba(0, 255, 0, 0.9)"; + ctx.lineWidth = lineWidth; + ctx.beginPath(); + ctx.rect(left, top, size, size); + ctx.stroke(); + + ctx.fillStyle = isDragging ? "rgba(255, 255, 0, 0.12)" : "rgba(0, 255, 0, 0.12)"; + ctx.fill(); + + drawCenterCrosshair(); + drawResizeHandle(screenX + screenHalfSize, screenY + screenHalfSize); + + } else if (roiMode === "rect" && roiWidth > 0 && roiHeight > 0) { + const screenHalfW = (roiWidth / 2) * zoom * displayScale; + const screenHalfH = (roiHeight / 2) * zoom * displayScale; + const left = screenX - screenHalfW; + const top = screenY - screenHalfH; + + ctx.strokeStyle = isDragging ? "rgba(255, 255, 0, 0.9)" : "rgba(0, 255, 0, 0.9)"; + ctx.lineWidth = lineWidth; + ctx.beginPath(); + ctx.rect(left, top, screenHalfW * 2, screenHalfH * 2); + ctx.stroke(); + + ctx.fillStyle = isDragging ? "rgba(255, 255, 0, 0.12)" : "rgba(0, 255, 0, 0.12)"; + ctx.fill(); + + drawCenterCrosshair(); + drawResizeHandle(screenX + screenHalfW, screenY + screenHalfH); + + } else if (roiMode === "annular" && radius > 0) { + const screenRadiusOuter = radius * zoom * displayScale; + const screenRadiusInner = (radiusInner || 0) * zoom * displayScale; + + // Outer circle (green) + ctx.strokeStyle = isDragging ? "rgba(255, 255, 0, 0.9)" : "rgba(0, 255, 0, 0.9)"; + ctx.lineWidth = lineWidth; + ctx.beginPath(); + ctx.arc(screenX, screenY, screenRadiusOuter, 0, 2 * Math.PI); + ctx.stroke(); + + // Inner circle (cyan) + ctx.strokeStyle = isDragging ? "rgba(255, 200, 0, 0.9)" : "rgba(0, 220, 255, 0.9)"; + ctx.beginPath(); + ctx.arc(screenX, screenY, screenRadiusInner, 0, 2 * Math.PI); + ctx.stroke(); + + // Fill annular region + ctx.fillStyle = isDragging ? "rgba(255, 255, 0, 0.12)" : "rgba(0, 255, 0, 0.12)"; + ctx.beginPath(); + ctx.arc(screenX, screenY, screenRadiusOuter, 0, 2 * Math.PI); + ctx.arc(screenX, screenY, screenRadiusInner, 0, 2 * Math.PI, true); + ctx.fill(); + + drawCenterCrosshair(); + + // Outer handle + const handleOffsetOuter = screenRadiusOuter * HANDLE_ANGLE; + drawResizeHandle(screenX + handleOffsetOuter, screenY + handleOffsetOuter); + + // Inner handle + const handleOffsetInner = screenRadiusInner * HANDLE_ANGLE; + drawResizeHandle(screenX + handleOffsetInner, screenY + handleOffsetInner, true); + } + + ctx.restore(); +} // ============================================================================ // Main Component @@ -410,9 +592,6 @@ function Show4DSTEM() { if (fft) { gpuFFTRef.current = fft; setGpuReady(true); - console.log("WebGPU FFT ready - using GPU acceleration!"); - } else { - console.log("⚠️ WebGPU not available - using CPU FFT"); } }); }, []); @@ -517,175 +696,15 @@ function Show4DSTEM() { ctx.restore(); }, [frameBytes, detX, detY, colormap, dpZoom, dpPanX, dpPanY]); - // Render DP overlay + // Render DP overlay - just clear (ROI shapes now drawn on high-DPI UI canvas) React.useEffect(() => { if (!dpOverlayRef.current) return; const canvas = dpOverlayRef.current; const ctx = canvas.getContext("2d"); if (!ctx) return; ctx.clearRect(0, 0, canvas.width, canvas.height); - - const screenKx = localKx * dpZoom + dpPanX; - const screenKy = localKy * dpZoom + dpPanY; - - // Convert fixed CSS pixel sizes to canvas pixels (canvas is detX x detY, displayed at 400x400) - const minDetSize = Math.min(detX, detY); - const canvasScale = minDetSize / 400; // How many canvas pixels per CSS pixel - const crosshairSize = CROSSHAIR_SIZE_PX * canvasScale * dpZoom; - const crosshairSizeSmall = CROSSHAIR_SIZE_SMALL_PX * canvasScale * dpZoom; - const centerDotRadius = CENTER_DOT_RADIUS_PX * canvasScale * dpZoom; - const lineWidth = Math.max(LINE_WIDTH_MIN_PX, Math.min(LINE_WIDTH_MAX_PX, minDetSize * LINE_WIDTH_FRACTION)); - - ctx.strokeStyle = isDraggingDP ? "rgba(255, 255, 0, 0.9)" : "rgba(0, 255, 0, 0.9)"; - ctx.lineWidth = lineWidth; - - // Helper to draw resize handle - const drawResizeHandle = (handleX: number, handleY: number) => { - let handleFill: string; - let handleStroke: string; - if (isDraggingResize) { - handleFill = "rgba(0, 200, 255, 1)"; // Cyan when dragging - handleStroke = "rgba(255, 255, 255, 1)"; - } else if (isHoveringResize) { - handleFill = "rgba(255, 100, 100, 1)"; // Red when hovering - handleStroke = "rgba(255, 255, 255, 1)"; - } else { - handleFill = "rgba(0, 255, 0, 0.8)"; // Green default - handleStroke = "rgba(255, 255, 255, 0.8)"; - } - ctx.beginPath(); - ctx.arc(handleX, handleY, RESIZE_HANDLE_RADIUS, 0, 2 * Math.PI); - ctx.fillStyle = handleFill; - ctx.fill(); - ctx.strokeStyle = handleStroke; - ctx.lineWidth = 1.5; - ctx.stroke(); - }; - - if (roiMode === "circle" && roiRadius > 0) { - // Circle mode: draw a filled circular ROI - const screenRadius = roiRadius * dpZoom; - ctx.beginPath(); - ctx.arc(screenKx, screenKy, screenRadius, 0, 2 * Math.PI); - ctx.stroke(); - // Semi-transparent fill - ctx.fillStyle = isDraggingDP ? "rgba(255, 255, 0, 0.15)" : "rgba(0, 255, 0, 0.15)"; - ctx.fill(); - // Draw center crosshair (smaller) - ctx.strokeStyle = isDraggingDP ? "rgba(255, 255, 0, 0.9)" : "rgba(0, 255, 0, 0.9)"; - ctx.lineWidth = lineWidth; - ctx.beginPath(); - ctx.moveTo(screenKx - crosshairSizeSmall, screenKy); ctx.lineTo(screenKx + crosshairSizeSmall, screenKy); - ctx.moveTo(screenKx, screenKy - crosshairSizeSmall); ctx.lineTo(screenKx, screenKy + crosshairSizeSmall); - ctx.stroke(); - - // Draw resize handle at bottom-right of circle (45° position) - const handleOffset = screenRadius * CIRCLE_HANDLE_ANGLE; - drawResizeHandle(screenKx + handleOffset, screenKy + handleOffset); - - } else if (roiMode === "square" && roiRadius > 0) { - // Square mode: draw a filled square ROI - const screenHalfSize = roiRadius * dpZoom; - const left = screenKx - screenHalfSize; - const top = screenKy - screenHalfSize; - const size = screenHalfSize * 2; - - ctx.beginPath(); - ctx.rect(left, top, size, size); - ctx.stroke(); - // Semi-transparent fill - ctx.fillStyle = isDraggingDP ? "rgba(255, 255, 0, 0.15)" : "rgba(0, 255, 0, 0.15)"; - ctx.fill(); - // Draw center crosshair (smaller) - ctx.strokeStyle = isDraggingDP ? "rgba(255, 255, 0, 0.9)" : "rgba(0, 255, 0, 0.9)"; - ctx.lineWidth = lineWidth; - ctx.beginPath(); - ctx.moveTo(screenKx - crosshairSizeSmall, screenKy); ctx.lineTo(screenKx + crosshairSizeSmall, screenKy); - ctx.moveTo(screenKx, screenKy - crosshairSizeSmall); ctx.lineTo(screenKx, screenKy + crosshairSizeSmall); - ctx.stroke(); - - // Draw resize handle at bottom-right corner of square - drawResizeHandle(screenKx + screenHalfSize, screenKy + screenHalfSize); - - } else if (roiMode === "rect" && roiWidth > 0 && roiHeight > 0) { - // Rectangular mode: draw a filled rectangular ROI with independent width/height - const screenHalfW = (roiWidth / 2) * dpZoom; - const screenHalfH = (roiHeight / 2) * dpZoom; - const left = screenKx - screenHalfW; - const top = screenKy - screenHalfH; - - ctx.beginPath(); - ctx.rect(left, top, screenHalfW * 2, screenHalfH * 2); - ctx.stroke(); - // Semi-transparent fill - ctx.fillStyle = isDraggingDP ? "rgba(255, 255, 0, 0.15)" : "rgba(0, 255, 0, 0.15)"; - ctx.fill(); - // Draw center crosshair (smaller) - ctx.strokeStyle = isDraggingDP ? "rgba(255, 255, 0, 0.9)" : "rgba(0, 255, 0, 0.9)"; - ctx.lineWidth = lineWidth; - ctx.beginPath(); - ctx.moveTo(screenKx - crosshairSizeSmall, screenKy); ctx.lineTo(screenKx + crosshairSizeSmall, screenKy); - ctx.moveTo(screenKx, screenKy - crosshairSizeSmall); ctx.lineTo(screenKx, screenKy + crosshairSizeSmall); - ctx.stroke(); - - // Draw resize handle at bottom-right corner of rectangle - drawResizeHandle(screenKx + screenHalfW, screenKy + screenHalfH); - - } else if (roiMode === "annular" && roiRadius > 0) { - // Annular mode: draw donut-shaped ROI (ADF/HAADF) - const screenRadiusOuter = roiRadius * dpZoom; - const screenRadiusInner = (roiRadiusInner || 0) * dpZoom; - - // Draw outer circle - ctx.beginPath(); - ctx.arc(screenKx, screenKy, screenRadiusOuter, 0, 2 * Math.PI); - ctx.stroke(); - - // Draw inner circle (different color for distinction) - ctx.save(); - ctx.strokeStyle = isDraggingDP ? "rgba(255, 200, 0, 0.9)" : "rgba(0, 220, 255, 0.9)"; // Cyan/orange for inner - ctx.beginPath(); - ctx.arc(screenKx, screenKy, screenRadiusInner, 0, 2 * Math.PI); - ctx.stroke(); - ctx.restore(); - - // Fill the annular region (donut) using composite operation - ctx.save(); - ctx.fillStyle = isDraggingDP ? "rgba(255, 255, 0, 0.15)" : "rgba(0, 255, 0, 0.15)"; - ctx.beginPath(); - ctx.arc(screenKx, screenKy, screenRadiusOuter, 0, 2 * Math.PI); - ctx.arc(screenKx, screenKy, screenRadiusInner, 0, 2 * Math.PI, true); // counter-clockwise to cut hole - ctx.fill(); - ctx.restore(); - - // Draw center crosshair (smaller) - ctx.strokeStyle = isDraggingDP ? "rgba(255, 255, 0, 0.9)" : "rgba(0, 255, 0, 0.9)"; - ctx.lineWidth = lineWidth; - ctx.beginPath(); - ctx.moveTo(screenKx - crosshairSizeSmall, screenKy); ctx.lineTo(screenKx + crosshairSizeSmall, screenKy); - ctx.moveTo(screenKx, screenKy - crosshairSizeSmall); ctx.lineTo(screenKx, screenKy + crosshairSizeSmall); - ctx.stroke(); - - // Draw resize handle at outer circle's 45° position (green) - const handleOffsetOuter = screenRadiusOuter * CIRCLE_HANDLE_ANGLE; - drawResizeHandle(screenKx + handleOffsetOuter, screenKy + handleOffsetOuter); - - // Draw resize handle at inner circle's 45° position (cyan) - ctx.save(); - ctx.strokeStyle = isHoveringResizeInner ? "rgba(0, 220, 255, 1)" : "rgba(0, 220, 255, 0.8)"; - ctx.fillStyle = "rgba(0, 40, 50, 0.8)"; - const handleOffsetInner = screenRadiusInner * CIRCLE_HANDLE_ANGLE; - ctx.beginPath(); - ctx.arc(screenKx + handleOffsetInner, screenKy + handleOffsetInner, RESIZE_HANDLE_RADIUS, 0, 2 * Math.PI); - ctx.fill(); - ctx.stroke(); - ctx.restore(); - - } - // Point mode crosshair is drawn on dpUiRef for crisp rendering - - drawScaleBar(ctx, canvas.width, canvas.height, dpZoom, detPixelSize || 1, "mrad", Math.max(detX, detY)); - }, [localKx, localKy, isDraggingDP, isDraggingResize, isDraggingResizeInner, isHoveringResize, isHoveringResizeInner, dpZoom, dpPanX, dpPanY, detPixelSize, roiMode, roiRadius, roiRadiusInner, roiWidth, roiHeight, detX, detY]); + // All visual overlays (crosshair, ROI shapes, scale bar) are now on dpUiRef for crisp rendering + }, [localKx, localKy, isDraggingDP, isDraggingResize, isDraggingResizeInner, isHoveringResize, isHoveringResizeInner, dpZoom, dpPanX, dpPanY, roiMode, roiRadius, roiRadiusInner, roiWidth, roiHeight, detX, detY]); // Render filtered virtual image React.useEffect(() => { @@ -802,7 +821,7 @@ function Show4DSTEM() { // Crosshair and scale bar now drawn on high-DPI UI canvas (viUiRef) }, [localPosX, localPosY, isDraggingVI, viZoom, viPanX, viPanY, pixelSize, shapeX, shapeY]); - // Render FFT (computed in JS from filtered image) + // Render FFT (WebGPU when available, CPU fallback) React.useEffect(() => { if (!rawVirtualImageRef.current || !fftCanvasRef.current) return; const canvas = fftCanvasRef.current; @@ -811,64 +830,83 @@ function Show4DSTEM() { const width = shapeY; const height = shapeX; + const sourceData = rawVirtualImageRef.current; + const lut = COLORMAPS[colormap] || COLORMAPS.inferno; - // Use raw or filtered data - let sourceData = rawVirtualImageRef.current; - if (bpInner > 0 || bpOuter > 0) { - sourceData = rawVirtualImageRef.current.slice(); - } - - const real = sourceData.slice(); - const imag = new Float32Array(real.length); - - // Forward FFT - fft2d(real, imag, width, height, false); - fftshift(real, width, height); - fftshift(imag, width, height); + // Helper to render magnitude to canvas + const renderMagnitude = (real: Float32Array, imag: Float32Array) => { + // Compute log magnitude + const magnitude = new Float32Array(real.length); + for (let i = 0; i < real.length; i++) { + magnitude[i] = Math.log1p(Math.sqrt(real[i] * real[i] + imag[i] * imag[i])); + } - // Compute log magnitude - const magnitude = new Float32Array(real.length); - for (let i = 0; i < real.length; i++) { - magnitude[i] = Math.log1p(Math.sqrt(real[i] * real[i] + imag[i] * imag[i])); - } + // Normalize + let min = Infinity, max = -Infinity; + for (let i = 0; i < magnitude.length; i++) { + if (magnitude[i] < min) min = magnitude[i]; + if (magnitude[i] > max) max = magnitude[i]; + } - // Normalize - let min = Infinity, max = -Infinity; - for (let i = 0; i < magnitude.length; i++) { - if (magnitude[i] < min) min = magnitude[i]; - if (magnitude[i] > max) max = magnitude[i]; - } + const offscreen = document.createElement("canvas"); + offscreen.width = width; + offscreen.height = height; + const offCtx = offscreen.getContext("2d"); + if (!offCtx) return; - const lut = COLORMAPS[colormap] || COLORMAPS.inferno; - const offscreen = document.createElement("canvas"); - offscreen.width = width; - offscreen.height = height; - const offCtx = offscreen.getContext("2d"); - if (!offCtx) return; + const imgData = offCtx.createImageData(width, height); + const rgba = imgData.data; + const range = max > min ? max - min : 1; + + for (let i = 0; i < magnitude.length; i++) { + const v = Math.round(((magnitude[i] - min) / range) * 255); + const j = i * 4; + const lutIdx = Math.max(0, Math.min(255, v)) * 3; + rgba[j] = lut[lutIdx]; + rgba[j + 1] = lut[lutIdx + 1]; + rgba[j + 2] = lut[lutIdx + 2]; + rgba[j + 3] = 255; + } + offCtx.putImageData(imgData, 0, 0); - const imgData = offCtx.createImageData(width, height); - const rgba = imgData.data; - const range = max > min ? max - min : 1; + ctx.imageSmoothingEnabled = false; + ctx.clearRect(0, 0, canvas.width, canvas.height); + ctx.save(); + ctx.translate(fftPanX, fftPanY); + ctx.scale(fftZoom, fftZoom); + ctx.drawImage(offscreen, 0, 0); + ctx.restore(); + }; - for (let i = 0; i < magnitude.length; i++) { - const v = Math.round(((magnitude[i] - min) / range) * 255); - const j = i * 4; - const lutIdx = Math.max(0, Math.min(255, v)) * 3; - rgba[j] = lut[lutIdx]; - rgba[j + 1] = lut[lutIdx + 1]; - rgba[j + 2] = lut[lutIdx + 2]; - rgba[j + 3] = 255; + // Try WebGPU first, fall back to CPU + if (gpuFFTRef.current && gpuReady) { + // WebGPU path (async) + let isCancelled = false; + const runGpuFFT = async () => { + const real = sourceData.slice(); + const imag = new Float32Array(real.length); + + const { real: fReal, imag: fImag } = await gpuFFTRef.current!.fft2D(real, imag, width, height, false); + if (isCancelled) return; + + // Shift in CPU (TODO: move to GPU shader) + fftshift(fReal, width, height); + fftshift(fImag, width, height); + + renderMagnitude(fReal, fImag); + }; + runGpuFFT(); + return () => { isCancelled = true; }; + } else { + // CPU fallback (sync) + const real = sourceData.slice(); + const imag = new Float32Array(real.length); + fft2d(real, imag, width, height, false); + fftshift(real, width, height); + fftshift(imag, width, height); + renderMagnitude(real, imag); } - offCtx.putImageData(imgData, 0, 0); - - ctx.imageSmoothingEnabled = false; - ctx.clearRect(0, 0, canvas.width, canvas.height); - ctx.save(); - ctx.translate(fftPanX, fftPanY); - ctx.scale(fftZoom, fftZoom); - ctx.drawImage(offscreen, 0, 0); - ctx.restore(); - }, [virtualImageBytes, shapeX, shapeY, colormap, fftZoom, fftPanX, fftPanY, bpInner, bpOuter]); + }, [virtualImageBytes, shapeX, shapeY, colormap, fftZoom, fftPanX, fftPanY, gpuReady]); // Render FFT overlay with high-pass filter circle React.useEffect(() => { @@ -902,25 +940,29 @@ function Show4DSTEM() { ctx.stroke(); ctx.setLineDash([]); } - - const fftPixelSize = pixelSize ? 1 / (shapeX * pixelSize) : 1; - drawScaleBar(ctx, canvas.width, canvas.height, fftZoom, fftPixelSize * 1000, "1/µm", Math.max(shapeX, shapeY)); }, [fftZoom, fftPanX, fftPanY, pixelSize, shapeX, shapeY, bpInner, bpOuter]); // ───────────────────────────────────────────────────────────────────────── // High-DPI Scale Bar UI Overlays // ───────────────────────────────────────────────────────────────────────── - // DP scale bar + crosshair (high-DPI) + // DP scale bar + crosshair + ROI overlay (high-DPI) React.useEffect(() => { if (!dpUiRef.current) return; // Draw scale bar first (clears canvas) drawScaleBarHiDPI(dpUiRef.current, DPR, dpZoom, detPixelSize || 1, "mrad", detY, detX); - // Draw crosshair (only for point mode - Python uses "point", not "crosshair") + // Draw ROI overlay (circle, square, rect, annular) or point crosshair if (roiMode === "point") { drawDpCrosshairHiDPI(dpUiRef.current, DPR, localKx, localKy, dpZoom, dpPanX, dpPanY, detY, detX, isDraggingDP); + } else { + drawRoiOverlayHiDPI( + dpUiRef.current, DPR, roiMode, + localKx, localKy, roiRadius, roiRadiusInner, roiWidth, roiHeight, + dpZoom, dpPanX, dpPanY, detY, detX, + isDraggingDP, isDraggingResize, isDraggingResizeInner, isHoveringResize, isHoveringResizeInner + ); } - }, [dpZoom, dpPanX, dpPanY, detPixelSize, detX, detY, roiMode, localKx, localKy, isDraggingDP]); + }, [dpZoom, dpPanX, dpPanY, detPixelSize, detX, detY, roiMode, roiRadius, roiRadiusInner, roiWidth, roiHeight, localKx, localKy, isDraggingDP, isDraggingResize, isDraggingResizeInner, isHoveringResize, isHoveringResizeInner]); // VI scale bar + crosshair (high-DPI) React.useEffect(() => { @@ -1407,23 +1449,18 @@ function Show4DSTEM() { )} {roiMode === "annular" && ( <> - in: setRoiRadiusInner(v as number)} + value={[roiRadiusInner || 5, roiRadius || 10]} + onChange={(_, v) => { + const [inner, outer] = v as number[]; + setRoiRadiusInner(inner); + setRoiRadius(outer); + }} min={0} max={Math.min(detX, detY) / 2} size="small" - sx={{ width: 60 }} - /> - out: - setRoiRadius(v as number)} - min={1} - max={Math.min(detX, detY) / 2} - size="small" - sx={{ width: 60 }} + sx={{ width: 120 }} + valueLabelDisplay="auto" /> {Math.round(roiRadiusInner || 5)}-{Math.round(roiRadius || 10)}px diff --git a/widget/package-lock.json b/widget/package-lock.json index 4e039394..32128b87 100644 --- a/widget/package-lock.json +++ b/widget/package-lock.json @@ -6,20 +6,28 @@ "": { "name": "quantem-widget-frontend", "dependencies": { - "@anywidget/react": "^0.1.0", + "@anywidget/react": "^0.2.0", + "@emotion/react": "^11.14.0", + "@emotion/styled": "^11.14.1", + "@mui/material": "^6.4.0", + "jszip": "^3.10.1", "react": "^18.2.0", "react-dom": "^18.2.0" }, "devDependencies": { "@anywidget/vite": "^0.2.0", + "@types/react": "^18.2.0", + "@types/react-dom": "^18.2.0", "@vitejs/plugin-react": "^4.3.0", + "@webgpu/types": "^0.1.68", + "typescript": "^5.8.3", "vite": "^5.2.0" } }, "node_modules/@anywidget/react": { - "version": "0.1.0", - "resolved": "https://registry.npmjs.org/@anywidget/react/-/react-0.1.0.tgz", - "integrity": "sha512-Hh6wbMGXsgTz2xz1I9/h40M3b6uDWLjmh3sJ4tNjfppDl5y9Sw1jyuQGhNzwViOtszmyjYEJZ4kWHdPFUXUanw==", + "version": "0.2.0", + "resolved": "https://registry.npmjs.org/@anywidget/react/-/react-0.2.0.tgz", + "integrity": "sha512-7jmyfEeKDzMAOmdvzQ/KNtct72lqR1j6KYtb3RtrnHoDMMO5aymqbXib5bvDdnZQOiqlTU+B+cKxtBpYc0W4hg==", "dependencies": { "@anywidget/types": "^0.2.0" }, @@ -49,7 +57,6 @@ "version": "7.27.1", "resolved": "https://registry.npmjs.org/@babel/code-frame/-/code-frame-7.27.1.tgz", "integrity": "sha512-cjQ7ZlQ0Mv3b47hABuTevyTuYN4i+loJKGeV9flcCgIK37cCXRh+L1bd3iBHlynerhQ7BhCkn2BPbQUL+rGqFg==", - "dev": true, "license": "MIT", "dependencies": { "@babel/helper-validator-identifier": "^7.27.1", @@ -106,7 +113,6 @@ "version": "7.28.5", "resolved": "https://registry.npmjs.org/@babel/generator/-/generator-7.28.5.tgz", "integrity": "sha512-3EwLFhZ38J4VyIP6WNtt2kUdW9dokXA9Cr4IVIFHuCpZ3H8/YFOl5JjZHisrn1fATPBmKKqXzDFvh9fUwHz6CQ==", - "dev": true, "license": "MIT", "dependencies": { "@babel/parser": "^7.28.5", @@ -140,7 +146,6 @@ "version": "7.28.0", "resolved": "https://registry.npmjs.org/@babel/helper-globals/-/helper-globals-7.28.0.tgz", "integrity": "sha512-+W6cISkXFa1jXsDEdYA8HeevQT/FULhxzR99pxphltZcVaugps53THCeiWA8SguxxpSp3gKPiuYfSWopkLQ4hw==", - "dev": true, "license": "MIT", "engines": { "node": ">=6.9.0" @@ -150,7 +155,6 @@ "version": "7.27.1", "resolved": "https://registry.npmjs.org/@babel/helper-module-imports/-/helper-module-imports-7.27.1.tgz", "integrity": "sha512-0gSFWUPNXNopqtIPQvlD5WgXYI5GY2kP2cCvoT8kczjbfcfuIljTbcWrulD1CIPIX2gt1wghbDy08yE1p+/r3w==", - "dev": true, "license": "MIT", "dependencies": { "@babel/traverse": "^7.27.1", @@ -192,7 +196,6 @@ "version": "7.27.1", "resolved": "https://registry.npmjs.org/@babel/helper-string-parser/-/helper-string-parser-7.27.1.tgz", "integrity": "sha512-qMlSxKbpRlAridDExk92nSobyDdpPijUq2DW6oDnUqd0iOGxmQjyqhMIihI9+zv4LPyZdRje2cavWPbCbWm3eA==", - "dev": true, "license": "MIT", "engines": { "node": ">=6.9.0" @@ -202,7 +205,6 @@ "version": "7.28.5", "resolved": "https://registry.npmjs.org/@babel/helper-validator-identifier/-/helper-validator-identifier-7.28.5.tgz", "integrity": "sha512-qSs4ifwzKJSV39ucNjsvc6WVHs6b7S03sOh2OcHF9UHfVPqWWALUsNUVzhSBiItjRZoLHx7nIarVjqKVusUZ1Q==", - "dev": true, "license": "MIT", "engines": { "node": ">=6.9.0" @@ -236,7 +238,6 @@ "version": "7.28.5", "resolved": "https://registry.npmjs.org/@babel/parser/-/parser-7.28.5.tgz", "integrity": "sha512-KKBU1VGYR7ORr3At5HAtUQ+TV3SzRCXmA/8OdDZiLDBIZxVyzXuztPjfLd3BV1PRAQGCMWWSHYhL0F8d5uHBDQ==", - "dev": true, "license": "MIT", "dependencies": { "@babel/types": "^7.28.5" @@ -280,11 +281,19 @@ "@babel/core": "^7.0.0-0" } }, + "node_modules/@babel/runtime": { + "version": "7.28.4", + "resolved": "https://registry.npmjs.org/@babel/runtime/-/runtime-7.28.4.tgz", + "integrity": "sha512-Q/N6JNWvIvPnLDvjlE1OUBLPQHH6l3CltCEsHIujp45zQUSSh8K+gHnaEX45yAT1nyngnINhvWtzN+Nb9D8RAQ==", + "license": "MIT", + "engines": { + "node": ">=6.9.0" + } + }, "node_modules/@babel/template": { "version": "7.27.2", "resolved": "https://registry.npmjs.org/@babel/template/-/template-7.27.2.tgz", "integrity": "sha512-LPDZ85aEJyYSd18/DkjNh4/y1ntkE5KwUHWTiqgRxruuZL2F1yuHligVHLvcHY2vMHXttKFpJn6LwfI7cw7ODw==", - "dev": true, "license": "MIT", "dependencies": { "@babel/code-frame": "^7.27.1", @@ -299,7 +308,6 @@ "version": "7.28.5", "resolved": "https://registry.npmjs.org/@babel/traverse/-/traverse-7.28.5.tgz", "integrity": "sha512-TCCj4t55U90khlYkVV/0TfkJkAkUg3jZFA3Neb7unZT8CPok7iiRfaX0F+WnqWqt7OxhOn0uBKXCw4lbL8W0aQ==", - "dev": true, "license": "MIT", "dependencies": { "@babel/code-frame": "^7.27.1", @@ -318,7 +326,6 @@ "version": "7.28.5", "resolved": "https://registry.npmjs.org/@babel/types/-/types-7.28.5.tgz", "integrity": "sha512-qQ5m48eI/MFLQ5PxQj4PFaprjyCTLI37ElWMmNs0K8Lk3dVeOdNpB3ks8jc7yM5CDmVC73eMVk/trk3fgmrUpA==", - "dev": true, "license": "MIT", "dependencies": { "@babel/helper-string-parser": "^7.27.1", @@ -328,6 +335,160 @@ "node": ">=6.9.0" } }, + "node_modules/@emotion/babel-plugin": { + "version": "11.13.5", + "resolved": "https://registry.npmjs.org/@emotion/babel-plugin/-/babel-plugin-11.13.5.tgz", + "integrity": "sha512-pxHCpT2ex+0q+HH91/zsdHkw/lXd468DIN2zvfvLtPKLLMo6gQj7oLObq8PhkrxOZb/gGCq03S3Z7PDhS8pduQ==", + "license": "MIT", + "dependencies": { + "@babel/helper-module-imports": "^7.16.7", + "@babel/runtime": "^7.18.3", + "@emotion/hash": "^0.9.2", + "@emotion/memoize": "^0.9.0", + "@emotion/serialize": "^1.3.3", + "babel-plugin-macros": "^3.1.0", + "convert-source-map": "^1.5.0", + "escape-string-regexp": "^4.0.0", + "find-root": "^1.1.0", + "source-map": "^0.5.7", + "stylis": "4.2.0" + } + }, + "node_modules/@emotion/babel-plugin/node_modules/convert-source-map": { + "version": "1.9.0", + "resolved": "https://registry.npmjs.org/convert-source-map/-/convert-source-map-1.9.0.tgz", + "integrity": "sha512-ASFBup0Mz1uyiIjANan1jzLQami9z1PoYSZCiiYW2FczPbenXc45FZdBZLzOT+r6+iciuEModtmCti+hjaAk0A==", + "license": "MIT" + }, + "node_modules/@emotion/cache": { + "version": "11.14.0", + "resolved": "https://registry.npmjs.org/@emotion/cache/-/cache-11.14.0.tgz", + "integrity": "sha512-L/B1lc/TViYk4DcpGxtAVbx0ZyiKM5ktoIyafGkH6zg/tj+mA+NE//aPYKG0k8kCHSHVJrpLpcAlOBEXQ3SavA==", + "license": "MIT", + "dependencies": { + "@emotion/memoize": "^0.9.0", + "@emotion/sheet": "^1.4.0", + "@emotion/utils": "^1.4.2", + "@emotion/weak-memoize": "^0.4.0", + "stylis": "4.2.0" + } + }, + "node_modules/@emotion/hash": { + "version": "0.9.2", + "resolved": "https://registry.npmjs.org/@emotion/hash/-/hash-0.9.2.tgz", + "integrity": "sha512-MyqliTZGuOm3+5ZRSaaBGP3USLw6+EGykkwZns2EPC5g8jJ4z9OrdZY9apkl3+UP9+sdz76YYkwCKP5gh8iY3g==", + "license": "MIT" + }, + "node_modules/@emotion/is-prop-valid": { + "version": "1.4.0", + "resolved": "https://registry.npmjs.org/@emotion/is-prop-valid/-/is-prop-valid-1.4.0.tgz", + "integrity": "sha512-QgD4fyscGcbbKwJmqNvUMSE02OsHUa+lAWKdEUIJKgqe5IwRSKd7+KhibEWdaKwgjLj0DRSHA9biAIqGBk05lw==", + "license": "MIT", + "dependencies": { + "@emotion/memoize": "^0.9.0" + } + }, + "node_modules/@emotion/memoize": { + "version": "0.9.0", + "resolved": "https://registry.npmjs.org/@emotion/memoize/-/memoize-0.9.0.tgz", + "integrity": "sha512-30FAj7/EoJ5mwVPOWhAyCX+FPfMDrVecJAM+Iw9NRoSl4BBAQeqj4cApHHUXOVvIPgLVDsCFoz/hGD+5QQD1GQ==", + "license": "MIT" + }, + "node_modules/@emotion/react": { + "version": "11.14.0", + "resolved": "https://registry.npmjs.org/@emotion/react/-/react-11.14.0.tgz", + "integrity": "sha512-O000MLDBDdk/EohJPFUqvnp4qnHeYkVP5B0xEG0D/L7cOKP9kefu2DXn8dj74cQfsEzUqh+sr1RzFqiL1o+PpA==", + "license": "MIT", + "peer": true, + "dependencies": { + "@babel/runtime": "^7.18.3", + "@emotion/babel-plugin": "^11.13.5", + "@emotion/cache": "^11.14.0", + "@emotion/serialize": "^1.3.3", + "@emotion/use-insertion-effect-with-fallbacks": "^1.2.0", + "@emotion/utils": "^1.4.2", + "@emotion/weak-memoize": "^0.4.0", + "hoist-non-react-statics": "^3.3.1" + }, + "peerDependencies": { + "react": ">=16.8.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, + "node_modules/@emotion/serialize": { + "version": "1.3.3", + "resolved": "https://registry.npmjs.org/@emotion/serialize/-/serialize-1.3.3.tgz", + "integrity": "sha512-EISGqt7sSNWHGI76hC7x1CksiXPahbxEOrC5RjmFRJTqLyEK9/9hZvBbiYn70dw4wuwMKiEMCUlR6ZXTSWQqxA==", + "license": "MIT", + "dependencies": { + "@emotion/hash": "^0.9.2", + "@emotion/memoize": "^0.9.0", + "@emotion/unitless": "^0.10.0", + "@emotion/utils": "^1.4.2", + "csstype": "^3.0.2" + } + }, + "node_modules/@emotion/sheet": { + "version": "1.4.0", + "resolved": "https://registry.npmjs.org/@emotion/sheet/-/sheet-1.4.0.tgz", + "integrity": "sha512-fTBW9/8r2w3dXWYM4HCB1Rdp8NLibOw2+XELH5m5+AkWiL/KqYX6dc0kKYlaYyKjrQ6ds33MCdMPEwgs2z1rqg==", + "license": "MIT" + }, + "node_modules/@emotion/styled": { + "version": "11.14.1", + "resolved": "https://registry.npmjs.org/@emotion/styled/-/styled-11.14.1.tgz", + "integrity": "sha512-qEEJt42DuToa3gurlH4Qqc1kVpNq8wO8cJtDzU46TjlzWjDlsVyevtYCRijVq3SrHsROS+gVQ8Fnea108GnKzw==", + "license": "MIT", + "peer": true, + "dependencies": { + "@babel/runtime": "^7.18.3", + "@emotion/babel-plugin": "^11.13.5", + "@emotion/is-prop-valid": "^1.3.0", + "@emotion/serialize": "^1.3.3", + "@emotion/use-insertion-effect-with-fallbacks": "^1.2.0", + "@emotion/utils": "^1.4.2" + }, + "peerDependencies": { + "@emotion/react": "^11.0.0-rc.0", + "react": ">=16.8.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, + "node_modules/@emotion/unitless": { + "version": "0.10.0", + "resolved": "https://registry.npmjs.org/@emotion/unitless/-/unitless-0.10.0.tgz", + "integrity": "sha512-dFoMUuQA20zvtVTuxZww6OHoJYgrzfKM1t52mVySDJnMSEa08ruEvdYQbhvyu6soU+NeLVd3yKfTfT0NeV6qGg==", + "license": "MIT" + }, + "node_modules/@emotion/use-insertion-effect-with-fallbacks": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/@emotion/use-insertion-effect-with-fallbacks/-/use-insertion-effect-with-fallbacks-1.2.0.tgz", + "integrity": "sha512-yJMtVdH59sxi/aVJBpk9FQq+OR8ll5GT8oWd57UpeaKEVGab41JWaCFA7FRLoMLloOZF/c/wsPoe+bfGmRKgDg==", + "license": "MIT", + "peerDependencies": { + "react": ">=16.8.0" + } + }, + "node_modules/@emotion/utils": { + "version": "1.4.2", + "resolved": "https://registry.npmjs.org/@emotion/utils/-/utils-1.4.2.tgz", + "integrity": "sha512-3vLclRofFziIa3J2wDh9jjbkUz9qk5Vi3IZ/FSTKViB0k+ef0fPV7dYrUIugbgupYDx7v9ud/SjrtEP8Y4xLoA==", + "license": "MIT" + }, + "node_modules/@emotion/weak-memoize": { + "version": "0.4.0", + "resolved": "https://registry.npmjs.org/@emotion/weak-memoize/-/weak-memoize-0.4.0.tgz", + "integrity": "sha512-snKqtPW01tN0ui7yu9rGv69aJXr/a/Ywvl11sUjNtEcRc+ng/mQriFL0wLXMef74iHa/EkftbDzU9F8iFbH+zg==", + "license": "MIT" + }, "node_modules/@esbuild/aix-ppc64": { "version": "0.21.5", "resolved": "https://registry.npmjs.org/@esbuild/aix-ppc64/-/aix-ppc64-0.21.5.tgz", @@ -723,7 +884,6 @@ "version": "0.3.13", "resolved": "https://registry.npmjs.org/@jridgewell/gen-mapping/-/gen-mapping-0.3.13.tgz", "integrity": "sha512-2kkt/7niJ6MgEPxF0bYdQ6etZaA+fQvDcLKckhy1yIQOzaoKjBBjSj63/aLVjYE3qhRt5dvM+uUyfCg6UKCBbA==", - "dev": true, "license": "MIT", "dependencies": { "@jridgewell/sourcemap-codec": "^1.5.0", @@ -745,7 +905,6 @@ "version": "3.1.2", "resolved": "https://registry.npmjs.org/@jridgewell/resolve-uri/-/resolve-uri-3.1.2.tgz", "integrity": "sha512-bRISgCIjP20/tbWSPWMEi54QVPRZExkuD9lJL+UIxUKtwVJA8wW1Trb1jMs1RFXo1CBTNZ/5hpC9QvmKWdopKw==", - "dev": true, "license": "MIT", "engines": { "node": ">=6.0.0" @@ -755,20 +914,232 @@ "version": "1.5.5", "resolved": "https://registry.npmjs.org/@jridgewell/sourcemap-codec/-/sourcemap-codec-1.5.5.tgz", "integrity": "sha512-cYQ9310grqxueWbl+WuIUIaiUaDcj7WOq5fVhEljNVgRfOUhY9fy2zTvfoqWsnebh8Sl70VScFbICvJnLKB0Og==", - "dev": true, "license": "MIT" }, "node_modules/@jridgewell/trace-mapping": { "version": "0.3.31", "resolved": "https://registry.npmjs.org/@jridgewell/trace-mapping/-/trace-mapping-0.3.31.tgz", "integrity": "sha512-zzNR+SdQSDJzc8joaeP8QQoCQr8NuYx2dIIytl1QeBEZHJ9uW6hebsrYgbz8hJwUQao3TWCMtmfV8Nu1twOLAw==", - "dev": true, "license": "MIT", "dependencies": { "@jridgewell/resolve-uri": "^3.1.0", "@jridgewell/sourcemap-codec": "^1.4.14" } }, + "node_modules/@mui/core-downloads-tracker": { + "version": "6.5.0", + "resolved": "https://registry.npmjs.org/@mui/core-downloads-tracker/-/core-downloads-tracker-6.5.0.tgz", + "integrity": "sha512-LGb8t8i6M2ZtS3Drn3GbTI1DVhDY6FJ9crEey2lZ0aN2EMZo8IZBZj9wRf4vqbZHaWjsYgtbOnJw5V8UWbmK2Q==", + "license": "MIT", + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/mui-org" + } + }, + "node_modules/@mui/material": { + "version": "6.5.0", + "resolved": "https://registry.npmjs.org/@mui/material/-/material-6.5.0.tgz", + "integrity": "sha512-yjvtXoFcrPLGtgKRxFaH6OQPtcLPhkloC0BML6rBG5UeldR0nPULR/2E2BfXdo5JNV7j7lOzrrLX2Qf/iSidow==", + "license": "MIT", + "dependencies": { + "@babel/runtime": "^7.26.0", + "@mui/core-downloads-tracker": "^6.5.0", + "@mui/system": "^6.5.0", + "@mui/types": "~7.2.24", + "@mui/utils": "^6.4.9", + "@popperjs/core": "^2.11.8", + "@types/react-transition-group": "^4.4.12", + "clsx": "^2.1.1", + "csstype": "^3.1.3", + "prop-types": "^15.8.1", + "react-is": "^19.0.0", + "react-transition-group": "^4.4.5" + }, + "engines": { + "node": ">=14.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/mui-org" + }, + "peerDependencies": { + "@emotion/react": "^11.5.0", + "@emotion/styled": "^11.3.0", + "@mui/material-pigment-css": "^6.5.0", + "@types/react": "^17.0.0 || ^18.0.0 || ^19.0.0", + "react": "^17.0.0 || ^18.0.0 || ^19.0.0", + "react-dom": "^17.0.0 || ^18.0.0 || ^19.0.0" + }, + "peerDependenciesMeta": { + "@emotion/react": { + "optional": true + }, + "@emotion/styled": { + "optional": true + }, + "@mui/material-pigment-css": { + "optional": true + }, + "@types/react": { + "optional": true + } + } + }, + "node_modules/@mui/private-theming": { + "version": "6.4.9", + "resolved": "https://registry.npmjs.org/@mui/private-theming/-/private-theming-6.4.9.tgz", + "integrity": "sha512-LktcVmI5X17/Q5SkwjCcdOLBzt1hXuc14jYa7NPShog0GBDCDvKtcnP0V7a2s6EiVRlv7BzbWEJzH6+l/zaCxw==", + "license": "MIT", + "dependencies": { + "@babel/runtime": "^7.26.0", + "@mui/utils": "^6.4.9", + "prop-types": "^15.8.1" + }, + "engines": { + "node": ">=14.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/mui-org" + }, + "peerDependencies": { + "@types/react": "^17.0.0 || ^18.0.0 || ^19.0.0", + "react": "^17.0.0 || ^18.0.0 || ^19.0.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, + "node_modules/@mui/styled-engine": { + "version": "6.5.0", + "resolved": "https://registry.npmjs.org/@mui/styled-engine/-/styled-engine-6.5.0.tgz", + "integrity": "sha512-8woC2zAqF4qUDSPIBZ8v3sakj+WgweolpyM/FXf8jAx6FMls+IE4Y8VDZc+zS805J7PRz31vz73n2SovKGaYgw==", + "license": "MIT", + "dependencies": { + "@babel/runtime": "^7.26.0", + "@emotion/cache": "^11.13.5", + "@emotion/serialize": "^1.3.3", + "@emotion/sheet": "^1.4.0", + "csstype": "^3.1.3", + "prop-types": "^15.8.1" + }, + "engines": { + "node": ">=14.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/mui-org" + }, + "peerDependencies": { + "@emotion/react": "^11.4.1", + "@emotion/styled": "^11.3.0", + "react": "^17.0.0 || ^18.0.0 || ^19.0.0" + }, + "peerDependenciesMeta": { + "@emotion/react": { + "optional": true + }, + "@emotion/styled": { + "optional": true + } + } + }, + "node_modules/@mui/system": { + "version": "6.5.0", + "resolved": "https://registry.npmjs.org/@mui/system/-/system-6.5.0.tgz", + "integrity": "sha512-XcbBYxDS+h/lgsoGe78ExXFZXtuIlSBpn/KsZq8PtZcIkUNJInkuDqcLd2rVBQrDC1u+rvVovdaWPf2FHKJf3w==", + "license": "MIT", + "dependencies": { + "@babel/runtime": "^7.26.0", + "@mui/private-theming": "^6.4.9", + "@mui/styled-engine": "^6.5.0", + "@mui/types": "~7.2.24", + "@mui/utils": "^6.4.9", + "clsx": "^2.1.1", + "csstype": "^3.1.3", + "prop-types": "^15.8.1" + }, + "engines": { + "node": ">=14.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/mui-org" + }, + "peerDependencies": { + "@emotion/react": "^11.5.0", + "@emotion/styled": "^11.3.0", + "@types/react": "^17.0.0 || ^18.0.0 || ^19.0.0", + "react": "^17.0.0 || ^18.0.0 || ^19.0.0" + }, + "peerDependenciesMeta": { + "@emotion/react": { + "optional": true + }, + "@emotion/styled": { + "optional": true + }, + "@types/react": { + "optional": true + } + } + }, + "node_modules/@mui/types": { + "version": "7.2.24", + "resolved": "https://registry.npmjs.org/@mui/types/-/types-7.2.24.tgz", + "integrity": "sha512-3c8tRt/CbWZ+pEg7QpSwbdxOk36EfmhbKf6AGZsD1EcLDLTSZoxxJ86FVtcjxvjuhdyBiWKSTGZFaXCnidO2kw==", + "license": "MIT", + "peerDependencies": { + "@types/react": "^17.0.0 || ^18.0.0 || ^19.0.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, + "node_modules/@mui/utils": { + "version": "6.4.9", + "resolved": "https://registry.npmjs.org/@mui/utils/-/utils-6.4.9.tgz", + "integrity": "sha512-Y12Q9hbK9g+ZY0T3Rxrx9m2m10gaphDuUMgWxyV5kNJevVxXYCLclYUCC9vXaIk1/NdNDTcW2Yfr2OGvNFNmHg==", + "license": "MIT", + "dependencies": { + "@babel/runtime": "^7.26.0", + "@mui/types": "~7.2.24", + "@types/prop-types": "^15.7.14", + "clsx": "^2.1.1", + "prop-types": "^15.8.1", + "react-is": "^19.0.0" + }, + "engines": { + "node": ">=14.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/mui-org" + }, + "peerDependencies": { + "@types/react": "^17.0.0 || ^18.0.0 || ^19.0.0", + "react": "^17.0.0 || ^18.0.0 || ^19.0.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, + "node_modules/@popperjs/core": { + "version": "2.11.8", + "resolved": "https://registry.npmjs.org/@popperjs/core/-/core-2.11.8.tgz", + "integrity": "sha512-P1st0aksCrn9sGZhp8GMYwBnQsbvAWsZAX44oXNNvLHGqAOcoVxmjZiohstwQ7SqKnbR47akdNi+uleWD8+g6A==", + "license": "MIT", + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/popperjs" + } + }, "node_modules/@rolldown/pluginutils": { "version": "1.0.0-beta.27", "resolved": "https://registry.npmjs.org/@rolldown/pluginutils/-/pluginutils-1.0.0-beta.27.tgz", @@ -1178,24 +1549,46 @@ "dev": true, "license": "MIT" }, + "node_modules/@types/parse-json": { + "version": "4.0.2", + "resolved": "https://registry.npmjs.org/@types/parse-json/-/parse-json-4.0.2.tgz", + "integrity": "sha512-dISoDXWWQwUquiKsyZ4Ng+HX2KsPL7LyHKHQwgGFEA3IaKac4Obd+h2a/a6waisAoepJlBcx9paWqjA8/HVjCw==", + "license": "MIT" + }, + "node_modules/@types/prop-types": { + "version": "15.7.15", + "resolved": "https://registry.npmjs.org/@types/prop-types/-/prop-types-15.7.15.tgz", + "integrity": "sha512-F6bEyamV9jKGAFBEmlQnesRPGOQqS2+Uwi0Em15xenOxHaf2hv6L8YCVn3rPdPJOiJfPiCnLIRyvwVaqMY3MIw==", + "license": "MIT" + }, "node_modules/@types/react": { - "version": "19.2.8", - "resolved": "https://registry.npmjs.org/@types/react/-/react-19.2.8.tgz", - "integrity": "sha512-3MbSL37jEchWZz2p2mjntRZtPt837ij10ApxKfgmXCTuHWagYg7iA5bqPw6C8BMPfwidlvfPI/fxOc42HLhcyg==", + "version": "18.3.27", + "resolved": "https://registry.npmjs.org/@types/react/-/react-18.3.27.tgz", + "integrity": "sha512-cisd7gxkzjBKU2GgdYrTdtQx1SORymWyaAFhaxQPK9bYO9ot3Y5OikQRvY0VYQtvwjeQnizCINJAenh/V7MK2w==", "license": "MIT", "peer": true, "dependencies": { + "@types/prop-types": "*", "csstype": "^3.2.2" } }, "node_modules/@types/react-dom": { - "version": "19.2.3", - "resolved": "https://registry.npmjs.org/@types/react-dom/-/react-dom-19.2.3.tgz", - "integrity": "sha512-jp2L/eY6fn+KgVVQAOqYItbF0VY/YApe5Mz2F0aykSO8gx31bYCZyvSeYxCHKvzHG5eZjc+zyaS5BrBWya2+kQ==", + "version": "18.3.7", + "resolved": "https://registry.npmjs.org/@types/react-dom/-/react-dom-18.3.7.tgz", + "integrity": "sha512-MEe3UeoENYVFXzoXEWsvcpg6ZvlrFNlOQ7EOsvhI3CfAXwzPfO8Qwuxd40nepsYKqyyVQnTdEfv68q91yLcKrQ==", "license": "MIT", "peer": true, "peerDependencies": { - "@types/react": "^19.2.0" + "@types/react": "^18.0.0" + } + }, + "node_modules/@types/react-transition-group": { + "version": "4.4.12", + "resolved": "https://registry.npmjs.org/@types/react-transition-group/-/react-transition-group-4.4.12.tgz", + "integrity": "sha512-8TV6R3h2j7a91c+1DXdJi3Syo69zzIZbz7Lg5tORM5LEJG7X/E6a1V3drRyBRZq7/utz7A+c4OgYLiLcYGHG6w==", + "license": "MIT", + "peerDependencies": { + "@types/react": "*" } }, "node_modules/@vitejs/plugin-react": { @@ -1219,6 +1612,28 @@ "vite": "^4.2.0 || ^5.0.0 || ^6.0.0 || ^7.0.0" } }, + "node_modules/@webgpu/types": { + "version": "0.1.68", + "resolved": "https://registry.npmjs.org/@webgpu/types/-/types-0.1.68.tgz", + "integrity": "sha512-3ab1B59Ojb6RwjOspYLsTpCzbNB3ZaamIAxBMmvnNkiDoLTZUOBXZ9p5nAYVEkQlDdf6qAZWi1pqj9+ypiqznA==", + "dev": true, + "license": "BSD-3-Clause" + }, + "node_modules/babel-plugin-macros": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/babel-plugin-macros/-/babel-plugin-macros-3.1.0.tgz", + "integrity": "sha512-Cg7TFGpIr01vOQNODXOOaGz2NpCU5gl8x1qJFbb6hbZxR7XrcE2vtbAsTAbJ7/xwJtUuJEw8K8Zr/AE0LHlesg==", + "license": "MIT", + "dependencies": { + "@babel/runtime": "^7.12.5", + "cosmiconfig": "^7.0.0", + "resolve": "^1.19.0" + }, + "engines": { + "node": ">=10", + "npm": ">=6" + } + }, "node_modules/baseline-browser-mapping": { "version": "2.9.14", "resolved": "https://registry.npmjs.org/baseline-browser-mapping/-/baseline-browser-mapping-2.9.14.tgz", @@ -1264,6 +1679,15 @@ "node": "^6 || ^7 || ^8 || ^9 || ^10 || ^11 || ^12 || >=13.7" } }, + "node_modules/callsites": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/callsites/-/callsites-3.1.0.tgz", + "integrity": "sha512-P8BjAsXvZS+VIDUI11hHCQEv74YT67YUi5JJFNWIqL235sBmjX4+qx9Muvls5ivyNENctx46xQLQ3aTuE7ssaQ==", + "license": "MIT", + "engines": { + "node": ">=6" + } + }, "node_modules/caniuse-lite": { "version": "1.0.30001764", "resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001764.tgz", @@ -1285,6 +1709,15 @@ ], "license": "CC-BY-4.0" }, + "node_modules/clsx": { + "version": "2.1.1", + "resolved": "https://registry.npmjs.org/clsx/-/clsx-2.1.1.tgz", + "integrity": "sha512-eYm0QWBtUrBWZWG0d386OGAw16Z995PiOVo2B7bjWSbHedGl5e0ZWaq65kOGgUSNesEIDkB9ISbTg/JK9dhCZA==", + "license": "MIT", + "engines": { + "node": ">=6" + } + }, "node_modules/convert-source-map": { "version": "2.0.0", "resolved": "https://registry.npmjs.org/convert-source-map/-/convert-source-map-2.0.0.tgz", @@ -1292,6 +1725,28 @@ "dev": true, "license": "MIT" }, + "node_modules/core-util-is": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/core-util-is/-/core-util-is-1.0.3.tgz", + "integrity": "sha512-ZQBvi1DcpJ4GDqanjucZ2Hj3wEO5pZDS89BWbkcrvdxksJorwUDDZamX9ldFkp9aw2lmBDLgkObEA4DWNJ9FYQ==", + "license": "MIT" + }, + "node_modules/cosmiconfig": { + "version": "7.1.0", + "resolved": "https://registry.npmjs.org/cosmiconfig/-/cosmiconfig-7.1.0.tgz", + "integrity": "sha512-AdmX6xUzdNASswsFtmwSt7Vj8po9IuqXm0UXz7QKPuEUmPB4XyjGfaAr2PSuELMwkRMVH1EpIkX5bTZGRB3eCA==", + "license": "MIT", + "dependencies": { + "@types/parse-json": "^4.0.0", + "import-fresh": "^3.2.1", + "parse-json": "^5.0.0", + "path-type": "^4.0.0", + "yaml": "^1.10.0" + }, + "engines": { + "node": ">=10" + } + }, "node_modules/csstype": { "version": "3.2.3", "resolved": "https://registry.npmjs.org/csstype/-/csstype-3.2.3.tgz", @@ -1302,7 +1757,6 @@ "version": "4.4.3", "resolved": "https://registry.npmjs.org/debug/-/debug-4.4.3.tgz", "integrity": "sha512-RGwwWnwQvkVfavKVt22FGLw+xYSdzARwm0ru6DhTVA3umU5hZc28V3kO4stgYryrTlLpuvgI9GiijltAjNbcqA==", - "dev": true, "license": "MIT", "dependencies": { "ms": "^2.1.3" @@ -1316,6 +1770,16 @@ } } }, + "node_modules/dom-helpers": { + "version": "5.2.1", + "resolved": "https://registry.npmjs.org/dom-helpers/-/dom-helpers-5.2.1.tgz", + "integrity": "sha512-nRCa7CK3VTrM2NmGkIy4cbK7IZlgBE/PYMn55rrXefr5xXDP0LdtfPnblFDoVdcAfslJ7or6iqAUnx0CCGIWQA==", + "license": "MIT", + "dependencies": { + "@babel/runtime": "^7.8.7", + "csstype": "^3.0.2" + } + }, "node_modules/electron-to-chromium": { "version": "1.5.267", "resolved": "https://registry.npmjs.org/electron-to-chromium/-/electron-to-chromium-1.5.267.tgz", @@ -1323,6 +1787,15 @@ "dev": true, "license": "ISC" }, + "node_modules/error-ex": { + "version": "1.3.4", + "resolved": "https://registry.npmjs.org/error-ex/-/error-ex-1.3.4.tgz", + "integrity": "sha512-sqQamAnR14VgCr1A618A3sGrygcpK+HEbenA/HiEAkkUwcZIIB/tgWqHFxWgOyDh4nB4JCRimh79dR5Ywc9MDQ==", + "license": "MIT", + "dependencies": { + "is-arrayish": "^0.2.1" + } + }, "node_modules/esbuild": { "version": "0.21.5", "resolved": "https://registry.npmjs.org/esbuild/-/esbuild-0.21.5.tgz", @@ -1372,6 +1845,24 @@ "node": ">=6" } }, + "node_modules/escape-string-regexp": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/escape-string-regexp/-/escape-string-regexp-4.0.0.tgz", + "integrity": "sha512-TtpcNJ3XAzx3Gq8sWRzJaVajRs0uVxA2YAkdb1jm2YkPz4G6egUFAyA3n5vtEIZefPk5Wa4UXbKuS5fKkJWdgA==", + "license": "MIT", + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/find-root": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/find-root/-/find-root-1.1.0.tgz", + "integrity": "sha512-NKfW6bec6GfKc0SGx1e07QZY9PE99u0Bft/0rzSD5k3sO/vwkVUpDUKVm5Gpp5Ue3YfShPFTX2070tDs5kB9Ng==", + "license": "MIT" + }, "node_modules/fsevents": { "version": "2.3.3", "resolved": "https://registry.npmjs.org/fsevents/-/fsevents-2.3.3.tgz", @@ -1387,6 +1878,15 @@ "node": "^8.16.0 || ^10.6.0 || >=11.0.0" } }, + "node_modules/function-bind": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/function-bind/-/function-bind-1.1.2.tgz", + "integrity": "sha512-7XHNxH7qX9xG5mIwxkhumTox/MIRNcOgDrxWsMt2pAr23WHp6MrRlN7FBSFpCpr+oVO0F744iUgR82nJMfG2SA==", + "license": "MIT", + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, "node_modules/gensync": { "version": "1.0.0-beta.2", "resolved": "https://registry.npmjs.org/gensync/-/gensync-1.0.0-beta.2.tgz", @@ -1397,6 +1897,88 @@ "node": ">=6.9.0" } }, + "node_modules/hasown": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/hasown/-/hasown-2.0.2.tgz", + "integrity": "sha512-0hJU9SCPvmMzIBdZFqNPXWa6dqh7WdH0cII9y+CyS8rG3nL48Bclra9HmKhVVUHyPWNH5Y7xDwAB7bfgSjkUMQ==", + "license": "MIT", + "dependencies": { + "function-bind": "^1.1.2" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/hoist-non-react-statics": { + "version": "3.3.2", + "resolved": "https://registry.npmjs.org/hoist-non-react-statics/-/hoist-non-react-statics-3.3.2.tgz", + "integrity": "sha512-/gGivxi8JPKWNm/W0jSmzcMPpfpPLc3dY/6GxhX2hQ9iGj3aDfklV4ET7NjKpSinLpJ5vafa9iiGIEZg10SfBw==", + "license": "BSD-3-Clause", + "dependencies": { + "react-is": "^16.7.0" + } + }, + "node_modules/hoist-non-react-statics/node_modules/react-is": { + "version": "16.13.1", + "resolved": "https://registry.npmjs.org/react-is/-/react-is-16.13.1.tgz", + "integrity": "sha512-24e6ynE2H+OKt4kqsOvNd8kBpV65zoxbA4BVsEOB3ARVWQki/DHzaUoC5KuON/BiccDaCCTZBuOcfZs70kR8bQ==", + "license": "MIT" + }, + "node_modules/immediate": { + "version": "3.0.6", + "resolved": "https://registry.npmjs.org/immediate/-/immediate-3.0.6.tgz", + "integrity": "sha512-XXOFtyqDjNDAQxVfYxuF7g9Il/IbWmmlQg2MYKOH8ExIT1qg6xc4zyS3HaEEATgs1btfzxq15ciUiY7gjSXRGQ==", + "license": "MIT" + }, + "node_modules/import-fresh": { + "version": "3.3.1", + "resolved": "https://registry.npmjs.org/import-fresh/-/import-fresh-3.3.1.tgz", + "integrity": "sha512-TR3KfrTZTYLPB6jUjfx6MF9WcWrHL9su5TObK4ZkYgBdWKPOFoSoQIdEuTuR82pmtxH2spWG9h6etwfr1pLBqQ==", + "license": "MIT", + "dependencies": { + "parent-module": "^1.0.0", + "resolve-from": "^4.0.0" + }, + "engines": { + "node": ">=6" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/inherits": { + "version": "2.0.4", + "resolved": "https://registry.npmjs.org/inherits/-/inherits-2.0.4.tgz", + "integrity": "sha512-k/vGaX4/Yla3WzyMCvTQOXYeIHvqOKtnqBduzTHpzpQZzAskKMhZ2K+EnBiSM9zGSoIFeMpXKxa4dYeZIQqewQ==", + "license": "ISC" + }, + "node_modules/is-arrayish": { + "version": "0.2.1", + "resolved": "https://registry.npmjs.org/is-arrayish/-/is-arrayish-0.2.1.tgz", + "integrity": "sha512-zz06S8t0ozoDXMG+ube26zeCTNXcKIPJZJi8hBrF4idCLms4CG9QtK7qBl1boi5ODzFpjswb5JPmHCbMpjaYzg==", + "license": "MIT" + }, + "node_modules/is-core-module": { + "version": "2.16.1", + "resolved": "https://registry.npmjs.org/is-core-module/-/is-core-module-2.16.1.tgz", + "integrity": "sha512-UfoeMA6fIJ8wTYFEUjelnaGI67v6+N7qXJEvQuIGa99l4xsCruSYOVSQ0uPANn4dAzm8lkYPaKLrrijLq7x23w==", + "license": "MIT", + "dependencies": { + "hasown": "^2.0.2" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/isarray": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/isarray/-/isarray-1.0.0.tgz", + "integrity": "sha512-VLghIWNM6ELQzo7zwmcg0NmTVyWKYjvIeM83yjp0wRDTmUnrM678fQbcKBo6n2CJEF0szoG//ytg+TKla89ALQ==", + "license": "MIT" + }, "node_modules/js-tokens": { "version": "4.0.0", "resolved": "https://registry.npmjs.org/js-tokens/-/js-tokens-4.0.0.tgz", @@ -1407,7 +1989,6 @@ "version": "3.1.0", "resolved": "https://registry.npmjs.org/jsesc/-/jsesc-3.1.0.tgz", "integrity": "sha512-/sM3dO2FOzXjKQhJuo0Q173wf2KOo8t4I8vHy6lF9poUp7bKT0/NHE8fPX23PwfhnykfqnC2xRxOnVw5XuGIaA==", - "dev": true, "license": "MIT", "bin": { "jsesc": "bin/jsesc" @@ -1416,6 +1997,12 @@ "node": ">=6" } }, + "node_modules/json-parse-even-better-errors": { + "version": "2.3.1", + "resolved": "https://registry.npmjs.org/json-parse-even-better-errors/-/json-parse-even-better-errors-2.3.1.tgz", + "integrity": "sha512-xyFwyhro/JEof6Ghe2iz2NcXoj2sloNsWr/XsERDK/oiPCfaNhl5ONfp+jQdAZRQQ0IJWNzH9zIZF7li91kh2w==", + "license": "MIT" + }, "node_modules/json5": { "version": "2.2.3", "resolved": "https://registry.npmjs.org/json5/-/json5-2.2.3.tgz", @@ -1429,6 +2016,33 @@ "node": ">=6" } }, + "node_modules/jszip": { + "version": "3.10.1", + "resolved": "https://registry.npmjs.org/jszip/-/jszip-3.10.1.tgz", + "integrity": "sha512-xXDvecyTpGLrqFrvkrUSoxxfJI5AH7U8zxxtVclpsUtMCq4JQ290LY8AW5c7Ggnr/Y/oK+bQMbqK2qmtk3pN4g==", + "license": "(MIT OR GPL-3.0-or-later)", + "dependencies": { + "lie": "~3.3.0", + "pako": "~1.0.2", + "readable-stream": "~2.3.6", + "setimmediate": "^1.0.5" + } + }, + "node_modules/lie": { + "version": "3.3.0", + "resolved": "https://registry.npmjs.org/lie/-/lie-3.3.0.tgz", + "integrity": "sha512-UaiMJzeWRlEujzAuw5LokY1L5ecNQYZKfmyZ9L7wDHb/p5etKaxXhohBcrw0EYby+G/NA52vRSN4N39dxHAIwQ==", + "license": "MIT", + "dependencies": { + "immediate": "~3.0.5" + } + }, + "node_modules/lines-and-columns": { + "version": "1.2.4", + "resolved": "https://registry.npmjs.org/lines-and-columns/-/lines-and-columns-1.2.4.tgz", + "integrity": "sha512-7ylylesZQ/PV29jhEDl3Ufjo6ZX7gCqJr5F7PKrqc93v7fzSymt1BpwEU8nAUXs8qzzvqhbjhK5QZg6Mt/HkBg==", + "license": "MIT" + }, "node_modules/loose-envify": { "version": "1.4.0", "resolved": "https://registry.npmjs.org/loose-envify/-/loose-envify-1.4.0.tgz", @@ -1455,7 +2069,6 @@ "version": "2.1.3", "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.3.tgz", "integrity": "sha512-6FlzubTLZG3J2a/NVCAleEhjzq5oxgHyaCU9yYXvcLsvoVaHJq/s5xXI6/XXP6tz7R9xAOtHnSO/tXtF3WRTlA==", - "dev": true, "license": "MIT" }, "node_modules/nanoid": { @@ -1484,11 +2097,70 @@ "dev": true, "license": "MIT" }, + "node_modules/object-assign": { + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/object-assign/-/object-assign-4.1.1.tgz", + "integrity": "sha512-rJgTQnkUnH1sFw8yT6VSU3zD3sWmu6sZhIseY8VX+GRu3P6F7Fu+JNDoXfklElbLJSnc3FUQHVe4cU5hj+BcUg==", + "license": "MIT", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/pako": { + "version": "1.0.11", + "resolved": "https://registry.npmjs.org/pako/-/pako-1.0.11.tgz", + "integrity": "sha512-4hLB8Py4zZce5s4yd9XzopqwVv/yGNhV1Bl8NTmCq1763HeK2+EwVTv+leGeL13Dnh2wfbqowVPXCIO0z4taYw==", + "license": "(MIT AND Zlib)" + }, + "node_modules/parent-module": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/parent-module/-/parent-module-1.0.1.tgz", + "integrity": "sha512-GQ2EWRpQV8/o+Aw8YqtfZZPfNRWZYkbidE9k5rpl/hC3vtHHBfGm2Ifi6qWV+coDGkrUKZAxE3Lot5kcsRlh+g==", + "license": "MIT", + "dependencies": { + "callsites": "^3.0.0" + }, + "engines": { + "node": ">=6" + } + }, + "node_modules/parse-json": { + "version": "5.2.0", + "resolved": "https://registry.npmjs.org/parse-json/-/parse-json-5.2.0.tgz", + "integrity": "sha512-ayCKvm/phCGxOkYRSCM82iDwct8/EonSEgCSxWxD7ve6jHggsFl4fZVQBPRNgQoKiuV/odhFrGzQXZwbifC8Rg==", + "license": "MIT", + "dependencies": { + "@babel/code-frame": "^7.0.0", + "error-ex": "^1.3.1", + "json-parse-even-better-errors": "^2.3.0", + "lines-and-columns": "^1.1.6" + }, + "engines": { + "node": ">=8" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/path-parse": { + "version": "1.0.7", + "resolved": "https://registry.npmjs.org/path-parse/-/path-parse-1.0.7.tgz", + "integrity": "sha512-LDJzPVEEEPR+y48z93A0Ed0yXb8pAByGWo/k5YYdYgpY2/2EsOsksJrq7lOHxryrVOn1ejG6oAp8ahvOIQD8sw==", + "license": "MIT" + }, + "node_modules/path-type": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/path-type/-/path-type-4.0.0.tgz", + "integrity": "sha512-gDKb8aZMDeD/tZWs9P6+q0J9Mwkdl6xMV8TjnGP3qJVJ06bdMgkbBlLU8IdfOsIsFz2BW1rNVT3XuNEl8zPAvw==", + "license": "MIT", + "engines": { + "node": ">=8" + } + }, "node_modules/picocolors": { "version": "1.1.1", "resolved": "https://registry.npmjs.org/picocolors/-/picocolors-1.1.1.tgz", "integrity": "sha512-xceH2snhtb5M9liqDsmEw56le376mTZkEX/jEb/RxNFyegNul7eNslCXP9FDj/Lcu0X8KEyMceP2ntpaHrDEVA==", - "dev": true, "license": "ISC" }, "node_modules/postcss": { @@ -1520,6 +2192,29 @@ "node": "^10 || ^12 || >=14" } }, + "node_modules/process-nextick-args": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/process-nextick-args/-/process-nextick-args-2.0.1.tgz", + "integrity": "sha512-3ouUOpQhtgrbOa17J7+uxOTpITYWaGP7/AhoR3+A+/1e9skrzelGi/dXzEYyvbxubEF6Wn2ypscTKiKJFFn1ag==", + "license": "MIT" + }, + "node_modules/prop-types": { + "version": "15.8.1", + "resolved": "https://registry.npmjs.org/prop-types/-/prop-types-15.8.1.tgz", + "integrity": "sha512-oj87CgZICdulUohogVAR7AjlC0327U4el4L6eAvOqCeudMDVU0NThNaV+b9Df4dXgSP1gXMTnPdhfe/2qDH5cg==", + "license": "MIT", + "dependencies": { + "loose-envify": "^1.4.0", + "object-assign": "^4.1.1", + "react-is": "^16.13.1" + } + }, + "node_modules/prop-types/node_modules/react-is": { + "version": "16.13.1", + "resolved": "https://registry.npmjs.org/react-is/-/react-is-16.13.1.tgz", + "integrity": "sha512-24e6ynE2H+OKt4kqsOvNd8kBpV65zoxbA4BVsEOB3ARVWQki/DHzaUoC5KuON/BiccDaCCTZBuOcfZs70kR8bQ==", + "license": "MIT" + }, "node_modules/react": { "version": "18.3.1", "resolved": "https://registry.npmjs.org/react/-/react-18.3.1.tgz", @@ -1547,6 +2242,12 @@ "react": "^18.3.1" } }, + "node_modules/react-is": { + "version": "19.2.3", + "resolved": "https://registry.npmjs.org/react-is/-/react-is-19.2.3.tgz", + "integrity": "sha512-qJNJfu81ByyabuG7hPFEbXqNcWSU3+eVus+KJs+0ncpGfMyYdvSmxiJxbWR65lYi1I+/0HBcliO029gc4F+PnA==", + "license": "MIT" + }, "node_modules/react-refresh": { "version": "0.17.0", "resolved": "https://registry.npmjs.org/react-refresh/-/react-refresh-0.17.0.tgz", @@ -1557,6 +2258,66 @@ "node": ">=0.10.0" } }, + "node_modules/react-transition-group": { + "version": "4.4.5", + "resolved": "https://registry.npmjs.org/react-transition-group/-/react-transition-group-4.4.5.tgz", + "integrity": "sha512-pZcd1MCJoiKiBR2NRxeCRg13uCXbydPnmB4EOeRrY7480qNWO8IIgQG6zlDkm6uRMsURXPuKq0GWtiM59a5Q6g==", + "license": "BSD-3-Clause", + "dependencies": { + "@babel/runtime": "^7.5.5", + "dom-helpers": "^5.0.1", + "loose-envify": "^1.4.0", + "prop-types": "^15.6.2" + }, + "peerDependencies": { + "react": ">=16.6.0", + "react-dom": ">=16.6.0" + } + }, + "node_modules/readable-stream": { + "version": "2.3.8", + "resolved": "https://registry.npmjs.org/readable-stream/-/readable-stream-2.3.8.tgz", + "integrity": "sha512-8p0AUk4XODgIewSi0l8Epjs+EVnWiK7NoDIEGU0HhE7+ZyY8D1IMY7odu5lRrFXGg71L15KG8QrPmum45RTtdA==", + "license": "MIT", + "dependencies": { + "core-util-is": "~1.0.0", + "inherits": "~2.0.3", + "isarray": "~1.0.0", + "process-nextick-args": "~2.0.0", + "safe-buffer": "~5.1.1", + "string_decoder": "~1.1.1", + "util-deprecate": "~1.0.1" + } + }, + "node_modules/resolve": { + "version": "1.22.11", + "resolved": "https://registry.npmjs.org/resolve/-/resolve-1.22.11.tgz", + "integrity": "sha512-RfqAvLnMl313r7c9oclB1HhUEAezcpLjz95wFH4LVuhk9JF/r22qmVP9AMmOU4vMX7Q8pN8jwNg/CSpdFnMjTQ==", + "license": "MIT", + "dependencies": { + "is-core-module": "^2.16.1", + "path-parse": "^1.0.7", + "supports-preserve-symlinks-flag": "^1.0.0" + }, + "bin": { + "resolve": "bin/resolve" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/resolve-from": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/resolve-from/-/resolve-from-4.0.0.tgz", + "integrity": "sha512-pb/MYmXstAkysRFx8piNI1tGFNQIFA3vkE3Gq4EuA1dF6gHp/+vgZqsCGJapvy8N3Q+4o7FwvquPJcnZ7RYy4g==", + "license": "MIT", + "engines": { + "node": ">=4" + } + }, "node_modules/rollup": { "version": "4.55.1", "resolved": "https://registry.npmjs.org/rollup/-/rollup-4.55.1.tgz", @@ -1602,6 +2363,12 @@ "fsevents": "~2.3.2" } }, + "node_modules/safe-buffer": { + "version": "5.1.2", + "resolved": "https://registry.npmjs.org/safe-buffer/-/safe-buffer-5.1.2.tgz", + "integrity": "sha512-Gd2UZBJDkXlY7GbJxfsE8/nvKkUEU1G38c1siN6QP6a9PT9MmHB8GnpscSmMJSoF8LOIrt8ud/wPtojys4G6+g==", + "license": "MIT" + }, "node_modules/scheduler": { "version": "0.23.2", "resolved": "https://registry.npmjs.org/scheduler/-/scheduler-0.23.2.tgz", @@ -1621,6 +2388,21 @@ "semver": "bin/semver.js" } }, + "node_modules/setimmediate": { + "version": "1.0.5", + "resolved": "https://registry.npmjs.org/setimmediate/-/setimmediate-1.0.5.tgz", + "integrity": "sha512-MATJdZp8sLqDl/68LfQmbP8zKPLQNV6BIZoIgrscFDQ+RsvK/BxeDQOgyxKKoh0y/8h3BqVFnCqQ/gd+reiIXA==", + "license": "MIT" + }, + "node_modules/source-map": { + "version": "0.5.7", + "resolved": "https://registry.npmjs.org/source-map/-/source-map-0.5.7.tgz", + "integrity": "sha512-LbrmJOMUSdEVxIKvdcJzQC+nQhe8FUZQTXQy6+I75skNgn3OoQ0DZA8YnFa7gp8tqtL3KPf1kmo0R5DoApeSGQ==", + "license": "BSD-3-Clause", + "engines": { + "node": ">=0.10.0" + } + }, "node_modules/source-map-js": { "version": "1.2.1", "resolved": "https://registry.npmjs.org/source-map-js/-/source-map-js-1.2.1.tgz", @@ -1631,6 +2413,47 @@ "node": ">=0.10.0" } }, + "node_modules/string_decoder": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/string_decoder/-/string_decoder-1.1.1.tgz", + "integrity": "sha512-n/ShnvDi6FHbbVfviro+WojiFzv+s8MPMHBczVePfUpDJLwoLT0ht1l4YwBCbi8pJAveEEdnkHyPyTP/mzRfwg==", + "license": "MIT", + "dependencies": { + "safe-buffer": "~5.1.0" + } + }, + "node_modules/stylis": { + "version": "4.2.0", + "resolved": "https://registry.npmjs.org/stylis/-/stylis-4.2.0.tgz", + "integrity": "sha512-Orov6g6BB1sDfYgzWfTHDOxamtX1bE/zo104Dh9e6fqJ3PooipYyfJ0pUmrZO2wAvO8YbEyeFrkV91XTsGMSrw==", + "license": "MIT" + }, + "node_modules/supports-preserve-symlinks-flag": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/supports-preserve-symlinks-flag/-/supports-preserve-symlinks-flag-1.0.0.tgz", + "integrity": "sha512-ot0WnXS9fgdkgIcePe6RHNk1WA8+muPa6cSjeR3V8K27q9BB1rTE3R1p7Hv0z1ZyAc8s6Vvv8DIyWf681MAt0w==", + "license": "MIT", + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/typescript": { + "version": "5.9.3", + "resolved": "https://registry.npmjs.org/typescript/-/typescript-5.9.3.tgz", + "integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==", + "dev": true, + "license": "Apache-2.0", + "bin": { + "tsc": "bin/tsc", + "tsserver": "bin/tsserver" + }, + "engines": { + "node": ">=14.17" + } + }, "node_modules/update-browserslist-db": { "version": "1.2.3", "resolved": "https://registry.npmjs.org/update-browserslist-db/-/update-browserslist-db-1.2.3.tgz", @@ -1662,6 +2485,12 @@ "browserslist": ">= 4.21.0" } }, + "node_modules/util-deprecate": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/util-deprecate/-/util-deprecate-1.0.2.tgz", + "integrity": "sha512-EPD5q1uXyFxJpCrLnCc1nHnq3gOa6DZBocAIiI2TaSCA7VCJ1UJDMagCzIkXNsUYfD1daK//LTEQ8xiIbrHtcw==", + "license": "MIT" + }, "node_modules/vite": { "version": "5.4.21", "resolved": "https://registry.npmjs.org/vite/-/vite-5.4.21.tgz", @@ -1729,6 +2558,15 @@ "integrity": "sha512-a4UGQaWPH59mOXUYnAG2ewncQS4i4F43Tv3JoAM+s2VDAmS9NsK8GpDMLrCHPksFT7h3K6TOoUNn2pb7RoXx4g==", "dev": true, "license": "ISC" + }, + "node_modules/yaml": { + "version": "1.10.2", + "resolved": "https://registry.npmjs.org/yaml/-/yaml-1.10.2.tgz", + "integrity": "sha512-r3vXyErRCYJ7wg28yvBY5VSoAF8ZvlcW9/BwUzEtUsjvX/DKs24dIkuwjtuprwJJHsbyUbLApepYTR1BN4uHrg==", + "license": "ISC", + "engines": { + "node": ">= 6" + } } } } diff --git a/widget/package.json b/widget/package.json index 4abc343d..d78da578 100644 --- a/widget/package.json +++ b/widget/package.json @@ -1,18 +1,24 @@ { - "name": "quantem-widget-frontend", - "type": "module", "scripts": { - "dev": "vite build --watch", - "build": "vite build" + "dev": "npm run build -- --sourcemap=inline --watch", + "build": "esbuild js/*.tsx --minify --format=esm --bundle --outdir=src/quantem/widget/static", + "typecheck": "tsc --noEmit" }, "dependencies": { - "react": "^18.2.0", - "react-dom": "^18.2.0", - "@anywidget/react": "^0.1.0" + "@anywidget/react": "^0.2.0", + "@emotion/react": "^11.14.0", + "@emotion/styled": "^11.14.1", + "@mui/material": "^7.3.6", + "jszip": "^3.10.1", + "react": "^19.1.0", + "react-dom": "^19.1.0", + "webfft": "^1.0.3" }, "devDependencies": { - "vite": "^5.2.0", - "@anywidget/vite": "^0.2.0", - "@vitejs/plugin-react": "^4.3.0" + "@types/react": "^19.1.3", + "@types/react-dom": "^19.1.4", + "@webgpu/types": "^0.1.68", + "esbuild": "^0.25.4", + "typescript": "^5.8.3" } } diff --git a/widget/pyproject.toml b/widget/pyproject.toml index e00b047e..84078690 100644 --- a/widget/pyproject.toml +++ b/widget/pyproject.toml @@ -10,6 +10,8 @@ license = { file = "../LICENSE" } requires-python = ">=3.11" dependencies = [ "anywidget>=0.9.0", + "numpy>=2.0.0", + "traitlets>=5.0.0", ] [tool.hatch.build.targets.wheel] diff --git a/widget/src/quantem/widget/__init__.py b/widget/src/quantem/widget/__init__.py index d4ca85a7..bc47cc42 100644 --- a/widget/src/quantem/widget/__init__.py +++ b/widget/src/quantem/widget/__init__.py @@ -1,24 +1,17 @@ -from importlib.metadata import version -import pathlib -import anywidget -import traitlets +""" +quantem.widget: Interactive Jupyter widgets using anywidget + React. +""" -__version__ = version("quantem.widget") +import importlib.metadata -_static = pathlib.Path(__file__).parent / "static" +try: + __version__ = importlib.metadata.version("quantem-widget") +except importlib.metadata.PackageNotFoundError: + __version__ = "unknown" +from quantem.widget.show4dstem import Show4DSTEM -class CounterWidget(anywidget.AnyWidget): - _esm = _static / "index.js" +# Alias for convenience +Show4D = Show4DSTEM - count = traitlets.Int(0).tag(sync=True) - - -def show4dstem(): - # TODO: Implement 4D-STEM visualization widget - print("show4dstem: not yet implemented") - - -def counter(): - """Create a minimal counter widget for testing.""" - return CounterWidget() +__all__ = ["Show4DSTEM", "Show4D"] diff --git a/widget/tests/test_widget.py b/widget/tests/test_widget.py index bd1ba517..2c6b6b81 100644 --- a/widget/tests/test_widget.py +++ b/widget/tests/test_widget.py @@ -1,4 +1,7 @@ +import numpy as np + import quantem.widget +from quantem.widget import Show4DSTEM def test_version_exists(): @@ -7,3 +10,10 @@ def test_version_exists(): def test_version_is_string(): assert isinstance(quantem.widget.__version__, str) + + +def test_show4dstem_loads(): + """Widget can be created from mock 4D data.""" + data = np.random.rand(8, 8, 16, 16).astype(np.float32) + widget = Show4DSTEM(data) + assert widget is not None diff --git a/widget/vite.config.js b/widget/vite.config.js index 8f303083..948e74d3 100644 --- a/widget/vite.config.js +++ b/widget/vite.config.js @@ -10,13 +10,17 @@ export default defineConfig({ build: { outDir: "src/quantem/widget/static", lib: { - entry: "js/index.jsx", + entry: { + show4dstem: "js/show4dstem.tsx", + }, formats: ["es"], - fileName: "index", }, rollupOptions: { output: { - inlineDynamicImports: true, + // Each entry gets its own file + entryFileNames: "[name].js", + // CSS is handled separately by anywidget + assetFileNames: "[name][extname]", }, }, }, From d2de496e4cc9fd684f775ff4aa45bd2899a937b0 Mon Sep 17 00:00:00 2001 From: Sangjoon Bob Lee Date: Mon, 12 Jan 2026 06:53:30 -0800 Subject: [PATCH 03/27] add init, array, show4dstem files --- widget/src/quantem/__init__.py | 0 widget/src/quantem/widget/array_utils.py | 150 ++++ widget/src/quantem/widget/show4dstem.py | 1025 ++++++++++++++++++++++ 3 files changed, 1175 insertions(+) create mode 100644 widget/src/quantem/__init__.py create mode 100644 widget/src/quantem/widget/array_utils.py create mode 100644 widget/src/quantem/widget/show4dstem.py diff --git a/widget/src/quantem/__init__.py b/widget/src/quantem/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/widget/src/quantem/widget/array_utils.py b/widget/src/quantem/widget/array_utils.py new file mode 100644 index 00000000..717c5bcc --- /dev/null +++ b/widget/src/quantem/widget/array_utils.py @@ -0,0 +1,150 @@ +""" +Array utilities for handling NumPy, CuPy, and PyTorch arrays uniformly. + +This module provides utilities to convert arrays from different backends +into NumPy arrays for widget processing. +""" + +from typing import Any, Literal +import numpy as np + + +ArrayBackend = Literal["numpy", "cupy", "torch", "unknown"] + + +def get_array_backend(data: Any) -> ArrayBackend: + """ + Detect the array backend of the input data. + + Parameters + ---------- + data : array-like + Input array (NumPy, CuPy, PyTorch, or other). + + Returns + ------- + str + One of: "numpy", "cupy", "torch", "unknown" + """ + # Check PyTorch first (has both .numpy and .detach methods) + if hasattr(data, "detach") and hasattr(data, "numpy"): + return "torch" + # Check CuPy (has .get() or __cuda_array_interface__) + if hasattr(data, "__cuda_array_interface__"): + return "cupy" + if hasattr(data, "get") and hasattr(data, "__array__"): + # CuPy arrays have .get() to transfer to CPU + type_name = type(data).__module__ + if "cupy" in type_name: + return "cupy" + # Check NumPy + if isinstance(data, np.ndarray): + return "numpy" + return "unknown" + + +def to_numpy(data: Any, dtype: np.dtype | None = None) -> np.ndarray: + """ + Convert any array-like (NumPy, CuPy, PyTorch) to a NumPy array. + + Parameters + ---------- + data : array-like + Input array from any supported backend. + dtype : np.dtype, optional + Target dtype for the output array. If None, preserves original dtype. + + Returns + ------- + np.ndarray + NumPy array with the same data. + + Examples + -------- + >>> import numpy as np + >>> from quantem.widget.array_utils import to_numpy + >>> + >>> # NumPy passthrough + >>> arr = np.random.rand(10, 10) + >>> result = to_numpy(arr) + >>> + >>> # CuPy conversion (if available) + >>> import cupy as cp + >>> gpu_arr = cp.random.rand(10, 10) + >>> cpu_arr = to_numpy(gpu_arr) + >>> + >>> # PyTorch conversion (if available) + >>> import torch + >>> tensor = torch.rand(10, 10) + >>> arr = to_numpy(tensor) + """ + backend = get_array_backend(data) + + if backend == "torch": + # PyTorch tensor: detach from graph, move to CPU, convert to numpy + result = data.detach().cpu().numpy() + + elif backend == "cupy": + # CuPy array: use .get() to transfer to CPU + if hasattr(data, "get"): + result = data.get() + else: + # Fallback for __cuda_array_interface__ + import cupy as cp + + result = cp.asnumpy(data) + + elif backend == "numpy": + # NumPy array: passthrough (may copy if dtype changes) + result = data + + else: + # Unknown backend: try np.asarray as fallback + result = np.asarray(data) + + # Apply dtype conversion if specified + if dtype is not None: + result = np.asarray(result, dtype=dtype) + + return result + + +def to_numpy_list(data_list: list[Any], dtype: np.dtype | None = None) -> list[np.ndarray]: + """ + Convert a list of arrays to NumPy arrays. + + Parameters + ---------- + data_list : list of array-like + List of arrays from any supported backend. + dtype : np.dtype, optional + Target dtype for all output arrays. + + Returns + ------- + list of np.ndarray + List of NumPy arrays. + """ + return [to_numpy(arr, dtype=dtype) for arr in data_list] + + +def get_gpu_module(data: Any): + """ + Get the GPU module (cupy) if the data is on GPU, else return numpy. + + Parameters + ---------- + data : array-like + Input array. + + Returns + ------- + module + Either cupy or numpy module. + """ + backend = get_array_backend(data) + if backend == "cupy": + import cupy as cp + + return cp + return np diff --git a/widget/src/quantem/widget/show4dstem.py b/widget/src/quantem/widget/show4dstem.py new file mode 100644 index 00000000..a7321445 --- /dev/null +++ b/widget/src/quantem/widget/show4dstem.py @@ -0,0 +1,1025 @@ +""" +show4d: Fast interactive 4D-STEM viewer widget with advanced features. + +Features: +- Binary transfer (no base64 overhead) +- Live statistics panel (mean/max/min) +- Virtual detector overlays (BF/ADF circles) +- Linked scan view (side-by-side) +- Auto-range with percentile scaling +- ROI drawing tools +""" + +import pathlib +from collections.abc import Callable +from typing import TYPE_CHECKING + +import anywidget +import numpy as np +import traitlets + +from quantem.widget.array_utils import to_numpy + +if TYPE_CHECKING: + from quantem.detector import Detector + + +# Detector geometry constants (ratios of detector size) +DEFAULT_BF_RATIO = 0.125 # 1/8 of detector size +DEFAULT_ADF_INNER_RATIO = 0.1875 # 1.5 * BF = 3/16 of detector +DEFAULT_ADF_OUTER_RATIO = 0.375 # 3 * BF = 3/8 of detector + + +class Show4DSTEM(anywidget.AnyWidget): + """ + Fast interactive 4D-STEM viewer with advanced features. + + Optimized for speed with binary transfer and pre-normalization. + Works with NumPy and PyTorch arrays. + + Parameters + ---------- + data : array_like + 4D array of shape (scan_x, scan_y, det_x, det_y). + Supports NumPy and PyTorch arrays. + scan_shape : tuple, optional + If data is flattened (N, det_x, det_y), provide scan dimensions. + detector : Detector, optional + Detector object from quantem.detector for automatic calibration. + If provided, center and bf_radius are extracted from the detector. + pixel_size : float, default 1.0 + Pixel size in nm (real-space). Used for scale bar. + det_pixel_size : float, default 1.0 + Detector pixel size in mrad (k-space). Used for scale bar. + center : tuple[float, float], optional + (center_x, center_y) of the diffraction pattern in pixels. + If not provided, defaults to detector center. + bf_radius : float, optional + Bright field disk radius in pixels. If not provided, estimated as 1/8 of detector size. + log_scale : bool, default True + Use log scale for better dynamic range visualization. + auto_range : bool, default False + Use percentile-based scaling instead of global min/max. + percentile_low : float, default 1.0 + Lower percentile for auto-range (0-100). + percentile_high : float, default 99.0 + Upper percentile for auto-range (0-100). + path_points : list[tuple[int, int]], optional + List of (x, y) scan positions for programmatic animation. + Use with play(), pause(), stop() methods. + path_interval_ms : int, default 100 + Time between frames in path animation (milliseconds). + path_loop : bool, default True + Whether to loop when path animation reaches the end. + + Examples + -------- + >>> from quantem.widget import Show4DSTEM + >>> import numpy as np + >>> data = np.random.rand(64, 64, 128, 128) + >>> Show4DSTEM(data) + + >>> # With manual calibration + >>> Show4DSTEM(data, pixel_size=0.5, det_pixel_size=0.1, bf_radius=20) + + >>> # With Detector object (optional dependency) + >>> from quantem.detector import Detector + >>> det = Detector("data.h5") + >>> Show4DSTEM(det.data, detector=det) # Uses det.center, det.bf_radius + + >>> # With path animation + >>> path = [(i, i) for i in range(64)] # Diagonal path + >>> widget = Show4DSTEM(data, path_points=path, path_interval_ms=50) + >>> widget.play() # Start animation + """ + + _esm = pathlib.Path(__file__).parent / "static" / "show4dstem.js" + _css = pathlib.Path(__file__).parent / "static" / "show4dstem.css" + + # Position in scan space + pos_x = traitlets.Int(0).tag(sync=True) + pos_y = traitlets.Int(0).tag(sync=True) + + # Shape of scan space (for slider bounds) + shape_x = traitlets.Int(1).tag(sync=True) + shape_y = traitlets.Int(1).tag(sync=True) + + # Detector shape for frontend + det_x = traitlets.Int(1).tag(sync=True) + det_y = traitlets.Int(1).tag(sync=True) + + # Pre-normalized uint8 frame as bytes (no base64!) + frame_bytes = traitlets.Bytes(b"").tag(sync=True) + + # Log scale toggle + log_scale = traitlets.Bool(True).tag(sync=True) + + # ========================================================================= + # Stats Panel + # ========================================================================= + stats_mean = traitlets.Float(0.0).tag(sync=True) + stats_max = traitlets.Float(0.0).tag(sync=True) + stats_min = traitlets.Float(0.0).tag(sync=True) + show_stats = traitlets.Bool(True).tag(sync=True) + + # ========================================================================= + # Detector Integration (BF/ADF overlays) + # ========================================================================= + has_detector = traitlets.Bool(False).tag(sync=True) + center_x = traitlets.Float(0.0).tag(sync=True) + center_y = traitlets.Float(0.0).tag(sync=True) + bf_radius = traitlets.Float(0.0).tag(sync=True) + show_bf_overlay = traitlets.Bool(True).tag(sync=True) + show_adf_overlay = traitlets.Bool(False).tag(sync=True) + adf_inner_radius = traitlets.Float(0.0).tag(sync=True) + adf_outer_radius = traitlets.Float(0.0).tag(sync=True) + + # ========================================================================= + # Linked Scan View + # ========================================================================= + show_scan_view = traitlets.Bool(False).tag(sync=True) + scan_mode = traitlets.Unicode("bf").tag(sync=True) # 'bf', 'adf', 'custom' + scan_image_bytes = traitlets.Bytes(b"").tag(sync=True) + + # ========================================================================= + # Auto-Range (percentile scaling) + # ========================================================================= + auto_range = traitlets.Bool(False).tag(sync=True) + percentile_low = traitlets.Float(1.0).tag(sync=True) + percentile_high = traitlets.Float(99.0).tag(sync=True) + + # ========================================================================= + # ROI Drawing (for virtual imaging) + # ========================================================================= + roi_active = traitlets.Bool(False).tag(sync=True) + roi_mode = traitlets.Unicode("point").tag(sync=True) # 'point', 'circle', 'square', 'rect', or 'annular' + roi_center_x = traitlets.Float(0.0).tag(sync=True) + roi_center_y = traitlets.Float(0.0).tag(sync=True) + roi_radius = traitlets.Float(10.0).tag(sync=True) # Outer radius for circle/annular, half-width for square + roi_radius_inner = traitlets.Float(5.0).tag(sync=True) # Inner radius for annular mode + roi_width = traitlets.Float(20.0).tag(sync=True) # Width for rectangular mode + roi_height = traitlets.Float(10.0).tag(sync=True) # Height for rectangular mode + roi_integrated_value = traitlets.Float(0.0).tag(sync=True) + + # ========================================================================= + # Mean Diffraction Pattern + # ========================================================================= + mean_dp_bytes = traitlets.Bytes(b"").tag(sync=True) + show_mean_dp = traitlets.Bool(True).tag(sync=True) + + # ========================================================================= + # BF Image (Bright Field integrated image) + # ========================================================================= + bf_image_bytes = traitlets.Bytes(b"").tag(sync=True) + show_bf_image = traitlets.Bool(True).tag(sync=True) + + # ========================================================================= + # Virtual Image (ROI-based, updates as you drag ROI on DP) + # ========================================================================= + virtual_image_bytes = traitlets.Bytes(b"").tag(sync=True) + + # ========================================================================= + # Scale Bar + # ========================================================================= + pixel_size = traitlets.Float(1.0).tag(sync=True) # nm per pixel (real-space) + det_pixel_size = traitlets.Float(1.0).tag(sync=True) # mrad per pixel (k-space) + + # ========================================================================= + # Path Animation (programmatic crosshair control) + # ========================================================================= + path_playing = traitlets.Bool(False).tag(sync=True) + path_index = traitlets.Int(0).tag(sync=True) + path_length = traitlets.Int(0).tag(sync=True) + path_interval_ms = traitlets.Int(100).tag(sync=True) # ms between frames + path_loop = traitlets.Bool(True).tag(sync=True) # loop when reaching end + + def __init__( + self, + data, + scan_shape: tuple[int, int] | None = None, + detector: "Detector | None" = None, + pixel_size: float = 1.0, + det_pixel_size: float = 1.0, + center: tuple[float, float] | None = None, + bf_radius: float | None = None, + log_scale: bool = True, + auto_range: bool = False, + percentile_low: float = 1.0, + percentile_high: float = 99.0, + path_points: list[tuple[int, int]] | None = None, + path_interval_ms: int = 100, + path_loop: bool = True, + **kwargs, + ): + super().__init__(**kwargs) + + self._log_scale = log_scale + self.log_scale = log_scale + self.auto_range = auto_range + self.percentile_low = percentile_low + self.percentile_high = percentile_high + + # Store calibration values + self.pixel_size = pixel_size + self.det_pixel_size = det_pixel_size + + # Path animation settings + self._path_points: list[tuple[int, int]] = path_points or [] + self.path_length = len(self._path_points) + self.path_interval_ms = path_interval_ms + self.path_loop = path_loop + + # Convert to NumPy + self._data = to_numpy(data) + + # Handle flattened data + if data.ndim == 3: + if scan_shape is not None: + self._scan_shape = scan_shape + else: + # Infer square scan shape from N + n = data.shape[0] + side = int(n ** 0.5) + if side * side != n: + raise ValueError( + f"Cannot infer square scan_shape from N={n}. " + f"Provide scan_shape explicitly." + ) + self._scan_shape = (side, side) + self._det_shape = (data.shape[1], data.shape[2]) + elif data.ndim == 4: + self._scan_shape = (data.shape[0], data.shape[1]) + self._det_shape = (data.shape[2], data.shape[3]) + else: + raise ValueError(f"Expected 3D or 4D array, got {data.ndim}D") + + self.shape_x = self._scan_shape[0] + self.shape_y = self._scan_shape[1] + self.det_x = self._det_shape[0] + self.det_y = self._det_shape[1] + + # Initial position at center + self.pos_x = self.shape_x // 2 + self.pos_y = self.shape_y // 2 + + # Precompute global range for consistent scaling + self._compute_global_range() + + # Setup center and BF radius: detector > user params > defaults + if detector is not None: + # Use Detector object for calibration + det_center = detector.center + self.center_x = float(det_center[0]) + self.center_y = float(det_center[1]) + self.bf_radius = float(detector.bf_radius) + elif center is not None: + # Use user-provided center + self.center_x = float(center[0]) + self.center_y = float(center[1]) + det_size = min(self.det_x, self.det_y) + self.bf_radius = float(bf_radius) if bf_radius is not None else det_size * DEFAULT_BF_RATIO + else: + # Default to detector center + self.center_x = float(self.det_y / 2) + self.center_y = float(self.det_x / 2) + det_size = min(self.det_x, self.det_y) + self.bf_radius = float(bf_radius) if bf_radius is not None else det_size * DEFAULT_BF_RATIO + + # Set ADF defaults based on detector size + det_size = min(self.det_x, self.det_y) + self.adf_inner_radius = det_size * DEFAULT_ADF_INNER_RATIO + self.adf_outer_radius = det_size * DEFAULT_ADF_OUTER_RATIO + self.has_detector = True + + # Compute mean DP and BF image (sent once on init) + self._compute_mean_dp() + self._compute_bf_image() + + # Pre-compute and cache common virtual images (BF, ABF, LAADF, HAADF) + self._cached_bf_virtual = None + self._cached_abf_virtual = None + self._cached_laadf_virtual = None + self._cached_haadf_virtual = None + self._precompute_common_virtual_images() + + # Update frame when position or settings change + self.observe(self._update_frame, names=[ + "pos_x", "pos_y", "log_scale", "auto_range", + "percentile_low", "percentile_high" + ]) + self.observe(self._on_roi_change, names=[ + "roi_center_x", "roi_center_y", "roi_radius", "roi_radius_inner", + "roi_active", "roi_mode", "roi_width", "roi_height" + ]) + self.observe(self._on_scan_mode_change, names=["scan_mode", "show_scan_view"]) + + # Initialize default ROI at BF center + self.roi_center_x = self.center_x + self.roi_center_y = self.center_y + self.roi_radius = self.bf_radius * 0.5 # Start with half BF radius + self.roi_active = True + + # Compute initial virtual image + try: + self._compute_virtual_image_from_roi() + except Exception: + pass + + self._update_frame() + + # Path animation: observe index changes from frontend + self.observe(self._on_path_index_change, names=["path_index"]) + + # ========================================================================= + # Array Utilities + # ========================================================================= + + @staticmethod + def _to_cpu(arr): + """Convert array to CPU (NumPy).""" + return np.asarray(arr) + + @staticmethod + def _to_scalar(val): + """Convert scalar value to Python float.""" + return float(val) + + # ========================================================================= + # Path Animation Methods + # ========================================================================= + + def set_path( + self, + points: list[tuple[int, int]] | None = None, + generator: "Callable[[int, int, int], tuple[int, int]] | None" = None, + n_frames: int | None = None, + interval_ms: int | None = None, + loop: bool | None = None, + autoplay: bool = True, + ) -> "Show4DSTEM": + """ + Set a path of scan positions to animate through. + + You can provide either a list of points OR a generator function. + + Parameters + ---------- + points : list[tuple[int, int]], optional + List of (x, y) scan positions to visit. + generator : callable, optional + Custom function with signature `f(index, shape_x, shape_y) -> (x, y)`. + Called for each frame to get the next position. + n_frames : int, optional + Number of frames when using generator. Required if using generator. + interval_ms : int, optional + Time between frames in milliseconds. Default 100ms. + loop : bool, optional + Whether to loop when reaching end. Default True. + autoplay : bool, default True + Start playing immediately. + + Returns + ------- + Show4DSTEM + Self for method chaining. + + Examples + -------- + >>> # Option 1: List of points + >>> path = [(0, 0), (10, 10), (20, 20), (30, 30)] + >>> widget.set_path(points=path) + + >>> # Option 2: Custom generator function + >>> def my_path(i, sx, sy): + ... # Random walk + ... import random + ... return (random.randint(0, sx-1), random.randint(0, sy-1)) + >>> widget.set_path(generator=my_path, n_frames=100) + + >>> # Option 3: Lambda for quick patterns + >>> widget.set_path( + ... generator=lambda i, sx, sy: (i % sx, (i * 3) % sy), + ... n_frames=200 + ... ) + """ + if generator is not None: + # Use generator function to create points + if n_frames is None: + n_frames = 100 # Default + self._path_points = [ + generator(i, self.shape_x, self.shape_y) + for i in range(n_frames) + ] + elif points is not None: + self._path_points = list(points) + else: + raise ValueError("Must provide either 'points' or 'generator'") + + self.path_length = len(self._path_points) + self.path_index = 0 + + if interval_ms is not None: + self.path_interval_ms = interval_ms + if loop is not None: + self.path_loop = loop + if autoplay and self.path_length > 0: + self.path_playing = True + + return self + + def play(self) -> "Show4DSTEM": + """Start playing the path animation.""" + if self.path_length > 0: + self.path_playing = True + return self + + def pause(self) -> "Show4DSTEM": + """Pause the path animation.""" + self.path_playing = False + return self + + def stop(self) -> "Show4DSTEM": + """Stop and reset path animation to beginning.""" + self.path_playing = False + self.path_index = 0 + return self + + def goto(self, index: int) -> "Show4DSTEM": + """Jump to a specific index in the path.""" + if 0 <= index < self.path_length: + self.path_index = index + return self + + def _on_path_index_change(self, change): + """Called when path_index changes (from frontend timer).""" + idx = change["new"] + if 0 <= idx < len(self._path_points): + x, y = self._path_points[idx] + # Clamp to valid range + self.pos_x = max(0, min(self.shape_x - 1, x)) + self.pos_y = max(0, min(self.shape_y - 1, y)) + + # ========================================================================= + # Path Animation Patterns + # ========================================================================= + + def play_raster(self, step: int = 1, bidirectional: bool = False) -> "Show4DSTEM": + """ + Play a raster scan path (row by row, left to right). + + This mimics real STEM scanning: left→right, step down, left→right, etc. + + Parameters + ---------- + step : int, default 1 + Step size between positions. + bidirectional : bool, default False + If True, use snake/boustrophedon pattern (alternating direction). + If False (default), always scan left→right like real STEM. + + Returns + ------- + Show4DSTEM + Self for method chaining. + """ + points = [] + for x in range(0, self.shape_x, step): + row = list(range(0, self.shape_y, step)) + if bidirectional and (x // step % 2 == 1): + row = row[::-1] # Alternate direction for snake pattern + for y in row: + points.append((x, y)) + return self.set_path(points=points) + + # ========================================================================= + # ROI Mode Methods + # ========================================================================= + + def set_roi_circle(self, radius: float | None = None) -> "Show4DSTEM": + """ + Switch to circle ROI mode for virtual imaging. + + In circle mode, the virtual image integrates over a circular region + centered at the current ROI position (like a virtual bright field detector). + + Parameters + ---------- + radius : float, optional + Radius of the circle in pixels. If not provided, uses current value + or defaults to half the BF radius. + + Returns + ------- + Show4DSTEM + Self for method chaining. + + Examples + -------- + >>> widget.set_roi_circle(20) # 20px radius circle + >>> widget.set_roi_circle() # Use default radius + """ + self.roi_mode = "circle" + if radius is not None: + self.roi_radius = float(radius) + return self + + def set_roi_point(self) -> "Show4DSTEM": + """ + Switch to point ROI mode (single-pixel indexing). + + In point mode, the virtual image shows intensity at the exact ROI position. + This is the default mode. + + Returns + ------- + Show4DSTEM + Self for method chaining. + """ + self.roi_mode = "point" + return self + + def set_roi_square(self, size: float | None = None) -> "Show4DSTEM": + """ + Switch to square ROI mode for virtual imaging. + + In square mode, the virtual image integrates over a square region + centered at the current ROI position. + + Parameters + ---------- + size : float, optional + Half-size of the square in pixels (distance from center to edge). + If not provided, uses current roi_radius value. + + Returns + ------- + Show4DSTEM + Self for method chaining. + + Examples + -------- + >>> widget.set_roi_square(15) # 30x30 pixel square + >>> widget.set_roi_square() # Use default size + """ + self.roi_mode = "square" + if size is not None: + self.roi_radius = float(size) + return self + + def set_roi_annular( + self, inner_radius: float | None = None, outer_radius: float | None = None + ) -> "Show4DSTEM": + """ + Set ROI mode to annular (donut-shaped) for ADF/HAADF imaging. + + Parameters + ---------- + inner_radius : float, optional + Inner radius in pixels. If not provided, uses current roi_radius_inner. + outer_radius : float, optional + Outer radius in pixels. If not provided, uses current roi_radius. + + Returns + ------- + Show4DSTEM + Self for method chaining. + + Examples + -------- + >>> widget.set_roi_annular(20, 50) # ADF: inner=20px, outer=50px + >>> widget.set_roi_annular(30, 80) # HAADF: larger angles + """ + self.roi_mode = "annular" + if inner_radius is not None: + self.roi_radius_inner = float(inner_radius) + if outer_radius is not None: + self.roi_radius = float(outer_radius) + return self + + def set_roi_rect( + self, width: float | None = None, height: float | None = None + ) -> "Show4DSTEM": + """ + Set ROI mode to rectangular. + + Parameters + ---------- + width : float, optional + Width in pixels. If not provided, uses current roi_width. + height : float, optional + Height in pixels. If not provided, uses current roi_height. + + Returns + ------- + Show4DSTEM + Self for method chaining. + + Examples + -------- + >>> widget.set_roi_rect(30, 20) # 30px wide, 20px tall + >>> widget.set_roi_rect(40, 40) # 40x40 rectangle + """ + self.roi_mode = "rect" + if width is not None: + self.roi_width = float(width) + if height is not None: + self.roi_height = float(height) + return self + + def _compute_global_range(self): + """Compute global min/max from sampled frames for consistent scaling.""" + + # Sample corners and center + samples = [ + (0, 0), + (0, self.shape_y - 1), + (self.shape_x - 1, 0), + (self.shape_x - 1, self.shape_y - 1), + (self.shape_x // 2, self.shape_y // 2), + ] + + all_min, all_max = float("inf"), float("-inf") + all_values = [] + for x, y in samples: + frame = self._get_frame(x, y) + fmin = float(frame.min()) + fmax = float(frame.max()) + all_min = min(all_min, fmin) + all_max = max(all_max, fmax) + + # Sample values for percentile estimation + all_values.append(self._to_cpu(frame).flatten()[::100]) + + self._global_min = max(all_min, 1e-10) + self._global_max = all_max + + # Precompute log range + self._log_min = np.log1p(self._global_min) + self._log_max = np.log1p(self._global_max) + + # Store sampled values for percentile computation + self._sampled_values = np.concatenate(all_values) + + def _get_frame(self, x: int, y: int): + """Get single diffraction frame at position (x, y).""" + if self._data.ndim == 3: + idx = x * self.shape_y + y + return self._data[idx] + else: + return self._data[x, y] + + def _compute_percentile_range(self, frame): + """Compute percentile-based range for a frame.""" + + # Use NumPy for percentile (faster for small arrays) + frame_np = self._to_cpu(frame).flatten() + + vmin = float(np.percentile(frame_np, self.percentile_low)) + vmax = float(np.percentile(frame_np, self.percentile_high)) + return max(vmin, 1e-10), vmax + + def _update_frame(self, change=None): + """Send pre-normalized uint8 frame to frontend with stats.""" + frame = self._get_frame(self.pos_x, self.pos_y) + + # Compute stats + self.stats_mean = self._to_scalar(frame.mean()) + self.stats_max = self._to_scalar(frame.max()) + self.stats_min = self._to_scalar(frame.min()) + + # Determine value range + if self.auto_range: + vmin, vmax = self._compute_percentile_range(frame) + if self.log_scale: + vmin = np.log1p(vmin) + vmax = np.log1p(vmax) + else: + if self.log_scale: + vmin, vmax = self._log_min, self._log_max + else: + vmin, vmax = self._global_min, self._global_max + + # Apply log scale if enabled + if self.log_scale: + frame = np.log1p(frame.astype(np.float32)) + else: + frame = frame.astype(np.float32) + + # Normalize to 0-255 + if vmax > vmin: + normalized = np.clip((frame - vmin) / (vmax - vmin) * 255, 0, 255) + normalized = normalized.astype(np.uint8) + else: + normalized = np.zeros(frame.shape, dtype=np.uint8) + + # Send as raw bytes (no base64 encoding!) + self.frame_bytes = normalized.tobytes() + + def _on_roi_change(self, change=None): + """Compute integrated value when ROI changes.""" + # Skip if ROI is not active or has no valid size + if not self.roi_active: + self.roi_integrated_value = 0.0 + return + + # For circle/square/annular modes, need positive radius + if self.roi_mode in ("circle", "square", "annular") and self.roi_radius <= 0: + self.roi_integrated_value = 0.0 + return + + frame = self._get_frame(self.pos_x, self.pos_y) + + # Create mask based on ROI mode + if self.roi_mode == "circle" and self.roi_radius > 0: + mask = self._create_circular_mask( + self.roi_center_x, self.roi_center_y, self.roi_radius + ) + elif self.roi_mode == "square" and self.roi_radius > 0: + mask = self._create_square_mask( + self.roi_center_x, self.roi_center_y, self.roi_radius + ) + elif self.roi_mode == "annular" and self.roi_radius > 0: + mask = self._create_annular_mask( + self.roi_center_x, self.roi_center_y, + self.roi_radius_inner, self.roi_radius + ) + elif self.roi_mode == "rect" and self.roi_width > 0 and self.roi_height > 0: + mask = self._create_rect_mask( + self.roi_center_x, self.roi_center_y, + self.roi_width / 2, self.roi_height / 2 + ) + else: + # Point mode: no mask, just single pixel + self.roi_integrated_value = 0.0 + self._compute_virtual_image_from_roi() + return + + # Compute integrated value (use multiplication to avoid indexing issues) + integrated = self._to_scalar((frame * mask).sum()) + self.roi_integrated_value = integrated + + # Fast path: check if we can use cached preset + cached = self._get_cached_preset() + if cached is not None: + self.virtual_image_bytes = cached + return + + # Real-time update using fast masked sum + self._compute_virtual_image_from_roi() + + def _on_scan_mode_change(self, change=None): + """Recompute scan image when mode changes.""" + if self.show_scan_view and self.has_detector: + self._compute_scan_image() + + def _compute_scan_image(self): + """Compute virtual detector image (BF or ADF).""" + + # Get appropriate mask + if self.scan_mode == "bf": + mask = self._create_circular_mask( + self.center_x, self.center_y, self.bf_radius + ) + elif self.scan_mode == "adf": + mask = self._create_annular_mask( + self.center_x, self.center_y, + self.adf_inner_radius, self.adf_outer_radius + ) + else: + # Custom ROI mask + if self.roi_active and self.roi_radius > 0: + mask = self._create_circular_mask( + self.roi_center_x, self.roi_center_y, self.roi_radius + ) + else: + return + + # Compute integrated image + if self._data.ndim == 4: + # (Rx, Ry, Qx, Qy) -> Apply mask and sum + scan_image = (self._data * mask).sum(axis=(-2, -1)) + else: + # (N, Qx, Qy) -> reshape and sum + scan_image = (self._data * mask).sum(axis=(-2, -1)) + scan_image = scan_image.reshape(self._scan_shape) + + # Normalize to uint8 + smin = self._to_scalar(scan_image.min()) + smax = self._to_scalar(scan_image.max()) + + if smax > smin: + normalized = np.clip((scan_image - smin) / (smax - smin) * 255, 0, 255) + normalized = normalized.astype(np.uint8) + else: + normalized = np.zeros(scan_image.shape, dtype=np.uint8) + + normalized = self._to_cpu(normalized) + + self.scan_image_bytes = normalized.tobytes() + + def _create_circular_mask(self, cx: float, cy: float, radius: float): + """Create circular mask (boolean).""" + y, x = np.ogrid[:self.det_x, :self.det_y] + mask = (x - cx) ** 2 + (y - cy) ** 2 <= radius ** 2 + return mask + + def _create_square_mask(self, cx: float, cy: float, half_size: float): + """Create square mask (boolean).""" + y, x = np.ogrid[:self.det_x, :self.det_y] + mask = (np.abs(x - cx) <= half_size) & (np.abs(y - cy) <= half_size) + return mask + + def _create_annular_mask( + self, cx: float, cy: float, inner: float, outer: float + ): + """Create annular (donut) mask (boolean).""" + y, x = np.ogrid[:self.det_x, :self.det_y] + dist_sq = (x - cx) ** 2 + (y - cy) ** 2 + mask = (dist_sq >= inner ** 2) & (dist_sq <= outer ** 2) + return mask + + def _create_rect_mask(self, cx: float, cy: float, half_width: float, half_height: float): + """Create rectangular mask (boolean).""" + y, x = np.ogrid[:self.det_x, :self.det_y] + mask = (np.abs(x - cx) <= half_width) & (np.abs(y - cy) <= half_height) + return mask + + def _compute_mean_dp(self): + """Compute and send mean diffraction pattern.""" + if self._data.ndim == 4: + mean_dp = self._data.mean(axis=(0, 1)) + else: + mean_dp = self._data.mean(axis=0) + # Log scale + mean_dp = np.log1p(mean_dp) + # Normalize to uint8 + mean_dp_cpu = self._to_cpu(mean_dp) + vmin, vmax = float(mean_dp_cpu.min()), float(mean_dp_cpu.max()) + if vmax > vmin: + normalized = np.clip((mean_dp_cpu - vmin) / (vmax - vmin) * 255, 0, 255).astype(np.uint8) + else: + normalized = np.zeros(mean_dp_cpu.shape, dtype=np.uint8) + self.mean_dp_bytes = normalized.tobytes() + + def _compute_bf_image(self): + """Compute BF integrated image using detected probe.""" + + # Create BF mask + mask = self._create_circular_mask(self.center_x, self.center_y, self.bf_radius) + + # Compute integrated BF image + if self._data.ndim == 4: + bf_image = (self._data * mask).sum(axis=(-2, -1)) + else: + bf_image = (self._data * mask).sum(axis=(-2, -1)) + bf_image = bf_image.reshape(self._scan_shape) + + # Normalize to uint8 + vmin = self._to_scalar(bf_image.min()) + vmax = self._to_scalar(bf_image.max()) + + if vmax > vmin: + normalized = np.clip((bf_image - vmin) / (vmax - vmin) * 255, 0, 255) + normalized = normalized.astype(np.uint8) + else: + normalized = np.zeros(bf_image.shape, dtype=np.uint8) + + normalized = self._to_cpu(normalized) + + self.bf_image_bytes = normalized.tobytes() + + def _precompute_common_virtual_images(self): + """Pre-compute BF/ABF/LAADF/HAADF virtual images for instant mode switching.""" + + def _compute_and_normalize(mask): + if self._data.ndim == 4: + img = (self._data * mask).sum(axis=(-2, -1)) + else: + img = (self._data * mask).sum(axis=(-2, -1)).reshape(self._scan_shape) + vmin, vmax = self._to_scalar(img.min()), self._to_scalar(img.max()) + if vmax > vmin: + norm = np.clip((img - vmin) / (vmax - vmin) * 255, 0, 255).astype(np.uint8) + else: + norm = np.zeros(img.shape, dtype=np.uint8) + return self._to_cpu(norm).tobytes() + + # BF: circle at bf_radius + bf_mask = self._create_circular_mask(self.center_x, self.center_y, self.bf_radius) + self._cached_bf_virtual = _compute_and_normalize(bf_mask) + + # ABF: annular at 0.5*bf to bf (matches JS button) + abf_mask = self._create_annular_mask( + self.center_x, self.center_y, self.bf_radius * 0.5, self.bf_radius + ) + self._cached_abf_virtual = _compute_and_normalize(abf_mask) + + # LAADF: annular at bf to 2*bf (matches JS button) + laadf_mask = self._create_annular_mask( + self.center_x, self.center_y, self.bf_radius, self.bf_radius * 2.0 + ) + self._cached_laadf_virtual = _compute_and_normalize(laadf_mask) + + # HAADF: annular at 2*bf to 4*bf (matches JS button) + haadf_mask = self._create_annular_mask( + self.center_x, self.center_y, self.bf_radius * 2.0, self.bf_radius * 4.0 + ) + self._cached_haadf_virtual = _compute_and_normalize(haadf_mask) + + def _get_cached_preset(self) -> bytes | None: + """Check if current ROI matches a cached preset and return it.""" + # Must be centered on detector center + if abs(self.roi_center_x - self.center_x) >= 1 or abs(self.roi_center_y - self.center_y) >= 1: + return None + + bf = self.bf_radius + + # BF: circle at bf_radius + if (self.roi_mode == "circle" and abs(self.roi_radius - bf) < 1): + return self._cached_bf_virtual + + # ABF: annular at 0.5*bf to bf + if (self.roi_mode == "annular" and + abs(self.roi_radius_inner - bf * 0.5) < 1 and + abs(self.roi_radius - bf) < 1): + return self._cached_abf_virtual + + # LAADF: annular at bf to 2*bf + if (self.roi_mode == "annular" and + abs(self.roi_radius_inner - bf) < 1 and + abs(self.roi_radius - bf * 2.0) < 1): + return self._cached_laadf_virtual + + # HAADF: annular at 2*bf to 4*bf + if (self.roi_mode == "annular" and + abs(self.roi_radius_inner - bf * 2.0) < 1 and + abs(self.roi_radius - bf * 4.0) < 1): + return self._cached_haadf_virtual + + return None + + def _fast_masked_sum(self, mask) -> 'np.ndarray': + """Fast masked sum using element-wise multiply (memory efficient).""" + # Handle both 3D and 4D data + if self._data.ndim == 4: + # (scan_x, scan_y, det_x, det_y) -> sum over detector dims + virtual_image = (self._data.astype(np.float32) * mask).sum(axis=(2, 3)) + else: + # (N, det_x, det_y) -> sum over detector dims then reshape + virtual_image = (self._data.astype(np.float32) * mask).sum(axis=(1, 2)) + virtual_image = virtual_image.reshape(self._scan_shape) + return virtual_image + + def _compute_virtual_image_from_roi(self): + """Compute virtual image based on ROI mode (point, circle, square, or annular).""" + + # Fast path: use cached images for presets (BF/ABF/LAADF/HAADF) + cached = self._get_cached_preset() + if cached is not None: + self.virtual_image_bytes = cached + return + + if self.roi_mode == "circle" and self.roi_radius > 0: + mask = self._create_circular_mask( + self.roi_center_x, self.roi_center_y, self.roi_radius + ) + virtual_image = self._fast_masked_sum(mask) + elif self.roi_mode == "square" and self.roi_radius > 0: + mask = self._create_square_mask( + self.roi_center_x, self.roi_center_y, self.roi_radius + ) + virtual_image = self._fast_masked_sum(mask) + elif self.roi_mode == "annular" and self.roi_radius > 0 and self.roi_radius_inner >= 0: + mask = self._create_annular_mask( + self.roi_center_x, self.roi_center_y, + self.roi_radius_inner, self.roi_radius + ) + virtual_image = self._fast_masked_sum(mask) + elif self.roi_mode == "rect" and self.roi_width > 0 and self.roi_height > 0: + mask = self._create_rect_mask( + self.roi_center_x, self.roi_center_y, + self.roi_width / 2, self.roi_height / 2 + ) + virtual_image = self._fast_masked_sum(mask) + else: + # Point mode: single-pixel indexing - O(1) on GPU! + kx = int(max(0, min(self._det_shape[0] - 1, round(self.roi_center_y)))) + ky = int(max(0, min(self._det_shape[1] - 1, round(self.roi_center_x)))) + + if self._data.ndim == 4: + virtual_image = self._data[:, :, kx, ky] + else: + virtual_image = self._data[:, kx, ky].reshape(self._scan_shape) + + # Normalize to uint8 + vmin = self._to_scalar(virtual_image.min()) + vmax = self._to_scalar(virtual_image.max()) + + if vmax > vmin: + normalized = np.clip((virtual_image - vmin) / (vmax - vmin) * 255, 0, 255) + normalized = normalized.astype(np.uint8) + else: + normalized = np.zeros(virtual_image.shape, dtype=np.uint8) + + normalized_cpu = self._to_cpu(normalized) + + self.virtual_image_bytes = normalized_cpu.tobytes() From b19f4f0db359c1878ad8e96cea86788aa8b4de5a Mon Sep 17 00:00:00 2001 From: Sangjoon Bob Lee Date: Mon, 12 Jan 2026 09:49:53 -0800 Subject: [PATCH 04/27] clean-up dead trailets, and support A/mmrad now --- widget/js/core/canvas-utils.ts | 25 +- widget/js/show4dstem.tsx | 121 ++--- widget/src/quantem/widget/__init__.py | 2 +- widget/src/quantem/widget/show4dstem.py | 608 +++++++----------------- widget/tests/test_widget.py | 68 +++ 5 files changed, 276 insertions(+), 548 deletions(-) diff --git a/widget/js/core/canvas-utils.ts b/widget/js/core/canvas-utils.ts index 04bf0101..1679fa2f 100644 --- a/widget/js/core/canvas-utils.ts +++ b/widget/js/core/canvas-utils.ts @@ -37,26 +37,23 @@ export function renderWithColormap( } /** - * Render float32 data to canvas with colormap and optional percentile contrast. + * Render float32 data to canvas with colormap. */ export function renderFloat32WithColormap( ctx: CanvasRenderingContext2D, data: Float32Array, width: number, height: number, - cmapName: string = "inferno", - percentileLow: number = 0, - percentileHigh: number = 100 + cmapName: string = "inferno" ): void { const lut = COLORMAPS[cmapName] || COLORMAPS.inferno; - // Calculate min/max using percentiles - const sorted = Float32Array.from(data).sort((a, b) => a - b); - const len = sorted.length; - const loIdx = Math.floor((percentileLow / 100) * (len - 1)); - const hiIdx = Math.floor((percentileHigh / 100) * (len - 1)); - const min = sorted[loIdx]; - const max = sorted[hiIdx]; + // Calculate min/max + let min = Infinity, max = -Infinity; + for (let i = 0; i < data.length; i++) { + if (data[i] < min) min = data[i]; + if (data[i] > max) max = data[i]; + } const range = max - min || 1; const scale = 255 / range; @@ -119,7 +116,11 @@ export function roundToNiceValue(value: number): number { export function formatScaleLabel(value: number, unit: string): string { const nice = roundToNiceValue(value); - if (unit === "nm") { + if (unit === "Å") { + if (nice >= 10) return `${Math.round(nice / 10)} nm`; + if (nice >= 1) return `${Math.round(nice)} Å`; + return `${nice.toFixed(2)} Å`; + } else if (unit === "nm") { if (nice >= 1000) return `${Math.round(nice / 1000)} µm`; if (nice >= 1) return `${Math.round(nice)} nm`; return `${nice.toFixed(2)} nm`; diff --git a/widget/js/show4dstem.tsx b/widget/js/show4dstem.tsx index f32e7adb..46b799b9 100644 --- a/widget/js/show4dstem.tsx +++ b/widget/js/show4dstem.tsx @@ -9,35 +9,18 @@ import Slider from "@mui/material/Slider"; import Button from "@mui/material/Button"; import Switch from "@mui/material/Switch"; import JSZip from "jszip"; -import { getWebGPUFFT, WebGPUFFT, getGPUInfo } from "./webgpu-fft"; -import { COLORMAPS, fft1d, fft2d, fftshift, applyBandPassFilter, MIN_ZOOM, MAX_ZOOM } from "./shared"; -import { colors, typography, controlPanel, container } from "./CONFIG"; +import { getWebGPUFFT, WebGPUFFT } from "./webgpu-fft"; +import { COLORMAPS, fft2d, fftshift, applyBandPassFilter, MIN_ZOOM, MAX_ZOOM } from "./shared"; +import { typography, controlPanel, container } from "./CONFIG"; import { upwardMenuProps, switchStyles } from "./components"; import "./show4dstem.css"; -// ============================================================================ -// Constants - Relative sizing for various detector sizes (64x64 to 256x256+) -// ============================================================================ -const RESIZE_HANDLE_FRACTION = 0.05; // Resize handle as fraction of detector size -const RESIZE_HANDLE_MIN_PX = 5; // Minimum resize handle radius -const RESIZE_HANDLE_MAX_PX = 8; // Maximum resize handle radius -const RESIZE_HANDLE_RADIUS = 6; // Fixed handle radius for drawing -const RESIZE_HIT_AREA_FRACTION = 0.06; // Click tolerance as fraction of detector -const RESIZE_HIT_AREA_MIN_PX = 6; // Minimum click tolerance -const RESIZE_HIT_AREA_PX = 10; // Fixed hit area for click detection -// Crosshair sizes: fixed pixel sizes for consistent appearance -const CROSSHAIR_SIZE_PX = 18; // Fixed crosshair size for point mode (CSS pixels on 400px canvas) -const CROSSHAIR_SIZE_SMALL_PX = 10; // Fixed small crosshair size for ROI center -const CENTER_DOT_RADIUS_PX = 6; // Center dot radius -const CIRCLE_HANDLE_ANGLE = 0.707; // cos(45°) for circle handle position -// Line widths as fraction of size -const LINE_WIDTH_FRACTION = 0.015; // Line width as fraction of size -const LINE_WIDTH_MIN_PX = 1.5; // Minimum line width -const LINE_WIDTH_MAX_PX = 3; // Maximum line width - -// ============================================================================ -// Scale Bar (dynamic adjustment to nice values) -// ============================================================================ +// Constants +const RESIZE_HIT_AREA_PX = 10; +const CIRCLE_HANDLE_ANGLE = 0.707; // cos(45°) +const LINE_WIDTH_FRACTION = 0.015; +const LINE_WIDTH_MIN_PX = 1.5; +const LINE_WIDTH_MAX_PX = 3; /** Round to a nice value (1, 2, 5, 10, 20, 50, etc.) */ function roundToNiceValue(value: number): number { @@ -53,8 +36,12 @@ function roundToNiceValue(value: number): number { /** Format scale bar label with appropriate unit */ function formatScaleLabel(value: number, unit: string): string { const nice = roundToNiceValue(value); - - if (unit === "nm") { + + if (unit === "Å") { + if (nice >= 10) return `${Math.round(nice / 10)} nm`; + if (nice >= 1) return `${Math.round(nice)} Å`; + return `${nice.toFixed(2)} Å`; + } else if (unit === "nm") { if (nice >= 1000) return `${Math.round(nice / 1000)} µm`; if (nice >= 1) return `${Math.round(nice)} nm`; return `${nice.toFixed(2)} nm`; @@ -79,7 +66,7 @@ function drawScaleBarHiDPI( dpr: number, zoom: number, pixelSize: number, - unit: string = "nm", + unit: string = "Å", imageWidth: number, // Original image width in pixels imageHeight: number // Original image height in pixels ) { @@ -349,8 +336,6 @@ function drawRoiOverlayHiDPI( ctx.stroke(); }; - const HANDLE_ANGLE = 0.707; // cos(45°) - if (roiMode === "circle" && radius > 0) { const screenRadius = radius * zoom * displayScale; @@ -368,7 +353,7 @@ function drawRoiOverlayHiDPI( drawCenterCrosshair(); // Resize handle at 45° - const handleOffset = screenRadius * HANDLE_ANGLE; + const handleOffset = screenRadius * CIRCLE_HANDLE_ANGLE; drawResizeHandle(screenX + handleOffset, screenY + handleOffset); } else if (roiMode === "square" && radius > 0) { @@ -434,11 +419,11 @@ function drawRoiOverlayHiDPI( drawCenterCrosshair(); // Outer handle - const handleOffsetOuter = screenRadiusOuter * HANDLE_ANGLE; + const handleOffsetOuter = screenRadiusOuter * CIRCLE_HANDLE_ANGLE; drawResizeHandle(screenX + handleOffsetOuter, screenY + handleOffsetOuter); // Inner handle - const handleOffsetInner = screenRadiusInner * HANDLE_ANGLE; + const handleOffsetInner = screenRadiusInner * CIRCLE_HANDLE_ANGLE; drawResizeHandle(screenX + handleOffsetInner, screenY + handleOffsetInner, true); } @@ -464,7 +449,7 @@ function Show4DSTEM() { const [, setRoiActive] = useModelState("roi_active"); const [pixelSize] = useModelState("pixel_size"); - const [detPixelSize] = useModelState("det_pixel_size"); + const [kPixelSize] = useModelState("k_pixel_size"); const [frameBytes] = useModelState("frame_bytes"); const [virtualImageBytes] = useModelState("virtual_image_bytes"); @@ -478,9 +463,6 @@ function Show4DSTEM() { // Display options const [logScale, setLogScale] = useModelState("log_scale"); - const [autoRange, setAutoRange] = useModelState("auto_range"); - const [percentileLow, setPercentileLow] = useModelState("percentile_low"); - const [percentileHigh, setPercentileHigh] = useModelState("percentile_high"); // Detector calibration (for presets) const [bfRadius] = useModelState("bf_radius"); @@ -717,22 +699,11 @@ function Show4DSTEM() { const height = shapeX; const renderData = (filtered: Float32Array) => { - // Normalize and render (with optional percentile contrast) + // Normalize and render let min = Infinity, max = -Infinity; - - if (autoRange) { - // Percentile-based contrast: sort a sample and pick percentile values - const sorted = Float32Array.from(filtered).sort((a, b) => a - b); - const lowIdx = Math.floor((percentileLow / 100) * sorted.length); - const highIdx = Math.floor((percentileHigh / 100) * sorted.length) - 1; - min = sorted[Math.max(0, lowIdx)]; - max = sorted[Math.min(sorted.length - 1, highIdx)]; - } else { - // Full range - for (let i = 0; i < filtered.length; i++) { - if (filtered[i] < min) min = filtered[i]; - if (filtered[i] > max) max = filtered[i]; - } + for (let i = 0; i < filtered.length; i++) { + if (filtered[i] < min) min = filtered[i]; + if (filtered[i] > max) max = filtered[i]; } const lut = COLORMAPS[colormap] || COLORMAPS.inferno; @@ -809,7 +780,7 @@ function Show4DSTEM() { } else { renderData(rawVirtualImageRef.current); } - }, [virtualImageBytes, shapeX, shapeY, colormap, viZoom, viPanX, viPanY, bpInner, bpOuter, gpuReady, autoRange, percentileLow, percentileHigh]); + }, [virtualImageBytes, shapeX, shapeY, colormap, viZoom, viPanX, viPanY, bpInner, bpOuter, gpuReady]); // Render virtual image overlay (just clear - crosshair drawn on high-DPI UI canvas) React.useEffect(() => { @@ -950,7 +921,7 @@ function Show4DSTEM() { React.useEffect(() => { if (!dpUiRef.current) return; // Draw scale bar first (clears canvas) - drawScaleBarHiDPI(dpUiRef.current, DPR, dpZoom, detPixelSize || 1, "mrad", detY, detX); + drawScaleBarHiDPI(dpUiRef.current, DPR, dpZoom, kPixelSize || 1, "mrad", detY, detX); // Draw ROI overlay (circle, square, rect, annular) or point crosshair if (roiMode === "point") { drawDpCrosshairHiDPI(dpUiRef.current, DPR, localKx, localKy, dpZoom, dpPanX, dpPanY, detY, detX, isDraggingDP); @@ -962,13 +933,13 @@ function Show4DSTEM() { isDraggingDP, isDraggingResize, isDraggingResizeInner, isHoveringResize, isHoveringResizeInner ); } - }, [dpZoom, dpPanX, dpPanY, detPixelSize, detX, detY, roiMode, roiRadius, roiRadiusInner, roiWidth, roiHeight, localKx, localKy, isDraggingDP, isDraggingResize, isDraggingResizeInner, isHoveringResize, isHoveringResizeInner]); + }, [dpZoom, dpPanX, dpPanY, kPixelSize, detX, detY, roiMode, roiRadius, roiRadiusInner, roiWidth, roiHeight, localKx, localKy, isDraggingDP, isDraggingResize, isDraggingResizeInner, isHoveringResize, isHoveringResizeInner]); // VI scale bar + crosshair (high-DPI) React.useEffect(() => { if (!viUiRef.current) return; // Draw scale bar first (clears canvas) - drawScaleBarHiDPI(viUiRef.current, DPR, viZoom, pixelSize || 1, "nm", shapeY, shapeX); + drawScaleBarHiDPI(viUiRef.current, DPR, viZoom, pixelSize || 1, "Å", shapeY, shapeX); // Then draw crosshair on top drawViCrosshairHiDPI(viUiRef.current, DPR, localPosX, localPosY, viZoom, viPanX, viPanY, shapeY, shapeX, isDraggingVI); }, [viZoom, viPanX, viPanY, pixelSize, shapeX, shapeY, localPosX, localPosY, isDraggingVI]); @@ -1258,16 +1229,13 @@ function Show4DSTEM() { display: { colormap: colormap, log_scale: logScale, - auto_range: autoRange, - percentile_low: percentileLow, - percentile_high: percentileHigh, }, calibration: { bf_radius: bfRadius, center_x: centerX, center_y: centerY, pixel_size: pixelSize, - det_pixel_size: detPixelSize, + k_pixel_size: kPixelSize, }, }; zip.file("metadata.json", JSON.stringify(metadata, null, 2)); @@ -1529,37 +1497,6 @@ function Show4DSTEM() { sx={switchStyles.medium} /> - - - Contrast: - setAutoRange(e.target.checked)} - size="small" - sx={switchStyles.small} - /> - {autoRange && ( - <> - { - const [low, high] = v as number[]; - setPercentileLow(low); - setPercentileHigh(high); - }} - min={0} - max={100} - size="small" - sx={{ width: 80 }} - valueLabelDisplay="auto" - valueLabelFormat={(v) => `${v}%`} - /> - - {Math.round(percentileLow || 1)}-{Math.round(percentileHigh || 99)}% - - - )} - diff --git a/widget/src/quantem/widget/__init__.py b/widget/src/quantem/widget/__init__.py index bc47cc42..d4c6c040 100644 --- a/widget/src/quantem/widget/__init__.py +++ b/widget/src/quantem/widget/__init__.py @@ -14,4 +14,4 @@ # Alias for convenience Show4D = Show4DSTEM -__all__ = ["Show4DSTEM", "Show4D"] +__all__ = ["Show4DSTEM"] diff --git a/widget/src/quantem/widget/show4dstem.py b/widget/src/quantem/widget/show4dstem.py index a7321445..66aad1a9 100644 --- a/widget/src/quantem/widget/show4dstem.py +++ b/widget/src/quantem/widget/show4dstem.py @@ -6,28 +6,22 @@ - Live statistics panel (mean/max/min) - Virtual detector overlays (BF/ADF circles) - Linked scan view (side-by-side) -- Auto-range with percentile scaling - ROI drawing tools +- Path animation (raster scan, custom paths) """ import pathlib -from collections.abc import Callable -from typing import TYPE_CHECKING import anywidget import numpy as np import traitlets from quantem.widget.array_utils import to_numpy +from quantem.core.datastructures import Dataset4dstem -if TYPE_CHECKING: - from quantem.detector import Detector - -# Detector geometry constants (ratios of detector size) +# Detector geometry constant DEFAULT_BF_RATIO = 0.125 # 1/8 of detector size -DEFAULT_ADF_INNER_RATIO = 0.1875 # 1.5 * BF = 3/16 of detector -DEFAULT_ADF_OUTER_RATIO = 0.375 # 3 * BF = 3/8 of detector class Show4DSTEM(anywidget.AnyWidget): @@ -39,58 +33,40 @@ class Show4DSTEM(anywidget.AnyWidget): Parameters ---------- - data : array_like - 4D array of shape (scan_x, scan_y, det_x, det_y). - Supports NumPy and PyTorch arrays. + data : Dataset4dstem or array_like + Dataset4dstem object (calibration auto-extracted) or 4D array + of shape (scan_x, scan_y, det_x, det_y). scan_shape : tuple, optional If data is flattened (N, det_x, det_y), provide scan dimensions. - detector : Detector, optional - Detector object from quantem.detector for automatic calibration. - If provided, center and bf_radius are extracted from the detector. - pixel_size : float, default 1.0 - Pixel size in nm (real-space). Used for scale bar. - det_pixel_size : float, default 1.0 + pixel_size : float, optional + Pixel size in Å (real-space). Used for scale bar. + Auto-extracted from Dataset4dstem if not provided. + k_pixel_size : float, optional Detector pixel size in mrad (k-space). Used for scale bar. + Auto-extracted from Dataset4dstem if not provided. center : tuple[float, float], optional (center_x, center_y) of the diffraction pattern in pixels. If not provided, defaults to detector center. bf_radius : float, optional Bright field disk radius in pixels. If not provided, estimated as 1/8 of detector size. - log_scale : bool, default True + log_scale : bool, default False Use log scale for better dynamic range visualization. - auto_range : bool, default False - Use percentile-based scaling instead of global min/max. - percentile_low : float, default 1.0 - Lower percentile for auto-range (0-100). - percentile_high : float, default 99.0 - Upper percentile for auto-range (0-100). - path_points : list[tuple[int, int]], optional - List of (x, y) scan positions for programmatic animation. - Use with play(), pause(), stop() methods. - path_interval_ms : int, default 100 - Time between frames in path animation (milliseconds). - path_loop : bool, default True - Whether to loop when path animation reaches the end. Examples -------- - >>> from quantem.widget import Show4DSTEM + >>> # From Dataset4dstem (calibration auto-extracted) + >>> from quantem.core.io.file_readers import read_emdfile_to_4dstem + >>> dataset = read_emdfile_to_4dstem("data.h5") + >>> Show4DSTEM(dataset) + + >>> # From raw array with manual calibration >>> import numpy as np >>> data = np.random.rand(64, 64, 128, 128) - >>> Show4DSTEM(data) - - >>> # With manual calibration - >>> Show4DSTEM(data, pixel_size=0.5, det_pixel_size=0.1, bf_radius=20) + >>> Show4DSTEM(data, pixel_size=2.39, k_pixel_size=0.46) - >>> # With Detector object (optional dependency) - >>> from quantem.detector import Detector - >>> det = Detector("data.h5") - >>> Show4DSTEM(det.data, detector=det) # Uses det.center, det.bf_radius - - >>> # With path animation - >>> path = [(i, i) for i in range(64)] # Diagonal path - >>> widget = Show4DSTEM(data, path_points=path, path_interval_ms=50) - >>> widget.play() # Start animation + >>> # With raster animation + >>> widget = Show4DSTEM(dataset) + >>> widget.raster(step=2, interval_ms=50) """ _esm = pathlib.Path(__file__).parent / "static" / "show4dstem.js" @@ -112,66 +88,31 @@ class Show4DSTEM(anywidget.AnyWidget): frame_bytes = traitlets.Bytes(b"").tag(sync=True) # Log scale toggle - log_scale = traitlets.Bool(True).tag(sync=True) + log_scale = traitlets.Bool(False).tag(sync=True) # ========================================================================= - # Stats Panel + # Detector Calibration (for presets and scale bar) # ========================================================================= - stats_mean = traitlets.Float(0.0).tag(sync=True) - stats_max = traitlets.Float(0.0).tag(sync=True) - stats_min = traitlets.Float(0.0).tag(sync=True) - show_stats = traitlets.Bool(True).tag(sync=True) - - # ========================================================================= - # Detector Integration (BF/ADF overlays) - # ========================================================================= - has_detector = traitlets.Bool(False).tag(sync=True) - center_x = traitlets.Float(0.0).tag(sync=True) - center_y = traitlets.Float(0.0).tag(sync=True) - bf_radius = traitlets.Float(0.0).tag(sync=True) - show_bf_overlay = traitlets.Bool(True).tag(sync=True) - show_adf_overlay = traitlets.Bool(False).tag(sync=True) - adf_inner_radius = traitlets.Float(0.0).tag(sync=True) - adf_outer_radius = traitlets.Float(0.0).tag(sync=True) - - # ========================================================================= - # Linked Scan View - # ========================================================================= - show_scan_view = traitlets.Bool(False).tag(sync=True) - scan_mode = traitlets.Unicode("bf").tag(sync=True) # 'bf', 'adf', 'custom' - scan_image_bytes = traitlets.Bytes(b"").tag(sync=True) - - # ========================================================================= - # Auto-Range (percentile scaling) - # ========================================================================= - auto_range = traitlets.Bool(False).tag(sync=True) - percentile_low = traitlets.Float(1.0).tag(sync=True) - percentile_high = traitlets.Float(99.0).tag(sync=True) + center_x = traitlets.Float(0.0).tag(sync=True) # Detector center X + center_y = traitlets.Float(0.0).tag(sync=True) # Detector center Y + bf_radius = traitlets.Float(0.0).tag(sync=True) # BF disk radius (pixels) # ========================================================================= # ROI Drawing (for virtual imaging) + # roi_radius is multi-purpose by mode: + # - circle: radius of circle + # - square: half-size (distance from center to edge) + # - annular: outer radius (roi_radius_inner = inner radius) + # - rect: uses roi_width/roi_height instead # ========================================================================= roi_active = traitlets.Bool(False).tag(sync=True) - roi_mode = traitlets.Unicode("point").tag(sync=True) # 'point', 'circle', 'square', 'rect', or 'annular' + roi_mode = traitlets.Unicode("point").tag(sync=True) roi_center_x = traitlets.Float(0.0).tag(sync=True) roi_center_y = traitlets.Float(0.0).tag(sync=True) - roi_radius = traitlets.Float(10.0).tag(sync=True) # Outer radius for circle/annular, half-width for square - roi_radius_inner = traitlets.Float(5.0).tag(sync=True) # Inner radius for annular mode - roi_width = traitlets.Float(20.0).tag(sync=True) # Width for rectangular mode - roi_height = traitlets.Float(10.0).tag(sync=True) # Height for rectangular mode - roi_integrated_value = traitlets.Float(0.0).tag(sync=True) - - # ========================================================================= - # Mean Diffraction Pattern - # ========================================================================= - mean_dp_bytes = traitlets.Bytes(b"").tag(sync=True) - show_mean_dp = traitlets.Bool(True).tag(sync=True) - - # ========================================================================= - # BF Image (Bright Field integrated image) - # ========================================================================= - bf_image_bytes = traitlets.Bytes(b"").tag(sync=True) - show_bf_image = traitlets.Bool(True).tag(sync=True) + roi_radius = traitlets.Float(10.0).tag(sync=True) + roi_radius_inner = traitlets.Float(5.0).tag(sync=True) + roi_width = traitlets.Float(20.0).tag(sync=True) + roi_height = traitlets.Float(10.0).tag(sync=True) # ========================================================================= # Virtual Image (ROI-based, updates as you drag ROI on DP) @@ -181,8 +122,8 @@ class Show4DSTEM(anywidget.AnyWidget): # ========================================================================= # Scale Bar # ========================================================================= - pixel_size = traitlets.Float(1.0).tag(sync=True) # nm per pixel (real-space) - det_pixel_size = traitlets.Float(1.0).tag(sync=True) # mrad per pixel (k-space) + pixel_size = traitlets.Float(1.0).tag(sync=True) # Å per pixel (real-space) + k_pixel_size = traitlets.Float(1.0).tag(sync=True) # mrad per pixel (k-space) # ========================================================================= # Path Animation (programmatic crosshair control) @@ -195,43 +136,36 @@ class Show4DSTEM(anywidget.AnyWidget): def __init__( self, - data, + data: "Dataset4dstem | np.ndarray", scan_shape: tuple[int, int] | None = None, - detector: "Detector | None" = None, - pixel_size: float = 1.0, - det_pixel_size: float = 1.0, + pixel_size: float | None = None, + k_pixel_size: float | None = None, center: tuple[float, float] | None = None, bf_radius: float | None = None, - log_scale: bool = True, - auto_range: bool = False, - percentile_low: float = 1.0, - percentile_high: float = 99.0, - path_points: list[tuple[int, int]] | None = None, - path_interval_ms: int = 100, - path_loop: bool = True, + log_scale: bool = False, **kwargs, ): super().__init__(**kwargs) - self._log_scale = log_scale self.log_scale = log_scale - self.auto_range = auto_range - self.percentile_low = percentile_low - self.percentile_high = percentile_high - - # Store calibration values - self.pixel_size = pixel_size - self.det_pixel_size = det_pixel_size - - # Path animation settings - self._path_points: list[tuple[int, int]] = path_points or [] - self.path_length = len(self._path_points) - self.path_interval_ms = path_interval_ms - self.path_loop = path_loop + # Extract calibration from Dataset4dstem if provided + if hasattr(data, "sampling") and hasattr(data, "array"): + # Dataset4dstem: extract calibration and array + # sampling = [scan_x, scan_y, det_x, det_y] + if pixel_size is None: + pixel_size = float(data.sampling[0]) + if k_pixel_size is None: + k_pixel_size = float(data.sampling[2]) + data = data.array + + # Store calibration values (default to 1.0 if not provided) + self.pixel_size = pixel_size if pixel_size is not None else 1.0 + self.k_pixel_size = k_pixel_size if k_pixel_size is not None else 1.0 + # Path animation (configured via set_path() or raster()) + self._path_points: list[tuple[int, int]] = [] # Convert to NumPy self._data = to_numpy(data) - # Handle flattened data if data.ndim == 3: if scan_shape is not None: @@ -257,61 +191,34 @@ def __init__( self.shape_y = self._scan_shape[1] self.det_x = self._det_shape[0] self.det_y = self._det_shape[1] - # Initial position at center self.pos_x = self.shape_x // 2 self.pos_y = self.shape_y // 2 - # Precompute global range for consistent scaling self._compute_global_range() - - # Setup center and BF radius: detector > user params > defaults - if detector is not None: - # Use Detector object for calibration - det_center = detector.center - self.center_x = float(det_center[0]) - self.center_y = float(det_center[1]) - self.bf_radius = float(detector.bf_radius) - elif center is not None: - # Use user-provided center + # Setup center and BF/ADF radii based on detector size + det_size = min(self.det_x, self.det_y) + if center is not None: self.center_x = float(center[0]) self.center_y = float(center[1]) - det_size = min(self.det_x, self.det_y) - self.bf_radius = float(bf_radius) if bf_radius is not None else det_size * DEFAULT_BF_RATIO else: - # Default to detector center self.center_x = float(self.det_y / 2) self.center_y = float(self.det_x / 2) - det_size = min(self.det_x, self.det_y) - self.bf_radius = float(bf_radius) if bf_radius is not None else det_size * DEFAULT_BF_RATIO - - # Set ADF defaults based on detector size - det_size = min(self.det_x, self.det_y) - self.adf_inner_radius = det_size * DEFAULT_ADF_INNER_RATIO - self.adf_outer_radius = det_size * DEFAULT_ADF_OUTER_RATIO - self.has_detector = True + self.bf_radius = float(bf_radius) if bf_radius is not None else det_size * DEFAULT_BF_RATIO - # Compute mean DP and BF image (sent once on init) - self._compute_mean_dp() - self._compute_bf_image() - # Pre-compute and cache common virtual images (BF, ABF, LAADF, HAADF) self._cached_bf_virtual = None self._cached_abf_virtual = None self._cached_laadf_virtual = None self._cached_haadf_virtual = None self._precompute_common_virtual_images() - + # Update frame when position or settings change - self.observe(self._update_frame, names=[ - "pos_x", "pos_y", "log_scale", "auto_range", - "percentile_low", "percentile_high" - ]) + self.observe(self._update_frame, names=["pos_x", "pos_y", "log_scale"]) self.observe(self._on_roi_change, names=[ - "roi_center_x", "roi_center_y", "roi_radius", "roi_radius_inner", + "roi_center_x", "roi_center_y", "roi_radius", "roi_radius_inner", "roi_active", "roi_mode", "roi_width", "roi_height" ]) - self.observe(self._on_scan_mode_change, names=["scan_mode", "show_scan_view"]) # Initialize default ROI at BF center self.roi_center_x = self.center_x @@ -330,19 +237,36 @@ def __init__( # Path animation: observe index changes from frontend self.observe(self._on_path_index_change, names=["path_index"]) + def __repr__(self) -> str: + return ( + f"Show4DSTEM(shape=({self.shape_x}, {self.shape_y}, {self.det_x}, {self.det_y}), " + f"sampling=({self.pixel_size} Å, {self.k_pixel_size} mrad), " + f"pos=({self.pos_x}, {self.pos_y}))" + ) + # ========================================================================= - # Array Utilities + # Convenience Properties # ========================================================================= - - @staticmethod - def _to_cpu(arr): - """Convert array to CPU (NumPy).""" - return np.asarray(arr) - - @staticmethod - def _to_scalar(val): - """Convert scalar value to Python float.""" - return float(val) + + @property + def position(self) -> tuple[int, int]: + """Current scan position as (x, y) tuple.""" + return (self.pos_x, self.pos_y) + + @position.setter + def position(self, value: tuple[int, int]) -> None: + """Set scan position from (x, y) tuple.""" + self.pos_x, self.pos_y = value + + @property + def scan_shape(self) -> tuple[int, int]: + """Scan dimensions as (shape_x, shape_y) tuple.""" + return (self.shape_x, self.shape_y) + + @property + def detector_shape(self) -> tuple[int, int]: + """Detector dimensions as (det_x, det_y) tuple.""" + return (self.det_x, self.det_y) # ========================================================================= # Path Animation Methods @@ -350,81 +274,42 @@ def _to_scalar(val): def set_path( self, - points: list[tuple[int, int]] | None = None, - generator: "Callable[[int, int, int], tuple[int, int]] | None" = None, - n_frames: int | None = None, - interval_ms: int | None = None, - loop: bool | None = None, + points: list[tuple[int, int]], + interval_ms: int = 100, + loop: bool = True, autoplay: bool = True, ) -> "Show4DSTEM": """ - Set a path of scan positions to animate through. - - You can provide either a list of points OR a generator function. - + Set a custom path of scan positions to animate through. + Parameters ---------- - points : list[tuple[int, int]], optional + points : list[tuple[int, int]] List of (x, y) scan positions to visit. - generator : callable, optional - Custom function with signature `f(index, shape_x, shape_y) -> (x, y)`. - Called for each frame to get the next position. - n_frames : int, optional - Number of frames when using generator. Required if using generator. - interval_ms : int, optional - Time between frames in milliseconds. Default 100ms. - loop : bool, optional - Whether to loop when reaching end. Default True. + interval_ms : int, default 100 + Time between frames in milliseconds. + loop : bool, default True + Whether to loop when reaching end. autoplay : bool, default True Start playing immediately. - + Returns ------- Show4DSTEM Self for method chaining. - + Examples -------- - >>> # Option 1: List of points - >>> path = [(0, 0), (10, 10), (20, 20), (30, 30)] - >>> widget.set_path(points=path) - - >>> # Option 2: Custom generator function - >>> def my_path(i, sx, sy): - ... # Random walk - ... import random - ... return (random.randint(0, sx-1), random.randint(0, sy-1)) - >>> widget.set_path(generator=my_path, n_frames=100) - - >>> # Option 3: Lambda for quick patterns - >>> widget.set_path( - ... generator=lambda i, sx, sy: (i % sx, (i * 3) % sy), - ... n_frames=200 - ... ) + >>> widget.set_path([(0, 0), (10, 10), (20, 20), (30, 30)]) + >>> widget.set_path([(i, i) for i in range(48)], interval_ms=50) """ - if generator is not None: - # Use generator function to create points - if n_frames is None: - n_frames = 100 # Default - self._path_points = [ - generator(i, self.shape_x, self.shape_y) - for i in range(n_frames) - ] - elif points is not None: - self._path_points = list(points) - else: - raise ValueError("Must provide either 'points' or 'generator'") - + self._path_points = list(points) self.path_length = len(self._path_points) self.path_index = 0 - - if interval_ms is not None: - self.path_interval_ms = interval_ms - if loop is not None: - self.path_loop = loop + self.path_interval_ms = interval_ms + self.path_loop = loop if autoplay and self.path_length > 0: self.path_playing = True - return self def play(self) -> "Show4DSTEM": @@ -462,13 +347,19 @@ def _on_path_index_change(self, change): # ========================================================================= # Path Animation Patterns # ========================================================================= - - def play_raster(self, step: int = 1, bidirectional: bool = False) -> "Show4DSTEM": + + def raster( + self, + step: int = 1, + bidirectional: bool = False, + interval_ms: int = 100, + loop: bool = True, + ) -> "Show4DSTEM": """ Play a raster scan path (row by row, left to right). - + This mimics real STEM scanning: left→right, step down, left→right, etc. - + Parameters ---------- step : int, default 1 @@ -476,7 +367,11 @@ def play_raster(self, step: int = 1, bidirectional: bool = False) -> "Show4DSTEM bidirectional : bool, default False If True, use snake/boustrophedon pattern (alternating direction). If False (default), always scan left→right like real STEM. - + interval_ms : int, default 100 + Time between frames in milliseconds. + loop : bool, default True + Whether to loop when reaching the end. + Returns ------- Show4DSTEM @@ -489,13 +384,13 @@ def play_raster(self, step: int = 1, bidirectional: bool = False) -> "Show4DSTEM row = row[::-1] # Alternate direction for snake pattern for y in row: points.append((x, y)) - return self.set_path(points=points) + return self.set_path(points=points, interval_ms=interval_ms, loop=loop) # ========================================================================= # ROI Mode Methods # ========================================================================= - def set_roi_circle(self, radius: float | None = None) -> "Show4DSTEM": + def roi_circle(self, radius: float | None = None) -> "Show4DSTEM": """ Switch to circle ROI mode for virtual imaging. @@ -515,15 +410,15 @@ def set_roi_circle(self, radius: float | None = None) -> "Show4DSTEM": Examples -------- - >>> widget.set_roi_circle(20) # 20px radius circle - >>> widget.set_roi_circle() # Use default radius + >>> widget.roi_circle(20) # 20px radius circle + >>> widget.roi_circle() # Use default radius """ self.roi_mode = "circle" if radius is not None: self.roi_radius = float(radius) return self - def set_roi_point(self) -> "Show4DSTEM": + def roi_point(self) -> "Show4DSTEM": """ Switch to point ROI mode (single-pixel indexing). @@ -538,35 +433,36 @@ def set_roi_point(self) -> "Show4DSTEM": self.roi_mode = "point" return self - def set_roi_square(self, size: float | None = None) -> "Show4DSTEM": + def roi_square(self, half_size: float | None = None) -> "Show4DSTEM": """ Switch to square ROI mode for virtual imaging. - + In square mode, the virtual image integrates over a square region centered at the current ROI position. - + Parameters ---------- - size : float, optional + half_size : float, optional Half-size of the square in pixels (distance from center to edge). + A half_size of 15 creates a 30x30 pixel square. If not provided, uses current roi_radius value. - + Returns ------- Show4DSTEM Self for method chaining. - + Examples -------- - >>> widget.set_roi_square(15) # 30x30 pixel square - >>> widget.set_roi_square() # Use default size + >>> widget.roi_square(15) # 30x30 pixel square (half_size=15) + >>> widget.roi_square() # Use default size """ self.roi_mode = "square" - if size is not None: - self.roi_radius = float(size) + if half_size is not None: + self.roi_radius = float(half_size) return self - def set_roi_annular( + def roi_annular( self, inner_radius: float | None = None, outer_radius: float | None = None ) -> "Show4DSTEM": """ @@ -586,8 +482,8 @@ def set_roi_annular( Examples -------- - >>> widget.set_roi_annular(20, 50) # ADF: inner=20px, outer=50px - >>> widget.set_roi_annular(30, 80) # HAADF: larger angles + >>> widget.roi_annular(20, 50) # ADF: inner=20px, outer=50px + >>> widget.roi_annular(30, 80) # HAADF: larger angles """ self.roi_mode = "annular" if inner_radius is not None: @@ -596,7 +492,7 @@ def set_roi_annular( self.roi_radius = float(outer_radius) return self - def set_roi_rect( + def roi_rect( self, width: float | None = None, height: float | None = None ) -> "Show4DSTEM": """ @@ -616,8 +512,8 @@ def set_roi_rect( Examples -------- - >>> widget.set_roi_rect(30, 20) # 30px wide, 20px tall - >>> widget.set_roi_rect(40, 40) # 40x40 rectangle + >>> widget.roi_rect(30, 20) # 30px wide, 20px tall + >>> widget.roi_rect(40, 40) # 40x40 rectangle """ self.roi_mode = "rect" if width is not None: @@ -639,26 +535,19 @@ def _compute_global_range(self): ] all_min, all_max = float("inf"), float("-inf") - all_values = [] for x, y in samples: frame = self._get_frame(x, y) fmin = float(frame.min()) fmax = float(frame.max()) all_min = min(all_min, fmin) all_max = max(all_max, fmax) - - # Sample values for percentile estimation - all_values.append(self._to_cpu(frame).flatten()[::100]) - + self._global_min = max(all_min, 1e-10) self._global_max = all_max - + # Precompute log range self._log_min = np.log1p(self._global_min) self._log_max = np.log1p(self._global_max) - - # Store sampled values for percentile computation - self._sampled_values = np.concatenate(all_values) def _get_frame(self, x: int, y: int): """Get single diffraction frame at position (x, y).""" @@ -668,36 +557,15 @@ def _get_frame(self, x: int, y: int): else: return self._data[x, y] - def _compute_percentile_range(self, frame): - """Compute percentile-based range for a frame.""" - - # Use NumPy for percentile (faster for small arrays) - frame_np = self._to_cpu(frame).flatten() - - vmin = float(np.percentile(frame_np, self.percentile_low)) - vmax = float(np.percentile(frame_np, self.percentile_high)) - return max(vmin, 1e-10), vmax - def _update_frame(self, change=None): - """Send pre-normalized uint8 frame to frontend with stats.""" + """Send pre-normalized uint8 frame to frontend.""" frame = self._get_frame(self.pos_x, self.pos_y) - # Compute stats - self.stats_mean = self._to_scalar(frame.mean()) - self.stats_max = self._to_scalar(frame.max()) - self.stats_min = self._to_scalar(frame.min()) - # Determine value range - if self.auto_range: - vmin, vmax = self._compute_percentile_range(frame) - if self.log_scale: - vmin = np.log1p(vmin) - vmax = np.log1p(vmax) + if self.log_scale: + vmin, vmax = self._log_min, self._log_max else: - if self.log_scale: - vmin, vmax = self._log_min, self._log_max - else: - vmin, vmax = self._global_min, self._global_max + vmin, vmax = self._global_min, self._global_max # Apply log scale if enabled if self.log_scale: @@ -716,107 +584,11 @@ def _update_frame(self, change=None): self.frame_bytes = normalized.tobytes() def _on_roi_change(self, change=None): - """Compute integrated value when ROI changes.""" - # Skip if ROI is not active or has no valid size + """Recompute virtual image when ROI changes.""" if not self.roi_active: - self.roi_integrated_value = 0.0 - return - - # For circle/square/annular modes, need positive radius - if self.roi_mode in ("circle", "square", "annular") and self.roi_radius <= 0: - self.roi_integrated_value = 0.0 - return - - frame = self._get_frame(self.pos_x, self.pos_y) - - # Create mask based on ROI mode - if self.roi_mode == "circle" and self.roi_radius > 0: - mask = self._create_circular_mask( - self.roi_center_x, self.roi_center_y, self.roi_radius - ) - elif self.roi_mode == "square" and self.roi_radius > 0: - mask = self._create_square_mask( - self.roi_center_x, self.roi_center_y, self.roi_radius - ) - elif self.roi_mode == "annular" and self.roi_radius > 0: - mask = self._create_annular_mask( - self.roi_center_x, self.roi_center_y, - self.roi_radius_inner, self.roi_radius - ) - elif self.roi_mode == "rect" and self.roi_width > 0 and self.roi_height > 0: - mask = self._create_rect_mask( - self.roi_center_x, self.roi_center_y, - self.roi_width / 2, self.roi_height / 2 - ) - else: - # Point mode: no mask, just single pixel - self.roi_integrated_value = 0.0 - self._compute_virtual_image_from_roi() - return - - # Compute integrated value (use multiplication to avoid indexing issues) - integrated = self._to_scalar((frame * mask).sum()) - self.roi_integrated_value = integrated - - # Fast path: check if we can use cached preset - cached = self._get_cached_preset() - if cached is not None: - self.virtual_image_bytes = cached return - - # Real-time update using fast masked sum self._compute_virtual_image_from_roi() - def _on_scan_mode_change(self, change=None): - """Recompute scan image when mode changes.""" - if self.show_scan_view and self.has_detector: - self._compute_scan_image() - - def _compute_scan_image(self): - """Compute virtual detector image (BF or ADF).""" - - # Get appropriate mask - if self.scan_mode == "bf": - mask = self._create_circular_mask( - self.center_x, self.center_y, self.bf_radius - ) - elif self.scan_mode == "adf": - mask = self._create_annular_mask( - self.center_x, self.center_y, - self.adf_inner_radius, self.adf_outer_radius - ) - else: - # Custom ROI mask - if self.roi_active and self.roi_radius > 0: - mask = self._create_circular_mask( - self.roi_center_x, self.roi_center_y, self.roi_radius - ) - else: - return - - # Compute integrated image - if self._data.ndim == 4: - # (Rx, Ry, Qx, Qy) -> Apply mask and sum - scan_image = (self._data * mask).sum(axis=(-2, -1)) - else: - # (N, Qx, Qy) -> reshape and sum - scan_image = (self._data * mask).sum(axis=(-2, -1)) - scan_image = scan_image.reshape(self._scan_shape) - - # Normalize to uint8 - smin = self._to_scalar(scan_image.min()) - smax = self._to_scalar(scan_image.max()) - - if smax > smin: - normalized = np.clip((scan_image - smin) / (smax - smin) * 255, 0, 255) - normalized = normalized.astype(np.uint8) - else: - normalized = np.zeros(scan_image.shape, dtype=np.uint8) - - normalized = self._to_cpu(normalized) - - self.scan_image_bytes = normalized.tobytes() - def _create_circular_mask(self, cx: float, cy: float, radius: float): """Create circular mask (boolean).""" y, x = np.ogrid[:self.det_x, :self.det_y] @@ -844,64 +616,19 @@ def _create_rect_mask(self, cx: float, cy: float, half_width: float, half_height mask = (np.abs(x - cx) <= half_width) & (np.abs(y - cy) <= half_height) return mask - def _compute_mean_dp(self): - """Compute and send mean diffraction pattern.""" - if self._data.ndim == 4: - mean_dp = self._data.mean(axis=(0, 1)) - else: - mean_dp = self._data.mean(axis=0) - # Log scale - mean_dp = np.log1p(mean_dp) - # Normalize to uint8 - mean_dp_cpu = self._to_cpu(mean_dp) - vmin, vmax = float(mean_dp_cpu.min()), float(mean_dp_cpu.max()) - if vmax > vmin: - normalized = np.clip((mean_dp_cpu - vmin) / (vmax - vmin) * 255, 0, 255).astype(np.uint8) - else: - normalized = np.zeros(mean_dp_cpu.shape, dtype=np.uint8) - self.mean_dp_bytes = normalized.tobytes() - - def _compute_bf_image(self): - """Compute BF integrated image using detected probe.""" - - # Create BF mask - mask = self._create_circular_mask(self.center_x, self.center_y, self.bf_radius) - - # Compute integrated BF image - if self._data.ndim == 4: - bf_image = (self._data * mask).sum(axis=(-2, -1)) - else: - bf_image = (self._data * mask).sum(axis=(-2, -1)) - bf_image = bf_image.reshape(self._scan_shape) - - # Normalize to uint8 - vmin = self._to_scalar(bf_image.min()) - vmax = self._to_scalar(bf_image.max()) - - if vmax > vmin: - normalized = np.clip((bf_image - vmin) / (vmax - vmin) * 255, 0, 255) - normalized = normalized.astype(np.uint8) - else: - normalized = np.zeros(bf_image.shape, dtype=np.uint8) - - normalized = self._to_cpu(normalized) - - self.bf_image_bytes = normalized.tobytes() - def _precompute_common_virtual_images(self): """Pre-compute BF/ABF/LAADF/HAADF virtual images for instant mode switching.""" - def _compute_and_normalize(mask): if self._data.ndim == 4: img = (self._data * mask).sum(axis=(-2, -1)) else: img = (self._data * mask).sum(axis=(-2, -1)).reshape(self._scan_shape) - vmin, vmax = self._to_scalar(img.min()), self._to_scalar(img.max()) + vmin, vmax = float(img.min()), float(img.max()) if vmax > vmin: norm = np.clip((img - vmin) / (vmax - vmin) * 255, 0, 255).astype(np.uint8) else: norm = np.zeros(img.shape, dtype=np.uint8) - return self._to_cpu(norm).tobytes() + return norm.tobytes() # BF: circle at bf_radius bf_mask = self._create_circular_mask(self.center_x, self.center_y, self.bf_radius) @@ -1001,25 +728,20 @@ def _compute_virtual_image_from_roi(self): ) virtual_image = self._fast_masked_sum(mask) else: - # Point mode: single-pixel indexing - O(1) on GPU! - kx = int(max(0, min(self._det_shape[0] - 1, round(self.roi_center_y)))) - ky = int(max(0, min(self._det_shape[1] - 1, round(self.roi_center_x)))) - + # Point mode: single-pixel indexing + # Array indexing: [row, col] where row=det_y, col=det_x in image coords + row = int(max(0, min(self._det_shape[0] - 1, round(self.roi_center_y)))) + col = int(max(0, min(self._det_shape[1] - 1, round(self.roi_center_x)))) if self._data.ndim == 4: - virtual_image = self._data[:, :, kx, ky] + virtual_image = self._data[:, :, row, col] else: - virtual_image = self._data[:, kx, ky].reshape(self._scan_shape) - + virtual_image = self._data[:, row, col].reshape(self._scan_shape) + # Normalize to uint8 - vmin = self._to_scalar(virtual_image.min()) - vmax = self._to_scalar(virtual_image.max()) - + vmin, vmax = float(virtual_image.min()), float(virtual_image.max()) if vmax > vmin: normalized = np.clip((virtual_image - vmin) / (vmax - vmin) * 255, 0, 255) normalized = normalized.astype(np.uint8) else: normalized = np.zeros(virtual_image.shape, dtype=np.uint8) - - normalized_cpu = self._to_cpu(normalized) - - self.virtual_image_bytes = normalized_cpu.tobytes() + self.virtual_image_bytes = normalized.tobytes() diff --git a/widget/tests/test_widget.py b/widget/tests/test_widget.py index 2c6b6b81..9a880b76 100644 --- a/widget/tests/test_widget.py +++ b/widget/tests/test_widget.py @@ -17,3 +17,71 @@ def test_show4dstem_loads(): data = np.random.rand(8, 8, 16, 16).astype(np.float32) widget = Show4DSTEM(data) assert widget is not None + + +def test_show4dstem_flattened_scan_shape_mapping(): + data = np.zeros((6, 2, 2), dtype=np.float32) + for idx in range(data.shape[0]): + data[idx] = idx + + widget = Show4DSTEM(data, scan_shape=(2, 3)) + assert (widget.shape_x, widget.shape_y) == (2, 3) + assert (widget.det_x, widget.det_y) == (2, 2) + frame = widget._get_frame(1, 2) + assert np.array_equal(frame, np.full((2, 2), 5, dtype=np.float32)) + + +def test_roi_circle_integrated_value(): + data = np.zeros((1, 1, 5, 5), dtype=np.float32) + rows = np.arange(5, dtype=np.float32)[:, None] + cols = np.arange(5, dtype=np.float32)[None, :] + data[0, 0] = rows * 10 + cols + widget = Show4DSTEM(data, center=(2, 2), bf_radius=1, log_scale=False) + widget.roi_mode = "circle" + widget.roi_center_x = 2 + widget.roi_center_y = 2 + widget.roi_radius = 1 + widget.roi_active = True + widget._on_roi_change() + assert np.isclose(widget.roi_integrated_value, 110.0) + + +def test_scan_image_bf_mode(): + base = np.array( + [ + [1, 2, 3], + [4, 5, 6], + [7, 8, 9], + ], + dtype=np.float32, + ) + data = np.zeros((2, 2, 3, 3), dtype=np.float32) + for x in range(2): + for y in range(2): + data[x, y] = base + x * 100 + y * 10 + + widget = Show4DSTEM(data, center=(1, 1), bf_radius=1, log_scale=False) + widget.scan_mode = "bf" + widget.show_scan_view = True + widget._compute_scan_image() + + actual = np.frombuffer(widget.scan_image_bytes, dtype=np.uint8).reshape(2, 2) + scan_image = np.array([[25, 75], [525, 575]], dtype=np.float32) + expected = np.clip( + (scan_image - scan_image.min()) / (scan_image.max() - scan_image.min()) * 255, + 0, + 255, + ).astype(np.uint8) + assert np.array_equal(actual, expected) + + +def test_log_scale_changes_frame_bytes(): + data = np.array([[[[0, 1], [3, 7]]]], dtype=np.float32) + widget = Show4DSTEM(data, log_scale=True) + log_bytes = bytes(widget.frame_bytes) + + widget.log_scale = False + widget._update_frame() + linear_bytes = bytes(widget.frame_bytes) + + assert log_bytes != linear_bytes From 5874d8944d417b60bea54ec27d28980dac3cf873 Mon Sep 17 00:00:00 2001 From: Sangjoon Bob Lee Date: Mon, 12 Jan 2026 10:30:30 -0800 Subject: [PATCH 05/27] remove extra fft overlay --- widget/js/show4dstem.tsx | 28 +++---------------------- widget/src/quantem/widget/show4dstem.py | 8 +++---- 2 files changed, 7 insertions(+), 29 deletions(-) diff --git a/widget/js/show4dstem.tsx b/widget/js/show4dstem.tsx index 46b799b9..0f9f7bcf 100644 --- a/widget/js/show4dstem.tsx +++ b/widget/js/show4dstem.tsx @@ -34,25 +34,18 @@ function roundToNiceValue(value: number): number { } /** Format scale bar label with appropriate unit */ -function formatScaleLabel(value: number, unit: string): string { +function formatScaleLabel(value: number, unit: "Å" | "mrad"): string { const nice = roundToNiceValue(value); if (unit === "Å") { if (nice >= 10) return `${Math.round(nice / 10)} nm`; if (nice >= 1) return `${Math.round(nice)} Å`; return `${nice.toFixed(2)} Å`; - } else if (unit === "nm") { - if (nice >= 1000) return `${Math.round(nice / 1000)} µm`; - if (nice >= 1) return `${Math.round(nice)} nm`; - return `${nice.toFixed(2)} nm`; - } else if (unit === "mrad") { + } + if (unit === "mrad") { if (nice >= 1000) return `${Math.round(nice / 1000)} rad`; if (nice >= 1) return `${Math.round(nice)} mrad`; return `${nice.toFixed(2)} mrad`; - } else if (unit === "1/µm") { - if (nice >= 1000) return `${Math.round(nice / 1000)} 1/nm`; - if (nice >= 1) return `${Math.round(nice)} 1/µm`; - return `${nice.toFixed(2)} 1/µm`; } return `${Math.round(nice)} ${unit}`; } @@ -607,7 +600,6 @@ function Show4DSTEM() { const viUiRef = React.useRef(null); // High-DPI UI overlay for scale bar const fftCanvasRef = React.useRef(null); const fftOverlayRef = React.useRef(null); - const fftUiRef = React.useRef(null); // High-DPI UI overlay for scale bar // Display size for high-DPI UI overlays const UI_SIZE = 400; @@ -944,13 +936,6 @@ function Show4DSTEM() { drawViCrosshairHiDPI(viUiRef.current, DPR, localPosX, localPosY, viZoom, viPanX, viPanY, shapeY, shapeX, isDraggingVI); }, [viZoom, viPanX, viPanY, pixelSize, shapeX, shapeY, localPosX, localPosY, isDraggingVI]); - // FFT - no scale bar (just clear canvas) - React.useEffect(() => { - if (!fftUiRef.current) return; - const ctx = fftUiRef.current.getContext("2d"); - if (ctx) ctx.clearRect(0, 0, fftUiRef.current.width, fftUiRef.current.height); - }, [fftZoom, shapeX, shapeY]); - // Generic zoom handler const createZoomHandler = ( setZoom: React.Dispatch>, @@ -1313,13 +1298,6 @@ function Show4DSTEM() { onDoubleClick={handleFftDoubleClick} style={{ position: "absolute", width: "100%", height: "100%", cursor: isDraggingFFT ? "grabbing" : "grab" }} /> - {/* High-DPI UI overlay for crisp scale bar */} - diff --git a/widget/src/quantem/widget/show4dstem.py b/widget/src/quantem/widget/show4dstem.py index 66aad1a9..c471be04 100644 --- a/widget/src/quantem/widget/show4dstem.py +++ b/widget/src/quantem/widget/show4dstem.py @@ -3,21 +3,21 @@ Features: - Binary transfer (no base64 overhead) -- Live statistics panel (mean/max/min) -- Virtual detector overlays (BF/ADF circles) -- Linked scan view (side-by-side) - ROI drawing tools - Path animation (raster scan, custom paths) """ import pathlib +from typing import TYPE_CHECKING import anywidget import numpy as np import traitlets from quantem.widget.array_utils import to_numpy -from quantem.core.datastructures import Dataset4dstem + +if TYPE_CHECKING: + from quantem.core.datastructures import Dataset4dstem # Detector geometry constant From 3b88c639d446596a9a2f2c8e426fb2499c79daa6 Mon Sep 17 00:00:00 2001 From: Sangjoon Bob Lee Date: Mon, 12 Jan 2026 15:58:27 -0800 Subject: [PATCH 06/27] pre-commute BF/ABF, etc. for instant preset switching --- widget/js/show4dstem.tsx | 161 ++++++++++++++++++------ widget/src/quantem/widget/show4dstem.py | 129 +++++++------------ 2 files changed, 169 insertions(+), 121 deletions(-) diff --git a/widget/js/show4dstem.tsx b/widget/js/show4dstem.tsx index 0f9f7bcf..71032df0 100644 --- a/widget/js/show4dstem.tsx +++ b/widget/js/show4dstem.tsx @@ -36,18 +36,12 @@ function roundToNiceValue(value: number): number { /** Format scale bar label with appropriate unit */ function formatScaleLabel(value: number, unit: "Å" | "mrad"): string { const nice = roundToNiceValue(value); - if (unit === "Å") { if (nice >= 10) return `${Math.round(nice / 10)} nm`; - if (nice >= 1) return `${Math.round(nice)} Å`; - return `${nice.toFixed(2)} Å`; - } - if (unit === "mrad") { - if (nice >= 1000) return `${Math.round(nice / 1000)} rad`; - if (nice >= 1) return `${Math.round(nice)} mrad`; - return `${nice.toFixed(2)} mrad`; + return nice >= 1 ? `${Math.round(nice)} Å` : `${nice.toFixed(2)} Å`; } - return `${Math.round(nice)} ${unit}`; + if (nice >= 1000) return `${Math.round(nice / 1000)} rad`; + return nice >= 1 ? `${Math.round(nice)} mrad` : `${nice.toFixed(2)} mrad`; } /** @@ -59,9 +53,9 @@ function drawScaleBarHiDPI( dpr: number, zoom: number, pixelSize: number, - unit: string = "Å", - imageWidth: number, // Original image width in pixels - imageHeight: number // Original image height in pixels + unit: "Å" | "mrad", + imageWidth: number, + imageHeight: number ) { const ctx = canvas.getContext("2d"); if (!ctx) return; @@ -485,6 +479,7 @@ function Show4DSTEM() { const [isHoveringResize, setIsHoveringResize] = React.useState(false); const [isHoveringResizeInner, setIsHoveringResizeInner] = React.useState(false); const [colormap, setColormap] = React.useState("inferno"); + const [showFft, setShowFft] = React.useState(true); // Band-pass filter range [innerCutoff, outerCutoff] in pixels - [0, 0] means disabled const [bandpass, setBandpass] = React.useState([0, 0]); @@ -595,11 +590,17 @@ function Show4DSTEM() { const dpCanvasRef = React.useRef(null); const dpOverlayRef = React.useRef(null); const dpUiRef = React.useRef(null); // High-DPI UI overlay for scale bar + const dpOffscreenRef = React.useRef(null); + const dpImageDataRef = React.useRef(null); const virtualCanvasRef = React.useRef(null); const virtualOverlayRef = React.useRef(null); const viUiRef = React.useRef(null); // High-DPI UI overlay for scale bar + const viOffscreenRef = React.useRef(null); + const viImageDataRef = React.useRef(null); const fftCanvasRef = React.useRef(null); const fftOverlayRef = React.useRef(null); + const fftOffscreenRef = React.useRef(null); + const fftImageDataRef = React.useRef(null); // Display size for high-DPI UI overlays const UI_SIZE = 400; @@ -619,16 +620,24 @@ function Show4DSTEM() { // Store raw virtual image data for filtering const rawVirtualImageRef = React.useRef(null); + const viWorkRealRef = React.useRef(null); + const viWorkImagRef = React.useRef(null); + const fftWorkRealRef = React.useRef(null); + const fftWorkImagRef = React.useRef(null); + const fftMagnitudeRef = React.useRef(null); // Parse virtual image bytes into Float32Array React.useEffect(() => { if (!virtualImageBytes) return; const bytes = new Uint8Array(virtualImageBytes.buffer, virtualImageBytes.byteOffset, virtualImageBytes.byteLength); - const floatData = new Float32Array(bytes.length); + let floatData = rawVirtualImageRef.current; + if (!floatData || floatData.length !== bytes.length) { + floatData = new Float32Array(bytes.length); + rawVirtualImageRef.current = floatData; + } for (let i = 0; i < bytes.length; i++) { floatData[i] = bytes[i]; } - rawVirtualImageRef.current = floatData; }, [virtualImageBytes]); // Render DP with zoom @@ -641,13 +650,25 @@ function Show4DSTEM() { const bytes = new Uint8Array(frameBytes.buffer, frameBytes.byteOffset, frameBytes.byteLength); const lut = COLORMAPS[colormap] || COLORMAPS.inferno; - const offscreen = document.createElement("canvas"); - offscreen.width = detY; - offscreen.height = detX; + let offscreen = dpOffscreenRef.current; + if (!offscreen) { + offscreen = document.createElement("canvas"); + dpOffscreenRef.current = offscreen; + } + const sizeChanged = offscreen.width !== detY || offscreen.height !== detX; + if (sizeChanged) { + offscreen.width = detY; + offscreen.height = detX; + dpImageDataRef.current = null; + } const offCtx = offscreen.getContext("2d"); if (!offCtx) return; - const imgData = offCtx.createImageData(detY, detX); + let imgData = dpImageDataRef.current; + if (!imgData) { + imgData = offCtx.createImageData(detY, detX); + dpImageDataRef.current = imgData; + } const rgba = imgData.data; for (let i = 0; i < bytes.length; i++) { @@ -699,13 +720,25 @@ function Show4DSTEM() { } const lut = COLORMAPS[colormap] || COLORMAPS.inferno; - const offscreen = document.createElement("canvas"); - offscreen.width = width; - offscreen.height = height; + let offscreen = viOffscreenRef.current; + if (!offscreen) { + offscreen = document.createElement("canvas"); + viOffscreenRef.current = offscreen; + } + const sizeChanged = offscreen.width !== width || offscreen.height !== height; + if (sizeChanged) { + offscreen.width = width; + offscreen.height = height; + viImageDataRef.current = null; + } const offCtx = offscreen.getContext("2d"); if (!offCtx) return; - const imageData = offCtx.createImageData(width, height); + let imageData = viImageDataRef.current; + if (!imageData) { + imageData = offCtx.createImageData(width, height); + viImageDataRef.current = imageData; + } for (let i = 0; i < filtered.length; i++) { const val = Math.floor(((filtered[i] - min) / (max - min || 1)) * 255); imageData.data[i * 4] = lut[val * 3]; @@ -758,8 +791,22 @@ function Show4DSTEM() { return () => { isCancelled = true; }; } else { // CPU Fallback (Sync) - const real = rawVirtualImageRef.current.slice(); - const imag = new Float32Array(real.length); + const source = rawVirtualImageRef.current; + if (!source) return; + const len = source.length; + let real = viWorkRealRef.current; + if (!real || real.length !== len) { + real = new Float32Array(len); + viWorkRealRef.current = real; + } + real.set(source); + let imag = viWorkImagRef.current; + if (!imag || imag.length !== len) { + imag = new Float32Array(len); + viWorkImagRef.current = imag; + } else { + imag.fill(0); + } fft2d(real, imag, width, height, false); fftshift(real, width, height); fftshift(imag, width, height); @@ -770,6 +817,7 @@ function Show4DSTEM() { renderData(real); } } else { + if (!rawVirtualImageRef.current) return; renderData(rawVirtualImageRef.current); } }, [virtualImageBytes, shapeX, shapeY, colormap, viZoom, viPanX, viPanY, bpInner, bpOuter, gpuReady]); @@ -790,6 +838,10 @@ function Show4DSTEM() { const canvas = fftCanvasRef.current; const ctx = canvas.getContext("2d"); if (!ctx) return; + if (!showFft) { + ctx.clearRect(0, 0, canvas.width, canvas.height); + return; + } const width = shapeY; const height = shapeX; @@ -799,7 +851,11 @@ function Show4DSTEM() { // Helper to render magnitude to canvas const renderMagnitude = (real: Float32Array, imag: Float32Array) => { // Compute log magnitude - const magnitude = new Float32Array(real.length); + let magnitude = fftMagnitudeRef.current; + if (!magnitude || magnitude.length !== real.length) { + magnitude = new Float32Array(real.length); + fftMagnitudeRef.current = magnitude; + } for (let i = 0; i < real.length; i++) { magnitude[i] = Math.log1p(Math.sqrt(real[i] * real[i] + imag[i] * imag[i])); } @@ -811,13 +867,25 @@ function Show4DSTEM() { if (magnitude[i] > max) max = magnitude[i]; } - const offscreen = document.createElement("canvas"); - offscreen.width = width; - offscreen.height = height; + let offscreen = fftOffscreenRef.current; + if (!offscreen) { + offscreen = document.createElement("canvas"); + fftOffscreenRef.current = offscreen; + } + const sizeChanged = offscreen.width !== width || offscreen.height !== height; + if (sizeChanged) { + offscreen.width = width; + offscreen.height = height; + fftImageDataRef.current = null; + } const offCtx = offscreen.getContext("2d"); if (!offCtx) return; - const imgData = offCtx.createImageData(width, height); + let imgData = fftImageDataRef.current; + if (!imgData) { + imgData = offCtx.createImageData(width, height); + fftImageDataRef.current = imgData; + } const rgba = imgData.data; const range = max > min ? max - min : 1; @@ -862,14 +930,26 @@ function Show4DSTEM() { return () => { isCancelled = true; }; } else { // CPU fallback (sync) - const real = sourceData.slice(); - const imag = new Float32Array(real.length); + const len = sourceData.length; + let real = fftWorkRealRef.current; + if (!real || real.length !== len) { + real = new Float32Array(len); + fftWorkRealRef.current = real; + } + real.set(sourceData); + let imag = fftWorkImagRef.current; + if (!imag || imag.length !== len) { + imag = new Float32Array(len); + fftWorkImagRef.current = imag; + } else { + imag.fill(0); + } fft2d(real, imag, width, height, false); fftshift(real, width, height); fftshift(imag, width, height); renderMagnitude(real, imag); } - }, [virtualImageBytes, shapeX, shapeY, colormap, fftZoom, fftPanX, fftPanY, gpuReady]); + }, [virtualImageBytes, shapeX, shapeY, colormap, fftZoom, fftPanX, fftPanY, gpuReady, showFft]); // Render FFT overlay with high-pass filter circle React.useEffect(() => { @@ -878,6 +958,7 @@ function Show4DSTEM() { const ctx = canvas.getContext("2d"); if (!ctx) return; ctx.clearRect(0, 0, canvas.width, canvas.height); + if (!showFft) return; // Draw band-pass filter circles (inner = HP, outer = LP) const centerX = (shapeY / 2) * fftZoom + fftPanX; @@ -903,7 +984,7 @@ function Show4DSTEM() { ctx.stroke(); ctx.setLineDash([]); } - }, [fftZoom, fftPanX, fftPanY, pixelSize, shapeX, shapeY, bpInner, bpOuter]); + }, [fftZoom, fftPanX, fftPanY, pixelSize, shapeX, shapeY, bpInner, bpOuter, showFft]); // ───────────────────────────────────────────────────────────────────────── // High-DPI Scale Bar UI Overlays @@ -1283,9 +1364,17 @@ function Show4DSTEM() { - - FFT - + + + FFT + + setShowFft(e.target.checked)} + size="small" + sx={switchStyles.medium} + /> + vmin: - norm = np.clip((img - vmin) / (vmax - vmin) * 255, 0, 255).astype(np.uint8) - else: - norm = np.zeros(img.shape, dtype=np.uint8) - return norm.tobytes() - - # BF: circle at bf_radius - bf_mask = self._create_circular_mask(self.center_x, self.center_y, self.bf_radius) - self._cached_bf_virtual = _compute_and_normalize(bf_mask) - - # ABF: annular at 0.5*bf to bf (matches JS button) - abf_mask = self._create_annular_mask( - self.center_x, self.center_y, self.bf_radius * 0.5, self.bf_radius - ) - self._cached_abf_virtual = _compute_and_normalize(abf_mask) - - # LAADF: annular at bf to 2*bf (matches JS button) - laadf_mask = self._create_annular_mask( - self.center_x, self.center_y, self.bf_radius, self.bf_radius * 2.0 - ) - self._cached_laadf_virtual = _compute_and_normalize(laadf_mask) - - # HAADF: annular at 2*bf to 4*bf (matches JS button) - haadf_mask = self._create_annular_mask( - self.center_x, self.center_y, self.bf_radius * 2.0, self.bf_radius * 4.0 - ) - self._cached_haadf_virtual = _compute_and_normalize(haadf_mask) + """Pre-compute BF/ABF/LAADF/HAADF virtual images for instant preset switching.""" + cx, cy, bf = self.center_x, self.center_y, self.bf_radius + self._cached_bf_virtual = self._normalize_to_bytes( + self._fast_masked_sum(self._create_circular_mask(cx, cy, bf))) + self._cached_abf_virtual = self._normalize_to_bytes( + self._fast_masked_sum(self._create_annular_mask(cx, cy, bf * 0.5, bf))) + self._cached_laadf_virtual = self._normalize_to_bytes( + self._fast_masked_sum(self._create_annular_mask(cx, cy, bf, bf * 2.0))) + self._cached_haadf_virtual = self._normalize_to_bytes( + self._fast_masked_sum(self._create_annular_mask(cx, cy, bf * 2.0, bf * 4.0))) def _get_cached_preset(self) -> bytes | None: """Check if current ROI matches a cached preset and return it.""" @@ -684,64 +660,47 @@ def _get_cached_preset(self) -> bytes | None: return None - def _fast_masked_sum(self, mask) -> 'np.ndarray': - """Fast masked sum using element-wise multiply (memory efficient).""" - # Handle both 3D and 4D data + def _fast_masked_sum(self, mask) -> "np.ndarray": + """Masked sum over detector dimensions.""" if self._data.ndim == 4: - # (scan_x, scan_y, det_x, det_y) -> sum over detector dims - virtual_image = (self._data.astype(np.float32) * mask).sum(axis=(2, 3)) + return (self._data.astype(np.float32) * mask).sum(axis=(2, 3)) + return (self._data.astype(np.float32) * mask).sum(axis=(1, 2)).reshape(self._scan_shape) + + def _normalize_to_bytes(self, arr: "np.ndarray") -> bytes: + """Normalize array to uint8 bytes.""" + vmin, vmax = float(arr.min()), float(arr.max()) + if vmax > vmin: + norm = np.clip((arr - vmin) / (vmax - vmin) * 255, 0, 255).astype(np.uint8) else: - # (N, det_x, det_y) -> sum over detector dims then reshape - virtual_image = (self._data.astype(np.float32) * mask).sum(axis=(1, 2)) - virtual_image = virtual_image.reshape(self._scan_shape) - return virtual_image + norm = np.zeros(arr.shape, dtype=np.uint8) + return norm.tobytes() def _compute_virtual_image_from_roi(self): - """Compute virtual image based on ROI mode (point, circle, square, or annular).""" - - # Fast path: use cached images for presets (BF/ABF/LAADF/HAADF) + """Compute virtual image based on ROI mode.""" cached = self._get_cached_preset() if cached is not None: self.virtual_image_bytes = cached return - + + cx, cy = self.roi_center_x, self.roi_center_y + if self.roi_mode == "circle" and self.roi_radius > 0: - mask = self._create_circular_mask( - self.roi_center_x, self.roi_center_y, self.roi_radius - ) - virtual_image = self._fast_masked_sum(mask) + mask = self._create_circular_mask(cx, cy, self.roi_radius) elif self.roi_mode == "square" and self.roi_radius > 0: - mask = self._create_square_mask( - self.roi_center_x, self.roi_center_y, self.roi_radius - ) - virtual_image = self._fast_masked_sum(mask) - elif self.roi_mode == "annular" and self.roi_radius > 0 and self.roi_radius_inner >= 0: - mask = self._create_annular_mask( - self.roi_center_x, self.roi_center_y, - self.roi_radius_inner, self.roi_radius - ) - virtual_image = self._fast_masked_sum(mask) + mask = self._create_square_mask(cx, cy, self.roi_radius) + elif self.roi_mode == "annular" and self.roi_radius > 0: + mask = self._create_annular_mask(cx, cy, self.roi_radius_inner, self.roi_radius) elif self.roi_mode == "rect" and self.roi_width > 0 and self.roi_height > 0: - mask = self._create_rect_mask( - self.roi_center_x, self.roi_center_y, - self.roi_width / 2, self.roi_height / 2 - ) - virtual_image = self._fast_masked_sum(mask) + mask = self._create_rect_mask(cx, cy, self.roi_width / 2, self.roi_height / 2) else: # Point mode: single-pixel indexing - # Array indexing: [row, col] where row=det_y, col=det_x in image coords - row = int(max(0, min(self._det_shape[0] - 1, round(self.roi_center_y)))) - col = int(max(0, min(self._det_shape[1] - 1, round(self.roi_center_x)))) + row = int(np.clip(round(cy), 0, self._det_shape[0] - 1)) + col = int(np.clip(round(cx), 0, self._det_shape[1] - 1)) if self._data.ndim == 4: - virtual_image = self._data[:, :, row, col] + vi = self._data[:, :, row, col] else: - virtual_image = self._data[:, row, col].reshape(self._scan_shape) + vi = self._data[:, row, col].reshape(self._scan_shape) + self.virtual_image_bytes = self._normalize_to_bytes(vi) + return - # Normalize to uint8 - vmin, vmax = float(virtual_image.min()), float(virtual_image.max()) - if vmax > vmin: - normalized = np.clip((virtual_image - vmin) / (vmax - vmin) * 255, 0, 255) - normalized = normalized.astype(np.uint8) - else: - normalized = np.zeros(virtual_image.shape, dtype=np.uint8) - self.virtual_image_bytes = normalized.tobytes() + self.virtual_image_bytes = self._normalize_to_bytes(self._fast_masked_sum(mask)) From f2a17c4a13204c5e399fc085199e3f1786b8bd12 Mon Sep 17 00:00:00 2001 From: Sangjoon Bob Lee Date: Mon, 12 Jan 2026 16:44:28 -0800 Subject: [PATCH 07/27] reset all when user clicks --- widget/js/show4dstem.css | 20 ++++-------- widget/js/show4dstem.tsx | 42 ++++++++++++++++++++----- widget/src/quantem/widget/show4dstem.py | 30 ++++++++++-------- 3 files changed, 58 insertions(+), 34 deletions(-) diff --git a/widget/js/show4dstem.css b/widget/js/show4dstem.css index f754251e..bd4f9579 100644 --- a/widget/js/show4dstem.css +++ b/widget/js/show4dstem.css @@ -1,19 +1,11 @@ -/* show4dstem.css - Minimal CSS for Show4DSTEM */ -/* Most styling handled by MUI, this is for canvas-specific styles */ - -.show4dstem-root { - font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif; - background-color: #1a1a1a; -} - -/* Target Jupyter/VS Code output areas */ +/* show4dstem.css - Force dark background in Jupyter/VS Code output areas */ .widget-output, .jp-OutputArea-output, .jp-RenderedHTMLCommon, -.cell-output-ipywidget-background { +.jp-OutputArea-child, +.jp-OutputArea, +.cell-output-ipywidget-background, +.cell-output, +.output_subarea { background-color: #1a1a1a !important; -} - -.show4dstem-root canvas { - display: block; } \ No newline at end of file diff --git a/widget/js/show4dstem.tsx b/widget/js/show4dstem.tsx index 71032df0..6e41ab7c 100644 --- a/widget/js/show4dstem.tsx +++ b/widget/js/show4dstem.tsx @@ -34,12 +34,15 @@ function roundToNiceValue(value: number): number { } /** Format scale bar label with appropriate unit */ -function formatScaleLabel(value: number, unit: "Å" | "mrad"): string { +function formatScaleLabel(value: number, unit: "Å" | "mrad" | "px"): string { const nice = roundToNiceValue(value); if (unit === "Å") { if (nice >= 10) return `${Math.round(nice / 10)} nm`; return nice >= 1 ? `${Math.round(nice)} Å` : `${nice.toFixed(2)} Å`; } + if (unit === "px") { + return nice >= 1 ? `${Math.round(nice)} px` : `${nice.toFixed(1)} px`; + } if (nice >= 1000) return `${Math.round(nice / 1000)} rad`; return nice >= 1 ? `${Math.round(nice)} mrad` : `${nice.toFixed(2)} mrad`; } @@ -53,7 +56,7 @@ function drawScaleBarHiDPI( dpr: number, zoom: number, pixelSize: number, - unit: "Å" | "mrad", + unit: "Å" | "mrad" | "px", imageWidth: number, imageHeight: number ) { @@ -437,6 +440,7 @@ function Show4DSTEM() { const [pixelSize] = useModelState("pixel_size"); const [kPixelSize] = useModelState("k_pixel_size"); + const [kCalibrated] = useModelState("k_calibrated"); const [frameBytes] = useModelState("frame_bytes"); const [virtualImageBytes] = useModelState("virtual_image_bytes"); @@ -566,6 +570,19 @@ function Show4DSTEM() { }); }, []); + // Fix VS Code Jupyter white background (traverse up and fix parent) + const rootRef = React.useRef(null); + React.useEffect(() => { + if (!rootRef.current) return; + let el: HTMLElement | null = rootRef.current; + while (el) { + if (el.classList.contains("cell-output-ipywidget-background")) { + el.style.setProperty("background-color", "#1a1a1a", "important"); + } + el = el.parentElement; + } + }, []); + // Zoom state const [dpZoom, setDpZoom] = React.useState(1); const [dpPanX, setDpPanX] = React.useState(0); @@ -994,7 +1011,8 @@ function Show4DSTEM() { React.useEffect(() => { if (!dpUiRef.current) return; // Draw scale bar first (clears canvas) - drawScaleBarHiDPI(dpUiRef.current, DPR, dpZoom, kPixelSize || 1, "mrad", detY, detX); + const kUnit = kCalibrated ? "mrad" : "px"; + drawScaleBarHiDPI(dpUiRef.current, DPR, dpZoom, kPixelSize || 1, kUnit, detY, detX); // Draw ROI overlay (circle, square, rect, annular) or point crosshair if (roiMode === "point") { drawDpCrosshairHiDPI(dpUiRef.current, DPR, localKx, localKy, dpZoom, dpPanX, dpPanY, detY, detX, isDraggingDP); @@ -1006,7 +1024,7 @@ function Show4DSTEM() { isDraggingDP, isDraggingResize, isDraggingResizeInner, isHoveringResize, isHoveringResizeInner ); } - }, [dpZoom, dpPanX, dpPanY, kPixelSize, detX, detY, roiMode, roiRadius, roiRadiusInner, roiWidth, roiHeight, localKx, localKy, isDraggingDP, isDraggingResize, isDraggingResizeInner, isHoveringResize, isHoveringResizeInner]); + }, [dpZoom, dpPanX, dpPanY, kPixelSize, kCalibrated, detX, detY, roiMode, roiRadius, roiRadiusInner, roiWidth, roiHeight, localKx, localKy, isDraggingDP, isDraggingResize, isDraggingResizeInner, isHoveringResize, isHoveringResizeInner]); // VI scale bar + crosshair (high-DPI) React.useEffect(() => { @@ -1207,7 +1225,7 @@ function Show4DSTEM() { // Render // ───────────────────────────────────────────────────────────────────────── return ( - + {/* Wrapper to ensure header and content have same width */} {/* Header */} @@ -1222,11 +1240,21 @@ function Show4DSTEM() { { - setBandpass([0, 0]); + // Reset position to center + setPosX(Math.floor(shapeX / 2)); + setPosY(Math.floor(shapeY / 2)); + // Reset ROI to detector center, point mode + setRoiCenterX(centerX); + setRoiCenterY(centerY); + setRoiRadius(bfRadius * 0.5); + setRoiMode("point"); + // Reset zoom/pan setDpZoom(1); setDpPanX(0); setDpPanY(0); setViZoom(1); setViPanX(0); setViPanY(0); setFftZoom(1); setFftPanX(0); setFftPanY(0); - setRoiMode("point"); + // Reset colormap and bandpass + setColormap("inferno"); + setBandpass([0, 0]); }} sx={{ ...controlPanel.button }} > diff --git a/widget/src/quantem/widget/show4dstem.py b/widget/src/quantem/widget/show4dstem.py index 0ae70ace..8c1acb2b 100644 --- a/widget/src/quantem/widget/show4dstem.py +++ b/widget/src/quantem/widget/show4dstem.py @@ -25,7 +25,7 @@ class Show4DSTEM(anywidget.AnyWidget): Fast interactive 4D-STEM viewer with advanced features. Optimized for speed with binary transfer and pre-normalization. - Works with NumPy and PyTorch arrays.`````` + Works with NumPy and PyTorch arrays. Parameters ---------- @@ -122,6 +122,7 @@ class Show4DSTEM(anywidget.AnyWidget): # ========================================================================= pixel_size = traitlets.Float(1.0).tag(sync=True) # Å per pixel (real-space) k_pixel_size = traitlets.Float(1.0).tag(sync=True) # mrad per pixel (k-space) + k_calibrated = traitlets.Bool(False).tag(sync=True) # True if k-space has mrad calibration # ========================================================================= # Path Animation (programmatic crosshair control) @@ -149,18 +150,24 @@ def __init__( self.log_scale = log_scale # Extract calibration from Dataset4dstem if provided + k_calibrated = False if hasattr(data, "sampling") and hasattr(data, "array"): # Dataset4dstem: extract calibration and array # sampling = [scan_x, scan_y, det_x, det_y] - if pixel_size is None: + units = getattr(data, "units", ["pixels"] * 4) + if pixel_size is None and units[0] in ("Å", "angstrom", "A", "nm"): pixel_size = float(data.sampling[0]) - if k_pixel_size is None: + if units[0] == "nm": + pixel_size *= 10 # Convert nm to Å + if k_pixel_size is None and units[2] in ("mrad", "1/Å", "1/A"): k_pixel_size = float(data.sampling[2]) + k_calibrated = True data = data.array # Store calibration values (default to 1.0 if not provided) self.pixel_size = pixel_size if pixel_size is not None else 1.0 self.k_pixel_size = k_pixel_size if k_pixel_size is not None else 1.0 + self.k_calibrated = k_calibrated or (k_pixel_size is not None) # Path animation (configured via set_path() or raster()) self._path_points: list[tuple[int, int]] = [] # Convert to NumPy @@ -226,21 +233,18 @@ def __init__( self.roi_radius = self.bf_radius * 0.5 # Start with half BF radius self.roi_active = True - # Compute initial virtual image - try: - self._compute_virtual_image_from_roi() - except Exception: - pass - + # Compute initial virtual image and frame + self._compute_virtual_image_from_roi() self._update_frame() # Path animation: observe index changes from frontend self.observe(self._on_path_index_change, names=["path_index"]) def __repr__(self) -> str: + k_unit = "mrad" if self.k_calibrated else "px" return ( f"Show4DSTEM(shape=({self.shape_x}, {self.shape_y}, {self.det_x}, {self.det_y}), " - f"sampling=({self.pixel_size} Å, {self.k_pixel_size} mrad), " + f"sampling=({self.pixel_size} Å, {self.k_pixel_size} {k_unit}), " f"pos=({self.pos_x}, {self.pos_y}))" ) @@ -697,10 +701,10 @@ def _compute_virtual_image_from_roi(self): row = int(np.clip(round(cy), 0, self._det_shape[0] - 1)) col = int(np.clip(round(cx), 0, self._det_shape[1] - 1)) if self._data.ndim == 4: - vi = self._data[:, :, row, col] + virtual_image = self._data[:, :, row, col] else: - vi = self._data[:, row, col].reshape(self._scan_shape) - self.virtual_image_bytes = self._normalize_to_bytes(vi) + virtual_image = self._data[:, row, col].reshape(self._scan_shape) + self.virtual_image_bytes = self._normalize_to_bytes(virtual_image) return self.virtual_image_bytes = self._normalize_to_bytes(self._fast_masked_sum(mask)) From f4f085bbc6bdf1a92f4b71fb3f1bce25292f623e Mon Sep 17 00:00:00 2001 From: Sangjoon Bob Lee Date: Mon, 19 Jan 2026 20:33:22 -0800 Subject: [PATCH 08/27] add VI ROI, stats display, scale modes, and UI improvements - Add VI ROI for real-space region selection (circle/square/rect) - Add summed DP computation from VI ROI positions - Add dp_stats/vi_stats for mean/min/max/std display on both panels - Add scale modes (linear/log/power) for DP panel - Auto-detect center and BF radius on initialization - Move FFT toggle to VI header row - Hide crosshair when VI ROI is active - Fix annular slider overflow with smaller thumbs - Set reasonable initial sizes for VI ROI (~15% of scan) - Mask DC component for DP stats by default --- widget/js/show4dstem.css | 75 +- widget/js/show4dstem.tsx | 1420 +++++++++++++++-------- widget/src/quantem/widget/show4dstem.py | 349 +++++- 3 files changed, 1311 insertions(+), 533 deletions(-) diff --git a/widget/js/show4dstem.css b/widget/js/show4dstem.css index bd4f9579..4030fbf1 100644 --- a/widget/js/show4dstem.css +++ b/widget/js/show4dstem.css @@ -1,11 +1,64 @@ -/* show4dstem.css - Force dark background in Jupyter/VS Code output areas */ -.widget-output, -.jp-OutputArea-output, -.jp-RenderedHTMLCommon, -.jp-OutputArea-child, -.jp-OutputArea, -.cell-output-ipywidget-background, -.cell-output, -.output_subarea { - background-color: #1a1a1a !important; -} \ No newline at end of file +/* show4dstem.css - Dark theme styling for 4D-STEM widget */ + +/* Widget root - always dark background for consistent appearance */ +.show4dstem-root { + background: #1e1e1e; + color: #e0e0e0; + border-radius: 2px; + padding: 16px; +} + +/* Image containers */ +.show4dstem-root .MuiBox-root[style*="position: relative"] { + background: #0a0a0a; + border-radius: 2px; +} + +/* Control groups - bordered boxes */ +.show4dstem-control-group { + background: #2a2a2a; + border: 1px solid #3a3a3a; + border-radius: 2px; + padding: 8px 12px; +} + +/* Buttons - clean bordered style with blue accent */ +.show4dstem-root .MuiButton-root { + background: #2a2a2a; + color: #ccc; + border: 1px solid #555; + border-radius: 2px; + text-transform: none; + font-size: 12px; + padding: 4px 16px; + min-width: 60px; +} + +.show4dstem-root .MuiButton-root:hover { + background: #3a3a3a; + border-color: #5af; +} + +/* Select dropdowns */ +.show4dstem-root .MuiSelect-select { + background: #3a3a3a; + border-radius: 2px !important; +} + +.show4dstem-root .MuiOutlinedInput-root { + border-radius: 2px; +} + +/* Slider styling - blue accent */ +.show4dstem-root .MuiSlider-root { + color: #5af; +} + +/* Switch styling - blue accent */ +.show4dstem-root .MuiSwitch-switchBase.Mui-checked { + color: #5af; +} + +.show4dstem-root .MuiSwitch-switchBase.Mui-checked + .MuiSwitch-track { + background-color: #5af; +} diff --git a/widget/js/show4dstem.tsx b/widget/js/show4dstem.tsx index 6e41ab7c..edee1836 100644 --- a/widget/js/show4dstem.tsx +++ b/widget/js/show4dstem.tsx @@ -15,7 +15,19 @@ import { typography, controlPanel, container } from "./CONFIG"; import { upwardMenuProps, switchStyles } from "./components"; import "./show4dstem.css"; -// Constants +// ============================================================================ +// Layout Constants - consistent spacing throughout +// ============================================================================ +const SPACING = { + XS: 4, // Extra small gap + SM: 8, // Small gap (default between elements) + MD: 12, // Medium gap (between control groups) + LG: 16, // Large gap (between major sections) +}; + +const CANVAS_SIZE = 450; // Both DP and VI canvases + +// Interaction constants const RESIZE_HIT_AREA_PX = 10; const CIRCLE_HANDLE_ANGLE = 0.707; // cos(45°) const LINE_WIDTH_FRACTION = 0.015; @@ -47,6 +59,18 @@ function formatScaleLabel(value: number, unit: "Å" | "mrad" | "px"): string { return nice >= 1 ? `${Math.round(nice)} mrad` : `${nice.toFixed(2)} mrad`; } +/** Format stat value for display (compact scientific notation for small values) */ +function formatStat(value: number): string { + if (value === 0) return "0"; + const abs = Math.abs(value); + if (abs < 0.001 || abs >= 10000) { + return value.toExponential(2); + } + if (abs < 0.01) return value.toFixed(4); + if (abs < 1) return value.toFixed(3); + return value.toFixed(2); +} + /** * Draw scale bar and zoom indicator on a high-DPI UI canvas. * This renders crisp text/lines independent of the image resolution. @@ -62,21 +86,22 @@ function drawScaleBarHiDPI( ) { const ctx = canvas.getContext("2d"); if (!ctx) return; - + // Clear canvas ctx.clearRect(0, 0, canvas.width, canvas.height); - + // Scale context for device pixel ratio ctx.save(); ctx.scale(dpr, dpr); - + // CSS pixel dimensions const cssWidth = canvas.width / dpr; const cssHeight = canvas.height / dpr; - - // Calculate the display scale factor (how much the image is scaled to fit the canvas) - const displayScale = Math.min(cssWidth / imageWidth, cssHeight / imageHeight); - const effectiveZoom = zoom * displayScale; + + // Calculate separate X/Y scale factors (canvas stretches to fill, not aspect-preserving) + const scaleX = cssWidth / imageWidth; + // Use X scale for horizontal measurements (scale bar is horizontal) + const effectiveZoom = zoom * scaleX; // Fixed UI sizes in CSS pixels (always crisp) const targetBarPx = 60; // Target bar length in CSS pixels @@ -125,7 +150,7 @@ function drawScaleBarHiDPI( * Draw VI crosshair on high-DPI canvas (crisp regardless of image resolution) * Note: Does NOT clear canvas - should be called after drawScaleBarHiDPI */ -function drawViCrosshairHiDPI( +function drawViPositionMarker( canvas: HTMLCanvasElement, dpr: number, posX: number, // Position in image coordinates @@ -139,44 +164,188 @@ function drawViCrosshairHiDPI( ) { const ctx = canvas.getContext("2d"); if (!ctx) return; - + ctx.save(); ctx.scale(dpr, dpr); - + const cssWidth = canvas.width / dpr; const cssHeight = canvas.height / dpr; - const displayScale = Math.min(cssWidth / imageWidth, cssHeight / imageHeight); - + const scaleX = cssWidth / imageWidth; + const scaleY = cssHeight / imageHeight; + // Convert image coordinates to CSS pixel coordinates - const screenX = posY * zoom * displayScale + panX * displayScale; - const screenY = posX * zoom * displayScale + panY * displayScale; - - // Fixed UI sizes in CSS pixels (consistent with DP crosshair) - const crosshairSize = 18; - const lineWidth = 3; - const dotRadius = 6; - + const screenX = posY * zoom * scaleX + panX * scaleX; + const screenY = posX * zoom * scaleY + panY * scaleY; + + // Simple crosshair (no circle) + const crosshairSize = 12; + const lineWidth = 1.5; + ctx.shadowColor = "rgba(0, 0, 0, 0.5)"; ctx.shadowBlur = 2; ctx.shadowOffsetX = 1; ctx.shadowOffsetY = 1; - + ctx.strokeStyle = isDragging ? "rgba(255, 255, 0, 0.9)" : "rgba(255, 100, 100, 0.9)"; ctx.lineWidth = lineWidth; - - // Draw crosshair + + // Draw crosshair lines only ctx.beginPath(); ctx.moveTo(screenX - crosshairSize, screenY); ctx.lineTo(screenX + crosshairSize, screenY); ctx.moveTo(screenX, screenY - crosshairSize); ctx.lineTo(screenX, screenY + crosshairSize); ctx.stroke(); - - // Draw center dot - ctx.beginPath(); - ctx.arc(screenX, screenY, dotRadius, 0, 2 * Math.PI); - ctx.stroke(); - + + ctx.restore(); +} + +/** + * Draw VI ROI overlay on high-DPI canvas for real-space region selection + * Note: Does NOT clear canvas - should be called after drawViPositionMarker + */ +function drawViRoiOverlayHiDPI( + canvas: HTMLCanvasElement, + dpr: number, + roiMode: string, + centerX: number, + centerY: number, + radius: number, + roiWidth: number, + roiHeight: number, + zoom: number, + panX: number, + panY: number, + imageWidth: number, + imageHeight: number, + isDragging: boolean, + isDraggingResize: boolean, + isHoveringResize: boolean +) { + if (roiMode === "off") return; + + const ctx = canvas.getContext("2d"); + if (!ctx) return; + + ctx.save(); + ctx.scale(dpr, dpr); + + const cssWidth = canvas.width / dpr; + const cssHeight = canvas.height / dpr; + const scaleX = cssWidth / imageWidth; + const scaleY = cssHeight / imageHeight; + + // Convert image coordinates to screen coordinates (note: Y is row, X is col in image) + const screenX = centerY * zoom * scaleX + panX * scaleX; + const screenY = centerX * zoom * scaleY + panY * scaleY; + + const lineWidth = 2.5; + const crosshairSize = 10; + const handleRadius = 6; + + ctx.shadowColor = "rgba(0, 0, 0, 0.4)"; + ctx.shadowBlur = 2; + ctx.shadowOffsetX = 1; + ctx.shadowOffsetY = 1; + + // Helper to draw resize handle (purple color for VI ROI to differentiate from DP) + const drawResizeHandle = (handleX: number, handleY: number) => { + let handleFill: string; + let handleStroke: string; + + if (isDraggingResize) { + handleFill = "rgba(180, 100, 255, 1)"; + handleStroke = "rgba(255, 255, 255, 1)"; + } else if (isHoveringResize) { + handleFill = "rgba(220, 150, 255, 1)"; + handleStroke = "rgba(255, 255, 255, 1)"; + } else { + handleFill = "rgba(160, 80, 255, 0.8)"; + handleStroke = "rgba(255, 255, 255, 0.8)"; + } + ctx.beginPath(); + ctx.arc(handleX, handleY, handleRadius, 0, 2 * Math.PI); + ctx.fillStyle = handleFill; + ctx.fill(); + ctx.strokeStyle = handleStroke; + ctx.lineWidth = 1.5; + ctx.stroke(); + }; + + // Helper to draw center crosshair (purple/magenta for VI ROI) + const drawCenterCrosshair = () => { + ctx.strokeStyle = isDragging ? "rgba(255, 200, 0, 0.9)" : "rgba(180, 80, 255, 0.9)"; + ctx.lineWidth = lineWidth; + ctx.beginPath(); + ctx.moveTo(screenX - crosshairSize, screenY); + ctx.lineTo(screenX + crosshairSize, screenY); + ctx.moveTo(screenX, screenY - crosshairSize); + ctx.lineTo(screenX, screenY + crosshairSize); + ctx.stroke(); + }; + + // Purple/magenta color for VI ROI to differentiate from green DP detector + const strokeColor = isDragging ? "rgba(255, 200, 0, 0.9)" : "rgba(180, 80, 255, 0.9)"; + const fillColor = isDragging ? "rgba(255, 200, 0, 0.15)" : "rgba(180, 80, 255, 0.15)"; + + if (roiMode === "circle" && radius > 0) { + const screenRadiusX = radius * zoom * scaleX; + const screenRadiusY = radius * zoom * scaleY; + + ctx.strokeStyle = strokeColor; + ctx.lineWidth = lineWidth; + ctx.beginPath(); + ctx.ellipse(screenX, screenY, screenRadiusX, screenRadiusY, 0, 0, 2 * Math.PI); + ctx.stroke(); + + ctx.fillStyle = fillColor; + ctx.fill(); + + drawCenterCrosshair(); + + // Resize handle at 45° diagonal + const handleOffsetX = screenRadiusX * CIRCLE_HANDLE_ANGLE; + const handleOffsetY = screenRadiusY * CIRCLE_HANDLE_ANGLE; + drawResizeHandle(screenX + handleOffsetX, screenY + handleOffsetY); + + } else if (roiMode === "square" && radius > 0) { + // Square uses radius as half-size + const screenHalfW = radius * zoom * scaleX; + const screenHalfH = radius * zoom * scaleY; + const left = screenX - screenHalfW; + const top = screenY - screenHalfH; + + ctx.strokeStyle = strokeColor; + ctx.lineWidth = lineWidth; + ctx.beginPath(); + ctx.rect(left, top, screenHalfW * 2, screenHalfH * 2); + ctx.stroke(); + + ctx.fillStyle = fillColor; + ctx.fill(); + + drawCenterCrosshair(); + drawResizeHandle(screenX + screenHalfW, screenY + screenHalfH); + + } else if (roiMode === "rect" && roiWidth > 0 && roiHeight > 0) { + const screenHalfW = (roiWidth / 2) * zoom * scaleX; + const screenHalfH = (roiHeight / 2) * zoom * scaleY; + const left = screenX - screenHalfW; + const top = screenY - screenHalfH; + + ctx.strokeStyle = strokeColor; + ctx.lineWidth = lineWidth; + ctx.beginPath(); + ctx.rect(left, top, screenHalfW * 2, screenHalfH * 2); + ctx.stroke(); + + ctx.fillStyle = fillColor; + ctx.fill(); + + drawCenterCrosshair(); + drawResizeHandle(screenX + screenHalfW, screenY + screenHalfH); + } + ctx.restore(); } @@ -198,17 +367,19 @@ function drawDpCrosshairHiDPI( ) { const ctx = canvas.getContext("2d"); if (!ctx) return; - + ctx.save(); ctx.scale(dpr, dpr); - + const cssWidth = canvas.width / dpr; const cssHeight = canvas.height / dpr; - const displayScale = Math.min(cssWidth / detWidth, cssHeight / detHeight); - + // Use separate X/Y scale factors (canvas stretches to fill container) + const scaleX = cssWidth / detWidth; + const scaleY = cssHeight / detHeight; + // Convert detector coordinates to CSS pixel coordinates (no swap - kx is X, ky is Y) - const screenX = kx * zoom * displayScale + panX * displayScale; - const screenY = ky * zoom * displayScale + panY * displayScale; + const screenX = kx * zoom * scaleX + panX * scaleX; + const screenY = ky * zoom * scaleY + panY * scaleY; // Fixed UI sizes in CSS pixels (consistent with VI crosshair) const crosshairSize = 18; @@ -266,17 +437,19 @@ function drawRoiOverlayHiDPI( ) { const ctx = canvas.getContext("2d"); if (!ctx) return; - + ctx.save(); ctx.scale(dpr, dpr); - + const cssWidth = canvas.width / dpr; const cssHeight = canvas.height / dpr; - const displayScale = Math.min(cssWidth / detWidth, cssHeight / detHeight); - + // Use separate X/Y scale factors (canvas stretches to fill container) + const scaleX = cssWidth / detWidth; + const scaleY = cssHeight / detHeight; + // Convert detector coordinates to CSS pixel coordinates - const screenX = centerX * zoom * displayScale + panX * displayScale; - const screenY = centerY * zoom * displayScale + panY * displayScale; + const screenX = centerX * zoom * scaleX + panX * scaleX; + const screenY = centerY * zoom * scaleY + panY * scaleY; // Fixed UI sizes in CSS pixels const lineWidth = 2.5; @@ -327,99 +500,266 @@ function drawRoiOverlayHiDPI( }; if (roiMode === "circle" && radius > 0) { - const screenRadius = radius * zoom * displayScale; - - // Draw circle + // Use separate X/Y radii for ellipse (handles non-square detectors) + const screenRadiusX = radius * zoom * scaleX; + const screenRadiusY = radius * zoom * scaleY; + + // Draw ellipse (becomes circle if scaleX === scaleY) ctx.strokeStyle = isDragging ? "rgba(255, 255, 0, 0.9)" : "rgba(0, 255, 0, 0.9)"; ctx.lineWidth = lineWidth; ctx.beginPath(); - ctx.arc(screenX, screenY, screenRadius, 0, 2 * Math.PI); + ctx.ellipse(screenX, screenY, screenRadiusX, screenRadiusY, 0, 0, 2 * Math.PI); ctx.stroke(); - + // Semi-transparent fill ctx.fillStyle = isDragging ? "rgba(255, 255, 0, 0.12)" : "rgba(0, 255, 0, 0.12)"; ctx.fill(); - + drawCenterCrosshair(); - - // Resize handle at 45° - const handleOffset = screenRadius * CIRCLE_HANDLE_ANGLE; - drawResizeHandle(screenX + handleOffset, screenY + handleOffset); - + + // Resize handle at 45° diagonal + const handleOffsetX = screenRadiusX * CIRCLE_HANDLE_ANGLE; + const handleOffsetY = screenRadiusY * CIRCLE_HANDLE_ANGLE; + drawResizeHandle(screenX + handleOffsetX, screenY + handleOffsetY); + } else if (roiMode === "square" && radius > 0) { - const screenHalfSize = radius * zoom * displayScale; - const left = screenX - screenHalfSize; - const top = screenY - screenHalfSize; - const size = screenHalfSize * 2; - + // Square in detector space uses same half-size in both dimensions + const screenHalfW = radius * zoom * scaleX; + const screenHalfH = radius * zoom * scaleY; + const left = screenX - screenHalfW; + const top = screenY - screenHalfH; + ctx.strokeStyle = isDragging ? "rgba(255, 255, 0, 0.9)" : "rgba(0, 255, 0, 0.9)"; ctx.lineWidth = lineWidth; ctx.beginPath(); - ctx.rect(left, top, size, size); + ctx.rect(left, top, screenHalfW * 2, screenHalfH * 2); ctx.stroke(); - + ctx.fillStyle = isDragging ? "rgba(255, 255, 0, 0.12)" : "rgba(0, 255, 0, 0.12)"; ctx.fill(); - + drawCenterCrosshair(); - drawResizeHandle(screenX + screenHalfSize, screenY + screenHalfSize); - + drawResizeHandle(screenX + screenHalfW, screenY + screenHalfH); + } else if (roiMode === "rect" && roiWidth > 0 && roiHeight > 0) { - const screenHalfW = (roiWidth / 2) * zoom * displayScale; - const screenHalfH = (roiHeight / 2) * zoom * displayScale; + const screenHalfW = (roiWidth / 2) * zoom * scaleX; + const screenHalfH = (roiHeight / 2) * zoom * scaleY; const left = screenX - screenHalfW; const top = screenY - screenHalfH; - + ctx.strokeStyle = isDragging ? "rgba(255, 255, 0, 0.9)" : "rgba(0, 255, 0, 0.9)"; ctx.lineWidth = lineWidth; ctx.beginPath(); ctx.rect(left, top, screenHalfW * 2, screenHalfH * 2); ctx.stroke(); - + ctx.fillStyle = isDragging ? "rgba(255, 255, 0, 0.12)" : "rgba(0, 255, 0, 0.12)"; ctx.fill(); - + drawCenterCrosshair(); drawResizeHandle(screenX + screenHalfW, screenY + screenHalfH); - + } else if (roiMode === "annular" && radius > 0) { - const screenRadiusOuter = radius * zoom * displayScale; - const screenRadiusInner = (radiusInner || 0) * zoom * displayScale; - - // Outer circle (green) + // Use separate X/Y radii for ellipses + const screenRadiusOuterX = radius * zoom * scaleX; + const screenRadiusOuterY = radius * zoom * scaleY; + const screenRadiusInnerX = (radiusInner || 0) * zoom * scaleX; + const screenRadiusInnerY = (radiusInner || 0) * zoom * scaleY; + + // Outer ellipse (green) ctx.strokeStyle = isDragging ? "rgba(255, 255, 0, 0.9)" : "rgba(0, 255, 0, 0.9)"; ctx.lineWidth = lineWidth; ctx.beginPath(); - ctx.arc(screenX, screenY, screenRadiusOuter, 0, 2 * Math.PI); + ctx.ellipse(screenX, screenY, screenRadiusOuterX, screenRadiusOuterY, 0, 0, 2 * Math.PI); ctx.stroke(); - - // Inner circle (cyan) + + // Inner ellipse (cyan) ctx.strokeStyle = isDragging ? "rgba(255, 200, 0, 0.9)" : "rgba(0, 220, 255, 0.9)"; ctx.beginPath(); - ctx.arc(screenX, screenY, screenRadiusInner, 0, 2 * Math.PI); + ctx.ellipse(screenX, screenY, screenRadiusInnerX, screenRadiusInnerY, 0, 0, 2 * Math.PI); ctx.stroke(); - + // Fill annular region ctx.fillStyle = isDragging ? "rgba(255, 255, 0, 0.12)" : "rgba(0, 255, 0, 0.12)"; ctx.beginPath(); - ctx.arc(screenX, screenY, screenRadiusOuter, 0, 2 * Math.PI); - ctx.arc(screenX, screenY, screenRadiusInner, 0, 2 * Math.PI, true); + ctx.ellipse(screenX, screenY, screenRadiusOuterX, screenRadiusOuterY, 0, 0, 2 * Math.PI); + ctx.ellipse(screenX, screenY, screenRadiusInnerX, screenRadiusInnerY, 0, 0, 2 * Math.PI, true); ctx.fill(); - + drawCenterCrosshair(); - - // Outer handle - const handleOffsetOuter = screenRadiusOuter * CIRCLE_HANDLE_ANGLE; - drawResizeHandle(screenX + handleOffsetOuter, screenY + handleOffsetOuter); - - // Inner handle - const handleOffsetInner = screenRadiusInner * CIRCLE_HANDLE_ANGLE; - drawResizeHandle(screenX + handleOffsetInner, screenY + handleOffsetInner, true); + + // Outer handle at 45° diagonal + const handleOffsetOuterX = screenRadiusOuterX * CIRCLE_HANDLE_ANGLE; + const handleOffsetOuterY = screenRadiusOuterY * CIRCLE_HANDLE_ANGLE; + drawResizeHandle(screenX + handleOffsetOuterX, screenY + handleOffsetOuterY); + + // Inner handle at 45° diagonal + const handleOffsetInnerX = screenRadiusInnerX * CIRCLE_HANDLE_ANGLE; + const handleOffsetInnerY = screenRadiusInnerY * CIRCLE_HANDLE_ANGLE; + drawResizeHandle(screenX + handleOffsetInnerX, screenY + handleOffsetInnerY, true); } ctx.restore(); } +// ============================================================================ +// Histogram Component +// ============================================================================ + +/** + * Compute histogram from byte data (0-255). + * Returns 256 bins normalized to 0-1 range. + */ +function computeHistogramFromBytes(data: Uint8Array | Float32Array | null, numBins = 256): number[] { + if (!data || data.length === 0) { + return new Array(numBins).fill(0); + } + + const bins = new Array(numBins).fill(0); + + // For Float32Array, find min/max and bin accordingly + if (data instanceof Float32Array) { + let min = Infinity, max = -Infinity; + for (let i = 0; i < data.length; i++) { + const v = data[i]; + if (isFinite(v)) { + if (v < min) min = v; + if (v > max) max = v; + } + } + if (!isFinite(min) || !isFinite(max) || min === max) { + return bins; + } + const range = max - min; + for (let i = 0; i < data.length; i++) { + const v = data[i]; + if (isFinite(v)) { + const binIdx = Math.min(numBins - 1, Math.floor(((v - min) / range) * numBins)); + bins[binIdx]++; + } + } + } else { + // Uint8Array - values are already 0-255 + for (let i = 0; i < data.length; i++) { + const binIdx = Math.min(numBins - 1, data[i]); + bins[binIdx]++; + } + } + + // Normalize bins to 0-1 + const maxCount = Math.max(...bins); + if (maxCount > 0) { + for (let i = 0; i < numBins; i++) { + bins[i] /= maxCount; + } + } + + return bins; +} + +interface HistogramProps { + data: Uint8Array | Float32Array | null; + colormap: string; + vminPct: number; + vmaxPct: number; + onRangeChange: (min: number, max: number) => void; + width?: number; + height?: number; +} + +/** + * Histogram component with integrated vmin/vmax slider and statistics. + * Shows data distribution with colormap gradient and adjustable clipping. + */ +function Histogram({ + data, + colormap, + vminPct, + vmaxPct, + onRangeChange, + width = 120, + height = 40 +}: HistogramProps) { + const canvasRef = React.useRef(null); + const bins = React.useMemo(() => computeHistogramFromBytes(data), [data]); + + // Draw histogram (vertical gray bars) + React.useEffect(() => { + const canvas = canvasRef.current; + if (!canvas) return; + + const ctx = canvas.getContext("2d"); + if (!ctx) return; + + const dpr = window.devicePixelRatio || 1; + canvas.width = width * dpr; + canvas.height = height * dpr; + ctx.scale(dpr, dpr); + + // Clear with dark background + ctx.fillStyle = "#1a1a1a"; + ctx.fillRect(0, 0, width, height); + + // Reduce to fewer bins for cleaner display + const displayBins = 64; + const binRatio = Math.floor(bins.length / displayBins); + const reducedBins: number[] = []; + for (let i = 0; i < displayBins; i++) { + let sum = 0; + for (let j = 0; j < binRatio; j++) { + sum += bins[i * binRatio + j] || 0; + } + reducedBins.push(sum / binRatio); + } + + // Normalize + const maxVal = Math.max(...reducedBins, 0.001); + const barWidth = width / displayBins; + + // Calculate which bins are in the clipped range + const vminBin = Math.floor((vminPct / 100) * displayBins); + const vmaxBin = Math.floor((vmaxPct / 100) * displayBins); + + // Draw histogram bars (gray) + for (let i = 0; i < displayBins; i++) { + const barHeight = (reducedBins[i] / maxVal) * (height - 2); + const x = i * barWidth; + + // Bars inside range are lighter gray, outside are darker + const inRange = i >= vminBin && i <= vmaxBin; + ctx.fillStyle = inRange ? "#888" : "#444"; + ctx.fillRect(x + 0.5, height - barHeight, Math.max(1, barWidth - 1), barHeight); + } + + }, [bins, colormap, vminPct, vmaxPct, width, height]); + + return ( + + + { + const [newMin, newMax] = v as number[]; + onRangeChange(Math.min(newMin, newMax - 1), Math.max(newMax, newMin + 1)); + }} + min={0} + max={100} + size="small" + sx={{ + width, + py: 0, + "& .MuiSlider-thumb": { width: 8, height: 8 }, + "& .MuiSlider-rail": { height: 2 }, + "& .MuiSlider-track": { height: 2 }, + }} + /> + + ); +} + // ============================================================================ // Main Component // ============================================================================ @@ -453,7 +793,8 @@ function Show4DSTEM() { const [roiHeight, setRoiHeight] = useModelState("roi_height"); // Display options - const [logScale, setLogScale] = useModelState("log_scale"); + const [dpScaleMode, setDpScaleMode] = useModelState("dp_scale_mode"); + const [dpPowerExp, setDpPowerExp] = useModelState("dp_power_exp"); // Detector calibration (for presets) const [bfRadius] = useModelState("bf_radius"); @@ -467,6 +808,9 @@ function Show4DSTEM() { const [pathIntervalMs] = useModelState("path_interval_ms"); const [pathLoop] = useModelState("path_loop"); + // Auto-detection trigger + const [, setAutoDetectTrigger] = useModelState("auto_detect_trigger"); + // ───────────────────────────────────────────────────────────────────────── // Local State (UI-only, not synced to Python) // ───────────────────────────────────────────────────────────────────────── @@ -482,8 +826,51 @@ function Show4DSTEM() { const [isDraggingResizeInner, setIsDraggingResizeInner] = React.useState(false); // For annular inner handle const [isHoveringResize, setIsHoveringResize] = React.useState(false); const [isHoveringResizeInner, setIsHoveringResizeInner] = React.useState(false); - const [colormap, setColormap] = React.useState("inferno"); - const [showFft, setShowFft] = React.useState(true); + // VI ROI drag/resize states (same pattern as DP) + const [isDraggingViRoi, setIsDraggingViRoi] = React.useState(false); + const [isDraggingViRoiResize, setIsDraggingViRoiResize] = React.useState(false); + const [isHoveringViRoiResize, setIsHoveringViRoiResize] = React.useState(false); + // Independent colormaps for DP and VI panels + const [dpColormap, setDpColormap] = React.useState("inferno"); + const [viColormap, setViColormap] = React.useState("inferno"); + // vmin/vmax percentile clipping (0-100) + const [dpVminPct, setDpVminPct] = React.useState(0); + const [dpVmaxPct, setDpVmaxPct] = React.useState(100); + const [viVminPct, setViVminPct] = React.useState(0); + const [viVmaxPct, setViVmaxPct] = React.useState(100); + // Scale mode: "linear" | "log" | "power" + const [viScaleMode, setViScaleMode] = React.useState<"linear" | "log" | "power">("linear"); + const [viPowerExp, setViPowerExp] = React.useState(0.5); + + // VI ROI state (real-space region selection for summed DP) - synced with Python + const [viRoiMode, setViRoiMode] = useModelState("vi_roi_mode"); + const [viRoiCenterX, setViRoiCenterX] = useModelState("vi_roi_center_x"); + const [viRoiCenterY, setViRoiCenterY] = useModelState("vi_roi_center_y"); + const [viRoiRadius, setViRoiRadius] = useModelState("vi_roi_radius"); + const [viRoiWidth, setViRoiWidth] = useModelState("vi_roi_width"); + const [viRoiHeight, setViRoiHeight] = useModelState("vi_roi_height"); + // Local VI ROI center for smooth dragging + const [localViRoiCenterX, setLocalViRoiCenterX] = React.useState(viRoiCenterX || 0); + const [localViRoiCenterY, setLocalViRoiCenterY] = React.useState(viRoiCenterY || 0); + const [summedDpBytes] = useModelState("summed_dp_bytes"); + const [summedDpCount] = useModelState("summed_dp_count"); + const [dpStats] = useModelState("dp_stats"); // [mean, min, max, std] + const [viStats] = useModelState("vi_stats"); // [mean, min, max, std] + const [showFft, setShowFft] = React.useState(false); // Hidden by default per feedback + + // Histogram data - use state to ensure re-renders + const [dpHistogramData, setDpHistogramData] = React.useState(null); + const [viHistogramData, setViHistogramData] = React.useState(null); + + // Parse DP frame bytes for histogram + React.useEffect(() => { + if (!frameBytes) return; + const bytes = new Uint8Array(frameBytes.buffer, frameBytes.byteOffset, frameBytes.byteLength); + // Create a copy to ensure state update triggers re-render + const copy = new Uint8Array(bytes.length); + copy.set(bytes); + setDpHistogramData(copy); + }, [frameBytes]); // Band-pass filter range [innerCutoff, outerCutoff] in pixels - [0, 0] means disabled const [bandpass, setBandpass] = React.useState([0, 0]); @@ -570,18 +957,8 @@ function Show4DSTEM() { }); }, []); - // Fix VS Code Jupyter white background (traverse up and fix parent) + // Root element ref (theme-aware styling handled via CSS variables) const rootRef = React.useRef(null); - React.useEffect(() => { - if (!rootRef.current) return; - let el: HTMLElement | null = rootRef.current; - while (el) { - if (el.classList.contains("cell-output-ipywidget-background")) { - el.style.setProperty("background-color", "#1a1a1a", "important"); - } - el = el.parentElement; - } - }, []); // Zoom state const [dpZoom, setDpZoom] = React.useState(1); @@ -603,6 +980,14 @@ function Show4DSTEM() { if (!isDraggingVI) { setLocalPosX(posX); setLocalPosY(posY); } }, [posX, posY, isDraggingVI]); + // Sync VI ROI local state + React.useEffect(() => { + if (!isDraggingViRoi && !isDraggingViRoiResize) { + setLocalViRoiCenterX(viRoiCenterX || shapeX / 2); + setLocalViRoiCenterY(viRoiCenterY || shapeY / 2); + } + }, [viRoiCenterX, viRoiCenterY, isDraggingViRoi, isDraggingViRoiResize, shapeX, shapeY]); + // Canvas refs const dpCanvasRef = React.useRef(null); const dpOverlayRef = React.useRef(null); @@ -619,8 +1004,7 @@ function Show4DSTEM() { const fftOffscreenRef = React.useRef(null); const fftImageDataRef = React.useRef(null); - // Display size for high-DPI UI overlays - const UI_SIZE = 400; + // Device pixel ratio for high-DPI UI overlays const DPR = typeof window !== 'undefined' ? window.devicePixelRatio || 1 : 1; // ───────────────────────────────────────────────────────────────────────── @@ -635,7 +1019,8 @@ function Show4DSTEM() { return () => overlays.forEach(el => el?.removeEventListener("wheel", preventDefault)); }, []); - // Store raw virtual image data for filtering + // Store raw data for histogram visualization + const dpBytesRef = React.useRef(null); const rawVirtualImageRef = React.useRef(null); const viWorkRealRef = React.useRef(null); const viWorkImagRef = React.useRef(null); @@ -655,17 +1040,31 @@ function Show4DSTEM() { for (let i = 0; i < bytes.length; i++) { floatData[i] = bytes[i]; } + // Update histogram state (triggers re-render) + setViHistogramData(floatData); }, [virtualImageBytes]); - // Render DP with zoom + // Render DP with zoom (use summed DP when VI ROI is active) React.useEffect(() => { - if (!frameBytes || !dpCanvasRef.current) return; + if (!dpCanvasRef.current) return; + + // Determine which bytes to display: summed DP (if VI ROI active) or single frame + const usesSummedDp = viRoiMode && viRoiMode !== "off" && summedDpBytes && summedDpBytes.byteLength > 0; + const sourceBytes = usesSummedDp ? summedDpBytes : frameBytes; + if (!sourceBytes) return; + const canvas = dpCanvasRef.current; const ctx = canvas.getContext("2d"); if (!ctx) return; - const bytes = new Uint8Array(frameBytes.buffer, frameBytes.byteOffset, frameBytes.byteLength); - const lut = COLORMAPS[colormap] || COLORMAPS.inferno; + const bytes = new Uint8Array(sourceBytes.buffer, sourceBytes.byteOffset, sourceBytes.byteLength); + dpBytesRef.current = bytes; // Store for histogram + const lut = COLORMAPS[dpColormap] || COLORMAPS.inferno; + + // Apply vmin/vmax percentile clipping + const vmin = Math.floor(255 * dpVminPct / 100); + const vmax = Math.ceil(255 * dpVmaxPct / 100); + const range = vmax > vmin ? vmax - vmin : 1; let offscreen = dpOffscreenRef.current; if (!offscreen) { @@ -689,7 +1088,9 @@ function Show4DSTEM() { const rgba = imgData.data; for (let i = 0; i < bytes.length; i++) { - const v = bytes[i]; + // Apply vmin/vmax clipping and rescaling + const clamped = Math.max(vmin, Math.min(vmax, bytes[i])); + const v = Math.round(((clamped - vmin) / range) * 255); const j = i * 4; const lutIdx = v * 3; rgba[j] = lut[lutIdx]; @@ -706,7 +1107,7 @@ function Show4DSTEM() { ctx.scale(dpZoom, dpZoom); ctx.drawImage(offscreen, 0, 0); ctx.restore(); - }, [frameBytes, detX, detY, colormap, dpZoom, dpPanX, dpPanY]); + }, [frameBytes, summedDpBytes, viRoiMode, detX, detY, dpColormap, dpVminPct, dpVmaxPct, dpZoom, dpPanX, dpPanY]); // Render DP overlay - just clear (ROI shapes now drawn on high-DPI UI canvas) React.useEffect(() => { @@ -730,13 +1131,34 @@ function Show4DSTEM() { const renderData = (filtered: Float32Array) => { // Normalize and render - let min = Infinity, max = -Infinity; - for (let i = 0; i < filtered.length; i++) { - if (filtered[i] < min) min = filtered[i]; - if (filtered[i] > max) max = filtered[i]; + // Apply scale transformation first + let scaled = filtered; + if (viScaleMode === "log") { + scaled = new Float32Array(filtered.length); + for (let i = 0; i < filtered.length; i++) { + scaled[i] = Math.log1p(Math.max(0, filtered[i])); + } + } else if (viScaleMode === "power") { + scaled = new Float32Array(filtered.length); + for (let i = 0; i < filtered.length; i++) { + scaled[i] = Math.pow(Math.max(0, filtered[i]), viPowerExp); + } } - const lut = COLORMAPS[colormap] || COLORMAPS.inferno; + // Compute actual min/max of scaled data + let dataMin = Infinity, dataMax = -Infinity; + for (let i = 0; i < scaled.length; i++) { + if (scaled[i] < dataMin) dataMin = scaled[i]; + if (scaled[i] > dataMax) dataMax = scaled[i]; + } + + // Apply vmin/vmax percentile clipping + const dataRange = dataMax - dataMin; + const vmin = dataMin + dataRange * viVminPct / 100; + const vmax = dataMin + dataRange * viVmaxPct / 100; + const range = vmax > vmin ? vmax - vmin : 1; + + const lut = COLORMAPS[viColormap] || COLORMAPS.inferno; let offscreen = viOffscreenRef.current; if (!offscreen) { offscreen = document.createElement("canvas"); @@ -756,8 +1178,10 @@ function Show4DSTEM() { imageData = offCtx.createImageData(width, height); viImageDataRef.current = imageData; } - for (let i = 0; i < filtered.length; i++) { - const val = Math.floor(((filtered[i] - min) / (max - min || 1)) * 255); + for (let i = 0; i < scaled.length; i++) { + // Clamp to vmin/vmax and rescale to 0-255 + const clamped = Math.max(vmin, Math.min(vmax, scaled[i])); + const val = Math.floor(((clamped - vmin) / range) * 255); imageData.data[i * 4] = lut[val * 3]; imageData.data[i * 4 + 1] = lut[val * 3 + 1]; imageData.data[i * 4 + 2] = lut[val * 3 + 2]; @@ -837,7 +1261,7 @@ function Show4DSTEM() { if (!rawVirtualImageRef.current) return; renderData(rawVirtualImageRef.current); } - }, [virtualImageBytes, shapeX, shapeY, colormap, viZoom, viPanX, viPanY, bpInner, bpOuter, gpuReady]); + }, [virtualImageBytes, shapeX, shapeY, viColormap, viVminPct, viVmaxPct, viScaleMode, viPowerExp, viZoom, viPanX, viPanY, bpInner, bpOuter, gpuReady]); // Render virtual image overlay (just clear - crosshair drawn on high-DPI UI canvas) React.useEffect(() => { @@ -863,7 +1287,7 @@ function Show4DSTEM() { const width = shapeY; const height = shapeX; const sourceData = rawVirtualImageRef.current; - const lut = COLORMAPS[colormap] || COLORMAPS.inferno; + const lut = COLORMAPS[viColormap] || COLORMAPS.inferno; // Helper to render magnitude to canvas const renderMagnitude = (real: Float32Array, imag: Float32Array) => { @@ -966,7 +1390,7 @@ function Show4DSTEM() { fftshift(imag, width, height); renderMagnitude(real, imag); } - }, [virtualImageBytes, shapeX, shapeY, colormap, fftZoom, fftPanX, fftPanY, gpuReady, showFft]); + }, [virtualImageBytes, shapeX, shapeY, viColormap, fftZoom, fftPanX, fftPanY, gpuReady, showFft]); // Render FFT overlay with high-pass filter circle React.useEffect(() => { @@ -1026,14 +1450,26 @@ function Show4DSTEM() { } }, [dpZoom, dpPanX, dpPanY, kPixelSize, kCalibrated, detX, detY, roiMode, roiRadius, roiRadiusInner, roiWidth, roiHeight, localKx, localKy, isDraggingDP, isDraggingResize, isDraggingResizeInner, isHoveringResize, isHoveringResizeInner]); - // VI scale bar + crosshair (high-DPI) + // VI scale bar + crosshair + ROI (high-DPI) React.useEffect(() => { if (!viUiRef.current) return; // Draw scale bar first (clears canvas) drawScaleBarHiDPI(viUiRef.current, DPR, viZoom, pixelSize || 1, "Å", shapeY, shapeX); - // Then draw crosshair on top - drawViCrosshairHiDPI(viUiRef.current, DPR, localPosX, localPosY, viZoom, viPanX, viPanY, shapeY, shapeX, isDraggingVI); - }, [viZoom, viPanX, viPanY, pixelSize, shapeX, shapeY, localPosX, localPosY, isDraggingVI]); + // Draw crosshair only when ROI is off (ROI replaces the crosshair) + if (!viRoiMode || viRoiMode === "off") { + drawViPositionMarker(viUiRef.current, DPR, localPosX, localPosY, viZoom, viPanX, viPanY, shapeY, shapeX, isDraggingVI); + } else { + // Draw VI ROI instead of crosshair + drawViRoiOverlayHiDPI( + viUiRef.current, DPR, viRoiMode, + localViRoiCenterX, localViRoiCenterY, viRoiRadius || 5, viRoiWidth || 10, viRoiHeight || 10, + viZoom, viPanX, viPanY, shapeY, shapeX, + isDraggingViRoi, isDraggingViRoiResize, isHoveringViRoiResize + ); + } + }, [viZoom, viPanX, viPanY, pixelSize, shapeX, shapeY, localPosX, localPosY, isDraggingVI, + viRoiMode, localViRoiCenterX, localViRoiCenterY, viRoiRadius, viRoiWidth, viRoiHeight, + isDraggingViRoi, isDraggingViRoiResize, isHoveringViRoiResize]); // Generic zoom handler const createZoomHandler = ( @@ -1088,6 +1524,33 @@ function Show4DSTEM() { return dist < RESIZE_HIT_AREA_PX / dpZoom; }; + // Helper: check if point is near VI ROI resize handle (same logic as DP) + // Hit area is capped to avoid overlap with center for small ROIs + const isNearViRoiResizeHandle = (imgX: number, imgY: number): boolean => { + if (!viRoiMode || viRoiMode === "off") return false; + if (viRoiMode === "rect") { + const halfH = (viRoiHeight || 10) / 2; + const halfW = (viRoiWidth || 10) / 2; + const handleX = localViRoiCenterX + halfH; + const handleY = localViRoiCenterY + halfW; + const dist = Math.sqrt((imgX - handleX) ** 2 + (imgY - handleY) ** 2); + const cornerDist = Math.sqrt(halfW ** 2 + halfH ** 2); + const hitArea = Math.min(RESIZE_HIT_AREA_PX / viZoom, cornerDist * 0.5); + return dist < hitArea; + } + if (viRoiMode === "circle" || viRoiMode === "square") { + const radius = viRoiRadius || 5; + const offset = viRoiMode === "square" ? radius : radius * CIRCLE_HANDLE_ANGLE; + const handleX = localViRoiCenterX + offset; + const handleY = localViRoiCenterY + offset; + const dist = Math.sqrt((imgX - handleX) ** 2 + (imgY - handleY) ** 2); + // Cap hit area to 50% of radius so center remains draggable + const hitArea = Math.min(RESIZE_HIT_AREA_PX / viZoom, radius * 0.5); + return dist < hitArea; + } + return false; + }; + // Mouse handlers const handleDpMouseDown = (e: React.MouseEvent) => { const canvas = dpOverlayRef.current; @@ -1175,6 +1638,25 @@ function Show4DSTEM() { const screenY = (e.clientY - rect.top) * (canvas.height / rect.height); const imgX = (screenY - viPanY) / viZoom; const imgY = (screenX - viPanX) / viZoom; + + // Check if VI ROI mode is active - same logic as DP + if (viRoiMode && viRoiMode !== "off") { + // Check if clicking on resize handle + if (isNearViRoiResizeHandle(imgX, imgY)) { + setIsDraggingViRoiResize(true); + return; + } + + // Otherwise, move ROI center to click position (same as DP) + setIsDraggingViRoi(true); + setLocalViRoiCenterX(imgX); + setLocalViRoiCenterY(imgY); + setViRoiCenterX(Math.round(Math.max(0, Math.min(shapeX - 1, imgX)))); + setViRoiCenterY(Math.round(Math.max(0, Math.min(shapeY - 1, imgY)))); + return; + } + + // Regular position selection (when ROI is off) setIsDraggingVI(true); setLocalPosX(imgX); setLocalPosY(imgY); setPosX(Math.round(Math.max(0, Math.min(shapeX - 1, imgX)))); @@ -1182,7 +1664,6 @@ function Show4DSTEM() { }; const handleViMouseMove = (e: React.MouseEvent) => { - if (!isDraggingVI) return; const canvas = virtualOverlayRef.current; if (!canvas) return; const rect = canvas.getBoundingClientRect(); @@ -1190,12 +1671,58 @@ function Show4DSTEM() { const screenY = (e.clientY - rect.top) * (canvas.height / rect.height); const imgX = (screenY - viPanY) / viZoom; const imgY = (screenX - viPanX) / viZoom; + + // Handle VI ROI resize dragging (same pattern as DP) + if (isDraggingViRoiResize) { + const dx = Math.abs(imgX - localViRoiCenterX); + const dy = Math.abs(imgY - localViRoiCenterY); + if (viRoiMode === "rect") { + setViRoiWidth(Math.max(2, Math.round(dy * 2))); + setViRoiHeight(Math.max(2, Math.round(dx * 2))); + } else if (viRoiMode === "square") { + const newHalfSize = Math.max(dx, dy); + setViRoiRadius(Math.max(1, Math.round(newHalfSize))); + } else { + // circle + const newRadius = Math.sqrt(dx ** 2 + dy ** 2); + setViRoiRadius(Math.max(1, Math.round(newRadius))); + } + return; + } + + // Check hover state for resize handles (same as DP) + if (!isDraggingViRoi) { + setIsHoveringViRoiResize(isNearViRoiResizeHandle(imgX, imgY)); + if (viRoiMode && viRoiMode !== "off") return; // Don't update position when ROI active + } + + // Handle VI ROI center dragging (same as DP) + if (isDraggingViRoi) { + setLocalViRoiCenterX(imgX); + setLocalViRoiCenterY(imgY); + setViRoiCenterX(Math.round(Math.max(0, Math.min(shapeX - 1, imgX)))); + setViRoiCenterY(Math.round(Math.max(0, Math.min(shapeY - 1, imgY)))); + return; + } + + // Handle regular position dragging (when ROI is off) + if (!isDraggingVI) return; setLocalPosX(imgX); setLocalPosY(imgY); setPosX(Math.round(Math.max(0, Math.min(shapeX - 1, imgX)))); setPosY(Math.round(Math.max(0, Math.min(shapeY - 1, imgY)))); }; - const handleViMouseUp = () => setIsDraggingVI(false); + const handleViMouseUp = () => { + setIsDraggingVI(false); + setIsDraggingViRoi(false); + setIsDraggingViRoiResize(false); + }; + const handleViMouseLeave = () => { + setIsDraggingVI(false); + setIsDraggingViRoi(false); + setIsDraggingViRoiResize(false); + setIsHoveringViRoiResize(false); + }; const handleViDoubleClick = () => { setViZoom(1); setViPanX(0); setViPanY(0); }; const handleFftDoubleClick = () => { setFftZoom(1); setFftPanX(0); setFftPanY(0); }; @@ -1224,53 +1751,92 @@ function Show4DSTEM() { // ───────────────────────────────────────────────────────────────────────── // Render // ───────────────────────────────────────────────────────────────────────── + + // Export DP handler + const handleExportDP = async () => { + const timestamp = new Date().toISOString().replace(/[:.]/g, '-').slice(0, 19); + const zip = new JSZip(); + const metadata = { + exported_at: new Date().toISOString(), + type: "diffraction_pattern", + scan_position: { x: posX, y: posY }, + scan_shape: { x: shapeX, y: shapeY }, + detector_shape: { x: detX, y: detY }, + roi: { mode: roiMode, center_x: roiCenterX, center_y: roiCenterY, radius_outer: roiRadius, radius_inner: roiRadiusInner }, + display: { colormap: dpColormap, vmin_pct: dpVminPct, vmax_pct: dpVmaxPct, scale_mode: dpScaleMode }, + calibration: { bf_radius: bfRadius, center_x: centerX, center_y: centerY, k_pixel_size: kPixelSize, k_calibrated: kCalibrated }, + }; + zip.file("metadata.json", JSON.stringify(metadata, null, 2)); + const canvasToBlob = (canvas: HTMLCanvasElement): Promise => new Promise((resolve) => canvas.toBlob((blob) => resolve(blob!), 'image/png')); + if (dpCanvasRef.current) zip.file("diffraction_pattern.png", await canvasToBlob(dpCanvasRef.current)); + const zipBlob = await zip.generateAsync({ type: "blob" }); + const link = document.createElement('a'); + link.download = `dp_export_${timestamp}.zip`; + link.href = URL.createObjectURL(zipBlob); + link.click(); + URL.revokeObjectURL(link.href); + }; + + // Export VI handler + const handleExportVI = async () => { + const timestamp = new Date().toISOString().replace(/[:.]/g, '-').slice(0, 19); + const zip = new JSZip(); + const metadata = { + exported_at: new Date().toISOString(), + scan_position: { x: posX, y: posY }, + scan_shape: { x: shapeX, y: shapeY }, + detector_shape: { x: detX, y: detY }, + roi: { mode: roiMode, center_x: roiCenterX, center_y: roiCenterY, radius_outer: roiRadius, radius_inner: roiRadiusInner }, + display: { dp_colormap: dpColormap, vi_colormap: viColormap, dp_scale_mode: dpScaleMode, vi_scale_mode: viScaleMode }, + calibration: { bf_radius: bfRadius, center_x: centerX, center_y: centerY, pixel_size: pixelSize, k_pixel_size: kPixelSize }, + }; + zip.file("metadata.json", JSON.stringify(metadata, null, 2)); + const canvasToBlob = (canvas: HTMLCanvasElement): Promise => new Promise((resolve) => canvas.toBlob((blob) => resolve(blob!), 'image/png')); + if (virtualCanvasRef.current) zip.file("virtual_image.png", await canvasToBlob(virtualCanvasRef.current)); + if (dpCanvasRef.current) zip.file("diffraction_pattern.png", await canvasToBlob(dpCanvasRef.current)); + if (fftCanvasRef.current) zip.file("fft.png", await canvasToBlob(fftCanvasRef.current)); + const zipBlob = await zip.generateAsync({ type: "blob" }); + const link = document.createElement('a'); + link.download = `4dstem_export_${timestamp}.zip`; + link.href = URL.createObjectURL(zipBlob); + link.click(); + URL.revokeObjectURL(link.href); + }; + + // Common styles for panel control groups (fills parent width = canvas width) + const panelControlStyle = { + display: "flex", + alignItems: "center", + gap: `${SPACING.SM}px`, + width: "100%", + boxSizing: "border-box", + }; + return ( - - {/* Wrapper to ensure header and content have same width */} - - {/* Header */} - - - 4D-STEM Explorer - - - - {shapeX}×{shapeY} scan | {detX}×{detY} det - - { - // Reset position to center - setPosX(Math.floor(shapeX / 2)); - setPosY(Math.floor(shapeY / 2)); - // Reset ROI to detector center, point mode - setRoiCenterX(centerX); - setRoiCenterY(centerY); - setRoiRadius(bfRadius * 0.5); - setRoiMode("point"); - // Reset zoom/pan - setDpZoom(1); setDpPanX(0); setDpPanY(0); - setViZoom(1); setViPanX(0); setViPanY(0); - setFftZoom(1); setFftPanX(0); setFftPanY(0); - // Reset colormap and bandpass - setColormap("inferno"); - setBandpass([0, 0]); - }} - sx={{ ...controlPanel.button }} - > - Reset + + {/* HEADER */} + + 4D-STEM Explorer + + + {/* MAIN CONTENT: Two columns */} + + {/* LEFT COLUMN: DP Panel */} + + {/* DP Header */} + + + DP at ({Math.round(localPosX)}, {Math.round(localPosY)}) + k: ({Math.round(localKx)}, {Math.round(localKy)}) + + + + - - - {/* LEFT: DP */} - - - DP at ({Math.round(localPosX)}, {Math.round(localPosY)}) - k: ({Math.round(localKx)}, {Math.round(localKy)}) - - + {/* DP Canvas */} + - {/* High-DPI UI overlay for crisp scale bar */} - + - - {/* RIGHT: Virtual Image + FFT */} - - - - - Virtual Image - - { - const timestamp = new Date().toISOString().replace(/[:.]/g, '-').slice(0, 19); - const zip = new JSZip(); - - // Add metadata JSON - const metadata = { - exported_at: new Date().toISOString(), - scan_position: { x: posX, y: posY }, - scan_shape: { x: shapeX, y: shapeY }, - detector_shape: { x: detX, y: detY }, - roi: { - mode: roiMode, - center_x: roiCenterX, - center_y: roiCenterY, - radius_outer: roiRadius, - radius_inner: roiRadiusInner, - }, - display: { - colormap: colormap, - log_scale: logScale, - }, - calibration: { - bf_radius: bfRadius, - center_x: centerX, - center_y: centerY, - pixel_size: pixelSize, - k_pixel_size: kPixelSize, - }, - }; - zip.file("metadata.json", JSON.stringify(metadata, null, 2)); - - // Helper to convert canvas to blob - const canvasToBlob = (canvas: HTMLCanvasElement): Promise => { - return new Promise((resolve) => { - canvas.toBlob((blob) => resolve(blob!), 'image/png'); - }); - }; - - // Add images - const viCanvas = virtualCanvasRef.current; - if (viCanvas) { - const blob = await canvasToBlob(viCanvas); - zip.file("virtual_image.png", blob); - } - const dpCanvas = dpCanvasRef.current; - if (dpCanvas) { - const blob = await canvasToBlob(dpCanvas); - zip.file("diffraction_pattern.png", blob); - } - const fftCanvas = fftCanvasRef.current; - if (fftCanvas) { - const blob = await canvasToBlob(fftCanvas); - zip.file("fft.png", blob); - } - - // Generate and download ZIP - const zipBlob = await zip.generateAsync({ type: "blob" }); - const link = document.createElement('a'); - link.download = `4dstem_export_${timestamp}.zip`; - link.href = URL.createObjectURL(zipBlob); - link.click(); - URL.revokeObjectURL(link.href); - }} - sx={{ ...controlPanel.button }} - > - Export - - - - - - {/* High-DPI UI overlay for crisp scale bar */} - + + {/* DP Stats Bar */} + {dpStats && dpStats.length === 4 && ( + + Mean {formatStat(dpStats[0])} + Min {formatStat(dpStats[1])} + Max {formatStat(dpStats[2])} + Std {formatStat(dpStats[3])} + + )} + + {/* DP Controls - two rows with histogram on right */} + + {/* Left: two rows of controls */} + + {/* Row 1: Detector + slider */} + + Detector: + + {(roiMode === "circle" || roiMode === "square" || roiMode === "annular") && ( + <> + { + if (roiMode === "annular") { + const [inner, outer] = v as number[]; + setRoiRadiusInner(Math.min(inner, outer - 1)); + setRoiRadius(Math.max(outer, inner + 1)); + } else { + setRoiRadius(v as number); + } + }} + min={1} + max={Math.min(detX, detY) / 2} + size="small" + sx={{ + width: roiMode === "annular" ? 100 : 70, + mx: 1, + "& .MuiSlider-thumb": { width: 14, height: 14 } + }} + /> + + {roiMode === "annular" ? `${Math.round(roiRadiusInner)}-${Math.round(roiRadius)}px` : `${Math.round(roiRadius)}px`} + + + )} + + {/* Row 2: Presets + Color + Scale */} + + { setRoiMode("circle"); setRoiRadius(bfRadius || 10); setRoiCenterX(centerX); setRoiCenterY(centerY); }} sx={{ color: "#4f4", fontSize: 11, fontWeight: "bold", cursor: "pointer", "&:hover": { textDecoration: "underline" } }}>BF + { setRoiMode("annular"); setRoiRadiusInner((bfRadius || 10) * 0.5); setRoiRadius(bfRadius || 10); setRoiCenterX(centerX); setRoiCenterY(centerY); }} sx={{ color: "#4af", fontSize: 11, fontWeight: "bold", cursor: "pointer", "&:hover": { textDecoration: "underline" } }}>ABF + { setRoiMode("annular"); setRoiRadiusInner(bfRadius || 10); setRoiRadius(Math.min((bfRadius || 10) * 3, Math.min(detX, detY) / 2 - 2)); setRoiCenterX(centerX); setRoiCenterY(centerY); }} sx={{ color: "#fa4", fontSize: 11, fontWeight: "bold", cursor: "pointer", "&:hover": { textDecoration: "underline" } }}>ADF + Color: + + Scale: + + + + {/* Right: Histogram spanning both rows */} + + { setDpVminPct(min); setDpVmaxPct(max); }} width={100} height={50} /> + - - - - FFT + {/* RIGHT COLUMN: VI Panel + FFT (when shown) */} + + {/* VI Header */} + + Virtual Image + + + {shapeX}×{shapeY} | {detX}×{detY} - setShowFft(e.target.checked)} - size="small" - sx={switchStyles.medium} - /> + FFT: + setShowFft(e.target.checked)} size="small" sx={switchStyles.small} /> + + - - - - + + + {/* VI Canvas */} + + + + - - - - - {/* Controls - Organized in 3 rows */} - - - {/* Row 1: Presets + Detector */} - - {/* Detector Presets - only show if bf_radius is calibrated */} - {bfRadius > 0 && ( - - Presets: - - - - - + + {/* VI Stats Bar */} + {viStats && viStats.length === 4 && ( + + Mean {formatStat(viStats[0])} + Min {formatStat(viStats[1])} + Max {formatStat(viStats[2])} + Std {formatStat(viStats[3])} + )} - {/* Virtual Detector Mode */} - - Detector: - - {(roiMode === "circle" || roiMode === "square") && ( - <> - {roiMode === "circle" ? "r:" : "½:"} - setRoiRadius(v as number)} - min={1} - max={Math.min(detX, detY) / 2} - size="small" - sx={{ width: 80 }} - /> - - {Math.round(roiRadius || 10)}px - - - )} - {roiMode === "rect" && ( - <> - W: - setRoiWidth(v as number)} - min={2} - max={detY} - size="small" - sx={{ width: 60 }} - /> - H: - setRoiHeight(v as number)} - min={2} - max={detX} - size="small" - sx={{ width: 60 }} - /> - - {Math.round(roiWidth || 20)}×{Math.round(roiHeight || 10)} - - - )} - {roiMode === "annular" && ( - <> - { - const [inner, outer] = v as number[]; - setRoiRadiusInner(inner); - setRoiRadius(outer); - }} - min={0} - max={Math.min(detX, detY) / 2} - size="small" - sx={{ width: 120 }} - valueLabelDisplay="auto" - /> - - {Math.round(roiRadiusInner || 5)}-{Math.round(roiRadius || 10)}px - - - )} - + {/* VI Controls - Two rows with histogram on right */} + + {/* Left: Two rows of controls */} + + {/* Row 1: ROI selector */} + + ROI: + + {viRoiMode && viRoiMode !== "off" && ( + <> + {(viRoiMode === "circle" || viRoiMode === "square") && ( + <> + setViRoiRadius(v as number)} + min={1} + max={Math.min(shapeX, shapeY) / 2} + size="small" + sx={{ width: 80, mx: 1 }} + /> + + {Math.round(viRoiRadius || 5)}px + + + )} + {summedDpCount > 0 && ( + + {summedDpCount} pos + + )} + + )} + + {/* Row 2: Color + Scale */} + + Color: + + Scale: + + + + {/* Right: Histogram spanning both rows */} + + { setViVminPct(min); setViVmaxPct(max); }} width={100} height={50} /> + + - {/* Path Animation Controls - only show if path is defined */} - {pathLength > 0 && ( - - Path: - { setPathPlaying(false); setPathIndex(0); }} - sx={{ color: "#888", fontSize: 14, cursor: "pointer", "&:hover": { color: "#fff" }, px: 0.5 }} - title="Stop" - >⏹ - setPathPlaying(!pathPlaying)} - sx={{ color: pathPlaying ? "#0f0" : "#888", fontSize: 14, cursor: "pointer", "&:hover": { color: "#fff" }, px: 0.5 }} - title={pathPlaying ? "Pause" : "Play"} - >{pathPlaying ? "⏸" : "▶"} - - {pathIndex + 1}/{pathLength} - - { setPathPlaying(false); setPathIndex(v as number); }} - min={0} - max={Math.max(0, pathLength - 1)} - size="small" - sx={{ width: 100 }} - /> - + {/* FFT Panel (conditionally shown) */} + {showFft && ( + + + FFT + + + + + + + )} - - - {/* Row 2: Colormap + Log + Contrast */} - - - Colormap: - - + + - - Log: - setLogScale(e.target.checked)} - size="small" - sx={switchStyles.medium} - /> - + {/* BOTTOM CONTROLS - Path only (FFT toggle moved to VI panel) */} + {pathLength > 0 && ( + + + Path: + { setPathPlaying(false); setPathIndex(0); }} sx={{ color: "#888", fontSize: 14, cursor: "pointer", "&:hover": { color: "#fff" }, px: 0.5 }} title="Stop">⏹ + setPathPlaying(!pathPlaying)} sx={{ color: pathPlaying ? "#0f0" : "#888", fontSize: 14, cursor: "pointer", "&:hover": { color: "#fff" }, px: 0.5 }} title={pathPlaying ? "Pause" : "Play"}>{pathPlaying ? "⏸" : "▶"} + {pathIndex + 1}/{pathLength} + { setPathPlaying(false); setPathIndex(v as number); }} min={0} max={Math.max(0, pathLength - 1)} size="small" sx={{ width: 100 }} /> + - + )} ); } diff --git a/widget/src/quantem/widget/show4dstem.py b/widget/src/quantem/widget/show4dstem.py index 8c1acb2b..579986e7 100644 --- a/widget/src/quantem/widget/show4dstem.py +++ b/widget/src/quantem/widget/show4dstem.py @@ -85,7 +85,11 @@ class Show4DSTEM(anywidget.AnyWidget): # Pre-normalized uint8 frame as bytes (no base64!) frame_bytes = traitlets.Bytes(b"").tag(sync=True) - # Log scale toggle + # Scale mode for DP: "linear", "log", or "power" + dp_scale_mode = traitlets.Unicode("linear").tag(sync=True) + dp_power_exp = traitlets.Float(0.5).tag(sync=True) # Power exponent (e.g., 0.5 = sqrt) + + # Legacy log_scale (kept for backward compatibility, syncs with dp_scale_mode) log_scale = traitlets.Bool(False).tag(sync=True) # ========================================================================= @@ -117,6 +121,18 @@ class Show4DSTEM(anywidget.AnyWidget): # ========================================================================= virtual_image_bytes = traitlets.Bytes(b"").tag(sync=True) + # ========================================================================= + # VI ROI (real-space region selection for summed DP) + # ========================================================================= + vi_roi_mode = traitlets.Unicode("off").tag(sync=True) # "off", "circle", "rect" + vi_roi_center_x = traitlets.Float(0.0).tag(sync=True) + vi_roi_center_y = traitlets.Float(0.0).tag(sync=True) + vi_roi_radius = traitlets.Float(5.0).tag(sync=True) + vi_roi_width = traitlets.Float(10.0).tag(sync=True) + vi_roi_height = traitlets.Float(10.0).tag(sync=True) + summed_dp_bytes = traitlets.Bytes(b"").tag(sync=True) # Summed DP from VI ROI + summed_dp_count = traitlets.Int(0).tag(sync=True) # Number of positions summed + # ========================================================================= # Scale Bar # ========================================================================= @@ -133,6 +149,18 @@ class Show4DSTEM(anywidget.AnyWidget): path_interval_ms = traitlets.Int(100).tag(sync=True) # ms between frames path_loop = traitlets.Bool(True).tag(sync=True) # loop when reaching end + # ========================================================================= + # Auto-detection trigger (frontend sets to True, backend resets to False) + # ========================================================================= + auto_detect_trigger = traitlets.Bool(False).tag(sync=True) + + # ========================================================================= + # Statistics for display (mean, min, max, std) + # ========================================================================= + dp_stats = traitlets.List(traitlets.Float(), default_value=[0.0, 0.0, 0.0, 0.0]).tag(sync=True) + vi_stats = traitlets.List(traitlets.Float(), default_value=[0.0, 0.0, 0.0, 0.0]).tag(sync=True) + mask_dc = traitlets.Bool(True).tag(sync=True) # Mask center pixel for DP stats + def __init__( self, data: "Dataset4dstem | np.ndarray", @@ -202,26 +230,44 @@ def __init__( self.pos_y = self.shape_y // 2 # Precompute global range for consistent scaling self._compute_global_range() - # Setup center and BF/ADF radii based on detector size + # Setup center and BF radius + # If user provides explicit values, use them + # Otherwise, auto-detect from the data for accurate presets det_size = min(self.det_x, self.det_y) - if center is not None: + if center is not None and bf_radius is not None: + # User provided both - use explicit values + self.center_x = float(center[0]) + self.center_y = float(center[1]) + self.bf_radius = float(bf_radius) + elif center is not None: + # User provided center only - use it with default bf_radius self.center_x = float(center[0]) self.center_y = float(center[1]) + self.bf_radius = det_size * DEFAULT_BF_RATIO + elif bf_radius is not None: + # User provided bf_radius only - use detector center + self.center_x = float(self.det_y / 2) + self.center_y = float(self.det_x / 2) + self.bf_radius = float(bf_radius) else: + # Neither provided - auto-detect from data + # Set defaults first (will be overwritten by auto-detect) self.center_x = float(self.det_y / 2) self.center_y = float(self.det_x / 2) - self.bf_radius = float(bf_radius) if bf_radius is not None else det_size * DEFAULT_BF_RATIO + self.bf_radius = det_size * DEFAULT_BF_RATIO + # Auto-detect center and bf_radius from the data + self._auto_detect_center_silent() - # Pre-compute and cache common virtual images (BF, ABF, LAADF, HAADF) + # Pre-compute and cache common virtual images (BF, ABF, ADF) + # Each cache stores (bytes, stats) tuple self._cached_bf_virtual = None self._cached_abf_virtual = None - self._cached_laadf_virtual = None - self._cached_haadf_virtual = None + self._cached_adf_virtual = None if precompute_virtual_images: self._precompute_common_virtual_images() # Update frame when position or settings change - self.observe(self._update_frame, names=["pos_x", "pos_y", "log_scale"]) + self.observe(self._update_frame, names=["pos_x", "pos_y", "dp_scale_mode", "dp_power_exp"]) self.observe(self._on_roi_change, names=[ "roi_center_x", "roi_center_y", "roi_radius", "roi_radius_inner", "roi_active", "roi_mode", "roi_width", "roi_height" @@ -240,6 +286,23 @@ def __init__( # Path animation: observe index changes from frontend self.observe(self._on_path_index_change, names=["path_index"]) + # Auto-detect trigger: observe changes from frontend + self.observe(self._on_auto_detect_trigger, names=["auto_detect_trigger"]) + + # VI ROI: observe changes for summed DP computation + # Initialize VI ROI center to scan center with reasonable default sizes + self.vi_roi_center_x = float(self.shape_x / 2) + self.vi_roi_center_y = float(self.shape_y / 2) + # Set initial ROI size to ~15% of minimum scan dimension + default_roi_size = max(3, min(self.shape_x, self.shape_y) * 0.15) + self.vi_roi_radius = float(default_roi_size) + self.vi_roi_width = float(default_roi_size * 2) + self.vi_roi_height = float(default_roi_size) + self.observe(self._on_vi_roi_change, names=[ + "vi_roi_mode", "vi_roi_center_x", "vi_roi_center_y", + "vi_roi_radius", "vi_roi_width", "vi_roi_height" + ]) + def __repr__(self) -> str: k_unit = "mrad" if self.k_calibrated else "px" return ( @@ -348,6 +411,13 @@ def _on_path_index_change(self, change): self.pos_x = max(0, min(self.shape_x - 1, x)) self.pos_y = max(0, min(self.shape_y - 1, y)) + def _on_auto_detect_trigger(self, change): + """Called when auto_detect_trigger is set to True from frontend.""" + if change["new"]: + self.auto_detect_center() + # Reset trigger to allow re-triggering + self.auto_detect_trigger = False + # ========================================================================= # Path Animation Patterns # ========================================================================= @@ -526,6 +596,96 @@ def roi_rect( self.roi_height = float(height) return self + def auto_detect_center(self) -> "Show4DSTEM": + """ + Automatically detect BF disk center and radius using centroid. + + This method analyzes the summed diffraction pattern to find the + bright field disk center and estimate its radius. The detected + values are applied to the widget's calibration (center_x, center_y, + bf_radius). + + Returns + ------- + Show4DSTEM + Self for method chaining. + + Examples + -------- + >>> widget = Show4DSTEM(data) + >>> widget.auto_detect_center() # Auto-detect and apply + """ + # Sum all diffraction patterns to get average + if self._data.ndim == 4: + summed_dp = self._data.sum(axis=(0, 1)).astype(np.float64) + else: + summed_dp = self._data.sum(axis=0).astype(np.float64) + + # Threshold at mean + std to isolate BF disk + threshold = summed_dp.mean() + summed_dp.std() + mask = summed_dp > threshold + + # Avoid division by zero + total = mask.sum() + if total == 0: + return self + + # Calculate centroid using meshgrid + y_coords, x_coords = np.meshgrid( + np.arange(self.det_x), np.arange(self.det_y), indexing='ij' + ) + cx = float((x_coords * mask).sum() / total) + cy = float((y_coords * mask).sum() / total) + + # Estimate radius from mask area (A = pi*r^2) + radius = float(np.sqrt(total / np.pi)) + + # Apply detected values + self.center_x = cx + self.center_y = cy + self.bf_radius = radius + + # Also update ROI to center + self.roi_center_x = cx + self.roi_center_y = cy + + # Recompute cached virtual images with new calibration + self._precompute_common_virtual_images() + + return self + + def _auto_detect_center_silent(self): + """Auto-detect center without updating ROI (used during __init__).""" + # Sum all diffraction patterns to get average + if self._data.ndim == 4: + summed_dp = self._data.sum(axis=(0, 1)).astype(np.float64) + else: + summed_dp = self._data.sum(axis=0).astype(np.float64) + + # Threshold at mean + std to isolate BF disk + threshold = summed_dp.mean() + summed_dp.std() + mask = summed_dp > threshold + + # Avoid division by zero + total = mask.sum() + if total == 0: + return + + # Calculate centroid using meshgrid + y_coords, x_coords = np.meshgrid( + np.arange(self.det_x), np.arange(self.det_y), indexing='ij' + ) + cx = float((x_coords * mask).sum() / total) + cy = float((y_coords * mask).sum() / total) + + # Estimate radius from mask area (A = pi*r^2) + radius = float(np.sqrt(total / np.pi)) + + # Apply detected values (don't update ROI - that happens later in __init__) + self.center_x = cx + self.center_y = cy + self.bf_radius = radius + def _compute_global_range(self): """Compute global min/max from sampled frames for consistent scaling.""" @@ -564,19 +724,40 @@ def _get_frame(self, x: int, y: int): def _update_frame(self, change=None): """Send pre-normalized uint8 frame to frontend.""" frame = self._get_frame(self.pos_x, self.pos_y) - - # Determine value range - if self.log_scale: - vmin, vmax = self._log_min, self._log_max + frame = frame.astype(np.float32) + + # Compute stats from raw frame (optionally mask DC component) + if self.mask_dc: + # Mask center 3x3 region for stats + cx, cy = self.det_x // 2, self.det_y // 2 + stats_frame = frame.copy() + stats_frame[max(0, cx-1):cx+2, max(0, cy-1):cy+2] = np.nan + self.dp_stats = [ + float(np.nanmean(stats_frame)), + float(np.nanmin(stats_frame)), + float(np.nanmax(stats_frame)), + float(np.nanstd(stats_frame)), + ] else: + self.dp_stats = [ + float(frame.mean()), + float(frame.min()), + float(frame.max()), + float(frame.std()), + ] + + # Apply scale transformation + if self.dp_scale_mode == "log": + frame = np.log1p(frame) + vmin, vmax = self._log_min, self._log_max + elif self.dp_scale_mode == "power": + # Power scaling (e.g., sqrt when exp=0.5) + frame = np.power(np.maximum(frame, 0), self.dp_power_exp) + vmin = np.power(max(self._global_min, 0), self.dp_power_exp) + vmax = np.power(max(self._global_max, 0), self.dp_power_exp) + else: # linear vmin, vmax = self._global_min, self._global_max - # Apply log scale if enabled - if self.log_scale: - frame = np.log1p(frame.astype(np.float32)) - else: - frame = frame.astype(np.float32) - # Normalize to 0-255 if vmax > vmin: normalized = np.clip((frame - vmin) / (vmax - vmin) * 255, 0, 255) @@ -593,6 +774,61 @@ def _on_roi_change(self, change=None): return self._compute_virtual_image_from_roi() + def _on_vi_roi_change(self, change=None): + """Compute summed DP when VI ROI changes.""" + if self.vi_roi_mode == "off": + self.summed_dp_bytes = b"" + self.summed_dp_count = 0 + return + self._compute_summed_dp_from_vi_roi() + + def _compute_summed_dp_from_vi_roi(self): + """Sum diffraction patterns from positions inside VI ROI.""" + # Create mask in scan space + # y (rows) corresponds to vi_roi_center_x, x (cols) corresponds to vi_roi_center_y + rows, cols = np.ogrid[:self.shape_x, :self.shape_y] + + if self.vi_roi_mode == "circle": + mask = (rows - self.vi_roi_center_x) ** 2 + (cols - self.vi_roi_center_y) ** 2 <= self.vi_roi_radius ** 2 + elif self.vi_roi_mode == "square": + # Square uses vi_roi_radius as half-size + half_size = self.vi_roi_radius + mask = (np.abs(rows - self.vi_roi_center_x) <= half_size) & (np.abs(cols - self.vi_roi_center_y) <= half_size) + elif self.vi_roi_mode == "rect": + half_w = self.vi_roi_width / 2 + half_h = self.vi_roi_height / 2 + mask = (np.abs(rows - self.vi_roi_center_x) <= half_h) & (np.abs(cols - self.vi_roi_center_y) <= half_w) + else: + return + + # Get positions inside mask + positions = np.argwhere(mask) + if len(positions) == 0: + self.summed_dp_bytes = b"" + self.summed_dp_count = 0 + return + + # Average DPs from all positions (average is more useful than sum) + # positions from argwhere are (row, col) = (x_idx, y_idx) in our naming + summed_dp = np.zeros((self.det_x, self.det_y), dtype=np.float64) + for row_idx, col_idx in positions: + summed_dp += self._get_frame(row_idx, col_idx).astype(np.float64) + + self.summed_dp_count = len(positions) + + # Convert to average + avg_dp = summed_dp / len(positions) + + # Normalize to 0-255 for display + vmin, vmax = avg_dp.min(), avg_dp.max() + if vmax > vmin: + normalized = np.clip((avg_dp - vmin) / (vmax - vmin) * 255, 0, 255) + normalized = normalized.astype(np.uint8) + else: + normalized = np.zeros(avg_dp.shape, dtype=np.uint8) + + self.summed_dp_bytes = normalized.tobytes() + def _create_circular_mask(self, cx: float, cy: float, radius: float): """Create circular mask (boolean).""" y, x = np.ogrid[:self.det_x, :self.det_y] @@ -621,47 +857,50 @@ def _create_rect_mask(self, cx: float, cy: float, half_width: float, half_height return mask def _precompute_common_virtual_images(self): - """Pre-compute BF/ABF/LAADF/HAADF virtual images for instant preset switching.""" + """Pre-compute BF/ABF/ADF virtual images for instant preset switching.""" cx, cy, bf = self.center_x, self.center_y, self.bf_radius - self._cached_bf_virtual = self._normalize_to_bytes( - self._fast_masked_sum(self._create_circular_mask(cx, cy, bf))) - self._cached_abf_virtual = self._normalize_to_bytes( - self._fast_masked_sum(self._create_annular_mask(cx, cy, bf * 0.5, bf))) - self._cached_laadf_virtual = self._normalize_to_bytes( - self._fast_masked_sum(self._create_annular_mask(cx, cy, bf, bf * 2.0))) - self._cached_haadf_virtual = self._normalize_to_bytes( - self._fast_masked_sum(self._create_annular_mask(cx, cy, bf * 2.0, bf * 4.0))) - - def _get_cached_preset(self) -> bytes | None: - """Check if current ROI matches a cached preset and return it.""" + # Cache both bytes and stats for each preset + bf_arr = self._fast_masked_sum(self._create_circular_mask(cx, cy, bf)) + abf_arr = self._fast_masked_sum(self._create_annular_mask(cx, cy, bf * 0.5, bf)) + adf_arr = self._fast_masked_sum(self._create_annular_mask(cx, cy, bf, bf * 4.0)) + + self._cached_bf_virtual = ( + self._normalize_to_bytes(bf_arr, update_vi_stats=False), + [float(bf_arr.mean()), float(bf_arr.min()), float(bf_arr.max()), float(bf_arr.std())] + ) + self._cached_abf_virtual = ( + self._normalize_to_bytes(abf_arr, update_vi_stats=False), + [float(abf_arr.mean()), float(abf_arr.min()), float(abf_arr.max()), float(abf_arr.std())] + ) + self._cached_adf_virtual = ( + self._normalize_to_bytes(adf_arr, update_vi_stats=False), + [float(adf_arr.mean()), float(adf_arr.min()), float(adf_arr.max()), float(adf_arr.std())] + ) + + def _get_cached_preset(self) -> tuple[bytes, list[float]] | None: + """Check if current ROI matches a cached preset and return (bytes, stats) tuple.""" # Must be centered on detector center if abs(self.roi_center_x - self.center_x) >= 1 or abs(self.roi_center_y - self.center_y) >= 1: return None - + bf = self.bf_radius - + # BF: circle at bf_radius if (self.roi_mode == "circle" and abs(self.roi_radius - bf) < 1): return self._cached_bf_virtual - + # ABF: annular at 0.5*bf to bf - if (self.roi_mode == "annular" and - abs(self.roi_radius_inner - bf * 0.5) < 1 and + if (self.roi_mode == "annular" and + abs(self.roi_radius_inner - bf * 0.5) < 1 and abs(self.roi_radius - bf) < 1): return self._cached_abf_virtual - - # LAADF: annular at bf to 2*bf - if (self.roi_mode == "annular" and - abs(self.roi_radius_inner - bf) < 1 and - abs(self.roi_radius - bf * 2.0) < 1): - return self._cached_laadf_virtual - - # HAADF: annular at 2*bf to 4*bf - if (self.roi_mode == "annular" and - abs(self.roi_radius_inner - bf * 2.0) < 1 and + + # ADF: annular at bf to 4*bf (combines LAADF + HAADF) + if (self.roi_mode == "annular" and + abs(self.roi_radius_inner - bf) < 1 and abs(self.roi_radius - bf * 4.0) < 1): - return self._cached_haadf_virtual - + return self._cached_adf_virtual + return None def _fast_masked_sum(self, mask) -> "np.ndarray": @@ -670,8 +909,17 @@ def _fast_masked_sum(self, mask) -> "np.ndarray": return (self._data.astype(np.float32) * mask).sum(axis=(2, 3)) return (self._data.astype(np.float32) * mask).sum(axis=(1, 2)).reshape(self._scan_shape) - def _normalize_to_bytes(self, arr: "np.ndarray") -> bytes: + def _normalize_to_bytes(self, arr: "np.ndarray", update_vi_stats: bool = True) -> bytes: """Normalize array to uint8 bytes.""" + # Compute VI stats from raw data + if update_vi_stats: + self.vi_stats = [ + float(arr.mean()), + float(arr.min()), + float(arr.max()), + float(arr.std()), + ] + vmin, vmax = float(arr.min()), float(arr.max()) if vmax > vmin: norm = np.clip((arr - vmin) / (vmax - vmin) * 255, 0, 255).astype(np.uint8) @@ -683,7 +931,10 @@ def _compute_virtual_image_from_roi(self): """Compute virtual image based on ROI mode.""" cached = self._get_cached_preset() if cached is not None: - self.virtual_image_bytes = cached + # Cached preset returns (bytes, stats) tuple + vi_bytes, vi_stats = cached + self.virtual_image_bytes = vi_bytes + self.vi_stats = vi_stats return cx, cy = self.roi_center_x, self.roi_center_y From 6c2fa20ca068efd5562de4f10e4c669e6d0946d4 Mon Sep 17 00:00:00 2001 From: Sangjoon Bob Lee Date: Mon, 19 Jan 2026 20:33:43 -0800 Subject: [PATCH 09/27] update config for theme inheritance, remove stale tests --- widget/js/CONFIG.ts | 6 +- widget/tests/test_widget.py | 110 +++++++++++++++++++++--------------- 2 files changed, 70 insertions(+), 46 deletions(-) diff --git a/widget/js/CONFIG.ts b/widget/js/CONFIG.ts index 9443471a..74ba8914 100644 --- a/widget/js/CONFIG.ts +++ b/widget/js/CONFIG.ts @@ -73,8 +73,10 @@ export const CONTROL_PANEL = { export const CONTAINER = { ROOT: { p: 2, - bgcolor: COLORS.BG, - color: COLORS.TEXT_PRIMARY, + // Use transparent background to inherit from parent (light/dark mode aware) + bgcolor: "transparent", + // Inherit text color from parent for theme awareness + color: "inherit", fontFamily: "monospace", borderRadius: 1, // CRITICAL: Allow dropdowns to overflow diff --git a/widget/tests/test_widget.py b/widget/tests/test_widget.py index 9a880b76..45524bc2 100644 --- a/widget/tests/test_widget.py +++ b/widget/tests/test_widget.py @@ -31,50 +31,6 @@ def test_show4dstem_flattened_scan_shape_mapping(): assert np.array_equal(frame, np.full((2, 2), 5, dtype=np.float32)) -def test_roi_circle_integrated_value(): - data = np.zeros((1, 1, 5, 5), dtype=np.float32) - rows = np.arange(5, dtype=np.float32)[:, None] - cols = np.arange(5, dtype=np.float32)[None, :] - data[0, 0] = rows * 10 + cols - widget = Show4DSTEM(data, center=(2, 2), bf_radius=1, log_scale=False) - widget.roi_mode = "circle" - widget.roi_center_x = 2 - widget.roi_center_y = 2 - widget.roi_radius = 1 - widget.roi_active = True - widget._on_roi_change() - assert np.isclose(widget.roi_integrated_value, 110.0) - - -def test_scan_image_bf_mode(): - base = np.array( - [ - [1, 2, 3], - [4, 5, 6], - [7, 8, 9], - ], - dtype=np.float32, - ) - data = np.zeros((2, 2, 3, 3), dtype=np.float32) - for x in range(2): - for y in range(2): - data[x, y] = base + x * 100 + y * 10 - - widget = Show4DSTEM(data, center=(1, 1), bf_radius=1, log_scale=False) - widget.scan_mode = "bf" - widget.show_scan_view = True - widget._compute_scan_image() - - actual = np.frombuffer(widget.scan_image_bytes, dtype=np.uint8).reshape(2, 2) - scan_image = np.array([[25, 75], [525, 575]], dtype=np.float32) - expected = np.clip( - (scan_image - scan_image.min()) / (scan_image.max() - scan_image.min()) * 255, - 0, - 255, - ).astype(np.uint8) - assert np.array_equal(actual, expected) - - def test_log_scale_changes_frame_bytes(): data = np.array([[[[0, 1], [3, 7]]]], dtype=np.float32) widget = Show4DSTEM(data, log_scale=True) @@ -85,3 +41,69 @@ def test_log_scale_changes_frame_bytes(): linear_bytes = bytes(widget.frame_bytes) assert log_bytes != linear_bytes + + +def test_auto_detect_center(): + """Test automatic center spot detection using centroid.""" + # Create data with a bright spot at (3, 3) in a 7x7 detector + data = np.zeros((2, 2, 7, 7), dtype=np.float32) + # Add a bright circular spot centered at (3, 3) + for i in range(7): + for j in range(7): + dist = np.sqrt((i - 3) ** 2 + (j - 3) ** 2) + if dist <= 1.5: + data[:, :, i, j] = 100.0 + + widget = Show4DSTEM(data, precompute_virtual_images=False) + # Initial center should be at detector center (3.5, 3.5) + assert widget.center_x == 3.5 + assert widget.center_y == 3.5 + + # Run auto-detection + widget.auto_detect_center() + + # Center should be detected near (3, 3) + assert abs(widget.center_x - 3.0) < 0.5 + assert abs(widget.center_y - 3.0) < 0.5 + # BF radius should be approximately sqrt(pi*r^2 / pi) = r ~ 1.5 + assert widget.bf_radius > 0 + + +def test_adf_preset_cache(): + """Test that ADF preset uses combined bf to 4*bf range.""" + data = np.random.rand(4, 4, 16, 16).astype(np.float32) + widget = Show4DSTEM(data, center=(8, 8), bf_radius=2) + + # Check that ADF cache exists (replaced LAADF/HAADF) + assert widget._cached_adf_virtual is not None + assert not hasattr(widget, "_cached_laadf_virtual") + assert not hasattr(widget, "_cached_haadf_virtual") + + # Set ROI to match ADF range + widget.roi_mode = "annular" + widget.roi_center_x = 8 + widget.roi_center_y = 8 + widget.roi_radius_inner = 2 # bf + widget.roi_radius = 8 # 4*bf + + # Should return cached value + cached = widget._get_cached_preset() + assert cached == widget._cached_adf_virtual + + +def test_rectangular_scan_shape(): + """Test that rectangular (non-square) scans work correctly.""" + # Non-square scan: 4 rows x 8 columns + data = np.random.rand(4, 8, 16, 16).astype(np.float32) + widget = Show4DSTEM(data) + + assert widget.shape_x == 4 + assert widget.shape_y == 8 + assert widget.det_x == 16 + assert widget.det_y == 16 + + # Verify frame retrieval works at corners + frame_00 = widget._get_frame(0, 0) + frame_37 = widget._get_frame(3, 7) + assert frame_00.shape == (16, 16) + assert frame_37.shape == (16, 16) From ac8fe471f48095eca30ae9c1cecb4b1bfc8685b1 Mon Sep 17 00:00:00 2001 From: Sangjoon Bob Lee Date: Mon, 19 Jan 2026 20:35:44 -0800 Subject: [PATCH 10/27] ignore Playwright testing artifacts --- .gitignore | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/.gitignore b/.gitignore index d87d85c8..e51818a0 100644 --- a/.gitignore +++ b/.gitignore @@ -191,3 +191,9 @@ ipynb-playground/ # widget (JS build artifacts) node_modules/ widget/src/quantem/widget/static/ + +# Playwright testing +playwright-report/ +playwright/ +test-results/ +playwright*.config.ts From 673bcf102d2ec32c665f66b2745dd4bfe9d3162d Mon Sep 17 00:00:00 2001 From: Sangjoon Bob Lee Date: Mon, 19 Jan 2026 20:39:04 -0800 Subject: [PATCH 11/27] make Reset/Export buttons more compact --- widget/js/show4dstem.tsx | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/widget/js/show4dstem.tsx b/widget/js/show4dstem.tsx index edee1836..1a69ff89 100644 --- a/widget/js/show4dstem.tsx +++ b/widget/js/show4dstem.tsx @@ -34,6 +34,14 @@ const LINE_WIDTH_FRACTION = 0.015; const LINE_WIDTH_MIN_PX = 1.5; const LINE_WIDTH_MAX_PX = 3; +// Compact button style for Reset/Export +const compactButton = { + fontSize: 10, + py: 0.25, + px: 1, + minWidth: 0, +}; + /** Round to a nice value (1, 2, 5, 10, 20, 50, etc.) */ function roundToNiceValue(value: number): number { if (value <= 0) return 1; @@ -1830,8 +1838,8 @@ function Show4DSTEM() { k: ({Math.round(localKx)}, {Math.round(localKy)}) - - + + @@ -1941,8 +1949,8 @@ function Show4DSTEM() { FFT: setShowFft(e.target.checked)} size="small" sx={switchStyles.small} /> - - + + @@ -2038,7 +2046,7 @@ function Show4DSTEM() { FFT - + From 3e25bd1b96b347ab17b164c74dc4df4fbdd1c8ae Mon Sep 17 00:00:00 2001 From: Sangjoon Bob Lee Date: Mon, 19 Jan 2026 20:40:33 -0800 Subject: [PATCH 12/27] vertically center control rows in containers --- widget/js/show4dstem.tsx | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/widget/js/show4dstem.tsx b/widget/js/show4dstem.tsx index 1a69ff89..2f87a0db 100644 --- a/widget/js/show4dstem.tsx +++ b/widget/js/show4dstem.tsx @@ -1870,7 +1870,7 @@ function Show4DSTEM() { {/* DP Controls - two rows with histogram on right */} {/* Left: two rows of controls */} - + {/* Row 1: Detector + slider */} Detector: @@ -1981,7 +1981,7 @@ function Show4DSTEM() { {/* VI Controls - Two rows with histogram on right */} {/* Left: Two rows of controls */} - + {/* Row 1: ROI selector */} ROI: From 6f8905d6de26276f3592dfc8cdd3dd1e908ffb2a Mon Sep 17 00:00:00 2001 From: Sangjoon Bob Lee Date: Mon, 19 Jan 2026 20:41:45 -0800 Subject: [PATCH 13/27] add bordered style to each control row --- widget/js/show4dstem.tsx | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/widget/js/show4dstem.tsx b/widget/js/show4dstem.tsx index 2f87a0db..fb4ffde0 100644 --- a/widget/js/show4dstem.tsx +++ b/widget/js/show4dstem.tsx @@ -42,6 +42,18 @@ const compactButton = { minWidth: 0, }; +// Control row style - bordered container for each row +const controlRow = { + display: "flex", + alignItems: "center", + gap: `${SPACING.SM}px`, + border: "1px solid #3a3a3a", + borderRadius: "2px", + px: 1, + py: 0.5, + bgcolor: "#252525", +}; + /** Round to a nice value (1, 2, 5, 10, 20, 50, etc.) */ function roundToNiceValue(value: number): number { if (value <= 0) return 1; @@ -1868,11 +1880,11 @@ function Show4DSTEM() { )} {/* DP Controls - two rows with histogram on right */} - + {/* Left: two rows of controls */} {/* Row 1: Detector + slider */} - + Detector: setViRoiMode(e.target.value)} size="small" sx={{ ...controlPanel.select, minWidth: 60, fontSize: 10 }} MenuProps={upwardMenuProps}> Off @@ -2017,7 +2029,7 @@ function Show4DSTEM() { )} {/* Row 2: Color + Scale */} - + Color: setRoiMode(e.target.value)} size="small" sx={{ ...controlPanel.select, minWidth: 65, fontSize: 10 }} MenuProps={upwardMenuProps}> + setDpColormap(String(e.target.value))} size="small" sx={{ ...controlPanel.select, minWidth: 65, fontSize: 10 }} MenuProps={upwardMenuProps}> + Scale: - setDpScaleMode(e.target.value as "linear" | "log" | "power")} size="small" sx={{ ...themedSelect, minWidth: 50, fontSize: 10 }} MenuProps={upwardMenuProps}> Lin Log Pow @@ -2012,7 +2155,7 @@ function Show4DSTEM() { {/* Right: Histogram spanning both rows */} - { setDpVminPct(min); setDpVmaxPct(max); }} width={110} height={58} /> + { setDpVminPct(min); setDpVmaxPct(max); }} width={110} height={58} theme={themeInfo.theme} /> @@ -2023,7 +2166,7 @@ function Show4DSTEM() { Virtual Image - + {shapeX}×{shapeY} | {detX}×{detY} FFT: @@ -2049,11 +2192,11 @@ function Show4DSTEM() { {/* VI Stats Bar */} {viStats && viStats.length === 4 && ( - - Mean {formatStat(viStats[0])} - Min {formatStat(viStats[1])} - Max {formatStat(viStats[2])} - Std {formatStat(viStats[3])} + + Mean {formatStat(viStats[0])} + Min {formatStat(viStats[1])} + Max {formatStat(viStats[2])} + Std {formatStat(viStats[3])} )} @@ -2062,9 +2205,9 @@ function Show4DSTEM() { {/* Left: Two rows of controls */} {/* Row 1: ROI selector */} - + ROI: - setViRoiMode(e.target.value)} size="small" sx={{ ...themedSelect, minWidth: 60, fontSize: 10 }} MenuProps={upwardMenuProps}> Off Circle Square @@ -2096,9 +2239,9 @@ function Show4DSTEM() { )} {/* Row 2: Color + Scale */} - + Color: - setViColormap(String(e.target.value))} size="small" sx={{ ...themedSelect, minWidth: 65, fontSize: 10 }} MenuProps={upwardMenuProps}> Inferno Viridis Plasma @@ -2107,7 +2250,7 @@ function Show4DSTEM() { Gray Scale: - setViScaleMode(e.target.value as "linear" | "log" | "power")} size="small" sx={{ ...themedSelect, minWidth: 50, fontSize: 10 }} MenuProps={upwardMenuProps}> Lin Log Pow @@ -2116,7 +2259,7 @@ function Show4DSTEM() { {/* Right: Histogram spanning both rows */} - { setViVminPct(min); setViVmaxPct(max); }} width={110} height={58} /> + { setViVminPct(min); setViVmaxPct(max); }} width={110} height={58} theme={themeInfo.theme} /> @@ -2148,7 +2291,7 @@ function Show4DSTEM() { Path: - { setPathPlaying(false); setPathIndex(0); }} sx={{ color: "#888", fontSize: 14, cursor: "pointer", "&:hover": { color: "#fff" }, px: 0.5 }} title="Stop">⏹ + { setPathPlaying(false); setPathIndex(0); }} sx={{ color: themeColors.textMuted, fontSize: 14, cursor: "pointer", "&:hover": { color: "#fff" }, px: 0.5 }} title="Stop">⏹ setPathPlaying(!pathPlaying)} sx={{ color: pathPlaying ? "#0f0" : "#888", fontSize: 14, cursor: "pointer", "&:hover": { color: "#fff" }, px: 0.5 }} title={pathPlaying ? "Pause" : "Play"}>{pathPlaying ? "⏸" : "▶"} {pathIndex + 1}/{pathLength} { setPathPlaying(false); setPathIndex(v as number); }} min={0} max={Math.max(0, pathLength - 1)} size="small" sx={{ width: 100 }} /> From 60631bfff9b04363668d638fb976fc3b78719a12 Mon Sep 17 00:00:00 2001 From: Sangjoon Bob Lee Date: Mon, 19 Jan 2026 22:22:47 -0800 Subject: [PATCH 18/27] inline all dependencies into show4dstem.tsx for self-contained widget --- widget/js/show4dstem.tsx | 269 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 265 insertions(+), 4 deletions(-) diff --git a/widget/js/show4dstem.tsx b/widget/js/show4dstem.tsx index b635ea86..f25b9a0c 100644 --- a/widget/js/show4dstem.tsx +++ b/widget/js/show4dstem.tsx @@ -1,3 +1,4 @@ +/// import * as React from "react"; import { createRender, useModelState } from "@anywidget/react"; import Box from "@mui/material/Box"; @@ -9,10 +10,6 @@ import Slider from "@mui/material/Slider"; import Button from "@mui/material/Button"; import Switch from "@mui/material/Switch"; import JSZip from "jszip"; -import { getWebGPUFFT, WebGPUFFT } from "./webgpu-fft"; -import { COLORMAPS, fft2d, fftshift, applyBandPassFilter, MIN_ZOOM, MAX_ZOOM } from "./shared"; -import { typography, controlPanel, container } from "./CONFIG"; -import { upwardMenuProps, switchStyles } from "./components"; import "./show4dstem.css"; // ============================================================================ @@ -94,6 +91,270 @@ function isColorDark(color: string): boolean { return luminance < 0.5; } +// ============================================================================ +// Colormaps - pre-computed LUTs for image display +// ============================================================================ +const COLORMAP_POINTS: Record = { + inferno: [[0,0,4],[40,11,84],[101,21,110],[159,42,99],[212,72,66],[245,125,21],[252,193,57],[252,255,164]], + viridis: [[68,1,84],[72,36,117],[65,68,135],[53,95,141],[42,120,142],[33,145,140],[34,168,132],[68,191,112],[122,209,81],[189,223,38],[253,231,37]], + plasma: [[13,8,135],[75,3,161],[126,3,168],[168,34,150],[203,70,121],[229,107,93],[248,148,65],[253,195,40],[240,249,33]], + magma: [[0,0,4],[28,16,68],[79,18,123],[129,37,129],[181,54,122],[229,80,100],[251,135,97],[254,194,135],[252,253,191]], + hot: [[0,0,0],[87,0,0],[173,0,0],[255,0,0],[255,87,0],[255,173,0],[255,255,0],[255,255,128],[255,255,255]], + gray: [[0,0,0],[255,255,255]], +}; + +function createColormapLUT(points: number[][]): Uint8Array { + const lut = new Uint8Array(256 * 3); + for (let i = 0; i < 256; i++) { + const t = (i / 255) * (points.length - 1); + const idx = Math.floor(t); + const frac = t - idx; + const p0 = points[Math.min(idx, points.length - 1)]; + const p1 = points[Math.min(idx + 1, points.length - 1)]; + lut[i * 3] = Math.round(p0[0] + frac * (p1[0] - p0[0])); + lut[i * 3 + 1] = Math.round(p0[1] + frac * (p1[1] - p0[1])); + lut[i * 3 + 2] = Math.round(p0[2] + frac * (p1[2] - p0[2])); + } + return lut; +} + +const COLORMAPS: Record = Object.fromEntries( + Object.entries(COLORMAP_POINTS).map(([name, points]) => [name, createColormapLUT(points)]) +); + +// ============================================================================ +// FFT Utilities - CPU implementation with WebGPU acceleration +// ============================================================================ +const MIN_ZOOM = 0.5; +const MAX_ZOOM = 10; + +function nextPow2(n: number): number { return Math.pow(2, Math.ceil(Math.log2(n))); } +function isPow2(n: number): boolean { return n > 0 && (n & (n - 1)) === 0; } + +function fft1dPow2(real: Float32Array, imag: Float32Array, inverse: boolean = false) { + const n = real.length; + if (n <= 1) return; + let j = 0; + for (let i = 0; i < n - 1; i++) { + if (i < j) { [real[i], real[j]] = [real[j], real[i]]; [imag[i], imag[j]] = [imag[j], imag[i]]; } + let k = n >> 1; + while (k <= j) { j -= k; k >>= 1; } + j += k; + } + const sign = inverse ? 1 : -1; + for (let len = 2; len <= n; len <<= 1) { + const halfLen = len >> 1; + const angle = (sign * 2 * Math.PI) / len; + const wReal = Math.cos(angle), wImag = Math.sin(angle); + for (let i = 0; i < n; i += len) { + let curReal = 1, curImag = 0; + for (let k = 0; k < halfLen; k++) { + const evenIdx = i + k, oddIdx = i + k + halfLen; + const tReal = curReal * real[oddIdx] - curImag * imag[oddIdx]; + const tImag = curReal * imag[oddIdx] + curImag * real[oddIdx]; + real[oddIdx] = real[evenIdx] - tReal; imag[oddIdx] = imag[evenIdx] - tImag; + real[evenIdx] += tReal; imag[evenIdx] += tImag; + const newReal = curReal * wReal - curImag * wImag; + curImag = curReal * wImag + curImag * wReal; curReal = newReal; + } + } + } + if (inverse) { for (let i = 0; i < n; i++) { real[i] /= n; imag[i] /= n; } } +} + +function fft2d(real: Float32Array, imag: Float32Array, width: number, height: number, inverse: boolean = false) { + const paddedW = nextPow2(width), paddedH = nextPow2(height); + const needsPadding = paddedW !== width || paddedH !== height; + let workReal: Float32Array, workImag: Float32Array; + if (needsPadding) { + workReal = new Float32Array(paddedW * paddedH); workImag = new Float32Array(paddedW * paddedH); + for (let y = 0; y < height; y++) for (let x = 0; x < width; x++) { + workReal[y * paddedW + x] = real[y * width + x]; workImag[y * paddedW + x] = imag[y * width + x]; + } + } else { workReal = real; workImag = imag; } + const rowReal = new Float32Array(paddedW), rowImag = new Float32Array(paddedW); + for (let y = 0; y < paddedH; y++) { + const offset = y * paddedW; + for (let x = 0; x < paddedW; x++) { rowReal[x] = workReal[offset + x]; rowImag[x] = workImag[offset + x]; } + fft1dPow2(rowReal, rowImag, inverse); + for (let x = 0; x < paddedW; x++) { workReal[offset + x] = rowReal[x]; workImag[offset + x] = rowImag[x]; } + } + const colReal = new Float32Array(paddedH), colImag = new Float32Array(paddedH); + for (let x = 0; x < paddedW; x++) { + for (let y = 0; y < paddedH; y++) { colReal[y] = workReal[y * paddedW + x]; colImag[y] = workImag[y * paddedW + x]; } + fft1dPow2(colReal, colImag, inverse); + for (let y = 0; y < paddedH; y++) { workReal[y * paddedW + x] = colReal[y]; workImag[y * paddedW + x] = colImag[y]; } + } + if (needsPadding) { + for (let y = 0; y < height; y++) for (let x = 0; x < width; x++) { + real[y * width + x] = workReal[y * paddedW + x]; imag[y * width + x] = workImag[y * paddedW + x]; + } + } +} + +function fftshift(data: Float32Array, width: number, height: number): void { + const halfW = width >> 1, halfH = height >> 1; + const temp = new Float32Array(width * height); + for (let y = 0; y < height; y++) for (let x = 0; x < width; x++) { + temp[((y + halfH) % height) * width + ((x + halfW) % width)] = data[y * width + x]; + } + data.set(temp); +} + +function applyBandPassFilter(real: Float32Array, imag: Float32Array, width: number, height: number, innerRadius: number, outerRadius: number) { + const centerX = width >> 1, centerY = height >> 1; + const innerSq = innerRadius * innerRadius, outerSq = outerRadius * outerRadius; + for (let y = 0; y < height; y++) for (let x = 0; x < width; x++) { + const distSq = (x - centerX) ** 2 + (y - centerY) ** 2; + if (distSq < innerSq || (outerRadius > 0 && distSq > outerSq)) { real[y * width + x] = 0; imag[y * width + x] = 0; } + } +} + +// ============================================================================ +// WebGPU FFT - GPU-accelerated FFT when available +// ============================================================================ +const FFT_SHADER = `fn cmul(a: vec2, b: vec2) -> vec2 { return vec2(a.x*b.x-a.y*b.y, a.x*b.y+a.y*b.x); } +fn twiddle(k: u32, N: u32, inverse: f32) -> vec2 { let angle = inverse * 2.0 * 3.14159265359 * f32(k) / f32(N); return vec2(cos(angle), sin(angle)); } +fn bitReverse(x: u32, log2N: u32) -> u32 { var result: u32 = 0u; var val = x; for (var i: u32 = 0u; i < log2N; i = i + 1u) { result = (result << 1u) | (val & 1u); val = val >> 1u; } return result; } +struct FFTParams { N: u32, log2N: u32, stage: u32, inverse: f32, } +@group(0) @binding(0) var params: FFTParams; +@group(0) @binding(1) var data: array>; +@compute @workgroup_size(256) fn bitReversePermute(@builtin(global_invocation_id) gid: vec3) { let idx = gid.x; if (idx >= params.N) { return; } let rev = bitReverse(idx, params.log2N); if (idx < rev) { let temp = data[idx]; data[idx] = data[rev]; data[rev] = temp; } } +@compute @workgroup_size(256) fn butterflyStage(@builtin(global_invocation_id) gid: vec3) { let idx = gid.x; if (idx >= params.N / 2u) { return; } let stage = params.stage; let halfSize = 1u << stage; let fullSize = halfSize << 1u; let group = idx / halfSize; let pos = idx % halfSize; let i = group * fullSize + pos; let j = i + halfSize; let w = twiddle(pos, fullSize, params.inverse); let u = data[i]; let t = cmul(w, data[j]); data[i] = u + t; data[j] = u - t; } +@compute @workgroup_size(256) fn normalize(@builtin(global_invocation_id) gid: vec3) { let idx = gid.x; if (idx >= params.N) { return; } let scale = 1.0 / f32(params.N); data[idx] = data[idx] * scale; }`; + +const FFT_2D_SHADER = `fn cmul(a: vec2, b: vec2) -> vec2 { return vec2(a.x*b.x-a.y*b.y, a.x*b.y+a.y*b.x); } +fn twiddle(k: u32, N: u32, inverse: f32) -> vec2 { let angle = inverse * 2.0 * 3.14159265359 * f32(k) / f32(N); return vec2(cos(angle), sin(angle)); } +fn bitReverse(x: u32, log2N: u32) -> u32 { var result: u32 = 0u; var val = x; for (var i: u32 = 0u; i < log2N; i = i + 1u) { result = (result << 1u) | (val & 1u); val = val >> 1u; } return result; } +struct FFT2DParams { width: u32, height: u32, log2Size: u32, stage: u32, inverse: f32, isRowWise: u32, } +@group(0) @binding(0) var params: FFT2DParams; +@group(0) @binding(1) var data: array>; +fn getIndex(row: u32, col: u32) -> u32 { return row * params.width + col; } +@compute @workgroup_size(16, 16) fn bitReverseRows(@builtin(global_invocation_id) gid: vec3) { let row = gid.y; let col = gid.x; if (row >= params.height || col >= params.width) { return; } let rev = bitReverse(col, params.log2Size); if (col < rev) { let idx1 = getIndex(row, col); let idx2 = getIndex(row, rev); let temp = data[idx1]; data[idx1] = data[idx2]; data[idx2] = temp; } } +@compute @workgroup_size(16, 16) fn bitReverseCols(@builtin(global_invocation_id) gid: vec3) { let row = gid.y; let col = gid.x; if (row >= params.height || col >= params.width) { return; } let rev = bitReverse(row, params.log2Size); if (row < rev) { let idx1 = getIndex(row, col); let idx2 = getIndex(rev, col); let temp = data[idx1]; data[idx1] = data[idx2]; data[idx2] = temp; } } +@compute @workgroup_size(16, 16) fn butterflyRows(@builtin(global_invocation_id) gid: vec3) { let row = gid.y; let idx = gid.x; if (row >= params.height || idx >= params.width / 2u) { return; } let stage = params.stage; let halfSize = 1u << stage; let fullSize = halfSize << 1u; let group = idx / halfSize; let pos = idx % halfSize; let col_i = group * fullSize + pos; let col_j = col_i + halfSize; if (col_j >= params.width) { return; } let w = twiddle(pos, fullSize, params.inverse); let i = getIndex(row, col_i); let j = getIndex(row, col_j); let u = data[i]; let t = cmul(w, data[j]); data[i] = u + t; data[j] = u - t; } +@compute @workgroup_size(16, 16) fn butterflyCols(@builtin(global_invocation_id) gid: vec3) { let col = gid.x; let idx = gid.y; if (col >= params.width || idx >= params.height / 2u) { return; } let stage = params.stage; let halfSize = 1u << stage; let fullSize = halfSize << 1u; let group = idx / halfSize; let pos = idx % halfSize; let row_i = group * fullSize + pos; let row_j = row_i + halfSize; if (row_j >= params.height) { return; } let w = twiddle(pos, fullSize, params.inverse); let i = getIndex(row_i, col); let j = getIndex(row_j, col); let u = data[i]; let t = cmul(w, data[j]); data[i] = u + t; data[j] = u - t; } +@compute @workgroup_size(16, 16) fn normalize2D(@builtin(global_invocation_id) gid: vec3) { let row = gid.y; let col = gid.x; if (row >= params.height || col >= params.width) { return; } let idx = getIndex(row, col); let scale = 1.0 / f32(params.width * params.height); data[idx] = data[idx] * scale; }`; + +class WebGPUFFT { + private device: GPUDevice; + private pipelines1D: { bitReverse: GPUComputePipeline; butterfly: GPUComputePipeline; normalize: GPUComputePipeline } | null = null; + private pipelines2D: { bitReverseRows: GPUComputePipeline; bitReverseCols: GPUComputePipeline; butterflyRows: GPUComputePipeline; butterflyCols: GPUComputePipeline; normalize: GPUComputePipeline } | null = null; + private initialized = false; + constructor(device: GPUDevice) { this.device = device; } + async init(): Promise { + if (this.initialized) return; + const module1D = this.device.createShaderModule({ code: FFT_SHADER }); + this.pipelines1D = { + bitReverse: this.device.createComputePipeline({ layout: 'auto', compute: { module: module1D, entryPoint: 'bitReversePermute' } }), + butterfly: this.device.createComputePipeline({ layout: 'auto', compute: { module: module1D, entryPoint: 'butterflyStage' } }), + normalize: this.device.createComputePipeline({ layout: 'auto', compute: { module: module1D, entryPoint: 'normalize' } }) + }; + const module2D = this.device.createShaderModule({ code: FFT_2D_SHADER }); + this.pipelines2D = { + bitReverseRows: this.device.createComputePipeline({ layout: 'auto', compute: { module: module2D, entryPoint: 'bitReverseRows' } }), + bitReverseCols: this.device.createComputePipeline({ layout: 'auto', compute: { module: module2D, entryPoint: 'bitReverseCols' } }), + butterflyRows: this.device.createComputePipeline({ layout: 'auto', compute: { module: module2D, entryPoint: 'butterflyRows' } }), + butterflyCols: this.device.createComputePipeline({ layout: 'auto', compute: { module: module2D, entryPoint: 'butterflyCols' } }), + normalize: this.device.createComputePipeline({ layout: 'auto', compute: { module: module2D, entryPoint: 'normalize2D' } }) + }; + this.initialized = true; + } + async fft2D(realData: Float32Array, imagData: Float32Array, width: number, height: number, inverse: boolean = false): Promise<{ real: Float32Array, imag: Float32Array }> { + await this.init(); + const paddedWidth = nextPow2(width), paddedHeight = nextPow2(height); + const needsPadding = paddedWidth !== width || paddedHeight !== height; + const log2Width = Math.log2(paddedWidth), log2Height = Math.log2(paddedHeight); + const paddedSize = paddedWidth * paddedHeight, originalSize = width * height; + let workReal: Float32Array, workImag: Float32Array; + if (needsPadding) { + workReal = new Float32Array(paddedSize); workImag = new Float32Array(paddedSize); + for (let y = 0; y < height; y++) for (let x = 0; x < width; x++) { workReal[y * paddedWidth + x] = realData[y * width + x]; workImag[y * paddedWidth + x] = imagData[y * width + x]; } + } else { workReal = realData; workImag = imagData; } + const complexData = new Float32Array(paddedSize * 2); + for (let i = 0; i < paddedSize; i++) { complexData[i * 2] = workReal[i]; complexData[i * 2 + 1] = workImag[i]; } + const dataBuffer = this.device.createBuffer({ size: complexData.byteLength, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST }); + this.device.queue.writeBuffer(dataBuffer, 0, complexData); + const paramsBuffer = this.device.createBuffer({ size: 24, usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST }); + const readBuffer = this.device.createBuffer({ size: complexData.byteLength, usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST }); + const inverseVal = inverse ? 1.0 : -1.0; + const workgroupsX = Math.ceil(paddedWidth / 16), workgroupsY = Math.ceil(paddedHeight / 16); + const runPass = (pipeline: GPUComputePipeline) => { + const bindGroup = this.device.createBindGroup({ layout: pipeline.getBindGroupLayout(0), entries: [{ binding: 0, resource: { buffer: paramsBuffer } }, { binding: 1, resource: { buffer: dataBuffer } }] }); + const encoder = this.device.createCommandEncoder(); const pass = encoder.beginComputePass(); + pass.setPipeline(pipeline); pass.setBindGroup(0, bindGroup); pass.dispatchWorkgroups(workgroupsX, workgroupsY); pass.end(); + this.device.queue.submit([encoder.finish()]); + }; + const params = new ArrayBuffer(24); const paramsU32 = new Uint32Array(params); const paramsF32 = new Float32Array(params); + paramsU32[0] = paddedWidth; paramsU32[1] = paddedHeight; paramsU32[2] = log2Width; paramsU32[3] = 0; paramsF32[4] = inverseVal; paramsU32[5] = 1; + this.device.queue.writeBuffer(paramsBuffer, 0, params); runPass(this.pipelines2D!.bitReverseRows); + for (let stage = 0; stage < log2Width; stage++) { paramsU32[3] = stage; this.device.queue.writeBuffer(paramsBuffer, 0, params); runPass(this.pipelines2D!.butterflyRows); } + paramsU32[2] = log2Height; paramsU32[3] = 0; paramsU32[5] = 0; + this.device.queue.writeBuffer(paramsBuffer, 0, params); runPass(this.pipelines2D!.bitReverseCols); + for (let stage = 0; stage < log2Height; stage++) { paramsU32[3] = stage; this.device.queue.writeBuffer(paramsBuffer, 0, params); runPass(this.pipelines2D!.butterflyCols); } + if (inverse) runPass(this.pipelines2D!.normalize); + const encoder = this.device.createCommandEncoder(); encoder.copyBufferToBuffer(dataBuffer, 0, readBuffer, 0, complexData.byteLength); + this.device.queue.submit([encoder.finish()]); await readBuffer.mapAsync(GPUMapMode.READ); + const result = new Float32Array(readBuffer.getMappedRange().slice(0)); readBuffer.unmap(); + dataBuffer.destroy(); paramsBuffer.destroy(); readBuffer.destroy(); + if (needsPadding) { + const realResult = new Float32Array(originalSize), imagResult = new Float32Array(originalSize); + for (let y = 0; y < height; y++) for (let x = 0; x < width; x++) { realResult[y * width + x] = result[(y * paddedWidth + x) * 2]; imagResult[y * width + x] = result[(y * paddedWidth + x) * 2 + 1]; } + return { real: realResult, imag: imagResult }; + } + const realResult = new Float32Array(paddedSize), imagResult = new Float32Array(paddedSize); + for (let i = 0; i < paddedSize; i++) { realResult[i] = result[i * 2]; imagResult[i] = result[i * 2 + 1]; } + return { real: realResult, imag: imagResult }; + } + destroy(): void { this.initialized = false; } +} + +let gpuFFT: WebGPUFFT | null = null; +async function getWebGPUFFT(): Promise { + if (gpuFFT) return gpuFFT; + if (!navigator.gpu) { console.warn('WebGPU not supported, using CPU FFT'); return null; } + try { + const adapter = await navigator.gpu.requestAdapter(); + if (!adapter) return null; + const device = await adapter.requestDevice(); + gpuFFT = new WebGPUFFT(device); await gpuFFT.init(); + console.log('🚀 WebGPU FFT ready'); + return gpuFFT; + } catch (e) { console.warn('WebGPU init failed:', e); return null; } +} + +// ============================================================================ +// UI Styles - component styling helpers +// ============================================================================ +const typography = { + label: { color: "#aaa", fontSize: 11 }, + labelSmall: { color: "#888", fontSize: 10 }, + value: { color: "#888", fontSize: 10, fontFamily: "monospace" }, + title: { color: "#0af", fontWeight: "bold" as const }, +}; + +const controlPanel = { + group: { bgcolor: "#222", px: 1.5, py: 0.5, borderRadius: 1, border: "1px solid #444", height: 32 }, + button: { color: "#888", fontSize: 10, cursor: "pointer", "&:hover": { color: "#fff" }, bgcolor: "#222", px: 1, py: 0.25, borderRadius: 0.5, border: "1px solid #444" }, + select: { minWidth: 90, bgcolor: "#333", color: "#fff", fontSize: 11, "& .MuiSelect-select": { py: 0.5 } }, +}; + +const container = { + root: { p: 2, bgcolor: "transparent", color: "inherit", fontFamily: "monospace", borderRadius: 1, overflow: "visible" }, + imageBox: { bgcolor: "#000", border: "1px solid #444", overflow: "hidden", position: "relative" as const }, +}; + +const upwardMenuProps = { + anchorOrigin: { vertical: "top" as const, horizontal: "left" as const }, + transformOrigin: { vertical: "bottom" as const, horizontal: "left" as const }, + sx: { zIndex: 9999 }, +}; + +const switchStyles = { + small: { '& .MuiSwitch-thumb': { width: 12, height: 12 }, '& .MuiSwitch-switchBase': { padding: '4px' } }, + medium: { '& .MuiSwitch-thumb': { width: 14, height: 14 }, '& .MuiSwitch-switchBase': { padding: '4px' } }, +}; + // ============================================================================ // Layout Constants - consistent spacing throughout // ============================================================================ From 75a29d45e700cc7b7e3858942a4cd71c43a7210a Mon Sep 17 00:00:00 2001 From: Sangjoon Bob Lee Date: Mon, 19 Jan 2026 22:27:23 -0800 Subject: [PATCH 19/27] remove unused shared code after inlining into show4dstem --- widget/js/CONFIG.ts | 191 ----------- widget/js/components.tsx | 220 ------------- widget/js/core/canvas-utils.ts | 269 ---------------- widget/js/core/canvas.ts | 276 ---------------- widget/js/core/colormaps.ts | 100 ------ widget/js/core/colors.ts | 71 ----- widget/js/core/export.ts | 135 -------- widget/js/core/fft-utils.ts | 161 ---------- widget/js/core/format.ts | 58 ---- widget/js/core/hooks.ts | 211 ------------- widget/js/core/index.ts | 86 ----- widget/js/core/styles.ts | 295 ----------------- widget/js/core/webgpu-hook.ts | 37 --- widget/js/shared.ts | 221 ------------- widget/js/webgpu-fft.ts | 558 --------------------------------- 15 files changed, 2889 deletions(-) delete mode 100644 widget/js/CONFIG.ts delete mode 100644 widget/js/components.tsx delete mode 100644 widget/js/core/canvas-utils.ts delete mode 100644 widget/js/core/canvas.ts delete mode 100644 widget/js/core/colormaps.ts delete mode 100644 widget/js/core/colors.ts delete mode 100644 widget/js/core/export.ts delete mode 100644 widget/js/core/fft-utils.ts delete mode 100644 widget/js/core/format.ts delete mode 100644 widget/js/core/hooks.ts delete mode 100644 widget/js/core/index.ts delete mode 100644 widget/js/core/styles.ts delete mode 100644 widget/js/core/webgpu-hook.ts delete mode 100644 widget/js/shared.ts delete mode 100644 widget/js/webgpu-fft.ts diff --git a/widget/js/CONFIG.ts b/widget/js/CONFIG.ts deleted file mode 100644 index 74ba8914..00000000 --- a/widget/js/CONFIG.ts +++ /dev/null @@ -1,191 +0,0 @@ -/** - * Global configuration for bobleesj.widget. - * Layout constants and styling presets for all widgets. - */ - -// Import colors from single source of truth -import { COLORS, colors } from "./core/colors"; -export { COLORS, colors }; - -// ============================================================================ -// TYPOGRAPHY -// ============================================================================ -export const TYPOGRAPHY = { - LABEL: { - color: COLORS.TEXT_SECONDARY, - fontSize: 11, - }, - LABEL_SMALL: { - color: COLORS.TEXT_MUTED, - fontSize: 10, - }, - VALUE: { - color: COLORS.TEXT_MUTED, - fontSize: 10, - fontFamily: "monospace", - }, - TITLE: { - color: COLORS.ACCENT, - fontWeight: "bold" as const, - }, -}; - -// ============================================================================ -// CONTROL PANEL STYLES -// ============================================================================ -export const CONTROL_PANEL = { - // Standard control group (height: 32px) - GROUP: { - bgcolor: COLORS.BG_PANEL, - px: 1.5, - py: 0.5, - borderRadius: 1, - border: `1px solid ${COLORS.BORDER}`, - height: 32, - }, - // Compact button - BUTTON: { - color: COLORS.TEXT_MUTED, - fontSize: 10, - cursor: "pointer", - "&:hover": { color: COLORS.TEXT_PRIMARY }, - bgcolor: COLORS.BG_PANEL, - px: 1, - py: 0.25, - borderRadius: 0.5, - border: `1px solid ${COLORS.BORDER}`, - }, - // Select dropdown - SELECT: { - minWidth: 90, - bgcolor: COLORS.BG_INPUT, - color: COLORS.TEXT_PRIMARY, - fontSize: 11, - "& .MuiSelect-select": { - py: 0.5, - }, - }, -}; - -// ============================================================================ -// CONTAINER STYLES -// ============================================================================ -export const CONTAINER = { - ROOT: { - p: 2, - // Use transparent background to inherit from parent (light/dark mode aware) - bgcolor: "transparent", - // Inherit text color from parent for theme awareness - color: "inherit", - fontFamily: "monospace", - borderRadius: 1, - // CRITICAL: Allow dropdowns to overflow - overflow: "visible", - }, - IMAGE_BOX: { - bgcolor: "#000", - border: `1px solid ${COLORS.BORDER}`, - overflow: "hidden", - position: "relative" as const, - }, -}; - -// ============================================================================ -// SLIDER SIZES -// ============================================================================ -export const SLIDER = { - // Width presets - WIDTH: { - TINY: 60, // Very compact (e.g., ms/frame slider) - SMALL: 80, // Standard small slider - MEDIUM: 100, // Medium slider - LARGE: 120, // Larger slider - }, - // Container min-widths (for label + slider + value combos) - CONTAINER: { - COMPACT: 120, // Minimal container - STANDARD: 150, // Standard container (e.g., delay slider) - WIDE: 180, // Wider container - }, -}; - -// ============================================================================ -// PANEL SIZES (for canvases and image boxes) -// ============================================================================ -export const PANEL = { - // Main image canvas sizes - MAIN: { - DEFAULT: 300, // Default main canvas size - MIN: 150, // Minimum resizable size - MAX: 600, // Maximum resizable size - }, - // Side panels (FFT, histogram, etc.) - SIDE: { - DEFAULT: 150, // Default side panel size - MIN: 80, // Minimum resizable size - MAX: 250, // Maximum resizable size - }, - // Show4DSTEM specific - DP: { - DEFAULT: 400, // Diffraction pattern panel - }, - VIRTUAL: { - DEFAULT: 300, // Virtual image panel - }, - FFT: { - DEFAULT: 300, // FFT panel - }, - // Gallery mode - GALLERY: { - IMAGE_SIZE: 200, // Target size for gallery images - MIN_COLS: 2, // Minimum columns - MAX_COLS: 4, // Maximum columns - }, -}; - -// ============================================================================ -// ZOOM/PAN LIMITS -// ============================================================================ -export const ZOOM = { - MIN: 0.5, - MAX: 10, - WHEEL_FACTOR: { - IN: 1.1, - OUT: 0.9, - }, -}; - -// ============================================================================ -// ANIMATION/PLAYBACK -// ============================================================================ -export const PLAYBACK = { - MS_PER_FRAME: { - DEFAULT: 1000, // Default: 1 fps - MIN: 200, // Fastest: 5 fps - MAX: 3000, // Slowest: ~0.33 fps - STEP: 100, // Step size for slider - }, -}; - -// ============================================================================ -// LEGACY ALIASES (for backward compatibility during migration) -// These use camelCase keys to match existing widget code -// Note: `colors` is imported from core/colors.ts and re-exported at the top -// ============================================================================ -export const typography = { - label: TYPOGRAPHY.LABEL, - labelSmall: TYPOGRAPHY.LABEL_SMALL, - value: TYPOGRAPHY.VALUE, - title: TYPOGRAPHY.TITLE, -}; - -export const controlPanel = { - group: CONTROL_PANEL.GROUP, - button: CONTROL_PANEL.BUTTON, - select: CONTROL_PANEL.SELECT, -}; - -export const container = { - root: CONTAINER.ROOT, - imageBox: CONTAINER.IMAGE_BOX, -}; diff --git a/widget/js/components.tsx b/widget/js/components.tsx deleted file mode 100644 index c275c8cb..00000000 --- a/widget/js/components.tsx +++ /dev/null @@ -1,220 +0,0 @@ -/** - * Shared styling constants and simple UI components for bobleesj.widget. - * - * ARCHITECTURE NOTE: Only styling should be shared here. - * Widget-specific logic (resize handlers, zoom handlers) should be inlined per-widget. - */ - -import * as React from "react"; -import Switch from "@mui/material/Switch"; -import Select from "@mui/material/Select"; -import MenuItem from "@mui/material/MenuItem"; -import Stack from "@mui/material/Stack"; -import Typography from "@mui/material/Typography"; -import { colors, controlPanel, typography } from "./CONFIG"; - -// ============================================================================ -// Switch Style Constants -// ============================================================================ -export const switchStyles = { - small: { - '& .MuiSwitch-thumb': { width: 12, height: 12 }, - '& .MuiSwitch-switchBase': { padding: '4px' }, - }, - medium: { - '& .MuiSwitch-thumb': { width: 14, height: 14 }, - '& .MuiSwitch-switchBase': { padding: '4px' }, - }, -}; - -// ============================================================================ -// Select MenuProps for upward dropdown (all widgets use this) -// ============================================================================ -export const upwardMenuProps = { - anchorOrigin: { vertical: "top" as const, horizontal: "left" as const }, - transformOrigin: { vertical: "bottom" as const, horizontal: "left" as const }, - sx: { zIndex: 9999 }, -}; - -// ============================================================================ -// LabeledSwitch - Label + Switch combo (optional, use if needed) -// ============================================================================ -interface LabeledSwitchProps { - label: string; - checked: boolean; - onChange: (checked: boolean) => void; - size?: "small" | "medium"; -} - -export function LabeledSwitch({ label, checked, onChange, size = "small" }: LabeledSwitchProps) { - return ( - - {label}: - onChange(e.target.checked)} - size="small" - sx={switchStyles[size]} - /> - - ); -} - -// ============================================================================ -// LabeledSelect - Label + Select dropdown combo (optional, use if needed) -// ============================================================================ -interface LabeledSelectProps { - label: string; - value: T; - options: readonly T[] | T[]; - onChange: (value: T) => void; - formatLabel?: (value: T) => string; -} - -export function LabeledSelect({ - label, - value, - options, - onChange, - formatLabel, -}: LabeledSelectProps) { - return ( - - {label}: - - - ); -} - -// ============================================================================ -// ScaleBar - Overlay component for canvas scale bars -// ============================================================================ -interface ScaleBarProps { - zoom: number; - size: number; - label?: string; -} - -export function ScaleBar({ zoom, size, label = "px" }: ScaleBarProps) { - const scaleBarPx = 50; - const realPixels = Math.round(scaleBarPx / zoom); - - return ( -
- - {realPixels} {label} - -
-
- ); -} - -// ============================================================================ -// ZoomIndicator - Overlay component for zoom level display -// ============================================================================ -interface ZoomIndicatorProps { - zoom: number; -} - -export function ZoomIndicator({ zoom }: ZoomIndicatorProps) { - return ( - - {zoom.toFixed(1)}× - - ); -} - -// ============================================================================ -// ResetButton - Compact reset button -// ============================================================================ -interface ResetButtonProps { - onClick: () => void; - label?: string; -} - -export function ResetButton({ onClick, label = "Reset" }: ResetButtonProps) { - return ( - - {label} - - ); -} - -// ============================================================================ -// ControlGroup - Wrapper for control panel groups -// ============================================================================ -interface ControlGroupProps { - children: React.ReactNode; -} - -export function ControlGroup({ children }: ControlGroupProps) { - return ( - - {children} - - ); -} - -// ============================================================================ -// ColormapSelect - Colormap dropdown with standard options -// ============================================================================ -const COLORMAP_OPTIONS = ["inferno", "viridis", "plasma", "magma", "hot", "gray"] as const; - -interface ColormapSelectProps { - value: string; - onChange: (value: string) => void; -} - -export function ColormapSelect({ value, onChange }: ColormapSelectProps) { - return ( - v.charAt(0).toUpperCase() + v.slice(1)} - /> - ); -} diff --git a/widget/js/core/canvas-utils.ts b/widget/js/core/canvas-utils.ts deleted file mode 100644 index 1679fa2f..00000000 --- a/widget/js/core/canvas-utils.ts +++ /dev/null @@ -1,269 +0,0 @@ -/** - * Shared canvas rendering utilities. - * Used by Show2D, Show3D, Show4DSTEM, and Reconstruct. - */ - -import { COLORMAPS } from "./colormaps"; -import { colors } from "./colors"; - -// ============================================================================ -// Colormap LUT Application -// ============================================================================ - -/** - * Render uint8 data to canvas with colormap LUT. - */ -export function renderWithColormap( - ctx: CanvasRenderingContext2D, - data: Uint8Array, - width: number, - height: number, - cmapName: string = "inferno" -): void { - const lut = COLORMAPS[cmapName] || COLORMAPS.inferno; - const imgData = ctx.createImageData(width, height); - const rgba = imgData.data; - - for (let i = 0; i < data.length; i++) { - const v = data[i]; - const j = i * 4; - const lutIdx = v * 3; - rgba[j] = lut[lutIdx]; - rgba[j + 1] = lut[lutIdx + 1]; - rgba[j + 2] = lut[lutIdx + 2]; - rgba[j + 3] = 255; - } - ctx.putImageData(imgData, 0, 0); -} - -/** - * Render float32 data to canvas with colormap. - */ -export function renderFloat32WithColormap( - ctx: CanvasRenderingContext2D, - data: Float32Array, - width: number, - height: number, - cmapName: string = "inferno" -): void { - const lut = COLORMAPS[cmapName] || COLORMAPS.inferno; - - // Calculate min/max - let min = Infinity, max = -Infinity; - for (let i = 0; i < data.length; i++) { - if (data[i] < min) min = data[i]; - if (data[i] > max) max = data[i]; - } - const range = max - min || 1; - const scale = 255 / range; - - const imgData = ctx.createImageData(width, height); - const rgba = imgData.data; - - for (let i = 0; i < data.length; i++) { - const v = Math.round((data[i] - min) * scale); - const lutIdx = Math.max(0, Math.min(255, v)) * 3; - const j = i * 4; - rgba[j] = lut[lutIdx]; - rgba[j + 1] = lut[lutIdx + 1]; - rgba[j + 2] = lut[lutIdx + 2]; - rgba[j + 3] = 255; - } - ctx.putImageData(imgData, 0, 0); -} - -/** - * Draw image data to canvas with zoom and pan. - */ -export function drawWithZoomPan( - ctx: CanvasRenderingContext2D, - source: HTMLCanvasElement | ImageData, - canvasWidth: number, - canvasHeight: number, - zoom: number, - panX: number, - panY: number -): void { - ctx.imageSmoothingEnabled = false; - ctx.clearRect(0, 0, canvasWidth, canvasHeight); - ctx.save(); - ctx.translate(panX, panY); - ctx.scale(zoom, zoom); - if (source instanceof ImageData) { - ctx.putImageData(source, 0, 0); - } else { - ctx.drawImage(source, 0, 0); - } - ctx.restore(); -} - -// ============================================================================ -// Scale Bar Rendering -// ============================================================================ - -/** Round to a nice value (1, 2, 5, 10, 20, 50, etc.) */ -export function roundToNiceValue(value: number): number { - if (value <= 0) return 1; - const magnitude = Math.pow(10, Math.floor(Math.log10(value))); - const normalized = value / magnitude; - if (normalized < 1.5) return magnitude; - if (normalized < 3.5) return 2 * magnitude; - if (normalized < 7.5) return 5 * magnitude; - return 10 * magnitude; -} - -/** Format scale bar label with appropriate unit */ -export function formatScaleLabel(value: number, unit: string): string { - const nice = roundToNiceValue(value); - - if (unit === "Å") { - if (nice >= 10) return `${Math.round(nice / 10)} nm`; - if (nice >= 1) return `${Math.round(nice)} Å`; - return `${nice.toFixed(2)} Å`; - } else if (unit === "nm") { - if (nice >= 1000) return `${Math.round(nice / 1000)} µm`; - if (nice >= 1) return `${Math.round(nice)} nm`; - return `${nice.toFixed(2)} nm`; - } else if (unit === "mrad") { - if (nice >= 1000) return `${Math.round(nice / 1000)} rad`; - if (nice >= 1) return `${Math.round(nice)} mrad`; - return `${nice.toFixed(2)} mrad`; - } else if (unit === "1/µm") { - if (nice >= 1000) return `${Math.round(nice / 1000)} 1/nm`; - if (nice >= 1) return `${Math.round(nice)} 1/µm`; - return `${nice.toFixed(2)} 1/µm`; - } else if (unit === "px") { - return `${Math.round(nice)} px`; - } - return `${Math.round(nice)} ${unit}`; -} - -/** - * Draw scale bar on high-DPI canvas. - */ -export function drawScaleBarHiDPI( - canvas: HTMLCanvasElement, - dpr: number, - zoom: number, - pixelSize: number, - unit: string = "nm", - imageWidth: number, - imageHeight: number -): void { - const ctx = canvas.getContext("2d"); - if (!ctx) return; - - ctx.clearRect(0, 0, canvas.width, canvas.height); - ctx.save(); - ctx.scale(dpr, dpr); - - const cssWidth = canvas.width / dpr; - const cssHeight = canvas.height / dpr; - const displayScale = Math.min(cssWidth / imageWidth, cssHeight / imageHeight); - const effectiveZoom = zoom * displayScale; - - // Fixed UI sizes in CSS pixels - const targetBarPx = 60; - const barThickness = 5; - const fontSize = 16; - const margin = 12; - - const targetPhysical = (targetBarPx / effectiveZoom) * pixelSize; - const nicePhysical = roundToNiceValue(targetPhysical); - const barPx = (nicePhysical / pixelSize) * effectiveZoom; - - const barY = cssHeight - margin; - const barX = cssWidth - barPx - margin; - - // Draw with shadow for visibility - ctx.shadowColor = "rgba(0, 0, 0, 0.5)"; - ctx.shadowBlur = 2; - ctx.shadowOffsetX = 1; - ctx.shadowOffsetY = 1; - - ctx.fillStyle = "white"; - ctx.fillRect(barX, barY, barPx, barThickness); - - // Draw label - const label = formatScaleLabel(nicePhysical, unit); - ctx.font = `${fontSize}px -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif`; - ctx.textAlign = "center"; - ctx.textBaseline = "bottom"; - ctx.fillText(label, barX + barPx / 2, barY - 4); - - // Draw zoom indicator (bottom left) - ctx.textAlign = "left"; - ctx.textBaseline = "bottom"; - ctx.fillText(`${zoom.toFixed(1)}×`, margin, cssHeight - margin + barThickness); - - ctx.restore(); -} - -/** - * Draw crosshair on high-DPI canvas. - */ -export function drawCrosshairHiDPI( - canvas: HTMLCanvasElement, - dpr: number, - posX: number, - posY: number, - zoom: number, - panX: number, - panY: number, - imageWidth: number, - imageHeight: number, - isDragging: boolean, - color: string = "rgba(0, 255, 0, 0.9)", - dragColor: string = "rgba(255, 255, 0, 0.9)" -): void { - const ctx = canvas.getContext("2d"); - if (!ctx) return; - - ctx.save(); - ctx.scale(dpr, dpr); - - const cssWidth = canvas.width / dpr; - const cssHeight = canvas.height / dpr; - const displayScale = Math.min(cssWidth / imageWidth, cssHeight / imageHeight); - - const screenX = posX * zoom * displayScale + panX * displayScale; - const screenY = posY * zoom * displayScale + panY * displayScale; - - const crosshairSize = 18; - const lineWidth = 3; - const dotRadius = 6; - - ctx.shadowColor = "rgba(0, 0, 0, 0.5)"; - ctx.shadowBlur = 2; - ctx.shadowOffsetX = 1; - ctx.shadowOffsetY = 1; - - ctx.strokeStyle = isDragging ? dragColor : color; - ctx.lineWidth = lineWidth; - - ctx.beginPath(); - ctx.moveTo(screenX - crosshairSize, screenY); - ctx.lineTo(screenX + crosshairSize, screenY); - ctx.moveTo(screenX, screenY - crosshairSize); - ctx.lineTo(screenX, screenY + crosshairSize); - ctx.stroke(); - - ctx.beginPath(); - ctx.arc(screenX, screenY, dotRadius, 0, 2 * Math.PI); - ctx.stroke(); - - ctx.restore(); -} - -// ============================================================================ -// Export to Blob/ZIP Helpers -// ============================================================================ - -/** - * Convert canvas to PNG blob. - */ -export function canvasToBlob(canvas: HTMLCanvasElement): Promise { - return new Promise((resolve) => { - canvas.toBlob((blob) => resolve(blob!), "image/png"); - }); -} diff --git a/widget/js/core/canvas.ts b/widget/js/core/canvas.ts deleted file mode 100644 index 336cee64..00000000 --- a/widget/js/core/canvas.ts +++ /dev/null @@ -1,276 +0,0 @@ -/** - * Canvas rendering utilities for image widgets. - * Scale bar, overlays, ROI drawing, etc. - */ - -import { colors } from "./colors"; - -/** Nice values for scale bar lengths */ -const NICE_VALUES = [1, 2, 5, 10, 20, 50, 100, 200, 500, 1000]; - -/** - * Calculate a "nice" scale bar length. - * @param imageWidthNm - Total image width in nm - * @param targetFraction - Target fraction of image width (default 0.2) - * @returns Scale bar length in nm - */ -export function calculateNiceScaleBar( - imageWidthNm: number, - targetFraction: number = 0.2 -): number { - const targetNm = imageWidthNm * targetFraction; - const magnitude = Math.pow(10, Math.floor(Math.log10(targetNm))); - - let barNm = magnitude; - for (const v of NICE_VALUES) { - if (v * magnitude <= targetNm * 1.5) { - barNm = v * magnitude; - } - } - return barNm; -} - -/** - * Round to a "nice" scale bar value (1, 2, 5, 10, 20, 50, 100, etc.) - * This ensures scale bars always show clean integer values. - * @param value - Raw value to round - * @returns Nice rounded value - */ -function roundToNiceValue(value: number): number { - if (value <= 0) return 1; - const magnitude = Math.pow(10, Math.floor(Math.log10(value))); - const normalized = value / magnitude; - // Round to 1, 2, 5, or 10 - if (normalized < 1.5) return magnitude; - if (normalized < 3.5) return 2 * magnitude; - if (normalized < 7.5) return 5 * magnitude; - return 10 * magnitude; -} - -/** - * Format scale bar label with appropriate unit. - * Always displays integer values (no decimals). - * @param angstroms - Length in Angstroms - * @returns Formatted string (e.g., "5 Å", "20 nm", "1 µm") - */ -export function formatScaleBarLabel(angstroms: number): string { - // Round to nice value first - const nice = roundToNiceValue(angstroms); - - if (nice >= 10000) { // >= 1 µm - return `${Math.round(nice / 10000)} µm`; - } - if (nice >= 100) { // >= 10 nm, show in nm - return `${Math.round(nice / 10)} nm`; - } - return `${Math.round(nice)} Å`; -} - -/** - * Draw scale bar on canvas overlay with nice integer labels. - * The bar length is dynamically calculated to accurately represent the physical distance. - * @param ctx - Canvas 2D context - * @param canvasWidth - Canvas width in pixels - * @param canvasHeight - Canvas height in pixels - * @param imageWidth - Image width in data pixels - * @param pixelSizeAngstrom - Pixel size in Angstroms - * @param displayScale - Canvas scale factor (includes zoom) - * @param targetBarLength - Target length of the scale bar in pixels (default 50) - * @param barThickness - Thickness of the scale bar (default 4) - * @param fontSize - Font size for the label (default 16) - */ -export function drawScaleBar( - ctx: CanvasRenderingContext2D, - canvasWidth: number, - canvasHeight: number, - _imageWidth: number, - pixelSizeAngstrom: number, - displayScale: number = 1, - targetBarLength: number = 50, - barThickness: number = 4, - fontSize: number = 16 -): void { - // Fallback if pixelSize is missing or invalid: show bar in pixels - if (pixelSizeAngstrom <= 0) { - const x = canvasWidth - targetBarLength - 10; - const y = canvasHeight - 20; - - ctx.fillStyle = colors.textPrimary; - ctx.fillRect(x, y, targetBarLength, barThickness); - - ctx.font = "11px sans-serif"; - ctx.fillStyle = colors.textPrimary; - ctx.textAlign = "right"; - ctx.fillText(`${targetBarLength} px`, x + targetBarLength, y - 5); - return; - } - - // Calculate what the target bar length represents in Angstroms at current zoom - const targetAngstroms = targetBarLength * pixelSizeAngstrom / displayScale; - - // Round to a nice value - const niceAngstroms = roundToNiceValue(targetAngstroms); - - // Calculate the actual bar length for the nice value - const barLength = (niceAngstroms / pixelSizeAngstrom) * displayScale; - - const x = canvasWidth - barLength - 10; - const y = canvasHeight - 20; - - // Draw bar (length matches the nice value) - ctx.fillStyle = colors.textPrimary; - ctx.fillRect(x, y, barLength, barThickness); - - // Draw label - const label = formatScaleBarLabel(niceAngstroms); - ctx.font = `${fontSize}px sans-serif`; - ctx.fillStyle = colors.textPrimary; - ctx.textAlign = "right"; - ctx.fillText(label, x + barLength, y - 5); -} - -/** - * Draw ROI on canvas overlay with different shapes. - * @param ctx - Canvas 2D context - * @param x - Center X in canvas pixels - * @param y - Center Y in canvas pixels - * @param shape - ROI shape: "circle", "square", or "rectangle" - * @param radius - Radius for circle, or half-size for square - * @param width - Width for rectangle - * @param height - Height for rectangle - * @param active - Whether ROI is being dragged - */ -export function drawROI( - ctx: CanvasRenderingContext2D, - x: number, - y: number, - shape: "circle" | "square" | "rectangle", - radius: number, - width: number, - height: number, - active: boolean = false -): void { - const strokeColor = active ? colors.accentYellow : colors.accentGreen; - ctx.strokeStyle = strokeColor; - ctx.lineWidth = 2; - - if (shape === "circle") { - ctx.beginPath(); - ctx.arc(x, y, radius, 0, Math.PI * 2); - ctx.stroke(); - } else if (shape === "square") { - const size = radius * 2; - ctx.strokeRect(x - radius, y - radius, size, size); - } else if (shape === "rectangle") { - const halfW = width / 2; - const halfH = height / 2; - ctx.strokeRect(x - halfW, y - halfH, width, height); - } - - // Center crosshair - only show while dragging - if (active) { - ctx.beginPath(); - ctx.moveTo(x - 5, y); - ctx.lineTo(x + 5, y); - ctx.moveTo(x, y - 5); - ctx.lineTo(x, y + 5); - ctx.stroke(); - } -} - -/** - * Draw ROI circle on canvas overlay. - * @param ctx - Canvas 2D context - * @param x - Center X in canvas pixels - * @param y - Center Y in canvas pixels - * @param radius - Radius in canvas pixels - * @param active - Whether ROI is being dragged - */ -export function drawROICircle( - ctx: CanvasRenderingContext2D, - x: number, - y: number, - radius: number, - active: boolean = false -): void { - const strokeColor = active ? colors.accentYellow : colors.accentGreen; - - // Circle - ctx.strokeStyle = strokeColor; - ctx.lineWidth = 2; - ctx.beginPath(); - ctx.arc(x, y, radius, 0, Math.PI * 2); - ctx.stroke(); - - // Center crosshair - ctx.beginPath(); - ctx.moveTo(x - 5, y); - ctx.lineTo(x + 5, y); - ctx.moveTo(x, y - 5); - ctx.lineTo(x, y + 5); - ctx.stroke(); -} - -/** - * Draw crosshair on canvas. - * @param ctx - Canvas 2D context - * @param x - Center X - * @param y - Center Y - * @param size - Half-length of crosshair arms - * @param color - Stroke color - */ -export function drawCrosshair( - ctx: CanvasRenderingContext2D, - x: number, - y: number, - size: number = 10, - color: string = colors.accentGreen -): void { - ctx.strokeStyle = color; - ctx.lineWidth = 2; - ctx.beginPath(); - ctx.moveTo(x - size, y); - ctx.lineTo(x + size, y); - ctx.moveTo(x, y - size); - ctx.lineTo(x, y + size); - ctx.stroke(); -} - -/** - * Calculate canvas scale factor for display. - * Aims for approximately targetSize pixels on screen. - * @param width - Image width - * @param height - Image height - * @param targetSize - Target display size in pixels (default 400) - * @returns Integer scale factor >= 1 - */ -export function calculateDisplayScale( - width: number, - height: number, - targetSize: number = 400 -): number { - return Math.max(1, Math.floor(targetSize / Math.max(width, height))); -} - -/** - * Extract bytes from DataView (handles anywidget's byte transfer). - * @param dataView - DataView from anywidget - * @returns Uint8Array of bytes - */ -export function extractBytes(dataView: DataView | ArrayBuffer | Uint8Array): Uint8Array { - if (dataView instanceof Uint8Array) { - return dataView; - } - if (dataView instanceof ArrayBuffer) { - return new Uint8Array(dataView); - } - // DataView from anywidget - if (dataView && "buffer" in dataView) { - return new Uint8Array( - dataView.buffer, - dataView.byteOffset, - dataView.byteLength - ); - } - return new Uint8Array(0); -} diff --git a/widget/js/core/colormaps.ts b/widget/js/core/colormaps.ts deleted file mode 100644 index 7047d1b4..00000000 --- a/widget/js/core/colormaps.ts +++ /dev/null @@ -1,100 +0,0 @@ -/** - * Colormap definitions and utilities for image display. - * Shared across Show2D, Show3D, Show4D widgets. - */ - -// Control points for interpolation -export const COLORMAP_POINTS: Record = { - inferno: [ - [0, 0, 4], [40, 11, 84], [101, 21, 110], [159, 42, 99], - [212, 72, 66], [245, 125, 21], [252, 193, 57], [252, 255, 164], - ], - viridis: [ - [68, 1, 84], [72, 36, 117], [65, 68, 135], [53, 95, 141], - [42, 120, 142], [33, 145, 140], [34, 168, 132], [68, 191, 112], - [122, 209, 81], [189, 223, 38], [253, 231, 37], - ], - plasma: [ - [13, 8, 135], [75, 3, 161], [126, 3, 168], [168, 34, 150], - [203, 70, 121], [229, 107, 93], [248, 148, 65], [253, 195, 40], [240, 249, 33], - ], - magma: [ - [0, 0, 4], [28, 16, 68], [79, 18, 123], [129, 37, 129], - [181, 54, 122], [229, 80, 100], [251, 135, 97], [254, 194, 135], [252, 253, 191], - ], - hot: [ - [0, 0, 0], [87, 0, 0], [173, 0, 0], [255, 0, 0], - [255, 87, 0], [255, 173, 0], [255, 255, 0], [255, 255, 128], [255, 255, 255], - ], - gray: [[0, 0, 0], [255, 255, 255]], -}; - -/** Available colormap names */ -export const COLORMAP_NAMES = Object.keys(COLORMAP_POINTS); - -/** Create 256-entry LUT from control points */ -export function createColormapLUT(points: number[][]): Uint8Array { - const lut = new Uint8Array(256 * 3); - for (let i = 0; i < 256; i++) { - const t = (i / 255) * (points.length - 1); - const idx = Math.floor(t); - const frac = t - idx; - const p0 = points[Math.min(idx, points.length - 1)]; - const p1 = points[Math.min(idx + 1, points.length - 1)]; - lut[i * 3] = Math.round(p0[0] + frac * (p1[0] - p0[0])); - lut[i * 3 + 1] = Math.round(p0[1] + frac * (p1[1] - p0[1])); - lut[i * 3 + 2] = Math.round(p0[2] + frac * (p1[2] - p0[2])); - } - return lut; -} - -/** Pre-computed LUTs for all colormaps (flat Uint8Array, 256*3 bytes each) */ -export const COLORMAPS: Record = Object.fromEntries( - Object.entries(COLORMAP_POINTS).map(([name, points]) => [name, createColormapLUT(points)]) -); - -/** Apply colormap to a single normalized value [0,1] */ -export function applyColormapValue( - value: number, - cmap: number[][] -): [number, number, number] { - const n = cmap.length - 1; - const t = Math.max(0, Math.min(1, value)) * n; - const i = Math.min(Math.floor(t), n - 1); - const f = t - i; - return [ - Math.round(cmap[i][0] * (1 - f) + cmap[i + 1][0] * f), - Math.round(cmap[i][1] * (1 - f) + cmap[i + 1][1] * f), - Math.round(cmap[i][2] * (1 - f) + cmap[i + 1][2] * f), - ]; -} - -/** - * Apply colormap to uint8 grayscale data, returning RGBA ImageData. - * @param data - Uint8Array of grayscale values (0-255) - * @param width - Image width - * @param height - Image height - * @param cmapName - Name of colormap to use - * @returns Uint8ClampedArray of RGBA values - */ -export function applyColormapToImage( - data: Uint8Array, - width: number, - height: number, - cmapName: string -): Uint8ClampedArray { - const lut = COLORMAPS[cmapName] || COLORMAPS.inferno; - const rgba = new Uint8ClampedArray(width * height * 4); - - for (let i = 0; i < data.length; i++) { - const v = Math.max(0, Math.min(255, data[i])); - const j = i * 4; - const lutIdx = v * 3; - rgba[j] = lut[lutIdx]; - rgba[j + 1] = lut[lutIdx + 1]; - rgba[j + 2] = lut[lutIdx + 2]; - rgba[j + 3] = 255; - } - - return rgba; -} diff --git a/widget/js/core/colors.ts b/widget/js/core/colors.ts deleted file mode 100644 index 996ee68d..00000000 --- a/widget/js/core/colors.ts +++ /dev/null @@ -1,71 +0,0 @@ -/** - * Shared color palette for all bobleesj.widget components. - * Single source of truth for theming across Show2D, Show3D, Show4D, and Reconstruct. - */ - -// Primary color definitions (SCREAMING_SNAKE_CASE for constants) -export const COLORS = { - // Backgrounds - BG: "#1a1a1a", - BG_PANEL: "#222", - BG_INPUT: "#333", - BG_CANVAS: "#000", - - // Borders - BORDER: "#444", - BORDER_LIGHT: "#555", - - // Text - TEXT_PRIMARY: "#fff", - TEXT_SECONDARY: "#aaa", - TEXT_MUTED: "#888", - TEXT_DIM: "#666", - - // Accent colors - ACCENT: "#0af", - ACCENT_GREEN: "#0f0", - ACCENT_RED: "#f00", - ACCENT_ORANGE: "#fa0", - ACCENT_CYAN: "#0cf", - ACCENT_YELLOW: "#ff0", -} as const; - -// Convenience alias with camelCase keys (for existing widget code) -export const colors = { - bg: COLORS.BG, - bgPanel: COLORS.BG_PANEL, - bgInput: COLORS.BG_INPUT, - bgCanvas: COLORS.BG_CANVAS, - border: COLORS.BORDER, - borderLight: COLORS.BORDER_LIGHT, - textPrimary: COLORS.TEXT_PRIMARY, - textSecondary: COLORS.TEXT_SECONDARY, - textMuted: COLORS.TEXT_MUTED, - textDim: COLORS.TEXT_DIM, - accent: COLORS.ACCENT, - accentGreen: COLORS.ACCENT_GREEN, - accentRed: COLORS.ACCENT_RED, - accentOrange: COLORS.ACCENT_ORANGE, - accentCyan: COLORS.ACCENT_CYAN, - accentYellow: COLORS.ACCENT_YELLOW, -} as const; - -// CSS variable export for vanilla JS widgets -export const cssVars = ` - --bg: ${COLORS.BG}; - --bg-panel: ${COLORS.BG_PANEL}; - --bg-input: ${COLORS.BG_INPUT}; - --bg-canvas: ${COLORS.BG_CANVAS}; - --border: ${COLORS.BORDER}; - --border-light: ${COLORS.BORDER_LIGHT}; - --text-primary: ${COLORS.TEXT_PRIMARY}; - --text-secondary: ${COLORS.TEXT_SECONDARY}; - --text-muted: ${COLORS.TEXT_MUTED}; - --text-dim: ${COLORS.TEXT_DIM}; - --accent: ${COLORS.ACCENT}; - --accent-green: ${COLORS.ACCENT_GREEN}; - --accent-red: ${COLORS.ACCENT_RED}; - --accent-orange: ${COLORS.ACCENT_ORANGE}; - --accent-cyan: ${COLORS.ACCENT_CYAN}; - --accent-yellow: ${COLORS.ACCENT_YELLOW}; -`; diff --git a/widget/js/core/export.ts b/widget/js/core/export.ts deleted file mode 100644 index 6a0324e6..00000000 --- a/widget/js/core/export.ts +++ /dev/null @@ -1,135 +0,0 @@ -/** - * Export utilities for downloading widget canvases as images. - * Composites multiple canvas layers and burns in overlays (scale bars, etc). - */ - -/** - * Generate a timestamped filename. - * @param prefix - Filename prefix (e.g., "show2d", "show3d") - * @param extension - File extension (default: "png") - */ -export function generateFilename(prefix: string, extension: string = "png"): string { - const now = new Date(); - const timestamp = now.toISOString() - .replace(/[:.]/g, "-") - .slice(0, 19); - return `${prefix}_${timestamp}.${extension}`; -} - -/** - * Composite multiple canvases into a single canvas. - * Layers are drawn in order (first = bottom, last = top). - * @param layers - Array of canvases to composite - * @param width - Output width - * @param height - Output height - */ -export function compositeCanvases( - layers: (HTMLCanvasElement | null)[], - width: number, - height: number -): HTMLCanvasElement { - const output = document.createElement("canvas"); - output.width = width; - output.height = height; - const ctx = output.getContext("2d"); - - if (ctx) { - // Fill with black background - ctx.fillStyle = "#000"; - ctx.fillRect(0, 0, width, height); - - // Draw each layer - for (const layer of layers) { - if (layer) { - ctx.drawImage(layer, 0, 0, width, height); - } - } - } - - return output; -} - -/** - * Download a canvas as a PNG file. - * @param canvas - The canvas to download - * @param filename - Output filename - */ -export function downloadCanvas(canvas: HTMLCanvasElement, filename: string): void { - canvas.toBlob((blob) => { - if (!blob) return; - - const url = URL.createObjectURL(blob); - const link = document.createElement("a"); - link.href = url; - link.download = filename; - link.click(); - - // Cleanup - URL.revokeObjectURL(url); - }, "image/png"); -} - -/** - * Export a widget's canvas with overlays burned in. - * @param imageCanvas - Main image canvas - * @param overlayCanvas - Overlay canvas (scale bar, etc) - * @param prefix - Filename prefix - * @param label - Optional label to append to filename - */ -export function exportWithOverlay( - imageCanvas: HTMLCanvasElement | null, - overlayCanvas: HTMLCanvasElement | null, - prefix: string, - label?: string -): void { - if (!imageCanvas) return; - - const width = imageCanvas.width; - const height = imageCanvas.height; - - const output = compositeCanvases([imageCanvas, overlayCanvas], width, height); - - // Generate filename with optional label - const cleanLabel = label ? `_${label.replace(/[^a-zA-Z0-9]/g, "_")}` : ""; - const filename = generateFilename(`${prefix}${cleanLabel}`); - - downloadCanvas(output, filename); -} - -/** - * Export multiple canvases as a ZIP file (for galleries). - * Requires JSZip to be available. - */ -export async function exportGalleryAsZip( - canvases: { image: HTMLCanvasElement | null; overlay: HTMLCanvasElement | null; label: string }[], - prefix: string -): Promise { - // Dynamic import to avoid bundling JSZip if not needed - const JSZip = (await import("jszip")).default; - const zip = new JSZip(); - - const timestamp = new Date().toISOString().replace(/[:.]/g, "-").slice(0, 19); - - for (let i = 0; i < canvases.length; i++) { - const { image, overlay, label } = canvases[i]; - if (!image) continue; - - const output = compositeCanvases([image, overlay], image.width, image.height); - const cleanLabel = label.replace(/[^a-zA-Z0-9]/g, "_"); - const filename = `${String(i + 1).padStart(3, "0")}_${cleanLabel}.png`; - - const blob = await new Promise((resolve) => { - output.toBlob((b) => resolve(b!), "image/png"); - }); - - zip.file(filename, blob); - } - - const zipBlob = await zip.generateAsync({ type: "blob" }); - const url = URL.createObjectURL(zipBlob); - const link = document.createElement("a"); - link.href = url; - link.download = `${prefix}_gallery_${timestamp}.zip`; - link.click(); - URL.revokeObjectURL(url); -} diff --git a/widget/js/core/fft-utils.ts b/widget/js/core/fft-utils.ts deleted file mode 100644 index 8d6a5626..00000000 --- a/widget/js/core/fft-utils.ts +++ /dev/null @@ -1,161 +0,0 @@ -/** - * FFT and histogram rendering utilities. - * Shared across Show2D and Show3D widgets. - */ - -import { COLORMAPS } from "./colormaps"; -import { colors } from "./colors"; - -// ============================================================================ -// FFT Rendering -// ============================================================================ - -/** - * Render FFT magnitude to canvas with log scale and colormap. - * @param ctx - Canvas 2D context - * @param fftMag - FFT magnitude data (Float32Array) - * @param width - Image width - * @param height - Image height - * @param panelSize - Canvas panel size - * @param zoom - Zoom level (default 3 for center detail) - * @param panX - Pan X offset - * @param panY - Pan Y offset - * @param cmapName - Colormap name (default "inferno") - */ -export function renderFFT( - ctx: CanvasRenderingContext2D, - fftMag: Float32Array, - width: number, - height: number, - panelSize: number, - zoom: number = 3, - panX: number = 0, - panY: number = 0, - cmapName: string = "inferno" -): void { - // Log scale and normalize - let min = Infinity; - let max = -Infinity; - const logData = new Float32Array(fftMag.length); - - for (let i = 0; i < fftMag.length; i++) { - logData[i] = Math.log(1 + fftMag[i]); - if (logData[i] < min) min = logData[i]; - if (logData[i] > max) max = logData[i]; - } - - const lut = COLORMAPS[cmapName] || COLORMAPS.inferno; - - // Create offscreen canvas at native resolution - const offscreen = document.createElement("canvas"); - offscreen.width = width; - offscreen.height = height; - const offCtx = offscreen.getContext("2d"); - if (!offCtx) return; - - const imgData = offCtx.createImageData(width, height); - const range = max - min || 1; - - for (let i = 0; i < logData.length; i++) { - const v = Math.floor(((logData[i] - min) / range) * 255); - const j = i * 4; - imgData.data[j] = lut[v * 3]; - imgData.data[j + 1] = lut[v * 3 + 1]; - imgData.data[j + 2] = lut[v * 3 + 2]; - imgData.data[j + 3] = 255; - } - offCtx.putImageData(imgData, 0, 0); - - // Draw with zoom/pan - center the zoomed view - const scale = panelSize / Math.max(width, height); - ctx.imageSmoothingEnabled = false; - ctx.clearRect(0, 0, panelSize, panelSize); - ctx.save(); - - const centerOffsetX = (panelSize - width * scale * zoom) / 2 + panX; - const centerOffsetY = (panelSize - height * scale * zoom) / 2 + panY; - - ctx.translate(centerOffsetX, centerOffsetY); - ctx.scale(zoom, zoom); - ctx.drawImage(offscreen, 0, 0, width * scale, height * scale); - ctx.restore(); -} - -// ============================================================================ -// Histogram Rendering -// ============================================================================ - -/** - * Render histogram to canvas. - * @param ctx - Canvas 2D context - * @param counts - Histogram bin counts - * @param panelSize - Canvas panel size - * @param accentColor - Bar color (default: accent blue) - * @param bgColor - Background color (default: panel background) - */ -export function renderHistogram( - ctx: CanvasRenderingContext2D, - counts: number[], - panelSize: number, - accentColor: string = colors.accent, - bgColor: string = colors.bgPanel -): void { - const w = panelSize; - const h = panelSize; - - // Clear and fill background - ctx.fillStyle = bgColor; - ctx.fillRect(0, 0, w, h); - - // Only draw bars if we have data - if (!counts || counts.length === 0) return; - - const maxCount = Math.max(...counts); - if (maxCount === 0) return; - - // Add padding for centering - const padding = 8; - const drawWidth = w - 2 * padding; - const drawHeight = h - padding - 5; // 5px bottom margin - const barWidth = drawWidth / counts.length; - - ctx.fillStyle = accentColor; - for (let i = 0; i < counts.length; i++) { - const barHeight = (counts[i] / maxCount) * drawHeight; - ctx.fillRect(padding + i * barWidth, h - padding - barHeight, barWidth - 1, barHeight); - } -} - -// ============================================================================ -// FFT Shift (move DC component to center) -// ============================================================================ - -/** - * Shift FFT data to center the DC component. - * Modifies data in place. - */ -export function fftshift(data: Float32Array, width: number, height: number): void { - const halfW = width >> 1; - const halfH = height >> 1; - const temp = new Float32Array(width * height); - - for (let y = 0; y < height; y++) { - for (let x = 0; x < width; x++) { - const newY = (y + halfH) % height; - const newX = (x + halfW) % width; - temp[newY * width + newX] = data[y * width + x]; - } - } - data.set(temp); -} - -/** - * Compute FFT magnitude from real and imaginary parts. - */ -export function computeMagnitude(real: Float32Array, imag: Float32Array): Float32Array { - const mag = new Float32Array(real.length); - for (let i = 0; i < real.length; i++) { - mag[i] = Math.sqrt(real[i] ** 2 + imag[i] ** 2); - } - return mag; -} diff --git a/widget/js/core/format.ts b/widget/js/core/format.ts deleted file mode 100644 index a168c68c..00000000 --- a/widget/js/core/format.ts +++ /dev/null @@ -1,58 +0,0 @@ -/** - * Number and text formatting utilities. - */ - -/** - * Format a number for display with appropriate precision. - * Uses exponential notation for very large or small values. - * @param val - Value to format - * @param decimals - Number of decimal places (default 2) - * @returns Formatted string - */ -export function formatNumber(val: number, decimals: number = 2): string { - if (val === 0) return "0"; - if (Math.abs(val) >= 1000 || Math.abs(val) < 0.01) { - return val.toExponential(decimals); - } - return val.toFixed(decimals); -} - -/** - * Format bytes as human-readable size. - * @param bytes - Number of bytes - * @returns Formatted string (e.g., "1.5 MB") - */ -export function formatBytes(bytes: number): string { - if (bytes === 0) return "0 B"; - const k = 1024; - const sizes = ["B", "KB", "MB", "GB"]; - const i = Math.floor(Math.log(bytes) / Math.log(k)); - return `${(bytes / Math.pow(k, i)).toFixed(1)} ${sizes[i]}`; -} - -/** - * Format time duration. - * @param seconds - Duration in seconds - * @returns Formatted string (e.g., "1.5 s" or "150 ms") - */ -export function formatDuration(seconds: number): string { - if (seconds < 0.001) { - return `${(seconds * 1e6).toFixed(0)} µs`; - } - if (seconds < 1) { - return `${(seconds * 1000).toFixed(1)} ms`; - } - if (seconds < 60) { - return `${seconds.toFixed(2)} s`; - } - const mins = Math.floor(seconds / 60); - const secs = seconds % 60; - return `${mins}m ${secs.toFixed(0)}s`; -} - -/** - * Clamp a value between min and max. - */ -export function clamp(val: number, min: number, max: number): number { - return Math.max(min, Math.min(max, val)); -} diff --git a/widget/js/core/hooks.ts b/widget/js/core/hooks.ts deleted file mode 100644 index c5d59a39..00000000 --- a/widget/js/core/hooks.ts +++ /dev/null @@ -1,211 +0,0 @@ -/** - * Shared React hooks for widget functionality. - * Provides reusable zoom/pan and resize logic. - */ - -import * as React from "react"; - -// ============================================================================ -// Constants -// ============================================================================ -export const ZOOM_LIMITS = { - MIN: 0.5, - MAX: 10, - WHEEL_IN: 1.1, - WHEEL_OUT: 0.9, -} as const; - -// ============================================================================ -// Types -// ============================================================================ -export interface ZoomPanState { - zoom: number; - panX: number; - panY: number; -} - -export const DEFAULT_ZOOM_PAN: ZoomPanState = { - zoom: 1, - panX: 0, - panY: 0, -}; - -// ============================================================================ -// useZoomPan Hook -// ============================================================================ -export interface UseZoomPanOptions { - canvasRef: React.RefObject; - canvasWidth: number; - canvasHeight: number; - initialState?: ZoomPanState; -} - -export interface UseZoomPanResult { - state: ZoomPanState; - setState: React.Dispatch>; - reset: () => void; - handleWheel: (e: React.WheelEvent) => void; - handleMouseDown: (e: React.MouseEvent) => void; - handleMouseMove: (e: React.MouseEvent) => void; - handleMouseUp: () => void; - handleDoubleClick: () => void; - isDragging: boolean; -} - -export function useZoomPan(options: UseZoomPanOptions): UseZoomPanResult { - const { canvasRef, canvasWidth, canvasHeight, initialState = DEFAULT_ZOOM_PAN } = options; - - const [state, setState] = React.useState(initialState); - const [isDragging, setIsDragging] = React.useState(false); - const [dragStart, setDragStart] = React.useState<{ x: number; y: number; panX: number; panY: number } | null>(null); - - const reset = React.useCallback(() => { - setState(DEFAULT_ZOOM_PAN); - }, []); - - const handleWheel = React.useCallback((e: React.WheelEvent) => { - const canvas = canvasRef.current; - if (!canvas) return; - - const rect = canvas.getBoundingClientRect(); - const scaleX = canvas.width / rect.width; - const scaleY = canvas.height / rect.height; - - // Mouse position in canvas coordinates - const mouseX = (e.clientX - rect.left) * scaleX; - const mouseY = (e.clientY - rect.top) * scaleY; - - // Canvas center - const cx = canvasWidth / 2; - const cy = canvasHeight / 2; - - setState(prev => { - // Calculate position in image space - const imageX = (mouseX - cx - prev.panX) / prev.zoom + cx; - const imageY = (mouseY - cy - prev.panY) / prev.zoom + cy; - - // Apply zoom factor - const zoomFactor = e.deltaY > 0 ? ZOOM_LIMITS.WHEEL_OUT : ZOOM_LIMITS.WHEEL_IN; - const newZoom = Math.max(ZOOM_LIMITS.MIN, Math.min(ZOOM_LIMITS.MAX, prev.zoom * zoomFactor)); - - // Calculate new pan to keep mouse position fixed - const newPanX = mouseX - (imageX - cx) * newZoom - cx; - const newPanY = mouseY - (imageY - cy) * newZoom - cy; - - return { zoom: newZoom, panX: newPanX, panY: newPanY }; - }); - }, [canvasRef, canvasWidth, canvasHeight]); - - const handleMouseDown = React.useCallback((e: React.MouseEvent) => { - setIsDragging(true); - setDragStart({ x: e.clientX, y: e.clientY, panX: state.panX, panY: state.panY }); - }, [state.panX, state.panY]); - - const handleMouseMove = React.useCallback((e: React.MouseEvent) => { - if (!isDragging || !dragStart) return; - - const canvas = canvasRef.current; - if (!canvas) return; - - const rect = canvas.getBoundingClientRect(); - const scaleX = canvas.width / rect.width; - const scaleY = canvas.height / rect.height; - - const dx = (e.clientX - dragStart.x) * scaleX; - const dy = (e.clientY - dragStart.y) * scaleY; - - setState(prev => ({ ...prev, panX: dragStart.panX + dx, panY: dragStart.panY + dy })); - }, [isDragging, dragStart, canvasRef]); - - const handleMouseUp = React.useCallback(() => { - setIsDragging(false); - setDragStart(null); - }, []); - - const handleDoubleClick = React.useCallback(() => { - reset(); - }, [reset]); - - return { - state, - setState, - reset, - handleWheel, - handleMouseDown, - handleMouseMove, - handleMouseUp, - handleDoubleClick, - isDragging, - }; -} - -// ============================================================================ -// useResize Hook -// ============================================================================ -export interface UseResizeOptions { - initialSize: number; - minSize?: number; - maxSize?: number; -} - -export interface UseResizeResult { - size: number; - setSize: React.Dispatch>; - isResizing: boolean; - handleResizeStart: (e: React.MouseEvent) => void; -} - -export function useResize(options: UseResizeOptions): UseResizeResult { - const { initialSize, minSize = 80, maxSize = 600 } = options; - - const [size, setSize] = React.useState(initialSize); - const [isResizing, setIsResizing] = React.useState(false); - const [resizeStart, setResizeStart] = React.useState<{ x: number; y: number; size: number } | null>(null); - - const handleResizeStart = React.useCallback((e: React.MouseEvent) => { - e.stopPropagation(); - e.preventDefault(); - setIsResizing(true); - setResizeStart({ x: e.clientX, y: e.clientY, size }); - }, [size]); - - React.useEffect(() => { - if (!isResizing || !resizeStart) return; - - const handleMouseMove = (e: MouseEvent) => { - const delta = Math.max(e.clientX - resizeStart.x, e.clientY - resizeStart.y); - const newSize = Math.max(minSize, Math.min(maxSize, resizeStart.size + delta)); - setSize(newSize); - }; - - const handleMouseUp = () => { - setIsResizing(false); - setResizeStart(null); - }; - - document.addEventListener("mousemove", handleMouseMove); - document.addEventListener("mouseup", handleMouseUp); - return () => { - document.removeEventListener("mousemove", handleMouseMove); - document.removeEventListener("mouseup", handleMouseUp); - }; - }, [isResizing, resizeStart, minSize, maxSize]); - - return { size, setSize, isResizing, handleResizeStart }; -} - -// ============================================================================ -// usePreventScroll Hook -// ============================================================================ -export function usePreventScroll(refs: React.RefObject[]): void { - React.useEffect(() => { - const preventDefault = (e: WheelEvent) => e.preventDefault(); - const elements = refs.map(ref => ref.current).filter(Boolean) as HTMLElement[]; - - elements.forEach(el => el.addEventListener("wheel", preventDefault, { passive: false })); - - return () => { - elements.forEach(el => el.removeEventListener("wheel", preventDefault)); - }; - }, [refs]); -} diff --git a/widget/js/core/index.ts b/widget/js/core/index.ts deleted file mode 100644 index 45769f87..00000000 --- a/widget/js/core/index.ts +++ /dev/null @@ -1,86 +0,0 @@ -/** - * Core utilities for bobleesj.widget components. - * Re-exports all shared modules. - */ - -// Colors and theming -export { COLORS, colors, cssVars } from "./colors"; - -// Colormaps -export { - COLORMAP_NAMES, - COLORMAP_POINTS, - COLORMAPS, - applyColormapToImage, - applyColormapValue, - createColormapLUT, -} from "./colormaps"; - -// Canvas utilities -export { - calculateDisplayScale, - calculateNiceScaleBar, - drawCrosshair, - drawROI, - drawROICircle, - drawScaleBar, - extractBytes, - formatScaleBarLabel, -} from "./canvas"; - -// Formatting -export { - clamp, - formatBytes, - formatDuration, - formatNumber, -} from "./format"; - -// Base CSS -export { baseCSS } from "./styles"; - -// FFT and histogram utilities -export { - computeMagnitude, - fftshift, - renderFFT, - renderHistogram, -} from "./fft-utils"; - -// React hooks -export { - DEFAULT_ZOOM_PAN, - ZOOM_LIMITS, - usePreventScroll, - useResize, - useZoomPan, - type UseResizeOptions, - type UseResizeResult, - type UseZoomPanOptions, - type UseZoomPanResult, - type ZoomPanState, -} from "./hooks"; - -// WebGPU hook -export { useWebGPU, type UseWebGPUResult } from "./webgpu-hook"; - -// Advanced canvas utilities (high-DPI, colormap rendering) -export { - canvasToBlob, - drawCrosshairHiDPI, - drawScaleBarHiDPI, - drawWithZoomPan, - formatScaleLabel, - renderFloat32WithColormap, - renderWithColormap, - roundToNiceValue, -} from "./canvas-utils"; - -// Export utilities -export { - compositeCanvases, - downloadCanvas, - exportGalleryAsZip, - exportWithOverlay, - generateFilename, -} from "./export"; diff --git a/widget/js/core/styles.ts b/widget/js/core/styles.ts deleted file mode 100644 index fbc3c8cf..00000000 --- a/widget/js/core/styles.ts +++ /dev/null @@ -1,295 +0,0 @@ -/** - * Shared CSS for widget components. - * Base styles used by Show2D, Show3D, and other vanilla JS widgets. - */ - -export const baseCSS = ` -/* ============================================================================ - Base Styles - Shared across Show2D, Show3D - ============================================================================ */ - -/* Root container */ -.widget-root { - font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif; - background-color: var(--bg, #1a1a1a); - color: var(--text-primary, #fff); - padding: 12px; - border-radius: 6px; - display: inline-block; - min-width: 320px; - - --bg: #1a1a1a; - --bg-panel: #222; - --bg-input: #333; - --bg-canvas: #000; - --border: #444; - --border-light: #555; - --text-primary: #fff; - --text-secondary: #aaa; - --text-muted: #888; - --text-dim: #666; - --accent: #0af; - --accent-green: #0f0; - --accent-red: #f00; -} - -.widget-root:focus { - outline: 2px solid var(--accent); - outline-offset: 2px; -} - -/* Title bar */ -.widget-title-bar { - margin-bottom: 8px; -} - -.widget-title { - color: var(--accent); - font-weight: bold; - font-size: 13px; -} - -/* Canvas container */ -.widget-canvas-container { - position: relative; - background-color: var(--bg-canvas); - border: 1px solid var(--border); - border-radius: 4px; - overflow: hidden; -} - -.widget-canvas { - display: block; - image-rendering: pixelated; - image-rendering: crisp-edges; -} - -.widget-overlay { - position: absolute; - top: 0; - left: 0; - pointer-events: none; -} - -/* Panels */ -.widget-panel { - background-color: var(--bg-panel); - border: 1px solid var(--border); - border-radius: 4px; - padding: 6px; -} - -.widget-panel-title { - font-size: 10px; - color: var(--text-muted); - text-transform: uppercase; - margin-bottom: 4px; -} - -/* Control group */ -.widget-control-group { - display: flex; - align-items: center; - gap: 6px; - background-color: var(--bg-panel); - padding: 4px 8px; - border-radius: 4px; - border: 1px solid var(--border); -} - -/* Buttons */ -.widget-btn { - background-color: var(--bg-input); - border: 1px solid var(--border-light); - color: var(--text-secondary); - min-width: 32px; - height: 28px; - border-radius: 4px; - cursor: pointer; - font-size: 12px; - display: flex; - align-items: center; - justify-content: center; - transition: all 0.15s; - padding: 0 8px; -} - -.widget-btn:hover { - background-color: var(--border); - color: var(--text-primary); -} - -.widget-btn:active, -.widget-btn-active { - background-color: var(--accent); - color: #000; - border-color: var(--accent); -} - -.widget-btn-primary { - background-color: var(--accent); - color: #000; - border-color: var(--accent); -} - -.widget-btn-primary:hover { - background-color: #0cf; -} - -/* Slider */ -.widget-slider { - flex: 1; - height: 6px; - -webkit-appearance: none; - appearance: none; - background: var(--border); - border-radius: 3px; - cursor: pointer; -} - -.widget-slider::-webkit-slider-thumb { - -webkit-appearance: none; - width: 14px; - height: 14px; - background: var(--accent); - border-radius: 50%; - cursor: pointer; - border: 2px solid var(--bg); - box-shadow: 0 1px 3px rgba(0, 0, 0, 0.4); -} - -.widget-slider::-moz-range-thumb { - width: 14px; - height: 14px; - background: var(--accent); - border-radius: 50%; - cursor: pointer; - border: 2px solid var(--bg); -} - -.widget-slider:focus { - outline: none; -} - -/* Inputs */ -.widget-input { - background-color: var(--bg-input); - border: 1px solid var(--border-light); - color: var(--text-secondary); - border-radius: 3px; - padding: 4px 6px; - font-size: 11px; - font-family: monospace; -} - -.widget-input:focus { - outline: none; - border-color: var(--accent); - color: var(--text-primary); -} - -.widget-input-small { - width: 45px; - height: 24px; - text-align: center; -} - -/* Toggles / Checkboxes */ -.widget-toggle { - display: flex; - align-items: center; - gap: 4px; - font-size: 11px; - color: var(--text-muted); - cursor: pointer; - user-select: none; -} - -.widget-toggle:hover { - color: var(--text-primary); -} - -.widget-toggle input[type="checkbox"] { - width: 14px; - height: 14px; - accent-color: var(--accent); - cursor: pointer; -} - -/* Select */ -.widget-select { - background-color: var(--bg-input); - border: 1px solid var(--border-light); - color: var(--text-secondary); - border-radius: 3px; - padding: 4px 8px; - font-size: 11px; - cursor: pointer; -} - -.widget-select:focus { - outline: none; - border-color: var(--accent); -} - -/* Stats bar */ -.widget-stats-bar { - display: flex; - flex-wrap: wrap; - gap: 16px; - background-color: var(--bg-panel); - padding: 6px 12px; - border-radius: 4px; - border: 1px solid var(--border); -} - -.widget-stat-item { - display: flex; - gap: 6px; - align-items: baseline; -} - -.widget-stat-label { - font-size: 10px; - color: var(--text-dim); -} - -.widget-stat-value { - font-size: 11px; - font-family: monospace; - color: var(--accent); -} - -/* Labels */ -.widget-label { - color: var(--text-secondary); - font-size: 11px; -} - -.widget-label-small { - color: var(--text-dim); - font-size: 10px; -} - -/* Layout helpers */ -.widget-row { - display: flex; - align-items: center; - gap: 8px; -} - -.widget-col { - display: flex; - flex-direction: column; - gap: 8px; -} - -.widget-flex { - flex: 1; -} - -/* Monospace text */ -.widget-mono { - font-family: monospace; -} -`; diff --git a/widget/js/core/webgpu-hook.ts b/widget/js/core/webgpu-hook.ts deleted file mode 100644 index 3c72f846..00000000 --- a/widget/js/core/webgpu-hook.ts +++ /dev/null @@ -1,37 +0,0 @@ -/** - * Shared WebGPU FFT hook for all widgets. - * Provides consistent GPU acceleration across Show4DSTEM and Reconstruct. - */ - -import * as React from "react"; -import { getWebGPUFFT, WebGPUFFT } from "../webgpu-fft"; - -export interface UseWebGPUResult { - gpuFFT: WebGPUFFT | null; - gpuReady: boolean; -} - -/** - * Hook to initialize WebGPU FFT on mount. - * Returns null if WebGPU is not available (falls back to CPU). - */ -export function useWebGPU(): UseWebGPUResult { - const gpuFFTRef = React.useRef(null); - const [gpuReady, setGpuReady] = React.useState(false); - - React.useEffect(() => { - let cancelled = false; - - getWebGPUFFT().then(fft => { - if (cancelled) return; - if (fft) { - gpuFFTRef.current = fft; - setGpuReady(true); - } - }); - - return () => { cancelled = true; }; - }, []); - - return { gpuFFT: gpuFFTRef.current, gpuReady }; -} diff --git a/widget/js/shared.ts b/widget/js/shared.ts deleted file mode 100644 index cb14fb33..00000000 --- a/widget/js/shared.ts +++ /dev/null @@ -1,221 +0,0 @@ -/** - * Shared utilities for widget components. - * Contains CPU FFT fallback and band-pass filtering. - * Re-exports commonly used utilities from core. - */ - -// Re-export colormaps from core -export { COLORMAP_NAMES, COLORMAP_POINTS, COLORMAPS, createColormapLUT, applyColormapValue, applyColormapToImage } from "./core/colormaps"; - -// Re-export fftshift from core (also available here for backward compatibility) -export { fftshift, computeMagnitude, renderFFT, renderHistogram } from "./core/fft-utils"; - -// Re-export zoom constants from core hooks -export { ZOOM_LIMITS } from "./core/hooks"; -export const MIN_ZOOM = 0.5; // Legacy alias -export const MAX_ZOOM = 10; // Legacy alias - -// ============================================================================ -// CPU FFT Implementation (Cooley-Tukey radix-2) - Fallback when WebGPU unavailable -// Supports ANY size via automatic zero-padding to next power of 2 -// ============================================================================ - -/** Get next power of 2 >= n */ -function nextPow2(n: number): number { - return Math.pow(2, Math.ceil(Math.log2(n))); -} - -/** Check if n is a power of 2 */ -function isPow2(n: number): boolean { - return n > 0 && (n & (n - 1)) === 0; -} - -/** Internal 1D FFT - requires power-of-2 size */ -function fft1dPow2(real: Float32Array, imag: Float32Array, inverse: boolean = false) { - const n = real.length; - if (n <= 1) return; - - // Bit-reversal permutation - let j = 0; - for (let i = 0; i < n - 1; i++) { - if (i < j) { - [real[i], real[j]] = [real[j], real[i]]; - [imag[i], imag[j]] = [imag[j], imag[i]]; - } - let k = n >> 1; - while (k <= j) { j -= k; k >>= 1; } - j += k; - } - - // Cooley-Tukey FFT - const sign = inverse ? 1 : -1; - for (let len = 2; len <= n; len <<= 1) { - const halfLen = len >> 1; - const angle = (sign * 2 * Math.PI) / len; - const wReal = Math.cos(angle); - const wImag = Math.sin(angle); - - for (let i = 0; i < n; i += len) { - let curReal = 1, curImag = 0; - for (let k = 0; k < halfLen; k++) { - const evenIdx = i + k; - const oddIdx = i + k + halfLen; - - const tReal = curReal * real[oddIdx] - curImag * imag[oddIdx]; - const tImag = curReal * imag[oddIdx] + curImag * real[oddIdx]; - - real[oddIdx] = real[evenIdx] - tReal; - imag[oddIdx] = imag[evenIdx] - tImag; - real[evenIdx] += tReal; - imag[evenIdx] += tImag; - - const newReal = curReal * wReal - curImag * wImag; - curImag = curReal * wImag + curImag * wReal; - curReal = newReal; - } - } - } - - if (inverse) { - for (let i = 0; i < n; i++) { - real[i] /= n; - imag[i] /= n; - } - } -} - -/** - * 1D FFT - supports ANY size via zero-padding - * Modifies arrays in-place - */ -export function fft1d(real: Float32Array, imag: Float32Array, inverse: boolean = false) { - const n = real.length; - if (isPow2(n)) { - fft1dPow2(real, imag, inverse); - return; - } - - // Pad to next power of 2 - const paddedN = nextPow2(n); - const paddedReal = new Float32Array(paddedN); - const paddedImag = new Float32Array(paddedN); - paddedReal.set(real); - paddedImag.set(imag); - - fft1dPow2(paddedReal, paddedImag, inverse); - - // Copy back (truncate to original size) - for (let i = 0; i < n; i++) { - real[i] = paddedReal[i]; - imag[i] = paddedImag[i]; - } -} - -/** - * 2D FFT - supports ANY size via zero-padding - * Modifies arrays in-place - */ -export function fft2d(real: Float32Array, imag: Float32Array, width: number, height: number, inverse: boolean = false) { - const paddedW = nextPow2(width); - const paddedH = nextPow2(height); - const needsPadding = paddedW !== width || paddedH !== height; - - // Work arrays (padded if needed) - let workReal: Float32Array; - let workImag: Float32Array; - - if (needsPadding) { - workReal = new Float32Array(paddedW * paddedH); - workImag = new Float32Array(paddedW * paddedH); - // Copy original data into top-left corner - for (let y = 0; y < height; y++) { - for (let x = 0; x < width; x++) { - const srcIdx = y * width + x; - const dstIdx = y * paddedW + x; - workReal[dstIdx] = real[srcIdx]; - workImag[dstIdx] = imag[srcIdx]; - } - } - } else { - workReal = real; - workImag = imag; - } - - // FFT on rows (padded width) - const rowReal = new Float32Array(paddedW); - const rowImag = new Float32Array(paddedW); - for (let y = 0; y < paddedH; y++) { - const offset = y * paddedW; - for (let x = 0; x < paddedW; x++) { - rowReal[x] = workReal[offset + x]; - rowImag[x] = workImag[offset + x]; - } - fft1dPow2(rowReal, rowImag, inverse); - for (let x = 0; x < paddedW; x++) { - workReal[offset + x] = rowReal[x]; - workImag[offset + x] = rowImag[x]; - } - } - - // FFT on columns (padded height) - const colReal = new Float32Array(paddedH); - const colImag = new Float32Array(paddedH); - for (let x = 0; x < paddedW; x++) { - for (let y = 0; y < paddedH; y++) { - colReal[y] = workReal[y * paddedW + x]; - colImag[y] = workImag[y * paddedW + x]; - } - fft1dPow2(colReal, colImag, inverse); - for (let y = 0; y < paddedH; y++) { - workReal[y * paddedW + x] = colReal[y]; - workImag[y * paddedW + x] = colImag[y]; - } - } - - // Copy back to original arrays if padded - if (needsPadding) { - for (let y = 0; y < height; y++) { - for (let x = 0; x < width; x++) { - const srcIdx = y * paddedW + x; - const dstIdx = y * width + x; - real[dstIdx] = workReal[srcIdx]; - imag[dstIdx] = workImag[srcIdx]; - } - } - } -} - -// ============================================================================ -// Band-pass Filter -// ============================================================================ - -/** Apply band-pass filter in frequency domain (keeps frequencies between inner and outer radius) */ -export function applyBandPassFilter( - real: Float32Array, - imag: Float32Array, - width: number, - height: number, - innerRadius: number, // High-pass: remove frequencies below this - outerRadius: number // Low-pass: remove frequencies above this -) { - const centerX = width >> 1; - const centerY = height >> 1; - const innerSq = innerRadius * innerRadius; - const outerSq = outerRadius * outerRadius; - - for (let y = 0; y < height; y++) { - for (let x = 0; x < width; x++) { - const dx = x - centerX; - const dy = y - centerY; - const distSq = dx * dx + dy * dy; - const idx = y * width + x; - - // Zero out frequencies outside the band - if (distSq < innerSq || (outerRadius > 0 && distSq > outerSq)) { - real[idx] = 0; - imag[idx] = 0; - } - } - } -} - diff --git a/widget/js/webgpu-fft.ts b/widget/js/webgpu-fft.ts deleted file mode 100644 index 4cbdbe7e..00000000 --- a/widget/js/webgpu-fft.ts +++ /dev/null @@ -1,558 +0,0 @@ -/// - -/** - * WebGPU FFT Implementation - * - * Implements Cooley-Tukey radix-2 FFT using WebGPU compute shaders. - * Supports 1D and 2D FFT with forward and inverse transforms. - */ - -// WGSL Shader for FFT butterfly operations -const FFT_SHADER = /* wgsl */` -// Complex number operations -fn cmul(a: vec2, b: vec2) -> vec2 { - return vec2( - a.x * b.x - a.y * b.y, - a.x * b.y + a.y * b.x - ); -} - -// Twiddle factor: e^(-2πi * k / N) for forward, e^(2πi * k / N) for inverse -fn twiddle(k: u32, N: u32, inverse: f32) -> vec2 { - let angle = inverse * 2.0 * 3.14159265359 * f32(k) / f32(N); - return vec2(cos(angle), sin(angle)); -} - -// Bit reversal for index -fn bitReverse(x: u32, log2N: u32) -> u32 { - var result: u32 = 0u; - var val = x; - for (var i: u32 = 0u; i < log2N; i = i + 1u) { - result = (result << 1u) | (val & 1u); - val = val >> 1u; - } - return result; -} - -struct FFTParams { - N: u32, // FFT size - log2N: u32, // log2(N) - stage: u32, // Current butterfly stage - inverse: f32, // -1.0 for forward, 1.0 for inverse -} - -@group(0) @binding(0) var params: FFTParams; -@group(0) @binding(1) var data: array>; - -// Bit-reversal permutation kernel -@compute @workgroup_size(256) -fn bitReversePermute(@builtin(global_invocation_id) gid: vec3) { - let idx = gid.x; - if (idx >= params.N) { return; } - - let rev = bitReverse(idx, params.log2N); - if (idx < rev) { - let temp = data[idx]; - data[idx] = data[rev]; - data[rev] = temp; - } -} - -// Butterfly operation kernel for one stage -@compute @workgroup_size(256) -fn butterflyStage(@builtin(global_invocation_id) gid: vec3) { - let idx = gid.x; - if (idx >= params.N / 2u) { return; } - - let stage = params.stage; - let halfSize = 1u << stage; // 2^stage - let fullSize = halfSize << 1u; // 2^(stage+1) - - let group = idx / halfSize; - let pos = idx % halfSize; - - let i = group * fullSize + pos; - let j = i + halfSize; - - let w = twiddle(pos, fullSize, params.inverse); - - let u = data[i]; - let t = cmul(w, data[j]); - - data[i] = u + t; - data[j] = u - t; -} - -// Normalization for inverse FFT -@compute @workgroup_size(256) -fn normalize(@builtin(global_invocation_id) gid: vec3) { - let idx = gid.x; - if (idx >= params.N) { return; } - - let scale = 1.0 / f32(params.N); - data[idx] = data[idx] * scale; -} -`; - -// 2D FFT Shader (row-wise and column-wise transforms) -const FFT_2D_SHADER = /* wgsl */` -fn cmul(a: vec2, b: vec2) -> vec2 { - return vec2( - a.x * b.x - a.y * b.y, - a.x * b.y + a.y * b.x - ); -} - -fn twiddle(k: u32, N: u32, inverse: f32) -> vec2 { - let angle = inverse * 2.0 * 3.14159265359 * f32(k) / f32(N); - return vec2(cos(angle), sin(angle)); -} - -fn bitReverse(x: u32, log2N: u32) -> u32 { - var result: u32 = 0u; - var val = x; - for (var i: u32 = 0u; i < log2N; i = i + 1u) { - result = (result << 1u) | (val & 1u); - val = val >> 1u; - } - return result; -} - -struct FFT2DParams { - width: u32, - height: u32, - log2Size: u32, - stage: u32, - inverse: f32, - isRowWise: u32, // 1 for row-wise, 0 for column-wise -} - -@group(0) @binding(0) var params: FFT2DParams; -@group(0) @binding(1) var data: array>; - -// Get linear index for 2D data -fn getIndex(row: u32, col: u32) -> u32 { - return row * params.width + col; -} - -// Bit-reversal for rows -@compute @workgroup_size(16, 16) -fn bitReverseRows(@builtin(global_invocation_id) gid: vec3) { - let row = gid.y; - let col = gid.x; - if (row >= params.height || col >= params.width) { return; } - - let rev = bitReverse(col, params.log2Size); - if (col < rev) { - let idx1 = getIndex(row, col); - let idx2 = getIndex(row, rev); - let temp = data[idx1]; - data[idx1] = data[idx2]; - data[idx2] = temp; - } -} - -// Bit-reversal for columns -@compute @workgroup_size(16, 16) -fn bitReverseCols(@builtin(global_invocation_id) gid: vec3) { - let row = gid.y; - let col = gid.x; - if (row >= params.height || col >= params.width) { return; } - - let rev = bitReverse(row, params.log2Size); - if (row < rev) { - let idx1 = getIndex(row, col); - let idx2 = getIndex(rev, col); - let temp = data[idx1]; - data[idx1] = data[idx2]; - data[idx2] = temp; - } -} - -// Butterfly for rows -@compute @workgroup_size(16, 16) -fn butterflyRows(@builtin(global_invocation_id) gid: vec3) { - let row = gid.y; - let idx = gid.x; - if (row >= params.height || idx >= params.width / 2u) { return; } - - let stage = params.stage; - let halfSize = 1u << stage; - let fullSize = halfSize << 1u; - - let group = idx / halfSize; - let pos = idx % halfSize; - - let col_i = group * fullSize + pos; - let col_j = col_i + halfSize; - - if (col_j >= params.width) { return; } - - let w = twiddle(pos, fullSize, params.inverse); - - let i = getIndex(row, col_i); - let j = getIndex(row, col_j); - - let u = data[i]; - let t = cmul(w, data[j]); - - data[i] = u + t; - data[j] = u - t; -} - -// Butterfly for columns -@compute @workgroup_size(16, 16) -fn butterflyCols(@builtin(global_invocation_id) gid: vec3) { - let col = gid.x; - let idx = gid.y; - if (col >= params.width || idx >= params.height / 2u) { return; } - - let stage = params.stage; - let halfSize = 1u << stage; - let fullSize = halfSize << 1u; - - let group = idx / halfSize; - let pos = idx % halfSize; - - let row_i = group * fullSize + pos; - let row_j = row_i + halfSize; - - if (row_j >= params.height) { return; } - - let w = twiddle(pos, fullSize, params.inverse); - - let i = getIndex(row_i, col); - let j = getIndex(row_j, col); - - let u = data[i]; - let t = cmul(w, data[j]); - - data[i] = u + t; - data[j] = u - t; -} - -// Normalization for inverse 2D FFT -@compute @workgroup_size(16, 16) -fn normalize2D(@builtin(global_invocation_id) gid: vec3) { - let row = gid.y; - let col = gid.x; - if (row >= params.height || col >= params.width) { return; } - - let idx = getIndex(row, col); - let scale = 1.0 / f32(params.width * params.height); - data[idx] = data[idx] * scale; -} -`; - -/** - * Get next power of 2 >= n - */ -function nextPow2(n: number): number { - return Math.pow(2, Math.ceil(Math.log2(n))); -} - -/** - * WebGPU FFT class for 1D and 2D transforms - */ -export class WebGPUFFT { - private device: GPUDevice; - private pipelines1D: { - bitReverse: GPUComputePipeline; - butterfly: GPUComputePipeline; - normalize: GPUComputePipeline; - } | null = null; - private pipelines2D: { - bitReverseRows: GPUComputePipeline; - bitReverseCols: GPUComputePipeline; - butterflyRows: GPUComputePipeline; - butterflyCols: GPUComputePipeline; - normalize: GPUComputePipeline; - } | null = null; - private initialized = false; - - constructor(device: GPUDevice) { - this.device = device; - } - - async init(): Promise { - if (this.initialized) return; - - // Create 1D FFT pipelines - const module1D = this.device.createShaderModule({ code: FFT_SHADER }); - - this.pipelines1D = { - bitReverse: this.device.createComputePipeline({ - layout: 'auto', - compute: { module: module1D, entryPoint: 'bitReversePermute' } - }), - butterfly: this.device.createComputePipeline({ - layout: 'auto', - compute: { module: module1D, entryPoint: 'butterflyStage' } - }), - normalize: this.device.createComputePipeline({ - layout: 'auto', - compute: { module: module1D, entryPoint: 'normalize' } - }) - }; - - // Create 2D FFT pipelines - const module2D = this.device.createShaderModule({ code: FFT_2D_SHADER }); - - this.pipelines2D = { - bitReverseRows: this.device.createComputePipeline({ - layout: 'auto', - compute: { module: module2D, entryPoint: 'bitReverseRows' } - }), - bitReverseCols: this.device.createComputePipeline({ - layout: 'auto', - compute: { module: module2D, entryPoint: 'bitReverseCols' } - }), - butterflyRows: this.device.createComputePipeline({ - layout: 'auto', - compute: { module: module2D, entryPoint: 'butterflyRows' } - }), - butterflyCols: this.device.createComputePipeline({ - layout: 'auto', - compute: { module: module2D, entryPoint: 'butterflyCols' } - }), - normalize: this.device.createComputePipeline({ - layout: 'auto', - compute: { module: module2D, entryPoint: 'normalize2D' } - }) - }; - - this.initialized = true; - console.log('WebGPU FFT initialized'); - } - - /** - * Perform 2D FFT - supports ANY size via automatic zero-padding - */ - async fft2D( - realData: Float32Array, - imagData: Float32Array, - width: number, - height: number, - inverse: boolean = false - ): Promise<{ real: Float32Array, imag: Float32Array }> { - await this.init(); - - // Compute padded power-of-2 dimensions - const paddedWidth = nextPow2(width); - const paddedHeight = nextPow2(height); - const needsPadding = paddedWidth !== width || paddedHeight !== height; - - const log2Width = Math.log2(paddedWidth); - const log2Height = Math.log2(paddedHeight); - - const paddedSize = paddedWidth * paddedHeight; - const originalSize = width * height; - - // Zero-pad input if needed - let workReal: Float32Array; - let workImag: Float32Array; - - if (needsPadding) { - workReal = new Float32Array(paddedSize); - workImag = new Float32Array(paddedSize); - // Copy original data into top-left corner - for (let y = 0; y < height; y++) { - for (let x = 0; x < width; x++) { - const srcIdx = y * width + x; - const dstIdx = y * paddedWidth + x; - workReal[dstIdx] = realData[srcIdx]; - workImag[dstIdx] = imagData[srcIdx]; - } - } - } else { - workReal = realData; - workImag = imagData; - } - - const size = paddedSize; - - // Interleave real and imaginary (use padded work arrays) - const complexData = new Float32Array(size * 2); - for (let i = 0; i < size; i++) { - complexData[i * 2] = workReal[i]; - complexData[i * 2 + 1] = workImag[i]; - } - - // Create buffers - const dataBuffer = this.device.createBuffer({ - size: complexData.byteLength, - usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST, - }); - this.device.queue.writeBuffer(dataBuffer, 0, complexData); - - const paramsBuffer = this.device.createBuffer({ - size: 24, // 6 x u32/f32 - usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST, - }); - - const readBuffer = this.device.createBuffer({ - size: complexData.byteLength, - usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST, - }); - - const inverseVal = inverse ? 1.0 : -1.0; - const workgroupsX = Math.ceil(paddedWidth / 16); - const workgroupsY = Math.ceil(paddedHeight / 16); - - // Helper to run a pass - const runPass = (pipeline: GPUComputePipeline) => { - const bindGroup = this.device.createBindGroup({ - layout: pipeline.getBindGroupLayout(0), - entries: [ - { binding: 0, resource: { buffer: paramsBuffer } }, - { binding: 1, resource: { buffer: dataBuffer } }, - ] - }); - - const encoder = this.device.createCommandEncoder(); - const pass = encoder.beginComputePass(); - pass.setPipeline(pipeline); - pass.setBindGroup(0, bindGroup); - pass.dispatchWorkgroups(workgroupsX, workgroupsY); - pass.end(); - this.device.queue.submit([encoder.finish()]); - }; - - // Row-wise FFT (use padded dimensions) - const params = new ArrayBuffer(24); - const paramsU32 = new Uint32Array(params); - const paramsF32 = new Float32Array(params); - paramsU32[0] = paddedWidth; - paramsU32[1] = paddedHeight; - paramsU32[2] = log2Width; - paramsU32[3] = 0; - paramsF32[4] = inverseVal; - paramsU32[5] = 1; - this.device.queue.writeBuffer(paramsBuffer, 0, params); - runPass(this.pipelines2D!.bitReverseRows); - - for (let stage = 0; stage < log2Width; stage++) { - paramsU32[3] = stage; - this.device.queue.writeBuffer(paramsBuffer, 0, params); - runPass(this.pipelines2D!.butterflyRows); - } - - // Column-wise FFT - paramsU32[2] = log2Height; - paramsU32[3] = 0; - paramsU32[5] = 0; - this.device.queue.writeBuffer(paramsBuffer, 0, params); - runPass(this.pipelines2D!.bitReverseCols); - - for (let stage = 0; stage < log2Height; stage++) { - paramsU32[3] = stage; - this.device.queue.writeBuffer(paramsBuffer, 0, params); - runPass(this.pipelines2D!.butterflyCols); - } - - if (inverse) { - runPass(this.pipelines2D!.normalize); - } - - // Read back results - const encoder = this.device.createCommandEncoder(); - encoder.copyBufferToBuffer(dataBuffer, 0, readBuffer, 0, complexData.byteLength); - this.device.queue.submit([encoder.finish()]); - - await readBuffer.mapAsync(GPUMapMode.READ); - const result = new Float32Array(readBuffer.getMappedRange().slice(0)); - readBuffer.unmap(); - - // Cleanup GPU buffers - dataBuffer.destroy(); - paramsBuffer.destroy(); - readBuffer.destroy(); - - // Deinterleave and crop back to original size if needed - if (needsPadding) { - const realResult = new Float32Array(originalSize); - const imagResult = new Float32Array(originalSize); - for (let y = 0; y < height; y++) { - for (let x = 0; x < width; x++) { - const srcIdx = y * paddedWidth + x; - const dstIdx = y * width + x; - realResult[dstIdx] = result[srcIdx * 2]; - imagResult[dstIdx] = result[srcIdx * 2 + 1]; - } - } - return { real: realResult, imag: imagResult }; - } else { - const realResult = new Float32Array(size); - const imagResult = new Float32Array(size); - for (let i = 0; i < size; i++) { - realResult[i] = result[i * 2]; - imagResult[i] = result[i * 2 + 1]; - } - return { real: realResult, imag: imagResult }; - } - } - - destroy(): void { - this.initialized = false; - } -} - -// Singleton instance -let gpuFFT: WebGPUFFT | null = null; -let gpuDevice: GPUDevice | null = null; -let gpuInfo = "GPU"; - -/** - * Initialize WebGPU and get FFT instance - */ -export async function getWebGPUFFT(): Promise { - if (gpuFFT) return gpuFFT; - - if (!navigator.gpu) { - console.warn('WebGPU not supported, falling back to CPU FFT'); - return null; - } - - try { - const adapter = await navigator.gpu.requestAdapter(); - if (!adapter) { - console.warn('No GPU adapter found'); - return null; - } - - // Attempt to get GPU info - try { - // In modern browsers, we can request adapter info - // @ts-ignore - requestAdapterInfo is not yet in all type definitions - const info = await adapter.requestAdapterInfo?.(); - if (info) { - // Prioritize 'description' which usually has the full name (e.g. "NVIDIA GeForce RTX 4090") - // Fallback to vendor/device if description is missing - gpuInfo = info.description || - `${info.vendor} ${info.architecture || ""} ${info.device || ""}`.trim() || - "Generic WebGPU Adapter"; - } - } catch (e) { - console.log("Could not get detailed adapter info", e); - } - - gpuDevice = await adapter.requestDevice(); - gpuFFT = new WebGPUFFT(gpuDevice); - await gpuFFT.init(); - - console.log(`🚀 WebGPU FFT ready on ${gpuInfo}!`); - return gpuFFT; - } catch (e) { - console.warn('WebGPU init failed:', e); - return null; - } -} - -/** - * Get current GPU info string - */ -export function getGPUInfo(): string { - return gpuInfo; -} - -export default WebGPUFFT; From 46f3718efa4b9063fdced720cf908484221f5cc2 Mon Sep 17 00:00:00 2001 From: Sangjoon Bob Lee Date: Mon, 19 Jan 2026 22:58:28 -0800 Subject: [PATCH 20/27] fix rectangular scan aspect ratio in VI and FFT panels --- widget/js/show4dstem.tsx | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/widget/js/show4dstem.tsx b/widget/js/show4dstem.tsx index f25b9a0c..18771108 100644 --- a/widget/js/show4dstem.tsx +++ b/widget/js/show4dstem.tsx @@ -1279,6 +1279,11 @@ function Show4DSTEM() { accent: "#0066cc", }; + // Compute VI canvas dimensions to respect aspect ratio of rectangular scans + // The longer dimension gets CANVAS_SIZE, the shorter scales proportionally + const viCanvasWidth = shapeX > shapeY ? Math.round(CANVAS_SIZE * (shapeY / shapeX)) : CANVAS_SIZE; + const viCanvasHeight = shapeY > shapeX ? Math.round(CANVAS_SIZE * (shapeX / shapeY)) : CANVAS_SIZE; + // Histogram data - use state to ensure re-renders (both are Float32Array now) const [dpHistogramData, setDpHistogramData] = React.useState(null); const [viHistogramData, setViHistogramData] = React.useState(null); @@ -2422,7 +2427,7 @@ function Show4DSTEM() { {/* RIGHT COLUMN: VI Panel + FFT (when shown) */} - + {/* VI Header */} Virtual Image @@ -2438,7 +2443,7 @@ function Show4DSTEM() { {/* VI Canvas */} - + - + {/* VI Stats Bar */} @@ -2531,7 +2536,7 @@ function Show4DSTEM() { FFT - + Date: Mon, 19 Jan 2026 23:01:14 -0800 Subject: [PATCH 21/27] restructure js/ into per-widget folders --- widget/js/show2d/index.tsx | 17 +++++++++++++++++ widget/js/show3d/index.tsx | 17 +++++++++++++++++ .../js/{show4dstem.tsx => show4dstem/index.tsx} | 2 +- .../{show4dstem.css => show4dstem/styles.css} | 0 widget/package.json | 2 +- 5 files changed, 36 insertions(+), 2 deletions(-) create mode 100644 widget/js/show2d/index.tsx create mode 100644 widget/js/show3d/index.tsx rename widget/js/{show4dstem.tsx => show4dstem/index.tsx} (99%) rename widget/js/{show4dstem.css => show4dstem/styles.css} (100%) diff --git a/widget/js/show2d/index.tsx b/widget/js/show2d/index.tsx new file mode 100644 index 00000000..6a42d9ff --- /dev/null +++ b/widget/js/show2d/index.tsx @@ -0,0 +1,17 @@ +// Placeholder for Show2D widget +// TODO: Implement 2D image viewer widget + +import * as React from "react"; +import { createRender } from "@anywidget/react"; +import Box from "@mui/material/Box"; +import Typography from "@mui/material/Typography"; + +function Show2D() { + return ( + + Show2D - Coming Soon + + ); +} + +export const render = createRender(Show2D); diff --git a/widget/js/show3d/index.tsx b/widget/js/show3d/index.tsx new file mode 100644 index 00000000..62d95a35 --- /dev/null +++ b/widget/js/show3d/index.tsx @@ -0,0 +1,17 @@ +// Placeholder for Show3D widget +// TODO: Implement 3D volume viewer widget + +import * as React from "react"; +import { createRender } from "@anywidget/react"; +import Box from "@mui/material/Box"; +import Typography from "@mui/material/Typography"; + +function Show3D() { + return ( + + Show3D - Coming Soon + + ); +} + +export const render = createRender(Show3D); diff --git a/widget/js/show4dstem.tsx b/widget/js/show4dstem/index.tsx similarity index 99% rename from widget/js/show4dstem.tsx rename to widget/js/show4dstem/index.tsx index 18771108..ec8e328d 100644 --- a/widget/js/show4dstem.tsx +++ b/widget/js/show4dstem/index.tsx @@ -10,7 +10,7 @@ import Slider from "@mui/material/Slider"; import Button from "@mui/material/Button"; import Switch from "@mui/material/Switch"; import JSZip from "jszip"; -import "./show4dstem.css"; +import "./styles.css"; // ============================================================================ // Theme Detection - detect environment and light/dark mode diff --git a/widget/js/show4dstem.css b/widget/js/show4dstem/styles.css similarity index 100% rename from widget/js/show4dstem.css rename to widget/js/show4dstem/styles.css diff --git a/widget/package.json b/widget/package.json index d78da578..9144270a 100644 --- a/widget/package.json +++ b/widget/package.json @@ -1,7 +1,7 @@ { "scripts": { "dev": "npm run build -- --sourcemap=inline --watch", - "build": "esbuild js/*.tsx --minify --format=esm --bundle --outdir=src/quantem/widget/static", + "build": "esbuild js/show4dstem/index.tsx js/show2d/index.tsx js/show3d/index.tsx --minify --format=esm --bundle --outdir=src/quantem/widget/static --entry-names=[dir]", "typecheck": "tsc --noEmit" }, "dependencies": { From edf4c97c844eeef48b2b5af42a9066569b1aae6f Mon Sep 17 00:00:00 2001 From: Sangjoon Bob Lee Date: Thu, 22 Jan 2026 06:50:32 -0800 Subject: [PATCH 22/27] george feedback - 3rd FFT column, hot pixel, loading speed, reset button unactivated, fft scale --- widget/js/show4dstem/index.tsx | 252 ++++++++++++++++++++---- widget/package.json | 3 +- widget/src/quantem/widget/show4dstem.py | 109 +++++++--- 3 files changed, 294 insertions(+), 70 deletions(-) diff --git a/widget/js/show4dstem/index.tsx b/widget/js/show4dstem/index.tsx index ec8e328d..6cd8a8d3 100644 --- a/widget/js/show4dstem/index.tsx +++ b/widget/js/show4dstem/index.tsx @@ -9,6 +9,7 @@ import MenuItem from "@mui/material/MenuItem"; import Slider from "@mui/material/Slider"; import Button from "@mui/material/Button"; import Switch from "@mui/material/Switch"; +import Tooltip from "@mui/material/Tooltip"; import JSZip from "jszip"; import "./styles.css"; @@ -380,6 +381,10 @@ const compactButton = { py: 0.25, px: 1, minWidth: 0, + "&.Mui-disabled": { + color: "#666", + borderColor: "#444", + }, }; // Control row style - bordered container for each row @@ -1025,6 +1030,52 @@ interface HistogramProps { width?: number; height?: number; theme?: "light" | "dark"; + dataMin?: number; + dataMax?: number; +} + +/** + * Info tooltip component - small ⓘ icon with hover tooltip + */ +function InfoTooltip({ text, theme = "dark" }: { text: string; theme?: "light" | "dark" }) { + const isDark = theme === "dark"; + return ( + {text}} + arrow + placement="bottom" + componentsProps={{ + tooltip: { + sx: { + bgcolor: isDark ? "#333" : "#fff", + color: isDark ? "#ddd" : "#333", + border: `1px solid ${isDark ? "#555" : "#ccc"}`, + maxWidth: 280, + p: 1, + }, + }, + arrow: { + sx: { + color: isDark ? "#333" : "#fff", + "&::before": { border: `1px solid ${isDark ? "#555" : "#ccc"}` }, + }, + }, + }} + > + + ⓘ + + + ); } /** @@ -1039,7 +1090,9 @@ function Histogram({ onRangeChange, width = 120, height = 40, - theme = "dark" + theme = "dark", + dataMin = 0, + dataMax = 1, }: HistogramProps) { const canvasRef = React.useRef(null); const bins = React.useMemo(() => computeHistogramFromBytes(data), [data]); @@ -1122,12 +1175,18 @@ function Histogram({ min={0} max={100} size="small" + valueLabelDisplay="auto" + valueLabelFormat={(pct) => { + const val = dataMin + (pct / 100) * (dataMax - dataMin); + return val >= 1000 ? val.toExponential(1) : val.toFixed(1); + }} sx={{ width, py: 0, "& .MuiSlider-thumb": { width: 8, height: 8 }, "& .MuiSlider-rail": { height: 2 }, "& .MuiSlider-track": { height: 2 }, + "& .MuiSlider-valueLabel": { fontSize: 10, padding: "2px 4px" }, }} /> @@ -1236,6 +1295,7 @@ function Show4DSTEM() { const [summedDpCount] = useModelState("summed_dp_count"); const [dpStats] = useModelState("dp_stats"); // [mean, min, max, std] const [viStats] = useModelState("vi_stats"); // [mean, min, max, std] + const [hotPixelFilter, setHotPixelFilter] = useModelState("hot_pixel_filter"); const [showFft, setShowFft] = React.useState(false); // Hidden by default per feedback // Theme detection - detect environment and light/dark mode @@ -1407,6 +1467,15 @@ function Show4DSTEM() { const [fftZoom, setFftZoom] = React.useState(1); const [fftPanX, setFftPanX] = React.useState(0); const [fftPanY, setFftPanY] = React.useState(0); + const [fftScaleMode, setFftScaleMode] = React.useState<"linear" | "log" | "power">("linear"); + const [fftColormap, setFftColormap] = React.useState("inferno"); + const [fftAuto, setFftAuto] = React.useState(true); // Auto: mask DC + 99.9% clipping + const [fftVminPct, setFftVminPct] = React.useState(0); + const [fftVmaxPct, setFftVmaxPct] = React.useState(100); + const [fftStats, setFftStats] = React.useState(null); // [mean, min, max, std] + const [fftHistogramData, setFftHistogramData] = React.useState(null); + const [fftDataMin, setFftDataMin] = React.useState(0); + const [fftDataMax, setFftDataMax] = React.useState(1); // Sync local state React.useEffect(() => { @@ -1773,26 +1842,74 @@ function Show4DSTEM() { const width = shapeY; const height = shapeX; const sourceData = rawVirtualImageRef.current; - const lut = COLORMAPS[viColormap] || COLORMAPS.inferno; + const lut = COLORMAPS[fftColormap] || COLORMAPS.inferno; // Helper to render magnitude to canvas const renderMagnitude = (real: Float32Array, imag: Float32Array) => { - // Compute log magnitude + // Compute magnitude (log or linear) let magnitude = fftMagnitudeRef.current; if (!magnitude || magnitude.length !== real.length) { magnitude = new Float32Array(real.length); fftMagnitudeRef.current = magnitude; } for (let i = 0; i < real.length; i++) { - magnitude[i] = Math.log1p(Math.sqrt(real[i] * real[i] + imag[i] * imag[i])); + const mag = Math.sqrt(real[i] * real[i] + imag[i] * imag[i]); + if (fftScaleMode === "log") { + magnitude[i] = Math.log1p(mag); + } else if (fftScaleMode === "power") { + magnitude[i] = Math.pow(mag, 0.5); // gamma = 0.5 + } else { + magnitude[i] = mag; + } } - // Normalize - let min = Infinity, max = -Infinity; + // Auto mode: mask DC component + 99.9% percentile clipping + let displayMin: number, displayMax: number; + if (fftAuto) { + // Mask DC (center pixel) by replacing with neighbor average + const centerIdx = Math.floor(height / 2) * width + Math.floor(width / 2); + const neighbors = [ + magnitude[centerIdx - 1], + magnitude[centerIdx + 1], + magnitude[centerIdx - width], + magnitude[centerIdx + width] + ]; + magnitude[centerIdx] = neighbors.reduce((a, b) => a + b, 0) / 4; + + // Apply 99.9% percentile clipping for display range + const sorted = magnitude.slice().sort((a, b) => a - b); + displayMin = sorted[0]; + displayMax = sorted[Math.floor(sorted.length * 0.999)]; + } else { + // No auto: use actual min/max + displayMin = Infinity; + displayMax = -Infinity; + for (let i = 0; i < magnitude.length; i++) { + if (magnitude[i] < displayMin) displayMin = magnitude[i]; + if (magnitude[i] > displayMax) displayMax = magnitude[i]; + } + } + setFftDataMin(displayMin); + setFftDataMax(displayMax); + + // Stats use same values + const actualMin = displayMin; + const actualMax = displayMax; + let sum = 0; for (let i = 0; i < magnitude.length; i++) { - if (magnitude[i] < min) min = magnitude[i]; - if (magnitude[i] > max) max = magnitude[i]; + sum += magnitude[i]; } + const mean = sum / magnitude.length; + let sumSq = 0; + for (let i = 0; i < magnitude.length; i++) { + const diff = magnitude[i] - mean; + sumSq += diff * diff; + } + const std = Math.sqrt(sumSq / magnitude.length); + setFftStats([mean, actualMin, actualMax, std]); + + // Store histogram data (copy of magnitude for histogram component) + setFftHistogramData(magnitude.slice()); let offscreen = fftOffscreenRef.current; if (!offscreen) { @@ -1814,10 +1931,15 @@ function Show4DSTEM() { fftImageDataRef.current = imgData; } const rgba = imgData.data; - const range = max > min ? max - min : 1; + + // Apply histogram slider range on top of percentile clipping + const dataRange = displayMax - displayMin; + const vmin = displayMin + (fftVminPct / 100) * dataRange; + const vmax = displayMin + (fftVmaxPct / 100) * dataRange; + const range = vmax > vmin ? vmax - vmin : 1; for (let i = 0; i < magnitude.length; i++) { - const v = Math.round(((magnitude[i] - min) / range) * 255); + const v = Math.round(((magnitude[i] - vmin) / range) * 255); const j = i * 4; const lutIdx = Math.max(0, Math.min(255, v)) * 3; rgba[j] = lut[lutIdx]; @@ -1876,7 +1998,7 @@ function Show4DSTEM() { fftshift(imag, width, height); renderMagnitude(real, imag); } - }, [virtualImageBytes, shapeX, shapeY, viColormap, fftZoom, fftPanX, fftPanY, gpuReady, showFft]); + }, [virtualImageBytes, shapeX, shapeY, fftColormap, fftZoom, fftPanX, fftPanY, gpuReady, showFft, fftScaleMode, fftAuto, fftVminPct, fftVmaxPct]); // Render FFT overlay with high-pass filter circle React.useEffect(() => { @@ -2315,7 +2437,7 @@ function Show4DSTEM() { 4D-STEM Explorer - {/* MAIN CONTENT: Two columns */} + {/* MAIN CONTENT: DP | VI | FFT (three columns when FFT shown) */} {/* LEFT COLUMN: DP Panel */} @@ -2324,9 +2446,10 @@ function Show4DSTEM() { DP at ({Math.round(localPosX)}, {Math.round(localPosY)}) k: ({Math.round(localKx)}, {Math.round(localKy)}) + - + @@ -2347,11 +2470,15 @@ function Show4DSTEM() { {/* DP Stats Bar */} {dpStats && dpStats.length === 4 && ( - + Mean {formatStat(dpStats[0])} Min {formatStat(dpStats[1])} Max {formatStat(dpStats[2])} Std {formatStat(dpStats[3])} + + Show hot px: + setHotPixelFilter(!e.target.checked)} size="small" sx={switchStyles.small} /> + )} @@ -2421,23 +2548,23 @@ function Show4DSTEM() { {/* Right: Histogram spanning both rows */} - { setDpVminPct(min); setDpVmaxPct(max); }} width={110} height={58} theme={themeInfo.theme} /> + { setDpVminPct(min); setDpVmaxPct(max); }} width={110} height={58} theme={themeInfo.theme} dataMin={dpGlobalMin} dataMax={dpGlobalMax} /> - {/* RIGHT COLUMN: VI Panel + FFT (when shown) */} + {/* SECOND COLUMN: VI Panel */} {/* VI Header */} - Virtual Image + Virtual Image {shapeX}×{shapeY} | {detX}×{detY} FFT: setShowFft(e.target.checked)} size="small" sx={switchStyles.small} /> - + @@ -2525,31 +2652,82 @@ function Show4DSTEM() { {/* Right: Histogram spanning both rows */} - { setViVminPct(min); setViVmaxPct(max); }} width={110} height={58} theme={themeInfo.theme} /> + { setViVminPct(min); setViVmaxPct(max); }} width={110} height={58} theme={themeInfo.theme} dataMin={viDataMin} dataMax={viDataMax} /> + - {/* FFT Panel (conditionally shown) */} - {showFft && ( - - - FFT - + {/* THIRD COLUMN: FFT Panel (conditionally shown) */} + {showFft && ( + + {/* FFT Header */} + + FFT + + - - - + + + {/* FFT Canvas */} + + + + + + {/* FFT Stats Bar */} + {fftStats && fftStats.length === 4 && ( + + Mean {formatStat(fftStats[0])} + Min {formatStat(fftStats[1])} + Max {formatStat(fftStats[2])} + Std {formatStat(fftStats[3])} + + )} + + {/* FFT Controls - Two rows with histogram on right */} + + {/* Left: Two rows of controls */} + + {/* Row 1: Scale + Clip */} + + Scale: + + Auto: + setFftAuto(e.target.checked)} size="small" sx={switchStyles.small} /> + + {/* Row 2: Color */} + + Color: + + + + {/* Right: Histogram spanning both rows */} + + {fftHistogramData && ( + { setFftVminPct(min); setFftVmaxPct(max); }} width={110} height={58} theme={themeInfo.theme} dataMin={fftDataMin} dataMax={fftDataMax} /> + )} - )} - + + )} {/* BOTTOM CONTROLS - Path only (FFT toggle moved to VI panel) */} diff --git a/widget/package.json b/widget/package.json index 9144270a..4e63363a 100644 --- a/widget/package.json +++ b/widget/package.json @@ -11,8 +11,7 @@ "@mui/material": "^7.3.6", "jszip": "^3.10.1", "react": "^19.1.0", - "react-dom": "^19.1.0", - "webfft": "^1.0.3" + "react-dom": "^19.1.0" }, "devDependencies": { "@types/react": "^19.1.3", diff --git a/widget/src/quantem/widget/show4dstem.py b/widget/src/quantem/widget/show4dstem.py index 58ca82ea..751574c2 100644 --- a/widget/src/quantem/widget/show4dstem.py +++ b/widget/src/quantem/widget/show4dstem.py @@ -11,8 +11,10 @@ import anywidget import numpy as np +import torch import traitlets +from quantem.core.config import validate_device from quantem.widget.array_utils import to_numpy @@ -159,6 +161,7 @@ class Show4DSTEM(anywidget.AnyWidget): dp_stats = traitlets.List(traitlets.Float(), default_value=[0.0, 0.0, 0.0, 0.0]).tag(sync=True) vi_stats = traitlets.List(traitlets.Float(), default_value=[0.0, 0.0, 0.0, 0.0]).tag(sync=True) mask_dc = traitlets.Bool(True).tag(sync=True) # Mask center pixel for DP stats + hot_pixel_filter = traitlets.Bool(True).tag(sync=True) # Filter hot pixels (default on) def __init__( self, @@ -168,7 +171,7 @@ def __init__( k_pixel_size: float | None = None, center: tuple[float, float] | None = None, bf_radius: float | None = None, - precompute_virtual_images: bool = True, + precompute_virtual_images: bool = False, log_scale: bool = False, **kwargs, ): @@ -197,8 +200,11 @@ def __init__( self.k_calibrated = k_calibrated or (k_pixel_size is not None) # Path animation (configured via set_path() or raster()) self._path_points: list[tuple[int, int]] = [] - # Convert to NumPy - self._data = to_numpy(data) + # Convert to NumPy then PyTorch tensor using quantem device config + data_np = to_numpy(data) + device_str, _ = validate_device(None) # Get device from quantem config + self._device = torch.device(device_str) + self._data = torch.from_numpy(data_np.astype(np.float32)).to(self._device) # Handle flattened data if data.ndim == 3: if scan_shape is not None: @@ -266,7 +272,8 @@ def __init__( self._precompute_common_virtual_images() # Update frame when position changes (scale/colormap handled in JS) - self.observe(self._update_frame, names=["pos_x", "pos_y"]) + self.observe(self._update_frame, names=["pos_x", "pos_y", "hot_pixel_filter"]) + self.observe(self._compute_global_range, names=["hot_pixel_filter"]) self.observe(self._on_roi_change, names=[ "roi_center_x", "roi_center_y", "roi_radius", "roi_radius_inner", "roi_active", "roi_mode", "roi_width", "roi_height" @@ -655,11 +662,11 @@ def auto_detect_center(self) -> "Show4DSTEM": def _auto_detect_center_silent(self): """Auto-detect center without updating ROI (used during __init__).""" - # Sum all diffraction patterns to get average + # Sum all frames (fast on GPU) if self._data.ndim == 4: - summed_dp = self._data.sum(axis=(0, 1)).astype(np.float64) + summed_dp = self._data.sum(dim=(0, 1)) else: - summed_dp = self._data.sum(axis=0).astype(np.float64) + summed_dp = self._data.sum(dim=0) # Threshold at mean + std to isolate BF disk threshold = summed_dp.mean() + summed_dp.std() @@ -670,23 +677,25 @@ def _auto_detect_center_silent(self): if total == 0: return - # Calculate centroid using meshgrid - y_coords, x_coords = np.meshgrid( - np.arange(self.det_x), np.arange(self.det_y), indexing='ij' + # Calculate centroid + y_coords, x_coords = torch.meshgrid( + torch.arange(self.det_x, device=self._device), + torch.arange(self.det_y, device=self._device), + indexing='ij' ) cx = float((x_coords * mask).sum() / total) cy = float((y_coords * mask).sum() / total) # Estimate radius from mask area (A = pi*r^2) - radius = float(np.sqrt(total / np.pi)) + radius = float(torch.sqrt(total / torch.pi)) # Apply detected values (don't update ROI - that happens later in __init__) self.center_x = cx self.center_y = cy self.bf_radius = radius - def _compute_global_range(self): - """Compute global min/max from sampled frames for consistent scaling.""" + def _compute_global_range(self, change=None): + """Compute min/max for histogram range. Uses percentiles when hot_pixel_filter is ON.""" # Sample corners and center samples = [ @@ -697,32 +706,47 @@ def _compute_global_range(self): (self.shape_x // 2, self.shape_y // 2), ] - all_min, all_max = float("inf"), float("-inf") + # Collect all pixel values from sampled frames + all_values = [] for x, y in samples: frame = self._get_frame(x, y) - fmin = float(frame.min()) - fmax = float(frame.max()) - all_min = min(all_min, fmin) - all_max = max(all_max, fmax) + all_values.append(frame.ravel()) - # Set traits for JS-side normalization - self.dp_global_min = max(all_min, 1e-10) - self.dp_global_max = all_max + all_values = np.concatenate(all_values) + + if self.hot_pixel_filter: + # Use 99.99 percentile to exclude hot pixels + p_high = np.percentile(all_values, 99.99) + self.dp_global_min = max(float(all_values.min()), 1e-10) + self.dp_global_max = float(p_high) + else: + # Use actual min/max (show hot pixels) + self.dp_global_min = max(float(all_values.min()), 1e-10) + self.dp_global_max = float(all_values.max()) - def _get_frame(self, x: int, y: int): - """Get single diffraction frame at position (x, y).""" + def _get_frame(self, x: int, y: int) -> np.ndarray: + """Get single diffraction frame at position (x, y) as numpy array.""" if self._data.ndim == 3: idx = x * self.shape_y + y - return self._data[idx] + return self._data[idx].cpu().numpy() else: - return self._data[x, y] + return self._data[x, y].cpu().numpy() + + def _filter_hot_pixels(self, frame: np.ndarray) -> np.ndarray: + """Filter hot pixels by clipping to 99.99th percentile.""" + threshold = np.percentile(frame, 99.99) + return np.clip(frame, None, threshold) def _update_frame(self, change=None): """Send raw float32 frame to frontend (JS handles scale/colormap).""" frame = self._get_frame(self.pos_x, self.pos_y) frame = frame.astype(np.float32) - # Compute stats from raw frame (optionally mask DC component) + # Apply hot pixel filtering if enabled + if self.hot_pixel_filter: + frame = self._filter_hot_pixels(frame) + + # Compute stats from frame (optionally mask DC component) if self.mask_dc: # Mask center 3x3 region for stats cx, cy = self.det_x // 2, self.det_y // 2 @@ -884,13 +908,36 @@ def _get_cached_preset(self) -> tuple[bytes, list[float], float, float] | None: return None def _fast_masked_sum(self, mask) -> "np.ndarray": - """Masked sum over detector dimensions.""" - if self._data.ndim == 4: - return (self._data.astype(np.float32) * mask).sum(axis=(2, 3)) - return (self._data.astype(np.float32) * mask).sum(axis=(1, 2)).reshape(self._scan_shape) + """Compute masked sum using PyTorch. + + Uses sparse indexing for small masks (<20% coverage) which is faster + because it only processes non-zero pixels: + - r=10 (1%): ~0.8ms (sparse) vs ~13ms (full) + - r=30 (8%): ~4ms (sparse) vs ~13ms (full) + + For large masks (≥20%), uses full tensordot which has constant ~13ms. + """ + mask_tensor = torch.from_numpy(mask.astype(np.float32)).to(self._device) + n_det = self._det_shape[0] * self._det_shape[1] + n_nonzero = int(mask_tensor.sum()) + coverage = n_nonzero / n_det + + if coverage < 0.2: + # Sparse: faster for small masks + indices = torch.nonzero(mask_tensor.flatten(), as_tuple=True)[0] + n_scan = self._scan_shape[0] * self._scan_shape[1] + data_flat = self._data.reshape(n_scan, n_det) + result = data_flat[:, indices].sum(dim=1).reshape(self._scan_shape) + else: + # Tensordot: faster for large masks + result = torch.tensordot(self._data, mask_tensor, dims=([2, 3], [0, 1])) + + return result.cpu().numpy() - def _to_float32_bytes(self, arr: "np.ndarray", update_vi_stats: bool = True) -> bytes: + def _to_float32_bytes(self, arr, update_vi_stats: bool = True) -> bytes: """Convert array to float32 bytes (JS handles scale/colormap).""" + if isinstance(arr, torch.Tensor): + arr = arr.cpu().numpy() arr = arr.astype(np.float32) # Set min/max for JS-side normalization From fa41d2712dac86207110776dda6c298da35a1ed4 Mon Sep 17 00:00:00 2001 From: Sangjoon Bob Lee Date: Thu, 22 Jan 2026 07:46:23 -0800 Subject: [PATCH 23/27] remove deadcode --- widget/js/show4dstem/index.tsx | 115 ++------------------------------- 1 file changed, 6 insertions(+), 109 deletions(-) diff --git a/widget/js/show4dstem/index.tsx b/widget/js/show4dstem/index.tsx index 6cd8a8d3..630a0138 100644 --- a/widget/js/show4dstem/index.tsx +++ b/widget/js/show4dstem/index.tsx @@ -202,15 +202,6 @@ function fftshift(data: Float32Array, width: number, height: number): void { data.set(temp); } -function applyBandPassFilter(real: Float32Array, imag: Float32Array, width: number, height: number, innerRadius: number, outerRadius: number) { - const centerX = width >> 1, centerY = height >> 1; - const innerSq = innerRadius * innerRadius, outerSq = outerRadius * outerRadius; - for (let y = 0; y < height; y++) for (let x = 0; x < width; x++) { - const distSq = (x - centerX) ** 2 + (y - centerY) ** 2; - if (distSq < innerSq || (outerRadius > 0 && distSq > outerSq)) { real[y * width + x] = 0; imag[y * width + x] = 0; } - } -} - // ============================================================================ // WebGPU FFT - GPU-accelerated FFT when available // ============================================================================ @@ -1277,9 +1268,9 @@ function Show4DSTEM() { const [viVmaxPct, setViVmaxPct] = React.useState(100); // Scale mode: "linear" | "log" | "power" const [dpScaleMode, setDpScaleMode] = React.useState<"linear" | "log" | "power">("linear"); - const [dpPowerExp, setDpPowerExp] = React.useState(0.5); + const dpPowerExp = 0.5; const [viScaleMode, setViScaleMode] = React.useState<"linear" | "log" | "power">("linear"); - const [viPowerExp, setViPowerExp] = React.useState(0.5); + const viPowerExp = 0.5; // VI ROI state (real-space region selection for summed DP) - synced with Python const [viRoiMode, setViRoiMode] = useModelState("vi_roi_mode"); @@ -1369,11 +1360,6 @@ function Show4DSTEM() { setDpHistogramData(scaledData); }, [frameBytes, dpScaleMode, dpPowerExp]); - // Band-pass filter range [innerCutoff, outerCutoff] in pixels - [0, 0] means disabled - const [bandpass, setBandpass] = React.useState([0, 0]); - const bpInner = bandpass[0]; - const bpOuter = bandpass[1]; - // GPU FFT state const gpuFFTRef = React.useRef(null); const [gpuReady, setGpuReady] = React.useState(false); @@ -1528,8 +1514,6 @@ function Show4DSTEM() { // Store raw data for filtering/FFT const rawVirtualImageRef = React.useRef(null); - const viWorkRealRef = React.useRef(null); - const viWorkImagRef = React.useRef(null); const fftWorkRealRef = React.useRef(null); const fftWorkImagRef = React.useRef(null); const fftMagnitudeRef = React.useRef(null); @@ -1753,70 +1737,9 @@ function Show4DSTEM() { ctx.restore(); }; - if (bpInner > 0 || bpOuter > 0) { - if (gpuFFTRef.current && gpuReady) { - // GPU filtering (Async) - const real = rawVirtualImageRef.current.slice(); - const imag = new Float32Array(real.length); - - // We use a local flag to prevent state updates if the effect has already re-run - let isCancelled = false; - - const runGpuFilter = async () => { - // WebGPU version of: Forward -> Filter -> Inverse - // Note: The provided WebGPUFFT doesn't have shift/unshift built-in yet, - // but we can apply the filter in shifted coordinates or modify it. - // For now, let's keep it simple: Forward -> Filter -> Inverse. - const { real: fReal, imag: fImag } = await gpuFFTRef.current!.fft2D(real, imag, width, height, false); - - if (isCancelled) return; - - // Shift in CPU for now (future: do this in WGSL) - fftshift(fReal, width, height); - fftshift(fImag, width, height); - applyBandPassFilter(fReal, fImag, width, height, bpInner, bpOuter); - fftshift(fReal, width, height); - fftshift(fImag, width, height); - - const { real: invReal } = await gpuFFTRef.current!.fft2D(fReal, fImag, width, height, true); - - if (!isCancelled) renderData(invReal); - }; - - runGpuFilter(); - return () => { isCancelled = true; }; - } else { - // CPU Fallback (Sync) - const source = rawVirtualImageRef.current; - if (!source) return; - const len = source.length; - let real = viWorkRealRef.current; - if (!real || real.length !== len) { - real = new Float32Array(len); - viWorkRealRef.current = real; - } - real.set(source); - let imag = viWorkImagRef.current; - if (!imag || imag.length !== len) { - imag = new Float32Array(len); - viWorkImagRef.current = imag; - } else { - imag.fill(0); - } - fft2d(real, imag, width, height, false); - fftshift(real, width, height); - fftshift(imag, width, height); - applyBandPassFilter(real, imag, width, height, bpInner, bpOuter); - fftshift(real, width, height); - fftshift(imag, width, height); - fft2d(real, imag, width, height, true); - renderData(real); - } - } else { - if (!rawVirtualImageRef.current) return; - renderData(rawVirtualImageRef.current); - } - }, [virtualImageBytes, shapeX, shapeY, viColormap, viVminPct, viVmaxPct, viScaleMode, viPowerExp, viZoom, viPanX, viPanY, bpInner, bpOuter, gpuReady]); + if (!rawVirtualImageRef.current) return; + renderData(rawVirtualImageRef.current); + }, [virtualImageBytes, shapeX, shapeY, viColormap, viVminPct, viVmaxPct, viScaleMode, viPowerExp, viZoom, viPanX, viPanY]); // Render virtual image overlay (just clear - crosshair drawn on high-DPI UI canvas) React.useEffect(() => { @@ -2007,33 +1930,7 @@ function Show4DSTEM() { const ctx = canvas.getContext("2d"); if (!ctx) return; ctx.clearRect(0, 0, canvas.width, canvas.height); - if (!showFft) return; - - // Draw band-pass filter circles (inner = HP, outer = LP) - const centerX = (shapeY / 2) * fftZoom + fftPanX; - const centerY = (shapeX / 2) * fftZoom + fftPanY; - const minScanSize = Math.min(shapeX, shapeY); - const fftLineWidth = Math.max(LINE_WIDTH_MIN_PX, Math.min(LINE_WIDTH_MAX_PX, minScanSize * LINE_WIDTH_FRACTION)); - - if (bpInner > 0) { - ctx.strokeStyle = "rgba(255, 0, 0, 0.8)"; - ctx.lineWidth = fftLineWidth; - ctx.setLineDash([5, 5]); - ctx.beginPath(); - ctx.arc(centerX, centerY, bpInner * fftZoom, 0, 2 * Math.PI); - ctx.stroke(); - ctx.setLineDash([]); - } - if (bpOuter > 0) { - ctx.strokeStyle = "rgba(0, 150, 255, 0.8)"; - ctx.lineWidth = fftLineWidth; - ctx.setLineDash([5, 5]); - ctx.beginPath(); - ctx.arc(centerX, centerY, bpOuter * fftZoom, 0, 2 * Math.PI); - ctx.stroke(); - ctx.setLineDash([]); - } - }, [fftZoom, fftPanX, fftPanY, pixelSize, shapeX, shapeY, bpInner, bpOuter, showFft]); + }, [fftZoom, fftPanX, fftPanY, showFft]); // ───────────────────────────────────────────────────────────────────────── // High-DPI Scale Bar UI Overlays From 2039076d6eea4014d8196f15dd741a7f486aff94 Mon Sep 17 00:00:00 2001 From: Sangjoon Bob Lee Date: Sun, 25 Jan 2026 13:44:15 -0800 Subject: [PATCH 24/27] cache detector coordinate for VI movement for higher FPS --- widget/js/show4dstem/index.tsx | 5 - widget/src/quantem/widget/show4dstem.py | 193 ++++++++++-------------- widget/tests/test_widget.py | 109 ------------- 3 files changed, 81 insertions(+), 226 deletions(-) delete mode 100644 widget/tests/test_widget.py diff --git a/widget/js/show4dstem/index.tsx b/widget/js/show4dstem/index.tsx index 630a0138..75a7183c 100644 --- a/widget/js/show4dstem/index.tsx +++ b/widget/js/show4dstem/index.tsx @@ -1286,7 +1286,6 @@ function Show4DSTEM() { const [summedDpCount] = useModelState("summed_dp_count"); const [dpStats] = useModelState("dp_stats"); // [mean, min, max, std] const [viStats] = useModelState("vi_stats"); // [mean, min, max, std] - const [hotPixelFilter, setHotPixelFilter] = useModelState("hot_pixel_filter"); const [showFft, setShowFft] = React.useState(false); // Hidden by default per feedback // Theme detection - detect environment and light/dark mode @@ -2372,10 +2371,6 @@ function Show4DSTEM() { Min {formatStat(dpStats[1])} Max {formatStat(dpStats[2])} Std {formatStat(dpStats[3])} - - Show hot px: - setHotPixelFilter(!e.target.checked)} size="small" sx={switchStyles.small} /> - )} diff --git a/widget/src/quantem/widget/show4dstem.py b/widget/src/quantem/widget/show4dstem.py index 751574c2..d44b41f9 100644 --- a/widget/src/quantem/widget/show4dstem.py +++ b/widget/src/quantem/widget/show4dstem.py @@ -161,7 +161,6 @@ class Show4DSTEM(anywidget.AnyWidget): dp_stats = traitlets.List(traitlets.Float(), default_value=[0.0, 0.0, 0.0, 0.0]).tag(sync=True) vi_stats = traitlets.List(traitlets.Float(), default_value=[0.0, 0.0, 0.0, 0.0]).tag(sync=True) mask_dc = traitlets.Bool(True).tag(sync=True) # Mask center pixel for DP stats - hot_pixel_filter = traitlets.Bool(True).tag(sync=True) # Filter hot pixels (default on) def __init__( self, @@ -205,6 +204,10 @@ def __init__( device_str, _ = validate_device(None) # Get device from quantem config self._device = torch.device(device_str) self._data = torch.from_numpy(data_np.astype(np.float32)).to(self._device) + # Remove saturated hot pixels (65535 for uint16, 255 for uint8) + saturated_value = 65535.0 if data_np.dtype == np.uint16 else 255.0 if data_np.dtype == np.uint8 else None + if saturated_value is not None: + self._data[self._data >= saturated_value] = 0 # Handle flattened data if data.ndim == 3: if scan_shape is not None: @@ -263,6 +266,12 @@ def __init__( # Auto-detect center and bf_radius from the data self._auto_detect_center_silent() + # Cache coordinate tensors for mask creation (avoid repeated torch.arange) + self._det_y_coords = torch.arange(self.det_x, device=self._device, dtype=torch.float32)[:, None] + self._det_x_coords = torch.arange(self.det_y, device=self._device, dtype=torch.float32)[None, :] + self._scan_row_coords = torch.arange(self.shape_x, device=self._device, dtype=torch.float32)[:, None] + self._scan_col_coords = torch.arange(self.shape_y, device=self._device, dtype=torch.float32)[None, :] + # Pre-compute and cache common virtual images (BF, ABF, ADF) # Each cache stores (bytes, stats) tuple self._cached_bf_virtual = None @@ -272,8 +281,7 @@ def __init__( self._precompute_common_virtual_images() # Update frame when position changes (scale/colormap handled in JS) - self.observe(self._update_frame, names=["pos_x", "pos_y", "hot_pixel_filter"]) - self.observe(self._compute_global_range, names=["hot_pixel_filter"]) + self.observe(self._update_frame, names=["pos_x", "pos_y"]) self.observe(self._on_roi_change, names=[ "roi_center_x", "roi_center_y", "roi_radius", "roi_radius_inner", "roi_active", "roi_mode", "roi_width", "roi_height" @@ -621,11 +629,11 @@ def auto_detect_center(self) -> "Show4DSTEM": >>> widget = Show4DSTEM(data) >>> widget.auto_detect_center() # Auto-detect and apply """ - # Sum all diffraction patterns to get average + # Sum all diffraction patterns to get average (PyTorch) if self._data.ndim == 4: - summed_dp = self._data.sum(axis=(0, 1)).astype(np.float64) + summed_dp = self._data.sum(dim=(0, 1)) else: - summed_dp = self._data.sum(axis=0).astype(np.float64) + summed_dp = self._data.sum(dim=0) # Threshold at mean + std to isolate BF disk threshold = summed_dp.mean() + summed_dp.std() @@ -636,15 +644,12 @@ def auto_detect_center(self) -> "Show4DSTEM": if total == 0: return self - # Calculate centroid using meshgrid - y_coords, x_coords = np.meshgrid( - np.arange(self.det_x), np.arange(self.det_y), indexing='ij' - ) - cx = float((x_coords * mask).sum() / total) - cy = float((y_coords * mask).sum() / total) + # Calculate centroid using cached coordinates + cx = float((self._det_x_coords * mask).sum() / total) + cy = float((self._det_y_coords * mask).sum() / total) # Estimate radius from mask area (A = pi*r^2) - radius = float(np.sqrt(total / np.pi)) + radius = float(torch.sqrt(total / torch.pi)) # Apply detected values self.center_x = cx @@ -695,34 +700,9 @@ def _auto_detect_center_silent(self): self.bf_radius = radius def _compute_global_range(self, change=None): - """Compute min/max for histogram range. Uses percentiles when hot_pixel_filter is ON.""" - - # Sample corners and center - samples = [ - (0, 0), - (0, self.shape_y - 1), - (self.shape_x - 1, 0), - (self.shape_x - 1, self.shape_y - 1), - (self.shape_x // 2, self.shape_y // 2), - ] - - # Collect all pixel values from sampled frames - all_values = [] - for x, y in samples: - frame = self._get_frame(x, y) - all_values.append(frame.ravel()) - - all_values = np.concatenate(all_values) - - if self.hot_pixel_filter: - # Use 99.99 percentile to exclude hot pixels - p_high = np.percentile(all_values, 99.99) - self.dp_global_min = max(float(all_values.min()), 1e-10) - self.dp_global_max = float(p_high) - else: - # Use actual min/max (show hot pixels) - self.dp_global_min = max(float(all_values.min()), 1e-10) - self.dp_global_max = float(all_values.max()) + """Compute min/max for histogram range. Hot pixels already removed at init.""" + self.dp_global_min = max(float(self._data.min()), 1e-10) + self.dp_global_max = float(self._data.max()) def _get_frame(self, x: int, y: int) -> np.ndarray: """Get single diffraction frame at position (x, y) as numpy array.""" @@ -732,31 +712,31 @@ def _get_frame(self, x: int, y: int) -> np.ndarray: else: return self._data[x, y].cpu().numpy() - def _filter_hot_pixels(self, frame: np.ndarray) -> np.ndarray: - """Filter hot pixels by clipping to 99.99th percentile.""" - threshold = np.percentile(frame, 99.99) - return np.clip(frame, None, threshold) - def _update_frame(self, change=None): """Send raw float32 frame to frontend (JS handles scale/colormap).""" - frame = self._get_frame(self.pos_x, self.pos_y) - frame = frame.astype(np.float32) + # Get frame as tensor (stays on device) + if self._data.ndim == 3: + idx = self.pos_x * self.shape_y + self.pos_y + frame = self._data[idx] + else: + frame = self._data[self.pos_x, self.pos_y] - # Apply hot pixel filtering if enabled - if self.hot_pixel_filter: - frame = self._filter_hot_pixels(frame) + # Apply log scale if enabled + if self.log_scale: + frame = torch.log1p(frame) # Compute stats from frame (optionally mask DC component) - if self.mask_dc: - # Mask center 3x3 region for stats + if self.mask_dc and self.det_x > 3 and self.det_y > 3: + # Mask center 3x3 region for stats (only for detectors > 3x3) cx, cy = self.det_x // 2, self.det_y // 2 - stats_frame = frame.copy() - stats_frame[max(0, cx-1):cx+2, max(0, cy-1):cy+2] = np.nan + mask = torch.ones_like(frame, dtype=torch.bool) + mask[max(0, cx-1):cx+2, max(0, cy-1):cy+2] = False + masked_vals = frame[mask] self.dp_stats = [ - float(np.nanmean(stats_frame)), - float(np.nanmin(stats_frame)), - float(np.nanmax(stats_frame)), - float(np.nanstd(stats_frame)), + float(masked_vals.mean()), + float(masked_vals.min()), + float(masked_vals.max()), + float(masked_vals.std()), ] else: self.dp_stats = [ @@ -766,8 +746,8 @@ def _update_frame(self, change=None): float(frame.std()), ] - # Send raw float32 bytes (JS handles scale/normalization/colormap) - self.frame_bytes = frame.tobytes() + # Convert to numpy only for sending bytes to frontend + self.frame_bytes = frame.cpu().numpy().astype(np.float32).tobytes() def _on_roi_change(self, change=None): """Recompute virtual image when ROI changes.""" @@ -784,77 +764,69 @@ def _on_vi_roi_change(self, change=None): self._compute_summed_dp_from_vi_roi() def _compute_summed_dp_from_vi_roi(self): - """Sum diffraction patterns from positions inside VI ROI.""" - # Create mask in scan space - # y (rows) corresponds to vi_roi_center_x, x (cols) corresponds to vi_roi_center_y - rows, cols = np.ogrid[:self.shape_x, :self.shape_y] - + """Sum diffraction patterns from positions inside VI ROI (PyTorch).""" + # Create mask in scan space using cached coordinates if self.vi_roi_mode == "circle": - mask = (rows - self.vi_roi_center_x) ** 2 + (cols - self.vi_roi_center_y) ** 2 <= self.vi_roi_radius ** 2 + mask = (self._scan_row_coords - self.vi_roi_center_x) ** 2 + (self._scan_col_coords - self.vi_roi_center_y) ** 2 <= self.vi_roi_radius ** 2 elif self.vi_roi_mode == "square": - # Square uses vi_roi_radius as half-size half_size = self.vi_roi_radius - mask = (np.abs(rows - self.vi_roi_center_x) <= half_size) & (np.abs(cols - self.vi_roi_center_y) <= half_size) + mask = (torch.abs(self._scan_row_coords - self.vi_roi_center_x) <= half_size) & (torch.abs(self._scan_col_coords - self.vi_roi_center_y) <= half_size) elif self.vi_roi_mode == "rect": half_w = self.vi_roi_width / 2 half_h = self.vi_roi_height / 2 - mask = (np.abs(rows - self.vi_roi_center_x) <= half_h) & (np.abs(cols - self.vi_roi_center_y) <= half_w) + mask = (torch.abs(self._scan_row_coords - self.vi_roi_center_x) <= half_h) & (torch.abs(self._scan_col_coords - self.vi_roi_center_y) <= half_w) else: return - # Get positions inside mask - positions = np.argwhere(mask) - if len(positions) == 0: + # Count positions in mask + n_positions = int(mask.sum()) + if n_positions == 0: self.summed_dp_bytes = b"" self.summed_dp_count = 0 return - # Average DPs from all positions (average is more useful than sum) - # positions from argwhere are (row, col) = (x_idx, y_idx) in our naming - summed_dp = np.zeros((self.det_x, self.det_y), dtype=np.float64) - for row_idx, col_idx in positions: - summed_dp += self._get_frame(row_idx, col_idx).astype(np.float64) + self.summed_dp_count = n_positions - self.summed_dp_count = len(positions) - - # Convert to average - avg_dp = summed_dp / len(positions) + # Compute average DP using masked sum (vectorized) + if self._data.ndim == 4: + # (scan_x, scan_y, det_x, det_y) - sum over masked scan positions + avg_dp = self._data[mask].mean(dim=0) + else: + # Flattened: (N, det_x, det_y) - need to convert mask indices + flat_indices = torch.nonzero(mask.flatten(), as_tuple=True)[0] + avg_dp = self._data[flat_indices].mean(dim=0) # Normalize to 0-255 for display - vmin, vmax = avg_dp.min(), avg_dp.max() + vmin, vmax = float(avg_dp.min()), float(avg_dp.max()) if vmax > vmin: - normalized = np.clip((avg_dp - vmin) / (vmax - vmin) * 255, 0, 255) - normalized = normalized.astype(np.uint8) + normalized = torch.clamp((avg_dp - vmin) / (vmax - vmin) * 255, 0, 255) + normalized = normalized.cpu().numpy().astype(np.uint8) else: - normalized = np.zeros(avg_dp.shape, dtype=np.uint8) + normalized = np.zeros((self.det_x, self.det_y), dtype=np.uint8) self.summed_dp_bytes = normalized.tobytes() def _create_circular_mask(self, cx: float, cy: float, radius: float): - """Create circular mask (boolean).""" - y, x = np.ogrid[:self.det_x, :self.det_y] - mask = (x - cx) ** 2 + (y - cy) ** 2 <= radius ** 2 + """Create circular mask (boolean tensor on device).""" + mask = (self._det_x_coords - cx) ** 2 + (self._det_y_coords - cy) ** 2 <= radius ** 2 return mask def _create_square_mask(self, cx: float, cy: float, half_size: float): - """Create square mask (boolean).""" - y, x = np.ogrid[:self.det_x, :self.det_y] - mask = (np.abs(x - cx) <= half_size) & (np.abs(y - cy) <= half_size) + """Create square mask (boolean tensor on device).""" + mask = (torch.abs(self._det_x_coords - cx) <= half_size) & (torch.abs(self._det_y_coords - cy) <= half_size) return mask def _create_annular_mask( self, cx: float, cy: float, inner: float, outer: float ): - """Create annular (donut) mask (boolean).""" - y, x = np.ogrid[:self.det_x, :self.det_y] - dist_sq = (x - cx) ** 2 + (y - cy) ** 2 + """Create annular (donut) mask (boolean tensor on device).""" + dist_sq = (self._det_x_coords - cx) ** 2 + (self._det_y_coords - cy) ** 2 mask = (dist_sq >= inner ** 2) & (dist_sq <= outer ** 2) return mask def _create_rect_mask(self, cx: float, cy: float, half_width: float, half_height: float): - """Create rectangular mask (boolean).""" - y, x = np.ogrid[:self.det_x, :self.det_y] - mask = (np.abs(x - cx) <= half_width) & (np.abs(y - cy) <= half_height) + """Create rectangular mask (boolean tensor on device).""" + mask = (torch.abs(self._det_x_coords - cx) <= half_width) & (torch.abs(self._det_y_coords - cy) <= half_height) return mask def _precompute_common_virtual_images(self): @@ -907,7 +879,7 @@ def _get_cached_preset(self) -> tuple[bytes, list[float], float, float] | None: return None - def _fast_masked_sum(self, mask) -> "np.ndarray": + def _fast_masked_sum(self, mask: torch.Tensor) -> torch.Tensor: """Compute masked sum using PyTorch. Uses sparse indexing for small masks (<20% coverage) which is faster @@ -917,30 +889,26 @@ def _fast_masked_sum(self, mask) -> "np.ndarray": For large masks (≥20%), uses full tensordot which has constant ~13ms. """ - mask_tensor = torch.from_numpy(mask.astype(np.float32)).to(self._device) + mask_float = mask.float() n_det = self._det_shape[0] * self._det_shape[1] - n_nonzero = int(mask_tensor.sum()) + n_nonzero = int(mask.sum()) coverage = n_nonzero / n_det if coverage < 0.2: # Sparse: faster for small masks - indices = torch.nonzero(mask_tensor.flatten(), as_tuple=True)[0] + indices = torch.nonzero(mask_float.flatten(), as_tuple=True)[0] n_scan = self._scan_shape[0] * self._scan_shape[1] data_flat = self._data.reshape(n_scan, n_det) result = data_flat[:, indices].sum(dim=1).reshape(self._scan_shape) else: # Tensordot: faster for large masks - result = torch.tensordot(self._data, mask_tensor, dims=([2, 3], [0, 1])) - - return result.cpu().numpy() + result = torch.tensordot(self._data, mask_float, dims=([2, 3], [0, 1])) - def _to_float32_bytes(self, arr, update_vi_stats: bool = True) -> bytes: - """Convert array to float32 bytes (JS handles scale/colormap).""" - if isinstance(arr, torch.Tensor): - arr = arr.cpu().numpy() - arr = arr.astype(np.float32) + return result - # Set min/max for JS-side normalization + def _to_float32_bytes(self, arr: torch.Tensor, update_vi_stats: bool = True) -> bytes: + """Convert tensor to float32 bytes (JS handles scale/colormap).""" + # Compute stats using PyTorch (on GPU) if update_vi_stats: self.vi_data_min = float(arr.min()) self.vi_data_max = float(arr.max()) @@ -951,7 +919,8 @@ def _to_float32_bytes(self, arr, update_vi_stats: bool = True) -> bytes: float(arr.std()), ] - return arr.tobytes() + # Convert to numpy only for sending bytes to frontend + return arr.cpu().numpy().astype(np.float32).tobytes() def _compute_virtual_image_from_roi(self): """Compute virtual image based on ROI mode.""" diff --git a/widget/tests/test_widget.py b/widget/tests/test_widget.py deleted file mode 100644 index 45524bc2..00000000 --- a/widget/tests/test_widget.py +++ /dev/null @@ -1,109 +0,0 @@ -import numpy as np - -import quantem.widget -from quantem.widget import Show4DSTEM - - -def test_version_exists(): - assert hasattr(quantem.widget, "__version__") - - -def test_version_is_string(): - assert isinstance(quantem.widget.__version__, str) - - -def test_show4dstem_loads(): - """Widget can be created from mock 4D data.""" - data = np.random.rand(8, 8, 16, 16).astype(np.float32) - widget = Show4DSTEM(data) - assert widget is not None - - -def test_show4dstem_flattened_scan_shape_mapping(): - data = np.zeros((6, 2, 2), dtype=np.float32) - for idx in range(data.shape[0]): - data[idx] = idx - - widget = Show4DSTEM(data, scan_shape=(2, 3)) - assert (widget.shape_x, widget.shape_y) == (2, 3) - assert (widget.det_x, widget.det_y) == (2, 2) - frame = widget._get_frame(1, 2) - assert np.array_equal(frame, np.full((2, 2), 5, dtype=np.float32)) - - -def test_log_scale_changes_frame_bytes(): - data = np.array([[[[0, 1], [3, 7]]]], dtype=np.float32) - widget = Show4DSTEM(data, log_scale=True) - log_bytes = bytes(widget.frame_bytes) - - widget.log_scale = False - widget._update_frame() - linear_bytes = bytes(widget.frame_bytes) - - assert log_bytes != linear_bytes - - -def test_auto_detect_center(): - """Test automatic center spot detection using centroid.""" - # Create data with a bright spot at (3, 3) in a 7x7 detector - data = np.zeros((2, 2, 7, 7), dtype=np.float32) - # Add a bright circular spot centered at (3, 3) - for i in range(7): - for j in range(7): - dist = np.sqrt((i - 3) ** 2 + (j - 3) ** 2) - if dist <= 1.5: - data[:, :, i, j] = 100.0 - - widget = Show4DSTEM(data, precompute_virtual_images=False) - # Initial center should be at detector center (3.5, 3.5) - assert widget.center_x == 3.5 - assert widget.center_y == 3.5 - - # Run auto-detection - widget.auto_detect_center() - - # Center should be detected near (3, 3) - assert abs(widget.center_x - 3.0) < 0.5 - assert abs(widget.center_y - 3.0) < 0.5 - # BF radius should be approximately sqrt(pi*r^2 / pi) = r ~ 1.5 - assert widget.bf_radius > 0 - - -def test_adf_preset_cache(): - """Test that ADF preset uses combined bf to 4*bf range.""" - data = np.random.rand(4, 4, 16, 16).astype(np.float32) - widget = Show4DSTEM(data, center=(8, 8), bf_radius=2) - - # Check that ADF cache exists (replaced LAADF/HAADF) - assert widget._cached_adf_virtual is not None - assert not hasattr(widget, "_cached_laadf_virtual") - assert not hasattr(widget, "_cached_haadf_virtual") - - # Set ROI to match ADF range - widget.roi_mode = "annular" - widget.roi_center_x = 8 - widget.roi_center_y = 8 - widget.roi_radius_inner = 2 # bf - widget.roi_radius = 8 # 4*bf - - # Should return cached value - cached = widget._get_cached_preset() - assert cached == widget._cached_adf_virtual - - -def test_rectangular_scan_shape(): - """Test that rectangular (non-square) scans work correctly.""" - # Non-square scan: 4 rows x 8 columns - data = np.random.rand(4, 8, 16, 16).astype(np.float32) - widget = Show4DSTEM(data) - - assert widget.shape_x == 4 - assert widget.shape_y == 8 - assert widget.det_x == 16 - assert widget.det_y == 16 - - # Verify frame retrieval works at corners - frame_00 = widget._get_frame(0, 0) - frame_37 = widget._get_frame(3, 7) - assert frame_00.shape == (16, 16) - assert frame_37.shape == (16, 16) From a55b6eca167cb56289252fbf9d0863846640042e Mon Sep 17 00:00:00 2001 From: Sangjoon Bob Lee Date: Sun, 25 Jan 2026 14:35:26 -0800 Subject: [PATCH 25/27] faster FPS, do not render twice when ROI changes --- widget/js/show4dstem/index.tsx | 9 +- widget/src/quantem/widget/show4dstem.py | 128 +++++++++++++----------- 2 files changed, 75 insertions(+), 62 deletions(-) diff --git a/widget/js/show4dstem/index.tsx b/widget/js/show4dstem/index.tsx index 75a7183c..63894354 100644 --- a/widget/js/show4dstem/index.tsx +++ b/widget/js/show4dstem/index.tsx @@ -2130,8 +2130,13 @@ function Show4DSTEM() { setRoiCenterY(Math.round(Math.max(0, Math.min(detX - 1, imgY)))); }; - const handleDpMouseUp = () => { setIsDraggingDP(false); setIsDraggingResize(false); setIsDraggingResizeInner(false); }; - const handleDpMouseLeave = () => { setIsDraggingDP(false); setIsDraggingResize(false); setIsDraggingResizeInner(false); setIsHoveringResize(false); setIsHoveringResizeInner(false); }; + const handleDpMouseUp = () => { + setIsDraggingDP(false); setIsDraggingResize(false); setIsDraggingResizeInner(false); + }; + const handleDpMouseLeave = () => { + setIsDraggingDP(false); setIsDraggingResize(false); setIsDraggingResizeInner(false); + setIsHoveringResize(false); setIsHoveringResizeInner(false); + }; const handleDpDoubleClick = () => { setDpZoom(1); setDpPanX(0); setDpPanY(0); }; const handleViMouseDown = (e: React.MouseEvent) => { diff --git a/widget/src/quantem/widget/show4dstem.py b/widget/src/quantem/widget/show4dstem.py index d44b41f9..1de13560 100644 --- a/widget/src/quantem/widget/show4dstem.py +++ b/widget/src/quantem/widget/show4dstem.py @@ -18,8 +18,13 @@ from quantem.widget.array_utils import to_numpy -# Detector geometry constant -DEFAULT_BF_RATIO = 0.125 # 1/8 of detector size +# ============================================================================ +# Constants +# ============================================================================ +DEFAULT_BF_RATIO = 0.125 # BF disk radius as fraction of detector size (1/8) +SPARSE_MASK_THRESHOLD = 0.2 # Use sparse indexing below this mask coverage +MIN_LOG_VALUE = 1e-10 # Minimum value for log scale to avoid log(0) +DEFAULT_VI_ROI_RATIO = 0.15 # Default VI ROI size as fraction of scan dimension class Show4DSTEM(anywidget.AnyWidget): @@ -264,14 +269,18 @@ def __init__( self.center_y = float(self.det_x / 2) self.bf_radius = det_size * DEFAULT_BF_RATIO # Auto-detect center and bf_radius from the data - self._auto_detect_center_silent() + self.auto_detect_center(update_roi=False) # Cache coordinate tensors for mask creation (avoid repeated torch.arange) - self._det_y_coords = torch.arange(self.det_x, device=self._device, dtype=torch.float32)[:, None] - self._det_x_coords = torch.arange(self.det_y, device=self._device, dtype=torch.float32)[None, :] + # det_row_coords: (det_x, 1), det_col_coords: (1, det_y) + self._det_row_coords = torch.arange(self.det_x, device=self._device, dtype=torch.float32)[:, None] + self._det_col_coords = torch.arange(self.det_y, device=self._device, dtype=torch.float32)[None, :] self._scan_row_coords = torch.arange(self.shape_x, device=self._device, dtype=torch.float32)[:, None] self._scan_col_coords = torch.arange(self.shape_y, device=self._device, dtype=torch.float32)[None, :] + # Batching flag for ROI updates (prevents double computation when X and Y change together) + self._roi_update_pending = False + # Pre-compute and cache common virtual images (BF, ABF, ADF) # Each cache stores (bytes, stats) tuple self._cached_bf_virtual = None @@ -307,8 +316,8 @@ def __init__( # Initialize VI ROI center to scan center with reasonable default sizes self.vi_roi_center_x = float(self.shape_x / 2) self.vi_roi_center_y = float(self.shape_y / 2) - # Set initial ROI size to ~15% of minimum scan dimension - default_roi_size = max(3, min(self.shape_x, self.shape_y) * 0.15) + # Set initial ROI size based on scan dimension + default_roi_size = max(3, min(self.shape_x, self.shape_y) * DEFAULT_VI_ROI_RATIO) self.vi_roi_radius = float(default_roi_size) self.vi_roi_width = float(default_roi_size * 2) self.vi_roi_height = float(default_roi_size) @@ -610,7 +619,7 @@ def roi_rect( self.roi_height = float(height) return self - def auto_detect_center(self) -> "Show4DSTEM": + def auto_detect_center(self, update_roi: bool = True) -> "Show4DSTEM": """ Automatically detect BF disk center and radius using centroid. @@ -619,6 +628,12 @@ def auto_detect_center(self) -> "Show4DSTEM": values are applied to the widget's calibration (center_x, center_y, bf_radius). + Parameters + ---------- + update_roi : bool, default True + If True, also update ROI center and recompute cached virtual images. + Set to False during __init__ when ROI is not yet initialized. + Returns ------- Show4DSTEM @@ -644,9 +659,17 @@ def auto_detect_center(self) -> "Show4DSTEM": if total == 0: return self - # Calculate centroid using cached coordinates - cx = float((self._det_x_coords * mask).sum() / total) - cy = float((self._det_y_coords * mask).sum() / total) + # Calculate centroid using coordinate grids + # Note: During __init__, cached coords may not exist yet, so create them + if hasattr(self, '_det_col_coords'): + col_coords = self._det_col_coords + row_coords = self._det_row_coords + else: + row_coords = torch.arange(self.det_x, device=self._device, dtype=torch.float32)[:, None] + col_coords = torch.arange(self.det_y, device=self._device, dtype=torch.float32)[None, :] + + cx = float((col_coords * mask).sum() / total) + cy = float((row_coords * mask).sum() / total) # Estimate radius from mask area (A = pi*r^2) radius = float(torch.sqrt(total / torch.pi)) @@ -656,52 +679,18 @@ def auto_detect_center(self) -> "Show4DSTEM": self.center_y = cy self.bf_radius = radius - # Also update ROI to center - self.roi_center_x = cx - self.roi_center_y = cy - - # Recompute cached virtual images with new calibration - self._precompute_common_virtual_images() + if update_roi: + # Also update ROI to center + self.roi_center_x = cx + self.roi_center_y = cy + # Recompute cached virtual images with new calibration + self._precompute_common_virtual_images() return self - def _auto_detect_center_silent(self): - """Auto-detect center without updating ROI (used during __init__).""" - # Sum all frames (fast on GPU) - if self._data.ndim == 4: - summed_dp = self._data.sum(dim=(0, 1)) - else: - summed_dp = self._data.sum(dim=0) - - # Threshold at mean + std to isolate BF disk - threshold = summed_dp.mean() + summed_dp.std() - mask = summed_dp > threshold - - # Avoid division by zero - total = mask.sum() - if total == 0: - return - - # Calculate centroid - y_coords, x_coords = torch.meshgrid( - torch.arange(self.det_x, device=self._device), - torch.arange(self.det_y, device=self._device), - indexing='ij' - ) - cx = float((x_coords * mask).sum() / total) - cy = float((y_coords * mask).sum() / total) - - # Estimate radius from mask area (A = pi*r^2) - radius = float(torch.sqrt(total / torch.pi)) - - # Apply detected values (don't update ROI - that happens later in __init__) - self.center_x = cx - self.center_y = cy - self.bf_radius = radius - def _compute_global_range(self, change=None): """Compute min/max for histogram range. Hot pixels already removed at init.""" - self.dp_global_min = max(float(self._data.min()), 1e-10) + self.dp_global_min = max(float(self._data.min()), MIN_LOG_VALUE) self.dp_global_max = float(self._data.max()) def _get_frame(self, x: int, y: int) -> np.ndarray: @@ -750,9 +739,28 @@ def _update_frame(self, change=None): self.frame_bytes = frame.cpu().numpy().astype(np.float32).tobytes() def _on_roi_change(self, change=None): - """Recompute virtual image when ROI changes.""" + """Recompute virtual image when ROI changes. + + Uses batching to prevent double computation when X and Y change together. + Multiple rapid changes within the same event loop tick are combined. + """ if not self.roi_active: return + if self._roi_update_pending: + return # Already scheduled, will pick up new values + self._roi_update_pending = True + # Schedule for next event loop tick to batch X and Y changes + try: + import asyncio + loop = asyncio.get_running_loop() + loop.call_soon(self._do_roi_update) + except RuntimeError: + # No running event loop (e.g., during init or testing) + self._do_roi_update() + + def _do_roi_update(self): + """Execute the batched ROI update.""" + self._roi_update_pending = False self._compute_virtual_image_from_roi() def _on_vi_roi_change(self, change=None): @@ -808,25 +816,25 @@ def _compute_summed_dp_from_vi_roi(self): def _create_circular_mask(self, cx: float, cy: float, radius: float): """Create circular mask (boolean tensor on device).""" - mask = (self._det_x_coords - cx) ** 2 + (self._det_y_coords - cy) ** 2 <= radius ** 2 + mask = (self._det_col_coords - cx) ** 2 + (self._det_row_coords - cy) ** 2 <= radius ** 2 return mask def _create_square_mask(self, cx: float, cy: float, half_size: float): """Create square mask (boolean tensor on device).""" - mask = (torch.abs(self._det_x_coords - cx) <= half_size) & (torch.abs(self._det_y_coords - cy) <= half_size) + mask = (torch.abs(self._det_col_coords - cx) <= half_size) & (torch.abs(self._det_row_coords - cy) <= half_size) return mask def _create_annular_mask( self, cx: float, cy: float, inner: float, outer: float ): """Create annular (donut) mask (boolean tensor on device).""" - dist_sq = (self._det_x_coords - cx) ** 2 + (self._det_y_coords - cy) ** 2 + dist_sq = (self._det_col_coords - cx) ** 2 + (self._det_row_coords - cy) ** 2 mask = (dist_sq >= inner ** 2) & (dist_sq <= outer ** 2) return mask def _create_rect_mask(self, cx: float, cy: float, half_width: float, half_height: float): """Create rectangular mask (boolean tensor on device).""" - mask = (torch.abs(self._det_x_coords - cx) <= half_width) & (torch.abs(self._det_y_coords - cy) <= half_height) + mask = (torch.abs(self._det_col_coords - cx) <= half_width) & (torch.abs(self._det_row_coords - cy) <= half_height) return mask def _precompute_common_virtual_images(self): @@ -894,7 +902,7 @@ def _fast_masked_sum(self, mask: torch.Tensor) -> torch.Tensor: n_nonzero = int(mask.sum()) coverage = n_nonzero / n_det - if coverage < 0.2: + if coverage < SPARSE_MASK_THRESHOLD: # Sparse: faster for small masks indices = torch.nonzero(mask_float.flatten(), as_tuple=True)[0] n_scan = self._scan_shape[0] * self._scan_shape[1] @@ -946,8 +954,8 @@ def _compute_virtual_image_from_roi(self): mask = self._create_rect_mask(cx, cy, self.roi_width / 2, self.roi_height / 2) else: # Point mode: single-pixel indexing - row = int(np.clip(round(cy), 0, self._det_shape[0] - 1)) - col = int(np.clip(round(cx), 0, self._det_shape[1] - 1)) + row = int(max(0, min(round(cy), self._det_shape[0] - 1))) + col = int(max(0, min(round(cx), self._det_shape[1] - 1))) if self._data.ndim == 4: virtual_image = self._data[:, :, row, col] else: From 9436c7493a1e303c7b2b3b5c8d47a06b8b24f6fd Mon Sep 17 00:00:00 2001 From: Sangjoon Bob Lee Date: Sun, 25 Jan 2026 18:14:16 -0800 Subject: [PATCH 26/27] remove unused code, batch model updates for faster ROI dragging - Remove isPow2(), debug console.log, unused array_utils functions - Remove vite.config.js (using esbuild directly) - Remove Show4D alias from __init__.py - Batch pos_x/pos_y and roi_center updates into single save_changes() - Use Python's pre-computed viDataMin/viDataMax instead of JS array scan - Rename "Virtual Image" label to "Image" - Add tests/visual_check.py for Playwright screenshot testing --- widget/js/show4dstem/index.tsx | 81 +++++++--- widget/package.json | 2 +- widget/src/quantem/widget/__init__.py | 3 - widget/src/quantem/widget/array_utils.py | 41 ----- widget/src/quantem/widget/show4dstem.py | 105 ++++++------- widget/tests/visual_check.py | 187 +++++++++++++++++++++++ widget/vite.config.js | 27 ---- 7 files changed, 294 insertions(+), 152 deletions(-) create mode 100644 widget/tests/visual_check.py delete mode 100644 widget/vite.config.js diff --git a/widget/js/show4dstem/index.tsx b/widget/js/show4dstem/index.tsx index 63894354..4ac54e64 100644 --- a/widget/js/show4dstem/index.tsx +++ b/widget/js/show4dstem/index.tsx @@ -1,6 +1,6 @@ /// import * as React from "react"; -import { createRender, useModelState } from "@anywidget/react"; +import { createRender, useModelState, useModel } from "@anywidget/react"; import Box from "@mui/material/Box"; import Typography from "@mui/material/Typography"; import Stack from "@mui/material/Stack"; @@ -130,7 +130,6 @@ const MIN_ZOOM = 0.5; const MAX_ZOOM = 10; function nextPow2(n: number): number { return Math.pow(2, Math.ceil(Math.log2(n))); } -function isPow2(n: number): boolean { return n > 0 && (n & (n - 1)) === 0; } function fft1dPow2(real: Float32Array, imag: Float32Array, inverse: boolean = false) { const n = real.length; @@ -310,7 +309,6 @@ async function getWebGPUFFT(): Promise { if (!adapter) return null; const device = await adapter.requestDevice(); gpuFFT = new WebGPUFFT(device); await gpuFFT.init(); - console.log('🚀 WebGPU FFT ready'); return gpuFFT; } catch (e) { console.warn('WebGPU init failed:', e); return null; } } @@ -1188,6 +1186,9 @@ function Histogram({ // Main Component // ============================================================================ function Show4DSTEM() { + // Direct model access for batched updates + const model = useModel(); + // ───────────────────────────────────────────────────────────────────────── // Model State (synced with Python) // ───────────────────────────────────────────────────────────────────────── @@ -1520,7 +1521,7 @@ function Show4DSTEM() { // Parse virtual image bytes into Float32Array and apply scale for histogram React.useEffect(() => { if (!virtualImageBytes) return; - // Parse as Float32Array since Python now sends raw float32 + // Parse as Float32Array const numFloats = virtualImageBytes.byteLength / 4; const rawData = new Float32Array(virtualImageBytes.buffer, virtualImageBytes.byteOffset, numFloats); @@ -1545,7 +1546,6 @@ function Show4DSTEM() { } else { scaledData.set(rawData); } - // Update histogram state (triggers re-render) setViHistogramData(scaledData); }, [virtualImageBytes, viScaleMode, viPowerExp]); @@ -1683,11 +1683,28 @@ function Show4DSTEM() { } } - // Compute actual min/max of scaled data - let dataMin = Infinity, dataMax = -Infinity; - for (let i = 0; i < scaled.length; i++) { - if (scaled[i] < dataMin) dataMin = scaled[i]; - if (scaled[i] > dataMax) dataMax = scaled[i]; + // Use Python's pre-computed min/max when valid, fallback to computing from data + let dataMin: number, dataMax: number; + const hasValidMinMax = viDataMin !== undefined && viDataMax !== undefined && viDataMax > viDataMin; + if (hasValidMinMax) { + // Apply scale transform to Python's values + if (viScaleMode === "log") { + dataMin = Math.log1p(Math.max(0, viDataMin)); + dataMax = Math.log1p(Math.max(0, viDataMax)); + } else if (viScaleMode === "power") { + dataMin = Math.pow(Math.max(0, viDataMin), viPowerExp); + dataMax = Math.pow(Math.max(0, viDataMax), viPowerExp); + } else { + dataMin = viDataMin; + dataMax = viDataMax; + } + } else { + // Fallback: compute from scaled data + dataMin = Infinity; dataMax = -Infinity; + for (let i = 0; i < scaled.length; i++) { + if (scaled[i] < dataMin) dataMin = scaled[i]; + if (scaled[i] > dataMax) dataMax = scaled[i]; + } } // Apply vmin/vmax percentile clipping @@ -1738,6 +1755,8 @@ function Show4DSTEM() { if (!rawVirtualImageRef.current) return; renderData(rawVirtualImageRef.current); + // Note: viDataMin/viDataMax intentionally not in deps - they arrive with virtualImageBytes + // and we have a fallback if they're stale }, [virtualImageBytes, shapeX, shapeY, viColormap, viVminPct, viVmaxPct, viScaleMode, viPowerExp, viZoom, viPanX, viPanY]); // Render virtual image overlay (just clear - crosshair drawn on high-DPI UI canvas) @@ -2077,9 +2096,12 @@ function Show4DSTEM() { setIsDraggingDP(true); setLocalKx(imgX); setLocalKy(imgY); - setRoiActive(true); - setRoiCenterX(Math.round(Math.max(0, Math.min(detY - 1, imgX)))); - setRoiCenterY(Math.round(Math.max(0, Math.min(detX - 1, imgY)))); + // Use compound roi_center trait - single observer fires in Python + const newX = Math.round(Math.max(0, Math.min(detY - 1, imgX))); + const newY = Math.round(Math.max(0, Math.min(detX - 1, imgY))); + model.set("roi_active", true); + model.set("roi_center", [newX, newY]); + model.save_changes(); }; const handleDpMouseMove = (e: React.MouseEvent) => { @@ -2126,8 +2148,11 @@ function Show4DSTEM() { } setLocalKx(imgX); setLocalKy(imgY); - setRoiCenterX(Math.round(Math.max(0, Math.min(detY - 1, imgX)))); - setRoiCenterY(Math.round(Math.max(0, Math.min(detX - 1, imgY)))); + // Use compound roi_center trait - single observer fires in Python + const newX = Math.round(Math.max(0, Math.min(detY - 1, imgX))); + const newY = Math.round(Math.max(0, Math.min(detX - 1, imgY))); + model.set("roi_center", [newX, newY]); + model.save_changes(); }; const handleDpMouseUp = () => { @@ -2168,8 +2193,12 @@ function Show4DSTEM() { // Regular position selection (when ROI is off) setIsDraggingVI(true); setLocalPosX(imgX); setLocalPosY(imgY); - setPosX(Math.round(Math.max(0, Math.min(shapeX - 1, imgX)))); - setPosY(Math.round(Math.max(0, Math.min(shapeY - 1, imgY)))); + // Batch X and Y updates into a single sync + const newX = Math.round(Math.max(0, Math.min(shapeX - 1, imgX))); + const newY = Math.round(Math.max(0, Math.min(shapeY - 1, imgY))); + model.set("pos_x", newX); + model.set("pos_y", newY); + model.save_changes(); }; const handleViMouseMove = (e: React.MouseEvent) => { @@ -2209,16 +2238,24 @@ function Show4DSTEM() { if (isDraggingViRoi) { setLocalViRoiCenterX(imgX); setLocalViRoiCenterY(imgY); - setViRoiCenterX(Math.round(Math.max(0, Math.min(shapeX - 1, imgX)))); - setViRoiCenterY(Math.round(Math.max(0, Math.min(shapeY - 1, imgY)))); + // Batch VI ROI center updates + const newViX = Math.round(Math.max(0, Math.min(shapeX - 1, imgX))); + const newViY = Math.round(Math.max(0, Math.min(shapeY - 1, imgY))); + model.set("vi_roi_center_x", newViX); + model.set("vi_roi_center_y", newViY); + model.save_changes(); return; } // Handle regular position dragging (when ROI is off) if (!isDraggingVI) return; setLocalPosX(imgX); setLocalPosY(imgY); - setPosX(Math.round(Math.max(0, Math.min(shapeX - 1, imgX)))); - setPosY(Math.round(Math.max(0, Math.min(shapeY - 1, imgY)))); + // Batch position updates into a single sync + const newX = Math.round(Math.max(0, Math.min(shapeX - 1, imgX))); + const newY = Math.round(Math.max(0, Math.min(shapeY - 1, imgY))); + model.set("pos_x", newX); + model.set("pos_y", newY); + model.save_changes(); }; const handleViMouseUp = () => { @@ -2454,7 +2491,7 @@ function Show4DSTEM() { {/* VI Header */} - Virtual Image + Image {shapeX}×{shapeY} | {detX}×{detY} diff --git a/widget/package.json b/widget/package.json index 4e63363a..dbe065f2 100644 --- a/widget/package.json +++ b/widget/package.json @@ -1,7 +1,7 @@ { "scripts": { "dev": "npm run build -- --sourcemap=inline --watch", - "build": "esbuild js/show4dstem/index.tsx js/show2d/index.tsx js/show3d/index.tsx --minify --format=esm --bundle --outdir=src/quantem/widget/static --entry-names=[dir]", + "build": "esbuild js/show4dstem/index.tsx --minify --format=esm --bundle --outdir=src/quantem/widget/static --outbase=js --entry-names=[dir]", "typecheck": "tsc --noEmit" }, "dependencies": { diff --git a/widget/src/quantem/widget/__init__.py b/widget/src/quantem/widget/__init__.py index d4c6c040..4d75aa24 100644 --- a/widget/src/quantem/widget/__init__.py +++ b/widget/src/quantem/widget/__init__.py @@ -11,7 +11,4 @@ from quantem.widget.show4dstem import Show4DSTEM -# Alias for convenience -Show4D = Show4DSTEM - __all__ = ["Show4DSTEM"] diff --git a/widget/src/quantem/widget/array_utils.py b/widget/src/quantem/widget/array_utils.py index 717c5bcc..5287ca0a 100644 --- a/widget/src/quantem/widget/array_utils.py +++ b/widget/src/quantem/widget/array_utils.py @@ -107,44 +107,3 @@ def to_numpy(data: Any, dtype: np.dtype | None = None) -> np.ndarray: result = np.asarray(result, dtype=dtype) return result - - -def to_numpy_list(data_list: list[Any], dtype: np.dtype | None = None) -> list[np.ndarray]: - """ - Convert a list of arrays to NumPy arrays. - - Parameters - ---------- - data_list : list of array-like - List of arrays from any supported backend. - dtype : np.dtype, optional - Target dtype for all output arrays. - - Returns - ------- - list of np.ndarray - List of NumPy arrays. - """ - return [to_numpy(arr, dtype=dtype) for arr in data_list] - - -def get_gpu_module(data: Any): - """ - Get the GPU module (cupy) if the data is on GPU, else return numpy. - - Parameters - ---------- - data : array-like - Input array. - - Returns - ------- - module - Either cupy or numpy module. - """ - backend = get_array_backend(data) - if backend == "cupy": - import cupy as cp - - return cp - return np diff --git a/widget/src/quantem/widget/show4dstem.py b/widget/src/quantem/widget/show4dstem.py index 1de13560..0c0c843f 100644 --- a/widget/src/quantem/widget/show4dstem.py +++ b/widget/src/quantem/widget/show4dstem.py @@ -115,6 +115,8 @@ class Show4DSTEM(anywidget.AnyWidget): roi_mode = traitlets.Unicode("point").tag(sync=True) roi_center_x = traitlets.Float(0.0).tag(sync=True) roi_center_y = traitlets.Float(0.0).tag(sync=True) + # Compound trait for batched X+Y updates (JS sends both at once, 1 observer fires) + roi_center = traitlets.List(traitlets.Float(), default_value=[0.0, 0.0]).tag(sync=True) roi_radius = traitlets.Float(10.0).tag(sync=True) roi_radius_inner = traitlets.Float(5.0).tag(sync=True) roi_width = traitlets.Float(20.0).tag(sync=True) @@ -241,8 +243,14 @@ def __init__( # Initial position at center self.pos_x = self.shape_x // 2 self.pos_y = self.shape_y // 2 - # Precompute global range for consistent scaling - self._compute_global_range() + # Precompute global range for consistent scaling (hot pixels already removed) + self.dp_global_min = max(float(self._data.min()), MIN_LOG_VALUE) + self.dp_global_max = float(self._data.max()) + # Cache coordinate tensors for mask creation (avoid repeated torch.arange) + self._det_row_coords = torch.arange(self.det_x, device=self._device, dtype=torch.float32)[:, None] + self._det_col_coords = torch.arange(self.det_y, device=self._device, dtype=torch.float32)[None, :] + self._scan_row_coords = torch.arange(self.shape_x, device=self._device, dtype=torch.float32)[:, None] + self._scan_col_coords = torch.arange(self.shape_y, device=self._device, dtype=torch.float32)[None, :] # Setup center and BF radius # If user provides explicit values, use them # Otherwise, auto-detect from the data for accurate presets @@ -271,16 +279,6 @@ def __init__( # Auto-detect center and bf_radius from the data self.auto_detect_center(update_roi=False) - # Cache coordinate tensors for mask creation (avoid repeated torch.arange) - # det_row_coords: (det_x, 1), det_col_coords: (1, det_y) - self._det_row_coords = torch.arange(self.det_x, device=self._device, dtype=torch.float32)[:, None] - self._det_col_coords = torch.arange(self.det_y, device=self._device, dtype=torch.float32)[None, :] - self._scan_row_coords = torch.arange(self.shape_x, device=self._device, dtype=torch.float32)[:, None] - self._scan_col_coords = torch.arange(self.shape_y, device=self._device, dtype=torch.float32)[None, :] - - # Batching flag for ROI updates (prevents double computation when X and Y change together) - self._roi_update_pending = False - # Pre-compute and cache common virtual images (BF, ABF, ADF) # Each cache stores (bytes, stats) tuple self._cached_bf_virtual = None @@ -291,14 +289,18 @@ def __init__( # Update frame when position changes (scale/colormap handled in JS) self.observe(self._update_frame, names=["pos_x", "pos_y"]) + # Observe individual ROI params (for backward compatibility) self.observe(self._on_roi_change, names=[ "roi_center_x", "roi_center_y", "roi_radius", "roi_radius_inner", "roi_active", "roi_mode", "roi_width", "roi_height" ]) - + # Observe compound roi_center for batched updates from JS + self.observe(self._on_roi_center_change, names=["roi_center"]) + # Initialize default ROI at BF center self.roi_center_x = self.center_x self.roi_center_y = self.center_y + self.roi_center = [self.center_x, self.center_y] self.roi_radius = self.bf_radius * 0.5 # Start with half BF radius self.roi_active = True @@ -659,17 +661,9 @@ def auto_detect_center(self, update_roi: bool = True) -> "Show4DSTEM": if total == 0: return self - # Calculate centroid using coordinate grids - # Note: During __init__, cached coords may not exist yet, so create them - if hasattr(self, '_det_col_coords'): - col_coords = self._det_col_coords - row_coords = self._det_row_coords - else: - row_coords = torch.arange(self.det_x, device=self._device, dtype=torch.float32)[:, None] - col_coords = torch.arange(self.det_y, device=self._device, dtype=torch.float32)[None, :] - - cx = float((col_coords * mask).sum() / total) - cy = float((row_coords * mask).sum() / total) + # Calculate centroid using cached coordinate grids + cx = float((self._det_col_coords * mask).sum() / total) + cy = float((self._det_row_coords * mask).sum() / total) # Estimate radius from mask area (A = pi*r^2) radius = float(torch.sqrt(total / torch.pi)) @@ -688,11 +682,6 @@ def auto_detect_center(self, update_roi: bool = True) -> "Show4DSTEM": return self - def _compute_global_range(self, change=None): - """Compute min/max for histogram range. Hot pixels already removed at init.""" - self.dp_global_min = max(float(self._data.min()), MIN_LOG_VALUE) - self.dp_global_max = float(self._data.max()) - def _get_frame(self, x: int, y: int) -> np.ndarray: """Get single diffraction frame at position (x, y) as numpy array.""" if self._data.ndim == 3: @@ -739,28 +728,30 @@ def _update_frame(self, change=None): self.frame_bytes = frame.cpu().numpy().astype(np.float32).tobytes() def _on_roi_change(self, change=None): - """Recompute virtual image when ROI changes. + """Recompute virtual image when individual ROI params change. + + This handles legacy setters (setRoiCenterX/Y) from button handlers. + High-frequency updates use the compound roi_center trait instead. + """ + if not self.roi_active: + return + self._compute_virtual_image_from_roi() - Uses batching to prevent double computation when X and Y change together. - Multiple rapid changes within the same event loop tick are combined. + def _on_roi_center_change(self, change=None): + """Handle batched roi_center updates from JS (single observer for X+Y). + + This is the fast path for drag operations. JS sends [x, y] as a single + compound trait, so only one observer fires per mouse move. """ if not self.roi_active: return - if self._roi_update_pending: - return # Already scheduled, will pick up new values - self._roi_update_pending = True - # Schedule for next event loop tick to batch X and Y changes - try: - import asyncio - loop = asyncio.get_running_loop() - loop.call_soon(self._do_roi_update) - except RuntimeError: - # No running event loop (e.g., during init or testing) - self._do_roi_update() - - def _do_roi_update(self): - """Execute the batched ROI update.""" - self._roi_update_pending = False + if change and "new" in change: + x, y = change["new"] + # Sync to individual traits (without triggering _on_roi_change observers) + self.unobserve(self._on_roi_change, names=["roi_center_x", "roi_center_y"]) + self.roi_center_x = x + self.roi_center_y = y + self.observe(self._on_roi_change, names=["roi_center_x", "roi_center_y"]) self._compute_virtual_image_from_roi() def _on_vi_roi_change(self, change=None): @@ -915,19 +906,17 @@ def _fast_masked_sum(self, mask: torch.Tensor) -> torch.Tensor: return result def _to_float32_bytes(self, arr: torch.Tensor, update_vi_stats: bool = True) -> bytes: - """Convert tensor to float32 bytes (JS handles scale/colormap).""" - # Compute stats using PyTorch (on GPU) + """Convert tensor to float32 bytes.""" + # Compute min/max (fast on GPU) + vmin = float(arr.min()) + vmax = float(arr.max()) + self.vi_data_min = vmin + self.vi_data_max = vmax + + # Compute full stats if requested if update_vi_stats: - self.vi_data_min = float(arr.min()) - self.vi_data_max = float(arr.max()) - self.vi_stats = [ - float(arr.mean()), - float(arr.min()), - float(arr.max()), - float(arr.std()), - ] + self.vi_stats = [float(arr.mean()), vmin, vmax, float(arr.std())] - # Convert to numpy only for sending bytes to frontend return arr.cpu().numpy().astype(np.float32).tobytes() def _compute_virtual_image_from_roi(self): diff --git a/widget/tests/visual_check.py b/widget/tests/visual_check.py new file mode 100644 index 00000000..6d07b489 --- /dev/null +++ b/widget/tests/visual_check.py @@ -0,0 +1,187 @@ +""" +Visual regression test - screenshots the actual Show4DSTEM widget in Jupyter. + +This is a manual test (not run in CI) because it: +- Requires real data files +- Takes ~1-2 minutes +- Needs JupyterLab + Chromium + +Usage: + playwright install chromium # one-time setup + python widget/tests/visual_check.py +""" + +import subprocess +import time +import json +import tempfile +import shutil +from pathlib import Path +from playwright.sync_api import sync_playwright, Page + +OUTPUT_DIR = Path(__file__).parent / "screenshots" +OUTPUT_DIR.mkdir(exist_ok=True) + +# GPU/memory cleanup to prepend to each test +CLEANUP = """ +import os +os.environ['CUDA_VISIBLE_DEVICES'] = '1' +import torch, gc +if torch.cuda.is_available(): + torch.cuda.empty_cache() +gc.collect() +""" + +# Real data test cases +REAL_TESTS = [ + ("arina_george", CLEANUP + """ +from quantem.core.io.file_readers import read_4dstem +from quantem.widget import Show4DSTEM +dataset = read_4dstem('/home/bobleesj/data/geroge/gold-10/gold_10_master.h5', file_type='arina') +widget = Show4DSTEM(dataset) +widget +"""), + ("arina_ncem", CLEANUP + """ +from quantem.core.io.file_readers import read_4dstem +from quantem.widget import Show4DSTEM +dataset = read_4dstem('/home/bobleesj/data/251115_ncem_arina_steph/lamella_2_002_master.h5', file_type='arina') +widget = Show4DSTEM(dataset) +widget +"""), + ("rect_scan", CLEANUP + """ +import h5py +from quantem.widget import Show4DSTEM +with h5py.File('/home/bobleesj/data/ptycho_MoS2_bin2.h5', 'r') as f: + data = f['4DSTEM/datacube/data'][:] +print(f'Data shape: {data.shape}') +widget = Show4DSTEM(data) +widget +"""), + ("legacy_h5", CLEANUP + """ +from quantem.core.io.file_readers import read_emdfile_to_4dstem +from quantem.widget import Show4DSTEM +dataset = read_emdfile_to_4dstem('/home/bobleesj/data/ptycho_gold_data_2024.h5') +widget = Show4DSTEM(dataset) +widget +"""), +] + + +def create_notebook(path: Path, code: str): + """Create a Jupyter notebook with the given code.""" + notebook = { + "cells": [{"cell_type": "code", "execution_count": None, "metadata": {}, "outputs": [], "source": code}], + "metadata": {"kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"}}, + "nbformat": 4, + "nbformat_minor": 5 + } + with open(path, 'w') as f: + json.dump(notebook, f) + + +def run_cell(page: Page, notebook_name: str): + """Execute notebook cell and wait for completion.""" + # Click tab to ensure focus + try: + page.click(f'.lm-TabBar-tab[data-id*="{notebook_name}"]', timeout=3000) + page.wait_for_timeout(500) + except Exception: + pass + + # Click cell editor and run + try: + page.click('.jp-NotebookPanel:not(.lm-mod-hidden) .jp-Cell .cm-content', timeout=3000) + except Exception: + page.click('.jp-NotebookPanel:not(.lm-mod-hidden) .jp-Cell', timeout=3000) + page.wait_for_timeout(300) + page.keyboard.press('Shift+Enter') + print(" Executing cell...") + + # Wait for idle + try: + page.wait_for_selector('.jp-NotebookPanel:not(.lm-mod-hidden) .jp-Notebook-ExecutionIndicator[data-status="idle"]', timeout=30000) + except Exception: + page.wait_for_timeout(10000) + + +def wait_for_widget(page: Page, timeout: int = 60000) -> bool: + """Wait for widget canvas to render.""" + try: + page.wait_for_selector('.jp-NotebookPanel:not(.lm-mod-hidden) .jp-OutputArea-output canvas', timeout=timeout) + print(" Widget rendered") + page.wait_for_timeout(2000) + return True + except Exception: + print(" Warning: Widget not detected") + return False + + +def run_tests(tests: list, port: int = 18888): + """Run screenshot tests for given test cases.""" + temp_dir = Path(tempfile.mkdtemp()) + jupyter_proc = None + + try: + print("=" * 60) + print(f"Show4DSTEM Screenshot Tests ({len(tests)} tests)") + print(f"Output: {OUTPUT_DIR}") + print("=" * 60) + + print("\nStarting JupyterLab...") + jupyter_proc = subprocess.Popen( + ["jupyter", "lab", "--no-browser", f"--port={port}", + f"--notebook-dir={temp_dir}", "--ServerApp.token=", "--ServerApp.password="], + stdout=subprocess.PIPE, stderr=subprocess.PIPE, + ) + time.sleep(8) + + with sync_playwright() as p: + browser = p.chromium.launch(headless=True) + + for name, code in tests: + print(f"\n[TEST] {name}") + create_notebook(temp_dir / f"{name}.ipynb", code.strip()) + + context = browser.new_context(viewport={"width": 1600, "height": 1200}) + page = context.new_page() + page.goto(f"http://localhost:{port}/lab/tree/{name}.ipynb") + + # Wait for JupyterLab ready + try: + page.wait_for_selector('.jp-NotebookPanel:not(.lm-mod-hidden)', timeout=30000) + page.wait_for_selector('.jp-NotebookPanel:not(.lm-mod-hidden) .jp-Notebook-ExecutionIndicator[data-status="idle"]', timeout=30000) + except Exception: + page.wait_for_timeout(8000) + + run_cell(page, name) + widget_found = wait_for_widget(page) + + if not widget_found: + page.screenshot(path=str(OUTPUT_DIR / f"{name}_debug.png"), full_page=True) + + # Scroll to widget and capture + page.evaluate("document.querySelector('.jp-OutputArea-output canvas')?.scrollIntoView({block: 'center'})") + page.wait_for_timeout(1000) + page.screenshot(path=str(OUTPUT_DIR / f"{name}.png"), full_page=True) + print(f" Saved: {name}.png") + + context.close() + + browser.close() + + print("\n" + "=" * 60) + print(f"Done! Screenshots in: {OUTPUT_DIR}") + print("=" * 60) + + finally: + if jupyter_proc: + jupyter_proc.terminate() + try: + jupyter_proc.wait(timeout=5) + except Exception: + jupyter_proc.kill() + shutil.rmtree(temp_dir, ignore_errors=True) + + +if __name__ == "__main__": + run_tests(REAL_TESTS) diff --git a/widget/vite.config.js b/widget/vite.config.js deleted file mode 100644 index 948e74d3..00000000 --- a/widget/vite.config.js +++ /dev/null @@ -1,27 +0,0 @@ -import { defineConfig } from "vite"; -import anywidget from "@anywidget/vite"; -import react from "@vitejs/plugin-react"; - -export default defineConfig({ - plugins: [anywidget(), react()], - define: { - "process.env.NODE_ENV": JSON.stringify("production"), - }, - build: { - outDir: "src/quantem/widget/static", - lib: { - entry: { - show4dstem: "js/show4dstem.tsx", - }, - formats: ["es"], - }, - rollupOptions: { - output: { - // Each entry gets its own file - entryFileNames: "[name].js", - // CSS is handled separately by anywidget - assetFileNames: "[name][extname]", - }, - }, - }, -}); From 25a59b414dd37451d593b05966221fd8458ccd2e Mon Sep 17 00:00:00 2001 From: Sangjoon Bob Lee Date: Sun, 25 Jan 2026 18:16:37 -0800 Subject: [PATCH 27/27] restore test_widget_show4dstem.py, remove visual_check.py --- widget/tests/test_widget_show4dstem.py | 140 ++++++++++++++++++ widget/tests/visual_check.py | 187 ------------------------- 2 files changed, 140 insertions(+), 187 deletions(-) create mode 100644 widget/tests/test_widget_show4dstem.py delete mode 100644 widget/tests/visual_check.py diff --git a/widget/tests/test_widget_show4dstem.py b/widget/tests/test_widget_show4dstem.py new file mode 100644 index 00000000..413d6470 --- /dev/null +++ b/widget/tests/test_widget_show4dstem.py @@ -0,0 +1,140 @@ +import numpy as np +import quantem.widget +from quantem.widget import Show4DSTEM + + +def test_version_exists(): + assert hasattr(quantem.widget, "__version__") + + +def test_version_is_string(): + assert isinstance(quantem.widget.__version__, str) + + +def test_show4dstem_loads(): + """Widget can be created from mock 4D data.""" + data = np.random.rand(8, 8, 16, 16).astype(np.float32) + widget = Show4DSTEM(data) + assert widget is not None + + +def test_show4dstem_flattened_scan_shape_mapping(): + """Test flattened 3D data with explicit scan shape.""" + data = np.zeros((6, 2, 2), dtype=np.float32) + for idx in range(data.shape[0]): + data[idx] = idx + widget = Show4DSTEM(data, scan_shape=(2, 3)) + assert (widget.shape_x, widget.shape_y) == (2, 3) + assert (widget.det_x, widget.det_y) == (2, 2) + frame = widget._get_frame(1, 2) + assert np.array_equal(frame, np.full((2, 2), 5, dtype=np.float32)) + + +def test_show4dstem_log_scale(): + """Test that log scale changes frame bytes.""" + data = np.random.rand(2, 2, 8, 8).astype(np.float32) * 100 + 1 + widget = Show4DSTEM(data, log_scale=True) + log_bytes = bytes(widget.frame_bytes) + widget.log_scale = False + widget._update_frame() + linear_bytes = bytes(widget.frame_bytes) + assert log_bytes != linear_bytes + + +def test_show4dstem_auto_detect_center(): + """Test automatic center spot detection using centroid.""" + data = np.zeros((2, 2, 7, 7), dtype=np.float32) + for i in range(7): + for j in range(7): + dist = np.sqrt((i - 3) ** 2 + (j - 3) ** 2) + if dist <= 1.5: + data[:, :, i, j] = 100.0 + widget = Show4DSTEM(data, precompute_virtual_images=False) + widget.auto_detect_center() + assert abs(widget.center_x - 3.0) < 0.5 + assert abs(widget.center_y - 3.0) < 0.5 + assert widget.bf_radius > 0 + + +def test_show4dstem_adf_preset_cache(): + """Test that ADF preset cache works when precompute is enabled.""" + data = np.random.rand(4, 4, 16, 16).astype(np.float32) + widget = Show4DSTEM(data, center=(8, 8), bf_radius=2, precompute_virtual_images=True) + assert widget._cached_adf_virtual is not None + widget.roi_mode = "annular" + widget.roi_center_x = 8 + widget.roi_center_y = 8 + widget.roi_radius_inner = 2 + widget.roi_radius = 8 + cached = widget._get_cached_preset() + assert cached == widget._cached_adf_virtual + + +def test_show4dstem_rectangular_scan_shape(): + """Test that rectangular (non-square) scans work correctly.""" + data = np.random.rand(4, 8, 16, 16).astype(np.float32) + widget = Show4DSTEM(data) + assert widget.shape_x == 4 + assert widget.shape_y == 8 + assert widget.det_x == 16 + assert widget.det_y == 16 + frame_00 = widget._get_frame(0, 0) + frame_37 = widget._get_frame(3, 7) + assert frame_00.shape == (16, 16) + assert frame_37.shape == (16, 16) + + +def test_show4dstem_hot_pixel_removal_uint16(): + """Test that saturated uint16 hot pixels are removed at init.""" + data = np.zeros((4, 4, 8, 8), dtype=np.uint16) + data[:, :, :, :] = 100 + data[:, :, 3, 5] = 65535 + data[:, :, 1, 2] = 65535 + widget = Show4DSTEM(data) + assert widget.dp_global_max < 65535 + assert widget.dp_global_max == 100.0 + frame = widget._get_frame(0, 0) + assert frame[3, 5] == 0 + assert frame[1, 2] == 0 + assert frame[0, 0] == 100 + + +def test_show4dstem_hot_pixel_removal_uint8(): + """Test that saturated uint8 hot pixels are removed at init.""" + data = np.zeros((4, 4, 8, 8), dtype=np.uint8) + data[:, :, :, :] = 50 + data[:, :, 2, 3] = 255 + widget = Show4DSTEM(data) + assert widget.dp_global_max == 50.0 + frame = widget._get_frame(0, 0) + assert frame[2, 3] == 0 + + +def test_show4dstem_no_hot_pixel_removal_float32(): + """Test that float32 data is not modified (no saturated value).""" + data = np.ones((4, 4, 8, 8), dtype=np.float32) * 1000 + widget = Show4DSTEM(data) + assert widget.dp_global_max == 1000.0 + + +def test_show4dstem_roi_modes(): + """Test all ROI modes compute virtual images correctly.""" + data = np.random.rand(8, 8, 16, 16).astype(np.float32) + widget = Show4DSTEM(data, center=(8, 8), bf_radius=3) + for mode in ["point", "circle", "square", "annular", "rect"]: + widget.roi_mode = mode + widget.roi_active = True + assert len(widget.vi_stats) == 4 + assert widget.vi_stats[2] >= widget.vi_stats[1] + + +def test_show4dstem_virtual_image_excludes_hot_pixels(): + """Test that virtual images don't include hot pixel contributions.""" + data = np.ones((4, 4, 8, 8), dtype=np.uint16) * 10 + data[:, :, 4, 4] = 65535 + widget = Show4DSTEM(data, center=(4, 4), bf_radius=2) + widget.roi_mode = "circle" + widget.roi_center_x = 4 + widget.roi_center_y = 4 + widget.roi_radius = 3 + assert widget.vi_stats[2] < 1000 diff --git a/widget/tests/visual_check.py b/widget/tests/visual_check.py deleted file mode 100644 index 6d07b489..00000000 --- a/widget/tests/visual_check.py +++ /dev/null @@ -1,187 +0,0 @@ -""" -Visual regression test - screenshots the actual Show4DSTEM widget in Jupyter. - -This is a manual test (not run in CI) because it: -- Requires real data files -- Takes ~1-2 minutes -- Needs JupyterLab + Chromium - -Usage: - playwright install chromium # one-time setup - python widget/tests/visual_check.py -""" - -import subprocess -import time -import json -import tempfile -import shutil -from pathlib import Path -from playwright.sync_api import sync_playwright, Page - -OUTPUT_DIR = Path(__file__).parent / "screenshots" -OUTPUT_DIR.mkdir(exist_ok=True) - -# GPU/memory cleanup to prepend to each test -CLEANUP = """ -import os -os.environ['CUDA_VISIBLE_DEVICES'] = '1' -import torch, gc -if torch.cuda.is_available(): - torch.cuda.empty_cache() -gc.collect() -""" - -# Real data test cases -REAL_TESTS = [ - ("arina_george", CLEANUP + """ -from quantem.core.io.file_readers import read_4dstem -from quantem.widget import Show4DSTEM -dataset = read_4dstem('/home/bobleesj/data/geroge/gold-10/gold_10_master.h5', file_type='arina') -widget = Show4DSTEM(dataset) -widget -"""), - ("arina_ncem", CLEANUP + """ -from quantem.core.io.file_readers import read_4dstem -from quantem.widget import Show4DSTEM -dataset = read_4dstem('/home/bobleesj/data/251115_ncem_arina_steph/lamella_2_002_master.h5', file_type='arina') -widget = Show4DSTEM(dataset) -widget -"""), - ("rect_scan", CLEANUP + """ -import h5py -from quantem.widget import Show4DSTEM -with h5py.File('/home/bobleesj/data/ptycho_MoS2_bin2.h5', 'r') as f: - data = f['4DSTEM/datacube/data'][:] -print(f'Data shape: {data.shape}') -widget = Show4DSTEM(data) -widget -"""), - ("legacy_h5", CLEANUP + """ -from quantem.core.io.file_readers import read_emdfile_to_4dstem -from quantem.widget import Show4DSTEM -dataset = read_emdfile_to_4dstem('/home/bobleesj/data/ptycho_gold_data_2024.h5') -widget = Show4DSTEM(dataset) -widget -"""), -] - - -def create_notebook(path: Path, code: str): - """Create a Jupyter notebook with the given code.""" - notebook = { - "cells": [{"cell_type": "code", "execution_count": None, "metadata": {}, "outputs": [], "source": code}], - "metadata": {"kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"}}, - "nbformat": 4, - "nbformat_minor": 5 - } - with open(path, 'w') as f: - json.dump(notebook, f) - - -def run_cell(page: Page, notebook_name: str): - """Execute notebook cell and wait for completion.""" - # Click tab to ensure focus - try: - page.click(f'.lm-TabBar-tab[data-id*="{notebook_name}"]', timeout=3000) - page.wait_for_timeout(500) - except Exception: - pass - - # Click cell editor and run - try: - page.click('.jp-NotebookPanel:not(.lm-mod-hidden) .jp-Cell .cm-content', timeout=3000) - except Exception: - page.click('.jp-NotebookPanel:not(.lm-mod-hidden) .jp-Cell', timeout=3000) - page.wait_for_timeout(300) - page.keyboard.press('Shift+Enter') - print(" Executing cell...") - - # Wait for idle - try: - page.wait_for_selector('.jp-NotebookPanel:not(.lm-mod-hidden) .jp-Notebook-ExecutionIndicator[data-status="idle"]', timeout=30000) - except Exception: - page.wait_for_timeout(10000) - - -def wait_for_widget(page: Page, timeout: int = 60000) -> bool: - """Wait for widget canvas to render.""" - try: - page.wait_for_selector('.jp-NotebookPanel:not(.lm-mod-hidden) .jp-OutputArea-output canvas', timeout=timeout) - print(" Widget rendered") - page.wait_for_timeout(2000) - return True - except Exception: - print(" Warning: Widget not detected") - return False - - -def run_tests(tests: list, port: int = 18888): - """Run screenshot tests for given test cases.""" - temp_dir = Path(tempfile.mkdtemp()) - jupyter_proc = None - - try: - print("=" * 60) - print(f"Show4DSTEM Screenshot Tests ({len(tests)} tests)") - print(f"Output: {OUTPUT_DIR}") - print("=" * 60) - - print("\nStarting JupyterLab...") - jupyter_proc = subprocess.Popen( - ["jupyter", "lab", "--no-browser", f"--port={port}", - f"--notebook-dir={temp_dir}", "--ServerApp.token=", "--ServerApp.password="], - stdout=subprocess.PIPE, stderr=subprocess.PIPE, - ) - time.sleep(8) - - with sync_playwright() as p: - browser = p.chromium.launch(headless=True) - - for name, code in tests: - print(f"\n[TEST] {name}") - create_notebook(temp_dir / f"{name}.ipynb", code.strip()) - - context = browser.new_context(viewport={"width": 1600, "height": 1200}) - page = context.new_page() - page.goto(f"http://localhost:{port}/lab/tree/{name}.ipynb") - - # Wait for JupyterLab ready - try: - page.wait_for_selector('.jp-NotebookPanel:not(.lm-mod-hidden)', timeout=30000) - page.wait_for_selector('.jp-NotebookPanel:not(.lm-mod-hidden) .jp-Notebook-ExecutionIndicator[data-status="idle"]', timeout=30000) - except Exception: - page.wait_for_timeout(8000) - - run_cell(page, name) - widget_found = wait_for_widget(page) - - if not widget_found: - page.screenshot(path=str(OUTPUT_DIR / f"{name}_debug.png"), full_page=True) - - # Scroll to widget and capture - page.evaluate("document.querySelector('.jp-OutputArea-output canvas')?.scrollIntoView({block: 'center'})") - page.wait_for_timeout(1000) - page.screenshot(path=str(OUTPUT_DIR / f"{name}.png"), full_page=True) - print(f" Saved: {name}.png") - - context.close() - - browser.close() - - print("\n" + "=" * 60) - print(f"Done! Screenshots in: {OUTPUT_DIR}") - print("=" * 60) - - finally: - if jupyter_proc: - jupyter_proc.terminate() - try: - jupyter_proc.wait(timeout=5) - except Exception: - jupyter_proc.kill() - shutil.rmtree(temp_dir, ignore_errors=True) - - -if __name__ == "__main__": - run_tests(REAL_TESTS)