from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Mapping
try:
from .core.clutch import Brake, Clutch, RotatingMember, Sprag
from .core.planetary import PlanetaryGearSet
from .core.solver import TransmissionSolver
from .utils import (
TransmissionAppError,
coerce_int,
dedupe_keep_order,
ensure_dict,
ensure_list,
ensure_str,
normalize_state_name,
)
except ImportError:
from core.clutch import Brake, Clutch, RotatingMember, Sprag
from core.planetary import PlanetaryGearSet
from core.solver import TransmissionSolver
from utils import (
TransmissionAppError,
coerce_int,
dedupe_keep_order,
ensure_dict,
ensure_list,
ensure_str,
normalize_state_name,
)
@dataclass(frozen=True)
class GearsetSpec:
    """Immutable declaration of one planetary gearset in a transmission spec.

    The three element fields hold member *names*; they are resolved to
    RotatingMember objects when the solver is built.
    """

    name: str  # label forwarded to PlanetaryGearSet(name=...)
    Ns: int  # sun count forwarded to PlanetaryGearSet (tooth count by convention)
    Nr: int  # ring count forwarded to PlanetaryGearSet (tooth count by convention)
    sun: str  # member name the sun gear is attached to
    ring: str  # member name the ring gear is attached to
    carrier: str  # member name the carrier is attached to
@dataclass(frozen=True)
class ClutchSpec:
    """Immutable declaration of a clutch that couples two named members."""

    name: str  # shift-element name referenced by schedule states
    a: str  # first coupled member name
    b: str  # second coupled member name
@dataclass(frozen=True)
class BrakeSpec:
    """Immutable declaration of a brake acting on a single named member."""

    name: str  # shift-element name referenced by schedule states
    member: str  # member name the brake acts on
@dataclass(frozen=True)
class SpragSpec:
    """Immutable declaration of a sprag (one-way element) on a named member.

    ``hold_direction`` and ``locked_when_engaged`` are forwarded unchanged to
    the Sprag constructor when the solver is built.
    """

    name: str  # shift-element name referenced by schedule states
    member: str  # member name the sprag acts on
    hold_direction: str = "counter_clockwise"
    locked_when_engaged: bool = True
@dataclass(frozen=True)
class ShiftStateSpec:
    """Immutable description of one shift state in a schedule."""

    name: str  # canonical (normalized) state name
    active_constraints: tuple[str, ...]  # element names engaged when solving
    display_elements: tuple[str, ...]  # element names reported as "engaged" in results
    notes: str = ""
    # When True the state is reported via the manual-neutral convention
    # instead of being solved.
    manual_neutral: bool = False
@dataclass
class TransmissionSpec:
    """Declarative topology of a planetary transmission.

    Usually built with :meth:`from_dict` from a plain mapping (for example
    parsed JSON). Members are referred to by name everywhere; the complete set
    of names is collected by :meth:`all_member_names`.
    """

    name: str
    input_member: str
    output_member: str
    gearsets: list[GearsetSpec]
    clutches: list[ClutchSpec]
    brakes: list[BrakeSpec]
    sprags: list[SpragSpec] = field(default_factory=list)
    # (a, b) member-name pairs handed to TransmissionSolver.add_permanent_tie.
    permanent_ties: list[tuple[str, str]] = field(default_factory=list)
    # Explicitly declared member names; merged with every name referenced
    # elsewhere by all_member_names().
    members: list[str] = field(default_factory=list)
    display_order: list[str] = field(default_factory=list)
    # Alternate state name -> canonical state name.
    state_aliases: dict[str, str] = field(default_factory=dict)
    speed_display_order: list[str] = field(default_factory=list)
    speed_display_labels: dict[str, str] = field(default_factory=dict)
    strict_geometry: bool = False
    notes: str = ""
    presets: dict[str, Any] = field(default_factory=dict)
    meta: dict[str, Any] = field(default_factory=dict)

    @staticmethod
    def _raw_clutch_items(spec_dict: Mapping[str, Any]) -> tuple[list[Any], str]:
        """Return clutch-like items with backward-compatible schema support.

        Preferred modern key:
            clutches_brakes_flywheels
        Legacy key still supported:
            clutches

        The modern key wins whenever it is present and not ``None``. The
        second tuple item is the context label used in validation errors.
        """
        modern = spec_dict.get("clutches_brakes_flywheels")
        if modern is not None:
            context = "spec.clutches_brakes_flywheels"
            return ensure_list(modern, context=context), context
        return ensure_list(spec_dict.get("clutches"), context="spec.clutches"), "spec.clutches"

    @staticmethod
    def _str_list(value: Any, context: str) -> list[str]:
        """Validate *value* with ensure_list and every item with ensure_str."""
        return [ensure_str(x, context=f"{context}[]") for x in ensure_list(value, context=context)]

    @staticmethod
    def _str_dict(value: Any, context: str) -> dict[str, str]:
        """Validate *value* with ensure_dict and every key/value with ensure_str."""
        raw = ensure_dict(value, context=context)
        return {
            ensure_str(k, context=f"{context} key"): ensure_str(v, context=f"{context} value")
            for k, v in raw.items()
        }

    @staticmethod
    def from_dict(data: Mapping[str, Any]) -> "TransmissionSpec":
        """Parse and validate a transmission spec mapping.

        Fields are validated in a fixed order, so the first malformed field
        determines the error raised.

        Raises:
            TransmissionAppError: if ``gearsets`` is empty, or a
                ``permanent_ties`` entry is not a 2-item list/tuple. The
                ``ensure_*`` / ``coerce_int`` helpers raise their own
                validation errors for malformed fields.
        """
        d = ensure_dict(data, context="transmission spec")
        name = str(d.get("name", "Generic Transmission")).strip() or "Generic Transmission"
        input_member = ensure_str(d.get("input_member"), context="spec.input_member")
        output_member = ensure_str(d.get("output_member"), context="spec.output_member")
        strict_geometry = bool(d.get("strict_geometry", False))
        notes = str(d.get("notes", ""))
        members = TransmissionSpec._str_list(d.get("members"), "spec.members")
        display_order = TransmissionSpec._str_list(d.get("display_order"), "spec.display_order")
        speed_display_order = TransmissionSpec._str_list(
            d.get("speed_display_order"), "spec.speed_display_order"
        )
        state_aliases = TransmissionSpec._str_dict(d.get("state_aliases"), "spec.state_aliases")
        speed_display_labels = TransmissionSpec._str_dict(
            d.get("speed_display_labels"), "spec.speed_display_labels"
        )
        presets = ensure_dict(d.get("presets"), context="spec.presets")
        meta = ensure_dict(d.get("meta"), context="spec.meta")

        gearsets: list[GearsetSpec] = []
        for idx, item in enumerate(ensure_list(d.get("gearsets"), context="spec.gearsets")):
            ctx = f"spec.gearsets[{idx}]"
            g = ensure_dict(item, context=ctx)
            gearsets.append(
                GearsetSpec(
                    name=ensure_str(g.get("name"), context=f"{ctx}.name"),
                    Ns=coerce_int(g.get("Ns"), context=f"{ctx}.Ns"),
                    Nr=coerce_int(g.get("Nr"), context=f"{ctx}.Nr"),
                    sun=ensure_str(g.get("sun"), context=f"{ctx}.sun"),
                    ring=ensure_str(g.get("ring"), context=f"{ctx}.ring"),
                    carrier=ensure_str(g.get("carrier"), context=f"{ctx}.carrier"),
                )
            )
        if not gearsets:
            raise TransmissionAppError("spec.gearsets must contain at least one planetary gearset.")

        raw_clutch_items, clutch_context = TransmissionSpec._raw_clutch_items(d)
        clutches: list[ClutchSpec] = []
        for idx, item in enumerate(raw_clutch_items):
            ctx = f"{clutch_context}[{idx}]"
            c = ensure_dict(item, context=ctx)
            clutches.append(
                ClutchSpec(
                    name=ensure_str(c.get("name"), context=f"{ctx}.name"),
                    a=ensure_str(c.get("a"), context=f"{ctx}.a"),
                    b=ensure_str(c.get("b"), context=f"{ctx}.b"),
                )
            )

        brakes: list[BrakeSpec] = []
        for idx, item in enumerate(ensure_list(d.get("brakes"), context="spec.brakes")):
            ctx = f"spec.brakes[{idx}]"
            b = ensure_dict(item, context=ctx)
            brakes.append(
                BrakeSpec(
                    name=ensure_str(b.get("name"), context=f"{ctx}.name"),
                    member=ensure_str(b.get("member"), context=f"{ctx}.member"),
                )
            )

        sprags: list[SpragSpec] = []
        for idx, item in enumerate(ensure_list(d.get("sprags"), context="spec.sprags")):
            ctx = f"spec.sprags[{idx}]"
            s = ensure_dict(item, context=ctx)
            sprags.append(
                SpragSpec(
                    name=ensure_str(s.get("name"), context=f"{ctx}.name"),
                    member=ensure_str(s.get("member"), context=f"{ctx}.member"),
                    hold_direction=str(s.get("hold_direction", "counter_clockwise")).strip()
                    or "counter_clockwise",
                    locked_when_engaged=bool(s.get("locked_when_engaged", True)),
                )
            )

        permanent_ties: list[tuple[str, str]] = []
        for idx, item in enumerate(ensure_list(d.get("permanent_ties"), context="spec.permanent_ties")):
            # Tuples are accepted as well as JSON arrays so mappings built in
            # Python (e.g. dataclasses.asdict(), which keeps tuples) parse too.
            if not isinstance(item, (list, tuple)) or len(item) != 2:
                raise TransmissionAppError(f"spec.permanent_ties[{idx}] must be a 2-item array.")
            a = ensure_str(item[0], context=f"spec.permanent_ties[{idx}][0]")
            b = ensure_str(item[1], context=f"spec.permanent_ties[{idx}][1]")
            permanent_ties.append((a, b))

        return TransmissionSpec(
            name=name,
            input_member=input_member,
            output_member=output_member,
            gearsets=gearsets,
            clutches=clutches,
            brakes=brakes,
            sprags=sprags,
            permanent_ties=permanent_ties,
            members=members,
            display_order=display_order,
            state_aliases=state_aliases,
            speed_display_order=speed_display_order,
            speed_display_labels=speed_display_labels,
            strict_geometry=strict_geometry,
            notes=notes,
            presets=presets,
            meta=meta,
        )

    def all_member_names(self) -> list[str]:
        """Return every member name referenced by the spec, deduplicated.

        Collection order is: explicit ``members``, input, output, gearset
        sun/ring/carrier, clutch sides, brake members, sprag members, then
        permanent-tie ends; the result is passed through dedupe_keep_order.
        """
        out: list[str] = [*self.members, self.input_member, self.output_member]
        for g in self.gearsets:
            out.extend((g.sun, g.ring, g.carrier))
        for c in self.clutches:
            out.extend((c.a, c.b))
        out.extend(b.member for b in self.brakes)
        out.extend(s.member for s in self.sprags)
        for a, b in self.permanent_ties:
            out.extend((a, b))
        return dedupe_keep_order(out)
@dataclass
class ShiftSchedule:
    """Shift states keyed by canonical (normalized) state name."""

    states: dict[str, ShiftStateSpec]
    notes: str = ""
    display_order: list[str] = field(default_factory=list)

    @staticmethod
    def from_dict(data: Mapping[str, Any], *, aliases: Mapping[str, str] | None = None) -> "ShiftSchedule":
        """Parse a shift-schedule mapping.

        Each value in ``states`` is either a bare list of element names (used
        for both ``active_constraints`` and ``display_elements``) or an object
        with ``active_constraints`` (legacy key: ``engaged``), optional
        ``display_elements`` (defaulting to the active constraints), ``notes``
        and ``manual_neutral``. A key explicitly set to ``None`` is treated as
        absent so the fallback still applies.

        State keys are passed through ``normalize_state_name`` with *aliases*.
        If two raw keys normalize to the same name, the later one wins.
        ``display_order`` entries that do not name a state verbatim are
        normalized the same way, but only replaced when the normalized name is
        an actual state.
        """
        d = ensure_dict(data, context="shift schedule")
        raw_states = ensure_dict(d.get("states"), context="schedule.states")
        states: dict[str, ShiftStateSpec] = {}
        for raw_state, raw_spec in raw_states.items():
            state_name = normalize_state_name(str(raw_state), aliases)
            ctx = f"schedule.states.{raw_state}"
            if isinstance(raw_spec, list):
                # Shorthand form: the list is both engaged and displayed.
                elems = ensure_list(raw_spec, context=ctx)
                active = tuple(ensure_str(x, context=f"{ctx}[]") for x in elems)
                states[state_name] = ShiftStateSpec(
                    name=state_name,
                    active_constraints=active,
                    display_elements=active,
                    notes="",
                    manual_neutral=False,
                )
                continue

            spec_obj = ensure_dict(raw_spec, context=ctx)
            raw_active = spec_obj.get("active_constraints")
            if raw_active is None:
                raw_active = spec_obj.get("engaged")
            active_raw = ensure_list(raw_active, context=f"{ctx}.active_constraints")
            active = tuple(
                ensure_str(x, context=f"{ctx}.active_constraints[]") for x in active_raw
            )

            raw_display = spec_obj.get("display_elements")
            if raw_display is None:
                raw_display = list(active)
            display_raw = ensure_list(raw_display, context=f"{ctx}.display_elements")
            display = tuple(
                ensure_str(x, context=f"{ctx}.display_elements[]") for x in display_raw
            )

            states[state_name] = ShiftStateSpec(
                name=state_name,
                active_constraints=active,
                display_elements=display,
                notes=str(spec_obj.get("notes", "")),
                manual_neutral=bool(spec_obj.get("manual_neutral", False)),
            )

        display_order: list[str] = []
        for x in ensure_list(d.get("display_order"), context="schedule.display_order"):
            entry = ensure_str(x, context="schedule.display_order[]")
            if entry not in states:
                # States are keyed by normalized name, so an alias here would
                # otherwise never match and its ordering would be lost.
                canonical = normalize_state_name(entry, aliases)
                if canonical in states:
                    entry = canonical
            display_order.append(entry)

        notes = str(d.get("notes", ""))
        return ShiftSchedule(states=states, notes=notes, display_order=display_order)
@dataclass(frozen=True)
class GenericSolveResult:
    """Immutable outcome of solving one shift state."""

    state: str  # canonical state name that was solved
    engaged: tuple[str, ...]  # the state's display elements
    ok: bool
    ratio: float | None  # input speed / output speed, or None when undefined
    speeds: dict[str, float]  # member name -> speed
    notes: str = ""
    solver_path: str = "core_generic_json_builder"
    status: str = "ok"
    message: str = ""
class GenericTransmission:
    """Solve the shift states of a TransmissionSpec as driven by a ShiftSchedule."""

    # Output speeds with magnitude at or below this are treated as zero:
    # no ratio is reported instead of dividing by numerical noise.
    _ZERO_SPEED_TOL = 1.0e-12

    def __init__(self, *, spec: TransmissionSpec, schedule: ShiftSchedule) -> None:
        """Store *spec* and *schedule* and validate the schedule against the spec.

        Raises:
            TransmissionAppError: if a state's ``active_constraints`` names an
                element that is not a clutch, brake or sprag of *spec*.
        """
        self.spec = spec
        self.schedule = schedule
        self._validate_schedule_elements()

    def _validate_schedule_elements(self) -> None:
        """Raise TransmissionAppError for the first unknown active constraint."""
        valid = (
            {c.name for c in self.spec.clutches}
            | {b.name for b in self.spec.brakes}
            | {s.name for s in self.spec.sprags}
        )
        for state, state_spec in self.schedule.states.items():
            for elem in state_spec.active_constraints:
                if elem not in valid:
                    valid_txt = ", ".join(sorted(valid))
                    raise TransmissionAppError(
                        f"Schedule state '{state}' references unknown shift element '{elem}'. "
                        f"Valid elements: {valid_txt}"
                    )

    def _member_map(self) -> dict[str, RotatingMember]:
        """Create a fresh RotatingMember for every member name in the spec."""
        return {name: RotatingMember(name) for name in self.spec.all_member_names()}

    def build_solver(
        self,
    ) -> tuple[
        TransmissionSolver,
        dict[str, RotatingMember],
        dict[str, Clutch],
        dict[str, Brake],
        dict[str, Sprag],
    ]:
        """Build a new solver with fresh members and all elements registered.

        Returns ``(solver, members, clutch_map, brake_map, sprag_map)``; the
        maps are keyed by element/member name. Sprags are registered with the
        solver through ``add_brake``.
        """
        members = self._member_map()
        solver = TransmissionSolver()
        for gear in self.spec.gearsets:
            gearset = PlanetaryGearSet(
                Ns=gear.Ns,
                Nr=gear.Nr,
                name=gear.name,
                sun=members[gear.sun],
                ring=members[gear.ring],
                carrier=members[gear.carrier],
                geometry_mode="strict" if self.spec.strict_geometry else "relaxed",
            )
            solver.add_gearset(gearset)
        clutch_map: dict[str, Clutch] = {}
        for c in self.spec.clutches:
            obj = Clutch(members[c.a], members[c.b], name=c.name)
            solver.add_clutch(obj)
            clutch_map[c.name] = obj
        brake_map: dict[str, Brake] = {}
        for b in self.spec.brakes:
            obj = Brake(members[b.member], name=b.name)
            solver.add_brake(obj)
            brake_map[b.name] = obj
        sprag_map: dict[str, Sprag] = {}
        for s in self.spec.sprags:
            obj = Sprag(
                members[s.member],
                hold_direction=s.hold_direction,
                locked_when_engaged=s.locked_when_engaged,
                name=s.name,
            )
            solver.add_brake(obj)  # type: ignore[arg-type]
            sprag_map[s.name] = obj
        for a, b in self.spec.permanent_ties:
            solver.add_permanent_tie(a, b)
        return solver, members, clutch_map, brake_map, sprag_map

    def topology_summary(self) -> dict[str, Any]:
        """Return a plain-data summary of the spec and schedule.

        Clutches are emitted under both ``clutches_brakes_flywheels`` and the
        legacy ``clutches`` key.
        """
        return {
            "name": self.spec.name,
            "input_member": self.spec.input_member,
            "output_member": self.spec.output_member,
            "strict_geometry": self.spec.strict_geometry,
            "members": self.spec.all_member_names(),
            "gearsets": [
                {
                    "name": g.name,
                    "Ns": g.Ns,
                    "Nr": g.Nr,
                    "sun": g.sun,
                    "ring": g.ring,
                    "carrier": g.carrier,
                }
                for g in self.spec.gearsets
            ],
            "clutches_brakes_flywheels": [{"name": c.name, "a": c.a, "b": c.b} for c in self.spec.clutches],
            "clutches": [{"name": c.name, "a": c.a, "b": c.b} for c in self.spec.clutches],
            "brakes": [{"name": b.name, "member": b.member} for b in self.spec.brakes],
            "sprags": [
                {
                    "name": s.name,
                    "member": s.member,
                    "hold_direction": s.hold_direction,
                    "locked_when_engaged": s.locked_when_engaged,
                }
                for s in self.spec.sprags
            ],
            "permanent_ties": [list(x) for x in self.spec.permanent_ties],
            "schedule_states": list(self.schedule.states.keys()),
            "notes": self.spec.notes,
            "meta": self.spec.meta,
        }

    def normalize_state_name(self, state: str) -> str:
        """Map *state* to its canonical name using the spec's state aliases."""
        return normalize_state_name(state, self.spec.state_aliases)

    def available_states(self) -> list[str]:
        """Return every schedule state exactly once, in display order.

        The schedule's ``display_order`` takes precedence over the spec's.
        Entries that are not state names verbatim are resolved through the
        state aliases. Unknown entries are skipped and repeats are dropped.
        States not mentioned in the display order follow in schedule order.
        """
        states = self.schedule.states
        preferred = self.schedule.display_order or self.spec.display_order
        ordered: list[str] = []
        seen: set[str] = set()
        for entry in preferred:
            key = entry if entry in states else self.normalize_state_name(entry)
            if key in states and key not in seen:
                seen.add(key)
                ordered.append(key)
        ordered.extend(s for s in states if s not in seen)
        return ordered

    def _manual_neutral_speeds(self, *, input_speed: float) -> dict[str, float]:
        """Return speeds for a manual-neutral state: input spins, all else is 0."""
        speeds = dict.fromkeys(self.spec.all_member_names(), 0.0)
        speeds[self.spec.input_member] = float(input_speed)
        speeds[self.spec.output_member] = 0.0
        return speeds

    def solve_state(self, state: str, *, input_speed: float = 1.0) -> GenericSolveResult:
        """Solve a single shift state.

        Manual-neutral states are reported by convention (ratio 0.0,
        status ``"manual_neutral"``) without running the solver. Otherwise a
        fresh solver is built, the state's active constraints are engaged, and
        the solver report is turned into a result. ``ratio`` is
        input/output speed, or ``None`` when the output speed is missing or
        within ``_ZERO_SPEED_TOL`` of zero.

        Raises:
            TransmissionAppError: if *state* resolves to ``"all"`` or to an
                unknown state.
        """
        resolved = self.normalize_state_name(state)
        if resolved.lower() == "all":
            raise TransmissionAppError("solve_state() expects one state, not 'all'.")
        if resolved not in self.schedule.states:
            valid = ", ".join(self.available_states())
            raise TransmissionAppError(f"Unknown state '{state}'. Valid states: {valid}")
        state_spec = self.schedule.states[resolved]
        speed_in = float(input_speed)

        if state_spec.manual_neutral:
            return GenericSolveResult(
                state=resolved,
                engaged=tuple(state_spec.display_elements),
                ok=True,
                ratio=0.0,
                speeds=self._manual_neutral_speeds(input_speed=speed_in),
                notes=state_spec.notes,
                solver_path="core_generic_json_builder",
                status="manual_neutral",
                message="State reported through manual-neutral convention.",
            )

        solver, _members, clutch_map, brake_map, sprag_map = self.build_solver()
        solver.release_all()
        for elem in state_spec.active_constraints:
            if elem in clutch_map:
                clutch_map[elem].engage()
            elif elem in brake_map:
                brake_map[elem].engage()
            elif elem in sprag_map:
                sprag_map[elem].engage()
            else:
                raise TransmissionAppError(f"Internal error: unknown shift element '{elem}'.")

        report = solver.solve_report(self.spec.input_member, input_speed=speed_in)
        speeds = dict(report.member_speeds)
        out_speed = speeds.get(self.spec.output_member)
        ratio: float | None = None
        status = report.classification.status
        message = report.classification.message
        if out_speed is not None:
            if abs(out_speed) > self._ZERO_SPEED_TOL:
                ratio = speed_in / float(out_speed)
                if report.ok:
                    status = "output_determined"
            elif report.ok:
                status = "output_zero"

        return GenericSolveResult(
            state=resolved,
            engaged=tuple(state_spec.display_elements),
            ok=bool(report.ok),
            ratio=ratio,
            speeds=speeds,
            notes=state_spec.notes,
            solver_path="core_generic_json_builder",
            status=status,
            message=message,
        )

    def solve(self, *, state: str, input_speed: float = 1.0) -> dict[str, GenericSolveResult]:
        """Solve one state, or every state in display order when *state* is "all".

        Returns a dict keyed by canonical state name.
        """
        if self.normalize_state_name(state).lower() != "all":
            # Pass the caller's spelling: solve_state normalizes it itself, so
            # normalizing here too would apply aliases twice and hide the
            # original name in error messages.
            res = self.solve_state(state, input_speed=input_speed)
            return {res.state: res}
        return {s: self.solve_state(s, input_speed=input_speed) for s in self.available_states()}