pylcm models involve two independent layers of transitions that compose during backward induction:
- State transitions: how a state variable arrived at its current value (attached to grids)
- Regime transitions: which regime the agent enters next period (attached to the regime itself)
The two layers have opposite orientations:
- Regime transitions are forward-looking. The transition field on a Regime answers “where does the agent go next?” It lives on the source regime and points toward the future.
- State transitions are backward-looking. The transition parameter on a grid answers “how did this state variable reach its current value?” It lives on the grid that receives the value. In multi-regime models, per-boundary mappings are placed on the target regime’s grid.
This notebook explains how each layer works, how per-boundary transitions resolve cross-regime state mismatches, and how everything composes in the Bellman equation.
import jax
import jax.numpy as jnp

from lcm import (
    DiscreteGrid,
    DiscreteMarkovGrid,
    LinSpacedGrid,
    Regime,
    RegimeTransition,
    categorical,
)
from lcm.typing import (
    ContinuousState,
    DiscreteState,
    FloatND,
    ScalarInt,
)

State Transition Mechanics¶
State transitions are attached directly to grid objects via the transition
parameter. There are four cases:
| Grid configuration | Behavior |
|---|---|
| transition=some_func | Deterministic: $s' = f(s, a)$ |
| transition=None | Fixed: $s' = s$ (identity auto-generated) |
| DiscreteMarkovGrid(transition=func) | Stochastic: probability-weighted expectation |
| Shock grids (lcm.shocks.*) | Intrinsic transitions with interpolated weights |
Deterministic state transitions¶
A state transition function defines how the current state value was determined from last period’s states, actions, and parameters. The function’s argument names are resolved from the regime’s namespace.
def next_wealth(
    wealth: ContinuousState,
    consumption: ContinuousState,
    interest_rate: float,
) -> ContinuousState:
    return (1 + interest_rate) * (wealth - consumption)


# Attached to the grid:
wealth_grid = LinSpacedGrid(start=0, stop=100, n_points=50, transition=next_wealth)

Fixed states and identity transitions¶
When transition=None, pylcm auto-generates an _IdentityTransition internally.
This shows up in regime.get_all_functions() under the key "next_<state_name>".
@categorical
class EducationLevel:
    low: int
    high: int


@categorical
class RegimeId:
    working: int
    retired: int


def utility(wealth: ContinuousState) -> FloatND:
    return jnp.log(wealth + 1.0)


def next_regime() -> ScalarInt:
    return RegimeId.retired


regime = Regime(
    transition=RegimeTransition(next_regime),
    states={
        "education": DiscreteGrid(EducationLevel, transition=None),
        "wealth": LinSpacedGrid(start=0, stop=50, n_points=10, transition=None),
    },
    functions={"utility": utility},
)

all_funcs = regime.get_all_functions()
print("Function keys:", list(all_funcs.keys()))
print("next_education type:", type(all_funcs["next_education"]).__name__)
print("next_wealth type:   ", type(all_funcs["next_wealth"]).__name__)

Function keys: ['utility', 'H', 'next_education', 'next_wealth', 'next_regime']
next_education type: _IdentityTransition
next_wealth type:    _IdentityTransition
Both fixed states produce _IdentityTransition objects. These are marked with
_is_auto_identity = True so that validation can distinguish them from
user-provided transitions.
Stochastic state transitions (DiscreteMarkovGrid)¶
For DiscreteMarkovGrid, the transition function returns a probability array over
the categories. During the solve step, pylcm computes a probability-weighted
expectation over next-period states:
@categorical
class Health:
    bad: int
    good: int


def health_transition(health: DiscreteState) -> FloatND:
    return jnp.where(
        health == Health.good,
        jnp.array([0.1, 0.9]),  # good → 90% stay good
        jnp.array([0.6, 0.4]),  # bad → 40% recover
    )


health_grid = DiscreteMarkovGrid(Health, transition=health_transition)

# Inspect the transition probabilities
for state_name, code in [("bad", Health.bad), ("good", Health.good)]:
    probs = health_transition(jnp.array(code))
    print(f"P(next | {state_name}) = {probs}")

P(next | bad) = [0.6 0.4]
P(next | good) = [0.1 0.9]
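To make the "probability-weighted expectation" concrete, the following is a minimal pure-Python sketch that reuses the probabilities printed above. The next-period values in `next_V` are made-up numbers for illustration, and `expected_value` is a hypothetical helper, not a pylcm function.

```python
# Hypothetical next-period values, indexed by health code (0 = bad, 1 = good).
next_V = [1.0, 3.0]

# Transition probabilities from the example above, keyed by current health.
P = {
    "bad": [0.6, 0.4],
    "good": [0.1, 0.9],
}


def expected_value(probs, values):
    """Probability-weighted sum over next-period categories."""
    return sum(p * v for p, v in zip(probs, values))


# One expected continuation value per current health state.
expected = {state: expected_value(probs, next_V) for state, probs in P.items()}
print(expected)
```

Each current state gets its own expectation because the probability row depends on the current health value.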
Shock grids¶
Shock grids (from lcm.shocks.iid and lcm.shocks.ar1) have intrinsic
transitions computed from the distribution. For IID shocks, the transition
probabilities are the same regardless of the current value. For AR(1) shocks,
probabilities depend on the current state.
Shock grids do not accept a transition parameter — their transitions are
built-in.
import lcm.shocks.iid

shock = lcm.shocks.iid.Normal(
    n_points=5, gauss_hermite=False, mu=0.0, sigma=1.0, n_std=2.5
)
print("Grid points:", shock.to_jax())

Grid points: [-2.5 -1.25 0. 1.25 2.5 ]
Per-Boundary State Transitions¶
When a discrete state has different categories across regimes, a simple callable transition is not enough — you need to map from one category set to another at the regime boundary.
The solution: a mapping transition keyed by (source_regime, target_regime)
pairs, placed on the target regime’s grid.
Example: different health categories¶
Suppose working life has three health states (disabled, bad, good) but retirement only has two (bad, good). The transition from working to retired needs an explicit mapping.
@categorical
class HealthWorking:
    disabled: int
    bad: int
    good: int


@categorical
class HealthRetired:
    bad: int
    good: int


def map_working_to_retired(health: DiscreteState) -> DiscreteState:
    """Map 3-category working health to 2-category retired health."""
    return jnp.where(
        health == HealthWorking.good,
        HealthRetired.good,
        HealthRetired.bad,
    )


# Verify the mapping
for name, code in [("disabled", 0), ("bad", 1), ("good", 2)]:
    result = map_working_to_retired(jnp.array(code))
    print(f"working {name} ({code}) → retired code {int(result)}")

working disabled (0) → retired code 0
working bad (1) → retired code 0
working good (2) → retired code 1
The mapping is placed on the target regime’s grid:
health_retired_grid = DiscreteGrid(
    HealthRetired,
    transition={
        ("working", "retired"): map_working_to_retired,
    },
)

Resolution priority¶
When resolving which transition function to use at a regime boundary (source, target), pylcm checks (in order):
1. Target grid mapping for (source, target)
2. Source grid mapping for (source, target)
3. Source grid’s callable transition
4. Target grid’s callable transition
5. Auto-generated identity (if categories match)
If the categories differ across regimes and no explicit mapping is found,
ModelInitializationError is raised.
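The priority order above can be sketched as a small lookup function. This is an illustration only: grids are modelled as plain dicts with hypothetical "mapping" and "callable" keys, `resolve_transition` is not a pylcm function, and a plain `ValueError` stands in for ModelInitializationError.

```python
def resolve_transition(source, target, source_grid, target_grid, categories_match):
    """Return (origin, transition) following the five-step priority order.

    Each grid is a plain dict (an assumption of this sketch) with optional keys
    "mapping" (dict keyed by (source, target)) and "callable" (one function).
    """
    boundary = (source, target)
    if boundary in target_grid.get("mapping", {}):
        return "target_mapping", target_grid["mapping"][boundary]
    if boundary in source_grid.get("mapping", {}):
        return "source_mapping", source_grid["mapping"][boundary]
    if source_grid.get("callable") is not None:
        return "source_callable", source_grid["callable"]
    if target_grid.get("callable") is not None:
        return "target_callable", target_grid["callable"]
    if categories_match:
        return "identity", lambda state: state
    # Stand-in for the error pylcm raises when nothing matches.
    raise ValueError(f"No transition found for boundary {boundary}")


def collapse_health(health):
    """Map working codes (0, 1, 2) to retired codes (0, 1)."""
    return 1 if health == 2 else 0


working_grid = {"callable": lambda health: health}
retired_grid = {"mapping": {("working", "retired"): collapse_health}}

origin, func = resolve_transition(
    "working", "retired", working_grid, retired_grid, categories_match=False
)
print(origin, func(2), func(0))
```

Note that the target grid's mapping wins here even though the source grid also provides a callable, which is the behaviour the list describes.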
Parameterized per-boundary transitions¶
Per-boundary mapping functions can take parameters beyond the state variable itself. A common use case is a continuous state whose transition law differs across regime boundaries — for example, wealth that grows at a rate specific to the target regime.
When pylcm resolves a per-boundary transition from the target grid’s mapping
(priority 1 above), any parameters in that function are looked up in the
target regime’s parameter template. This means the user specifies the
parameter value under the target regime in the params dict, and pylcm
automatically routes it to the transition function at the boundary.
The rule is simple: whoever owns the mapping owns the parameters. Since per-boundary mappings live on the target regime’s grid, their parameters come from the target regime.
Example: regime-specific growth rate¶
Consider a two-regime model (phase 1 → phase 2) where wealth grows at a rate
that is specific to phase 2. The transition function on phase 2’s wealth grid
takes a growth_rate parameter:
def next_wealth_at_boundary(
    wealth: ContinuousState,
    growth_rate: float,
) -> ContinuousState:
    """Wealth transition at the phase1 → phase2 boundary.

    The growth_rate parameter is resolved from phase2's params template,
    because the mapping lives on phase2's grid.
    """
    return (1 + growth_rate) * wealth


# Phase 2's wealth grid declares the per-boundary mapping:
phase2_wealth_grid = LinSpacedGrid(
    start=0,
    stop=100,
    n_points=20,
    transition={
        ("phase1", "phase2"): next_wealth_at_boundary,
    },
)

Because the mapping {("phase1", "phase2"): next_wealth_at_boundary} lives on
phase 2’s grid, the growth_rate parameter appears in phase 2’s parameter
template. The user supplies it under "phase2" in the params dict:
params = {
    "phase1": {...},
    "phase2": {
        "next_wealth": {"growth_rate": 0.05},
        ...
    },
}

Internally, pylcm detects that the transition was resolved from the target
grid’s mapping and renames the parameter to a cross-boundary qualified name
(e.g., phase2__next_wealth__growth_rate). At solve and simulation time, the
value is looked up from internal_params["phase2"] — not from "phase1" —
even though the transition is evaluated as part of phase 1’s backward induction
step.
Internals: How Cross-Boundary Parameters Are Processed¶
The sections above describe the user-facing rule: whoever owns the mapping owns the parameters. This section explains the internal machinery that makes it work.
The problem: flat function signatures¶
All compiled Q functions have flat parameter signatures. The dags library
composes utility, constraints, and transitions into a single function whose
parameters are flat qualified names (qnames) like utility__risk_aversion or
next_wealth__interest_rate. These are passed as **kwargs at the call site:
# In solve_brute.py:
V_arr = max_Q_over_a(
    **state_action_space.states,
    **state_action_space.actions,
    next_V_arr=next_V_arr,
    **internal_params[regime_name],  # flat kwargs
)

When a per-boundary transition is target-originated (e.g., phase2’s grid owns
the (phase1, phase2) mapping), its transition function is compiled into
phase1’s Q function — but the parameter values live in
internal_params["phase2"]. The flat signature means we can’t just pass a nested
dict; the value must appear as a named kwarg in phase1’s params.
Step 1: Transition resolution and origin tracking¶
_resolve_state_transition (in regime_processing.py) resolves which transition
function to use at each regime boundary. When the winning function comes from the
target grid’s mapping dict (priority 1), it is recorded in a
target_originated_transitions set for that source regime.
For the phase1 → phase2 example, phase1’s resolved transitions include the
function next_wealth_at_boundary under the flat name "phase2__next_wealth".
Because it came from phase2’s grid, it is marked as target-originated.
Step 2: Parameter renaming (_rename_target_originated_transition)¶
Normal (source-originated) transitions rename growth_rate →
next_wealth__growth_rate — a standard qname within the source regime’s
namespace.
Target-originated transitions use a cross-boundary qname that embeds the
target regime name. _rename_target_originated_transition does the following:
1. Extracts the target regime name from the flat function name ("phase2__next_wealth" → "phase2").
2. Looks up the target regime’s parameter template to find which parameters the function uses (e.g., growth_rate).
3. Renames each parameter to a cross-boundary qname: growth_rate → phase2__next_wealth__growth_rate.
4. Records the mapping in a cross_boundary_params dict:

    cross_boundary_params["phase2__next_wealth__growth_rate"] = (
        "phase2",  # target regime
        "next_wealth__growth_rate",  # qname in target's namespace
    )

This renaming avoids collisions with phase1’s own parameters (phase1 might have
its own next_wealth__growth_rate for within-regime transitions).
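The renaming step can be sketched in a few lines of plain Python. The helper below, `rename_target_originated`, is hypothetical and only mirrors the string manipulation described above; it assumes regime names contain no double underscore and takes the parameter names as an explicit list rather than reading them from a template.

```python
def rename_target_originated(flat_name, param_names):
    """Build cross-boundary qnames for a target-originated transition.

    flat_name is e.g. "phase2__next_wealth"; param_names are the function's
    parameter names (here passed in directly, an assumption of this sketch).
    """
    # The first "__" separates the target regime from the function name.
    target_regime, func_name = flat_name.split("__", 1)
    renamed = {}
    cross_boundary_params = {}
    for name in param_names:
        cross_qname = f"{flat_name}__{name}"
        renamed[name] = cross_qname
        # Remember where the value lives: (target regime, qname in that regime).
        cross_boundary_params[cross_qname] = (target_regime, f"{func_name}__{name}")
    return renamed, cross_boundary_params


renamed, cross_boundary_params = rename_target_originated(
    "phase2__next_wealth", ["growth_rate"]
)
print(renamed)
print(cross_boundary_params)
```

Because the regime name is embedded in the new qname, a source-regime parameter called next_wealth__growth_rate can coexist with it without a clash.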
Step 3: flat_param_names and function compilation¶
After all transitions are renamed, process_regimes computes flat_param_names
— the model-level set of all regime-prefixed flat parameter names:
flat_param_names = frozenset(flatten_to_qnames(params_template))

This flattens the nested params_template (keyed by regime → function → arg) into
qualified names like "phase1__utility__risk_aversion" or
"phase2__next_wealth__growth_rate". Cross-boundary qnames are included because
_rename_target_originated_transition already inserted them into the compiled
function signatures. The frozenset is stored on InternalRegime.flat_param_names
and passed to the function compilation pipeline (build_Q_and_F_functions,
build_next_state_simulation_functions).
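As an illustration of the flattening, here is a minimal sketch of how a nested template could be turned into double-underscore qnames. `flatten_to_qnames_sketch` is a stand-in written for this example; the real flatten_to_qnames may differ in signature and behaviour, and the template values are invented.

```python
def flatten_to_qnames_sketch(tree, prefix=""):
    """Flatten a nested dict into {qname: leaf}, joining keys with "__"."""
    flat = {}
    for key, value in tree.items():
        qname = f"{prefix}__{key}" if prefix else key
        if isinstance(value, dict):
            flat.update(flatten_to_qnames_sketch(value, qname))
        else:
            flat[qname] = value
    return flat


# Hypothetical template: regime → function → argument.
params_template = {
    "phase1": {"utility": {"risk_aversion": 2.0}},
    "phase2": {"next_wealth": {"growth_rate": 0.05}},
}

# Iterating a dict yields its keys, so this is the set of flat names.
flat_param_names = frozenset(flatten_to_qnames_sketch(params_template))
print(sorted(flat_param_names))
```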
The compilation pipeline uses flat_param_names for one critical purpose:
determining which arguments to vmap over. In _get_vmap_params:
# Excerpt from _get_vmap_params:
non_vmap = {"period", "age"} | flat_param_names
return tuple(arg for arg in all_args if arg not in non_vmap)

Parameters (including cross-boundary ones) are scalar, so they must not be vmapped. Everything else (states, actions) gets vmapped.
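The excerpt can be wrapped into a standalone function to see which arguments survive the filter. The function name and the argument list below are hypothetical and chosen only to illustrate the set arithmetic.

```python
def get_vmap_params_sketch(all_args, flat_param_names):
    """Return the arguments to vmap over: everything that is not a parameter."""
    non_vmap = {"period", "age"} | set(flat_param_names)
    # Preserve the original argument order.
    return tuple(arg for arg in all_args if arg not in non_vmap)


# Hypothetical flat signature of a compiled function for phase1.
all_args = (
    "wealth",
    "consumption",
    "period",
    "age",
    "next_wealth__interest_rate",
    "phase2__next_wealth__growth_rate",
)
flat_param_names = frozenset(
    {"next_wealth__interest_rate", "phase2__next_wealth__growth_rate"}
)

vmap_params = get_vmap_params_sketch(all_args, flat_param_names)
print(vmap_params)
```

Only the state and the action remain. If the cross-boundary qname were missing from the set, it would be treated as a grid dimension and vmapped by mistake.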
Step 4: Runtime resolution (merge_cross_boundary_params)¶
At the top of solve() and simulate(), before any backward induction or
forward simulation begins, merge_cross_boundary_params is called once:
internal_params = merge_cross_boundary_params(internal_params, internal_regimes)

This copies values from the target regime’s params into the source regime’s params dict under the cross-boundary qname:
# For phase1's cross_boundary_params:
# "phase2__next_wealth__growth_rate" → ("phase2", "next_wealth__growth_rate")
#
# merge_cross_boundary_params does:
internal_params["phase1"]["phase2__next_wealth__growth_rate"] = (
    internal_params["phase2"]["next_wealth__growth_rate"]
)

After merging, every call site just uses **internal_params[regime_name]; no
per-call-site cross-boundary logic is needed.
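A minimal sketch of the merge on plain dicts follows. It assumes the cross-boundary mappings are available as a dict keyed by source regime, whereas the real function reads them from internal_regimes; `merge_cross_boundary_params_sketch` and `toy_Q` are hypothetical names.

```python
def merge_cross_boundary_params_sketch(internal_params, cross_boundary_by_regime):
    """Copy target-regime values into source-regime params under cross qnames."""
    # Shallow-copy each regime's params so the input is left untouched.
    merged = {regime: dict(params) for regime, params in internal_params.items()}
    for source_regime, mapping in cross_boundary_by_regime.items():
        for cross_qname, (target_regime, target_qname) in mapping.items():
            merged[source_regime][cross_qname] = internal_params[target_regime][
                target_qname
            ]
    return merged


internal_params = {
    "phase1": {"utility__risk_aversion": 2.0},
    "phase2": {"next_wealth__growth_rate": 0.05},
}
cross_boundary_by_regime = {
    "phase1": {
        "phase2__next_wealth__growth_rate": ("phase2", "next_wealth__growth_rate")
    }
}
merged = merge_cross_boundary_params_sketch(internal_params, cross_boundary_by_regime)


def toy_Q(wealth, utility__risk_aversion, phase2__next_wealth__growth_rate):
    """Toy stand-in for a compiled function with a flat signature."""
    return (1 + phase2__next_wealth__growth_rate) * wealth


value = toy_Q(wealth=100.0, **merged["phase1"])
print(value)
```

After the merge, the toy function receives the phase 2 growth rate through phase 1's flat kwargs, with no special handling at the call site.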
Summary of the internal pipeline¶
User defines:   phase2 grid transition={("phase1","phase2"): f(wealth, growth_rate)}
                    │
                    ▼
Resolution:     _resolve_state_transition marks it as target-originated
                    │
                    ▼
Renaming:       _rename_target_originated_transition renames
                growth_rate → phase2__next_wealth__growth_rate
                and records cross_boundary_params mapping
                    │
                    ▼
Compilation:    flat_param_names = own_params ∪ cross_boundary_params.keys()
                Q function compiled with phase2__next_wealth__growth_rate
                in its flat signature
                    │
                    ▼
Runtime:        merge_cross_boundary_params copies
                internal_params["phase2"]["next_wealth__growth_rate"]
                → internal_params["phase1"]["phase2__next_wealth__growth_rate"]
                    │
                    ▼
Call site:      max_Q_over_a(**internal_params["phase1"])
                (the cross-boundary value is now a regular flat kwarg)

Key data structures:
| Structure | Location | Purpose |
|---|---|---|
| cross_boundary_params | InternalRegime | Maps cross-boundary qnames to (target_regime, target_qname) |
| flat_param_names | InternalRegime | Union of own params + cross-boundary qnames; used for vmap exclusion |
| merge_cross_boundary_params() | interfaces.py | Resolves values from target into source at runtime entry |
Why not just index the target regime directly? The compiled Q function has a
flat signature determined at model init time. It expects named kwargs — it cannot
accept a nested dict or index into internal_params at runtime. The
cross-boundary machinery bridges the gap between “params are owned by the target
regime” and “compiled functions have flat signatures.”
Regime Transition Mechanics¶
Regime transitions determine which regime the agent enters next period. Internally, both deterministic and stochastic transitions are converted to a uniform probability array format.
Deterministic transitions → one-hot encoding¶
A RegimeTransition wraps a function that returns an integer regime ID.
Internally, _wrap_deterministic_regime_transition converts this to a one-hot
probability array using jax.nn.one_hot:
@categorical
class RegimeIdExample:
    working: int
    retired: int
    dead: int


# Deterministic: retire at age 65
def next_regime_det(age: float, retirement_age: float) -> ScalarInt:
    return jnp.where(
        age >= retirement_age, RegimeIdExample.retired, RegimeIdExample.working
    )


# What pylcm does internally:
regime_idx = next_regime_det(age=50.0, retirement_age=65.0)
one_hot = jax.nn.one_hot(regime_idx, num_classes=3)
print(f"Regime index: {int(regime_idx)}")
print(f"One-hot: {one_hot} (= [P(working), P(retired), P(dead)])")

Regime index: 0
One-hot: [1. 0. 0.] (= [P(working), P(retired), P(dead)])
Stochastic transitions → probability array¶
A MarkovRegimeTransition wraps a function that directly returns a probability
array. No conversion is needed — the array is used as-is.
def next_regime_stoch(survival_prob: float) -> FloatND:
    """Alive → [P(working), P(retired), P(dead)]."""
    return jnp.array([survival_prob, 0.0, 1 - survival_prob])


probs = next_regime_stoch(survival_prob=0.98)
print(f"Probabilities: {probs} (= [P(working), P(retired), P(dead)])")

Probabilities: [0.98 0. 0.02] (= [P(working), P(retired), P(dead)])
After wrapping, the probability array is further converted to a dictionary keyed by
regime name (via _wrap_regime_transition_probs), giving a uniform internal
representation regardless of whether the original transition was deterministic or
stochastic.
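The conversion to a name-keyed dictionary can be pictured with plain Python lists. Both helpers below are hypothetical and are not the pylcm wrappers named above; they only show that deterministic and stochastic transitions end up in the same shape.

```python
regime_names = ("working", "retired", "dead")


def one_hot(index, num_classes):
    """Probability list with all mass on one regime index."""
    return [1.0 if i == index else 0.0 for i in range(num_classes)]


def probs_to_dict(probs, names):
    """Key a probability list by regime name."""
    if len(probs) != len(names):
        raise ValueError("Need exactly one probability per regime")
    return dict(zip(names, probs))


# Deterministic transition to "retired" (index 1), then the same wrapping
# that a stochastic probability list receives.
deterministic = probs_to_dict(one_hot(1, len(regime_names)), regime_names)
stochastic = probs_to_dict([0.98, 0.0, 0.02], regime_names)
print(deterministic)
print(stochastic)
```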
How Transitions Compose in the Bellman Equation¶
The value function computation depends on the regime type:
Non-terminal with deterministic regime transition¶
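The continuation value comes from a single next-period regime:

$$V_t^{r}(s) = \max_{a} \left\{ u(s, a) + \beta \, V_{t+1}^{r'}(s') \right\}$$

where $r'$ is the deterministically chosen next regime, $s'$ is the next-period state, and $\beta$ is the discount factor.
Non-terminal with stochastic regime transition¶
The continuation value is an expectation over possible next regimes:

$$V_t^{r}(s) = \max_{a} \left\{ u(s, a) + \beta \sum_{r'} P(r' \mid s, a) \, V_{t+1}^{r'}(s') \right\}$$

where $P(r' \mid s, a)$ is the probability of transitioning to regime $r'$.
Adding stochastic state transitions¶
When a state has a Markov transition (DiscreteMarkovGrid) or shock grid, an
additional layer of expectation is added inside the max:

$$V_t^{r}(s) = \max_{a} \left\{ u(s, a) + \beta \sum_{r'} P(r' \mid s, a) \sum_{s'} P(s' \mid s, a) \, V_{t+1}^{r'}(s') \right\}$$

The inner sum handles the stochastic state transition; the outer sum handles the stochastic regime transition. When either is deterministic, its corresponding sum collapses to a single term.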
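The nested expectation, and the way it collapses in the deterministic case, can be checked numerically with a small pure-Python sketch. All numbers and the `continuation_value` helper are invented for illustration; for simplicity every regime shares the same two-point state space.

```python
def continuation_value(regime_probs, state_probs, next_V):
    """Outer sum over next regimes, inner sum over next-period states."""
    total = 0.0
    for regime, p_regime in regime_probs.items():
        if p_regime == 0.0:
            continue  # unreachable regimes contribute nothing
        inner = sum(p * v for p, v in zip(state_probs, next_V[regime]))
        total += p_regime * inner
    return total


# Hypothetical next-period value functions on a two-point state grid.
next_V = {"working": [1.0, 3.0], "retired": [0.5, 2.0], "dead": [0.0, 0.0]}

# Both layers stochastic.
stochastic = continuation_value(
    {"working": 0.98, "retired": 0.0, "dead": 0.02}, [0.1, 0.9], next_V
)

# Both layers deterministic: one-hot probabilities pick a single term.
deterministic = continuation_value(
    {"working": 1.0, "retired": 0.0, "dead": 0.0}, [0.0, 1.0], next_V
)
print(stochastic, deterministic)
```

With one-hot probabilities the result equals the single value `next_V["working"][1]`, which is the collapse to one term described above.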
Summary¶
| Component | Deterministic | Stochastic |
|---|---|---|
| State transition | $s' = f(s, a)$ | $\sum_{s'} P(s' \mid s, a) \, V(s')$ |
| Regime transition | One-hot → single $r'$ | $\sum_{r'} P(r' \mid s, a) \, V^{r'}$ |
| Internal format | Both converted to probability arrays | — |
The uniform probability format means the backward induction algorithm treats all transitions the same way — deterministic transitions are just the special case where one probability is 1 and the rest are 0.