import os
import json
from datetime import datetime
from hashlib import sha256

import cv2  # used below to re-read the saved heatmap
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

from flask import Flask, request
from flask_cors import CORS
import gevent.pywsgi

# The server relies on TF1-style graphs/sessions, so import the compat.v1 API.
import tensorflow.compat.v1 as tf
from tensorflow.python.keras.backend import set_session
from keras import backend as K
from keras.applications.xception import preprocess_input as xception_preprocess_input
from keras_preprocessing import image
from keras.regularizers import l2, l1

from tensorflow.keras.layers import *  # Dense, Conv2D, Flatten, Add, Reshape, ...
from tensorflow.keras.models import Model
input_shape = (224, 224, 3)
nbr_of_classes = 38

# Run in TF1 graph mode and keep handles on the default graph and session, so
# the model can be called safely from Flask's request-handling threads.
tf.disable_v2_behavior()
graph = tf.get_default_graph()
app = Flask(__name__)
CORS(app)
sess = tf.Session()
set_session(sess)
# ResTS ARCHITECTURE
# ResTeacher
base_model1 = tf.keras.applications.Xception(include_top=False, weights='imagenet', input_shape=input_shape)
x1_0 = base_model1.output
x1_0 = Flatten(name='Flatten1')(x1_0)
dense1 = Dense(256, name='fc1', activation='relu')(x1_0)
x = classif_out_encoder1 = Dense(nbr_of_classes, name='out1', activation='softmax')(dense1)  # Latent Representation / Bottleneck
#Get Xception's tensors for skip connection.
...
#DECODER
dense2 = Dense(256, activation='relu')(x)
x = Add(name='first_merge')([dense1, dense2])
x = Dense(7*7*2048)(x)
reshape1 = Reshape((7, 7, 2048))(x)
#BLOCK 1
...
#BLOCK 2
...
#BLOCK 3-10
...
#BLOCK 11
...
#BLOCK 12
...
#BLOCK 13
...
#BLOCK 14
...
x = Conv2D(2, 3, activation='relu', padding='same')(x)
mask = x = Conv2D(3, 1, activation='sigmoid', name='Mask')(x)
# ResStudent
base_model2 = tf.keras.applications.Xception(include_top=False, weights='imagenet', input_shape=input_shape)
x2_0 = base_model2(mask)
x2_0 = Flatten(name='Flatten2')(x2_0)
x2_1 = Dense(256, name='fc2',activation='relu')(x2_0)
classif_out_encoder2 = Dense(nbr_of_classes, name='out2',activation='softmax')(x2_1)
# Create the ResTS model and load pre-trained weights
ResTS = Model(base_model1.input, [classif_out_encoder1, classif_out_encoder2])
ResTS.load_weights('tf/RTS')

# For visualization: expose the output of the 'Mask' layer
layer_name = 'Mask'
NewInput = ResTS.get_layer(layer_name).output
visualization = K.function([ResTS.input], [NewInput])
def reduce_channels_square(heatmap):
    # Collapse the 3-channel mask into one channel: the Euclidean distance of
    # each pixel from a fixed reference colour (0.149, 0.1529, 0.3412).
    channel1 = heatmap[:, :, 0]
    channel2 = heatmap[:, :, 1]
    channel3 = heatmap[:, :, 2]
    new_heatmap = np.sqrt((channel1 - 0.149)**2 + (channel2 - 0.1529)**2 + (channel3 - 0.3412)**2)
    return new_heatmap
def postprocess_vis(heatmap1, threshold=0.9):
    heatmap = heatmap1.copy()
    heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min())
    heatmap = reduce_channels_square(heatmap)
    heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min())
    heatmap[heatmap > threshold] = 1
    heatmap = heatmap * 255
    return heatmap
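
These two helpers are never exercised on their own in the article; as a quick smoke test (the random input below is purely hypothetical, standing in for the 'Mask' layer output), one can verify the shape and value range they produce:

# Hypothetical smoke test, not part of the original app.
dummy_mask = np.random.rand(224, 224, 3)  # stand-in for the Mask layer's output
heatmap = postprocess_vis(dummy_mask)     # normalise -> collapse channels -> threshold
assert heatmap.shape == (224, 224)        # single channel
assert 0.0 <= heatmap.min() and heatmap.max() <= 255.0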
def visualize_image(img_name):
    image_size = (224, 224)
    original_image = image.load_img(img_name, target_size=image_size)
    img = image.img_to_array(original_image)
    img = np.expand_dims(img, axis=0)
    img = xception_preprocess_input(img)
    global sess
    global graph
    with graph.as_default():
        set_session(sess)
        vis = visualization([img])[0][0]
        disease = ResTS.predict(img)[0]  # first head: ResTeacher's class probabilities
        probab = max(disease[0])
        disease = np.argmax(disease)
        heatmap = postprocess_vis(vis)
        plt.imshow(heatmap, cmap='Reds')
        plt.axis('off')
        plt.savefig(img_name, bbox_inches='tight')
    return disease, probab
#cv2.imwrite('vis.jpg',vis)
Figure 6. Heatmap written with OpenCV (without cmap='Reds')

Creating the '/detect' route

Figure 7. Flow of the '/detect' route
@app.route('/detect', methods=['POST'])
def change():
    image_size = (224, 224)
    img_data = request.get_json()['image']
    # Derive a short, collision-resistant temporary file name.
    img_name = str(int(datetime.timestamp(datetime.now()))) + str(np.random.randint(1000000000))
    img_name = sha256(img_name.encode()).hexdigest()[0:12]
    img_data = np.array(list(img_data.values())).reshape([224, 224, 3])
    im = Image.fromarray(img_data.astype(np.uint8))
    im.save(img_name + '.jpg')
    disease, probab = visualize_image(img_name + '.jpg')
    img = cv2.imread(img_name + '.jpg')  # re-read the heatmap saved by visualize_image
    img = cv2.resize(img, image_size) / 255.0
    img = img.tolist()
    os.remove(img_name + '.jpg')
    return json.dumps({"image": img, "disease": int(disease), "probab": str(probab)})
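
Note that gevent.pywsgi is imported above but the article never shows the server actually being started. A minimal sketch consistent with those imports (the host and port are assumptions; port 5000 matches the URL the React client below posts to) might look like:

if __name__ == '__main__':
    # Serve the Flask app with gevent's WSGI server (host/port assumed).
    app_server = gevent.pywsgi.WSGIServer(('0.0.0.0', 5000), app)
    app_server.serve_forever()

The route can also be exercised without the React frontend. The payload below mimics what tf.js's dataSync() serializes to over JSON (an object keyed by index); this client is a hypothetical sketch, not part of the original app:

# Hypothetical client-side check of the /detect contract.
import numpy as np
import requests

pixels = np.random.rand(224 * 224 * 3).astype(np.float32)
payload = {'image': {str(i): float(v) for i, v in enumerate(pixels)}}
resp = requests.post('http://localhost:5000/detect', json=payload)
data = resp.json()
print(data['disease'], data['probab'])  # predicted class index and its probability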
import logo from './logo.svg';
import './App.css';
import React from 'react';
import * as tf from '@tensorflow/tfjs';
import cat from './cat.jpg';
import {CLASSES} from './imagenet_classes';
import axios from 'axios';
const IMAGE_SIZE = 224;
class App extends React.Component {
constructor(props){
super(props);
this.state = {
load:false,
status: "F1 score of the model is: 0.9908 ",
probab: ""
};
this.mobilenetDemo = this.mobilenetDemo.bind(this);
this.predict = this.predict.bind(this);
this.showResults = this.showResults.bind(this);
this.filechangehandler = this.filechangehandler.bind(this);
}
async mobilenetDemo(){
const catElement = document.getElementById('cat');
if (catElement.complete && catElement.naturalHeight !== 0) {
this.predict(catElement);
catElement.style.display = '';
} else {
catElement.onload = () => {
this.predict(catElement);
catElement.style.display = '';
}
}
};
async predict(imgElement) {
let img = tf.browser.fromPixels(imgElement).toFloat().reshape([1, 224, 224, 3]);
//img = tf.reverse(img, -1);
this.setState({
load:true
});
const image = await axios.post('http://localhost:5000/detect', {'image': img.dataSync()});
this.setState({
load:false
});
// // Show the classes in the DOM.
this.showResults(imgElement, image.data['disease'], image.data['probab'], tf.tensor3d([image.data['image']].flat(), [224, 224, 3]));
}
async showResults(imgElement, diseaseClass, probab, tensor) {
const predictionContainer = document.createElement('div');
predictionContainer.className = 'pred-container';
const imgContainer = document.createElement('div');
imgContainer.appendChild(imgElement);
predictionContainer.appendChild(imgContainer);
const probsContainer = document.createElement('div');
const predictedCanvas = document.createElement('canvas');
probsContainer.appendChild(predictedCanvas);
predictedCanvas.width = tensor.shape[1];   // tensor shape is [height, width, channels]
predictedCanvas.height = tensor.shape[0];
tensor = tf.reverse(tensor, -1);
await tf.browser.toPixels(tensor, predictedCanvas);
console.log(probab);
this.setState({
probab: "The last prediction was " + parseFloat(probab)*100 + " % accurate!"
});
const predictedDisease = document.createElement('p');
predictedDisease.innerHTML = 'Disease: ';
const i = document.createElement('i');
i.innerHTML = CLASSES[diseaseClass];
predictedDisease.appendChild(i);

//probsContainer.appendChild(predictedCanvas);
//probsContainer.appendChild(predictedDisease);

predictionContainer.appendChild(probsContainer);
predictionContainer.appendChild(predictedDisease);
const predictionsElement = document.getElementById('predictions');
predictionsElement.insertBefore(
predictionContainer, predictionsElement.firstChild);
}
filechangehandler(evt){
let files = evt.target.files;
for (let i = 0, f; f = files[i]; i++) {
// Only process image files (skip non image files)
if (!f.type.match('image.*')) {
continue;
}
let reader = new FileReader();
reader.onload = e => {
// Fill the image & call predict.
let img = document.createElement('img');
img.src = e.target.result;
img.width = IMAGE_SIZE;
img.height = IMAGE_SIZE;
img.onload = () => this.predict(img);
};
// Read in the image file as a data URL.
reader.readAsDataURL(f);
}
}
componentDidMount(){
this.mobilenetDemo();
}
render(){
return (
<div className="tfjs-example-container">
<section className='title-area'>
<h1>ResTS for Plant Disease Diagnosis</h1>
</section>
<section>
<p className='section-head'>Description</p>
<p>
This WebApp uses the ResTS model, which will be made available for public use soon.
The model was trained only on images with black backgrounds, so it cannot reliably recognize images without them. For best performance, upload images of a leaf/plant against a black background. You can see the disease categories it has been trained to recognize in <a
href="https://github.com/spMohanty/PlantVillage-Dataset/tree/master/raw/segmented">this folder</a>.
</p>
</section>
<section>
<p className='section-head'>Status</p>
{this.state.load?<div id="status">{this.state.status}</div>:<div id="status">{this.state.status}<br></br>{this.state.probab}</div>}
</section>
<section>
<p className='section-head'>Model Output</p>
<div id="file-container">
Upload an image: <input type="file" id="files" name="files[]" onChange={this.filechangehandler} multiple />
</div>
{this.state.load?<div className="lds-roller"><div></div><div></div><div></div><div></div><div></div><div></div><div></div><div></div></div>:''}
<div id="predictions"></div><img id="cat" src={cat}/>
</section>
</div>
);
}
}

export default App;
Gif 1. Response of the application

Contact me on LinkedIn here.
