Comments (5)
Hi, your description of the problem is not clear enough for me to answer. The two scenarios are quite different.
- If you input one image, run TRIQ on it, and then do segmentation, I am not sure what the segmentation would be applied to. TRIQ does not change the input image; it only calculates a quality value. Please clarify your question.
- If you want to input two images, you need to modify the create_triq_model method: define the other operations (e.g., the backbone) on the two input images, and decide how to combine them ...
`def create_triq_model(n_quality_levels,
                      input_shape=(None, None, 3),
                      backbone='resnet50',
                      transformer_params=(2, 32, 8, 64),
                      maximum_position_encoding=193,
                      vis=False):
    """
    Creates the hybrid TRIQ model
    :param n_quality_levels: number of quality levels, use 5 to predict quality distribution
    :param input_shape: input shape
    :param backbone: backbone nets, supports ResNet50 and VGG16 now
    :param transformer_params: Transformer parameters
    :param maximum_position_encoding: the maximal number of positional embeddings
    :param vis: flag to visualize attention weight maps
    :return: TRIQ model
    """
    input_1 = Input(shape=input_shape)
    input_2 = Input(shape=input_shape)  # placeholder for your other image
    # TODO: define the backbone operations on input_2 and how to combine them with C5
    if backbone == 'resnet50':
        backbone_model = ResNet50(input_1,
                                  return_feature_maps=False, return_last_map=True)
    elif backbone == 'vgg16':
        backbone_model = VGG16(input_1, return_last_map=True)
    else:
        raise NotImplementedError
    C5 = backbone_model.output
    dropout_rate = 0.1
    transformer = TriQImageQualityTransformer(
        num_layers=transformer_params[0],
        d_model=transformer_params[1],
        num_heads=transformer_params[2],
        mlp_dim=transformer_params[3],
        dropout=dropout_rate,
        n_quality_levels=n_quality_levels,
        maximum_position_encoding=maximum_position_encoding,
        vis=vis
    )
    outputs = transformer(C5)
    model = Model(inputs=[input_1, input_2], outputs=outputs)
    model.summary()
    return model`
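For the two-image scenario, one possible way to "define the other operations and combine them" is sketched below. This is only an illustration under stated assumptions, not the TRIQ implementation: `build_two_input_model` and the small `extractor` network are hypothetical stand-ins for the TRIQ backbone and transformer, and concatenating pooled features is just one of several reasonable ways to merge the two branches.

```python
import tensorflow as tf
from tensorflow.keras import layers, Model


def build_two_input_model(input_shape=(None, None, 3), n_outputs=5):
    """Hypothetical two-image model: shared extractor, features concatenated."""
    input_1 = layers.Input(shape=input_shape)
    input_2 = layers.Input(shape=input_shape)

    # Stand-in feature extractor (assumption); in TRIQ this role is played
    # by the ResNet50/VGG16 backbone. Reusing one instance shares weights.
    extractor = tf.keras.Sequential([
        layers.Conv2D(8, 3, strides=2, padding="same", activation="relu"),
        layers.GlobalAveragePooling2D(),
    ])
    features_1 = extractor(input_1)
    features_2 = extractor(input_2)

    # Combine the two branches along the feature axis, then predict.
    merged = layers.Concatenate(axis=-1)([features_1, features_2])
    outputs = layers.Dense(n_outputs, activation="softmax")(merged)

    # Both inputs are connected to the output, so the model is valid.
    return Model(inputs=[input_1, input_2], outputs=outputs)


if __name__ == "__main__":
    model = build_two_input_model()
    image_a = tf.zeros((2, 32, 48, 3))
    image_b = tf.zeros((2, 32, 48, 3))
    print(model([image_a, image_b]).shape)
```

The key point is that every `Input` passed to `Model(inputs=[...])` has to feed into the graph that produces `outputs`.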
from triq.
`def create_triq_model(n_quality_levels,
                      input_shape=(None, None, 3),
                      backbone='resnet50',
                      transformer_params=(2, 32, 8, 64),
                      maximum_position_encoding=193,
                      vis=False):
    """
    Creates the hybrid TRIQ model
    :param n_quality_levels: number of quality levels, use 5 to predict quality distribution
    :param input_shape: input shape
    :param backbone: backbone nets, supports ResNet50 and VGG16 now
    :param transformer_params: Transformer parameters
    :param maximum_position_encoding: the maximal number of positional embeddings
    :param vis: flag to visualize attention weight maps
    :return: TRIQ model
    """
    inputs = Input(shape=input_shape)
    m = inputs
    x = tf.keras.layers.Lambda(tf.split, arguments={'axis': 2, 'num_or_size_splits': 2})(m)
    input1 = x[0]
    input2 = x[1]
    print(input1.shape)
    inputs1 = Input(shape=(input1.shape[1], input1.shape[2], input1.shape[3]))
    inputs2 = Input(shape=(input2.shape[1], input2.shape[2], input2.shape[3]))
    if backbone == 'resnet50':
        backbone_model1 = ResNet50(inputs1, return_feature_maps=False, return_last_map=True)
        backbone_model2 = ResNet50(inputs2, return_feature_maps=False, return_last_map=True)
    elif backbone == 'vgg16':
        backbone_model1 = VGG16(inputs1, return_last_map=True)
        backbone_model2 = VGG16(inputs2, return_last_map=True)
    else:
        raise NotImplementedError
    C51 = backbone_model1.output
    C52 = backbone_model2.output
    dropout_rate = 0.1
    transformer = TriQImageQualityTransformer(
        num_layers=transformer_params[0],
        d_model=transformer_params[1],
        num_heads=transformer_params[2],
        mlp_dim=transformer_params[3],
        dropout=dropout_rate,
        n_quality_levels=n_quality_levels,
        maximum_position_encoding=maximum_position_encoding,
        vis=vis
    )
    outputs1 = transformer(C51)
    outputs2 = transformer(C52)
    outputs = tf.concat([outputs1, outputs2], 0)
    model = Model(inputs=inputs, outputs=outputs)
    # model.summary()  # print the model parameters
    return model
`
Oh, I'm sorry. It is like this: the input data has shape [768, 2048, 3]. It is divided into two sets of data, each of size [768, 1024, 3], and each set is then put into the transformer separately to get the output.
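The split described here can be sketched in plain `tf.keras` as follows. This is a minimal illustration, not the TRIQ code: `build_split_model` and the small `scorer` network are hypothetical stand-ins for the TRIQ backbone plus transformer, and `Cropping2D` is used instead of `tf.split` purely as one convenient way to cut the width in half. The halves are sliced from the same `inputs` tensor rather than fed through fresh `Input` layers, so the graph from model input to output stays connected.

```python
import tensorflow as tf
from tensorflow.keras import layers, Model


def build_split_model(height=768, width=2048, channels=3, n_outputs=5):
    """Hypothetical model: split a wide image into two halves, score each."""
    inputs = layers.Input(shape=(height, width, channels))
    half = width // 2

    # Cropping2D takes ((top, bottom), (left, right)) crop amounts.
    left = layers.Cropping2D(cropping=((0, 0), (0, half)))(inputs)
    right = layers.Cropping2D(cropping=((0, 0), (half, 0)))(inputs)

    # Stand-in for backbone + transformer (assumption). Calling the same
    # instance on both halves shares its weights between them.
    scorer = tf.keras.Sequential([
        layers.Conv2D(8, 3, strides=2, padding="same", activation="relu"),
        layers.GlobalAveragePooling2D(),
        layers.Dense(n_outputs, activation="softmax"),
    ])
    out_left = scorer(left)
    out_right = scorer(right)

    # Keep one row per input image: join the two results on the last axis.
    outputs = layers.Concatenate(axis=-1)([out_left, out_right])
    return Model(inputs=inputs, outputs=outputs)


if __name__ == "__main__":
    model = build_split_model(height=32, width=64)
    print(model(tf.zeros((2, 32, 64, 3))).shape)
```

Concatenating along axis 0, as in the script above, would instead stack the two results along the batch dimension, so the output would have twice as many rows as there are input images.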
If you are using a fixed input size, you do not need to specify the input shape as (None, None, 3). I don't see any problems in your script.
Thank you very much for your patient reply and guidance. I think the problem may be in the model call. It probably works now.
Great to hear it works.