Coverage for /home/runner/work/AutoDiff/AutoDiff/autodiff/ad.py: 100%


Generated by Amelia Li for AutoDiff. (GitHub Profile)

68 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-07 04:22 +0000

1# File : ad.py 

2# Description: Parent class AD that stores the function passed in by the user 

3# to perform automatic differentiation on 

4import inspect 

5import numpy as np 

6 

class AD:
    """Automatic differentiation base class.

    Stores the user-supplied function (or list-like of functions) together
    with the names of its input variables, and validates that every
    positional argument of every function is one of those names.

    The elementary-function methods (``sqrt``, ``exp``, ``log``, ``sin``,
    ...) are dispatchers: ``AD.sin(x)`` forwards to ``x.__class__.sin(x)``.
    This lets a user function written with ``AD.sin`` be evaluated on
    whatever object type implements ``sin`` (per the docstrings, Dual or
    Node).
    """

    _supported_types = (int, float, np.ndarray, list)
    _supported_scalars = (int, float)
    _supported_vectors = (np.ndarray, list)

    def __init__(self, f, inputs=None):
        """
        Initialize the function of which the derivative will be calculated based on input 'f'.

        Parameters
        ----------
        f : function or array-like of functions
            A single Python function, or a list / ``np.ndarray`` of
            functions. A list-like sets ``self.jacobian`` to True.

        inputs : list or np.ndarray of str, optional
            Names of the input variables. Must be non-empty and contain every
            positional argument name of every function in ``f``. ``None``
            (the default) is treated as an empty list and is therefore
            rejected.

        Raises
        ------
        TypeError
            If ``f`` (or any element of it) is not a function, if ``inputs``
            is not a list / ``np.ndarray``, or if any input name is not a
            string.
        ValueError
            If ``inputs`` is empty, or if a function argument is missing from
            ``inputs``.
        """
        self.f = f
        # A list-like of functions describes a vector-valued function, whose
        # derivative is a Jacobian.
        self.jacobian = isinstance(f, self._supported_vectors)
        functions = list(f) if self.jacobian else [f]

        # Every supplied callable must be a plain Python function (def/lambda).
        for func in functions:
            if not inspect.isfunction(func):
                raise TypeError(f"Unsupported type '{type(func)}'")

        # None replaces the former mutable `[]` default and behaves like it.
        if inputs is None:
            inputs = []
        if not isinstance(inputs, self._supported_vectors):
            raise TypeError(f"Unsupported type '{type(inputs)}'")
        inputs = list(inputs)

        # store the number of input variables
        self.n = len(inputs)
        if self.n == 0:
            raise ValueError("Input list is empty.")

        # isinstance (rather than an exact type comparison) so that NumPy
        # string arrays, whose elements are np.str_ (a str subclass), are
        # accepted; names are then normalised to plain str.
        for name in inputs:
            if not isinstance(name, str):
                raise TypeError(f"Unsupported type '{type(name)}' for input elements.")
        self.inputs = [str(name) for name in inputs]

        # Every positional argument of every function must be a declared input.
        declared = set(self.inputs)
        for func in functions:
            for arg in inspect.getfullargspec(func).args:
                if arg not in declared:
                    raise ValueError(f"Argument '{arg}' is not in '{self.inputs}'.")

    def get_function(self):
        """
        Get the function.

        Returns
        -------
        f
            The function (or list-like of functions) passed to the constructor.
        """
        return self.f

    def get_f(self, x):
        """
        Returns the value(s) of the function(s) evaluated at input 'x' computed by get_results.

        ``get_results`` is not defined in this class; it must be supplied by
        a subclass and return an indexable whose item 0 holds the value(s).

        Parameters
        ----------
        x : Scalar, Vector.
            The point at which the function(s) is evaluated.

        Returns
        -------
        f(x)
            The method returns the value(s) of the function(s) evaluated at 'x'.

        Raises
        ------
        TypeError
            This method raises a `TypeError` if the type of input 'x' is not supported.

        ValueError
            This method also raises a `ValueError` if the dimension of input 'x' is not matched with the function(s).
        """
        return self.get_results(x)[0]

    def get_f_prime(self, x):
        """
        Returns the derivative(s) of the function(s) based on input 'x' computed by get_results.

        ``get_results`` is not defined in this class; it must be supplied by
        a subclass and return an indexable whose item 1 holds the
        derivative(s).

        Parameters
        ----------
        x : Scalar, Vector.
            The point at which the derivative(s) of the function(s) is evaluated.

        Returns
        -------
        f'(x)
            The method returns the derivative(s) of the function(s) at 'x'.

        Raises
        ------
        TypeError
            This method raises a `TypeError` if the type of input 'x' is not supported.

        ValueError
            This method also raises a `ValueError` if the dimension of input 'x' is not matched with the function(s).
        """
        return self.get_results(x)[1]

    @staticmethod
    def _dispatch(obj, name, *args):
        """Forward to ``obj.__class__.<name>(obj, *args)``.

        Raises NotImplementedError when that lookup resolves back to AD's own
        dispatcher (``obj`` is an AD, or an AD subclass that does not
        override ``name``); forwarding there would recurse forever. Raises
        AttributeError if ``obj``'s class has no such attribute.
        """
        method = getattr(obj.__class__, name)
        if method is getattr(AD, name):
            raise NotImplementedError(
                f"'{obj.__class__.__name__}' does not implement '{name}'"
            )
        return method(obj, *args)

    ### Square Root Function ###
    def sqrt(self):
        """
        Call the sqrt function in Dual or Node.
        """
        return AD._dispatch(self, "sqrt")

    ### Exponential Function ###
    def exp(self):
        """
        Call the exp function in Dual or Node.
        """
        return AD._dispatch(self, "exp")

    ### Logarithmic Function ###
    def log(self, base):
        """
        Call the log function in Dual or Node, forwarding 'base'.
        """
        return AD._dispatch(self, "log", base)

    ### Logistic Function ###
    def standard_logistic(self):
        """
        Call the standard_logistic function in Dual or Node.
        """
        return AD._dispatch(self, "standard_logistic")

    ### Trigonometric Functions ###
    def sin(self):
        """
        Call the sin function in Dual or Node.
        """
        return AD._dispatch(self, "sin")

    def cos(self):
        """
        Call the cos function in Dual or Node.
        """
        return AD._dispatch(self, "cos")

    def tan(self):
        """
        Call the tan function in Dual or Node.
        """
        return AD._dispatch(self, "tan")

    ### Inverse Trigonometric Functions ###
    def arcsin(self):
        """
        Call the arcsin function in Dual or Node.
        """
        return AD._dispatch(self, "arcsin")

    def arccos(self):
        """
        Call the arccos function in Dual or Node.
        """
        return AD._dispatch(self, "arccos")

    def arctan(self):
        """
        Call the arctan function in Dual or Node.
        """
        return AD._dispatch(self, "arctan")

    ### Hyperbolic Functions ###
    def sinh(self):
        """
        Call the sinh function in Dual or Node.
        """
        return AD._dispatch(self, "sinh")

    def cosh(self):
        """
        Call the cosh function in Dual or Node.
        """
        return AD._dispatch(self, "cosh")

    def tanh(self):
        """
        Call the tanh function in Dual or Node.
        """
        return AD._dispatch(self, "tanh")