add support for multiple source states in transitions
Some checks failed
Machine Tests / test (push) Waiting to run
Lint / golangci-lint (push) Has been cancelled

This commit is contained in:
Matthew Rich 2024-05-20 12:24:24 -07:00
parent fbffd2b947
commit e533a67303
4 changed files with 21 additions and 15 deletions

View File

@ -11,7 +11,7 @@ type State string
type Stater interface { type Stater interface {
AddStates(name ...State) AddStates(name ...State)
GetState(name string) State GetState(name string) State
AddTransition(trigger string, source State, dest State) AddTransition(trigger string, source []State, dest State)
AddSubscription(transition string, subscription Subscriber) error AddSubscription(transition string, subscription Subscriber) error
AddModel(m Modeler) AddModel(m Modeler)
Trigger(transition string) error Trigger(transition string) error
@ -28,11 +28,15 @@ func New(initial State) Stater {
return &Definition{model: NewModel(initial), triggers: make(map[string]Transitioner)} return &Definition{model: NewModel(initial), triggers: make(map[string]Transitioner)}
} }
func States(name ...State) []State {
return name
}
func (d *Definition) AddStates(name ...State) { func (d *Definition) AddStates(name ...State) {
d.states = append(d.states, name...) d.states = append(d.states, name...)
} }
func (d *Definition) AddTransition(trigger string, source State, dest State) { func (d *Definition) AddTransition(trigger string, source []State, dest State) {
d.triggers[trigger] = NewTransition(trigger, source, dest) d.triggers[trigger] = NewTransition(trigger, source, dest)
} }

View File

@ -39,7 +39,7 @@ func TestMachineCurrentState(t *testing.T) {
func TestMachineAddTransition(t *testing.T) { func TestMachineAddTransition(t *testing.T) {
s := setupStater("disconnected") s := setupStater("disconnected")
s.AddStates("disconnected", "start_connection", "connected") s.AddStates("disconnected", "start_connection", "connected")
s.AddTransition("connect", "disconnected", "start_connection") s.AddTransition("connect", States("disconnected"), "start_connection")
s.AddModel(setupModel("disconnected")) s.AddModel(setupModel("disconnected"))
// worker gets a trigger message // worker gets a trigger message
assert.Nil(t, s.Trigger("connect")) assert.Nil(t, s.Trigger("connect"))
@ -54,7 +54,7 @@ func TestMachineAddSubscription(t *testing.T) {
x := setupSubscriber() x := setupSubscriber()
s := setupStater("disconnected") s := setupStater("disconnected")
s.AddStates("disconnected", "start_connection", "connected") s.AddStates("disconnected", "start_connection", "connected")
s.AddTransition("connect", "disconnected", "start_connection") s.AddTransition("connect", States("disconnected"), "start_connection")
assert.Nil(t, s.AddSubscription("connect", x)) assert.Nil(t, s.AddSubscription("connect", x))
s.AddModel(setupModel("disconnected")) s.AddModel(setupModel("disconnected"))
assert.Nil(t, s.Trigger("connect")) assert.Nil(t, s.Trigger("connect"))

View File

@ -15,12 +15,12 @@ type Transitioner interface {
type Transition struct { type Transition struct {
trigger string trigger string
source State source []State
dest State dest State
subscriptions []Subscriber subscriptions []Subscriber
} }
func NewTransition(trigger string, source State, dest State) Transitioner { func NewTransition(trigger string, source []State, dest State) Transitioner {
return &Transition{trigger: trigger, source: source, dest: dest} return &Transition{trigger: trigger, source: source, dest: dest}
} }
@ -30,7 +30,8 @@ func (r *Transition) Run(m Modeler) error {
if currentState == r.dest { if currentState == r.dest {
return nil return nil
} }
if currentState == r.source || r.source == "*" { for _, transitionSource := range r.source {
if currentState == transitionSource || transitionSource == "*" {
res := m.ChangeState(r.dest) res := m.ChangeState(r.dest)
if res == currentState { if res == currentState {
r.Notify(EXITSTATEEVENT, currentState, r.dest) r.Notify(EXITSTATEEVENT, currentState, r.dest)
@ -38,6 +39,7 @@ func (r *Transition) Run(m Modeler) error {
return nil return nil
} }
} }
}
return fmt.Errorf("Transition from %s to %s failed on model state %s", r.source, r.dest, currentState) return fmt.Errorf("Transition from %s to %s failed on model state %s", r.source, r.dest, currentState)
} }

View File

@ -9,7 +9,7 @@ import (
) )
func setupTransition() Transitioner { func setupTransition() Transitioner {
t := NewTransition("open", "closed", "open") t := NewTransition("open", States("closed"), "open")
if t == nil { if t == nil {
log.Fatal("Failed creating new transition") log.Fatal("Failed creating new transition")
} }
@ -22,7 +22,7 @@ func setupSubscriber() Subscriber {
} }
func TestNewTransition(t *testing.T) { func TestNewTransition(t *testing.T) {
s := NewTransition("connect", "disconnected", "connected") s := NewTransition("connect", States("disconnected"), "connected")
if s == nil { if s == nil {
t.Errorf("Failed creating new transition") t.Errorf("Failed creating new transition")
} }
@ -39,7 +39,7 @@ func TestTransitionExecution(t *testing.T) {
} }
func TestTransitionWildcard(t *testing.T) { func TestTransitionWildcard(t *testing.T) {
tr := NewTransition("open", "*", "opened") tr := NewTransition("open", States("*"), "opened")
assert.NotNil(t, tr) assert.NotNil(t, tr)
m := setupModel("closed") m := setupModel("closed")
assert.Nil(t, tr.Run(m)) assert.Nil(t, tr.Run(m))