{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "List of keyboard shortcuts:\n", "1. \"Ctrl-Enter\" - Run a cell\n", "2. \"Shift-Enter\" - Run a cell and move to the next cell\n", "3. \"Enter\" - Start editing the current cell\n", "4. \"Escape\" - Stop editing the current cell\n", "5. \"x\" - Cut (delete) the current cell\n", "6. \"a\" - Create a new cell above\n", "7. \"b\" - Create a new cell below\n", "\n", "**Run every cell, in order, when you first load this file.**\n", "\n", "**If you change a value in a cell, you need to rerun that cell for the change to take effect.**\n", "\n", "**If a computation is taking too long, you can interrupt it from the Kernel menu above.**" ] }, 
{ "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from IPython.display import clear_output\n", "x,y = var('x y')" ] }, 
{ "cell_type": "markdown", "metadata": {}, "source": [ "The cell below generates 50 random points with $-5 \\leq x \\leq 5$, scattered around the parabola $y = 10 + 2x + 3x^2$ with added Gaussian noise." ] }, 
{ "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "#------Generate random data-----------\n", "sigma = 2\n", "T = RealDistribution('gaussian', sigma) # Gaussian noise with mean 0 and standard deviation sigma\n", "S = RealDistribution('uniform', [-5,5]) # Uniform distribution of x-values on [-5, 5]\n", "A,B,C = 10,2,3 # True coefficients of y = A + B*x + C*x^2\n", "data = []\n", "for i in range(50):\n", "    xi = S.get_random_element() # xi, yi keep the symbolic variables x, y from being overwritten\n", "    yi = A + B*xi + C*xi**2 + 5*T.get_random_element()\n", "    data.append([xi, yi])" ] }, 
{ "cell_type": "markdown", "metadata": {}, "source": [ "Once you have run this code with the random data above, go to http://math.jhu.edu/~vzakharevich/teaching/fall2020/data.php and paste the generated data into the cell below, in place of the commented-out `#data=[]` line. The $x$-values of that data lie in the range $-15 \\leq x \\leq 15$. Because these $x$-values are larger, the gradient is much larger too, so you will probably need to decrease the step size for the algorithm to converge.\n", "\n", "Also, remember to change the $x$ bounds (`x_bounds`) when you graph your result at the bottom of the page." ] }, 
{ "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "#---------Insert data from http://math.jhu.edu/~vzakharevich/teaching/fall2020/data.php -------\n", "#data=[]" ] }, 
{ "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/lib/python3/dist-packages/sage/plot/graphics.py:2327: MatplotlibDeprecationWarning: \n", "The OldScalarFormatter class was deprecated in Matplotlib 3.3 and will be removed two minor releases later.\n", " x_formatter = OldScalarFormatter()\n", "/usr/lib/python3/dist-packages/sage/plot/graphics.py:2352: MatplotlibDeprecationWarning: \n", "The OldScalarFormatter class was deprecated in Matplotlib 3.3 and will be removed two minor releases later.\n", " y_formatter = OldScalarFormatter()\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "Graphics object consisting of 1 graphics primitive" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#--------Plot Data--------------\n", "scatter_plot(data)" ] }, 
{ "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "#--------- Define the functions used in the gradient descent--------\n", "a,b,c,x,y = var('a b c x y')\n", "SE(a,b,c,x,y) = (a + b*x + c*x**2 - y)**2 # Squared error of the model at one data point\n", "MSE(a,b,c) = sum([SE(a,b,c,x,y) for x,y in data])/len(data) # Mean squared error over all data points\n", "residual(a,b,c,x,y) = a + b*x + c*x**2 - y # Residual: model prediction minus observed y\n", "def neg_grad(a,b,c): # Negative gradient of MSE with respect to (a, b, c)\n", "    num_points = len(data)\n", "    return vector((-2 * sum([residual(a,b,c,x,y) for x,y in data])/num_points,\n", "                   -2 * sum([residual(a,b,c,x,y) * x for x,y in data])/num_points,\n", "                   -2 * sum([residual(a,b,c,x,y) * x**2 for x,y in data])/num_points))" ] }, { "cell_type": "markdown", 
"metadata": {}, "source": [ "**When you run the algorithm below, the mean squared error should decrease from one step to the next; it is the function we are trying to minimize. If it is not decreasing, decrease the step size `delta_t`.**" ] }, 
{ "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The value of the mean square difference at the current step 1000 is 67.159234: \n", "The predicted function is f(x)=10.9316 + 2.5943 x + 2.8174 x^2 \n" ] } ], "source": [ "#-------------Running the algorithm-----------\n", "delta_t = 0.005 # Step size: decrease it if the mean squared error is not decreasing\n", "starting_point = vector((0,0,0)) # Starting guess for (a, b, c)\n", "number_of_iterations = 1000 # Number of gradient descent steps\n", "\n", "points = [starting_point] # points[i] is the estimate of (a, b, c) after i steps\n", "for i in range(number_of_iterations):\n", "    new_point = points[i] + delta_t * neg_grad(*points[i]) # Step in the direction of the negative gradient\n", "    points.append(new_point)\n", "    clear_output(wait=True)\n", "    print(\"The value of the mean squared error after step %d is %f\" % (i+1, MSE(*new_point)))\n", "prediction = points[number_of_iterations] # Final estimate of (a, b, c)\n", "print(\"The predicted function is f(x)=%.4f + %.4f x + %.4f x^2\" % tuple(prediction))\n", "predicted_function(x) = prediction[0] + prediction[1]*x + prediction[2]*x**2" ] }, 
{ "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/lib/python3/dist-packages/sage/plot/graphics.py:2327: MatplotlibDeprecationWarning: \n", "The OldScalarFormatter class was deprecated in Matplotlib 3.3 and will be removed two minor releases later.\n", " x_formatter = OldScalarFormatter()\n", "/usr/lib/python3/dist-packages/sage/plot/graphics.py:2352: MatplotlibDeprecationWarning: \n", "The OldScalarFormatter class was deprecated in Matplotlib 3.3 and will be removed two minor releases later.\n", " y_formatter = OldScalarFormatter()\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "Graphics object consisting of 2 graphics primitives" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#-------------Plot the predicted function next to the data points--------\n", "x_bounds = (x, -5, 5) # Change to (x, -15, 15) when using the downloaded data\n", "scatter_plot(data) + plot(predicted_function, x_bounds)" ] }, 
{ "cell_type": "code", "execution_count": 0, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "SageMath 9.0", "language": "sage", "name": "sagemath" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.5" } }, "nbformat": 4, "nbformat_minor": 4 }