// -*- C++ -*-
// 
// SimplIC3: a "simple" implementation of IC3 (and other SAT-based algorithms)
// for finite-state functional transition systems
//
// C API
//
// Author: Alberto Griggio <griggio@fbk.eu>
// See LICENSE.txt for copyright/licensing information
// See CREDITS.txt for other credits
//

#include "simplic3.h"
#include "prover.h"
#include "tdecomp.h"
#include <stdlib.h>
#include <string.h>
#include <sstream>

namespace {

using namespace simplic3;

// C function-pointer type accepted by simplic3_set_safe_bound_callback():
// it receives the integer reported by the prover plus the user-data pointer.
typedef void (*c_safe_cb)(int, void *);

// Adapts a plain C function pointer and an opaque user-data pointer to the
// Prover::SafeBoundCallback interface, so that C clients can be notified
// through the prover's safe-bound callback mechanism.
class CSafeBoundCallback: public Prover::SafeBoundCallback {
public:
    CSafeBoundCallback(c_safe_cb func, void *data)
        : fn_(func), user_data_(data)
    {
    }

    // Forwards the value passed by the prover to the stored C function,
    // together with the stored user data.
    void operator()(int bound) const
    {
        c_safe_cb target = fn_;
        target(bound, user_data_);
    }

private:
    c_safe_cb fn_;      // C callback to invoke
    void *user_data_;   // opaque pointer handed back to the callback
};


// State behind an opaque SimplIC3 C handle: the model being checked, the
// user-set options, and the results of the last prove() call.
//
// Ownership: `prover`, `cb` and `simpl` are heap objects owned by this
// struct and released in the destructor, so the struct must not be copied
// (a copy would delete them twice); copying is therefore disabled.
struct ProverState {
    Model model;
    Options opts;
    Prover *prover;             // prover used by the last prove() call, or NULL
    int status;                 // SAFE, UNSAFE, or -1 (not run / undecided)
    bool initialized;           // set by simplic3_init() once the model is loaded
    std::string tmpstr;         // storage for strings returned by simplic3_getopt()
    CSafeBoundCallback *cb;     // user callback, or NULL
    TemporalDecomp *simpl;      // decomposition used by the last run, or NULL

    enum { SAFE = 0, UNSAFE = 1 };

    // Initializers listed in declaration order (the order in which they run).
    ProverState():
        model(),
        opts(),
        prover(NULL),
        status(-1),
        initialized(false),
        tmpstr(),
        cb(NULL),
        simpl(NULL)
    {
    }

    ~ProverState()
    {
        // The prover may have been built on the model returned by
        // simpl->simplify(), so it is released before `simpl`.
        release_run();
        delete cb;
    }

    // Checks property `prop_idx` and records the outcome in `status`
    // (SAFE, UNSAFE, or -1 when undecided). The prover and decomposition of
    // any previous call are released first. Exceptions thrown while proving
    // propagate to the caller; `status` is already -1 at that point.
    void prove(unsigned int prop_idx)
    {
        // Release the previous run and reset the pointers immediately, so
        // that a throw below never leaves dangling pointers behind (which
        // the destructor or simplic3_witness() would otherwise touch).
        release_run();
        status = -1;
        Model *m = &model;
        Prover::Result res = Prover::UNKNOWN;
        if (opts.simpl_max_td_depth > 0 || opts.simpl_min_td_depth > 0) {
            simpl = new TemporalDecomp(opts, &model);
            m = simpl->simplify();
            if (m == NULL) {
                // simplify() produced no model: fall back to the original one
                m = &model;
                delete simpl;
                simpl = NULL;
            } else {
                // check with bmc that the property holds until the unrolled
                // depth
                Options bopts = opts;
                bopts.bmc_mindepth = 0;
                bopts.bmc_maxdepth = simpl->get_unroll_depth();
                prover = Prover::get("bmc", bopts, &model, prop_idx);
                res = prover->prove();
                if (res != Prover::FALSIFIED) {
                    // no counterexample within the unrolled depth: run the
                    // main algorithm on the simplified model below
                    delete prover;
                    prover = NULL;
                    res = Prover::UNKNOWN;
                } else {
                    // bmc ran on the original model, so its trace does not
                    // need to be mapped back through the decomposition
                    delete simpl;
                    simpl = NULL;
                }
            }
        }
        if (res == Prover::UNKNOWN) {
            prover = Prover::get(opts.algorithm, opts, m, prop_idx);
            if (cb) {
                prover->set_safe_bound_callback(cb);
            }
            res = prover->prove();
        }
        if (res == Prover::VERIFIED) {
            status = SAFE;
        } else if (res == Prover::FALSIFIED) {
            status = UNSAFE;
        }
    }

private:
    // Deletes the prover and the temporal decomposition of the previous run
    // and resets both pointers, so they are never freed twice or read after
    // being freed.
    void release_run()
    {
        delete prover;
        prover = NULL;
        delete simpl;
        simpl = NULL;
    }

    // Non-copyable (owns raw pointers): declared but intentionally not
    // defined.
    ProverState(const ProverState &);
    ProverState &operator=(const ProverState &);
};

// Recovers the ProverState stored behind an opaque C handle; yields NULL
// when the handle's `repr` is NULL.
inline ProverState *S(SimplIC3 s)
{
    void *raw = s.repr;
    return static_cast<ProverState *>(raw);
}


} // namespace

extern "C" {

// Releases memory handed out by this API (the results of simplic3_opthelp(),
// simplic3_witness() and simplic3_stats(), including each string inside the
// simplic3_stats() array), all of which is allocated with malloc().
// Passing NULL is harmless.
void simplic3_free(void *mem)
{
    ::free(mem);
}


// Creates a new, uninitialized handle. If the state cannot be constructed
// (e.g. out of memory), the returned handle has `repr == NULL`; every entry
// point that takes a handle rejects such a handle. No C++ exception is
// allowed to propagate into the C caller.
SimplIC3 simplic3_new(void)
{
    SimplIC3 ret = { NULL };
    try {
        ret.repr = new ProverState();
    } catch (std::exception &) {
        ret.repr = NULL;
    }
    return ret;
}


// Loads the model from `aigmgr` into the handle; a handle can be
// initialized only once. Returns 0 on success, -1 if the handle is invalid,
// `aigmgr` is NULL, the handle is already initialized, or the model could
// not be built from `aigmgr`.
int simplic3_init(SimplIC3 s, aiger *aigmgr)
{
    ProverState *state = S(s);
    if (!state || !aigmgr || state->initialized) {
        return -1;
    }
    try {
        if (!state->model.init_from_aiger(aigmgr)) {
            return -1;
        }
    } catch (std::exception &) {
        // never let a C++ exception escape into the C caller
        return -1;
    }
    state->initialized = true;
    return 0;
}


// Destroys a handle created by simplic3_new(), together with the prover,
// decomposition and callback it owns. A handle whose `repr` is NULL is
// ignored.
void simplic3_delete(SimplIC3 s)
{
    // deleting a null pointer is a no-op, so no explicit check is needed
    delete S(s);
}


// Sets option `opt` to `val`. Returns 0 on success, -1 if the handle or an
// argument is NULL, or if the option store rejects the setting (either by
// returning false or by throwing Options::error).
int simplic3_setopt(SimplIC3 s, const char *opt, const char *val)
{
    ProverState *state = S(s);
    int result = -1;
    if (state != NULL && opt != NULL && val != NULL) {
        try {
            result = state->opts.set(opt, val) ? 0 : -1;
        } catch (Options::error &) {
            result = -1;
        }
    }
    return result;
}


// Returns the current value of option `opt`, or NULL if the handle or `opt`
// is NULL or the option store throws Options::error. The returned string
// is owned by the handle: it is overwritten by the next simplic3_getopt()
// call on the same handle and freed with the handle; do not free it.
const char *simplic3_getopt(SimplIC3 s, const char *opt)
{
    ProverState *state = S(s);
    const char *result = NULL;
    if (state != NULL && opt != NULL) {
        try {
            state->tmpstr = state->opts.get(opt);
            result = state->tmpstr.c_str();
        } catch (Options::error &) {
            result = NULL;
        }
    }
    return result;
}


// Returns the options help text as a malloc'd, NUL-terminated string that
// the caller releases with simplic3_free(). Returns NULL if the handle is
// invalid, memory cannot be allocated, or producing the text throws.
char *simplic3_opthelp(SimplIC3 s)
{
    ProverState *state = S(s);
    if (!state) {
        return NULL;
    }
    try {
        std::ostringstream tmp;
        state->opts.print_help(tmp);
        const std::string tmps = tmp.str();
        char *ret = static_cast<char *>(malloc(tmps.size() + 1));
        if (!ret) {
            return NULL;
        }
        // copy the terminating NUL as well
        memcpy(ret, tmps.c_str(), tmps.size() + 1);
        return ret;
    } catch (Options::error &) {
        return NULL;
    } catch (std::exception &) {
        // e.g. std::bad_alloc from the stream/string: never let it reach C
        return NULL;
    }
}


// Installs `cb` (called with the prover's value and `user_data`) as the
// safe-bound callback handed to the main prover created by subsequent
// simplic3_prove() calls, replacing any previous callback. Passing a NULL
// `cb` removes the callback instead of installing a null function pointer.
// Returns 0 on success, -1 if the handle is invalid or the callback cannot
// be allocated (in which case the previous callback is kept).
int simplic3_set_safe_bound_callback(SimplIC3 s, void(*cb)(int, void *),
                                     void *user_data)
{
    ProverState *state = S(s);
    if (!state) {
        return -1;
    }
    // Build the replacement before releasing the old one, so a failed
    // allocation never leaves `state->cb` pointing at freed memory.
    CSafeBoundCallback *new_cb = NULL;
    if (cb) {
        try {
            new_cb = new CSafeBoundCallback(cb, user_data);
        } catch (std::exception &) {
            return -1;
        }
    }
    // NOTE(review): a prover from an earlier run may still hold the old
    // pointer; presumably it is not invoked once prove() has returned.
    delete state->cb;
    state->cb = new_cb;
    return 0;
}


// Checks property number `prop_idx` of the loaded model. Returns 0 if it
// was proved safe, 1 if it was falsified, and -1 if the handle is invalid
// or uninitialized, the index is out of range, the result is undecided, or
// a std::exception was raised while proving.
int simplic3_prove(SimplIC3 s, unsigned int prop_idx)
{
    ProverState *state = S(s);
    const bool ready = state != NULL && state->initialized &&
                       prop_idx < state->model.properties().size();
    if (!ready) {
        return -1;
    }
    int outcome = -1;
    try {
        state->prove(prop_idx);
        outcome = state->status;
    } catch (std::exception &) {
        outcome = -1;
    }
    return outcome;
}


// Returns the witness of the last simplic3_prove() call as a malloc'd array
// of AIGER literals, or NULL if there is none. That happens when there is no
// conclusive result, the prover cannot supply a witness, an exception is
// raised, or memory runs out. The caller releases the array with
// simplic3_free().
//
// Layout: every list returned by the prover is emitted as its literals
// followed by a 0, and the whole array ends with one extra 0.
// - UNSAFE results: the lists come from get_counterexample_trace(), mapped
//   back through the temporal decomposition when one was used.
// - SAFE results: the lists come from get_final_invariant().
unsigned int *simplic3_witness(SimplIC3 s)
{
    ProverState *state = S(s);
    if (!state || !state->prover) {
        return NULL;
    }

    std::vector<AigList> v;
    bool ok = false;
    try {
        switch (state->status) {
        case ProverState::UNSAFE:
            ok = state->prover->get_counterexample_trace(&v);
            if (ok && state->simpl) {
                std::vector<AigList> vv;
                state->simpl->get_trace(v, vv);
                std::swap(v, vv);
            }
            break;
        case ProverState::SAFE:
            ok = state->prover->get_final_invariant(&v);
            break;
        default:
            break;
        }
    } catch (std::exception &) {
        // never let a C++ exception escape into the C caller
        return NULL;
    }
    if (!ok) {
        return NULL;
    }

    // one slot per literal, one 0 terminator per list, one final 0
    size_t sz = v.size() + 1;
    for (size_t i = 0; i < v.size(); ++i) {
        sz += v[i].size();
    }
    unsigned int *ret =
        static_cast<unsigned int *>(malloc(sizeof(unsigned int) * sz));
    if (!ret) {
        return NULL;
    }
    size_t j = 0;
    for (size_t i = 0; i < v.size(); ++i) {
        AigList &l = v[i];
        for (size_t n = 0; n < l.size(); ++n) {
            Aig a = l[n];
            unsigned lit = aiger_var2lit(AigManager::aig_get_var(a));
            if (AigManager::aig_is_negated(a)) {
                lit = aiger_not(lit);
            }
            ret[j++] = lit;
        }
        ret[j++] = 0;
    }
    ret[j++] = 0;
    return ret;
}


// Returns the statistics of the prover used by the last simplic3_prove()
// call as a malloc'd, NULL-terminated array of strings. For each statistic
// it holds the `first` entry followed by the `second` entry (presumably
// name and value). Both the array and every string in it are malloc'd;
// release each string and then the array with simplic3_free(). Returns NULL
// in any of these cases:
// - the handle is invalid;
// - no prover has been run;
// - memory cannot be allocated (nothing is leaked in that case);
// - collecting the statistics throws.
char **simplic3_stats(SimplIC3 s)
{
    ProverState *state = S(s);
    if (!state || !state->prover) {
        return NULL;
    }
    try {
        Stats st = state->prover->get_stats();
        // two strings per statistic plus the terminating NULL
        char **ret = static_cast<char **>(
            malloc(sizeof(char *) * (st.size() * 2 + 1)));
        if (!ret) {
            return NULL;
        }
        size_t j = 0;
        for (size_t i = 0; i < st.size(); ++i) {
            char *first = static_cast<char *>(malloc(st[i].first.size() + 1));
            char *second =
                static_cast<char *>(malloc(st[i].second.size() + 1));
            if (!first || !second) {
                // out of memory: undo everything allocated so far
                free(first);
                free(second);
                while (j > 0) {
                    free(ret[--j]);
                }
                free(ret);
                return NULL;
            }
            strcpy(first, st[i].first.c_str());
            strcpy(second, st[i].second.c_str());
            ret[j++] = first;
            ret[j++] = second;
        }
        ret[j] = NULL;
        return ret;
    } catch (std::exception &) {
        // only get_stats()/copying the stats can throw, before any malloc
        return NULL;
    }
}

} // extern "C"
