Linear Regression with a Graph Database
How to Run a Linear Regression using Pokémon Data from a Graph in TigerGraph
(Note: This is a bonus post in a series. To obtain the data used here, refer to the earlier posts in the series.)
Hello! Today, we’re going to learn how to run a linear regression on data stored in TigerGraph. In this example, we’ll regress Pokémon weight on height. Let’s get to it!
Step I: Setup
To start, we’re going to need to import the necessary libraries and create a connection to our graph.
The libraries we’ll be using are pyTigerGraph, Plotly Express, and Pandas. You can import them using the following:
import pyTigerGraph as tg
import plotly.express as px
import pandas as pd
Note that if you are using a Colab notebook, you’ll need to install pyTigerGraph first:
!pip install pyTigerGraph
Now that your libraries are installed and imported, start your solution on TigerGraph Cloud and create a connection in Python.
conn = tg.TigerGraphConnection(host="https://SUBDOMAIN.i.tgcloud.io", password="tigergraph", gsqlVersion="3.0.5", useCert=True)
conn.graphname = "PokemonGraph"
conn.apiToken = conn.getToken(conn.createSecret())[0]  # getToken() returns a tuple; the token is its first element
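Hard-coding the password is fine for a quick demo, but if you share the notebook you may prefer to read the connection settings from environment variables. A minimal sketch, assuming hypothetical variable names TG_HOST, TG_PASSWORD and TG_GRAPH (these names are this sketch's own choice, not something pyTigerGraph or TigerGraph Cloud defines):

```python
import os


def connection_settings(env=None):
    """Collect TigerGraphConnection keyword arguments from environment variables.

    TG_HOST, TG_PASSWORD and TG_GRAPH are hypothetical names chosen for this sketch.
    """
    env = os.environ if env is None else env
    return {
        "host": env.get("TG_HOST", "https://SUBDOMAIN.i.tgcloud.io"),
        "password": env["TG_PASSWORD"],  # raise KeyError if the password is missing
        "graphname": env.get("TG_GRAPH", "PokemonGraph"),
    }


# Passing a plain dict stands in for real environment variables here.
settings = connection_settings({"TG_PASSWORD": "example-password"})
print(settings["graphname"])  # PokemonGraph
```

In the notebook you would then create the connection with something like `tg.TigerGraphConnection(**connection_settings(), gsqlVersion="3.0.5", useCert=True)`, which also sets the graph name through the constructor's `graphname` argument.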
Step II: Write a Query
Next, write a query to grab all the Pokémon. Here, we’ll simply create a seed set containing every Pokemon vertex and print it. We’ll assign the results of the query to a variable called result.
result = conn.runInterpretedQuery('''
INTERPRET QUERY () FOR GRAPH PokemonGraph {
  catchThemAll = {Pokemon.*};  # Pokemon.* means that it'll grab all the Pokémon
  PRINT catchThemAll;          # PRINT displays the results
}
''')
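It helps to know what result looks like before reshaping it. pyTigerGraph returns the query output as a list with one dictionary per PRINT statement, and each printed vertex is a dictionary with v_id, v_type and attributes keys. The sketch below uses a made-up two-vertex response (the IDs and numbers are illustrative placeholders, not real Pokémon data) and a hypothetical helper, printed_vertices, that finds a printed vertex set by name:

```python
def printed_vertices(result, set_name):
    """Return the vertex list printed under set_name by any PRINT statement."""
    for printed in result:  # one dict per PRINT statement
        if set_name in printed:
            return printed[set_name]
    raise KeyError(f"{set_name!r} was not printed by the query")


# Made-up response in the shape a query like the one above returns.
sample_result = [
    {"catchThemAll": [
        {"v_id": "p1", "v_type": "Pokemon", "attributes": {"height": 1.0, "weight": 3.0}},
        {"v_id": "p2", "v_type": "Pokemon", "attributes": {"height": 2.0, "weight": 5.0}},
    ]}
]

vertices = printed_vertices(sample_result, "catchThemAll")
print(len(vertices), vertices[0]["attributes"]["height"])  # 2 1.0
```

With a single PRINT statement, as in this query, this is equivalent to result[0]["catchThemAll"]; the lookup by name only matters once a query prints several sets.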
Step III: Create a DataFrame
Plotly Express works directly with pandas DataFrames, so we first need to reshape the query results into one. Here’s how we’d do that:
result = result[0]["catchThemAll"]  # Grab just the catchThemAll result

# The lists we need
height = []
weight = []

for i in result:  # For each Pokémon
    height.append(i["attributes"]["height"])  # Add the height
    weight.append(i["attributes"]["weight"])  # Add the weight

d = {'height': height, 'weight': weight}  # Put it in a dictionary
df = pd.DataFrame(data=d)  # Put it in a DataFrame
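The loop can also be written more compactly by letting pandas build the columns straight from each vertex's attributes dictionary. A sketch, assuming every vertex has numeric height and weight attributes; vertices below is a hypothetical stand-in for result[0]["catchThemAll"], with made-up values:

```python
import pandas as pd

# Hypothetical stand-in for result[0]["catchThemAll"]; values are made up.
vertices = [
    {"v_id": "p1", "v_type": "Pokemon", "attributes": {"height": 1.0, "weight": 3.0}},
    {"v_id": "p2", "v_type": "Pokemon", "attributes": {"height": 2.0, "weight": 5.0}},
]

# One row per vertex; keep only the two columns used in the regression.
df = pd.DataFrame([v["attributes"] for v in vertices])[["height", "weight"]]
print(df.shape)  # (2, 2)
```

Selecting the two columns at the end means any other attributes stored on the vertices are simply dropped.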
Step IV: Plot it on a Scatter Plot!
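Finally, plotting everything takes just two lines. Setting trendline="ols" tells Plotly Express to fit and draw an ordinary least-squares regression line; this option requires the statsmodels package to be installed.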
fig = px.scatter(df, x="height", y="weight", trendline="ols")
fig.show()
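The OLS trendline is the straight line weight ≈ slope × height + intercept that minimises the sum of squared vertical errors. To see what is being computed, here is a minimal pure-Python sketch of the closed-form fit (slope = covariance of x and y divided by the variance of x). ols_fit is a hypothetical helper, not part of Plotly, and the sample points are made up:

```python
def ols_fit(xs, ys):
    """Return (slope, intercept) of the least-squares line y = slope * x + intercept."""
    n = len(xs)
    if n != len(ys) or n < 2:
        raise ValueError("need at least two paired observations")
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    sxx = sum((x - mean_x) ** 2 for x in xs)
    if sxx == 0:
        raise ValueError("x values are all identical; slope is undefined")
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    slope = sxy / sxx
    return slope, mean_y - slope * mean_x


# Made-up points lying exactly on y = 2x + 1.
slope, intercept = ols_fit([1.0, 2.0, 3.0], [3.0, 5.0, 7.0])
print(slope, intercept)  # 2.0 1.0
```

Calling ols_fit(df["height"].tolist(), df["weight"].tolist()) on your own data should agree, up to floating-point rounding, with the fitted coefficients Plotly reports through px.get_trendline_results(fig).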
And with that, you should be done!
Step V: Explore and Experiment!
That’s it! You’ve now run a linear regression on data from a graph database! Try querying other attributes and running a linear regression on those, or enhance the plot by adding a title, changing the colours, and so on. Good luck!
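When you try other attribute pairs, a quick way to compare how well a straight line fits each one is R². For a simple regression with an intercept, R² equals the square of Pearson's correlation coefficient, which pandas computes with Series.corr. A sketch, with r_squared as a hypothetical helper and a made-up DataFrame standing in for the one built in Step III:

```python
import pandas as pd


def r_squared(df, x, y):
    """R^2 of the simple OLS fit y ~ x, computed as the squared Pearson correlation."""
    return df[x].corr(df[y]) ** 2


# Made-up data; in the notebook, df would come from Step III.
df = pd.DataFrame({"height": [0.0, 1.0, 2.0, 3.0], "weight": [1.0, 3.0, 2.0, 5.0]})
print(round(r_squared(df, "height", "weight"), 3))  # 0.691
```

Values near 1 mean the points lie close to the fitted line; values near 0 mean height explains little of the variation in weight. Which other numeric attributes are available depends on your graph's schema.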
You can also reference the Colab here: https://colab.research.google.com/drive/1U1mG7gzce3lCLhrsi0KCUZn3z0v-iLca