From e533a67303a31054870567b78178b33716f091d8 Mon Sep 17 00:00:00 2001 From: Matthew Rich Date: Mon, 20 May 2024 12:24:24 -0700 Subject: [PATCH] add support for multiple source states in transitions --- machine.go | 8 ++++++-- machine_test.go | 4 ++-- transition.go | 18 ++++++++++-------- transition_test.go | 6 +++--- 4 files changed, 21 insertions(+), 15 deletions(-) diff --git a/machine.go b/machine.go index 771ef8f..286f531 100644 --- a/machine.go +++ b/machine.go @@ -11,7 +11,7 @@ type State string type Stater interface { AddStates(name ...State) GetState(name string) State - AddTransition(trigger string, source State, dest State) + AddTransition(trigger string, source []State, dest State) AddSubscription(transition string, subscription Subscriber) error AddModel(m Modeler) Trigger(transition string) error @@ -28,11 +28,15 @@ func New(initial State) Stater { return &Definition{model: NewModel(initial), triggers: make(map[string]Transitioner)} } +func States(name ...State) []State { + return name +} + func (d *Definition) AddStates(name ...State) { d.states = append(d.states, name...) } -func (d *Definition) AddTransition(trigger string, source State, dest State) { +func (d *Definition) AddTransition(trigger string, source []State, dest State) { d.triggers[trigger] = NewTransition(trigger, source, dest) } diff --git a/machine_test.go b/machine_test.go index fd5873b..3b811cb 100644 --- a/machine_test.go +++ b/machine_test.go @@ -39,7 +39,7 @@ func TestMachineCurrentState(t *testing.T) { func TestMachineAddTransition(t *testing.T) { s := setupStater("disconnected") s.AddStates("disconnected", "start_connection", "connected") - s.AddTransition("connect", "disconnected", "start_connection") + s.AddTransition("connect", States("disconnected"), "start_connection") s.AddModel(setupModel("disconnected")) // worker gets a trigger message assert.Nil(t, s.Trigger("connect")) @@ -54,7 +54,7 @@ func TestMachineAddSubscription(t *testing.T) { x := setupSubscriber() s := setupStater("disconnected") s.AddStates("disconnected", "start_connection", "connected") - s.AddTransition("connect", "disconnected", "start_connection") + s.AddTransition("connect", States("disconnected"), "start_connection") assert.Nil(t, s.AddSubscription("connect", x)) s.AddModel(setupModel("disconnected")) assert.Nil(t, s.Trigger("connect")) diff --git a/transition.go b/transition.go index 9235a7d..1ee4460 100644 --- a/transition.go +++ b/transition.go @@ -15,12 +15,12 @@ type Transitioner interface { type Transition struct { trigger string - source State + source []State dest State subscriptions []Subscriber } -func NewTransition(trigger string, source State, dest State) Transitioner { +func NewTransition(trigger string, source []State, dest State) Transitioner { return &Transition{trigger: trigger, source: source, dest: dest} } @@ -30,12 +30,14 @@ func (r *Transition) Run(m Modeler) error { if currentState == r.dest { return nil } - if currentState == r.source || r.source == "*" { - res := m.ChangeState(r.dest) - if res == currentState { - r.Notify(EXITSTATEEVENT, currentState, r.dest) - r.Notify(ENTERSTATEEVENT, currentState, r.dest) - return nil + for _, transitionSource := range r.source { + if currentState == transitionSource || transitionSource == "*" { + res := m.ChangeState(r.dest) + if res == currentState { + r.Notify(EXITSTATEEVENT, currentState, r.dest) + r.Notify(ENTERSTATEEVENT, currentState, r.dest) + return nil + } } } return fmt.Errorf("Transition from %s to %s failed on model state %s", r.source, r.dest, currentState) diff --git a/transition_test.go b/transition_test.go index 50da08c..5a780ff 100644 --- a/transition_test.go +++ b/transition_test.go @@ -9,7 +9,7 @@ import ( ) func setupTransition() Transitioner { - t := NewTransition("open", "closed", "open") + t := NewTransition("open", States("closed"), "open") if t == nil { log.Fatal("Failed creating new transition") } @@ -22,7 +22,7 @@ func setupSubscriber() Subscriber { } func TestNewTransition(t *testing.T) { - s := NewTransition("connect", "disconnected", "connected") + s := NewTransition("connect", States("disconnected"), "connected") if s == nil { t.Errorf("Failed creating new transition") } @@ -39,7 +39,7 @@ func TestTransitionExecution(t *testing.T) { } func TestTransitionWildcard(t *testing.T) { - tr := NewTransition("open", "*", "opened") + tr := NewTransition("open", States("*"), "opened") assert.NotNil(t, tr) m := setupModel("closed") assert.Nil(t, tr.Run(m))