#!/usr/bin/python
# This program is free software; you can redistribute it and/or modify
# it under the terms of the (LGPL) GNU Lesser General Public License as
# published by the Free Software Foundation; either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Library Lesser General Public License for more details at
# ( http://www.gnu.org/licenses/lgpl.html ).
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.

"""
Blend is a tool used to sort and combine .sql files as defined by the arguments
and the contents of the .deps files.  Files with an extension of (.sql|.pkb|.pks)
are dependency sorted as defined by the (optional) .deps files contained within each
subdirectory.  The initial ordering is the natural alphabetical (directory) ordering.
The secondary ordering is based on dependencies.

DEPENDENCY FILES (.deps)
-------------------------------------------------------------------------------

Terms:
  namespace - A root directory.
              May be: (.|class|tables|views|procs|packages)
              where (.) represents the current namespace.

  qname     - A qualified name.
              May be: [namespace/]basename[.ext].  When the namespace
              is not specified, the name is unqualified and resolved
              using the "path".  Using the extension is only necessary
              when more than (1) file exists in a given namespace with
              the same name but with different extensions.

Format:

  path = {namespace}([\s,]+{namespace})*

  {qname} :: {qname}([\s,]+{qname})*

Examples:

  path = . tables packages

  * :: sequence_nextval

  rhnSharedChannelView :: rhnChannel procs/lookup_functions \
                          rhn_channel.pks

  Assuming the current namespace (or directory) is /views.
  The table dependency "rhnChannel" is unqualified and would be searched
  for in the following order:
    1st. views/rhnChannel
    2nd. tables/rhnChannel
    3rd. packages/rhnChannel
  However, lookup_functions dependency is qualified by namespace so the
  path will not be used.
  The package header dependency "rhn_channel.pks" is unqualified but has
  an extension and would be searched for in the following order:
    1st. views/rhn_channel.pks
    2nd. tables/rhn_channel.pks
    3rd. packages/rhn_channel.pks

  The star notation prepends its list to all following entries in the file.

Lines may be joined by (\).

MAIN.SQL

The main.sql file is generated by blend and references (includes)
the individual .sql, .pks, .pkb files within the directory
tree, ordered as defined by the .deps files.  This takes the place
of the universal .sql file.

SPECIAL FILES

clean.sql - If found in the current directory, the contents of this file is
            inserted at the beginning of the output.  This file is expected
            to contain commands to clean (clear) the database before schema
            installation.

start.sql - If found in the current directory, the contents of this file is
            inserted at the beginning of the output (after clean.sql).  This
            file is expected to contain commands used to condition the database
            or database session to schema installation.

end.sql -   If found in the current directory, the contents of this file is
            appended to the end of the output.  This file contains commands
            to commit or post process the installed schema such as commit inserted
            data or to exit an application used for the installation.
"""

import sys, os, re
from getopt import getopt, GetoptError

# Module-level option flags; main() overwrites them from the command line.
verbose = False     # when true, generate() also prints the ordered file list
aggregated = False  # copied by main() onto the writer (writer.aggregated)

#
# CLASSES
#

class Name:
    """
    Helpers used to build the lookup keys for files and dependencies.
    A name is considered qualified when it contains a (/).
    """

    @classmethod
    def qualify(cls, nss, fn):
        """
        Build the candidate keys for a name.
        @param nss: A namespace or a list/tuple of namespaces.
        @param fn: A (file) name, optionally qualified and/or with
            an extension.
        @return: A tuple of keys.  For a qualified name, the basenames of
            (fn) as-is.  For an unqualified name, every basename prefixed
            with each namespace (namespace order first).
        """
        basenames = cls.basenames(fn)
        if cls.qualified(fn):
            return tuple(basenames)
        if not isinstance(nss, (tuple, list)):
            nss = (nss,)
        return tuple(['%s/%s' % (ns, bn) for ns in nss for bn in basenames])

    @classmethod
    def basenames(cls, fn):
        """
        Get the name followed by the name with the text after
        the last (.) removed (when the name contains a dot).
        @param fn: A (file) name.
        @return: A list of 1 or 2 names.
        """
        names = [fn]
        # An explicit length check replaces the former bare except that
        # swallowed every error (not only the failed unpack).
        parts = fn.rsplit('.', 1)
        if len(parts) == 2:
            names.append(parts[0])
        return names

    @classmethod
    def qualified(cls, name):
        """ Get whether the name contains a namespace separator (/). """
        return ( '/' in name )

    @classmethod
    def unqualified(cls, name):
        """ Get whether the name does not contain a namespace separator. """
        return ( not cls.qualified(name) )


class DepTab:
    """
    The dependency table built from the contents of .deps files.
    Entries are kept in the order read (content) and indexed by the
    qualified names of their subject (index).
    """

    class Options(object):
        """ Option bag.  Item assignment stores the value as an attribute. """
        def __setitem__(self, k, v):
            setattr(self, k, v)

    class Entry:
        """
        A single rule: {subject} :: {deps}.
        @ivar tab: The owning table.
        @ivar info: The origin of the rule as (path:line).
        @ivar subject: The rule subject as written in the file.
        @ivar deps: A tuple of keysets.  Each keyset is a tuple of the
            candidate keys for one dependency.
        @ivar hits: The number of times the entry has been found/expanded.
        """
        def __init__(self, tab, info, subject):
            self.tab = tab
            self.info = info
            self.subject = subject
            self.deps = ()
            self.hits = 0

        def __iter__(self):
            """ Iterate the dependencies with the aliases expanded. """
            return iter(self.expanded([]))

        def expanded(self, history):
            """
            Get the dependencies with aliases replaced by their dependencies.
            @param history: The list of alias keys already expanded.
                It is updated by the table and used to detect circular
                references.
            @return: A list of keysets.
            """
            deps = []
            for keyset in self.deps:
                keyset, expanded = self.tab.expand(keyset, history)
                if len(keyset):
                    deps.append(tuple(keyset))
                deps += expanded
            return deps

        def __repr__(self):
            return ''.join((
                self.subject,
                '::',
                str(self.deps),
                '@',
                self.info,
                '/%d' % self.hits))

    # Separator between values: any run of commas and/or white space.
    pattern = re.compile(r'[,\s]+')

    def __init__(self):
        self.content = []
        self.index = {}
        self.options = self.Options()
        self.aliases = None

    def read(self, ns, path):
        """
        Read a .deps file and add its rules to the table.
        The (*) rule is not added; its dependencies are prepended to
        those of every rule that follows it in the same file.
        @param ns: The namespace the file belongs to.
        @param path: The path of the .deps file.
        """
        self.options.path = (ns,)
        star_entry = None
        fp = open(path)
        try:
            # lines() reads the whole file, so it can be closed right away
            # (and is closed even when reading fails).
            lines = self.lines(fp)
        finally:
            fp.close()
        for ln, line in lines:
            if self.setoption(ns, line):
                continue
            parts = line.split('::', 1)
            if len(parts) < 2:
                continue
            info = '%s:%d' % (path, ln)
            subject = parts[0].strip()
            entry = self.Entry(self, info, subject)
            # drop any left over line-join characters.
            spec = parts[1].replace('\\', '')
            deps = []
            for dep in self.values(spec):
                deps.append(tuple(Name.qualify(self.options.path, dep)))
            entry.deps = tuple(deps)
            if star_entry is not None:
                entry.deps = star_entry.deps + entry.deps
            if subject == '*':
                star_entry = entry
                continue
            self.content.append(entry)
            for key in Name.qualify(ns, subject):
                self.index[key] = entry

    def setoption(self, ns, line):
        """
        Process the line as an option (tag = value) when it is one.
        Only the (path) option is recognized; a (.) in its value
        is replaced with the namespace.
        @param ns: The current namespace.
        @param line: A line of a .deps file.
        @return: True when the line was a (path) option.
        """
        parts = line.split('=', 1)
        if len(parts) != 2:
            return False
        tag = parts[0].strip()
        if tag != 'path':
            return False
        values = []
        for v in self.values(parts[1]):
            if v == '.':
                v = ns
            values.append(v)
        self.options[tag] = values
        return True

    def values(self, line):
        """
        Split the text into values separated by commas and/or white space.
        Empty values (eg: produced by empty text or a leading/trailing
        comma) are discarded.
        @param line: The text to split.
        @return: A list of non-empty values.
        """
        return [v for v in self.pattern.split(line.strip()) if v]

    def lines(self, fp):
        """
        Read the logical lines.  Lines starting with (#) are skipped and
        a line ending with (\\) is joined with the next line read.
        @param fp: An open file (or any iterable of lines).
        @return: A list of (line-number, text).  The line number of a
            joined line is that of the last physical line joined.
        """
        ln = 0
        lines = []
        for line in fp:
            ln += 1
            if not line or line[0] == '#':
                continue
            if lines and lines[-1][1].endswith('\\\n'):
                # strip the (\) and newline then join.
                last = lines.pop()[1][:-2]
                line = ' '.join((last, line))
            lines.append((ln, line))
        return lines

    def find(self, ns, fn):
        """
        Find the entry for a file and count the hit.
        @param ns: The namespace of the file.
        @param fn: The file name.
        @return: The matching entry or an empty tuple when not found.
        """
        for key in Name.qualify(ns, fn):
            entry = self.index.get(key)
            if entry is None:
                continue
            entry.hits += 1
            return entry
        return ()

    def expand(self, keyset, history):
        """
        Expand the aliases referenced by a keyset.
        @param keyset: The candidate keys for a single dependency.
        @param history: The alias keys already expanded (updated).
        @return: A tuple of (keys, expanded) where (keys) are the keys that
            are not aliases and (expanded) are the keysets contributed by
            the aliases.  When anything was expanded, (keys) is empty.
        """
        # tolerate expansion before findaliases() has been called.
        aliases = self.aliases or {}
        keys = []
        expanded = []
        for key in keyset:
            if key in history:
                print('circular (alias) reference:%s' % history)
                continue
            alias = aliases.get(key)
            if alias is None:
                keys.append(key)
                continue
            history.append(key)
            alias.hits += 1
            for deps in alias.expanded(history):
                expanded.append(deps)
        if len(expanded):
            keys = []
        return (keys, expanded)

    def findaliases(self):
        """
        Build the alias table from the indexed entries that have no hits
        at the time of the call.
        @return: self
        """
        self.aliases = {}
        for s, e in self.index.items():
            if e.hits == 0:
                self.aliases[s] = e
        return self

class Reader:
    """
    Reads the .deps files and the schema files found in a list of
    directories and produces the dependency sorted list of file paths.
    @cvar EXT: The file extensions processed.  The position within the
        list orders files that have the same base name.
    @ivar overrides: A list of (path, kept-path) for files skipped because
        a file with the same key had already been added.
    """

    EXT = ['sql', 'pks', 'pkb']

    def __init__(self):
        self.deptab = DepTab()
        self.deplist = DepList()
        self.overrides = []

    def read(self, directories):
        """
        Read the .deps files and then the schema files.
        @param directories: A list of directories (namespaces).
        @return: self
        """
        self.opendeptabs(directories)
        self.getfiles(directories)
        self.deptab.findaliases()
        return self

    def sort(self):
        """
        Dependency sort the files read.
        @return: The ordered list of file paths.
        """
        return [item[2] for item in self.deplist.sort()]

    def opendeptabs(self, directories):
        """
        Load every .deps file found within the directories.
        @return: self
        """
        for d in directories:
            for fn, path in self.files(d, ('deps',)):
                self.deptab.read(d, path)
        return self

    def getfiles(self, directories):
        """
        Add every file with an extension in (EXT) to the dependency list.
        The first file added for a key wins; later ones are recorded
        in (overrides).
        @return: self
        """
        for d in directories:
            for fn, path in self.files(d, self.EXT):
                entry = self.deptab.find(d, fn)
                keys = Name.qualify(d, fn)
                found = self.deplist.find(keys[0])
                if found is None:
                    self.deplist.add(keys, entry, path)
                    continue
                self.overrides.append((path, found[2]))
        return self

    def path(self, d, fn):
        """ Join a directory and a file name. """
        return '%s/%s' % (d, fn)

    def files(self, d, extensions):
        """
        Get the files within the directory tree having one of
        the specified extensions.
        @param d: The root directory.
        @param extensions: A collection of extensions (without the dot).
        @return: A sorted list of (file-name, path).
        """
        files = []
        for fn, dirpath in self.flattened(d):
            parts = fn.rsplit('.', 1)
            if len(parts) == 2 and parts[1] in extensions:
                files.append((fn, self.path(dirpath, fn)))
        return files

    def flattened(self, d):
        """
        Get all of the files within the directory tree as a flat list.
        @param d: The root directory.
        @return: A list of (file-name, directory) sorted by base name
            and then by extension as ordered in (EXT).
        """
        files = []
        for dirpath, dirnames, filenames in os.walk(d):
            for f in filenames:
                files.append((f, dirpath))
        # A key function gives a consistent ordering and (unlike passing
        # a cmp function positionally) works on both python 2 and 3.
        return sorted(files, key=self._sortkey)

    def _sortkey(self, item):
        """ Get the (base-name, extension-rank) sort key for a file item. """
        parts = item[0].rsplit('.', 1)
        # Files without a known extension sort after those with one.
        rank = len(self.EXT)
        if len(parts) == 2 and parts[1] in self.EXT:
            rank = self.EXT.index(parts[1])
        return (parts[0], rank)

    def fsort(self, a, b):
        """
        Compare two (file-name, directory) items using the same ordering
        as flattened().  Retained for compatibility.
        @return: -1, 0 or 1.
        """
        ka = self._sortkey(a)
        kb = self._sortkey(b)
        return (ka > kb) - (ka < kb)

class Writer:
    """
    Writes the output file.  The base writer aggregates (copies) the
    contents of every file into the output.
    @cvar comment: The format of the comment written before the contents
        of each file copied.
    @ivar aggregated: Used by the subclasses to choose between copying the
        file contents and writing include references.
    """

    comment = """
-- Source: %s
    """

    def __init__(self, aggregated=True):
        self.aggregated = aggregated

    def write(self, sorted, output):
        """
        Write the special files and the contents of every sorted file.
        @param sorted: The ordered list of file paths.
        @param output: An open (writable) file.
        """
        self.start(output)
        for p in sorted:
            self.copy(p, output)
        self.end(output)

    def start(self, output):
        """
        Insert the contents of clean.sql and then start.sql (when found
        in the current directory) into the output.
        """
        # NOTE(review): the subclasses called start() but it was never
        # defined (AttributeError).  Implemented as described by the module
        # documentation for the special files - confirm.
        self.copy('clean.sql', output, True)
        self.copy('start.sql', output, True)

    def end(self, output):
        """
        Append the contents of end.sql (when found in the
        current directory) to the output.
        """
        self.copy('end.sql', output, True)

    def copy(self, path, output, optional=False):
        """
        Copy the contents of a file into the output.  Lines starting with
        (--) are dropped and lines starting with (@) or (\\i) are replaced
        by the contents of the file they reference.
        @param path: The path of the file to copy.
        @param output: An open (writable) file.
        @param optional: When true, a missing file is silently skipped.
        """
        if optional and not os.path.exists(path):
            return
        f = open(path)
        try:
            output.write(self.comment % path)
            for line in f:
                if line.startswith('--'):
                    continue
                if line.startswith('@'):
                    self.copy(line[1:].strip(), output)
                    continue
                if line.startswith('\\i'):
                    self.copy(line[2:].strip(), output)
                    continue
                output.write(line)
        finally:
            # closed even when an included file cannot be read.
            f.close()

    def _references(self, sorted, output, fmt):
        """
        Write an include reference for each file (instead of its contents)
        between the special files.
        @param fmt: The include statement format for a single path.
        """
        self.start(output)
        for p in sorted:
            output.write('\n')
            output.write(fmt % p)
        self.end(output)


class Oracle(Writer):
    """ Writer that references files using the (@path) include syntax. """

    def write(self, sorted, output):
        """ Write the aggregated contents or (@) include references. """
        if self.aggregated:
            Writer.write(self, sorted, output)
            return
        self._references(sorted, output, '@%s')


class Postgres(Writer):
    """ Writer that references files using the (\\i path) include syntax. """

    def write(self, sorted, output):
        """ Write the aggregated contents or (\\i) include references. """
        if self.aggregated:
            Writer.write(self, sorted, output)
            return
        self._references(sorted, output, '\\i %s')



class DepList:
    """
    A list of items that can be sorted so that each item follows the items
    it depends on.  An item is a tuple of: (keys, deps, *payload) where
    (deps) is an iterable of references; each reference is a key or a
    tuple/list of candidate keys.
    @ivar unsorted: The items in the order added.
    @ivar index: The items by key.
    @ivar unfound: A list of (reference, info) for the references that
        did not match any item during the last sort.
    """

    def __init__(self):
        """ """
        self.unsorted = []
        self.index = {}
        self.reset()

    def reset(self):
        """ Reset the state used by sort(). """
        self.stack = []
        self.pushed = set()
        self.sorted = []
        self.unfound = []

    def add(self, keys, deps, *payload):
        """
        Add an item.
        @param keys: The keys the item may be referenced by.
        @param deps: An iterable of references to other items.
        @param payload: Additional values stored with the item.
        @return: self
        """
        item = (keys, deps) + tuple(payload)
        self.unsorted.append(item)
        for key in keys:
            self.index[key] = item
        return self

    def sort(self):
        """
        Sort the items by dependency (depth first).  Items are visited in
        the order added and each item is emitted once, after the items it
        references.  References that cannot be resolved are recorded in
        (unfound) along with the (info) attribute of the item's deps.
        @return: The sorted list of items.
        """
        self.reset()
        for item in self.unsorted:
            popped = []
            self.push(item)
            while self.stack:
                top = self.top()
                # The frame iterator resumes where it left off, so each
                # reference is visited only once per frame.
                for ref in top[1]:
                    refd = self.find(ref)
                    if refd is None:
                        info = top[0][1].info
                        self.unfound.append((ref, info))
                        continue
                    # process the referenced item (if not already pushed)
                    # before the remaining references of this frame.
                    self.push(refd)
                    break
                else:
                    # references exhausted.
                    popped.append(self.pop())
            self.sorted.extend(popped)
        self.unsorted = self.sorted
        return self.sorted

    def find(self, key):
        """
        Find an item by key.
        @param key: A key or a tuple/list of candidate keys.
        @return: The item for the first key that matches, else None.
        """
        if not isinstance(key, (tuple, list)):
            key = (key,)
        for k in key:
            v = self.index.get(k)
            if v is not None:
                return v
        return None

    def top(self):
        """ Get the frame at the top of the stack. """
        return self.stack[-1]

    def push(self, item):
        """
        Push a frame of (item, iterator over its deps) onto the stack
        unless the item has already been pushed during this sort.
        """
        if item in self.pushed:
            return
        frame = (item, iter(item[1]))
        self.stack.append(frame)
        self.pushed.add(item)

    def pop(self):
        """
        Pop the top frame.
        @return: The item of the popped frame; None when the stack is empty.
        """
        if not self.stack:
            return None
        return self.stack.pop()[0]

#
# FUNCTIONS
#

def generate(path, reader, writer, output, directories):
    """
    Generate the output file and print a report of the problems found.
    @param path: The working directory (made the current directory).
    @param reader: The reader used to read and sort the files.
    @param writer: The writer used to write the output.
    @param output: The path of the output file.
    @param directories: The directories to process.  Names that are not
        directories are ignored.
    @return: 1 when there are overrides, unused rules or unfound
        references; else 0.
    """
    os.chdir(path)
    out = open(output, 'w')
    try:
        ds = [fn for fn in directories if os.path.isdir(fn)]
        reader.read(ds)
        ordered = reader.sort()
        overrides = reader.overrides
        print('\nOVERRIDES (%d):' % len(overrides))
        for entry in overrides:
            print('\t"%s" overridden by "%s"' % entry)
        unused = \
            [x for x in reader.deptab.content if x.hits == 0]
        print('\nUNUSED RULES (%d):' % len(unused))
        for entry in unused:
            print('\t%s @%s' % (entry.subject, entry.info))
        unfound = reader.deplist.unfound
        print('\nUNFOUND (rule) REFERENCES (%d):' % len(unfound))
        for unf in unfound:
            print('\t"%s" @%s' % unf)
        if verbose:
            ln = 1
            print('\n\nFILES (%d):' % len(ordered))
            for p in ordered:
                print('\t%d %s' % (ln, p))
                ln += 1
        writer.write(ordered, out)
    finally:
        # the output is closed even when reading or writing fails.
        out.close()
    if overrides or unused or unfound:
        return 1
    return 0

def usage():
    """ Print the command line usage. """
    s = [
        'Usage: blend [OPTIONS] directory ...',
        ' Options:',
        '  -h, --help',
        '      Show usage information.',
        '  -v, --verbose',
        '      Show extra processing information.',
        '  -a, --aggregated',
        '      Aggregate the contents vs. reference by includes',
        '  -s, --style',
        # the accepted values are those tested by main().
        '      The output style (oracle|postgres).',
        '  -o, --output',
        '      The output file.  Default: main.sql',
        '  -d, --directory',
        '      The working directory to process.',
    ]
    print('\n'.join(s))

def main(argv):
    """
    The command line entry point.  Parses the options, generates the
    output and exits with: 0 on success, 1 when generate() reported
    problems, 2 on a usage error.
    @param argv: The command line arguments (without the program name).
    """
    global verbose, aggregated
    path = '.'
    reader = Reader()
    writer = Oracle()
    output = 'main.sql'
    flags = 'vhad:o:s:'
    # getopt expects long options without the leading (--); those
    # requiring an argument end with (=).
    keywords = [
        'help',
        'verbose',
        'aggregated',
        'directory=',
        'output=',
        'style=',
    ]
    try:
        opts, args = getopt(argv, flags, keywords)
        for opt, arg in opts:
            if opt in ('-h', '--help'):
                usage()
                sys.exit(0)
            if opt in ('-v', '--verbose'):
                verbose = True
                continue
            if opt in ('-a', '--aggregated'):
                aggregated = True
                continue
            if opt in ('-o', '--output'):
                output = arg
                continue
            if opt in ('-d', '--directory'):
                path = arg
                continue
            if opt in ('-s', '--style'):
                if arg == 'oracle':
                    writer = Oracle()
                    continue
                if arg == 'postgres':
                    writer = Postgres()
                    continue
                # reported like any other invalid option (usage, exit 2).
                raise GetoptError('style "%s" not valid' % arg, opt)
    except GetoptError:
        # sys.exc_info() is used so the handler is valid on any version.
        e = sys.exc_info()[1]
        print(e)
        usage()
        sys.exit(2)
    writer.aggregated = aggregated
    errnum = generate(path, reader, writer, output, args)
    sys.exit(errnum)

# Run as a script: pass the command line (without the program name) to main().
if __name__ == '__main__':
    main(sys.argv[1:])

