Coverage for /home/runner/work/AutoDiff/AutoDiff/autodiff/ad.py: 100%
Generated by Amelia Li for AutoDiff. (GitHub Profile)
68 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-07 04:22 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-07 04:22 +0000
# File       : ad.py
# Description: Parent class AD that stores the function(s) passed in by the
#              user on which automatic differentiation is performed.
4import inspect
5import numpy as np
class AD:
    """Automatic differentiation base class.

    Stores and validates the user-supplied function(s) together with the
    names of their input variables.  Evaluation is delegated to
    ``get_results``, which is not defined here and must be supplied by the
    concrete class (the elementary-function methods below refer to ``Dual``
    and ``Node``).
    """

    _supported_types = (int, float, np.ndarray, list)
    _supported_scalars = (int, float)
    _supported_vectors = (np.ndarray, list)

    def __init__(self, f, inputs=[]):
        """
        Store and validate the function(s) to differentiate and their inputs.

        Parameters
        ----------
        f : function or array-like of functions
            A single Python function, or a ``list`` / ``numpy.ndarray`` of
            functions.  Passing a list-type sets ``self.jacobian`` to True.

        inputs : array-like of str
            ``list`` / ``numpy.ndarray`` with the names of the input
            variables.  Every positional argument name of every function
            must appear in it.  The default list is never mutated (a copy is
            stored), so sharing it between calls is harmless.

        Raises
        ------
        TypeError
            If ``f`` (or one of its elements) is not a function, if
            ``inputs`` is not a list or ndarray, or if an element of
            ``inputs`` is not a string.

        ValueError
            If ``inputs`` is empty, or if a function argument is missing
            from ``inputs``.
        """
        self.f = f
        self.inputs = inputs
        self.jacobian = False

        # A list-type of functions means a Jacobian will be computed;
        # a single function is handled as a one-element list for validation.
        if isinstance(f, self._supported_vectors):
            functions = list(f)
        else:
            functions = [f]

        for func in functions:
            if not inspect.isfunction(func):
                raise TypeError(f"Unsupported type '{type(func)}'")
        self.jacobian = isinstance(f, self._supported_vectors)

        # The input variables must come in a list-type container.
        if not isinstance(inputs, self._supported_vectors):
            raise TypeError(f"Unsupported type '{type(inputs)}'")

        names = list(inputs)
        self.inputs = names
        self.n = len(names)

        if self.n == 0:
            raise ValueError("Input list is empty.")

        # isinstance (rather than an exact type comparison) so that the
        # numpy.str_ elements produced by iterating an ndarray of names are
        # accepted; ndarray is one of the supported input containers.
        for name in names:
            if not isinstance(name, str):
                raise TypeError(f"Unsupported type '{type(name)}' for input elements.")

        # Normalise str subclasses (e.g. numpy.str_) to plain str.
        self.inputs = [str(name) for name in names]

        # Every positional argument of every function must be a known input.
        for func in functions:
            for arg in inspect.getfullargspec(func).args:
                if arg not in self.inputs:
                    raise ValueError(f"Argument '{arg}' is not in '{self.inputs}'.")

    def get_function(self):
        """
        Get the function.

        Returns
        -------
        f
            The function (or list-type of functions) passed to the constructor.
        """
        return self.f

    def get_f(self, x):
        """
        Return the value(s) of the function(s) evaluated at input 'x'.

        This is the first element of ``self.get_results(x)``.

        Parameters
        ----------
        x : Scalar, Vector.
            The point at which the function(s) is evaluated.

        Returns
        -------
        f(x)
            The value(s) of the function(s) evaluated at 'x'.

        Raises
        ------
        TypeError
            Expected from ``get_results`` if the type of 'x' is not supported.

        ValueError
            Expected from ``get_results`` if the dimension of 'x' does not
            match the function(s).
        """
        return self.get_results(x)[0]

    def get_f_prime(self, x):
        """
        Return the derivative(s) of the function(s) evaluated at input 'x'.

        This is the second element of ``self.get_results(x)``.

        Parameters
        ----------
        x : Scalar, Vector.
            The point at which the derivative(s) of the function(s) is evaluated.

        Returns
        -------
        f'(x)
            The derivative(s) of the function(s) at 'x'.

        Raises
        ------
        TypeError
            Expected from ``get_results`` if the type of 'x' is not supported.

        ValueError
            Expected from ``get_results`` if the dimension of 'x' does not
            match the function(s).
        """
        return self.get_results(x)[1]

    # The elementary functions below dispatch to the implementation found on
    # the instance's own class (Dual or Node).
    # NOTE: if that class does not override the method, the lookup resolves
    # back to this method and the call recurses without terminating.

    ### Square Root Function ###
    def sqrt(self):
        """
        Call the sqrt function in Dual or Node.
        """
        return self.__class__.sqrt(self)

    ### Exponential Function ###
    def exp(self):
        """
        Call the exp function in Dual or Node.
        """
        return self.__class__.exp(self)

    ### Logarithmic Function ###
    def log(self, base):
        """
        Call the log function in Dual or Node, forwarding 'base'.
        """
        return self.__class__.log(self, base)

    ### Logistic Function ###
    def standard_logistic(self):
        """
        Call the standard_logistic function in Dual or Node.
        """
        return self.__class__.standard_logistic(self)

    ### Trigonometric Functions ###
    def sin(self):
        """
        Call the sin function in Dual or Node.
        """
        return self.__class__.sin(self)

    def cos(self):
        """
        Call the cos function in Dual or Node.
        """
        return self.__class__.cos(self)

    def tan(self):
        """
        Call the tan function in Dual or Node.
        """
        return self.__class__.tan(self)

    ### Inverse Trigonometric Functions ###
    def arcsin(self):
        """
        Call the arcsin function in Dual or Node.
        """
        return self.__class__.arcsin(self)

    def arccos(self):
        """
        Call the arccos function in Dual or Node.
        """
        return self.__class__.arccos(self)

    def arctan(self):
        """
        Call the arctan function in Dual or Node.
        """
        return self.__class__.arctan(self)

    ### Hyperbolic Functions ###
    def sinh(self):
        """
        Call the sinh function in Dual or Node.
        """
        return self.__class__.sinh(self)

    def cosh(self):
        """
        Call the cosh function in Dual or Node.
        """
        return self.__class__.cosh(self)

    def tanh(self):
        """
        Call the tanh function in Dual or Node.
        """
        return self.__class__.tanh(self)