package sqlingo

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"reflect"
)

// insertStatus holds the state behind every stage of an INSERT or REPLACE
// builder. Every builder method uses a value receiver and returns a modified
// copy; slices that are extended are copied first, so an intermediate builder
// can be reused or branched without affecting the others.
type insertStatus struct {
	// method is the statement keyword emitted verbatim: "INSERT" or "REPLACE".
	method                          string
	// scope carries the database handle and the single target table (Tables[0]).
	scope                           scope
	// fields is the explicit column list from Fields(); empty means all table columns.
	fields                          []Field
	// values holds one element per row; each element is the slice passed to one Values() call.
	values                          []interface{}
	// models holds the raw arguments to Models(); they are expanded via addModel in GetSQL.
	models                          []interface{}
	// onDuplicateKeyUpdateAssignments is rendered as ON DUPLICATE KEY UPDATE when non-empty.
	onDuplicateKeyUpdateAssignments []assignment
	// ctx is handed to ExecuteContext; it is nil unless WithContext was called.
	ctx                             context.Context
}

type insertWithTable interface {
	Fields(fields ...Field) insertWithValues
	Values(values ...interface{}) insertWithValues
	Models(models ...interface{}) insertWithModels
}

type insertWithValues interface {
	toInsertWithContext
	toInsertFinal
	Values(values ...interface{}) insertWithValues
	OnDuplicateKeyIgnore() toInsertFinal
	OnDuplicateKeyUpdate() insertWithOnDuplicateKeyUpdateBegin
}

type insertWithModels interface {
	toInsertWithContext
	toInsertFinal
	Models(models ...interface{}) insertWithModels
	OnDuplicateKeyIgnore() toInsertFinal
	OnDuplicateKeyUpdate() insertWithOnDuplicateKeyUpdateBegin
}

// insertWithOnDuplicateKeyUpdateBegin is returned by OnDuplicateKeyUpdate. It
// deliberately does not expose GetSQL/Execute, so callers must go through Set
// or SetIf before the statement can be finalized.
//
// The parameter is named field (not Field) so it no longer shadows the Field
// type and matches the insertStatus implementation.
type insertWithOnDuplicateKeyUpdateBegin interface {
	Set(field Field, value interface{}) insertWithOnDuplicateKeyUpdate
	SetIf(condition bool, field Field, value interface{}) insertWithOnDuplicateKeyUpdate
}

// insertWithOnDuplicateKeyUpdate is the builder stage inside an
// ON DUPLICATE KEY UPDATE clause. More assignments may be chained, or the
// statement may be finalized.
//
// The parameter is named field (not Field) so it no longer shadows the Field
// type and matches the insertStatus implementation.
type insertWithOnDuplicateKeyUpdate interface {
	toInsertWithContext
	toInsertFinal
	Set(field Field, value interface{}) insertWithOnDuplicateKeyUpdate
	SetIf(condition bool, field Field, value interface{}) insertWithOnDuplicateKeyUpdate
}

// toInsertWithContext lets a finished builder attach the context that
// Execute will pass to the database.
type toInsertWithContext interface {
	WithContext(ctx context.Context) toInsertFinal
}

// toInsertFinal is the terminal builder stage: the statement can be rendered
// to SQL text or executed against the database.
type toInsertFinal interface {
	Execute() (sql.Result, error)
	GetSQL() (string, error)
}

// InsertInto starts building an "INSERT INTO <table> ..." statement that will
// run against d.
func (d *database) InsertInto(table Table) insertWithTable {
	return d.newInsertStatus("INSERT", table)
}

// ReplaceInto starts building a "REPLACE INTO <table> ..." statement that will
// run against d. It offers the same builder chain as InsertInto.
//
// NOTE(review): the chain still permits OnDuplicateKeyUpdate/Ignore, but
// MySQL's REPLACE syntax has no ON DUPLICATE KEY UPDATE clause — confirm
// whether that combination should be rejected.
func (d *database) ReplaceInto(table Table) insertWithTable {
	return d.newInsertStatus("REPLACE", table)
}

// newInsertStatus returns an empty builder for the given statement keyword
// that targets exactly one table.
func (d *database) newInsertStatus(method string, table Table) insertStatus {
	target := scope{Database: d, Tables: []Table{table}}
	return insertStatus{method: method, scope: target}
}

// Fields sets the explicit column list for the rows supplied by Values.
// Without it, GetSQL falls back to every column of the target table.
func (s insertStatus) Fields(fields ...Field) insertWithValues {
	next := s
	next.fields = fields
	return next
}

// Values appends one row to the statement and returns the updated builder.
// Each call contributes exactly one row; chain calls to insert several rows.
//
// Both the row list and the row itself are copied. The receiver is never
// modified, so a partially built statement can be branched. A caller may also
// reuse or mutate the slice passed as Values(row...) afterwards without
// changing rows already added. Previously every row aliased the caller's
// slice.
func (s insertStatus) Values(values ...interface{}) insertWithValues {
	row := append([]interface{}{}, values...)

	rows := make([]interface{}, 0, len(s.values)+1)
	rows = append(rows, s.values...)
	s.values = append(rows, row)
	return s
}

// addModel flattens model into *models.
//
// model may be:
//   - a value implementing Model, appended as-is;
//   - a non-nil pointer, which is dereferenced and retried;
//   - a slice or array, whose elements are visited in order through a pointer,
//     so element types that implement Model with pointer receivers are found.
//
// Nil pointers, including nil elements of a []*T, are rejected with an error.
// Previously they either panicked or were appended as a typed-nil Model.
// Arrays passed by value are supported: their elements are not addressable,
// so each one is copied into fresh storage. Previously this panicked in Addr.
// Any other kind yields an "unknown model type" error. Models appended before
// an error is returned remain in *models.
func addModel(models *[]Model, model interface{}) error {
	value := reflect.ValueOf(model)
	if value.Kind() == reflect.Ptr && value.IsNil() {
		return errors.New("nil model")
	}

	if m, ok := model.(Model); ok {
		*models = append(*models, m)
		return nil
	}

	switch value.Kind() {
	case reflect.Ptr:
		return addModel(models, value.Elem().Interface())
	case reflect.Slice, reflect.Array:
		for i := 0; i < value.Len(); i++ {
			elem := value.Index(i)
			var ptr reflect.Value
			if elem.CanAddr() {
				ptr = elem.Addr()
			} else {
				// Elements of an array held by value are not addressable.
				// Copy the element so a pointer to it can still be taken.
				ptr = reflect.New(elem.Type())
				ptr.Elem().Set(elem)
			}
			if err := addModel(models, ptr.Interface()); err != nil {
				return err
			}
		}
		return nil
	default:
		return fmt.Errorf("unknown model type (kind = %d)", value.Kind())
	}
}

// Models appends model objects as rows and returns the updated builder. Each
// argument may be a Model, a pointer to one, or a slice/array of them; they
// are expanded when GetSQL is called.
//
// Chained calls accumulate, matching Values. Previously a second call
// silently dropped the models from the first. The argument list is copied, so
// the caller's slice is not retained, and the receiver is never modified.
func (s insertStatus) Models(models ...interface{}) insertWithModels {
	merged := make([]interface{}, 0, len(s.models)+len(models))
	merged = append(merged, s.models...)
	s.models = append(merged, models...)
	return s
}

// OnDuplicateKeyUpdate begins an ON DUPLICATE KEY UPDATE clause. It returns
// the builder unchanged but narrowed to an interface that only offers
// Set/SetIf, so at least one Set/SetIf call is needed before finalizing.
func (s insertStatus) OnDuplicateKeyUpdate() insertWithOnDuplicateKeyUpdateBegin {
	return s
}

// SetIf adds the assignment of value to field to the ON DUPLICATE KEY UPDATE
// clause only when condition is true; otherwise the builder is returned
// unchanged.
func (s insertStatus) SetIf(condition bool, field Field, value interface{}) insertWithOnDuplicateKeyUpdate {
	if !condition {
		return s
	}
	return s.Set(field, value)
}

// Set appends the assignment of value to field to the ON DUPLICATE KEY UPDATE
// clause. The assignment list is copied before appending, so builders derived
// from the same parent never share or overwrite each other's assignments.
func (s insertStatus) Set(field Field, value interface{}) insertWithOnDuplicateKeyUpdate {
	existing := s.onDuplicateKeyUpdateAssignments
	updated := make([]assignment, len(existing), len(existing)+1)
	copy(updated, existing)
	s.onDuplicateKeyUpdateAssignments = append(updated, assignment{field: field, value: value})
	return s
}

// OnDuplicateKeyIgnore emits an ON DUPLICATE KEY UPDATE clause that assigns
// the target table's first column to itself. It is presumably rendered as
// `col = col`, leaving a conflicting row unchanged. It panics if the table
// reports no fields.
func (s insertStatus) OnDuplicateKeyIgnore() toInsertFinal {
	columns := s.scope.Tables[0].GetFields()
	first := columns[0]
	return s.Set(first, first)
}

// GetSQL renders the statement as one SQL string.
//
// With models, the column list comes from the first model's table and each
// model contributes one row. Every model must belong to the target table,
// otherwise "invalid table from model" is returned. Without models, the
// Fields(...) list is used, falling back to all columns of the target table,
// and the rows come from Values(...).
//
// If no rows end up present, the no-op "/* INSERT without VALUES */ DO 0" is
// returned instead of an invalid INSERT.
func (s insertStatus) GetSQL() (string, error) {
	var columns []Field
	var rows []interface{}

	if len(s.models) == 0 {
		columns = s.fields
		if len(columns) == 0 {
			columns = s.scope.Tables[0].GetFields()
		}
		rows = s.values
	} else {
		collected := make([]Model, 0, len(s.models))
		for _, m := range s.models {
			if err := addModel(&collected, m); err != nil {
				return "", err
			}
		}
		if len(collected) > 0 {
			columns = collected[0].GetTable().GetFields()
		}
		for _, m := range collected {
			if m.GetTable().GetName() != s.scope.Tables[0].GetName() {
				return "", errors.New("invalid table from model")
			}
			rows = append(rows, m.GetValues())
		}
	}

	if len(rows) == 0 {
		return "/* INSERT without VALUES */ DO 0", nil
	}

	tableSQL := s.scope.Tables[0].GetSQL(s.scope)
	columnsSQL, err := commaFields(s.scope, columns)
	if err != nil {
		return "", err
	}
	rowsSQL, err := commaValues(s.scope, rows)
	if err != nil {
		return "", err
	}

	statement := s.method + " INTO " + tableSQL + " (" + columnsSQL + ") VALUES " + rowsSQL
	if len(s.onDuplicateKeyUpdateAssignments) == 0 {
		return statement, nil
	}

	updatesSQL, err := commaAssignments(s.scope, s.onDuplicateKeyUpdateAssignments)
	if err != nil {
		return "", err
	}
	return statement + " ON DUPLICATE KEY UPDATE " + updatesSQL, nil
}

// WithContext returns a copy of the finished builder that will pass ctx to the
// database when Execute is called.
func (s insertStatus) WithContext(ctx context.Context) toInsertFinal {
	withCtx := s
	withCtx.ctx = ctx
	return withCtx
}

// Execute renders the statement with GetSQL and runs it through the database's
// ExecuteContext. Any GetSQL error is returned without touching the database.
//
// When WithContext was never called, or was given nil, context.Background()
// is used. Previously a nil Context was forwarded, which the context package
// forbids.
func (s insertStatus) Execute() (result sql.Result, err error) {
	sqlString, err := s.GetSQL()
	if err != nil {
		return nil, err
	}

	ctx := s.ctx
	if ctx == nil {
		ctx = context.Background()
	}
	return s.scope.Database.ExecuteContext(ctx, sqlString)
}