view pyrect/translator/c_translator.py @ 99:e327e93aeb3a

remove callgraph and use Transition.
author Ryoma SHINYA <shinya@firefly.cr.ie.u-ryukyu.ac.jp>
date Sun, 12 Dec 2010 23:09:19 +0900
parents 020ba001c58a
children
line wrap: on
line source

#!/usr/bin/env python

from pyrect.regexp import Regexp, DFA
from pyrect.regexp.ast import *
from translator import Translator

class CTranslator(Translator):
    """CTranslator
    This Class can translate from DFA or NFA into C source code.
    DFA: A simple state transition as tail call (also can implement with CbC).
    NFA: using stack, deepening depth-first search.

    Each automaton state is emitted as a C function
    ``int state_N(unsigned char* s)`` that examines the input at ``s`` and
    returns the result of calling the function of the next state; the
    functions ``accept`` and ``reject`` end the match with 1 and 0.

    >>> string = '(A|B)*C'
    >>> reg = Regexp(string)
    >>> CTranslator(reg).translate()
    >>> CTranslator(reg).translate()
    """
    def __init__(self, regexp):
        Translator.__init__(self, regexp)
        # The automaton that is translated: the regexp's DFA.
        self.fa = regexp.dfa
        self.debug = False
        # Inputs that end the subject string: NUL, LF and CR.
        self.eols = (Character('\0'), Character('\n'), Character('\r'))
        # Input kinds that are emitted as if-statements in front of the
        # switch (see emit_state) instead of as switch case labels.
        self.special_rule = (Range, BegLine, MBCharacter)
        self.trans_stmt = self._trans_stmt(self.emit)

    def state_name(self, name):
        """Return the C function name for state ``name``.

        ``accept`` and ``reject`` are used verbatim; any other state is
        rendered as ``state_<name>``.
        """
        if name in ("accept", "reject"):
            return name
        return "state_" + str(name)

    def emit_accept_state(self):
        """Emit the terminal ``accept`` function, which returns 1."""
        self.emiti("int accept(unsigned char* s) {")
        self.emit("return 1;")
        self.demit("}", 2)

    def emit_reject_state(self):
        """Emit the terminal ``reject`` function, which returns 0."""
        self.emiti("int reject(unsigned char* s) {")
        self.emit("return 0;")
        self.demit("}", 2)

    def emit_skip(self):
        """Emit ``skip_tbl`` and the ``SKIP(s)`` macro.

        The table maps a lead byte to the number of bytes to advance: 1 for
        0x00-0xBF and 0xFE-0xFF, 2 for 0xC0-0xDF, 3 for 0xE0-0xEF, 4 for
        0xF0-0xF7, 5 for 0xF8-0xFB and 6 for 0xFC-0xFD.
        """
        self.emiti("const char skip_tbl[256] = {")
        self.emit("1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,")
        self.emit("1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,")
        self.emit("1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,")
        self.emit("1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,")
        self.emit("1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,")
        self.emit("1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,")
        self.emit("2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,")
        self.emit("3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,4,4,4,4,4,4,4,4,5,5,5,5,6,6,1,1,")
        self.emit("};")
        # The macro argument is parenthesized inside the cast as well;
        # otherwise SKIP(p+1) would look up skip_tbl[*p + 1].
        self.demit("#define SKIP(s) ((s) + skip_tbl[*(unsigned char *)(s)])", 2)

    def emit_driver(self):
        """Emit ``main``, which matches ``argv[1]`` and prints the verdict."""
        # Escape the pattern before embedding it in a C string literal, so
        # backslashes, quotes and newlines in the regexp neither break the
        # generated code nor get reinterpreted as C escapes.
        pattern = ("%s" % self.regexp.regexp).replace('\\', '\\\\') \
                                             .replace('"', '\\"') \
                                             .replace('\n', '\\n')
        self.emiti("int main(int argc, unsigned char* argv[]) {")
        # Without an argument argv[1] is NULL and would be dereferenced below.
        self.emit('if (argc < 2) { puts("usage: <program> string"); return 1; }')
        self.emit('buf = argv[1];')
        self.emit('puts("regexp: %s");' % pattern)
        self.emit('puts("number of state: %d");' % len(self.fa.states))
        self.emit(r'printf("string: %s\n", argv[1]);')
        self.emit0("if (%s(argv[1]))" % self.state_name(self.fa.start))
        self.emit(r'printf("accept: regexp matches string. \n\n");')
        self.emit0("else ")
        self.emit(r'printf("reject: regexp not matches string. \n\n");')
        self.emit("return 0;")
        self.demit("}", 2)

    def emit_strcmp1(self, string, next):
        """Emit a test of the input against ``string`` using wide loads.

        ``string`` is split into 4-byte chunks, then at most one 2-byte and
        one 1-byte chunk. Each chunk is compared through a pointer cast at
        its offset from ``s``. If all chunks match, the generated code
        returns the call of state ``next`` at ``s`` + len(string).

        NOTE(review): the casts read ``s`` as ``unsigned int``/``unsigned
        short`` at arbitrary offsets, which may be unaligned; also, ``string``
        is pasted into C literals unescaped.
        """
        cmp_stmt = list()
        offset = 0
        # `//` keeps the chunk count an integer on every Python version.
        for n in range(len(string) // 4):
            type_ = "unsigned int *"
            ptr = "intp" + str(n)
            self.emit('static %s%s = (%s)\"%s\";'
                      % (type_, ptr, type_, string[:4]))
            cmp_stmt.append((type_, offset, "*"+ptr))
            string = string[4:]
            offset += 4
        if len(string) >= 2:
            type_ = "unsigned short int *"
            ptr = "shortp"
            self.emit('static %s%s = (%s)\"%s\";'
                      % (type_, ptr, type_, string[:2]))
            cmp_stmt.append((type_, offset, "*"+ptr))
            offset += 2
            string = string[2:]
        if len(string) == 1:
            ptr = "'%s'" % string[0]
            cmp_stmt.append(("unsigned char *", offset, ptr))
            offset += 1

        self.emit()
        self.emit0("if (")
        last = len(cmp_stmt) - 1
        for i, stmt in enumerate(cmp_stmt):
            self.emit0("*(%s)((unsigned char *)s+%d) == %s" % stmt)
            if i != last:
                self.emit(" && ")
        self.emiti(")")
        self.emit("return %s(s+%d);" % (self.state_name(next), offset))
        self.demit()

    def emit_strcmp2(self, string, next):
        """Emit a byte-by-byte comparison of the input against ``string``.

        ``string`` must be non-empty (``string[0]`` is read unconditionally).
        On success the generated code continues in state ``next`` at the
        position just past the match.
        """
        # NOTE(review): `iemit` is not used anywhere else in this class (the
        # convention elsewhere is emiti/demit), and the emitted code assigns
        # to `ls`, which no visible emitter declares -- confirm both.
        self.iemit("ls = s;")
        # emit -> if (cmp_stmt && cmp_stmt && ...)
        self.emit0("if (")
        self.emit("*ls++ == '%c'" % string[0])
        for char in string[1:]:
            self.emit0(" && *ls++ == '%c'" % char)
        self.emiti(")")
        self.emit("return %s(ls);" % self.state_name(next))
        self.demit()

    def emit_strcmp3(self, string, next):
        """Emit a ``memcmp`` of the input against ``string``.

        On success the generated code continues in state ``next`` at
        ``s`` + len(string).
        """
        self.emit('static unsigned char* string = \"%s\";' % string)
        self.emiti("if (memcmp(string, s, %d) == 0)" % len(string))
        self.emit("return %s(s+%d);" % (self.state_name(next), len(string)))
        self.demit()

    def emit_switch(self, case, default=None):
        """Emit a ``switch`` over the next input byte.

        :param case: mapping from input node to next state; each entry is
            rendered by ``trans_stmt``.
        :param default: C function name that the ``default:`` label returns
            the call of, or None for no default label. With an empty
            ``case``, only ``return default(s);`` is emitted (if given).
        """
        if not case:
            if default:
                self.emit("return %s(s);" % default)
            return
        # The byte is consumed here, so case bodies continue at the new `s`.
        self.emiti("switch(*s++) {")
        for input_, next_ in case.iteritems():
            self.trans_stmt.emit(input_, self.state_name(next_))
        if default:
            self.emit("default: return %s(s);" % default)
        self.demit("}")

    def emit_state(self, cur_state, transition):
        """Emit the C function implementing ``cur_state``.

        :param cur_state: state identifier (rendered with state_name).
        :param transition: mapping from input node to next state(s) for
            ``cur_state``; the key '' holds epsilon targets. It is copied
            before use and never modified.
        """
        # Work on a private copy: entries are popped and end-of-line entries
        # added below, and emit_from_callgraph passes the automaton's own
        # mapping. Mutating it would drop the special-rule/AnyChar
        # transitions from any later translation of the same regexp.
        transition = dict(transition.iteritems())
        self.emiti("int %s(unsigned char* s) {" % self.state_name(cur_state))

        if self.debug:
            self.emit(r'printf("state: %s, input: %%s\n", s);' % cur_state)
        if type(self.fa) == DFA:
            default = None
            if '' in transition:
                epsilon_transition = transition.pop('')
                for n in epsilon_transition:
                    # NOTE(review): only the first of these returns is
                    # reachable in the generated C.
                    self.emit("return %s(s);\n" % self.state_name(n))
        else:
            default = "reject"

        # Special inputs become if-statements ahead of the switch; AnyChar
        # becomes the switch's default label. Iterate over a snapshot of the
        # keys because entries are popped inside the loop.
        for input_ in list(transition.keys()):
            if type(input_) in self.special_rule:
                self.trans_stmt.emit(input_, self.state_name(transition.pop(input_)))
            elif type(input_) is AnyChar:
                default = self.state_name(transition.pop(input_))

        if cur_state in self.fa.accepts:
            for eol in self.eols:
                transition[eol] = "accept"
        elif default != "reject":
            for eol in self.eols:
                transition[eol] = "reject"

        self.emit_switch(transition, default)

        if default is None:
            # No default label was emitted, so a byte not covered above would
            # reach the end of this non-void C function (undefined behaviour).
            # Such input cannot match; reject it explicitly.
            self.emit("return reject(s);")

        self.demit("}", 2)

    def emit_initialization(self):
        """Emit includes, prototypes for every state, ``buf`` and the skip table."""
        self.emit("#include <stdio.h>")
        # memcmp, used by code from emit_strcmp3, is declared here.
        self.emit("#include <string.h>")
        for state in self.fa.transition.iterkeys():
            self.emit("int %s(unsigned char* s);" % self.state_name(state))
        self.emit('int accept(unsigned char* s);')
        self.emit('int reject(unsigned char* s);')
        self.emit('unsigned char* buf;')
        self.emit_skip()

    def emit_from_callgraph(self):
        """Emit the complete C program for ``self.fa``.

        The name is kept for compatibility; the program is generated from
        the automaton's transition table.
        """
        # self.emit C-source code
        self.emit_initialization()
        self.emit_driver()

        for cur_state, transition in self.fa.transition.iteritems():
            self.emit_state(cur_state, transition)

        self.emit_accept_state()
        self.emit_reject_state()

    class _trans_stmt(ASTWalker):
        """AST visitor emitting the C statement for one transition.

        ``emit(input_node, next_)`` dispatches on the node type and writes
        code that continues in the C function named ``next_``.
        """
        def __init__(self, emit):
            self._emit = emit

        def emit(self, input_node, next_):
            self.next = next_
            input_node.accept(self)

        @staticmethod
        def _code(char):
            """Return the byte value of ``char`` (an int or a 1-char string)."""
            if isinstance(char, (str, type(u''))):
                return ord(char)
            return int(char)

        def visit(self, input_node):
            # Fallback for node types without a dedicated visitor.
            self._emit("/* UNKNOW RULE */")
            self._emit("/* %s */" % input_node.__repr__())

        def visit_Character(self, char):
            self._emit("case %d: /* match %s */" % (char.char, char))
            self._emit("  return %s(s);" % self.next)

        def visit_EndLine(self, endline):
            self._emit("/* end of line  */")
            self._emit(r"case '\0':")
            self._emit("  return %s(s);" % self.next)

        # Special Rule: emitted as if-statements ahead of the switch, so
        # ``s`` still points at the current byte.

        def visit_MBCharacter(self, mbchar):
            self._emit("/* match %s  */" % mbchar)
            bytes = mbchar.bytes
            self._emit("  if(%s)" % \
                       " && ".join(["*(s+%d) == 0x%x" % (d, x) for d, x in enumerate(bytes)]))
            self._emit("    return %s(s+%d);" % (self.next, len(bytes)), 2)

        def visit_BegLine(self, begline):
            self._emit("/* begin of line  */")
            self._emit("if (s == buf)")
            self._emit("  return %s(s);" % self.next, 2)

        def visit_Range(self, range):
            lower_mb = isinstance(range.lower, MBCharacter)
            upper_mb = isinstance(range.upper, MBCharacter)
            if lower_mb != upper_mb:
                # A range mixing a multibyte and a single-byte bound is
                # skipped without emitting any code.
                return

            if lower_mb:
                # Multibyte ranges are unsupported: visit() emits a marker.
                self.visit(range)
            else:
                # Emit numeric bounds. `.char` is formatted with %d in
                # visit_Character, so quoting it as '%s' would produce
                # multi-character constants such as '97'; numbers also keep
                # bytes >= 0x80, quotes and backslashes correct.
                self._emit("if (%d <= *s && *s <= %d)"
                           % (self._code(range.lower.char),
                              self._code(range.upper.char)))
                self._emit("  return %s(s+1);" % self.next, 2)

def test():
    """Run the doctests contained in this module's docstrings."""
    from doctest import testmod
    testmod()


if __name__ == '__main__':
    test()