A Deep Reinforcement Learning Framework for Stock Market Trading

Overview

DQN-Trading

This is a framework based on deep reinforcement learning for stock market trading. This project is the implementation code for the two papers listed under References.

The deep reinforcement learning algorithm used here is Deep Q-Learning.

Acknowledgement

Requirements

Install PyTorch using the following command. This is for CUDA 11.1 and Python 3.8:

pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html

  • python = 3.8
  • pandas = 1.3.2
  • numpy = 1.21.2
  • matplotlib = 3.4.3
  • cython = 0.29.24
  • scikit-learn = 0.24.2
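
To check an existing environment against the versions above, a small script along the following lines may help. It is only a sketch: the `PINNED` table and the `find_mismatches` helper are names made up for this example and are not part of the project.

```python
from importlib import metadata

# Versions copied from the list above (python and torch are checked separately).
PINNED = {
    "pandas": "1.3.2",
    "numpy": "1.21.2",
    "matplotlib": "3.4.3",
    "cython": "0.29.24",
    "scikit-learn": "0.24.2",
}


def find_mismatches(pinned, lookup=metadata.version):
    """Return {package: installed_version_or_None} for every package that differs."""
    mismatches = {}
    for name, wanted in pinned.items():
        try:
            installed = lookup(name)
        except metadata.PackageNotFoundError:
            installed = None  # package is not installed at all
        if installed != wanted:
            mismatches[name] = installed
    return mismatches


if __name__ == "__main__":
    for name, installed in find_mismatches(PINNED).items():
        print(f"{name}: wanted {PINNED[name]}, found {installed}")
```

Newer versions may still work; a mismatch is only a hint about where to look if something breaks.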

TODO List

  • Right now this project has no code for reading hyper-parameters from the terminal and running the experiments. We preferred writing a Jupyter notebook (Main.ipynb) in which you can set the input data, the model, and the hyper-parameters.

  • The project also has no code for hyper-parameter search (it's easy to implement).

  • You can also set the seed for running the experiments in the original code for training the models.
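
As a rough idea of what the missing hyper-parameter search could look like, here is a minimal grid-search sketch. It is not part of the project: `grid_search` and the `train_and_evaluate` callback are hypothetical, and the callback is assumed to train one agent with the given hyper-parameters and return a single score where higher is better.

```python
import itertools
import random


def grid_search(train_and_evaluate, grid, seed=0):
    """Try every combination in `grid` and return (best_params, best_score).

    `grid` maps a hyper-parameter name to the list of values to try.
    """
    names = sorted(grid)
    best_params, best_score = None, float("-inf")
    for values in itertools.product(*(grid[name] for name in names)):
        params = dict(zip(names, values))
        # Re-seed before each run so the runs are comparable. A real training
        # run would also need to seed numpy and torch.
        random.seed(seed)
        score = train_and_evaluate(**params)
        if score > best_score:
            best_params, best_score = params, score
    return best_params, best_score
```

Usage would be something like `grid_search(run_experiment, {"gamma": [0.7, 0.9], "batch_size": [10, 16]})`, where `run_experiment` is whatever function wraps the training loop from the notebook.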

Developers' Guidelines

In this section, I briefly explain the different parts of the project and how to change each. The data for the project was downloaded from Yahoo Finance, where you can search for a specific market and download its data under the Historical Data section. Then create a directory with the name of the stock under the data directory and put the .csv file there.
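
The layout described above can be read with a few lines of standard-library Python. This is only an illustration and not the project's loader: `load_ohlc` is a hypothetical helper, and the column names (`Date`, `Open`, `High`, `Low`, `Close`) are assumed to match the Yahoo Finance historical-data export.

```python
import csv
from pathlib import Path


def load_ohlc(data_dir, stock_name, file_name):
    """Read <data_dir>/<stock_name>/<file_name> into a list of OHLC dicts."""
    path = Path(data_dir) / stock_name / file_name
    rows = []
    with path.open(newline="") as handle:
        for record in csv.DictReader(handle):
            try:
                rows.append({
                    "date": record["Date"],
                    "open": float(record["Open"]),
                    "high": float(record["High"]),
                    "low": float(record["Low"]),
                    "close": float(record["Close"]),
                })
            except ValueError:
                # Skip rows whose prices are not numbers (for example "null").
                continue
    return rows
```

For example, `load_ohlc("Data", "AAPL", "AAPL.csv")` would read a file placed at `Data/AAPL/AAPL.csv`; the stock name here is just a placeholder.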

The DataLoader directory contains files to process the data and interact with the RL agent. DataLoader.py loads the data given the folder name under the Data folder and the name of the .csv file. For data downloaded from Yahoo Finance, use the YahooFinanceDataLoader class.

The Data.py file is the environment that interacts with the RL agent. This file contains all the functionality of a standard RL environment. For each agent, I developed a class inherited from the Data class that differs only in the state space (states for the LSTM and convolutional models are time series, while for the other models they are simple OHLC values). In DataForPatternBasedAgent.py the states are patterns extracted using rule-based methods from technical analysis. In DataAutoPatternExtractionAgent.py the states are Open, High, Low, and Close prices (plus some other information about the candlestick, such as trend, upper shadow, and lower shadow). In DataSequential.py, as the name suggests, the state space is a time series, which is used in both the LSTM and convolutional models. DataSequencePrediction.py was an idea for feeding states predicted by an LSTM model to the RL agent. This idea is still rough and needs further development.
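
To make the two kinds of state more concrete, here is a small sketch of per-candle features and of sliding windows over a series. It uses the usual textbook candlestick definitions, which may differ from what the project's Data classes compute; `candle_features` and `sliding_windows` are hypothetical names, and "trend" is simplified here to the direction of a single candle.

```python
def candle_features(open_, high, low, close):
    """Describe one candle: direction, body size, and the two shadows."""
    body_top = max(open_, close)
    body_bottom = min(open_, close)
    if close > open_:
        trend = 1       # bullish candle
    elif close < open_:
        trend = -1      # bearish candle
    else:
        trend = 0       # open equals close
    return {
        "trend": trend,
        "body": body_top - body_bottom,
        "upper_shadow": high - body_top,
        "lower_shadow": body_bottom - low,
    }


def sliding_windows(candles, window_size):
    """Return every run of `window_size` consecutive candles, oldest first."""
    return [
        candles[start:start + window_size]
        for start in range(len(candles) - window_size + 1)
    ]
```

A single-candle state would use the output of `candle_features` directly, while a time-series state would be one of the windows returned by `sliding_windows`.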

Wherever we used an encoder-decoder architecture, the decoder is the DQN agent, whose neural network is the same across all the models.

The DeepRLAgent directory contains three models: the DQN model without an encoder (VanillaInput), whose data loaders are DataAutoPatternExtractionAgent.py and DataForPatternBasedAgent.py; an encoder-decoder model under the SimpleCNNEncoder directory, where the encoder is a 1D convolutional layer added to the DQN decoder; and an encoder-decoder model under the MLPEncoder directory, where the encoder is a simple MLP and the decoder is the DQN agent.

The EncoderDecoderAgent directory contains all the time-series models: CNN (a two-layered 1D CNN as encoder); CNN2D (a single-layered 2D CNN as encoder); CNN-GRU (the encoder is a 1D CNN over the input followed by a GRU on the output of the CNN, so that the CNN extracts features from each candlestick and the GRU then extracts the temporal dependency among those features); CNNAttn (a two-layered 1D CNN with an attention layer that puts higher emphasis on specific parts of the features extracted from the time-series data); and a GRU encoder, which extracts temporal relations among candles. All of these models use the DataSequential.py file as their environment.
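
The CNN-GRU idea can be sketched in PyTorch as follows. This is an illustration of the encoder-decoder pattern only, not the code in the repository: the class names, layer sizes, the kernel size of 1, and the choice of three actions are all assumptions made for the example.

```python
import torch
from torch import nn


class CNNGRUEncoder(nn.Module):
    """1D CNN applied to each candle, followed by a GRU over the window."""

    def __init__(self, n_features=4, conv_channels=8, hidden_size=16):
        super().__init__()
        # kernel_size=1 means each candle is transformed independently.
        self.conv = nn.Conv1d(n_features, conv_channels, kernel_size=1)
        self.gru = nn.GRU(conv_channels, hidden_size, batch_first=True)

    def forward(self, x):
        # x: (batch, window, n_features); Conv1d expects (batch, channels, window).
        features = torch.relu(self.conv(x.permute(0, 2, 1)))
        # Back to (batch, window, channels) for the GRU.
        _, last_hidden = self.gru(features.permute(0, 2, 1))
        return last_hidden.squeeze(0)  # (batch, hidden_size)


class DQNDecoder(nn.Module):
    """Small MLP mapping the encoded state to one Q-value per action."""

    def __init__(self, hidden_size=16, n_actions=3):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(hidden_size, 32),
            nn.ReLU(),
            nn.Linear(32, n_actions),
        )

    def forward(self, encoded):
        return self.net(encoded)


class EncoderDecoderQNetwork(nn.Module):
    """Encoder and decoder chained so they can be trained end to end."""

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x):
        return self.decoder(self.encoder(x))
```

Swapping `CNNGRUEncoder` for another module with the same output size leaves `DQNDecoder` untouched, which is the point of keeping the decoder identical across models.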

To run an agent, please refer to the Main.py file for instructions on how to run it and feed it data. The Main.py file also has code for plotting results.

The Objects directory contains the saved models from our experiments for each agent.

The PatternDetectionCandleStick directory contains the Evaluation.py file, which has all the evaluation metrics used in the paper. This file receives the actions from the agents and evaluates the result of the strategy offered by each agent. The LabelPatterns.py file uses rule-based methods to generate buy or sell signals. Extract.py is another file, used for detecting well-known candlestick patterns.
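
As a simplified picture of what evaluating a sequence of actions means, the sketch below replays the actions against closing prices and reports the final portfolio value. It is not the code or the metrics in Evaluation.py: the function names, the all-in/all-out trading rule, the action labels, and the absence of transaction costs are assumptions made for the example.

```python
def simulate_portfolio(actions, closes, initial_cash=1000.0):
    """Replay "buy"/"sell"/other actions and return the final portfolio value.

    "buy" spends all cash at that day's close, "sell" liquidates all shares,
    and anything else holds the current position.
    """
    if len(actions) != len(closes):
        raise ValueError("actions and closes must have the same length")
    cash, shares = initial_cash, 0.0
    for action, price in zip(actions, closes):
        if action == "buy" and cash > 0:
            shares, cash = cash / price, 0.0
        elif action == "sell" and shares > 0:
            cash, shares = shares * price, 0.0
    # Value any open position at the last close.
    return cash + (shares * closes[-1] if closes else 0.0)


def total_return(actions, closes, initial_cash=1000.0):
    """Fractional gain or loss of the strategy over the whole period."""
    final_value = simulate_portfolio(actions, closes, initial_cash)
    return (final_value - initial_cash) / initial_cash
```

For instance, buying at 10 and selling at 15 turns 1000 into 1500, a total return of 0.5.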

The RLAgent directory is the implementation of the traditional RL algorithm SARSA-λ using Cython. To run it in Main.ipynb, you should first build the Cython file. To do that, run the following command inside its directory in a terminal:

python setup.py build_ext --inplace

This works on both Linux and Windows.
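
For orientation, one step of tabular SARSA-λ with accumulating eligibility traces can be written in plain Python as below. This is the textbook update, not a transcription of the Cython code in the RLAgent directory; the function name, argument names, and default values are made up for the example.

```python
from collections import defaultdict


def sarsa_lambda_step(q, traces, state, action, reward, next_state, next_action,
                      alpha=0.1, gamma=0.99, lam=0.9):
    """Apply one SARSA(lambda) update in place and return the TD error.

    `q` and `traces` are dicts keyed by (state, action), defaulting to 0.0.
    """
    td_error = reward + gamma * q[(next_state, next_action)] - q[(state, action)]
    traces[(state, action)] += 1.0  # accumulating trace for the visited pair
    for key in list(traces):
        # Every recently visited pair shares the credit, weighted by its trace.
        q[key] += alpha * td_error * traces[key]
        traces[key] *= gamma * lam  # decay the trace for the next step
    return td_error


if __name__ == "__main__":
    q_table, trace_table = defaultdict(float), defaultdict(float)
    sarsa_lambda_step(q_table, trace_table, state=0, action=0, reward=1.0,
                      next_state=1, next_action=0)
    print(dict(q_table))
```

The traces should be reset to zero at the start of each episode.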

For more information on the algorithms and models, please refer to the original papers. You can find them in the References.

If you have any questions regarding the papers or the code, or you want to contribute, please send me an email: [email protected]

References

@article{taghian2020learning,
  title={Learning financial asset-specific trading rules via deep reinforcement learning},
  author={Taghian, Mehran and Asadi, Ahmad and Safabakhsh, Reza},
  journal={arXiv preprint arXiv:2010.14194},
  year={2020}
}

@article{taghian2021reinforcement,
  title={A Reinforcement Learning Based Encoder-Decoder Framework for Learning Stock Trading Rules},
  author={Taghian, Mehran and Asadi, Ahmad and Safabakhsh, Reza},
  journal={arXiv preprint arXiv:2101.03867},
  year={2021}
}