// -*- C++ -*-
// 
// SimplIC3: a "simple" implementation of IC3 (and other SAT-based algorithms)
// for finite-state functional transition systems
//
// Model simplification via temporal decomposition
//
// Author: Alberto Griggio <griggio@fbk.eu>
// See LICENSE.txt for copyright/licensing information
// See CREDITS.txt for other credits
//

#include "tdecomp.h"
#include "modelsim.h"
#include "miscutils.h"
#include "bmc.h"
#include "unroll.h"
#include <iostream>
#include <sstream>

namespace simplic3 {

namespace {

// Union-find over AIG nodes that additionally records, for every
// representative, the explicit list of members of its class, so that the
// classes can be enumerated via begin()/end().
//
// Both tables are `mutable` because repr(), although a const query, registers
// nodes it has not seen before and repoints the queried node directly at its
// representative.
class UnionFind {
    // node -> parent node (a representative is its own parent, or absent)
    typedef HashMap<Aig, Aig> UfMap;
    // representative -> members of its class (representative included)
    typedef HashMap<Aig, AigList> ClassMap;

    mutable UfMap uf_;
    mutable ClassMap cc_;

public:
    // iterating yields (representative, members) pairs
    typedef ClassMap::const_iterator iterator;

    // joins the classes of a and b
    void merge(Aig a, Aig b);
    // representative of a's class (a is registered if unseen)
    Aig repr(Aig a) const;
    // number of members of a's class
    size_t size(Aig a) const;

    iterator begin() const;
    iterator end() const;

    void clear();
    void swap(UnionFind &other);
};


// Joins the classes of `a` and `b`. The representative that is smaller
// according to aig_lt becomes the root of the merged class, and the members
// of the other class are appended (in order) to the root's member list.
// Merging two nodes that are already in the same class is a no-op.
void UnionFind::merge(Aig a, Aig b)
{
    Aig ra = repr(a);
    Aig rb = repr(b);

    if (ra == rb) {
        // Already equivalent. Falling through would splice the member list
        // into itself (vector::insert with iterators into *this is
        // undefined behaviour) and then erase the class entry altogether.
        return;
    }

    if (aig_lt(rb, ra)) {
        std::swap(ra, rb);
    }
    uf_[rb] = ra;

    // Detach rb's member list before touching ra's entry, so that we never
    // hold two references into cc_ at the same time.
    AigList moved;
    std::swap(moved, cc_[rb]);
    cc_.erase(rb);

    AigList &c = cc_[ra];
    c.insert(c.end(), moved.begin(), moved.end());
}


// Returns the representative of the class of `a`.
//
// Side effects (on the mutable tables): every node on the parent chain from
// `a` to its representative is repointed directly at the representative
// (full path compression). If `a` was unknown, it is registered as the
// representative of a new singleton class.
Aig UnionFind::repr(Aig a) const
{
    // 1. follow parent links up to the root
    Aig root = a;
    while (true) {
        UfMap::iterator it = uf_.find(root);
        if (it == uf_.end() || it->second == root) {
            break;
        }
        root = it->second;
    }

    // 2. path compression: repoint every node on the chain to the root.
    // Any node on the chain other than the root has an entry in uf_.
    Aig cur = a;
    while (cur != root) {
        UfMap::iterator it = uf_.find(cur);
        assert(it != uf_.end());
        Aig parent = it->second;
        it->second = root;
        cur = parent;
    }

    // registers `a` when it was not known yet (in which case a == root)
    uf_[a] = root;
    if (a == root) {
        AigList &c = cc_[a];
        if (c.empty()) {
            c.push_back(a);
        }
    }
    return root;
}


// First (representative, members) entry of the class table.
UnionFind::iterator UnionFind::begin() const
{
    const ClassMap &classes = cc_;
    return classes.begin();
}


// Past-the-end iterator of the class table.
UnionFind::iterator UnionFind::end() const
{
    const ClassMap &classes = cc_;
    return classes.end();
}


// Forgets all nodes and classes.
void UnionFind::clear()
{
    cc_.clear();
    uf_.clear();
}


// Exchanges the whole state (parent links and class lists) with `other`.
void UnionFind::swap(UnionFind &other)
{
    std::swap(cc_, other.cc_);
    std::swap(uf_, other.uf_);
}


// Number of members of the class containing `a` (`a` is registered as a
// singleton class if it was unknown).
size_t UnionFind::size(Aig a) const
{
    ClassMap::const_iterator found = cc_.find(repr(a));
    assert(found != cc_.end());
    return found->second.size();
}


void intersect_uf(const UnionFind &a, const UnionFind &b, UnionFind &out)
{
    out.clear();
    
    for (UnionFind::iterator i = a.begin(), end = a.end(); i != end; ++i) {
        const AigList &c = i->second;
        for (size_t j = 0, k = 1; k < c.size(); ++k) {
            if (b.repr(c[j]) == b.repr(c[k])) {
                out.merge(c[j], c[k]);
            } else {
                j = k;
            }
        }
    }
}


// One simulated state: the value of every state variable of the model, in
// the order of Model::statevars().
typedef std::vector<ModelSim::value> State;

// Builds in `out` the partition induced by `state`: every state variable
// whose value is TRUE is merged with the constant true, every one whose value
// is FALSE with the constant false. Variables with any other value are not
// registered.
void init_uf(Model *model, const State &state, UnionFind &out)
{
    out.clear();

    const Aig t = model->aig_manager()->aig_true();
    const Aig f = model->aig_manager()->aig_false();

    const AigList &vars = model->statevars();
    for (size_t idx = 0; idx < vars.size(); ++idx) {
        switch (state[idx]) {
        case ModelSim::TRUE:
            out.merge(t, vars[idx]);
            break;
        case ModelSim::FALSE:
            out.merge(f, vars[idx]);
            break;
        default:
            // an unknown value does not induce any equivalence
            break;
        }
    }
}


// Number of nodes that are merged into some other node's class in `uf`,
// i.e. the sum over all classes of (class size - 1).
//
// Every class is counted, not only those of the constants true/false:
// intersect_uf can produce classes of state variables that are equal to each
// other without being constant, and those are substituted too. Counting only
// the constant classes would report "no equivalences" for such partitions.
size_t uf_size(Model *model, const UnionFind &uf)
{
    (void)model; // no longer needed; kept for interface compatibility

    size_t merged = 0;
    for (UnionFind::iterator it = uf.begin(), end = uf.end(); it != end;
         ++it) {
        const AigList &members = it->second;
        if (!members.empty()) {
            merged += members.size() - 1;
        }
    }
    return merged;
}


struct StateIdx_hash_eq {
    StateIdx_hash_eq(const std::vector<State> &trace):
        trace_(trace) {}

    size_t operator()(size_t idx) const
    {
        const State &s = trace_[idx];
        size_t ret = 0;

        for (size_t i = 0; i < s.size(); ++i) {
            ret = 5 * ret + size_t(s[i]);
        }

        return ret;
    }

    bool operator()(size_t i1, size_t i2) const
    {
        const State &s1 = trace_[i1];
        const State &s2 = trace_[i2];

        assert(s1.size() == s2.size());

        for (size_t i = 0; i < s1.size(); ++i) {
            if (s1[i] != s2[i]) {
                return false;
            }
        }
        return true;
    }

    const std::vector<State> &trace_;
};


std::string state2str(const State &s)
{
    std::ostringstream out;
    for (size_t i = 0; i < s.size(); ++i) {
        switch (s[i]) {
        case ModelSim::TRUE: out << '1'; break;
        case ModelSim::FALSE: out << '0'; break;
        default: out << 'x';
        }
    }
    return out.str();
}


// ModelUnroll in which time frame 0 of every state variable is bound to that
// variable's initial value. unroll(a) unrolls `a` at the fixed depth given
// at construction (presumably the value of `a` after that many steps from
// the initial states -- see ModelUnroll::unroll to confirm).
class TDUnroll: public ModelUnroll {
    typedef ModelUnroll Super;
public:
    TDUnroll(Model *model, int depth):
        Super(model),
        depth_(depth)
    {
        const AigList &vars = model->statevars();
        for (AigList::const_iterator it = vars.begin(), end = vars.end();
             it != end; ++it) {
            Aig var = *it;
            cache_[Key(var, 0)] = model->init_value(var);
        }
    }

    Aig unroll(Aig a)
    {
        return Super::unroll(a, depth_);
    }

private:
    int depth_;
};


} // namespace


//-----------------------------------------------------------------------------
// TemporalDecomp
//-----------------------------------------------------------------------------

// Stores the options and the (non-owned) input model; no simplified model
// exists until simplify() is called.
TemporalDecomp::TemporalDecomp(Options opts, Model *model):
    opts_(opts),
    model_(model),
    max_td_depth_(opts_.simpl_max_td_depth),
    max_sim_depth_(opts_.simpl_max_sim_depth),
    verb_(opts_.verbosity)
{
    simplmodel_ = NULL;
    unroll_depth_ = 0;
}


// Releases the simplified model built by the last call to simplify(), if
// any (deleting a null pointer is a no-op).
TemporalDecomp::~TemporalDecomp()
{
    delete simplmodel_;
}


/**
 * Builds a simplified copy of the model by temporal decomposition.
 *
 * The model is simulated with three-valued values (TRUE/FALSE/UNDEF), all
 * inputs UNDEF, from its initial values until a state repeats (a cycle),
 * an all-UNDEF state is reached, or max_sim_depth_ steps are done.
 * Equalities among state variables and the constants that hold in every
 * state of the cycle are used to merge state variables. The new initial
 * states are obtained by unrolling the original model unroll_depth_ steps,
 * where unroll_depth_ is the start of the cycle (possibly moved earlier when
 * the equalities already hold there), raised to opts_.simpl_min_td_depth
 * when that option is non-zero.
 *
 * @return the simplified model, owned by this object (valid until the next
 *         call to simplify() or destruction), or NULL if no simplification
 *         is possible.
 */
Model *TemporalDecomp::simplify()
{
    if (simplmodel_) {
        delete simplmodel_;
    }
    simplmodel_ = new Model();
    
    unroll_depth_ = 0;
    int min_depth = opts_.simpl_min_td_depth;
    
    // first, detect equivalences
    std::vector<ModelSim::value> inp(model_->inputs().size(), ModelSim::UNDEF);

    std::vector<State> trace;
    trace.push_back(State());
    State *s = &(trace.back());
    Aig tt = model_->aig_manager()->aig_true();
    Aig ff = model_->aig_manager()->aig_false();

    // state 0: constant initial values become TRUE/FALSE, others UNDEF
    for (AigList::const_iterator it = model_->statevars().begin(),
             end = model_->statevars().end(); it != end; ++it) {
        Aig a = *it;
        Aig i = model_->init_value(a);
        if (i == tt) {
            s->push_back(ModelSim::TRUE);
        } else if (i == ff) {
            s->push_back(ModelSim::FALSE);
        } else {
            s->push_back(ModelSim::UNDEF);
        }
    }

    // states are identified by their index in `trace`; the functor hashes
    // and compares the states those indices refer to
    StateIdx_hash_eq he(trace);
    typedef HashMap<size_t,size_t, StateIdx_hash_eq, StateIdx_hash_eq> StateMap;
    StateMap smap(0, he, he);
    smap[0] = 0;

    ModelSim sim(model_);
    sim.init(*s);

    if (verb_ > 1) {
        std::cout << "simpl: reached state " << state2str(*s) << std::endl;
    }
    
    size_t c_start = 0;
    size_t c_end = 0;
    bool c_found = false;
    
    for (int n = 1; n < max_sim_depth_ && !c_found; ++n) {
        trace.push_back(State());
        sim.step(inp);

        // re-fetch: push_back may have reallocated `trace`
        s = &(trace.back());
        bool good_state = false;
        for (AigList::const_iterator it = model_->statevars().begin(),
                 end = model_->statevars().end(); it != end; ++it) {
            Aig a = *it;
            ModelSim::value v = sim.get(a);
            s->push_back(v);
            if (v != ModelSim::UNDEF) {
                good_state = true;
            }
        }

        if (!good_state) {
            if (verb_) {
                std::cout << "simpl: reached all X state, no simplification "
                          << "possible" << std::endl;
            }
            break;
        } else if (verb_ > 1) {
            std::cout << "simpl: reached state " << state2str(*s) << std::endl;
        }

        size_t idx = trace.size()-1;
        StateMap::iterator it = smap.find(idx);
        if (it != smap.end()) {
            // the state at step n was already seen at step it->second
            c_start = it->second;
            c_end = n;
            c_found = true;
            break;
        } else {
            smap[idx] = idx;
        }
    }

    if (!c_found) {
        if (!min_depth) {
            return NULL; // no simplification possible
        }
    } else if (verb_) {
        std::cout << "simpl: found cycle from " << c_start << " to "
                  << c_end << std::endl;
    }

    // find all the equivalences holding in the recurring part of the model
    UnionFind equivs, cur, intersection;

    if (c_found) {
        init_uf(model_, trace[c_start], equivs);
    
        for (size_t i = c_start+1; i <= c_end; ++i) {
            init_uf(model_, trace[i], cur);
            intersect_uf(equivs, cur, intersection);
            equivs.swap(intersection);
        }

        // make sure we do not unroll too many cycles
        size_t limit = max_td_depth_;
        while (c_end > limit) {
            --c_end;
            init_uf(model_, trace[c_end], cur);
            intersect_uf(equivs, cur, intersection);
            equivs.swap(intersection);
        }
        if (c_end < c_start) {
            c_start = c_end;
        }

        // compute the substitution map
        size_t num_subst = uf_size(model_, equivs);
        if (!num_subst) {
            if (verb_) {
                std::cout << "simpl: no equivalences found" << std::endl;
            }
            if (!min_depth) {
                return NULL; // nothing to simplify
            }
        } else if (verb_) {
            std::cout << "simpl: found " << num_subst << " equivalences"
                      << std::endl;
        }
    
        // check whether the equivalences hold at earlier steps
        size_t c_cand = c_start;
        while (c_cand > 0) {
            --c_cand;
            bool good = true;
            init_uf(model_, trace[c_cand], cur);

            for (UnionFind::iterator it = equivs.begin(), end = equivs.end();
                 it != end && good; ++it) {
                Aig r = cur.repr(it->first);
                const AigList &c = it->second;
                for (size_t j = 0; j < c.size(); ++j) {
                    if (r != cur.repr(c[j])) {
                        good = false;
                        break;
                    }
                }
            }
        
            if (!good) {
                ++c_cand;
                break;
            }
        }

        if (c_cand < c_start) {
            if (verb_) {
                std::cout << "simpl: equivalences hold also until " << c_cand
                          << std::endl;
            }
            c_start = c_cand;
        }

        unroll_depth_ = c_start;

        if (min_depth && unroll_depth_ < min_depth) {
            unroll_depth_ = min_depth;
        }
    } else {
        assert(min_depth > 0);
        unroll_depth_ = min_depth;
    }
        
    // The new initial state is the unrolling up to unroll_depth_ steps. The
    // equivalences hold from c_start on, hence also at any later depth.
    // The new transition function is the old one simplified with subst.
    Model tmpmodel;
    SubstMap subst;
    AigManager *mgr = tmpmodel.aig_manager();

    subst[model_->aig_manager()->aig_true()] = mgr->aig_true();
    subst[model_->aig_manager()->aig_false()] = mgr->aig_false();
    for (UnionFind::iterator it = equivs.begin(), end = equivs.end();
         it != end; ++it) {
        Aig a = it->first;
        if (subst.find(a) == subst.end()) {
            Aig vv = mgr->aig_var(AigManager::aig_get_var(a));
            subst[a] = vv;
        }
    }

    // create inputs
    for (AigList::const_iterator it = model_->inputs().begin(),
             end = model_->inputs().end(); it != end; ++it) {
        Aig v = *it;
        Aig vv = mgr->aig_var(AigManager::aig_get_var(v));
        tmpmodel.add_aiger_input(vv);
        subst[v] = vv;
    }
    
    // compute latches: every state variable is mapped to (the new variable
    // of) the representative of its equivalence class
    const AigList &sv = model_->statevars();
    for (size_t i = 0; i < sv.size(); ++i) {
        Aig a = sv[i];
        Aig r = equivs.repr(a);
        if (a != r) {
            assert(subst.find(r) != subst.end());
            subst[a] = subst[r];
        } else {
            Aig vv = mgr->aig_var(AigManager::aig_get_var(a));
            subst[a] = vv;
        }
    }
    tmpmodel.set_highest_var(model_->highest_var());
    
    AigList init;
    AigList next;
    TDUnroll un(model_, int(unroll_depth_));
    SubstMap cache;
    cache[model_->aig_manager()->aig_true()] = mgr->aig_true();
    
    for (size_t i = 0; i < sv.size(); ++i) {
        Aig a = sv[i];
        Aig ua = un.unroll(a);
        // The unrolling may create fresh variables in model_; register them
        // as inputs of the new model. Variables that already have an entry
        // in subst (original inputs and state variables, or fresh variables
        // registered in a previous iteration) are skipped, so that nothing
        // is registered twice and no existing substitution is overwritten.
        for (int v = tmpmodel.highest_var(); v <= model_->highest_var();
             ++v) {
            Aig iv = model_->aig_manager()->aig_var(v);
            if (subst.find(iv) != subst.end()) {
                continue;
            }
            Aig siv = mgr->aig_var(v);
            subst[iv] = siv;
            tmpmodel.add_aiger_input(siv);
        }
        tmpmodel.set_highest_var(model_->highest_var());
        
        ua = apply_subst(mgr, ua, cache);
        init.push_back(ua);
        Aig f = apply_subst(mgr, model_->next_value(a), subst);
        next.push_back(f);
    }

    for (size_t i = 0; i < sv.size(); ++i) {
        Aig a = sv[i];
        Aig aa = mgr->aig_var(AigManager::aig_get_var(a));
        tmpmodel.add_aiger_latch(aa, init[i], next[i]);
    }

    for (AigList::const_iterator it = model_->properties().begin(),
             end = model_->properties().end(); it != end; ++it) {
        Aig p = *it;
        p = apply_subst(mgr, p, subst);
        tmpmodel.add_aiger_output(p);
    }

    // apply cone-of-influence reduction again
    apply_coi(tmpmodel, *simplmodel_);

    if (verb_) {
        std::cout << "simpl: done ("
                  << model_->statevars().size() << ", "
                  << model_->inputs().size() << ") -> ("
                  << simplmodel_->statevars().size() << ", "
                  << simplmodel_->inputs().size() << ")"
                  << std::endl;
    }

    return simplmodel_;
}


// Depth computed by the last call to simplify() (0 before any call).
// get_trace() uses it as the BMC bound when rebuilding the prefix of a trace.
size_t TemporalDecomp::get_unroll_depth() const
{
    const size_t depth = unroll_depth_;
    return depth;
}


// Rebuilds the AIG rooted at `a` inside `mgr` and returns the image of `a`.
//
// `cache` maps source nodes, in the form returned by aig_var() (apparently
// the sign-stripped node, since signs are re-applied separately with
// aig_lit/aig_sign), to their images in `mgr`. Entries present on entry act
// as a substitution: those nodes are not traversed further. Any other
// non-AND node is mapped to the variable of `mgr` with the same index. AND
// nodes are rebuilt with mgr->aig_and on the images of their children. All
// new mappings are stored in `cache`, so a cache shared across calls reuses
// earlier work.
//
// The traversal is an iterative post-order DFS over an explicit stack: an
// AND node is completed only once both children have an image in `cache`.
//
// NOTE(review): child signs are re-applied with aig_lit(image, sign). If an
// image can itself be a negated literal (e.g. if mgr->aig_and simplifies an
// AND to a negated child), this is correct only if aig_lit combines signs
// instead of overwriting them -- confirm against aig_lit's definition.
Aig TemporalDecomp::apply_subst(AigManager *mgr, Aig a, SubstMap &cache)
{
    AigList to_process;
    to_process.push_back(aig_var(a));

    while (!to_process.empty()) {
        Aig cur = to_process.back();
        // already mapped (pre-seeded substitution or done earlier)
        if (cache.find(cur) != cache.end()) {
            to_process.pop_back();
            continue;
        }

        if (AigManager::aig_is_and(cur)) {
            Aig l = AigManager::aig_get_left(cur);
            Aig r = AigManager::aig_get_right(cur);

            Aig vl = aig_var(l);
            Aig vr = aig_var(r);

            // schedule the children that still lack an image; `cur` stays on
            // the stack and is revisited after them
            bool children_done = true;
            if (cache.find(vr) == cache.end()) {
                children_done = false;
                to_process.push_back(vr);
            }

            if (cache.find(vl) == cache.end()) {
                children_done = false;
                to_process.push_back(vl);
            }

            if (children_done) {
                to_process.pop_back();

                vl = cache[vl];
                vr = cache[vr];

                // rebuild the AND on the children's images, restoring the
                // signs of the original edges
                Aig res = mgr->aig_and(aig_lit(vl, aig_sign(l)),
                                       aig_lit(vr, aig_sign(r)));
                cache[cur] = res;
            }
        } else {
            // unmapped leaf: same-index variable in the target manager
            to_process.pop_back();
            Aig v = mgr->aig_var(AigManager::aig_get_var(cur));
            cache[cur] = v;
        }
    }

    assert(cache.find(aig_var(a)) != cache.end());
    
    // restore the sign of the requested literal
    Aig ret = cache[aig_var(a)];
    ret = aig_lit(ret, aig_sign(a));

    return ret;
}


// Copies into `dst` the cone of influence (COI) of the properties of `src`.
//
// The COI is every node reachable from a property through AND gates and,
// for state variables, through both their next-state function and their
// initial-value expression. The initial values must be followed too: if an
// in-COI latch's initial value mentions another latch, that latch's
// definition has to be kept, otherwise the set of initial states changes.
//
// Every state variable of `src` is still declared in `dst`. Those outside
// the COI get initial and next value false (presumably to keep the variable
// numbering of `src`). All inputs of `src` are declared in `dst`.
void TemporalDecomp::apply_coi(Model &src, Model &dst)
{
    AigList to_process = src.properties();
    HashSet<Aig> coi;

    while (!to_process.empty()) {
        Aig cur = AigManager::aig_strip(to_process.back());
        to_process.pop_back();

        if (!coi.insert(cur).second) {
            continue;
        }

        if (src.is_statevar(cur)) {
            to_process.push_back(src.next_value(cur));
            to_process.push_back(src.init_value(cur));
        } else if (AigManager::aig_is_and(cur)) {
            to_process.push_back(AigManager::aig_get_left(cur));
            to_process.push_back(AigManager::aig_get_right(cur));
        }
    }

    SubstMap cache;
    cache[src.aig_manager()->aig_true()] = dst.aig_manager()->aig_true();

    dst.set_highest_var(src.highest_var());
    for (AigList::const_iterator it = src.statevars().begin(),
             end = src.statevars().end(); it != end; ++it) {
        Aig v = *it;
        if (coi.find(v) != coi.end()) {
            Aig vv = dst.aig_manager()->aig_var(AigManager::aig_get_var(v));
            Aig i = apply_subst(dst.aig_manager(), src.init_value(v), cache);
            Aig n = apply_subst(dst.aig_manager(), src.next_value(v), cache);
            dst.add_aiger_latch(vv, i, n);
        } else {
            Aig vv = dst.aig_manager()->aig_var(AigManager::aig_get_var(v));
            Aig f = dst.aig_manager()->aig_false();
            dst.add_aiger_latch(vv, f, f);
        }
    }

    for (AigList::const_iterator it = src.properties().begin(),
             end = src.properties().end(); it != end; ++it) {
        Aig o = *it;
        Aig oo = apply_subst(dst.aig_manager(), o, cache);
        dst.add_aiger_output(oo);
    }

    for (AigList::const_iterator it = src.inputs().begin(),
             end = src.inputs().end(); it != end; ++it) {
        Aig v = *it;
        Aig vv = dst.aig_manager()->aig_var(AigManager::aig_get_var(v));
        dst.add_aiger_input(vv);
    }
}


void TemporalDecomp::get_trace(const std::vector<AigList> &trace,
                               std::vector<AigList> &out)
{
    assert(!trace.empty());

    out.clear();
    // first, add the initial steps
    size_t idx = 0;
    if (unroll_depth_ > 0) {
        Options opts;
        opts.solver = opts_.solver;
        opts.verbosity = 0;
        opts.bmc_mindepth = 0;
        opts.bmc_maxdepth = unroll_depth_;

        BMC bmc(opts, model_, 0);
        AigList target;
        for (size_t i = 0; i < trace[0].size(); ++i) {
            Aig a = trace[0][i];
            Aig v = AigManager::aig_strip(a);
            Aig vv = model_->aig_manager()->aig_var(AigManager::aig_get_var(v));
            if (model_->is_statevar(vv) || model_->is_inputvar(vv)) {
                target.push_back(aig_lit(vv, aig_sign(a)));
            }
        }
        bool ok = bmc.check_reachable(target);
        assert(ok);
        ok = bmc.get_counterexample_trace(&out);
        assert(ok);
        idx = 1;
    }

    for (size_t i = idx; i < trace.size(); ++i) {
        out.push_back(AigList());
        AigList &l = out.back();
        for (size_t j = 0; j < trace[i].size(); ++j) {
            Aig a = trace[i][j];
            Aig v = model_->aig_manager()->aig_var(AigManager::aig_get_var(a));
            if (model_->is_statevar(v) || model_->is_inputvar(v)) {
                l.push_back(aig_lit(v, aig_sign(a)));
            }
        }
    }
}


} // namespace simplic3
