using ImageSegmentation
using ImageFiltering
using ImageCore.FixedPointNumbers, ImageCore.ColorTypes
using Test

"""
    of_mean_type(p::Colorant)

Convert the colour `p` to the element type that `ImageSegmentation` uses for segment
means (the type reported by `ImageSegmentation.meantype` for `typeof(p)`). This lets the
tests build expected means that can be compared directly with `==`.
"""
function of_mean_type(p::Colorant)
    MeanT = ImageSegmentation.meantype(typeof(p))
    return MeanT(p)
end

# Seeded region growing. Each scenario builds a small synthetic image, a list of seeds,
# and the expected label map, label set, per-label pixel counts and per-label means.
# It then checks them against the fields of the returned segmentation.
@testset "Seeded Region Growing" begin
    # 2-D image: background 0, a 0.5 block at rows 6:10 × cols 4:8, then a 0.8 block at
    # rows 3:7 × cols 2:6 that overwrites the part of the 0.5 block it overlaps.
    img = zeros(Gray{N0f8}, 10, 10)
    img[6:10,4:8] .= 0.5
    img[3:7,2:6] .= 0.8
    # Seeds given as `CartesianIndex => label` pairs, one inside each region.
    seeds = [CartesianIndex(3,9) => 1, CartesianIndex(5,2) => 2, CartesianIndex(9,7) => 3]

    expected = ones(Int, 10, 10)
    expected[6:10,4:8] .= 3
    expected[3:7,2:6] .= 2
    expected_labels = [1,2,3]
    # Every region is uniform, so its mean equals its seed value converted to the mean type.
    expected_means = Dict(1 => of_mean_type(img[3,9]), 2 => of_mean_type(img[5,2]), 3 => of_mean_type(img[9,7]))
    # Label 2: the 5×5 block (25).
    # Label 3: 25 minus the 2×3 = 6 pixels overwritten by the 0.8 block (19).
    # Label 1: everything else, 100 - 25 - 19 = 56.
    expected_count = Dict(1 => 56, 2 => 25, 3 => 19)

    seg = seeded_region_growing(img, seeds)
    # Together, the two `all` checks assert that the two label sets are equal.
    @test all(label->(label in expected_labels), seg.segment_labels)
    @test all(label->(label in seg.segment_labels), expected_labels)
    @test expected_count == seg.segment_pixel_count
    @test expected_means == seg.segment_means
    @test seg.image_indexmap == expected

    # Custom neighbourhood given as a function. It returns the 4-connected
    # (up/down/left/right) neighbours of `c`; the result must match the default run above.
    seg = seeded_region_growing(img, seeds, c->[CartesianIndex(c[1]-1,c[2]), CartesianIndex(c[1]+1,c[2]), CartesianIndex(c[1],c[2]-1), CartesianIndex(c[1],c[2]+1)])
    @test all(label->(label in expected_labels), seg.segment_labels)
    @test all(label->(label in seg.segment_labels), expected_labels)
    @test expected_count == seg.segment_pixel_count
    @test expected_means == seg.segment_means
    @test seg.image_indexmap == expected

    # Offset image: `centered` re-indexes the same data around the origin.
    # The seeds are the same three pixels as above, (3,9), (5,2) and (9,7), each shifted
    # by -5 along both axes. They are now passed as `(CartesianIndex, label)` tuples.
    # Labels, counts and means are reused unchanged.
    img = centered(img)
    seeds = [(CartesianIndex(-2,4),1), (CartesianIndex(0,-3),2), (CartesianIndex(4,2),3)]
    expected = centered(expected)
    seg = seeded_region_growing(img, seeds)
    @test all(label->(label in expected_labels), seg.segment_labels)
    @test all(label->(label in seg.segment_labels), expected_labels)
    @test expected_count == seg.segment_pixel_count
    @test expected_means == seg.segment_means
    @test seg.image_indexmap == expected

    # Custom neighbourhood given as a kernel size: (3,3) vs (5,5).
    # 5×5 image: a 3×3 block of ones at 2:4 × 2:4 with a 0 in the centre.
    # Label 1 is seeded on the centre (value 0); label 2 on a neighbouring one.
    img = zeros(Gray{N0f8}, 5, 5)
    img[2:4,2:4] .= 1
    img[3,3] = 0
    seeds = [(CartesianIndex(3,3),1), (CartesianIndex(2,3),2)]

    # With a 3×3 neighbourhood, label 1 is expected to be only the centre pixel (its 3×3
    # neighbours are all ones). Label 2 covers the other 24 pixels: 8 ones + 16 zeros,
    # giving a mean of 8/24 = 1/3.
    expected = fill(2,(5,5))
    expected[3,3] = 1
    expected_labels = [1,2]
    expected_means = Dict(1=>Gray{Float64}(0.0), 2=>Gray{Float64}(1/3))
    expected_count = Dict(1=>1, 2=>24)

    seg = seeded_region_growing(img, seeds, (3,3))
    @test all(label->(label in expected_labels), seg.segment_labels)
    @test all(label->(label in seg.segment_labels), expected_labels)
    @test expected_count == seg.segment_pixel_count
    # 1/3 is not exactly representable, so means are compared approximately.
    @test all(label->(expected_means[label] ≈ seg.segment_means[label]), seg.segment_labels)
    @test seg.image_indexmap == expected

    # With a 5×5 neighbourhood, the centre reaches the zero border. Label 1 collects all 17
    # zero pixels (16 border + centre); label 2 gets the 8 ones.
    expected = ones(Int, 5, 5)
    expected[2:4,2:4] .= 2
    expected[3,3] = 1
    expected_labels = [1,2]
    expected_means = Dict(1=>Gray{N0f8}(0.0), 2=>Gray{N0f8}(1.0))
    expected_count = Dict(1=>17, 2=>8)

    seg = seeded_region_growing(img, seeds, (5,5))
    @test all(label->(label in expected_labels), seg.segment_labels)
    @test all(label->(label in seg.segment_labels), expected_labels)
    @test expected_count == seg.segment_pixel_count
    @test expected_means == seg.segment_means
    @test seg.image_indexmap == expected

    # Element type / InexactError: a plain `Int` image.
    # Region 2's mean, (8*10 + 11)/9 = 91/9, is not an integer, so means are expected as
    # floating-point values. Presumably this guards against an InexactError from storing
    # them in the `Int` element type.
    img = zeros(Int, 4, 4)
    img[2:end, 2:end] .= 10
    img[end,end] = 11
    seeds = [(CartesianIndex(1,1), 1), (CartesianIndex(2,2), 2)]

    expected = ones(Int, size(img))
    expected[2:end,2:end] .= 2
    expected_labels = [1,2]
    expected_means = Dict(1=>0.0, 2=>91/9)
    expected_count = Dict(1=>7, 2=>9)

    # The custom difference maps |v1 - v2| <= 1 to zero, so the 11 joins the 10-valued region.
    seg = seeded_region_growing(img, seeds, (3,3), (v1,v2)->(δ = abs(v1-v2); return δ > 1 ? δ : zero(δ)))
    @test all(label->(label in expected_labels), seg.segment_labels)
    @test all(label->(label in seg.segment_labels), expected_labels)
    @test expected_count == seg.segment_pixel_count
    @test expected_means == seg.segment_means
    @test seg.image_indexmap == expected

    # Regression test on a natural-image patch. Region means change while the same
    # candidate pixels remain in the neighbour queue ("NHQ") over several iterations.
    # Such a candidate must not simply be pushed onto the priority queue ("pq") again each
    # time. NHQ and pq refer to the implementation's internal queues.
    img = Gray{N0f8}.([
     0.475  0.412  0.443  0.475  0.702  0.714  0.667  0.443  0.404  0.353
     0.525  0.431  0.365  0.424  0.682  0.698  0.635  0.365  0.329  0.318
     0.561  0.553  0.545  0.463  0.643  0.663  0.58   0.275  0.251  0.259
     0.561  0.588  0.655  0.639  0.694  0.62   0.537  0.263  0.184  0.184
     0.463  0.584  0.639  0.635  0.655  0.584  0.482  0.267  0.169  0.169
     0.263  0.263  0.541  0.557  0.58   0.522  0.431  0.243  0.145  0.145
     0.239  0.231  0.239  0.38   0.443  0.467  0.376  0.224  0.118  0.122
     0.263  0.212  0.153  0.325  0.404  0.471  0.349  0.22   0.106  0.106
     0.255  0.208  0.133  0.333  0.416  0.482  0.345  0.224  0.098  0.169
     0.235  0.18   0.102  0.294  0.369  0.447  0.325  0.22   0.122  0.184 ])

    expected = [
     1  1  1  1  1  1  1  1  1  2
     1  1  1  1  1  1  1  2  2  2
     1  1  1  1  1  1  1  2  2  2
     1  1  1  1  1  1  1  2  2  2
     1  1  1  1  1  1  1  2  2  2
     3  3  1  1  1  1  1  2  2  2
     3  3  3  3  1  1  2  2  2  2
     3  3  3  3  3  1  2  2  2  2
     3  3  3  3  3  1  2  2  2  2
     3  3  3  3  3  1  2  2  2  2]

    # Only the final label map is checked here.
    seeds = [ (CartesianIndex(1,5),1), (CartesianIndex(8,9),2), (CartesianIndex(9,3),3) ]
    seg = seeded_region_growing(img, seeds)
    @test seg.image_indexmap == expected

    # 3-D RGB image: a 0.5-grey cube at 3:7 along each axis, then a 0.8-grey box at
    # 2:5 × 5:9 × 4:6 that overwrites part of it.
    img = zeros(RGB{N0f8},(9,9,9))
    img[3:7,3:7,3:7] .= RGB{N0f8}(0.5,0.5,0.5)
    img[2:5,5:9,4:6] .= RGB{N0f8}(0.8,0.8,0.8)
    seeds = [(CartesianIndex(1,1,1),1), (CartesianIndex(6,4,4),2), (CartesianIndex(3,6,5),3)]

    expected = ones(Int, (9,9,9))
    expected[3:7,3:7,3:7] .= 2
    expected[2:5,5:9,4:6] .= 3
    expected_labels = [1,2,3]
    # Uniform regions: each mean is the value at the corresponding seed index.
    expected_means = Dict([(i, of_mean_type(img[seeds[i][1]])) for i in 1:3])
    # Label 3: 4·5·3 = 60.
    # Label 2: 5³ = 125 minus the 3·3·3 = 27 overwritten voxels (98).
    # Label 1: 729 - 98 - 60 = 571.
    expected_count = Dict(1=>571, 2=>98, 3=>60)

    seg = seeded_region_growing(img, seeds)
    @test all(label->(label in expected_labels), seg.segment_labels)
    @test all(label->(label in seg.segment_labels), expected_labels)
    @test expected_count == seg.segment_pixel_count
    @test expected_means == seg.segment_means
    @test seg.image_indexmap == expected

    # Custom diff_fn. 3×3 image: column 1 is (0.4,1,0), column 2 is (0.2,1,0) and
    # column 3 stays black. Seeds sit in columns 1 and 3.
    img = zeros(RGB{N0f8},(3,3))
    img[1:3,1] .= RGB{N0f8}(0.4,1,0)
    img[1:3,2] .= RGB{N0f8}(0.2,1,0)
    seeds = [(CartesianIndex(2,1),1), (CartesianIndex(2,3),2)]

    # Under the default difference, column 2 is expected to join seed 1. That gives
    # 6 pixels with mean (0.3,1,0).
    expected = ones(Int, (3,3))
    expected[1:3,3] .= 2
    expected_labels = [1,2]
    expected_means = Dict(1=>RGB{Float32}(0.3,1.0,0.0), 2=>RGB{Float32}(0.0,0.0,0.0))
    expected_count = Dict(1=>6, 2=>3)

    seg = seeded_region_growing(img, seeds)
    @test all(label->(label in expected_labels), seg.segment_labels)
    @test all(label->(label in seg.segment_labels), expected_labels)
    @test expected_count == seg.segment_pixel_count
    @test expected_means[1] ≈ seg.segment_means[1]
    @test expected_means[2] ≈ seg.segment_means[2]
    @test seg.image_indexmap == expected

    # Red-channel-only difference. Column 2's red value (0.2) is equally far from
    # column 1 (0.4) and column 3 (0.0). The test expects those pixels to get label 0,
    # with no entry for label 0 in the means.
    expected = ones(Int, (3,3))
    expected[1:3,2] .= 0
    expected[1:3,3] .= 2
    expected_labels = [0,1,2]
    expected_means = Dict(1=>RGB{Float32}(0.4,1.0,0.0), 2=>RGB{Float32}(0.0,0.0,0.0))
    expected_count = Dict(0=>3, 1=>3, 2=>3)

    seg = seeded_region_growing(img, seeds, (3,3), (c1,c2)->abs(of_mean_type(c1).r - of_mean_type(c2).r))
    @test all(label->(label in expected_labels), seg.segment_labels)
    @test all(label->(label in seg.segment_labels), expected_labels)
    @test expected_count == seg.segment_pixel_count
    @test expected_means == seg.segment_means
    @test seg.image_indexmap == expected
    # Passing the kernel size as a Vector (`[3,3]`) is the deprecated form. It must
    # produce the same labels as the tuple form above.
    @info "The deprecation warning below is expected"   # but can be deleted eventually!
    segd = seeded_region_growing(img, seeds, [3,3], (c1,c2)->abs(of_mean_type(c1).r - of_mean_type(c2).r))
    @test labels_map(segd) == labels_map(seg)
end

# Unseeded region growing. Regions are discovered automatically using a difference
# threshold instead of user-supplied seeds. Each scenario compares the label map,
# label set, per-label pixel counts and per-label means against hand-computed values.
@testset "Unseeded Region Growing" begin
    # 2-D image: background 0, a 0.5 block at rows 6:10 × cols 4:8, then a 0.8 block at
    # rows 3:7 × cols 2:6 that overwrites the part of the 0.5 block it overlaps.
    img = zeros(Gray{N0f8}, 10, 10)
    img[6:10,4:8] .= 0.5
    img[3:7,2:6] .= 0.8

    # Expected labels: 1 for the background, 2 for the 0.5 block, 3 for the 0.8 block.
    # This numbering is what the implementation is expected to produce.
    expected = ones(Int, 10, 10)
    expected[6:10,4:8] .= 2
    expected[3:7,2:6] .= 3
    expected_labels = [1,2,3]
    # Uniform regions: each mean is the value of any pixel in that region.
    expected_means = Dict(1 => of_mean_type(img[3,9]), 3 => of_mean_type(img[5,2]), 2 => of_mean_type(img[9,7]))
    # Label 3: the 5×5 block (25).
    # Label 2: 25 minus the 6 pixels overwritten by the 0.8 block (19).
    # Label 1: 100 - 25 - 19 = 56.
    expected_count = Dict(1 => 56, 3 => 25, 2 => 19)

    seg = unseeded_region_growing(img, 0.2)
    # Together, the two `all` checks assert that the two label sets are equal.
    @test all(label->(label in expected_labels), seg.segment_labels)
    @test all(label->(label in seg.segment_labels), expected_labels)
    @test expected_count == seg.segment_pixel_count
    @test expected_means == seg.segment_means
    @test seg.image_indexmap == expected

    # Custom neighbourhood given as a function. It returns the 4-connected
    # (up/down/left/right) neighbours of `c`; the result must match the default run above.
    seg = unseeded_region_growing(img, 0.2, c->[CartesianIndex(c[1]-1,c[2]), CartesianIndex(c[1]+1,c[2]), CartesianIndex(c[1],c[2]-1), CartesianIndex(c[1],c[2]+1)])
    @test all(label->(label in expected_labels), seg.segment_labels)
    @test all(label->(label in seg.segment_labels), expected_labels)
    @test expected_count == seg.segment_pixel_count
    @test expected_means == seg.segment_means
    @test seg.image_indexmap == expected

    # Offset image: `centered` re-indexes the same data around the origin.
    # The label map is re-indexed the same way; labels, counts and means are unchanged.
    img = centered(img)
    expected = centered(expected)
    seg = unseeded_region_growing(img, 0.2)
    @test all(label->(label in expected_labels), seg.segment_labels)
    @test all(label->(label in seg.segment_labels), expected_labels)
    @test expected_count == seg.segment_pixel_count
    @test expected_means == seg.segment_means
    @test seg.image_indexmap == expected

    # Custom neighbourhood given as a kernel size ((5,5) and (3,3)) with varying threshold.
    # 5×5 image: a 3×3 block of ones at 2:4 × 2:4 with 0.8 in the centre.
    img = zeros(Gray{N0f8}, 5, 5)
    img[2:4,2:4] .= 1
    img[3,3] = 0.8

    # Threshold 0.2: the 0.8 centre is expected to merge into the block of ones.
    # That gives 9 pixels with mean (8 + 0.8)/9.
    expected = fill(1,(5,5))
    expected[2:4,2:4] .= 2
    expected_labels = [1,2]
    expected_means = Dict(1=>Gray{Float64}(0.0), 2=>Gray{Float64}(8.8/9))
    expected_count = Dict(2=>9, 1=>16)

    seg = unseeded_region_growing(img, 0.2, (5,5))
    @test all(label->(label in expected_labels), seg.segment_labels)
    @test all(label->(label in seg.segment_labels), expected_labels)
    @test expected_count == seg.segment_pixel_count
    # 8.8/9 is not exactly representable, so means are compared approximately.
    @test all(label->(expected_means[label] ≈ seg.segment_means[label]), seg.segment_labels)
    @test seg.image_indexmap == expected

    # Threshold 0.1 with a (5,5) kernel: the centre becomes its own single-pixel region,
    # expected here as label 2.
    expected = ones(Int, 5, 5)
    expected[2:4,2:4] .= 3
    expected[3,3] = 2
    expected_labels = [1,2,3]
    expected_means = Dict(1=>Gray{N0f8}(0.0), 3=>Gray{N0f8}(1.0), 2=>Gray{N0f8}(0.8))
    expected_count = Dict(1=>16, 2=>1, 3=>8)

    seg = unseeded_region_growing(img, 0.1, (5,5))
    @test all(label->(label in expected_labels), seg.segment_labels)
    @test all(label->(label in seg.segment_labels), expected_labels)
    @test expected_count == seg.segment_pixel_count
    @test expected_means == seg.segment_means
    @test seg.image_indexmap == expected

    # Threshold 0.1 with a (3,3) kernel: same regions as with (5,5), but the centre is
    # expected to be labelled 3 and the ring of ones 2.
    expected = ones(Int, 5, 5)
    expected[2:4,2:4] .= 2
    expected[3,3] = 3
    expected_labels = [1,2,3]
    expected_means = Dict(1=>Gray{N0f8}(0.0), 2=>Gray{N0f8}(1.0), 3=>Gray{N0f8}(0.8))
    expected_count = Dict(1=>16, 3=>1, 2=>8)

    seg = unseeded_region_growing(img, 0.1, (3,3))
    @test all(label->(label in expected_labels), seg.segment_labels)
    @test all(label->(label in seg.segment_labels), expected_labels)
    @test expected_count == seg.segment_pixel_count
    @test expected_means == seg.segment_means
    @test seg.image_indexmap == expected

    # A NaN threshold must be rejected with an ArgumentError.
    @test_throws ArgumentError unseeded_region_growing(img, NaN, (3,3))

    # 3-D RGB image: a 0.5-grey cube at 3:7 along each axis, then a 0.8-grey box at
    # 2:5 × 5:9 × 4:6 that overwrites part of it.
    img = zeros(RGB{N0f8},(9,9,9))
    img[3:7,3:7,3:7] .= RGB{N0f8}(0.5,0.5,0.5)
    img[2:5,5:9,4:6] .= RGB{N0f8}(0.8,0.8,0.8)

    expected = ones(Int, (9,9,9))
    expected[3:7,3:7,3:7] .= 2
    expected[2:5,5:9,4:6] .= 3
    expected_labels = [1,2,3]
    expected_means = Dict(1=>of_mean_type(img[1,1,1]), 2=>of_mean_type(img[3,3,3]), 3=>of_mean_type(img[2,5,4]))
    # Label 3: 4·5·3 = 60.
    # Label 2: 5³ = 125 minus the 3·3·3 = 27 overwritten voxels (98).
    # Label 1: 729 - 98 - 60 = 571.
    expected_count = Dict(1=>571, 2=>98, 3=>60)

    seg = unseeded_region_growing(img, 0.2)
    @test all(label->(label in expected_labels), seg.segment_labels)
    @test all(label->(label in seg.segment_labels), expected_labels)
    @test expected_count == seg.segment_pixel_count
    @test expected_means == seg.segment_means
    @test seg.image_indexmap == expected

    # Custom diff_fn. 3×3 image: column 1 is (0.4,1,0), column 2 is (0.2,1,0) and
    # column 3 stays black.
    img = zeros(RGB{N0f8},(3,3))
    img[1:3,1] .= RGB{N0f8}(0.4,1,0)
    img[1:3,2] .= RGB{N0f8}(0.2,1,0)

    # With a red-channel-only difference and threshold 0.2, adjacent columns differ by
    # 0.2 in red (0.4, 0.2, 0.0). Each column is expected to become its own 3-pixel region.
    expected = ones(Int, (3,3))
    expected[1:3,2] .= 2
    expected[1:3,3] .= 3
    expected_labels = [1,2,3]
    expected_means = Dict(1=>of_mean_type(img[1,1]), 3=>of_mean_type(img[1,3]), 2=>of_mean_type(img[1,2]))
    expected_count = Dict(1=>3, 2=>3, 3=>3)

    seg = unseeded_region_growing(img, 0.2, (3,3), (c1,c2)->abs(of_mean_type(c1).r - of_mean_type(c2).r))
    @test all(label->(label in expected_labels), seg.segment_labels)
    @test all(label->(label in seg.segment_labels), expected_labels)
    @test expected_count == seg.segment_pixel_count
    @test expected_means == seg.segment_means
    @test seg.image_indexmap == expected
    # Passing the kernel size as a Vector (`[3,3]`) is the deprecated form. It must
    # produce the same labels as the tuple form above.
    @info "The deprecation warning below is expected"   # but can be deleted eventually!
    segd = unseeded_region_growing(img, 0.2, [3,3], (c1,c2)->abs(of_mean_type(c1).r - of_mean_type(c2).r))
    @test labels_map(segd) == labels_map(seg)
end
