// -*- C++ -*-
// 
// SimplIC3: a "simple" implementation of IC3 (and other SAT-based algorithms)
// for finite-state functional transition systems
//
// IC3 with lazy abstraction
//
// Author: Alberto Griggio <griggio@fbk.eu>
// See LICENSE.txt for copyright/licensing information
// See CREDITS.txt for other credits
//

#include "aic3.h"
#include <assert.h>
#include <sstream>
#include <iomanip>
#include <math.h>

namespace simplic3 {

// Build an IC3 engine with lazy abstraction on top of the plain IC3 engine.
//
// The user's SAT-preprocessing choice is saved in bmc_ref_preprocess_ (it is
// consulted by refine_cex_bmc) and then switched off in opts_ for the rest
// of the engine. The refinement statistics start at zero.
AbsIC3::AbsIC3(Options opts, Model *model, size_t prop_idx):
    Super(opts, model, prop_idx)
{
    // refinement statistics
    refinement_time_ = 0.0;
    num_refinements_ = 0;

    // remember the preprocessing flag before disabling it in opts_
    bmc_ref_preprocess_ = opts_.sat_preprocess;
    opts_.sat_preprocess = false;

    opts_.max_trans_solvers = -1; // TODO
}


// Destructor: no explicit cleanup is performed here.
AbsIC3::~AbsIC3()
{
    // intentionally empty
}


void AbsIC3::new_frame()
{
    Super::new_frame();

    if (absvars_.empty()) {
        absvars_.push_back(AigSet());
    } else {
        absvars_.push_back(absvars_.back());
    }
}


void AbsIC3::get_next(const Cube &c, Cube *out, int time)
{
    IC3Solver &s = get_trans_solver(time);

    for (size_t i = 0; i < c.size(); ++i) {
        Aig v = aig_var(c[i]);
        Aig n = model_->next(v);
        
        if (absvars_[time].find(v) != absvars_[time].end()) {
            s.cnf->clausify(n);
            
            Aig f = model_->next_value(v);
            Aig e = model_->aig_manager()->aig_iff(n, f);
            SatLit l = s.cnf->lookup(e);
            if (l == satLit_Undef) {
                l = s.cnf->clausify(e);
                s.solver->add_clause(l);
            }
        }

        SatLit l = s.cnf->lookup(n);
        if (l != satLit_Undef) {
            Aig nl = aig_lit(n, aig_sign(c[i]));
            out->push_back(nl);
        } else {
            out->push_back(model_->aig_manager()->aig_true());
        }
    }
}


// Decide whether the top obligation of `q` is a (real) counterexample.
//
// Only obligations at time 0 are considered. Those longer than depth()+1
// are "may" obligations: the queue is cleared and false is returned. For
// the others a refinement is attempted. If it fails, the counterexample is
// real and true is returned. If it succeeds, the queue is cleared and
// false is returned.
bool AbsIC3::is_cex(ProofObligationQueue &q)
{
    ProofObligation *top = q.top();

    if (top->get_time() != 0) {
        return false;
    }

    if (top->length() > depth()+1) {
        // this is a "may" obligation, simply ignore it
        q.clear();
        return false;
    }

    if (!refine_cex(top)) {
        // refinement impossible: the counterexample is real
        return true;
    }

    q.clear();
    return false;
}


// Try to refine the abstraction so that the abstract counterexample ending
// in `s` is ruled out. The strategy is picked by opts_.aic3_ref_bmc: BMC
// when set, a concrete IC3 run otherwise. Returns false when no refinement
// was possible (the counterexample is real).
bool AbsIC3::refine_cex(ProofObligation *s)
{
    return opts_.aic3_ref_bmc ? refine_cex_bmc(s) : refine_cex_ic3(s);
}


// Refine the abstraction by running a concrete IC3 instance (conc_ic3).
// conc_ic3 uses max_trans_solvers = 1 and verbosity 0, and has the same
// depth as this engine (asserted).
//
// conc_ic3 is seeded with every active cube of frames_[1..], at the same
// frame index. If conc_ic3 fails to block some bad cube at its top frame,
// the counterexample is real. Its trace is then copied into cex_trace_ and
// false is returned.
//
// Otherwise, consider each frame i of conc_ic3 that received new clauses,
// from the highest frame down. An unsat core of
//     F[i-1] /\ T /\ ~F[i]'
// is computed, with T split into one assumption per state variable. The
// state variables whose assumption is in the core become concrete in
// absvars_[k] for every k >= i-1. The new clauses of frame i are then also
// added to this engine's frame i. Returns true.
//
// The parameter `s` is not used by this strategy.
bool AbsIC3::refine_cex_ic3(ProofObligation *s)
{
    TimeKeeper t(refinement_time_);
    ++num_refinements_;
    
    Options opts = opts_;
    opts.max_trans_solvers = 1;
    opts.verbosity = 0;
    IC3 conc_ic3(opts, model_, prop_idx_);
    conc_ic3.use_gc_frame_ = false;

    // initialize the frames
    conc_ic3.initialize();
    for (size_t i = 0; i <= depth(); ++i) {
        conc_ic3.new_frame();
    }
    assert(conc_ic3.depth() == depth());

    // seed the concrete engine with the lemmas of the abstract one
    Cube c;
    for (size_t i = 1; i < frames_.size(); ++i) {
        Frame &f = frames_[i];
        for (size_t j = 0; j < f.size(); ++j) {
            FrameCube &fc = f[j];
            if (fc.is_active()) {
                c = fc.get_cube();
                conc_ic3.add_blocked_cube(c, i);
            }
        }
    }

    // remember the frame sizes, so that we can later tell which clauses
    // were added by the concrete run
    std::vector<size_t> sizes;
    sizes.reserve(conc_ic3.frames_.size());
    for (size_t i = 0; i < conc_ic3.frames_.size(); ++i) {
        sizes.push_back(conc_ic3.frames_[i].size());
    }

    while (conc_ic3.get_bad_cube(&c)) {
        if (!conc_ic3.rec_block_cube(c, conc_ic3.depth())) {
            // the counterexample is real, extract the trace
            cex_trace_.clear();
            bool ok = conc_ic3.get_counterexample_trace(&cex_trace_);
            assert(ok);
            // flag the trace as usable only if it was actually extracted
            // (matters when assertions are compiled out)
            cex_trace_ok_ = ok;
            return false;
        }
    }

    // in order to refine the abstraction, we need to check whether we have
    // added something to some of the frames
    IC3Solver rs;
    rs.init(model_, factory_, opts_.cnf_simple);

    // one assumption per state variable, guarding next(v) <-> next_value(v);
    // assumptions_[n] corresponds to model_->statevars()[n]
    assumptions_.clear();
    for (AigList::const_iterator it = model_->statevars().begin(),
             end = model_->statevars().end(); it != end; ++it) {
        Aig v = *it;
        SatLit l = rs.cnf->clausify(v);
        rs.solver->set_frozen(var(l));
        
        Aig n = model_->next(v);
        l = rs.cnf->clausify(n);
        rs.solver->set_frozen(var(l));

        Aig f = model_->next_value(v);
        Aig e = model_->aig_manager()->aig_iff(n, f);
        l = rs.cnf->clausify(e);
        assumptions_.push_back(l);
    }

    for (size_t j = conc_ic3.frames_.size(); j > 0; --j) {
        size_t i = j-1;
        Frame &f = conc_ic3.frames_[i];
        if (f.size() > sizes[i]) {
            // we have added clauses at time i-1, we need to refine the
            // corresponding trans
            assert(i > 0);

            Aig pre = get_frame_formula(conc_ic3.frames_, i-1, false);
            Aig cur = get_frame_formula(conc_ic3.frames_, i, true);

            SatLit l1 = rs.cnf->clausify(pre);
            SatLit l2 = rs.cnf->clausify(cur);

            assumptions_.push_back(l1);
            assumptions_.push_back(~l2);
            bool sat = rs.solver->solve(assumptions_);
            assert(!sat);
            (void)sat; // only checked by the assertion

            // fresh core for every frame, so that literals from a previous
            // iteration can never trigger a refinement here
            SatLitSet core;
            rs.solver->get_unsat_core(core);

            assumptions_.pop_back();
            assumptions_.pop_back();
            
            // all the vars in the unsat core are used for refinement (the
            // core is searched for the negated assumption literal, the same
            // convention used by refine_cex_bmc)
            for (size_t n = 0; n < model_->statevars().size(); ++n) {
                SatLit l = assumptions_[n];
                if (core.find(~l) != core.end()) {
                    Aig v = model_->statevars()[n];
                    for (size_t k = i-1; k < absvars_.size(); ++k) {
                        absvars_[k].insert(v);
                    }
                }
            }

            // finally, add the new clauses
            for (size_t k = sizes[i]; k < f.size(); ++k) {
                if (f[k].is_active()) {
                    c = f[k].get_cube();
                    add_blocked_cube(c, i);
                }
            }
        }
    }

    rs.destroy();

    return true;
}


// Refine the abstraction with a BMC check over the concrete transition
// relation, along the path of the abstract counterexample ending in `s`.
//
// The path is unrolled for s->length()-1 steps. At each step i:
// - the formula get_frame_formula(frames_, i, false) is asserted;
// - every equation next(v) <-> next_value(v) gets a (frozen) label literal.
// The query asks for a violation of the property in the last state (the
// property is unrolled first when opts_.prop_unroll > 0).
//
// Label literals are assumed from the last step backwards. After each batch
// of one literal per state variable the query is solved, stopping as soon
// as it becomes unsat.
//
// If the query is satisfiable, cex_trace_ok_ is set to false and false is
// returned: the counterexample cannot be refined away. Otherwise, every
// state variable whose label at step i is in the unsat core becomes
// concrete in absvars_[k] for all k >= i, and true is returned.
bool AbsIC3::refine_cex_bmc(ProofObligation *s)
{
    TimeKeeper t(refinement_time_);
    ++num_refinements_;

    // the unrolling length is s->length()-1: guard against underflow
    assert(s->length() > 0);
    const size_t nsteps = s->length() - 1;

    SatSolver *bmc = factory_->new_satsolver();

    typedef HashMap<Aig, SatVar> BmcMap;
    BmcMap vmap;

    // SAT variables for the state variables and inputs of step 0
    for (AigList::const_iterator it = model_->statevars().begin(),
             end = model_->statevars().end(); it != end; ++it) {
        Aig v = *it;
        SatVar sv = bmc->new_var();
        vmap[v] = sv;
    }
    for (AigList::const_iterator it = model_->inputs().begin(),
             end = model_->inputs().end(); it != end; ++it) {
        Aig v = *it;
        SatVar sv = bmc->new_var();
        vmap[v] = sv;
    }
    
    // one next-state equation per state variable (same order as statevars)
    AigList trans;
    SatLitList trans_labels;
    for (AigList::const_iterator it = model_->statevars().begin(),
             end = model_->statevars().end(); it != end; ++it) {
        Aig v = *it;
        Aig n = model_->next(v);
        Aig f = model_->next_value(v);
        Aig e = model_->aig_manager()->aig_iff(n, f);
        trans.push_back(e);
    }

    assumptions_.clear();

    // create the BMC path and setup assumptions
    for (size_t i = 0; i < nsteps; ++i) {
        CnfConv cnf(model_, bmc, opts_.cnf_simple, true);

        for (BmcMap::const_iterator it = vmap.begin(), end = vmap.end();
             it != end; ++it) {
            cnf.set_label(it->first, it->second);
        }

        Aig f = get_frame_formula(frames_, i, false);
        SatLit fl = cnf.clausify(f);
        bmc->add_clause(fl);
    
        // label the transition equations of this step; frozen so that
        // preprocessing cannot eliminate them
        for (size_t j = 0; j < trans.size(); ++j) {
            SatLit tl = cnf.clausify(trans[j]);
            bmc->set_frozen(var(tl));
            trans_labels.push_back(tl);
        }

        // the next-state variables of this step become the current-state
        // variables of the following one; inputs are fresh at every step
        for (AigList::const_iterator it = model_->statevars().begin(),
                 end = model_->statevars().end(); it != end; ++it) {
            Aig v = *it;
            Aig n = model_->next(v);
            SatLit nl = cnf.lookup(n);
            if (nl == satLit_Undef) {
                nl = SatLit(bmc->new_var());
            }
            assert(!sign(nl));
            vmap[v] = var(nl);
        }
        for (AigList::const_iterator it = model_->inputs().begin(),
                 end = model_->inputs().end(); it != end; ++it) {
            Aig v = *it;
            SatVar sv = bmc->new_var();
            vmap[v] = sv;
        }
    }

    // the property must be violated in the last state of the path
    {
        Aig prop = model_->properties()[prop_idx_];
        CnfConv cnf(model_, bmc, opts_.cnf_simple, true);

        for (BmcMap::const_iterator it = vmap.begin(), end = vmap.end();
             it != end; ++it) {
            cnf.set_label(it->first, it->second);
        }

        if (opts_.prop_unroll > 0) {
            prop = un_.unroll(prop, opts_.prop_unroll);
        }

        SatLit l = cnf.clausify(prop);
        bmc->add_clause(~l);
    }

    if (bmc_ref_preprocess_) {
        bmc->preprocess();
    }

    // add the labels one step at a time, starting from the end of the path
    bool sat = true;
    size_t nvars = model_->statevars().size();
    for (size_t i = trans_labels.size(), j = 0; i > 0; --i, ++j) {
        if (j == nvars) {
            sat = bmc->solve(assumptions_);
            j = 0;
            if (!sat) {
                break;
            }
        }
        
        SatLit l = trans_labels[i-1];
        assumptions_.push_back(l);
    }
    if (sat) {
        sat = bmc->solve(assumptions_);
    }

    if (sat) {
        // the path is feasible on the concrete system: no refinement
        // possible, and no trace is extracted here
        cex_trace_ok_ = false;
        delete bmc; // previously leaked on this path
        return false;
    }

    // otherwise, extract the used variables at each step
    SatLitSet core;
    bmc->get_unsat_core(core);

    size_t idx = 0;
    for (size_t i = 0; i < nsteps; ++i) {
        for (size_t j = 0; j < trans.size(); ++j, ++idx) {
            SatLit l = trans_labels[idx];
            if (core.find(~l) != core.end()) {
                Aig v = model_->statevars()[j];
                for (size_t k = i; k < absvars_.size(); ++k) {
                    absvars_[k].insert(v);
                }
            }
        }
    }

    delete bmc;

    return true;
}


// Build the AIG formula of frame `time` of `frames`.
//
// For time 0 the result is the conjunction of model_->init_formula(v) over
// all state variables; `next` must be false in that case.
//
// For any other time the result conjoins one clause per active cube in
// frames[time], frames[time+1], ..., up to the last frame. Each clause is
// the negation of its cube. When `next` is true, the clause is written over
// the next-state copies model_->next(v) of the cube's variables.
Aig AbsIC3::get_frame_formula(const FrameList &frames, int time, bool next)
{
    AigManager *mgr = model_->aig_manager();
    Aig result = mgr->aig_true();

    if (time == 0) {
        assert(!next);

        const AigList &svars = model_->statevars();
        for (size_t vi = 0; vi < svars.size(); ++vi) {
            Aig v = svars[vi];
            Aig init = model_->init_formula(v);
            result = mgr->aig_and(result, init);
        }
        return result;
    }

    for (size_t fi = time; fi < frames.size(); ++fi) {
        const Frame &frame = frames[fi];

        for (size_t ci = 0; ci < frame.size(); ++ci) {
            if (!frame[ci].is_active()) {
                continue;
            }

            // negation of the cube: disjunction of its complemented literals
            const Cube &cube = frame[ci].get_cube();
            Aig clause = mgr->aig_false();
            for (size_t li = 0; li < cube.size(); ++li) {
                Aig v = aig_var(cube[li]);
                if (next) {
                    v = model_->next(v);
                }
                clause = mgr->aig_or(clause, aig_lit(v, !aig_sign(cube[li])));
            }

            result = mgr->aig_and(result, clause);
        }
    }

    return result;
}


// Return the base-class statistics extended with the refinement statistics:
// number of refinements, refinement time, and min/max/average size of the
// per-frame sets of concrete variables (absvars_).
//
// The two trailing entries of the base-class stats are removed first and
// re-appended at the end, so they stay last. If absvars_ is still empty
// (no frame created yet), the size statistics are reported as 0 instead of
// reading past an empty container.
Stats AbsIC3::get_stats()
{
    Stats ret = Super::get_stats();
    std::pair<std::string, std::string> ms = ret.back();
    ret.pop_back();
    std::pair<std::string, std::string> pt = ret.back();
    ret.pop_back();

    size_t minvars = 0;
    size_t maxvars = 0;
    size_t avgvars = 0;
    if (!absvars_.empty()) {
        // the sets only grow with the frame index: new_frame copies the
        // previous set, and refinements insert into a whole suffix of
        // absvars_; so the first/last entries are the smallest/largest
        minvars = absvars_[0].size();
        maxvars = absvars_.back().size();
        size_t totvars = 0;
        for (size_t i = 0; i < absvars_.size(); ++i) {
            totvars += absvars_[i].size();
        }
        avgvars = size_t(round(double(totvars)/double(absvars_.size())));
    }
    
    std::ostringstream tmp;
    
#define ADDSTAT(name, val) do { tmp.str(""); \
        tmp << std::setprecision(3) << std::fixed << val;       \
        ret.push_back(std::make_pair(name, tmp.str()));         \
    } while (0)
    
    ADDSTAT("num_refinements", num_refinements_);
    ADDSTAT("refinement_time", refinement_time_);
    ADDSTAT("min_refinement_vars", minvars);
    ADDSTAT("max_refinement_vars", maxvars);
    ADDSTAT("avg_refinement_vars", avgvars);

#undef ADDSTAT

    ret.push_back(pt);
    ret.push_back(ms);
    
    return ret;
}


} // namespace simplic3
