import Node from "./Node";
import Edge from "./Edge";
import {EdgeSerialized} from "../interfaces/EdgeSerialized";
import {NodeSerialized} from "../interfaces/NodeSerialized";
import _ from "lodash";
import ResultNode from "./ResultNode";
import {
    ExecutionState,
    InputValue,
    NodeOptions, Optional,
    OutputValueFunction,
    ResultValueFunction,
    RiskInputOption
} from "../types";
import NodeGroupSerialized from "../interfaces/NodeGroupSerialized";
import NodeGroup from "./NodeGroup";
import {DEFAULT_RESULT_NODE_LABEL, RESULT_NODE_ID} from "../constants";
import hash from "object-hash";
import {Key} from "react";
import objectPath from "object-path";

/**
 * Plain-data form of a {@link Network}, as produced by `Network.serialize()` and
 * accepted by the `Network` constructor.
 *
 * Exported so callers can type the payloads they pass to / receive from `Network`.
 */
export interface NetworkSerialized {
    /** Edges between nodes (or into the result node); treated as empty when omitted. */
    edges?: EdgeSerialized[];
    /** Nodes of the network; treated as empty when omitted. */
    nodes?: NodeSerialized[];
    /** Named groupings of node ids; treated as empty when omitted. */
    nodeGroups?: NodeGroupSerialized[];
    /** Label given to the network's result node. */
    resultNodeLabel: string;
    /** Options stored verbatim on the network and returned by `Network.possibleValues`. */
    possibleValues: RiskInputOption[];
}


/**
 * Directed, weighted graph of {@link Node}s that feed a single {@link ResultNode}.
 *
 * {@link Network.execute} propagates values from the entry nodes (no incoming edges,
 * at least one outgoing edge) along the edges; a node is evaluated once every one of
 * its expected incoming values has arrived. `hash` is recomputed from `serialize()`
 * whenever nodes, edges or node groups are added and whenever node values are updated.
 */
export default class Network {
    private readonly _nodes: Record<string, Node> = {};
    private readonly _edges: Edge[] = [];
    private readonly _nodeGroups: NodeGroup[] = [];
    private readonly _resultNode: ResultNode;
    private readonly _possibleValues: RiskInputOption[];
    // Per-node bookkeeping (expected/received inputs, output) for the latest `execute` run.
    private _executionState: ExecutionState = {};
    // Assigned by the `_updateHash()` call that ends the constructor.
    private _hash!: string;
    private _comparisonValues?: Record<Key, Record<Key, number>>;

    /**
     * Rebuilds a network from its serialized form.
     *
     * @param serializedNetwork - Missing `nodes`, `edges` and `nodeGroups` default to empty
     *   lists and a missing `resultNodeLabel` to `DEFAULT_RESULT_NODE_LABEL`. An edge whose
     *   `target` equals the result node's id is connected to the result node.
     */
    constructor(
        serializedNetwork: NetworkSerialized,
    ) {
        const {
            nodes = [],
            edges = [],
            nodeGroups = [],
            resultNodeLabel = DEFAULT_RESULT_NODE_LABEL,
            possibleValues
        } = serializedNetwork || {};
        // Set these first: every add* call below re-hashes through `serialize()`, which
        // reads both the result node's label and the possible values.
        this._resultNode = new ResultNode(resultNodeLabel);
        this._possibleValues = possibleValues;
        // `this` is passed as forEach's thisArg; a bare method reference would run unbound.
        nodes.forEach(this.addNode, this);
        edges.forEach(serializedEdge => {
            const { source, target, probability } = serializedEdge;
            const sourceNode = this._nodes[source];
            // The result node is not stored in `_nodes`, so it has to be resolved explicitly.
            const targetNode = target === this._resultNode.id
                ? this._resultNode
                : this._nodes[target];
            this.addEdge(sourceNode, targetNode, probability);
        });
        nodeGroups.forEach(serializedGroup => {
            const { label, id, nodes: memberIds } = serializedGroup;
            const members = memberIds.map(memberId => this._nodes[memberId]);
            this.addNodeGroup(label, members, id);
        });
        // Guarantees a hash even when nothing was added above.
        this._updateHash();
    }

    /**
     * Creates a node from `options`, registers it under its id (replacing any node with
     * the same id) and refreshes the hash.
     */
    public addNode(
        options: NodeOptions
    ): Node {
        const node = new Node(options);
        this._nodes[node.id] = node;
        this._updateHash();
        return node;
    }

    /**
     * Connects `sourceNode` to `targetNode` with the given weight and refreshes the hash.
     *
     * @param weight - Used as the edge probability when values are propagated.
     */
    public addEdge(
        sourceNode: Node,
        targetNode: Node|ResultNode,
        weight: number
    ): Edge {
        const edge = new Edge(sourceNode, targetNode, weight);
        this._edges.push(edge);
        this._updateHash();
        return edge;
    }

    /** Creates a labelled group over `nodes`, stores it and refreshes the hash. */
    public addNodeGroup(
        label: string,
        nodes: Node[],
        id?: string,
    ): NodeGroup {
        const nodeGroup = new NodeGroup(nodes, label, id);
        this._nodeGroups.push(nodeGroup);
        this._updateHash();
        return nodeGroup;
    }

    /** Plain-data snapshot of the network, accepted back by the constructor. */
    public serialize(): NetworkSerialized {
        return {
            edges: this.edges.map(edge => edge.serialize()),
            nodes: this.nodes.map(node => node.serialize()),
            nodeGroups: this._nodeGroups.map(nodeGroup => nodeGroup.serialize()),
            resultNodeLabel: this._resultNode.label,
            possibleValues: this._possibleValues
        }
    }

    /**
     * Output value of the result node from the most recent {@link Network.execute} run.
     *
     * @returns `undefined` if `execute` has not run yet or the result node was not reached.
     */
    public getResult(): number|undefined {
        return this._executionState[this.resultNode.id]?.outputValue;
    }

    /** Stores comparison values; they are not part of `serialize()` or the hash. */
    public setComparisonValues(
        comparisonValues: Record<Key, Record<Key, number>>
    ): void {
        this._comparisonValues = comparisonValues;
    }

    /**
     * Stored comparison values, restricted to the top-level entries that hold a numeric
     * value for the id of every node group. Returns `{}` when none have been set.
     */
    public get comparisonValues(): Record<Key, Record<Key, number>>|undefined {
        const validKeys = new Set(
            _.keys(this._comparisonValues).filter(key => this._validateComparisonValues(key))
        );
        return _.pickBy(this._comparisonValues, (_value, key) => validKeys.has(key));
    }

    public get possibleValues(): RiskInputOption[] {
        return this._possibleValues;
    }

    public get edges(): Edge[] {
        return this._edges;
    }

    public get nodes(): Node[] {
        return _.values(this._nodes);
    }

    public get nodeGroups(): NodeGroup[] {
        return this._nodeGroups;
    }

    public get resultNode(): ResultNode {
        return this._resultNode;
    }

    public get hash(): string {
        return this._hash;
    }

    /**
     * Propagates values through the network, starting from the nodes that have outgoing
     * but no incoming edges, and returns the result node's output.
     *
     * @param outputValueFunction - Computes a standard node's output from its inputs;
     *   defaults to `value * probability + (1 - probability) * Σ(input * edgeProbability)`.
     * @param resultValueFunction - Computes the result node's output; defaults to
     *   `Σ(input * edgeProbability)`.
     * @returns The result node's output, or `undefined` if it was never reached.
     */
    public execute(
        outputValueFunction?: OutputValueFunction,
        resultValueFunction?: ResultValueFunction
    ): number|undefined {
        this._reInitializeExecutionState();
        let currentNodes = this.nodes
            .filter(node => !this.hasIncomingEdges(node) && this.hasOutgoingEdges(node));

        // Each pass evaluates one wave; stops once no node has all of its inputs.
        while (currentNodes.length > 0) {
            currentNodes = this._calculateOutputsForNodes(currentNodes, outputValueFunction, resultValueFunction);
        }
        return this.getResult();
    }

    /**
     * Overwrites the `value` of each node named in `nodeValues` and refreshes the hash.
     * Assumes every key is the id of an existing node.
     */
    public updateNodeValues(
        nodeValues: Record<string, number>
    ): void {
        _.keys(nodeValues).forEach(key => {
            this._nodes[key].value = nodeValues[key];
        });
        this._updateHash();
    }

    /** Whether any edge targets `node` (given as a node or a node id). */
    public hasIncomingEdges(
        node: Node|ResultNode|string
    ): boolean {
        if (_.isString(node)) {
            node = this.getNode(node);
        }
        return this._getIncomingEdges(node).length > 0;
    }

    /** Whether any edge starts at `node` (given as a node or a node id). */
    public hasOutgoingEdges(
        node: Node|string
    ): boolean {
        if (_.isString(node)) {
            node = this.getNode(node);
        }
        return this._getOutgoingEdges(node).length > 0;
    }

    /** Looks up a registered (non-result) node by id. */
    public getNode(
        nodeId: string
    ): Node {
        return this._nodes[nodeId];
    }

    /** True when `_comparisonValues[key][groupId]` is a number for every node group. */
    private _validateComparisonValues(
        key: Key
    ): boolean {
        return this._nodeGroups.map(nodeGroup => nodeGroup.id)
            .every(id => _.isNumber(objectPath.get(this._comparisonValues!, [key, id])));
    }

    /**
     * Evaluates one wave of nodes and returns the next wave: the targets of the evaluated
     * nodes that have now received all of their expected inputs.
     */
    private _calculateOutputsForNodes(
        currentNodes: Node[],
        outputValueFunction?: OutputValueFunction,
        resultValueFunction?: ResultValueFunction
    ): Node[] {
        const nextNodes: Node[] = [];
        currentNodes.forEach(node => {
            if (this._isResultNode(node)) {
                this._calculateOutputsForResultNode(node, resultValueFunction);
            } else {
                this._calculateOutputsForStandardNode(node, nextNodes, outputValueFunction);
            }
        });
        return nextNodes.filter(node => this._nodeHasAllInputValues(node));
    }

    /**
     * Computes `node`'s output, forwards it along every outgoing edge and records each
     * edge target (once) in `nextNodes`.
     */
    private _calculateOutputsForStandardNode(
        node: Node,
        nextNodes: (Node|ResultNode)[],
        outputValueFunction?: OutputValueFunction
    ) {
        const inputValues = this._getIncomingValues(node);
        const outputValue = _.isFunction(outputValueFunction)
            ? outputValueFunction(inputValues, node, this._executionState)
            : this._defaultOutputValueFunction(inputValues, node);
        this._getOutgoingEdges(node).forEach(edge => {
            this._addOutputAsIncomingValueForTargetNode(edge, outputValue);
            if (!nextNodes.includes(edge.targetNode)) {
                nextNodes.push(edge.targetNode);
            }
        });
    }

    /**
     * Blends the node's own value with its weighted inputs:
     * `value * probability + (1 - probability) * Σ(input * edgeProbability)`.
     */
    private _defaultOutputValueFunction(
        inputValues: InputValue[],
        node: Node
    ): number {
        const inverseProbability = (1 - node.probability);
        const linearCombinationOfInputs = _.sumBy(inputValues, d => d.value * d.edgeProbability);
        return node.value * node.probability + inverseProbability * linearCombinationOfInputs
    }

    /** Stores the result node's output value in the execution state. */
    private _calculateOutputsForResultNode(
        node: Node,
        resultValueFunction?: ResultValueFunction
    ) {
        const inputValues = this._getIncomingValues(node);
        this._executionState[node.id].outputValue = _.isFunction(resultValueFunction)
            ? resultValueFunction(inputValues, this._executionState)
            : this._defaultResultValueFunction(inputValues);
    }

    /** Weighted sum of the inputs: `Σ(input * edgeProbability)`. */
    private _defaultResultValueFunction(
        inputValues: InputValue[]
    ): number {
        return _.sumBy(inputValues, d => d.value * d.edgeProbability);
    }

    /** Appends `outputValue`, weighted by the edge, to the edge target's incoming values. */
    private _addOutputAsIncomingValueForTargetNode(
        edge: Edge,
        outputValue: number
    ) {
        const { targetId, targetNode } = edge;
        if (_.isUndefined(this._executionState[targetId])) {
            // Normally pre-populated by `_reInitializeExecutionState`. The expected input
            // count must come from the target's incoming edges: its incoming values cannot
            // be read here because its execution-state entry does not exist yet.
            this._executionState[targetId] = {
                expectedIncomingValues: this._getIncomingEdges(targetNode).length,
                incomingValues: []
            };
        }
        this._executionState[targetId].incomingValues.push({
            edgeProbability: edge.weight,
            value: outputValue
        });
    }

    /** True once `node` has received as many incoming values as it has incoming edges. */
    private _nodeHasAllInputValues(node: Node) {
        const {
            incomingValues,
            expectedIncomingValues
        } = this._executionState[node.id];
        return expectedIncomingValues === incomingValues.length;
    }

    /** Resets the execution state to an empty entry for every node and the result node. */
    private _reInitializeExecutionState(): void {
        this._executionState = {};
        this.nodes.forEach(node => this._initializeNodeInExecutionState(node));
        this._initializeNodeInExecutionState(this.resultNode);
    }

    private _initializeNodeInExecutionState(node: Node|ResultNode) {
        this._executionState[node.id] = {
            expectedIncomingValues: this._getIncomingEdges(node).length,
            incomingValues: [],
        };
    }

    private _getIncomingEdges(
        node: Node|ResultNode
    ): Edge[] {
        return this.edges.filter(edge => edge.targetId === node.id);
    }

    private _getOutgoingEdges(
        node: Node
    ): Edge[] {
        return this.edges.filter(edge => edge.sourceId === node.id);
    }

    /** Values received so far by `node` in the current execution. */
    private _getIncomingValues(
        node: Node|ResultNode
    ): InputValue[] {
        return this._executionState[node.id].incomingValues
    }

    private _isResultNode(
        node: Node|ResultNode
    ): boolean {
        return node.id === RESULT_NODE_ID;
    }

    /** Recomputes the content hash from the serialized network. */
    private _updateHash(): void {
        this._hash = hash(this.serialize());
    }

}

