Skip to content

Commit 22072ac

Browse files
Cast all component parameters to lists in Component.__init__
1 parent b71354a commit 22072ac

File tree

2 files changed

+39
-9
lines changed

2 files changed

+39
-9
lines changed

pymc_extras/statespace/models/structural/core.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -507,12 +507,14 @@ def __init__(
507507
self.k_posdef = k_posdef
508508
self.measurement_error = measurement_error
509509

510-
self.state_names = state_names if state_names is not None else []
511-
self.observed_state_names = observed_state_names if observed_state_names is not None else []
512-
self.data_names = data_names if data_names is not None else []
513-
self.shock_names = shock_names if shock_names is not None else []
514-
self.param_names = param_names if param_names is not None else []
515-
self.exog_names = exog_names if exog_names is not None else []
510+
self.state_names = list(state_names) if state_names is not None else []
511+
self.observed_state_names = (
512+
list(observed_state_names) if observed_state_names is not None else []
513+
)
514+
self.data_names = list(data_names) if data_names is not None else []
515+
self.shock_names = list(shock_names) if shock_names is not None else []
516+
self.param_names = list(param_names) if param_names is not None else []
517+
self.exog_names = list(exog_names) if exog_names is not None else []
516518

517519
self.needs_exog_data = len(self.exog_names) > 0
518520
self.coords = {}
@@ -741,13 +743,15 @@ def make_slice(name, x, o_x):
741743

742744
def _combine_property(self, other, name, allow_duplicates=True):
743745
self_prop = getattr(self, name)
746+
other_prop = getattr(other, name)
747+
744748
if isinstance(self_prop, list) and allow_duplicates:
745-
return self_prop + getattr(other, name)
749+
return self_prop + other_prop
746750
elif isinstance(self_prop, list) and not allow_duplicates:
747-
return self_prop + [x for x in getattr(other, name) if x not in self_prop]
751+
return self_prop + [x for x in other_prop if x not in self_prop]
748752
elif isinstance(self_prop, dict):
749753
new_prop = self_prop.copy()
750-
new_prop.update(getattr(other, name))
754+
new_prop.update(other_prop)
751755
return new_prop
752756

753757
def _combine_component_info(self, other):

tests/statespace/models/structural/test_core.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from functools import partial
2+
13
import numpy as np
24
import pandas as pd
35
import pymc as pm
@@ -162,3 +164,27 @@ def test_extract_multiple_observed(rng):
162164

163165
missing = set(comp_states) - set(expected_states)
164166
assert len(missing) == 0, missing
167+
168+
169+
@pytest.mark.parametrize(
170+
"arg_type", [tuple, list, set, np.array], ids=["tuple", "list", "set", "array"]
171+
)
172+
def test_sequence_type_component_arguments(arg_type):
173+
state_names = list("ABCDEFG")
174+
components = [
175+
st.LevelTrendComponent,
176+
partial(st.CycleComponent, cycle_length=12),
177+
st.AutoregressiveComponent,
178+
partial(st.FrequencySeasonality, season_length=12),
179+
partial(st.TimeSeasonality, season_length=12),
180+
st.MeasurementError,
181+
]
182+
183+
components = [
184+
components[i](observed_state_names=arg_type(state_names))
185+
for i in np.random.choice(len(components), size=3, replace=False)
186+
]
187+
ss_mod = sum(components[1:], start=components[0]).build(verbose=False)
188+
189+
assert ss_mod.k_endog == len(state_names)
190+
assert sorted(ss_mod.observed_states) == sorted(list(state_names))

0 commit comments

Comments
 (0)