module ComputePipeline

using Preferences

# Switch for the additional checks wrapped in `@ifelse_enabled`. It is read from
# the package preferences when the module is loaded, which is why changing it
# only takes effect after restarting Julia.
const ENABLE_COMPUTE_CHECKS = @load_preference("ENABLE_COMPUTE_CHECKS", false)

# Store the requested debug mode in the package preferences. Nothing is written
# if the requested mode equals the one this session was loaded with.
function set_debug!(value::Bool)
    value == ENABLE_COMPUTE_CHECKS && return
    @set_preferences!("ENABLE_COMPUTE_CHECKS" => value)
    @info "Changing the debug mode requires restarting Julia to take effect!"
    return
end

# Convenience wrappers around `set_debug!`
function enable_debugging!()
    return set_debug!(true)
end

function disable_debugging!()
    return set_debug!(false)
end

using Observables

using Base: RefValue

# Unwrap a `RefValue` to the value it holds; anything else is passed through
# unchanged.
function deref(r::RefValue)
    return r[]
end

function deref(x)
    return x
end

# Common supertype of everything that can be the `parent` of a `Computed`
# (`ComputeEdge` and `Input` both subtype it).
abstract type AbstractEdge end

"""
    struct Computed

A `Computed` represents a node in the `ComputeGraph`. It should not be created
directly but be generated by [`add_input!`](@ref) or [`register_computation!`](@ref).

A `Computed` can be accessed from the graph by `graph[:name]` where `:name` is
the name the node was created with. Its value is returned by `graph[:name][]`
which also runs any pending updates.
"""
mutable struct Computed
    # name the node is registered under in the graph
    name::Symbol
    # if a parent edge got resolved and updated this computed, dirty is temporarily true
    # so that the edge's dependents can update their inputs accordingly
    dirty::Bool

    # Ref holding the node's data. May be left undefined by the constructors
    # below; it is then assigned when the parent is resolved for the first time.
    value::RefValue
    # edge (or `Input`) producing this node. May be left undefined, see `hasparent`.
    parent::AbstractEdge
    parent_idx::Int # index of parent.outputs this value refers to

    # Node without value and without parent (both fields stay undefined)
    Computed(name) = new(name, false)
    # Node with a value but no parent. The value is checked so that graph
    # objects are not stored as node data.
    function Computed(name, value::RefValue)
        validate_node_value(value)
        return new(name, false, value)
    end
    # Fully initialized node
    function Computed(name, value::RefValue, parent::AbstractEdge, idx::Integer)
        validate_node_value(value)
        return new(name, false, value, parent, idx)
    end
    # Node with a parent but no value yet. `value` has to stay undefined, so
    # the trailing fields are assigned after construction.
    function Computed(name, edge::AbstractEdge, idx::Integer)
        p = new(name, false)
        p.parent = edge
        p.parent_idx = idx
        return p
    end
end

# `parent` may be left undefined by the `Computed` constructors, so it has to be
# probed with `isdefined` rather than accessed directly.
function hasparent(computed::Computed)
    return isdefined(computed, :parent)
end

# Parent edge of `computed`, or `nothing` if none has been attached.
function getparent(computed::Computed)
    hasparent(computed) || return nothing
    return computed.parent
end

# Wraps an exception thrown while resolving a node.
struct ResolveException{E <: Exception} <: Exception
    # node on which `resolve!` was called
    start::Computed
    # the exception that was caught
    error::E
end

# Fully typed companion of a `ComputeEdge`. It is built on the first resolve of
# the edge, when the concrete types of the callback, inputs and outputs are known.
struct TypedEdge{InputTuple, OutputTuple, F}
    callback::F
    # NamedTuple mapping each input name to the `RefValue` of the input node
    inputs::InputTuple
    # same vector as `inputs_dirty` of the originating `ComputeEdge`
    inputs_dirty::Vector{Bool}
    # tuple with one typed `RefValue` per output
    outputs::OutputTuple
    # the output nodes of the originating `ComputeEdge`
    output_nodes::Vector{Computed}
end

"""
    struct ComputeEdge

A `ComputeEdge` represents the computation connecting one set of nodes to another.
They are considered internal and should not be interacted with or created directly.
"""
struct ComputeEdge{T} <: AbstractEdge
    # graph this edge belongs to
    graph::T
    # user callback, called as `callback(inputs, changed, cached)`
    callback::Function

    inputs::Vector{Computed}
    # one flag per input: has it changed since the callback last ran?
    inputs_dirty::Vector{Bool}

    outputs::Vector{Computed}
    # false while the edge is dirty, i.e. has to be (re-)resolved
    got_resolved::RefValue{Bool}

    # edges, that rely on outputs from this edge
    # Mainly needed for mark_dirty!(edge) to propagate to all dependents
    dependents::Vector{ComputeEdge{T}}
    # unassigned until the edge is resolved for the first time
    typed_edge::RefValue{TypedEdge}
end


# Build an edge with exactly one input and one output. The edge starts out
# unresolved with its input flagged as changed.
function ComputeEdge(f, graph::T, input::Computed, output::Computed) where {T}
    inputs = [input]
    inputs_dirty = [true]
    outputs = [output]
    return ComputeEdge{ComputeGraph}(
        graph, f, inputs, inputs_dirty, outputs, RefValue(false),
        ComputeEdge[], RefValue{TypedEdge}()
    )
end

# Turn the first `length(Names)` entries of `dirty` into a NamedTuple of `Bool`s
# that uses the same names as the given NamedTuple.
function _get_named_change(::NamedTuple{Names}, dirty) where {Names}
    N = length(Names)
    flags = ntuple(N) do i
        dirty[i]
    end
    return NamedTuple{Names, NTuple{N, Bool}}(flags)
end

# Collect the input Refs of `edge` into a NamedTuple keyed by node name and hand
# over to the typed constructor.
function TypedEdge(edge::ComputeEdge)
    nodes = edge.inputs
    N = length(nodes)
    input_names = ntuple(N) do i
        nodes[i].name
    end
    input_refs = ntuple(N) do i
        nodes[i].value
    end
    named_inputs = NamedTuple{input_names}(input_refs)
    # Going through `invokelatest` is a function barrier: the types of
    # `callback` and `named_inputs` are concrete inside the called constructor,
    # so the rest of the construction can be type stable.
    return Base.invokelatest(TypedEdge, edge, edge.callback, named_inputs)
end

# Run the callback `f` for the first time (with `cached = nothing`) and use the
# result to create the typed output Refs of `edge`.
#
# `f` must return either a `Tuple` with one element per output or `nothing`;
# anything else is an error. Each output node gets a freshly created, concretely
# typed `RefValue` (or the returned `RefValue` itself, if the callback returned one).
function TypedEdge(edge::ComputeEdge, f, inputs)
    dirty = _get_named_change(inputs, edge.inputs_dirty)

    # The callback receives the dereferenced input values
    result = f(map(getindex, inputs), dirty, nothing)

    if result isa Tuple

        # Graph objects (nodes, inputs, edges, graphs) must not be stored as node data
        if !all(is_node_value_valid, result)
            invalid_results = [output.name => value for (output, value) in zip(edge.outputs, result) if !is_node_value_valid(value)]
            strings = map(kv -> "$(kv[1]) = ::$(typeof(kv[2]))", invalid_results)
            str = join(strings, ", ")
            error("Edge callback returned invalid types for outputs: [$str]")
        end

        if length(result) != length(edge.outputs)
            # Point at the callback's definition to make the error easier to locate
            m = first(methods(edge.callback))
            line = string(m.file, ":", m.line)
            error("Result needs to have same length. Found: $(result), for func $(line)")
        end

        outputs = ntuple(length(result)) do i
            v = result[i] isa RefValue ? result[i] : RefValue(result[i])
            edge.outputs[i].value = v # initialize to fully typed RefValue
            return v
        end
        foreach(node -> node.dirty = true, edge.outputs)

    elseif isnothing(result)

        # No result: every output is initialized to a Ref holding `nothing`
        # and is not flagged as changed
        outputs = ntuple(length(edge.outputs)) do i
            v = RefValue(nothing)
            edge.outputs[i].value = v # initialize to fully typed RefValue
            return v
        end
        foreach(node -> node.dirty = false, edge.outputs)

    else
        error("Wrong type as result $(typeof(result)). Needs to be Tuple with one element per output or nothing. Value: $result")
    end
    # `inputs_dirty` and the output nodes are shared with `edge`, not copied
    return TypedEdge(f, inputs, edge.inputs_dirty, outputs, edge.outputs)
end


"""
    struct Input

An `Input` represents an entry point to the `ComputeGraph`. Like [`Computed`](@ref)
it should not be created directly, but through [`add_input!`](@ref). It should
be updated by `update!(graph, input_name = new_value)` to correctly update the
state of the compute graph.
"""
mutable struct Input{T} <: AbstractEdge
    # graph this input belongs to
    graph::T
    name::Symbol
    # raw value as it was passed in, before `f` is applied
    value::Any
    # applied to `value` when the input is resolved
    f::Function
    # node holding the result of `f(value)`
    output::Computed
    # true if `value` has changed since the input was last resolved
    dirty::Bool
    # edges that use `output` as one of their inputs
    dependents::Vector{ComputeEdge{T}}
end

# Guard rails: an `Input` must never have an Observable or a graph node
# assigned to any of its fields.
function Base.setproperty!(::Input, ::Symbol, ::Observable)
    return error("Setting the value of an ::Input to an Observable is not allowed")
end

function Base.setproperty!(::Input, ::Symbol, ::Computed)
    return error("Setting the value of an ::Input to a Computed is not allowed")
end

# Create an input that starts out dirty and without dependents. The initial
# `value` is checked so that graph objects are not stored as data.
function Input(graph, name, value, f, output)
    validate_node_value(value)
    dependents = ComputeEdge[]
    return Input{ComputeGraph}(graph, name, value, f, output, true, dependents)
end


"""
    ComputeGraph()

Creates a new empty `ComputeGraph`.

Inputs can be added to the graph using [`add_input!`](@ref). Computations and the
connected outputs can be added with [`register_computation!`](@ref).

To update an input, `update!(graph, input_name = new_value)` can be used. To get
an up-to-date value from an output use `graph[:output_name][]`.

## Example:

```
graph = ComputeGraph()

add_input!(graph, :first_node, 1)
register_computation!(graph, [:first_node], [:derived_node]) do inputs, changed, cached
    return (2 * inputs.first_node, )
end

update!(graph, first_node = 2)
graph[:derived_node][]
```
"""
struct ComputeGraph
    # settable entry points of the graph, by name
    inputs::Dict{Symbol, Input}
    # every node of the graph by name (the output nodes of inputs are included)
    outputs::Dict{Symbol, Computed}
    lock::ReentrantLock

    # holds the names of nodes that got marked dirty since the last notification
    onchange::Observable{Set{Symbol}}
    # Observables mirroring node values, created lazily by `get_observable!`
    observables::Dict{Symbol, Observable}
    # names of mirrored nodes whose values are deepcopied and compared before notifying
    should_deepcopy::Set{Symbol}
    # listeners created when an Observable is attached as an input
    observerfunctions::Vector{Observables.ObserverFunction}
    # Observables queued for notification by `update_observables!`
    obs_to_update::Vector{Observable}
end

# Throws if `x` (or the content of an assigned Ref `x`) is a graph object.
# Anything else is accepted.
function validate_node_value(x)
    return nothing
end

function validate_node_value(x::RefValue)
    isassigned(x) || return nothing
    return validate_node_value(x[])
end

# shouldn't have those in input.value or computed.value[]
function validate_node_value(::Union{T, RefValue{T}}) where {T <: Union{Computed, Input, ComputeGraph, ComputeEdge}}
    return error("The value of a compute node is not allowed to be of type ::$T.")
end

# Non-throwing counterpart of `validate_node_value`: `false` if `x` (or the
# content of an assigned Ref `x`) is a graph object, `true` otherwise.
function is_node_value_valid(x)
    return true
end

function is_node_value_valid(x::RefValue)
    isassigned(x) || return true
    return is_node_value_valid(x[])
end

# shouldn't have those in input.value or computed.value[]
function is_node_value_valid(::Union{T, RefValue{T}}) where {T <: Union{Computed, Input, ComputeGraph, ComputeEdge}}
    return false
end

"""
    get_observable!(graph::ComputeGraph, key::Symbol[; use_deepcopy = true])

Returns an observable which contains the up to date value of the graph node with
name `key`. The observable is created on the first call and reused afterwards.

Note that in order for the Observable to be up to date, the respective node needs
to resolve immediately after being marked dirty. This may cause the graph to
update more frequently.

Within the compute graph, data can be updated in-place by using cached values.
In that case the old data is no longer available, so the compute graph cannot
know whether the data has changed. Such updates are therefore always propagated
and the observable is always triggered. If the observable is used to
(eventually) update an input of the node it represents, this could lead to
infinite loops of equal updates. To prevent these, the data in the Observable is
a `deepcopy` of the data in the compute node, and the two are compared before
the observable is updated.

Setting `use_deepcopy = false` turns this safeguard off, removing both the
deepcopy and the equality check. In this case in-place updates of the node will
always trigger observable updates. Any other duplicate updates will be skipped.
"""
function get_observable!(attr::ComputeGraph, key::Symbol; use_deepcopy = true)
    # Output arrays may be reused, which can make it impossible to tell whether
    # data has changed. Such data is treated as changed and ends up in
    # `onchange`. If it is then fed into an Observable that updates the graph,
    # we could loop forever. To break the loop the Observable keeps its own
    # copy which is compared with `==` on update. `deepcopy` is used because
    # `copy` is not available for every type (e.g. not for Rect).
    return get!(attr.observables, key) do
        use_deepcopy && push!(attr.should_deepcopy, key)
        node = attr.outputs[key]
        # resolve first so that `eltype(node.value)` below can work
        current = node[]
        initial_value = use_deepcopy ? deepcopy(current) : current
        # No `ignore_equal_values = true` here: the graph already discards
        # equal values whenever data is not updated in-place.
        return Observable{eltype(node.value)}(initial_value)
    end
end

# Look up (or create) the Observable for `c` in the graph of its parent edge.
# Errors if `c` has no parent.
function get_observable!(c::Computed; use_deepcopy = true)
    hasparent(c) || error("Cannot get observable for Computed without parent")
    graph = getparent(c).graph
    return get_observable!(graph, c.name; use_deepcopy = use_deepcopy)
end

# Listen to a graph node by listening to the Observable that mirrors it.
function Observables.on(f, x::Computed; kwargs...)
    return on(f, get_observable!(x); kwargs...)
end

# `onany` for a mix of graph nodes and Observables: nodes are replaced by their
# mirroring Observable, Observables are passed through.
function Observables.onany(f, arg1::Computed, args::Union{Observable, Computed}...; kwargs...)
    to_observable(x) = x isa Computed ? get_observable!(x) : x
    obsies = map(to_observable, (arg1, args...))
    @assert all(obs -> obs isa Observable, obsies) "Failed to create Observables for all entries"
    return onany(f, obsies...; kwargs...)
end
# `map!` into an Observable `target` with graph nodes as sources. Each node is
# replaced by the Observable that mirrors it.
#
# At least one `Computed` is required in the signature. A bare
# `args::Computed...` would also match a call without any sources, i.e.
# `map!(f, target)`, which would then call itself without end instead of
# reaching the method provided by Observables.
function Observables.map!(f, target::Observable, arg1::Computed, args::Computed...; kwargs...)
    obsies = map(get_observable!, (arg1, args...))
    return map!(f, target, obsies...; kwargs...)
end
# `map` with a graph node as the first source. Every `Computed` among the
# sources is replaced by its mirroring Observable, other arguments pass through.
function Observables.map(f, arg1::Computed, args...; kwargs...)
    to_observable(x) = x isa Computed ? get_observable!(x) : x
    obsies = map(to_observable, (arg1, args...))
    return map(f, obsies...; kwargs...)
end


# ComputeEdge(f) = ComputeEdge(f, Computed[])
# Build an unresolved edge without outputs. All inputs start out flagged as
# changed.
function ComputeEdge(f, graph::ComputeGraph, inputs::Vector{Computed})
    inputs_dirty = fill(true, length(inputs))
    outputs = Computed[]
    return ComputeEdge{ComputeGraph}(
        graph, f, inputs, inputs_dirty, outputs, RefValue(false),
        ComputeEdge[], RefValue{TypedEdge}()
    )
end

# Create an empty graph and install the listener that keeps the Observables
# handed out by `get_observable!` in sync with their nodes.
function ComputeGraph()
    graph = ComputeGraph(
        # matches the declared field type `Dict{Symbol, Input}`
        Dict{Symbol, Input}(), Dict{Symbol, Computed}(), Base.ReentrantLock(),
        Observable(Set{Symbol}()), Dict{Symbol, Observable}(), Set{Symbol}(),
        Observables.ObserverFunction[], Observable[]
    )

    on(graph.onchange) do changeset
        # only nodes that are mirrored by an Observable are of interest
        intersect!(changeset, keys(graph.observables))

        # update data
        # Iterate over a snapshot, because entries may be deleted from
        # `changeset` inside the loop.
        for key in collect(changeset)
            val = graph.outputs[key][]
            obs = graph.observables[key]
            # Trust the graph to discard equal values. This doesn't work for
            # anything updated in-place
            if !(key in graph.should_deepcopy)
                obs.val = val
            elseif val != obs[] # treat in-place updates
                obs.val = deepcopy(val)
            else # same value (with deepcopy), skip update
                delete!(changeset, key)
            end
        end

        # trigger observables
        # (done after all values are written, so listeners see a consistent state)
        for key in changeset
            notify(graph.observables[key])
        end

        # clear changeset after processing observables
        empty!(changeset)
        return Consume(false)
    end

    return graph
end

# Edge callback that forwards its first input as the only output.
function _first_arg(args, changed, last)
    return (args[1],)
end

"""
    alias!(graph::ComputeGraph, input::Symbol, output::Symbol)

Registers the node `output` as an alias of the existing node `input`, i.e. a
node that forwards the value of `input`. Returns the graph.
"""
function alias!(attr::ComputeGraph, key::Symbol, alias_key::Symbol)
    # TODO: more efficient implementation!
    sources = [key]
    targets = [alias_key]
    register_computation!(_first_arg, attr, sources, targets)
    return attr
end

# A node is dirty if its parent is dirty; a node without parent never is.
function isdirty(computed::Computed)
    hasparent(computed) || return false
    return isdirty(computed.parent)
end

# An edge is dirty until it got resolved.
function isdirty(edge::ComputeEdge)
    return !edge.got_resolved[]
end

# Note:
# GLMakie may mark an unresolved renderobject as resolved to avoid repeated
# errors from repeatedly pulling it. This requires us to not shortcut mark_dirty!()
# Without that, we should be able to skip mark_dirty for any child/dependent that
# is already dirty

"""
    mark_resolved!(computed)

Mark the parent edge of a compute node as resolved, so that the node will no
longer try to update. This will be undone the next time any (recursive) input
to the node is updated.
"""
function mark_resolved!(computed::Computed)
    hasparent(computed) && mark_resolved!(computed.parent)
    return
end
mark_resolved!(edge::ComputeEdge) = edge.got_resolved[] = true
# An `Input` tracks its state in the `dirty` field (there is no `is_dirty`
# field), and "resolved" means "not dirty".
mark_resolved!(edge::Input) = edge.dirty = false

# Flag `edge` and, recursively, all edges depending on it as unresolved. The
# names of the edge's outputs are added to the changeset of the edge's graph and
# that graph's `onchange` is queued (once) in `obs_to_update`.
function mark_dirty!(edge::ComputeEdge, obs_to_update::Vector{Observable})
    # Assumes this is the same graph as edge.outputs (for parent -> child graph edges)
    onchange = edge.graph.onchange
    for output in edge.outputs
        push!(onchange.val, output.name)
        if !(onchange in obs_to_update)
            push!(obs_to_update, onchange)
        end
    end

    edge.got_resolved[] = false
    foreach(dep -> mark_dirty!(dep, obs_to_update), edge.dependents)
    return
end

# Flag a node as dirty and pass the request on to its parent, if it has one.
function mark_dirty!(computed::Computed)
    computed.dirty = true
    if hasparent(computed)
        return mark_dirty!(computed.parent)
    end
    return
end

# Apply the input's function to its raw value and store the result in the output
# node. Does nothing (and returns `nothing`) if the input is not dirty,
# otherwise returns the new output value.
function resolve!(input::Input)
    input.dirty || return
    converted = input.f(input.value)
    node = input.output
    if isdefined(node, :value) && isassigned(node.value)
        # reuse the existing Ref
        node.value[] = deref(converted)
    else
        # first resolve: create the Ref (or adopt the returned one)
        node.value = converted isa RefValue ? converted : RefValue(converted)
    end
    input.dirty = false
    # The node is only flagged as changed while the dependents pick up the
    # change into their `inputs_dirty`.
    node.dirty = true
    foreach(edge -> mark_input_dirty!(input, edge), input.dependents)
    node.dirty = false
    return node.value[]
end

# Flag `input` as dirty, record its name in the graph's changeset, queue the
# graph's `onchange` (once) in `obs_to_update` and invalidate all dependent edges.
function mark_dirty!(input::Input, obs_to_update::Vector{Observable})
    onchange = input.graph.onchange
    push!(onchange.val, input.name)
    onchange in obs_to_update || push!(obs_to_update, onchange)

    input.dirty = true
    foreach(edge -> mark_dirty!(edge, obs_to_update), input.dependents)
    return
end

# Entry point: mark `x` dirty and collect the Observables that need to be
# notified in the queue of `x`'s own graph.
function mark_dirty!(x)
    queue = x.graph.obs_to_update
    return mark_dirty!(x, queue)
end

# Notify every queued Observable and clear the queue. The methods for nodes,
# inputs, edges and graphs all forward to the queue of the graph they belong to.
function update_observables!(comp::Computed)
    return update_observables!(comp.parent)
end

function update_observables!(edge::Input)
    return update_observables!(edge.graph)
end

function update_observables!(edge::ComputeEdge)
    return update_observables!(edge.graph)
end

function update_observables!(graph::ComputeGraph)
    return update_observables!(graph.obs_to_update)
end

function update_observables!(obs_to_update::Vector{Observable})
    for obs in obs_to_update
        notify(obs)
    end
    empty!(obs_to_update)
    return
end

# `computed[] = value`. For the output node of an `Input` this updates the
# input. Any other node has its value overwritten directly, after which it is
# marked dirty and the queued Observables are notified.
function Base.setindex!(computed::Computed, value)
    computed.parent isa Input && return setindex!(computed.parent, value)
    computed.value[] = value
    mark_dirty!(computed)
    update_observables!(computed)
    return value
end

# `input[] = value`. Stores the new raw value, marks the input dirty and
# notifies the queued Observables - unless the value is the same as before.
function Base.setindex!(input::Input, value)
    if !is_same(input.value, value)
        input.value = value
        mark_dirty!(input)
        update_observables!(input)
    end
    return value
end

# Store a new raw value in the input `key` and mark it dirty, without notifying
# any Observables. Does nothing if the value is the same as before.
function _setproperty!(attr::ComputeGraph, key::Symbol, value)
    input = attr.inputs[key]
    if is_same(input.value, value)
        return value
    end
    # Observables can't be notified right here: `update!` may call this
    # several times for one synchronized update, and notifying in between
    # would cause a desync.
    mark_dirty!(input)
    input.value = value
    return value
end

# `graph.key = value`: update a single input and notify the affected
# Observables while holding the graph lock. Returns `value`.
function Base.setproperty!(attr::ComputeGraph, key::Symbol, value)
    return lock(attr.lock) do
        _setproperty!(attr, key, value)
        # Notify AND clear the queue, like every other update path does. Only
        # notifying would leave the processed Observables queued, so they
        # would be notified again by the next update.
        update_observables!(attr)
        return value
    end
end

"""
    update!(graph; kwargs...)
    update!(graph, pairs::Pair{Symbol, Any}...)

Updates any number of inputs of the graph in one go. Each `key = value` keyword
argument (or `key => value` pair) names an input and the new value it should
take. A `Dict` or a vector of such pairs is accepted as well.

## Example:

```
graph = ComputeGraph()
add_input!(graph, :first_node, 1)
update!(graph, first_node = 2)
update!(graph, :first_node => 2)
```
"""
function update!(attr::ComputeGraph; kwargs...)
    return update!(attr, [Pair{Symbol, Any}(k, v) for (k, v) in kwargs])
end

function update!(attr::ComputeGraph, dict::Dict{Symbol})
    return _update!(attr, dict)
end

function update!(attr::ComputeGraph, pairs::Pair{Symbol}...)
    return _update!(attr, [Pair{Symbol, Any}(k, v) for (k, v) in pairs])
end

function update!(attr::ComputeGraph, pairs::AbstractVector{<:Pair{Symbol}})
    return _update!(attr, pairs)
end

# Apply all `key => value` updates under the graph lock and notify the queued
# Observables once at the end. Errors on the first key that is not an input.
function _update!(attr::ComputeGraph, values)
    return lock(attr.lock) do
        for (key, value) in values
            haskey(attr.inputs, key) || error("Attribute $key not found in ComputeGraph")
            _setproperty!(attr, key, value)
        end
        update_observables!(attr)
        return attr
    end
end

# Dictionary-like access to the nodes of the graph
function Base.haskey(attr::ComputeGraph, key::Symbol)
    return haskey(attr.outputs, key)
end

function Base.get(attr::ComputeGraph, key::Symbol, default)
    return get(attr.outputs, key, default)
end

# `graph.name` returns the field `name` if it is one of the struct's fields,
# otherwise the node registered under that name.
function Base.getproperty(attr::ComputeGraph, key::Symbol)
    # more efficient to hardcode?
    is_field = key in (
        :inputs, :outputs, :onchange, :observables,
        :observerfunctions, :obs_to_update, :lock, :should_deepcopy,
    )
    is_field && return getfield(attr, key)
    return getfield(attr, :outputs)[key]
end

# `graph[:name]` returns the node, not its value
Base.getindex(attr::ComputeGraph, key::Symbol) = attr.outputs[key]

function isdirty(input::Input)
    return input.dirty
end

# `node[]` resolves the node and returns its value
function Base.getindex(computed::Computed)
    return resolve!(computed)
end

# Called by a freshly resolved `parent` on each of its dependents: every input
# node of `edge` that is currently flagged dirty gets its `inputs_dirty` entry
# set. Entries that are already set stay set.
function mark_input_dirty!(parent::ComputeEdge, edge::ComputeEdge)
    @assert parent.got_resolved[] # parent should only call this after resolve!
    for (i, node) in enumerate(edge.inputs)
        edge.inputs_dirty[i] |= getfield(node, :dirty)
    end
    return
end

function mark_input_dirty!(parent::Input, edge::ComputeEdge)
    @assert !parent.dirty # should got resolved
    for (i, node) in enumerate(edge.inputs)
        edge.inputs_dirty[i] |= getfield(node, :dirty)
    end
    return
end

# Store `value` as the `i`-th output of `edge`, then recurse into `result`,
# which holds the values for the remaining outputs.
#
# An output is left untouched and flagged as unchanged if `value` is `nothing`
# or `is_same` as the stored value. Otherwise the value is written and the
# output node is flagged dirty.
function set_result!(edge::TypedEdge, result, i, value)
    if isnothing(value) || is_same(edge.outputs[i][], value)
        edge.output_nodes[i].dirty = false
    else
        edge.output_nodes[i].dirty = true
        edge.outputs[i][] = deref(value)
    end
    # Peel off the next element of the tuple and continue with the rest
    if !isempty(result)
        next_val = first(result)
        rem = Base.tail(result)
        set_result!(edge, rem, i + 1, next_val)
    end
    return
end

# Entry point: distribute the callback's result tuple to the outputs of `edge`,
# starting with the first output.
function set_result!(edge::TypedEdge, result)
    next_val = first(result)
    rem = Base.tail(result)
    return set_result!(edge, rem, 1, next_val)
end

# Decide whether replacing `a` by `b` can be treated as "nothing changed".
# Values of different types never count as the same.
is_same(@nospecialize(a), @nospecialize(b)) = false

function is_same(a::Symbol, b::Symbol)
    return a == b
end

function is_same(a::T, b::T) where {T}
    # Immutable isbits values can be compared by value with `===`
    isbitstype(T) && return a === b
    # For the very same object we have to give up: it may have been mutated
    # in-between, and there is no old state left to compare against.
    a === b && return false
    # Distinct objects can be compared by equivalence
    return a == b
end

# do we want this type stable?
# This is how we could get a type stable callback body for resolve
#
# Re-run the callback of an already initialized edge, but only if at least one
# input changed. The callback gets the dereferenced inputs, the per-input change
# flags and the previous outputs as a NamedTuple keyed by output name.
function resolve!(edge::TypedEdge)
    any(edge.inputs_dirty) || return
    changed = _get_named_change(edge.inputs, edge.inputs_dirty)
    previous = map(getindex, edge.outputs)
    output_names = ntuple(i -> edge.output_nodes[i].name, length(previous))
    cached = NamedTuple{output_names}(previous)
    result = edge.callback(map(getindex, edge.inputs), changed, cached)
    if isnothing(result)
        # nothing returned: no output has changed
        foreach(x -> x.dirty = false, edge.output_nodes)
    elseif result isa Tuple
        if length(result) != length(edge.outputs)
            error("Did not return correct length: $(result), $(edge.callback)")
        end
        set_result!(edge, result)
    else
        error("Needs to return a Tuple with one element per output, or nothing")
    end
    return
end

# Bring `computed` up to date and return its value. Any error raised on the way
# is wrapped in a `ResolveException` that records the node resolution started at.
function resolve!(computed::Computed)
    try
        return _resolve!(computed)
    catch err
        rethrow(ResolveException(computed, err))
    end
end

# Same as `resolve!(computed)` but without wrapping errors.
function _resolve!(computed::Computed)
    hasparent(computed) && resolve!(computed.parent)
    return computed.value[]
end

# Bring the outputs of `edge` up to date. Returns `false` if the edge was not
# dirty and `true` if it got resolved.
function resolve!(edge::ComputeEdge)
    isdirty(edge) || return false
    return lock(edge.graph.lock) do
        # Resolve inputs first
        foreach(_resolve!, edge.inputs)
        if !isassigned(edge.typed_edge)
            # constructor does first resolve to determine fully typed outputs
            edge.typed_edge[] = TypedEdge(edge)
        else
            resolve!(edge.typed_edge[])
        end
        edge.got_resolved[] = true
        # all input changes have been consumed by the callback
        fill!(edge.inputs_dirty, false)
        # The outputs still carry their `dirty` flags from the callback run at
        # this point, so the dependents can record which of their inputs changed ...
        for dep in edge.dependents
            mark_input_dirty!(edge, dep)
        end
        # ... after which the flags are reset
        foreach(comp -> comp.dirty = false, edge.outputs)
        return true
    end
end


"""
    add_input!([callback], compute_graph, name::Symbol, value)

Adds a new input called `name` to `compute_graph` and initializes it with
`value`. If a `callback` is given, every new value is passed through
`callback(name, new_value)` and the result is what gets stored in the compute
graph.

## Example:

```
graph = ComputeGraph()

add_input!(graph, :first_node, 1)
add_input!((k, v) -> Float32(v), graph, :second_node, 2)
```
"""
function add_input!(attr::ComputeGraph, key::Symbol, value)
    return _add_input!(identity, attr, key, value)
end

# Binds a user function `user_func(key, value)` to a fixed `key`.
#
# An explicit wrapper type is used instead of an anonymous function like
#   value -> conversion_function(key, value)
# or
#   (value,), changed, cached -> conversion_function(key, value)
# because it gives cleaner printing and error tracking.
struct InputFunctionWrapper{FT} <: Function
    key::Symbol
    user_func::FT
end

# Called with a single value (used as the function of an `Input`)
function (x::InputFunctionWrapper)(v)
    return x.user_func(x.key, v)
end

# Called like an edge callback: converts the first input, returns a 1-tuple
function (x::InputFunctionWrapper)(inputs, changed, cached)
    converted = x.user_func(x.key, inputs[1])
    return (converted,)
end

# Variant with a conversion function, which is called as
# `conversion_func(key, value)` for every value of the input.
function add_input!(conversion_func, attr::ComputeGraph, key::Symbol, value)
    wrapped = InputFunctionWrapper(key, conversion_func)
    return _add_input!(wrapped, attr, key, value)
end

# Create the `Input` together with its output node and register both under
# `key`. Errors if the name is already used by an input or a node.
function _add_input!(func, attr::ComputeGraph, key::Symbol, value)
    @assert !(value isa Computed)
    name_taken = haskey(attr.inputs, key) || haskey(attr.outputs, key)
    name_taken && error("Cannot attach input with name $key - already exists!")

    node = Computed(key)
    input = Input(attr, key, value, func, node)
    # the input is the parent of its output node
    node.parent = input
    node.parent_idx = 1
    # Needs to be Any, since input can change type
    attr.inputs[key] = input
    attr.outputs[key] = node
    return attr
end

# Add one input per keyword argument. All of them use `conversion_func`.
function add_inputs!(conversion_func, attr::ComputeGraph; kw...)
    for (name, value) in pairs(kw)
        add_input!(conversion_func, attr, name, value)
    end
    return attr
end

# Edge callback that forwards all inputs unchanged as outputs.
function compute_identity(inputs, changed, cached)
    return values(inputs)
end

"""
    add_input!([callback], compute_graph, name::Symbol, node::Computed)

Connects `node`, an output of another compute graph, to the given
`compute_graph` under the name `name`.

This does not create a settable input, meaning you cannot use
`update!(graph, name = new_value)` to update it. It is solely updated by the
connected node.
"""
function add_input!(attr::ComputeGraph, key::Symbol, value::Computed)
    haskey(attr.outputs, key) && error("Cannot attach throughput with name $key - already exists!")
    sources = [value]
    targets = [key]
    register_computation!(compute_identity, attr, sources, targets)
    return attr
end

# for recipe -> primitive (mostly)
# Like the method without `conversion_func`, except that the node's value is
# passed through `conversion_func(key, value)` on the way.
function add_input!(conversion_func, attr::ComputeGraph, key::Symbol, value::Computed)
    haskey(attr.outputs, key) && error("Cannot attach throughput with name $key - already exists!")
    callback = InputFunctionWrapper(key, conversion_func)
    register_computation!(callback, attr, [value], [key])
    return attr
end

"""
    add_input!([callback], compute_graph, name::Symbol, obs::Observable)

Connects an `Observable` as an input to the given `compute_graph`. The input is
initialized with the current value of the observable and is updated
automatically whenever the observable updates. It can also be updated through
the usual `update!(graph, name = new_value)` call.

The Observable listener is set up with `priority = typemax(Int)-1` and
`Consume(false)`. This means it will take precedence over all but `typemax(Int)`
priority and not block later updates.
"""
function add_input!(attr::ComputeGraph, k::Symbol, obs::Observable)
    add_input!(attr, k, obs[])
    # typemax-1 so it doesn't get disturbed by other listeners but can still be
    # blocked by a typemax obs
    # Setting it to lower priority, axis and colorbar don't get the latest update
    # Still need to investigate why exactly
    listener = on(obs; priority = typemax(Int) - 1) do new_val
        setproperty!(attr, k, new_val)
        return Consume(false)
    end
    # keep the listener, so it is stored alongside the graph
    push!(attr.observerfunctions, listener)
    return attr
end

# Observable input with a conversion function `f(name, value)`. The listener is
# set up exactly like in the method without `f`.
function add_input!(f, attr::ComputeGraph, k::Symbol, obs::Observable)
    add_input!(f, attr, k, obs[])
    listener = on(obs, priority = typemax(Int) - 1) do new_val
        setproperty!(attr, k, new_val)
        return Consume(false)
    end
    push!(attr.observerfunctions, listener)
    return attr
end

"""
    add_constant!(graph, name::Symbol, value)

Adds a constant to the Graph. A constant is not connected to an `Input` and thus
can't change through compute graph resolution.

If a node called `name` already exists, nothing is added and `nothing` is
returned. Otherwise the graph is returned.
"""
function add_constant!(attr::ComputeGraph, k::Symbol, value)
    if !haskey(attr, k)
        # a computation without inputs that always produces `value`
        map!(() -> value, attr, Symbol[], k)
        return attr
    end
    return
end

# Add one constant per keyword argument, see `add_constant!`.
function add_constants!(attr::ComputeGraph; kw...)
    for (name, value) in pairs(kw)
        add_constant!(attr, name, value)
    end
    return attr
end

# Function that produces `computed`, or `nothing` if the node has no parent.
# The lookup dispatches on the parent type, because the function is stored in a
# different field depending on the kind of parent: `callback` for a
# `ComputeEdge`, `f` for an `Input` (which has no `callback` field).
get_callback(computed::Computed) = hasparent(computed) ? _parent_callback(computed.parent) : nothing
_parent_callback(edge::AbstractEdge) = edge.callback
_parent_callback(input::Input) = input.f

"""
    register_computation!(callback, compute_graph, input_names, output_names)

Registers a new computation which transforms the given inputs into a new set of
outputs. Both the inputs and outputs are referred to by name. The inputs must
exist when the function is called. The outputs should usually be created by
this function.

The callback function must accept 3 arguments:
- `inputs::NamedTuple` which contains the current input values, keyed by input name, in the order given to `register_computation!`
- `changed::NamedTuple` which contains a `Bool` per input name signifying whether the input has been updated since the last execution of `callback`
- `cached` which is a `NamedTuple` of the output values from the previous execution, keyed by output name. On the first execution no previous outputs exist and `cached = nothing`.

The values in `inputs` and `cached` are passed to the callback already
dereferenced.

The callback must return either a `Tuple` with one element per output, or
`nothing`.

## Example:

```julia
graph = ComputeGraph()
add_input!(graph, :input1, 1)
add_input!(graph, :input2, 1)

register_computation!(graph, [:input1, :input2], [:output1, :output2]) do inputs, changed, cached
    input1, input2 = inputs
    has_input1_changed, has_input2_changed = changed
    if !isnothing(cached)
        cached_output1, cached_output2 = cached
    end

    # compute new outputs

    return (new_output1, new_output2)
end
```

See also: [`add_input!`](@ref), [`map!`](@ref)
"""
function register_computation!(f, attr::ComputeGraph, inputs::Vector{Symbol}, outputs::Vector{Symbol})
    missing_keys = filter(k -> !haskey(attr.outputs, k), inputs)
    isempty(missing_keys) || error("Could not register computation: Inputs $missing_keys not found.")
    nodes = Computed[attr.outputs[k] for k in inputs]
    register_computation!(f, attr, nodes, outputs)
    return
end

# [computed, symbol] is an Any Vector so no eltype here
# Inputs may be a mix of names and nodes. Names are looked up in the graph,
# nodes are used as they are.
function register_computation!(f, attr::ComputeGraph, inputs::Vector, outputs::Vector{Symbol})
    nodes = Computed[]
    for k in inputs
        push!(nodes, k isa Symbol ? attr.outputs[k] : k)
    end
    register_computation!(f, attr, nodes, outputs)
    return
end

# Throws if the callback `f` has captured a variable as a `Core.Box`. Returns
# `nothing` otherwise.
function check_boxed_values(f)
    # captured variables are the fields of the closure object
    boxed = [name => getfield(f, name) for name in propertynames(f) if getfield(f, name) isa Core.Box]
    isempty(boxed) && return nothing
    boxed_str = map(boxed) do (k, v)
        box = isdefined(v, :contents) ? typeof(v.contents) : "#undef"
        return "$(k)::Core.Box($(box))"
    end
    return error("Cannot register computation: Callback function cannot use boxed values: $(first(methods(f))), $(join(boxed_str, ",")). This might be caused by a variable of the same name existing inside and outside a `do ... end` block.")
end

# Expands to `enabled_expr` if `ENABLE_COMPUTE_CHECKS` is set, otherwise to
# `disabled_expr` (an empty tuple by default). The choice is made at macro
# expansion time.
macro ifelse_enabled(enabled_expr, disabled_expr = :())
    chosen = ENABLE_COMPUTE_CHECKS ? enabled_expr : disabled_expr
    return esc(chosen)
end

# Verify that registering `f: inputs -> outputs` would only repeat a computation
# that already exists: all outputs must share one parent `ComputeEdge` which has
# the same inputs and the same callback. Throws otherwise.
function assert_same_computation(@nospecialize(f), attr::ComputeGraph, inputs, outputs)
    # Check this so we can type assert later
    input_nodes = [k for k in outputs if haskey(attr.inputs, k)]
    if !isempty(input_nodes)
        error("One or multiple edge outputs already exist as graph inputs: $input_nodes")
    end

    e = attr.outputs[outputs[1]].parent::ComputeEdge

    # Check that the edge is shared
    shares_edge(k) = attr.outputs[k].parent::ComputeEdge === e
    if !all(shares_edge, outputs)
        error("Cannot register computation: $outputs already have multiple different parent compute edges.")
    end

    # Check that the requested inputs are the inputs of the new edge
    same_inputs = length(e.inputs) == length(inputs) && e.inputs == inputs
    if !same_inputs
        new_inputs = join([n.name for n in inputs], ", ")
        new_outputs = join(outputs, ", ")
        old_inputs = join([n.name for n in e.inputs], ", ")
        old_outputs = join([n.name for n in e.outputs], ", ")
        error(
            "Cannot register computation: Outputs already have a parent compute edge with different inputs.\n" *
                "   New: (" * new_inputs * ") -> (" * new_outputs * ")\n" *
                "   Old: (" * old_inputs * ") -> (" * old_outputs * ")"
        )
    end

    # Check that the same callback is used
    # TODO: === is much faster, but why?
    if @ifelse_enabled(e.callback != f, e.callback !== f)
        # We should only care about input arg types...
        func1, loc1 = edge_callback_to_string(f, (NamedTuple, NamedTuple, Nothing))
        func2, loc2 = edge_callback_to_string(e)
        error(
            "Cannot register computation: The outputs already have a parent compute edge using " *
                "a different callback function.\n  Given: $func1 $loc1\n  Found: $func2 $loc2\n  $(methods(f))"
        )
    end

    # all checks passed: the requested edge already exists
    return
end

# Connect `inputs` to the nodes named by `outputs` through a new compute edge that runs `f`.
# If all outputs already have a parent edge, the request is only checked against that
# edge and nothing is added. If only some of them do, an error is raised.
function register_computation!(f, attr::ComputeGraph, inputs::Vector{Computed}, outputs::Vector{Symbol})
    @ifelse_enabled(check_boxed_values(f))

    # outputs that already exist in the graph and are produced by some edge
    parented = [k for k in outputs if haskey(attr.outputs, k) && hasparent(attr.outputs[k])]

    if !isempty(parented)
        if length(parented) != length(outputs)
            combined = join(parented, ", ")
            error("Cannot register computation: Some outputs already have parent compute edges: $combined")
        end

        assert_same_computation(f, attr, inputs, outputs)

        # edge already exists so we can return
        return
    end
    # otherwise we won't be overwriting an edge

    new_edge = ComputeEdge(f, attr, inputs)

    for input in inputs
        @assert hasparent(input) "Computed should be guaranteed to have a parent edge, but does not"
        # Edges can have multiple outputs so multiple inputs of this edge could
        # come from the same edge; register the new edge only once per parent
        siblings = input.parent.dependents
        if !any(x -> x === new_edge, siblings)
            push!(siblings, new_edge)
        end
    end

    # use order of namedtuple, which should not change!
    for (idx, name) in enumerate(outputs)
        # create an uninitialized Ref, which gets replaced by the correctly strictly typed Ref on first resolve
        node = get!(attr.outputs, name, Computed(name))
        node.parent = new_edge
        node.parent_idx = idx
        node.dirty = true
        push!(new_edge.outputs, node)
    end

    return
end

"""
    MapFunctionWrapper(f, pack = true)

Callable adapter that gives a plain function `f` the `(inputs, changed, cached)` call
signature. The wrapper calls `f` with the values of `inputs` splatted as positional
arguments and ignores `changed` and `cached`. With `pack = true` the result of `f` is
wrapped in a 1-tuple, with `pack = false` it is returned as is.
"""
struct MapFunctionWrapper{pack, FT} <: Function
    user_func::FT
    function MapFunctionWrapper(f::FT, pack = true) where {FT}
        return new{pack, FT}(f)
    end
end

# packing variant: a single return value becomes a 1-tuple
function (wrapper::MapFunctionWrapper{true})(inputs, @nospecialize(changed), @nospecialize(cached))
    return (wrapper.user_func(values(inputs)...),)
end
# non-packing variant: whatever the user function returns is passed through
function (wrapper::MapFunctionWrapper{false})(inputs, @nospecialize(changed), @nospecialize(cached))
    return wrapper.user_func(values(inputs)...)
end

"""
    map!(f, compute_graph::ComputeGraph, inputs::Union{Symbol, Computed, Vector}, outputs::Union{Symbol, Vector{Symbol}})

Registers a new ComputeEdge in the `compute_graph` which connects one or multiple
`inputs` to one or multiple `outputs`.

Inputs can be Symbols referring to compute nodes in `compute_graph`, or compute
nodes from any graph. These can also be mixed. Outputs are always Symbols naming
the new nodes generated by this function.

The callback function `f` will be called with the values of the inputs as arguments.
If the output is a `::Symbol`, the function is expected to return a value, otherwise
it is expected to return a tuple of values to be mapped to the outputs.

```julia
graph = ComputeGraph()
add_input!(graph, :input1, 2)
add_input!(graph, :input2, 1)

other_graph = ComputeGraph()
add_input!(other_graph, :input, 3)

map!(x -> 2x, graph, :input1, :output1)
map!((x, y) -> x+y, graph, [:input1, other_graph.input], :output2)
map!((x, y) -> (x+y, x-y), graph, [:input1, :input2], [:output3, :output4])
```

See also: [`add_input!`](@ref), [`register_computation!`](@ref)
"""
function Base.map!(f, attr::ComputeGraph, input::Union{Symbol, Computed}, output::Symbol)
    # one input, one output: the wrapper packs the result of `f` into a 1-tuple
    callback = MapFunctionWrapper(f)
    register_computation!(callback, attr, [input], [output])
    return attr
end

function Base.map!(f, attr::ComputeGraph, inputs::Vector, output::Symbol)
    # several inputs, one output: the wrapper packs the result of `f` into a 1-tuple
    callback = MapFunctionWrapper(f)
    register_computation!(callback, attr, inputs, [output])
    return attr
end

function Base.map!(f, attr::ComputeGraph, inputs::Vector, outputs::Vector{Symbol})
    # several inputs, several outputs: the result of `f` is passed through unpacked
    callback = MapFunctionWrapper(f, false)
    register_computation!(callback, attr, inputs, outputs)
    return attr
end

function Base.map!(f, attr::ComputeGraph, inputs::Union{Symbol, Computed}, outputs::Vector{Symbol})
    # one input, several outputs: the result of `f` is passed through unpacked
    callback = MapFunctionWrapper(f, false)
    register_computation!(callback, attr, [inputs], outputs)
    return attr
end

# Clear every observable tracked by `attr`, switch off every tracked observer function and
# empty both collections. Returns the emptied `attr.observerfunctions`.
function Base.empty!(attr::ComputeGraph)
    # NOTE: `attr.inputs` and `attr.outputs` are not emptied here
    for (_, observable) in attr.observables
        Observables.clear(observable)
    end
    empty!(attr.observables)

    foreach(Observables.off, attr.observerfunctions)
    return empty!(attr.observerfunctions)
end

"""
    delete!(graph::ComputeGraph, key::Symbol[; force = false, recursive = false])

Deletes a node from the given graph based on its name.

If `recursive = true` all child nodes of the selected node are deleted. If
`force = true` all siblings (outputs from the same parent edge) are deleted.
If either exists without the respective option being true an error will be thrown.
"""
function Base.delete!(attr::ComputeGraph, key::Symbol; force::Bool = false, recursive::Bool = false)
    # lookup and deletion both happen while holding the graph lock
    return lock(attr.lock) do
        if !haskey(attr.outputs, key)
            throw(KeyError(key))
        end
        node = attr.outputs[key]
        _delete!(attr, node, force, recursive)
        return attr
    end
end

# Delete `node` by deleting the edge that produces it. Returns `attr`.
function _delete!(attr::ComputeGraph, node::Computed, force::Bool, recursive::Bool)
    @assert hasparent(node)
    producer = node.parent
    _delete!(attr, producer, force, recursive)
    return attr
end

# Check that `edge` and, recursively, its dependents may be deleted with the given options.
# Raises an error if siblings exist without `force` or children exist without `recursive`.
function validate_deletion(edge::ComputeEdge, force::Bool, recursive::Bool)
    # with both options set every deletion is allowed
    (force && recursive) && return

    if !force && length(edge.outputs) != 1
        error("Cannot delete node because it or one of its dependents has siblings. Set `force = true` to also delete siblings.")
    end
    if !recursive && !isempty(edge.dependents)
        error("Cannot delete node because it has children. Set `recursive = true` to also delete its children.")
    end

    for dependent in edge.dependents
        validate_deletion(dependent, force, recursive)
    end
    return
end

# Check that the input `edge` and, recursively, its dependents may be deleted with the
# given options. Raises an error if children exist without `recursive`.
function validate_deletion(edge::Input, force::Bool, recursive::Bool)
    # with both options set every deletion is allowed
    (force && recursive) && return

    if !recursive && !isempty(edge.dependents)
        error("Cannot delete node because it has children. Set `recursive = true` to also delete its children.")
    end

    for dependent in edge.dependents
        validate_deletion(dependent, force, recursive)
    end
    return
end

# Validate the deletion of `edge` first, then actually remove it from `attr`.
function _delete!(attr::ComputeGraph, edge::AbstractEdge, force::Bool, recursive::Bool)
    validate_deletion(edge, force, recursive)
    result = unsafe_delete!(attr, edge)
    return result
end

# Remove `edge`, all of its (transitive) dependents and their output nodes from `attr`
# without validating the deletion first. Returns `attr`.
function unsafe_delete!(attr::ComputeGraph, edge::ComputeEdge)
    # All dependents become invalid as their parent computation no longer runs.
    # Deleting a dependent removes it from `edge.dependents` (see the deregistration
    # step below), so iterate over a snapshot: looping over the live vector would
    # skip the element following each removed one and leave it in the graph.
    for dependent in copy(edge.dependents)
        # A dependent that also hangs off another dependent may already have been
        # deleted by an earlier recursive call. It is then no longer registered
        # here and must not be deleted a second time.
        any(e -> e === dependent, edge.dependents) || continue
        unsafe_delete!(attr, dependent)
    end

    # deregister this edge as a dependency of its parents
    for computed in edge.inputs
        @assert hasparent(computed)
        parent_edge = computed.parent
        filter!(e -> e !== edge, parent_edge.dependents)
    end

    # Delete output nodes of this edge
    for computed in edge.outputs
        k = computed.name
        @assert haskey(attr.outputs, k) && attr.outputs[k] === computed
        delete!(attr.outputs, k)
    end

    return attr
end

# Remove the input `edge`, all of its (transitive) dependents and the associated nodes
# from `attr` without validating the deletion first. Returns `attr`.
function unsafe_delete!(attr::ComputeGraph, edge::Input)
    # All dependents become invalid as their parent computation no longer runs.
    # Deleting a dependent removes it from the `dependents` of its parent edges, which
    # includes `edge.dependents`, so iterate over a snapshot: looping over the live
    # vector would skip the element following each removed one.
    for dependent in copy(edge.dependents)
        # A dependent that also hangs off another dependent may already have been
        # deleted by an earlier recursive call. It is then no longer registered
        # here and must not be deleted a second time.
        any(e -> e === dependent, edge.dependents) || continue
        unsafe_delete!(attr, dependent)
    end

    # Delete output node of this edge
    k = edge.name
    @assert haskey(attr.outputs, k) && attr.outputs[k] === edge.output
    delete!(attr.outputs, k)

    # Delete Input
    @assert haskey(attr.inputs, k) && attr.inputs[k] === edge
    delete!(attr.inputs, k)

    return attr
end

"""
    unsafe_disconnect_from_parents!(graph)

Removes every reference to this graph from every connected parent graph. This is
meant to prepare the given graph for garbage collection.

After calling this function, the graph will be in a broken state. Edges
connecting this graph to parent graphs still exist with references to parent
graphs, but they can no longer be triggered.
"""
function unsafe_disconnect_from_parents!(attr::ComputeGraph)
    # visit the producing edge of every node that has one
    for node in values(attr.outputs)
        hasparent(node) || continue
        unsafe_disconnect_parent_graph_nodes!(attr, node.parent)
    end
    return
end

# Nothing to do for `Input` edges.
unsafe_disconnect_parent_graph_nodes!(attr::ComputeGraph, edge::Input) = nothing

# Deregister `edge` from the edges producing its inputs if at least one of those inputs is
# not a node of `attr`. Returns `nothing`.
function unsafe_disconnect_parent_graph_nodes!(attr::ComputeGraph, edge::ComputeEdge)
    # A node belongs to `attr` if `attr.outputs` stores that very object under the
    # node's name. The keyed lookup avoids scanning all values for every input.
    belongs_to_graph(input) = haskey(attr.outputs, input.name) && attr.outputs[input.name] === input

    if !all(belongs_to_graph, edge.inputs)
        # deregistering is idempotent, so doing it once per edge is enough
        unsafe_atomic_delete!(edge)
    end
    return
end

# Remove `edge` from the `dependents` of every edge that produces one of its inputs.
# The edge itself and its output nodes are left untouched.
function unsafe_atomic_delete!(edge::ComputeEdge)
    for input in edge.inputs
        # inputs without a parent edge have nothing to deregister from
        hasparent(input) || continue
        filter!(other -> other !== edge, input.parent.dependents)
    end

    return
end

include("io.jl")

export Computed, ComputeEdge
export ComputeGraph
export register_computation!
export add_input!, add_inputs!, add_constant!, add_constants!
export update!

end
