package client

import (
	"context"
	"crypto/rand"
	"fmt"
	"io"
	"math"
	"math/big"
	"reflect"
	"sync"
	"time"

	"github.com/golang/protobuf/proto"
	"github.com/google/uuid"
	connection "gitlab.wtotem.net/webtotem/backend-connection-library"
	pb "gitlab.wtotem.net/webtotem/backend-transport-client-quic/pb"
	"gitlab.wtotem.net/webtotem/logger"
)

// Timeouts and system request names used by TransportClient.
const (
	// GET_CLIENTS_TIMEOUT bounds the GetClients sync request.
	GET_CLIENTS_TIMEOUT = 10 * time.Second
	// INIT_TIMEOUT bounds the Init sync request sent by initHandler.
	INIT_TIMEOUT        = 3 * time.Second
	// CALLBACK_TIMEOUT is how long SendRequest waits for a response.
	CALLBACK_TIMEOUT    = time.Minute

	// Names placed in pb.System.Name to select the system operation.
	SUBSCRIBE   = "Subscribe"
	GET_CLIENTS = "GetClients"
	INIT        = "Init"
)

// Error codes carried in pb.Response.ErrCode (0 means success). ERR_CODE_TIMEOUT is
// also used as the prefix of the local timeout error returned by sendRequest.
const (
	ERR_CODE_TIMEOUT           = -1
	ERR_CODE_UNEXPECTED        = -11
	ERR_CODE_HANDLER_NOT_FOUND = -12
	ERR_CODE_MALFORMED_MSG     = -13
	ERR_CODE_TARGET_FUNC_ERROR = -14
	ERR_CODE_ALL_WORKERS_BUSY  = -15
)

// Human-readable messages paired with the ERR_CODE_* constants. There is no message
// for ERR_CODE_TARGET_FUNC_ERROR: in that case the handler's own error text is sent.
const (
	ERR_MSG_TIMEOUT           = "Timeout exceeded"
	ERR_MSG_UNEXPECTED        = "Unexpected error"
	ERR_MSG_HANDLER_NOT_FOUND = "Such handler is not registered"
	ERR_MSG_MALFORMED_MSG     = "Malformed message"
	ERR_MSG_ALL_WORKERS_BUSY  = "Max limit for parallel jobs reached"
)

// ITransport is the public surface of a transport client (implemented by TransportClient).
type ITransport interface {
	// RegisterHandler exposes a request handler function to remote clients.
	RegisterHandler(handler interface{}) error
	// RegisterEvent registers an event callback; filter describes the emitters to subscribe to.
	RegisterEvent(event interface{}, filter ClientInfo) error
	// SendRequest sends req to the client at addr and waits for its response.
	SendRequest(req interface{}, addr int64) (interface{}, error)
	// SendRequestWithTimeout is SendRequest with a caller-chosen response timeout.
	SendRequestWithTimeout(req interface{}, addr int64, timeout time.Duration) (interface{}, error)
	// EmitEvent publishes an event to subscribed clients.
	EmitEvent(interface{}) error
	// GetClients returns the addresses of clients matching info.
	GetClients(info *ClientInfo) ([]int64, error)
	// Run performs the handshake and subscriptions, then blocks until a fatal error.
	Run() error
	// Stop shuts the client down.
	Stop()
	// NewError builds an error from a numeric code and description.
	NewError(code int8, desc string) error
	// GetSelfId returns the address assigned to this client by the transport.
	GetSelfId() int64
}

// TransportClient is the ITransport implementation. It is also passed to the
// connection layer (see NewTransportClient) to receive session callbacks.
type TransportClient struct {
	// Request handlers, keyed by the type name of the handler's first argument.
	handlers        map[string]reflect.Value
	// Event subscriptions, keyed by the type name of the callback's first argument.
	events          map[string]subscription
	// Filters of accepted subscriptions, keyed by the subscription id returned by the transport.
	subscribers     map[int64]*pb.ClientInfo
	// Pending outgoing requests, keyed by request id, awaiting a response.
	callbacks       map[int64]callback
	// This client's identity, sent during the Init handshake.
	info            *pb.ClientInfo
	// Root context of every per-call context in this file; cancelled by Stop.
	ctx             context.Context
	cancel          context.CancelFunc
	// Guards handlers, events, subscribers and callbacks.
	mu              *sync.RWMutex
	// Work queue consumed by the worker goroutines started in Run; closed by Stop.
	requestsChannel chan func()
	// Address id of client set by transport
	id               int64
	// Number of worker goroutines Run starts.
	concurrencyLimit uint32
	// Session stored by OnNewSession and used for all stream operations.
	sess             connection.Session
	// Errors that make Run return.
	errChan          chan error
}

// ClientInfo describes a client: this client's own identity, or a filter when
// subscribing to events and querying GetClients.
type ClientInfo struct {
	// Available types are: [UNKNOWN, SERVER, CLIENT, STANDALONE]
	// Any other value is sent as UNKNOWN (see getPbClientInfoType).
	Type    string
	Class   int32
	Name    string
	Version int32
}

// Response is an outgoing reply: the protobuf header written into the frame header,
// plus the marshalled result message written as the stream body (see sendResponse).
type Response struct {
	header  *pb.Response
	payload []byte
}

// OnNewSession is the connection-layer hook for a new session (presumably invoked
// from OpenSession in NewTransportClient — confirm against the connection library).
// It stores the session for later stream operations and returns the client itself
// as the receiver of that session's events. It never fails.
func (t *TransportClient) OnNewSession(sess connection.Session) (connection.SessionEvents, error) {
	var events connection.SessionEvents = t
	t.sess = sess
	return events, nil
}

// OnStreamAccept handles one incoming stream. It reads the whole body, decodes the
// frame header according to header.Type and dispatches the frame:
//   - request frames go to requestHandler,
//   - response frames go to responseHandler,
//   - system frames are ignored,
//   - any other type is decoded as an event and passed to eventHandler.
//
// Read and decode failures are reported on errChan, which makes Run return. errChan is
// unbuffered, so the send also watches t.ctx. Once Stop has cancelled the client, this
// goroutine does not block forever on a channel nobody reads any more.
func (t *TransportClient) OnStreamAccept(header *connection.FrameHeader, stream connection.FrameReadStream) {
	report := func(err error) {
		select {
		case t.errChan <- err:
		case <-t.ctx.Done():
			log.Errorf("[OnStreamAccept]: dropping error after stop: %v", err)
		}
	}

	log.Infof("New stream!")

	body, err := io.ReadAll(stream)
	if err != nil {
		report(fmt.Errorf("Error! Read error: %v", err))
		return
	}

	log.Infof("Body was read")

	switch header.Type {
	case connection.SystemFrameType:
		return
	case connection.RequestFrameType:
		request := &pb.Request{}
		if err := proto.Unmarshal(header.Header, request); err != nil {
			report(fmt.Errorf("failed to unmarshal request frame header. Error: %v", err))
			return
		}
		t.requestHandler(request, body)
	case connection.ResponseFrameType:
		response := &pb.Response{}
		if err := proto.Unmarshal(header.Header, response); err != nil {
			report(fmt.Errorf("failed to unmarshal response frame header. Error: %v", err))
			return
		}
		t.responseHandler(response, body)
	default:
		event := &pb.Event{}
		if err := proto.Unmarshal(header.Header, event); err != nil {
			report(fmt.Errorf("failed to unmarshal event frame header. Error: %v", err))
			return
		}
		t.eventHandler(event, body)
	}
}

// OnClose is the connection-layer hook for a closed session. It logs err and forwards
// it to errChan so that Run returns. errChan is unbuffered, so the send gives up once
// t.ctx is cancelled (Stop). A close that happens after the client was stopped does
// not leave this goroutine blocked forever.
func (t *TransportClient) OnClose(err error) {
	log.Errorf("[OnClose]: Close error: %v", err)
	select {
	case t.errChan <- err:
	case <-t.ctx.Done():
	}
}

/* DEPRECATED */

// getPbClientInfoType converts the textual Type of info into the protobuf enum.
// "SERVER", "CLIENT" and "STANDALONE" map to their enum values. "UNKNOWN" and any
// other string (including "") map to pb.ClientInfo_UNKNOWN.
func (info *ClientInfo) getPbClientInfoType() pb.ClientInfo_Type {
	switch info.Type {
	case "SERVER":
		return pb.ClientInfo_SERVER
	case "CLIENT":
		return pb.ClientInfo_CLIENT
	case "STANDALONE":
		return pb.ClientInfo_STANDALONE
	}
	return pb.ClientInfo_UNKNOWN
}

// callback tracks one pending outgoing request (stored in TransportClient.callbacks).
type callback struct {
	// Channel on which responseHandler (or the timeout) delivers the outcome.
	resp    chan responseCallback
	// How long sendRequest waits for the response.
	timeout time.Duration
	// NOTE(review): not read anywhere in this file — confirm whether it is still needed.
	atomic  int64
}

// responseCallback is the outcome of a request: the decoded response message, or an error.
type responseCallback struct {
	Response proto.Message
	Error    error
}

// packet is a marshalled protobuf message together with its full message name
// (produced by extractPacket).
type packet struct {
	Name    string
	Payload []byte
}

// subscription is a registered event callback and the filter describing which
// emitters it subscribes to (sent to the transport in Run).
type subscription struct {
	eventFunc reflect.Value
	target    ClientInfo
}

// log is the package-wide logger, tagged as the transport-layer "clientLibrary".
var log = logger.NewLogger(&logger.LoggerInfo{
	CallerType: logger.LibCaller,
	Layer:      logger.TransportLayer,
	Name:       "clientLibrary",
	ModuleType: logger.OtherType,
})

// NewTransportClient initializes new TransportClient struct instance and opens a
// session to host:port.
//
// info is this client's identity (sent during Run's Init handshake).
// concurrencyLimit is the number of worker goroutines Run starts for incoming
// requests; with 0, every incoming request is answered with ERR_CODE_ALL_WORKERS_BUSY.
// NOTE(review): timeout is currently unused.
//
// On error the returned client is nil and its internal context is released.
func NewTransportClient(info *ClientInfo, host string, port uint16, timeout time.Duration, concurrencyLimit uint32) (*TransportClient, error) {
	if info == nil {
		return nil, fmt.Errorf("NewTransportClient: client info must not be nil")
	}

	clientInfo := &pb.ClientInfo{
		Type:    info.getPbClientInfoType(),
		Class:   info.Class,
		Name:    info.Name,
		Version: info.Version,
	}
	ctx, cancel := context.WithCancel(context.Background())

	tc := &TransportClient{
		handlers:         make(map[string]reflect.Value),
		events:           make(map[string]subscription),
		subscribers:      make(map[int64]*pb.ClientInfo),
		callbacks:        make(map[int64]callback),
		info:             clientInfo,
		requestsChannel:  make(chan func()),
		ctx:              ctx,
		cancel:           cancel,
		mu:               &sync.RWMutex{},
		concurrencyLimit: concurrencyLimit,
		errChan:          make(chan error),
	}

	clientConnLayer := connection.NewConnectionLayer(tc)

	if _, err := clientConnLayer.OpenSession(fmt.Sprintf("%s:%d", host, port)); err != nil {
		log.Errorf("Error while opening session. Error: %v", err)
		// The client is unusable without a session: release its context and
		// do not hand back a half-initialised value.
		cancel()
		return nil, err
	}
	return tc, nil
}

// RegisterHandler makes handler callable by other clients. Incoming requests are
// routed to it by the type name of its first argument.
// Example: RegisterHandler(CheckSite) registers CheckSite as the handler for its request type.
func (t *TransportClient) RegisterHandler(handler interface{}) error {
	if err := t.registerHandler(handler); err != nil {
		return err
	}
	return nil
}

// RegisterEvent installs event as the callback for events of its first argument's
// type. filter describes which emitters to subscribe to when Run is called.
// Example: RegisterEvent(SiteChecked, filter) registers SiteChecked for its event type.
func (t *TransportClient) RegisterEvent(event interface{}, filter ClientInfo) error {
	if err := t.registerEvent(event, filter); err != nil {
		return err
	}
	return nil
}

// SendRequest sends req to the client at addr and waits up to CALLBACK_TIMEOUT for
// the response. It is SendRequestWithTimeout with the default timeout.
func (t *TransportClient) SendRequest(req interface{}, addr int64) (interface{}, error) {
	return t.SendRequestWithTimeout(req, addr, CALLBACK_TIMEOUT)
}

// SendRequestWithTimeout sends req to the client at addr and waits at most timeout
// for the response message (or the remote error).
func (t *TransportClient) SendRequestWithTimeout(req interface{}, addr int64, timeout time.Duration) (interface{}, error) {
	resp, err := t.sendRequest(req, addr, timeout)
	return resp, err
}

// sendRequest sends req to the client at addr and waits up to wait for its response.
//
// Flow:
//  1. A callback is registered under a random request id.
//  2. The request header and marshalled payload are written to a new stream.
//  3. The call blocks until one of these happens: responseHandler delivers the
//     matching response, wait elapses (counted from just before the stream is
//     opened), or the client is stopped.
//
// The callback is always removed from t.callbacks on return.
func (t *TransportClient) sendRequest(req interface{}, addr int64, wait time.Duration) (interface{}, error) {
	requestId := generateRandNumber()
	traceId := generateTraceId()

	pkt, err := extractPacket(req)
	if err != nil {
		return nil, err
	}

	// Buffered so that responseHandler's non-blocking send succeeds even when the
	// response arrives before this goroutine reaches the select below.
	cb := callback{
		resp:    make(chan responseCallback, 1),
		timeout: wait,
	}

	t.mu.Lock()
	t.callbacks[requestId] = cb
	t.mu.Unlock()
	// Cleanup callbacks map no matter if error or normal result returns
	defer func() {
		t.mu.Lock()
		delete(t.callbacks, requestId)
		t.mu.Unlock()
	}()

	request := &pb.Request{
		Id:         requestId,
		Address:    addr,
		RequestRpc: pkt.Name,
		TraceId:    traceId,
	}
	data, err := proto.Marshal(request)
	if err != nil {
		return nil, fmt.Errorf("couldn't marshal request header: %w", err)
	}

	ctx, cancelFunc := context.WithTimeout(t.ctx, cb.timeout)
	defer cancelFunc()

	stream, err := t.sess.OpenStream(ctx, &connection.FrameHeader{
		Type:       connection.RequestFrameType,
		Header:     data,
		BodyLength: uint64(len(pkt.Payload)),
	})
	if err != nil {
		return nil, fmt.Errorf("Failed to open stream: %w", err)
	}

	_, err = stream.Write(pkt.Payload)
	// Close in every case: on success it finishes the request body, on failure it
	// releases the stream.
	stream.Close()
	if err != nil {
		return nil, fmt.Errorf("couldn't write request payload into body: %w", err)
	}

	select {
	case res := <-cb.resp:
		return res.Response, res.Error
	case <-ctx.Done():
		if t.ctx.Err() != nil {
			return nil, fmt.Errorf("transport stopped before response arrived: %w", t.ctx.Err())
		}
		log.Errorf("Timeout exceeded")
		return nil, fmt.Errorf("%d: %s", ERR_CODE_TIMEOUT, ERR_MSG_TIMEOUT)
	}
}

// EmitEvent publishes event (a protobuf message) to the transport so that clients
// subscribed to its type are notified. The event header carries a random event id
// and a fresh trace id. The marshalled message is written as the stream body.
func (t *TransportClient) EmitEvent(event interface{}) error {
	traceId := generateTraceId()
	eventId := generateRandNumber()

	pkt, err := extractPacket(event)
	if err != nil {
		return err
	}

	ctx, cancel := context.WithCancel(t.ctx)
	defer cancel()

	header, err := proto.Marshal(&pb.Event{
		Id:       eventId,
		EventRpc: pkt.Name,
		TraceId:  traceId,
	})
	if err != nil {
		log.Errorf("Failed to marshal pb event")
		return err
	}

	stream, err := t.sess.OpenStream(ctx, &connection.FrameHeader{
		Type:       connection.EventFrameType,
		Header:     header,
		BodyLength: uint64(len(pkt.Payload)),
	})
	if err != nil {
		return fmt.Errorf("Failed to open stream: %v", err)
	}
	defer stream.Close()

	if _, err := stream.Write(pkt.Payload); err != nil {
		return fmt.Errorf("Failed to write event payload to stream: %v", err)
	}
	return nil
}

// GetClients returns addresses of clients by passed pattern in params.
//
// info is marshalled as a pb.ClientInfo and sent in a GET_CLIENTS system request.
// The call is bounded by GET_CLIENTS_TIMEOUT. A non-zero ErrCode in the reply header
// is returned as an error. A nil info is rejected instead of panicking.
func (t *TransportClient) GetClients(info *ClientInfo) ([]int64, error) {
	if info == nil {
		return nil, fmt.Errorf("GetClients: client info filter must not be nil")
	}

	clientsPayload, err := proto.Marshal(&pb.ClientInfo{
		Type:    info.getPbClientInfoType(),
		Class:   info.Class,
		Name:    info.Name,
		Version: info.Version,
	})
	if err != nil {
		return nil, err
	}

	systemPayload, err := proto.Marshal(&pb.System{Name: GET_CLIENTS})
	if err != nil {
		return nil, fmt.Errorf("failed to marshal system proto: %w", err)
	}

	getClientsCtx, cancel := context.WithTimeout(t.ctx, GET_CLIENTS_TIMEOUT)
	defer cancel()

	fh := &connection.FrameHeader{
		Type:       connection.SystemFrameType,
		Header:     systemPayload,
		BodyLength: uint64(len(clientsPayload)),
	}

	respHeader, respBody, err := t.sess.SendSyncRequest(getClientsCtx, fh, clientsPayload)
	if err != nil {
		return nil, fmt.Errorf("failed to send sync request: %w", err)
	}

	// A missing header is treated as "no error".
	syncErr := &pb.SyncError{}
	if respHeader != nil {
		if err := proto.Unmarshal(respHeader.Header, syncErr); err != nil {
			return nil, fmt.Errorf("failed to unmarshal frame header. Error: %w", err)
		}
	}
	if syncErr.ErrCode != 0 {
		return nil, fmt.Errorf("GetClients call error: %s", syncErr.ErrDesc)
	}

	clients := &pb.GetClientsResponse{}
	if err := proto.Unmarshal(respBody, clients); err != nil {
		return nil, fmt.Errorf("failed to unmarshal to clients proto. Error: %w", err)
	}
	return clients.Addresses, nil
}

// Run starts the transport gateway. It:
//  1. performs the Init handshake and stores the assigned address (see GetSelfId),
//  2. subscribes to every event registered so far,
//  3. starts concurrencyLimit worker goroutines for incoming requests, and
//  4. blocks until an error is reported on errChan (returned) or Stop is called
//     (returns nil).
func (t *TransportClient) Run() error {
	var err error

	if t.id, err = t.initHandler(); err != nil {
		return fmt.Errorf("Failed to get init info: %w", err)
	}

	// Snapshot the subscriptions so the map is not iterated while registerEvent writes to it.
	t.mu.RLock()
	events := make(map[string]subscription, len(t.events))
	for eventType, sub := range t.events {
		events[eventType] = sub
	}
	t.mu.RUnlock()

	for eventType, sub := range events {
		subID, protoClient, err := t.subscribeEvent(eventType, sub.target)
		if err != nil {
			return err
		}
		t.mu.Lock()
		t.subscribers[subID] = protoClient
		t.mu.Unlock()
	}

	for i := 0; i < int(t.concurrencyLimit); i++ {
		go worker(i, t.requestsChannel)
	}

	select {
	case err := <-t.errChan:
		return err
	case <-t.ctx.Done():
		return nil
	}
}

// subscribeEvent sends one SUBSCRIBE system request for eventType, filtered by target.
// It returns the subscription id assigned by the transport and the pb filter that was sent.
// It lives in its own function so each request's context is released as soon as the
// request completes, rather than when Run returns.
func (t *TransportClient) subscribeEvent(eventType string, target ClientInfo) (int64, *pb.ClientInfo, error) {
	protoClient := &pb.ClientInfo{
		Type:    target.getPbClientInfoType(),
		Class:   target.Class,
		Name:    target.Name,
		Version: target.Version,
	}

	payload, err := proto.Marshal(&pb.SubInfo{
		ClientInfo: protoClient,
		EventType:  eventType,
	})
	if err != nil {
		return 0, nil, err
	}

	systemPayload, err := proto.Marshal(&pb.System{Name: SUBSCRIBE})
	if err != nil {
		return 0, nil, fmt.Errorf("failed to marshal system proto: %w", err)
	}

	subscribeCtx, cancelFunc := context.WithCancel(t.ctx)
	defer cancelFunc()

	fh := &connection.FrameHeader{
		Type:       connection.SystemFrameType,
		Header:     systemPayload,
		BodyLength: uint64(len(payload)),
	}

	subscribeHeader, subscribeBody, err := t.sess.SendSyncRequest(subscribeCtx, fh, payload)
	if err != nil {
		return 0, nil, fmt.Errorf("failed to send subscribe request for %q: %w", eventType, err)
	}

	headerError := &pb.SyncError{}
	if subscribeHeader != nil {
		if err := proto.Unmarshal(subscribeHeader.Header, headerError); err != nil {
			return 0, nil, fmt.Errorf("failed to unmarshal frame header. Error: %w", err)
		}
	}
	if headerError.ErrCode != 0 {
		return 0, nil, fmt.Errorf("Subscribe call error: %s", headerError.ErrDesc)
	}

	subResult := &pb.SubResult{}
	if err := proto.Unmarshal(subscribeBody, subResult); err != nil {
		return 0, nil, fmt.Errorf("failed to unmarshal subscribe frame header. Error: %w", err)
	}
	return subResult.Id, protoClient, nil
}

// Stop shuts the client down. It cancels t.ctx, from which every per-call context
// in this client is derived. It then closes the request queue, so the worker
// goroutines started by Run exit.
//
// A second sequential call is a no-op instead of panicking on a double close.
// t.ctx is only cancelled here, so ctx.Err() reliably means "already stopped".
// Concurrent calls to Stop are still not safe.
func (t *TransportClient) Stop() {
	if t.ctx.Err() != nil {
		return
	}
	t.cancel()
	close(t.requestsChannel)
}

func (t *TransportClient) NewError(code int8, desc string) error {
	return fmt.Errorf("code: %d | description: %s", code, desc)
}

// GetSelfId returns the address assigned to this client by the Init handshake in Run
// (0 before Run has completed Init).
func (t *TransportClient) GetSelfId() int64 {
	selfID := t.id
	return selfID
}

// registerEvent stores event as the callback for events whose type name equals the
// element type of its single pointer argument. filter is the emitter filter later
// sent by Run.
//
// event must be a function taking exactly one pointer argument. eventHandler calls
// it with a freshly allocated *Message. Anything else is rejected with an error
// here, instead of panicking inside reflect.
func (t *TransportClient) registerEvent(event interface{}, filter ClientInfo) error {
	log.Infof("Registering new event %v with filter %v\n", event, filter)

	if event == nil {
		return fmt.Errorf("event handler must be function, received: nil")
	}

	eventValue := reflect.ValueOf(event)
	fn := eventValue.Type()

	if fn.Kind() != reflect.Func {
		return fmt.Errorf("event handler must be function, received: %q\n", fn.Kind())
	}
	if fn.NumIn() != 1 || fn.In(0).Kind() != reflect.Ptr {
		return fmt.Errorf("event handler must take exactly one pointer argument, received: %s", fn)
	}

	eventType := fn.In(0).Elem().String()
	t.mu.Lock()
	t.events[eventType] = subscription{
		eventFunc: eventValue,
		target:    filter,
	}
	t.mu.Unlock()

	return nil
}

// registerHandler stores handler under the type name of its argument's element type.
// Incoming requests whose RequestRpc equals that name are dispatched to it.
//
// requestHandler calls the handler with one freshly allocated *Request value. It then
// reads two results: a proto.Message and an error. Handlers of any other shape are
// rejected here, rather than panicking later inside a worker.
func (t *TransportClient) registerHandler(handler interface{}) error {
	log.Infof("Registering new handler: %v\n", handler)

	if handler == nil {
		return fmt.Errorf("handler must be function type, received nil")
	}

	val := reflect.ValueOf(handler)
	fn := val.Type()

	if fn.Kind() != reflect.Func {
		return fmt.Errorf("handler must be function type, received %q\n", fn.Kind())
	}
	if fn.NumIn() != 1 || fn.In(0).Kind() != reflect.Ptr {
		return fmt.Errorf("handler must take exactly one pointer argument, received %s", fn)
	}

	errorType := reflect.TypeOf((*error)(nil)).Elem()
	messageType := reflect.TypeOf((*proto.Message)(nil)).Elem()
	if fn.NumOut() != 2 || !fn.Out(0).Implements(messageType) || !fn.Out(1).Implements(errorType) {
		return fmt.Errorf("handler must return (proto.Message, error), received %s", fn)
	}

	argName := fn.In(0).Elem().String()
	log.Infof("Adding argument to handlers: %s", argName)
	t.mu.Lock()
	t.handlers[argName] = val
	t.mu.Unlock()
	return nil
}

func (t *TransportClient) requestHandler(request *pb.Request, payload []byte) {

	// TODO: Define in obsolete function before the call of func
	response := &pb.Response{Address: request.Address, Id: request.Id, TraceId: request.TraceId}

	log.Infof("Received RPC: %s", request.RequestRpc)
	log.Infof("Available handlers: %v", t.handlers)

	requestJob := func() {
		resp := &Response{
			header: response,
		}
		defer t.sendResponse(resp)

		t.mu.RLock()
		handler, found := t.handlers[request.RequestRpc]
		t.mu.RUnlock()

		if found {
			value := initMessageType(request.RequestRpc)
			if err := proto.Unmarshal(payload, value.Interface().(proto.Message)); err != nil {
				response.ErrCode = ERR_CODE_MALFORMED_MSG
				response.ErrMsg = ERR_MSG_MALFORMED_MSG
			}

			// TODO: call in goroutine
			remoteResult := handler.Call([]reflect.Value{value})
			if !remoteResult[1].IsNil() {
				err := remoteResult[1].Interface().(error)
				response.ErrCode = ERR_CODE_TARGET_FUNC_ERROR
				response.ErrMsg = err.Error()
			} else {
				var (
					err error
				)
				if resp.payload, err = proto.Marshal(remoteResult[0].Interface().(proto.Message)); err != nil {
					response.ErrCode = ERR_CODE_UNEXPECTED
					response.ErrMsg = ERR_MSG_UNEXPECTED
				}

				response.ResponseRpc = remoteResult[0].Elem().Type().String()
				response.ErrCode = 0
			}
		} else {
			response.ErrCode = ERR_CODE_HANDLER_NOT_FOUND
			response.ErrMsg = ERR_MSG_HANDLER_NOT_FOUND
		}
	}
	select {
	case t.requestsChannel <- requestJob:
	default:
		response.ErrMsg = ERR_MSG_ALL_WORKERS_BUSY
		response.ErrCode = ERR_CODE_ALL_WORKERS_BUSY
		t.sendResponse(&Response{
			header: response,
		})
	}
}

// responseHandler delivers an incoming response to the sendRequest call waiting on
// response.Id. Exactly one of Response/Error is set in what is delivered:
//   - a non-zero ErrCode from the remote side becomes an error;
//   - an unregistered ResponseRpc type becomes an error;
//   - a payload that fails to decode becomes an error (never a partially decoded
//     message with a nil error);
//   - otherwise the decoded message is delivered.
//
// The send is non-blocking: if nobody is waiting (already timed out), it is logged
// and dropped.
func (t *TransportClient) responseHandler(response *pb.Response, payload []byte) {
	log.Infof("Received response RPC: %s", response.ResponseRpc)

	t.mu.RLock()
	cb, found := t.callbacks[response.Id]
	t.mu.RUnlock()
	if !found {
		log.Errorf("No callback found for given id %d", response.Id)
		return
	}

	var respCallback responseCallback
	if response.ErrCode != 0 {
		respCallback.Error = fmt.Errorf("Error while processing request. Error code: %v. Error description: %v", response.ErrCode, response.ErrMsg)
	} else if value := initMessageType(response.ResponseRpc); !value.IsValid() {
		respCallback.Error = fmt.Errorf("unknown response type %q: %s", response.ResponseRpc, response.ErrMsg)
	} else if err := proto.Unmarshal(payload, value.Interface().(proto.Message)); err != nil {
		respCallback.Error = fmt.Errorf("%d: %s: %w", ERR_CODE_MALFORMED_MSG, ERR_MSG_MALFORMED_MSG, err)
	} else {
		respCallback.Response = value.Interface().(proto.Message)
	}

	select {
	case cb.resp <- respCallback:
	default:
		log.Errorf("Response receiver not found")
	}
}

// eventHandler decodes an incoming event and invokes the registered callback for
// event.EventRpc in a new goroutine. Events with no registered callback, an
// unregistered message type, or an undecodable payload are logged and dropped.
// Callbacks never receive partially decoded messages.
func (t *TransportClient) eventHandler(event *pb.Event, payload []byte) {
	t.mu.RLock()
	subscriber, found := t.events[event.EventRpc]
	t.mu.RUnlock()
	if !found {
		log.Errorf("No event found for given name %s", event.EventRpc)
		return
	}

	value := initMessageType(event.EventRpc)
	if !value.IsValid() {
		log.Errorf("Can't decode %s event: message type is not registered", event.EventRpc)
		return
	}
	if err := proto.Unmarshal(payload, value.Interface().(proto.Message)); err != nil {
		log.Errorf("Can't unmarshal %s event: %s", event.EventRpc, err)
		return
	}

	go subscriber.eventFunc.Call([]reflect.Value{value})
}

// sendResponse writes resp to a new response stream. It writes resp.header as the
// frame header and resp.payload as the body. Failures are logged, not returned.
// The function stops at the first failure, so a stream that could not be opened is
// never closed or written to.
func (t *TransportClient) sendResponse(resp *Response) {
	sendResponseCtx, cancelFunc := context.WithCancel(t.ctx)
	defer cancelFunc()

	data, err := proto.Marshal(resp.header)
	if err != nil {
		log.Errorf("Failed to send response: %s", err)
		return
	}

	fh := &connection.FrameHeader{Type: connection.ResponseFrameType, Header: data, BodyLength: uint64(len(resp.payload))}
	stream, err := t.sess.OpenStream(sendResponseCtx, fh)
	if err != nil {
		log.Errorf("Failed to send response: %s", err)
		return
	}
	defer stream.Close()

	if _, err := stream.Write(resp.payload); err != nil {
		log.Errorf("Failed to send response: %s", err)
	}
}

// worker runs jobs from jobsChan one at a time and returns once the channel is closed.
// Each job is surrounded by start/finish log lines tagged with workerId.
func worker(workerId int, jobsChan <-chan func()) {
	for {
		job, open := <-jobsChan
		if !open {
			return
		}
		log.Infof("Worker %d started", workerId)
		job()
		log.Infof("Worker %d finished", workerId)
	}
}

func extractPacket(any interface{}) (*packet, error) {
	var value reflect.Value
	var msg proto.Message
	var msgData []byte

	value = reflect.ValueOf(any)
	msg = value.Interface().(proto.Message)
	msgData, err := proto.Marshal(msg)
	if err != nil {
		return nil, fmt.Errorf("error during proto message marshaling: %s\n", err)
	}

	return &packet{Payload: msgData, Name: proto.MessageName(msg)}, nil
}

func initMessageType(messageName string) reflect.Value {
	var valueType reflect.Type
	var value reflect.Value

	if valueType = proto.MessageType(messageName); valueType != nil {
		valueType = valueType.Elem()
	}

	if valueType != nil {
		value = reflect.New(valueType)
	}
	return value
}

// randNumMax is the exclusive upper bound used by generateRandNumber.
var randNumMax = new(big.Int).SetInt64(math.MaxInt64)

// generateRandNumber returns a random int64 in [0, math.MaxInt64) from crypto/rand.
// It is used for request and event ids. The crypto/rand error is ignored. On failure,
// rand.Int returns a nil *big.Int and the Int64 call panics.
func generateRandNumber() int64 {
	var draw *big.Int
	draw, _ = rand.Int(rand.Reader, randNumMax)
	return draw.Int64()
}

// generateTraceId returns a random UUID string for request/event tracing.
// If generation fails, it logs the error and returns "".
func generateTraceId() string {
	id, err := uuid.NewRandom()
	if err == nil {
		return id.String()
	}
	log.Errorf("Couldn't generate uuid: %s", err)
	return ""
}

// handlerErrorWrapper prefixes err with the handler name: "[name] error: <err>".
// err is wrapped with %w, so callers can still match it with errors.Is / errors.As.
func handlerErrorWrapper(handlerName string, err error) error {
	return fmt.Errorf("[%s] error: %w", handlerName, err)
}

// initHandler performs the Init handshake. It sends this client's pb.ClientInfo in an
// INIT system request, bounded by INIT_TIMEOUT, and returns the address the transport
// assigned (pb.InitResponse.Addr).
//
// It returns an error in these cases:
//   - the request cannot be sent;
//   - the reply header carries a non-zero ErrCode;
//   - the reply body is not a valid InitResponse.
func (t *TransportClient) initHandler() (int64, error) {
	initPayload, err := proto.Marshal(t.info)
	if err != nil {
		return 0, err
	}

	systemPayload, err := proto.Marshal(&pb.System{Name: INIT})
	if err != nil {
		return 0, fmt.Errorf("failed to marshal system proto: %w", err)
	}

	initCtx, cancel := context.WithTimeout(t.ctx, INIT_TIMEOUT)
	defer cancel()

	fh := &connection.FrameHeader{
		Type:       connection.SystemFrameType,
		Header:     systemPayload,
		BodyLength: uint64(len(initPayload)),
	}

	initHeader, initBody, err := t.sess.SendSyncRequest(initCtx, fh, initPayload)
	if err != nil {
		return 0, fmt.Errorf("failed to send init request: %w", err)
	}

	// A missing header is treated as "no error".
	headerError := &pb.SyncError{}
	if initHeader != nil {
		if err := proto.Unmarshal(initHeader.Header, headerError); err != nil {
			return 0, fmt.Errorf("failed to unmarshal init frame header. Error: %w", err)
		}
	}
	if headerError.ErrCode != 0 {
		return 0, fmt.Errorf("Init call error: %s", headerError.ErrDesc)
	}

	initResponse := &pb.InitResponse{}
	if err := proto.Unmarshal(initBody, initResponse); err != nil {
		return 0, fmt.Errorf("failed to unmarshal init response. Error: %w", err)
	}
	return initResponse.Addr, nil
}
