Skip to content
Extraits de code Groupes Projets
Non vérifiée Valider aa10200e rédigé par Eugen Rochko's avatar Eugen Rochko Validation de GitHub
Parcourir les fichiers

Fix streaming API allowing connections to persist after access token invalidation (#15111)

Fix #14816
parent 8532429a
Aucune branche associée trouvée
Aucune étiquette associée trouvée
Aucune requête de fusion associée trouvée
# frozen_string_literal: true
module AccessTokenExtension
extend ActiveSupport::Concern
included do
after_commit :push_to_streaming_api
end
def revoke(clock = Time)
update(revoked_at: clock.now.utc)
end
def push_to_streaming_api
Redis.current.publish("timeline:access_token:#{id}", Oj.dump(event: :kill)) if revoked? || destroyed?
end
end
...@@ -70,12 +70,16 @@ class SessionActivation < ApplicationRecord ...@@ -70,12 +70,16 @@ class SessionActivation < ApplicationRecord
end end
def assign_access_token def assign_access_token
superapp = Doorkeeper::Application.find_by(superapp: true) self.access_token = Doorkeeper::AccessToken.create!(access_token_attributes)
end
self.access_token = Doorkeeper::AccessToken.create!(application_id: superapp&.id, def access_token_attributes
resource_owner_id: user_id, {
scopes: 'read write follow', application_id: Doorkeeper::Application.find_by(superapp: true)&.id,
expires_in: Doorkeeper.configuration.access_token_expires_in, resource_owner_id: user_id,
use_refresh_token: Doorkeeper.configuration.refresh_token_enabled?) scopes: 'read write follow',
expires_in: Doorkeeper.configuration.access_token_expires_in,
use_refresh_token: Doorkeeper.configuration.refresh_token_enabled?,
}
end end
end end
...@@ -140,6 +140,7 @@ module Mastodon ...@@ -140,6 +140,7 @@ module Mastodon
Doorkeeper::AuthorizationsController.layout 'modal' Doorkeeper::AuthorizationsController.layout 'modal'
Doorkeeper::AuthorizedApplicationsController.layout 'admin' Doorkeeper::AuthorizedApplicationsController.layout 'admin'
Doorkeeper::Application.send :include, ApplicationExtension Doorkeeper::Application.send :include, ApplicationExtension
Doorkeeper::AccessToken.send :include, AccessTokenExtension
Devise::FailureApp.send :include, AbstractController::Callbacks Devise::FailureApp.send :include, AbstractController::Callbacks
Devise::FailureApp.send :include, HttpAcceptLanguage::EasyAccess Devise::FailureApp.send :include, HttpAcceptLanguage::EasyAccess
Devise::FailureApp.send :include, Localized Devise::FailureApp.send :include, Localized
......
...@@ -294,7 +294,7 @@ const startWorker = (workerId) => { ...@@ -294,7 +294,7 @@ const startWorker = (workerId) => {
return; return;
} }
client.query('SELECT oauth_access_tokens.resource_owner_id, users.account_id, users.chosen_languages, oauth_access_tokens.scopes, devices.device_id FROM oauth_access_tokens INNER JOIN users ON oauth_access_tokens.resource_owner_id = users.id LEFT OUTER JOIN devices ON oauth_access_tokens.id = devices.access_token_id WHERE oauth_access_tokens.token = $1 AND oauth_access_tokens.revoked_at IS NULL LIMIT 1', [token], (err, result) => { client.query('SELECT oauth_access_tokens.id, oauth_access_tokens.resource_owner_id, users.account_id, users.chosen_languages, oauth_access_tokens.scopes, devices.device_id FROM oauth_access_tokens INNER JOIN users ON oauth_access_tokens.resource_owner_id = users.id LEFT OUTER JOIN devices ON oauth_access_tokens.id = devices.access_token_id WHERE oauth_access_tokens.token = $1 AND oauth_access_tokens.revoked_at IS NULL LIMIT 1', [token], (err, result) => {
done(); done();
if (err) { if (err) {
...@@ -310,6 +310,7 @@ const startWorker = (workerId) => { ...@@ -310,6 +310,7 @@ const startWorker = (workerId) => {
return; return;
} }
req.accessTokenId = result.rows[0].id;
req.scopes = result.rows[0].scopes.split(' '); req.scopes = result.rows[0].scopes.split(' ');
req.accountId = result.rows[0].account_id; req.accountId = result.rows[0].account_id;
req.chosenLanguages = result.rows[0].chosen_languages; req.chosenLanguages = result.rows[0].chosen_languages;
...@@ -450,6 +451,55 @@ const startWorker = (workerId) => { ...@@ -450,6 +451,55 @@ const startWorker = (workerId) => {
}); });
}; };
/**
* @typedef SystemMessageHandlers
* @property {function(): void} onKill
*/
/**
* @param {any} req
* @param {SystemMessageHandlers} eventHandlers
* @return {function(string): void}
*/
const createSystemMessageListener = (req, eventHandlers) => {
return message => {
const json = parseJSON(message);
if (!json) return;
const { event } = json;
log.silly(req.requestId, `System message for ${req.accountId}: ${event}`);
if (event === 'kill') {
log.verbose(req.requestId, `Closing connection for ${req.accountId} due to expired access token`);
eventHandlers.onKill();
}
}
};
/**
* @param {any} req
* @param {any} res
*/
const subscribeHttpToSystemChannel = (req, res) => {
const systemChannelId = `timeline:access_token:${req.accessTokenId}`;
const listener = createSystemMessageListener(req, {
onKill () {
res.end();
},
});
res.on('close', () => {
unsubscribe(`${redisPrefix}${systemChannelId}`, listener);
});
subscribe(`${redisPrefix}${systemChannelId}`, listener);
};
/** /**
* @param {any} req * @param {any} req
* @param {any} res * @param {any} res
...@@ -462,6 +512,8 @@ const startWorker = (workerId) => { ...@@ -462,6 +512,8 @@ const startWorker = (workerId) => {
} }
accountFromRequest(req, alwaysRequireAuth).then(() => checkScopes(req, channelNameFromPath(req))).then(() => { accountFromRequest(req, alwaysRequireAuth).then(() => checkScopes(req, channelNameFromPath(req))).then(() => {
subscribeHttpToSystemChannel(req, res);
}).then(() => {
next(); next();
}).catch(err => { }).catch(err => {
next(err); next(err);
...@@ -536,7 +588,9 @@ const startWorker = (workerId) => { ...@@ -536,7 +588,9 @@ const startWorker = (workerId) => {
const listener = message => { const listener = message => {
const json = parseJSON(message); const json = parseJSON(message);
if (!json) return; if (!json) return;
const { event, payload, queued_at } = json; const { event, payload, queued_at } = json;
const transmit = () => { const transmit = () => {
...@@ -902,6 +956,28 @@ const startWorker = (workerId) => { ...@@ -902,6 +956,28 @@ const startWorker = (workerId) => {
socket.send(JSON.stringify({ error: err.toString() })); socket.send(JSON.stringify({ error: err.toString() }));
}); });
/**
* @param {WebSocketSession} session
*/
const subscribeWebsocketToSystemChannel = ({ socket, request, subscriptions }) => {
const systemChannelId = `timeline:access_token:${request.accessTokenId}`;
const listener = createSystemMessageListener(request, {
onKill () {
socket.close();
},
});
subscribe(`${redisPrefix}${systemChannelId}`, listener);
subscriptions[systemChannelId] = {
listener,
stopHeartbeat: () => {},
};
};
/** /**
* @param {string|string[]} arrayOrString * @param {string|string[]} arrayOrString
* @return {string} * @return {string}
...@@ -948,7 +1024,9 @@ const startWorker = (workerId) => { ...@@ -948,7 +1024,9 @@ const startWorker = (workerId) => {
ws.on('message', data => { ws.on('message', data => {
const json = parseJSON(data); const json = parseJSON(data);
if (!json) return; if (!json) return;
const { type, stream, ...params } = json; const { type, stream, ...params } = json;
if (type === 'subscribe') { if (type === 'subscribe') {
...@@ -960,6 +1038,8 @@ const startWorker = (workerId) => { ...@@ -960,6 +1038,8 @@ const startWorker = (workerId) => {
} }
}); });
subscribeWebsocketToSystemChannel(session);
if (location.query.stream) { if (location.query.stream) {
subscribeWebsocketToChannel(session, firstParam(location.query.stream), location.query); subscribeWebsocketToChannel(session, firstParam(location.query.stream), location.query);
} }
......
0% Chargement en cours ou .
You are about to add 0 people to the discussion. Proceed with caution.
Terminez d'abord l'édition de ce message.
Veuillez vous inscrire ou vous pour commenter