Source code for utils

"""
utils

Small, dependency-light utilities used across the tdpy codebase.

Goals:
- deterministic, CLI-friendly behavior
- safe JSON serialization helpers
- consistent error-context wrapping
- predictable logging without duplicate handlers
- light-weight dict override utilities (for CLI --set style overrides)
"""

# The docstring must be the first statement in the module to become __doc__;
# a __future__ import is still allowed directly after it.
from __future__ import annotations

import functools
import json
import logging
import os
import time
from dataclasses import asdict, is_dataclass
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, Mapping, MutableMapping, Optional, TypeVar

T = TypeVar("T")


# -----------------------------------------------------------------------------
# Filesystem helpers
# -----------------------------------------------------------------------------

[docs] def ensure_dir(path: str | Path) -> None: """ Ensure a directory exists. If `path` looks like a file path (has a suffix), create its parent directory. Otherwise, create the directory itself. """ p = Path(path) target = p.parent if p.suffix else p if str(target) and not target.exists(): target.mkdir(parents=True, exist_ok=True)
# -----------------------------------------------------------------------------
# Logging helpers
# -----------------------------------------------------------------------------

# Names of loggers that setup_logger() has already attached a handler to.
_LOGGER_INIT: set[str] = set()


def _parse_env_log_level(env_value: str, default: int) -> int:
    """
    Translate an environment-variable value into a numeric logging level.

    Accepts either a number ("10") or a level name in any case ("debug",
    "WARNING"). Blank or unrecognized values yield `default`.

    Args:
        env_value: Raw text read from the environment.
        default: Level returned when `env_value` cannot be interpreted.

    Returns:
        The resolved integer logging level.
    """
    text = str(env_value).strip()
    if not text:
        return default

    # Numeric levels ("10") take precedence over names ("DEBUG").
    try:
        return int(text)
    except ValueError:
        pass

    named = getattr(logging, text.upper(), None)
    # Only accept genuine integer level constants. The logging module also has
    # upper-case non-level attributes (e.g. the BASIC_FORMAT string) that must
    # not be fed to int(), which would raise instead of falling back.
    if isinstance(named, int) and not isinstance(named, bool):
        return named
    return default
def setup_logger(name: str = "tdpy", level: int = logging.INFO) -> logging.Logger:
    """
    Return a logger configured in a repeatable way.

    - The level defaults to `level` (INFO), but the TDPY_LOG_LEVEL environment
      variable wins when it is set (e.g. DEBUG, INFO, WARNING, 10, 20).
    - A stream handler is attached at most once per logger name; a logger that
      already has handlers (e.g. from external configuration) is left alone
      apart from its level.
    """
    override = os.environ.get("TDPY_LOG_LEVEL")
    effective_level = level if override is None else _parse_env_log_level(override, level)

    log = logging.getLogger(name)
    log.setLevel(effective_level)

    # Attach our handler only the first time and only if nobody else did.
    already_configured = name in _LOGGER_INIT or bool(log.handlers)
    if not already_configured:
        handler = logging.StreamHandler()
        handler.setFormatter(logging.Formatter("[%(levelname)s] %(name)s: %(message)s"))
        log.addHandler(handler)
        log.propagate = False
        _LOGGER_INIT.add(name)
    return log
# ----------------------------------------------------------------------------- # JSON helpers # -----------------------------------------------------------------------------
def json_default(obj: Any) -> Any:
    """
    Fallback serializer for ``json.dumps(..., default=json_default)``.

    Supports:
    - dataclass instances -> dict (via ``dataclasses.asdict``)
    - pathlib.Path -> str
    - simple iterables (set / frozenset / tuple) -> list

    Args:
        obj: The object the JSON encoder could not serialize on its own.

    Returns:
        A JSON-serializable stand-in for `obj`.

    Raises:
        TypeError: If `obj` is none of the supported types.
    """
    # is_dataclass() is also true for dataclass *classes*, which asdict()
    # rejects; only instances are converted, classes fall through to the
    # uniform "not JSON serializable" error below.
    if is_dataclass(obj) and not isinstance(obj, type):
        return asdict(obj)
    if isinstance(obj, Path):
        return str(obj)
    if isinstance(obj, (set, frozenset, tuple)):
        return list(obj)
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")
def json_dumps(data: Any, *, indent: int = 2) -> str:
    """
    Serialize `data` to JSON text with the project's standard settings.

    Objects the encoder cannot handle natively are passed to `json_default`;
    output is indented by `indent` spaces.
    """
    encoder = json.JSONEncoder(indent=indent, default=json_default)
    return encoder.encode(data)
# ----------------------------------------------------------------------------- # Timing / error-context decorators # -----------------------------------------------------------------------------
def timed(fn: Callable[..., T]) -> Callable[..., T]:
    """
    Decorator that logs how long each call to `fn` takes (DEBUG level).

    The elapsed wall-clock time is measured with ``time.perf_counter`` and
    written to the "tdpy" logger after `fn` returns. If `fn` raises, the
    exception propagates and nothing is logged.

    Args:
        fn: The callable to wrap.

    Returns:
        A wrapper with the same call signature and return value as `fn`.
    """
    # Look the logger up once so handlers are not touched on every call.
    logger = logging.getLogger("tdpy")
    # Not every callable has __name__ (e.g. functools.partial objects); resolve
    # a label up front so logging can never fail after the call has succeeded.
    label = getattr(fn, "__name__", None) or repr(fn)

    @functools.wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> T:
        started = time.perf_counter()
        result = fn(*args, **kwargs)
        elapsed = time.perf_counter() - started
        logger.debug("%s took %.6fs", label, elapsed)
        return result

    return wrapper
def with_error_context(context: str) -> Callable[[Callable[..., T]], Callable[..., T]]:
    """
    Build a decorator that prefixes any raised exception with `context`.

    Exceptions escaping the decorated function are re-raised as RuntimeError
    whose message is "<context>: <ExceptionType>: <original message>". The
    original exception is chained as __cause__, so the full traceback is kept.
    """

    def deco(fn: Callable[..., T]) -> Callable[..., T]:
        @functools.wraps(fn)
        def wrapper(*args: Any, **kwargs: Any) -> T:
            try:
                result = fn(*args, **kwargs)
            except Exception as exc:
                raise RuntimeError(f"{context}: {type(exc).__name__}: {str(exc)}") from exc
            return result

        return wrapper

    return deco
# ----------------------------------------------------------------------------- # Env var helper # -----------------------------------------------------------------------------
[docs] def env_path(*parts: str, default: str | None = None) -> str | None: """ Read env var like TDPY_FOO_BAR for parts ("foo","bar"). Returns string or default. """ key = "_".join(["TDPY", *parts]).upper() return os.environ.get(key, default)
# ----------------------------------------------------------------------------- # Dict / override helpers (used for CLI --set and future GUI editing) # -----------------------------------------------------------------------------
def deep_update(dst: MutableMapping[str, Any], src: Mapping[str, Any]) -> MutableMapping[str, Any]:
    """
    Recursively merge `src` into `dst` (in-place) and return `dst`.

    Mapping values merge deeply; every other type overwrites. Mappings taken
    from `src` are copied into fresh dicts rather than stored by reference, so
    later merges into `dst` never modify `src`. Non-mapping values (lists,
    objects, ...) are still stored by reference.

    Args:
        dst: Mapping that receives the merged values; modified in place.
        src: Mapping whose values take precedence.

    Returns:
        `dst`, to allow chaining.
    """
    for key, value in src.items():
        if not isinstance(value, Mapping):
            dst[key] = value
            continue

        current = dst.get(key)
        if isinstance(current, MutableMapping):
            # Both sides are mappings and the destination is writable: merge.
            deep_update(current, value)
        else:
            # No mergeable destination: store a copy so dst and src do not
            # share nested containers.
            dst[key] = deep_update({}, value)
    return dst
def dotted_get(mapping: Mapping[str, Any], path: str, default: Any = None) -> Any:
    """
    Fetch a nested value by dotted path, e.g. "a.b.c" -> mapping["a"]["b"]["c"].

    Returns `default` as soon as a path segment is missing or the value
    reached so far is not a mapping.
    """
    node: Any = mapping
    for segment in str(path).split("."):
        if isinstance(node, Mapping) and segment in node:
            node = node[segment]
        else:
            return default
    return node
def dotted_set(mapping: MutableMapping[str, Any], path: str, value: Any) -> None:
    """
    Assign a nested value by dotted path, e.g. "a.b.c" -> mapping["a"]["b"]["c"].

    Missing intermediate levels are created as dicts; an intermediate value
    that is not a mutable mapping is replaced by a new dict.

    Raises:
        ValueError: If any path segment is empty or whitespace-only.
    """
    keys = str(path).split(".")
    if any(not key.strip() for key in keys):
        raise ValueError(f"Invalid dotted path: {path!r}")

    *parents, leaf = keys
    node: MutableMapping[str, Any] = mapping
    for key in parents:
        child = node.get(key)
        if not isinstance(child, MutableMapping):
            # Create (or overwrite) the level so the walk can continue.
            child = node[key] = {}
        node = child
    node[leaf] = value
def coerce_scalar(text: str) -> Any:
    """
    Coerce a string into bool/int/float if it looks like one.

    Rules, applied in order to the whitespace-stripped text:
    - "" stays "".
    - "true"/"false" (any case) become bool.
    - Zero-padded whole numbers such as "007" or "-007" stay strings, so
      identifiers and codes keep their leading zeros.
    - Anything ``int()`` accepts becomes int (unless it starts with "0").
    - Anything ``float()`` accepts becomes float.
    - Everything else is returned as the stripped string.

    Args:
        text: The raw value, typically the right-hand side of a CLI override.

    Returns:
        The coerced bool/int/float, or the stripped string.
    """
    s = str(text).strip()
    if not s:
        return ""

    lowered = s.lower()
    if lowered in ("true", "false"):
        return lowered == "true"

    # Zero-padded integers ("007") are ambiguous (octal? identifier?). Keep
    # them verbatim instead of letting float() quietly turn them into 7.0.
    unsigned = s[1:] if s[0] in "+-" else s
    if len(unsigned) > 1 and unsigned.startswith("0") and unsigned.isdecimal():
        return s

    # Other texts starting with "0" (e.g. "0.5", "0e3") are never plain ints.
    if s == "0" or not s.startswith("0"):
        try:
            return int(s)
        except ValueError:
            pass

    try:
        return float(s)
    except ValueError:
        return s
def parse_overrides(pairs: Iterable[str]) -> Dict[str, Any]:
    """
    Turn CLI-style "key=value" overrides into a nested dict.

    Example input: ["a.b=3", "foo=true", "bar.baz=1e-6"]. Keys are dotted
    paths, values are coerced with `coerce_scalar`, and blank entries are
    skipped. The result is suitable for deep_update().

    Raises:
        ValueError: If an entry contains no "=".
    """
    overrides: Dict[str, Any] = {}
    for raw in pairs:
        entry = str(raw).strip()
        if entry == "":
            continue
        # Split on the first "=" only, so values may themselves contain "=".
        key, separator, value = entry.partition("=")
        if separator != "=":
            raise ValueError(f"Override must be key=value; got {raw!r}")
        dotted_set(overrides, key.strip(), coerce_scalar(value))
    return overrides
# ----------------------------------------------------------------------------- # Optional dependency hints (nice UX for CLI) # -----------------------------------------------------------------------------
def require_package(import_name: str, pip_name: Optional[str] = None) -> None:
    """
    Verify that an optional dependency can be imported.

    Tries to import `import_name`; on any failure raises ImportError with an
    install hint that names `pip_name` (or `import_name` when no pip name is
    given). The underlying exception is chained as the cause.
    """
    install_target = pip_name if pip_name else import_name
    try:
        __import__(import_name)
    except Exception as exc:
        hint = f"Missing optional dependency {import_name!r}. Install with: pip install {install_target}"
        raise ImportError(hint) from exc
# ----------------------------------------------------------------------------- # Small convenience helpers # -----------------------------------------------------------------------------
def now_iso() -> str:
    """
    Return the current UTC time as "YYYY-MM-DDTHH:MM:SS".

    The string carries no timezone designator; it is meant for metadata
    fields rather than precise time arithmetic.
    """
    utc_now = time.gmtime()
    return time.strftime("%Y-%m-%dT%H:%M:%S", utc_now)
[docs] def clamp(x: float, lo: float | None = None, hi: float | None = None) -> float: if lo is not None and x < lo: return lo if hi is not None and x > hi: return hi return x