package nject

import (
	"fmt"
	"testing"

	"github.com/stretchr/testify/assert"
)

// TestSaveToInvalid hands SaveTo a series of unusable destinations and
// verifies that each one is rejected with an error whose text contains the
// expected fragment.
func TestSaveToInvalid(t *testing.T) {
	fn := func() {}
	type badTarget struct {
		wantMsg string
		target  interface{}
	}
	badTargets := []badTarget{
		// No target set: SaveTo receives a nil interface value.
		{wantMsg: "not a valid pointer"},
		{wantMsg: "is nil", target: (*string)(nil)},
		{wantMsg: "is not a pointer", target: 7},
		{wantMsg: "may not be a pointer to a function", target: &fn},
	}
	for i := range badTargets {
		bt := badTargets[i]
		t.Log(bt.wantMsg)
		_, err := SaveTo(bt.target)
		if !assert.Error(t, err, bt.wantMsg) {
			// Without an error there is no message to inspect.
			continue
		}
		assert.Contains(t, err.Error(), bt.wantMsg)
	}
}

// TestSaveToValid runs a chain that supplies a value and ends with
// MustSaveTo, then confirms the value was stored in the destination variable.
// One case saves a struct value, the other saves a pointer to a struct.
func TestSaveToValid(t *testing.T) {
	type foo struct {
		i int
	}
	var (
		byValue   foo
		byPointer *foo
	)
	scenarios := []struct {
		provided interface{}
		dest     interface{}
		verify   func()
	}{
		{
			provided: foo{i: 7},
			dest:     &byValue,
			verify: func() {
				assert.Equal(t, foo{i: 7}, byValue)
			},
		},
		{
			provided: &foo{i: 7},
			dest:     &byPointer,
			verify: func() {
				assert.Equal(t, foo{i: 7}, *byPointer)
			},
		},
	}
	for idx := range scenarios {
		sc := scenarios[idx]
		saver := MustSaveTo(sc.dest)
		err := Run("x", sc.provided, saver)
		if !assert.NoErrorf(t, err, "%d", idx) {
			continue
		}
		// The pointer case dereferences its destination, so the check is
		// wrapped to turn a nil dereference into a test failure.
		assert.NotPanics(t, sc.verify, "check")
	}
}

// TestCurry is a table-driven test of Curry and MustCurry.
//
// Each case supplies an original function and a pointer that should receive
// the curried form. A case with an empty wantErr must curry successfully, run
// inside a chain that follows the shared provider sequence, and pass its
// verify callback. A case with a non-empty wantErr must produce an error
// containing that text, either from Curry itself (in which case MustCurry
// must also panic) or from Run.
func TestCurry(t *testing.T) {
	t.Parallel()
	providers := Sequence("available",
		func() string { return "foo" },
		func() int { return 3 },
		func() uint { return 7 },
	)
	var oneArg func(string) string
	var fourArg func(bool, bool, string, string) string
	table := []struct {
		label   string
		wantErr string
		target  interface{}
		verify  func(t *testing.T)
		source  interface{}
	}{
		{
			target: &oneArg,
			source: func(x int, s string) string {
				return fmt.Sprintf("%s-%d", s, x)
			},
			verify: func(t *testing.T) {
				assert.Equal(t, "bar-3", oneArg("bar"))
			},
		},
		{
			target: &oneArg,
			source: func(x int, s string, u uint) string {
				return fmt.Sprintf("%s-%d/%d", s, x, u)
			},
			verify: func(t *testing.T) {
				assert.Equal(t, "bar-3/7", oneArg("bar"))
			},
		},
		{
			target: &fourArg,
			source: func(b1 bool, x int, b2 bool, s1 string, s2 string, u uint) string {
				return fmt.Sprintf("%v-%d-%v %s %s-%d", b1, x, b2, s1, s2, u)
			},
			verify: func(t *testing.T) {
				assert.Equal(t, "true-3-false bee boot-7", fourArg(true, false, "bee", "boot"))
			},
		},
		{
			target:  &fourArg,
			source:  func(b1 bool, s1 string, s2 string) string { return "" },
			wantErr: "curried function must take fewer arguments",
		},
		{
			target:  &fourArg,
			source:  func(b1 bool, b2 bool, b3 bool, s1 string, s2 string) string { return "" },
			wantErr: "original function takes more arguments of type bool",
		},
		{
			label:   "no original",
			target:  &fourArg,
			wantErr: "original function is not a valid value",
		},
		{
			label:   "no curry",
			source:  func(b1 bool, b2 bool, b3 bool, s1 string, s2 string) string { return "" },
			wantErr: "curried function is not a valid value",
		},
		{
			label:   "non-pointer",
			target:  7,
			source:  func(b1 bool, b2 bool, b3 bool, s1 string, s2 string) string { return "" },
			wantErr: "pointer (to a function)",
		},
		{
			label:   "non-func",
			target:  providers,
			source:  func(b1 bool, b2 bool, b3 bool, s1 string, s2 string) string { return "" },
			wantErr: "pointer to a function",
		},
		{
			target:  &fourArg,
			source:  "original non-func",
			wantErr: "first argument to Curry must be a function",
		},
		{
			label:   "nil",
			target:  (*func())(nil),
			source:  func(string) {},
			wantErr: "pointer to curried function cannot be nil",
		},
		{
			target:  &oneArg,
			source:  func(string) {},
			wantErr: "same number of outputs",
		},
		{
			target: &fourArg,
			source: func(b1 bool, x int, b2 bool, s1 string, s2 string, u uint) int {
				return 22
			},
			wantErr: "return value #1 has a different type",
		},
		{
			target: &oneArg,
			source: func(i1 int, i2 int, s string) string {
				return "foo"
			},
			wantErr: "cannot curry the same type (int) more than once",
		},
		{
			target: &oneArg,
			source: func(uint, int) string {
				return "foo"
			},
			wantErr: "not all of the string inputs to the curried function were used",
		},
		{
			target: &oneArg,
			source: func(s string, inner func(), i int) string {
				return fmt.Sprintf("%s-%d", s, i)
			},
			wantErr: "may not be a function",
		},
	}

	for i := range table {
		tc := table[i]
		// Unlabelled cases are named after the type of the original function.
		label := tc.label
		if label == "" {
			label = fmt.Sprintf("%T", tc.source)
		}
		t.Run(label, func(t *testing.T) {
			expectFailure := tc.wantErr != ""
			provider, err := Curry(tc.source, tc.target)
			if expectFailure && err != nil {
				// Rejected at curry time: MustCurry must panic on the same input.
				assert.Contains(t, err.Error(), tc.wantErr, "curry")
				assert.Panics(t, func() {
					_ = MustCurry(tc.source, tc.target)
				}, "curry")
				return
			}
			//nolint:testifylint
			if !assert.NoError(t, err, "curry") {
				return
			}

			// Currying succeeded; any expected failure must now come from Run.
			var called bool
			err = Run(label, providers, provider, func() { called = true })
			if !expectFailure {
				assert.True(t, called, "called")
				tc.verify(t)
				return
			}
			if assert.Error(t, err, "run") {
				assert.Contains(t, err.Error(), tc.wantErr, "run")
			}
		})
	}
}