package reactive

import (
	"context"
	"sync/atomic"
)

// SwitchMap returns an Observable that maps every value emitted by source to
// an inner Observable via project and mirrors only the most recent inner
// Observable. When source emits a new value, the previously active inner
// Observable's context is cancelled and anything it would still emit is
// dropped.
//
// Errors emitted by source are forwarded and observation of source continues.
// The channel returned by Observe is closed once source's channel has closed
// and the last inner Observable (if any) has finished forwarding.
func SwitchMap[T, R any](source Observable[T], project func(T) Observable[R]) Observable[R] {
	return &switchMap[T, R]{source: source, project: project}
}

// switchMap implements SwitchMap. It holds only configuration; all
// per-subscription state lives inside Observe so the same Observable can be
// observed any number of times, concurrently, without interference.
type switchMap[T, R any] struct {
	source  Observable[T]
	project func(T) Observable[R]
}

// Observe subscribes to the source with ctx and returns a channel that
// carries the notifications of the currently active inner Observable plus
// any error notifications from the source.
//
// Each inner Observable is observed with a context derived from ctx, so
// cancelling ctx also stops the active inner. Inners are serialized: the
// previous inner's forwarder is cancelled and waited for before the next one
// starts, so notifications of two inners are never interleaved.
func (s *switchMap[T, R]) Observe(ctx context.Context) <-chan Notification[R] {
	observing := s.source.Observe(ctx)
	res := make(chan Notification[R])

	go func() {
		// Closing here is safe: every forwarder has exited before we return
		// (see the final wait below), so nobody can send on res afterwards.
		defer close(res)

		// generation identifies the active inner. It is bumped before the
		// previous inner is cancelled, so a forwarder holding a value it
		// received just before the switch notices and drops it.
		var generation atomic.Uint64
		cancelInner := func() {}
		// innerDone is closed when the active forwarder exits; it starts out
		// closed because there is no inner yet.
		innerDone := make(chan struct{})
		close(innerDone)

		for notification := range observing {
			switch notification.Kind() {
			case NextNotification:
				gen := generation.Add(1)
				cancelInner()
				<-innerDone

				innerCtx, cancel := context.WithCancel(ctx)
				done := make(chan struct{})
				cancelInner, innerDone = cancel, done

				inner := s.project(notification.Value())
				go func() {
					defer close(done)
					defer cancel()

					// NOTE(review): assumes inner stops producing once
					// innerCtx is cancelled; its channel is not drained here.
					in := inner.Observe(innerCtx)
					for {
						select {
						case <-innerCtx.Done():
							return
						case n, ok := <-in:
							if !ok || generation.Load() != gen {
								return
							}
							select {
							case res <- n:
							case <-innerCtx.Done():
								return
							}
						}
					}
				}()
			case ErrorNotification:
				// Don't block forever if the consumer has gone away.
				select {
				case res <- MakeError[R](notification.Error()):
				case <-ctx.Done():
				}
			}
		}

		// Source finished: let the last inner run to completion before
		// closing res.
		<-innerDone
	}()

	return res
}