--feq/diff.lua
--(Source-viewer metadata: 110 lines, 2.5 KiB, Lua.)

--Bespoke symbolic differentiation.
--A tiny expression-tree toolkit whose only node kinds are addition (+),
--multiplication (written with *, printed "x") and function composition
--(written with ^, printed "o"). Trees can be printed, evaluated, and
--differentiated using the sum, product and chain rules.
--Atom D(n) stands for the n-th derivative of a single function f
--(D(0) is f itself).

--Node metatables; forward-declared so the helpers below can refer to them.
local op

--Differentiates any node by delegating to its own diff method.
local d = function( node ) return node:diff() end

--Builds the atom node for derivative order n.
local D = function( n ) return setmetatable( { ["n"] = n }, op.atom ) end

--Values read by atom:eval, keyed by derivative order n (presumably f^(n)(p)
--at the fixed point p). NOTE(review): nothing visible here fills this table
--and it is not exported, so eval() yields nil for every atom as written.
local computedDerivatives = {}

--Calling a node, node(...), builds an atom wrapping that node (D(node)).
--NOTE(review): this looks unintended -- confirm what call syntax should mean.
local __call = D

--Returns a metamethod that packs its two operands into a node of kind `kind`.
--op[kind] is looked up on every call, so op may be assigned after this runs.
local function binaryNode( kind )
  return function( lhs, rhs )
    return setmetatable( { lhs, rhs }, op[ kind ] )
  end
end

local __add = binaryNode( "add" )
local __mul = binaryNode( "mul" )
local __pow = binaryNode( "compose" )
--Formats a child node for its parent's string form, adding parentheses when
--printing it bare would let a reader regroup it under the usual precedence
--(composition binds tighter than x, which binds tighter than +).
--Example: a product whose left factor is a sum prints as "(1 + 2) x 3"
--rather than the misleading "1 + 2 x 3".
local function operandString( child, parentMeta )
  local childMeta = getmetatable( child )
  local text = tostring( child )
  if childMeta == op.add and parentMeta ~= op.add then
    return "(" .. text .. ")"
  end
  if childMeta == op.mul and parentMeta == op.compose then
    return "(" .. text .. ")"
  end
  return text
end

--Node metatables, one per node kind. All kinds share the same operator
--metamethods, so nodes combine with + (add), * (mul, printed "x") and
--^ (compose, printed "o"). Each __index table supplies the node methods:
--  eval() - numeric value of the node, built from computedDerivatives
--  diff() - a new tree representing the node's derivative
op = {
  --Leaf node D(n): the n-th derivative of f, printed as n.
  atom = {
    __add = __add,
    __mul = __mul,
    __pow = __pow,
    __call = __call,
    __tostring = function( self ) return tostring( self.n ) end,
    __index = {
      name = "atom",
      eval = function( self )
        --Presumably f^(n)(p); nil when no value has been stored for n.
        return computedDerivatives[ self.n ]
      end,
      diff = function( self )
        --d/dx f^(n) = f^(n+1)
        return D( self.n + 1 )
      end,
    },
  },
  --Composition self[1] o self[2] (outer o inner), built with ^.
  compose = {
    __add = __add,
    __mul = __mul,
    __pow = __pow,
    __call = __call,
    __tostring = function( self )
      return "(" .. operandString( self[1], op.compose ) .. " o "
        .. operandString( self[2], op.compose ) .. ")"
    end,
    __index = {
      name = "compose",
      eval = function( self )
        --Assumes the form f^(n) o f: at a fixed point p (f(p) = p) its value
        --is f^(n)(p), i.e. just the outer node's value. The inner node is
        --ignored, so other inner functions are not evaluated correctly.
        return self[1]:eval()
      end,
      diff = function( self )
        --Chain rule: (g o h)' = (g' o h) x h'. For the f^(n) o f form used
        --by this file it gives (f^(n+1) o f) x f'.
        --(^ binds tighter than *, so this parses as (g' o h) * h'.)
        return d( self[1] ) ^ self[2] * d( self[2] )
      end,
    },
  },
  --Sum self[1] + self[2].
  add = {
    __add = __add,
    __mul = __mul,
    __pow = __pow,
    __call = __call,
    --+ has the lowest precedence, so operands never need parentheses here.
    __tostring = function( self ) return tostring( self[1] ).." + "..tostring( self[2] ) end,
    __index = {
      name = "add",
      eval = function( self )
        return self[1]:eval() + self[2]:eval()
      end,
      diff = function( self )
        --Sum rule: (a + b)' = a' + b'
        return d(self[1]) + d(self[2])
      end,
    },
  },
  --Product self[1] x self[2].
  mul = {
    __add = __add,
    __mul = __mul,
    __pow = __pow,
    __call = __call,
    __tostring = function( self )
      return operandString( self[1], op.mul ) .. " x " .. operandString( self[2], op.mul )
    end,
    __index = {
      name = "mul",
      eval = function( self )
        return self[1]:eval() * self[2]:eval()
      end,
      diff = function( self )
        --Product rule: (a x b)' = a' x b + a x b'
        return d(self[1]) * self[2] + self[1] * d(self[2])
      end,
    },
  },
}
--Builds a node of the given kind from two children. For two nodes,
--op.New( "add", a, b ), op.New( "mul", a, b ) and op.New( "compose", a, b )
--build the same nodes as a + b, a * b and a ^ b respectively.
--"atom" is also accepted, but an atom built this way has no n field;
--NOTE(review): presumably D(n) is the intended way to create atoms.
--Raises "Node Type Not Recognized" (reported at the caller) for any
--nodeType that does not name a node metatable, including "New" itself.
op.New = function( nodeType, a, b )
  local meta = op[ nodeType ]
  --Only the node metatables are tables; "New" maps to this function.
  if type( meta ) ~= "table" then
    error( "Node Type Not Recognized", 2 )
  end
  return setmetatable( { a, b }, meta )
end
--Demo: prints the tree for f' o f, then the trees of its first four
--derivatives (one more diff is computed after the last print).
--NOTE(review): this runs every time the file is loaded, including via
--require; consider moving it behind a script-only guard.
do
  local expr = D(1) ^ D(0)
  local printed = 0
  while printed < 5 do
    print( expr )
    expr = expr:diff()
    printed = printed + 1
  end
end
return op