open import 1Lab.Reflection
open import 1Lab.Prelude

open import Cat.Base

open import Data.List

import Cat.Reasoning as Cat

module Cat.Functor.Solver where


module NbE {o h o' h'} {π’ž : Precategory o h} {π’Ÿ : Precategory o' h'} (F : Functor π’ž π’Ÿ) where
  private
    module π’ž = Cat π’ž
    module π’Ÿ = Cat π’Ÿ
    open Functor F

    variable
      A B C : π’ž.Ob
      X Y Z : π’Ÿ.Ob

  data CExpr : π’ž.Ob β†’ π’ž.Ob β†’ Type (o βŠ” h) where
    _β€Άβˆ˜β€Ά_ : CExpr B C β†’ CExpr A B β†’ CExpr A C
    •id•  : CExpr A A
    _↑    : π’ž.Hom A B β†’ CExpr A B

  data DExpr : π’Ÿ.Ob β†’ π’Ÿ.Ob β†’ Type (o βŠ” h βŠ” o' βŠ” h') where
    •F₁•  : CExpr A B β†’ DExpr (Fβ‚€ A) (Fβ‚€ B)
    _β€Άβˆ˜β€Ά_ : DExpr Y Z β†’ DExpr X Y β†’ DExpr X Z
    •id•  : DExpr X X
    _↑    : π’Ÿ.Hom X Y β†’ DExpr X Y

  uncexpr : CExpr A B β†’ π’ž.Hom A B
  uncexpr (e1 β€Άβˆ˜β€Ά e2) = uncexpr e1 π’ž.∘ uncexpr e2
  uncexpr •id• = π’ž.id
  uncexpr (f ↑) = f

  undexpr : DExpr X Y β†’ π’Ÿ.Hom X Y
  undexpr (•F₁• e) = F₁ (uncexpr e)
  undexpr (e1 β€Άβˆ˜β€Ά e2) = undexpr e1 π’Ÿ.∘ undexpr e2
  undexpr •id• = π’Ÿ.id
  undexpr (f ↑) = f

  --------------------------------------------------------------------------------
  -- Values

  data CValue : π’ž.Ob β†’ π’ž.Ob β†’ Type (o βŠ” h) where
    vid : CValue A A
    vcomp : π’ž.Hom B C β†’ CValue A B β†’ CValue A C

  data Frame : π’Ÿ.Ob β†’ π’Ÿ.Ob β†’ Type (o βŠ” h βŠ” o' βŠ” h') where
    vhom : π’Ÿ.Hom X Y β†’ Frame X Y
    vfmap : π’ž.Hom A B β†’ Frame (Fβ‚€ A) (Fβ‚€ B)

  data DValue : π’Ÿ.Ob β†’ π’Ÿ.Ob β†’ Type (o βŠ” h βŠ” o' βŠ” h') where
    vid   : DValue X X
    vcomp : Frame Y Z β†’ DValue X Y β†’ DValue X Z

  uncvalue : CValue A B β†’ π’ž.Hom A B
  uncvalue vid = π’ž.id
  uncvalue (vcomp f v) = f π’ž.∘ uncvalue v

  unframe : Frame X Y β†’ π’Ÿ.Hom X Y
  unframe (vhom f) = f
  unframe (vfmap f) = F₁ f

  undvalue : DValue X Y β†’ π’Ÿ.Hom X Y
  undvalue vid = π’Ÿ.id
  undvalue (vcomp f v) = unframe f π’Ÿ.∘ undvalue v

  --------------------------------------------------------------------------------
  -- Evaluation

  do-cvcomp : CValue B C β†’ CValue A B β†’ CValue A C
  do-cvcomp vid v2 = v2
  do-cvcomp (vcomp f v1) v2 = vcomp f (do-cvcomp v1 v2)

  ceval : CExpr A B β†’ CValue A B
  ceval (e1 β€Άβˆ˜β€Ά e2) = do-cvcomp (ceval e1) (ceval e2)
  ceval •id• = vid
  ceval (f ↑) = vcomp f vid

  do-dvcomp : DValue Y Z β†’ DValue X Y β†’ DValue X Z
  do-dvcomp vid v2 = v2
  do-dvcomp (vcomp f v1) v2 = vcomp f (do-dvcomp v1 v2)

  do-vfmap : CValue A B β†’ DValue (Fβ‚€ A) (Fβ‚€ B)
  do-vfmap vid = vid
  do-vfmap (vcomp f v) = vcomp (vfmap f) (do-vfmap v)

  deval : DExpr X Y β†’ DValue X Y
  deval (•F₁• e) = do-vfmap (ceval e)
  deval (e1 β€Άβˆ˜β€Ά e2) = do-dvcomp (deval e1) (deval e2)
  deval •id• = vid
  deval (f ↑) = vcomp (vhom f) vid

  --------------------------------------------------------------------------------
  -- Soundness

  do-cvcomp-sound : βˆ€ (v1 : CValue B C) β†’ (v2 : CValue A B) β†’ uncvalue (do-cvcomp v1 v2) ≑ uncvalue v1 π’ž.∘ uncvalue v2
  do-cvcomp-sound vid v2 = sym (π’ž.idl (uncvalue v2))
  do-cvcomp-sound (vcomp f v1) v2 = π’ž.pushr (do-cvcomp-sound v1 v2)

  ceval-sound : βˆ€ (e : CExpr A B) β†’ uncvalue (ceval e) ≑ uncexpr e
  ceval-sound (e1 β€Άβˆ˜β€Ά e2) =
    uncvalue (do-cvcomp (ceval e1) (ceval e2))    β‰‘βŸ¨ do-cvcomp-sound (ceval e1) (ceval e2) ⟩
    (uncvalue (ceval e1) π’ž.∘ uncvalue (ceval e2)) β‰‘βŸ¨ apβ‚‚ π’ž._∘_ (ceval-sound e1) (ceval-sound e2) ⟩
    uncexpr e1 π’ž.∘ uncexpr e2                     ∎
  ceval-sound •id• = refl
  ceval-sound (f ↑) = π’ž.idr f

  do-vfmap-sound : βˆ€ (v : CValue A B) β†’ undvalue (do-vfmap v) ≑ F₁ (uncvalue v)
  do-vfmap-sound vid = sym F-id
  do-vfmap-sound (vcomp f v) =
    F₁ f π’Ÿ.∘ undvalue (do-vfmap v) β‰‘βŸ¨ ap (F₁ f π’Ÿ.∘_) (do-vfmap-sound v) ⟩
    F₁ f π’Ÿ.∘ F₁ (uncvalue v)       β‰‘Λ˜βŸ¨ F-∘ f (uncvalue v) ⟩
    F₁ (f π’ž.∘ uncvalue v)          ∎

  do-dvcomp-sound : βˆ€ (v1 : DValue Y Z) β†’ (v2 : DValue X Y) β†’ undvalue (do-dvcomp v1 v2) ≑ undvalue v1 π’Ÿ.∘ undvalue v2
  do-dvcomp-sound vid v2 = sym (π’Ÿ.idl (undvalue v2))
  do-dvcomp-sound (vcomp f v1) v2 = π’Ÿ.pushr (do-dvcomp-sound v1 v2)

  deval-sound : βˆ€ (e : DExpr X Y) β†’ undvalue (deval e) ≑ undexpr e
  deval-sound (•F₁• e) =
    undvalue (do-vfmap (ceval e)) β‰‘βŸ¨ do-vfmap-sound (ceval e) ⟩
    F₁ (uncvalue (ceval e))       β‰‘βŸ¨ ap F₁ (ceval-sound e ) ⟩
    F₁ (uncexpr e)                ∎
  deval-sound (e1 β€Άβˆ˜β€Ά e2) =
    undvalue (do-dvcomp (deval e1) (deval e2))  β‰‘βŸ¨ do-dvcomp-sound (deval e1) (deval e2) ⟩
    undvalue (deval e1) π’Ÿ.∘ undvalue (deval e2) β‰‘βŸ¨ apβ‚‚ π’Ÿ._∘_ (deval-sound e1) (deval-sound e2) ⟩
    undexpr e1 π’Ÿ.∘ undexpr e2                   ∎
  deval-sound •id• = refl
  deval-sound (f ↑) = π’Ÿ.idr f

  abstract
    solve : (e1 e2 : DExpr X Y) β†’ undvalue (deval e1) ≑ undvalue (deval e2) β†’ undexpr e1 ≑ undexpr e2
    solve e1 e2 p  = sym (deval-sound e1) Β·Β· p Β·Β· (deval-sound e2)

-- Reflection front-end of the solver. It reads the boundary `lhs ≑ rhs` of
-- the goal, reflects both sides into `NbE.DExpr` syntax, and fills the hole
-- with `NbE.solve functor lhs rhs refl`.
module Reflection where

  -- Arguments of a projection out of `Precategory`: two hidden arguments
  -- (presumably its levels) and the category record itself, followed by `xs`.
  pattern category-args xs = _ hm∷ _ hm∷ _ v∷ xs

  -- Arguments of a projection out of `Functor`: six hidden arguments
  -- (presumably the levels and the two categories) and the functor record
  -- itself, followed by `xs`.
  pattern functor-args functor xs =
    _ hm∷ _ hm∷ _ hm∷ _ hm∷ _ hm∷ _ hm∷ functor v∷ xs

  -- Quoted `Precategory.id` (its hidden object argument is ignored).
  pattern β€œid” =
    def (quote Precategory.id) (category-args (_ h∷ []))

  -- Quoted `f ∘ g` in some precategory.
  pattern β€œβˆ˜β€ f g =
    def (quote Precategory._∘_) (category-args (_ h∷ _ h∷ _ h∷ f v∷ g v∷ []))

  -- Quoted `Functor.F₁ functor f`.
  pattern β€œF₁” functor f =
    def (quote Functor.F₁) (functor-args functor (_ h∷ _ h∷ f v∷ []))

  -- Instantiate the parameters of `NbE` with `functor`, leaving the levels
  -- and categories `unknown` for Agda to infer, then append `args`.
  mk-functor-args : Term β†’ List (Arg Term) β†’ List (Arg Term)
  mk-functor-args functor args =
    unknown h∷ unknown h∷ unknown h∷ unknown h∷ unknown h∷ unknown h∷ functor v∷ args

  -- The term `NbE.solve functor lhs rhs refl`. The two hidden arguments of
  -- `solve` (its endpoint objects) are left for Agda to infer.
  β€œsolve” : Term β†’ Term β†’ Term β†’ Term
  β€œsolve” functor lhs rhs =
    def (quote NbE.solve) (mk-functor-args functor $ infer-hidden 2 $ lhs v∷ rhs v∷ def (quote refl) [] v∷ [])

  -- Reflect a π’ž-morphism into `NbE.CExpr` syntax. Anything that is not an
  -- identity or a composite becomes an opaque atom.
  build-cexpr : Term β†’ Term
  build-cexpr β€œid” = con (quote NbE.CExpr.•id•) []
  build-cexpr (β€œβˆ˜β€ f g) = con (quote NbE.CExpr._β€Άβˆ˜β€Ά_) (build-cexpr f v∷ build-cexpr g v∷ [])
  build-cexpr f = con (quote NbE.CExpr._↑) (f v∷ [])

  -- Reflect a π’Ÿ-morphism into `NbE.DExpr` syntax.
  --
  -- An application of `F₁` becomes `•F₁•` only when its functor unifies
  -- with `functor`. Otherwise (for example F₁ of some other functor) the
  -- whole application is kept as an opaque atom, instead of aborting the
  -- macro with a unification error. `catchTC` backtracks the failed
  -- unification before taking the fallback.
  build-dexpr : Term β†’ Term β†’ TC Term
  build-dexpr functor β€œid” =
    pure $ con (quote NbE.DExpr.•id•) []
  build-dexpr functor (β€œβˆ˜β€ f g) = do
    f ← build-dexpr functor f
    g ← build-dexpr functor g
    pure $ con (quote NbE.DExpr._β€Άβˆ˜β€Ά_) (f v∷ g v∷ [])
  build-dexpr functor tm@(β€œF₁” functor' f) =
    catchTC
      (unify functor functor' >> pure (con (quote NbE.DExpr.•F₁•) (build-cexpr f v∷ [])))
      (pure (con (quote NbE.DExpr._↑) (tm v∷ [])))
  build-dexpr functor f =
    pure $ con (quote NbE.DExpr._↑) (f v∷ [])

  -- Definitions kept folded while the goal is reduced, so that the patterns
  -- above can still recognise them.
  dont-reduce : List Name
  dont-reduce = quote Precategory.id ∷ quote Precategory._∘_ ∷ quote Functor.F₁ ∷ []

  -- Solve the goal `hole` for the given functor:
  --   1. reduce the goal type without unfolding `dont-reduce`,
  --   2. extract its boundary `lhs ≑ rhs` (a type error if there is none),
  --   3. reflect both sides into `DExpr`, and
  --   4. fill the hole with `NbE.solve`, failing if constraints remain.
  solve-macro : βˆ€ {o h o' h'} {π’ž : Precategory o h} {π’Ÿ : Precategory o' h'} β†’ Functor π’ž π’Ÿ β†’ Term β†’ TC ⊀
  solve-macro functor hole =
   withNormalisation false $
   withReduceDefs (false , dont-reduce) $ do
     functor-tm ← quoteTC functor
     goal ← infer-type hole >>= reduce
     just (lhs , rhs) ← get-boundary goal
       where nothing β†’ typeError $ strErr "Can't determine boundary: " ∷
                                   termErr goal ∷ []
     elhs ← build-dexpr functor-tm lhs
     erhs ← build-dexpr functor-tm rhs
     noConstraints $ unify hole (β€œsolve” functor-tm elhs erhs)

macro
  -- `functor! F` proves a goal `lhs ≑ rhs` between π’Ÿ-morphisms built from
  -- identities, composition, `F₁` of the given functor `F`, and other
  -- morphisms, by normalising both sides (via `Reflection.solve-macro`).
  functor! : βˆ€ {o h o' h'} {π’ž : Precategory o h} {π’Ÿ : Precategory o' h'} β†’ Functor π’ž π’Ÿ β†’ Term β†’ TC ⊀
  functor! F goal = Reflection.solve-macro F goal

private module Test {o h o' h'} {π’ž : Precategory o h} {π’Ÿ : Precategory o' h'} (F : Functor π’ž π’Ÿ) where
  module π’ž = Cat π’ž
  module π’Ÿ = Cat π’Ÿ
  open Functor F

  variable
    A B : π’ž.Ob
    X Y : π’Ÿ.Ob
    a b c : π’ž.Hom A B
    x y z : π’Ÿ.Hom X Y

  test : (x π’Ÿ.∘ F₁ (π’ž.id π’ž.∘ π’ž.id)) π’Ÿ.∘ F₁ a π’Ÿ.∘ F₁ (π’ž.id π’ž.∘ b) ≑ π’Ÿ.id π’Ÿ.∘ x π’Ÿ.∘ F₁ (a π’ž.∘ b)
  test = functor! F