// -*- C++ -*-
// 
// SimplIC3: a "simple" implementation of IC3 (and other SAT-based algorithms)
// for finite-state functional transition systems
//
// K-induction engine
//
// Author: Alberto Griggio <griggio@fbk.eu>
// See LICENSE.txt for copyright/licensing information
// See CREDITS.txt for other credits
//

#include "kinduction.h"
#include <sstream>
#include <iomanip>

namespace simplic3 {

/**
 * Creates a k-induction prover for property number prop_idx of model.
 *
 * The options are copied; the depth bound and verbosity are taken from
 * them. All statistics start at zero, and no SAT solver factory is
 * selected until prove() runs.
 */
KInduction::KInduction(Options opts, Model *model, size_t prop_idx)
    : opts_(opts), model_(model), prop_idx_(prop_idx),
      un_(model_), cb_(NULL)
{
    // statistics
    total_time_ = 0.0;
    clausify_time_ = 0.0;
    base_time_ = 0.0;
    step_time_ = 0.0;
    num_simple_path_constraints_ = 0;
    reached_depth_ = 0;

    // configuration derived from the (already copied) options
    max_k_ = opts_.bmc_maxdepth;
    verb_ = opts_.verbosity;

    // looked up in prove()
    factory_ = NULL;
}


// Nothing to release here: the SAT solver created by prove() is deleted
// by prove() itself before it returns.
KInduction::~KInduction() {}


void KInduction::set_safe_bound_callback(const SafeBoundCallback *cb)
{
    cb_ = cb;
}


/**
 * Checks property number prop_idx_ of the model with k-induction, for
 * k = 0, 1, 2, ... up to opts_.bmc_maxdepth (no bound if that value is
 * not positive).
 *
 * A single incremental SAT solver is used throughout. For each depth k:
 *  - base case: with the initial-state constraints enabled, ask whether
 *    the property can fail at step k. SAT means the property is
 *    falsified, and the satisfying assignment is turned into cex_trace_.
 *    UNSAT means the property holds up to step k, and the safe-bound
 *    callback (if any) is notified with k.
 *  - step case: with the initial-state constraints disabled, ask whether
 *    the property can fail at step k. While the solver returns SAT and
 *    the model contains two steps i < j <= k whose state variables are
 *    not known to differ, a simple-path constraint "state_i != state_j"
 *    is added and the query is repeated. UNSAT means the property is
 *    verified. Otherwise the property at step k is asserted for all
 *    later queries and depth k+1 is tried.
 *
 * Statistics updated: reached_depth_, num_simple_path_constraints_ and
 * the timers total_time_, clausify_time_, base_time_ and step_time_
 * (base_time_ and step_time_ measure SAT solving only).
 *
 * @return FALSIFIED, VERIFIED, or UNKNOWN if the depth bound is reached
 *         without a conclusion.
 */
Prover::Result KInduction::prove()
{
    TimeKeeper t(total_time_);

    factory_ = SatSolverFactory::get(opts_.solver);
    SatSolver *solver = factory_->new_satsolver();

    cex_trace_.clear();

    Result done = UNKNOWN;

    // Everything that uses the CNF converter lives in this scope. `cnf` is
    // constructed with a pointer to `solver`, so it must be destroyed
    // before the solver is deleted below.
    {
        CnfConv cnf(model_, solver, opts_.cnf_simple, true);

        SatLitList cls;
        SatLitList assumptions;

        // Activation literal for the initial-state constraints. Each of
        // them is added as (~init_lbl | init_v), so it is enforced only when
        // init_lbl is assumed true (base case). It is vacuous when init_lbl
        // is assumed false (step case).
        SatLit init_lbl = SatLit(solver->new_var());
        solver->set_frozen(var(init_lbl));

        // send initial states
        {
            TimeKeeper tk(clausify_time_);

            for (AigList::const_iterator it = model_->statevars().begin(),
                     end = model_->statevars().end(); it != end; ++it) {
                Aig v = *it;
                Aig a = un_.unroll(model_->init_formula(v), 0);
                SatLit l = cnf.clausify(a);
                solver->set_frozen(var(l));

                cls.clear();
                cls.push_back(~init_lbl);
                cls.push_back(l);
                solver->add_clause(cls);
            }
        }

        reached_depth_ = 0;

        for (int k = 0; (max_k_ > 0 ? (k <= max_k_) : true) && done == UNKNOWN;
             ++k) {
            reached_depth_ = k;

            if (verb_ > 0) {
                std::cout << "checking bound " << k << std::endl;
            }
            SatLit proplbl = satLit_Undef;
            {
                TimeKeeper tk(clausify_time_);
                Aig a = un_.unroll(model_->properties()[prop_idx_], k);
                proplbl = cnf.clausify(a);
            }

            // base case: initial states enabled, property violated at step k
            assumptions.clear();
            assumptions.push_back(init_lbl);
            assumptions.push_back(~proplbl);
            bool sat = false;
            {
                TimeKeeper tk(base_time_);
                sat = solver->solve(assumptions);
            }

            if (sat) {
                done = FALSIFIED;
            } else {
                if (cb_) {
                    (*cb_)(k);
                }

                // step case: same query with the initial states disabled
                assumptions[0] = ~init_lbl;
                while (true) {
                    // Time only the SAT call here, as for base_time_. The
                    // constraint generation below is accounted to
                    // clausify_time_ alone, instead of to both timers.
                    {
                        TimeKeeper tk(step_time_);
                        sat = solver->solve(assumptions);
                    }
                    if (!sat) {
                        break;
                    }

                    bool added = false;
                    {
                        TimeKeeper tk(clausify_time_);

                        for (int i = 0; i < k; ++i) {
                            for (int j = i+1; j <= k; ++j) {
                                // stays true while the model does not show
                                // states i and j to be different
                                bool addcur = true;
                                // conjunction of (v@i <-> v@j) over the
                                // state variables visited so far
                                Aig uc = model_->aig_manager()->aig_true();

                                // Stop at the first state variable whose
                                // values at i and j differ in the model. uc
                                // is only needed when there is no such
                                // variable, so building the rest of it
                                // would just create unused AIG nodes.
                                for (AigList::const_iterator it =
                                         model_->statevars().begin(),
                                         end = model_->statevars().end();
                                     addcur && it != end; ++it) {
                                    Aig a = *it;
                                    Aig ai = un_.unroll(a, i);
                                    Aig aj = un_.unroll(a, j);
                                    Aig eq =
                                        model_->aig_manager()->aig_iff(ai, aj);
                                    uc = model_->aig_manager()->aig_and(uc, eq);

                                    SatLit li = cnf.lookup(ai);
                                    SatLit lj = cnf.lookup(aj);
                                    if (li != satLit_Undef &&
                                        lj != satLit_Undef) {
                                        SatValue vi =
                                            solver->get_model_value(li);
                                        SatValue vj =
                                            solver->get_model_value(lj);
                                        if (vi != sat_Undef &&
                                            vj != sat_Undef && vi != vj) {
                                            addcur = false;
                                        }
                                    }
                                }

                                if (addcur) {
                                    // the model repeats a state: forbid it
                                    // by adding the unit clause ~uc
                                    SatLit l = cnf.clausify(uc);
                                    solver->set_frozen(var(l));
                                    cls.clear();
                                    cls.push_back(~l);
                                    solver->add_clause(cls);
                                    added = true;
                                    ++num_simple_path_constraints_;
                                }
                            }
                        }
                    }

                    if (!added) {
                        break;
                    }
                }

                if (!sat) {
                    done = VERIFIED;
                } else {
                    // The base case showed the property cannot fail at step
                    // k from the initial states, so assert it for the later
                    // checks.
                    cls.clear();
                    cls.push_back(proplbl);
                    solver->add_clause(cls);
                }
            }
        }

        if (done == FALSIFIED) {
            // read the counterexample off the model of the last (base-case)
            // query, one entry per step 0..reached_depth_
            for (size_t i = 0; i <= reached_depth_; ++i) {
                cex_trace_.push_back(AigList());
                AigList &c = cex_trace_.back();

                for (AigList::const_iterator it = model_->statevars().begin(),
                         end = model_->statevars().end(); it != end; ++it) {
                    Aig a = *it;
                    Aig ua = un_.unroll(a, i);
                    SatLit l = cnf.lookup(ua);
                    if (l != satLit_Undef) {
                        SatValue val = solver->get_model_value(l);
                        if (val != sat_Undef) {
                            c.push_back(aig_lit(a, val == sat_False));
                        }
                    }
                }

                for (AigList::const_iterator it = model_->inputs().begin(),
                         end = model_->inputs().end(); it != end; ++it) {
                    Aig a = *it;
                    Aig ua = un_.unroll(a, i);
                    SatLit l = cnf.lookup(ua);
                    if (l != satLit_Undef) {
                        SatValue val = solver->get_model_value(l);
                        if (val != sat_Undef) {
                            c.push_back(aig_lit(a, val == sat_False));
                        }
                    }
                }

                sort(c, aig_lt);
            }
        }
    }

    delete solver;

    return done;
}


/**
 * Copies the counterexample trace produced by the last prove() call into
 * *out. The trace has one list of state-variable and input literals per
 * step.
 *
 * @return true if a trace was available (prove() returned FALSIFIED),
 *         false otherwise, in which case *out is left untouched.
 */
bool KInduction::get_counterexample_trace(std::vector<AigList> *out)
{
    const bool have_trace = !cex_trace_.empty();
    if (have_trace) {
        *out = cex_trace_;
    }
    return have_trace;
}


/**
 * Returns the statistics of the prover as (name, value) pairs. Values are
 * rendered as text, with floating-point values in fixed notation with
 * three decimals. The last entry is the current memory usage in MB.
 */
Stats KInduction::get_stats()
{
    Stats ret;
    std::ostringstream buf;
    // std::fixed and std::setprecision are sticky stream settings, so the
    // stream is configured once for all the entries below
    buf << std::setprecision(3) << std::fixed;

#define KIND_STAT(label, value)                                 \
    do {                                                        \
        buf.str("");                                            \
        buf << (value);                                         \
        ret.push_back(std::make_pair(label, buf.str()));        \
    } while (0)

    KIND_STAT("reached_depth", reached_depth_);
    KIND_STAT("num_simple_path_constraints", num_simple_path_constraints_);
    KIND_STAT("clausify_time", clausify_time_);
    KIND_STAT("base_time", base_time_);
    KIND_STAT("step_time", step_time_);
    KIND_STAT("prove_time", total_time_);

    const size_t megabytes = get_mem_used_bytes() / (1024 * 1024);
    KIND_STAT("memory_used_mb", megabytes);

#undef KIND_STAT

    return ret;
}


} // namespace simplic3
