Source code for interpreter.numeric_eval

from __future__ import annotations

"""Safe numeric evaluation helpers for the TDPy interpreter.

This module evaluates numeric constants and numeric constant expressions in a
restricted way. It is used by the interpreter layer before the full nonlinear
equation solver is invoked.

Goals
-----
The evaluator is designed to:

* Evaluate simple numeric constants and constant expressions safely.
* Support unit-aware quantity parsing when the optional units layer is enabled.
* Provide useful error messages that identify the failing expression.
* Stay consistent with equation preprocessing by reusing the shared
  ``equations.safe_eval.preprocess_expr`` helper when it is importable; when it
  is not, a minimal fallback that only rewrites ``^`` as ``**`` is used.

Security policy
---------------
The evaluator uses an AST whitelist. It does not allow attribute access,
subscripts, comprehensions, lambdas, imports, comparisons, boolean operations,
conditional expressions, starred arguments, or calls to non-whitelisted
functions.
"""

import ast
import math
import re
from dataclasses import dataclass
from typing import Any, Dict, Mapping, Optional


# ------------------------------ preprocessing ------------------------------

# Reuse solver-safe preprocessing when available.
# Prefer the solver's own preprocessing so constants and equations are
# normalized the same way; otherwise install a minimal local substitute.
try:
    from equations.safe_eval import preprocess_expr  # type: ignore
except Exception:  # pragma: no cover

    def preprocess_expr(s: str) -> str:
        """Rewrite every ``^`` as ``**`` (fallback when equations.safe_eval is unavailable)."""
        return "**".join(s.split("^"))


# ------------------------------ whitelist configuration ------------------------------

# Binary operators accepted by the numeric evaluator: + - * / ** %
# (floor division ``//`` and bitwise operators are intentionally absent).
_ALLOWED_BINOPS = (
    ast.Add,
    ast.Sub,
    ast.Mult,
    ast.Div,
    ast.Pow,
    ast.Mod,
)
# Unary operators accepted by the numeric evaluator: leading + and -.
_ALLOWED_UNARYOPS = (
    ast.UAdd,
    ast.USub,
)


def default_numeric_funcs() -> Dict[str, Any]:
    """Build the default allowlist of numeric-only callables.

    Every entry performs pure numeric work: no file I/O, no imports, no
    attribute access and no backend calls. A new dictionary is built on each
    call, so callers may extend the result without affecting one another.
    Entries whose implementation is missing on the running interpreter
    (looked up with ``getattr`` on :mod:`math`) are left out.
    """

    def _clamp(x, lo, hi):
        # Bound x to [lo, hi]; when lo > hi the result is lo.
        return max(lo, min(hi, x))

    candidates = (
        # core
        ("abs", abs),
        ("min", min),
        ("max", max),
        ("pow", pow),
        ("clamp", _clamp),
        # exponentials and logs
        ("sqrt", math.sqrt),
        ("exp", math.exp),
        ("log", math.log),
        ("ln", math.log),
        ("log10", math.log10),
        ("log2", getattr(math, "log2", None)),
        # trigonometry
        ("sin", math.sin),
        ("cos", math.cos),
        ("tan", math.tan),
        ("asin", math.asin),
        ("acos", math.acos),
        ("atan", math.atan),
        ("atan2", math.atan2),
        # hyperbolic functions
        ("sinh", math.sinh),
        ("cosh", math.cosh),
        ("tanh", math.tanh),
        # rounding
        ("floor", math.floor),
        ("ceil", math.ceil),
        # misc
        ("hypot", math.hypot),
        ("radians", math.radians),
        ("degrees", math.degrees),
    )
    return {name: fn for name, fn in candidates if fn is not None}
# ------------------------------ errors ------------------------------
@dataclass(frozen=True)
class NumericEvalContext:
    """Immutable description of the expression behind a numeric evaluation error.

    Passed as ``ctx`` to :class:`NumericEvalError`, which uses ``where`` as a
    message prefix and echoes ``expr`` at the end of the message.
    """

    # Expression text being evaluated (safe_eval_numeric stores the
    # preprocessed, stripped form).
    expr: str
    # Human-readable label for the evaluation site.
    where: str = "numeric expression"
class NumericEvalError(ValueError):
    """Error raised when a numeric expression cannot be safely evaluated.

    Subclasses :class:`ValueError`, so existing ``except ValueError`` handlers
    keep working.

    Parameters
    ----------
    message:
        Description of the failure, without location information.
    ctx:
        Optional evaluation context. When given, the rendered message is
        ``"<ctx.where>: <message> | expr=<repr(ctx.expr)>"``; otherwise it is
        ``message`` unchanged.

    Attributes
    ----------
    message:
        The undecorated failure description supplied by the raiser.
    ctx:
        The context supplied by the raiser, or ``None``. Kept so callers can
        inspect the failing expression programmatically instead of parsing
        the rendered message.
    """

    def __init__(self, message: str, *, ctx: Optional["NumericEvalContext"] = None):
        # Stored in __dict__, so BaseException's pickling support restores
        # both attributes alongside the rendered message.
        self.message = message
        self.ctx = ctx
        if ctx is not None:
            super().__init__(f"{ctx.where}: {message} | expr={ctx.expr!r}")
        else:
            super().__init__(message)
# ------------------------------ safe numeric eval ------------------------------
# Implementation of each whitelisted binary operator, keyed by AST op type.
_BINOP_IMPLS: Dict[type, Any] = {
    ast.Add: lambda a, b: a + b,
    ast.Sub: lambda a, b: a - b,
    ast.Mult: lambda a, b: a * b,
    ast.Div: lambda a, b: a / b,
    ast.Pow: lambda a, b: a**b,
    ast.Mod: lambda a, b: a % b,
}


class _NumericEvaluator(ast.NodeVisitor):
    """AST visitor implementing the numeric whitelist used by safe_eval_numeric.

    Every ``visit_*`` method either returns a Python ``float`` or raises
    :class:`NumericEvalError`. Node types without a dedicated handler are
    rejected by :meth:`generic_visit`.
    """

    def __init__(
        self,
        consts: Mapping[str, Any],
        fns: Mapping[str, Any],
        ctx: "NumericEvalContext",
    ) -> None:
        self._consts = consts
        self._fns = fns
        self._ctx = ctx

    # ---- helpers ----

    def _error(self, message: str) -> "NumericEvalError":
        """Build an error that carries this evaluation's context."""
        return NumericEvalError(message, ctx=self._ctx)

    def _to_real(self, value: Any, what: str) -> float:
        """Convert *value* to float, reporting every failure as NumericEvalError."""
        # Python yields complex numbers for e.g. (-8) ** (1/3); float() would
        # then raise a bare TypeError, so reject it explicitly.
        if isinstance(value, complex):
            raise self._error(f"Complex result not allowed from {what}")
        try:
            return float(value)
        except OverflowError as e:
            # e.g. an integer literal too large for a float
            raise self._error(f"Numeric overflow in {what}") from e
        except (TypeError, ValueError) as e:
            raise self._error(f"Non-numeric value from {what}: {value!r}") from e

    # ---- allowed nodes ----

    def visit_Expression(self, n: ast.Expression) -> float:
        return self._to_real(self.visit(n.body), "expression")

    def visit_Constant(self, n: ast.Constant) -> float:
        value = n.value
        # bool is an int subclass, but True/False are not numeric literals.
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            raise self._error(f"Non-numeric literal not allowed: {value!r}")
        return self._to_real(value, "literal")

    def visit_Num(self, n: ast.AST) -> float:  # pragma: no cover
        """Support Python versions that still emit ast.Num nodes."""
        return self._to_real(n.n, "literal")  # type: ignore[attr-defined]

    def visit_Name(self, n: ast.Name) -> float:
        name = n.id
        if name.startswith("__"):
            raise self._error("Dunder names are not allowed")
        if name not in self._consts:
            raise self._error(f"Unknown name: {name!r}")
        return self._to_real(self._consts[name], f"name {name!r}")

    def visit_UnaryOp(self, n: ast.UnaryOp) -> float:
        if not isinstance(n.op, _ALLOWED_UNARYOPS):
            raise self._error(f"Unary operator not allowed: {type(n.op).__name__}")
        v = self.visit(n.operand)
        return +v if isinstance(n.op, ast.UAdd) else -v

    def visit_BinOp(self, n: ast.BinOp) -> float:
        op_name = type(n.op).__name__
        if not isinstance(n.op, _ALLOWED_BINOPS):
            raise self._error(f"Operator not allowed: {op_name}")
        a = self.visit(n.left)
        b = self.visit(n.right)
        impl = _BINOP_IMPLS.get(type(n.op))
        if impl is None:
            raise self._error("Unhandled operator")
        try:
            result = impl(a, b)
        except ArithmeticError as e:
            # ZeroDivisionError (x / 0, x % 0, 0 ** -1), OverflowError (1e308 ** 2)
            raise self._error(f"Arithmetic error in {op_name}: {e}") from e
        return self._to_real(result, f"operator {op_name}")

    def visit_Call(self, n: ast.Call) -> float:
        if not isinstance(n.func, ast.Name):
            raise self._error("Only direct calls f(x) are allowed")
        fn = n.func.id
        if fn not in self._fns:
            raise self._error(f"Function not allowed: {fn!r}")
        if any(isinstance(a, ast.Starred) for a in n.args):
            raise self._error("Star-args are not allowed")
        if any(kw.arg is None for kw in n.keywords):
            raise self._error("Star-kwargs are not allowed")
        args = [self.visit(a) for a in n.args]
        kwargs = {str(kw.arg): self.visit(kw.value) for kw in n.keywords}
        try:
            return float(self._fns[fn](*args, **kwargs))
        except Exception as e:
            raise self._error(f"Call failed: {fn}(...) -> {e}") from e

    # ---- explicitly rejected nodes (clearer messages than generic_visit) ----

    def visit_Attribute(self, n: ast.Attribute) -> float:
        raise self._error("Attribute access is not allowed")

    def visit_Subscript(self, n: ast.Subscript) -> float:
        raise self._error("Indexing is not allowed")

    def visit_Lambda(self, n: ast.Lambda) -> float:  # pragma: no cover
        raise self._error("Lambda is not allowed")

    def visit_Compare(self, n: ast.Compare) -> float:
        raise self._error("Comparisons are not allowed in numeric constants")

    def visit_BoolOp(self, n: ast.BoolOp) -> float:
        raise self._error("Boolean operations are not allowed in numeric constants")

    def visit_IfExp(self, n: ast.IfExp) -> float:
        raise self._error("Conditional expressions are not allowed")

    def generic_visit(self, n: ast.AST) -> float:
        raise self._error(f"Unsafe or unsupported syntax: {type(n).__name__}")


def safe_eval_numeric(
    expr: str,
    *,
    names: Mapping[str, float],
    funcs: Optional[Mapping[str, Any]] = None,
) -> float:
    """Evaluate a numeric expression with an AST whitelist.

    Allowed syntax includes numeric literals, resolved names supplied through
    ``names``, the operators ``+``, ``-``, ``*``, ``/``, ``**``, ``%``, unary
    signs, parentheses, and calls to whitelisted numeric functions.
    ``pi`` and ``e`` are predefined unless ``names`` overrides them.

    Disallowed syntax includes attribute access, indexing, comprehensions,
    lambdas, assignments, comparisons, boolean operations and boolean
    literals, conditional expressions, calls to non-whitelisted functions,
    and starred arguments.

    Parameters
    ----------
    expr:
        Expression text to evaluate. It is passed through ``preprocess_expr``
        (``^`` becomes ``**``) and stripped before parsing.
    names:
        Mapping of previously resolved numeric constants.
    funcs:
        Optional extra numeric-safe functions, merged over the defaults from
        ``default_numeric_funcs()``.

    Returns
    -------
    float
        Evaluated numeric value. It is not checked for finiteness; overflow in
        ``*`` or ``+`` can yield ``inf``.

    Raises
    ------
    NumericEvalError
        On any of the following:

        * an empty expression or invalid syntax;
        * disallowed syntax, names or functions;
        * non-numeric name values;
        * division or modulo by zero, or overflow in ``**``;
        * complex results;
        * failing function calls;
        * nesting too deep to parse or walk.
    """
    s = preprocess_expr(str(expr)).strip()
    ctx = NumericEvalContext(expr=s)
    if not s:
        raise NumericEvalError("Empty expression", ctx=ctx)

    fns: Dict[str, Any] = dict(default_numeric_funcs())
    if funcs:
        fns.update(funcs)

    consts: Dict[str, float] = dict(names)
    consts.setdefault("pi", float(math.pi))
    consts.setdefault("e", float(math.e))

    try:
        node = ast.parse(s, mode="eval")
    except (SyntaxError, ValueError) as e:
        # Before Python 3.12, NUL bytes in the source raise ValueError, not SyntaxError.
        raise NumericEvalError("Invalid syntax", ctx=ctx) from e
    except RecursionError as e:
        raise NumericEvalError("Expression is too deeply nested", ctx=ctx) from e

    try:
        return float(_NumericEvaluator(consts, fns, ctx).visit(node))
    except RecursionError as e:
        # Pathological nesting (e.g. thousands of unary minus signs) can
        # exhaust the interpreter stack while the tree is being walked.
        raise NumericEvalError("Expression is too deeply nested", ctx=ctx) from e
# ------------------------------ units parsing ------------------------------

# An optionally signed decimal or scientific-notation literal, with optional
# surrounding whitespace: "1", "-2.5", ".5", "3.", "+1e-3". Words such as
# "inf" or "nan" do not match.
_FLOAT_RE = re.compile(
    r"""^\s* [+-]? (?: (?:\d+(?:\.\d*)?)|(?:\.\d+) ) (?:[eE][+-]?\d+)? \s*$""",
    re.VERBOSE,
)


def _looks_like_plain_float(s: str) -> bool:
    """Report whether *s* is a plain (unit-free) float literal."""
    return _FLOAT_RE.match(s) is not None
def try_parse_float_or_quantity(s: str, *, enable_units: bool = True) -> Optional[float]:
    """Interpret *s* as a finite number, optionally carrying units.

    The text is first passed through ``preprocess_expr`` and stripped. A plain
    decimal or scientific-notation literal is converted directly. Otherwise,
    when ``enable_units`` is true, the optional ``units`` package is imported
    (on each call) and ``parse_quantity`` is asked to parse the text against
    ``DEFAULT_REGISTRY``. The value returned by the quantity's
    ``base_value()`` is then used.

    Returns
    -------
    float | None
        The parsed finite value. ``None`` is returned when:

        * the text is empty;
        * the text is a plain literal that is not finite (e.g. ``"1e999"``);
        * the units layer is disabled or cannot be imported;
        * parsing fails;
        * the resulting value is not finite.
    """
    text = preprocess_expr(str(s)).strip()
    if not text:
        return None

    if _looks_like_plain_float(text):
        try:
            plain = float(text)
        except Exception:
            pass  # not convertible after all; let the units layer try
        else:
            return plain if math.isfinite(plain) else None

    if not enable_units:
        return None

    try:
        from units import DEFAULT_REGISTRY, parse_quantity  # type: ignore
    except Exception:
        return None

    try:
        base = float(parse_quantity(text, DEFAULT_REGISTRY).base_value())
    except Exception:
        return None
    return base if math.isfinite(base) else None