主页 >> 程序猿的东西 >> 编译原理极简入门:表达式求值

编译原理极简入门:表达式求值

本算法基于编译原理实现,可作为以上三个 leetcode 的通解。同时,稍作改动即可实现:生成汇编指令、后缀表达式、AST 语法树等。

思路

表达啥求值属于编译器的一部分,我们就根据编译原理的方法实现。因为一个编译器的复杂度是单纯表达式求值的成百上千倍,所以这种方式实现的代码结构更清晰易懂容易维护。

解题方法

编译原理简介

编译原理是计算机科学里面最接近科学的部分,整个编译过程都是一个不断抽象的过程。词法分析是把一个个的字符组织成单词;语法分析是把一个个单词组织成语句;整个编译过程又是对语句的抽象。

这种逐层抽象的方法是编程思维里面最重要的部分,不懂抽象的程序员很容易遇到发展瓶颈。

词法分析器 tokenizer

词法分析就是把一个个字符组织成单词,众所周不知的正则表达式正是为了解决这一问题而发明的。词法分析产生的结果我们叫做 token。

在完整的编译器中词法分析非常复杂,不仅因为token 的类型很多,还有许多状态需要管理。比如在表达式中遇到数字就可以直接提取为 token,但是在一个完整的编译器里面,遇到数字可能是在字符串里面,你必须把整个字符串作为一个 token。或者是代码在注释里面,你必须把整个注释作为一个 token。

在完整的编译器里面会用二元组来表示token,一个是 token的类型,一个是 token 具体的值。比如<数字, “123”>、<字符串, “123”>、<RETURN, “return”>、<PLUS, “+”>……

但是在表达式求值中,我们可以遇到的token只有 3 种:数字,运算符,终止符。而运算符和终止符也很简单,都是单字符的,因此不需要复杂的状态管理和二元组表示,只要把数字和其他字符一个个分开即可

语法分析器 parser

语法分析就是实现token的“正则表达式”。在语法定义里面有一个更专业的名词叫巴科斯范式(英语:Backus Normal Form,缩写为 BNF)。现代编译器一定是先定义 BNF,这才是编译器设计里面最难的部分,需要严谨的数学逻辑能力。有了 BNF 之后,我们的代码实现会非常容易。

expression => term [ + term | - term ] ...
term => factor [ * factor | / factor ] ...
factor => NUM | ( expression ) | -factor

这里定义了三种BNF,expression(表达式),term(项),factor(因子)。我们在代码里面可以直接根据这三个范式来实现。

  • expression 被定义为 term 的连续加减法。
  • term 被定义为 factor 的连续乘除法。
  • factor 有三种:数字,括号表达式,负号一元运算符。

可以看到下面的代码中实现了这三个同名的函数,每个函数返回两个值i和a,i表示编译器当前的扫描位置,a表示当前函数的计算结果。

在本题的官方题解中,总是少不了提到两个栈,以及运算符的优先级处理。而在 BNF 中,已经隐含实现了栈和优先级,比如 expression 只有加减法,乘除法全部在 term 中实现,也就表明在计算加减法之前就要计算好 term,而栈也是用函数间的调用来隐含实现了。这是这种算法更清晰易读的原因。

复杂度

时间复杂度: O(N)

Code

class Solution:

    def tokenizer(self, expr: str):
        """
        词法分析
        """
        expr = expr.replace(" ", "")
        tokens = []
        for c in expr:
            if c.isdigit():
                if tokens and tokens[-1].isdigit():
                    tokens[-1] += c
                else:
                    tokens.append(c)
            else:
                tokens.append(c)
        return tokens 
    
    def parser(self, tokens):
        """
        语法分析: BNF表达式
        expression => term [ + term | - term ] ...
        term => factor [ * factor | / factor ] ...
        factor => NUM | ( expression ) | -factor
        """
        
        # 实现计算单元
        def cpu(a, op , b):
            if op == "+":
                return a+b
            elif op == "-":
                return a-b
            elif op == "*":
                return a*b
            elif op == "/":
                return a//b

        # 以下实现 BNF 表达式的语法分析
        def expression(i):
            i, a = term(i)
            while(tokens[i] in ["+","-"]):
                op = tokens[i]
                i += 1
                i, b = term(i)
                a = cpu(a, op , b)
            return i, a
        def term(i):
            i, a = factor(i)
            while(tokens[i] in ["*","/"]):
                op = tokens[i]
                i += 1
                i, b = factor(i)
                a = cpu(a, op , b)
            return i, a
        def factor(i):
            if tokens[i].isdigit():
                return i+1, int(tokens[i])
            elif tokens[i] == "-":
                i += 1
                i, a = factor(i)
                return i, -1*a
            elif tokens[i] == "(":
                i += 1
                i, a = expression(i)
                if tokens[i] == ")":
                    i += 1
                    return i, a
                else:
                    raise Exception("SyntaxError: near the '%s'"%(tokens[i]))
            else:
                raise Exception("SyntaxError: near the '%s'"%(tokens[i]))
        
        # 表达式解析开始
        i,res = expression(0)
        if tokens[i] != ";":
            raise Exception("SyntaxError: incorrect ending near '%s'"%(tokens[i]))
        return res

    def calculate(self, expr: str) -> int:
        tokens = self.tokenizer(expr+";")
        return self.parser(tokens)

s = Solution()
print(s.calculate("2+3*4"))
print(s.calculate("(2+3)/4"))
print(s.calculate("14/3*2"))
print(s.calculate("0"))
print(s.calculate("1-(     -2)"))
print(s.calculate("1 + 1"))
print(s.calculate(" 2-1 + 2 "))
print(s.calculate("1+2*5/3+6/4*2"))
print(s.calculate("3+2*2"))
print(s.calculate(" 3/2 "))
print(s.calculate(" 3+5 / 2 "))
print(s.calculate("1+1+1"))
print(s.calculate("2-3+4"))
print(s.calculate("1*2-3/4+5*6-7*8+9/10"))
# 测试语法错误
print(s.calculate("100+20)"))

加入知识星球,与我探讨编译原理和算法。